feat: add gTTS provider and update EdgeTTS & media docs

- In **docs/media.md**:
  - Updated the import to include `gTTS` alongside `EdgeTTS`, `Gemini`, and `PollinationsAI`.
  - Changed the audio parameter for EdgeTTS from `"locale": "en-US"` to `"language": "en"`.
  - Added a new code example demonstrating how to use the gTTS provider and save the output as "google-tts.mp3".

- In **g4f/Provider/__init__.py**:
  - Replaced the import of `EdgeTTS` with a wildcard import (`from .audio import *`) to include all audio providers.

- In **g4f/Provider/audio/EdgeTTS.py**:
  - Added a new class attribute `model_id = "edge-tts"`.
  - Changed the voice selection logic to use `cls.model_id` instead of the hardcoded string "edge-tts".
  - Updated the filename generation to use `[cls.model_id]` instead of `[cls.default_model]`.

- In **g4f/Provider/audio/__init__.py**:
  - Added an import for the new `gTTS` provider.

- Added new file **g4f/Provider/audio/gTTS.py**:
  - Implements the gTTS provider using the `gtts` library.
  - Defines provider attributes (`label`, `working`, `model_id`, etc.) and generates audio using a similar structure to EdgeTTS.

- In **g4f/image/__init__.py**:
  - Modified the `get_extension` function to extract the extension using `.lower().lstrip('.')` instead of slicing with `[1:]`.
This commit is contained in:
hlohaus
2025-04-19 14:40:53 +02:00
parent b1f01d464b
commit 099d7283ed
6 changed files with 90 additions and 7 deletions
+6 -2
View File
@@ -33,7 +33,7 @@ asyncio.run(main())
```python
from g4f.client import Client
from g4f.Provider import EdgeTTS, Gemini, PollinationsAI
from g4f.Provider import gTTS, EdgeTTS, Gemini, PollinationsAI
client = Client(provider=PollinationsAI)
response = client.media.generate("Hello", audio={"voice": "alloy", "format": "mp3"})
@@ -48,8 +48,12 @@ response = client.media.generate("Hello", model="gemini-audio")
response.data[0].save("gemini.ogx")
client = Client(provider=EdgeTTS)
response = client.media.generate("Hello", audio={"locale": "en-US"})
response = client.media.generate("Hello", audio={"language": "en"})
response.data[0].save("edge-tts.mp3")
client = Client(provider=gTTS)
response = client.media.generate("Hello", audio={"language": "en"})
response.data[0].save("google-tts.mp3")
```
#### **Transcribe an Audio File:**
+1 -1
View File
@@ -29,7 +29,7 @@ try:
except ImportError as e:
debug.error("MiniMax providers not loaded:", e)
try:
from .audio import EdgeTTS
from .audio import *
except ImportError as e:
debug.error("Audio providers not loaded:", e)
+3 -2
View File
@@ -20,6 +20,7 @@ from ..helper import format_image_prompt
class EdgeTTS(AsyncGeneratorProvider, ProviderModelMixin):
label = "Edge TTS"
working = has_edge_tts
model_id = "edge-tts"
default_language = "en"
default_locale = "en-US"
default_format = "mp3"
@@ -45,7 +46,7 @@ class EdgeTTS(AsyncGeneratorProvider, ProviderModelMixin):
prompt = format_image_prompt(messages, prompt)
if not prompt:
raise ValueError("Prompt is empty.")
voice = audio.get("voice", model if model and model != "edge-tts" else None)
voice = audio.get("voice", model if model and model != cls.model_id else None)
if not voice:
voices = await VoicesManager.create()
if "locale" in audio:
@@ -62,7 +63,7 @@ class EdgeTTS(AsyncGeneratorProvider, ProviderModelMixin):
voice = random.choice(voices)["Name"]
format = audio.get("format", cls.default_format)
filename = get_filename([cls.default_model], prompt, f".{format}", prompt)
filename = get_filename([cls.model_id], prompt, f".{format}", prompt)
target_path = os.path.join(get_media_dir(), filename)
ensure_media_dir()
+2 -1
View File
@@ -1 +1,2 @@
from .EdgeTTS import EdgeTTS
from .EdgeTTS import EdgeTTS
from .gTTS import gTTS
+77
View File
@@ -0,0 +1,77 @@
from __future__ import annotations
import os
import random
import asyncio
try:
from gtts import gTTS as gTTS_Service
has_gtts = True
except ImportError:
has_gtts = False
from ...typing import AsyncResult, Messages
from ...providers.response import AudioResponse
from ...image.copy_images import get_filename, get_media_dir, ensure_media_dir
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_image_prompt
locals = {
"en-AU": ["English (Australia)", "en", "com.au"],
"en-GB": ["English (United Kingdom)", "en", "co.uk"],
"en-US": ["English (United States)", "en", "us"],
"en-CA": ["English (Canada)", "en", "ca"],
"en-IN": ["English (India)", "en", "co.in"],
"en-IE": ["English (Ireland)", "en", "ie"],
"en-ZA": ["English (South Africa)", "en", "co.za"],
"en-NG": ["English (Nigeria)", "en", "com.ng"],
"fr-CA": ["French (Canada)", "fr", "ca"],
"fr-FR": ["French (France)", "fr", "fr"],
"de-DE": ["German (Germany)", "de", "de"],
"zh-CN": ["Mandarin (China Mainland)", "zh-CN", "com"],
"zh-TW": ["Mandarin (Taiwan)", "zh-TW", "com"],
"pt-BR": ["Portuguese (Brazil)", "pt", "com.br"],
"pt-PT": ["Portuguese (Portugal)", "pt", "pt"],
"es-MX": ["Spanish (Mexico)", "es", "com.mx"],
"es-ES": ["Spanish (Spain)", "es", "es"],
"es-US": ["Spanish (United States)", "es", "us"],
}
models = {locale[0]: {"lang": locale[1], "tld": locale[2]} for locale in locals.values()}
class gTTS(AsyncGeneratorProvider, ProviderModelMixin):
label = "gTTS (Google Text-to-Speech)"
working = has_gtts
model_id = "google-tts"
default_language = "en"
default_tld = "com"
default_format = "mp3"
models = list(models.keys())
@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
prompt: str = None,
audio: dict = {},
**kwargs
) -> AsyncResult:
prompt = format_image_prompt(messages, prompt)
if not prompt:
raise ValueError("Prompt is empty.")
format = audio.get("format", cls.default_format)
filename = get_filename([cls.model_id], prompt, f".{format}", prompt)
target_path = os.path.join(get_media_dir(), filename)
ensure_media_dir()
gTTS_Service(
prompt,
**{
"lang": audio.get("language", cls.default_language),
"tld": audio.get("tld", cls.default_tld),
"slow": audio.get("slow", False),
**models.get(model, {})
}
).save(target_path)
yield AudioResponse(f"/media/{filename}", audio=audio, text=prompt)
+1 -1
View File
@@ -79,7 +79,7 @@ def to_image(image: ImageType, is_svg: bool = False) -> Image:
def get_extension(filename: str) -> Optional[str]:
if '.' in filename:
ext = os.path.splitext(filename)[1][1:].lower()
ext = os.path.splitext(filename)[1].lower().lstrip('.')
return ext if ext in EXTENSIONS_MAP else None
return None