mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 17:35:57 +00:00
feat: MiniMax provider for image/video/podcast skills + new music-generation skill (#3437)
* docs(spec): MiniMax integration for generation skills + new music skill Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> * docs(plan): MiniMax generation providers implementation plan Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> * test(skills): add importlib loader + FakeResp for skill tests * test(skills): register loaded module in sys.modules; raise requests.HTTPError in FakeResp * feat(image-generation): add MiniMax provider with env auto-detect Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * refactor(image-generation): guard unknown provider, derive ref MIME, strengthen tests Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> * feat(video-generation): add MiniMax provider with async poll/download Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * refactor(video-generation): surface base_resp errors while polling; add timeout test * feat(podcast-generation): add MiniMax t2a_v2 provider with env auto-detect Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * refactor(podcast-generation): restore TTS credential guard; add volcengine + voice tests Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * feat(music-generation): new MiniMax music skill via skill-creator Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> * refactor(music-generation): treat empty lyrics as absent; test no-audio-data path * refactor(skills): add request timeouts to MiniMax network calls Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Potential fix for pull request finding 'Explicit returns mixed with implicit (fall through) returns' Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com> * fix(models): strip inconsistent user-message names for MiniMax chat DeerFlow middlewares tag user messages with provenance names (user-input, summary, loop_warning); langchain serializes them into the OpenAI-compatible payload and MiniMax rejects mismatched user-message names with "user name must be consistent (2013)". PatchedChatMiniMax now drops the per-message name from user-role messages. Point the config.example MiniMax models at PatchedChatMiniMax so they also get reasoning_content mapping. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> * feat(image-generation): MiniMax sends JSON prompt field, guard 1500-char limit MiniMax image-01 takes one text string capped at 1500 chars, but the skill was sending the whole structured JSON. The MiniMax provider now extracts the JSON `prompt` field (relying on prompt_optimizer to expand it) and fails fast with a clear error before calling the API when that field exceeds 1500 chars. Authoring stays provider-agnostic; Gemini still receives the full JSON. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> * feat(podcast-generation): per-provider TTS concurrency and retry/backoff Each TTS provider owns its concurrency internally — MiniMax runs single-threaded to reduce rate-limit failures, Volcengine keeps 4 workers — with automatic retry and backoff on transient HTTP and base_resp errors. No caller-facing concurrency knob. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> * fix(skills): address Copilot review comments on generation skills - video: add raise_for_status + timeout to the Gemini download/POST/poll calls so non-2xx responses surface as clear HTTP errors instead of JSON/KeyError or hangs - video: check the task Fail status before the generic base_resp check so the failure keeps its task_id context - video/image: create the output file parent directory before writing (matching music-generation) so nested output paths do not raise FileNotFoundError - music: require a non-empty prompt and fail fast with ValueError instead of sending an empty prompt to the API Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> * fix(scripts): reclaim dev ports across worktrees in make stop/dev All deer-flow worktrees (main checkout + linked worktrees) hardcode the same dev ports (8001/3000/2026), so a service started from any worktree must be reclaimable from another. stop_all now resolves the set of worktree roots (DEERFLOW_ROOTS) and treats a process as deer-flow-owned when its open files live under any of them. It also force-kills survivors on 2026 alongside 8001/3000, fixing `make dev` aborting on the nginx port preflight when a prior nginx lingered on 2026. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> * fix(view-image): hide the injected image-context message from the UI ViewImageMiddleware injects a HumanMessage (text + base64 images) so the vision model can see viewed images, but it was the only internal injector that set neither hide_from_ui nor a hidden name, so it leaked into the chat UI (and IM channels) as a user bubble reading "Here are the images you've viewed:". Mark it with additional_kwargs={"hide_from_ui": True}, matching todo/dynamic_context injections, which the frontend isHiddenFromUIMessage and the channel sender already honor. The model still receives the full content. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> * fix(minimax): mark M2.7 models as text-only (no vision) MiniMax M2.7 / M2.7-highspeed do not support vision; only M3 does. The provider config asserted vision support for M2.7 in four places. - config.example.yaml: 4 M2.7 entries -> supports_vision: false - backend/docs/CONFIGURATION.md: M2.7 + highspeed -> supports_vision: false - wizard: add LLMProvider.model_vision_overrides + extra_config_for() so selecting an M2.7 model writes supports_vision: false while M3 (default) keeps vision; wire it through setup_wizard.py - tests: M2.7-highspeed fixture -> supports_vision=False; add test_minimax_vision_is_per_model Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Co-authored-by: Willem Jiang <willem.jiang@gmail.com> Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
This commit is contained in:
@@ -178,6 +178,27 @@ For scenarios where visual accuracy is critical, **use the `image_search` tool f
|
||||
|
||||
This approach significantly improves generation quality by providing the model with concrete visual guidance rather than relying solely on text descriptions.
|
||||
|
||||
## Providers (Gemini / MiniMax)
|
||||
|
||||
This skill auto-selects the provider by environment variables (no CLI change):
|
||||
|
||||
- `GEMINI_API_KEY` set → use Gemini (default, unchanged).
|
||||
- Only `MINIMAX_API_KEY` set → use MiniMax (`/v1/image_generation`, model `image-01`).
|
||||
- Force one explicitly with `IMAGE_GENERATION_PROVIDER=gemini|minimax`.
|
||||
|
||||
MiniMax optional overrides: `MINIMAX_API_HOST` (default `https://api.minimaxi.com`),
|
||||
`MINIMAX_IMAGE_MODEL` (default `image-01`). Reference images are sent as the MiniMax
|
||||
`subject_reference` character image. The CLI and `--prompt-file` / `--reference-images`
|
||||
/ `--output-file` / `--aspect-ratio` arguments are identical for both providers.
|
||||
|
||||
**MiniMax prompt handling (provider-internal).** Authoring is provider-agnostic — write
|
||||
the same structured JSON regardless of which provider is active. MiniMax `image-01`
|
||||
consumes a single text string, so the MiniMax path itself sends only the JSON `prompt`
|
||||
field (the other fields such as `style` / `composition` / `negative_prompt` apply to the
|
||||
Gemini path) and enables `prompt_optimizer` so MiniMax expands it server-side. MiniMax
|
||||
caps that prompt at 1500 characters; if the `prompt` field is longer, the script returns
|
||||
an error instead of calling the API. The Gemini path receives the full structured JSON.
|
||||
|
||||
## Notes
|
||||
|
||||
- Always use English for prompts regardless of user's language
|
||||
|
||||
@@ -1,32 +1,196 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
MINIMAX_DEFAULT_HOST = "https://api.minimaxi.com"
|
||||
# MiniMax image-01 caps the prompt at 1500 characters and rejects longer requests
|
||||
# with a generic "invalid params" error, so validate before calling the API.
|
||||
MINIMAX_PROMPT_MAX_CHARS = 1500
|
||||
|
||||
|
||||
def validate_image(image_path: str) -> bool:
|
||||
"""
|
||||
Validate if an image file can be opened and is not corrupted.
|
||||
|
||||
Args:
|
||||
image_path: Path to the image file
|
||||
|
||||
Returns:
|
||||
True if the image is valid and can be opened, False otherwise
|
||||
"""
|
||||
"""Validate if an image file can be opened and is not corrupted."""
|
||||
from PIL import Image # lazy import: keeps module importable without Pillow
|
||||
|
||||
try:
|
||||
with Image.open(image_path) as img:
|
||||
img.verify() # Verify that it's a valid image
|
||||
# Re-open to check if it can be fully loaded (verify() may not catch all issues)
|
||||
with Image.open(image_path) as img:
|
||||
img.load() # Force load the image data
|
||||
with Image.open(image_path) as image:
|
||||
image.verify()
|
||||
with Image.open(image_path) as image:
|
||||
image.load()
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Warning: Image '{image_path}' is invalid or corrupted: {e}")
|
||||
except Exception as exc:
|
||||
print(f"Warning: Image '{image_path}' is invalid or corrupted: {exc}")
|
||||
return False
|
||||
|
||||
|
||||
def _resolve_provider(override_env: str, existing_provider: str, has_existing_creds: bool) -> str:
|
||||
"""Pick the generation provider.
|
||||
|
||||
1. Explicit <SKILL>_PROVIDER override wins.
|
||||
2. Otherwise prefer the existing provider when its credentials are present.
|
||||
3. Otherwise fall back to MiniMax when MINIMAX_API_KEY is set.
|
||||
"""
|
||||
override = os.getenv(override_env)
|
||||
if override:
|
||||
return override.strip().lower()
|
||||
if has_existing_creds:
|
||||
return existing_provider
|
||||
if os.getenv("MINIMAX_API_KEY"):
|
||||
return "minimax"
|
||||
raise ValueError(
|
||||
f"No credentials found. Set GEMINI_API_KEY for {existing_provider}, "
|
||||
f"or MINIMAX_API_KEY for minimax (optionally force with {override_env})."
|
||||
)
|
||||
|
||||
|
||||
def _minimax_host() -> str:
|
||||
return os.getenv("MINIMAX_API_HOST", MINIMAX_DEFAULT_HOST).rstrip("/")
|
||||
|
||||
|
||||
def _check_base_resp(payload: dict) -> None:
|
||||
base = payload.get("base_resp") or {}
|
||||
if base.get("status_code", 0) != 0:
|
||||
raise Exception(
|
||||
f"MiniMax error {base.get('status_code')}: {base.get('status_msg')}"
|
||||
)
|
||||
|
||||
|
||||
def _guess_mime(image_path: str) -> str:
|
||||
ext = os.path.splitext(image_path)[1].lower()
|
||||
return {
|
||||
".png": "image/png",
|
||||
".webp": "image/webp",
|
||||
".gif": "image/gif",
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
}.get(ext, "image/jpeg")
|
||||
|
||||
|
||||
def _to_data_url(image_path: str) -> str:
|
||||
with open(image_path, "rb") as f:
|
||||
b64 = base64.b64encode(f.read()).decode("utf-8")
|
||||
return f"data:{_guess_mime(image_path)};base64,{b64}"
|
||||
|
||||
|
||||
def _ensure_output_dir(output_file: str) -> None:
|
||||
"""Create the output file's parent directory so nested paths don't fail."""
|
||||
output_dir = os.path.dirname(output_file)
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
|
||||
def _minimax_prompt(raw: str) -> str:
|
||||
"""Extract the single text prompt MiniMax image-01 expects.
|
||||
|
||||
The shared prompt file is structured JSON (a consolidated ``prompt`` plus
|
||||
Gemini-oriented fields like ``style`` / ``composition`` / ``negative_prompt``),
|
||||
but MiniMax consumes one string and expands it via ``prompt_optimizer``. The
|
||||
provider adapts the input itself — the caller never needs to know MiniMax is
|
||||
active. Use the JSON ``prompt`` field; fall back to the raw text for plain-text
|
||||
prompt files or JSON without a ``prompt`` field.
|
||||
"""
|
||||
text = raw.strip()
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
return text
|
||||
if isinstance(data, dict):
|
||||
core = data.get("prompt")
|
||||
if isinstance(core, str) and core.strip():
|
||||
return core.strip()
|
||||
return text
|
||||
|
||||
|
||||
def _generate_image_minimax(
|
||||
prompt: str, reference_images: list[str], output_file: str, aspect_ratio: str
|
||||
) -> str:
|
||||
api_key = os.getenv("MINIMAX_API_KEY")
|
||||
if not api_key:
|
||||
return "MINIMAX_API_KEY is not set"
|
||||
prompt = _minimax_prompt(prompt)
|
||||
if len(prompt) > MINIMAX_PROMPT_MAX_CHARS:
|
||||
return (
|
||||
f"Prompt is {len(prompt)} characters but MiniMax image-01 accepts at most "
|
||||
f"{MINIMAX_PROMPT_MAX_CHARS}. Shorten the prompt to stay within the limit; "
|
||||
f"reference images plus a tighter description usually recover the detail."
|
||||
)
|
||||
body = {
|
||||
"model": os.getenv("MINIMAX_IMAGE_MODEL", "image-01"),
|
||||
"prompt": prompt,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"response_format": "base64",
|
||||
"n": 1,
|
||||
"prompt_optimizer": True,
|
||||
}
|
||||
if reference_images:
|
||||
# Reference images are passed as character subjects as-is; unlike the Gemini
|
||||
# path we do not pre-validate them — invalid files surface as a MiniMax API error.
|
||||
body["subject_reference"] = [
|
||||
{"type": "character", "image_file": _to_data_url(p)} for p in reference_images
|
||||
]
|
||||
response = requests.post(
|
||||
f"{_minimax_host()}/v1/image_generation",
|
||||
headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
|
||||
json=body,
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
_check_base_resp(payload)
|
||||
images = (payload.get("data") or {}).get("image_base64") or []
|
||||
if not images:
|
||||
raise Exception("MiniMax returned no image data")
|
||||
_ensure_output_dir(output_file)
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(base64.b64decode(images[0]))
|
||||
return f"Successfully generated image to {output_file}"
|
||||
|
||||
|
||||
def _generate_image_gemini(
|
||||
prompt: str, reference_images: list[str], output_file: str, aspect_ratio: str
|
||||
) -> str:
|
||||
parts = []
|
||||
valid_reference_images = []
|
||||
for ref_img in reference_images:
|
||||
if validate_image(ref_img):
|
||||
valid_reference_images.append(ref_img)
|
||||
else:
|
||||
print(f"Skipping invalid reference image: {ref_img}")
|
||||
if len(valid_reference_images) < len(reference_images):
|
||||
skipped = len(reference_images) - len(valid_reference_images)
|
||||
print(f"Note: {skipped} reference image(s) were skipped due to validation failure.")
|
||||
|
||||
for reference_image in valid_reference_images:
|
||||
with open(reference_image, "rb") as f:
|
||||
image_b64 = base64.b64encode(f.read()).decode("utf-8")
|
||||
parts.append({"inlineData": {"mimeType": "image/jpeg", "data": image_b64}})
|
||||
|
||||
api_key = os.getenv("GEMINI_API_KEY")
|
||||
if not api_key:
|
||||
return "GEMINI_API_KEY is not set"
|
||||
response = requests.post(
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/gemini-3-pro-image-preview:generateContent",
|
||||
headers={"x-goog-api-key": api_key, "Content-Type": "application/json"},
|
||||
json={
|
||||
"generationConfig": {"imageConfig": {"aspectRatio": aspect_ratio}},
|
||||
"contents": [{"parts": [*parts, {"text": prompt}]}],
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
response_parts: list[dict] = data["candidates"][0]["content"]["parts"]
|
||||
image_parts = [part for part in response_parts if part.get("inlineData", False)]
|
||||
if len(image_parts) == 1:
|
||||
base64_image = image_parts[0]["inlineData"]["data"]
|
||||
_ensure_output_dir(output_file)
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(base64.b64decode(base64_image))
|
||||
return f"Successfully generated image to {output_file}"
|
||||
raise Exception("Failed to generate image")
|
||||
|
||||
|
||||
def generate_image(
|
||||
prompt_file: str,
|
||||
reference_images: list[str],
|
||||
@@ -35,98 +199,30 @@ def generate_image(
|
||||
) -> str:
|
||||
with open(prompt_file, "r", encoding="utf-8") as f:
|
||||
prompt = f.read()
|
||||
parts = []
|
||||
i = 0
|
||||
|
||||
# Filter out invalid reference images
|
||||
valid_reference_images = []
|
||||
for ref_img in reference_images:
|
||||
if validate_image(ref_img):
|
||||
valid_reference_images.append(ref_img)
|
||||
else:
|
||||
print(f"Skipping invalid reference image: {ref_img}")
|
||||
|
||||
if len(valid_reference_images) < len(reference_images):
|
||||
print(f"Note: {len(reference_images) - len(valid_reference_images)} reference image(s) were skipped due to validation failure.")
|
||||
|
||||
for reference_image in valid_reference_images:
|
||||
i += 1
|
||||
with open(reference_image, "rb") as f:
|
||||
image_b64 = base64.b64encode(f.read()).decode("utf-8")
|
||||
parts.append(
|
||||
{
|
||||
"inlineData": {
|
||||
"mimeType": "image/jpeg",
|
||||
"data": image_b64,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
api_key = os.getenv("GEMINI_API_KEY")
|
||||
if not api_key:
|
||||
return "GEMINI_API_KEY is not set"
|
||||
response = requests.post(
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/gemini-3-pro-image-preview:generateContent",
|
||||
headers={
|
||||
"x-goog-api-key": api_key,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"generationConfig": {"imageConfig": {"aspectRatio": aspect_ratio}},
|
||||
"contents": [{"parts": [*parts, {"text": prompt}]}],
|
||||
},
|
||||
provider = _resolve_provider(
|
||||
"IMAGE_GENERATION_PROVIDER", "gemini", bool(os.getenv("GEMINI_API_KEY"))
|
||||
)
|
||||
response.raise_for_status()
|
||||
json = response.json()
|
||||
parts: list[dict] = json["candidates"][0]["content"]["parts"]
|
||||
image_parts = [part for part in parts if part.get("inlineData", False)]
|
||||
if len(image_parts) == 1:
|
||||
base64_image = image_parts[0]["inlineData"]["data"]
|
||||
# Save the image to a file
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(base64.b64decode(base64_image))
|
||||
return f"Successfully generated image to {output_file}"
|
||||
else:
|
||||
raise Exception("Failed to generate image")
|
||||
if provider == "minimax":
|
||||
return _generate_image_minimax(prompt, reference_images, output_file, aspect_ratio)
|
||||
if provider in ("gemini", "google"):
|
||||
return _generate_image_gemini(prompt, reference_images, output_file, aspect_ratio)
|
||||
raise ValueError(f"Unknown image provider: {provider!r} (use 'gemini' or 'minimax')")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Generate images using Gemini API")
|
||||
parser.add_argument(
|
||||
"--prompt-file",
|
||||
required=True,
|
||||
help="Absolute path to JSON prompt file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reference-images",
|
||||
nargs="*",
|
||||
default=[],
|
||||
help="Absolute paths to reference images (space-separated)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-file",
|
||||
required=True,
|
||||
help="Output path for generated image",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--aspect-ratio",
|
||||
required=False,
|
||||
default="16:9",
|
||||
help="Aspect ratio of the generated image",
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser(description="Generate images using Gemini or MiniMax API")
|
||||
parser.add_argument("--prompt-file", required=True, help="Absolute path to JSON prompt file")
|
||||
parser.add_argument("--reference-images", nargs="*", default=[],
|
||||
help="Absolute paths to reference images (space-separated)")
|
||||
parser.add_argument("--output-file", required=True, help="Output path for generated image")
|
||||
parser.add_argument("--aspect-ratio", required=False, default="16:9",
|
||||
help="Aspect ratio of the generated image")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
print(
|
||||
generate_image(
|
||||
args.prompt_file,
|
||||
args.reference_images,
|
||||
args.output_file,
|
||||
args.aspect_ratio,
|
||||
)
|
||||
)
|
||||
print(generate_image(args.prompt_file, args.reference_images,
|
||||
args.output_file, args.aspect_ratio))
|
||||
except Exception as e:
|
||||
print(f"Error while generating image: {e}")
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
---
|
||||
name: music-generation
|
||||
description: Use this skill when the user requests to generate, create, compose, or produce music or songs — background music, theme songs, jingles, or instrumental tracks. Generates a song from a style/mood prompt and optional lyrics via the MiniMax music API.
|
||||
---
|
||||
|
||||
# Music Generation Skill
|
||||
|
||||
## Overview
|
||||
|
||||
This skill generates songs (vocal or instrumental) from a structured JSON spec using the
|
||||
MiniMax music generation API (`/v1/music_generation`). You describe the style/mood/scene in
|
||||
`prompt`, optionally provide `lyrics`, and the script returns an MP3.
|
||||
|
||||
## Workflow
|
||||
|
||||
### Step 1: Understand Requirements
|
||||
|
||||
Identify the desired style, mood, scene, language, and whether the user wants vocals or a
|
||||
pure instrumental track. Decide whether to supply lyrics or let the model write them.
|
||||
|
||||
### Step 2: Create the Spec JSON
|
||||
|
||||
Write a JSON file in `/mnt/user-data/workspace/` named `{descriptive-name}.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"title": "Rainy Night Cafe",
|
||||
"prompt": "indie folk, melancholic, introspective, walking alone, cafe",
|
||||
"lyrics": "[verse]\nStreetlights glow the night wind sighs\n[chorus]\nPush the wooden door warm air inside"
|
||||
}
|
||||
```
|
||||
|
||||
Fields:
|
||||
- `title` (optional): a human-readable name.
|
||||
- `prompt` (required): style, mood, and scene. Drives the musical character.
|
||||
- `lyrics` (optional): song lyrics. Use `\n` between lines and structure tags such as
|
||||
`[Intro]`, `[Verse]`, `[Pre Chorus]`, `[Chorus]`, `[Bridge]`, `[Outro]`.
|
||||
- `is_instrumental` (optional, bool): set `true` for a pure instrumental track (no lyrics needed).
|
||||
|
||||
Behavior:
|
||||
- `lyrics` provided → those lyrics are sung.
|
||||
- `is_instrumental: true` → instrumental, no vocals.
|
||||
- neither → the model auto-writes lyrics from `prompt` (`lyrics_optimizer`).
|
||||
|
||||
### Step 3: Execute Generation
|
||||
|
||||
```bash
|
||||
python /mnt/skills/public/music-generation/scripts/generate.py \
|
||||
--prompt-file /mnt/user-data/workspace/rainy-night-cafe.json \
|
||||
--output-file /mnt/user-data/outputs/rainy-night-cafe.mp3
|
||||
```
|
||||
|
||||
Parameters:
|
||||
- `--prompt-file`: Absolute path to the JSON spec (required).
|
||||
- `--output-file`: Absolute path for the output MP3 (required).
|
||||
|
||||
[!NOTE]
|
||||
Do NOT read the python file, just call it with the parameters.
|
||||
|
||||
## Environment
|
||||
|
||||
- `MINIMAX_API_KEY` (required): your MiniMax interface key.
|
||||
- `MINIMAX_API_HOST` (optional): default `https://api.minimaxi.com`.
|
||||
- `MINIMAX_MUSIC_MODEL` (optional): default `music-2.6-free` (works for all API-key users);
|
||||
paid/Token-Plan users can set `music-2.6` for higher limits.
|
||||
|
||||
## Output Handling
|
||||
|
||||
- Music is saved as MP3 (typically in `/mnt/user-data/outputs/`).
|
||||
- Share the generated file with the user using the present_files tool.
|
||||
- Offer to iterate on style or lyrics if adjustments are needed.
|
||||
|
||||
## Notes
|
||||
|
||||
- Keep `prompt` focused on style/mood/scene; put the actual sung words in `lyrics`.
|
||||
- For non-English songs, write `lyrics` in the target language.
|
||||
@@ -0,0 +1,82 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import requests
|
||||
|
||||
MINIMAX_DEFAULT_HOST = "https://api.minimaxi.com"
|
||||
|
||||
|
||||
def _check_base_resp(payload: dict) -> None:
|
||||
base = payload.get("base_resp") or {}
|
||||
if base.get("status_code", 0) != 0:
|
||||
raise Exception(f"MiniMax error {base.get('status_code')}: {base.get('status_msg')}")
|
||||
|
||||
|
||||
def generate_music(prompt_file: str, output_file: str) -> str:
|
||||
"""Generate a song from a JSON spec via MiniMax /v1/music_generation.
|
||||
|
||||
Spec JSON: {"title": str, "prompt": str, "lyrics"?: str, "is_instrumental"?: bool}
|
||||
- lyrics given -> use them (supports [Verse]/[Chorus] structure tags, \\n lines)
|
||||
- is_instrumental true -> pure music, no lyrics needed
|
||||
- otherwise -> lyrics_optimizer auto-writes lyrics from prompt
|
||||
"""
|
||||
with open(prompt_file, "r", encoding="utf-8") as f:
|
||||
spec = json.load(f)
|
||||
|
||||
api_key = os.getenv("MINIMAX_API_KEY")
|
||||
if not api_key:
|
||||
return "MINIMAX_API_KEY is not set"
|
||||
|
||||
prompt = (spec.get("prompt") or "").strip()
|
||||
if not prompt:
|
||||
raise ValueError("`prompt` is required in the music spec")
|
||||
lyrics = spec.get("lyrics") or None # treat empty string the same as absent
|
||||
is_instrumental = bool(spec.get("is_instrumental", False))
|
||||
|
||||
body = {
|
||||
"model": os.getenv("MINIMAX_MUSIC_MODEL", "music-2.6-free"),
|
||||
"prompt": prompt,
|
||||
"output_format": "hex",
|
||||
"audio_setting": {"sample_rate": 44100, "bitrate": 256000, "format": "mp3"},
|
||||
}
|
||||
if lyrics:
|
||||
body["lyrics"] = lyrics
|
||||
elif is_instrumental:
|
||||
body["is_instrumental"] = True
|
||||
else:
|
||||
body["lyrics_optimizer"] = True
|
||||
|
||||
host = os.getenv("MINIMAX_API_HOST", MINIMAX_DEFAULT_HOST).rstrip("/")
|
||||
response = requests.post(
|
||||
f"{host}/v1/music_generation",
|
||||
headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
|
||||
json=body,
|
||||
timeout=300,
|
||||
)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
_check_base_resp(payload)
|
||||
audio_hex = (payload.get("data") or {}).get("audio")
|
||||
if not audio_hex:
|
||||
raise Exception("MiniMax returned no audio data")
|
||||
|
||||
output_dir = os.path.dirname(output_file)
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(bytes.fromhex(audio_hex))
|
||||
return f"Successfully generated music to {output_file}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Generate music using MiniMax API")
|
||||
parser.add_argument("--prompt-file", required=True,
|
||||
help="Absolute path to JSON spec file {title, prompt, lyrics?, is_instrumental?}")
|
||||
parser.add_argument("--output-file", required=True, help="Output path for generated MP3")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
print(generate_music(args.prompt_file, args.output_file))
|
||||
except Exception as e:
|
||||
print(f"Error while generating music: {e}")
|
||||
@@ -64,6 +64,7 @@ Parameters:
|
||||
> - The script handles all TTS API calls and audio generation internally.
|
||||
> - Do NOT read the Python file, just call it with the parameters.
|
||||
> - Always include `--transcript-file` to generate a readable transcript for the user.
|
||||
> - The TTS provider and its concurrency are selected automatically from environment variables — you do not choose or tune them.
|
||||
|
||||
## Script JSON Format
|
||||
|
||||
@@ -172,8 +173,8 @@ After generation:
|
||||
## Requirements
|
||||
|
||||
The following environment variables must be set:
|
||||
- `VOLCENGINE_TTS_APPID`: Volcengine TTS application ID
|
||||
- `VOLCENGINE_TTS_ACCESS_TOKEN`: Volcengine TTS access token
|
||||
- For Volcengine: `VOLCENGINE_TTS_APPID` and `VOLCENGINE_TTS_ACCESS_TOKEN`
|
||||
- For MiniMax: `MINIMAX_API_KEY`
|
||||
- `VOLCENGINE_TTS_CLUSTER`: Volcengine TTS cluster (optional, defaults to "volcano_tts")
|
||||
|
||||
## Notes
|
||||
@@ -183,3 +184,20 @@ The following environment variables must be set:
|
||||
- Technical content should be simplified for audio accessibility in the script
|
||||
- Complex notations (formulas, code) should be translated to plain language in the script
|
||||
- Long content may result in longer podcasts
|
||||
|
||||
## Providers (Volcengine / MiniMax)
|
||||
|
||||
Auto-selected by environment variables:
|
||||
|
||||
- `VOLCENGINE_TTS_APPID` + `VOLCENGINE_TTS_ACCESS_TOKEN` set → Volcengine TTS (default).
|
||||
- Only `MINIMAX_API_KEY` set → MiniMax TTS (`/v1/t2a_v2`).
|
||||
- Force with `PODCAST_GENERATION_PROVIDER=volcengine|minimax`.
|
||||
|
||||
MiniMax overrides: `MINIMAX_API_HOST` (default `https://api.minimaxi.com`),
|
||||
`MINIMAX_TTS_MODEL` (default `speech-2.6-hd`), `MINIMAX_TTS_VOICE_MALE`
|
||||
(default `male-qn-qingse`), `MINIMAX_TTS_VOICE_FEMALE` (default `female-tianmei`).
|
||||
|
||||
Concurrency is owned by each provider internally — MiniMax runs single-threaded
|
||||
to reduce rate-limit failures, Volcengine uses 4 workers. There is no
|
||||
caller-facing concurrency knob; transient rate limits are handled by automatic
|
||||
retry with backoff.
|
||||
|
||||
@@ -3,6 +3,8 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Literal, Optional
|
||||
@@ -12,8 +14,14 @@ import requests
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MINIMAX_DEFAULT_HOST = "https://api.minimaxi.com"
|
||||
# MiniMax base_resp codes worth retrying: unknown, timeout, RPM limit, TPM limit.
|
||||
MINIMAX_RETRYABLE_CODES = {1000, 1001, 1002, 1039}
|
||||
DEFAULT_TTS_MAX_RETRIES = 4
|
||||
DEFAULT_MAX_WORKERS = 4
|
||||
DEFAULT_MINIMAX_MAX_WORKERS = 1
|
||||
|
||||
|
||||
# Types
|
||||
class ScriptLine:
|
||||
def __init__(self, speaker: Literal["male", "female"] = "male", paragraph: str = ""):
|
||||
self.speaker = speaker
|
||||
@@ -30,113 +38,243 @@ class Script:
|
||||
script = cls(locale=data.get("locale", "en"))
|
||||
for line in data.get("lines", []):
|
||||
script.lines.append(
|
||||
ScriptLine(
|
||||
speaker=line.get("speaker", "male"),
|
||||
paragraph=line.get("paragraph", ""),
|
||||
)
|
||||
ScriptLine(speaker=line.get("speaker", "male"),
|
||||
paragraph=line.get("paragraph", ""))
|
||||
)
|
||||
return script
|
||||
|
||||
|
||||
def text_to_speech(text: str, voice_type: str) -> Optional[bytes]:
|
||||
"""Convert text to speech using Volcengine TTS."""
|
||||
def _resolve_provider(override_env: str, existing_provider: str, has_existing_creds: bool) -> str:
|
||||
override = os.getenv(override_env)
|
||||
if override:
|
||||
return override.strip().lower()
|
||||
if has_existing_creds:
|
||||
return existing_provider
|
||||
if os.getenv("MINIMAX_API_KEY"):
|
||||
return "minimax"
|
||||
raise ValueError(
|
||||
f"No credentials found. Set VOLCENGINE_TTS_APPID + VOLCENGINE_TTS_ACCESS_TOKEN "
|
||||
f"for {existing_provider}, or MINIMAX_API_KEY for minimax "
|
||||
f"(optionally force with {override_env})."
|
||||
)
|
||||
|
||||
|
||||
def _resolve_tts_provider() -> str:
|
||||
has_volc = bool(
|
||||
os.getenv("VOLCENGINE_TTS_APPID") and os.getenv("VOLCENGINE_TTS_ACCESS_TOKEN")
|
||||
)
|
||||
provider = _resolve_provider("PODCAST_GENERATION_PROVIDER", "volcengine", has_volc)
|
||||
if provider not in ("volcengine", "minimax"):
|
||||
raise ValueError(
|
||||
f"Unknown podcast provider: {provider!r} (use 'volcengine' or 'minimax')"
|
||||
)
|
||||
return provider
|
||||
|
||||
|
||||
def _default_max_retries() -> int:
|
||||
try:
|
||||
return int(os.getenv("MINIMAX_TTS_MAX_RETRIES", str(DEFAULT_TTS_MAX_RETRIES)))
|
||||
except ValueError:
|
||||
return DEFAULT_TTS_MAX_RETRIES
|
||||
|
||||
|
||||
def _default_max_workers(provider: str) -> int:
|
||||
"""Each provider owns its own concurrency: MiniMax stays low to avoid rate
|
||||
limits, Volcengine keeps the historical default. Not user-tunable by design.
|
||||
"""
|
||||
if provider == "minimax":
|
||||
return DEFAULT_MINIMAX_MAX_WORKERS
|
||||
return DEFAULT_MAX_WORKERS
|
||||
|
||||
|
||||
def _parse_retry_after(response) -> Optional[float]:
|
||||
"""Return the server-provided Retry-After (seconds), if any."""
|
||||
headers = getattr(response, "headers", None) or {}
|
||||
value = headers.get("Retry-After")
|
||||
try:
|
||||
return float(value) if value else None
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def _backoff_sleep(attempt: int, retry_after: Optional[float]) -> None:
|
||||
"""Sleep with exponential backoff + jitter, honoring Retry-After when present.
|
||||
|
||||
Jitter de-synchronizes concurrent workers that all got rate-limited at once,
|
||||
avoiding a thundering-herd retry storm.
|
||||
"""
|
||||
base = retry_after if retry_after else min(2 ** attempt, 30)
|
||||
time.sleep(base + random.uniform(0, 1))
|
||||
|
||||
|
||||
def text_to_speech_volcengine(
|
||||
text: str, voice_type: str, max_retries: Optional[int] = None
|
||||
) -> Optional[bytes]:
|
||||
"""Convert text to speech using Volcengine TTS (returns base64-decoded mp3 bytes).
|
||||
|
||||
Retries with exponential backoff on transient HTTP errors (429 / 5xx).
|
||||
"""
|
||||
app_id = os.getenv("VOLCENGINE_TTS_APPID")
|
||||
access_token = os.getenv("VOLCENGINE_TTS_ACCESS_TOKEN")
|
||||
cluster = os.getenv("VOLCENGINE_TTS_CLUSTER", "volcano_tts")
|
||||
|
||||
if not app_id or not access_token:
|
||||
raise ValueError(
|
||||
"VOLCENGINE_TTS_APPID and VOLCENGINE_TTS_ACCESS_TOKEN environment variables must be set"
|
||||
)
|
||||
|
||||
if max_retries is None:
|
||||
max_retries = _default_max_retries()
|
||||
url = "https://openspeech.bytedance.com/api/v1/tts"
|
||||
|
||||
# Authentication: Bearer token with semicolon separator
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer;{access_token}",
|
||||
}
|
||||
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer;{access_token}"}
|
||||
payload = {
|
||||
"app": {
|
||||
"appid": app_id,
|
||||
"token": "access_token", # literal string, not the actual token
|
||||
"cluster": cluster,
|
||||
},
|
||||
"app": {"appid": app_id, "token": "access_token", "cluster": cluster},
|
||||
"user": {"uid": "podcast-generator"},
|
||||
"audio": {
|
||||
"voice_type": voice_type,
|
||||
"encoding": "mp3",
|
||||
"speed_ratio": 1.2,
|
||||
},
|
||||
"request": {
|
||||
"reqid": str(uuid.uuid4()), # must be unique UUID
|
||||
"text": text,
|
||||
"text_type": "plain",
|
||||
"operation": "query",
|
||||
},
|
||||
"audio": {"voice_type": voice_type, "encoding": "mp3", "speed_ratio": 1.2},
|
||||
"request": {"reqid": str(uuid.uuid4()), "text": text,
|
||||
"text_type": "plain", "operation": "query"},
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(url, json=payload, headers=headers)
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
response = requests.post(url, json=payload, headers=headers, timeout=60)
|
||||
except Exception as e:
|
||||
logger.error(f"TTS error: {e}")
|
||||
if attempt < max_retries:
|
||||
_backoff_sleep(attempt, None)
|
||||
continue
|
||||
return None
|
||||
if response.status_code == 429 or response.status_code >= 500:
|
||||
logger.warning(
|
||||
f"Volcengine TTS transient HTTP {response.status_code} "
|
||||
f"(attempt {attempt + 1}/{max_retries + 1})"
|
||||
)
|
||||
if attempt < max_retries:
|
||||
_backoff_sleep(attempt, _parse_retry_after(response))
|
||||
continue
|
||||
return None
|
||||
if response.status_code != 200:
|
||||
logger.error(f"TTS API error: {response.status_code} - {response.text}")
|
||||
return None
|
||||
|
||||
result = response.json()
|
||||
if result.get("code") != 3000:
|
||||
logger.error(f"TTS error: {result.get('message')} (code: {result.get('code')})")
|
||||
return None
|
||||
|
||||
audio_data = result.get("data")
|
||||
if audio_data:
|
||||
return base64.b64decode(audio_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"TTS error: {str(e)}")
|
||||
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _process_line(args: tuple[int, ScriptLine, int]) -> tuple[int, Optional[bytes]]:
|
||||
def text_to_speech_minimax(
|
||||
text: str, voice_id: str, max_retries: Optional[int] = None
|
||||
) -> Optional[bytes]:
|
||||
"""Convert text to speech using MiniMax t2a_v2 (returns hex-decoded mp3 bytes).
|
||||
|
||||
Retries with exponential backoff on HTTP 429/5xx and on retryable base_resp
|
||||
codes (rate/TPM limits, timeouts). Permanent errors (auth, balance, bad input)
|
||||
are not retried.
|
||||
"""
|
||||
api_key = os.getenv("MINIMAX_API_KEY")
|
||||
host = os.getenv("MINIMAX_API_HOST", MINIMAX_DEFAULT_HOST).rstrip("/")
|
||||
if max_retries is None:
|
||||
max_retries = _default_max_retries()
|
||||
payload = {
|
||||
"model": os.getenv("MINIMAX_TTS_MODEL", "speech-2.6-hd"),
|
||||
"text": text,
|
||||
"voice_setting": {"voice_id": voice_id, "speed": 1.0, "vol": 1.0, "pitch": 0},
|
||||
"audio_setting": {"sample_rate": 32000, "bitrate": 128000, "format": "mp3", "channel": 1},
|
||||
"output_format": "hex",
|
||||
}
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{host}/v1/t2a_v2",
|
||||
headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
|
||||
json=payload,
|
||||
timeout=60,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"MiniMax TTS error: {e}")
|
||||
if attempt < max_retries:
|
||||
_backoff_sleep(attempt, None)
|
||||
continue
|
||||
return None
|
||||
if response.status_code == 429 or response.status_code >= 500:
|
||||
logger.warning(
|
||||
f"MiniMax TTS rate-limited HTTP {response.status_code} "
|
||||
f"(attempt {attempt + 1}/{max_retries + 1})"
|
||||
)
|
||||
if attempt < max_retries:
|
||||
_backoff_sleep(attempt, _parse_retry_after(response))
|
||||
continue
|
||||
return None
|
||||
if response.status_code != 200:
|
||||
logger.error(f"MiniMax TTS error: {response.status_code} - {response.text}")
|
||||
return None
|
||||
result = response.json()
|
||||
base = result.get("base_resp") or {}
|
||||
code = base.get("status_code", 0)
|
||||
if code in MINIMAX_RETRYABLE_CODES:
|
||||
logger.warning(
|
||||
f"MiniMax TTS retryable error {code}: {base.get('status_msg')} "
|
||||
f"(attempt {attempt + 1}/{max_retries + 1})"
|
||||
)
|
||||
if attempt < max_retries:
|
||||
_backoff_sleep(attempt, None)
|
||||
continue
|
||||
return None
|
||||
if code != 0:
|
||||
logger.error(f"MiniMax TTS error {code}: {base.get('status_msg')}")
|
||||
return None
|
||||
audio_hex = (result.get("data") or {}).get("audio")
|
||||
if audio_hex:
|
||||
return bytes.fromhex(audio_hex)
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _process_line(args: tuple[int, ScriptLine, int, str]) -> tuple[int, Optional[bytes]]:
|
||||
"""Process a single script line for TTS. Returns (index, audio_bytes)."""
|
||||
i, line, total = args
|
||||
|
||||
# Select voice based on speaker gender
|
||||
if line.speaker == "male":
|
||||
voice_type = "zh_male_yangguangqingnian_moon_bigtts" # Male voice
|
||||
i, line, total, provider = args
|
||||
logger.info(f"Processing line {i + 1}/{total} ({line.speaker}) via {provider}")
|
||||
if provider == "minimax":
|
||||
if line.speaker == "male":
|
||||
voice = os.getenv("MINIMAX_TTS_VOICE_MALE", "male-qn-qingse")
|
||||
else:
|
||||
voice = os.getenv("MINIMAX_TTS_VOICE_FEMALE", "female-tianmei")
|
||||
audio = text_to_speech_minimax(line.paragraph, voice)
|
||||
else:
|
||||
voice_type = "zh_female_sajiaonvyou_moon_bigtts" # Female voice
|
||||
|
||||
logger.info(f"Processing line {i + 1}/{total} ({line.speaker})")
|
||||
audio = text_to_speech(line.paragraph, voice_type)
|
||||
|
||||
if line.speaker == "male":
|
||||
voice = "zh_male_yangguangqingnian_moon_bigtts"
|
||||
else:
|
||||
voice = "zh_female_sajiaonvyou_moon_bigtts"
|
||||
audio = text_to_speech_volcengine(line.paragraph, voice)
|
||||
if not audio:
|
||||
logger.warning(f"Failed to generate audio for line {i + 1}")
|
||||
|
||||
return (i, audio)
|
||||
|
||||
|
||||
def tts_node(script: Script, max_workers: int = 4) -> list[bytes]:
|
||||
"""Convert script lines to audio chunks using TTS with multi-threading."""
|
||||
logger.info(f"Converting script to audio using {max_workers} workers...")
|
||||
def tts_node(script: Script) -> list[bytes]:
|
||||
"""Convert script lines to audio chunks using TTS with multi-threading.
|
||||
|
||||
Concurrency is owned by the resolved provider (see _default_max_workers);
|
||||
there is no caller-facing knob. Fails loudly: if any line cannot be
|
||||
synthesized (even after retries), raise rather than silently emitting an
|
||||
incomplete podcast.
|
||||
"""
|
||||
total = len(script.lines)
|
||||
|
||||
# Handle empty script case
|
||||
if total == 0:
|
||||
raise ValueError("Script contains no lines to process")
|
||||
|
||||
# Validate required environment variables before starting TTS
|
||||
if not os.getenv("VOLCENGINE_TTS_APPID") or not os.getenv("VOLCENGINE_TTS_ACCESS_TOKEN"):
|
||||
provider = _resolve_tts_provider()
|
||||
max_workers = _default_max_workers(provider)
|
||||
if provider == "volcengine" and not (
|
||||
os.getenv("VOLCENGINE_TTS_APPID") and os.getenv("VOLCENGINE_TTS_ACCESS_TOKEN")
|
||||
):
|
||||
raise ValueError(
|
||||
"Missing required environment variables: VOLCENGINE_TTS_APPID and VOLCENGINE_TTS_ACCESS_TOKEN must be set"
|
||||
"Volcengine TTS selected but VOLCENGINE_TTS_APPID / "
|
||||
"VOLCENGINE_TTS_ACCESS_TOKEN are not set"
|
||||
)
|
||||
if provider == "minimax" and not os.getenv("MINIMAX_API_KEY"):
|
||||
raise ValueError("MiniMax TTS selected but MINIMAX_API_KEY is not set")
|
||||
logger.info(f"Converting script to audio using {max_workers} workers (provider={provider})...")
|
||||
tasks = [(i, line, total, provider) for i, line in enumerate(script.lines)]
|
||||
|
||||
tasks = [(i, line, total) for i, line in enumerate(script.lines)]
|
||||
|
||||
# Use ThreadPoolExecutor for parallel TTS generation
|
||||
results: dict[int, Optional[bytes]] = {}
|
||||
failed_indices: list[int] = []
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
@@ -144,81 +282,52 @@ def tts_node(script: Script, max_workers: int = 4) -> list[bytes]:
|
||||
for future in as_completed(futures):
|
||||
idx, audio = future.result()
|
||||
results[idx] = audio
|
||||
# Use `not audio` to catch both None and empty bytes
|
||||
if not audio:
|
||||
failed_indices.append(idx)
|
||||
|
||||
# Log failed lines with 1-based indices for user-friendly output
|
||||
if failed_indices:
|
||||
logger.warning(
|
||||
f"Failed to generate audio for {len(failed_indices)}/{total} lines: "
|
||||
f"line numbers {sorted(i + 1 for i in failed_indices)}"
|
||||
)
|
||||
|
||||
# Collect results in order, skipping failed ones
|
||||
audio_chunks = []
|
||||
for i in range(total):
|
||||
audio = results.get(i)
|
||||
if audio:
|
||||
audio_chunks.append(audio)
|
||||
|
||||
logger.info(f"Generated {len(audio_chunks)}/{total} audio chunks successfully")
|
||||
|
||||
if not audio_chunks:
|
||||
raise ValueError(
|
||||
f"TTS generation failed for all {total} lines. "
|
||||
"Please check VOLCENGINE_TTS_APPID and VOLCENGINE_TTS_ACCESS_TOKEN environment variables."
|
||||
f"TTS failed for {len(failed_indices)}/{total} lines after retries: "
|
||||
f"line numbers {sorted(i + 1 for i in failed_indices)}. "
|
||||
f"This is usually transient API rate limiting — wait a moment and retry."
|
||||
)
|
||||
|
||||
|
||||
audio_chunks = [results[i] for i in range(total)]
|
||||
logger.info(f"Generated {len(audio_chunks)}/{total} audio chunks successfully")
|
||||
return audio_chunks
|
||||
|
||||
|
||||
def mix_audio(audio_chunks: list[bytes]) -> bytes:
|
||||
"""Combine audio chunks into a single audio file."""
|
||||
logger.info("Mixing audio chunks...")
|
||||
|
||||
if not audio_chunks:
|
||||
raise ValueError("No audio chunks to mix - TTS generation may have failed")
|
||||
|
||||
output = b"".join(audio_chunks)
|
||||
|
||||
if len(output) == 0:
|
||||
raise ValueError("Mixed audio is empty - TTS generation may have failed")
|
||||
|
||||
logger.info(f"Audio mixing complete: {len(output)} bytes")
|
||||
return output
|
||||
|
||||
|
||||
def generate_markdown(script: Script, title: str = "Podcast Script") -> str:
|
||||
"""Generate a markdown script from the podcast script."""
|
||||
lines = [f"# {title}", ""]
|
||||
|
||||
for line in script.lines:
|
||||
speaker_name = "**Host (Male)**" if line.speaker == "male" else "**Host (Female)**"
|
||||
lines.append(f"{speaker_name}: {line.paragraph}")
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def generate_podcast(
|
||||
script_file: str,
|
||||
output_file: str,
|
||||
transcript_file: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Generate a podcast from a script JSON file."""
|
||||
|
||||
# Read script JSON
|
||||
def generate_podcast(script_file: str, output_file: str,
|
||||
transcript_file: Optional[str] = None) -> str:
|
||||
with open(script_file, "r", encoding="utf-8") as f:
|
||||
script_json = json.load(f)
|
||||
|
||||
if "lines" not in script_json:
|
||||
raise ValueError(f"Invalid script format: missing 'lines' key. Got keys: {list(script_json.keys())}")
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid script format: missing 'lines' key. Got keys: {list(script_json.keys())}"
|
||||
)
|
||||
script = Script.from_dict(script_json)
|
||||
logger.info(f"Loaded script with {len(script.lines)} lines")
|
||||
|
||||
# Generate transcript markdown if requested
|
||||
if transcript_file:
|
||||
title = script_json.get("title", "Podcast Script")
|
||||
markdown_content = generate_markdown(script, title)
|
||||
@@ -229,16 +338,11 @@ def generate_podcast(
|
||||
f.write(markdown_content)
|
||||
logger.info(f"Generated transcript to {transcript_file}")
|
||||
|
||||
# Convert to audio
|
||||
audio_chunks = tts_node(script)
|
||||
|
||||
if not audio_chunks:
|
||||
raise Exception("Failed to generate any audio")
|
||||
|
||||
# Mix audio
|
||||
output_audio = mix_audio(audio_chunks)
|
||||
|
||||
# Save output
|
||||
output_dir = os.path.dirname(output_file)
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
@@ -253,30 +357,15 @@ def generate_podcast(
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Generate podcast from script JSON file")
|
||||
parser.add_argument(
|
||||
"--script-file",
|
||||
required=True,
|
||||
help="Absolute path to script JSON file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-file",
|
||||
required=True,
|
||||
help="Output path for generated podcast MP3",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--transcript-file",
|
||||
required=False,
|
||||
help="Output path for transcript markdown file (optional)",
|
||||
)
|
||||
|
||||
parser.add_argument("--script-file", required=True, help="Absolute path to script JSON file")
|
||||
parser.add_argument("--output-file", required=True, help="Output path for generated podcast MP3")
|
||||
parser.add_argument("--transcript-file", required=False,
|
||||
help="Output path for transcript markdown file (optional)")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
result = generate_podcast(
|
||||
args.script_file,
|
||||
args.output_file,
|
||||
args.transcript_file,
|
||||
)
|
||||
result = generate_podcast(args.script_file, args.output_file,
|
||||
args.transcript_file)
|
||||
print(result)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
@@ -137,3 +137,15 @@ After generation:
|
||||
- JSON format ensures structured, parsable prompts
|
||||
- Reference image enhance generation quality significantly
|
||||
- Iterative refinement is normal for optimal results
|
||||
|
||||
## Providers (Gemini / MiniMax)
|
||||
|
||||
Auto-selected by environment variables (CLI unchanged):
|
||||
|
||||
- `GEMINI_API_KEY` set → Gemini Veo (default, unchanged).
|
||||
- Only `MINIMAX_API_KEY` set → MiniMax video (`/v1/video_generation`, async 3-step poll/download).
|
||||
- Force with `VIDEO_GENERATION_PROVIDER=gemini|minimax`.
|
||||
|
||||
MiniMax overrides: `MINIMAX_API_HOST` (default `https://api.minimaxi.com`),
|
||||
`MINIMAX_VIDEO_MODEL` (default `MiniMax-Hailuo-2.3`). The first reference image is used
|
||||
as MiniMax `first_frame_image`. MiniMax ignores `--aspect-ratio` (it uses resolution/duration).
|
||||
|
||||
@@ -4,6 +4,185 @@ import time
|
||||
|
||||
import requests
|
||||
|
||||
MINIMAX_DEFAULT_HOST = "https://api.minimaxi.com"
|
||||
|
||||
|
||||
def _resolve_provider(override_env: str, existing_provider: str, has_existing_creds: bool) -> str:
|
||||
"""Pick the provider: <SKILL>_PROVIDER override > existing creds > MiniMax fallback."""
|
||||
override = os.getenv(override_env)
|
||||
if override:
|
||||
return override.strip().lower()
|
||||
if has_existing_creds:
|
||||
return existing_provider
|
||||
if os.getenv("MINIMAX_API_KEY"):
|
||||
return "minimax"
|
||||
raise ValueError(
|
||||
f"No credentials found. Set GEMINI_API_KEY for {existing_provider}, "
|
||||
f"or MINIMAX_API_KEY for minimax (optionally force with {override_env})."
|
||||
)
|
||||
|
||||
|
||||
def _minimax_host() -> str:
|
||||
return os.getenv("MINIMAX_API_HOST", MINIMAX_DEFAULT_HOST).rstrip("/")
|
||||
|
||||
|
||||
def _ensure_output_dir(output_file: str) -> None:
|
||||
"""Create the output file's parent directory so nested paths don't fail."""
|
||||
output_dir = os.path.dirname(output_file)
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
|
||||
def _check_base_resp(payload: dict) -> None:
|
||||
base = payload.get("base_resp") or {}
|
||||
if base.get("status_code", 0) != 0:
|
||||
raise Exception(f"MiniMax error {base.get('status_code')}: {base.get('status_msg')}")
|
||||
|
||||
|
||||
def _guess_mime(image_path: str) -> str:
|
||||
ext = os.path.splitext(image_path)[1].lower()
|
||||
return {
|
||||
".png": "image/png",
|
||||
".webp": "image/webp",
|
||||
".gif": "image/gif",
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
}.get(ext, "image/jpeg")
|
||||
|
||||
|
||||
def _to_data_url(image_path: str) -> str:
|
||||
with open(image_path, "rb") as f:
|
||||
b64 = base64.b64encode(f.read()).decode("utf-8")
|
||||
return f"data:{_guess_mime(image_path)};base64,{b64}"
|
||||
|
||||
|
||||
def _poll_video_task(host: str, auth: str, task_id: str,
|
||||
max_attempts: int = 120, interval: int = 3) -> str:
|
||||
for _ in range(max_attempts):
|
||||
response = requests.get(
|
||||
f"{host}/v1/query/video_generation",
|
||||
headers={"Authorization": auth},
|
||||
params={"task_id": task_id},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
status = payload.get("status")
|
||||
if status == "Success":
|
||||
return payload["file_id"]
|
||||
if status == "Fail":
|
||||
base = payload.get("base_resp") or {}
|
||||
raise Exception(
|
||||
f"MiniMax video task {task_id} failed: "
|
||||
f"{base.get('status_code')} {base.get('status_msg')}"
|
||||
)
|
||||
# Surface query-level errors (bad task_id, auth) that arrive as a non-zero
|
||||
# base_resp without a terminal status, then keep polling.
|
||||
_check_base_resp(payload)
|
||||
time.sleep(interval)
|
||||
raise Exception(f"MiniMax video task {task_id} timed out after {max_attempts} polls")
|
||||
|
||||
|
||||
def _retrieve_file_url(host: str, auth: str, file_id: str) -> str:
|
||||
response = requests.get(
|
||||
f"{host}/v1/files/retrieve",
|
||||
headers={"Authorization": auth},
|
||||
params={"file_id": file_id},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
_check_base_resp(payload)
|
||||
return payload["file"]["download_url"]
|
||||
|
||||
|
||||
def _download(url: str, output_file: str) -> None:
|
||||
response = requests.get(url, timeout=300)
|
||||
response.raise_for_status()
|
||||
_ensure_output_dir(output_file)
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
|
||||
def _generate_video_minimax(
|
||||
prompt: str, reference_images: list[str], output_file: str
|
||||
) -> str:
|
||||
api_key = os.getenv("MINIMAX_API_KEY")
|
||||
if not api_key:
|
||||
return "MINIMAX_API_KEY is not set"
|
||||
host = _minimax_host()
|
||||
auth = f"Bearer {api_key}"
|
||||
body = {"model": os.getenv("MINIMAX_VIDEO_MODEL", "MiniMax-Hailuo-2.3"), "prompt": prompt}
|
||||
if reference_images:
|
||||
body["first_frame_image"] = _to_data_url(reference_images[0])
|
||||
response = requests.post(
|
||||
f"{host}/v1/video_generation",
|
||||
headers={"Authorization": auth, "Content-Type": "application/json"},
|
||||
json=body,
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
_check_base_resp(payload)
|
||||
task_id = payload["task_id"]
|
||||
file_id = _poll_video_task(host, auth, task_id)
|
||||
download_url = _retrieve_file_url(host, auth, file_id)
|
||||
_download(download_url, output_file)
|
||||
return f"The video has been generated successfully to {output_file}"
|
||||
|
||||
|
||||
def download(url: str, output_file: str) -> None:
|
||||
api_key = os.getenv("GEMINI_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError("GEMINI_API_KEY is not set")
|
||||
response = requests.get(url, headers={"x-goog-api-key": api_key}, timeout=300)
|
||||
response.raise_for_status()
|
||||
_ensure_output_dir(output_file)
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
|
||||
def _generate_video_gemini(
|
||||
prompt: str, reference_images: list[str], output_file: str
|
||||
) -> str:
|
||||
reference_payload = []
|
||||
request_json = {"instances": [{"prompt": prompt}]}
|
||||
for reference_image in reference_images:
|
||||
with open(reference_image, "rb") as f:
|
||||
image_b64 = base64.b64encode(f.read()).decode("utf-8")
|
||||
reference_payload.append(
|
||||
{"image": {"mimeType": "image/jpeg", "bytesBase64Encoded": image_b64},
|
||||
"referenceType": "asset"}
|
||||
)
|
||||
if reference_payload:
|
||||
request_json["instances"][0]["referenceImages"] = reference_payload
|
||||
api_key = os.getenv("GEMINI_API_KEY")
|
||||
if not api_key:
|
||||
return "GEMINI_API_KEY is not set"
|
||||
response = requests.post(
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/veo-3.1-generate-preview:predictLongRunning",
|
||||
headers={"x-goog-api-key": api_key, "Content-Type": "application/json"},
|
||||
json=request_json,
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
operation_name = data["name"]
|
||||
while True:
|
||||
response = requests.get(
|
||||
f"https://generativelanguage.googleapis.com/v1beta/{operation_name}",
|
||||
headers={"x-goog-api-key": api_key},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
if data.get("done", False):
|
||||
sample = data["response"]["generateVideoResponse"]["generatedSamples"][0]
|
||||
download(sample["video"]["uri"], output_file)
|
||||
break
|
||||
time.sleep(3)
|
||||
return f"The video has been generated successfully to {output_file}"
|
||||
|
||||
|
||||
def generate_video(
|
||||
prompt_file: str,
|
||||
@@ -13,104 +192,31 @@ def generate_video(
|
||||
) -> str:
|
||||
with open(prompt_file, "r", encoding="utf-8") as f:
|
||||
prompt = f.read()
|
||||
referenceImages = []
|
||||
i = 0
|
||||
json = {
|
||||
"instances": [{"prompt": prompt}],
|
||||
}
|
||||
for reference_image in reference_images:
|
||||
i += 1
|
||||
with open(reference_image, "rb") as f:
|
||||
image_b64 = base64.b64encode(f.read()).decode("utf-8")
|
||||
referenceImages.append(
|
||||
{
|
||||
"image": {"mimeType": "image/jpeg", "bytesBase64Encoded": image_b64},
|
||||
"referenceType": "asset",
|
||||
}
|
||||
)
|
||||
if i > 0:
|
||||
json["instances"][0]["referenceImages"] = referenceImages
|
||||
api_key = os.getenv("GEMINI_API_KEY")
|
||||
if not api_key:
|
||||
return "GEMINI_API_KEY is not set"
|
||||
response = requests.post(
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/veo-3.1-generate-preview:predictLongRunning",
|
||||
headers={
|
||||
"x-goog-api-key": api_key,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json=json,
|
||||
provider = _resolve_provider(
|
||||
"VIDEO_GENERATION_PROVIDER", "gemini", bool(os.getenv("GEMINI_API_KEY"))
|
||||
)
|
||||
json = response.json()
|
||||
operation_name = json["name"]
|
||||
while True:
|
||||
response = requests.get(
|
||||
f"https://generativelanguage.googleapis.com/v1beta/{operation_name}",
|
||||
headers={
|
||||
"x-goog-api-key": api_key,
|
||||
},
|
||||
)
|
||||
json = response.json()
|
||||
if json.get("done", False):
|
||||
sample = json["response"]["generateVideoResponse"]["generatedSamples"][0]
|
||||
url = sample["video"]["uri"]
|
||||
download(url, output_file)
|
||||
break
|
||||
time.sleep(3)
|
||||
return f"The video has been generated successfully to {output_file}"
|
||||
|
||||
|
||||
def download(url: str, output_file: str):
|
||||
api_key = os.getenv("GEMINI_API_KEY")
|
||||
if not api_key:
|
||||
return "GEMINI_API_KEY is not set"
|
||||
response = requests.get(
|
||||
url,
|
||||
headers={
|
||||
"x-goog-api-key": api_key,
|
||||
},
|
||||
)
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(response.content)
|
||||
if provider == "minimax":
|
||||
# MiniMax video uses resolution/duration, not aspect_ratio; aspect_ratio ignored.
|
||||
return _generate_video_minimax(prompt, reference_images, output_file)
|
||||
if provider in ("gemini", "google"):
|
||||
return _generate_video_gemini(prompt, reference_images, output_file)
|
||||
raise ValueError(f"Unknown video provider: {provider!r} (use 'gemini' or 'minimax')")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Generate videos using Gemini API")
|
||||
parser.add_argument(
|
||||
"--prompt-file",
|
||||
required=True,
|
||||
help="Absolute path to JSON prompt file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reference-images",
|
||||
nargs="*",
|
||||
default=[],
|
||||
help="Absolute paths to reference images (space-separated)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-file",
|
||||
required=True,
|
||||
help="Output path for generated image",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--aspect-ratio",
|
||||
required=False,
|
||||
default="16:9",
|
||||
help="Aspect ratio of the generated image",
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser(description="Generate videos using Gemini or MiniMax API")
|
||||
parser.add_argument("--prompt-file", required=True, help="Absolute path to JSON prompt file")
|
||||
parser.add_argument("--reference-images", nargs="*", default=[],
|
||||
help="Absolute paths to reference images (space-separated)")
|
||||
parser.add_argument("--output-file", required=True, help="Output path for generated video")
|
||||
parser.add_argument("--aspect-ratio", required=False, default="16:9",
|
||||
help="Aspect ratio of the generated video (Gemini only)")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
print(
|
||||
generate_video(
|
||||
args.prompt_file,
|
||||
args.reference_images,
|
||||
args.output_file,
|
||||
args.aspect_ratio,
|
||||
)
|
||||
)
|
||||
print(generate_video(args.prompt_file, args.reference_images,
|
||||
args.output_file, args.aspect_ratio))
|
||||
except Exception as e:
|
||||
print(f"Error while generating video: {e}")
|
||||
|
||||
Reference in New Issue
Block a user