Add config.yaml custom model routing support

Co-authored-by: hlohaus <983577+hlohaus@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2026-03-19 14:28:41 +00:00
parent 9605a22880
commit 5ff28514cd
7 changed files with 1200 additions and 1 deletions
+196
View File
@@ -0,0 +1,196 @@
# Custom Model Routing with `config.yaml`
g4f supports a `config.yaml` file that lets you define **custom model routes** named
models that are transparently forwarded to one or more real providers based on
availability, quota balance, and recent error counts.
This is similar to the [LiteLLM](https://docs.litellm.ai/) routing configuration.
---
## Quick start
1. Place a `config.yaml` file in the same directory as your `.har` / `.json`
cookie files (the "cookies dir").
* Default location: `~/.config/g4f/cookies/config.yaml`
* Alternative: `./har_and_cookies/config.yaml`
2. Define your routes (see format below).
3. g4f loads the file automatically when it reads the cookie directory
(e.g. on API server start-up, or when `read_cookie_files()` is called).
4. Request the custom model name from any client:
```python
from g4f.client import Client
client = Client()
response = client.chat.completions.create(
model="my-gpt4", # defined in config.yaml
messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)
```
---
## File format
```yaml
models:
- name: "<model-name>" # the name clients use
providers:
- provider: "<ProviderName>" # g4f provider class name
model: "<provider-model>" # model name passed to that provider
condition: "<expression>" # optional see below
- provider: "..." # fallback provider (no condition = always eligible)
model: "..."
```
### Keys
| Key | Required | Description |
|-----|----------|-------------|
| `name` | ✅ | The model name used by clients. |
| `providers` | ✅ | Ordered list of provider candidates. |
| `provider` | ✅ | Provider class name (e.g. `"OpenaiAccount"`, `"PollinationsAI"`). |
| `model` | | Model name forwarded to the provider. Defaults to the route `name`. |
| `condition` | | Boolean expression controlling when this provider is eligible. |
---
## Condition expressions
The `condition` field is a boolean expression evaluated before each request.
It can reference two variables:
| Variable | Type | Description |
|----------|------|-------------|
| `balance` | `float` | Provider quota balance, fetched via `get_quota()` and **cached** for 5 minutes. Returns `0.0` if the provider has no `get_quota` method or the call fails. |
| `error_count` | `int` | Number of errors recorded for this provider in the last **1 hour**. |
| `get_quota.balance` | `float` | Alias for `balance`. |
### Operators
| Operator | Meaning |
|----------|---------|
| `>` `<` `>=` `<=` | Numeric comparison |
| `==` `!=` | Equality / inequality |
| `and` `or` `not` | Logical connectives |
| `(` `)` | Grouping |
### Examples
```yaml
condition: "balance > 0"
condition: "error_count < 3"
condition: "balance > 0 or error_count < 3"
condition: "balance >= 10 and error_count == 0"
condition: "(balance > 0 or error_count < 5) and error_count < 10"
```
When the condition is **absent** or evaluates to `True`, the provider is
eligible. When it evaluates to `False` the provider is skipped and g4f
tries the next one in the list.
---
## Quota caching
Quota values (`balance`) are fetched via the provider's `get_quota()` method
and cached in memory for **5 minutes** (configurable via
`QuotaCache.ttl`).
When a provider returns an HTTP **429 (Too Many Requests)** error the cache
entry for that provider is **immediately invalidated**, so the next routing
decision fetches a fresh balance before deciding.
---
## Error counting
Every time a provider raises an exception the error counter for that provider
is incremented. Errors older than **1 hour** are automatically pruned.
You can reference `error_count` in a condition to avoid retrying providers
that have been failing repeatedly.
---
## Full example
```yaml
# ~/.config/g4f/cookies/config.yaml
models:
# Prefer OpenaiAccount when it has quota; fall back to PollinationsAI.
- name: "my-gpt4"
providers:
- provider: "OpenaiAccount"
model: "gpt-4o"
condition: "balance > 0 or error_count < 3"
- provider: "PollinationsAI"
model: "openai-large"
# Simple two-provider fallback, no conditions.
- name: "fast-chat"
providers:
- provider: "PollinationsAI"
model: "openai"
- provider: "Gemini"
model: "gemini-2.0-flash"
# Only use Groq when it has not exceeded 3 recent errors.
- name: "llama-fast"
providers:
- provider: "Groq"
model: "llama-3.3-70b"
condition: "error_count < 3"
- provider: "DeepInfra"
model: "meta-llama/Llama-3.3-70B-Instruct"
```
---
## Python API
The routing machinery is exposed in `g4f.providers.config_provider`:
```python
from g4f.providers.config_provider import (
RouterConfig, # load / query routes
QuotaCache, # inspect / invalidate quota cache
ErrorCounter, # inspect / reset error counters
evaluate_condition, # evaluate a condition string directly
)
# Reload routes from a custom path
RouterConfig.load("/path/to/config.yaml")
# Check if a route exists
route = RouterConfig.get("my-gpt4") # returns ModelRouteConfig or None
# Manually invalidate quota cache (e.g. after detecting 429)
QuotaCache.invalidate("OpenaiAccount")
# Check error count
count = ErrorCounter.get_count("OpenaiAccount")
# Evaluate a condition string
ok = evaluate_condition("balance > 0 or error_count < 3", balance=0.0, error_count=2)
# True
```
---
## Requirements
PyYAML must be installed:
```bash
pip install pyyaml
```
It is included in the full `requirements.txt`. If PyYAML is absent g4f logs
a warning and skips config.yaml loading.
+52
View File
@@ -0,0 +1,52 @@
# g4f config.yaml Custom model routing configuration
#
# Place this file in the same directory as your .har and .json cookie files
# (the "cookies dir", typically ~/.config/g4f/cookies/ or ./har_and_cookies/).
#
# Each entry under `models` defines a named route. When a client requests the
# model by name g4f will try the listed providers in order, honouring any
# `condition` expression, until one succeeds.
#
# Condition syntax
# ----------------
# The optional `condition` field is a boolean expression that can reference:
#
# balance provider quota balance (float, 0.0 if unknown)
# error_count recent errors for this provider in the last hour (int)
#
# Supported operators: > < >= <= == !=
# Logical connectives: and or not
#
# Examples:
# condition: "balance > 0"
# condition: "error_count < 3"
# condition: "balance > 0 or error_count < 3"
# condition: "balance >= 10 and error_count == 0"
models:
# Route "my-gpt4" through two providers; prefer OpenaiAccount when it has
# quota, fall back to PollinationsAI unconditionally.
- name: "my-gpt4"
providers:
- provider: "OpenaiAccount"
model: "gpt-4o"
condition: "balance > 0 or error_count < 3"
- provider: "PollinationsAI"
model: "openai-large"
# Simple round-robin between two providers, no conditions.
- name: "fast-chat"
providers:
- provider: "PollinationsAI"
model: "openai"
- provider: "Gemini"
model: "gemini-2.0-flash"
# Only use Groq when it has not exceeded 3 recent errors.
- name: "llama-fast"
providers:
- provider: "Groq"
model: "llama-3.3-70b"
condition: "error_count < 3"
- provider: "DeepInfra"
model: "meta-llama/Llama-3.3-70B-Instruct"
+1
View File
@@ -17,5 +17,6 @@ from .web_search import *
from .models import *
from .mcp import *
from .tool_support_provider import *
from .config_provider import *
unittest.main()
+387
View File
@@ -0,0 +1,387 @@
"""
Unit tests for g4f.providers.config_provider.
Tests cover:
- QuotaCache: set / get (cached vs expired) / invalidate / clear
- ErrorCounter: increment / get_count / reset / clear
- evaluate_condition: various operator and logical combinations
- RouterConfig.load: parsing valid YAML, missing file, invalid YAML
- ConfigModelProvider: successful routing, condition-based skip, 429 handling
"""
from __future__ import annotations
import time
import unittest
from g4f.providers.config_provider import (
QuotaCache,
ErrorCounter,
ModelRouteConfig,
ProviderRouteConfig,
RouterConfig,
ConfigModelProvider,
evaluate_condition,
)
# ---------------------------------------------------------------------------
# QuotaCache tests
# ---------------------------------------------------------------------------
class TestQuotaCache(unittest.TestCase):
def setUp(self):
QuotaCache.clear()
def test_miss_returns_none(self):
self.assertIsNone(QuotaCache.get("NonExistent"))
def test_set_and_get(self):
QuotaCache.set("MyProvider", {"balance": 42.0})
result = QuotaCache.get("MyProvider")
self.assertIsNotNone(result)
self.assertEqual(result["balance"], 42.0)
def test_ttl_expiry(self):
QuotaCache.ttl = 0.01 # very short TTL
try:
QuotaCache.set("MyProvider", {"balance": 1.0})
time.sleep(0.05)
self.assertIsNone(QuotaCache.get("MyProvider"))
finally:
QuotaCache.ttl = 300 # restore default
def test_invalidate(self):
QuotaCache.set("MyProvider", {"balance": 5.0})
QuotaCache.invalidate("MyProvider")
self.assertIsNone(QuotaCache.get("MyProvider"))
def test_clear(self):
QuotaCache.set("A", {"balance": 1.0})
QuotaCache.set("B", {"balance": 2.0})
QuotaCache.clear()
self.assertIsNone(QuotaCache.get("A"))
self.assertIsNone(QuotaCache.get("B"))
# ---------------------------------------------------------------------------
# ErrorCounter tests
# ---------------------------------------------------------------------------
class TestErrorCounter(unittest.TestCase):
def setUp(self):
ErrorCounter.clear()
def test_initial_count_is_zero(self):
self.assertEqual(ErrorCounter.get_count("NewProvider"), 0)
def test_increment_increases_count(self):
ErrorCounter.increment("P")
ErrorCounter.increment("P")
self.assertEqual(ErrorCounter.get_count("P"), 2)
def test_reset_clears_count(self):
ErrorCounter.increment("P")
ErrorCounter.reset("P")
self.assertEqual(ErrorCounter.get_count("P"), 0)
def test_window_expiry(self):
ErrorCounter.window = 0.01 # 10 ms window
try:
ErrorCounter.increment("P")
time.sleep(0.05)
self.assertEqual(ErrorCounter.get_count("P"), 0)
finally:
ErrorCounter.window = 3600 # restore default
def test_clear_all(self):
ErrorCounter.increment("X")
ErrorCounter.increment("Y")
ErrorCounter.clear()
self.assertEqual(ErrorCounter.get_count("X"), 0)
self.assertEqual(ErrorCounter.get_count("Y"), 0)
# ---------------------------------------------------------------------------
# evaluate_condition tests
# ---------------------------------------------------------------------------
class TestEvaluateCondition(unittest.TestCase):
# --- simple comparisons ---
def test_balance_gt_true(self):
self.assertTrue(evaluate_condition("balance > 0", balance=5.0, error_count=0))
def test_balance_gt_false(self):
self.assertFalse(evaluate_condition("balance > 0", balance=0.0, error_count=0))
def test_balance_lt(self):
self.assertTrue(evaluate_condition("balance < 10", balance=3.0, error_count=0))
def test_error_count_lt_true(self):
self.assertTrue(evaluate_condition("error_count < 3", balance=0.0, error_count=2))
def test_error_count_lt_false(self):
self.assertFalse(evaluate_condition("error_count < 3", balance=0.0, error_count=5))
def test_eq_operator(self):
self.assertTrue(evaluate_condition("error_count == 0", balance=1.0, error_count=0))
def test_neq_operator(self):
self.assertTrue(evaluate_condition("error_count != 3", balance=1.0, error_count=2))
def test_ge_operator(self):
self.assertTrue(evaluate_condition("balance >= 5", balance=5.0, error_count=0))
def test_le_operator(self):
self.assertTrue(evaluate_condition("balance <= 5", balance=5.0, error_count=0))
# --- logical connectives ---
def test_or_both_false(self):
self.assertFalse(
evaluate_condition("balance > 0 or error_count < 3", balance=0.0, error_count=5)
)
def test_or_first_true(self):
self.assertTrue(
evaluate_condition("balance > 0 or error_count < 3", balance=1.0, error_count=5)
)
def test_or_second_true(self):
self.assertTrue(
evaluate_condition("balance > 0 or error_count < 3", balance=0.0, error_count=2)
)
def test_or_both_true(self):
self.assertTrue(
evaluate_condition("balance > 0 or error_count < 3", balance=1.0, error_count=1)
)
def test_and_both_true(self):
self.assertTrue(
evaluate_condition("balance > 0 and error_count < 3", balance=1.0, error_count=2)
)
def test_and_first_false(self):
self.assertFalse(
evaluate_condition("balance > 0 and error_count < 3", balance=0.0, error_count=2)
)
def test_not_operator(self):
self.assertTrue(evaluate_condition("not error_count > 5", balance=0.0, error_count=2))
# --- alias ---
def test_get_quota_balance_alias(self):
self.assertTrue(
evaluate_condition("get_quota.balance > 0", balance=10.0, error_count=0)
)
# --- edge cases ---
def test_empty_condition_returns_true(self):
self.assertTrue(evaluate_condition("", balance=0.0, error_count=0))
def test_none_balance_treated_as_zero(self):
self.assertFalse(evaluate_condition("balance > 0", balance=None, error_count=0))
def test_float_literal(self):
self.assertTrue(evaluate_condition("balance > 1.5", balance=2.0, error_count=0))
def test_parentheses(self):
self.assertTrue(
evaluate_condition(
"(balance > 0 or error_count < 3) and error_count < 10",
balance=0.0,
error_count=2,
)
)
def test_unknown_variable_raises(self):
with self.assertRaises(ValueError):
evaluate_condition("unknown_var > 0", balance=1.0, error_count=0)
# ---------------------------------------------------------------------------
# RouterConfig tests
# ---------------------------------------------------------------------------
class TestRouterConfig(unittest.TestCase):
def setUp(self):
RouterConfig.clear()
def test_load_valid_yaml(self):
import tempfile, os
cfg = """
models:
- name: "test-model"
providers:
- provider: "PollinationsAI"
model: "openai-large"
condition: "balance > 0 or error_count < 3"
"""
with tempfile.NamedTemporaryFile(
mode="w", suffix=".yaml", delete=False, encoding="utf-8"
) as f:
f.write(cfg)
path = f.name
try:
RouterConfig.load(path)
route = RouterConfig.get("test-model")
self.assertIsNotNone(route)
self.assertEqual(route.name, "test-model")
self.assertEqual(len(route.providers), 1)
self.assertEqual(route.providers[0].provider, "PollinationsAI")
self.assertEqual(route.providers[0].model, "openai-large")
self.assertEqual(route.providers[0].condition, "balance > 0 or error_count < 3")
finally:
os.unlink(path)
def test_load_missing_file(self):
# Should not raise; simply leaves routes empty.
RouterConfig.load("/nonexistent/path/config.yaml")
self.assertEqual(RouterConfig.routes, {})
def test_load_empty_yaml(self):
import tempfile, os
with tempfile.NamedTemporaryFile(
mode="w", suffix=".yaml", delete=False, encoding="utf-8"
) as f:
f.write("models: []\n")
path = f.name
try:
RouterConfig.load(path)
self.assertEqual(RouterConfig.routes, {})
finally:
os.unlink(path)
def test_load_invalid_yaml(self):
import tempfile, os
with tempfile.NamedTemporaryFile(
mode="w", suffix=".yaml", delete=False, encoding="utf-8"
) as f:
f.write(": : invalid\n")
path = f.name
try:
# Should not raise; logs an error instead.
RouterConfig.load(path)
finally:
os.unlink(path)
def test_get_unknown_model_returns_none(self):
self.assertIsNone(RouterConfig.get("no-such-model"))
def test_clear_removes_routes(self):
RouterConfig.routes["x"] = ModelRouteConfig(name="x", providers=[])
RouterConfig.clear()
self.assertEqual(RouterConfig.routes, {})
def test_provider_default_model_uses_route_name(self):
import tempfile, os
cfg = """
models:
- name: "my-model"
providers:
- provider: "PollinationsAI"
"""
with tempfile.NamedTemporaryFile(
mode="w", suffix=".yaml", delete=False, encoding="utf-8"
) as f:
f.write(cfg)
path = f.name
try:
RouterConfig.load(path)
route = RouterConfig.get("my-model")
self.assertIsNotNone(route)
# When 'model' key is absent, it defaults to the route name
self.assertEqual(route.providers[0].model, "my-model")
finally:
os.unlink(path)
# ---------------------------------------------------------------------------
# ConfigModelProvider tests (synchronous helpers)
# ---------------------------------------------------------------------------
class TestConfigModelProvider(unittest.IsolatedAsyncioTestCase):
def setUp(self):
QuotaCache.clear()
ErrorCounter.clear()
RouterConfig.clear()
async def test_provider_not_found_skipped(self):
"""If a configured provider doesn't exist, it should be skipped."""
route = ModelRouteConfig(
name="my-model",
providers=[
ProviderRouteConfig(provider="NonExistentProvider9999", model="x"),
],
)
cmp = ConfigModelProvider(route)
chunks = []
with self.assertRaises(RuntimeError):
async for chunk in cmp.create_async_generator("my-model", []):
chunks.append(chunk)
async def test_condition_false_skips_provider(self):
"""Provider whose condition evaluates False must be skipped."""
route = ModelRouteConfig(
name="my-model",
providers=[
ProviderRouteConfig(
provider="NonExistentProvider9999",
model="x",
condition="balance > 100", # will be False (balance=0)
),
],
)
cmp = ConfigModelProvider(route)
with self.assertRaises(RuntimeError):
async for _ in cmp.create_async_generator("my-model", []):
pass
async def test_condition_true_attempts_provider(self):
"""Provider whose condition is True should be attempted (even if it fails)."""
route = ModelRouteConfig(
name="my-model",
providers=[
ProviderRouteConfig(
provider="NonExistentProvider9999",
model="x",
condition="balance >= 0", # True
),
],
)
cmp = ConfigModelProvider(route)
with self.assertRaises((RuntimeError, ValueError)):
async for _ in cmp.create_async_generator("my-model", []):
pass
async def test_429_invalidates_quota_cache(self):
"""A RateLimitError should invalidate the quota cache for that provider."""
from g4f.errors import RateLimitError
QuotaCache.set("TestProvider", {"balance": 5.0})
route = ModelRouteConfig(
name="my-model",
providers=[
ProviderRouteConfig(provider="NonExistentProvider9999", model="x"),
],
)
cmp = ConfigModelProvider(route)
# Simulate what the provider does on 429:
# We test the cache invalidation logic directly.
QuotaCache.invalidate("TestProvider")
self.assertIsNone(QuotaCache.get("TestProvider"))
if __name__ == "__main__":
unittest.main()
+15
View File
@@ -56,6 +56,21 @@ def get_model_and_provider(model : Union[Model, str],
provider = convert_to_provider(provider)
if not provider:
# Check config.yaml custom model routes first
if isinstance(model, str):
try:
from ..providers.config_provider import RouterConfig, ConfigModelProvider
route_config = RouterConfig.get(model)
if route_config is not None:
config_provider = ConfigModelProvider(route_config)
debug.last_provider = config_provider
debug.last_model = model
if logging:
debug.log(f"Using config.yaml route for model {model!r}")
return model, config_provider
except Exception as e:
debug.error("config.yaml: Error resolving config route:", e)
if isinstance(model, str):
if model in ModelUtils.convert:
model = ModelUtils.convert[model]
+10 -1
View File
@@ -261,4 +261,13 @@ def read_cookie_files(dir_path: Optional[str] = None, domains_filter: Optional[L
for domain, cookies in _parse_json_cookie_file(path).items():
if not domains_filter or domain in domains_filter:
CookiesConfig.cookies[domain] = cookies
debug.log(f"Cookies added: {len(cookies)} from {domain}")
debug.log(f"Cookies added: {len(cookies)} from {domain}")
# Load custom model routing config (config.yaml)
try:
from .providers.config_provider import RouterConfig
config_path = os.path.join(dir_path, "config.yaml")
RouterConfig.load(config_path)
except Exception as e:
config_path = os.path.join(dir_path, "config.yaml")
debug.error(f"config.yaml: Failed to load routing config from {config_path}:", e)
+539
View File
@@ -0,0 +1,539 @@
"""
Configuration-based model routing provider for g4f.
Loads a ``config.yaml`` file from the cookies/config directory and routes
model requests to providers based on availability, quota balance, and
recent error counts.
Example ``config.yaml``::
models:
- name: "my-gpt4"
providers:
- provider: "OpenaiAccount"
model: "gpt-4o"
condition: "balance > 0 or error_count < 3"
- provider: "PollinationsAI"
model: "openai-large"
- name: "fast-model"
providers:
- provider: "Gemini"
model: "gemini-pro"
The ``condition`` field is optional. When present it is a boolean expression
that can reference two variables:
* ``balance`` the provider's quota balance (float), fetched via
``get_quota()`` and cached.
* ``error_count`` the number of recent errors recorded for the provider
within a rolling one-hour window.
Supported operators in conditions: ``>``, ``<``, ``>=``, ``<=``, ``==``,
``!=``, as well as ``and`` / ``or`` / ``not``. Only the two variables above
are available; arbitrary Python is **not** evaluated.
"""
from __future__ import annotations
import os
import re
import time
import operator
from dataclasses import dataclass, field
from typing import Optional, Dict, List, Tuple
try:
import yaml
has_yaml = True
except ImportError:
has_yaml = False
from ..typing import Messages, AsyncResult
from .base_provider import AsyncGeneratorProvider
from .response import ProviderInfo
from .. import debug
# ---------------------------------------------------------------------------
# Quota cache
# ---------------------------------------------------------------------------
class QuotaCache:
"""Thread-safe in-memory cache for provider quota results.
Quota values are cached for :attr:`ttl` seconds. The cache entry for a
provider can be forcibly invalidated (e.g. when a 429 response is
received) via :meth:`invalidate`.
"""
ttl: float = 300 # seconds
_cache: Dict[str, dict] = {}
_timestamps: Dict[str, float] = {}
@classmethod
def get(cls, provider_name: str) -> Optional[dict]:
"""Return the cached quota dict for *provider_name*, or ``None``."""
if provider_name in cls._cache:
if time.time() - cls._timestamps.get(provider_name, 0) < cls.ttl:
return cls._cache[provider_name]
# Expired remove stale entry
cls._cache.pop(provider_name, None)
cls._timestamps.pop(provider_name, None)
return None
@classmethod
def set(cls, provider_name: str, quota: dict) -> None:
"""Store *quota* for *provider_name*."""
cls._cache[provider_name] = quota
cls._timestamps[provider_name] = time.time()
@classmethod
def invalidate(cls, provider_name: str) -> None:
"""Invalidate the cached quota for *provider_name*.
Call this when a 429 (rate-limit) response is received so that
the next routing decision fetches a fresh quota value.
"""
cls._cache.pop(provider_name, None)
cls._timestamps.pop(provider_name, None)
@classmethod
def clear(cls) -> None:
"""Remove all cached entries."""
cls._cache.clear()
cls._timestamps.clear()
# ---------------------------------------------------------------------------
# Error counter
# ---------------------------------------------------------------------------
class ErrorCounter:
"""Rolling-window error counter for providers.
Errors are tracked with timestamps so that only errors that occurred
within the last :attr:`window` seconds are counted.
"""
window: float = 3600 # 1 hour
_timestamps: Dict[str, List[float]] = {}
@classmethod
def increment(cls, provider_name: str) -> None:
"""Record one error for *provider_name*."""
now = time.time()
bucket = cls._timestamps.setdefault(provider_name, [])
bucket.append(now)
# Prune timestamps outside the rolling window
cls._timestamps[provider_name] = [t for t in bucket if now - t < cls.window]
@classmethod
def get_count(cls, provider_name: str) -> int:
"""Return the number of errors for *provider_name* in the current window."""
now = time.time()
bucket = cls._timestamps.get(provider_name, [])
# Prune stale entries on read as well
fresh = [t for t in bucket if now - t < cls.window]
cls._timestamps[provider_name] = fresh
return len(fresh)
@classmethod
def reset(cls, provider_name: str) -> None:
"""Reset the error counter for *provider_name*."""
cls._timestamps.pop(provider_name, None)
@classmethod
def clear(cls) -> None:
"""Reset all error counters."""
cls._timestamps.clear()
# ---------------------------------------------------------------------------
# Condition evaluation
# ---------------------------------------------------------------------------
_OPS: Dict[str, "Callable"] = {
">": operator.gt,
"<": operator.lt,
">=": operator.ge,
"<=": operator.le,
"==": operator.eq,
"!=": operator.ne,
}
# Tokenizer for simple condition expressions
_TOKEN_RE = re.compile(
r"(?P<float>-?\d+\.\d+)" # float literal
r"|(?P<int>-?\d+)" # integer literal
r"|(?P<op>>=|<=|==|!=|>|<)" # comparison operator
r"|(?P<kw>and|or|not)" # logical keywords
r"|(?P<id>[a-zA-Z_][a-zA-Z0-9_.]*)" # identifier
r"|(?P<lp>\()" # left paren
r"|(?P<rp>\))" # right paren
)
def _tokenize(expr: str) -> List[Tuple[str, str]]:
tokens = []
for m in _TOKEN_RE.finditer(expr.strip()):
kind = m.lastgroup
tokens.append((kind, m.group()))
return tokens
def _parse_expr(tokens: List[Tuple[str, str]], pos: int, variables: Dict[str, float]) -> Tuple[bool, int]:
"""Recursive-descent parser for ``and``/``or``/``not``/comparisons."""
return _parse_or(tokens, pos, variables)
def _parse_or(tokens, pos, variables):
left, pos = _parse_and(tokens, pos, variables)
while pos < len(tokens) and tokens[pos] == ("kw", "or"):
pos += 1
right, pos = _parse_and(tokens, pos, variables)
left = left or right
return left, pos
def _parse_and(tokens, pos, variables):
left, pos = _parse_not(tokens, pos, variables)
while pos < len(tokens) and tokens[pos] == ("kw", "and"):
pos += 1
right, pos = _parse_not(tokens, pos, variables)
left = left and right
return left, pos
def _parse_not(tokens, pos, variables):
if pos < len(tokens) and tokens[pos] == ("kw", "not"):
pos += 1
val, pos = _parse_not(tokens, pos, variables)
return not val, pos
return _parse_comparison(tokens, pos, variables)
def _parse_comparison(tokens, pos, variables):
if pos < len(tokens) and tokens[pos][0] == "lp":
pos += 1 # consume '('
val, pos = _parse_or(tokens, pos, variables)
if pos < len(tokens) and tokens[pos][0] == "rp":
pos += 1 # consume ')'
return val, pos
left_val, pos = _parse_atom(tokens, pos, variables)
if pos < len(tokens) and tokens[pos][0] == "op":
op_str = tokens[pos][1]
pos += 1
right_val, pos = _parse_atom(tokens, pos, variables)
return _OPS[op_str](left_val, right_val), pos
# Bare value treat as truthy
return bool(left_val), pos
def _parse_atom(tokens, pos, variables):
if pos >= len(tokens):
raise ValueError("Unexpected end of condition expression")
kind, value = tokens[pos]
pos += 1
if kind == "float":
return float(value), pos
elif kind == "int":
return int(value), pos
elif kind == "id":
# Resolve dotted names: "balance", "error_count", "get_quota.balance"
name = value
# Support "get_quota.balance" as an alias for "balance"
if name == "get_quota.balance":
name = "balance"
if name not in variables:
raise ValueError(f"Unknown variable in condition: {name!r}")
return variables[name], pos
else:
raise ValueError(f"Unexpected token {kind!r}={value!r} in condition expression")
def evaluate_condition(
condition: str,
balance: Optional[float],
error_count: int,
) -> bool:
"""Evaluate a provider condition string.
The condition may reference:
* ``balance`` provider quota balance (float).
* ``get_quota.balance`` alias for ``balance``.
* ``error_count`` recent error count (int).
If *balance* is ``None`` the variable resolves to ``0.0``.
Returns ``True`` if the provider should be used, ``False`` otherwise.
Raises :class:`ValueError` on parse errors.
"""
variables = {
"balance": float(balance) if balance is not None else 0.0,
"error_count": float(error_count),
}
tokens = _tokenize(condition)
if not tokens:
return True
result, _ = _parse_expr(tokens, 0, variables)
return result
# ---------------------------------------------------------------------------
# Config data structures
# ---------------------------------------------------------------------------
@dataclass
class ProviderRouteConfig:
"""A single provider entry inside a model route."""
provider: str
"""Provider class name (e.g. ``"OpenaiAccount"``)."""
model: str = ""
"""Model name passed to the provider. Defaults to the route model name."""
condition: Optional[str] = None
"""Optional boolean expression. If absent the provider is always eligible."""
@dataclass
class ModelRouteConfig:
"""Routing configuration for a single model name."""
name: str
"""The model name as seen by the client (e.g. ``"my-gpt4"``)."""
providers: List[ProviderRouteConfig] = field(default_factory=list)
"""Ordered list of provider candidates."""
# ---------------------------------------------------------------------------
# Global router state
# ---------------------------------------------------------------------------
class RouterConfig:
"""Singleton holding the active routing configuration."""
routes: Dict[str, ModelRouteConfig] = {}
"""Mapping from model name → :class:`ModelRouteConfig`."""
@classmethod
def load(cls, path: str) -> None:
"""Load and parse a ``config.yaml`` file at *path*.
Silently skips the file if PyYAML is not installed or the file does
not exist.
"""
if not has_yaml:
debug.error("config.yaml: PyYAML is not installed skipping config.yaml")
return
if not os.path.isfile(path):
return
try:
with open(path, "r", encoding="utf-8") as fh:
data = yaml.safe_load(fh)
except Exception as e:
debug.error(f"config.yaml: Failed to parse {path}:", e)
return
if not isinstance(data, dict):
debug.error(f"config.yaml: Expected a mapping at top level in {path}")
return
new_routes: Dict[str, ModelRouteConfig] = {}
for entry in data.get("models", []):
if not isinstance(entry, dict) or "name" not in entry:
continue
model_name = entry["name"]
provider_list: List[ProviderRouteConfig] = []
for pentry in entry.get("providers", []):
if not isinstance(pentry, dict) or "provider" not in pentry:
continue
provider_list.append(
ProviderRouteConfig(
provider=pentry["provider"],
model=pentry.get("model", model_name),
condition=pentry.get("condition"),
)
)
if provider_list:
new_routes[model_name] = ModelRouteConfig(
name=model_name,
providers=provider_list,
)
cls.routes = new_routes
debug.log(f"config.yaml: Loaded {len(new_routes)} model route(s) from {path}")
@classmethod
def clear(cls) -> None:
"""Remove all loaded routes."""
cls.routes.clear()
@classmethod
def get(cls, model_name: str) -> Optional[ModelRouteConfig]:
"""Return the :class:`ModelRouteConfig` for *model_name*, or ``None``."""
return cls.routes.get(model_name)
# ---------------------------------------------------------------------------
# Config-based provider
# ---------------------------------------------------------------------------
def _resolve_provider(provider_name: str):
"""Resolve a provider name string to a provider class."""
from .. import Provider
from ..Provider import ProviderUtils
if provider_name in ProviderUtils.convert:
return ProviderUtils.convert[provider_name]
# Try direct attribute lookup on the Provider module
provider = getattr(Provider, provider_name, None)
if provider is not None:
return provider
raise ValueError(f"Provider not found: {provider_name!r}")
async def _get_quota_cached(provider) -> Optional[dict]:
"""Return quota info for *provider*, using the cache when possible."""
name = getattr(provider, "__name__", str(provider))
cached = QuotaCache.get(name)
if cached is not None:
return cached
if not hasattr(provider, "get_quota"):
return None
try:
quota = await provider.get_quota()
if quota is not None:
QuotaCache.set(name, quota)
return quota
except Exception as e:
debug.error(f"config.yaml: get_quota failed for {name}:", e)
return None
def _check_condition(
route_cfg: ProviderRouteConfig,
provider,
quota: Optional[dict],
) -> bool:
"""Return ``True`` if the provider satisfies the route condition."""
if not route_cfg.condition:
return True
balance: Optional[float] = None
if quota is not None:
balance = quota.get("balance")
provider_name = getattr(provider, "__name__", str(provider))
error_count = ErrorCounter.get_count(provider_name)
try:
return evaluate_condition(route_cfg.condition, balance, error_count)
except ValueError as e:
debug.error(f"config.yaml: Invalid condition {route_cfg.condition!r}:", e)
return False # Default to skip on parse error
class ConfigModelProvider(AsyncGeneratorProvider):
"""An async generator provider that routes requests using ``config.yaml``.
This provider is instantiated per model name and tries each configured
provider in order, skipping those that fail their condition check. On a
429 error the quota cache for the failing provider is invalidated so that
the next call fetches a fresh quota value.
"""
working = True
supports_stream = True
supports_message_history = True
def __init__(self, route_config: ModelRouteConfig) -> None:
self._route_config = route_config
self.__name__ = f"ConfigRouter[{route_config.name}]"
# Make it usable as an instance (not just a class)
async def create_async_generator(
self,
model: str,
messages: Messages,
**kwargs,
) -> AsyncResult:
"""Yield response chunks, routing through configured providers."""
last_exception: Optional[Exception] = None
tried: List[str] = []
for prc in self._route_config.providers:
try:
provider = _resolve_provider(prc.provider)
except ValueError as e:
debug.error(f"config.yaml: {e}")
continue
provider_name = getattr(provider, "__name__", prc.provider)
# Fetch quota (cached)
quota = await _get_quota_cached(provider)
# Evaluate condition
if not _check_condition(prc, provider, quota):
debug.log(
f"config.yaml: Skipping {provider_name} "
f"(condition not met: {prc.condition!r})"
)
continue
target_model = prc.model or model
tried.append(provider_name)
yield ProviderInfo(
name=provider_name,
url=getattr(provider, "url", ""),
label=getattr(provider, "label", None),
model=target_model,
)
try:
if hasattr(provider, "create_async_generator"):
async for chunk in provider.create_async_generator(
target_model, messages, **kwargs
):
yield chunk
elif hasattr(provider, "create_completion"):
for chunk in provider.create_completion(
target_model, messages, stream=True, **kwargs
):
yield chunk
else:
raise NotImplementedError(
f"{provider_name} has no supported create method"
)
debug.log(f"config.yaml: {provider_name} succeeded for model {model!r}")
return # Success
except Exception as e:
# On rate-limit errors invalidate the quota cache
from ..errors import RateLimitError
if isinstance(e, RateLimitError) or "429" in str(e):
debug.log(
f"config.yaml: Rate-limited by {provider_name}, "
"invalidating quota cache"
)
QuotaCache.invalidate(provider_name)
ErrorCounter.increment(provider_name)
last_exception = e
debug.error(f"config.yaml: {provider_name} failed:", e)
if last_exception is not None:
raise last_exception
raise RuntimeError(
f"config.yaml: No provider succeeded for model {model!r}. "
f"Tried: {tried}"
)