# Mirror of https://github.com/zeke-chin/cursor-api.git
# (synced 2026-04-22 15:07:16 +08:00; 412 lines, 13 KiB, Python)
import json
|
|
|
|
from fastapi import FastAPI, Request, Response, HTTPException
|
|
from fastapi.responses import StreamingResponse
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from pydantic import BaseModel
|
|
from typing import List, Optional, Dict, Any
|
|
import uuid
|
|
import httpx
|
|
import os
|
|
from dotenv import load_dotenv
|
|
import time
|
|
import re
|
|
|
|
# 加载环境变量
|
|
load_dotenv()
|
|
|
|
app = FastAPI()
|
|
|
|
# 添加CORS中间件
|
|
app.add_middleware(
|
|
CORSMiddleware,
|
|
allow_origins=["*"],
|
|
allow_credentials=True,
|
|
allow_methods=["*"],
|
|
allow_headers=["*"],
|
|
)
|
|
|
|
|
|
# 定义请求模型
|
|
class Message(BaseModel):
|
|
role: str
|
|
content: str
|
|
|
|
|
|
class ChatRequest(BaseModel):
|
|
model: str
|
|
messages: List[Message]
|
|
stream: bool = False
|
|
|
|
|
|
def string_to_hex(text: str, model_name: str) -> bytes:
|
|
"""将文本转换为特定格式的十六进制数据"""
|
|
# 将输入文本转换为UTF-8字节
|
|
text_bytes = text.encode('utf-8')
|
|
text_length = len(text_bytes)
|
|
|
|
# 固定常量
|
|
FIXED_HEADER = 2
|
|
SEPARATOR = 1
|
|
FIXED_SUFFIX_LENGTH = 0xA3 + len(model_name)
|
|
|
|
# 计算第一个长度字段
|
|
if text_length < 128:
|
|
text_length_field1 = format(text_length, '02x')
|
|
text_length_field_size1 = 1
|
|
else:
|
|
low_byte1 = (text_length & 0x7F) | 0x80
|
|
high_byte1 = (text_length >> 7) & 0xFF
|
|
text_length_field1 = format(low_byte1, '02x') + format(high_byte1, '02x')
|
|
text_length_field_size1 = 2
|
|
|
|
# 计算基础长度字段
|
|
base_length = text_length + 0x2A
|
|
if base_length < 128:
|
|
text_length_field = format(base_length, '02x')
|
|
text_length_field_size = 1
|
|
else:
|
|
low_byte = (base_length & 0x7F) | 0x80
|
|
high_byte = (base_length >> 7) & 0xFF
|
|
text_length_field = format(low_byte, '02x') + format(high_byte, '02x')
|
|
text_length_field_size = 2
|
|
|
|
# 计算总消息长度
|
|
message_total_length = (FIXED_HEADER + text_length_field_size + SEPARATOR +
|
|
text_length_field_size1 + text_length + FIXED_SUFFIX_LENGTH)
|
|
|
|
# 构造十六进制字符串
|
|
model_name_bytes = model_name.encode('utf-8')
|
|
model_name_length_hex = format(len(model_name_bytes), '02X')
|
|
model_name_hex = model_name_bytes.hex().upper()
|
|
|
|
hex_string = (
|
|
f"{message_total_length:010x}"
|
|
"12"
|
|
f"{text_length_field}"
|
|
"0A"
|
|
f"{text_length_field1}"
|
|
f"{text_bytes.hex()}"
|
|
"10016A2432343163636435662D393162612D343131382D393239612D3936626330313631626432612"
|
|
"2002A132F643A2F6964656150726F2F656475626F73733A1E0A"
|
|
f"{model_name_length_hex}"
|
|
f"{model_name_hex}"
|
|
"22004A"
|
|
"2461383761396133342D323164642D343863372D623434662D616636633365636536663765"
|
|
"680070007A2436393337376535612D386332642D343835342D623564392D653062623232336163303061"
|
|
"800101B00100C00100E00100E80100"
|
|
).upper()
|
|
|
|
return bytes.fromhex(hex_string)
|
|
|
|
|
|
def chunk_to_utf8_string(chunk: bytes) -> str:
|
|
"""将二进制chunk转换为UTF-8字符串"""
|
|
if not chunk or len(chunk) < 2:
|
|
return ''
|
|
|
|
if chunk[0] in [0x01, 0x02] or (chunk[0] == 0x60 and chunk[1] == 0x0C):
|
|
return ''
|
|
|
|
# 记录原始chunk的十六进制(调试用)
|
|
print(f"chunk length: {len(chunk)}")
|
|
# print(f"chunk hex: {chunk.hex()}")
|
|
|
|
try:
|
|
# 去掉0x0A之前的所有字节
|
|
try:
|
|
chunk = chunk[chunk.index(0x0A) + 1:]
|
|
except ValueError:
|
|
pass
|
|
|
|
filtered_chunk = bytearray()
|
|
i = 0
|
|
while i < len(chunk):
|
|
# 检查是否有连续的0x00
|
|
if i + 4 <= len(chunk) and all(chunk[j] == 0x00 for j in range(i, i + 4)):
|
|
i += 4
|
|
while i < len(chunk) and chunk[i] <= 0x0F:
|
|
i += 1
|
|
continue
|
|
|
|
if chunk[i] == 0x0C:
|
|
i += 1
|
|
while i < len(chunk) and chunk[i] == 0x0A:
|
|
i += 1
|
|
else:
|
|
filtered_chunk.append(chunk[i])
|
|
i += 1
|
|
|
|
# 过滤掉特定字节
|
|
filtered_chunk = bytes(b for b in filtered_chunk
|
|
if b != 0x00 and b != 0x0C)
|
|
|
|
if not filtered_chunk:
|
|
return ''
|
|
|
|
result = filtered_chunk.decode('utf-8', errors='ignore').strip()
|
|
# print(f"decoded result: {result}") # 调试输出
|
|
return result
|
|
|
|
except Exception as e:
|
|
print(f"Error in chunk_to_utf8_string: {str(e)}")
|
|
return ''
|
|
|
|
|
|
async def process_stream(chunks, ):
|
|
"""处理流式响应"""
|
|
response_id = f"chatcmpl-{str(uuid.uuid4())}"
|
|
|
|
# 先将所有chunks读取到列表中
|
|
# chunks = []
|
|
# async for chunk in response.aiter_raw():
|
|
# chunks.append(chunk)
|
|
|
|
# 然后处理保存的chunks
|
|
for chunk in chunks:
|
|
text = chunk_to_utf8_string(chunk)
|
|
if text:
|
|
# 清理文本
|
|
text = text.strip()
|
|
if "<|END_USER|>" in text:
|
|
text = text.split("<|END_USER|>")[-1].strip()
|
|
if text and text[0].isalpha():
|
|
text = text[1:].strip()
|
|
text = re.sub(r"[\x00-\x1F\x7F]", "", text)
|
|
|
|
if text: # 确保清理后的文本不为空
|
|
data_body = {
|
|
"id": response_id,
|
|
"object": "chat.completion.chunk",
|
|
"created": int(time.time()),
|
|
"choices": [{
|
|
"index": 0,
|
|
"delta": {
|
|
"content": text
|
|
}
|
|
}]
|
|
}
|
|
yield f"data: {json.dumps(data_body, ensure_ascii=False)}\n\n"
|
|
# yield "data: {\n"
|
|
# yield f' "id": "{response_id}",\n'
|
|
# yield ' "object": "chat.completion.chunk",\n'
|
|
# yield f' "created": {int(time.time())},\n'
|
|
# yield ' "choices": [{\n'
|
|
# yield ' "index": 0,\n'
|
|
# yield ' "delta": {\n'
|
|
# yield f' "content": "{text}"\n'
|
|
# yield " }\n"
|
|
# yield " }]\n"
|
|
# yield "}\n\n"
|
|
|
|
yield "data: [DONE]\n\n"
|
|
|
|
|
|
@app.post("/v1/chat/completions")
|
|
async def chat_completions(request: Request, chat_request: ChatRequest):
|
|
# 验证o1模型不支持流式输出
|
|
if chat_request.model.startswith('o1-') and chat_request.stream:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail="Model not supported stream"
|
|
)
|
|
|
|
# 获取并处理认证令牌
|
|
auth_header = request.headers.get('authorization', '')
|
|
if not auth_header.startswith('Bearer '):
|
|
raise HTTPException(
|
|
status_code=401,
|
|
detail="Invalid authorization header"
|
|
)
|
|
|
|
auth_token = auth_header.replace('Bearer ', '')
|
|
if not auth_token:
|
|
raise HTTPException(
|
|
status_code=401,
|
|
detail="Missing authorization token"
|
|
)
|
|
|
|
# 处理多个密钥
|
|
keys = [key.strip() for key in auth_token.split(',')]
|
|
if keys:
|
|
auth_token = keys[0] # 使用第一个密钥
|
|
|
|
if '%3A%3A' in auth_token:
|
|
auth_token = auth_token.split('%3A%3A')[1]
|
|
|
|
# 格式化消息
|
|
formatted_messages = "\n".join(
|
|
f"{msg.role}:{msg.content}" for msg in chat_request.messages
|
|
)
|
|
|
|
# 生成请求数据
|
|
hex_data = string_to_hex(formatted_messages, chat_request.model)
|
|
|
|
# 准备请求头
|
|
headers = {
|
|
'Content-Type': 'application/connect+proto',
|
|
'Authorization': f'Bearer {auth_token}',
|
|
'Connect-Accept-Encoding': 'gzip,br',
|
|
'Connect-Protocol-Version': '1',
|
|
'User-Agent': 'connect-es/1.4.0',
|
|
'X-Amzn-Trace-Id': f'Root={str(uuid.uuid4())}',
|
|
'X-Cursor-Checksum': 'zo6Qjequ9b9734d1f13c3438ba25ea31ac93d9287248b9d30434934e9fcbfa6b3b22029e/7e4af391f67188693b722eff0090e8e6608bca8fa320ef20a0ccb5d7d62dfdef',
|
|
'X-Cursor-Client-Version': '0.42.3',
|
|
'X-Cursor-Timezone': 'Asia/Shanghai',
|
|
'X-Ghost-Mode': 'false',
|
|
'X-Request-Id': str(uuid.uuid4()),
|
|
'Host': 'api2.cursor.sh'
|
|
}
|
|
|
|
async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
|
|
try:
|
|
# 使用 stream=True 参数
|
|
# 打印 headers 和 二进制 data
|
|
print(f"headers: {headers}")
|
|
print(hex_data)
|
|
async with client.stream(
|
|
'POST',
|
|
'https://api2.cursor.sh/aiserver.v1.AiService/StreamChat',
|
|
headers=headers,
|
|
content=hex_data,
|
|
timeout=None
|
|
) as response:
|
|
if chat_request.stream:
|
|
chunks = []
|
|
async for chunk in response.aiter_raw():
|
|
chunks.append(chunk)
|
|
return StreamingResponse(
|
|
process_stream(chunks),
|
|
media_type="text/event-stream"
|
|
)
|
|
else:
|
|
# 非流式响应处理
|
|
text = ''
|
|
async for chunk in response.aiter_raw():
|
|
# print('chunk:', chunk.hex())
|
|
print('chunk length:', len(chunk))
|
|
|
|
res = chunk_to_utf8_string(chunk)
|
|
# print('res:', res)
|
|
if res:
|
|
text += res
|
|
|
|
# 清理响应文本
|
|
import re
|
|
text = re.sub(r'^.*<\|END_USER\|>', '', text, flags=re.DOTALL)
|
|
text = re.sub(r'^\n[a-zA-Z]?', '', text).strip()
|
|
text = re.sub(r'[\x00-\x1F\x7F]', '', text)
|
|
|
|
return {
|
|
"id": f"chatcmpl-{str(uuid.uuid4())}",
|
|
"object": "chat.completion",
|
|
"created": int(time.time()),
|
|
"model": chat_request.model,
|
|
"choices": [{
|
|
"index": 0,
|
|
"message": {
|
|
"role": "assistant",
|
|
"content": text
|
|
},
|
|
"finish_reason": "stop"
|
|
}],
|
|
"usage": {
|
|
"prompt_tokens": 0,
|
|
"completion_tokens": 0,
|
|
"total_tokens": 0
|
|
}
|
|
}
|
|
|
|
except Exception as e:
|
|
print(f"Error: {str(e)}")
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail="Internal server error"
|
|
)
|
|
|
|
|
|
@app.post("/models")
|
|
async def models():
|
|
return {
|
|
"object": "list",
|
|
"data": [
|
|
{
|
|
"id": "claude-3-5-sonnet-20241022",
|
|
"object": "model",
|
|
"created": 1713744000,
|
|
"owned_by": "anthropic"
|
|
},
|
|
{
|
|
"id": "claude-3-opus",
|
|
"object": "model",
|
|
"created": 1709251200,
|
|
"owned_by": "anthropic"
|
|
},
|
|
{
|
|
"id": "claude-3.5-haiku",
|
|
"object": "model",
|
|
"created": 1711929600,
|
|
"owned_by": "anthropic"
|
|
},
|
|
{
|
|
"id": "claude-3.5-sonnet",
|
|
"object": "model",
|
|
"created": 1711929600,
|
|
"owned_by": "anthropic"
|
|
},
|
|
{
|
|
"id": "cursor-small",
|
|
"object": "model",
|
|
"created": 1712534400,
|
|
"owned_by": "cursor"
|
|
},
|
|
{
|
|
"id": "gpt-3.5-turbo",
|
|
"object": "model",
|
|
"created": 1677649200,
|
|
"owned_by": "openai"
|
|
},
|
|
{
|
|
"id": "gpt-4",
|
|
"object": "model",
|
|
"created": 1687392000,
|
|
"owned_by": "openai"
|
|
},
|
|
{
|
|
"id": "gpt-4-turbo-2024-04-09",
|
|
"object": "model",
|
|
"created": 1712620800,
|
|
"owned_by": "openai"
|
|
},
|
|
{
|
|
"id": "gpt-4o",
|
|
"object": "model",
|
|
"created": 1712620800,
|
|
"owned_by": "openai"
|
|
},
|
|
{
|
|
"id": "gpt-4o-mini",
|
|
"object": "model",
|
|
"created": 1712620800,
|
|
"owned_by": "openai"
|
|
},
|
|
{
|
|
"id": "o1-mini",
|
|
"object": "model",
|
|
"created": 1712620800,
|
|
"owned_by": "openai"
|
|
},
|
|
{
|
|
"id": "o1-preview",
|
|
"object": "model",
|
|
"created": 1712620800,
|
|
"owned_by": "openai"
|
|
}
|
|
]
|
|
}
|
|
|
|
if __name__ == "__main__":
|
|
import uvicorn
|
|
port = int(os.getenv("PORT", "3001"))
|
|
uvicorn.run("main:app", host="0.0.0.0", port=port, reload=True, timeout_keep_alive=30)
|