Compare commits
21 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fc284cc1f0 | ||
|
|
9c0a49b97a | ||
|
|
61c47a3b08 | ||
|
|
c5aa59e297 | ||
|
|
211a862d54 | ||
|
|
c4c89e8e1b | ||
|
|
72983ac734 | ||
|
|
4d43dce64b | ||
|
|
0fa94d3c94 | ||
|
|
002dba5a75 | ||
|
|
fb24d024a7 | ||
|
|
eeb867da10 | ||
|
|
47b72b850f | ||
|
|
f44fbe3fe7 | ||
|
|
1c8922153d | ||
|
|
f3c07e1451 | ||
|
|
40ceb29e54 | ||
|
|
0699ecd0af | ||
|
|
ee9e746520 | ||
|
|
a763681c2e | ||
|
|
be613883a1 |
@@ -6,23 +6,61 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"golang.org/x/net/proxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
var HttpClient *http.Client
|
var clientPool = &sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return &http.Client{}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
func init() {
|
func GetHttpClient(proxyAddr string) *http.Client {
|
||||||
if RelayTimeout == 0 {
|
client := clientPool.Get().(*http.Client)
|
||||||
HttpClient = &http.Client{}
|
|
||||||
} else {
|
if RelayTimeout > 0 {
|
||||||
HttpClient = &http.Client{
|
client.Timeout = time.Duration(RelayTimeout) * time.Second
|
||||||
Timeout: time.Duration(RelayTimeout) * time.Second,
|
}
|
||||||
|
|
||||||
|
if proxyAddr != "" {
|
||||||
|
proxyURL, err := url.Parse(proxyAddr)
|
||||||
|
if err != nil {
|
||||||
|
SysError("Error parsing proxy address: " + err.Error())
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
|
switch proxyURL.Scheme {
|
||||||
|
case "http", "https":
|
||||||
|
client.Transport = &http.Transport{
|
||||||
|
Proxy: http.ProxyURL(proxyURL),
|
||||||
|
}
|
||||||
|
case "socks5":
|
||||||
|
dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
|
||||||
|
if err != nil {
|
||||||
|
SysError("Error creating SOCKS5 dialer: " + err.Error())
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
client.Transport = &http.Transport{
|
||||||
|
Dial: dialer.Dial,
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
SysError("Unsupported proxy scheme: " + proxyURL.Scheme)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return client
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func PutHttpClient(c *http.Client) {
|
||||||
|
clientPool.Put(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
@@ -92,12 +130,14 @@ func (c *Client) NewRequest(method, url string, setters ...requestOption) (*http
|
|||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SendRequest(req *http.Request, response any, outputResp bool) (*http.Response, *types.OpenAIErrorWithStatusCode) {
|
func SendRequest(req *http.Request, response any, outputResp bool, proxyAddr string) (*http.Response, *types.OpenAIErrorWithStatusCode) {
|
||||||
// 发送请求
|
// 发送请求
|
||||||
resp, err := HttpClient.Do(req)
|
client := GetHttpClient(proxyAddr)
|
||||||
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
return nil, ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
PutHttpClient(client)
|
||||||
|
|
||||||
if !outputResp {
|
if !outputResp {
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
@@ -210,8 +250,10 @@ func HandleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.Open
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) SendRequestRaw(req *http.Request) (body io.ReadCloser, err error) {
|
func (c *Client) SendRequestRaw(req *http.Request, proxyAddr string) (body io.ReadCloser, err error) {
|
||||||
resp, err := HttpClient.Do(req)
|
client := GetHttpClient(proxyAddr)
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
PutHttpClient(client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -189,6 +189,7 @@ const (
|
|||||||
ChannelTypeTencent = 23
|
ChannelTypeTencent = 23
|
||||||
ChannelTypeAzureSpeech = 24
|
ChannelTypeAzureSpeech = 24
|
||||||
ChannelTypeGemini = 25
|
ChannelTypeGemini = 25
|
||||||
|
ChannelTypeBaichuan = 26
|
||||||
)
|
)
|
||||||
|
|
||||||
var ChannelBaseURLs = []string{
|
var ChannelBaseURLs = []string{
|
||||||
@@ -218,6 +219,7 @@ var ChannelBaseURLs = []string{
|
|||||||
"https://hunyuan.cloud.tencent.com", //23
|
"https://hunyuan.cloud.tencent.com", //23
|
||||||
"", //24
|
"", //24
|
||||||
"", //25
|
"", //25
|
||||||
|
"https://api.baichuan-ai.com", //26
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package image
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
"image"
|
"image"
|
||||||
_ "image/gif"
|
_ "image/gif"
|
||||||
_ "image/jpeg"
|
_ "image/jpeg"
|
||||||
@@ -15,7 +16,22 @@ import (
|
|||||||
_ "golang.org/x/image/webp"
|
_ "golang.org/x/image/webp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func IsImageUrl(url string) (bool, error) {
|
||||||
|
resp, err := http.Head(url)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
func GetImageSizeFromUrl(url string) (width int, height int, err error) {
|
func GetImageSizeFromUrl(url string) (width int, height int, err error) {
|
||||||
|
isImage, err := IsImageUrl(url)
|
||||||
|
if !isImage {
|
||||||
|
return
|
||||||
|
}
|
||||||
resp, err := http.Get(url)
|
resp, err := http.Get(url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@@ -28,6 +44,44 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) {
|
|||||||
return img.Width, img.Height, nil
|
return img.Width, img.Height, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
||||||
|
|
||||||
|
if strings.HasPrefix(url, "data:image/") {
|
||||||
|
dataURLPattern := regexp.MustCompile(`data:image/([^;]+);base64,(.*)`)
|
||||||
|
|
||||||
|
matches := dataURLPattern.FindStringSubmatch(url)
|
||||||
|
if len(matches) == 3 && matches[2] != "" {
|
||||||
|
mimeType = "image/" + matches[1]
|
||||||
|
data = matches[2]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = errors.New("image base64 decode failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
isImage, err := IsImageUrl(url)
|
||||||
|
if !isImage {
|
||||||
|
if err == nil {
|
||||||
|
err = errors.New("invalid image link")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp, err := http.Get(url)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
buffer := bytes.NewBuffer(nil)
|
||||||
|
_, err = buffer.ReadFrom(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
mimeType = resp.Header.Get("Content-Type")
|
||||||
|
data = base64.StdEncoding.EncodeToString(buffer.Bytes())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
reg = regexp.MustCompile(`data:image/([^;]+);base64,`)
|
reg = regexp.MustCompile(`data:image/([^;]+);base64,`)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -152,3 +152,51 @@ func TestGetImageSize(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetImageSizeFromBase64(t *testing.T) {
|
||||||
|
for i, c := range cases {
|
||||||
|
t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) {
|
||||||
|
resp, err := http.Get(c.url)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
data, err := io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
encoded := base64.StdEncoding.EncodeToString(data)
|
||||||
|
width, height, err := img.GetImageSizeFromBase64(encoded)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, c.width, width)
|
||||||
|
assert.Equal(t, c.height, height)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetImageFromUrl(t *testing.T) {
|
||||||
|
for i, c := range cases {
|
||||||
|
t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) {
|
||||||
|
resp, err := http.Get(c.url)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
data, err := io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
encoded := base64.StdEncoding.EncodeToString(data)
|
||||||
|
|
||||||
|
mimeType, base64Data, err := img.GetImageFromUrl(c.url)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, encoded, base64Data)
|
||||||
|
assert.Equal(t, "image/"+c.format, mimeType)
|
||||||
|
|
||||||
|
encodedBase64 := "data:image/" + c.format + ";base64," + encoded
|
||||||
|
mimeType, base64Data, err = img.GetImageFromUrl(encodedBase64)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, encoded, base64Data)
|
||||||
|
assert.Equal(t, "image/"+c.format, mimeType)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
url := "https://raw.githubusercontent.com/songquanpeng/one-api/main/README.md"
|
||||||
|
_, _, err := img.GetImageFromUrl(url)
|
||||||
|
assert.Error(t, err)
|
||||||
|
encodedBase64 := "data:image/text;base64,"
|
||||||
|
_, _, err = img.GetImageFromUrl(encodedBase64)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|||||||
@@ -84,6 +84,7 @@ var ModelRatio = map[string]float64{
|
|||||||
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
|
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
|
||||||
"PaLM-2": 1,
|
"PaLM-2": 1,
|
||||||
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||||
|
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||||
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
|
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
|
||||||
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
||||||
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
|
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
|
||||||
@@ -92,6 +93,7 @@ var ModelRatio = map[string]float64{
|
|||||||
"qwen-plus": 1.4286, // ¥0.02 / 1k tokens
|
"qwen-plus": 1.4286, // ¥0.02 / 1k tokens
|
||||||
"qwen-max": 1.4286, // ¥0.02 / 1k tokens
|
"qwen-max": 1.4286, // ¥0.02 / 1k tokens
|
||||||
"qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens
|
"qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens
|
||||||
|
"qwen-vl-plus": 0.5715, // ¥0.008 / 1k tokens
|
||||||
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
|
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
|
||||||
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
|
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
|
||||||
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
|
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
|
||||||
@@ -99,6 +101,10 @@ var ModelRatio = map[string]float64{
|
|||||||
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
|
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
|
||||||
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
|
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
|
||||||
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
|
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
|
||||||
|
"Baichuan2-Turbo": 0.5715, // ¥0.008 / 1k tokens
|
||||||
|
"Baichuan2-Turbo-192k": 1.143, // ¥0.016 / 1k tokens
|
||||||
|
"Baichuan2-53B": 1.4286, // ¥0.02 / 1k tokens
|
||||||
|
"Baichuan-Text-Embedding": 0.0357, // ¥0.0005 / 1k tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
func ModelRatio2JSONString() string {
|
func ModelRatio2JSONString() string {
|
||||||
@@ -115,6 +121,9 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetModelRatio(name string) float64 {
|
func GetModelRatio(name string) float64 {
|
||||||
|
if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") {
|
||||||
|
name = strings.TrimSuffix(name, "-internet")
|
||||||
|
}
|
||||||
ratio, ok := ModelRatio[name]
|
ratio, ok := ModelRatio[name]
|
||||||
if !ok {
|
if !ok {
|
||||||
SysError("model ratio not found: " + name)
|
SysError("model ratio not found: " + name)
|
||||||
|
|||||||
@@ -190,13 +190,13 @@ func countImageTokens(url string, detail string) (_ int, err error) {
|
|||||||
func CountTokenInput(input any, model string) int {
|
func CountTokenInput(input any, model string) int {
|
||||||
switch v := input.(type) {
|
switch v := input.(type) {
|
||||||
case string:
|
case string:
|
||||||
return CountTokenInput(v, model)
|
return CountTokenText(v, model)
|
||||||
case []string:
|
case []string:
|
||||||
text := ""
|
text := ""
|
||||||
for _, s := range v {
|
for _, s := range v {
|
||||||
text += s
|
text += s
|
||||||
}
|
}
|
||||||
return CountTokenInput(text, model)
|
return CountTokenText(text, model)
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -55,10 +55,9 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
|
|||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Request = req
|
c.Request = req
|
||||||
|
|
||||||
setChannelToContext(c, channel)
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
provider := providers.GetProvider(channel.Type, c)
|
provider := providers.GetProvider(channel, c)
|
||||||
if provider == nil {
|
if provider == nil {
|
||||||
return 0, errors.New("provider not found")
|
return 0, errors.New("provider not found")
|
||||||
}
|
}
|
||||||
@@ -102,7 +101,6 @@ func UpdateChannelBalance(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"balance": balance,
|
"balance": balance,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateAllChannelsBalance() error {
|
func updateAllChannelsBalance() error {
|
||||||
@@ -146,7 +144,6 @@ func UpdateAllChannelsBalance(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func AutomaticallyUpdateChannels(frequency int) {
|
func AutomaticallyUpdateChannels(frequency int) {
|
||||||
|
|||||||
@@ -18,6 +18,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (err error, openaiErr *types.OpenAIError) {
|
func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (err error, openaiErr *types.OpenAIError) {
|
||||||
|
if channel.TestModel == "" {
|
||||||
|
return errors.New("请填写测速模型后再试"), nil
|
||||||
|
}
|
||||||
|
|
||||||
// 创建一个 http.Request
|
// 创建一个 http.Request
|
||||||
req, err := http.NewRequest("POST", "/v1/chat/completions", nil)
|
req, err := http.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -28,29 +32,9 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
|
|||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Request = req
|
c.Request = req
|
||||||
|
request.Model = channel.TestModel
|
||||||
|
|
||||||
setChannelToContext(c, channel)
|
provider := providers.GetProvider(channel, c)
|
||||||
// 创建映射
|
|
||||||
channelTypeToModel := map[int]string{
|
|
||||||
common.ChannelTypePaLM: "PaLM-2",
|
|
||||||
common.ChannelTypeAnthropic: "claude-2",
|
|
||||||
common.ChannelTypeBaidu: "ERNIE-Bot",
|
|
||||||
common.ChannelTypeZhipu: "chatglm_lite",
|
|
||||||
common.ChannelTypeAli: "qwen-turbo",
|
|
||||||
common.ChannelType360: "360GPT_S2_V9",
|
|
||||||
common.ChannelTypeXunfei: "SparkDesk",
|
|
||||||
common.ChannelTypeTencent: "hunyuan",
|
|
||||||
common.ChannelTypeAzure: "gpt-3.5-turbo",
|
|
||||||
}
|
|
||||||
|
|
||||||
// 从映射中获取模型名称
|
|
||||||
model, ok := channelTypeToModel[channel.Type]
|
|
||||||
if !ok {
|
|
||||||
model = "gpt-3.5-turbo" // 默认值
|
|
||||||
}
|
|
||||||
request.Model = model
|
|
||||||
|
|
||||||
provider := providers.GetProvider(channel.Type, c)
|
|
||||||
if provider == nil {
|
if provider == nil {
|
||||||
return errors.New("channel not implemented"), nil
|
return errors.New("channel not implemented"), nil
|
||||||
}
|
}
|
||||||
@@ -70,13 +54,15 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
|
|||||||
promptTokens := common.CountTokenMessages(request.Messages, request.Model)
|
promptTokens := common.CountTokenMessages(request.Messages, request.Model)
|
||||||
Usage, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, true, promptTokens)
|
Usage, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, true, promptTokens)
|
||||||
if openAIErrorWithStatusCode != nil {
|
if openAIErrorWithStatusCode != nil {
|
||||||
return nil, &openAIErrorWithStatusCode.OpenAIError
|
return errors.New(openAIErrorWithStatusCode.Message), &openAIErrorWithStatusCode.OpenAIError
|
||||||
}
|
}
|
||||||
|
|
||||||
if Usage.CompletionTokens == 0 {
|
if Usage.CompletionTokens == 0 {
|
||||||
return errors.New(fmt.Sprintf("channel %s, message 补全 tokens 非预期返回 0", channel.Name)), nil
|
return fmt.Errorf("channel %s, message 补全 tokens 非预期返回 0", channel.Name), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
common.SysLog(fmt.Sprintf("测试模型 %s 返回内容为:%s", channel.Name, w.Body.String()))
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -132,7 +118,6 @@ func TestChannel(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"time": consumedTime,
|
"time": consumedTime,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var testAllChannelsLock sync.Mutex
|
var testAllChannelsLock sync.Mutex
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetAllChannels(c *gin.Context) {
|
func GetAllChannels(c *gin.Context) {
|
||||||
@@ -27,7 +28,6 @@ func GetAllChannels(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": channels,
|
"data": channels,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func SearchChannels(c *gin.Context) {
|
func SearchChannels(c *gin.Context) {
|
||||||
@@ -45,7 +45,6 @@ func SearchChannels(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": channels,
|
"data": channels,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetChannel(c *gin.Context) {
|
func GetChannel(c *gin.Context) {
|
||||||
@@ -70,7 +69,6 @@ func GetChannel(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": channel,
|
"data": channel,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func AddChannel(c *gin.Context) {
|
func AddChannel(c *gin.Context) {
|
||||||
@@ -106,7 +104,6 @@ func AddChannel(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteChannel(c *gin.Context) {
|
func DeleteChannel(c *gin.Context) {
|
||||||
@@ -124,7 +121,6 @@ func DeleteChannel(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteDisabledChannel(c *gin.Context) {
|
func DeleteDisabledChannel(c *gin.Context) {
|
||||||
@@ -141,7 +137,6 @@ func DeleteDisabledChannel(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": rows,
|
"data": rows,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateChannel(c *gin.Context) {
|
func UpdateChannel(c *gin.Context) {
|
||||||
@@ -167,5 +162,4 @@ func UpdateChannel(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": channel,
|
"data": channel,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,13 +5,14 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-contrib/sessions"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-contrib/sessions"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GitHubOAuthResponse struct {
|
type GitHubOAuthResponse struct {
|
||||||
@@ -211,7 +212,6 @@ func GitHubBind(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "bind",
|
"message": "bind",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenerateOAuthCode(c *gin.Context) {
|
func GenerateOAuthCode(c *gin.Context) {
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
|
|
||||||
func GetGroups(c *gin.Context) {
|
func GetGroups(c *gin.Context) {
|
||||||
groupNames := make([]string, 0)
|
groupNames := make([]string, 0)
|
||||||
for groupName, _ := range common.GroupRatio {
|
for groupName := range common.GroupRatio {
|
||||||
groupNames = append(groupNames, groupName)
|
groupNames = append(groupNames, groupName)
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetAllLogs(c *gin.Context) {
|
func GetAllLogs(c *gin.Context) {
|
||||||
@@ -33,7 +34,6 @@ func GetAllLogs(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": logs,
|
"data": logs,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserLogs(c *gin.Context) {
|
func GetUserLogs(c *gin.Context) {
|
||||||
@@ -60,7 +60,6 @@ func GetUserLogs(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": logs,
|
"data": logs,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func SearchAllLogs(c *gin.Context) {
|
func SearchAllLogs(c *gin.Context) {
|
||||||
@@ -78,7 +77,6 @@ func SearchAllLogs(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": logs,
|
"data": logs,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func SearchUserLogs(c *gin.Context) {
|
func SearchUserLogs(c *gin.Context) {
|
||||||
@@ -97,7 +95,6 @@ func SearchUserLogs(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": logs,
|
"data": logs,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetLogsStat(c *gin.Context) {
|
func GetLogsStat(c *gin.Context) {
|
||||||
@@ -118,7 +115,6 @@ func GetLogsStat(c *gin.Context) {
|
|||||||
//"token": tokenNum,
|
//"token": tokenNum,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetLogsSelfStat(c *gin.Context) {
|
func GetLogsSelfStat(c *gin.Context) {
|
||||||
@@ -139,7 +135,6 @@ func GetLogsSelfStat(c *gin.Context) {
|
|||||||
//"token": tokenNum,
|
//"token": tokenNum,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteHistoryLogs(c *gin.Context) {
|
func DeleteHistoryLogs(c *gin.Context) {
|
||||||
@@ -164,5 +159,4 @@ func DeleteHistoryLogs(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": count,
|
"data": count,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -35,7 +35,6 @@ func GetStatus(c *gin.Context) {
|
|||||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetNotice(c *gin.Context) {
|
func GetNotice(c *gin.Context) {
|
||||||
@@ -46,7 +45,6 @@ func GetNotice(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": common.OptionMap["Notice"],
|
"data": common.OptionMap["Notice"],
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAbout(c *gin.Context) {
|
func GetAbout(c *gin.Context) {
|
||||||
@@ -57,7 +55,6 @@ func GetAbout(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": common.OptionMap["About"],
|
"data": common.OptionMap["About"],
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetHomePageContent(c *gin.Context) {
|
func GetHomePageContent(c *gin.Context) {
|
||||||
@@ -68,7 +65,6 @@ func GetHomePageContent(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": common.OptionMap["HomePageContent"],
|
"data": common.OptionMap["HomePageContent"],
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func SendEmailVerification(c *gin.Context) {
|
func SendEmailVerification(c *gin.Context) {
|
||||||
@@ -121,7 +117,6 @@ func SendEmailVerification(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func SendPasswordResetEmail(c *gin.Context) {
|
func SendPasswordResetEmail(c *gin.Context) {
|
||||||
@@ -160,7 +155,6 @@ func SendPasswordResetEmail(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type PasswordResetRequest struct {
|
type PasswordResetRequest struct {
|
||||||
@@ -200,5 +194,4 @@ func ResetPassword(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": password,
|
"data": password,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,11 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
|
"sort"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -25,550 +29,38 @@ type OpenAIModelPermission struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OpenAIModels struct {
|
type OpenAIModels struct {
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
Created int `json:"created"`
|
Created int `json:"created"`
|
||||||
OwnedBy string `json:"owned_by"`
|
OwnedBy *string `json:"owned_by"`
|
||||||
Permission []OpenAIModelPermission `json:"permission"`
|
Permission *[]OpenAIModelPermission `json:"permission"`
|
||||||
Root string `json:"root"`
|
Root *string `json:"root"`
|
||||||
Parent *string `json:"parent"`
|
Parent *string `json:"parent"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var openAIModels []OpenAIModels
|
var openAIModels []OpenAIModels
|
||||||
var openAIModelsMap map[string]OpenAIModels
|
var openAIModelsMap map[string]OpenAIModels
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
var permission []OpenAIModelPermission
|
|
||||||
permission = append(permission, OpenAIModelPermission{
|
|
||||||
Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
|
|
||||||
Object: "model_permission",
|
|
||||||
Created: 1626777600,
|
|
||||||
AllowCreateEngine: true,
|
|
||||||
AllowSampling: true,
|
|
||||||
AllowLogprobs: true,
|
|
||||||
AllowSearchIndices: false,
|
|
||||||
AllowView: true,
|
|
||||||
AllowFineTuning: false,
|
|
||||||
Organization: "*",
|
|
||||||
Group: nil,
|
|
||||||
IsBlocking: false,
|
|
||||||
})
|
|
||||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||||
openAIModels = []OpenAIModels{
|
keys := make([]string, 0, len(common.ModelRatio))
|
||||||
{
|
for k := range common.ModelRatio {
|
||||||
Id: "dall-e-2",
|
keys = append(keys, k)
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "dall-e-2",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "dall-e-3",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "dall-e-3",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "whisper-1",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "whisper-1",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "tts-1",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "tts-1",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "tts-1-1106",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "tts-1-1106",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "tts-1-hd",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "tts-1-hd",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "tts-1-hd-1106",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "tts-1-hd-1106",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-3.5-turbo",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-3.5-turbo",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-3.5-turbo-0301",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-3.5-turbo-0301",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-3.5-turbo-0613",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-3.5-turbo-0613",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-3.5-turbo-16k",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-3.5-turbo-16k",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-3.5-turbo-16k-0613",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-3.5-turbo-16k-0613",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-3.5-turbo-1106",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1699593571,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-3.5-turbo-1106",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-3.5-turbo-instruct",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-3.5-turbo-instruct",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-4",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-4",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-4-0314",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-4-0314",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-4-0613",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-4-0613",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-4-32k",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-4-32k",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-4-32k-0314",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-4-32k-0314",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-4-32k-0613",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-4-32k-0613",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-4-1106-preview",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1699593571,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-4-1106-preview",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gpt-4-vision-preview",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1699593571,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gpt-4-vision-preview",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "text-embedding-ada-002",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "text-embedding-ada-002",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "text-davinci-003",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "text-davinci-003",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "text-davinci-002",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "text-davinci-002",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "text-curie-001",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "text-curie-001",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "text-babbage-001",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "text-babbage-001",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "text-ada-001",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "text-ada-001",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "text-moderation-latest",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "text-moderation-latest",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "text-moderation-stable",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "text-moderation-stable",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "text-davinci-edit-001",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "text-davinci-edit-001",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "code-davinci-edit-001",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "code-davinci-edit-001",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "claude-instant-1",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "anthropic",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "claude-instant-1",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "claude-2",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "anthropic",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "claude-2",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "claude-2.1",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "anthropic",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "claude-2.1",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "claude-2.0",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "anthropic",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "claude-2.0",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "ERNIE-Bot",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "baidu",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "ERNIE-Bot",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "ERNIE-Bot-turbo",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "baidu",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "ERNIE-Bot-turbo",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "ERNIE-Bot-4",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "baidu",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "ERNIE-Bot-4",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "Embedding-V1",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "baidu",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "Embedding-V1",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "PaLM-2",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "PaLM-2",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "gemini-pro",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "gemini-pro",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "chatglm_turbo",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "zhipu",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "chatglm_turbo",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "chatglm_pro",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "zhipu",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "chatglm_pro",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "chatglm_std",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "zhipu",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "chatglm_std",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "chatglm_lite",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "zhipu",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "chatglm_lite",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "qwen-turbo",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "ali",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "qwen-turbo",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "qwen-plus",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "ali",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "qwen-plus",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "qwen-max",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "ali",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "qwen-max",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "qwen-max-longcontext",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "ali",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "qwen-max-longcontext",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "text-embedding-v1",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "ali",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "text-embedding-v1",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "SparkDesk",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "xunfei",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "SparkDesk",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "360GPT_S2_V9",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "360",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "360GPT_S2_V9",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "embedding-bert-512-v1",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "360",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "embedding-bert-512-v1",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "embedding_s1_v1",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "360",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "embedding_s1_v1",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "semantic_similarity_s1_v1",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "360",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "semantic_similarity_s1_v1",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "hunyuan",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "tencent",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "hunyuan",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
sort.Strings(keys)
|
||||||
|
|
||||||
|
for _, modelId := range keys {
|
||||||
|
openAIModels = append(openAIModels, OpenAIModels{
|
||||||
|
Id: modelId,
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: nil,
|
||||||
|
Permission: nil,
|
||||||
|
Root: nil,
|
||||||
|
Parent: nil,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
openAIModelsMap = make(map[string]OpenAIModels)
|
openAIModelsMap = make(map[string]OpenAIModels)
|
||||||
for _, model := range openAIModels {
|
for _, model := range openAIModels {
|
||||||
openAIModelsMap[model.Id] = model
|
openAIModelsMap[model.Id] = model
|
||||||
@@ -576,6 +68,35 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func ListModels(c *gin.Context) {
|
func ListModels(c *gin.Context) {
|
||||||
|
groupName := c.GetString("group")
|
||||||
|
|
||||||
|
models, err := model.CacheGetGroupModels(groupName)
|
||||||
|
if err != nil {
|
||||||
|
common.AbortWithMessage(c, http.StatusServiceUnavailable, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sort.Strings(models)
|
||||||
|
|
||||||
|
groupOpenAIModels := make([]OpenAIModels, 0, len(models))
|
||||||
|
for _, modelId := range models {
|
||||||
|
groupOpenAIModels = append(groupOpenAIModels, OpenAIModels{
|
||||||
|
Id: modelId,
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: nil,
|
||||||
|
Permission: nil,
|
||||||
|
Root: nil,
|
||||||
|
Parent: nil,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"object": "list",
|
||||||
|
"data": groupOpenAIModels,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListModelsForAdmin(c *gin.Context) {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"object": "list",
|
"object": "list",
|
||||||
"data": openAIModels,
|
"data": openAIModels,
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ func RelayChat(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取供应商
|
// 获取供应商
|
||||||
provider, pass := getProvider(c, channel.Type, common.RelayModeChatCompletions)
|
provider, pass := getProvider(c, channel, common.RelayModeChatCompletions)
|
||||||
if pass {
|
if pass {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ func RelayCompletions(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取供应商
|
// 获取供应商
|
||||||
provider, pass := getProvider(c, channel.Type, common.RelayModeCompletions)
|
provider, pass := getProvider(c, channel, common.RelayModeCompletions)
|
||||||
if pass {
|
if pass {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ func RelayEmbeddings(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取供应商
|
// 获取供应商
|
||||||
provider, pass := getProvider(c, channel.Type, common.RelayModeEmbeddings)
|
provider, pass := getProvider(c, channel, common.RelayModeEmbeddings)
|
||||||
if pass {
|
if pass {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ func RelayImageEdits(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取供应商
|
// 获取供应商
|
||||||
provider, pass := getProvider(c, channel.Type, common.RelayModeImagesEdits)
|
provider, pass := getProvider(c, channel, common.RelayModeImagesEdits)
|
||||||
if pass {
|
if pass {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ func RelayImageGenerations(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取供应商
|
// 获取供应商
|
||||||
provider, pass := getProvider(c, channel.Type, common.RelayModeImagesGenerations)
|
provider, pass := getProvider(c, channel, common.RelayModeImagesGenerations)
|
||||||
if pass {
|
if pass {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ func RelayImageVariations(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取供应商
|
// 获取供应商
|
||||||
provider, pass := getProvider(c, channel.Type, common.RelayModeImagesVariations)
|
provider, pass := getProvider(c, channel, common.RelayModeImagesVariations)
|
||||||
if pass {
|
if pass {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ func RelayModerations(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取供应商
|
// 获取供应商
|
||||||
provider, pass := getProvider(c, channel.Type, common.RelayModeModerations)
|
provider, pass := getProvider(c, channel, common.RelayModeModerations)
|
||||||
if pass {
|
if pass {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ func RelaySpeech(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取供应商
|
// 获取供应商
|
||||||
provider, pass := getProvider(c, channel.Type, common.RelayModeAudioSpeech)
|
provider, pass := getProvider(c, channel, common.RelayModeAudioSpeech)
|
||||||
if pass {
|
if pass {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ func RelayTranscriptions(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取供应商
|
// 获取供应商
|
||||||
provider, pass := getProvider(c, channel.Type, common.RelayModeAudioTranscription)
|
provider, pass := getProvider(c, channel, common.RelayModeAudioTranscription)
|
||||||
if pass {
|
if pass {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ func RelayTranslations(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取供应商
|
// 获取供应商
|
||||||
provider, pass := getProvider(c, channel.Type, common.RelayModeAudioTranslation)
|
provider, pass := getProvider(c, channel, common.RelayModeAudioTranslation)
|
||||||
if pass {
|
if pass {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,7 +45,8 @@ func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, pas
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
setChannelToContext(c, channel)
|
c.Set("channel_id", channel.Id)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -84,8 +85,8 @@ func fetchChannelByModel(c *gin.Context, modelName string) (*model.Channel, bool
|
|||||||
return channel, false
|
return channel, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func getProvider(c *gin.Context, channelType int, relayMode int) (providersBase.ProviderInterface, bool) {
|
func getProvider(c *gin.Context, channel *model.Channel, relayMode int) (providersBase.ProviderInterface, bool) {
|
||||||
provider := providers.GetProvider(channelType, c)
|
provider := providers.GetProvider(channel, c)
|
||||||
if provider == nil {
|
if provider == nil {
|
||||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not found")
|
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not found")
|
||||||
return nil, true
|
return nil, true
|
||||||
@@ -99,27 +100,6 @@ func getProvider(c *gin.Context, channelType int, relayMode int) (providersBase.
|
|||||||
return provider, false
|
return provider, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func setChannelToContext(c *gin.Context, channel *model.Channel) {
|
|
||||||
// c.Set("channel", channel.Type)
|
|
||||||
c.Set("channel_id", channel.Id)
|
|
||||||
c.Set("channel_name", channel.Name)
|
|
||||||
c.Set("api_key", channel.Key)
|
|
||||||
c.Set("base_url", channel.GetBaseURL())
|
|
||||||
switch channel.Type {
|
|
||||||
case common.ChannelTypeAzure:
|
|
||||||
c.Set("api_version", channel.Other)
|
|
||||||
case common.ChannelTypeXunfei:
|
|
||||||
c.Set("api_version", channel.Other)
|
|
||||||
case common.ChannelTypeGemini:
|
|
||||||
c.Set("api_version", channel.Other)
|
|
||||||
case common.ChannelTypeAIProxyLibrary:
|
|
||||||
c.Set("library_id", channel.Other)
|
|
||||||
case common.ChannelTypeAli:
|
|
||||||
c.Set("plugin", channel.Other)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func shouldDisableChannel(err *types.OpenAIError, statusCode int) bool {
|
func shouldDisableChannel(err *types.OpenAIError, statusCode int) bool {
|
||||||
if !common.AutomaticDisableChannelEnabled {
|
if !common.AutomaticDisableChannelEnabled {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetAllTokens(c *gin.Context) {
|
func GetAllTokens(c *gin.Context) {
|
||||||
@@ -27,7 +28,6 @@ func GetAllTokens(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": tokens,
|
"data": tokens,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func SearchTokens(c *gin.Context) {
|
func SearchTokens(c *gin.Context) {
|
||||||
@@ -46,7 +46,6 @@ func SearchTokens(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": tokens,
|
"data": tokens,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetToken(c *gin.Context) {
|
func GetToken(c *gin.Context) {
|
||||||
@@ -72,7 +71,6 @@ func GetToken(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": token,
|
"data": token,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetTokenStatus(c *gin.Context) {
|
func GetTokenStatus(c *gin.Context) {
|
||||||
@@ -138,7 +136,6 @@ func AddToken(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteToken(c *gin.Context) {
|
func DeleteToken(c *gin.Context) {
|
||||||
@@ -156,7 +153,6 @@ func DeleteToken(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateToken(c *gin.Context) {
|
func UpdateToken(c *gin.Context) {
|
||||||
@@ -224,5 +220,4 @@ func UpdateToken(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": cleanToken,
|
"data": cleanToken,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -174,7 +174,6 @@ func Register(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllUsers(c *gin.Context) {
|
func GetAllUsers(c *gin.Context) {
|
||||||
@@ -195,7 +194,6 @@ func GetAllUsers(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": users,
|
"data": users,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func SearchUsers(c *gin.Context) {
|
func SearchUsers(c *gin.Context) {
|
||||||
@@ -213,7 +211,6 @@ func SearchUsers(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": users,
|
"data": users,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUser(c *gin.Context) {
|
func GetUser(c *gin.Context) {
|
||||||
@@ -246,7 +243,6 @@ func GetUser(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": user,
|
"data": user,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserDashboard(c *gin.Context) {
|
func GetUserDashboard(c *gin.Context) {
|
||||||
@@ -306,7 +302,6 @@ func GenerateAccessToken(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": user.AccessToken,
|
"data": user.AccessToken,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAffCode(c *gin.Context) {
|
func GetAffCode(c *gin.Context) {
|
||||||
@@ -334,7 +329,6 @@ func GetAffCode(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": user.AffCode,
|
"data": user.AffCode,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetSelf(c *gin.Context) {
|
func GetSelf(c *gin.Context) {
|
||||||
@@ -352,7 +346,6 @@ func GetSelf(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": user,
|
"data": user,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateUser(c *gin.Context) {
|
func UpdateUser(c *gin.Context) {
|
||||||
@@ -416,7 +409,6 @@ func UpdateUser(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateSelf(c *gin.Context) {
|
func UpdateSelf(c *gin.Context) {
|
||||||
@@ -463,7 +455,6 @@ func UpdateSelf(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteUser(c *gin.Context) {
|
func DeleteUser(c *gin.Context) {
|
||||||
@@ -525,7 +516,6 @@ func DeleteSelf(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateUser(c *gin.Context) {
|
func CreateUser(c *gin.Context) {
|
||||||
@@ -574,7 +564,6 @@ func CreateUser(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ManageRequest struct {
|
type ManageRequest struct {
|
||||||
@@ -691,7 +680,6 @@ func ManageUser(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": clearUser,
|
"data": clearUser,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func EmailBind(c *gin.Context) {
|
func EmailBind(c *gin.Context) {
|
||||||
@@ -733,7 +721,6 @@ func EmailBind(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type topUpRequest struct {
|
type topUpRequest struct {
|
||||||
@@ -764,5 +751,4 @@ func TopUp(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": quota,
|
"data": quota,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,12 +4,13 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
type wechatLoginResponse struct {
|
type wechatLoginResponse struct {
|
||||||
@@ -160,5 +161,4 @@ func WeChatBind(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|||||||
6
go.mod
6
go.mod
@@ -16,7 +16,7 @@ require (
|
|||||||
github.com/gorilla/websocket v1.5.0
|
github.com/gorilla/websocket v1.5.0
|
||||||
github.com/pkoukk/tiktoken-go v0.1.5
|
github.com/pkoukk/tiktoken-go v0.1.5
|
||||||
github.com/stretchr/testify v1.8.3
|
github.com/stretchr/testify v1.8.3
|
||||||
golang.org/x/crypto v0.14.0
|
golang.org/x/crypto v0.17.0
|
||||||
golang.org/x/image v0.14.0
|
golang.org/x/image v0.14.0
|
||||||
gorm.io/driver/mysql v1.4.3
|
gorm.io/driver/mysql v1.4.3
|
||||||
gorm.io/driver/postgres v1.5.2
|
gorm.io/driver/postgres v1.5.2
|
||||||
@@ -58,8 +58,8 @@ require (
|
|||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||||
golang.org/x/arch v0.3.0 // indirect
|
golang.org/x/arch v0.3.0 // indirect
|
||||||
golang.org/x/net v0.17.0 // indirect
|
golang.org/x/net v0.19.0 // indirect
|
||||||
golang.org/x/sys v0.13.0 // indirect
|
golang.org/x/sys v0.15.0 // indirect
|
||||||
golang.org/x/text v0.14.0 // indirect
|
golang.org/x/text v0.14.0 // indirect
|
||||||
google.golang.org/protobuf v1.30.0 // indirect
|
google.golang.org/protobuf v1.30.0 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
|||||||
10
go.sum
10
go.sum
@@ -152,13 +152,15 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
|
|||||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
||||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
|
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
|
||||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
|
||||||
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
|
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
|
||||||
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
||||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||||
|
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
|
||||||
|
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
|
||||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
@@ -166,8 +168,8 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||||||
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
|
||||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"runtime/debug"
|
||||||
)
|
)
|
||||||
|
|
||||||
func RelayPanicRecover() gin.HandlerFunc {
|
func RelayPanicRecover() gin.HandlerFunc {
|
||||||
@@ -12,6 +13,7 @@ func RelayPanicRecover() gin.HandlerFunc {
|
|||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
common.SysError(fmt.Sprintf("panic detected: %v", err))
|
common.SysError(fmt.Sprintf("panic detected: %v", err))
|
||||||
|
common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{
|
c.JSON(http.StatusInternalServerError, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err),
|
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err),
|
||||||
|
|||||||
@@ -39,6 +39,22 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
|||||||
return &channel, err
|
return &channel, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetGroupModels(group string) ([]string, error) {
|
||||||
|
var models []string
|
||||||
|
groupCol := "`group`"
|
||||||
|
trueVal := "1"
|
||||||
|
if common.UsingPostgreSQL {
|
||||||
|
groupCol = `"group"`
|
||||||
|
trueVal = "true"
|
||||||
|
}
|
||||||
|
|
||||||
|
err := DB.Model(&Ability{}).Where(groupCol+" = ? and enabled = ? ", group, trueVal).Distinct("model").Pluck("model", &models).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return models, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (channel *Channel) AddAbilities() error {
|
func (channel *Channel) AddAbilities() error {
|
||||||
models_ := strings.Split(channel.Models, ",")
|
models_ := strings.Split(channel.Models, ",")
|
||||||
groups_ := strings.Split(channel.Group, ",")
|
groups_ := strings.Split(channel.Group, ",")
|
||||||
|
|||||||
@@ -213,3 +213,22 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
|
|||||||
idx := rand.Intn(endIdx)
|
idx := rand.Intn(endIdx)
|
||||||
return channels[idx], nil
|
return channels[idx], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CacheGetGroupModels(group string) ([]string, error) {
|
||||||
|
if !common.MemoryCacheEnabled {
|
||||||
|
return GetGroupModels(group)
|
||||||
|
}
|
||||||
|
channelSyncLock.RLock()
|
||||||
|
defer channelSyncLock.RUnlock()
|
||||||
|
|
||||||
|
groupModels := group2model2channels[group]
|
||||||
|
if groupModels == nil {
|
||||||
|
return nil, errors.New("group not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
models := make([]string, 0)
|
||||||
|
for model := range groupModels {
|
||||||
|
models = append(models, model)
|
||||||
|
}
|
||||||
|
return models, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ type Channel struct {
|
|||||||
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
||||||
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
||||||
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||||||
|
Proxy string `json:"proxy" gorm:"type:varchar(255);default:''"`
|
||||||
|
TestModel string `json:"test_model" gorm:"type:varchar(50);default:''"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
||||||
|
|||||||
@@ -200,6 +200,10 @@ func SearchLogsByDayAndModel(user_id, start, end int) (LogStatistics []*LogStati
|
|||||||
groupSelect = "TO_CHAR(date_trunc('day', to_timestamp(created_at)), 'YYYY-MM-DD') as day"
|
groupSelect = "TO_CHAR(date_trunc('day', to_timestamp(created_at)), 'YYYY-MM-DD') as day"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if common.UsingSQLite {
|
||||||
|
groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day"
|
||||||
|
}
|
||||||
|
|
||||||
err = DB.Raw(`
|
err = DB.Raw(`
|
||||||
SELECT `+groupSelect+`,
|
SELECT `+groupSelect+`,
|
||||||
model_name, count(1) as request_count,
|
model_name, count(1) as request_count,
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ func (p *Aigc2dProvider) Balance(channel *model.Channel) (float64, error) {
|
|||||||
|
|
||||||
// 发送请求
|
// 发送请求
|
||||||
var response base.BalanceResponse
|
var response base.BalanceResponse
|
||||||
_, errWithCode := common.SendRequest(req, &response, false)
|
_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return 0, errors.New(errWithCode.OpenAIError.Message)
|
return 0, errors.New(errWithCode.OpenAIError.Message)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ func (p *AIProxyProvider) Balance(channel *model.Channel) (float64, error) {
|
|||||||
|
|
||||||
// 发送请求
|
// 发送请求
|
||||||
var response AIProxyUserOverviewResponse
|
var response AIProxyUserOverviewResponse
|
||||||
_, errWithCode := common.SendRequest(req, &response, false)
|
_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return 0, errors.New(errWithCode.OpenAIError.Message)
|
return 0, errors.New(errWithCode.OpenAIError.Message)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package ali
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"one-api/providers/base"
|
"one-api/providers/base"
|
||||||
|
|
||||||
@@ -28,13 +29,23 @@ type AliProvider struct {
|
|||||||
base.BaseProvider
|
base.BaseProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *AliProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||||
|
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||||
|
|
||||||
|
if modelName == "qwen-vl-plus" {
|
||||||
|
requestURL = "/api/v1/services/aigc/multimodal-generation/generation"
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||||
|
}
|
||||||
|
|
||||||
// 获取请求头
|
// 获取请求头
|
||||||
func (p *AliProvider) GetRequestHeaders() (headers map[string]string) {
|
func (p *AliProvider) GetRequestHeaders() (headers map[string]string) {
|
||||||
headers = make(map[string]string)
|
headers = make(map[string]string)
|
||||||
p.CommonRequestHeaders(headers)
|
p.CommonRequestHeaders(headers)
|
||||||
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key"))
|
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Channel.Key)
|
||||||
if p.Context.GetString("plugin") != "" {
|
if p.Channel.Other != "" {
|
||||||
headers["X-DashScope-Plugin"] = p.Context.GetString("plugin")
|
headers["X-DashScope-Plugin"] = p.Channel.Other
|
||||||
}
|
}
|
||||||
|
|
||||||
return headers
|
return headers
|
||||||
|
|||||||
@@ -26,20 +26,12 @@ func (aliResponse *AliChatResponse) ResponseHandler(resp *http.Response) (OpenAI
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
choice := types.ChatCompletionChoice{
|
|
||||||
Index: 0,
|
|
||||||
Message: types.ChatCompletionMessage{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: aliResponse.Output.Text,
|
|
||||||
},
|
|
||||||
FinishReason: aliResponse.Output.FinishReason,
|
|
||||||
}
|
|
||||||
|
|
||||||
OpenAIResponse = types.ChatCompletionResponse{
|
OpenAIResponse = types.ChatCompletionResponse{
|
||||||
ID: aliResponse.RequestId,
|
ID: aliResponse.RequestId,
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
Created: common.GetTimestamp(),
|
Created: common.GetTimestamp(),
|
||||||
Choices: []types.ChatCompletionChoice{choice},
|
Model: aliResponse.Model,
|
||||||
|
Choices: aliResponse.Output.ToChatCompletionChoices(),
|
||||||
Usage: &types.Usage{
|
Usage: &types.Usage{
|
||||||
PromptTokens: aliResponse.Usage.InputTokens,
|
PromptTokens: aliResponse.Usage.InputTokens,
|
||||||
CompletionTokens: aliResponse.Usage.OutputTokens,
|
CompletionTokens: aliResponse.Usage.OutputTokens,
|
||||||
@@ -50,21 +42,57 @@ func (aliResponse *AliChatResponse) ResponseHandler(resp *http.Response) (OpenAI
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const AliEnableSearchModelSuffix = "-internet"
|
||||||
|
|
||||||
// 获取聊天请求体
|
// 获取聊天请求体
|
||||||
func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest {
|
func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest {
|
||||||
messages := make([]AliMessage, 0, len(request.Messages))
|
messages := make([]AliMessage, 0, len(request.Messages))
|
||||||
for i := 0; i < len(request.Messages); i++ {
|
for i := 0; i < len(request.Messages); i++ {
|
||||||
message := request.Messages[i]
|
message := request.Messages[i]
|
||||||
messages = append(messages, AliMessage{
|
if request.Model != "qwen-vl-plus" {
|
||||||
Content: message.StringContent(),
|
messages = append(messages, AliMessage{
|
||||||
Role: strings.ToLower(message.Role),
|
Content: message.StringContent(),
|
||||||
})
|
Role: strings.ToLower(message.Role),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
openaiContent := message.ParseContent()
|
||||||
|
var parts []AliMessagePart
|
||||||
|
for _, part := range openaiContent {
|
||||||
|
if part.Type == types.ContentTypeText {
|
||||||
|
parts = append(parts, AliMessagePart{
|
||||||
|
Text: part.Text,
|
||||||
|
})
|
||||||
|
} else if part.Type == types.ContentTypeImageURL {
|
||||||
|
parts = append(parts, AliMessagePart{
|
||||||
|
Image: part.ImageURL.URL,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
messages = append(messages, AliMessage{
|
||||||
|
Content: parts,
|
||||||
|
Role: strings.ToLower(message.Role),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enableSearch := false
|
||||||
|
aliModel := request.Model
|
||||||
|
if strings.HasSuffix(aliModel, AliEnableSearchModelSuffix) {
|
||||||
|
enableSearch = true
|
||||||
|
aliModel = strings.TrimSuffix(aliModel, AliEnableSearchModelSuffix)
|
||||||
|
}
|
||||||
|
|
||||||
return &AliChatRequest{
|
return &AliChatRequest{
|
||||||
Model: request.Model,
|
Model: aliModel,
|
||||||
Input: AliInput{
|
Input: AliInput{
|
||||||
Messages: messages,
|
Messages: messages,
|
||||||
},
|
},
|
||||||
|
Parameters: AliParameters{
|
||||||
|
ResultFormat: "message",
|
||||||
|
EnableSearch: enableSearch,
|
||||||
|
IncrementalOutput: request.Stream,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -72,6 +100,7 @@ func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *
|
|||||||
func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
|
||||||
requestBody := p.getChatRequestBody(request)
|
requestBody := p.getChatRequestBody(request)
|
||||||
|
|
||||||
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||||
headers := p.GetRequestHeaders()
|
headers := p.GetRequestHeaders()
|
||||||
if request.Stream {
|
if request.Stream {
|
||||||
@@ -86,7 +115,7 @@ func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMa
|
|||||||
}
|
}
|
||||||
|
|
||||||
if request.Stream {
|
if request.Stream {
|
||||||
usage, errWithCode = p.sendStreamRequest(req)
|
usage, errWithCode = p.sendStreamRequest(req, request.Model)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -100,7 +129,9 @@ func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMa
|
|||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
aliResponse := &AliChatResponse{}
|
aliResponse := &AliChatResponse{
|
||||||
|
Model: request.Model,
|
||||||
|
}
|
||||||
errWithCode = p.SendRequest(req, aliResponse, false)
|
errWithCode = p.SendRequest(req, aliResponse, false)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return
|
return
|
||||||
@@ -117,10 +148,15 @@ func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMa
|
|||||||
|
|
||||||
// 阿里云响应转OpenAI响应
|
// 阿里云响应转OpenAI响应
|
||||||
func (p *AliProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *types.ChatCompletionStreamResponse {
|
func (p *AliProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *types.ChatCompletionStreamResponse {
|
||||||
|
// chatChoice := aliResponse.Output.ToChatCompletionChoices()
|
||||||
|
// jsonBody, _ := json.MarshalIndent(chatChoice, "", " ")
|
||||||
|
// fmt.Println("requestBody:", string(jsonBody))
|
||||||
var choice types.ChatCompletionStreamChoice
|
var choice types.ChatCompletionStreamChoice
|
||||||
choice.Delta.Content = aliResponse.Output.Text
|
choice.Index = aliResponse.Output.Choices[0].Index
|
||||||
if aliResponse.Output.FinishReason != "null" {
|
choice.Delta.Content = aliResponse.Output.Choices[0].Message.StringContent()
|
||||||
finishReason := aliResponse.Output.FinishReason
|
// fmt.Println("choice.Delta.Content:", chatChoice[0].Message)
|
||||||
|
if aliResponse.Output.Choices[0].FinishReason != "null" {
|
||||||
|
finishReason := aliResponse.Output.Choices[0].FinishReason
|
||||||
choice.FinishReason = &finishReason
|
choice.FinishReason = &finishReason
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -128,22 +164,24 @@ func (p *AliProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ty
|
|||||||
ID: aliResponse.RequestId,
|
ID: aliResponse.RequestId,
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Created: common.GetTimestamp(),
|
Created: common.GetTimestamp(),
|
||||||
Model: "ernie-bot",
|
Model: aliResponse.Model,
|
||||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||||
}
|
}
|
||||||
return &response
|
return &response
|
||||||
}
|
}
|
||||||
|
|
||||||
// 发送流请求
|
// 发送流请求
|
||||||
func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
func (p *AliProvider) sendStreamRequest(req *http.Request, model string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||||
defer req.Body.Close()
|
defer req.Body.Close()
|
||||||
|
|
||||||
usage = &types.Usage{}
|
usage = &types.Usage{}
|
||||||
// 发送请求
|
// 发送请求
|
||||||
resp, err := common.HttpClient.Do(req)
|
client := common.GetHttpClient(p.Channel.Proxy)
|
||||||
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
common.PutHttpClient(client)
|
||||||
|
|
||||||
if common.IsFailureStatusCode(resp) {
|
if common.IsFailureStatusCode(resp) {
|
||||||
return nil, common.HandleErrorResp(resp)
|
return nil, common.HandleErrorResp(resp)
|
||||||
@@ -182,6 +220,7 @@ func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage,
|
|||||||
}()
|
}()
|
||||||
common.SetEventStreamHeaders(p.Context)
|
common.SetEventStreamHeaders(p.Context)
|
||||||
lastResponseText := ""
|
lastResponseText := ""
|
||||||
|
index := 0
|
||||||
p.Context.Stream(func(w io.Writer) bool {
|
p.Context.Stream(func(w io.Writer) bool {
|
||||||
select {
|
select {
|
||||||
case data := <-dataChan:
|
case data := <-dataChan:
|
||||||
@@ -196,9 +235,12 @@ func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage,
|
|||||||
usage.CompletionTokens = aliResponse.Usage.OutputTokens
|
usage.CompletionTokens = aliResponse.Usage.OutputTokens
|
||||||
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
|
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
|
||||||
}
|
}
|
||||||
|
aliResponse.Model = model
|
||||||
|
aliResponse.Output.Choices[0].Index = index
|
||||||
|
index++
|
||||||
response := p.streamResponseAli2OpenAI(&aliResponse)
|
response := p.streamResponseAli2OpenAI(&aliResponse)
|
||||||
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
|
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
|
||||||
lastResponseText = aliResponse.Output.Text
|
lastResponseText = aliResponse.Output.Choices[0].Message.StringContent()
|
||||||
jsonResponse, err := json.Marshal(response)
|
jsonResponse, err := json.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
|||||||
@@ -1,5 +1,9 @@
|
|||||||
package ali
|
package ali
|
||||||
|
|
||||||
|
import (
|
||||||
|
"one-api/types"
|
||||||
|
)
|
||||||
|
|
||||||
type AliError struct {
|
type AliError struct {
|
||||||
Code string `json:"code"`
|
Code string `json:"code"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
@@ -13,20 +17,27 @@ type AliUsage struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type AliMessage struct {
|
type AliMessage struct {
|
||||||
Content string `json:"content"`
|
Content any `json:"content"`
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AliMessagePart struct {
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
Image string `json:"image,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
type AliInput struct {
|
type AliInput struct {
|
||||||
// Prompt string `json:"prompt"`
|
// Prompt string `json:"prompt"`
|
||||||
Messages []AliMessage `json:"messages"`
|
Messages []AliMessage `json:"messages"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AliParameters struct {
|
type AliParameters struct {
|
||||||
TopP float64 `json:"top_p,omitempty"`
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
TopK int `json:"top_k,omitempty"`
|
TopK int `json:"top_k,omitempty"`
|
||||||
Seed uint64 `json:"seed,omitempty"`
|
Seed uint64 `json:"seed,omitempty"`
|
||||||
EnableSearch bool `json:"enable_search,omitempty"`
|
EnableSearch bool `json:"enable_search,omitempty"`
|
||||||
|
IncrementalOutput bool `json:"incremental_output,omitempty"`
|
||||||
|
ResultFormat string `json:"result_format,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AliChatRequest struct {
|
type AliChatRequest struct {
|
||||||
@@ -35,14 +46,31 @@ type AliChatRequest struct {
|
|||||||
Parameters AliParameters `json:"parameters,omitempty"`
|
Parameters AliParameters `json:"parameters,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AliChoice struct {
|
||||||
|
FinishReason string `json:"finish_reason"`
|
||||||
|
Message types.ChatCompletionMessage `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
type AliOutput struct {
|
type AliOutput struct {
|
||||||
Text string `json:"text"`
|
Choices []types.ChatCompletionChoice `json:"choices"`
|
||||||
FinishReason string `json:"finish_reason"`
|
}
|
||||||
|
|
||||||
|
func (o *AliOutput) ToChatCompletionChoices() []types.ChatCompletionChoice {
|
||||||
|
for i := range o.Choices {
|
||||||
|
_, ok := o.Choices[i].Message.Content.(string)
|
||||||
|
if ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
o.Choices[i].Message.Content = o.Choices[i].Message.ParseContent()
|
||||||
|
}
|
||||||
|
return o.Choices
|
||||||
}
|
}
|
||||||
|
|
||||||
type AliChatResponse struct {
|
type AliChatResponse struct {
|
||||||
Output AliOutput `json:"output"`
|
Output AliOutput `json:"output"`
|
||||||
Usage AliUsage `json:"usage"`
|
Usage AliUsage `json:"usage"`
|
||||||
|
Model string `json:"model,omitempty"`
|
||||||
AliError
|
AliError
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ func (p *Api2dProvider) Balance(channel *model.Channel) (float64, error) {
|
|||||||
|
|
||||||
// 发送请求
|
// 发送请求
|
||||||
var response base.BalanceResponse
|
var response base.BalanceResponse
|
||||||
_, errWithCode := common.SendRequest(req, &response, false)
|
_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return 0, errors.New(errWithCode.OpenAIError.Message)
|
return 0, errors.New(errWithCode.OpenAIError.Message)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ func (p *Api2gptProvider) Balance(channel *model.Channel) (float64, error) {
|
|||||||
|
|
||||||
// 发送请求
|
// 发送请求
|
||||||
var response base.BalanceResponse
|
var response base.BalanceResponse
|
||||||
_, errWithCode := common.SendRequest(req, &response, false)
|
_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return 0, errors.New(errWithCode.OpenAIError.Message)
|
return 0, errors.New(errWithCode.OpenAIError.Message)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ func (c *ImageAzureResponse) ResponseHandler(resp *http.Response) (OpenAIRespons
|
|||||||
for i := 0; i < 3; i++ {
|
for i := 0; i < 3; i++ {
|
||||||
// 休眠 2 秒
|
// 休眠 2 秒
|
||||||
time.Sleep(2 * time.Second)
|
time.Sleep(2 * time.Second)
|
||||||
_, errWithCode = common.SendRequest(req, &getImageAzureResponse, false)
|
_, errWithCode = common.SendRequest(req, &getImageAzureResponse, false, c.Proxy)
|
||||||
fmt.Println("getImageAzureResponse", getImageAzureResponse)
|
fmt.Println("getImageAzureResponse", getImageAzureResponse)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return
|
return
|
||||||
@@ -81,6 +81,7 @@ func (p *AzureProvider) ImageGenerationsAction(request *types.ImageRequest, isMo
|
|||||||
if request.Model == "dall-e-2" {
|
if request.Model == "dall-e-2" {
|
||||||
imageAzureResponse := &ImageAzureResponse{
|
imageAzureResponse := &ImageAzureResponse{
|
||||||
Header: headers,
|
Header: headers,
|
||||||
|
Proxy: p.Channel.Proxy,
|
||||||
}
|
}
|
||||||
errWithCode = p.SendRequest(req, imageAzureResponse, false)
|
errWithCode = p.SendRequest(req, imageAzureResponse, false)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ type ImageAzureResponse struct {
|
|||||||
Status string `json:"status,omitempty"`
|
Status string `json:"status,omitempty"`
|
||||||
Error ImageAzureError `json:"error,omitempty"`
|
Error ImageAzureError `json:"error,omitempty"`
|
||||||
Header map[string]string `json:"header,omitempty"`
|
Header map[string]string `json:"header,omitempty"`
|
||||||
|
Proxy string `json:"proxy,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ImageAzureError struct {
|
type ImageAzureError struct {
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ type AzureSpeechProvider struct {
|
|||||||
// 获取请求头
|
// 获取请求头
|
||||||
func (p *AzureSpeechProvider) GetRequestHeaders() (headers map[string]string) {
|
func (p *AzureSpeechProvider) GetRequestHeaders() (headers map[string]string) {
|
||||||
headers = make(map[string]string)
|
headers = make(map[string]string)
|
||||||
headers["Ocp-Apim-Subscription-Key"] = p.Context.GetString("api_key")
|
headers["Ocp-Apim-Subscription-Key"] = p.Channel.Key
|
||||||
headers["Content-Type"] = "application/ssml+xml"
|
headers["Content-Type"] = "application/ssml+xml"
|
||||||
headers["User-Agent"] = "OneAPI"
|
headers["User-Agent"] = "OneAPI"
|
||||||
// headers["X-Microsoft-OutputFormat"] = "audio-16khz-128kbitrate-mono-mp3"
|
// headers["X-Microsoft-OutputFormat"] = "audio-16khz-128kbitrate-mono-mp3"
|
||||||
|
|||||||
30
providers/baichuan/base.go
Normal file
30
providers/baichuan/base.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package baichuan
|
||||||
|
|
||||||
|
import (
|
||||||
|
"one-api/providers/base"
|
||||||
|
"one-api/providers/openai"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 定义供应商工厂
|
||||||
|
type BaichuanProviderFactory struct{}
|
||||||
|
|
||||||
|
// 创建 BaichuanProvider
|
||||||
|
// https://platform.baichuan-ai.com/docs/api
|
||||||
|
func (f BaichuanProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||||
|
return &BaichuanProvider{
|
||||||
|
OpenAIProvider: openai.OpenAIProvider{
|
||||||
|
BaseProvider: base.BaseProvider{
|
||||||
|
BaseURL: "https://api.baichuan-ai.com",
|
||||||
|
ChatCompletions: "/v1/chat/completions",
|
||||||
|
Embeddings: "/v1/embeddings",
|
||||||
|
Context: c,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaichuanProvider struct {
|
||||||
|
openai.OpenAIProvider
|
||||||
|
}
|
||||||
100
providers/baichuan/chat.go
Normal file
100
providers/baichuan/chat.go
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
package baichuan
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/providers/openai"
|
||||||
|
"one-api/types"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (baichuanResponse *BaichuanChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
if baichuanResponse.Error.Message != "" {
|
||||||
|
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: baichuanResponse.Error,
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
OpenAIResponse = types.ChatCompletionResponse{
|
||||||
|
ID: baichuanResponse.ID,
|
||||||
|
Object: baichuanResponse.Object,
|
||||||
|
Created: baichuanResponse.Created,
|
||||||
|
Model: baichuanResponse.Model,
|
||||||
|
Choices: baichuanResponse.Choices,
|
||||||
|
Usage: baichuanResponse.Usage,
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取聊天请求体
|
||||||
|
func (p *BaichuanProvider) getChatRequestBody(request *types.ChatCompletionRequest) *BaichuanChatRequest {
|
||||||
|
messages := make([]BaichuanMessage, 0, len(request.Messages))
|
||||||
|
for i := 0; i < len(request.Messages); i++ {
|
||||||
|
message := request.Messages[i]
|
||||||
|
if message.Role == "system" || message.Role == "assistant" {
|
||||||
|
message.Role = "assistant"
|
||||||
|
} else {
|
||||||
|
message.Role = "user"
|
||||||
|
}
|
||||||
|
messages = append(messages, BaichuanMessage{
|
||||||
|
Content: message.StringContent(),
|
||||||
|
Role: strings.ToLower(message.Role),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return &BaichuanChatRequest{
|
||||||
|
Model: request.Model,
|
||||||
|
Messages: messages,
|
||||||
|
Stream: request.Stream,
|
||||||
|
Temperature: request.Temperature,
|
||||||
|
TopP: request.TopP,
|
||||||
|
TopK: request.N,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 聊天
|
||||||
|
func (p *BaichuanProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
|
||||||
|
requestBody := p.getChatRequestBody(request)
|
||||||
|
|
||||||
|
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||||
|
headers := p.GetRequestHeaders()
|
||||||
|
if request.Stream {
|
||||||
|
headers["Accept"] = "text/event-stream"
|
||||||
|
}
|
||||||
|
|
||||||
|
client := common.NewClient()
|
||||||
|
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||||
|
if err != nil {
|
||||||
|
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.Stream {
|
||||||
|
openAIProviderChatStreamResponse := &openai.OpenAIProviderChatStreamResponse{}
|
||||||
|
var textResponse string
|
||||||
|
errWithCode, textResponse = p.SendStreamRequest(req, openAIProviderChatStreamResponse)
|
||||||
|
if errWithCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
usage = &types.Usage{
|
||||||
|
PromptTokens: promptTokens,
|
||||||
|
CompletionTokens: common.CountTokenText(textResponse, request.Model),
|
||||||
|
TotalTokens: promptTokens + common.CountTokenText(textResponse, request.Model),
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
baichuanResponse := &BaichuanChatResponse{}
|
||||||
|
errWithCode = p.SendRequest(req, baichuanResponse, false)
|
||||||
|
if errWithCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
usage = baichuanResponse.Usage
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
36
providers/baichuan/type.go
Normal file
36
providers/baichuan/type.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package baichuan
|
||||||
|
|
||||||
|
import "one-api/providers/openai"
|
||||||
|
|
||||||
|
type BaichuanMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaichuanKnowledgeBase struct {
|
||||||
|
Ids []string `json:"id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaichuanChatRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []BaichuanMessage `json:"messages"`
|
||||||
|
Stream bool `json:"stream,omitempty"`
|
||||||
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
|
TopK int `json:"top_k,omitempty"`
|
||||||
|
WithSearchEnhance bool `json:"with_search_enhance,omitempty"`
|
||||||
|
KnowledgeBase BaichuanKnowledgeBase `json:"knowledge_base,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaichuanKnowledgeBaseResponse struct {
|
||||||
|
Cites []struct {
|
||||||
|
Title string `json:"title"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
FileId string `json:"file_id"`
|
||||||
|
} `json:"cites"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaichuanChatResponse struct {
|
||||||
|
openai.OpenAIProviderChatResponse
|
||||||
|
KnowledgeBase BaichuanKnowledgeBaseResponse `json:"knowledge_base,omitempty"`
|
||||||
|
}
|
||||||
@@ -63,7 +63,7 @@ func (p *BaiduProvider) GetRequestHeaders() (headers map[string]string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *BaiduProvider) getBaiduAccessToken() (string, error) {
|
func (p *BaiduProvider) getBaiduAccessToken() (string, error) {
|
||||||
apiKey := p.Context.GetString("api_key")
|
apiKey := p.Channel.Key
|
||||||
if val, ok := baiduTokenStore.Load(apiKey); ok {
|
if val, ok := baiduTokenStore.Load(apiKey); ok {
|
||||||
var accessToken BaiduAccessToken
|
var accessToken BaiduAccessToken
|
||||||
if accessToken, ok = val.(BaiduAccessToken); ok {
|
if accessToken, ok = val.(BaiduAccessToken); ok {
|
||||||
@@ -105,10 +105,12 @@ func (p *BaiduProvider) getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessTo
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := common.HttpClient.Do(req)
|
httpClient := common.GetHttpClient(p.Channel.Proxy)
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
common.PutHttpClient(httpClient)
|
||||||
|
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
|||||||
@@ -88,13 +88,15 @@ func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModel
|
|||||||
}
|
}
|
||||||
|
|
||||||
if request.Stream {
|
if request.Stream {
|
||||||
usage, errWithCode = p.sendStreamRequest(req)
|
usage, errWithCode = p.sendStreamRequest(req, request.Model)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
baiduChatRequest := &BaiduChatResponse{}
|
baiduChatRequest := &BaiduChatResponse{
|
||||||
|
Model: request.Model,
|
||||||
|
}
|
||||||
errWithCode = p.SendRequest(req, baiduChatRequest, false)
|
errWithCode = p.SendRequest(req, baiduChatRequest, false)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return
|
return
|
||||||
@@ -117,21 +119,23 @@ func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStrea
|
|||||||
ID: baiduResponse.Id,
|
ID: baiduResponse.Id,
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Created: baiduResponse.Created,
|
Created: baiduResponse.Created,
|
||||||
Model: "ernie-bot",
|
Model: baiduResponse.Model,
|
||||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||||
}
|
}
|
||||||
return &response
|
return &response
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *BaiduProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||||
defer req.Body.Close()
|
defer req.Body.Close()
|
||||||
|
|
||||||
usage = &types.Usage{}
|
usage = &types.Usage{}
|
||||||
// 发送请求
|
// 发送请求
|
||||||
resp, err := common.HttpClient.Do(req)
|
client := common.GetHttpClient(p.Channel.Proxy)
|
||||||
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
common.PutHttpClient(client)
|
||||||
|
|
||||||
if common.IsFailureStatusCode(resp) {
|
if common.IsFailureStatusCode(resp) {
|
||||||
return nil, common.HandleErrorResp(resp)
|
return nil, common.HandleErrorResp(resp)
|
||||||
@@ -180,6 +184,7 @@ func (p *BaiduProvider) sendStreamRequest(req *http.Request) (usage *types.Usage
|
|||||||
usage.PromptTokens = baiduResponse.Usage.PromptTokens
|
usage.PromptTokens = baiduResponse.Usage.PromptTokens
|
||||||
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
|
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
|
||||||
}
|
}
|
||||||
|
baiduResponse.Model = model
|
||||||
response := p.streamResponseBaidu2OpenAI(&baiduResponse)
|
response := p.streamResponseBaidu2OpenAI(&baiduResponse)
|
||||||
jsonResponse, err := json.Marshal(response)
|
jsonResponse, err := json.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ type BaiduChatResponse struct {
|
|||||||
IsTruncated bool `json:"is_truncated"`
|
IsTruncated bool `json:"is_truncated"`
|
||||||
NeedClearHistory bool `json:"need_clear_history"`
|
NeedClearHistory bool `json:"need_clear_history"`
|
||||||
Usage *types.Usage `json:"usage"`
|
Usage *types.Usage `json:"usage"`
|
||||||
|
Model string `json:"model,omitempty"`
|
||||||
BaiduError
|
BaiduError
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -28,17 +29,22 @@ type BaseProvider struct {
|
|||||||
ImagesVariations string
|
ImagesVariations string
|
||||||
Proxy string
|
Proxy string
|
||||||
Context *gin.Context
|
Context *gin.Context
|
||||||
|
Channel *model.Channel
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取基础URL
|
// 获取基础URL
|
||||||
func (p *BaseProvider) GetBaseURL() string {
|
func (p *BaseProvider) GetBaseURL() string {
|
||||||
if p.Context.GetString("base_url") != "" {
|
if p.Channel.GetBaseURL() != "" {
|
||||||
return p.Context.GetString("base_url")
|
return p.Channel.GetBaseURL()
|
||||||
}
|
}
|
||||||
|
|
||||||
return p.BaseURL
|
return p.BaseURL
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *BaseProvider) SetChannel(channel *model.Channel) {
|
||||||
|
p.Channel = channel
|
||||||
|
}
|
||||||
|
|
||||||
// 获取完整请求URL
|
// 获取完整请求URL
|
||||||
func (p *BaseProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
func (p *BaseProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||||
@@ -59,7 +65,7 @@ func (p *BaseProvider) CommonRequestHeaders(headers map[string]string) {
|
|||||||
func (p *BaseProvider) SendRequest(req *http.Request, response ProviderResponseHandler, rawOutput bool) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
func (p *BaseProvider) SendRequest(req *http.Request, response ProviderResponseHandler, rawOutput bool) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
defer req.Body.Close()
|
defer req.Body.Close()
|
||||||
|
|
||||||
resp, openAIErrorWithStatusCode := common.SendRequest(req, response, true)
|
resp, openAIErrorWithStatusCode := common.SendRequest(req, response, true, p.Channel.Proxy)
|
||||||
if openAIErrorWithStatusCode != nil {
|
if openAIErrorWithStatusCode != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -102,10 +108,12 @@ func (p *BaseProvider) SendRequestRaw(req *http.Request) (openAIErrorWithStatusC
|
|||||||
defer req.Body.Close()
|
defer req.Body.Close()
|
||||||
|
|
||||||
// 发送请求
|
// 发送请求
|
||||||
resp, err := common.HttpClient.Do(req)
|
client := common.GetHttpClient(p.Channel.Proxy)
|
||||||
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
common.PutHttpClient(client)
|
||||||
|
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ type ProviderInterface interface {
|
|||||||
GetFullRequestURL(requestURL string, modelName string) string
|
GetFullRequestURL(requestURL string, modelName string) string
|
||||||
GetRequestHeaders() (headers map[string]string)
|
GetRequestHeaders() (headers map[string]string)
|
||||||
SupportAPI(relayMode int) bool
|
SupportAPI(relayMode int) bool
|
||||||
|
SetChannel(channel *model.Channel)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 完成接口
|
// 完成接口
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ func (p *ClaudeProvider) GetRequestHeaders() (headers map[string]string) {
|
|||||||
headers = make(map[string]string)
|
headers = make(map[string]string)
|
||||||
p.CommonRequestHeaders(headers)
|
p.CommonRequestHeaders(headers)
|
||||||
|
|
||||||
headers["x-api-key"] = p.Context.GetString("api_key")
|
headers["x-api-key"] = p.Channel.Key
|
||||||
anthropicVersion := p.Context.Request.Header.Get("anthropic-version")
|
anthropicVersion := p.Context.Request.Header.Get("anthropic-version")
|
||||||
if anthropicVersion == "" {
|
if anthropicVersion == "" {
|
||||||
anthropicVersion = "2023-06-01"
|
anthropicVersion = "2023-06-01"
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ func (claudeResponse *ClaudeResponse) ResponseHandler(resp *http.Response) (Open
|
|||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
Created: common.GetTimestamp(),
|
Created: common.GetTimestamp(),
|
||||||
Choices: []types.ChatCompletionChoice{choice},
|
Choices: []types.ChatCompletionChoice{choice},
|
||||||
|
Model: claudeResponse.Model,
|
||||||
}
|
}
|
||||||
|
|
||||||
completionTokens := common.CountTokenText(claudeResponse.Completion, claudeResponse.Model)
|
completionTokens := common.CountTokenText(claudeResponse.Completion, claudeResponse.Model)
|
||||||
@@ -141,10 +142,12 @@ func (p *ClaudeProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErro
|
|||||||
defer req.Body.Close()
|
defer req.Body.Close()
|
||||||
|
|
||||||
// 发送请求
|
// 发送请求
|
||||||
resp, err := common.HttpClient.Do(req)
|
client := common.GetHttpClient(p.Channel.Proxy)
|
||||||
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
|
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
|
||||||
}
|
}
|
||||||
|
common.PutHttpClient(client)
|
||||||
|
|
||||||
if common.IsFailureStatusCode(resp) {
|
if common.IsFailureStatusCode(resp) {
|
||||||
return common.HandleErrorResp(resp), ""
|
return common.HandleErrorResp(resp), ""
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ func (p *CloseaiProxyProvider) Balance(channel *model.Channel) (float64, error)
|
|||||||
|
|
||||||
// 发送请求
|
// 发送请求
|
||||||
var response OpenAICreditGrants
|
var response OpenAICreditGrants
|
||||||
_, errWithCode := common.SendRequest(req, &response, false)
|
_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return 0, errors.New(errWithCode.OpenAIError.Message)
|
return 0, errors.New(errWithCode.OpenAIError.Message)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,11 +28,11 @@ type GeminiProvider struct {
|
|||||||
func (p *GeminiProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
func (p *GeminiProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||||
version := "v1"
|
version := "v1"
|
||||||
if p.Context.GetString("api_version") != "" {
|
if p.Channel.Other != "" {
|
||||||
version = p.Context.GetString("api_version")
|
version = p.Channel.Other
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("%s/%s/models/%s:%s?key=%s", baseURL, version, modelName, requestURL, p.Context.GetString("api_key"))
|
return fmt.Sprintf("%s/%s/models/%s:%s", baseURL, version, modelName, requestURL)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -40,6 +40,7 @@ func (p *GeminiProvider) GetFullRequestURL(requestURL string, modelName string)
|
|||||||
func (p *GeminiProvider) GetRequestHeaders() (headers map[string]string) {
|
func (p *GeminiProvider) GetRequestHeaders() (headers map[string]string) {
|
||||||
headers = make(map[string]string)
|
headers = make(map[string]string)
|
||||||
p.CommonRequestHeaders(headers)
|
p.CommonRequestHeaders(headers)
|
||||||
|
headers["x-goog-api-key"] = p.Channel.Key
|
||||||
|
|
||||||
return headers
|
return headers
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,11 +7,16 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/image"
|
||||||
"one-api/providers/base"
|
"one-api/providers/base"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
GeminiVisionMaxImageNum = 16
|
||||||
|
)
|
||||||
|
|
||||||
func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||||
if len(response.Candidates) == 0 {
|
if len(response.Candidates) == 0 {
|
||||||
return nil, &types.OpenAIErrorWithStatusCode{
|
return nil, &types.OpenAIErrorWithStatusCode{
|
||||||
@@ -29,6 +34,7 @@ func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAI
|
|||||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
Created: common.GetTimestamp(),
|
Created: common.GetTimestamp(),
|
||||||
|
Model: response.Model,
|
||||||
Choices: make([]types.ChatCompletionChoice, 0, len(response.Candidates)),
|
Choices: make([]types.ChatCompletionChoice, 0, len(response.Candidates)),
|
||||||
}
|
}
|
||||||
for i, candidate := range response.Candidates {
|
for i, candidate := range response.Candidates {
|
||||||
@@ -46,7 +52,7 @@ func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAI
|
|||||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||||
}
|
}
|
||||||
|
|
||||||
completionTokens := common.CountTokenText(response.GetResponseText(), "gemini-pro")
|
completionTokens := common.CountTokenText(response.GetResponseText(), response.Model)
|
||||||
response.Usage.CompletionTokens = completionTokens
|
response.Usage.CompletionTokens = completionTokens
|
||||||
response.Usage.TotalTokens = response.Usage.PromptTokens + completionTokens
|
response.Usage.TotalTokens = response.Usage.PromptTokens + completionTokens
|
||||||
|
|
||||||
@@ -54,27 +60,27 @@ func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAI
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||||
func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest) (requestBody *GeminiChatRequest) {
|
func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest) (requestBody *GeminiChatRequest, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||||
geminiRequest := GeminiChatRequest{
|
geminiRequest := GeminiChatRequest{
|
||||||
Contents: make([]GeminiChatContent, 0, len(request.Messages)),
|
Contents: make([]GeminiChatContent, 0, len(request.Messages)),
|
||||||
//SafetySettings: []GeminiChatSafetySettings{
|
SafetySettings: []GeminiChatSafetySettings{
|
||||||
// {
|
{
|
||||||
// Category: "HARM_CATEGORY_HARASSMENT",
|
Category: "HARM_CATEGORY_HARASSMENT",
|
||||||
// Threshold: "BLOCK_ONLY_HIGH",
|
Threshold: "BLOCK_NONE",
|
||||||
// },
|
},
|
||||||
// {
|
{
|
||||||
// Category: "HARM_CATEGORY_HATE_SPEECH",
|
Category: "HARM_CATEGORY_HATE_SPEECH",
|
||||||
// Threshold: "BLOCK_ONLY_HIGH",
|
Threshold: "BLOCK_NONE",
|
||||||
// },
|
},
|
||||||
// {
|
{
|
||||||
// Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||||
// Threshold: "BLOCK_ONLY_HIGH",
|
Threshold: "BLOCK_NONE",
|
||||||
// },
|
},
|
||||||
// {
|
{
|
||||||
// Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
|
Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||||
// Threshold: "BLOCK_ONLY_HIGH",
|
Threshold: "BLOCK_NONE",
|
||||||
// },
|
},
|
||||||
//},
|
},
|
||||||
GenerationConfig: GeminiChatGenerationConfig{
|
GenerationConfig: GeminiChatGenerationConfig{
|
||||||
Temperature: request.Temperature,
|
Temperature: request.Temperature,
|
||||||
TopP: request.TopP,
|
TopP: request.TopP,
|
||||||
@@ -98,6 +104,34 @@ func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
openaiContent := message.ParseContent()
|
||||||
|
var parts []GeminiPart
|
||||||
|
imageNum := 0
|
||||||
|
for _, part := range openaiContent {
|
||||||
|
if part.Type == types.ContentTypeText {
|
||||||
|
parts = append(parts, GeminiPart{
|
||||||
|
Text: part.Text,
|
||||||
|
})
|
||||||
|
} else if part.Type == types.ContentTypeImageURL {
|
||||||
|
imageNum += 1
|
||||||
|
if imageNum > GeminiVisionMaxImageNum {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
mimeType, data, err := image.GetImageFromUrl(part.ImageURL.URL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, common.ErrorWrapper(err, "image_url_invalid", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
parts = append(parts, GeminiPart{
|
||||||
|
InlineData: &GeminiInlineData{
|
||||||
|
MimeType: mimeType,
|
||||||
|
Data: data,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
content.Parts = parts
|
||||||
|
|
||||||
// there's no assistant role in gemini and API shall vomit if Role is not user or model
|
// there's no assistant role in gemini and API shall vomit if Role is not user or model
|
||||||
if content.Role == "assistant" {
|
if content.Role == "assistant" {
|
||||||
content.Role = "model"
|
content.Role = "model"
|
||||||
@@ -123,11 +157,14 @@ func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &geminiRequest
|
return &geminiRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||||
requestBody := p.getChatRequestBody(request)
|
requestBody, errWithCode := p.getChatRequestBody(request)
|
||||||
|
if errWithCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
fullRequestURL := p.GetFullRequestURL("generateContent", request.Model)
|
fullRequestURL := p.GetFullRequestURL("generateContent", request.Model)
|
||||||
headers := p.GetRequestHeaders()
|
headers := p.GetRequestHeaders()
|
||||||
if request.Stream {
|
if request.Stream {
|
||||||
@@ -142,7 +179,7 @@ func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isMode
|
|||||||
|
|
||||||
if request.Stream {
|
if request.Stream {
|
||||||
var responseText string
|
var responseText string
|
||||||
errWithCode, responseText = p.sendStreamRequest(req)
|
errWithCode, responseText = p.sendStreamRequest(req, request.Model)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -155,6 +192,7 @@ func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isMode
|
|||||||
|
|
||||||
} else {
|
} else {
|
||||||
var geminiResponse = &GeminiChatResponse{
|
var geminiResponse = &GeminiChatResponse{
|
||||||
|
Model: request.Model,
|
||||||
Usage: &types.Usage{
|
Usage: &types.Usage{
|
||||||
PromptTokens: promptTokens,
|
PromptTokens: promptTokens,
|
||||||
},
|
},
|
||||||
@@ -170,25 +208,27 @@ func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isMode
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GeminiProvider) streamResponseClaude2OpenAI(geminiResponse *GeminiChatResponse) *types.ChatCompletionStreamResponse {
|
// func (p *GeminiProvider) streamResponseClaude2OpenAI(geminiResponse *GeminiChatResponse) *types.ChatCompletionStreamResponse {
|
||||||
var choice types.ChatCompletionStreamChoice
|
// var choice types.ChatCompletionStreamChoice
|
||||||
choice.Delta.Content = geminiResponse.GetResponseText()
|
// choice.Delta.Content = geminiResponse.GetResponseText()
|
||||||
choice.FinishReason = &base.StopFinishReason
|
// choice.FinishReason = &base.StopFinishReason
|
||||||
var response types.ChatCompletionStreamResponse
|
// var response types.ChatCompletionStreamResponse
|
||||||
response.Object = "chat.completion.chunk"
|
// response.Object = "chat.completion.chunk"
|
||||||
response.Model = "gemini"
|
// response.Model = "gemini"
|
||||||
response.Choices = []types.ChatCompletionStreamChoice{choice}
|
// response.Choices = []types.ChatCompletionStreamChoice{choice}
|
||||||
return &response
|
// return &response
|
||||||
}
|
// }
|
||||||
|
|
||||||
func (p *GeminiProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) {
|
func (p *GeminiProvider) sendStreamRequest(req *http.Request, model string) (*types.OpenAIErrorWithStatusCode, string) {
|
||||||
defer req.Body.Close()
|
defer req.Body.Close()
|
||||||
|
|
||||||
// 发送请求
|
// 发送请求
|
||||||
resp, err := common.HttpClient.Do(req)
|
client := common.GetHttpClient(p.Channel.Proxy)
|
||||||
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
|
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
|
||||||
}
|
}
|
||||||
|
common.PutHttpClient(client)
|
||||||
|
|
||||||
if common.IsFailureStatusCode(resp) {
|
if common.IsFailureStatusCode(resp) {
|
||||||
return common.HandleErrorResp(resp), ""
|
return common.HandleErrorResp(resp), ""
|
||||||
@@ -235,7 +275,7 @@ func (p *GeminiProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErro
|
|||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
}
|
}
|
||||||
var dummy dummyStruct
|
var dummy dummyStruct
|
||||||
err := json.Unmarshal([]byte(data), &dummy)
|
json.Unmarshal([]byte(data), &dummy)
|
||||||
responseText += dummy.Content
|
responseText += dummy.Content
|
||||||
var choice types.ChatCompletionStreamChoice
|
var choice types.ChatCompletionStreamChoice
|
||||||
choice.Delta.Content = dummy.Content
|
choice.Delta.Content = dummy.Content
|
||||||
@@ -243,7 +283,7 @@ func (p *GeminiProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErro
|
|||||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Created: common.GetTimestamp(),
|
Created: common.GetTimestamp(),
|
||||||
Model: "gemini-pro",
|
Model: model,
|
||||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||||
}
|
}
|
||||||
jsonResponse, err := json.Marshal(response)
|
jsonResponse, err := json.Marshal(response)
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ type GeminiChatResponse struct {
|
|||||||
Candidates []GeminiChatCandidate `json:"candidates"`
|
Candidates []GeminiChatCandidate `json:"candidates"`
|
||||||
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
|
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
|
||||||
Usage *types.Usage `json:"usage,omitempty"`
|
Usage *types.Usage `json:"usage,omitempty"`
|
||||||
|
Model string `json:"model,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GeminiChatCandidate struct {
|
type GeminiChatCandidate struct {
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ func (p *OpenAIProvider) Balance(channel *model.Channel) (float64, error) {
|
|||||||
|
|
||||||
// 发送请求
|
// 发送请求
|
||||||
var subscription OpenAISubscriptionResponse
|
var subscription OpenAISubscriptionResponse
|
||||||
_, errWithCode := common.SendRequest(req, &subscription, false)
|
_, errWithCode := common.SendRequest(req, &subscription, false, p.Channel.Proxy)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return 0, errors.New(errWithCode.OpenAIError.Message)
|
return 0, errors.New(errWithCode.OpenAIError.Message)
|
||||||
}
|
}
|
||||||
@@ -38,7 +38,7 @@ func (p *OpenAIProvider) Balance(channel *model.Channel) (float64, error) {
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
usage := OpenAIUsageResponse{}
|
usage := OpenAIUsageResponse{}
|
||||||
_, errWithCode = common.SendRequest(req, &usage, false)
|
_, errWithCode = common.SendRequest(req, &usage, false, p.Channel.Proxy)
|
||||||
|
|
||||||
balance := subscription.HardLimitUSD - usage.TotalUsage/100
|
balance := subscription.HardLimitUSD - usage.TotalUsage/100
|
||||||
channel.UpdateBalance(balance)
|
channel.UpdateBalance(balance)
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string)
|
|||||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||||
|
|
||||||
if p.IsAzure {
|
if p.IsAzure {
|
||||||
apiVersion := p.Context.GetString("api_version")
|
apiVersion := p.Channel.Other
|
||||||
if modelName == "dall-e-2" {
|
if modelName == "dall-e-2" {
|
||||||
// 因为dall-e-3需要api-version=2023-12-01-preview,但是该版本
|
// 因为dall-e-3需要api-version=2023-12-01-preview,但是该版本
|
||||||
// 已经没有dall-e-2了,所以暂时写死
|
// 已经没有dall-e-2了,所以暂时写死
|
||||||
@@ -85,9 +85,9 @@ func (p *OpenAIProvider) GetRequestHeaders() (headers map[string]string) {
|
|||||||
headers = make(map[string]string)
|
headers = make(map[string]string)
|
||||||
p.CommonRequestHeaders(headers)
|
p.CommonRequestHeaders(headers)
|
||||||
if p.IsAzure {
|
if p.IsAzure {
|
||||||
headers["api-key"] = p.Context.GetString("api_key")
|
headers["api-key"] = p.Channel.Key
|
||||||
} else {
|
} else {
|
||||||
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key"))
|
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Channel.Key)
|
||||||
}
|
}
|
||||||
|
|
||||||
return headers
|
return headers
|
||||||
@@ -108,13 +108,15 @@ func (p *OpenAIProvider) GetRequestBody(request any, isModelMapped bool) (reques
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 发送流式请求
|
// 发送流式请求
|
||||||
func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) {
|
func (p *OpenAIProvider) SendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) {
|
||||||
defer req.Body.Close()
|
defer req.Body.Close()
|
||||||
|
|
||||||
resp, err := common.HttpClient.Do(req)
|
client := common.GetHttpClient(p.Channel.Proxy)
|
||||||
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
|
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
|
||||||
}
|
}
|
||||||
|
common.PutHttpClient(client)
|
||||||
|
|
||||||
if common.IsFailureStatusCode(resp) {
|
if common.IsFailureStatusCode(resp) {
|
||||||
return common.HandleErrorResp(resp), ""
|
return common.HandleErrorResp(resp), ""
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ func (p *OpenAIProvider) ChatAction(request *types.ChatCompletionRequest, isMode
|
|||||||
if request.Stream {
|
if request.Stream {
|
||||||
openAIProviderChatStreamResponse := &OpenAIProviderChatStreamResponse{}
|
openAIProviderChatStreamResponse := &OpenAIProviderChatStreamResponse{}
|
||||||
var textResponse string
|
var textResponse string
|
||||||
errWithCode, textResponse = p.sendStreamRequest(req, openAIProviderChatStreamResponse)
|
errWithCode, textResponse = p.SendStreamRequest(req, openAIProviderChatStreamResponse)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ func (p *OpenAIProvider) CompleteAction(request *types.CompletionRequest, isMode
|
|||||||
if request.Stream {
|
if request.Stream {
|
||||||
// TODO
|
// TODO
|
||||||
var textResponse string
|
var textResponse string
|
||||||
errWithCode, textResponse = p.sendStreamRequest(req, openAIProviderCompletionResponse)
|
errWithCode, textResponse = p.SendStreamRequest(req, openAIProviderCompletionResponse)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ func (p *OpenaiSBProvider) Balance(channel *model.Channel) (float64, error) {
|
|||||||
|
|
||||||
// 发送请求
|
// 发送请求
|
||||||
var response OpenAISBUsageResponse
|
var response OpenAISBUsageResponse
|
||||||
_, errWithCode := common.SendRequest(req, &response, false)
|
_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.New(errWithCode.OpenAIError.Message)
|
return 0, errors.New(errWithCode.OpenAIError.Message)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ type PalmProvider struct {
|
|||||||
func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) {
|
func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) {
|
||||||
headers = make(map[string]string)
|
headers = make(map[string]string)
|
||||||
p.CommonRequestHeaders(headers)
|
p.CommonRequestHeaders(headers)
|
||||||
|
headers["x-goog-api-key"] = p.Channel.Key
|
||||||
|
|
||||||
return headers
|
return headers
|
||||||
}
|
}
|
||||||
@@ -37,5 +38,5 @@ func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) {
|
|||||||
func (p *PalmProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
func (p *PalmProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||||
|
|
||||||
return fmt.Sprintf("%s%s?key=%s", baseURL, requestURL, p.Context.GetString("api_key"))
|
return fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ func (palmResponse *PaLMChatResponse) ResponseHandler(resp *http.Response) (Open
|
|||||||
palmResponse.Usage.TotalTokens = palmResponse.Usage.PromptTokens + completionTokens
|
palmResponse.Usage.TotalTokens = palmResponse.Usage.PromptTokens + completionTokens
|
||||||
|
|
||||||
fullTextResponse.Usage = palmResponse.Usage
|
fullTextResponse.Usage = palmResponse.Usage
|
||||||
|
fullTextResponse.Model = palmResponse.Model
|
||||||
|
|
||||||
return fullTextResponse, nil
|
return fullTextResponse, nil
|
||||||
}
|
}
|
||||||
@@ -133,10 +134,12 @@ func (p *PalmProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorW
|
|||||||
defer req.Body.Close()
|
defer req.Body.Close()
|
||||||
|
|
||||||
// 发送请求
|
// 发送请求
|
||||||
resp, err := common.HttpClient.Do(req)
|
client := common.GetHttpClient(p.Channel.Proxy)
|
||||||
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
|
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
|
||||||
}
|
}
|
||||||
|
common.PutHttpClient(client)
|
||||||
|
|
||||||
if common.IsFailureStatusCode(resp) {
|
if common.IsFailureStatusCode(resp) {
|
||||||
return common.HandleErrorResp(resp), ""
|
return common.HandleErrorResp(resp), ""
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package providers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
"one-api/providers/aigc2d"
|
"one-api/providers/aigc2d"
|
||||||
"one-api/providers/aiproxy"
|
"one-api/providers/aiproxy"
|
||||||
"one-api/providers/ali"
|
"one-api/providers/ali"
|
||||||
@@ -9,6 +10,7 @@ import (
|
|||||||
"one-api/providers/api2gpt"
|
"one-api/providers/api2gpt"
|
||||||
"one-api/providers/azure"
|
"one-api/providers/azure"
|
||||||
azurespeech "one-api/providers/azureSpeech"
|
azurespeech "one-api/providers/azureSpeech"
|
||||||
|
"one-api/providers/baichuan"
|
||||||
"one-api/providers/baidu"
|
"one-api/providers/baidu"
|
||||||
"one-api/providers/base"
|
"one-api/providers/base"
|
||||||
"one-api/providers/claude"
|
"one-api/providers/claude"
|
||||||
@@ -51,23 +53,28 @@ func init() {
|
|||||||
providerFactories[common.ChannelTypeAPI2GPT] = api2gpt.Api2gptProviderFactory{}
|
providerFactories[common.ChannelTypeAPI2GPT] = api2gpt.Api2gptProviderFactory{}
|
||||||
providerFactories[common.ChannelTypeAzureSpeech] = azurespeech.AzureSpeechProviderFactory{}
|
providerFactories[common.ChannelTypeAzureSpeech] = azurespeech.AzureSpeechProviderFactory{}
|
||||||
providerFactories[common.ChannelTypeGemini] = gemini.GeminiProviderFactory{}
|
providerFactories[common.ChannelTypeGemini] = gemini.GeminiProviderFactory{}
|
||||||
|
providerFactories[common.ChannelTypeBaichuan] = baichuan.BaichuanProviderFactory{}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取供应商
|
// 获取供应商
|
||||||
func GetProvider(channelType int, c *gin.Context) base.ProviderInterface {
|
func GetProvider(channel *model.Channel, c *gin.Context) base.ProviderInterface {
|
||||||
factory, ok := providerFactories[channelType]
|
factory, ok := providerFactories[channel.Type]
|
||||||
|
var provider base.ProviderInterface
|
||||||
if !ok {
|
if !ok {
|
||||||
// 处理未找到的供应商工厂
|
// 处理未找到的供应商工厂
|
||||||
baseURL := common.ChannelBaseURLs[channelType]
|
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||||
if c.GetString("base_url") != "" {
|
if channel.GetBaseURL() != "" {
|
||||||
baseURL = c.GetString("base_url")
|
baseURL = channel.GetBaseURL()
|
||||||
}
|
}
|
||||||
if baseURL != "" {
|
if baseURL == "" {
|
||||||
return openai.CreateOpenAIProvider(c, baseURL)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
provider = openai.CreateOpenAIProvider(c, baseURL)
|
||||||
}
|
}
|
||||||
return factory.Create(c)
|
provider = factory.Create(c)
|
||||||
|
provider.SetChannel(channel)
|
||||||
|
|
||||||
|
return provider
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ func (p *TencentProvider) parseTencentConfig(config string) (appId int64, secret
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *TencentProvider) getTencentSign(req TencentChatRequest) string {
|
func (p *TencentProvider) getTencentSign(req TencentChatRequest) string {
|
||||||
apiKey := p.Context.GetString("api_key")
|
apiKey := p.Channel.Key
|
||||||
appId, secretId, secretKey, err := p.parseTencentConfig(apiKey)
|
appId, secretId, secretKey, err := p.parseTencentConfig(apiKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ func (TencentResponse *TencentChatResponse) ResponseHandler(resp *http.Response)
|
|||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
Created: common.GetTimestamp(),
|
Created: common.GetTimestamp(),
|
||||||
Usage: TencentResponse.Usage,
|
Usage: TencentResponse.Usage,
|
||||||
|
Model: TencentResponse.Model,
|
||||||
}
|
}
|
||||||
if len(TencentResponse.Choices) > 0 {
|
if len(TencentResponse.Choices) > 0 {
|
||||||
choice := types.ChatCompletionChoice{
|
choice := types.ChatCompletionChoice{
|
||||||
@@ -100,7 +101,7 @@ func (p *TencentProvider) ChatAction(request *types.ChatCompletionRequest, isMod
|
|||||||
|
|
||||||
if request.Stream {
|
if request.Stream {
|
||||||
var responseText string
|
var responseText string
|
||||||
errWithCode, responseText = p.sendStreamRequest(req)
|
errWithCode, responseText = p.sendStreamRequest(req, request.Model)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -112,7 +113,9 @@ func (p *TencentProvider) ChatAction(request *types.ChatCompletionRequest, isMod
|
|||||||
usage.TotalTokens = promptTokens + usage.CompletionTokens
|
usage.TotalTokens = promptTokens + usage.CompletionTokens
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
tencentResponse := &TencentChatResponse{}
|
tencentResponse := &TencentChatResponse{
|
||||||
|
Model: request.Model,
|
||||||
|
}
|
||||||
errWithCode = p.SendRequest(req, tencentResponse, false)
|
errWithCode = p.SendRequest(req, tencentResponse, false)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return
|
return
|
||||||
@@ -128,7 +131,7 @@ func (p *TencentProvider) streamResponseTencent2OpenAI(TencentResponse *TencentC
|
|||||||
response := types.ChatCompletionStreamResponse{
|
response := types.ChatCompletionStreamResponse{
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Created: common.GetTimestamp(),
|
Created: common.GetTimestamp(),
|
||||||
Model: "tencent-hunyuan",
|
Model: TencentResponse.Model,
|
||||||
}
|
}
|
||||||
if len(TencentResponse.Choices) > 0 {
|
if len(TencentResponse.Choices) > 0 {
|
||||||
var choice types.ChatCompletionStreamChoice
|
var choice types.ChatCompletionStreamChoice
|
||||||
@@ -141,13 +144,15 @@ func (p *TencentProvider) streamResponseTencent2OpenAI(TencentResponse *TencentC
|
|||||||
return &response
|
return &response
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *TencentProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) {
|
func (p *TencentProvider) sendStreamRequest(req *http.Request, model string) (*types.OpenAIErrorWithStatusCode, string) {
|
||||||
defer req.Body.Close()
|
defer req.Body.Close()
|
||||||
// 发送请求
|
// 发送请求
|
||||||
resp, err := common.HttpClient.Do(req)
|
client := common.GetHttpClient(p.Channel.Proxy)
|
||||||
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
|
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
|
||||||
}
|
}
|
||||||
|
common.PutHttpClient(client)
|
||||||
|
|
||||||
if common.IsFailureStatusCode(resp) {
|
if common.IsFailureStatusCode(resp) {
|
||||||
return common.HandleErrorResp(resp), ""
|
return common.HandleErrorResp(resp), ""
|
||||||
@@ -195,6 +200,7 @@ func (p *TencentProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErr
|
|||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
TencentResponse.Model = model
|
||||||
response := p.streamResponseTencent2OpenAI(&TencentResponse)
|
response := p.streamResponseTencent2OpenAI(&TencentResponse)
|
||||||
if len(response.Choices) != 0 {
|
if len(response.Choices) != 0 {
|
||||||
responseText += response.Choices[0].Delta.Content
|
responseText += response.Choices[0].Delta.Content
|
||||||
|
|||||||
@@ -58,4 +58,5 @@ type TencentChatResponse struct {
|
|||||||
Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
|
Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
|
||||||
Note string `json:"note,omitempty"` // 注释
|
Note string `json:"note,omitempty"` // 注释
|
||||||
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
|
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
|
||||||
|
Model string `json:"model,omitempty"` // 模型名称
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ func (p *XunfeiProvider) GetRequestHeaders() (headers map[string]string) {
|
|||||||
|
|
||||||
// 获取完整请求 URL
|
// 获取完整请求 URL
|
||||||
func (p *XunfeiProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
func (p *XunfeiProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||||
splits := strings.Split(p.Context.GetString("api_key"), "|")
|
splits := strings.Split(p.Channel.Key, "|")
|
||||||
if len(splits) != 3 {
|
if len(splits) != 3 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@@ -58,7 +58,7 @@ func (p *XunfeiProvider) getXunfeiAuthUrl(apiKey string, apiSecret string) (stri
|
|||||||
query := p.Context.Request.URL.Query()
|
query := p.Context.Request.URL.Query()
|
||||||
apiVersion := query.Get("api-version")
|
apiVersion := query.Get("api-version")
|
||||||
if apiVersion == "" {
|
if apiVersion == "" {
|
||||||
apiVersion = p.Context.GetString("api_version")
|
apiVersion = p.Channel.Other
|
||||||
}
|
}
|
||||||
if apiVersion == "" {
|
if apiVersion == "" {
|
||||||
apiVersion = "v1.1"
|
apiVersion = "v1.1"
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ func (p *ZhipuProvider) GetFullRequestURL(requestURL string, modelName string) s
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *ZhipuProvider) getZhipuToken() string {
|
func (p *ZhipuProvider) getZhipuToken() string {
|
||||||
apikey := p.Context.GetString("api_key")
|
apikey := p.Channel.Key
|
||||||
data, ok := zhipuTokens.Load(apikey)
|
data, ok := zhipuTokens.Load(apikey)
|
||||||
if ok {
|
if ok {
|
||||||
tokenData := data.(zhipuTokenData)
|
tokenData := data.(zhipuTokenData)
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ func (zhipuResponse *ZhipuResponse) ResponseHandler(resp *http.Response) (OpenAI
|
|||||||
ID: zhipuResponse.Data.TaskId,
|
ID: zhipuResponse.Data.TaskId,
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
Created: common.GetTimestamp(),
|
Created: common.GetTimestamp(),
|
||||||
|
Model: zhipuResponse.Model,
|
||||||
Choices: make([]types.ChatCompletionChoice, 0, len(zhipuResponse.Data.Choices)),
|
Choices: make([]types.ChatCompletionChoice, 0, len(zhipuResponse.Data.Choices)),
|
||||||
Usage: &zhipuResponse.Data.Usage,
|
Usage: &zhipuResponse.Data.Usage,
|
||||||
}
|
}
|
||||||
@@ -94,13 +95,15 @@ func (p *ZhipuProvider) ChatAction(request *types.ChatCompletionRequest, isModel
|
|||||||
}
|
}
|
||||||
|
|
||||||
if request.Stream {
|
if request.Stream {
|
||||||
errWithCode, usage = p.sendStreamRequest(req)
|
errWithCode, usage = p.sendStreamRequest(req, request.Model)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
zhipuResponse := &ZhipuResponse{}
|
zhipuResponse := &ZhipuResponse{
|
||||||
|
Model: request.Model,
|
||||||
|
}
|
||||||
errWithCode = p.SendRequest(req, zhipuResponse, false)
|
errWithCode = p.SendRequest(req, zhipuResponse, false)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return
|
return
|
||||||
@@ -132,20 +135,22 @@ func (p *ZhipuProvider) streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStrea
|
|||||||
ID: zhipuResponse.RequestId,
|
ID: zhipuResponse.RequestId,
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Created: common.GetTimestamp(),
|
Created: common.GetTimestamp(),
|
||||||
Model: "chatglm",
|
Model: zhipuResponse.Model,
|
||||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||||
}
|
}
|
||||||
return &response, &zhipuResponse.Usage
|
return &response, &zhipuResponse.Usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, *types.Usage) {
|
func (p *ZhipuProvider) sendStreamRequest(req *http.Request, model string) (*types.OpenAIErrorWithStatusCode, *types.Usage) {
|
||||||
defer req.Body.Close()
|
defer req.Body.Close()
|
||||||
|
|
||||||
// 发送请求
|
// 发送请求
|
||||||
resp, err := common.HttpClient.Do(req)
|
client := common.GetHttpClient(p.Channel.Proxy)
|
||||||
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil
|
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
|
common.PutHttpClient(client)
|
||||||
|
|
||||||
if common.IsFailureStatusCode(resp) {
|
if common.IsFailureStatusCode(resp) {
|
||||||
return common.HandleErrorResp(resp), nil
|
return common.HandleErrorResp(resp), nil
|
||||||
@@ -159,7 +164,7 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIError
|
|||||||
if atEOF && len(data) == 0 {
|
if atEOF && len(data) == 0 {
|
||||||
return 0, nil, nil
|
return 0, nil, nil
|
||||||
}
|
}
|
||||||
if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 {
|
if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Contains(string(data), ":") {
|
||||||
return i + 2, data[0:i], nil
|
return i + 2, data[0:i], nil
|
||||||
}
|
}
|
||||||
if atEOF {
|
if atEOF {
|
||||||
@@ -195,6 +200,7 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIError
|
|||||||
select {
|
select {
|
||||||
case data := <-dataChan:
|
case data := <-dataChan:
|
||||||
response := p.streamResponseZhipu2OpenAI(data)
|
response := p.streamResponseZhipu2OpenAI(data)
|
||||||
|
response.Model = model
|
||||||
jsonResponse, err := json.Marshal(response)
|
jsonResponse, err := json.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
@@ -209,6 +215,7 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIError
|
|||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
zhipuResponse.Model = model
|
||||||
response, zhipuUsage := p.streamMetaResponseZhipu2OpenAI(&zhipuResponse)
|
response, zhipuUsage := p.streamMetaResponseZhipu2OpenAI(&zhipuResponse)
|
||||||
jsonResponse, err := json.Marshal(response)
|
jsonResponse, err := json.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ type ZhipuResponse struct {
|
|||||||
Msg string `json:"msg"`
|
Msg string `json:"msg"`
|
||||||
Success bool `json:"success"`
|
Success bool `json:"success"`
|
||||||
Data ZhipuResponseData `json:"data"`
|
Data ZhipuResponseData `json:"data"`
|
||||||
|
Model string `json:"model,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ZhipuStreamMetaResponse struct {
|
type ZhipuStreamMetaResponse struct {
|
||||||
@@ -38,6 +39,7 @@ type ZhipuStreamMetaResponse struct {
|
|||||||
TaskId string `json:"task_id"`
|
TaskId string `json:"task_id"`
|
||||||
TaskStatus string `json:"task_status"`
|
TaskStatus string `json:"task_status"`
|
||||||
types.Usage `json:"usage"`
|
types.Usage `json:"usage"`
|
||||||
|
Model string `json:"model,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type zhipuTokenData struct {
|
type zhipuTokenData struct {
|
||||||
|
|||||||
@@ -1,3 +1,9 @@
|
|||||||
|
[//]: # (请按照以下格式关联 issue)
|
||||||
|
[//]: # (请在提交 PR 前确认所提交的功能可用,附上截图即可,这将有助于项目维护者 review & merge 该 PR,谢谢)
|
||||||
|
[//]: # (项目维护者一般仅在周末处理 PR,因此如若未能及时回复希望能理解)
|
||||||
|
[//]: # (开发者交流群:910657413)
|
||||||
|
[//]: # (请在提交 PR 之前删除上面的注释)
|
||||||
|
|
||||||
close #issue_number
|
close #issue_number
|
||||||
|
|
||||||
我已确认该 PR 已自测通过,相关截图如下:
|
我已确认该 PR 已自测通过,相关截图如下:
|
||||||
@@ -67,7 +67,7 @@ func SetApiRouter(router *gin.Engine) {
|
|||||||
{
|
{
|
||||||
channelRoute.GET("/", controller.GetAllChannels)
|
channelRoute.GET("/", controller.GetAllChannels)
|
||||||
channelRoute.GET("/search", controller.SearchChannels)
|
channelRoute.GET("/search", controller.SearchChannels)
|
||||||
channelRoute.GET("/models", controller.ListModels)
|
channelRoute.GET("/models", controller.ListModelsForAdmin)
|
||||||
channelRoute.GET("/:id", controller.GetChannel)
|
channelRoute.GET("/:id", controller.GetChannel)
|
||||||
channelRoute.GET("/test", controller.TestAllChannels)
|
channelRoute.GET("/test", controller.TestAllChannels)
|
||||||
channelRoute.GET("/test/:id", controller.TestChannel)
|
channelRoute.GET("/test/:id", controller.TestChannel)
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ func SetRelayRouter(router *gin.Engine) {
|
|||||||
router.Use(middleware.CORS())
|
router.Use(middleware.CORS())
|
||||||
// https://platform.openai.com/docs/api-reference/introduction
|
// https://platform.openai.com/docs/api-reference/introduction
|
||||||
modelsRouter := router.Group("/v1/models")
|
modelsRouter := router.Group("/v1/models")
|
||||||
modelsRouter.Use(middleware.TokenAuth())
|
modelsRouter.Use(middleware.TokenAuth(), middleware.Distribute())
|
||||||
{
|
{
|
||||||
modelsRouter.GET("", controller.ListModels)
|
modelsRouter.GET("", controller.ListModels)
|
||||||
modelsRouter.GET("/:model", controller.RetrieveModel)
|
modelsRouter.GET("/:model", controller.RetrieveModel)
|
||||||
|
|||||||
@@ -1,5 +1,10 @@
|
|||||||
package types
|
package types
|
||||||
|
|
||||||
|
const (
|
||||||
|
ContentTypeText = "text"
|
||||||
|
ContentTypeImageURL = "image_url"
|
||||||
|
)
|
||||||
|
|
||||||
type ChatCompletionMessage struct {
|
type ChatCompletionMessage struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content any `json:"content"`
|
Content any `json:"content"`
|
||||||
@@ -22,17 +27,61 @@ func (m ChatCompletionMessage) StringContent() string {
|
|||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if contentMap["type"] == "text" {
|
|
||||||
if subStr, ok := contentMap["text"].(string); ok {
|
if subStr, ok := contentMap["text"].(string); ok && subStr != "" {
|
||||||
contentStr += subStr
|
contentStr += subStr
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
return contentStr
|
return contentStr
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m ChatCompletionMessage) ParseContent() []ChatMessagePart {
|
||||||
|
var contentList []ChatMessagePart
|
||||||
|
content, ok := m.Content.(string)
|
||||||
|
if ok {
|
||||||
|
contentList = append(contentList, ChatMessagePart{
|
||||||
|
Type: ContentTypeText,
|
||||||
|
Text: content,
|
||||||
|
})
|
||||||
|
return contentList
|
||||||
|
}
|
||||||
|
anyList, ok := m.Content.([]any)
|
||||||
|
if ok {
|
||||||
|
for _, contentItem := range anyList {
|
||||||
|
contentMap, ok := contentItem.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if subStr, ok := contentMap["text"].(string); ok && subStr != "" {
|
||||||
|
contentList = append(contentList, ChatMessagePart{
|
||||||
|
Type: ContentTypeText,
|
||||||
|
Text: subStr,
|
||||||
|
})
|
||||||
|
} else if subObj, ok := contentMap["image_url"].(map[string]any); ok {
|
||||||
|
contentList = append(contentList, ChatMessagePart{
|
||||||
|
Type: ContentTypeImageURL,
|
||||||
|
ImageURL: &ChatMessageImageURL{
|
||||||
|
URL: subObj["url"].(string),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
} else if subObj, ok := contentMap["image"].(string); ok {
|
||||||
|
contentList = append(contentList, ChatMessagePart{
|
||||||
|
Type: ContentTypeImageURL,
|
||||||
|
ImageURL: &ChatMessageImageURL{
|
||||||
|
URL: subObj,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return contentList
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type ChatMessageImageURL struct {
|
type ChatMessageImageURL struct {
|
||||||
URL string `json:"url,omitempty"`
|
URL string `json:"url,omitempty"`
|
||||||
Detail string `json:"detail,omitempty"`
|
Detail string `json:"detail,omitempty"`
|
||||||
|
|||||||
@@ -65,6 +65,12 @@ export const CHANNEL_OPTIONS = {
|
|||||||
value: 23,
|
value: 23,
|
||||||
color: 'default'
|
color: 'default'
|
||||||
},
|
},
|
||||||
|
26: {
|
||||||
|
key: 26,
|
||||||
|
text: '百川',
|
||||||
|
value: 26,
|
||||||
|
color: 'orange'
|
||||||
|
},
|
||||||
24: {
|
24: {
|
||||||
key: 24,
|
key: 24,
|
||||||
text: 'Azure Speech',
|
text: 'Azure Speech',
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
import { isAdmin } from 'utils/common';
|
import { isAdmin } from 'utils/common';
|
||||||
import { useNavigate } from 'react-router-dom';
|
import { useNavigate } from 'react-router-dom';
|
||||||
const navigate = useNavigate();
|
|
||||||
|
|
||||||
const useAuth = () => {
|
const useAuth = () => {
|
||||||
const userIsAdmin = isAdmin();
|
const userIsAdmin = isAdmin();
|
||||||
|
const navigate = useNavigate();
|
||||||
|
|
||||||
if (!userIsAdmin) {
|
if (!userIsAdmin) {
|
||||||
navigate('/panel/404');
|
navigate('/panel/404');
|
||||||
|
|||||||
@@ -35,6 +35,8 @@ const validationSchema = Yup.object().shape({
|
|||||||
type: Yup.number().required('渠道 不能为空'),
|
type: Yup.number().required('渠道 不能为空'),
|
||||||
key: Yup.string().when('is_edit', { is: false, then: Yup.string().required('密钥 不能为空') }),
|
key: Yup.string().when('is_edit', { is: false, then: Yup.string().required('密钥 不能为空') }),
|
||||||
other: Yup.string(),
|
other: Yup.string(),
|
||||||
|
proxy: Yup.string(),
|
||||||
|
test_model: Yup.string(),
|
||||||
models: Yup.array().min(1, '模型 不能为空'),
|
models: Yup.array().min(1, '模型 不能为空'),
|
||||||
groups: Yup.array().min(1, '用户组 不能为空'),
|
groups: Yup.array().min(1, '用户组 不能为空'),
|
||||||
base_url: Yup.string().when('type', {
|
base_url: Yup.string().when('type', {
|
||||||
@@ -89,7 +91,7 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
|
|||||||
if (newInput) {
|
if (newInput) {
|
||||||
Object.keys(newInput).forEach((key) => {
|
Object.keys(newInput).forEach((key) => {
|
||||||
if (
|
if (
|
||||||
(!Array.isArray(values[key]) && values[key] !== null && values[key] !== undefined) ||
|
(!Array.isArray(values[key]) && values[key] !== null && values[key] !== undefined && values[key] !== '') ||
|
||||||
(Array.isArray(values[key]) && values[key].length > 0)
|
(Array.isArray(values[key]) && values[key].length > 0)
|
||||||
) {
|
) {
|
||||||
return;
|
return;
|
||||||
@@ -442,6 +444,50 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
|
|||||||
<FormHelperText id="helper-tex-channel-model_mapping-label"> {inputPrompt.model_mapping} </FormHelperText>
|
<FormHelperText id="helper-tex-channel-model_mapping-label"> {inputPrompt.model_mapping} </FormHelperText>
|
||||||
)}
|
)}
|
||||||
</FormControl>
|
</FormControl>
|
||||||
|
<FormControl fullWidth error={Boolean(touched.proxy && errors.proxy)} sx={{ ...theme.typography.otherInput }}>
|
||||||
|
<InputLabel htmlFor="channel-proxy-label">{inputLabel.proxy}</InputLabel>
|
||||||
|
<OutlinedInput
|
||||||
|
id="channel-proxy-label"
|
||||||
|
label={inputLabel.proxy}
|
||||||
|
type="text"
|
||||||
|
value={values.proxy}
|
||||||
|
name="proxy"
|
||||||
|
onBlur={handleBlur}
|
||||||
|
onChange={handleChange}
|
||||||
|
inputProps={{}}
|
||||||
|
aria-describedby="helper-text-channel-proxy-label"
|
||||||
|
/>
|
||||||
|
{touched.proxy && errors.proxy ? (
|
||||||
|
<FormHelperText error id="helper-tex-channel-proxy-label">
|
||||||
|
{errors.proxy}
|
||||||
|
</FormHelperText>
|
||||||
|
) : (
|
||||||
|
<FormHelperText id="helper-tex-channel-proxy-label"> {inputPrompt.proxy} </FormHelperText>
|
||||||
|
)}
|
||||||
|
</FormControl>
|
||||||
|
{inputPrompt.test_model && (
|
||||||
|
<FormControl fullWidth error={Boolean(touched.test_model && errors.test_model)} sx={{ ...theme.typography.otherInput }}>
|
||||||
|
<InputLabel htmlFor="channel-test_model-label">{inputLabel.test_model}</InputLabel>
|
||||||
|
<OutlinedInput
|
||||||
|
id="channel-test_model-label"
|
||||||
|
label={inputLabel.test_model}
|
||||||
|
type="text"
|
||||||
|
value={values.test_model}
|
||||||
|
name="test_model"
|
||||||
|
onBlur={handleBlur}
|
||||||
|
onChange={handleChange}
|
||||||
|
inputProps={{}}
|
||||||
|
aria-describedby="helper-text-channel-test_model-label"
|
||||||
|
/>
|
||||||
|
{touched.test_model && errors.test_model ? (
|
||||||
|
<FormHelperText error id="helper-tex-channel-test_model-label">
|
||||||
|
{errors.test_model}
|
||||||
|
</FormHelperText>
|
||||||
|
) : (
|
||||||
|
<FormHelperText id="helper-tex-channel-test_model-label"> {inputPrompt.test_model} </FormHelperText>
|
||||||
|
)}
|
||||||
|
</FormControl>
|
||||||
|
)}
|
||||||
<DialogActions>
|
<DialogActions>
|
||||||
<Button onClick={onCancel}>取消</Button>
|
<Button onClick={onCancel}>取消</Button>
|
||||||
<Button disableElevation disabled={isSubmitting} type="submit" variant="contained" color="primary">
|
<Button disableElevation disabled={isSubmitting} type="submit" variant="contained" color="primary">
|
||||||
|
|||||||
53
web/src/views/Channel/component/NameLabel.js
Normal file
53
web/src/views/Channel/component/NameLabel.js
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
import PropTypes from 'prop-types';
|
||||||
|
import { Tooltip, Stack, Container } from '@mui/material';
|
||||||
|
import Label from 'ui-component/Label';
|
||||||
|
import { styled } from '@mui/material/styles';
|
||||||
|
import { showSuccess } from 'utils/common';
|
||||||
|
|
||||||
|
const TooltipContainer = styled(Container)({
|
||||||
|
maxHeight: '250px',
|
||||||
|
overflow: 'auto',
|
||||||
|
'&::-webkit-scrollbar': {
|
||||||
|
width: '0px' // Set the width to 0 to hide the scrollbar
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
const NameLabel = ({ name, models }) => {
|
||||||
|
let modelMap = [];
|
||||||
|
modelMap = models.split(',');
|
||||||
|
modelMap.sort();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Tooltip
|
||||||
|
title={
|
||||||
|
<TooltipContainer>
|
||||||
|
<Stack spacing={1}>
|
||||||
|
{modelMap.map((item, index) => {
|
||||||
|
return (
|
||||||
|
<Label
|
||||||
|
variant="ghost"
|
||||||
|
key={index}
|
||||||
|
onClick={() => {
|
||||||
|
navigator.clipboard.writeText(item);
|
||||||
|
showSuccess('复制模型名称成功!');
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{item}
|
||||||
|
</Label>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</Stack>
|
||||||
|
</TooltipContainer>
|
||||||
|
}
|
||||||
|
placement="top"
|
||||||
|
>
|
||||||
|
{name}
|
||||||
|
</Tooltip>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
NameLabel.propTypes = {
|
||||||
|
group: PropTypes.string
|
||||||
|
};
|
||||||
|
|
||||||
|
export default NameLabel;
|
||||||
@@ -29,6 +29,7 @@ import TableSwitch from 'ui-component/Switch';
|
|||||||
|
|
||||||
import ResponseTimeLabel from './ResponseTimeLabel';
|
import ResponseTimeLabel from './ResponseTimeLabel';
|
||||||
import GroupLabel from './GroupLabel';
|
import GroupLabel from './GroupLabel';
|
||||||
|
import NameLabel from './NameLabel';
|
||||||
|
|
||||||
import { IconDotsVertical, IconEdit, IconTrash, IconPencil } from '@tabler/icons-react';
|
import { IconDotsVertical, IconEdit, IconTrash, IconPencil } from '@tabler/icons-react';
|
||||||
|
|
||||||
@@ -102,7 +103,9 @@ export default function ChannelTableRow({ item, manageChannel, handleOpenModal,
|
|||||||
<TableRow tabIndex={item.id}>
|
<TableRow tabIndex={item.id}>
|
||||||
<TableCell>{item.id}</TableCell>
|
<TableCell>{item.id}</TableCell>
|
||||||
|
|
||||||
<TableCell>{item.name}</TableCell>
|
<TableCell>
|
||||||
|
<NameLabel name={item.name} models={item.models} />
|
||||||
|
</TableCell>
|
||||||
|
|
||||||
<TableCell>
|
<TableCell>
|
||||||
<GroupLabel group={item.group} />
|
<GroupLabel group={item.group} />
|
||||||
|
|||||||
@@ -116,7 +116,7 @@ export default function ChannelPage() {
|
|||||||
if (success) {
|
if (success) {
|
||||||
showSuccess('操作成功完成!');
|
showSuccess('操作成功完成!');
|
||||||
if (action === 'delete') {
|
if (action === 'delete') {
|
||||||
await loadChannels(0);
|
await handleRefresh();
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
showError(message);
|
showError(message);
|
||||||
@@ -127,9 +127,7 @@ export default function ChannelPage() {
|
|||||||
|
|
||||||
// 处理刷新
|
// 处理刷新
|
||||||
const handleRefresh = async () => {
|
const handleRefresh = async () => {
|
||||||
await loadChannels(0);
|
await loadChannels(activePage);
|
||||||
setActivePage(0);
|
|
||||||
setSearchKeyword('');
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// 处理测试所有启用渠道
|
// 处理测试所有启用渠道
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ const defaultConfig = {
|
|||||||
key: '',
|
key: '',
|
||||||
base_url: '',
|
base_url: '',
|
||||||
other: '',
|
other: '',
|
||||||
|
proxy: '',
|
||||||
|
test_model: '',
|
||||||
model_mapping: '',
|
model_mapping: '',
|
||||||
models: [],
|
models: [],
|
||||||
groups: ['default']
|
groups: ['default']
|
||||||
@@ -15,6 +17,8 @@ const defaultConfig = {
|
|||||||
base_url: '渠道API地址',
|
base_url: '渠道API地址',
|
||||||
key: '密钥',
|
key: '密钥',
|
||||||
other: '其他参数',
|
other: '其他参数',
|
||||||
|
proxy: '代理地址',
|
||||||
|
test_model: '测速模型',
|
||||||
models: '模型',
|
models: '模型',
|
||||||
model_mapping: '模型映射关系',
|
model_mapping: '模型映射关系',
|
||||||
groups: '用户组'
|
groups: '用户组'
|
||||||
@@ -25,6 +29,8 @@ const defaultConfig = {
|
|||||||
base_url: '可空,请输入中转API地址,例如通过cloudflare中转',
|
base_url: '可空,请输入中转API地址,例如通过cloudflare中转',
|
||||||
key: '请输入渠道对应的鉴权密钥',
|
key: '请输入渠道对应的鉴权密钥',
|
||||||
other: '',
|
other: '',
|
||||||
|
proxy: '单独设置代理地址,支持http和socks5,例如:http://127.0.0.1:1080',
|
||||||
|
test_model: '用于测试使用的模型,为空时无法测速,如:gpt-3.5-turbo',
|
||||||
models: '请选择该渠道所支持的模型',
|
models: '请选择该渠道所支持的模型',
|
||||||
model_mapping:
|
model_mapping:
|
||||||
'请输入要修改的模型映射关系,格式为:api请求模型ID:实际转发给渠道的模型ID,使用JSON数组表示,例如:{"gpt-3.5": "gpt-35"}',
|
'请输入要修改的模型映射关系,格式为:api请求模型ID:实际转发给渠道的模型ID,使用JSON数组表示,例如:{"gpt-3.5": "gpt-35"}',
|
||||||
@@ -45,17 +51,20 @@ const typeConfig = {
|
|||||||
},
|
},
|
||||||
11: {
|
11: {
|
||||||
input: {
|
input: {
|
||||||
models: ['PaLM-2']
|
models: ['PaLM-2'],
|
||||||
|
test_model: 'PaLM-2'
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
14: {
|
14: {
|
||||||
input: {
|
input: {
|
||||||
models: ['claude-instant-1', 'claude-2', 'claude-2.0', 'claude-2.1']
|
models: ['claude-instant-1', 'claude-2', 'claude-2.0', 'claude-2.1'],
|
||||||
|
test_model: 'claude-2'
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
15: {
|
15: {
|
||||||
input: {
|
input: {
|
||||||
models: ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1']
|
models: ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1'],
|
||||||
|
test_model: 'ERNIE-Bot'
|
||||||
},
|
},
|
||||||
prompt: {
|
prompt: {
|
||||||
key: '按照如下格式输入:APIKey|SecretKey'
|
key: '按照如下格式输入:APIKey|SecretKey'
|
||||||
@@ -63,7 +72,8 @@ const typeConfig = {
|
|||||||
},
|
},
|
||||||
16: {
|
16: {
|
||||||
input: {
|
input: {
|
||||||
models: ['chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite']
|
models: ['chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite'],
|
||||||
|
test_model: 'chatglm_lite'
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
17: {
|
17: {
|
||||||
@@ -71,7 +81,18 @@ const typeConfig = {
|
|||||||
other: '插件参数'
|
other: '插件参数'
|
||||||
},
|
},
|
||||||
input: {
|
input: {
|
||||||
models: ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-longcontext', 'text-embedding-v1']
|
models: [
|
||||||
|
'qwen-turbo',
|
||||||
|
'qwen-plus',
|
||||||
|
'qwen-max',
|
||||||
|
'qwen-max-longcontext',
|
||||||
|
'text-embedding-v1',
|
||||||
|
'qwen-turbo-internet',
|
||||||
|
'qwen-plus-internet',
|
||||||
|
'qwen-max-internet',
|
||||||
|
'qwen-max-longcontext-internet'
|
||||||
|
],
|
||||||
|
test_model: 'qwen-turbo'
|
||||||
},
|
},
|
||||||
prompt: {
|
prompt: {
|
||||||
other: '请输入插件参数,即 X-DashScope-Plugin 请求头的取值'
|
other: '请输入插件参数,即 X-DashScope-Plugin 请求头的取值'
|
||||||
@@ -91,7 +112,8 @@ const typeConfig = {
|
|||||||
},
|
},
|
||||||
19: {
|
19: {
|
||||||
input: {
|
input: {
|
||||||
models: ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1']
|
models: ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1'],
|
||||||
|
test_model: '360GPT_S2_V9'
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
22: {
|
22: {
|
||||||
@@ -101,7 +123,8 @@ const typeConfig = {
|
|||||||
},
|
},
|
||||||
23: {
|
23: {
|
||||||
input: {
|
input: {
|
||||||
models: ['hunyuan']
|
models: ['hunyuan'],
|
||||||
|
test_model: 'hunyuan'
|
||||||
},
|
},
|
||||||
prompt: {
|
prompt: {
|
||||||
key: '按照如下格式输入:AppId|SecretId|SecretKey'
|
key: '按照如下格式输入:AppId|SecretId|SecretKey'
|
||||||
@@ -112,11 +135,26 @@ const typeConfig = {
|
|||||||
other: '版本号'
|
other: '版本号'
|
||||||
},
|
},
|
||||||
input: {
|
input: {
|
||||||
models: ['gemini-pro']
|
models: ['gemini-pro', 'gemini-pro-vision'],
|
||||||
|
test_model: 'gemini-pro'
|
||||||
},
|
},
|
||||||
prompt: {
|
prompt: {
|
||||||
other: '请输入版本号,例如:v1'
|
other: '请输入版本号,例如:v1'
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
26: {
|
||||||
|
input: {
|
||||||
|
models: ['Baichuan2-Turbo', 'Baichuan2-Turbo-192k', 'Baichuan2-53B', 'Baichuan-Text-Embedding'],
|
||||||
|
test_model: 'Baichuan2-Turbo'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
24: {
|
||||||
|
input: {
|
||||||
|
models: ['tts-1', 'tts-1-hd']
|
||||||
|
},
|
||||||
|
prompt: {
|
||||||
|
test_model: ''
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -108,7 +108,7 @@ export default function Token() {
|
|||||||
if (success) {
|
if (success) {
|
||||||
showSuccess('操作成功完成!');
|
showSuccess('操作成功完成!');
|
||||||
if (action === 'delete') {
|
if (action === 'delete') {
|
||||||
await loadTokens(0);
|
await handleRefresh();
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
showError(message);
|
showError(message);
|
||||||
@@ -119,9 +119,7 @@ export default function Token() {
|
|||||||
|
|
||||||
// 处理刷新
|
// 处理刷新
|
||||||
const handleRefresh = async () => {
|
const handleRefresh = async () => {
|
||||||
await loadTokens(0);
|
await loadTokens(activePage);
|
||||||
setActivePage(0);
|
|
||||||
setSearchKeyword('');
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleOpenModal = (tokenId) => {
|
const handleOpenModal = (tokenId) => {
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ const TopupCard = () => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
let status = localStorage.getItem('status');
|
let status = localStorage.getItem('siteInfo');
|
||||||
if (status) {
|
if (status) {
|
||||||
status = JSON.parse(status);
|
status = JSON.parse(status);
|
||||||
if (status.top_up_link) {
|
if (status.top_up_link) {
|
||||||
|
|||||||
@@ -109,9 +109,7 @@ export default function Users() {
|
|||||||
|
|
||||||
// 处理刷新
|
// 处理刷新
|
||||||
const handleRefresh = async () => {
|
const handleRefresh = async () => {
|
||||||
await loadUsers(0);
|
await loadUsers(activePage);
|
||||||
setActivePage(0);
|
|
||||||
setSearchKeyword('');
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleOpenModal = (userId) => {
|
const handleOpenModal = (userId) => {
|
||||||
|
|||||||
Reference in New Issue
Block a user