diff --git a/config.dev.json b/config.dev.json index 5aad201..5a126cc 100644 --- a/config.dev.json +++ b/config.dev.json @@ -1,7 +1,7 @@ { "api_key": "xxxxxxxxx", - "session_timeout": 60, - "max_tokens": 1024, + "session_timeout": 180, + "max_tokens": 2000, "model": "text-davinci-003", "temperature": 0.9, "session_clear_token": "清空会话" diff --git a/config/config.go b/config/config.go index e7961db..5af2db2 100644 --- a/config/config.go +++ b/config/config.go @@ -34,13 +34,7 @@ var once sync.Once func LoadConfig() *Configuration { once.Do(func() { // 从文件中读取 - config = &Configuration{ - SessionTimeout: 60, - MaxTokens: 512, - Model: "text-davinci-003", - Temperature: 0.9, - SessionClearToken: "下一个问题", - } + config = &Configuration{} f, err := os.Open("config.json") if err != nil { logger.Danger("open config err: %v", err) @@ -55,7 +49,6 @@ func LoadConfig() *Configuration { } // 如果环境变量有配置,读取环境变量 - // 有环境变量使用环境变量 ApiKey := os.Getenv("APIKEY") SessionTimeout := os.Getenv("SESSION_TIMEOUT") Model := os.Getenv("MODEL") @@ -66,12 +59,14 @@ func LoadConfig() *Configuration { config.ApiKey = ApiKey } if SessionTimeout != "" { - duration, err := time.ParseDuration(SessionTimeout) + duration, err := strconv.ParseInt(SessionTimeout, 10, 64) if err != nil { logger.Danger(fmt.Sprintf("config session timeout err: %v ,get is %v", err, SessionTimeout)) return } - config.SessionTimeout = duration + config.SessionTimeout = time.Duration(duration) * time.Second + } else { + config.SessionTimeout = time.Duration(config.SessionTimeout) * time.Second } if Model != "" { config.Model = Model diff --git a/gtp/gtp.go b/gpt/gpt.go similarity index 70% rename from gtp/gtp.go rename to gpt/gpt.go index e390fd8..5debe6b 100644 --- a/gtp/gtp.go +++ b/gpt/gpt.go @@ -1,13 +1,11 @@ -package gtp +package gpt import ( "bytes" "encoding/json" - "errors" "fmt" "io/ioutil" "net/http" - "time" "github.com/eryajf/chatgpt-dingtalk/config" "github.com/eryajf/chatgpt-dingtalk/public/logger" @@ -34,13 +32,10 @@ type ChoiceItem struct { // ChatGPTRequestBody 响应体 type ChatGPTRequestBody struct { - Model string `json:"model"` - Prompt string `json:"prompt"` - MaxTokens uint `json:"max_tokens"` - Temperature float64 `json:"temperature"` - TopP int `json:"top_p"` - FrequencyPenalty int `json:"frequency_penalty"` - PresencePenalty int `json:"presence_penalty"` + Model string `json:"model"` + Prompt string `json:"prompt"` + MaxTokens uint `json:"max_tokens"` + Temperature float64 `json:"temperature"` } // Completions gtp文本模型回复 @@ -51,13 +46,10 @@ type ChatGPTRequestBody struct { func Completions(msg string) (string, error) { cfg := config.LoadConfig() requestBody := ChatGPTRequestBody{ - Model: cfg.Model, - Prompt: msg, - MaxTokens: cfg.MaxTokens, - Temperature: cfg.Temperature, - TopP: 1, - FrequencyPenalty: 0, - PresencePenalty: 0, + Model: cfg.Model, + Prompt: msg, + MaxTokens: cfg.MaxTokens, + Temperature: cfg.Temperature, } requestData, err := json.Marshal(requestBody) if err != nil { @@ -69,23 +61,23 @@ func Completions(msg string) (string, error) { return "", err } - apiKey := config.LoadConfig().ApiKey req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+apiKey) - client := &http.Client{Timeout: 30 * time.Second} + req.Header.Set("Authorization", "Bearer "+cfg.ApiKey) + client := &http.Client{Timeout: cfg.SessionTimeout} response, err := client.Do(req) if err != nil { return "", err } defer response.Body.Close() - if response.StatusCode != 200 { - body, _ := ioutil.ReadAll(response.Body) - return "", errors.New(fmt.Sprintf("请求GTP出错了,gtp api status code not equals 200,code is %d ,details: %v ", response.StatusCode, string(body))) - } + body, err := ioutil.ReadAll(response.Body) if err != nil { return "", err } + + if response.StatusCode != 200 { + return "", fmt.Errorf("请求GTP出错了,gtp api status code not equals 200,code is %d ,details: %v ", response.StatusCode, string(body)) + } logger.Info(fmt.Sprintf("response gtp json string : %v", string(body))) gptResponseBody := &ChatGPTResponseBody{} diff --git a/main.go b/main.go index f1e7a60..042ca19 100644 --- a/main.go +++ b/main.go @@ -7,7 +7,7 @@ import ( "net/http" "strings" - "github.com/eryajf/chatgpt-dingtalk/gtp" + "github.com/eryajf/chatgpt-dingtalk/gpt" "github.com/eryajf/chatgpt-dingtalk/public" "github.com/eryajf/chatgpt-dingtalk/public/logger" "github.com/eryajf/chatgpt-dingtalk/service" @@ -33,7 +33,6 @@ func Start() { return } // TODO: 校验请求 - // fmt.Println(r.Header) if len(data) == 0 { logger.Warning("回调参数为空,以至于无法正常解析,请检查原因") return @@ -76,9 +75,9 @@ func ProcessRequest(rmsg public.ReceiveMsg) error { } else { requestText := getRequestText(rmsg) // 获取问题的答案 - reply, err := gtp.Completions(requestText) + reply, err := gpt.Completions(requestText) if err != nil { - logger.Info("gtp request error: %v \n", err) + logger.Info("gpt request error: %v \n", err) _, err = rmsg.ReplyText("机器人太累了,让她休息会儿,过一会儿再来请求。") if err != nil { logger.Warning("send message error: %v \n", err)