mirror of
https://github.com/xtekky/gpt4free.git
synced 2026-04-22 15:47:11 +08:00
Refactor provider methods to unify async and sync handling, enhance clarity, and improve error management
This commit is contained in:
@@ -174,7 +174,12 @@ class GithubCopilot(OpenaiTemplate):
|
||||
"Please run 'g4f auth github-copilot' to authenticate."
|
||||
) from e
|
||||
raise
|
||||
return super().get_models(api_key, base_url, timeout)
|
||||
response = super().get_models(api_key, base_url, timeout)
|
||||
if isinstance(response, dict):
|
||||
for key in list(response.keys()):
|
||||
if key.startswith("accounts/") or key.startswith("text-embedding-") or key in ("minimax-m2.5", "goldeneye-free-auto"):
|
||||
del response[key]
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def get_headers(cls, stream: bool, api_key: str | None = None, headers: dict[str, str] | None = None) -> dict[str, str]:
|
||||
|
||||
+7
-2
@@ -11,7 +11,9 @@ from .client import Client, AsyncClient, ClientFactory, create_custom_provider
|
||||
from .typing import Messages, CreateResult, AsyncResult, ImageType
|
||||
from .cookies import get_cookies, set_cookies
|
||||
from .providers.types import ProviderType
|
||||
from .providers.base_provider import get_async_provider_method, get_provider_method
|
||||
from .providers.helper import concat_chunks, async_concat_chunks
|
||||
from .providers.asyncio import to_sync_generator
|
||||
from .client.service import get_model_and_provider
|
||||
|
||||
# Configure logger
|
||||
@@ -69,7 +71,9 @@ class ChatCompletion:
|
||||
model, messages, provider, stream, image, image_name,
|
||||
ignore_working, ignore_stream, **kwargs
|
||||
)
|
||||
result = provider.create_function(model, messages, stream=stream, **kwargs)
|
||||
method = get_provider_method(provider)
|
||||
result = method(model, messages, stream=stream, **kwargs)
|
||||
result = to_sync_generator(result)
|
||||
return result if stream or ignore_stream else concat_chunks(result)
|
||||
|
||||
@staticmethod
|
||||
@@ -86,7 +90,8 @@ class ChatCompletion:
|
||||
model, messages, provider, stream, image, image_name,
|
||||
ignore_working, ignore_stream, **kwargs
|
||||
)
|
||||
result = provider.async_create_function(model, messages, stream=stream, **kwargs)
|
||||
method = get_async_provider_method(provider)
|
||||
result = method(model, messages, stream=stream, **kwargs)
|
||||
if not stream and not ignore_stream and hasattr(result, "__aiter__"):
|
||||
result = async_concat_chunks(result)
|
||||
return result
|
||||
@@ -539,8 +539,6 @@ class AnyProvider(AsyncGeneratorProvider, AnyModelProviderMixin):
|
||||
):
|
||||
yield chunk
|
||||
|
||||
async_create_function = create_async_generator
|
||||
|
||||
|
||||
# Clean model names function
|
||||
def clean_name(name: str) -> str:
|
||||
|
||||
@@ -45,6 +45,23 @@ async def async_generator_to_list(generator: AsyncIterator) -> list:
|
||||
|
||||
def to_sync_generator(generator: AsyncIterator, stream: bool = True, timeout: int = None) -> Iterator:
|
||||
loop = get_running_loop(check_nested=False)
|
||||
if asyncio.iscoroutine(generator):
|
||||
if loop is not None:
|
||||
try:
|
||||
result = loop.run_until_complete(generator)
|
||||
except RuntimeError as e:
|
||||
if asyncio.iscoroutine(generator):
|
||||
try:
|
||||
generator.close()
|
||||
except Exception:
|
||||
pass
|
||||
raise NestAsyncioError(
|
||||
'Install "nest-asyncio2" package | pip install -U nest-asyncio2'
|
||||
) from e
|
||||
else:
|
||||
result = asyncio.run(generator)
|
||||
yield result
|
||||
return
|
||||
if not stream:
|
||||
yield from asyncio.run(async_generator_to_list(generator))
|
||||
return
|
||||
@@ -72,6 +89,9 @@ def to_sync_generator(generator: AsyncIterator, stream: bool = True, timeout: in
|
||||
|
||||
# Helper function to convert a synchronous iterator to an async iterator
|
||||
async def to_async_iterator(iterator) -> AsyncIterator:
|
||||
if isinstance(iterator, (str, bytes)):
|
||||
yield iterator
|
||||
return
|
||||
if hasattr(iterator, '__aiter__'):
|
||||
async for item in iterator:
|
||||
yield item
|
||||
|
||||
+40
-199
@@ -2,12 +2,10 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
from asyncio import AbstractEventLoop
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from abc import abstractmethod
|
||||
import json
|
||||
from inspect import signature, Parameter
|
||||
from typing import Optional, _GenericAlias
|
||||
from typing import Optional, _GenericAlias, AsyncIterator
|
||||
from pathlib import Path
|
||||
from aiohttp import ClientSession
|
||||
try:
|
||||
@@ -22,7 +20,7 @@ from .response import BaseConversation, AuthResult
|
||||
from .helper import concat_chunks
|
||||
from ..cookies import get_cookies_dir
|
||||
from ..requests import raise_for_status
|
||||
from ..errors import ModelNotFoundError, ResponseError, MissingAuthError, NoValidHarFileError, PaymentRequiredError, CloudflareError
|
||||
from ..errors import ResponseError, MissingAuthError, NoValidHarFileError, PaymentRequiredError, CloudflareError
|
||||
from .. import debug
|
||||
|
||||
SAFE_PARAMETERS = [
|
||||
@@ -71,92 +69,45 @@ PARAMETER_EXAMPLES = {
|
||||
"aspect_ratio": "1:1",
|
||||
}
|
||||
|
||||
async def wait_for(response: AsyncIterator, timeout: int = None) -> AsyncIterator:
    """
    Re-yield every item of ``response``, optionally limiting the wait per item.

    Args:
        response (AsyncIterator): The asynchronous iterator to consume.
        timeout (int, optional): Maximum number of seconds to wait for each
            single item. ``None`` disables the limit. Defaults to None.

    Yields:
        The items produced by ``response``, in their original order.

    Raises:
        TimeoutError: If the next item is not produced within ``timeout`` seconds.
    """
    if timeout is None:
        # No limit requested: plain pass-through.
        async for chunk in response:
            yield chunk
        return

    while True:
        # Keep only the awaited call inside the ``try`` so that exceptions
        # thrown into this generator at the ``yield`` are not intercepted here.
        try:
            chunk = await asyncio.wait_for(response.__anext__(), timeout=timeout)
        except StopAsyncIteration:
            break
        except (asyncio.TimeoutError, TimeoutError) as e:
            # ``asyncio.TimeoutError`` only became an alias of the builtin
            # ``TimeoutError`` in Python 3.11; catch both so the re-raise also
            # happens on older interpreters.
            raise TimeoutError(f"The operation timed out after {timeout} seconds") from e
        yield chunk
|
||||
|
||||
def get_async_provider_method(provider: type) -> Optional[callable]:
    """
    Pick the callable used to run ``provider`` from asynchronous code.

    The lookup order is ``create_async_generator``, then ``create_async``.
    When neither exists but ``create_completion`` does, an async generator
    function is returned that iterates ``create_completion`` directly
    (no executor or thread is involved) and re-yields its chunks.

    Args:
        provider (type): The provider to inspect.

    Returns:
        The selected callable.

    Raises:
        NotImplementedError: If the provider has none of the three methods.
    """
    for attr_name in ("create_async_generator", "create_async"):
        if hasattr(provider, attr_name):
            return getattr(provider, attr_name)

    if not hasattr(provider, "create_completion"):
        raise NotImplementedError(f"{provider.__name__} does not implement an async method")

    # Synchronous fallback, exposed through an async generator interface.
    async def wrapper(*args, **kwargs):
        for chunk in provider.create_completion(*args, **kwargs):
            yield chunk

    return wrapper
|
||||
|
||||
|
||||
def get_provider_method(provider: type) -> Optional[callable]:
    """
    Pick the callable used to run ``provider`` from synchronous code.

    The first attribute found wins, checked in this order:
    ``create_completion``, ``create_async_generator``, ``create_async``.

    Args:
        provider (type): The provider to inspect.

    Returns:
        The selected callable, returned as-is (it may be asynchronous).

    Raises:
        NotImplementedError: If the provider has none of the three methods.
    """
    lookup_order = ("create_completion", "create_async_generator", "create_async")
    for attr_name in lookup_order:
        if hasattr(provider, attr_name):
            return getattr(provider, attr_name)
    raise NotImplementedError(f"{provider.__name__} does not implement a create method")
|
||||
|
||||
class AbstractProvider(BaseProvider):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def create_completion(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
**kwargs
|
||||
) -> CreateResult:
|
||||
"""
|
||||
Create a completion with the given parameters.
|
||||
|
||||
Args:
|
||||
model (str): The model to use.
|
||||
messages (Messages): The messages to process.
|
||||
stream (bool): Whether to use streaming.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
CreateResult: The result of the creation process.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
async def create_async(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
*,
|
||||
timeout: int = None,
|
||||
loop: AbstractEventLoop = None,
|
||||
executor: ThreadPoolExecutor = None,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
Asynchronously creates a result based on the given model and messages.
|
||||
|
||||
Args:
|
||||
cls (type): The class on which this method is called.
|
||||
model (str): The model to use for creation.
|
||||
messages (Messages): The messages to process.
|
||||
loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
|
||||
executor (ThreadPoolExecutor, optional): The executor for running async tasks. Defaults to None.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
str: The created result as a string.
|
||||
"""
|
||||
loop = asyncio.get_running_loop() if loop is None else loop
|
||||
|
||||
def create_func() -> str:
|
||||
return concat_chunks(cls.create_completion(model=model, messages=messages, **kwargs))
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
loop.run_in_executor(executor, create_func), timeout=timeout
|
||||
)
|
||||
except TimeoutError as e:
|
||||
raise TimeoutError("The operation timed out after {} seconds in {}".format(timeout, cls.__name__)) from e
|
||||
|
||||
@classmethod
|
||||
def create_function(cls, *args, **kwargs) -> CreateResult:
|
||||
"""
|
||||
Creates a completion using the synchronous method.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
CreateResult: The result of the completion creation.
|
||||
"""
|
||||
return cls.create_completion(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def async_create_function(cls, *args, **kwargs) -> AsyncResult:
|
||||
"""
|
||||
Creates a completion using the synchronous method.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
CreateResult: The result of the completion creation.
|
||||
"""
|
||||
return cls.create_async(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_parameters(cls, as_json: bool = False) -> dict[str, Parameter]:
|
||||
params = {name: parameter for name, parameter in signature(
|
||||
@@ -240,29 +191,6 @@ class AsyncProvider(AbstractProvider):
|
||||
Provides asynchronous functionality for creating completions.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def create_completion(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
**kwargs
|
||||
) -> CreateResult:
|
||||
"""
|
||||
Creates a completion result synchronously.
|
||||
|
||||
Args:
|
||||
cls (type): The class on which this method is called.
|
||||
model (str): The model to use for creation.
|
||||
messages (Messages): The messages to process.
|
||||
loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
CreateResult: The result of the completion creation.
|
||||
"""
|
||||
get_running_loop(check_nested=False)
|
||||
yield asyncio.run(cls.create_async(model, messages, **kwargs))
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
async def create_async(
|
||||
@@ -309,33 +237,6 @@ class AsyncGeneratorProvider(AbstractProvider):
|
||||
await raise_for_status(response)
|
||||
return await response.json()
|
||||
|
||||
@classmethod
|
||||
def create_completion(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
timeout: int = None,
|
||||
stream_timeout: int = None,
|
||||
**kwargs
|
||||
) -> CreateResult:
|
||||
"""
|
||||
Creates a streaming completion result synchronously.
|
||||
|
||||
Args:
|
||||
cls (type): The class on which this method is called.
|
||||
model (str): The model to use for creation.
|
||||
messages (Messages): The messages to process.
|
||||
loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
CreateResult: The result of the streaming completion creation.
|
||||
"""
|
||||
return to_sync_generator(
|
||||
cls.create_async_generator(model, messages, **kwargs),
|
||||
timeout=stream_timeout if cls.use_stream_timeout is None else timeout,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
async def create_async_generator(
|
||||
@@ -359,34 +260,6 @@ class AsyncGeneratorProvider(AbstractProvider):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
async def async_create_function(cls, *args, **kwargs) -> AsyncResult:
|
||||
"""
|
||||
Creates a completion using the synchronous method.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
CreateResult: The result of the completion creation.
|
||||
"""
|
||||
response = cls.create_async_generator(*args, **kwargs)
|
||||
if "stream_timeout" in kwargs or "timeout" in kwargs:
|
||||
timeout = kwargs.get("stream_timeout") if cls.use_stream_timeout else kwargs.get("timeout")
|
||||
while True:
|
||||
try:
|
||||
yield await asyncio.wait_for(
|
||||
response.__anext__(),
|
||||
timeout=timeout
|
||||
)
|
||||
except TimeoutError as e:
|
||||
raise TimeoutError("The operation timed out after {} seconds in {}".format(timeout, cls.__name__)) from e
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
else:
|
||||
async for chunk in response:
|
||||
yield chunk
|
||||
|
||||
class ProviderModelMixin:
|
||||
default_model: str = None
|
||||
models: list[str] = []
|
||||
@@ -466,13 +339,6 @@ class AsyncAuthedProvider(AsyncGeneratorProvider, AuthFileMixin):
|
||||
raise MissingAuthError(f"API key is required for {cls.__name__}")
|
||||
return AuthResult()
|
||||
|
||||
@classmethod
|
||||
def on_auth(cls, **kwargs) -> AuthResult:
|
||||
auth_result = cls.on_auth_async(**kwargs)
|
||||
if hasattr(auth_result, "__aiter__"):
|
||||
return to_sync_generator(auth_result)
|
||||
return asyncio.run(auth_result)
|
||||
|
||||
@classmethod
|
||||
def write_cache_file(cls, cache_file: Path, auth_result: AuthResult = None):
|
||||
if auth_result is not None:
|
||||
@@ -505,31 +371,6 @@ class AsyncAuthedProvider(AsyncGeneratorProvider, AuthFileMixin):
|
||||
else:
|
||||
raise MissingAuthError
|
||||
|
||||
@classmethod
|
||||
def create_completion(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
**kwargs
|
||||
) -> CreateResult:
|
||||
auth_result: AuthResult = None
|
||||
cache_file = cls.get_cache_file()
|
||||
try:
|
||||
auth_result = cls.get_auth_result()
|
||||
yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
|
||||
except (MissingAuthError, NoValidHarFileError, CloudflareError):
|
||||
response = cls.on_auth(**kwargs)
|
||||
for chunk in response:
|
||||
if isinstance(chunk, AuthResult):
|
||||
auth_result = chunk
|
||||
else:
|
||||
yield chunk
|
||||
for chunk in to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs), kwargs.get("stream_timeout", kwargs.get("timeout"))):
|
||||
if cache_file is not None:
|
||||
cls.write_cache_file(cache_file, auth_result)
|
||||
cache_file = None
|
||||
yield chunk
|
||||
|
||||
@classmethod
|
||||
async def create_async_generator(
|
||||
cls,
|
||||
|
||||
+53
-226
@@ -2,14 +2,42 @@ from __future__ import annotations
|
||||
|
||||
import random
|
||||
|
||||
from ..typing import Dict, Type, List, CreateResult, Messages, AsyncResult
|
||||
from ..typing import Dict, Type, List, Messages, AsyncResult
|
||||
from .types import BaseProvider, BaseRetryProvider, ProviderType
|
||||
from .response import ProviderInfo, JsonConversation, is_content
|
||||
from .base_provider import get_async_provider_method, to_async_iterator
|
||||
from .. import debug
|
||||
from ..tools.run_tools import AuthManager
|
||||
from ..config import AppConfig
|
||||
from ..errors import RetryProviderError, RetryNoProviderError, MissingAuthError, NoValidHarFileError
|
||||
|
||||
|
||||
def _resolve_model(provider: Type[BaseProvider], model: str) -> str:
|
||||
alias = model or getattr(provider, "default_model", None)
|
||||
if hasattr(provider, "model_aliases"):
|
||||
alias = provider.model_aliases.get(model, model)
|
||||
if isinstance(alias, list):
|
||||
alias = random.choice(alias)
|
||||
return alias
|
||||
|
||||
|
||||
def _prepare_provider_kwargs(
    provider: Type[BaseProvider],
    api_key,
    conversation: JsonConversation,
    kwargs: dict,
) -> dict:
    """
    Build the keyword arguments for one call to ``provider``.

    Args:
        provider (Type[BaseProvider]): The provider about to be called.
        api_key: Either a single key, or a dict that is looked up with
            ``provider.get_parent()``.
        conversation (JsonConversation): Optional container holding one
            attribute per provider name; may be None.
        kwargs (dict): The caller's keyword arguments. Not modified; a copy
            is returned.

    Returns:
        dict: A copy of ``kwargs`` with ``api_key`` set when a key is available
        and ``conversation`` set when ``conversation`` has an attribute named
        after the provider.
    """
    prepared = kwargs.copy()

    if isinstance(api_key, dict):
        key_for_provider = api_key.get(provider.get_parent())
    else:
        key_for_provider = api_key
    # Use the key returned by AuthManager when none was supplied or when
    # custom keys are disabled in the app config.
    if not key_for_provider or AppConfig.disable_custom_api_key:
        key_for_provider = AuthManager.load_api_key(provider)
    if key_for_provider:
        prepared["api_key"] = key_for_provider

    if conversation is not None:
        provider_name = provider.__name__
        if hasattr(conversation, provider_name):
            stored_state = getattr(conversation, provider_name)
            prepared["conversation"] = JsonConversation(**stored_state)

    return prepared
|
||||
|
||||
|
||||
class RotatedProvider(BaseRetryProvider):
|
||||
"""
|
||||
A provider that rotates through a list of providers, attempting one provider per
|
||||
@@ -48,69 +76,6 @@ class RotatedProvider(BaseRetryProvider):
|
||||
#new_provider_name = self.providers[self.current_index].__name__
|
||||
#debug.log(f"Rotated to next provider: {new_provider_name}")
|
||||
|
||||
def create_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
ignored: list[str] = [], # 'ignored' is less relevant now but kept for compatibility
|
||||
api_key: str = None,
|
||||
**kwargs,
|
||||
) -> CreateResult:
|
||||
"""
|
||||
Create a completion using the current provider and rotating on failure.
|
||||
|
||||
It will try each provider in the list once per call, rotating after each
|
||||
failed attempt, until one succeeds or all have failed.
|
||||
"""
|
||||
exceptions: Dict[str, Exception] = {}
|
||||
|
||||
# Loop over the number of providers, giving each one a chance
|
||||
for _ in range(len(self.providers)):
|
||||
provider = self._get_current_provider()
|
||||
self.last_provider = provider
|
||||
self._rotate_provider()
|
||||
|
||||
# Skip if provider is in the ignored list
|
||||
if provider.get_parent() in ignored:
|
||||
continue
|
||||
|
||||
alias = model or getattr(provider, "default_model", None)
|
||||
if hasattr(provider, "model_aliases"):
|
||||
alias = provider.model_aliases.get(model, model)
|
||||
if isinstance(alias, list):
|
||||
alias = random.choice(alias)
|
||||
|
||||
debug.log(f"Attempting provider: {provider.__name__} with model: {alias}")
|
||||
yield ProviderInfo(**provider.get_dict(), model=alias, alias=model)
|
||||
|
||||
extra_body = kwargs.copy()
|
||||
current_api_key = api_key.get(provider.get_parent()) if isinstance(api_key, dict) else api_key
|
||||
if not current_api_key or AppConfig.disable_custom_api_key:
|
||||
current_api_key = AuthManager.load_api_key(provider)
|
||||
if current_api_key:
|
||||
extra_body["api_key"] = current_api_key
|
||||
|
||||
try:
|
||||
# Attempt to get a response from the current provider
|
||||
response = provider.create_function(alias, messages, **extra_body)
|
||||
started = False
|
||||
for chunk in response:
|
||||
if chunk:
|
||||
yield chunk
|
||||
if is_content(chunk):
|
||||
started = True
|
||||
if started:
|
||||
provider.live += 1
|
||||
# Success, so we return and do not rotate
|
||||
return
|
||||
except Exception as e:
|
||||
provider.live -= 1
|
||||
exceptions[provider.__name__] = e
|
||||
debug.error(f"{provider.__name__} failed: {e}")
|
||||
|
||||
# If the loop completes, all providers have failed
|
||||
raise_exceptions(exceptions)
|
||||
|
||||
async def create_async_generator(
|
||||
self,
|
||||
model: str,
|
||||
@@ -133,28 +98,18 @@ class RotatedProvider(BaseRetryProvider):
|
||||
if provider.get_parent() in ignored:
|
||||
continue
|
||||
|
||||
alias = model or getattr(provider, "default_model", None)
|
||||
if hasattr(provider, "model_aliases"):
|
||||
alias = provider.model_aliases.get(model, model)
|
||||
if isinstance(alias, list):
|
||||
alias = random.choice(alias)
|
||||
alias = _resolve_model(provider, model)
|
||||
|
||||
debug.log(f"Attempting provider: {provider.__name__} with model: {alias}")
|
||||
yield ProviderInfo(**provider.get_dict(), model=alias)
|
||||
|
||||
extra_body = kwargs.copy()
|
||||
current_api_key = api_key.get(provider.get_parent()) if isinstance(api_key, dict) else api_key
|
||||
if not current_api_key or AppConfig.disable_custom_api_key:
|
||||
current_api_key = AuthManager.load_api_key(provider)
|
||||
if current_api_key:
|
||||
extra_body["api_key"] = current_api_key
|
||||
if conversation and hasattr(conversation, provider.__name__):
|
||||
extra_body["conversation"] = JsonConversation(**getattr(conversation, provider.__name__))
|
||||
extra_body = _prepare_provider_kwargs(provider, api_key, conversation, kwargs)
|
||||
|
||||
try:
|
||||
response = provider.async_create_function(alias, messages, **extra_body)
|
||||
method = get_async_provider_method(provider)
|
||||
response = method(model=alias, messages=messages, **extra_body)
|
||||
started = False
|
||||
async for chunk in response:
|
||||
async for chunk in to_async_iterator(response):
|
||||
if isinstance(chunk, JsonConversation):
|
||||
if conversation is None: conversation = JsonConversation()
|
||||
setattr(conversation, provider.__name__, chunk.get_dict())
|
||||
@@ -173,10 +128,6 @@ class RotatedProvider(BaseRetryProvider):
|
||||
|
||||
raise_exceptions(exceptions)
|
||||
|
||||
# Maintain API compatibility
|
||||
create_function = create_completion
|
||||
async_create_function = create_async_generator
|
||||
|
||||
class IterListProvider(BaseRetryProvider):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -196,61 +147,6 @@ class IterListProvider(BaseRetryProvider):
|
||||
self.working = True
|
||||
self.last_provider: Type[BaseProvider] = None
|
||||
|
||||
def create_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
ignored: list[str] = [],
|
||||
api_key: str = None,
|
||||
**kwargs,
|
||||
) -> CreateResult:
|
||||
"""
|
||||
Create a completion using available providers.
|
||||
Args:
|
||||
model (str): The model to be used for completion.
|
||||
messages (Messages): The messages to be used for generating completion.
|
||||
Yields:
|
||||
CreateResult: Tokens or results from the completion.
|
||||
Raises:
|
||||
Exception: Any exception encountered during the completion process.
|
||||
"""
|
||||
exceptions = {}
|
||||
started: bool = False
|
||||
for provider in self.get_providers(ignored):
|
||||
self.last_provider = provider
|
||||
alias = model
|
||||
if not model:
|
||||
alias = getattr(provider, "default_model", None)
|
||||
if hasattr(provider, "model_aliases"):
|
||||
alias = provider.model_aliases.get(model, model)
|
||||
if isinstance(alias, list):
|
||||
alias = random.choice(alias)
|
||||
debug.log(f"Using provider: {provider.__name__} with model: {alias}")
|
||||
yield ProviderInfo(**provider.get_dict(), model=alias)
|
||||
extra_body = kwargs.copy()
|
||||
current_api_key = api_key.get(provider.get_parent()) if isinstance(api_key, dict) else api_key
|
||||
if not current_api_key or AppConfig.disable_custom_api_key:
|
||||
current_api_key = AuthManager.load_api_key(provider)
|
||||
if current_api_key:
|
||||
extra_body["api_key"] = current_api_key
|
||||
try:
|
||||
response = provider.create_function(alias, messages, **extra_body)
|
||||
for chunk in response:
|
||||
if chunk:
|
||||
yield chunk
|
||||
if is_content(chunk):
|
||||
started = True
|
||||
if started:
|
||||
return
|
||||
except Exception as e:
|
||||
exceptions[provider.__name__] = e
|
||||
debug.error(f"{provider.__name__}:", e)
|
||||
if started:
|
||||
raise e
|
||||
yield e
|
||||
|
||||
raise_exceptions(exceptions)
|
||||
|
||||
async def create_async_generator(
|
||||
self,
|
||||
model: str,
|
||||
@@ -265,41 +161,23 @@ class IterListProvider(BaseRetryProvider):
|
||||
|
||||
for provider in self.get_providers(ignored):
|
||||
self.last_provider = provider
|
||||
alias = model
|
||||
if not model:
|
||||
alias = getattr(provider, "default_model", None)
|
||||
if hasattr(provider, "model_aliases"):
|
||||
alias = provider.model_aliases.get(model, model)
|
||||
if isinstance(alias, list):
|
||||
alias = random.choice(alias)
|
||||
alias = _resolve_model(provider, model)
|
||||
debug.log(f"Using {provider.__name__} provider with model {alias}")
|
||||
yield ProviderInfo(**provider.get_dict(), model=alias)
|
||||
extra_body = kwargs.copy()
|
||||
current_api_key = api_key.get(provider.get_parent()) if isinstance(api_key, dict) else api_key
|
||||
if not current_api_key or AppConfig.disable_custom_api_key:
|
||||
current_api_key = AuthManager.load_api_key(provider)
|
||||
if current_api_key:
|
||||
extra_body["api_key"] = current_api_key
|
||||
if conversation is not None and hasattr(conversation, provider.__name__):
|
||||
extra_body["conversation"] = JsonConversation(**getattr(conversation, provider.__name__))
|
||||
extra_body = _prepare_provider_kwargs(provider, api_key, conversation, kwargs)
|
||||
try:
|
||||
response = provider.async_create_function(model, messages, **extra_body)
|
||||
if hasattr(response, "__aiter__"):
|
||||
async for chunk in response:
|
||||
if isinstance(chunk, JsonConversation):
|
||||
if conversation is None:
|
||||
conversation = JsonConversation()
|
||||
setattr(conversation, provider.__name__, chunk.get_dict())
|
||||
yield conversation
|
||||
elif chunk:
|
||||
yield chunk
|
||||
if is_content(chunk):
|
||||
started = True
|
||||
elif response:
|
||||
response = await response
|
||||
if response:
|
||||
yield response
|
||||
started = True
|
||||
method = get_async_provider_method(provider)
|
||||
response = method(model=alias, messages=messages, **extra_body)
|
||||
async for chunk in to_async_iterator(response):
|
||||
if isinstance(chunk, JsonConversation):
|
||||
if conversation is None:
|
||||
conversation = JsonConversation()
|
||||
setattr(conversation, provider.__name__, chunk.get_dict())
|
||||
yield conversation
|
||||
elif chunk:
|
||||
yield chunk
|
||||
if is_content(chunk):
|
||||
started = True
|
||||
if started:
|
||||
return
|
||||
except Exception as e:
|
||||
@@ -307,13 +185,9 @@ class IterListProvider(BaseRetryProvider):
|
||||
debug.error(f"{provider.__name__}:", e)
|
||||
if started:
|
||||
raise e
|
||||
yield e
|
||||
|
||||
raise_exceptions(exceptions)
|
||||
|
||||
create_function = create_completion
|
||||
async_create_function = create_async_generator
|
||||
|
||||
def get_providers(self, ignored: list[str]) -> list[ProviderType]:
|
||||
providers = [p for p in self.providers if p.__name__ not in ignored]
|
||||
if self.shuffle:
|
||||
@@ -340,48 +214,6 @@ class RetryProvider(IterListProvider):
|
||||
self.single_provider_retry = single_provider_retry
|
||||
self.max_retries = max_retries
|
||||
|
||||
def create_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
**kwargs,
|
||||
) -> CreateResult:
|
||||
"""
|
||||
Create a completion using available providers.
|
||||
Args:
|
||||
model (str): The model to be used for completion.
|
||||
messages (Messages): The messages to be used for generating completion.
|
||||
Yields:
|
||||
CreateResult: Tokens or results from the completion.
|
||||
Raises:
|
||||
Exception: Any exception encountered during the completion process.
|
||||
"""
|
||||
if self.single_provider_retry:
|
||||
exceptions = {}
|
||||
started: bool = False
|
||||
provider = self.providers[0]
|
||||
self.last_provider = provider
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
if debug.logging:
|
||||
print(f"Using {provider.__name__} provider (attempt {attempt + 1})")
|
||||
response = provider.create_function(model, messages, **kwargs)
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
if is_content(chunk):
|
||||
started = True
|
||||
if started:
|
||||
return
|
||||
except Exception as e:
|
||||
exceptions[provider.__name__] = e
|
||||
if debug.logging:
|
||||
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
|
||||
if started:
|
||||
raise e
|
||||
raise_exceptions(exceptions)
|
||||
else:
|
||||
yield from super().create_completion(model, messages, **kwargs)
|
||||
|
||||
async def create_async_generator(
|
||||
self,
|
||||
model: str,
|
||||
@@ -397,16 +229,11 @@ class RetryProvider(IterListProvider):
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
debug.log(f"Using {provider.__name__} provider (attempt {attempt + 1})")
|
||||
response = provider.async_create_function(model, messages, **kwargs)
|
||||
if hasattr(response, "__aiter__"):
|
||||
async for chunk in response:
|
||||
yield chunk
|
||||
if is_content(chunk):
|
||||
started = True
|
||||
else:
|
||||
response = await response
|
||||
if response:
|
||||
yield response
|
||||
method = get_async_provider_method(provider)
|
||||
response = method(model=model, messages=messages, **kwargs)
|
||||
async for chunk in to_async_iterator(response):
|
||||
yield chunk
|
||||
if is_content(chunk):
|
||||
started = True
|
||||
if started:
|
||||
return
|
||||
|
||||
@@ -7,8 +7,8 @@ from typing import Optional, Union
|
||||
from ..typing import AsyncResult, Messages, MediaListType
|
||||
from ..client.service import get_model_and_provider
|
||||
from ..client.helper import filter_json
|
||||
from ..providers.types import ProviderType
|
||||
from .base_provider import AsyncGeneratorProvider
|
||||
from .types import ProviderType
|
||||
from .base_provider import AsyncGeneratorProvider, get_async_provider_method, to_async_iterator
|
||||
from .response import ToolCalls, FinishReason, Usage
|
||||
|
||||
|
||||
@@ -74,14 +74,15 @@ class ToolSupportProvider(AsyncGeneratorProvider):
|
||||
finish = None
|
||||
chunks = []
|
||||
has_usage = False
|
||||
async for chunk in provider.async_create_function(
|
||||
method = get_async_provider_method(provider)
|
||||
async for chunk in to_async_iterator(method(
|
||||
model,
|
||||
messages,
|
||||
stream=stream,
|
||||
media=media,
|
||||
response_format=response_format,
|
||||
**kwargs,
|
||||
):
|
||||
)):
|
||||
if isinstance(chunk, str):
|
||||
chunks.append(chunk)
|
||||
elif isinstance(chunk, Usage):
|
||||
|
||||
+1
-38
@@ -26,8 +26,6 @@ class BaseProvider(ABC):
|
||||
supports_message_history: bool = False
|
||||
supports_system_message: bool = False
|
||||
params: str
|
||||
create_function: callable
|
||||
async_create_function: callable
|
||||
live: int = 0
|
||||
|
||||
@classmethod
|
||||
@@ -44,42 +42,6 @@ class BaseProvider(ABC):
|
||||
def get_parent(cls) -> str:
|
||||
return getattr(cls, "parent", cls.__name__)
|
||||
|
||||
@abstractmethod
|
||||
def create_function(
|
||||
*args,
|
||||
**kwargs
|
||||
) -> CreateResult:
|
||||
"""
|
||||
Create a function to generate a response based on the model and messages.
|
||||
|
||||
Args:
|
||||
model (str): The model to use.
|
||||
messages (Messages): The messages to process.
|
||||
stream (bool): Whether to stream the response.
|
||||
|
||||
Returns:
|
||||
CreateResult: The result of the creation.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def async_create_function(
|
||||
*args,
|
||||
**kwargs
|
||||
) -> CreateResult:
|
||||
"""
|
||||
Asynchronously create a function to generate a response based on the model and messages.
|
||||
|
||||
Args:
|
||||
model (str): The model to use.
|
||||
messages (Messages): The messages to process.
|
||||
stream (bool): Whether to stream the response.
|
||||
|
||||
Returns:
|
||||
CreateResult: The result of the creation.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
class BaseRetryProvider(BaseProvider):
|
||||
"""
|
||||
Base class for a provider that implements retry logic.
|
||||
@@ -93,6 +55,7 @@ class BaseRetryProvider(BaseProvider):
|
||||
|
||||
__name__: str = "RetryProvider"
|
||||
supports_stream: bool = True
|
||||
use_stream_timeout: bool = True
|
||||
last_provider: Type[BaseProvider] = None
|
||||
|
||||
ProviderType = Union[Type[BaseProvider], BaseRetryProvider]
|
||||
|
||||
@@ -152,7 +152,7 @@ async def get_args_from_nodriver(
|
||||
await page.wait_for(wait_for, timeout=timeout)
|
||||
if callback is not None:
|
||||
await callback(page)
|
||||
for c in await page.send(nodriver.cdp.network.get_cookies([url])):
|
||||
for c in await asyncio.wait_for(page.send(nodriver.cdp.network.get_cookies([url])), timeout=timeout):
|
||||
cookies[c.name] = c.value
|
||||
await stop_browser()
|
||||
return {
|
||||
|
||||
@@ -22,6 +22,7 @@ from ..providers.helper import filter_none
|
||||
from ..providers.asyncio import to_async_iterator, to_sync_generator
|
||||
from ..providers.response import Reasoning, FinishReason, Sources, Usage, ProviderInfo
|
||||
from ..providers.types import ProviderType
|
||||
from ..providers.base_provider import get_async_provider_method, get_provider_method, wait_for
|
||||
from ..cookies import get_cookies_dir
|
||||
from ..config import AppConfig
|
||||
from .web_search import do_search, get_search_message
|
||||
@@ -296,9 +297,12 @@ async def async_iter_run_tools(
|
||||
kwargs.update(extra_kwargs)
|
||||
|
||||
# Generate response
|
||||
method = get_async_provider_method(provider)
|
||||
response = to_async_iterator(
|
||||
provider.async_create_function(model=model, messages=messages, **kwargs)
|
||||
method(model=model, messages=messages, **kwargs)
|
||||
)
|
||||
timeout = kwargs.get("stream_timeout") if provider.use_stream_timeout else kwargs.get("timeout")
|
||||
response = wait_for(response, timeout=timeout) if stream else response
|
||||
|
||||
try:
|
||||
usage_model = model
|
||||
@@ -471,9 +475,10 @@ def iter_run_tools(
|
||||
usage_provider = provider.__name__
|
||||
completion_tokens = 0
|
||||
usage = None
|
||||
for chunk in provider.create_function(
|
||||
method = get_provider_method(provider)
|
||||
for chunk in to_sync_generator(method(
|
||||
model=model, messages=messages, provider=provider, **kwargs
|
||||
):
|
||||
)):
|
||||
if isinstance(chunk, FinishReason):
|
||||
if sources is not None:
|
||||
yield sources
|
||||
|
||||
Reference in New Issue
Block a user