Refactor provider methods to unify async and sync handling, enhance clarity, and improve error management

This commit is contained in:
hlohaus
2026-04-16 02:38:03 +02:00
parent 65ee4cb088
commit 06df47f279
10 changed files with 141 additions and 476 deletions
+6 -1
View File
@@ -174,7 +174,12 @@ class GithubCopilot(OpenaiTemplate):
"Please run 'g4f auth github-copilot' to authenticate."
) from e
raise
return super().get_models(api_key, base_url, timeout)
response = super().get_models(api_key, base_url, timeout)
if isinstance(response, dict):
for key in list(response.keys()):
if key.startswith("accounts/") or key.startswith("text-embedding-") or key in ("minimax-m2.5", "goldeneye-free-auto"):
del response[key]
return response
@classmethod
def get_headers(cls, stream: bool, api_key: str | None = None, headers: dict[str, str] | None = None) -> dict[str, str]:
+7 -2
View File
@@ -11,7 +11,9 @@ from .client import Client, AsyncClient, ClientFactory, create_custom_provider
from .typing import Messages, CreateResult, AsyncResult, ImageType
from .cookies import get_cookies, set_cookies
from .providers.types import ProviderType
from .providers.base_provider import get_async_provider_method, get_provider_method
from .providers.helper import concat_chunks, async_concat_chunks
from .providers.asyncio import to_sync_generator
from .client.service import get_model_and_provider
# Configure logger
@@ -69,7 +71,9 @@ class ChatCompletion:
model, messages, provider, stream, image, image_name,
ignore_working, ignore_stream, **kwargs
)
result = provider.create_function(model, messages, stream=stream, **kwargs)
method = get_provider_method(provider)
result = method(model, messages, stream=stream, **kwargs)
result = to_sync_generator(result)
return result if stream or ignore_stream else concat_chunks(result)
@staticmethod
@@ -86,7 +90,8 @@ class ChatCompletion:
model, messages, provider, stream, image, image_name,
ignore_working, ignore_stream, **kwargs
)
result = provider.async_create_function(model, messages, stream=stream, **kwargs)
method = get_async_provider_method(provider)
result = method(model, messages, stream=stream, **kwargs)
if not stream and not ignore_stream and hasattr(result, "__aiter__"):
result = async_concat_chunks(result)
return result
-2
View File
@@ -539,8 +539,6 @@ class AnyProvider(AsyncGeneratorProvider, AnyModelProviderMixin):
):
yield chunk
async_create_function = create_async_generator
# Clean model names function
def clean_name(name: str) -> str:
+20
View File
@@ -45,6 +45,23 @@ async def async_generator_to_list(generator: AsyncIterator) -> list:
def to_sync_generator(generator: AsyncIterator, stream: bool = True, timeout: int = None) -> Iterator:
loop = get_running_loop(check_nested=False)
if asyncio.iscoroutine(generator):
if loop is not None:
try:
result = loop.run_until_complete(generator)
except RuntimeError as e:
if asyncio.iscoroutine(generator):
try:
generator.close()
except Exception:
pass
raise NestAsyncioError(
'Install "nest-asyncio2" package | pip install -U nest-asyncio2'
) from e
else:
result = asyncio.run(generator)
yield result
return
if not stream:
yield from asyncio.run(async_generator_to_list(generator))
return
@@ -72,6 +89,9 @@ def to_sync_generator(generator: AsyncIterator, stream: bool = True, timeout: in
# Helper function to convert a synchronous iterator to an async iterator
async def to_async_iterator(iterator) -> AsyncIterator:
if isinstance(iterator, (str, bytes)):
yield iterator
return
if hasattr(iterator, '__aiter__'):
async for item in iterator:
yield item
+40 -199
View File
@@ -2,12 +2,10 @@ from __future__ import annotations
import asyncio
import random
from asyncio import AbstractEventLoop
from concurrent.futures import ThreadPoolExecutor
from abc import abstractmethod
import json
from inspect import signature, Parameter
from typing import Optional, _GenericAlias
from typing import Optional, _GenericAlias, AsyncIterator
from pathlib import Path
from aiohttp import ClientSession
try:
@@ -22,7 +20,7 @@ from .response import BaseConversation, AuthResult
from .helper import concat_chunks
from ..cookies import get_cookies_dir
from ..requests import raise_for_status
from ..errors import ModelNotFoundError, ResponseError, MissingAuthError, NoValidHarFileError, PaymentRequiredError, CloudflareError
from ..errors import ResponseError, MissingAuthError, NoValidHarFileError, PaymentRequiredError, CloudflareError
from .. import debug
SAFE_PARAMETERS = [
@@ -71,92 +69,45 @@ PARAMETER_EXAMPLES = {
"aspect_ratio": "1:1",
}
async def wait_for(response: AsyncIterator, timeout: int = None) -> AsyncIterator:
    """Yield chunks from *response*, enforcing a per-chunk timeout.

    Args:
        response (AsyncIterator): The async iterator to consume.
        timeout (int, optional): Maximum seconds to wait for each chunk.
            If None, the iterator is consumed without any timeout.

    Raises:
        TimeoutError: If fetching the next chunk exceeds *timeout* seconds.
    """
    if timeout is None:
        # No timeout requested: plain pass-through iteration.
        async for chunk in response:
            yield chunk
        return
    while True:
        try:
            yield await asyncio.wait_for(
                response.__anext__(),
                timeout=timeout
            )
        except asyncio.TimeoutError as e:
            # Catch asyncio.TimeoutError explicitly: it is the builtin
            # TimeoutError only since Python 3.11; on older versions the
            # bare `except TimeoutError` would never fire here.
            raise TimeoutError("The operation timed out after {} seconds".format(timeout)) from e
        except StopAsyncIteration:
            break
def get_async_provider_method(provider: type) -> Optional[callable]:
    """Return the preferred async entry point of *provider*.

    Preference order: ``create_async_generator``, then ``create_async``,
    then an async-generator wrapper around the sync ``create_completion``.

    Raises:
        NotImplementedError: If the provider exposes none of the three.
    """
    for candidate in ("create_async_generator", "create_async"):
        if hasattr(provider, candidate):
            return getattr(provider, candidate)
    if not hasattr(provider, "create_completion"):
        raise NotImplementedError(f"{provider.__name__} does not implement an async method")
    async def wrapper(*args, **kwargs):
        # Bridge the sync generator into async iteration; resolve the
        # attribute at call time, matching the original late binding.
        for chunk in provider.create_completion(*args, **kwargs):
            yield chunk
    return wrapper
def get_provider_method(provider: type) -> Optional[callable]:
    """Return the preferred sync-facing entry point of *provider*.

    Preference order: ``create_completion``, then ``create_async_generator``,
    then ``create_async``.

    Raises:
        NotImplementedError: If the provider exposes none of the three.
    """
    for candidate in ("create_completion", "create_async_generator", "create_async"):
        if hasattr(provider, candidate):
            return getattr(provider, candidate)
    raise NotImplementedError(f"{provider.__name__} does not implement a create method")
class AbstractProvider(BaseProvider):
@classmethod
@abstractmethod
def create_completion(
cls,
model: str,
messages: Messages,
**kwargs
) -> CreateResult:
"""
Create a completion with the given parameters.
Args:
model (str): The model to use.
messages (Messages): The messages to process.
stream (bool): Whether to use streaming.
**kwargs: Additional keyword arguments.
Returns:
CreateResult: The result of the creation process.
"""
raise NotImplementedError()
@classmethod
async def create_async(
cls,
model: str,
messages: Messages,
*,
timeout: int = None,
loop: AbstractEventLoop = None,
executor: ThreadPoolExecutor = None,
**kwargs
) -> str:
"""
Asynchronously creates a result based on the given model and messages.
Args:
cls (type): The class on which this method is called.
model (str): The model to use for creation.
messages (Messages): The messages to process.
loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
executor (ThreadPoolExecutor, optional): The executor for running async tasks. Defaults to None.
**kwargs: Additional keyword arguments.
Returns:
str: The created result as a string.
"""
loop = asyncio.get_running_loop() if loop is None else loop
def create_func() -> str:
return concat_chunks(cls.create_completion(model=model, messages=messages, **kwargs))
try:
return await asyncio.wait_for(
loop.run_in_executor(executor, create_func), timeout=timeout
)
except TimeoutError as e:
raise TimeoutError("The operation timed out after {} seconds in {}".format(timeout, cls.__name__)) from e
@classmethod
def create_function(cls, *args, **kwargs) -> CreateResult:
"""
Creates a completion using the synchronous method.
Args:
**kwargs: Additional keyword arguments.
Returns:
CreateResult: The result of the completion creation.
"""
return cls.create_completion(*args, **kwargs)
@classmethod
def async_create_function(cls, *args, **kwargs) -> AsyncResult:
"""
Creates a completion using the synchronous method.
Args:
**kwargs: Additional keyword arguments.
Returns:
CreateResult: The result of the completion creation.
"""
return cls.create_async(*args, **kwargs)
@classmethod
def get_parameters(cls, as_json: bool = False) -> dict[str, Parameter]:
params = {name: parameter for name, parameter in signature(
@@ -240,29 +191,6 @@ class AsyncProvider(AbstractProvider):
Provides asynchronous functionality for creating completions.
"""
@classmethod
def create_completion(
cls,
model: str,
messages: Messages,
**kwargs
) -> CreateResult:
"""
Creates a completion result synchronously.
Args:
cls (type): The class on which this method is called.
model (str): The model to use for creation.
messages (Messages): The messages to process.
loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
**kwargs: Additional keyword arguments.
Returns:
CreateResult: The result of the completion creation.
"""
get_running_loop(check_nested=False)
yield asyncio.run(cls.create_async(model, messages, **kwargs))
@staticmethod
@abstractmethod
async def create_async(
@@ -309,33 +237,6 @@ class AsyncGeneratorProvider(AbstractProvider):
await raise_for_status(response)
return await response.json()
@classmethod
def create_completion(
cls,
model: str,
messages: Messages,
timeout: int = None,
stream_timeout: int = None,
**kwargs
) -> CreateResult:
"""
Creates a streaming completion result synchronously.
Args:
cls (type): The class on which this method is called.
model (str): The model to use for creation.
messages (Messages): The messages to process.
loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
**kwargs: Additional keyword arguments.
Returns:
CreateResult: The result of the streaming completion creation.
"""
return to_sync_generator(
cls.create_async_generator(model, messages, **kwargs),
timeout=stream_timeout if cls.use_stream_timeout is None else timeout,
)
@staticmethod
@abstractmethod
async def create_async_generator(
@@ -359,34 +260,6 @@ class AsyncGeneratorProvider(AbstractProvider):
"""
raise NotImplementedError()
@classmethod
async def async_create_function(cls, *args, **kwargs) -> AsyncResult:
"""
Creates a completion using the synchronous method.
Args:
**kwargs: Additional keyword arguments.
Returns:
CreateResult: The result of the completion creation.
"""
response = cls.create_async_generator(*args, **kwargs)
if "stream_timeout" in kwargs or "timeout" in kwargs:
timeout = kwargs.get("stream_timeout") if cls.use_stream_timeout else kwargs.get("timeout")
while True:
try:
yield await asyncio.wait_for(
response.__anext__(),
timeout=timeout
)
except TimeoutError as e:
raise TimeoutError("The operation timed out after {} seconds in {}".format(timeout, cls.__name__)) from e
except StopAsyncIteration:
break
else:
async for chunk in response:
yield chunk
class ProviderModelMixin:
default_model: str = None
models: list[str] = []
@@ -466,13 +339,6 @@ class AsyncAuthedProvider(AsyncGeneratorProvider, AuthFileMixin):
raise MissingAuthError(f"API key is required for {cls.__name__}")
return AuthResult()
@classmethod
def on_auth(cls, **kwargs) -> AuthResult:
auth_result = cls.on_auth_async(**kwargs)
if hasattr(auth_result, "__aiter__"):
return to_sync_generator(auth_result)
return asyncio.run(auth_result)
@classmethod
def write_cache_file(cls, cache_file: Path, auth_result: AuthResult = None):
if auth_result is not None:
@@ -505,31 +371,6 @@ class AsyncAuthedProvider(AsyncGeneratorProvider, AuthFileMixin):
else:
raise MissingAuthError
@classmethod
def create_completion(
cls,
model: str,
messages: Messages,
**kwargs
) -> CreateResult:
auth_result: AuthResult = None
cache_file = cls.get_cache_file()
try:
auth_result = cls.get_auth_result()
yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
except (MissingAuthError, NoValidHarFileError, CloudflareError):
response = cls.on_auth(**kwargs)
for chunk in response:
if isinstance(chunk, AuthResult):
auth_result = chunk
else:
yield chunk
for chunk in to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs), kwargs.get("stream_timeout", kwargs.get("timeout"))):
if cache_file is not None:
cls.write_cache_file(cache_file, auth_result)
cache_file = None
yield chunk
@classmethod
async def create_async_generator(
cls,
+53 -226
View File
@@ -2,14 +2,42 @@ from __future__ import annotations
import random
from ..typing import Dict, Type, List, CreateResult, Messages, AsyncResult
from ..typing import Dict, Type, List, Messages, AsyncResult
from .types import BaseProvider, BaseRetryProvider, ProviderType
from .response import ProviderInfo, JsonConversation, is_content
from .base_provider import get_async_provider_method, to_async_iterator
from .. import debug
from ..tools.run_tools import AuthManager
from ..config import AppConfig
from ..errors import RetryProviderError, RetryNoProviderError, MissingAuthError, NoValidHarFileError
def _resolve_model(provider: Type[BaseProvider], model: str) -> str:
alias = model or getattr(provider, "default_model", None)
if hasattr(provider, "model_aliases"):
alias = provider.model_aliases.get(model, model)
if isinstance(alias, list):
alias = random.choice(alias)
return alias
def _prepare_provider_kwargs(
    provider: Type[BaseProvider],
    api_key,
    conversation: JsonConversation,
    kwargs: dict,
) -> dict:
    """Build the keyword arguments to pass to *provider*.

    Injects the resolved API key and any provider-specific conversation
    state into a copy of *kwargs*; the input dict is never mutated.
    """
    prepared = dict(kwargs)
    # A dict of keys maps parent-provider names to their key; otherwise
    # the value is used directly.
    if isinstance(api_key, dict):
        resolved_key = api_key.get(provider.get_parent())
    else:
        resolved_key = api_key
    # Fall back to stored credentials when no key was supplied, or when
    # custom keys are disabled by configuration.
    if not resolved_key or AppConfig.disable_custom_api_key:
        resolved_key = AuthManager.load_api_key(provider)
    if resolved_key:
        prepared["api_key"] = resolved_key
    # Restore per-provider conversation state stored under the provider's
    # class name on the aggregate conversation object.
    if conversation is not None and hasattr(conversation, provider.__name__):
        prepared["conversation"] = JsonConversation(**getattr(conversation, provider.__name__))
    return prepared
class RotatedProvider(BaseRetryProvider):
"""
A provider that rotates through a list of providers, attempting one provider per
@@ -48,69 +76,6 @@ class RotatedProvider(BaseRetryProvider):
#new_provider_name = self.providers[self.current_index].__name__
#debug.log(f"Rotated to next provider: {new_provider_name}")
def create_completion(
self,
model: str,
messages: Messages,
ignored: list[str] = [], # 'ignored' is less relevant now but kept for compatibility
api_key: str = None,
**kwargs,
) -> CreateResult:
"""
Create a completion using the current provider and rotating on failure.
It will try each provider in the list once per call, rotating after each
failed attempt, until one succeeds or all have failed.
"""
exceptions: Dict[str, Exception] = {}
# Loop over the number of providers, giving each one a chance
for _ in range(len(self.providers)):
provider = self._get_current_provider()
self.last_provider = provider
self._rotate_provider()
# Skip if provider is in the ignored list
if provider.get_parent() in ignored:
continue
alias = model or getattr(provider, "default_model", None)
if hasattr(provider, "model_aliases"):
alias = provider.model_aliases.get(model, model)
if isinstance(alias, list):
alias = random.choice(alias)
debug.log(f"Attempting provider: {provider.__name__} with model: {alias}")
yield ProviderInfo(**provider.get_dict(), model=alias, alias=model)
extra_body = kwargs.copy()
current_api_key = api_key.get(provider.get_parent()) if isinstance(api_key, dict) else api_key
if not current_api_key or AppConfig.disable_custom_api_key:
current_api_key = AuthManager.load_api_key(provider)
if current_api_key:
extra_body["api_key"] = current_api_key
try:
# Attempt to get a response from the current provider
response = provider.create_function(alias, messages, **extra_body)
started = False
for chunk in response:
if chunk:
yield chunk
if is_content(chunk):
started = True
if started:
provider.live += 1
# Success, so we return and do not rotate
return
except Exception as e:
provider.live -= 1
exceptions[provider.__name__] = e
debug.error(f"{provider.__name__} failed: {e}")
# If the loop completes, all providers have failed
raise_exceptions(exceptions)
async def create_async_generator(
self,
model: str,
@@ -133,28 +98,18 @@ class RotatedProvider(BaseRetryProvider):
if provider.get_parent() in ignored:
continue
alias = model or getattr(provider, "default_model", None)
if hasattr(provider, "model_aliases"):
alias = provider.model_aliases.get(model, model)
if isinstance(alias, list):
alias = random.choice(alias)
alias = _resolve_model(provider, model)
debug.log(f"Attempting provider: {provider.__name__} with model: {alias}")
yield ProviderInfo(**provider.get_dict(), model=alias)
extra_body = kwargs.copy()
current_api_key = api_key.get(provider.get_parent()) if isinstance(api_key, dict) else api_key
if not current_api_key or AppConfig.disable_custom_api_key:
current_api_key = AuthManager.load_api_key(provider)
if current_api_key:
extra_body["api_key"] = current_api_key
if conversation and hasattr(conversation, provider.__name__):
extra_body["conversation"] = JsonConversation(**getattr(conversation, provider.__name__))
extra_body = _prepare_provider_kwargs(provider, api_key, conversation, kwargs)
try:
response = provider.async_create_function(alias, messages, **extra_body)
method = get_async_provider_method(provider)
response = method(model=alias, messages=messages, **extra_body)
started = False
async for chunk in response:
async for chunk in to_async_iterator(response):
if isinstance(chunk, JsonConversation):
if conversation is None: conversation = JsonConversation()
setattr(conversation, provider.__name__, chunk.get_dict())
@@ -173,10 +128,6 @@ class RotatedProvider(BaseRetryProvider):
raise_exceptions(exceptions)
# Maintain API compatibility
create_function = create_completion
async_create_function = create_async_generator
class IterListProvider(BaseRetryProvider):
def __init__(
self,
@@ -196,61 +147,6 @@ class IterListProvider(BaseRetryProvider):
self.working = True
self.last_provider: Type[BaseProvider] = None
def create_completion(
self,
model: str,
messages: Messages,
ignored: list[str] = [],
api_key: str = None,
**kwargs,
) -> CreateResult:
"""
Create a completion using available providers.
Args:
model (str): The model to be used for completion.
messages (Messages): The messages to be used for generating completion.
Yields:
CreateResult: Tokens or results from the completion.
Raises:
Exception: Any exception encountered during the completion process.
"""
exceptions = {}
started: bool = False
for provider in self.get_providers(ignored):
self.last_provider = provider
alias = model
if not model:
alias = getattr(provider, "default_model", None)
if hasattr(provider, "model_aliases"):
alias = provider.model_aliases.get(model, model)
if isinstance(alias, list):
alias = random.choice(alias)
debug.log(f"Using provider: {provider.__name__} with model: {alias}")
yield ProviderInfo(**provider.get_dict(), model=alias)
extra_body = kwargs.copy()
current_api_key = api_key.get(provider.get_parent()) if isinstance(api_key, dict) else api_key
if not current_api_key or AppConfig.disable_custom_api_key:
current_api_key = AuthManager.load_api_key(provider)
if current_api_key:
extra_body["api_key"] = current_api_key
try:
response = provider.create_function(alias, messages, **extra_body)
for chunk in response:
if chunk:
yield chunk
if is_content(chunk):
started = True
if started:
return
except Exception as e:
exceptions[provider.__name__] = e
debug.error(f"{provider.__name__}:", e)
if started:
raise e
yield e
raise_exceptions(exceptions)
async def create_async_generator(
self,
model: str,
@@ -265,41 +161,23 @@ class IterListProvider(BaseRetryProvider):
for provider in self.get_providers(ignored):
self.last_provider = provider
alias = model
if not model:
alias = getattr(provider, "default_model", None)
if hasattr(provider, "model_aliases"):
alias = provider.model_aliases.get(model, model)
if isinstance(alias, list):
alias = random.choice(alias)
alias = _resolve_model(provider, model)
debug.log(f"Using {provider.__name__} provider with model {alias}")
yield ProviderInfo(**provider.get_dict(), model=alias)
extra_body = kwargs.copy()
current_api_key = api_key.get(provider.get_parent()) if isinstance(api_key, dict) else api_key
if not current_api_key or AppConfig.disable_custom_api_key:
current_api_key = AuthManager.load_api_key(provider)
if current_api_key:
extra_body["api_key"] = current_api_key
if conversation is not None and hasattr(conversation, provider.__name__):
extra_body["conversation"] = JsonConversation(**getattr(conversation, provider.__name__))
extra_body = _prepare_provider_kwargs(provider, api_key, conversation, kwargs)
try:
response = provider.async_create_function(model, messages, **extra_body)
if hasattr(response, "__aiter__"):
async for chunk in response:
if isinstance(chunk, JsonConversation):
if conversation is None:
conversation = JsonConversation()
setattr(conversation, provider.__name__, chunk.get_dict())
yield conversation
elif chunk:
yield chunk
if is_content(chunk):
started = True
elif response:
response = await response
if response:
yield response
started = True
method = get_async_provider_method(provider)
response = method(model=alias, messages=messages, **extra_body)
async for chunk in to_async_iterator(response):
if isinstance(chunk, JsonConversation):
if conversation is None:
conversation = JsonConversation()
setattr(conversation, provider.__name__, chunk.get_dict())
yield conversation
elif chunk:
yield chunk
if is_content(chunk):
started = True
if started:
return
except Exception as e:
@@ -307,13 +185,9 @@ class IterListProvider(BaseRetryProvider):
debug.error(f"{provider.__name__}:", e)
if started:
raise e
yield e
raise_exceptions(exceptions)
create_function = create_completion
async_create_function = create_async_generator
def get_providers(self, ignored: list[str]) -> list[ProviderType]:
providers = [p for p in self.providers if p.__name__ not in ignored]
if self.shuffle:
@@ -340,48 +214,6 @@ class RetryProvider(IterListProvider):
self.single_provider_retry = single_provider_retry
self.max_retries = max_retries
def create_completion(
self,
model: str,
messages: Messages,
**kwargs,
) -> CreateResult:
"""
Create a completion using available providers.
Args:
model (str): The model to be used for completion.
messages (Messages): The messages to be used for generating completion.
Yields:
CreateResult: Tokens or results from the completion.
Raises:
Exception: Any exception encountered during the completion process.
"""
if self.single_provider_retry:
exceptions = {}
started: bool = False
provider = self.providers[0]
self.last_provider = provider
for attempt in range(self.max_retries):
try:
if debug.logging:
print(f"Using {provider.__name__} provider (attempt {attempt + 1})")
response = provider.create_function(model, messages, **kwargs)
for chunk in response:
yield chunk
if is_content(chunk):
started = True
if started:
return
except Exception as e:
exceptions[provider.__name__] = e
if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
if started:
raise e
raise_exceptions(exceptions)
else:
yield from super().create_completion(model, messages, **kwargs)
async def create_async_generator(
self,
model: str,
@@ -397,16 +229,11 @@ class RetryProvider(IterListProvider):
for attempt in range(self.max_retries):
try:
debug.log(f"Using {provider.__name__} provider (attempt {attempt + 1})")
response = provider.async_create_function(model, messages, **kwargs)
if hasattr(response, "__aiter__"):
async for chunk in response:
yield chunk
if is_content(chunk):
started = True
else:
response = await response
if response:
yield response
method = get_async_provider_method(provider)
response = method(model=model, messages=messages, **kwargs)
async for chunk in to_async_iterator(response):
yield chunk
if is_content(chunk):
started = True
if started:
return
+5 -4
View File
@@ -7,8 +7,8 @@ from typing import Optional, Union
from ..typing import AsyncResult, Messages, MediaListType
from ..client.service import get_model_and_provider
from ..client.helper import filter_json
from ..providers.types import ProviderType
from .base_provider import AsyncGeneratorProvider
from .types import ProviderType
from .base_provider import AsyncGeneratorProvider, get_async_provider_method, to_async_iterator
from .response import ToolCalls, FinishReason, Usage
@@ -74,14 +74,15 @@ class ToolSupportProvider(AsyncGeneratorProvider):
finish = None
chunks = []
has_usage = False
async for chunk in provider.async_create_function(
method = get_async_provider_method(provider)
async for chunk in to_async_iterator(method(
model,
messages,
stream=stream,
media=media,
response_format=response_format,
**kwargs,
):
)):
if isinstance(chunk, str):
chunks.append(chunk)
elif isinstance(chunk, Usage):
+1 -38
View File
@@ -26,8 +26,6 @@ class BaseProvider(ABC):
supports_message_history: bool = False
supports_system_message: bool = False
params: str
create_function: callable
async_create_function: callable
live: int = 0
@classmethod
@@ -44,42 +42,6 @@ class BaseProvider(ABC):
def get_parent(cls) -> str:
return getattr(cls, "parent", cls.__name__)
@abstractmethod
def create_function(
*args,
**kwargs
) -> CreateResult:
"""
Create a function to generate a response based on the model and messages.
Args:
model (str): The model to use.
messages (Messages): The messages to process.
stream (bool): Whether to stream the response.
Returns:
CreateResult: The result of the creation.
"""
raise NotImplementedError()
@staticmethod
def async_create_function(
*args,
**kwargs
) -> CreateResult:
"""
Asynchronously create a function to generate a response based on the model and messages.
Args:
model (str): The model to use.
messages (Messages): The messages to process.
stream (bool): Whether to stream the response.
Returns:
CreateResult: The result of the creation.
"""
raise NotImplementedError()
class BaseRetryProvider(BaseProvider):
"""
Base class for a provider that implements retry logic.
@@ -93,6 +55,7 @@ class BaseRetryProvider(BaseProvider):
__name__: str = "RetryProvider"
supports_stream: bool = True
use_stream_timeout: bool = True
last_provider: Type[BaseProvider] = None
ProviderType = Union[Type[BaseProvider], BaseRetryProvider]
+1 -1
View File
@@ -152,7 +152,7 @@ async def get_args_from_nodriver(
await page.wait_for(wait_for, timeout=timeout)
if callback is not None:
await callback(page)
for c in await page.send(nodriver.cdp.network.get_cookies([url])):
for c in await asyncio.wait_for(page.send(nodriver.cdp.network.get_cookies([url])), timeout=timeout):
cookies[c.name] = c.value
await stop_browser()
return {
+8 -3
View File
@@ -22,6 +22,7 @@ from ..providers.helper import filter_none
from ..providers.asyncio import to_async_iterator, to_sync_generator
from ..providers.response import Reasoning, FinishReason, Sources, Usage, ProviderInfo
from ..providers.types import ProviderType
from ..providers.base_provider import get_async_provider_method, get_provider_method, wait_for
from ..cookies import get_cookies_dir
from ..config import AppConfig
from .web_search import do_search, get_search_message
@@ -296,9 +297,12 @@ async def async_iter_run_tools(
kwargs.update(extra_kwargs)
# Generate response
method = get_async_provider_method(provider)
response = to_async_iterator(
provider.async_create_function(model=model, messages=messages, **kwargs)
method(model=model, messages=messages, **kwargs)
)
timeout = kwargs.get("stream_timeout") if provider.use_stream_timeout else kwargs.get("timeout")
response = wait_for(response, timeout=timeout) if stream else response
try:
usage_model = model
@@ -471,9 +475,10 @@ def iter_run_tools(
usage_provider = provider.__name__
completion_tokens = 0
usage = None
for chunk in provider.create_function(
method = get_provider_method(provider)
for chunk in to_sync_generator(method(
model=model, messages=messages, provider=provider, **kwargs
):
)):
if isinstance(chunk, FinishReason):
if sources is not None:
yield sources