mirror of
https://github.com/xtekky/gpt4free.git
synced 2026-04-22 15:47:11 +08:00
Add PaProviderRegistry, /pa/* API routes (providers list, chat/completions, backend-api/v2/conversation)
Agent-Logs-Url: https://github.com/xtekky/gpt4free/sessions/e0daf662-ee35-43ac-bdef-27dd570bc00d Co-authored-by: hlohaus <983577+hlohaus@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
d2f2af886c
commit
dd9230ef4e
@@ -560,3 +560,123 @@ class TestSecurityHardening(unittest.IsolatedAsyncioTestCase):
|
||||
"max_depth": MAX_RECURSION_DEPTH * 100,
|
||||
})
|
||||
self.assertTrue(result.get("success"))
|
||||
|
||||
class TestPaProviderRegistry(unittest.TestCase):
|
||||
"""Tests for PaProviderRegistry — stable IDs without exposing filenames."""
|
||||
|
||||
def setUp(self):
|
||||
"""Create a temporary .pa.py file in the workspace for testing."""
|
||||
from g4f.mcp.pa_provider import get_workspace_dir, get_pa_registry, _pa_registry
|
||||
self.workspace = get_workspace_dir()
|
||||
# Force a fresh registry for each test
|
||||
import g4f.mcp.pa_provider as _mod
|
||||
_mod._pa_registry = None
|
||||
|
||||
self.pa_file = self.workspace / "registry_test.pa.py"
|
||||
self.pa_file.write_text("""
|
||||
class Provider:
|
||||
label = "RegistryTestProvider"
|
||||
working = True
|
||||
models = ["rt-model-1", "rt-model-2"]
|
||||
url = "https://test.example.com"
|
||||
|
||||
@classmethod
|
||||
async def create_async_generator(cls, model, messages, **kwargs):
|
||||
yield "hello from registry test"
|
||||
""")
|
||||
|
||||
def tearDown(self):
|
||||
if self.pa_file.exists():
|
||||
self.pa_file.unlink()
|
||||
import g4f.mcp.pa_provider as _mod
|
||||
_mod._pa_registry = None
|
||||
|
||||
def test_list_providers_returns_list(self):
|
||||
from g4f.mcp.pa_provider import get_pa_registry
|
||||
reg = get_pa_registry()
|
||||
reg.refresh()
|
||||
result = reg.list_providers()
|
||||
self.assertIsInstance(result, list)
|
||||
self.assertGreaterEqual(len(result), 1)
|
||||
|
||||
def test_provider_has_required_fields(self):
|
||||
from g4f.mcp.pa_provider import get_pa_registry
|
||||
reg = get_pa_registry()
|
||||
reg.refresh()
|
||||
providers = reg.list_providers()
|
||||
p = next((x for x in providers if x.get("label") == "RegistryTestProvider"), None)
|
||||
self.assertIsNotNone(p, "Test provider not found in registry")
|
||||
self.assertIn("id", p)
|
||||
self.assertIn("label", p)
|
||||
self.assertIn("models", p)
|
||||
self.assertIn("working", p)
|
||||
self.assertIn("url", p)
|
||||
self.assertEqual(p["label"], "RegistryTestProvider")
|
||||
self.assertIn("rt-model-1", p["models"])
|
||||
self.assertTrue(p["working"])
|
||||
|
||||
def test_filename_not_exposed(self):
|
||||
"""Provider IDs and info must NOT contain the filename or path."""
|
||||
from g4f.mcp.pa_provider import get_pa_registry
|
||||
import json
|
||||
reg = get_pa_registry()
|
||||
reg.refresh()
|
||||
providers = reg.list_providers()
|
||||
for p in providers:
|
||||
serialized = json.dumps(p)
|
||||
self.assertNotIn("registry_test", serialized, "Filename leaked in provider info")
|
||||
self.assertNotIn(".pa.py", serialized, "Extension leaked in provider info")
|
||||
self.assertNotIn(str(self.workspace), serialized, "Workspace path leaked")
|
||||
|
||||
def test_stable_id(self):
|
||||
"""The same file gets the same ID across refreshes."""
|
||||
from g4f.mcp.pa_provider import get_pa_registry
|
||||
reg = get_pa_registry()
|
||||
reg.refresh()
|
||||
p1 = next(x for x in reg.list_providers() if x["label"] == "RegistryTestProvider")
|
||||
reg.refresh()
|
||||
p2 = next(x for x in reg.list_providers() if x["label"] == "RegistryTestProvider")
|
||||
self.assertEqual(p1["id"], p2["id"])
|
||||
|
||||
def test_get_provider_class_returns_class(self):
|
||||
from g4f.mcp.pa_provider import get_pa_registry
|
||||
reg = get_pa_registry()
|
||||
reg.refresh()
|
||||
p = next(x for x in reg.list_providers() if x["label"] == "RegistryTestProvider")
|
||||
cls = reg.get_provider_class(p["id"])
|
||||
self.assertIsNotNone(cls)
|
||||
self.assertTrue(hasattr(cls, "create_async_generator"))
|
||||
|
||||
def test_get_provider_class_missing_returns_none(self):
|
||||
from g4f.mcp.pa_provider import get_pa_registry
|
||||
reg = get_pa_registry()
|
||||
self.assertIsNone(reg.get_provider_class("nonexistent00"))
|
||||
|
||||
def test_get_provider_info_returns_dict(self):
|
||||
from g4f.mcp.pa_provider import get_pa_registry
|
||||
reg = get_pa_registry()
|
||||
reg.refresh()
|
||||
p = next(x for x in reg.list_providers() if x["label"] == "RegistryTestProvider")
|
||||
info = reg.get_provider_info(p["id"])
|
||||
self.assertIsNotNone(info)
|
||||
self.assertEqual(info["id"], p["id"])
|
||||
self.assertEqual(info["label"], "RegistryTestProvider")
|
||||
|
||||
def test_get_provider_info_missing_returns_none(self):
|
||||
from g4f.mcp.pa_provider import get_pa_registry
|
||||
reg = get_pa_registry()
|
||||
self.assertIsNone(reg.get_provider_info("nonexistent00"))
|
||||
|
||||
def test_id_length(self):
|
||||
"""IDs should be 8 hex characters."""
|
||||
from g4f.mcp.pa_provider import get_pa_registry
|
||||
reg = get_pa_registry()
|
||||
reg.refresh()
|
||||
for p in reg.list_providers():
|
||||
self.assertRegex(p["id"], r'^[0-9a-f]{8}$')
|
||||
|
||||
def test_registry_singleton(self):
|
||||
from g4f.mcp.pa_provider import get_pa_registry
|
||||
r1 = get_pa_registry()
|
||||
r2 = get_pa_registry()
|
||||
self.assertIs(r1, r2)
|
||||
|
||||
+209
-2
@@ -267,7 +267,7 @@ class Api:
|
||||
else:
|
||||
user = "admin"
|
||||
path = request.url.path
|
||||
if path.startswith("/v1") or path.startswith("/api/") or (AppConfig.demo and path == '/backend-api/v2/upload_cookies'):
|
||||
if path.startswith("/v1") or path.startswith("/api/") or path.startswith("/pa/") or (AppConfig.demo and path == '/backend-api/v2/upload_cookies'):
|
||||
if request.method != "OPTIONS" and not path.endswith("/models"):
|
||||
if not user_g4f_api_key:
|
||||
return ErrorResponse.from_message("G4F API key required", HTTP_401_UNAUTHORIZED)
|
||||
@@ -636,7 +636,214 @@ class Api:
|
||||
'params': [*provider.get_parameters()] if hasattr(provider, "get_parameters") else []
|
||||
}
|
||||
|
||||
responses = {
|
||||
# ------------------------------------------------------------------ #
|
||||
# PA Provider routes #
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
@self.app.get("/pa/providers", responses={
|
||||
HTTP_200_OK: {},
|
||||
})
|
||||
async def pa_providers_list():
|
||||
"""List all PA providers loaded from the workspace.
|
||||
|
||||
Filenames are never exposed; each provider is identified by a
|
||||
stable opaque ID (SHA-256 of the path, first 8 hex chars).
|
||||
"""
|
||||
from g4f.mcp.pa_provider import get_pa_registry
|
||||
return get_pa_registry().list_providers()
|
||||
|
||||
@self.app.get("/pa/providers/{provider_id}", responses={
|
||||
HTTP_200_OK: {},
|
||||
HTTP_404_NOT_FOUND: {"model": ErrorResponseModel},
|
||||
})
|
||||
async def pa_providers_detail(provider_id: str):
|
||||
"""Get details for a single PA provider by its opaque ID."""
|
||||
from g4f.mcp.pa_provider import get_pa_registry
|
||||
info = get_pa_registry().get_provider_info(provider_id)
|
||||
if info is None:
|
||||
return ErrorResponse.from_message(
|
||||
f"PA provider '{provider_id}' not found", HTTP_404_NOT_FOUND
|
||||
)
|
||||
return info
|
||||
|
||||
responses_pa = {
|
||||
HTTP_200_OK: {"model": ChatCompletion},
|
||||
HTTP_401_UNAUTHORIZED: {"model": ErrorResponseModel},
|
||||
HTTP_404_NOT_FOUND: {"model": ErrorResponseModel},
|
||||
HTTP_422_UNPROCESSABLE_ENTITY: {"model": ErrorResponseModel},
|
||||
HTTP_500_INTERNAL_SERVER_ERROR: {"model": ErrorResponseModel},
|
||||
}
|
||||
|
||||
@self.app.post("/pa/chat/completions", responses=responses_pa)
|
||||
@self.app.post("/pa/{provider_id}/chat/completions", responses=responses_pa)
|
||||
async def pa_chat_completions(
|
||||
config: ChatCompletionsConfig,
|
||||
credentials: Annotated[HTTPAuthorizationCredentials, Depends(Api.security)] = None,
|
||||
provider_id: str = None,
|
||||
):
|
||||
"""OpenAI-compatible chat completions endpoint backed by PA providers.
|
||||
|
||||
The PA provider is identified by its opaque ID either from the URL
|
||||
path (``/pa/{provider_id}/chat/completions``) or from the ``provider``
|
||||
field in the JSON body. When both are absent the first available PA
|
||||
provider is used.
|
||||
"""
|
||||
from g4f.mcp.pa_provider import get_pa_registry
|
||||
|
||||
registry = get_pa_registry()
|
||||
pid = provider_id or config.provider
|
||||
if pid is None:
|
||||
listing = registry.list_providers()
|
||||
if not listing:
|
||||
return ErrorResponse.from_message(
|
||||
"No PA providers found in workspace", HTTP_404_NOT_FOUND
|
||||
)
|
||||
pid = listing[0]["id"]
|
||||
|
||||
provider_cls = registry.get_provider_class(pid)
|
||||
if provider_cls is None:
|
||||
return ErrorResponse.from_message(
|
||||
f"PA provider '{pid}' not found", HTTP_404_NOT_FOUND
|
||||
)
|
||||
|
||||
try:
|
||||
config.provider = None # pass the class directly below
|
||||
if credentials is not None and credentials.credentials != "secret":
|
||||
config.api_key = credentials.credentials
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
**filter_none(
|
||||
**(
|
||||
config.model_dump(exclude_none=True)
|
||||
if hasattr(config, "model_dump")
|
||||
else config.dict(exclude_none=True)
|
||||
),
|
||||
**{
|
||||
"conversation_id": None,
|
||||
"provider": provider_cls,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
if not config.stream:
|
||||
return await response
|
||||
|
||||
async def streaming():
|
||||
try:
|
||||
async for chunk in response:
|
||||
if not isinstance(chunk, BaseConversation):
|
||||
yield (
|
||||
f"data: "
|
||||
f"{chunk.model_dump_json() if hasattr(chunk, 'model_dump_json') else chunk.json()}"
|
||||
f"\n\n"
|
||||
)
|
||||
except GeneratorExit:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
yield f"data: {format_exception(e, config)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(streaming(), media_type="text/event-stream")
|
||||
|
||||
except (ModelNotFoundError, ProviderNotFoundError) as e:
|
||||
logger.exception(e)
|
||||
return ErrorResponse.from_exception(e, config, HTTP_404_NOT_FOUND)
|
||||
except (MissingAuthError, NoValidHarFileError) as e:
|
||||
logger.exception(e)
|
||||
return ErrorResponse.from_exception(e, config, HTTP_401_UNAUTHORIZED)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return ErrorResponse.from_exception(e, config, HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
|
||||
@self.app.post("/pa/backend-api/v2/conversation", responses={
|
||||
HTTP_200_OK: {},
|
||||
HTTP_404_NOT_FOUND: {"model": ErrorResponseModel},
|
||||
HTTP_422_UNPROCESSABLE_ENTITY: {"model": ErrorResponseModel},
|
||||
HTTP_500_INTERNAL_SERVER_ERROR: {"model": ErrorResponseModel},
|
||||
})
|
||||
async def pa_backend_conversation(request: Request):
|
||||
"""GUI-compatible streaming conversation endpoint for PA providers.
|
||||
|
||||
Accepts the same JSON body as ``/backend-api/v2/conversation`` and
|
||||
streams Server-Sent Events in the same format used by the gpt4free
|
||||
web interface (``{"type": "content", "content": "..."}`` etc.).
|
||||
|
||||
The ``provider`` field should contain the opaque PA provider ID
|
||||
returned by ``GET /pa/providers``. When omitted the first available
|
||||
PA provider is used.
|
||||
"""
|
||||
from g4f.mcp.pa_provider import get_pa_registry
|
||||
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
return ErrorResponse.from_message(
|
||||
"Invalid JSON body", HTTP_422_UNPROCESSABLE_ENTITY
|
||||
)
|
||||
|
||||
registry = get_pa_registry()
|
||||
pid = body.get("provider")
|
||||
if pid:
|
||||
provider_cls = registry.get_provider_class(pid)
|
||||
if provider_cls is None:
|
||||
return ErrorResponse.from_message(
|
||||
f"PA provider '{pid}' not found", HTTP_404_NOT_FOUND
|
||||
)
|
||||
else:
|
||||
listing = registry.list_providers()
|
||||
if not listing:
|
||||
return ErrorResponse.from_message(
|
||||
"No PA providers found in workspace", HTTP_404_NOT_FOUND
|
||||
)
|
||||
provider_cls = registry.get_provider_class(listing[0]["id"])
|
||||
|
||||
provider_label = getattr(provider_cls, "label", provider_cls.__name__)
|
||||
messages = body.get("messages") or []
|
||||
model = body.get("model") or getattr(provider_cls, "default_model", "") or ""
|
||||
|
||||
async def gen_backend_stream():
|
||||
yield (
|
||||
"data: "
|
||||
+ json.dumps({"type": "provider", "provider": provider_label, "model": model})
|
||||
+ "\n\n"
|
||||
)
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=model,
|
||||
provider=provider_cls,
|
||||
stream=True,
|
||||
)
|
||||
async for chunk in response:
|
||||
if isinstance(chunk, BaseConversation):
|
||||
continue
|
||||
text = ""
|
||||
if hasattr(chunk, "choices") and chunk.choices:
|
||||
delta = chunk.choices[0].delta
|
||||
text = getattr(delta, "content", "") or ""
|
||||
if text:
|
||||
yield (
|
||||
"data: "
|
||||
+ json.dumps({"type": "content", "content": text})
|
||||
+ "\n\n"
|
||||
)
|
||||
except GeneratorExit:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
yield (
|
||||
"data: "
|
||||
+ json.dumps({"type": "error", "error": f"{type(e).__name__}: {e}"})
|
||||
+ "\n\n"
|
||||
)
|
||||
yield (
|
||||
"data: "
|
||||
+ json.dumps({"type": "finish", "finish": "stop"})
|
||||
+ "\n\n"
|
||||
)
|
||||
|
||||
return StreamingResponse(gen_backend_stream(), media_type="text/event-stream")
|
||||
HTTP_200_OK: {"model": TranscriptionResponseModel},
|
||||
HTTP_401_UNAUTHORIZED: {"model": ErrorResponseModel},
|
||||
HTTP_404_NOT_FOUND: {"model": ErrorResponseModel},
|
||||
|
||||
@@ -28,8 +28,10 @@ from .pa_provider import (
|
||||
load_pa_provider,
|
||||
list_pa_providers,
|
||||
get_workspace_dir,
|
||||
get_pa_registry,
|
||||
SAFE_MODULES,
|
||||
SafeExecutionResult,
|
||||
PaProviderRegistry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -51,6 +53,8 @@ __all__ = [
|
||||
'load_pa_provider',
|
||||
'list_pa_providers',
|
||||
'get_workspace_dir',
|
||||
'get_pa_registry',
|
||||
'SAFE_MODULES',
|
||||
'SafeExecutionResult',
|
||||
'PaProviderRegistry',
|
||||
]
|
||||
|
||||
@@ -60,7 +60,9 @@ import io
|
||||
import ast
|
||||
import sys
|
||||
import json
|
||||
import hashlib
|
||||
import threading
|
||||
import time as _time_module
|
||||
import traceback
|
||||
import builtins as _builtins
|
||||
from pathlib import Path
|
||||
@@ -480,3 +482,127 @@ def list_pa_providers(directory: "Optional[str | Path]" = None) -> List[Path]:
|
||||
if not directory.exists():
|
||||
return []
|
||||
return sorted(directory.rglob("*.pa.py"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PA Provider Registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class PaProviderRegistry:
|
||||
"""Singleton registry for PA providers loaded from the workspace.
|
||||
|
||||
Each provider is assigned a **stable opaque ID** derived from the SHA-256
|
||||
hash of its canonical file path (truncated to 8 hex chars). The filename
|
||||
is never exposed in any public-facing method.
|
||||
|
||||
The registry is automatically refreshed when the cache is older than
|
||||
:attr:`TTL` seconds so hot-reloaded PA files are picked up without a
|
||||
restart.
|
||||
"""
|
||||
|
||||
#: How long (in seconds) the cached entries remain valid.
|
||||
TTL: float = 5.0
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Each entry: (id, label, models, working, url, cls)
|
||||
self._entries: List[tuple] = []
|
||||
# Force a refresh on the first access.
|
||||
self._loaded_at: float = -self.TTL
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _make_id(path: Path) -> str:
|
||||
"""Return a stable 8-char hex ID for *path* (no path info exposed)."""
|
||||
return hashlib.sha256(str(path.resolve()).encode("utf-8")).hexdigest()[:8]
|
||||
|
||||
def _ensure_fresh(self) -> None:
|
||||
if _time_module.monotonic() - self._loaded_at >= self.TTL:
|
||||
self.refresh()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def refresh(self) -> None:
|
||||
"""Re-scan the workspace and reload all ``.pa.py`` providers."""
|
||||
entries: List[tuple] = []
|
||||
for pa_path in list_pa_providers():
|
||||
try:
|
||||
cls = load_pa_provider(pa_path)
|
||||
if cls is None:
|
||||
continue
|
||||
provider_id = self._make_id(pa_path)
|
||||
models_list: List[str] = []
|
||||
try:
|
||||
if hasattr(cls, "get_models"):
|
||||
raw = cls.get_models()
|
||||
models_list = list(raw) if raw else []
|
||||
elif hasattr(cls, "models"):
|
||||
models_list = list(getattr(cls, "models") or [])
|
||||
except Exception:
|
||||
pass
|
||||
entries.append((
|
||||
provider_id,
|
||||
getattr(cls, "label", cls.__name__),
|
||||
models_list,
|
||||
bool(getattr(cls, "working", True)),
|
||||
getattr(cls, "url", None),
|
||||
cls,
|
||||
))
|
||||
except Exception:
|
||||
pass
|
||||
self._entries = entries
|
||||
self._loaded_at = _time_module.monotonic()
|
||||
|
||||
def list_providers(self) -> List[Dict[str, Any]]:
|
||||
"""Return a list of provider info dicts (no filesystem paths)."""
|
||||
self._ensure_fresh()
|
||||
return [
|
||||
{
|
||||
"id": e[0],
|
||||
"object": "pa_provider",
|
||||
"label": e[1],
|
||||
"models": e[2],
|
||||
"working": e[3],
|
||||
"url": e[4],
|
||||
}
|
||||
for e in self._entries
|
||||
]
|
||||
|
||||
def get_provider_class(self, provider_id: str) -> Optional[Type]:
|
||||
"""Return the provider class for *provider_id*, or ``None``."""
|
||||
self._ensure_fresh()
|
||||
for e in self._entries:
|
||||
if e[0] == provider_id:
|
||||
return e[5]
|
||||
return None
|
||||
|
||||
def get_provider_info(self, provider_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Return the info dict for *provider_id*, or ``None``."""
|
||||
self._ensure_fresh()
|
||||
for e in self._entries:
|
||||
if e[0] == provider_id:
|
||||
return {
|
||||
"id": e[0],
|
||||
"object": "pa_provider",
|
||||
"label": e[1],
|
||||
"models": e[2],
|
||||
"working": e[3],
|
||||
"url": e[4],
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
#: Module-level singleton.
|
||||
_pa_registry: Optional[PaProviderRegistry] = None
|
||||
|
||||
|
||||
def get_pa_registry() -> PaProviderRegistry:
|
||||
"""Return the singleton :class:`PaProviderRegistry`, creating it if needed."""
|
||||
global _pa_registry
|
||||
if _pa_registry is None:
|
||||
_pa_registry = PaProviderRegistry()
|
||||
return _pa_registry
|
||||
|
||||
Reference in New Issue
Block a user