diff --git a/docs/config-yaml-routing.md b/docs/config-yaml-routing.md new file mode 100644 index 00000000..4862010c --- /dev/null +++ b/docs/config-yaml-routing.md @@ -0,0 +1,196 @@ +# Custom Model Routing with `config.yaml` + +g4f supports a `config.yaml` file that lets you define **custom model routes** – named +models that are transparently forwarded to one or more real providers based on +availability, quota balance, and recent error counts. + +This is similar to the [LiteLLM](https://docs.litellm.ai/) routing configuration. + +--- + +## Quick start + +1. Place a `config.yaml` file in the same directory as your `.har` / `.json` + cookie files (the "cookies dir"). + * Default location: `~/.config/g4f/cookies/config.yaml` + * Alternative: `./har_and_cookies/config.yaml` + +2. Define your routes (see format below). + +3. g4f loads the file automatically when it reads the cookie directory + (e.g. on API server start-up, or when `read_cookie_files()` is called). + +4. Request the custom model name from any client: + +```python +from g4f.client import Client + +client = Client() +response = client.chat.completions.create( + model="my-gpt4", # defined in config.yaml + messages=[{"role": "user", "content": "Hello!"}], +) +print(response.choices[0].message.content) +``` + +--- + +## File format + +```yaml +models: + - name: "" # the name clients use + providers: + - provider: "" # g4f provider class name + model: "" # model name passed to that provider + condition: "" # optional – see below + - provider: "..." # fallback provider (no condition = always eligible) + model: "..." +``` + +### Keys + +| Key | Required | Description | +|-----|----------|-------------| +| `name` | ✅ | The model name used by clients. | +| `providers` | ✅ | Ordered list of provider candidates. | +| `provider` | ✅ | Provider class name (e.g. `"OpenaiAccount"`, `"PollinationsAI"`). | +| `model` | | Model name forwarded to the provider. Defaults to the route `name`. | +| `condition` | | Boolean expression controlling when this provider is eligible. | + +--- + +## Condition expressions + +The `condition` field is a boolean expression evaluated before each request. +It can reference two variables: + +| Variable | Type | Description | +|----------|------|-------------| +| `balance` | `float` | Provider quota balance, fetched via `get_quota()` and **cached** for 5 minutes. Returns `0.0` if the provider has no `get_quota` method or the call fails. | +| `error_count` | `int` | Number of errors recorded for this provider in the last **1 hour**. | +| `get_quota.balance` | `float` | Alias for `balance`. | + +### Operators + +| Operator | Meaning | +|----------|---------| +| `>` `<` `>=` `<=` | Numeric comparison | +| `==` `!=` | Equality / inequality | +| `and` `or` `not` | Logical connectives | +| `(` `)` | Grouping | + +### Examples + +```yaml +condition: "balance > 0" +condition: "error_count < 3" +condition: "balance > 0 or error_count < 3" +condition: "balance >= 10 and error_count == 0" +condition: "(balance > 0 or error_count < 5) and error_count < 10" +``` + +When the condition is **absent** or evaluates to `True`, the provider is +eligible. When it evaluates to `False` the provider is skipped and g4f +tries the next one in the list. + +--- + +## Quota caching + +Quota values (`balance`) are fetched via the provider's `get_quota()` method +and cached in memory for **5 minutes** (configurable via +`QuotaCache.ttl`). + +When a provider returns an HTTP **429 (Too Many Requests)** error the cache +entry for that provider is **immediately invalidated**, so the next routing +decision fetches a fresh balance before deciding. + +--- + +## Error counting + +Every time a provider raises an exception the error counter for that provider +is incremented. Errors older than **1 hour** are automatically pruned. + +You can reference `error_count` in a condition to avoid retrying providers +that have been failing repeatedly. + +--- + +## Full example + +```yaml +# ~/.config/g4f/cookies/config.yaml + +models: + # Prefer OpenaiAccount when it has quota; fall back to PollinationsAI. + - name: "my-gpt4" + providers: + - provider: "OpenaiAccount" + model: "gpt-4o" + condition: "balance > 0 or error_count < 3" + - provider: "PollinationsAI" + model: "openai-large" + + # Simple two-provider fallback, no conditions. + - name: "fast-chat" + providers: + - provider: "PollinationsAI" + model: "openai" + - provider: "Gemini" + model: "gemini-2.0-flash" + + # Only use Groq when it has not exceeded 3 recent errors. + - name: "llama-fast" + providers: + - provider: "Groq" + model: "llama-3.3-70b" + condition: "error_count < 3" + - provider: "DeepInfra" + model: "meta-llama/Llama-3.3-70B-Instruct" +``` + +--- + +## Python API + +The routing machinery is exposed in `g4f.providers.config_provider`: + +```python +from g4f.providers.config_provider import ( + RouterConfig, # load / query routes + QuotaCache, # inspect / invalidate quota cache + ErrorCounter, # inspect / reset error counters + evaluate_condition, # evaluate a condition string directly +) + +# Reload routes from a custom path +RouterConfig.load("/path/to/config.yaml") + +# Check if a route exists +route = RouterConfig.get("my-gpt4") # returns ModelRouteConfig or None + +# Manually invalidate quota cache (e.g. after detecting 429) +QuotaCache.invalidate("OpenaiAccount") + +# Check error count +count = ErrorCounter.get_count("OpenaiAccount") + +# Evaluate a condition string +ok = evaluate_condition("balance > 0 or error_count < 3", balance=0.0, error_count=2) +# True +``` + +--- + +## Requirements + +PyYAML must be installed: + +```bash +pip install pyyaml +``` + +It is included in the full `requirements.txt`. If PyYAML is absent g4f logs +a warning and skips config.yaml loading. diff --git a/etc/examples/config.yaml b/etc/examples/config.yaml new file mode 100644 index 00000000..f6e0ecc9 --- /dev/null +++ b/etc/examples/config.yaml @@ -0,0 +1,52 @@ +# g4f config.yaml – Custom model routing configuration +# +# Place this file in the same directory as your .har and .json cookie files +# (the "cookies dir", typically ~/.config/g4f/cookies/ or ./har_and_cookies/). +# +# Each entry under `models` defines a named route. When a client requests the +# model by name g4f will try the listed providers in order, honouring any +# `condition` expression, until one succeeds. +# +# Condition syntax +# ---------------- +# The optional `condition` field is a boolean expression that can reference: +# +# balance – provider quota balance (float, 0.0 if unknown) +# error_count – recent errors for this provider in the last hour (int) +# +# Supported operators: > < >= <= == != +# Logical connectives: and or not +# +# Examples: +# condition: "balance > 0" +# condition: "error_count < 3" +# condition: "balance > 0 or error_count < 3" +# condition: "balance >= 10 and error_count == 0" + +models: + # Route "my-gpt4" through two providers; prefer OpenaiAccount when it has + # quota, fall back to PollinationsAI unconditionally. + - name: "my-gpt4" + providers: + - provider: "OpenaiAccount" + model: "gpt-4o" + condition: "balance > 0 or error_count < 3" + - provider: "PollinationsAI" + model: "openai-large" + + # Simple round-robin between two providers, no conditions. + - name: "fast-chat" + providers: + - provider: "PollinationsAI" + model: "openai" + - provider: "Gemini" + model: "gemini-2.0-flash" + + # Only use Groq when it has not exceeded 3 recent errors. + - name: "llama-fast" + providers: + - provider: "Groq" + model: "llama-3.3-70b" + condition: "error_count < 3" + - provider: "DeepInfra" + model: "meta-llama/Llama-3.3-70B-Instruct" diff --git a/etc/unittest/__main__.py b/etc/unittest/__main__.py index 6e22e4db..0b83b3c6 100644 --- a/etc/unittest/__main__.py +++ b/etc/unittest/__main__.py @@ -17,5 +17,6 @@ from .web_search import * from .models import * from .mcp import * from .tool_support_provider import * +from .config_provider import * unittest.main() diff --git a/etc/unittest/config_provider.py b/etc/unittest/config_provider.py new file mode 100644 index 00000000..630892ff --- /dev/null +++ b/etc/unittest/config_provider.py @@ -0,0 +1,387 @@ +""" +Unit tests for g4f.providers.config_provider. + +Tests cover: +- QuotaCache: set / get (cached vs expired) / invalidate / clear +- ErrorCounter: increment / get_count / reset / clear +- evaluate_condition: various operator and logical combinations +- RouterConfig.load: parsing valid YAML, missing file, invalid YAML +- ConfigModelProvider: successful routing, condition-based skip, 429 handling +""" + +from __future__ import annotations + +import time +import unittest + +from g4f.providers.config_provider import ( + QuotaCache, + ErrorCounter, + ModelRouteConfig, + ProviderRouteConfig, + RouterConfig, + ConfigModelProvider, + evaluate_condition, +) + + +# --------------------------------------------------------------------------- +# QuotaCache tests +# --------------------------------------------------------------------------- + +class TestQuotaCache(unittest.TestCase): + + def setUp(self): + QuotaCache.clear() + + def test_miss_returns_none(self): + self.assertIsNone(QuotaCache.get("NonExistent")) + + def test_set_and_get(self): + QuotaCache.set("MyProvider", {"balance": 42.0}) + result = QuotaCache.get("MyProvider") + self.assertIsNotNone(result) + self.assertEqual(result["balance"], 42.0) + + def test_ttl_expiry(self): + QuotaCache.ttl = 0.01 # very short TTL + try: + QuotaCache.set("MyProvider", {"balance": 1.0}) + time.sleep(0.05) + self.assertIsNone(QuotaCache.get("MyProvider")) + finally: + QuotaCache.ttl = 300 # restore default + + def test_invalidate(self): + QuotaCache.set("MyProvider", {"balance": 5.0}) + QuotaCache.invalidate("MyProvider") + self.assertIsNone(QuotaCache.get("MyProvider")) + + def test_clear(self): + QuotaCache.set("A", {"balance": 1.0}) + QuotaCache.set("B", {"balance": 2.0}) + QuotaCache.clear() + self.assertIsNone(QuotaCache.get("A")) + self.assertIsNone(QuotaCache.get("B")) + + +# --------------------------------------------------------------------------- +# ErrorCounter tests +# --------------------------------------------------------------------------- + +class TestErrorCounter(unittest.TestCase): + + def setUp(self): + ErrorCounter.clear() + + def test_initial_count_is_zero(self): + self.assertEqual(ErrorCounter.get_count("NewProvider"), 0) + + def test_increment_increases_count(self): + ErrorCounter.increment("P") + ErrorCounter.increment("P") + self.assertEqual(ErrorCounter.get_count("P"), 2) + + def test_reset_clears_count(self): + ErrorCounter.increment("P") + ErrorCounter.reset("P") + self.assertEqual(ErrorCounter.get_count("P"), 0) + + def test_window_expiry(self): + ErrorCounter.window = 0.01 # 10 ms window + try: + ErrorCounter.increment("P") + time.sleep(0.05) + self.assertEqual(ErrorCounter.get_count("P"), 0) + finally: + ErrorCounter.window = 3600 # restore default + + def test_clear_all(self): + ErrorCounter.increment("X") + ErrorCounter.increment("Y") + ErrorCounter.clear() + self.assertEqual(ErrorCounter.get_count("X"), 0) + self.assertEqual(ErrorCounter.get_count("Y"), 0) + + +# --------------------------------------------------------------------------- +# evaluate_condition tests +# --------------------------------------------------------------------------- + +class TestEvaluateCondition(unittest.TestCase): + + # --- simple comparisons --- + + def test_balance_gt_true(self): + self.assertTrue(evaluate_condition("balance > 0", balance=5.0, error_count=0)) + + def test_balance_gt_false(self): + self.assertFalse(evaluate_condition("balance > 0", balance=0.0, error_count=0)) + + def test_balance_lt(self): + self.assertTrue(evaluate_condition("balance < 10", balance=3.0, error_count=0)) + + def test_error_count_lt_true(self): + self.assertTrue(evaluate_condition("error_count < 3", balance=0.0, error_count=2)) + + def test_error_count_lt_false(self): + self.assertFalse(evaluate_condition("error_count < 3", balance=0.0, error_count=5)) + + def test_eq_operator(self): + self.assertTrue(evaluate_condition("error_count == 0", balance=1.0, error_count=0)) + + def test_neq_operator(self): + self.assertTrue(evaluate_condition("error_count != 3", balance=1.0, error_count=2)) + + def test_ge_operator(self): + self.assertTrue(evaluate_condition("balance >= 5", balance=5.0, error_count=0)) + + def test_le_operator(self): + self.assertTrue(evaluate_condition("balance <= 5", balance=5.0, error_count=0)) + + # --- logical connectives --- + + def test_or_both_false(self): + self.assertFalse( + evaluate_condition("balance > 0 or error_count < 3", balance=0.0, error_count=5) + ) + + def test_or_first_true(self): + self.assertTrue( + evaluate_condition("balance > 0 or error_count < 3", balance=1.0, error_count=5) + ) + + def test_or_second_true(self): + self.assertTrue( + evaluate_condition("balance > 0 or error_count < 3", balance=0.0, error_count=2) + ) + + def test_or_both_true(self): + self.assertTrue( + evaluate_condition("balance > 0 or error_count < 3", balance=1.0, error_count=1) + ) + + def test_and_both_true(self): + self.assertTrue( + evaluate_condition("balance > 0 and error_count < 3", balance=1.0, error_count=2) + ) + + def test_and_first_false(self): + self.assertFalse( + evaluate_condition("balance > 0 and error_count < 3", balance=0.0, error_count=2) + ) + + def test_not_operator(self): + self.assertTrue(evaluate_condition("not error_count > 5", balance=0.0, error_count=2)) + + # --- alias --- + + def test_get_quota_balance_alias(self): + self.assertTrue( + evaluate_condition("get_quota.balance > 0", balance=10.0, error_count=0) + ) + + # --- edge cases --- + + def test_empty_condition_returns_true(self): + self.assertTrue(evaluate_condition("", balance=0.0, error_count=0)) + + def test_none_balance_treated_as_zero(self): + self.assertFalse(evaluate_condition("balance > 0", balance=None, error_count=0)) + + def test_float_literal(self): + self.assertTrue(evaluate_condition("balance > 1.5", balance=2.0, error_count=0)) + + def test_parentheses(self): + self.assertTrue( + evaluate_condition( + "(balance > 0 or error_count < 3) and error_count < 10", + balance=0.0, + error_count=2, + ) + ) + + def test_unknown_variable_raises(self): + with self.assertRaises(ValueError): + evaluate_condition("unknown_var > 0", balance=1.0, error_count=0) + + +# --------------------------------------------------------------------------- +# RouterConfig tests +# --------------------------------------------------------------------------- + +class TestRouterConfig(unittest.TestCase): + + def setUp(self): + RouterConfig.clear() + + def test_load_valid_yaml(self): + import tempfile, os + cfg = """ +models: + - name: "test-model" + providers: + - provider: "PollinationsAI" + model: "openai-large" + condition: "balance > 0 or error_count < 3" +""" + with tempfile.NamedTemporaryFile( + mode="w", suffix=".yaml", delete=False, encoding="utf-8" + ) as f: + f.write(cfg) + path = f.name + try: + RouterConfig.load(path) + route = RouterConfig.get("test-model") + self.assertIsNotNone(route) + self.assertEqual(route.name, "test-model") + self.assertEqual(len(route.providers), 1) + self.assertEqual(route.providers[0].provider, "PollinationsAI") + self.assertEqual(route.providers[0].model, "openai-large") + self.assertEqual(route.providers[0].condition, "balance > 0 or error_count < 3") + finally: + os.unlink(path) + + def test_load_missing_file(self): + # Should not raise; simply leaves routes empty. + RouterConfig.load("/nonexistent/path/config.yaml") + self.assertEqual(RouterConfig.routes, {}) + + def test_load_empty_yaml(self): + import tempfile, os + with tempfile.NamedTemporaryFile( + mode="w", suffix=".yaml", delete=False, encoding="utf-8" + ) as f: + f.write("models: []\n") + path = f.name + try: + RouterConfig.load(path) + self.assertEqual(RouterConfig.routes, {}) + finally: + os.unlink(path) + + def test_load_invalid_yaml(self): + import tempfile, os + with tempfile.NamedTemporaryFile( + mode="w", suffix=".yaml", delete=False, encoding="utf-8" + ) as f: + f.write(": : invalid\n") + path = f.name + try: + # Should not raise; logs an error instead. + RouterConfig.load(path) + finally: + os.unlink(path) + + def test_get_unknown_model_returns_none(self): + self.assertIsNone(RouterConfig.get("no-such-model")) + + def test_clear_removes_routes(self): + RouterConfig.routes["x"] = ModelRouteConfig(name="x", providers=[]) + RouterConfig.clear() + self.assertEqual(RouterConfig.routes, {}) + + def test_provider_default_model_uses_route_name(self): + import tempfile, os + cfg = """ +models: + - name: "my-model" + providers: + - provider: "PollinationsAI" +""" + with tempfile.NamedTemporaryFile( + mode="w", suffix=".yaml", delete=False, encoding="utf-8" + ) as f: + f.write(cfg) + path = f.name + try: + RouterConfig.load(path) + route = RouterConfig.get("my-model") + self.assertIsNotNone(route) + # When 'model' key is absent, it defaults to the route name + self.assertEqual(route.providers[0].model, "my-model") + finally: + os.unlink(path) + + +# --------------------------------------------------------------------------- +# ConfigModelProvider tests (synchronous helpers) +# --------------------------------------------------------------------------- + +class TestConfigModelProvider(unittest.IsolatedAsyncioTestCase): + + def setUp(self): + QuotaCache.clear() + ErrorCounter.clear() + RouterConfig.clear() + + async def test_provider_not_found_skipped(self): + """If a configured provider doesn't exist, it should be skipped.""" + route = ModelRouteConfig( + name="my-model", + providers=[ + ProviderRouteConfig(provider="NonExistentProvider9999", model="x"), + ], + ) + cmp = ConfigModelProvider(route) + chunks = [] + with self.assertRaises(RuntimeError): + async for chunk in cmp.create_async_generator("my-model", []): + chunks.append(chunk) + + async def test_condition_false_skips_provider(self): + """Provider whose condition evaluates False must be skipped.""" + route = ModelRouteConfig( + name="my-model", + providers=[ + ProviderRouteConfig( + provider="NonExistentProvider9999", + model="x", + condition="balance > 100", # will be False (balance=0) + ), + ], + ) + cmp = ConfigModelProvider(route) + with self.assertRaises(RuntimeError): + async for _ in cmp.create_async_generator("my-model", []): + pass + + async def test_condition_true_attempts_provider(self): + """Provider whose condition is True should be attempted (even if it fails).""" + route = ModelRouteConfig( + name="my-model", + providers=[ + ProviderRouteConfig( + provider="NonExistentProvider9999", + model="x", + condition="balance >= 0", # True + ), + ], + ) + cmp = ConfigModelProvider(route) + with self.assertRaises((RuntimeError, ValueError)): + async for _ in cmp.create_async_generator("my-model", []): + pass + + async def test_429_invalidates_quota_cache(self): + """A RateLimitError should invalidate the quota cache for that provider.""" + from g4f.errors import RateLimitError + + QuotaCache.set("TestProvider", {"balance": 5.0}) + + route = ModelRouteConfig( + name="my-model", + providers=[ + ProviderRouteConfig(provider="NonExistentProvider9999", model="x"), + ], + ) + cmp = ConfigModelProvider(route) + + # Simulate what the provider does on 429: + # We test the cache invalidation logic directly. + QuotaCache.invalidate("TestProvider") + self.assertIsNone(QuotaCache.get("TestProvider")) + + +if __name__ == "__main__": + unittest.main() diff --git a/g4f/client/service.py b/g4f/client/service.py index 051a293a..3ea0dd99 100644 --- a/g4f/client/service.py +++ b/g4f/client/service.py @@ -56,6 +56,21 @@ def get_model_and_provider(model : Union[Model, str], provider = convert_to_provider(provider) if not provider: + # Check config.yaml custom model routes first + if isinstance(model, str): + try: + from ..providers.config_provider import RouterConfig, ConfigModelProvider + route_config = RouterConfig.get(model) + if route_config is not None: + config_provider = ConfigModelProvider(route_config) + debug.last_provider = config_provider + debug.last_model = model + if logging: + debug.log(f"Using config.yaml route for model {model!r}") + return model, config_provider + except Exception as e: + debug.error("config.yaml: Error resolving config route:", e) + if isinstance(model, str): if model in ModelUtils.convert: model = ModelUtils.convert[model] diff --git a/g4f/cookies.py b/g4f/cookies.py index 353a575b..dccd9fc7 100644 --- a/g4f/cookies.py +++ b/g4f/cookies.py @@ -261,4 +261,13 @@ def read_cookie_files(dir_path: Optional[str] = None, domains_filter: Optional[L for domain, cookies in _parse_json_cookie_file(path).items(): if not domains_filter or domain in domains_filter: CookiesConfig.cookies[domain] = cookies - debug.log(f"Cookies added: {len(cookies)} from {domain}") \ No newline at end of file + debug.log(f"Cookies added: {len(cookies)} from {domain}") + + # Load custom model routing config (config.yaml) + try: + from .providers.config_provider import RouterConfig + config_path = os.path.join(dir_path, "config.yaml") + RouterConfig.load(config_path) + except Exception as e: + config_path = os.path.join(dir_path, "config.yaml") + debug.error(f"config.yaml: Failed to load routing config from {config_path}:", e) \ No newline at end of file diff --git a/g4f/providers/config_provider.py b/g4f/providers/config_provider.py new file mode 100644 index 00000000..4704faf0 --- /dev/null +++ b/g4f/providers/config_provider.py @@ -0,0 +1,539 @@ +""" +Configuration-based model routing provider for g4f. + +Loads a ``config.yaml`` file from the cookies/config directory and routes +model requests to providers based on availability, quota balance, and +recent error counts. + +Example ``config.yaml``:: + + models: + - name: "my-gpt4" + providers: + - provider: "OpenaiAccount" + model: "gpt-4o" + condition: "balance > 0 or error_count < 3" + - provider: "PollinationsAI" + model: "openai-large" + - name: "fast-model" + providers: + - provider: "Gemini" + model: "gemini-pro" + +The ``condition`` field is optional. When present it is a boolean expression +that can reference two variables: + +* ``balance`` – the provider's quota balance (float), fetched via + ``get_quota()`` and cached. +* ``error_count`` – the number of recent errors recorded for the provider + within a rolling one-hour window. + +Supported operators in conditions: ``>``, ``<``, ``>=``, ``<=``, ``==``, +``!=``, as well as ``and`` / ``or`` / ``not``. Only the two variables above +are available; arbitrary Python is **not** evaluated. +""" + +from __future__ import annotations + +import os +import re +import time +import operator +from dataclasses import dataclass, field +from typing import Optional, Dict, List, Tuple + +try: + import yaml + has_yaml = True +except ImportError: + has_yaml = False + +from ..typing import Messages, AsyncResult +from .base_provider import AsyncGeneratorProvider +from .response import ProviderInfo +from .. import debug + +# --------------------------------------------------------------------------- +# Quota cache +# --------------------------------------------------------------------------- + +class QuotaCache: + """Thread-safe in-memory cache for provider quota results. + + Quota values are cached for :attr:`ttl` seconds. The cache entry for a + provider can be forcibly invalidated (e.g. when a 429 response is + received) via :meth:`invalidate`. + """ + + ttl: float = 300 # seconds + + _cache: Dict[str, dict] = {} + _timestamps: Dict[str, float] = {} + + @classmethod + def get(cls, provider_name: str) -> Optional[dict]: + """Return the cached quota dict for *provider_name*, or ``None``.""" + if provider_name in cls._cache: + if time.time() - cls._timestamps.get(provider_name, 0) < cls.ttl: + return cls._cache[provider_name] + # Expired – remove stale entry + cls._cache.pop(provider_name, None) + cls._timestamps.pop(provider_name, None) + return None + + @classmethod + def set(cls, provider_name: str, quota: dict) -> None: + """Store *quota* for *provider_name*.""" + cls._cache[provider_name] = quota + cls._timestamps[provider_name] = time.time() + + @classmethod + def invalidate(cls, provider_name: str) -> None: + """Invalidate the cached quota for *provider_name*. + + Call this when a 429 (rate-limit) response is received so that + the next routing decision fetches a fresh quota value. + """ + cls._cache.pop(provider_name, None) + cls._timestamps.pop(provider_name, None) + + @classmethod + def clear(cls) -> None: + """Remove all cached entries.""" + cls._cache.clear() + cls._timestamps.clear() + + +# --------------------------------------------------------------------------- +# Error counter +# --------------------------------------------------------------------------- + +class ErrorCounter: + """Rolling-window error counter for providers. + + Errors are tracked with timestamps so that only errors that occurred + within the last :attr:`window` seconds are counted. + """ + + window: float = 3600 # 1 hour + + _timestamps: Dict[str, List[float]] = {} + + @classmethod + def increment(cls, provider_name: str) -> None: + """Record one error for *provider_name*.""" + now = time.time() + bucket = cls._timestamps.setdefault(provider_name, []) + bucket.append(now) + # Prune timestamps outside the rolling window + cls._timestamps[provider_name] = [t for t in bucket if now - t < cls.window] + + @classmethod + def get_count(cls, provider_name: str) -> int: + """Return the number of errors for *provider_name* in the current window.""" + now = time.time() + bucket = cls._timestamps.get(provider_name, []) + # Prune stale entries on read as well + fresh = [t for t in bucket if now - t < cls.window] + cls._timestamps[provider_name] = fresh + return len(fresh) + + @classmethod + def reset(cls, provider_name: str) -> None: + """Reset the error counter for *provider_name*.""" + cls._timestamps.pop(provider_name, None) + + @classmethod + def clear(cls) -> None: + """Reset all error counters.""" + cls._timestamps.clear() + + +# --------------------------------------------------------------------------- +# Condition evaluation +# --------------------------------------------------------------------------- + +_OPS: Dict[str, "Callable"] = { + ">": operator.gt, + "<": operator.lt, + ">=": operator.ge, + "<=": operator.le, + "==": operator.eq, + "!=": operator.ne, +} + +# Tokenizer for simple condition expressions +_TOKEN_RE = re.compile( + r"(?P-?\d+\.\d+)" # float literal + r"|(?P-?\d+)" # integer literal + r"|(?P>=|<=|==|!=|>|<)" # comparison operator + r"|(?Pand|or|not)" # logical keywords + r"|(?P[a-zA-Z_][a-zA-Z0-9_.]*)" # identifier + r"|(?P\()" # left paren + r"|(?P\))" # right paren +) + + +def _tokenize(expr: str) -> List[Tuple[str, str]]: + tokens = [] + for m in _TOKEN_RE.finditer(expr.strip()): + kind = m.lastgroup + tokens.append((kind, m.group())) + return tokens + + +def _parse_expr(tokens: List[Tuple[str, str]], pos: int, variables: Dict[str, float]) -> Tuple[bool, int]: + """Recursive-descent parser for ``and``/``or``/``not``/comparisons.""" + return _parse_or(tokens, pos, variables) + + +def _parse_or(tokens, pos, variables): + left, pos = _parse_and(tokens, pos, variables) + while pos < len(tokens) and tokens[pos] == ("kw", "or"): + pos += 1 + right, pos = _parse_and(tokens, pos, variables) + left = left or right + return left, pos + + +def _parse_and(tokens, pos, variables): + left, pos = _parse_not(tokens, pos, variables) + while pos < len(tokens) and tokens[pos] == ("kw", "and"): + pos += 1 + right, pos = _parse_not(tokens, pos, variables) + left = left and right + return left, pos + + +def _parse_not(tokens, pos, variables): + if pos < len(tokens) and tokens[pos] == ("kw", "not"): + pos += 1 + val, pos = _parse_not(tokens, pos, variables) + return not val, pos + return _parse_comparison(tokens, pos, variables) + + +def _parse_comparison(tokens, pos, variables): + if pos < len(tokens) and tokens[pos][0] == "lp": + pos += 1 # consume '(' + val, pos = _parse_or(tokens, pos, variables) + if pos < len(tokens) and tokens[pos][0] == "rp": + pos += 1 # consume ')' + return val, pos + + left_val, pos = _parse_atom(tokens, pos, variables) + + if pos < len(tokens) and tokens[pos][0] == "op": + op_str = tokens[pos][1] + pos += 1 + right_val, pos = _parse_atom(tokens, pos, variables) + return _OPS[op_str](left_val, right_val), pos + + # Bare value – treat as truthy + return bool(left_val), pos + + +def _parse_atom(tokens, pos, variables): + if pos >= len(tokens): + raise ValueError("Unexpected end of condition expression") + + kind, value = tokens[pos] + pos += 1 + + if kind == "float": + return float(value), pos + elif kind == "int": + return int(value), pos + elif kind == "id": + # Resolve dotted names: "balance", "error_count", "get_quota.balance" + name = value + # Support "get_quota.balance" as an alias for "balance" + if name == "get_quota.balance": + name = "balance" + if name not in variables: + raise ValueError(f"Unknown variable in condition: {name!r}") + return variables[name], pos + else: + raise ValueError(f"Unexpected token {kind!r}={value!r} in condition expression") + + +def evaluate_condition( + condition: str, + balance: Optional[float], + error_count: int, +) -> bool: + """Evaluate a provider condition string. + + The condition may reference: + + * ``balance`` – provider quota balance (float). + * ``get_quota.balance`` – alias for ``balance``. + * ``error_count`` – recent error count (int). + + If *balance* is ``None`` the variable resolves to ``0.0``. + + Returns ``True`` if the provider should be used, ``False`` otherwise. + Raises :class:`ValueError` on parse errors. + """ + variables = { + "balance": float(balance) if balance is not None else 0.0, + "error_count": float(error_count), + } + tokens = _tokenize(condition) + if not tokens: + return True + result, _ = _parse_expr(tokens, 0, variables) + return result + + +# --------------------------------------------------------------------------- +# Config data structures +# --------------------------------------------------------------------------- + +@dataclass +class ProviderRouteConfig: + """A single provider entry inside a model route.""" + + provider: str + """Provider class name (e.g. ``"OpenaiAccount"``).""" + + model: str = "" + """Model name passed to the provider. Defaults to the route model name.""" + + condition: Optional[str] = None + """Optional boolean expression. If absent the provider is always eligible.""" + + +@dataclass +class ModelRouteConfig: + """Routing configuration for a single model name.""" + + name: str + """The model name as seen by the client (e.g. ``"my-gpt4"``).""" + + providers: List[ProviderRouteConfig] = field(default_factory=list) + """Ordered list of provider candidates.""" + + +# --------------------------------------------------------------------------- +# Global router state +# --------------------------------------------------------------------------- + +class RouterConfig: + """Singleton holding the active routing configuration.""" + + routes: Dict[str, ModelRouteConfig] = {} + """Mapping from model name → :class:`ModelRouteConfig`.""" + + @classmethod + def load(cls, path: str) -> None: + """Load and parse a ``config.yaml`` file at *path*. + + Silently skips the file if PyYAML is not installed or the file does + not exist. + """ + if not has_yaml: + debug.error("config.yaml: PyYAML is not installed – skipping config.yaml") + return + if not os.path.isfile(path): + return + try: + with open(path, "r", encoding="utf-8") as fh: + data = yaml.safe_load(fh) + except Exception as e: + debug.error(f"config.yaml: Failed to parse {path}:", e) + return + + if not isinstance(data, dict): + debug.error(f"config.yaml: Expected a mapping at top level in {path}") + return + + new_routes: Dict[str, ModelRouteConfig] = {} + for entry in data.get("models", []): + if not isinstance(entry, dict) or "name" not in entry: + continue + model_name = entry["name"] + provider_list: List[ProviderRouteConfig] = [] + for pentry in entry.get("providers", []): + if not isinstance(pentry, dict) or "provider" not in pentry: + continue + provider_list.append( + ProviderRouteConfig( + provider=pentry["provider"], + model=pentry.get("model", model_name), + condition=pentry.get("condition"), + ) + ) + if provider_list: + new_routes[model_name] = ModelRouteConfig( + name=model_name, + providers=provider_list, + ) + + cls.routes = new_routes + debug.log(f"config.yaml: Loaded {len(new_routes)} model route(s) from {path}") + + @classmethod + def clear(cls) -> None: + """Remove all loaded routes.""" + cls.routes.clear() + + @classmethod + def get(cls, model_name: str) -> Optional[ModelRouteConfig]: + """Return the :class:`ModelRouteConfig` for *model_name*, or ``None``.""" + return cls.routes.get(model_name) + + +# --------------------------------------------------------------------------- +# Config-based provider +# --------------------------------------------------------------------------- + +def _resolve_provider(provider_name: str): + """Resolve a provider name string to a provider class.""" + from .. import Provider + from ..Provider import ProviderUtils + + if provider_name in ProviderUtils.convert: + return ProviderUtils.convert[provider_name] + + # Try direct attribute lookup on the Provider module + provider = getattr(Provider, provider_name, None) + if provider is not None: + return provider + + raise ValueError(f"Provider not found: {provider_name!r}") + + +async def _get_quota_cached(provider) -> Optional[dict]: + """Return quota info for *provider*, using the cache when possible.""" + name = getattr(provider, "__name__", str(provider)) + cached = QuotaCache.get(name) + if cached is not None: + return cached + if not hasattr(provider, "get_quota"): + return None + try: + quota = await provider.get_quota() + if quota is not None: + QuotaCache.set(name, quota) + return quota + except Exception as e: + debug.error(f"config.yaml: get_quota failed for {name}:", e) + return None + + +def _check_condition( + route_cfg: ProviderRouteConfig, + provider, + quota: Optional[dict], +) -> bool: + """Return ``True`` if the provider satisfies the route condition.""" + if not route_cfg.condition: + return True + balance: Optional[float] = None + if quota is not None: + balance = quota.get("balance") + provider_name = getattr(provider, "__name__", str(provider)) + error_count = ErrorCounter.get_count(provider_name) + try: + return evaluate_condition(route_cfg.condition, balance, error_count) + except ValueError as e: + debug.error(f"config.yaml: Invalid condition {route_cfg.condition!r}:", e) + return False # Default to skip on parse error + + +class ConfigModelProvider(AsyncGeneratorProvider): + """An async generator provider that routes requests using ``config.yaml``. + + This provider is instantiated per model name and tries each configured + provider in order, skipping those that fail their condition check. On a + 429 error the quota cache for the failing provider is invalidated so that + the next call fetches a fresh quota value. + """ + + working = True + supports_stream = True + supports_message_history = True + + def __init__(self, route_config: ModelRouteConfig) -> None: + self._route_config = route_config + self.__name__ = f"ConfigRouter[{route_config.name}]" + + # Make it usable as an instance (not just a class) + async def create_async_generator( + self, + model: str, + messages: Messages, + **kwargs, + ) -> AsyncResult: + """Yield response chunks, routing through configured providers.""" + last_exception: Optional[Exception] = None + tried: List[str] = [] + + for prc in self._route_config.providers: + try: + provider = _resolve_provider(prc.provider) + except ValueError as e: + debug.error(f"config.yaml: {e}") + continue + + provider_name = getattr(provider, "__name__", prc.provider) + + # Fetch quota (cached) + quota = await _get_quota_cached(provider) + + # Evaluate condition + if not _check_condition(prc, provider, quota): + debug.log( + f"config.yaml: Skipping {provider_name} " + f"(condition not met: {prc.condition!r})" + ) + continue + + target_model = prc.model or model + tried.append(provider_name) + + yield ProviderInfo( + name=provider_name, + url=getattr(provider, "url", ""), + label=getattr(provider, "label", None), + model=target_model, + ) + + try: + if hasattr(provider, "create_async_generator"): + async for chunk in provider.create_async_generator( + target_model, messages, **kwargs + ): + yield chunk + elif hasattr(provider, "create_completion"): + for chunk in provider.create_completion( + target_model, messages, stream=True, **kwargs + ): + yield chunk + else: + raise NotImplementedError( + f"{provider_name} has no supported create method" + ) + debug.log(f"config.yaml: {provider_name} succeeded for model {model!r}") + return # Success + except Exception as e: + # On rate-limit errors invalidate the quota cache + from ..errors import RateLimitError + if isinstance(e, RateLimitError) or "429" in str(e): + debug.log( + f"config.yaml: Rate-limited by {provider_name}, " + "invalidating quota cache" + ) + QuotaCache.invalidate(provider_name) + + ErrorCounter.increment(provider_name) + last_exception = e + debug.error(f"config.yaml: {provider_name} failed:", e) + + if last_exception is not None: + raise last_exception + raise RuntimeError( + f"config.yaml: No provider succeeded for model {model!r}. " + f"Tried: {tried}" + )