Files
AI2API/qwen2api.go
T
BlueSkyXN b63b8dae00 0.1.8
2025-03-25 10:26:22 +08:00

2144 lines
56 KiB
Go

package main
import (
"bufio"
"bytes"
"context"
"crypto/rand"
"crypto/tls"
"encoding/base64"
"encoding/json"
"flag"
"fmt"
"io"
"log"
"mime/multipart"
"net/http"
"os"
"os/signal"
"regexp"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
)
// Version and upstream Qwen API endpoint constants.
const (
	Version   = "1.0.0"
	TargetURL = "https://chat.qwen.ai/api/chat/completions" // chat completions (streaming SSE)
	ModelsURL = "https://chat.qwen.ai/api/models"           // model catalogue
	FilesURL  = "https://chat.qwen.ai/api/v1/files/"        // file/image upload
	TasksURL  = "https://chat.qwen.ai/api/v1/tasks/status/" // async task (e.g. image) status polling
)

// DefaultModels is the fallback model list served when the upstream
// models endpoint cannot be reached or returns nothing.
var DefaultModels = []string{
	"qwen-max-latest",
	"qwen-plus-latest",
	"qwen2.5-vl-72b-instruct",
	"qwen2.5-14b-instruct-1m",
	"qvq-72b-preview",
	"qwq-32b-preview",
	"qwen2.5-coder-32b-instruct",
	"qwen-turbo-latest",
	"qwen2.5-72b-instruct",
}

// ModelSuffixes are the feature-variant suffixes appended to every base
// model when building the advertised model list ("" keeps the base name).
var ModelSuffixes = []string{
	"",
	"-thinking",
	"-search",
	"-thinking-search",
	"-draw",
}

// Log level identifiers, in increasing severity: debug < info < warn < error.
const (
	LogLevelDebug = "debug"
	LogLevelInfo  = "info"
	LogLevelWarn  = "warn"
	LogLevelError = "error"
)
// WorkerPool manages a fixed set of worker goroutines that consume
// queued HTTP-handling tasks.
type WorkerPool struct {
	taskQueue       chan *Task    // buffered queue of pending tasks
	workerCount     int           // number of worker goroutines
	shutdownChannel chan struct{} // closed to signal workers to exit
	wg              sync.WaitGroup
}

// Task carries everything a worker needs to process one HTTP request.
type Task struct {
	r        *http.Request
	w        http.ResponseWriter
	done     chan struct{} // closed by the worker when processing finishes
	reqID    string        // per-request correlation ID for logging
	isStream bool          // true when the client asked for an SSE stream
	apiReq   APIRequest    // pre-parsed request body (chat/images endpoints)
	path     string        // endpoint path used to dispatch to a handler
}

// Semaphore is a counting semaphore backed by a buffered channel,
// used to cap the number of concurrently admitted requests.
type Semaphore struct {
	sem chan struct{}
}

// Config holds the runtime configuration parsed from command-line flags.
type Config struct {
	Port          string // listen port
	Address       string // listen address
	LogLevel      string // one of the LogLevel* constants
	DevMode       bool   // development mode forces debug logging
	MaxRetries    int    // maximum retries for failed upstream requests
	Timeout       int    // request timeout in seconds (server and client)
	VerifySSL     bool   // when false, upstream TLS certificates are not verified
	WorkerCount   int    // worker goroutines in the pool
	QueueSize     int    // capacity of the task queue
	MaxConcurrent int    // concurrent-request admission limit
	APIPrefix     string // optional prefix prepended to all API routes
}
// APIRequest is the OpenAI-compatible chat completion request body.
type APIRequest struct {
	Model       string       `json:"model"`
	Messages    []APIMessage `json:"messages"`
	Stream      bool         `json:"stream"`
	Temperature float64      `json:"temperature,omitempty"`
	MaxTokens   int          `json:"max_tokens,omitempty"`
}

// APIMessage is a single chat message. Content is either a plain string
// or an array of content items (e.g. text plus image_url entries).
type APIMessage struct {
	Role          string      `json:"role"`
	Content       interface{} `json:"content"`
	FeatureConfig interface{} `json:"feature_config,omitempty"` // e.g. {"thinking_enabled": true}
	ChatType      string      `json:"chat_type,omitempty"`      // e.g. "search", "t2i"
	Extra         interface{} `json:"extra,omitempty"`
}

// ContentItem is one element of an array-valued message content
// (text, image URL, or an uploaded-image reference).
type ContentItem struct {
	Type     string    `json:"type,omitempty"`
	Text     string    `json:"text,omitempty"`
	ImageURL *ImageURL `json:"image_url,omitempty"`
	Image    string    `json:"image,omitempty"`
}

// ImageURL wraps an image URL value.
type ImageURL struct {
	URL string `json:"url"`
}

// QwenRequest is the request body sent to the upstream Qwen chat API.
type QwenRequest struct {
	Model             string       `json:"model"`
	Messages          []APIMessage `json:"messages"`
	Stream            bool         `json:"stream"`
	ChatType          string       `json:"chat_type,omitempty"`
	ID                string       `json:"id,omitempty"`
	IncrementalOutput bool         `json:"incremental_output,omitempty"`
	Size              string       `json:"size,omitempty"` // image size for t2i requests
}

// QwenResponse models the relevant parts of an upstream Qwen response:
// message payloads (including image task IDs), streaming deltas, and usage.
type QwenResponse struct {
	Messages []struct {
		Role    string `json:"role"`
		Content string `json:"content"`
		Extra   struct {
			Wanx struct {
				TaskID string `json:"task_id"` // async image-generation task ID
			} `json:"wanx"`
		} `json:"extra"`
	} `json:"messages"`
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
}

// FileUploadResponse is the upstream response to a file upload.
type FileUploadResponse struct {
	ID string `json:"id"`
}

// TaskStatusResponse is the upstream response when polling a task.
type TaskStatusResponse struct {
	Content string `json:"content"`
}

// StreamChunk is one OpenAI-compatible SSE chunk emitted to the client.
type StreamChunk struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	Model   string `json:"model"`
	Choices []struct {
		Index int `json:"index"`
		Delta struct {
			Role    string `json:"role,omitempty"`
			Content string `json:"content,omitempty"`
		} `json:"delta"`
		FinishReason *string `json:"finish_reason,omitempty"`
	} `json:"choices"`
}

// CompletionResponse is the OpenAI-compatible non-streaming completion body.
type CompletionResponse struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	Model   string `json:"model"`
	Choices []struct {
		Index   int `json:"index"`
		Message struct {
			Role    string `json:"role"`
			Content string `json:"content"`
		} `json:"message"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
}

// ImagesResponse is the OpenAI-compatible image-generation response.
type ImagesResponse struct {
	Created int64      `json:"created"`
	Data    []ImageURL `json:"data"`
}

// ImagesRequest is the OpenAI-compatible image-generation request.
type ImagesRequest struct {
	Model  string `json:"model"`
	Prompt string `json:"prompt"`
	N      int    `json:"n"`
	Size   string `json:"size"`
}

// ModelData describes one model in the models listing.
type ModelData struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

// ModelsResponse is the OpenAI-compatible model list envelope.
type ModelsResponse struct {
	Object string      `json:"object"`
	Data   []ModelData `json:"data"`
}
// Package-level state shared across handlers and workers.
var (
	appConfig    *Config      // parsed command-line configuration
	logger       *log.Logger  // shared logger; writes guarded by logMutex
	logLevel     string       // active log level (one of the LogLevel* constants)
	logMutex     sync.Mutex   // serializes logger writes
	workerPool   *WorkerPool  // global task worker pool
	requestSem   *Semaphore   // admission-control semaphore
	requestCount uint64 = 0   // monotonically increasing request sequence (guarded by countMutex)
	countMutex   sync.Mutex
	// Performance metrics (all updated via sync/atomic).
	requestCounter   int64 // total requests received
	successCounter   int64 // requests that completed without context error
	errorCounter     int64 // errors logged via logError
	avgResponseTime  int64 // accumulated latency in ms (health endpoint divides by requestCounter)
	queuedRequests   int64 // tasks currently queued/in flight
	rejectedRequests int64 // requests rejected by admission control or full queue
)
// NewSemaphore returns a counting semaphore with the given capacity.
func NewSemaphore(size int) *Semaphore {
	s := &Semaphore{sem: make(chan struct{}, size)}
	return s
}
// Acquire takes one slot from the semaphore, blocking until one is free.
func (s *Semaphore) Acquire() {
	s.sem <- struct{}{}
}
// Release returns one slot to the semaphore.
func (s *Semaphore) Release() {
	<-s.sem
}
// TryAcquire attempts to take one slot without blocking. It reports
// whether a slot was acquired.
func (s *Semaphore) TryAcquire() bool {
	select {
	case s.sem <- struct{}{}:
		return true
	default:
	}
	return false
}
// NewWorkerPool builds a pool with the given worker count and queue
// capacity and immediately starts its workers.
func NewWorkerPool(workerCount int, queueSize int) *WorkerPool {
	p := &WorkerPool{
		taskQueue:       make(chan *Task, queueSize),
		workerCount:     workerCount,
		shutdownChannel: make(chan struct{}),
	}
	p.Start()
	return p
}
// Start launches the pool's worker goroutines. Each worker consumes
// tasks from the queue until the queue is closed or the shutdown channel
// is closed.
func (pool *WorkerPool) Start() {
	for id := 0; id < pool.workerCount; id++ {
		pool.wg.Add(1)
		go pool.runWorker(id)
	}
}

// runWorker is the loop executed by a single worker goroutine.
func (pool *WorkerPool) runWorker(id int) {
	defer pool.wg.Done()
	logInfo("Worker %d 已启动", id)
	for {
		select {
		case task, ok := <-pool.taskQueue:
			if !ok {
				// Queue closed: nothing more to do.
				logInfo("Worker %d 收到队列关闭信号,准备退出", id)
				return
			}
			pool.process(id, task)
		case <-pool.shutdownChannel:
			// Explicit shutdown signal.
			logInfo("Worker %d 收到关闭信号,准备退出", id)
			return
		}
	}
}

// process dispatches one task to the handler matching its path and then
// signals completion by closing the task's done channel.
func (pool *WorkerPool) process(id int, task *Task) {
	logDebug("Worker %d 处理任务 reqID:%s", id, task.reqID)
	switch task.path {
	case "/v1/models":
		handleModels(task.w, task.r)
	case "/v1/chat/completions":
		if task.isStream {
			handleStreamingRequest(task.w, task.r, task.apiReq, task.reqID)
		} else {
			handleNonStreamingRequest(task.w, task.r, task.apiReq, task.reqID)
		}
	case "/v1/images/generations":
		handleImageGenerations(task.w, task.r, task.apiReq, task.reqID)
	}
	close(task.done)
}
// SubmitTask enqueues a task without blocking. It reports whether the
// task was accepted; when the queue is full it returns false and an error.
func (pool *WorkerPool) SubmitTask(task *Task) (bool, error) {
	select {
	case pool.taskQueue <- task:
		return true, nil
	default:
	}
	return false, fmt.Errorf("任务队列已满")
}
// Shutdown stops the pool: it signals every worker to exit, waits for
// them, then closes the task queue.
//
// NOTE(review): workers exit on the shutdown signal without draining the
// queue, so any still-queued tasks are abandoned and their done channels
// never closed (waiters are unblocked only via request-context
// cancellation). Also, a concurrent SubmitTask after close(taskQueue)
// would panic — presumably safe because the HTTP server is shut down
// first in main; confirm that ordering is always upheld.
func (pool *WorkerPool) Shutdown() {
	logInfo("正在关闭工作池...")
	// Signal all workers to stop.
	close(pool.shutdownChannel)
	// Wait for every worker goroutine to exit.
	pool.wg.Wait()
	// Close the queue (no worker is receiving from it anymore).
	close(pool.taskQueue)
	logInfo("工作池已关闭")
}
// initLogger configures the package-level logger (stdout, "[QwenAPI] "
// prefix, standard date/time flags) and records the active log level
// consulted by the log* helpers.
func initLogger(level string) {
	logger = log.New(os.Stdout, "[QwenAPI] ", log.LstdFlags)
	logLevel = level
}
// logDebug logs only when the configured level is debug.
func logDebug(format string, v ...interface{}) {
	if logLevel != LogLevelDebug {
		return
	}
	logMutex.Lock()
	defer logMutex.Unlock()
	logger.Printf("[DEBUG] "+format, v...)
}
// logInfo logs when the configured level is debug or info.
func logInfo(format string, v ...interface{}) {
	switch logLevel {
	case LogLevelDebug, LogLevelInfo:
		logMutex.Lock()
		defer logMutex.Unlock()
		logger.Printf("[INFO] "+format, v...)
	}
}
// logWarn logs when the configured level is debug, info, or warn.
func logWarn(format string, v ...interface{}) {
	switch logLevel {
	case LogLevelDebug, LogLevelInfo, LogLevelWarn:
		logMutex.Lock()
		defer logMutex.Unlock()
		logger.Printf("[WARN] "+format, v...)
	}
}
// logError always logs, regardless of the configured level, and bumps
// the global error counter.
func logError(format string, v ...interface{}) {
	logMutex.Lock()
	defer logMutex.Unlock()
	logger.Printf("[ERROR] "+format, v...)
	atomic.AddInt64(&errorCounter, 1)
}
// parseFlags builds the runtime configuration from command-line flags.
// Enabling development mode forces the log level to debug.
func parseFlags() *Config {
	config := new(Config)
	flag.StringVar(&config.Port, "port", "8080", "Port to listen on")
	flag.StringVar(&config.Address, "address", "localhost", "Address to listen on")
	flag.StringVar(&config.LogLevel, "log-level", LogLevelInfo, "Log level (debug, info, warn, error)")
	flag.BoolVar(&config.DevMode, "dev", false, "Enable development mode with enhanced logging")
	flag.IntVar(&config.MaxRetries, "max-retries", 3, "Maximum number of retries for failed requests")
	flag.IntVar(&config.Timeout, "timeout", 300, "Request timeout in seconds")
	flag.BoolVar(&config.VerifySSL, "verify-ssl", true, "Verify SSL certificates")
	flag.IntVar(&config.WorkerCount, "workers", 50, "Number of worker goroutines in the pool")
	flag.IntVar(&config.QueueSize, "queue-size", 500, "Size of the task queue")
	flag.IntVar(&config.MaxConcurrent, "max-concurrent", 100, "Maximum number of concurrent requests")
	flag.StringVar(&config.APIPrefix, "api-prefix", "", "API prefix for all endpoints")
	flag.Parse()
	// Dev mode implies debug-level logging.
	if config.DevMode && config.LogLevel != LogLevelDebug {
		config.LogLevel = LogLevelDebug
		fmt.Println("开发模式已启用,日志级别设置为debug")
	}
	return config
}
// 从请求头中提取令牌
func extractToken(r *http.Request) (string, error) {
// 获取 Authorization 头部
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
return "", fmt.Errorf("missing Authorization header")
}
// 验证格式并提取令牌
if !strings.HasPrefix(authHeader, "Bearer ") {
return "", fmt.Errorf("invalid Authorization header format, must start with 'Bearer '")
}
// 提取令牌值
token := strings.TrimPrefix(authHeader, "Bearer ")
if token == "" {
return "", fmt.Errorf("empty token in Authorization header")
}
return token, nil
}
// 设置CORS头
func setCORSHeaders(w http.ResponseWriter) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
}
// 生成UUID
func generateUUID() string {
b := make([]byte, 16)
_, err := rand.Read(b)
if err != nil {
return fmt.Sprintf("%d", time.Now().UnixNano())
}
return fmt.Sprintf("%x-%x-%x-%x-%x",
b[0:4], b[4:6], b[6:8], b[8:10], b[10:])
}
// Shared upstream HTTP client, built once on first use. The original
// implementation constructed a fresh Transport and Client on every call,
// which defeated connection pooling (each request opened a new TCP/TLS
// connection despite MaxIdleConnsPerHost being set).
var (
	httpClientOnce   sync.Once
	sharedHTTPClient *http.Client
)

// getHTTPClient returns the process-wide HTTP client for upstream calls.
// The client honors the configured timeout and, when VerifySSL is false,
// skips TLS certificate verification. appConfig must be initialized
// before the first call (main does this before starting the server).
func getHTTPClient() *http.Client {
	httpClientOnce.Do(func() {
		tr := &http.Transport{
			MaxIdleConnsPerHost: 100,
			IdleConnTimeout:     90 * time.Second,
		}
		// Optionally disable certificate verification.
		if !appConfig.VerifySSL {
			tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
		}
		sharedHTTPClient = &http.Client{
			Timeout:   time.Duration(appConfig.Timeout) * time.Second,
			Transport: tr,
		}
	})
	return sharedHTTPClient
}
// 主入口函数
func main() {
// 解析配置
appConfig = parseFlags()
// 初始化日志
initLogger(appConfig.LogLevel)
logInfo("启动服务: 地址=%s, 端口=%s, 版本=%s, 日志级别=%s",
appConfig.Address, appConfig.Port, Version, appConfig.LogLevel)
// 创建工作池和信号量
workerPool = NewWorkerPool(appConfig.WorkerCount, appConfig.QueueSize)
requestSem = NewSemaphore(appConfig.MaxConcurrent)
logInfo("工作池已创建: %d个worker, 队列大小为%d", appConfig.WorkerCount, appConfig.QueueSize)
// 配置更高的并发处理能力
http.DefaultTransport.(*http.Transport).MaxIdleConnsPerHost = 100
http.DefaultTransport.(*http.Transport).MaxIdleConns = 100
http.DefaultTransport.(*http.Transport).IdleConnTimeout = 90 * time.Second
// 创建自定义服务器,支持更高并发
server := &http.Server{
Addr: appConfig.Address + ":" + appConfig.Port,
ReadTimeout: time.Duration(appConfig.Timeout) * time.Second,
WriteTimeout: time.Duration(appConfig.Timeout) * time.Second,
IdleTimeout: 120 * time.Second,
Handler: nil, // 使用默认的ServeMux
}
// API路径前缀
apiPrefix := appConfig.APIPrefix
// 创建处理器
http.HandleFunc(apiPrefix+"/v1/models", func(w http.ResponseWriter, r *http.Request) {
setCORSHeaders(w)
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
// 计数
countMutex.Lock()
requestCount++
currentCount := requestCount
countMutex.Unlock()
reqID := generateRequestID()
logInfo("[reqID:%s] 收到模型列表请求 #%d", reqID, currentCount)
// 请求计数
atomic.AddInt64(&requestCounter, 1)
startTime := time.Now()
// 创建任务
task := &Task{
r: r,
w: w,
done: make(chan struct{}),
reqID: reqID,
path: "/v1/models",
}
// 尝试获取信号量
if !requestSem.TryAcquire() {
// 请求数量超过限制
atomic.AddInt64(&rejectedRequests, 1)
logWarn("[reqID:%s] 请求被拒绝: 当前并发请求数已达上限", reqID)
w.Header().Set("Retry-After", "30")
http.Error(w, "Server is busy, please try again later", http.StatusServiceUnavailable)
return
}
// 释放信号量(在函数返回时)
defer requestSem.Release()
// 添加到任务队列
atomic.AddInt64(&queuedRequests, 1)
submitted, err := workerPool.SubmitTask(task)
if !submitted {
atomic.AddInt64(&queuedRequests, -1)
atomic.AddInt64(&rejectedRequests, 1)
logError("[reqID:%s] 提交任务失败: %v", reqID, err)
w.Header().Set("Retry-After", "60")
http.Error(w, "Server queue is full, please try again later", http.StatusServiceUnavailable)
return
}
logInfo("[reqID:%s] 任务已提交到队列", reqID)
// 等待任务完成或超时
select {
case <-task.done:
// 任务已完成
logInfo("[reqID:%s] 任务已完成", reqID)
case <-r.Context().Done():
// 请求被取消或超时
logWarn("[reqID:%s] 请求被取消或超时", reqID)
}
// 请求处理完成,更新指标
atomic.AddInt64(&queuedRequests, -1)
elapsed := time.Since(startTime).Milliseconds()
// 更新平均响应时间
atomic.AddInt64(&avgResponseTime, elapsed)
if r.Context().Err() == nil {
// 成功计数增加
atomic.AddInt64(&successCounter, 1)
logInfo("[reqID:%s] 请求处理成功,耗时: %dms", reqID, elapsed)
} else {
logError("[reqID:%s] 请求处理失败: %v, 耗时: %dms", reqID, r.Context().Err(), elapsed)
}
})
http.HandleFunc(apiPrefix+"/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
setCORSHeaders(w)
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
// 计数器增加
countMutex.Lock()
requestCount++
currentCount := requestCount
countMutex.Unlock()
reqID := generateRequestID()
logInfo("[reqID:%s] 收到新请求 #%d", reqID, currentCount)
// 请求计数
atomic.AddInt64(&requestCounter, 1)
startTime := time.Now()
// 尝试获取信号量
if !requestSem.TryAcquire() {
// 请求数量超过限制
atomic.AddInt64(&rejectedRequests, 1)
logWarn("[reqID:%s] 请求 #%d 被拒绝: 当前并发请求数已达上限", reqID, currentCount)
w.Header().Set("Retry-After", "30")
http.Error(w, "Server is busy, please try again later", http.StatusServiceUnavailable)
return
}
// 释放信号量(在函数返回时)
defer requestSem.Release()
// 解析请求体
var apiReq APIRequest
if err := json.NewDecoder(r.Body).Decode(&apiReq); err != nil {
logError("[reqID:%s] 解析请求失败: %v", reqID, err)
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
// 创建任务
task := &Task{
r: r,
w: w,
done: make(chan struct{}),
reqID: reqID,
isStream: apiReq.Stream,
apiReq: apiReq,
path: "/v1/chat/completions",
}
// 添加到任务队列
atomic.AddInt64(&queuedRequests, 1)
submitted, err := workerPool.SubmitTask(task)
if !submitted {
atomic.AddInt64(&queuedRequests, -1)
atomic.AddInt64(&rejectedRequests, 1)
logError("[reqID:%s] 提交任务失败: %v", reqID, err)
w.Header().Set("Retry-After", "60")
http.Error(w, "Server queue is full, please try again later", http.StatusServiceUnavailable)
return
}
logInfo("[reqID:%s] 任务已提交到队列", reqID)
// 等待任务完成或超时
select {
case <-task.done:
// 任务已完成
logInfo("[reqID:%s] 任务已完成", reqID)
case <-r.Context().Done():
// 请求被取消或超时
logWarn("[reqID:%s] 请求被取消或超时", reqID)
}
// 请求处理完成,更新指标
atomic.AddInt64(&queuedRequests, -1)
elapsed := time.Since(startTime).Milliseconds()
// 更新平均响应时间
atomic.AddInt64(&avgResponseTime, elapsed)
if r.Context().Err() == nil {
// 成功计数增加
atomic.AddInt64(&successCounter, 1)
logInfo("[reqID:%s] 请求处理成功,耗时: %dms", reqID, elapsed)
} else {
logError("[reqID:%s] 请求处理失败: %v, 耗时: %dms", reqID, r.Context().Err(), elapsed)
}
})
http.HandleFunc(apiPrefix+"/v1/images/generations", func(w http.ResponseWriter, r *http.Request) {
setCORSHeaders(w)
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
// 计数器增加
countMutex.Lock()
requestCount++
currentCount := requestCount
countMutex.Unlock()
reqID := generateRequestID()
logInfo("[reqID:%s] 收到图像生成请求 #%d", reqID, currentCount)
// 请求计数
atomic.AddInt64(&requestCounter, 1)
startTime := time.Now()
// 尝试获取信号量
if !requestSem.TryAcquire() {
// 请求数量超过限制
atomic.AddInt64(&rejectedRequests, 1)
logWarn("[reqID:%s] 请求 #%d 被拒绝: 当前并发请求数已达上限", reqID, currentCount)
w.Header().Set("Retry-After", "30")
http.Error(w, "Server is busy, please try again later", http.StatusServiceUnavailable)
return
}
// 释放信号量(在函数返回时)
defer requestSem.Release()
// 解析请求体
var apiReq APIRequest
if err := json.NewDecoder(r.Body).Decode(&apiReq); err != nil {
logError("[reqID:%s] 解析请求失败: %v", reqID, err)
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
// 创建任务
task := &Task{
r: r,
w: w,
done: make(chan struct{}),
reqID: reqID,
apiReq: apiReq,
path: "/v1/images/generations",
}
// 添加到任务队列
atomic.AddInt64(&queuedRequests, 1)
submitted, err := workerPool.SubmitTask(task)
if !submitted {
atomic.AddInt64(&queuedRequests, -1)
atomic.AddInt64(&rejectedRequests, 1)
logError("[reqID:%s] 提交任务失败: %v", reqID, err)
w.Header().Set("Retry-After", "60")
http.Error(w, "Server queue is full, please try again later", http.StatusServiceUnavailable)
return
}
logInfo("[reqID:%s] 任务已提交到队列", reqID)
// 等待任务完成或超时
select {
case <-task.done:
// 任务已完成
logInfo("[reqID:%s] 任务已完成", reqID)
case <-r.Context().Done():
// 请求被取消或超时
logWarn("[reqID:%s] 请求被取消或超时", reqID)
}
// 请求处理完成,更新指标
atomic.AddInt64(&queuedRequests, -1)
elapsed := time.Since(startTime).Milliseconds()
// 更新平均响应时间
atomic.AddInt64(&avgResponseTime, elapsed)
if r.Context().Err() == nil {
// 成功计数增加
atomic.AddInt64(&successCounter, 1)
logInfo("[reqID:%s] 请求处理成功,耗时: %dms", reqID, elapsed)
} else {
logError("[reqID:%s] 请求处理失败: %v, 耗时: %dms", reqID, r.Context().Err(), elapsed)
}
})
// 添加健康检查端点
http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
setCORSHeaders(w)
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
// 获取各种计数器的值
reqCount := atomic.LoadInt64(&requestCounter)
succCount := atomic.LoadInt64(&successCounter)
errCount := atomic.LoadInt64(&errorCounter)
queuedCount := atomic.LoadInt64(&queuedRequests)
rejectedCount := atomic.LoadInt64(&rejectedRequests)
// 计算平均响应时间
var avgTime int64 = 0
if reqCount > 0 {
avgTime = atomic.LoadInt64(&avgResponseTime) / reqCount
}
// 构建响应
stats := map[string]interface{}{
"status": "ok",
"version": Version,
"requests": reqCount,
"success": succCount,
"errors": errCount,
"queued": queuedCount,
"rejected": rejectedCount,
"avg_time_ms": avgTime,
"worker_count": workerPool.workerCount,
"queue_size": len(workerPool.taskQueue),
"queue_capacity": cap(workerPool.taskQueue),
"queue_percent": float64(len(workerPool.taskQueue)) / float64(cap(workerPool.taskQueue)) * 100,
"concurrent_limit": appConfig.MaxConcurrent,
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(stats)
})
// 创建停止通道
stop := make(chan os.Signal, 1)
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
// 在goroutine中启动服务器
go func() {
logInfo("Starting proxy server on %s", server.Addr)
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logError("Failed to start server: %v", err)
os.Exit(1)
}
}()
// 等待停止信号
<-stop
// 创建上下文用于优雅关闭
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// 优雅关闭服务器
logInfo("Server is shutting down...")
if err := server.Shutdown(ctx); err != nil {
logError("Server shutdown failed: %v", err)
}
// 关闭工作池
workerPool.Shutdown()
logInfo("Server gracefully stopped")
}
// generateRequestID returns a per-request correlation ID: the current
// Unix nanosecond timestamp rendered in hex. IDs from requests arriving
// in the same nanosecond can collide, which is acceptable for logging.
func generateRequestID() string {
	return fmt.Sprintf("%x", time.Now().UnixNano())
}
// handleModels serves GET /v1/models. It fetches the upstream Qwen model
// list, expands every model with the supported variant suffixes, and
// falls back to the built-in default list whenever authentication or the
// upstream call fails.
func handleModels(w http.ResponseWriter, r *http.Request) {
	logInfo("处理模型列表请求")
	// The upstream call is authenticated with the caller's own token.
	authToken, err := extractToken(r)
	if err != nil {
		logWarn("提取token失败: %v", err)
		returnDefaultModels(w)
		return
	}
	client := getHTTPClient()
	// Consistency fix: every other upstream call in this file uses
	// NewRequestWithContext so the request is cancelled with the client;
	// this one previously used plain NewRequest.
	req, err := http.NewRequestWithContext(r.Context(), "GET", ModelsURL, nil)
	if err != nil {
		logError("创建请求失败: %v", err)
		returnDefaultModels(w)
		return
	}
	req.Header.Set("Authorization", "Bearer "+authToken)
	req.Header.Set("User-Agent", "Mozilla/5.0")
	resp, err := client.Do(req)
	if err != nil {
		logError("请求模型列表失败: %v", err)
		returnDefaultModels(w)
		return
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		logError("获取模型列表返回非200状态码: %d", resp.StatusCode)
		returnDefaultModels(w)
		return
	}
	// Only the model IDs are needed from the upstream payload.
	var qwenResp struct {
		Data []struct {
			ID string `json:"id"`
		} `json:"data"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&qwenResp); err != nil {
		logError("解析模型列表响应失败: %v", err)
		returnDefaultModels(w)
		return
	}
	models := make([]string, 0, len(qwenResp.Data))
	for _, model := range qwenResp.Data {
		models = append(models, model.ID)
	}
	// Empty upstream list: behave as if the call failed.
	if len(models) == 0 {
		logWarn("未获取到模型,使用默认列表")
		returnDefaultModels(w)
		return
	}
	// Advertise every base model together with its feature variants.
	expandedModels := make([]ModelData, 0, len(models)*len(ModelSuffixes))
	for _, model := range models {
		for _, suffix := range ModelSuffixes {
			expandedModels = append(expandedModels, ModelData{
				ID:      model + suffix,
				Object:  "model",
				Created: time.Now().Unix(),
				OwnedBy: "qwen",
			})
		}
	}
	modelsResp := ModelsResponse{
		Object: "list",
		Data:   expandedModels,
	}
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(modelsResp)
}
// returnDefaultModels writes the built-in model catalogue, expanded with
// every variant suffix, as an OpenAI-style model list.
func returnDefaultModels(w http.ResponseWriter) {
	data := make([]ModelData, 0, len(DefaultModels)*len(ModelSuffixes))
	for _, base := range DefaultModels {
		for _, suffix := range ModelSuffixes {
			entry := ModelData{
				ID:      base + suffix,
				Object:  "model",
				Created: time.Now().Unix(),
				OwnedBy: "qwen",
			}
			data = append(data, entry)
		}
	}
	resp := ModelsResponse{Object: "list", Data: data}
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(resp)
}
// handleStreamingRequest proxies an OpenAI-style streaming chat request
// to the Qwen API and relays the upstream SSE stream back to the client
// as OpenAI-compatible chunks. Model-name suffixes select features:
// "-draw" delegates to the image pipeline, "-thinking" enables thinking
// on the last message, "-search" marks the conversation for search.
//
// Fix vs. the original: the `dataRegex` local (compiled with
// regexp.MustCompile but never referenced) has been removed — an unused
// local is a compile error in Go, and the per-call compile was wasted
// work regardless; line handling below uses strings.HasPrefix only.
func handleStreamingRequest(w http.ResponseWriter, r *http.Request, apiReq APIRequest, reqID string) {
	logInfo("[reqID:%s] 处理流式请求", reqID)
	// Authenticate with the caller's own bearer token.
	authToken, err := extractToken(r)
	if err != nil {
		logError("[reqID:%s] 提取token失败: %v", reqID, err)
		http.Error(w, "无效的认证信息", http.StatusUnauthorized)
		return
	}
	// Reject empty conversations.
	if len(apiReq.Messages) == 0 {
		logError("[reqID:%s] 消息为空", reqID)
		http.Error(w, "消息为空", http.StatusBadRequest)
		return
	}
	// Resolve model name (with default) and chat type.
	modelName := "qwen-turbo-latest"
	if apiReq.Model != "" {
		modelName = apiReq.Model
	}
	chatType := "t2t"
	// "-draw" models are image generation; delegate entirely.
	if strings.Contains(modelName, "-draw") {
		handleDrawRequest(w, r, apiReq, reqID, authToken)
		return
	}
	// "-thinking": strip the suffix, enable thinking on the last message.
	if strings.Contains(modelName, "-thinking") {
		modelName = strings.Replace(modelName, "-thinking", "", 1)
		lastMsgIdx := len(apiReq.Messages) - 1
		if lastMsgIdx >= 0 {
			apiReq.Messages[lastMsgIdx].FeatureConfig = map[string]interface{}{
				"thinking_enabled": true,
			}
		}
	}
	// "-search": strip the suffix, mark the last message for search.
	if strings.Contains(modelName, "-search") {
		modelName = strings.Replace(modelName, "-search", "", 1)
		chatType = "search"
		lastMsgIdx := len(apiReq.Messages) - 1
		if lastMsgIdx >= 0 {
			apiReq.Messages[lastMsgIdx].ChatType = "search"
		}
	}
	// If the last message's content is an item array containing an
	// image_url entry, upload that image and swap the entry for an
	// uploaded-image reference. Only the first image_url is handled
	// (the loop breaks after one successful upload).
	lastMsgIdx := len(apiReq.Messages) - 1
	if lastMsgIdx >= 0 {
		lastMsg := apiReq.Messages[lastMsgIdx]
		contentArray, ok := lastMsg.Content.([]interface{})
		if ok {
			for i, item := range contentArray {
				itemMap, isMap := item.(map[string]interface{})
				if !isMap {
					continue
				}
				if imageURL, hasImageURL := itemMap["image_url"]; hasImageURL {
					imageURLMap, isMap := imageURL.(map[string]interface{})
					if !isMap {
						continue
					}
					url, hasURL := imageURLMap["url"].(string)
					if !hasURL {
						continue
					}
					imageID, uploadErr := uploadImage(url, authToken)
					if uploadErr != nil {
						logError("[reqID:%s] 上传图像失败: %v", reqID, uploadErr)
						continue
					}
					// Copy-on-write so the original slice is untouched.
					contentArrayCopy := make([]interface{}, len(contentArray))
					copy(contentArrayCopy, contentArray)
					contentArrayCopy[i] = map[string]interface{}{
						"type":  "image",
						"image": imageID,
					}
					apiReq.Messages[lastMsgIdx].Content = contentArrayCopy
					break
				}
			}
		}
	}
	// Build and serialize the upstream request.
	qwenReq := QwenRequest{
		Model:    modelName,
		Messages: apiReq.Messages,
		Stream:   true,
		ChatType: chatType,
		ID:       generateUUID(),
	}
	reqData, err := json.Marshal(qwenReq)
	if err != nil {
		logError("[reqID:%s] 序列化请求失败: %v", reqID, err)
		http.Error(w, "内部服务器错误", http.StatusInternalServerError)
		return
	}
	req, err := http.NewRequestWithContext(r.Context(), "POST", TargetURL, bytes.NewBuffer(reqData))
	if err != nil {
		logError("[reqID:%s] 创建请求失败: %v", reqID, err)
		http.Error(w, "内部服务器错误", http.StatusInternalServerError)
		return
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+authToken)
	req.Header.Set("User-Agent", "Mozilla/5.0")
	client := getHTTPClient()
	resp, err := client.Do(req)
	if err != nil {
		logError("[reqID:%s] 发送请求失败: %v", reqID, err)
		http.Error(w, "连接到API失败", http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		bodyBytes, _ := io.ReadAll(resp.Body)
		logError("[reqID:%s] API返回非200状态码: %d, 响应: %s", reqID, resp.StatusCode, string(bodyBytes))
		http.Error(w, fmt.Sprintf("API错误,状态码: %d", resp.StatusCode), resp.StatusCode)
		return
	}
	// Begin the SSE response to the client.
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	respID := fmt.Sprintf("chatcmpl-%s", generateUUID())
	createdTime := time.Now().Unix()
	reader := bufio.NewReaderSize(resp.Body, 16384)
	flusher, ok := w.(http.Flusher)
	if !ok {
		logError("[reqID:%s] 流式传输不支持", reqID)
		http.Error(w, "流式传输不支持", http.StatusInternalServerError)
		return
	}
	// First chunk announces the assistant role.
	roleChunk := createRoleChunk(respID, createdTime, modelName)
	w.Write([]byte("data: " + string(roleChunk) + "\n\n"))
	flusher.Flush()
	// previousContent is used to de-duplicate upstream deltas that
	// resend cumulative text (see the prefix check below).
	previousContent := ""
	// buffer accumulates raw bytes until full lines are available;
	// pendingContent accumulates everything forwarded to the client.
	buffer := ""
	pendingContent := ""
	for {
		// Bail out promptly if the client has gone away.
		select {
		case <-r.Context().Done():
			logWarn("[reqID:%s] 请求超时或被客户端取消", reqID)
			return
		default:
			// keep reading
		}
		chunk := make([]byte, 4096)
		n, err := reader.Read(chunk)
		if err != nil {
			if err != io.EOF {
				logError("[reqID:%s] 读取响应出错: %v", reqID, err)
				return
			}
			break
		}
		buffer += string(chunk[:n])
		// Split on newlines; the final element may be a partial line and
		// is carried over to the next read.
		lines := strings.Split(buffer, "\n")
		if len(lines) > 0 {
			buffer = lines[len(lines)-1]
		}
		for i := 0; i < len(lines)-1; i++ {
			line := lines[i]
			if !strings.HasPrefix(line, "data: ") {
				continue
			}
			dataStr := strings.TrimPrefix(line, "data: ")
			// Upstream end-of-stream marker is forwarded verbatim.
			if dataStr == "[DONE]" {
				logDebug("[reqID:%s] 收到[DONE]消息", reqID)
				w.Write([]byte("data: [DONE]\n\n"))
				flusher.Flush()
				continue
			}
			var qwenResp QwenResponse
			if err := json.Unmarshal([]byte(dataStr), &qwenResp); err != nil {
				logWarn("[reqID:%s] 解析JSON失败: %v, data: %s", reqID, err, dataStr)
				continue
			}
			for _, choice := range qwenResp.Choices {
				content := choice.Delta.Content
				// De-dup: when a delta restates everything sent so far,
				// forward only the new suffix.
				if previousContent != "" && strings.HasPrefix(content, previousContent) {
					newContent := content[len(previousContent):]
					if newContent != "" {
						contentChunk := createContentChunk(respID, createdTime, modelName, newContent)
						w.Write([]byte("data: " + string(contentChunk) + "\n\n"))
						flusher.Flush()
						pendingContent += newContent
					}
				} else if content != "" {
					// Genuinely incremental delta: forward as-is.
					contentChunk := createContentChunk(respID, createdTime, modelName, content)
					w.Write([]byte("data: " + string(contentChunk) + "\n\n"))
					flusher.Flush()
					pendingContent += content
				}
				if content != "" {
					previousContent = content
				}
				if choice.FinishReason != "" {
					finishReason := choice.FinishReason
					doneChunk := createDoneChunk(respID, createdTime, modelName, finishReason)
					w.Write([]byte("data: " + string(doneChunk) + "\n\n"))
					flusher.Flush()
				}
			}
		}
	}
	if pendingContent != "" {
		logInfo("[reqID:%s] 流处理完成,累积内容长度: %d", reqID, len(pendingContent))
	}
	// Always close the stream with a finish chunk and [DONE]; clients
	// may therefore see these twice when upstream already terminated.
	finishReason := "stop"
	doneChunk := createDoneChunk(respID, createdTime, modelName, finishReason)
	w.Write([]byte("data: " + string(doneChunk) + "\n\n"))
	w.Write([]byte("data: [DONE]\n\n"))
	flusher.Flush()
}
// handleNonStreamingRequest serves a non-streaming chat completion. It
// still calls the upstream Qwen API in streaming mode, then aggregates
// the whole stream (via extractFullContentFromStream) into a single
// OpenAI-compatible CompletionResponse. Model-name suffix handling
// ("-draw", "-thinking", "-search") mirrors the streaming handler.
func handleNonStreamingRequest(w http.ResponseWriter, r *http.Request, apiReq APIRequest, reqID string) {
	logInfo("[reqID:%s] 处理非流式请求", reqID)
	// Authenticate with the caller's own bearer token.
	authToken, err := extractToken(r)
	if err != nil {
		logError("[reqID:%s] 提取token失败: %v", reqID, err)
		http.Error(w, "无效的认证信息", http.StatusUnauthorized)
		return
	}
	// Reject empty conversations.
	if len(apiReq.Messages) == 0 {
		logError("[reqID:%s] 消息为空", reqID)
		http.Error(w, "消息为空", http.StatusBadRequest)
		return
	}
	// Resolve model name (with default) and chat type.
	modelName := "qwen-turbo-latest"
	if apiReq.Model != "" {
		modelName = apiReq.Model
	}
	chatType := "t2t"
	// "-draw" models are image generation; delegate entirely.
	if strings.Contains(modelName, "-draw") {
		handleDrawRequest(w, r, apiReq, reqID, authToken)
		return
	}
	// "-thinking": strip the suffix, enable thinking on the last message.
	if strings.Contains(modelName, "-thinking") {
		modelName = strings.Replace(modelName, "-thinking", "", 1)
		lastMsgIdx := len(apiReq.Messages) - 1
		if lastMsgIdx >= 0 {
			apiReq.Messages[lastMsgIdx].FeatureConfig = map[string]interface{}{
				"thinking_enabled": true,
			}
		}
	}
	// "-search": strip the suffix, mark the last message for search.
	if strings.Contains(modelName, "-search") {
		modelName = strings.Replace(modelName, "-search", "", 1)
		chatType = "search"
		lastMsgIdx := len(apiReq.Messages) - 1
		if lastMsgIdx >= 0 {
			apiReq.Messages[lastMsgIdx].ChatType = "search"
		}
	}
	// If the last message's content is an item array containing an
	// image_url entry, upload that image and swap the entry for an
	// uploaded-image reference. Only the first image_url is handled
	// (the loop breaks after one successful upload).
	lastMsgIdx := len(apiReq.Messages) - 1
	if lastMsgIdx >= 0 {
		lastMsg := apiReq.Messages[lastMsgIdx]
		contentArray, ok := lastMsg.Content.([]interface{})
		if ok {
			for i, item := range contentArray {
				itemMap, isMap := item.(map[string]interface{})
				if !isMap {
					continue
				}
				if imageURL, hasImageURL := itemMap["image_url"]; hasImageURL {
					imageURLMap, isMap := imageURL.(map[string]interface{})
					if !isMap {
						continue
					}
					url, hasURL := imageURLMap["url"].(string)
					if !hasURL {
						continue
					}
					imageID, uploadErr := uploadImage(url, authToken)
					if uploadErr != nil {
						logError("[reqID:%s] 上传图像失败: %v", reqID, uploadErr)
						continue
					}
					// Copy-on-write so the original slice is untouched.
					contentArrayCopy := make([]interface{}, len(contentArray))
					copy(contentArrayCopy, contentArray)
					contentArrayCopy[i] = map[string]interface{}{
						"type":  "image",
						"image": imageID,
					}
					apiReq.Messages[lastMsgIdx].Content = contentArrayCopy
					break
				}
			}
		}
	}
	// Upstream request uses the streaming API even for this non-streaming
	// endpoint; the stream is aggregated below.
	qwenReq := QwenRequest{
		Model:    modelName,
		Messages: apiReq.Messages,
		Stream:   true, // intentionally streaming upstream
		ChatType: chatType,
		ID:       generateUUID(),
	}
	reqData, err := json.Marshal(qwenReq)
	if err != nil {
		logError("[reqID:%s] 序列化请求失败: %v", reqID, err)
		http.Error(w, "内部服务器错误", http.StatusInternalServerError)
		return
	}
	req, err := http.NewRequestWithContext(r.Context(), "POST", TargetURL, bytes.NewBuffer(reqData))
	if err != nil {
		logError("[reqID:%s] 创建请求失败: %v", reqID, err)
		http.Error(w, "内部服务器错误", http.StatusInternalServerError)
		return
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+authToken)
	req.Header.Set("User-Agent", "Mozilla/5.0")
	client := getHTTPClient()
	resp, err := client.Do(req)
	if err != nil {
		logError("[reqID:%s] 发送请求失败: %v", reqID, err)
		http.Error(w, "连接到API失败", http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		bodyBytes, _ := io.ReadAll(resp.Body)
		logError("[reqID:%s] API返回非200状态码: %d, 响应: %s", reqID, resp.StatusCode, string(bodyBytes))
		http.Error(w, fmt.Sprintf("API错误,状态码: %d", resp.StatusCode), resp.StatusCode)
		return
	}
	// Drain the stream into one string (helper defined elsewhere in file).
	fullContent, err := extractFullContentFromStream(resp.Body, reqID)
	if err != nil {
		logError("[reqID:%s] 提取内容失败: %v", reqID, err)
		http.Error(w, "解析响应失败", http.StatusInternalServerError)
		return
	}
	// Token counts are rough estimates (characters/4), not upstream usage.
	completionResponse := CompletionResponse{
		ID:      fmt.Sprintf("chatcmpl-%s", generateUUID()),
		Object:  "chat.completion",
		Created: time.Now().Unix(),
		Model:   modelName,
		Choices: []struct {
			Index   int `json:"index"`
			Message struct {
				Role    string `json:"role"`
				Content string `json:"content"`
			} `json:"message"`
			FinishReason string `json:"finish_reason"`
		}{
			{
				Index: 0,
				Message: struct {
					Role    string `json:"role"`
					Content string `json:"content"`
				}{
					Role:    "assistant",
					Content: fullContent,
				},
				FinishReason: "stop",
			},
		},
		Usage: struct {
			PromptTokens     int `json:"prompt_tokens"`
			CompletionTokens int `json:"completion_tokens"`
			TotalTokens      int `json:"total_tokens"`
		}{
			PromptTokens:     estimateTokens(apiReq.Messages),
			CompletionTokens: len(fullContent) / 4,
			TotalTokens:      estimateTokens(apiReq.Messages) + len(fullContent)/4,
		},
	}
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(completionResponse)
}
// handleImageGenerations serves OpenAI-compatible image generation requests.
// Flow: extract bearer token -> parse the request -> submit a t2i task to the
// Qwen chat endpoint -> poll the task status endpoint -> return image URLs.
//
// Fixes over the previous implementation:
//   - the inter-poll wait is context-aware, so a client cancellation is
//     honored immediately instead of after a blocking time.Sleep;
//   - the status endpoint's HTTP status code is checked before decoding;
//   - no extra 6-second sleep after the final polling attempt;
//   - error-path poll bodies are drained so the connection can be reused.
func handleImageGenerations(w http.ResponseWriter, r *http.Request, apiReq APIRequest, reqID string) {
	logInfo("[reqID:%s] 处理图像生成请求", reqID)
	// Extract the upstream auth token from the incoming request.
	authToken, err := extractToken(r)
	if err != nil {
		logError("[reqID:%s] 提取token失败: %v", reqID, err)
		http.Error(w, "无效的认证信息", http.StatusUnauthorized)
		return
	}
	// Parse the image generation request body.
	var imgReq ImagesRequest
	if err := json.NewDecoder(r.Body).Decode(&imgReq); err != nil {
		logError("[reqID:%s] 解析图像请求失败: %v", reqID, err)
		http.Error(w, "无效的请求体", http.StatusBadRequest)
		return
	}
	// Apply defaults for missing fields.
	if imgReq.Model == "" {
		imgReq.Model = "qwen-max-latest-draw"
	}
	if imgReq.Size == "" {
		imgReq.Size = "1024*1024"
	}
	if imgReq.N <= 0 {
		imgReq.N = 1
	}
	// Strip the variant suffixes to obtain the bare upstream model name.
	modelName := strings.Replace(imgReq.Model, "-draw", "", 1)
	modelName = strings.Replace(modelName, "-thinking", "", 1)
	modelName = strings.Replace(modelName, "-search", "", 1)
	// Build the t2i task submission payload.
	qwenReq := QwenRequest{
		Stream:            false,
		IncrementalOutput: true,
		ChatType:          "t2i",
		Model:             modelName,
		Messages: []APIMessage{
			{
				Role:     "user",
				Content:  imgReq.Prompt,
				ChatType: "t2i",
				Extra:    map[string]interface{}{},
				FeatureConfig: map[string]interface{}{
					"thinking_enabled": false,
				},
			},
		},
		ID:   generateUUID(),
		Size: imgReq.Size,
	}
	reqData, err := json.Marshal(qwenReq)
	if err != nil {
		logError("[reqID:%s] 序列化请求失败: %v", reqID, err)
		http.Error(w, "内部服务器错误", http.StatusInternalServerError)
		return
	}
	req, err := http.NewRequestWithContext(r.Context(), "POST", TargetURL, bytes.NewBuffer(reqData))
	if err != nil {
		logError("[reqID:%s] 创建请求失败: %v", reqID, err)
		http.Error(w, "内部服务器错误", http.StatusInternalServerError)
		return
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+authToken)
	req.Header.Set("User-Agent", "Mozilla/5.0")
	client := getHTTPClient()
	resp, err := client.Do(req)
	if err != nil {
		logError("[reqID:%s] 发送请求失败: %v", reqID, err)
		http.Error(w, "连接到API失败", http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		bodyBytes, _ := io.ReadAll(resp.Body)
		logError("[reqID:%s] API返回非200状态码: %d, 响应: %s", reqID, resp.StatusCode, string(bodyBytes))
		http.Error(w, fmt.Sprintf("API错误,状态码: %d", resp.StatusCode), resp.StatusCode)
		return
	}
	// Decode the submission response and pull out the async task ID.
	var qwenResp QwenResponse
	if err := json.NewDecoder(resp.Body).Decode(&qwenResp); err != nil {
		logError("[reqID:%s] 解析响应失败: %v", reqID, err)
		http.Error(w, "解析响应失败", http.StatusInternalServerError)
		return
	}
	taskID := ""
	for _, msg := range qwenResp.Messages {
		if msg.Role == "assistant" && msg.Extra.Wanx.TaskID != "" {
			taskID = msg.Extra.Wanx.TaskID
			break
		}
	}
	if taskID == "" {
		logError("[reqID:%s] 无法获取图像生成任务ID", reqID)
		http.Error(w, "无法获取图像生成任务ID", http.StatusInternalServerError)
		return
	}
	// Poll the task status endpoint: up to 30 attempts, 6 seconds apart.
	var imageURL string
	for attempt := 0; attempt < 30; attempt++ {
		if attempt > 0 {
			// Wait for the next poll; abort promptly if the client cancels.
			select {
			case <-r.Context().Done():
				logWarn("[reqID:%s] 请求超时或被客户端取消", reqID)
				http.Error(w, "请求超时", http.StatusGatewayTimeout)
				return
			case <-time.After(6 * time.Second):
			}
		}
		statusReq, err := http.NewRequestWithContext(r.Context(), "GET", TasksURL+taskID, nil)
		if err != nil {
			logError("[reqID:%s] 创建状态请求失败: %v", reqID, err)
			continue
		}
		statusReq.Header.Set("Authorization", "Bearer "+authToken)
		statusReq.Header.Set("User-Agent", "Mozilla/5.0")
		statusResp, err := client.Do(statusReq)
		if err != nil {
			logError("[reqID:%s] 发送状态请求失败: %v", reqID, err)
			continue
		}
		if statusResp.StatusCode != http.StatusOK {
			// Drain and close so the transport can reuse the connection.
			io.Copy(io.Discard, statusResp.Body)
			statusResp.Body.Close()
			logError("[reqID:%s] 状态接口返回非200状态码: %d", reqID, statusResp.StatusCode)
			continue
		}
		var statusData TaskStatusResponse
		decodeErr := json.NewDecoder(statusResp.Body).Decode(&statusData)
		statusResp.Body.Close()
		if decodeErr != nil {
			logError("[reqID:%s] 解析状态响应失败: %v", reqID, decodeErr)
			continue
		}
		// A non-empty Content field means the image is ready.
		if statusData.Content != "" {
			imageURL = statusData.Content
			break
		}
	}
	if imageURL == "" {
		logError("[reqID:%s] 图像生成超时", reqID)
		http.Error(w, "图像生成超时", http.StatusGatewayTimeout)
		return
	}
	// The upstream task yields a single URL; replicate it N times to satisfy
	// the OpenAI response shape.
	images := make([]ImageURL, imgReq.N)
	for i := 0; i < imgReq.N; i++ {
		images[i] = ImageURL{URL: imageURL}
	}
	imgResp := ImagesResponse{
		Created: time.Now().Unix(),
		Data:    images,
	}
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(imgResp)
}
// handleDrawRequest serves chat-completion requests addressed to a "-draw"
// model variant by turning the last user message into a t2i task and
// returning the generated image as a Markdown-embedded chat response.
//
// Fixes over the previous implementation:
//   - the prompt is also extracted from multimodal array content (text parts
//     are concatenated) instead of failing with 400;
//   - the inter-poll wait is context-aware, so a client cancellation is
//     honored immediately instead of after a blocking time.Sleep;
//   - the status endpoint's HTTP status code is checked before decoding;
//   - no extra 6-second sleep after the final polling attempt;
//   - error-path poll bodies are drained so the connection can be reused.
func handleDrawRequest(w http.ResponseWriter, r *http.Request, apiReq APIRequest, reqID string, authToken string) {
	logInfo("[reqID:%s] 处理绘图请求", reqID)
	// Extract the drawing prompt from the last message. Plain string content
	// is used as-is; multimodal array content contributes its "text" parts.
	var prompt string
	if len(apiReq.Messages) > 0 {
		lastMsg := apiReq.Messages[len(apiReq.Messages)-1]
		switch content := lastMsg.Content.(type) {
		case string:
			prompt = content
		case []interface{}:
			var sb strings.Builder
			for _, item := range content {
				if itemMap, ok := item.(map[string]interface{}); ok {
					if text, ok := itemMap["text"].(string); ok {
						sb.WriteString(text)
					}
				}
			}
			prompt = sb.String()
		}
	}
	if prompt == "" {
		logError("[reqID:%s] 绘图提示为空", reqID)
		http.Error(w, "绘图提示为空", http.StatusBadRequest)
		return
	}
	// Fixed output size; strip variant suffixes to get the upstream model.
	size := "1024*1024"
	modelName := strings.Replace(apiReq.Model, "-draw", "", 1)
	modelName = strings.Replace(modelName, "-thinking", "", 1)
	modelName = strings.Replace(modelName, "-search", "", 1)
	// Build the t2i task submission payload.
	qwenReq := QwenRequest{
		Stream:            false,
		IncrementalOutput: true,
		ChatType:          "t2i",
		Model:             modelName,
		Messages: []APIMessage{
			{
				Role:     "user",
				Content:  prompt,
				ChatType: "t2i",
				Extra:    map[string]interface{}{},
				FeatureConfig: map[string]interface{}{
					"thinking_enabled": false,
				},
			},
		},
		ID:   generateUUID(),
		Size: size,
	}
	reqData, err := json.Marshal(qwenReq)
	if err != nil {
		logError("[reqID:%s] 序列化请求失败: %v", reqID, err)
		http.Error(w, "内部服务器错误", http.StatusInternalServerError)
		return
	}
	req, err := http.NewRequestWithContext(r.Context(), "POST", TargetURL, bytes.NewBuffer(reqData))
	if err != nil {
		logError("[reqID:%s] 创建请求失败: %v", reqID, err)
		http.Error(w, "内部服务器错误", http.StatusInternalServerError)
		return
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+authToken)
	req.Header.Set("User-Agent", "Mozilla/5.0")
	client := getHTTPClient()
	resp, err := client.Do(req)
	if err != nil {
		logError("[reqID:%s] 发送请求失败: %v", reqID, err)
		http.Error(w, "连接到API失败", http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		bodyBytes, _ := io.ReadAll(resp.Body)
		logError("[reqID:%s] API返回非200状态码: %d, 响应: %s", reqID, resp.StatusCode, string(bodyBytes))
		http.Error(w, fmt.Sprintf("API错误,状态码: %d", resp.StatusCode), resp.StatusCode)
		return
	}
	// Decode the submission response and pull out the async task ID.
	var qwenResp QwenResponse
	if err := json.NewDecoder(resp.Body).Decode(&qwenResp); err != nil {
		logError("[reqID:%s] 解析响应失败: %v", reqID, err)
		http.Error(w, "解析响应失败", http.StatusInternalServerError)
		return
	}
	taskID := ""
	for _, msg := range qwenResp.Messages {
		if msg.Role == "assistant" && msg.Extra.Wanx.TaskID != "" {
			taskID = msg.Extra.Wanx.TaskID
			break
		}
	}
	if taskID == "" {
		logError("[reqID:%s] 无法获取图像生成任务ID", reqID)
		http.Error(w, "无法获取图像生成任务ID", http.StatusInternalServerError)
		return
	}
	// Poll the task status endpoint: up to 30 attempts, 6 seconds apart.
	var imageURL string
	for attempt := 0; attempt < 30; attempt++ {
		if attempt > 0 {
			// Wait for the next poll; abort promptly if the client cancels.
			select {
			case <-r.Context().Done():
				logWarn("[reqID:%s] 请求超时或被客户端取消", reqID)
				http.Error(w, "请求超时", http.StatusGatewayTimeout)
				return
			case <-time.After(6 * time.Second):
			}
		}
		statusReq, err := http.NewRequestWithContext(r.Context(), "GET", TasksURL+taskID, nil)
		if err != nil {
			logError("[reqID:%s] 创建状态请求失败: %v", reqID, err)
			continue
		}
		statusReq.Header.Set("Authorization", "Bearer "+authToken)
		statusReq.Header.Set("User-Agent", "Mozilla/5.0")
		statusResp, err := client.Do(statusReq)
		if err != nil {
			logError("[reqID:%s] 发送状态请求失败: %v", reqID, err)
			continue
		}
		if statusResp.StatusCode != http.StatusOK {
			// Drain and close so the transport can reuse the connection.
			io.Copy(io.Discard, statusResp.Body)
			statusResp.Body.Close()
			logError("[reqID:%s] 状态接口返回非200状态码: %d", reqID, statusResp.StatusCode)
			continue
		}
		var statusData TaskStatusResponse
		decodeErr := json.NewDecoder(statusResp.Body).Decode(&statusData)
		statusResp.Body.Close()
		if decodeErr != nil {
			logError("[reqID:%s] 解析状态响应失败: %v", reqID, decodeErr)
			continue
		}
		// A non-empty Content field means the image is ready.
		if statusData.Content != "" {
			imageURL = statusData.Content
			break
		}
	}
	if imageURL == "" {
		logError("[reqID:%s] 图像生成超时", reqID)
		http.Error(w, "图像生成超时", http.StatusGatewayTimeout)
		return
	}
	// Return a standard chat completion whose content embeds the image
	// as Markdown.
	completionResponse := CompletionResponse{
		ID:      fmt.Sprintf("chatcmpl-%s", generateUUID()),
		Object:  "chat.completion",
		Created: time.Now().Unix(),
		Model:   apiReq.Model,
		Choices: []struct {
			Index   int `json:"index"`
			Message struct {
				Role    string `json:"role"`
				Content string `json:"content"`
			} `json:"message"`
			FinishReason string `json:"finish_reason"`
		}{
			{
				Index: 0,
				Message: struct {
					Role    string `json:"role"`
					Content string `json:"content"`
				}{
					Role:    "assistant",
					Content: fmt.Sprintf("![%s](%s)", imageURL, imageURL),
				},
				FinishReason: "stop",
			},
		},
		Usage: struct {
			PromptTokens     int `json:"prompt_tokens"`
			CompletionTokens int `json:"completion_tokens"`
			TotalTokens      int `json:"total_tokens"`
		}{
			// Fixed nominal usage numbers: image generation has no
			// meaningful token accounting.
			PromptTokens:     1024,
			CompletionTokens: 1024,
			TotalTokens:      2048,
		},
	}
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(completionResponse)
}
// extractFullContentFromStream aggregates the complete assistant reply from
// an SSE stream (lines prefixed "data: ", terminated by "[DONE]").
//
// body is the upstream streaming response body (not closed here; the caller
// owns it). reqID is used only for log correlation. On a non-EOF read error
// the content accumulated so far is returned together with the error.
//
// Fix: the previous implementation kept the last buffered line for the next
// read and never processed it after EOF, so a final "data:" event without a
// trailing newline was silently dropped. bufio.Scanner processes the final
// unterminated line (and strips any trailing '\r').
func extractFullContentFromStream(body io.ReadCloser, reqID string) (string, error) {
	var contentBuilder strings.Builder
	scanner := bufio.NewScanner(body)
	// SSE data lines can be large; raise the scan limit (16KB initial, 1MB max).
	scanner.Buffer(make([]byte, 16384), 1024*1024)
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		dataStr := strings.TrimPrefix(line, "data: ")
		// "[DONE]" is the stream terminator, not a JSON payload.
		if dataStr == "[DONE]" {
			logDebug("[reqID:%s] 非流式模式收到[DONE]消息", reqID)
			continue
		}
		var qwenResp QwenResponse
		if err := json.Unmarshal([]byte(dataStr), &qwenResp); err != nil {
			logWarn("[reqID:%s] 解析JSON失败: %v, data: %s", reqID, err, dataStr)
			continue
		}
		// Accumulate every delta fragment across all choices.
		for _, choice := range qwenResp.Choices {
			if choice.Delta.Content != "" {
				contentBuilder.WriteString(choice.Delta.Content)
			}
		}
	}
	// Scanner.Err returns nil on clean EOF, matching the old behavior of
	// treating io.EOF as success.
	if err := scanner.Err(); err != nil {
		return contentBuilder.String(), err
	}
	contentStr := contentBuilder.String()
	logInfo("[reqID:%s] 非流式模式:成功提取完整内容,长度: %d", reqID, len(contentStr))
	return contentStr, nil
}
// uploadImage uploads a base64 data-URI image to the Qwen files endpoint
// and returns the server-assigned file ID.
//
// base64Data must look like "data:<mime>;base64,<payload>"; only the payload
// after the first comma is decoded. authToken is sent as a Bearer token.
func uploadImage(base64Data string, authToken string) (string, error) {
	// Validate and split the data URI.
	if !strings.HasPrefix(base64Data, "data:") {
		return "", fmt.Errorf("invalid base64 data format")
	}
	segments := strings.SplitN(base64Data, ",", 2)
	if len(segments) != 2 {
		return "", fmt.Errorf("invalid base64 data format")
	}
	rawImage, err := base64.StdEncoding.DecodeString(segments[1])
	if err != nil {
		return "", fmt.Errorf("failed to decode base64 data: %v", err)
	}
	// Assemble a multipart form with a single "file" part; the nanosecond
	// timestamp keeps generated filenames unique.
	var form bytes.Buffer
	mw := multipart.NewWriter(&form)
	filePart, err := mw.CreateFormFile("file", fmt.Sprintf("image-%d.jpg", time.Now().UnixNano()))
	if err != nil {
		return "", fmt.Errorf("failed to create form file: %v", err)
	}
	if _, err = filePart.Write(rawImage); err != nil {
		return "", fmt.Errorf("failed to write image data: %v", err)
	}
	// Closing the writer finalizes the multipart boundary.
	if err = mw.Close(); err != nil {
		return "", fmt.Errorf("failed to close writer: %v", err)
	}
	req, err := http.NewRequest("POST", FilesURL, &form)
	if err != nil {
		return "", fmt.Errorf("failed to create request: %v", err)
	}
	req.Header.Set("Content-Type", mw.FormDataContentType())
	req.Header.Set("Authorization", "Bearer "+authToken)
	req.Header.Set("User-Agent", "Mozilla/5.0")
	resp, err := getHTTPClient().Do(req)
	if err != nil {
		return "", fmt.Errorf("failed to send request: %v", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		respBody, _ := io.ReadAll(resp.Body)
		return "", fmt.Errorf("API returned non-200 status code: %d, response: %s", resp.StatusCode, string(respBody))
	}
	var uploadResp FileUploadResponse
	if err = json.NewDecoder(resp.Body).Decode(&uploadResp); err != nil {
		return "", fmt.Errorf("failed to parse response: %v", err)
	}
	return uploadResp.ID, nil
}
// createRoleChunk builds the opening SSE chunk that announces the
// "assistant" role for the stream, serialized as JSON.
func createRoleChunk(id string, created int64, model string) []byte {
	var chunk StreamChunk
	chunk.ID = id
	chunk.Object = "chat.completion.chunk"
	chunk.Created = created
	chunk.Model = model
	// Single choice whose delta carries only the role.
	chunk.Choices = make([]struct {
		Index int `json:"index"`
		Delta struct {
			Role    string `json:"role,omitempty"`
			Content string `json:"content,omitempty"`
		} `json:"delta"`
		FinishReason *string `json:"finish_reason,omitempty"`
	}, 1)
	chunk.Choices[0].Delta.Role = "assistant"
	encoded, _ := json.Marshal(chunk)
	return encoded
}
// createContentChunk builds an SSE chunk carrying one fragment of assistant
// content, serialized as JSON.
func createContentChunk(id string, created int64, model string, content string) []byte {
	var chunk StreamChunk
	chunk.ID = id
	chunk.Object = "chat.completion.chunk"
	chunk.Created = created
	chunk.Model = model
	// Single choice whose delta carries only the content fragment.
	chunk.Choices = make([]struct {
		Index int `json:"index"`
		Delta struct {
			Role    string `json:"role,omitempty"`
			Content string `json:"content,omitempty"`
		} `json:"delta"`
		FinishReason *string `json:"finish_reason,omitempty"`
	}, 1)
	chunk.Choices[0].Delta.Content = content
	encoded, _ := json.Marshal(chunk)
	return encoded
}
// createDoneChunk builds the terminating SSE chunk with an empty delta and
// the given finish reason (e.g. "stop"), serialized as JSON.
func createDoneChunk(id string, created int64, model string, reason string) []byte {
	var chunk StreamChunk
	chunk.ID = id
	chunk.Object = "chat.completion.chunk"
	chunk.Created = created
	chunk.Model = model
	// Single choice with an empty delta; only FinishReason is set.
	chunk.Choices = make([]struct {
		Index int `json:"index"`
		Delta struct {
			Role    string `json:"role,omitempty"`
			Content string `json:"content,omitempty"`
		} `json:"delta"`
		FinishReason *string `json:"finish_reason,omitempty"`
	}, 1)
	// Parameters are addressable in Go, so &reason is safe here.
	chunk.Choices[0].FinishReason = &reason
	encoded, _ := json.Marshal(chunk)
	return encoded
}
// estimateTokens gives a rough token count for a message list using the
// crude heuristic of one token per four bytes of text. String content is
// counted directly; multimodal array content contributes its "text" parts.
func estimateTokens(messages []APIMessage) int {
	total := 0
	for _, msg := range messages {
		// Plain-string content: count it and move on.
		if text, ok := msg.Content.(string); ok {
			total += len(text) / 4
			continue
		}
		// Multimodal content: a slice of part maps, each possibly textual.
		parts, ok := msg.Content.([]interface{})
		if !ok {
			continue
		}
		for _, part := range parts {
			entry, ok := part.(map[string]interface{})
			if !ok {
				continue
			}
			if text, ok := entry["text"].(string); ok {
				total += len(text) / 4
			}
		}
	}
	return total
}