mirror of
https://github.com/xtekky/gpt4free.git
synced 2026-04-22 23:57:17 +08:00
feat: add gTTS provider and update EdgeTTS & media docs
- In **docs/media.md**:
- Updated the import to include `gTTS` alongside `EdgeTTS`, `Gemini`, and `PollinationsAI`.
- Changed the audio parameter for EdgeTTS from `"locale": "en-US"` to `"language": "en"`.
- Added a new code example demonstrating how to use the gTTS provider and save the output as "google-tts.mp3".
- In **g4f/Provider/__init__.py**:
- Replaced the import of `EdgeTTS` with a wildcard import (`from .audio import *`) to include all audio providers.
- In **g4f/Provider/audio/EdgeTTS.py**:
- Added a new class attribute `model_id = "edge-tts"`.
- Changed the voice selection logic to use `cls.model_id` instead of the hardcoded string "edge-tts".
- Updated the filename generation to use `[cls.model_id]` instead of `[cls.default_model]`.
- In **g4f/Provider/audio/__init__.py**:
- Added an import for the new `gTTS` provider.
- Added new file **g4f/Provider/audio/gTTS.py**:
- Implements the gTTS provider using the `gtts` library.
- Defines provider attributes (`label`, `working`, `model_id`, etc.) and generates audio using a similar structure to EdgeTTS.
- In **g4f/image/__init__.py**:
- Modified the `get_extension` function to extract the extension using `.lower().lstrip('.')` instead of slicing with `[1:]`.
This commit is contained in:
+6
-2
@@ -33,7 +33,7 @@ asyncio.run(main())
|
||||
```python
|
||||
from g4f.client import Client
|
||||
|
||||
from g4f.Provider import EdgeTTS, Gemini, PollinationsAI
|
||||
from g4f.Provider import gTTS, EdgeTTS, Gemini, PollinationsAI
|
||||
|
||||
client = Client(provider=PollinationsAI)
|
||||
response = client.media.generate("Hello", audio={"voice": "alloy", "format": "mp3"})
|
||||
@@ -48,8 +48,12 @@ response = client.media.generate("Hello", model="gemini-audio")
|
||||
response.data[0].save("gemini.ogx")
|
||||
|
||||
client = Client(provider=EdgeTTS)
|
||||
response = client.media.generate("Hello", audio={"locale": "en-US"})
|
||||
response = client.media.generate("Hello", audio={"language": "en"})
|
||||
response.data[0].save("edge-tts.mp3")
|
||||
|
||||
client = Client(provider=gTTS)
|
||||
response = client.media.generate("Hello", audio={"language": "en"})
|
||||
response.data[0].save("google-tts.mp3")
|
||||
```
|
||||
|
||||
#### **Transcribe an Audio File:**
|
||||
|
||||
@@ -29,7 +29,7 @@ try:
|
||||
except ImportError as e:
|
||||
debug.error("MiniMax providers not loaded:", e)
|
||||
try:
|
||||
from .audio import EdgeTTS
|
||||
from .audio import *
|
||||
except ImportError as e:
|
||||
debug.error("Audio providers not loaded:", e)
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from ..helper import format_image_prompt
|
||||
class EdgeTTS(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
label = "Edge TTS"
|
||||
working = has_edge_tts
|
||||
model_id = "edge-tts"
|
||||
default_language = "en"
|
||||
default_locale = "en-US"
|
||||
default_format = "mp3"
|
||||
@@ -45,7 +46,7 @@ class EdgeTTS(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
prompt = format_image_prompt(messages, prompt)
|
||||
if not prompt:
|
||||
raise ValueError("Prompt is empty.")
|
||||
voice = audio.get("voice", model if model and model != "edge-tts" else None)
|
||||
voice = audio.get("voice", model if model and model != cls.model_id else None)
|
||||
if not voice:
|
||||
voices = await VoicesManager.create()
|
||||
if "locale" in audio:
|
||||
@@ -62,7 +63,7 @@ class EdgeTTS(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
voice = random.choice(voices)["Name"]
|
||||
|
||||
format = audio.get("format", cls.default_format)
|
||||
filename = get_filename([cls.default_model], prompt, f".{format}", prompt)
|
||||
filename = get_filename([cls.model_id], prompt, f".{format}", prompt)
|
||||
target_path = os.path.join(get_media_dir(), filename)
|
||||
ensure_media_dir()
|
||||
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
from .EdgeTTS import EdgeTTS
|
||||
from .EdgeTTS import EdgeTTS
|
||||
from .gTTS import gTTS
|
||||
@@ -0,0 +1,77 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import random
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
from gtts import gTTS as gTTS_Service
|
||||
has_gtts = True
|
||||
except ImportError:
|
||||
has_gtts = False
|
||||
|
||||
from ...typing import AsyncResult, Messages
|
||||
from ...providers.response import AudioResponse
|
||||
from ...image.copy_images import get_filename, get_media_dir, ensure_media_dir
|
||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ..helper import format_image_prompt
|
||||
|
||||
# Supported regional variants, keyed by locale code.
# Each value is [display label, gTTS language code, Google Translate TLD].
locales = {
    "en-AU": ["English (Australia)", "en", "com.au"],
    "en-GB": ["English (United Kingdom)", "en", "co.uk"],
    "en-US": ["English (United States)", "en", "us"],
    "en-CA": ["English (Canada)", "en", "ca"],
    "en-IN": ["English (India)", "en", "co.in"],
    "en-IE": ["English (Ireland)", "en", "ie"],
    "en-ZA": ["English (South Africa)", "en", "co.za"],
    "en-NG": ["English (Nigeria)", "en", "com.ng"],
    "fr-CA": ["French (Canada)", "fr", "ca"],
    "fr-FR": ["French (France)", "fr", "fr"],
    "de-DE": ["German (Germany)", "de", "de"],
    "zh-CN": ["Mandarin (China Mainland)", "zh-CN", "com"],
    "zh-TW": ["Mandarin (Taiwan)", "zh-TW", "com"],
    "pt-BR": ["Portuguese (Brazil)", "pt", "com.br"],
    "pt-PT": ["Portuguese (Portugal)", "pt", "pt"],
    "es-MX": ["Spanish (Mexico)", "es", "com.mx"],
    "es-ES": ["Spanish (Spain)", "es", "es"],
    "es-US": ["Spanish (United States)", "es", "us"],
}

# Backward-compatible alias: the table was originally published under the name
# ``locals``, which shadows the builtin. Prefer ``locales`` in new code.
locals = locales

# Display label -> keyword arguments ("lang", "tld") for the gTTS constructor.
models = {
    label: {"lang": lang, "tld": tld}
    for label, lang, tld in locales.values()
}
|
||||
|
||||
class gTTS(AsyncGeneratorProvider, ProviderModelMixin):
    """Text-to-speech provider backed by the ``gtts`` library."""

    label = "gTTS (Google Text-to-Speech)"
    # Only usable when the optional ``gtts`` package could be imported.
    working = has_gtts
    model_id = "google-tts"
    default_language = "en"
    default_tld = "com"
    default_format = "mp3"
    # Class-level list of display labels; the module-level ``models`` dict
    # (label -> lang/tld) is still what the method below looks up.
    models = list(models.keys())

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        prompt: str = None,
        audio: dict = None,
        **kwargs
    ) -> AsyncResult:
        """Synthesize speech for the prompt and yield an ``AudioResponse``.

        Args:
            model: Optional display label from the module-level ``models``
                table; when it matches, its ``lang``/``tld`` override the
                values taken from ``audio``.
            messages: Conversation passed to ``format_image_prompt`` to
                derive the text when ``prompt`` is not given.
            prompt: Explicit text to speak.
            audio: Options dict; the keys read here are ``"format"``,
                ``"language"``, ``"tld"`` and ``"slow"``.
            **kwargs: Accepted for interface compatibility; unused here.

        Yields:
            One ``AudioResponse`` pointing at ``/media/<filename>``.

        Raises:
            ValueError: If no prompt text could be determined.
        """
        # Avoid sharing one mutable default dict between calls.
        audio = {} if audio is None else audio

        prompt = format_image_prompt(messages, prompt)
        if not prompt:
            raise ValueError("Prompt is empty.")

        # The format only determines the file extension of the saved file.
        audio_format = audio.get("format", cls.default_format)
        filename = get_filename([cls.model_id], prompt, f".{audio_format}", prompt)
        target_path = os.path.join(get_media_dir(), filename)
        ensure_media_dir()

        # Entries of the selected model come last so they take precedence
        # over the generic audio options.
        options = {
            "lang": audio.get("language", cls.default_language),
            "tld": audio.get("tld", cls.default_tld),
            "slow": audio.get("slow", False),
            **models.get(model, {}),
        }
        speech = gTTS_Service(prompt, **options)

        # save() is a synchronous call; run it in the default executor so the
        # event loop is not blocked while the audio is fetched and written.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, speech.save, target_path)

        yield AudioResponse(f"/media/{filename}", audio=audio, text=prompt)
|
||||
@@ -79,7 +79,7 @@ def to_image(image: ImageType, is_svg: bool = False) -> Image:
|
||||
|
||||
def get_extension(filename: str) -> Optional[str]:
    """Return the lower-cased extension of *filename* if it is a known one.

    Args:
        filename: File name or path to inspect.

    Returns:
        The extension without its leading dot when it is a key of
        ``EXTENSIONS_MAP``; otherwise ``None`` (including when *filename*
        contains no dot at all).
    """
    if '.' not in filename:
        return None
    # splitext yields the suffix with its leading dot (or an empty string).
    ext = os.path.splitext(filename)[1].lower().lstrip('.')
    return ext if ext in EXTENSIONS_MAP else None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user