本帖最后由 CrLf 于 2023-7-29 00:17 编辑
命令行下调用OpenAI接口,从标准输入中读取用户输入并将其发送到GPT模型,再将响应写入标准输出。因原版默认是UTF8,所以我改成默认以GBK编码读取输入,并增加 --utf8 开关兼容utf8编码。
原版GitHub:https://github.com/pdfinn/sgpt
用法:
sgpt -k <API_KEY> [-i INSTRUCTION] [-t TEMPERATURE] [-m MODEL] [-s SEPARATOR] [-u] [-d]
参数说明:短参数 | 长参数 | 环境变量 | 描述 | 默认值 |
-k | --api_key | SGPT_API_KEY | 配置OpenAI的API KEY | 无 |
-i | --instruction | SGPT_INSTRUCTION | 系统指令,用于补充一些背景信息或要求 | 无 |
-t | --temperature | SGPT_TEMPERATURE | 温度值,OpenAI接口允许的范围是0~2;数值越高,回答越随机、越有想象力,但也越容易编造 | 0.5 |
-m | --model | SGPT_MODEL | 所采用的模型 | gpt-3.5-turbo |
-s | --separator | SGPT_SEPARATOR | 输入内容的分隔符:标准输入按它切分成多段,每段单独发送给模型并各自输出一个回答 | \n |
-u | --utf8 | SGPT_UTF8 | 以UTF8编码解读输入内容(该参数由CrLf添加;不指定时默认按GBK解读输入) | false |
-d | --debug | SGPT_DEBUG | 启用调试模式,将输出很多调试信息 | false |
CrLf修改后的源码:
package main | |
import ( | |
"bufio" | |
"encoding/json" | |
"fmt" | |
"github.com/spf13/pflag" | |
"github.com/spf13/viper" | |
"io" | |
"io/ioutil" | |
"log" | |
"net/http" | |
"os" | |
"strconv" | |
"strings" | |
// mod by CrLf 添加必要的模块 | |
"bytes" | |
"golang.org/x/text/encoding/simplifiedchinese" | |
"golang.org/x/text/transform" | |
) | |
// mod by CrLf 用于将UTF8转码为GBK | |
// UTF-8 转 GBK | |
func Utf8ToGbk(s []byte) ([]byte, error) { | |
reader := transform.NewReader(bytes.NewReader(s), simplifiedchinese.GBK.NewEncoder()) | |
d, e := ioutil.ReadAll(reader) | |
if e != nil { | |
return nil, e | |
} | |
return d, nil | |
} | |
func GbkToUtf8(s []byte) ([]byte, error) { | |
reader := transform.NewReader(bytes.NewReader(s), simplifiedchinese.GBK.NewDecoder()) | |
d, e := ioutil.ReadAll(reader) | |
if e != nil { | |
return nil, e | |
} | |
return d, nil | |
} | |
// OpenAIResponse is the part of an OpenAI API response body that callOpenAI
// decodes: the list of returned choices.
type OpenAIResponse struct {
	Choices []OpenAIChoice `json:"choices"`
}

// OpenAIChoice is one element of OpenAIResponse.Choices. callOpenAI reads
// Message for chat-style replies and Text for plain completion replies.
type OpenAIChoice struct {
	Text    string        `json:"text,omitempty"`
	Message OpenAIMessage `json:"message,omitempty"`
}

// OpenAIMessage is the role/content pair carried by a chat-style choice.
type OpenAIMessage struct {
	Role    string `json:"role,omitempty"`
	Content string `json:"content,omitempty"`
}
// Switches shared between main and the helpers; main assigns them when it
// registers --utf8/-u and --debug/-d (defaults from SGPT_UTF8 / SGPT_DEBUG).
// They are nil until then.
var (
	// utf8: when true, main reads stdin as UTF-8; otherwise stdin is decoded from GBK.
	utf8 *bool
	// debug: when true, debugOutput messages (such as the raw API response
	// logged by callOpenAI) are printed.
	debug *bool
)
func main() { | |
// Default values | |
defaultTemperature := 0.5 | |
defaultModel := "gpt-3.5-turbo" | |
// Check environment variables | |
envApiKey := os.Getenv("SGPT_API_KEY") | |
envInstruction := os.Getenv("SGPT_INSTRUCTION") | |
envTemperature, err := strconv.ParseFloat(os.Getenv("SGPT_TEMPERATURE"), 64) | |
if err != nil { | |
envTemperature = defaultTemperature | |
} | |
envModel := os.Getenv("SGPT_MODEL") | |
envSeparator := os.Getenv("SGPT_SEPARATOR") | |
// mod by CrLf 增加对环境变量 SGPT_UTF8 的支持 | |
envUTF8 := parseBoolWithDefault(os.Getenv("SGPT_UTF8"), false) | |
envDebug := parseBoolWithDefault(os.Getenv("SGPT_DEBUG"), false) | |
// Command line arguments | |
apiKey := pflag.StringP("api_key", "k", envApiKey, "OpenAI API key") | |
instruction := pflag.StringP("instruction", "i", envInstruction, "Instruction for the GPT model") | |
temperature := pflag.Float64P("temperature", "t", envTemperature, "Temperature for the GPT model") | |
model := pflag.StringP("model", "m", envModel, "GPT model to use") | |
defaulSeparator := "\n" | |
separator := pflag.StringP("separator", "s", envSeparator, "Separator character for input") | |
if *separator == "" { | |
*separator = defaulSeparator | |
} | |
// mod by CrLf 增加对参数 --utf8 或 -u 的支持 | |
utf8 = pflag.BoolP("utf8", "u", envUTF8, "Enable UTF8 input") | |
debug = pflag.BoolP("debug", "d", envDebug, "Enable debug output") | |
pflag.Parse() | |
// Read the configuration file | |
viper.SetConfigName("sgpt") | |
viper.AddConfigPath(".") | |
viper.AddConfigPath("$HOME/.sgpt") | |
viper.SetConfigType("yaml") | |
err = viper.ReadInConfig() | |
// mod by CrLf 默认屏蔽无用警告,仅在debug模式下展示 | |
if _, ok := err.(viper.ConfigFileNotFoundError); ok { | |
debugOutput(*debug, "Warning: Config file not found: %v", err) | |
} else if err != nil { | |
debugOutput(*debug, "Warning: Error reading config file: %v", err) | |
} | |
// Set default values and bind configuration values to flags | |
viper.SetDefault("model", defaultModel) | |
viper.SetDefault("temperature", defaultTemperature) | |
viper.BindPFlag("api_key", pflag.Lookup("k")) | |
viper.BindPFlag("instruction", pflag.Lookup("i")) | |
viper.BindPFlag("model", pflag.Lookup("m")) | |
viper.BindPFlag("temperature", pflag.Lookup("t")) | |
viper.BindPFlag("separator", pflag.Lookup("s")) | |
viper.BindPFlag("debug", pflag.Lookup("d")) | |
// Use default values if neither flags nor environment variables are set | |
if *model == "" { | |
*model = defaultModel | |
} | |
if *apiKey == "" { | |
log.Fatal("API key is required") | |
} | |
// Read input from stdin continuously | |
// mod by CrLf 根据utf8开关的启禁用状态判断以utf8还是gbk读取stdin | |
var reader io.RuneReader | |
if *utf8 { | |
reader = bufio.NewReader(os.Stdin) | |
} else { | |
byteInput, _ := io.ReadAll(os.Stdin) | |
gbkBytes, _ := GbkToUtf8(byteInput) | |
reader = bytes.NewReader(gbkBytes) | |
} | |
var inputBuffer strings.Builder | |
for { | |
inputChar, _, err := reader.ReadRune() | |
if err == io.EOF { | |
input := inputBuffer.String() | |
if input != "" { | |
response, err := callOpenAI(*apiKey, *instruction, input, *temperature, *model) | |
if err != nil { | |
log.Fatal(err) | |
} | |
fmt.Println(response) | |
} | |
break | |
} | |
if err != nil { | |
log.Fatal(err) | |
} | |
if string(inputChar) == *separator { | |
input := inputBuffer.String() | |
inputBuffer.Reset() | |
response, err := callOpenAI(*apiKey, *instruction, input, *temperature, *model) | |
if err != nil { | |
log.Fatal(err) | |
} | |
fmt.Println(response) | |
} else { | |
inputBuffer.WriteRune(inputChar) | |
} | |
} | |
} | |
// debugOutput writes a printf-style message through the standard logger when
// debug is true and does nothing otherwise.
func debugOutput(debug bool, format string, a ...interface{}) {
	if !debug {
		return
	}
	log.Printf(format, a...)
}
// parseFloatWithDefault converts value with strconv.ParseFloat (64-bit).
// An empty value yields defaultValue silently; a value that fails to parse
// logs a warning and also yields defaultValue.
func parseFloatWithDefault(value string, defaultValue float64) float64 {
	if value == "" {
		return defaultValue
	}
	parsed, err := strconv.ParseFloat(value, 64)
	if err == nil {
		return parsed
	}
	log.Printf("Warning: Failed to parse float value: %v", err)
	return defaultValue
}
// parseBoolWithDefault converts value with strconv.ParseBool (accepting forms
// such as "1", "t", "true", "0", "f", "false"). An empty value yields
// defaultValue silently; an unrecognised one logs a warning and also yields
// defaultValue.
func parseBoolWithDefault(value string, defaultValue bool) bool {
	switch parsed, err := strconv.ParseBool(value); {
	case value == "":
		return defaultValue
	case err != nil:
		log.Printf("Warning: Failed to parse bool value: %v", err)
		return defaultValue
	default:
		return parsed
	}
}
func callOpenAI(apiKey, instruction, input string, temperature float64, model string) (string, error) { | |
var url string | |
var jsonData []byte | |
var err error | |
switch model { | |
case "gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314", "gpt-3.5-turbo": | |
url = "https://api.openai.com/v1/chat/completions" | |
// Prepare JSON data for GPT-4 models | |
messages := []map[string]string{ | |
{"role": "system", "content": instruction}, | |
{"role": "user", "content": input}, | |
} | |
jsonData, err = json.Marshal(map[string]interface{}{ | |
"model": model, | |
"messages": messages, | |
"temperature": temperature, | |
"max_tokens": 100, | |
"stop": []string{"\n"}, | |
}) | |
case "text-davinci-003", "text-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001": | |
url = "https://api.openai.com/v1/completions" | |
// Prepare JSON data for GPT-3 models | |
prompt := instruction + " " + input | |
jsonData, err = json.Marshal(map[string]interface{}{ | |
"model": model, | |
"prompt": prompt, | |
"temperature": temperature, | |
"max_tokens": 100, | |
"stop": []string{"\n"}, | |
}) | |
case "whisper-1": | |
url = "https://api.openai.com/v1/audio/transcriptions" | |
default: | |
return "", fmt.Errorf("unsupported model: %s", model) | |
} | |
if err != nil { | |
return "", err | |
} | |
data := strings.NewReader(string(jsonData)) | |
req, err := http.NewRequest("POST", url, data) | |
if err != nil { | |
return "", err | |
} | |
req.Header.Set("Content-Type", "application/json") | |
req.Header.Set("Authorization", "Bearer "+apiKey) | |
client := &http.Client{} | |
resp, err := client.Do(req) | |
if err != nil { | |
return "", err | |
} | |
defer resp.Body.Close() | |
body, err := ioutil.ReadAll(resp.Body) | |
if err != nil { | |
return "", err | |
} | |
debugOutput(*debug, "API response: %s\n", string(body)) | |
var openAIResponse OpenAIResponse | |
err = json.Unmarshal(body, &openAIResponse) | |
if err != nil { | |
return "", err | |
} | |
if len(openAIResponse.Choices) == 0 { | |
debugOutput(*debug, "API response: %s\n", string(body)) | |
debugOutput(*debug, "HTTP status code: %s\n", strconv.Itoa(resp.StatusCode)) | |
return "", fmt.Errorf("no choices returned from the API") | |
} | |
assistantMessage := "" | |
for _, choice := range openAIResponse.Choices { | |
if choice.Message.Role == "assistant" { | |
assistantMessage = strings.TrimSpace(choice.Message.Content) | |
break | |
} | |
if choice.Text != "" { | |
assistantMessage = strings.TrimSpace(choice.Text) | |
break | |
} | |
} | |
if assistantMessage == "" { | |
return "", fmt.Errorf("no assistant message found in the API response") | |
} | |
return assistantMessage, nil | |
}COPY |
编译后的下载地址:http://bcn.bathome.net/s/tool/index.html?key=sgpt
---------------------------------------------------------------------------------------
本帖最后由 CrLf 于 2023-7-29 00:13 编辑
举个例子:
echo 柬埔寨在哪里|sgpt.exe --api_key "***这里是你的openai_api_key***" --instruction "请用中文回答:" --model "gpt-3.5-turbo"
:: 回答为:柬埔寨位于东南亚,东临越南,南接泰国,西邻泰国和洞朗,北界老挝。
:: (注意:这段回答本身含有事实错误——柬埔寨西南濒临泰国湾,并不与“洞朗”相邻。模型的输出请自行核实。)
如果要传入GBK无法表示的字符,请先执行 chcp 65001 把控制台切换到UTF-8代码页,再加上 --utf8 开关运行。
出处:http://bbs.bathome.net/thread-66919-1-2.html
标签:pflag,return,string,err,nil,修改版,命令行,model,ChatGPT From: https://www.cnblogs.com/mq0036/p/18290338