Compare commits
v0.6.10-al...v0.6.11-al — 38 commits
| SHA1 |
|---|
| fa2a772731 |
| 4f68f3e1b3 |
| 0bab887b2d |
| 0230d36643 |
| bad57d049a |
| dc470ce82e |
| ea0721d525 |
| d0402f9086 |
| 1fead8e7f7 |
| 09911a301d |
| f95e6b78b8 |
| 605bb06667 |
| d88e07fd9a |
| 3915ce9814 |
| 999defc88b |
| b51c47bc77 |
| 4f25cde132 |
| d89e9d7e44 |
| a858292b54 |
| ff589b5e4a |
| 95e8c16338 |
| 381172cb36 |
| 59eae186a3 |
| ce52f355bb |
| cb9d0a74c9 |
| 49ffb1c60d |
| 2f16649896 |
| af3aa57bd6 |
| e9f117ff72 |
| 6bb5247bd6 |
| 305ce14fe3 |
| 36c8f4f15c |
| 45b51ea0ee |
| 7c8628bd95 |
| 6ab87f8a08 |
| 833fa7ad6f |
| 6eb0770a89 |
| 92cd46d64f |

.github/workflows/ci.yml (vendored) — 10 changed lines

@@ -1,19 +1,17 @@
 name: CI
 
 # This setup assumes that you run the unit tests with code coverage in the same
 # workflow that will also print the coverage report as comment to the pull request.
 # Therefore, you need to trigger this workflow when a pull request is (re)opened or
 # when new code is pushed to the branch of the pull request. In addition, you also
 # need to trigger this workflow when new code is pushed to the main branch because
 # we need to upload the code coverage results as artifact for the main branch as
 # well since it will be the baseline code coverage.
 #
 # We do not want to trigger the workflow for pushes to *any* branch because this
 # would trigger our jobs twice on pull requests (once from "push" event and once
 # from "pull_request->synchronize")
 on:
-  pull_request:
-    types: [opened, reopened, synchronize]
   push:
     branches:
       - 'main'
@@ -31,7 +29,7 @@ jobs:
         with:
           go-version: ^1.22
 
       # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a
       # coverage profile to a file. You will need the name of the file (e.g. "coverage.txt")
       # in the next step as well as the next job.
       - name: Test

.gitignore (vendored) — 3 changed lines

@@ -9,4 +9,5 @@ logs
 data
 /web/node_modules
 cmd.md
 .env
+/one-api

README.md — 17 changed lines

@@ -115,8 +115,8 @@ _✨ Access all large models through the standard OpenAI API format, ready to use out of the box
 21. Support for Cloudflare Turnstile user verification.
 22. User management with **multiple login/registration methods**:
     + Email login/registration (with a registration email whitelist) and password reset via email.
-    + Support for authorized login via Feishu.
+    + Support for [Feishu OAuth login](https://open.feishu.cn/document/uAjLw4CM/ukTMukTMukTM/reference/authen-v1/authorize/get) ([implementation notes for One API are available here for reference](https://iamazing.cn/page/feishu-oauth-login)).
-    + [GitHub OAuth](https://github.com/settings/applications/new).
+    + Support for [GitHub OAuth login](https://github.com/settings/applications/new).
     + WeChat Official Account authorization (requires an additional [WeChat Server](https://github.com/songquanpeng/wechat-server) deployment).
 23. Theme switching via the `THEME` environment variable, default `default`; PRs for more themes are welcome, see [here](./web/README.md) for details.
 24. Works with [Message Pusher](https://github.com/songquanpeng/message-pusher) to push alert messages to many apps.
@@ -175,6 +175,10 @@ sudo service nginx restart
 
 The initial account username is `root` with password `123456`.
 
+### One-click deployment via Baota Panel (宝塔面板)
+1. Install Baota Panel 9.2.0 or later: go to the official [Baota Panel](https://www.bt.cn/new/download.html?r=dk_oneapi) site and download and install the stable-release script;
+2. After installation, log in to Baota Panel and click `Docker` in the left sidebar; on first entry you will be prompted to install the `Docker` service, click install now and follow the prompts;
+3. Once installed, search for `One-API` in the app store, click install, and fill in the domain name and other basic information to complete the installation;
 
 ### Deployment with Docker Compose
 
@@ -218,7 +222,7 @@ docker-compose ps
 3. All slave servers must set `NODE_TYPE` to `slave`; if unset, a node defaults to being the master server.
 4. With `SYNC_FREQUENCY` set, servers periodically sync configuration from the database; when using a remote database, setting it and enabling Redis is recommended for both master and slave nodes.
 5. Slave servers may optionally set `FRONTEND_BASE_URL` to redirect page requests to the master server.
-6. Install Redis **separately** on each slave server and set `REDIS_CONN_STRING`, so that the database is not touched at all while the cache has not expired, which reduces latency.
+6. Install Redis **separately** on each slave server and set `REDIS_CONN_STRING`, so that the database is not touched at all while the cache has not expired, which reduces latency (for Redis cluster or sentinel mode support, see the environment variable notes).
 7. If the master server's database access latency is also high, it also needs Redis enabled and `SYNC_FREQUENCY` set, so that configuration is synced from the database periodically.
 
 See [here](#环境变量) for details on how to use the environment variables.
@@ -347,6 +351,11 @@ graph LR
 1. `REDIS_CONN_STRING`: when set, Redis is used as the cache.
    + Example: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
    + If database access latency is already low, there is no need to enable Redis; enabling it can actually introduce stale-data issues.
+   + To use sentinel or cluster mode:
+     + set this environment variable to the node list instead, e.g. `localhost:49153,localhost:49154,localhost:49155`;
+     + in addition, the following environment variables must be set:
+       + `REDIS_PASSWORD`: the password for Redis cluster or sentinel mode.
+       + `REDIS_MASTER_NAME`: the name of the master node in Redis sentinel mode.
 2. `SESSION_SECRET`: when set, a fixed session key is used so that logged-in users' cookies remain valid after the system restarts.
    + Example: `SESSION_SECRET=random_string`
 3. `SQL_DSN`: when set, the specified database is used instead of SQLite; please use MySQL or PostgreSQL.
@@ -400,6 +409,8 @@ graph LR
 26. `METRIC_SUCCESS_RATE_THRESHOLD`: request success-rate threshold, default `0.8`.
 27. `INITIAL_ROOT_TOKEN`: if set, a root user token with this value is created automatically on first startup.
 28. `INITIAL_ROOT_ACCESS_TOKEN`: if set, a system management (access) token with this value is created for the root user automatically on first startup.
+29. `ENFORCE_INCLUDE_USAGE`: whether to force returning usage in stream mode, disabled by default; allowed values are `true` and `false`.
+30. `TEST_PROMPT`: the user prompt sent when testing a model, default `Print your model name exactly and do not output without any other text.`.
 
 ### Command-line arguments
 1. `--port <port_number>`: the port the server listens on, default `3000`.

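To make the new sentinel/cluster variables concrete, here is a minimal sketch (not from the repository) of how a comma-separated `REDIS_CONN_STRING` plus `REDIS_PASSWORD` and `REDIS_MASTER_NAME` can be turned into a go-redis v8 client; One API performs the equivalent step internally in `InitRedisClient` (see the diff further down).

```go
package main

import (
	"context"
	"os"
	"strings"

	"github.com/go-redis/redis/v8"
)

func main() {
	// In sentinel/cluster mode REDIS_CONN_STRING holds a node list,
	// e.g. "localhost:49153,localhost:49154,localhost:49155".
	rdb := redis.NewUniversalClient(&redis.UniversalOptions{
		Addrs:      strings.Split(os.Getenv("REDIS_CONN_STRING"), ","),
		Password:   os.Getenv("REDIS_PASSWORD"),    // cluster/sentinel password
		MasterName: os.Getenv("REDIS_MASTER_NAME"), // set only for sentinel mode
	})
	// Quick connectivity check.
	if err := rdb.Ping(context.Background()).Err(); err != nil {
		panic(err)
	}
}
```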
@@ -1,13 +1,14 @@
 package config
 
 import (
-    "github.com/songquanpeng/one-api/common/env"
     "os"
     "strconv"
     "strings"
     "sync"
     "time"
 
+    "github.com/songquanpeng/one-api/common/env"
+
     "github.com/google/uuid"
 )
 
@@ -160,3 +161,6 @@ var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false)
 var RelayProxy = env.String("RELAY_PROXY", "")
 var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "")
 var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)
+
+var EnforceIncludeUsage = env.Bool("ENFORCE_INCLUDE_USAGE", false)
+var TestPrompt = env.String("TEST_PROMPT", "Print your model name exactly and do not output without any other text.")

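A hedged sketch of the intent behind `EnforceIncludeUsage`: the diff only adds the flag, and where it is consumed is not shown here, so the request types below are hypothetical stand-ins. The idea is that when the flag is on and the caller requested streaming, usage reporting is switched on.

```go
package main

import "fmt"

// Hypothetical request shape for illustration only; One API's actual relay
// types live in relay/model and may differ.
type StreamOptions struct {
	IncludeUsage bool `json:"include_usage"`
}

type ChatRequest struct {
	Stream        bool           `json:"stream"`
	StreamOptions *StreamOptions `json:"stream_options,omitempty"`
}

// enforceUsage forces usage reporting on streaming requests when the
// ENFORCE_INCLUDE_USAGE flag is enabled.
func enforceUsage(req *ChatRequest, enforceIncludeUsage bool) {
	if enforceIncludeUsage && req.Stream && req.StreamOptions == nil {
		req.StreamOptions = &StreamOptions{IncludeUsage: true}
	}
}

func main() {
	req := &ChatRequest{Stream: true}
	enforceUsage(req, true)
	fmt.Println(req.StreamOptions.IncludeUsage) // true
}
```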
@@ -20,4 +20,5 @@ const (
     BaseURL         = "base_url"
     AvailableModels = "available_models"
     KeyRequestBody  = "key_request_body"
+    SystemPrompt    = "system_prompt"
 )

@@ -1,9 +1,8 @@
 package helper
 
 import (
+    "context"
     "fmt"
-    "github.com/gin-gonic/gin"
-    "github.com/songquanpeng/one-api/common/random"
     "html/template"
     "log"
     "net"
@@ -11,6 +10,10 @@ import (
     "runtime"
     "strconv"
     "strings"
+
+    "github.com/gin-gonic/gin"
+
+    "github.com/songquanpeng/one-api/common/random"
 )
 
 func OpenBrowser(url string) {
@@ -106,6 +109,18 @@ func GenRequestID() string {
     return GetTimeString() + random.GetRandomNumberString(8)
 }
 
+func SetRequestID(ctx context.Context, id string) context.Context {
+    return context.WithValue(ctx, RequestIdKey, id)
+}
+
+func GetRequestID(ctx context.Context) string {
+    rawRequestId := ctx.Value(RequestIdKey)
+    if rawRequestId == nil {
+        return ""
+    }
+    return rawRequestId.(string)
+}
+
 func GetResponseID(c *gin.Context) string {
     logID := c.GetString(RequestIdKey)
     return fmt.Sprintf("chatcmpl-%s", logID)

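A minimal standalone sketch of how the new `SetRequestID`/`GetRequestID` pair is meant to be used: store the request ID on a `context.Context` once, then read it anywhere downstream (the logger does exactly this). The key value below is a placeholder; the real constant is `helper.RequestIdKey`.

```go
package main

import (
	"context"
	"fmt"
)

// requestIdKey stands in for helper.RequestIdKey; the actual value is defined
// in common/helper and is not shown in this diff.
const requestIdKey = "requestId"

func setRequestID(ctx context.Context, id string) context.Context {
	return context.WithValue(ctx, requestIdKey, id)
}

func getRequestID(ctx context.Context) string {
	raw := ctx.Value(requestIdKey)
	if raw == nil {
		return ""
	}
	return raw.(string)
}

func main() {
	ctx := setRequestID(context.Background(), "20250101abc12345678")
	fmt.Println(getRequestID(ctx))                          // 20250101abc12345678
	fmt.Println(getRequestID(context.Background()) == "")   // true: a missing ID degrades to ""
}
```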
@@ -13,3 +13,8 @@ func GetTimeString() string {
     now := time.Now()
     return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
 }
+
+// CalcElapsedTime return the elapsed time in milliseconds (ms)
+func CalcElapsedTime(start time.Time) int64 {
+    return time.Now().Sub(start).Milliseconds()
+}

@@ -7,19 +7,25 @@ import (
     "log"
     "os"
     "path/filepath"
+    "runtime"
+    "strings"
     "sync"
     "time"
 
     "github.com/gin-gonic/gin"
 
     "github.com/songquanpeng/one-api/common/config"
     "github.com/songquanpeng/one-api/common/helper"
 )
 
+type loggerLevel string
+
 const (
-    loggerDEBUG = "DEBUG"
-    loggerINFO  = "INFO"
-    loggerWarn  = "WARN"
-    loggerError = "ERR"
+    loggerDEBUG loggerLevel = "DEBUG"
+    loggerINFO  loggerLevel = "INFO"
+    loggerWarn  loggerLevel = "WARN"
+    loggerError loggerLevel = "ERROR"
+    loggerFatal loggerLevel = "FATAL"
 )
 
 var setupLogOnce sync.Once
@@ -44,27 +50,26 @@ func SetupLogger() {
 }
 
 func SysLog(s string) {
-    t := time.Now()
-    _, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
+    logHelper(nil, loggerINFO, s)
 }
 
 func SysLogf(format string, a ...any) {
-    SysLog(fmt.Sprintf(format, a...))
+    logHelper(nil, loggerINFO, fmt.Sprintf(format, a...))
 }
 
 func SysError(s string) {
-    t := time.Now()
-    _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
+    logHelper(nil, loggerError, s)
 }
 
 func SysErrorf(format string, a ...any) {
-    SysError(fmt.Sprintf(format, a...))
+    logHelper(nil, loggerError, fmt.Sprintf(format, a...))
 }
 
 func Debug(ctx context.Context, msg string) {
-    if config.DebugEnabled {
-        logHelper(ctx, loggerDEBUG, msg)
+    if !config.DebugEnabled {
+        return
     }
+    logHelper(ctx, loggerDEBUG, msg)
 }
 
 func Info(ctx context.Context, msg string) {
@@ -80,37 +85,65 @@ func Error(ctx context.Context, msg string) {
 }
 
 func Debugf(ctx context.Context, format string, a ...any) {
-    Debug(ctx, fmt.Sprintf(format, a...))
+    logHelper(ctx, loggerDEBUG, fmt.Sprintf(format, a...))
 }
 
 func Infof(ctx context.Context, format string, a ...any) {
-    Info(ctx, fmt.Sprintf(format, a...))
+    logHelper(ctx, loggerINFO, fmt.Sprintf(format, a...))
 }
 
 func Warnf(ctx context.Context, format string, a ...any) {
-    Warn(ctx, fmt.Sprintf(format, a...))
+    logHelper(ctx, loggerWarn, fmt.Sprintf(format, a...))
 }
 
 func Errorf(ctx context.Context, format string, a ...any) {
-    Error(ctx, fmt.Sprintf(format, a...))
+    logHelper(ctx, loggerError, fmt.Sprintf(format, a...))
 }
 
-func logHelper(ctx context.Context, level string, msg string) {
+func FatalLog(s string) {
+    logHelper(nil, loggerFatal, s)
+}
+
+func FatalLogf(format string, a ...any) {
+    logHelper(nil, loggerFatal, fmt.Sprintf(format, a...))
+}
+
+func logHelper(ctx context.Context, level loggerLevel, msg string) {
     writer := gin.DefaultErrorWriter
     if level == loggerINFO {
         writer = gin.DefaultWriter
     }
-    id := ctx.Value(helper.RequestIdKey)
-    if id == nil {
-        id = helper.GenRequestID()
+    var requestId string
+    if ctx != nil {
+        rawRequestId := helper.GetRequestID(ctx)
+        if rawRequestId != "" {
+            requestId = fmt.Sprintf(" | %s", rawRequestId)
+        }
     }
+    lineInfo, funcName := getLineInfo()
     now := time.Now()
-    _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
+    _, _ = fmt.Fprintf(writer, "[%s] %v%s%s %s%s \n", level, now.Format("2006/01/02 - 15:04:05"), requestId, lineInfo, funcName, msg)
     SetupLogger()
+    if level == loggerFatal {
+        os.Exit(1)
+    }
 }
 
-func FatalLog(v ...any) {
-    t := time.Now()
-    _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
-    os.Exit(1)
+func getLineInfo() (string, string) {
+    funcName := "[unknown] "
+    pc, file, line, ok := runtime.Caller(3)
+    if ok {
+        if fn := runtime.FuncForPC(pc); fn != nil {
+            parts := strings.Split(fn.Name(), ".")
+            funcName = "[" + parts[len(parts)-1] + "] "
+        }
+    } else {
+        file = "unknown"
+        line = 0
+    }
+    parts := strings.Split(file, "one-api/")
+    if len(parts) > 1 {
+        file = parts[1]
+    }
+    return fmt.Sprintf(" | %s:%d", file, line), funcName
 }

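The reworked logger resolves the caller's file, line and function via `runtime.Caller`, as `getLineInfo` does above with a fixed skip of 3 frames. A small self-contained sketch of the same idea, with the skip value exposed for clarity (the log line format here is illustrative, not One API's exact format):

```go
package main

import (
	"fmt"
	"runtime"
	"strings"
)

// callerInfo walks `skip` frames up the stack and reports the caller's
// file:line plus the bare function name, mirroring getLineInfo.
func callerInfo(skip int) string {
	pc, file, line, ok := runtime.Caller(skip)
	if !ok {
		return "unknown:0 [unknown]"
	}
	funcName := "unknown"
	if fn := runtime.FuncForPC(pc); fn != nil {
		parts := strings.Split(fn.Name(), ".")
		funcName = parts[len(parts)-1]
	}
	return fmt.Sprintf("%s:%d [%s]", file, line, funcName)
}

func doWork() {
	// skip=1 resolves to this call site inside doWork.
	fmt.Println("[INFO] " + callerInfo(1) + " something happened")
}

func main() {
	doWork()
}
```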
@@ -2,13 +2,15 @@ package common
 
 import (
     "context"
+    "os"
+    "strings"
+    "time"
 
     "github.com/go-redis/redis/v8"
     "github.com/songquanpeng/one-api/common/logger"
-    "os"
-    "time"
 )
 
-var RDB *redis.Client
+var RDB redis.Cmdable
 var RedisEnabled = true
 
 // InitRedisClient This function is called after init()
@@ -23,13 +25,23 @@ func InitRedisClient() (err error) {
         logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled")
         return nil
     }
-    logger.SysLog("Redis is enabled")
-    opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
-    if err != nil {
-        logger.FatalLog("failed to parse Redis connection string: " + err.Error())
+    redisConnString := os.Getenv("REDIS_CONN_STRING")
+    if os.Getenv("REDIS_MASTER_NAME") == "" {
+        logger.SysLog("Redis is enabled")
+        opt, err := redis.ParseURL(redisConnString)
+        if err != nil {
+            logger.FatalLog("failed to parse Redis connection string: " + err.Error())
+        }
+        RDB = redis.NewClient(opt)
+    } else {
+        // cluster mode
+        logger.SysLog("Redis cluster mode enabled")
+        RDB = redis.NewUniversalClient(&redis.UniversalOptions{
+            Addrs:      strings.Split(redisConnString, ","),
+            Password:   os.Getenv("REDIS_PASSWORD"),
+            MasterName: os.Getenv("REDIS_MASTER_NAME"),
+        })
     }
-    RDB = redis.NewClient(opt)
 
     ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
     defer cancel()

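Changing `RDB` from `*redis.Client` to the `redis.Cmdable` interface is what lets both branches above coexist: the single-node client and the universal (cluster/sentinel) client both satisfy it, so existing call sites keep compiling. A short illustration with a hypothetical helper (not from the repository):

```go
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/go-redis/redis/v8"
)

// setCache works against redis.Cmdable, so it accepts whichever client
// InitRedisClient happened to construct.
func setCache(ctx context.Context, rdb redis.Cmdable, key, value string, ttl time.Duration) error {
	return rdb.Set(ctx, key, value, ttl).Err()
}

func main() {
	ctx := context.Background()
	single := redis.NewClient(&redis.Options{Addr: "localhost:6379"})
	universal := redis.NewUniversalClient(&redis.UniversalOptions{Addrs: []string{"localhost:6379"}})

	// Both flavours can be passed to the same code path.
	fmt.Println(setCache(ctx, single, "k", "v", time.Minute))
	fmt.Println(setCache(ctx, universal, "k", "v", time.Minute))
}
```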
@@ -3,9 +3,10 @@ package render
 import (
     "encoding/json"
     "fmt"
+    "strings"
+
     "github.com/gin-gonic/gin"
     "github.com/songquanpeng/one-api/common"
-    "strings"
 )
 
 func StringData(c *gin.Context, str string) {

@@ -5,16 +5,18 @@ import (
     "encoding/json"
     "errors"
     "fmt"
+    "net/http"
+    "strconv"
+    "time"
+
     "github.com/gin-contrib/sessions"
     "github.com/gin-gonic/gin"
 
     "github.com/songquanpeng/one-api/common/config"
     "github.com/songquanpeng/one-api/common/logger"
     "github.com/songquanpeng/one-api/common/random"
     "github.com/songquanpeng/one-api/controller"
     "github.com/songquanpeng/one-api/model"
-    "net/http"
-    "strconv"
-    "time"
 )
 
 type GitHubOAuthResponse struct {
@@ -81,6 +83,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
 }
 
 func GitHubOAuth(c *gin.Context) {
+    ctx := c.Request.Context()
     session := sessions.Default(c)
     state := c.Query("state")
     if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
@@ -136,7 +139,7 @@ func GitHubOAuth(c *gin.Context) {
         user.Role = model.RoleCommonUser
         user.Status = model.UserStatusEnabled
 
-        if err := user.Insert(0); err != nil {
+        if err := user.Insert(ctx, 0); err != nil {
             c.JSON(http.StatusOK, gin.H{
                 "success": false,
                 "message": err.Error(),

@@ -5,15 +5,17 @@ import (
     "encoding/json"
     "errors"
     "fmt"
+    "net/http"
+    "strconv"
+    "time"
+
     "github.com/gin-contrib/sessions"
     "github.com/gin-gonic/gin"
 
     "github.com/songquanpeng/one-api/common/config"
     "github.com/songquanpeng/one-api/common/logger"
     "github.com/songquanpeng/one-api/controller"
     "github.com/songquanpeng/one-api/model"
-    "net/http"
-    "strconv"
-    "time"
 )
 
 type LarkOAuthResponse struct {
@@ -40,7 +42,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) {
     if err != nil {
         return nil, err
     }
-    req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData))
+    req, err := http.NewRequest("POST", "https://open.feishu.cn/open-apis/authen/v2/oauth/token", bytes.NewBuffer(jsonData))
     if err != nil {
         return nil, err
     }
@@ -79,6 +81,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) {
 }
 
 func LarkOAuth(c *gin.Context) {
+    ctx := c.Request.Context()
     session := sessions.Default(c)
     state := c.Query("state")
     if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
@@ -125,7 +128,7 @@ func LarkOAuth(c *gin.Context) {
         user.Role = model.RoleCommonUser
         user.Status = model.UserStatusEnabled
 
-        if err := user.Insert(0); err != nil {
+        if err := user.Insert(ctx, 0); err != nil {
             c.JSON(http.StatusOK, gin.H{
                 "success": false,
                 "message": err.Error(),

@@ -5,15 +5,17 @@ import (
     "encoding/json"
     "errors"
     "fmt"
+    "net/http"
+    "strconv"
+    "time"
+
     "github.com/gin-contrib/sessions"
     "github.com/gin-gonic/gin"
 
     "github.com/songquanpeng/one-api/common/config"
     "github.com/songquanpeng/one-api/common/logger"
     "github.com/songquanpeng/one-api/controller"
     "github.com/songquanpeng/one-api/model"
-    "net/http"
-    "strconv"
-    "time"
 )
 
 type OidcResponse struct {
@@ -87,6 +89,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
 }
 
 func OidcAuth(c *gin.Context) {
+    ctx := c.Request.Context()
     session := sessions.Default(c)
     state := c.Query("state")
     if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
@@ -142,7 +145,7 @@ func OidcAuth(c *gin.Context) {
     } else {
         user.DisplayName = "OIDC User"
     }
-    err := user.Insert(0)
+    err := user.Insert(ctx, 0)
     if err != nil {
         c.JSON(http.StatusOK, gin.H{
             "success": false,

@@ -4,14 +4,16 @@ import (
     "encoding/json"
     "errors"
     "fmt"
+    "net/http"
+    "strconv"
+    "time"
+
     "github.com/gin-gonic/gin"
 
     "github.com/songquanpeng/one-api/common/config"
     "github.com/songquanpeng/one-api/common/ctxkey"
     "github.com/songquanpeng/one-api/controller"
     "github.com/songquanpeng/one-api/model"
-    "net/http"
-    "strconv"
-    "time"
 )
 
 type wechatLoginResponse struct {
@@ -52,6 +54,7 @@ func getWeChatIdByCode(code string) (string, error) {
 }
 
 func WeChatAuth(c *gin.Context) {
+    ctx := c.Request.Context()
     if !config.WeChatAuthEnabled {
         c.JSON(http.StatusOK, gin.H{
             "message": "管理员未开启通过微信登录以及注册",
@@ -87,7 +90,7 @@ func WeChatAuth(c *gin.Context) {
         user.Role = model.RoleCommonUser
         user.Status = model.UserStatusEnabled
 
-        if err := user.Insert(0); err != nil {
+        if err := user.Insert(ctx, 0); err != nil {
             c.JSON(http.StatusOK, gin.H{
                 "success": false,
                 "message": err.Error(),

@@ -4,16 +4,17 @@ import (
     "encoding/json"
     "errors"
     "fmt"
+    "io"
+    "net/http"
+    "strconv"
+    "time"
+
     "github.com/songquanpeng/one-api/common/client"
     "github.com/songquanpeng/one-api/common/config"
     "github.com/songquanpeng/one-api/common/logger"
     "github.com/songquanpeng/one-api/model"
     "github.com/songquanpeng/one-api/monitor"
     "github.com/songquanpeng/one-api/relay/channeltype"
-    "io"
-    "net/http"
-    "strconv"
-    "time"
 
     "github.com/gin-gonic/gin"
 )
@@ -101,6 +102,16 @@ type SiliconFlowUsageResponse struct {
     } `json:"data"`
 }
 
+type DeepSeekUsageResponse struct {
+    IsAvailable  bool `json:"is_available"`
+    BalanceInfos []struct {
+        Currency        string `json:"currency"`
+        TotalBalance    string `json:"total_balance"`
+        GrantedBalance  string `json:"granted_balance"`
+        ToppedUpBalance string `json:"topped_up_balance"`
+    } `json:"balance_infos"`
+}
+
 // GetAuthHeader get auth header
 func GetAuthHeader(token string) http.Header {
     h := http.Header{}
@@ -237,7 +248,36 @@ func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) {
     if response.Code != 20000 {
         return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message)
     }
-    balance, err := strconv.ParseFloat(response.Data.Balance, 64)
+    balance, err := strconv.ParseFloat(response.Data.TotalBalance, 64)
+    if err != nil {
+        return 0, err
+    }
+    channel.UpdateBalance(balance)
+    return balance, nil
+}
+
+func updateChannelDeepSeekBalance(channel *model.Channel) (float64, error) {
+    url := "https://api.deepseek.com/user/balance"
+    body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+    if err != nil {
+        return 0, err
+    }
+    response := DeepSeekUsageResponse{}
+    err = json.Unmarshal(body, &response)
+    if err != nil {
+        return 0, err
+    }
+    index := -1
+    for i, balanceInfo := range response.BalanceInfos {
+        if balanceInfo.Currency == "CNY" {
+            index = i
+            break
+        }
+    }
+    if index == -1 {
+        return 0, errors.New("currency CNY not found")
+    }
+    balance, err := strconv.ParseFloat(response.BalanceInfos[index].TotalBalance, 64)
     if err != nil {
         return 0, err
     }
@@ -271,6 +311,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
         return updateChannelAIGC2DBalance(channel)
     case channeltype.SiliconFlow:
         return updateChannelSiliconFlowBalance(channel)
+    case channeltype.DeepSeek:
+        return updateChannelDeepSeekBalance(channel)
     default:
         return 0, errors.New("尚未实现")
     }

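For reference, here is a small standalone sketch of the parsing step `updateChannelDeepSeekBalance` performs: unmarshal the balance payload and pick out the CNY entry. The JSON shape follows the `DeepSeekUsageResponse` struct added above; the figures are made up.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strconv"
)

type deepSeekUsage struct {
	IsAvailable  bool `json:"is_available"`
	BalanceInfos []struct {
		Currency     string `json:"currency"`
		TotalBalance string `json:"total_balance"`
	} `json:"balance_infos"`
}

func main() {
	// Illustrative payload shaped like DeepSeekUsageResponse; values are invented.
	raw := []byte(`{"is_available":true,"balance_infos":[{"currency":"CNY","total_balance":"12.34"}]}`)
	var resp deepSeekUsage
	if err := json.Unmarshal(raw, &resp); err != nil {
		panic(err)
	}
	for _, info := range resp.BalanceInfos {
		if info.Currency == "CNY" { // only the CNY balance is recorded on the channel
			balance, _ := strconv.ParseFloat(info.TotalBalance, 64)
			fmt.Println(balance) // 12.34
		}
	}
}
```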
@@ -2,6 +2,7 @@ package controller
 
 import (
     "bytes"
+    "context"
     "encoding/json"
     "errors"
     "fmt"
@@ -15,14 +16,17 @@ import (
     "time"
 
     "github.com/gin-gonic/gin"
+
     "github.com/songquanpeng/one-api/common/config"
     "github.com/songquanpeng/one-api/common/ctxkey"
+    "github.com/songquanpeng/one-api/common/helper"
     "github.com/songquanpeng/one-api/common/logger"
     "github.com/songquanpeng/one-api/common/message"
     "github.com/songquanpeng/one-api/middleware"
     "github.com/songquanpeng/one-api/model"
     "github.com/songquanpeng/one-api/monitor"
-    relay "github.com/songquanpeng/one-api/relay"
+    "github.com/songquanpeng/one-api/relay"
+    "github.com/songquanpeng/one-api/relay/adaptor/openai"
     "github.com/songquanpeng/one-api/relay/channeltype"
     "github.com/songquanpeng/one-api/relay/controller"
     "github.com/songquanpeng/one-api/relay/meta"
@@ -35,18 +39,34 @@ func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest {
         model = "gpt-3.5-turbo"
     }
     testRequest := &relaymodel.GeneralOpenAIRequest{
-        MaxTokens: 2,
-        Model:     model,
+        Model: model,
     }
     testMessage := relaymodel.Message{
         Role:    "user",
-        Content: "hi",
+        Content: config.TestPrompt,
     }
     testRequest.Messages = append(testRequest.Messages, testMessage)
     return testRequest
 }
 
-func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (err error, openaiErr *relaymodel.Error) {
+func parseTestResponse(resp string) (*openai.TextResponse, string, error) {
+    var response openai.TextResponse
+    err := json.Unmarshal([]byte(resp), &response)
+    if err != nil {
+        return nil, "", err
+    }
+    if len(response.Choices) == 0 {
+        return nil, "", errors.New("response has no choices")
+    }
+    stringContent, ok := response.Choices[0].Content.(string)
+    if !ok {
+        return nil, "", errors.New("response content is not string")
+    }
+    return &response, stringContent, nil
+}
+
+func testChannel(ctx context.Context, channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (responseMessage string, err error, openaiErr *relaymodel.Error) {
+    startTime := time.Now()
     w := httptest.NewRecorder()
     c, _ := gin.CreateTestContext(w)
     c.Request = &http.Request{
@@ -66,7 +86,7 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques
     apiType := channeltype.ToAPIType(channel.Type)
     adaptor := relay.GetAdaptor(apiType)
     if adaptor == nil {
-        return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
+        return "", fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
     }
     adaptor.Init(meta)
     modelName := request.Model
@@ -84,41 +104,69 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques
     request.Model = modelName
     convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request)
     if err != nil {
-        return err, nil
+        return "", err, nil
     }
     jsonData, err := json.Marshal(convertedRequest)
     if err != nil {
-        return err, nil
+        return "", err, nil
     }
+    defer func() {
+        logContent := fmt.Sprintf("渠道 %s 测试成功,响应:%s", channel.Name, responseMessage)
+        if err != nil || openaiErr != nil {
+            errorMessage := ""
+            if err != nil {
+                errorMessage = err.Error()
+            } else {
+                errorMessage = openaiErr.Message
+            }
+            logContent = fmt.Sprintf("渠道 %s 测试失败,错误:%s", channel.Name, errorMessage)
+        }
+        go model.RecordTestLog(ctx, &model.Log{
+            ChannelId:   channel.Id,
+            ModelName:   modelName,
+            Content:     logContent,
+            ElapsedTime: helper.CalcElapsedTime(startTime),
+        })
+    }()
     logger.SysLog(string(jsonData))
     requestBody := bytes.NewBuffer(jsonData)
     c.Request.Body = io.NopCloser(requestBody)
     resp, err := adaptor.DoRequest(c, meta, requestBody)
     if err != nil {
-        return err, nil
+        return "", err, nil
     }
     if resp != nil && resp.StatusCode != http.StatusOK {
         err := controller.RelayErrorHandler(resp)
-        return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
+        errorMessage := err.Error.Message
+        if errorMessage != "" {
+            errorMessage = ", error message: " + errorMessage
+        }
+        return "", fmt.Errorf("http status code: %d%s", resp.StatusCode, errorMessage), &err.Error
     }
     usage, respErr := adaptor.DoResponse(c, resp, meta)
     if respErr != nil {
-        return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
+        return "", fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
     }
     if usage == nil {
-        return errors.New("usage is nil"), nil
+        return "", errors.New("usage is nil"), nil
+    }
+    rawResponse := w.Body.String()
+    _, responseMessage, err = parseTestResponse(rawResponse)
+    if err != nil {
+        return "", err, nil
     }
     result := w.Result()
     // print result.Body
     respBody, err := io.ReadAll(result.Body)
     if err != nil {
-        return err, nil
+        return "", err, nil
     }
     logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
-    return nil, nil
+    return responseMessage, nil, nil
 }
 
 func TestChannel(c *gin.Context) {
+    ctx := c.Request.Context()
     id, err := strconv.Atoi(c.Param("id"))
     if err != nil {
         c.JSON(http.StatusOK, gin.H{
@@ -135,10 +183,10 @@ func TestChannel(c *gin.Context) {
         })
         return
     }
-    model := c.Query("model")
-    testRequest := buildTestRequest(model)
+    modelName := c.Query("model")
+    testRequest := buildTestRequest(modelName)
     tik := time.Now()
-    err, _ = testChannel(channel, testRequest)
+    responseMessage, err, _ := testChannel(ctx, channel, testRequest)
     tok := time.Now()
     milliseconds := tok.Sub(tik).Milliseconds()
     if err != nil {
@@ -148,18 +196,18 @@ func TestChannel(c *gin.Context) {
     consumedTime := float64(milliseconds) / 1000.0
     if err != nil {
         c.JSON(http.StatusOK, gin.H{
             "success": false,
             "message": err.Error(),
             "time":    consumedTime,
-            "model":   model,
+            "modelName": modelName,
         })
         return
     }
     c.JSON(http.StatusOK, gin.H{
         "success": true,
-        "message": "",
+        "message": responseMessage,
         "time":    consumedTime,
-        "model":   model,
+        "modelName": modelName,
     })
     return
 }
@@ -167,7 +215,7 @@ func TestChannel(c *gin.Context) {
 var testAllChannelsLock sync.Mutex
 var testAllChannelsRunning bool = false
 
-func testChannels(notify bool, scope string) error {
+func testChannels(ctx context.Context, notify bool, scope string) error {
     if config.RootUserEmail == "" {
         config.RootUserEmail = model.GetRootUserEmail()
     }
@@ -191,7 +239,7 @@ func testChannels(notify bool, scope string) error {
             isChannelEnabled := channel.Status == model.ChannelStatusEnabled
             tik := time.Now()
             testRequest := buildTestRequest("")
-            err, openaiErr := testChannel(channel, testRequest)
+            _, err, openaiErr := testChannel(ctx, channel, testRequest)
             tok := time.Now()
             milliseconds := tok.Sub(tik).Milliseconds()
             if isChannelEnabled && milliseconds > disableThreshold {
@@ -225,11 +273,12 @@ func testChannels(notify bool, scope string) error {
 }
 
 func TestChannels(c *gin.Context) {
+    ctx := c.Request.Context()
     scope := c.Query("scope")
     if scope == "" {
         scope = "all"
     }
-    err := testChannels(true, scope)
+    err := testChannels(ctx, true, scope)
     if err != nil {
         c.JSON(http.StatusOK, gin.H{
             "success": false,
@@ -245,10 +294,11 @@ func TestChannels(c *gin.Context) {
 }
 
 func AutomaticallyTestChannels(frequency int) {
+    ctx := context.Background()
     for {
         time.Sleep(time.Duration(frequency) * time.Minute)
         logger.SysLog("testing all channels")
-        _ = testChannels(false, "all")
+        _ = testChannels(ctx, false, "all")
         logger.SysLog("channel test finished")
     }
 }

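The new `parseTestResponse` is what lets the test endpoint return the model's actual reply (the `TEST_PROMPT` answer) in the `message` field. A cut-down standalone sketch of the same parsing step; `chatResponse` is a stand-in for `openai.TextResponse`, reduced to the fields needed here:

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
)

// chatResponse is a simplified stand-in for openai.TextResponse.
type chatResponse struct {
	Choices []struct {
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
	} `json:"choices"`
}

// firstMessage mirrors parseTestResponse: reject empty choice lists and hand
// back the assistant text that becomes the test result's "message" field.
func firstMessage(body []byte) (string, error) {
	var resp chatResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return "", err
	}
	if len(resp.Choices) == 0 {
		return "", errors.New("response has no choices")
	}
	return resp.Choices[0].Message.Content, nil
}

func main() {
	body := []byte(`{"choices":[{"message":{"content":"gpt-3.5-turbo"}}]}`)
	msg, err := firstMessage(body)
	fmt.Println(msg, err) // gpt-3.5-turbo <nil>
}
```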
@@ -60,7 +60,7 @@ func Relay(c *gin.Context) {
         channelName := c.GetString(ctxkey.ChannelName)
         group := c.GetString(ctxkey.Group)
         originalModel := c.GetString(ctxkey.OriginalModel)
-        go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
+        go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr)
         requestId := c.GetString(helper.RequestIdKey)
         retryTimes := config.RetryTimes
         if !shouldRetry(c, bizErr.StatusCode) {
@@ -87,8 +87,7 @@ func Relay(c *gin.Context) {
         channelId := c.GetInt(ctxkey.ChannelId)
         lastFailedChannelId = channelId
         channelName := c.GetString(ctxkey.ChannelName)
-        // BUG: bizErr is in race condition
-        go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
+        go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr)
     }
     if bizErr != nil {
         if bizErr.StatusCode == http.StatusTooManyRequests {
@@ -122,7 +121,7 @@ func shouldRetry(c *gin.Context, statusCode int) bool {
     return true
 }
 
-func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) {
+func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err model.ErrorWithStatusCode) {
     logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message)
     // https://platform.openai.com/docs/guides/error-codes/api-errors
     if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) {

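This change removes the race the old `// BUG` comment warned about: instead of sharing `bizErr` (a pointer that the retry loop may overwrite) with a background goroutine, each goroutine now receives its own struct copy via `*bizErr`. A tiny illustration of the value-copy idea, with hypothetical names:

```go
package main

import (
	"fmt"
	"sync"
)

type relayError struct{ Message string }

// report takes the error by value, so the goroutine owns a snapshot,
// just like processChannelRelayError now does.
func report(wg *sync.WaitGroup, e relayError) {
	defer wg.Done()
	fmt.Println("reported:", e.Message)
}

func main() {
	var wg sync.WaitGroup
	bizErr := relayError{Message: "attempt 1 failed"}

	wg.Add(1)
	go report(&wg, bizErr) // hands over a copy, analogous to passing *bizErr

	// The retry loop may immediately move on to the next error; the goroutine
	// above still reports "attempt 1 failed".
	bizErr = relayError{Message: "attempt 2 failed"}
	_ = bizErr
	wg.Wait()
}
```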
@@ -109,6 +109,7 @@ func Logout(c *gin.Context) {
 }
 
 func Register(c *gin.Context) {
+    ctx := c.Request.Context()
     if !config.RegisterEnabled {
         c.JSON(http.StatusOK, gin.H{
             "message": "管理员关闭了新用户注册",
@@ -166,7 +167,7 @@ func Register(c *gin.Context) {
     if config.EmailVerificationEnabled {
         cleanUser.Email = user.Email
     }
-    if err := cleanUser.Insert(inviterId); err != nil {
+    if err := cleanUser.Insert(ctx, inviterId); err != nil {
         c.JSON(http.StatusOK, gin.H{
             "success": false,
             "message": err.Error(),
@@ -362,6 +363,7 @@ func GetSelf(c *gin.Context) {
 }
 
 func UpdateUser(c *gin.Context) {
+    ctx := c.Request.Context()
     var updatedUser model.User
     err := json.NewDecoder(c.Request.Body).Decode(&updatedUser)
     if err != nil || updatedUser.Id == 0 {
@@ -416,7 +418,7 @@ func UpdateUser(c *gin.Context) {
         return
     }
     if originUser.Quota != updatedUser.Quota {
-        model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota)))
+        model.RecordLog(ctx, originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota)))
     }
     c.JSON(http.StatusOK, gin.H{
         "success": true,
@@ -535,6 +537,7 @@ func DeleteSelf(c *gin.Context) {
 }
 
 func CreateUser(c *gin.Context) {
+    ctx := c.Request.Context()
     var user model.User
     err := json.NewDecoder(c.Request.Body).Decode(&user)
     if err != nil || user.Username == "" || user.Password == "" {
@@ -568,7 +571,7 @@ func CreateUser(c *gin.Context) {
         Password:    user.Password,
         DisplayName: user.DisplayName,
     }
-    if err := cleanUser.Insert(0); err != nil {
+    if err := cleanUser.Insert(ctx, 0); err != nil {
         c.JSON(http.StatusOK, gin.H{
             "success": false,
             "message": err.Error(),
@@ -747,6 +750,7 @@ type topUpRequest struct {
 }
 
 func TopUp(c *gin.Context) {
+    ctx := c.Request.Context()
     req := topUpRequest{}
     err := c.ShouldBindJSON(&req)
     if err != nil {
@@ -757,7 +761,7 @@ func TopUp(c *gin.Context) {
         return
     }
     id := c.GetInt("id")
-    quota, err := model.Redeem(req.Key, id)
+    quota, err := model.Redeem(ctx, req.Key, id)
     if err != nil {
         c.JSON(http.StatusOK, gin.H{
             "success": false,
@@ -780,6 +784,7 @@ type adminTopUpRequest struct {
 }
 
 func AdminTopUp(c *gin.Context) {
+    ctx := c.Request.Context()
     req := adminTopUpRequest{}
     err := c.ShouldBindJSON(&req)
     if err != nil {
@@ -800,7 +805,7 @@ func AdminTopUp(c *gin.Context) {
     if req.Remark == "" {
         req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota)))
     }
-    model.RecordTopupLog(req.UserId, req.Remark, req.Quota)
+    model.RecordTopupLog(ctx, req.UserId, req.Remark, req.Quota)
     c.JSON(http.StatusOK, gin.H{
         "success": true,
         "message": "",

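The user-controller changes all follow one pattern: grab `c.Request.Context()` once per handler and thread it into every model call (`Insert`, `RecordLog`, `Redeem`, `RecordTopupLog`), so request-scoped values such as the request ID travel with the write. A minimal gin sketch of the pattern; `recordAudit` is a hypothetical stand-in for those model functions:

```go
package main

import (
	"context"
	"fmt"
	"net/http"

	"github.com/gin-gonic/gin"
)

// recordAudit stands in for model.RecordLog and friends: it only needs a
// context so that request-scoped values can travel with it.
func recordAudit(ctx context.Context, msg string) {
	fmt.Println(ctx.Value("requestId"), msg)
}

func handler(c *gin.Context) {
	ctx := c.Request.Context() // grab once, pass to every model call
	recordAudit(ctx, "quota updated")
	c.JSON(http.StatusOK, gin.H{"success": true})
}

func main() {
	r := gin.Default()
	r.POST("/demo", handler)
	_ = r.Run(":3000")
}
```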
go.mod — 8 changed lines

@@ -25,7 +25,7 @@ require (
     github.com/pkoukk/tiktoken-go v0.1.7
     github.com/smartystreets/goconvey v1.8.1
     github.com/stretchr/testify v1.9.0
-    golang.org/x/crypto v0.24.0
+    golang.org/x/crypto v0.31.0
     golang.org/x/image v0.18.0
     google.golang.org/api v0.187.0
     gorm.io/driver/mysql v1.5.6
@@ -99,9 +99,9 @@ require (
     golang.org/x/arch v0.8.0 // indirect
     golang.org/x/net v0.26.0 // indirect
     golang.org/x/oauth2 v0.21.0 // indirect
-    golang.org/x/sync v0.7.0 // indirect
-    golang.org/x/sys v0.21.0 // indirect
-    golang.org/x/text v0.16.0 // indirect
+    golang.org/x/sync v0.10.0 // indirect
+    golang.org/x/sys v0.28.0 // indirect
+    golang.org/x/text v0.21.0 // indirect
     golang.org/x/time v0.5.0 // indirect
     google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect
     google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect

go.sum — 16 changed lines

@@ -222,8 +222,8 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
 golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
-golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
+golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
+golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
 golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ=
 golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E=
@@ -244,20 +244,20 @@ golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbht
 golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
-golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
+golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
 golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
-golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
+golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
-golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
+golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
+golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
 golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
 golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

@@ -2,13 +2,15 @@ package middleware

 import (
     "fmt"
+    "net/http"
+    "strconv"
+
     "github.com/gin-gonic/gin"
+
     "github.com/songquanpeng/one-api/common/ctxkey"
     "github.com/songquanpeng/one-api/common/logger"
     "github.com/songquanpeng/one-api/model"
     "github.com/songquanpeng/one-api/relay/channeltype"
-    "net/http"
-    "strconv"
 )

 type ModelRequest struct {
@@ -17,6 +19,7 @@ type ModelRequest struct {

 func Distribute() func(c *gin.Context) {
     return func(c *gin.Context) {
+        ctx := c.Request.Context()
         userId := c.GetInt(ctxkey.Id)
         userGroup, _ := model.CacheGetUserGroup(userId)
         c.Set(ctxkey.Group, userGroup)
@@ -52,6 +55,7 @@ func Distribute() func(c *gin.Context) {
             return
         }
         }
+        logger.Debugf(ctx, "user id %d, user group: %s, request model: %s, using channel #%d", userId, userGroup, requestModel, channel.Id)
         SetupContextForSelectedChannel(c, channel, requestModel)
         c.Next()
     }
@@ -61,6 +65,9 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
     c.Set(ctxkey.Channel, channel.Type)
     c.Set(ctxkey.ChannelId, channel.Id)
     c.Set(ctxkey.ChannelName, channel.Name)
+    if channel.SystemPrompt != nil && *channel.SystemPrompt != "" {
+        c.Set(ctxkey.SystemPrompt, *channel.SystemPrompt)
+    }
     c.Set(ctxkey.ModelMapping, channel.GetModelMapping())
     c.Set(ctxkey.OriginalModel, modelName) // for retry
     c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
27 middleware/gzip.go Normal file
@@ -0,0 +1,27 @@
+package middleware
+
+import (
+    "compress/gzip"
+    "github.com/gin-gonic/gin"
+    "io"
+    "net/http"
+)
+
+func GzipDecodeMiddleware() gin.HandlerFunc {
+    return func(c *gin.Context) {
+        if c.GetHeader("Content-Encoding") == "gzip" {
+            gzipReader, err := gzip.NewReader(c.Request.Body)
+            if err != nil {
+                c.AbortWithStatus(http.StatusBadRequest)
+                return
+            }
+            defer gzipReader.Close()
+
+            // Replace the request body with the decompressed data
+            c.Request.Body = io.NopCloser(gzipReader)
+        }
+
+        // Continue processing the request
+        c.Next()
+    }
+}
@@ -1,8 +1,8 @@
 package middleware

 import (
-    "context"
     "github.com/gin-gonic/gin"
+
     "github.com/songquanpeng/one-api/common/helper"
 )

@@ -10,7 +10,7 @@ func RequestId() func(c *gin.Context) {
     return func(c *gin.Context) {
         id := helper.GenRequestID()
         c.Set(helper.RequestIdKey, id)
-        ctx := context.WithValue(c.Request.Context(), helper.RequestIdKey, id)
+        ctx := helper.SetRequestID(c.Request.Context(), id)
         c.Request = c.Request.WithContext(ctx)
         c.Header(helper.RequestIdKey, id)
         c.Next()
@@ -37,6 +37,7 @@ type Channel struct {
     ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
     Priority *int64 `json:"priority" gorm:"bigint;default:0"`
     Config string `json:"config"`
+    SystemPrompt *string `json:"system_prompt" gorm:"type:text"`
 }

 type ChannelConfig struct {
87 model/log.go
@@ -4,26 +4,31 @@ import (
     "context"
     "fmt"
+
+    "gorm.io/gorm"
+
     "github.com/songquanpeng/one-api/common"
     "github.com/songquanpeng/one-api/common/config"
     "github.com/songquanpeng/one-api/common/helper"
     "github.com/songquanpeng/one-api/common/logger"
-    "gorm.io/gorm"
 )

 type Log struct {
     Id int `json:"id"`
     UserId int `json:"user_id" gorm:"index"`
     CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_type"`
     Type int `json:"type" gorm:"index:idx_created_at_type"`
     Content string `json:"content"`
     Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"`
     TokenName string `json:"token_name" gorm:"index;default:''"`
     ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
     Quota int `json:"quota" gorm:"default:0"`
     PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
     CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
     ChannelId int `json:"channel" gorm:"index"`
+    RequestId string `json:"request_id" gorm:"default:''"`
+    ElapsedTime int64 `json:"elapsed_time" gorm:"default:0"` // unit is ms
+    IsStream bool `json:"is_stream" gorm:"default:false"`
+    SystemPromptReset bool `json:"system_prompt_reset" gorm:"default:false"`
 }

 const (
@@ -32,9 +37,21 @@ const (
     LogTypeConsume
     LogTypeManage
     LogTypeSystem
+    LogTypeTest
 )

-func RecordLog(userId int, logType int, content string) {
+func recordLogHelper(ctx context.Context, log *Log) {
+    requestId := helper.GetRequestID(ctx)
+    log.RequestId = requestId
+    err := LOG_DB.Create(log).Error
+    if err != nil {
+        logger.Error(ctx, "failed to record log: "+err.Error())
+        return
+    }
+    logger.Infof(ctx, "record log: %+v", log)
+}
+
+func RecordLog(ctx context.Context, userId int, logType int, content string) {
     if logType == LogTypeConsume && !config.LogConsumeEnabled {
         return
     }
@@ -45,13 +62,10 @@ func RecordLog(userId int, logType int, content string) {
         Type: logType,
         Content: content,
     }
-    err := LOG_DB.Create(log).Error
-    if err != nil {
-        logger.SysError("failed to record log: " + err.Error())
-    }
+    recordLogHelper(ctx, log)
 }

-func RecordTopupLog(userId int, content string, quota int) {
+func RecordTopupLog(ctx context.Context, userId int, content string, quota int) {
     log := &Log{
         UserId: userId,
         Username: GetUsernameById(userId),
@@ -60,34 +74,23 @@ func RecordTopupLog(userId int, content string, quota int) {
         Content: content,
         Quota: quota,
     }
-    err := LOG_DB.Create(log).Error
-    if err != nil {
-        logger.SysError("failed to record log: " + err.Error())
-    }
+    recordLogHelper(ctx, log)
 }

-func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) {
-    logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
+func RecordConsumeLog(ctx context.Context, log *Log) {
     if !config.LogConsumeEnabled {
         return
     }
-    log := &Log{
-        UserId: userId,
-        Username: GetUsernameById(userId),
-        CreatedAt: helper.GetTimestamp(),
-        Type: LogTypeConsume,
-        Content: content,
-        PromptTokens: promptTokens,
-        CompletionTokens: completionTokens,
-        TokenName: tokenName,
-        ModelName: modelName,
-        Quota: int(quota),
-        ChannelId: channelId,
-    }
-    err := LOG_DB.Create(log).Error
-    if err != nil {
-        logger.Error(ctx, "failed to record log: "+err.Error())
-    }
+    log.Username = GetUsernameById(log.UserId)
+    log.CreatedAt = helper.GetTimestamp()
+    log.Type = LogTypeConsume
+    recordLogHelper(ctx, log)
+}
+
+func RecordTestLog(ctx context.Context, log *Log) {
+    log.CreatedAt = helper.GetTimestamp()
+    log.Type = LogTypeTest
+    recordLogHelper(ctx, log)
 }

 func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
@@ -1,11 +1,14 @@
 package model

 import (
+    "context"
     "errors"
     "fmt"
+
+    "gorm.io/gorm"
+
     "github.com/songquanpeng/one-api/common"
     "github.com/songquanpeng/one-api/common/helper"
-    "gorm.io/gorm"
 )

 const (
@@ -48,7 +51,7 @@ func GetRedemptionById(id int) (*Redemption, error) {
     return &redemption, err
 }

-func Redeem(key string, userId int) (quota int64, err error) {
+func Redeem(ctx context.Context, key string, userId int) (quota int64, err error) {
     if key == "" {
         return 0, errors.New("未提供兑换码")
     }
@@ -82,7 +85,7 @@ func Redeem(key string, userId int) (quota int64, err error) {
     if err != nil {
         return 0, errors.New("兑换失败," + err.Error())
     }
-    RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota)))
+    RecordLog(ctx, userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota)))
     return redemption.Quota, nil
 }

@@ -1,16 +1,19 @@
 package model

 import (
+    "context"
     "errors"
     "fmt"
+    "strings"
+
+    "gorm.io/gorm"
+
     "github.com/songquanpeng/one-api/common"
     "github.com/songquanpeng/one-api/common/blacklist"
     "github.com/songquanpeng/one-api/common/config"
     "github.com/songquanpeng/one-api/common/helper"
     "github.com/songquanpeng/one-api/common/logger"
     "github.com/songquanpeng/one-api/common/random"
-    "gorm.io/gorm"
-    "strings"
 )

 const (
@@ -114,7 +117,7 @@ func DeleteUserById(id int) (err error) {
     return user.Delete()
 }

-func (user *User) Insert(inviterId int) error {
+func (user *User) Insert(ctx context.Context, inviterId int) error {
     var err error
     if user.Password != "" {
         user.Password, err = common.Password2Hash(user.Password)
@@ -130,16 +133,16 @@ func (user *User) Insert(inviterId int) error {
         return result.Error
     }
     if config.QuotaForNewUser > 0 {
-        RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser)))
+        RecordLog(ctx, user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser)))
     }
     if inviterId != 0 {
         if config.QuotaForInvitee > 0 {
             _ = IncreaseUserQuota(user.Id, config.QuotaForInvitee)
-            RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee)))
+            RecordLog(ctx, user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee)))
         }
         if config.QuotaForInviter > 0 {
             _ = IncreaseUserQuota(inviterId, config.QuotaForInviter)
-            RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter)))
+            RecordLog(ctx, inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter)))
         }
     }
     // create default token
@@ -34,7 +34,7 @@ func ShouldDisableChannel(err *model.Error, statusCode int) bool {
     strings.Contains(lowerMessage, "credit") ||
     strings.Contains(lowerMessage, "balance") ||
     strings.Contains(lowerMessage, "permission denied") ||
     strings.Contains(lowerMessage, "organization has been restricted") || // groq
     strings.Contains(lowerMessage, "已欠费") {
     return true
 }
@@ -16,6 +16,7 @@ import (
     "github.com/songquanpeng/one-api/relay/adaptor/openai"
     "github.com/songquanpeng/one-api/relay/adaptor/palm"
     "github.com/songquanpeng/one-api/relay/adaptor/proxy"
+    "github.com/songquanpeng/one-api/relay/adaptor/replicate"
     "github.com/songquanpeng/one-api/relay/adaptor/tencent"
     "github.com/songquanpeng/one-api/relay/adaptor/vertexai"
     "github.com/songquanpeng/one-api/relay/adaptor/xunfei"
@@ -61,6 +62,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor {
         return &vertexai.Adaptor{}
     case apitype.Proxy:
         return &proxy.Adaptor{}
+    case apitype.Replicate:
+        return &replicate.Adaptor{}
     }
     return nil
 }
@@ -1,7 +1,23 @@
 package ali

 var ModelList = []string{
-    "qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
-    "text-embedding-v1",
+    "qwen-turbo", "qwen-turbo-latest",
+    "qwen-plus", "qwen-plus-latest",
+    "qwen-max", "qwen-max-latest",
+    "qwen-max-longcontext",
+    "qwen-vl-max", "qwen-vl-max-latest", "qwen-vl-plus", "qwen-vl-plus-latest",
+    "qwen-vl-ocr", "qwen-vl-ocr-latest",
+    "qwen-audio-turbo",
+    "qwen-math-plus", "qwen-math-plus-latest", "qwen-math-turbo", "qwen-math-turbo-latest",
+    "qwen-coder-plus", "qwen-coder-plus-latest", "qwen-coder-turbo", "qwen-coder-turbo-latest",
+    "qwq-32b-preview", "qwen2.5-72b-instruct", "qwen2.5-32b-instruct", "qwen2.5-14b-instruct", "qwen2.5-7b-instruct", "qwen2.5-3b-instruct", "qwen2.5-1.5b-instruct", "qwen2.5-0.5b-instruct",
+    "qwen2-72b-instruct", "qwen2-57b-a14b-instruct", "qwen2-7b-instruct", "qwen2-1.5b-instruct", "qwen2-0.5b-instruct",
+    "qwen1.5-110b-chat", "qwen1.5-72b-chat", "qwen1.5-32b-chat", "qwen1.5-14b-chat", "qwen1.5-7b-chat", "qwen1.5-1.8b-chat", "qwen1.5-0.5b-chat",
+    "qwen-72b-chat", "qwen-14b-chat", "qwen-7b-chat", "qwen-1.8b-chat", "qwen-1.8b-longcontext-chat",
+    "qwen2-vl-7b-instruct", "qwen2-vl-2b-instruct", "qwen-vl-v1", "qwen-vl-chat-v1",
+    "qwen2-audio-instruct", "qwen-audio-chat",
+    "qwen2.5-math-72b-instruct", "qwen2.5-math-7b-instruct", "qwen2.5-math-1.5b-instruct", "qwen2-math-72b-instruct", "qwen2-math-7b-instruct", "qwen2-math-1.5b-instruct",
+    "qwen2.5-coder-32b-instruct", "qwen2.5-coder-14b-instruct", "qwen2.5-coder-7b-instruct", "qwen2.5-coder-3b-instruct", "qwen2.5-coder-1.5b-instruct", "qwen2.5-coder-0.5b-instruct",
+    "text-embedding-v1", "text-embedding-v3", "text-embedding-v2", "text-embedding-async-v2", "text-embedding-async-v1",
     "ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1",
 }
@@ -9,5 +9,4 @@ var ModelList = []string{
     "claude-3-5-sonnet-20240620",
     "claude-3-5-sonnet-20241022",
     "claude-3-5-sonnet-latest",
-    "claude-3-5-haiku-20241022",
 }
@@ -7,7 +7,6 @@ import (
     "net/http"

     "github.com/gin-gonic/gin"
-    "github.com/songquanpeng/one-api/common/config"
     "github.com/songquanpeng/one-api/common/helper"
     channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
     "github.com/songquanpeng/one-api/relay/adaptor/openai"
@@ -24,7 +23,15 @@ func (a *Adaptor) Init(meta *meta.Meta) {
 }

 func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
-    version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion)
+    var defaultVersion string
+    switch meta.ActualModelName {
+    case "gemini-2.0-flash-exp",
+        "gemini-2.0-flash-thinking-exp",
+        "gemini-2.0-flash-thinking-exp-01-21":
+        defaultVersion = "v1beta"
+    }
+
+    version := helper.AssignOrDefault(meta.Config.APIVersion, defaultVersion)
     action := ""
     switch meta.Mode {
     case relaymode.Embeddings:
@@ -36,6 +43,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
     if meta.IsStream {
         action = "streamGenerateContent?alt=sse"
     }
+
     return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
 }

@@ -3,5 +3,9 @@ package gemini
 // https://ai.google.dev/models/gemini

 var ModelList = []string{
-    "gemini-pro", "gemini-1.0-pro", "gemini-1.5-flash", "gemini-1.5-pro", "text-embedding-004", "aqa",
+    "gemini-pro", "gemini-1.0-pro",
+    "gemini-1.5-flash", "gemini-1.5-pro",
+    "text-embedding-004", "aqa",
+    "gemini-2.0-flash-exp",
+    "gemini-2.0-flash-thinking-exp", "gemini-2.0-flash-thinking-exp-01-21",
 }
@@ -55,6 +55,10 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
         Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
         Threshold: config.GeminiSafetySetting,
         },
+        {
+            Category: "HARM_CATEGORY_CIVIC_INTEGRITY",
+            Threshold: config.GeminiSafetySetting,
+        },
     },
     GenerationConfig: ChatGenerationConfig{
         Temperature: textRequest.Temperature,
@@ -247,7 +251,14 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
         if candidate.Content.Parts[0].FunctionCall != nil {
             choice.Message.ToolCalls = getToolCalls(&candidate)
         } else {
-            choice.Message.Content = candidate.Content.Parts[0].Text
+            var builder strings.Builder
+            for i, part := range candidate.Content.Parts {
+                if i > 0 {
+                    builder.WriteString("\n")
+                }
+                builder.WriteString(part.Text)
+            }
+            choice.Message.Content = builder.String()
         }
     } else {
         choice.Message.Content = ""
@@ -31,8 +31,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
     TopP: request.TopP,
     FrequencyPenalty: request.FrequencyPenalty,
     PresencePenalty: request.PresencePenalty,
     NumPredict: request.MaxTokens,
     NumCtx: request.NumCtx,
     },
     Stream: request.Stream,
 }
@@ -122,7 +122,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
     for scanner.Scan() {
         data := scanner.Text()
         if strings.HasPrefix(data, "}") {
             data = strings.TrimPrefix(data, "}") + "}"
         }

         var ollamaResponse ChatResponse
@@ -9,6 +9,7 @@ var ModelList = []string{
     "gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
     "gpt-4o", "gpt-4o-2024-05-13",
     "gpt-4o-2024-08-06",
+    "gpt-4o-2024-11-20",
     "chatgpt-4o-latest",
     "gpt-4o-mini", "gpt-4o-mini-2024-07-18",
     "gpt-4-vision-preview",
@@ -20,4 +21,7 @@ var ModelList = []string{
     "dall-e-2", "dall-e-3",
     "whisper-1",
     "tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
+    "o1", "o1-2024-12-17",
+    "o1-preview", "o1-preview-2024-09-12",
+    "o1-mini", "o1-mini-2024-09-12",
 }
@@ -2,15 +2,16 @@ package openai

 import (
     "fmt"
+    "strings"
+
     "github.com/songquanpeng/one-api/relay/channeltype"
     "github.com/songquanpeng/one-api/relay/model"
-    "strings"
 )

-func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage {
+func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage {
     usage := &model.Usage{}
     usage.PromptTokens = promptTokens
-    usage.CompletionTokens = CountTokenText(responseText, modeName)
+    usage.CompletionTokens = CountTokenText(responseText, modelName)
     usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
     return usage
 }
@@ -1,8 +1,16 @@
 package openai

-import "github.com/songquanpeng/one-api/relay/model"
+import (
+    "context"
+    "fmt"
+
+    "github.com/songquanpeng/one-api/common/logger"
+    "github.com/songquanpeng/one-api/relay/model"
+)

 func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode {
+    logger.Error(context.TODO(), fmt.Sprintf("[%s]%+v", code, err))
+
     Error := model.Error{
         Message: err.Error(),
         Type: "one_api_error",
136 relay/adaptor/replicate/adaptor.go Normal file
@@ -0,0 +1,136 @@
+package replicate
+
+import (
+    "fmt"
+    "io"
+    "net/http"
+    "slices"
+    "strings"
+    "time"
+
+    "github.com/gin-gonic/gin"
+    "github.com/pkg/errors"
+    "github.com/songquanpeng/one-api/common/logger"
+    "github.com/songquanpeng/one-api/relay/adaptor"
+    "github.com/songquanpeng/one-api/relay/adaptor/openai"
+    "github.com/songquanpeng/one-api/relay/meta"
+    "github.com/songquanpeng/one-api/relay/model"
+    "github.com/songquanpeng/one-api/relay/relaymode"
+)
+
+type Adaptor struct {
+    meta *meta.Meta
+}
+
+// ConvertImageRequest implements adaptor.Adaptor.
+func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+    return DrawImageRequest{
+        Input: ImageInput{
+            Steps: 25,
+            Prompt: request.Prompt,
+            Guidance: 3,
+            Seed: int(time.Now().UnixNano()),
+            SafetyTolerance: 5,
+            NImages: 1, // replicate will always return 1 image
+            Width: 1440,
+            Height: 1440,
+            AspectRatio: "1:1",
+        },
+    }, nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+    if !request.Stream {
+        // TODO: support non-stream mode
+        return nil, errors.Errorf("replicate models only support stream mode now, please set stream=true")
+    }
+
+    // Build the prompt from OpenAI messages
+    var promptBuilder strings.Builder
+    for _, message := range request.Messages {
+        switch msgCnt := message.Content.(type) {
+        case string:
+            promptBuilder.WriteString(message.Role)
+            promptBuilder.WriteString(": ")
+            promptBuilder.WriteString(msgCnt)
+            promptBuilder.WriteString("\n")
+        default:
+        }
+    }
+
+    replicateRequest := ReplicateChatRequest{
+        Input: ChatInput{
+            Prompt: promptBuilder.String(),
+            MaxTokens: request.MaxTokens,
+            Temperature: 1.0,
+            TopP: 1.0,
+            PresencePenalty: 0.0,
+            FrequencyPenalty: 0.0,
+        },
+    }
+
+    // Map optional fields
+    if request.Temperature != nil {
+        replicateRequest.Input.Temperature = *request.Temperature
+    }
+    if request.TopP != nil {
+        replicateRequest.Input.TopP = *request.TopP
+    }
+    if request.PresencePenalty != nil {
+        replicateRequest.Input.PresencePenalty = *request.PresencePenalty
+    }
+    if request.FrequencyPenalty != nil {
+        replicateRequest.Input.FrequencyPenalty = *request.FrequencyPenalty
+    }
+    if request.MaxTokens > 0 {
+        replicateRequest.Input.MaxTokens = request.MaxTokens
+    } else if request.MaxTokens == 0 {
+        replicateRequest.Input.MaxTokens = 500
+    }
+
+    return replicateRequest, nil
+}
+
+func (a *Adaptor) Init(meta *meta.Meta) {
+    a.meta = meta
+}
+
+func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
+    if !slices.Contains(ModelList, meta.OriginModelName) {
+        return "", errors.Errorf("model %s not supported", meta.OriginModelName)
+    }
+
+    return fmt.Sprintf("https://api.replicate.com/v1/models/%s/predictions", meta.OriginModelName), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
+    adaptor.SetupCommonRequestHeader(c, req, meta)
+    req.Header.Set("Authorization", "Bearer "+meta.APIKey)
+    return nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
+    logger.Info(c, "send request to replicate")
+    return adaptor.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+    switch meta.Mode {
+    case relaymode.ImagesGenerations:
+        err, usage = ImageHandler(c, resp)
+    case relaymode.ChatCompletions:
+        err, usage = ChatHandler(c, resp)
+    default:
+        err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError)
+    }
+
+    return
+}
+
+func (a *Adaptor) GetModelList() []string {
+    return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+    return "replicate"
+}
191 relay/adaptor/replicate/chat.go Normal file
@@ -0,0 +1,191 @@
+package replicate
+
+import (
+    "bufio"
+    "encoding/json"
+    "io"
+    "net/http"
+    "strings"
+    "time"
+
+    "github.com/gin-gonic/gin"
+    "github.com/pkg/errors"
+    "github.com/songquanpeng/one-api/common"
+    "github.com/songquanpeng/one-api/common/render"
+    "github.com/songquanpeng/one-api/relay/adaptor/openai"
+    "github.com/songquanpeng/one-api/relay/meta"
+    "github.com/songquanpeng/one-api/relay/model"
+)
+
+func ChatHandler(c *gin.Context, resp *http.Response) (
+    srvErr *model.ErrorWithStatusCode, usage *model.Usage) {
+    if resp.StatusCode != http.StatusCreated {
+        payload, _ := io.ReadAll(resp.Body)
+        return openai.ErrorWrapper(
+                errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
+                "bad_status_code", http.StatusInternalServerError),
+            nil
+    }
+
+    respBody, err := io.ReadAll(resp.Body)
+    if err != nil {
+        return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+    }
+
+    respData := new(ChatResponse)
+    if err = json.Unmarshal(respBody, respData); err != nil {
+        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+    }
+
+    for {
+        err = func() error {
+            // get task
+            taskReq, err := http.NewRequestWithContext(c.Request.Context(),
+                http.MethodGet, respData.URLs.Get, nil)
+            if err != nil {
+                return errors.Wrap(err, "new request")
+            }
+
+            taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
+            taskResp, err := http.DefaultClient.Do(taskReq)
+            if err != nil {
+                return errors.Wrap(err, "get task")
+            }
+            defer taskResp.Body.Close()
+
+            if taskResp.StatusCode != http.StatusOK {
+                payload, _ := io.ReadAll(taskResp.Body)
+                return errors.Errorf("bad status code [%d]%s",
+                    taskResp.StatusCode, string(payload))
+            }
+
+            taskBody, err := io.ReadAll(taskResp.Body)
+            if err != nil {
+                return errors.Wrap(err, "read task response")
+            }
+
+            taskData := new(ChatResponse)
+            if err = json.Unmarshal(taskBody, taskData); err != nil {
+                return errors.Wrap(err, "decode task response")
+            }
+
+            switch taskData.Status {
+            case "succeeded":
+            case "failed", "canceled":
+                return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error)
+            default:
+                time.Sleep(time.Second * 3)
+                return errNextLoop
+            }
+
+            if taskData.URLs.Stream == "" {
+                return errors.New("stream url is empty")
+            }
+
+            // request stream url
+            responseText, err := chatStreamHandler(c, taskData.URLs.Stream)
+            if err != nil {
+                return errors.Wrap(err, "chat stream handler")
+            }
+
+            ctxMeta := meta.GetByContext(c)
+            usage = openai.ResponseText2Usage(responseText,
+                ctxMeta.ActualModelName, ctxMeta.PromptTokens)
+            return nil
+        }()
+        if err != nil {
+            if errors.Is(err, errNextLoop) {
+                continue
+            }
+
+            return openai.ErrorWrapper(err, "chat_task_failed", http.StatusInternalServerError), nil
+        }
+
+        break
+    }
+
+    return nil, usage
+}
+
+const (
+    eventPrefix = "event: "
+    dataPrefix  = "data: "
+    done        = "[DONE]"
+)
+
+func chatStreamHandler(c *gin.Context, streamUrl string) (responseText string, err error) {
+    // request stream endpoint
+    streamReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, streamUrl, nil)
+    if err != nil {
+        return "", errors.Wrap(err, "new request to stream")
+    }
+
+    streamReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
+    streamReq.Header.Set("Accept", "text/event-stream")
+    streamReq.Header.Set("Cache-Control", "no-store")
+
+    resp, err := http.DefaultClient.Do(streamReq)
+    if err != nil {
+        return "", errors.Wrap(err, "do request to stream")
+    }
+    defer resp.Body.Close()
+
+    if resp.StatusCode != http.StatusOK {
+        payload, _ := io.ReadAll(resp.Body)
+        return "", errors.Errorf("bad status code [%d]%s", resp.StatusCode, string(payload))
+    }
+
+    scanner := bufio.NewScanner(resp.Body)
+    scanner.Split(bufio.ScanLines)
+
+    common.SetEventStreamHeaders(c)
+    doneRendered := false
+    for scanner.Scan() {
+        line := strings.TrimSpace(scanner.Text())
+        if line == "" {
+            continue
+        }
+
+        // Handle comments starting with ':'
+        if strings.HasPrefix(line, ":") {
+            continue
+        }
+
+        // Parse SSE fields
+        if strings.HasPrefix(line, eventPrefix) {
+            event := strings.TrimSpace(line[len(eventPrefix):])
+            var data string
+            // Read the following lines to get data and id
+            for scanner.Scan() {
+                nextLine := scanner.Text()
+                if nextLine == "" {
+                    break
+                }
+                if strings.HasPrefix(nextLine, dataPrefix) {
+                    data = nextLine[len(dataPrefix):]
+                } else if strings.HasPrefix(nextLine, "id:") {
+                    // id = strings.TrimSpace(nextLine[len("id:"):])
+                }
+            }
+
+            if event == "output" {
+                render.StringData(c, data)
+                responseText += data
+            } else if event == "done" {
+                render.Done(c)
+                doneRendered = true
+                break
+            }
+        }
+    }
+
+    if err := scanner.Err(); err != nil {
+        return "", errors.Wrap(err, "scan stream")
+    }
+
+    if !doneRendered {
+        render.Done(c)
+    }
+
+    return responseText, nil
+}
58 relay/adaptor/replicate/constant.go Normal file
@@ -0,0 +1,58 @@
+package replicate
+
+// ModelList is a list of models that can be used with Replicate.
+//
+// https://replicate.com/pricing
+var ModelList = []string{
+    // -------------------------------------
+    // image model
+    // -------------------------------------
+    "black-forest-labs/flux-1.1-pro",
+    "black-forest-labs/flux-1.1-pro-ultra",
+    "black-forest-labs/flux-canny-dev",
+    "black-forest-labs/flux-canny-pro",
+    "black-forest-labs/flux-depth-dev",
+    "black-forest-labs/flux-depth-pro",
+    "black-forest-labs/flux-dev",
+    "black-forest-labs/flux-dev-lora",
+    "black-forest-labs/flux-fill-dev",
+    "black-forest-labs/flux-fill-pro",
+    "black-forest-labs/flux-pro",
+    "black-forest-labs/flux-redux-dev",
+    "black-forest-labs/flux-redux-schnell",
+    "black-forest-labs/flux-schnell",
+    "black-forest-labs/flux-schnell-lora",
+    "ideogram-ai/ideogram-v2",
+    "ideogram-ai/ideogram-v2-turbo",
+    "recraft-ai/recraft-v3",
+    "recraft-ai/recraft-v3-svg",
+    "stability-ai/stable-diffusion-3",
+    "stability-ai/stable-diffusion-3.5-large",
+    "stability-ai/stable-diffusion-3.5-large-turbo",
+    "stability-ai/stable-diffusion-3.5-medium",
+    // -------------------------------------
+    // language model
+    // -------------------------------------
+    "ibm-granite/granite-20b-code-instruct-8k",
+    "ibm-granite/granite-3.0-2b-instruct",
+    "ibm-granite/granite-3.0-8b-instruct",
+    "ibm-granite/granite-8b-code-instruct-128k",
+    "meta/llama-2-13b",
+    "meta/llama-2-13b-chat",
+    "meta/llama-2-70b",
+    "meta/llama-2-70b-chat",
+    "meta/llama-2-7b",
+    "meta/llama-2-7b-chat",
+    "meta/meta-llama-3.1-405b-instruct",
+    "meta/meta-llama-3-70b",
+    "meta/meta-llama-3-70b-instruct",
+    "meta/meta-llama-3-8b",
+    "meta/meta-llama-3-8b-instruct",
+    "mistralai/mistral-7b-instruct-v0.2",
+    "mistralai/mistral-7b-v0.1",
+    "mistralai/mixtral-8x7b-instruct-v0.1",
+    // -------------------------------------
+    // video model
+    // -------------------------------------
+    // "minimax/video-01", // TODO: implement the adaptor
+}
222 relay/adaptor/replicate/image.go Normal file
@@ -0,0 +1,222 @@
+package replicate
+
+import (
+    "bytes"
+    "encoding/base64"
+    "encoding/json"
+    "fmt"
+    "image"
+    "image/png"
+    "io"
+    "net/http"
+    "sync"
+    "time"
+
+    "github.com/gin-gonic/gin"
+    "github.com/pkg/errors"
+    "github.com/songquanpeng/one-api/common/logger"
+    "github.com/songquanpeng/one-api/relay/adaptor/openai"
+    "github.com/songquanpeng/one-api/relay/meta"
+    "github.com/songquanpeng/one-api/relay/model"
+    "golang.org/x/image/webp"
+    "golang.org/x/sync/errgroup"
+)
+
+// ImagesEditsHandler just copy response body to client
+//
+// https://replicate.com/black-forest-labs/flux-fill-pro
+// func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+//     c.Writer.WriteHeader(resp.StatusCode)
+//     for k, v := range resp.Header {
+//         c.Writer.Header().Set(k, v[0])
+//     }
+
+//     if _, err := io.Copy(c.Writer, resp.Body); err != nil {
+//         return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
+//     }
+//     defer resp.Body.Close()
+
+//     return nil, nil
+// }
+
+var errNextLoop = errors.New("next_loop")
+
+func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+    if resp.StatusCode != http.StatusCreated {
+        payload, _ := io.ReadAll(resp.Body)
+        return openai.ErrorWrapper(
+                errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
+                "bad_status_code", http.StatusInternalServerError),
+            nil
+    }
+
+    respBody, err := io.ReadAll(resp.Body)
+    if err != nil {
+        return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+    }
+
+    respData := new(ImageResponse)
+    if err = json.Unmarshal(respBody, respData); err != nil {
+        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+    }
+
+    for {
+        err = func() error {
+            // get task
+            taskReq, err := http.NewRequestWithContext(c.Request.Context(),
+                http.MethodGet, respData.URLs.Get, nil)
+            if err != nil {
+                return errors.Wrap(err, "new request")
+            }
+
+            taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
+            taskResp, err := http.DefaultClient.Do(taskReq)
+            if err != nil {
+                return errors.Wrap(err, "get task")
+            }
+            defer taskResp.Body.Close()
+
+            if taskResp.StatusCode != http.StatusOK {
+                payload, _ := io.ReadAll(taskResp.Body)
+                return errors.Errorf("bad status code [%d]%s",
+                    taskResp.StatusCode, string(payload))
+            }
+
+            taskBody, err := io.ReadAll(taskResp.Body)
+            if err != nil {
+                return errors.Wrap(err, "read task response")
+            }
+
+            taskData := new(ImageResponse)
+            if err = json.Unmarshal(taskBody, taskData); err != nil {
+                return errors.Wrap(err, "decode task response")
+            }
+
+            switch taskData.Status {
+            case "succeeded":
+            case "failed", "canceled":
+                return errors.Errorf("task failed: %s", taskData.Status)
+            default:
+                time.Sleep(time.Second * 3)
+                return errNextLoop
+            }
+
+            output, err := taskData.GetOutput()
+            if err != nil {
+                return errors.Wrap(err, "get output")
+            }
+            if len(output) == 0 {
+                return errors.New("response output is empty")
+            }
+
+            var mu sync.Mutex
+            var pool errgroup.Group
+            respBody := &openai.ImageResponse{
+                Created: taskData.CompletedAt.Unix(),
+                Data:    []openai.ImageData{},
+            }
+
+            for _, imgOut := range output {
+                imgOut := imgOut
+                pool.Go(func() error {
+                    // download image
+                    downloadReq, err := http.NewRequestWithContext(c.Request.Context(),
+                        http.MethodGet, imgOut, nil)
+                    if err != nil {
+                        return errors.Wrap(err, "new request")
+                    }
+
+                    imgResp, err := http.DefaultClient.Do(downloadReq)
+                    if err != nil {
+                        return errors.Wrap(err, "download image")
+                    }
+                    defer imgResp.Body.Close()
+
+                    if imgResp.StatusCode != http.StatusOK {
+                        payload, _ := io.ReadAll(imgResp.Body)
+                        return errors.Errorf("bad status code [%d]%s",
+                            imgResp.StatusCode, string(payload))
+                    }
+
+                    imgData, err := io.ReadAll(imgResp.Body)
+                    if err != nil {
+                        return errors.Wrap(err, "read image")
+                    }
+
+                    imgData, err = ConvertImageToPNG(imgData)
+                    if err != nil {
+                        return errors.Wrap(err, "convert image")
+                    }
+
+                    mu.Lock()
+                    respBody.Data = append(respBody.Data, openai.ImageData{
+                        B64Json: fmt.Sprintf("data:image/png;base64,%s",
+                            base64.StdEncoding.EncodeToString(imgData)),
+                    })
+                    mu.Unlock()
+
+                    return nil
+                })
+            }
+
+            if err := pool.Wait(); err != nil {
+                if len(respBody.Data) == 0 {
+                    return errors.WithStack(err)
+                }
+
+                logger.Error(c, fmt.Sprintf("some images failed to download: %+v", err))
+            }
+
+            c.JSON(http.StatusOK, respBody)
+            return nil
+        }()
+        if err != nil {
+            if errors.Is(err, errNextLoop) {
+                continue
+            }
+
+            return openai.ErrorWrapper(err, "image_task_failed", http.StatusInternalServerError), nil
+        }
+
+        break
+    }
+
+    return nil, nil
+}
+
+// ConvertImageToPNG converts a WebP image to PNG format
+func ConvertImageToPNG(webpData []byte) ([]byte, error) {
+    // bypass if it's already a PNG image
+    if bytes.HasPrefix(webpData, []byte("\x89PNG")) {
+        return webpData, nil
+    }
+
+    // check if is jpeg, convert to png
+    if bytes.HasPrefix(webpData, []byte("\xff\xd8\xff")) {
+        img, _, err := image.Decode(bytes.NewReader(webpData))
+        if err != nil {
+            return nil, errors.Wrap(err, "decode jpeg")
+        }
+
+        var pngBuffer bytes.Buffer
+        if err := png.Encode(&pngBuffer, img); err != nil {
+            return nil, errors.Wrap(err, "encode png")
+        }
+
+        return pngBuffer.Bytes(), nil
+    }
+
+    // Decode the WebP image
+    img, err := webp.Decode(bytes.NewReader(webpData))
+    if err != nil {
+        return nil, errors.Wrap(err, "decode webp")
+    }
+
+    // Encode the image as PNG
+    var pngBuffer bytes.Buffer
+    if err := png.Encode(&pngBuffer, img); err != nil {
+        return nil, errors.Wrap(err, "encode png")
+    }
+
+    return pngBuffer.Bytes(), nil
+}
159
relay/adaptor/replicate/model.go
Normal file
159
relay/adaptor/replicate/model.go
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
package replicate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DrawImageRequest draw image by fluxpro
|
||||||
|
//
|
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
type DrawImageRequest struct {
    Input ImageInput `json:"input"`
}

// ImageInput is input of DrawImageByFluxProRequest
//
// https://replicate.com/black-forest-labs/flux-1.1-pro/api/schema
type ImageInput struct {
    Steps           int    `json:"steps" binding:"required,min=1"`
    Prompt          string `json:"prompt" binding:"required,min=5"`
    ImagePrompt     string `json:"image_prompt"`
    Guidance        int    `json:"guidance" binding:"required,min=2,max=5"`
    Interval        int    `json:"interval" binding:"required,min=1,max=4"`
    AspectRatio     string `json:"aspect_ratio" binding:"required,oneof=1:1 16:9 2:3 3:2 4:5 5:4 9:16"`
    SafetyTolerance int    `json:"safety_tolerance" binding:"required,min=1,max=5"`
    Seed            int    `json:"seed"`
    NImages         int    `json:"n_images" binding:"required,min=1,max=8"`
    Width           int    `json:"width" binding:"required,min=256,max=1440"`
    Height          int    `json:"height" binding:"required,min=256,max=1440"`
}

// InpaintingImageByFlusReplicateRequest is request to inpainting image by flux pro
//
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
type InpaintingImageByFlusReplicateRequest struct {
    Input FluxInpaintingInput `json:"input"`
}

// FluxInpaintingInput is input of DrawImageByFluxProRequest
//
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
type FluxInpaintingInput struct {
    Mask             string `json:"mask" binding:"required"`
    Image            string `json:"image" binding:"required"`
    Seed             int    `json:"seed"`
    Steps            int    `json:"steps" binding:"required,min=1"`
    Prompt           string `json:"prompt" binding:"required,min=5"`
    Guidance         int    `json:"guidance" binding:"required,min=2,max=5"`
    OutputFormat     string `json:"output_format"`
    SafetyTolerance  int    `json:"safety_tolerance" binding:"required,min=1,max=5"`
    PromptUnsampling bool   `json:"prompt_unsampling"`
}

// ImageResponse is response of DrawImageByFluxProRequest
//
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
type ImageResponse struct {
    CompletedAt time.Time        `json:"completed_at"`
    CreatedAt   time.Time        `json:"created_at"`
    DataRemoved bool             `json:"data_removed"`
    Error       string           `json:"error"`
    ID          string           `json:"id"`
    Input       DrawImageRequest `json:"input"`
    Logs        string           `json:"logs"`
    Metrics     FluxMetrics      `json:"metrics"`
    // Output could be `string` or `[]string`
    Output    any       `json:"output"`
    StartedAt time.Time `json:"started_at"`
    Status    string    `json:"status"`
    URLs      FluxURLs  `json:"urls"`
    Version   string    `json:"version"`
}

func (r *ImageResponse) GetOutput() ([]string, error) {
    switch v := r.Output.(type) {
    case string:
        return []string{v}, nil
    case []string:
        return v, nil
    case nil:
        return nil, nil
    case []interface{}:
        // convert []interface{} to []string
        ret := make([]string, len(v))
        for idx, vv := range v {
            if vvv, ok := vv.(string); ok {
                ret[idx] = vvv
            } else {
                return nil, errors.Errorf("unknown output type: [%T]%v", vv, vv)
            }
        }

        return ret, nil
    default:
        return nil, errors.Errorf("unknown output type: [%T]%v", r.Output, r.Output)
    }
}

// FluxMetrics is metrics of ImageResponse
type FluxMetrics struct {
    ImageCount  int     `json:"image_count"`
    PredictTime float64 `json:"predict_time"`
    TotalTime   float64 `json:"total_time"`
}

// FluxURLs is urls of ImageResponse
type FluxURLs struct {
    Get    string `json:"get"`
    Cancel string `json:"cancel"`
}

type ReplicateChatRequest struct {
    Input ChatInput `json:"input" form:"input" binding:"required"`
}

// ChatInput is input of ChatByReplicateRequest
//
// https://replicate.com/meta/meta-llama-3.1-405b-instruct/api/schema
type ChatInput struct {
    TopK             int     `json:"top_k"`
    TopP             float64 `json:"top_p"`
    Prompt           string  `json:"prompt"`
    MaxTokens        int     `json:"max_tokens"`
    MinTokens        int     `json:"min_tokens"`
    Temperature      float64 `json:"temperature"`
    SystemPrompt     string  `json:"system_prompt"`
    StopSequences    string  `json:"stop_sequences"`
    PromptTemplate   string  `json:"prompt_template"`
    PresencePenalty  float64 `json:"presence_penalty"`
    FrequencyPenalty float64 `json:"frequency_penalty"`
}

// ChatResponse is response of ChatByReplicateRequest
//
// https://replicate.com/meta/meta-llama-3.1-405b-instruct/examples?input=http&output=json
type ChatResponse struct {
    CompletedAt time.Time   `json:"completed_at"`
    CreatedAt   time.Time   `json:"created_at"`
    DataRemoved bool        `json:"data_removed"`
    Error       string      `json:"error"`
    ID          string      `json:"id"`
    Input       ChatInput   `json:"input"`
    Logs        string      `json:"logs"`
    Metrics     FluxMetrics `json:"metrics"`
    // Output could be `string` or `[]string`
    Output    []string        `json:"output"`
    StartedAt time.Time       `json:"started_at"`
    Status    string          `json:"status"`
    URLs      ChatResponseUrl `json:"urls"`
    Version   string          `json:"version"`
}

// ChatResponseUrl is task urls of ChatResponse
type ChatResponseUrl struct {
    Stream string `json:"stream"`
    Get    string `json:"get"`
    Cancel string `json:"cancel"`
}
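Replicate's `output` field is polymorphic: depending on the model it can be a bare string, an array of strings, or null, which is why callers are expected to go through `GetOutput` instead of asserting a type themselves. A minimal usage sketch (the JSON payload is invented for illustration and assumes this code lives in the same package as the types above):

```go
import (
	"encoding/json"
	"fmt"
	"log"
)

func exampleGetOutput() {
	// Invented prediction payload; real ones come back from the Replicate API.
	raw := `{"id":"p1","status":"succeeded","output":["https://example.com/a.png","https://example.com/b.png"]}`

	var resp ImageResponse
	if err := json.Unmarshal([]byte(raw), &resp); err != nil {
		log.Fatal(err)
	}

	// GetOutput normalizes string, []string, []interface{} and nil into []string.
	urls, err := resp.GetOutput()
	if err != nil {
		log.Fatal(err)
	}
	for _, u := range urls {
		fmt.Println(u)
	}
}
```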
@@ -2,16 +2,19 @@ package tencent

import (
    "errors"
+   "io"
+   "net/http"
+   "strconv"
+   "strings"

    "github.com/gin-gonic/gin"

    "github.com/songquanpeng/one-api/common/helper"
    "github.com/songquanpeng/one-api/relay/adaptor"
    "github.com/songquanpeng/one-api/relay/adaptor/openai"
    "github.com/songquanpeng/one-api/relay/meta"
    "github.com/songquanpeng/one-api/relay/model"
-   "io"
-   "net/http"
-   "strconv"
-   "strings"
+   "github.com/songquanpeng/one-api/relay/relaymode"
)

// https://cloud.tencent.com/document/api/1729/101837

@@ -52,10 +55,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
    if err != nil {
        return nil, err
    }
-   tencentRequest := ConvertRequest(*request)
+   var convertedRequest any
+   switch relayMode {
+   case relaymode.Embeddings:
+       a.Action = "GetEmbedding"
+       convertedRequest = ConvertEmbeddingRequest(*request)
+   default:
+       a.Action = "ChatCompletions"
+       convertedRequest = ConvertRequest(*request)
+   }
    // we have to calculate the sign here
-   a.Sign = GetSign(*tencentRequest, a, secretId, secretKey)
-   return tencentRequest, nil
+   a.Sign = GetSign(convertedRequest, a, secretId, secretKey)
+   return convertedRequest, nil
}

func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {

@@ -75,7 +86,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
        err, responseText = StreamHandler(c, resp)
        usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
    } else {
-       err, usage = Handler(c, resp)
+       switch meta.Mode {
+       case relaymode.Embeddings:
+           err, usage = EmbeddingHandler(c, resp)
+       default:
+           err, usage = Handler(c, resp)
+       }
    }
    return
}

@@ -6,4 +6,5 @@ var ModelList = []string{
    "hunyuan-standard-256K",
    "hunyuan-pro",
    "hunyuan-vision",
+   "hunyuan-embedding",
}

@@ -8,7 +8,6 @@ import (
    "encoding/json"
    "errors"
    "fmt"
-   "github.com/songquanpeng/one-api/common/render"
    "io"
    "net/http"
    "strconv"
@@ -16,11 +15,14 @@ import (
    "time"

    "github.com/gin-gonic/gin"

    "github.com/songquanpeng/one-api/common"
    "github.com/songquanpeng/one-api/common/conv"
+   "github.com/songquanpeng/one-api/common/ctxkey"
    "github.com/songquanpeng/one-api/common/helper"
    "github.com/songquanpeng/one-api/common/logger"
    "github.com/songquanpeng/one-api/common/random"
+   "github.com/songquanpeng/one-api/common/render"
    "github.com/songquanpeng/one-api/relay/adaptor/openai"
    "github.com/songquanpeng/one-api/relay/constant"
    "github.com/songquanpeng/one-api/relay/model"

@@ -44,8 +46,68 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
    }
}

+func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
+   return &EmbeddingRequest{
+       InputList: request.ParseInput(),
+   }
+}
+
+func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+   var tencentResponseP EmbeddingResponseP
+   err := json.NewDecoder(resp.Body).Decode(&tencentResponseP)
+   if err != nil {
+       return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+   }
+
+   err = resp.Body.Close()
+   if err != nil {
+       return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+   }
+
+   tencentResponse := tencentResponseP.Response
+   if tencentResponse.Error.Code != "" {
+       return &model.ErrorWithStatusCode{
+           Error: model.Error{
+               Message: tencentResponse.Error.Message,
+               Code:    tencentResponse.Error.Code,
+           },
+           StatusCode: resp.StatusCode,
+       }, nil
+   }
+   requestModel := c.GetString(ctxkey.RequestModel)
+   fullTextResponse := embeddingResponseTencent2OpenAI(&tencentResponse)
+   fullTextResponse.Model = requestModel
+   jsonResponse, err := json.Marshal(fullTextResponse)
+   if err != nil {
+       return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+   }
+   c.Writer.Header().Set("Content-Type", "application/json")
+   c.Writer.WriteHeader(resp.StatusCode)
+   _, err = c.Writer.Write(jsonResponse)
+   return nil, &fullTextResponse.Usage
+}
+
+func embeddingResponseTencent2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
+   openAIEmbeddingResponse := openai.EmbeddingResponse{
+       Object: "list",
+       Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Data)),
+       Model:  "hunyuan-embedding",
+       Usage:  model.Usage{TotalTokens: response.EmbeddingUsage.TotalTokens},
+   }
+
+   for _, item := range response.Data {
+       openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
+           Object:    item.Object,
+           Index:     item.Index,
+           Embedding: item.Embedding,
+       })
+   }
+   return &openAIEmbeddingResponse
+}
+
func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
    fullTextResponse := openai.TextResponse{
+       Id:      response.ReqID,
        Object:  "chat.completion",
        Created: helper.GetTimestamp(),
        Usage: model.Usage{
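With the handlers above, an OpenAI-style embeddings call routed to `hunyuan-embedding` is converted into Hunyuan's `GetEmbedding` action, and the response is mapped back to the OpenAI embedding shape. A rough sketch of the request side (the literal is illustrative; `Input` accepting a string or a list is an assumption based on the upstream `ParseInput` helper):

```go
// Sketch only; assumes the tencent adaptor package above.
req := model.GeneralOpenAIRequest{
	Model: "hunyuan-embedding",
	Input: "hello world",
}
tencentReq := ConvertEmbeddingRequest(req)
// tencentReq.InputList == []string{"hello world"}
```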
@@ -148,7 +210,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
    }
    TencentResponse = responseP.Response
-   if TencentResponse.Error.Code != 0 {
+   if TencentResponse.Error.Code != "" {
        return &model.ErrorWithStatusCode{
            Error: model.Error{
                Message: TencentResponse.Error.Message,

@@ -195,7 +257,7 @@ func hmacSha256(s, key string) string {
    return string(hashed.Sum(nil))
}

-func GetSign(req ChatRequest, adaptor *Adaptor, secId, secKey string) string {
+func GetSign(req any, adaptor *Adaptor, secId, secKey string) string {
    // build canonical request string
    host := "hunyuan.tencentcloudapi.com"
    httpRequestMethod := "POST"

@@ -35,16 +35,16 @@ type ChatRequest struct {
    // 1. 影响输出文本的多样性,取值越大,生成文本的多样性越强。
    // 2. 取值区间为 [0.0, 1.0],未传值时使用各模型推荐值。
    // 3. 非必要不建议使用,不合理的取值会影响效果。
-   TopP *float64 `json:"TopP"`
+   TopP *float64 `json:"TopP,omitempty"`
    // 说明:
    // 1. 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定。
    // 2. 取值区间为 [0.0, 2.0],未传值时使用各模型推荐值。
    // 3. 非必要不建议使用,不合理的取值会影响效果。
-   Temperature *float64 `json:"Temperature"`
+   Temperature *float64 `json:"Temperature,omitempty"`
}

type Error struct {
-   Code    int    `json:"Code"`
+   Code    string `json:"Code"`
    Message string `json:"Message"`
}

@@ -61,15 +61,41 @@ type ResponseChoices struct {
}

type ChatResponse struct {
    Choices []ResponseChoices `json:"Choices,omitempty"` // 结果
    Created int64             `json:"Created,omitempty"` // unix 时间戳的字符串
    Id      string            `json:"Id,omitempty"`      // 会话 id
    Usage   Usage             `json:"Usage,omitempty"`   // token 数量
    Error   Error             `json:"Error,omitempty"`   // 错误信息 注意:此字段可能返回 null,表示取不到有效值
    Note    string            `json:"Note,omitempty"`    // 注释
-   ReqID   string            `json:"Req_id,omitempty"`    // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
+   ReqID   string            `json:"RequestId,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
}

type ChatResponseP struct {
    Response ChatResponse `json:"Response,omitempty"`
}

+type EmbeddingRequest struct {
+   InputList []string `json:"InputList"`
+}
+
+type EmbeddingData struct {
+   Embedding []float64 `json:"Embedding"`
+   Index     int       `json:"Index"`
+   Object    string    `json:"Object"`
+}
+
+type EmbeddingUsage struct {
+   PromptTokens int `json:"PromptTokens"`
+   TotalTokens  int `json:"TotalTokens"`
+}
+
+type EmbeddingResponse struct {
+   Data           []EmbeddingData `json:"Data"`
+   EmbeddingUsage EmbeddingUsage  `json:"Usage,omitempty"`
+   RequestId      string          `json:"RequestId,omitempty"`
+   Error          Error           `json:"Error,omitempty"`
+}
+
+type EmbeddingResponseP struct {
+   Response EmbeddingResponse `json:"Response,omitempty"`
+}

@@ -15,7 +15,11 @@ import (
)

var ModelList = []string{
-   "gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision", "gemini-1.5-pro-002", "gemini-1.5-flash-002",
+   "gemini-pro", "gemini-pro-vision",
+   "gemini-1.5-pro-001", "gemini-1.5-flash-001",
+   "gemini-1.5-pro-002", "gemini-1.5-flash-002",
+   "gemini-2.0-flash-exp",
+   "gemini-2.0-flash-thinking-exp", "gemini-2.0-flash-thinking-exp-01-21",
}

type Adaptor struct {

@@ -19,6 +19,7 @@ const (
    DeepL
    VertexAI
    Proxy
+   Replicate

    Dummy // this one is only for count, do not add any channel after this
)

@@ -3,6 +3,7 @@ package billing
import (
    "context"
    "fmt"

    "github.com/songquanpeng/one-api/common/logger"
    "github.com/songquanpeng/one-api/model"
)

@@ -31,8 +32,17 @@ func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQ
    }
    // totalQuota is total quota consumed
    if totalQuota != 0 {
-       logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
+       logContent := fmt.Sprintf("倍率:%.2f × %.2f", modelRatio, groupRatio)
-       model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent)
+       model.RecordConsumeLog(ctx, &model.Log{
+           UserId:           userId,
+           ChannelId:        channelId,
+           PromptTokens:     int(totalQuota),
+           CompletionTokens: 0,
+           ModelName:        modelName,
+           TokenName:        tokenName,
+           Quota:            int(totalQuota),
+           Content:          logContent,
+       })
        model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
        model.UpdateChannelUsedQuota(channelId, totalQuota)
    }

@@ -9,9 +9,10 @@ import (
)

const (
    USD2RMB = 7
    USD     = 500 // $0.002 = 1 -> $1 = 500
-   RMB = USD / USD2RMB
+   MILLI_USD = 1.0 / 1000 * USD
+   RMB       = USD / USD2RMB
)

// ModelRatio
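A quick sanity check on the new constant (my arithmetic, not part of the diff): a ratio of 1 corresponds to $0.002 per 1K tokens, so USD = 500 is the ratio for $1 per 1K tokens and MILLI_USD = 1.0 / 1000 × 500 = 0.5 is the ratio for $1 per 1M tokens. A price quoted in dollars per million tokens can therefore be multiplied by MILLI_USD directly; for example, deepseek-chat at $0.14 per 1M input tokens becomes 0.14 × MILLI_USD = 0.07, i.e. $0.00014 per 1K tokens.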
@@ -37,6 +38,7 @@ var ModelRatio = map[string]float64{
    "chatgpt-4o-latest":      2.5,   // $0.005 / 1K tokens
    "gpt-4o-2024-05-13":      2.5,   // $0.005 / 1K tokens
    "gpt-4o-2024-08-06":      1.25,  // $0.0025 / 1K tokens
+   "gpt-4o-2024-11-20":      1.25,  // $0.0025 / 1K tokens
    "gpt-4o-mini":            0.075, // $0.00015 / 1K tokens
    "gpt-4o-mini-2024-07-18": 0.075, // $0.00015 / 1K tokens
    "gpt-4-vision-preview":   5,     // $0.01 / 1K tokens
@@ -48,8 +50,14 @@ var ModelRatio = map[string]float64{
    "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
    "gpt-3.5-turbo-1106":     0.5,  // $0.001 / 1K tokens
    "gpt-3.5-turbo-0125":     0.25, // $0.0005 / 1K tokens
-   "davinci-002":            1,    // $0.002 / 1K tokens
-   "babbage-002":            0.2,  // $0.0004 / 1K tokens
+   "o1":                     7.5, // $15.00 / 1M input tokens
+   "o1-2024-12-17":          7.5,
+   "o1-preview":             7.5, // $15.00 / 1M input tokens
+   "o1-preview-2024-09-12":  7.5,
+   "o1-mini":                1.5, // $3.00 / 1M input tokens
+   "o1-mini-2024-09-12":     1.5,
+   "davinci-002":            1,   // $0.002 / 1K tokens
+   "babbage-002":            0.2, // $0.0004 / 1K tokens
    "text-ada-001":           0.2,
    "text-babbage-001":       0.25,
    "text-curie-001":         1,
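For reference (my arithmetic): with a ratio of 1 equal to $0.002 per 1K tokens, the o1 entry of 7.5 works out to 7.5 × $0.002 = $0.015 per 1K, i.e. $15 per 1M input tokens, and o1-mini at 1.5 to $3 per 1M, consistent with the inline comments.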
@@ -102,11 +110,16 @@ var ModelRatio = map[string]float64{
    "bge-large-en": 0.002 * RMB,
    "tao-8k":       0.002 * RMB,
    // https://ai.google.dev/pricing
    "gemini-pro":     1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
    "gemini-1.0-pro": 1,
-   "gemini-1.5-flash": 1,
-   "gemini-1.5-pro":   1,
-   "aqa":              1,
+   "gemini-1.5-pro":                      1,
+   "gemini-1.5-pro-001":                  1,
+   "gemini-1.5-flash":                    1,
+   "gemini-1.5-flash-001":                1,
+   "gemini-2.0-flash-exp":                1,
+   "gemini-2.0-flash-thinking-exp":       1,
+   "gemini-2.0-flash-thinking-exp-01-21": 1,
+   "aqa":                                 1,
    // https://open.bigmodel.cn/pricing
    "glm-4":  0.1 * RMB,
    "glm-4v": 0.1 * RMB,

@@ -118,29 +131,94 @@ var ModelRatio = map[string]float64{
    "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
    "cogview-3":    0.25 * RMB,
    // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
-   "qwen-turbo":                0.5715, // ¥0.008 / 1k tokens
-   "qwen-plus":                 1.4286, // ¥0.02 / 1k tokens
-   "qwen-max":                  1.4286, // ¥0.02 / 1k tokens
-   "qwen-max-longcontext":      1.4286, // ¥0.02 / 1k tokens
-   "text-embedding-v1":         0.05,   // ¥0.0007 / 1k tokens
-   "ali-stable-diffusion-xl":   8,
-   "ali-stable-diffusion-v1.5": 8,
-   "wanx-v1":                   8,
-   "SparkDesk":                 1.2858, // ¥0.018 / 1k tokens
-   "SparkDesk-v1.1":            1.2858, // ¥0.018 / 1k tokens
-   "SparkDesk-v2.1":            1.2858, // ¥0.018 / 1k tokens
-   "SparkDesk-v3.1":            1.2858, // ¥0.018 / 1k tokens
-   "SparkDesk-v3.1-128K":       1.2858, // ¥0.018 / 1k tokens
-   "SparkDesk-v3.5":            1.2858, // ¥0.018 / 1k tokens
-   "SparkDesk-v3.5-32K":        1.2858, // ¥0.018 / 1k tokens
-   "SparkDesk-v4.0":            1.2858, // ¥0.018 / 1k tokens
-   "360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens
-   "embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens
-   "embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens
-   "semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
-   "hunyuan":                   7.143,  // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
-   "ChatStd":                   0.01 * RMB,
-   "ChatPro":                   0.1 * RMB,
+   "qwen-turbo":                  1.4286, // ¥0.02 / 1k tokens
+   "qwen-turbo-latest":           1.4286,
+   "qwen-plus":                   1.4286,
+   "qwen-plus-latest":            1.4286,
+   "qwen-max":                    1.4286,
+   "qwen-max-latest":             1.4286,
+   "qwen-max-longcontext":        1.4286,
+   "qwen-vl-max":                 1.4286,
+   "qwen-vl-max-latest":          1.4286,
+   "qwen-vl-plus":                1.4286,
+   "qwen-vl-plus-latest":         1.4286,
+   "qwen-vl-ocr":                 1.4286,
+   "qwen-vl-ocr-latest":          1.4286,
+   "qwen-audio-turbo":            1.4286,
+   "qwen-math-plus":              1.4286,
+   "qwen-math-plus-latest":       1.4286,
+   "qwen-math-turbo":             1.4286,
+   "qwen-math-turbo-latest":      1.4286,
+   "qwen-coder-plus":             1.4286,
+   "qwen-coder-plus-latest":      1.4286,
+   "qwen-coder-turbo":            1.4286,
+   "qwen-coder-turbo-latest":     1.4286,
+   "qwq-32b-preview":             1.4286,
+   "qwen2.5-72b-instruct":        1.4286,
+   "qwen2.5-32b-instruct":        1.4286,
+   "qwen2.5-14b-instruct":        1.4286,
+   "qwen2.5-7b-instruct":         1.4286,
+   "qwen2.5-3b-instruct":         1.4286,
+   "qwen2.5-1.5b-instruct":       1.4286,
+   "qwen2.5-0.5b-instruct":       1.4286,
+   "qwen2-72b-instruct":          1.4286,
+   "qwen2-57b-a14b-instruct":     1.4286,
+   "qwen2-7b-instruct":           1.4286,
+   "qwen2-1.5b-instruct":         1.4286,
+   "qwen2-0.5b-instruct":         1.4286,
+   "qwen1.5-110b-chat":           1.4286,
+   "qwen1.5-72b-chat":            1.4286,
+   "qwen1.5-32b-chat":            1.4286,
+   "qwen1.5-14b-chat":            1.4286,
+   "qwen1.5-7b-chat":             1.4286,
+   "qwen1.5-1.8b-chat":           1.4286,
+   "qwen1.5-0.5b-chat":           1.4286,
+   "qwen-72b-chat":               1.4286,
+   "qwen-14b-chat":               1.4286,
+   "qwen-7b-chat":                1.4286,
+   "qwen-1.8b-chat":              1.4286,
+   "qwen-1.8b-longcontext-chat":  1.4286,
+   "qwen2-vl-7b-instruct":        1.4286,
+   "qwen2-vl-2b-instruct":        1.4286,
+   "qwen-vl-v1":                  1.4286,
+   "qwen-vl-chat-v1":             1.4286,
+   "qwen2-audio-instruct":        1.4286,
+   "qwen-audio-chat":             1.4286,
+   "qwen2.5-math-72b-instruct":   1.4286,
+   "qwen2.5-math-7b-instruct":    1.4286,
+   "qwen2.5-math-1.5b-instruct":  1.4286,
+   "qwen2-math-72b-instruct":     1.4286,
+   "qwen2-math-7b-instruct":      1.4286,
+   "qwen2-math-1.5b-instruct":    1.4286,
+   "qwen2.5-coder-32b-instruct":  1.4286,
+   "qwen2.5-coder-14b-instruct":  1.4286,
+   "qwen2.5-coder-7b-instruct":   1.4286,
+   "qwen2.5-coder-3b-instruct":   1.4286,
+   "qwen2.5-coder-1.5b-instruct": 1.4286,
+   "qwen2.5-coder-0.5b-instruct": 1.4286,
+   "text-embedding-v1":           0.05, // ¥0.0007 / 1k tokens
+   "text-embedding-v3":           0.05,
+   "text-embedding-v2":           0.05,
+   "text-embedding-async-v2":     0.05,
+   "text-embedding-async-v1":     0.05,
+   "ali-stable-diffusion-xl":     8.00,
+   "ali-stable-diffusion-v1.5":   8.00,
+   "wanx-v1":                     8.00,
+   "SparkDesk":                   1.2858, // ¥0.018 / 1k tokens
+   "SparkDesk-v1.1":              1.2858, // ¥0.018 / 1k tokens
+   "SparkDesk-v2.1":              1.2858, // ¥0.018 / 1k tokens
+   "SparkDesk-v3.1":              1.2858, // ¥0.018 / 1k tokens
+   "SparkDesk-v3.1-128K":         1.2858, // ¥0.018 / 1k tokens
+   "SparkDesk-v3.5":              1.2858, // ¥0.018 / 1k tokens
+   "SparkDesk-v3.5-32K":          1.2858, // ¥0.018 / 1k tokens
+   "SparkDesk-v4.0":              1.2858, // ¥0.018 / 1k tokens
+   "360GPT_S2_V9":                0.8572, // ¥0.012 / 1k tokens
+   "embedding-bert-512-v1":       0.0715, // ¥0.001 / 1k tokens
+   "embedding_s1_v1":             0.0715, // ¥0.001 / 1k tokens
+   "semantic_similarity_s1_v1":   0.0715, // ¥0.001 / 1k tokens
+   "hunyuan":                     7.143,  // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
+   "ChatStd":                     0.01 * RMB,
+   "ChatPro":                     0.1 * RMB,
    // https://platform.moonshot.cn/pricing
    "moonshot-v1-8k":  0.012 * RMB,
    "moonshot-v1-32k": 0.024 * RMB,

@@ -203,20 +281,69 @@ var ModelRatio = map[string]float64{
    "command-r":      0.5 / 1000 * USD,
    "command-r-plus": 3.0 / 1000 * USD,
    // https://platform.deepseek.com/api-docs/pricing/
-   "deepseek-chat":  1.0 / 1000 * RMB,
-   "deepseek-coder": 1.0 / 1000 * RMB,
+   "deepseek-chat":     0.14 * MILLI_USD,
+   "deepseek-reasoner": 0.55 * MILLI_USD,
    // https://www.deepl.com/pro?cta=header-prices
    "deepl-zh": 25.0 / 1000 * USD,
    "deepl-en": 25.0 / 1000 * USD,
    "deepl-ja": 25.0 / 1000 * USD,
    // https://console.x.ai/
    "grok-beta": 5.0 / 1000 * USD,
+   // replicate charges based on the number of generated images
+   // https://replicate.com/pricing
+   "black-forest-labs/flux-1.1-pro":                0.04 * USD,
+   "black-forest-labs/flux-1.1-pro-ultra":          0.06 * USD,
+   "black-forest-labs/flux-canny-dev":              0.025 * USD,
+   "black-forest-labs/flux-canny-pro":              0.05 * USD,
+   "black-forest-labs/flux-depth-dev":              0.025 * USD,
+   "black-forest-labs/flux-depth-pro":              0.05 * USD,
+   "black-forest-labs/flux-dev":                    0.025 * USD,
+   "black-forest-labs/flux-dev-lora":               0.032 * USD,
+   "black-forest-labs/flux-fill-dev":               0.04 * USD,
+   "black-forest-labs/flux-fill-pro":               0.05 * USD,
+   "black-forest-labs/flux-pro":                    0.055 * USD,
+   "black-forest-labs/flux-redux-dev":              0.025 * USD,
+   "black-forest-labs/flux-redux-schnell":          0.003 * USD,
+   "black-forest-labs/flux-schnell":                0.003 * USD,
+   "black-forest-labs/flux-schnell-lora":           0.02 * USD,
+   "ideogram-ai/ideogram-v2":                       0.08 * USD,
+   "ideogram-ai/ideogram-v2-turbo":                 0.05 * USD,
+   "recraft-ai/recraft-v3":                         0.04 * USD,
+   "recraft-ai/recraft-v3-svg":                     0.08 * USD,
+   "stability-ai/stable-diffusion-3":               0.035 * USD,
+   "stability-ai/stable-diffusion-3.5-large":       0.065 * USD,
+   "stability-ai/stable-diffusion-3.5-large-turbo": 0.04 * USD,
+   "stability-ai/stable-diffusion-3.5-medium":      0.035 * USD,
+   // replicate chat models
+   "ibm-granite/granite-20b-code-instruct-8k":  0.100 * USD,
+   "ibm-granite/granite-3.0-2b-instruct":       0.030 * USD,
+   "ibm-granite/granite-3.0-8b-instruct":       0.050 * USD,
+   "ibm-granite/granite-8b-code-instruct-128k": 0.050 * USD,
+   "meta/llama-2-13b":                          0.100 * USD,
+   "meta/llama-2-13b-chat":                     0.100 * USD,
+   "meta/llama-2-70b":                          0.650 * USD,
+   "meta/llama-2-70b-chat":                     0.650 * USD,
+   "meta/llama-2-7b":                           0.050 * USD,
+   "meta/llama-2-7b-chat":                      0.050 * USD,
+   "meta/meta-llama-3.1-405b-instruct":         9.500 * USD,
+   "meta/meta-llama-3-70b":                     0.650 * USD,
+   "meta/meta-llama-3-70b-instruct":            0.650 * USD,
+   "meta/meta-llama-3-8b":                      0.050 * USD,
+   "meta/meta-llama-3-8b-instruct":             0.050 * USD,
+   "mistralai/mistral-7b-instruct-v0.2":        0.050 * USD,
+   "mistralai/mistral-7b-v0.1":                 0.050 * USD,
+   "mistralai/mixtral-8x7b-instruct-v0.1":      0.300 * USD,
}

var CompletionRatio = map[string]float64{
    // aws llama3
    "llama3-8b-8192(33)":  0.0006 / 0.0003,
    "llama3-70b-8192(33)": 0.0035 / 0.00265,
+   // whisper
+   "whisper-1": 0, // only count input tokens
+   // deepseek
+   "deepseek-chat":     0.28 / 0.14,
+   "deepseek-reasoner": 2.19 / 0.55,
}

var (
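The two deepseek completion ratios are simply output price divided by input price (my reading of the upstream price list): 0.28 / 0.14 = 2 for deepseek-chat and 2.19 / 0.55 ≈ 3.98 for deepseek-reasoner, so completion tokens are billed at roughly 2× and 4× the input rate.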
@@ -334,16 +461,22 @@ func GetCompletionRatio(name string, channelType int) float64 {
        return 4.0 / 3.0
    }
    if strings.HasPrefix(name, "gpt-4") {
-       if strings.HasPrefix(name, "gpt-4o-mini") || name == "gpt-4o-2024-08-06" {
+       if strings.HasPrefix(name, "gpt-4o") {
+           if name == "gpt-4o-2024-05-13" {
+               return 3
+           }
            return 4
        }
        if strings.HasPrefix(name, "gpt-4-turbo") ||
-           strings.HasPrefix(name, "gpt-4o") ||
            strings.HasSuffix(name, "preview") {
            return 3
        }
        return 2
    }
+   // including o1, o1-preview, o1-mini
+   if strings.HasPrefix(name, "o1") {
+       return 4
+   }
    if name == "chatgpt-4o-latest" {
        return 3
    }

@@ -362,6 +495,7 @@ func GetCompletionRatio(name string, channelType int) float64 {
    if strings.HasPrefix(name, "deepseek-") {
        return 2
    }

    switch name {
    case "llama2-70b-4096":
        return 0.8 / 0.64

@@ -377,6 +511,35 @@ func GetCompletionRatio(name string, channelType int) float64 {
        return 5
    case "grok-beta":
        return 3
+   // Replicate Models
+   // https://replicate.com/pricing
+   case "ibm-granite/granite-20b-code-instruct-8k":
+       return 5
+   case "ibm-granite/granite-3.0-2b-instruct":
+       return 8.333333333333334
+   case "ibm-granite/granite-3.0-8b-instruct",
+       "ibm-granite/granite-8b-code-instruct-128k":
+       return 5
+   case "meta/llama-2-13b",
+       "meta/llama-2-13b-chat",
+       "meta/llama-2-7b",
+       "meta/llama-2-7b-chat",
+       "meta/meta-llama-3-8b",
+       "meta/meta-llama-3-8b-instruct":
+       return 5
+   case "meta/llama-2-70b",
+       "meta/llama-2-70b-chat",
+       "meta/meta-llama-3-70b",
+       "meta/meta-llama-3-70b-instruct":
+       return 2.750 / 0.650 // ≈4.230769
+   case "meta/meta-llama-3.1-405b-instruct":
+       return 1
+   case "mistralai/mistral-7b-instruct-v0.2",
+       "mistralai/mistral-7b-v0.1":
+       return 5
+   case "mistralai/mixtral-8x7b-instruct-v0.1":
+       return 1.000 / 0.300 // ≈3.333333
    }

    return 1
}
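These Replicate completion ratios again look like output price divided by input price (my inference, not stated in the diff): granite-3.0-2b-instruct's 8.333… matches $0.25 per 1M output tokens against the $0.03 input price listed above, and 2.750 / 0.650 for the llama 70B family encodes $2.75 output versus $0.65 input per 1M tokens.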
@@ -47,5 +47,6 @@ const (
    Proxy
    SiliconFlow
    XAI
+   Replicate
    Dummy
)

@@ -37,6 +37,8 @@ func ToAPIType(channelType int) int {
        apiType = apitype.DeepL
    case VertextAI:
        apiType = apitype.VertexAI
+   case Replicate:
+       apiType = apitype.Replicate
    case Proxy:
        apiType = apitype.Proxy
    }

@@ -47,6 +47,7 @@ var ChannelBaseURLs = []string{
    "",                           // 43
    "https://api.siliconflow.cn", // 44
    "https://api.x.ai",           // 45
+   "https://api.replicate.com/v1/models/", // 46
}

func init() {

@@ -1,5 +1,6 @@
package role

const (
+   System    = "system"
    Assistant = "assistant"
)

@@ -110,16 +110,9 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
    }()

    // map model name
-   modelMapping := c.GetString(ctxkey.ModelMapping)
-   if modelMapping != "" {
-       modelMap := make(map[string]string)
-       err := json.Unmarshal([]byte(modelMapping), &modelMap)
-       if err != nil {
-           return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
-       }
-       if modelMap[audioModel] != "" {
-           audioModel = modelMap[audioModel]
-       }
+   modelMapping := c.GetStringMapString(ctxkey.ModelMapping)
+   if modelMapping != nil && modelMapping[audioModel] != "" {
+       audioModel = modelMapping[audioModel]
    }

    baseURL := channeltype.ChannelBaseURLs[channelType]
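A sketch of what the simplified lookup does (mapping values invented for illustration):

```go
// The channel's model mapping is now read as a map straight from the request
// context instead of being re-unmarshalled from a JSON string.
modelMapping := map[string]string{"whisper-1": "whisper-large-v3"} // illustrative
audioModel := "whisper-1"
if modelMapping != nil && modelMapping[audioModel] != "" {
	audioModel = modelMapping[audioModel] // -> "whisper-large-v3"
}
```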
@@ -8,7 +8,11 @@ import (
    "net/http"
    "strings"

+   "github.com/songquanpeng/one-api/common/helper"
+   "github.com/songquanpeng/one-api/relay/constant/role"

    "github.com/gin-gonic/gin"

    "github.com/songquanpeng/one-api/common"
    "github.com/songquanpeng/one-api/common/config"
    "github.com/songquanpeng/one-api/common/logger"

@@ -90,7 +94,7 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR
    return preConsumedQuota, nil
}

-func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) {
+func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64, systemPromptReset bool) {
    if usage == nil {
        logger.Error(ctx, "usage is nil, which is unexpected")
        return

@@ -118,8 +122,20 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M
    if err != nil {
        logger.Error(ctx, "error update user quota cache: "+err.Error())
    }
-   logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
+   logContent := fmt.Sprintf("倍率:%.2f × %.2f × %.2f", modelRatio, groupRatio, completionRatio)
-   model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
+   model.RecordConsumeLog(ctx, &model.Log{
+       UserId:            meta.UserId,
+       ChannelId:         meta.ChannelId,
+       PromptTokens:      promptTokens,
+       CompletionTokens:  completionTokens,
+       ModelName:         textRequest.Model,
+       TokenName:         meta.TokenName,
+       Quota:             int(quota),
+       Content:           logContent,
+       IsStream:          meta.IsStream,
+       ElapsedTime:       helper.CalcElapsedTime(meta.StartTime),
+       SystemPromptReset: systemPromptReset,
+   })
    model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
    model.UpdateChannelUsedQuota(meta.ChannelId, quota)
}

@@ -142,15 +158,41 @@ func isErrorHappened(meta *meta.Meta, resp *http.Response) bool {
        }
        return true
    }
-   if resp.StatusCode != http.StatusOK {
+   if resp.StatusCode != http.StatusOK &&
+       // replicate return 201 to create a task
+       resp.StatusCode != http.StatusCreated {
        return true
    }
    if meta.ChannelType == channeltype.DeepL {
        // skip stream check for deepl
        return false
    }
-   if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") {
+   if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") &&
+       // Even if stream mode is enabled, replicate will first return a task info in JSON format,
+       // requiring the client to request the stream endpoint in the task info
+       meta.ChannelType != channeltype.Replicate {
        return true
    }
    return false
}

+func setSystemPrompt(ctx context.Context, request *relaymodel.GeneralOpenAIRequest, prompt string) (reset bool) {
+   if prompt == "" {
+       return false
+   }
+   if len(request.Messages) == 0 {
+       return false
+   }
+   if request.Messages[0].Role == role.System {
+       request.Messages[0].Content = prompt
+       logger.Infof(ctx, "rewrite system prompt")
+       return true
+   }
+   request.Messages = append([]relaymodel.Message{{
+       Role:    role.System,
+       Content: prompt,
+   }}, request.Messages...)
+   logger.Infof(ctx, "add system prompt")
+   return true
+}
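A small sketch of the new helper's behavior (request invented for illustration): if the channel has a system prompt configured, an existing system message is overwritten, otherwise one is prepended, and the returned flag is what later lands in the consume log as SystemPromptReset.

```go
// Sketch only; assumes the package above.
ctx := context.Background()
req := &relaymodel.GeneralOpenAIRequest{
	Messages: []relaymodel.Message{
		{Role: "user", Content: "hello"},
	},
}
reset := setSystemPrompt(ctx, req, "You are a helpful assistant.")
// reset == true; req.Messages now starts with the injected system message:
//   [{system "You are a helpful assistant."} {user "hello"}]
```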
@@ -10,6 +10,7 @@ import (
    "net/http"

    "github.com/gin-gonic/gin"

    "github.com/songquanpeng/one-api/common"
    "github.com/songquanpeng/one-api/common/ctxkey"
    "github.com/songquanpeng/one-api/common/logger"

@@ -22,7 +23,7 @@ import (
    relaymodel "github.com/songquanpeng/one-api/relay/model"
)

-func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
+func getImageRequest(c *gin.Context, _ int) (*relaymodel.ImageRequest, error) {
    imageRequest := &relaymodel.ImageRequest{}
    err := common.UnmarshalBodyReusable(c, imageRequest)
    if err != nil {

@@ -65,7 +66,7 @@ func getImageSizeRatio(model string, size string) float64 {
    return 1
}

-func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
+func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta) *relaymodel.ErrorWithStatusCode {
    // check prompt length
    if imageRequest.Prompt == "" {
        return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)

@@ -150,12 +151,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
    }
    adaptor.Init(meta)

+   // these adaptors need to convert the request
    switch meta.ChannelType {
-   case channeltype.Ali:
-       fallthrough
-   case channeltype.Baidu:
-       fallthrough
-   case channeltype.Zhipu:
+   case channeltype.Zhipu,
+       channeltype.Ali,
+       channeltype.Replicate,
+       channeltype.Baidu:
        finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
        if err != nil {
            return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)

@@ -172,7 +173,14 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
    ratio := modelRatio * groupRatio
    userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)

-   quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
+   var quota int64
+   switch meta.ChannelType {
+   case channeltype.Replicate:
+       // replicate always return 1 image
+       quota = int64(ratio * imageCostRatio * 1000)
+   default:
+       quota = int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
+   }

    if userQuota-quota < 0 {
        return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
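To make the Replicate branch concrete (my arithmetic, assuming one-api's default of 500,000 quota per dollar): black-forest-labs/flux-1.1-pro has a model ratio of 0.04 × USD = 20, so with a group ratio of 1 and an image cost ratio of 1 a single prediction is charged 20 × 1 × 1000 = 20,000 quota, i.e. $0.04, matching Replicate's per-image price, regardless of the `n` value in the request.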
@@ -186,7 +194,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
    }

    defer func(ctx context.Context) {
-       if resp != nil && resp.StatusCode != http.StatusOK {
+       if resp != nil &&
+           resp.StatusCode != http.StatusCreated && // replicate returns 201
+           resp.StatusCode != http.StatusOK {
            return
        }

@@ -200,8 +210,17 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
    }
    if quota != 0 {
        tokenName := c.GetString(ctxkey.TokenName)
-       logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
+       logContent := fmt.Sprintf("倍率:%.2f × %.2f", modelRatio, groupRatio)
-       model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, 0, 0, imageRequest.Model, tokenName, quota, logContent)
+       model.RecordConsumeLog(ctx, &model.Log{
+           UserId:           meta.UserId,
+           ChannelId:        meta.ChannelId,
+           PromptTokens:     0,
+           CompletionTokens: 0,
+           ModelName:        imageRequest.Model,
+           TokenName:        tokenName,
+           Quota:            int(quota),
+           Content:          logContent,
+       })
        model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
        channelId := c.GetInt(ctxkey.ChannelId)
        model.UpdateChannelUsedQuota(channelId, quota)

@@ -8,6 +8,8 @@ import (
    "net/http"

    "github.com/gin-gonic/gin"

+   "github.com/songquanpeng/one-api/common/config"
    "github.com/songquanpeng/one-api/common/logger"
    "github.com/songquanpeng/one-api/relay"
    "github.com/songquanpeng/one-api/relay/adaptor"

@@ -35,6 +37,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
    meta.OriginModelName = textRequest.Model
    textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping)
    meta.ActualModelName = textRequest.Model
+   // set system prompt if not empty
+   systemPromptReset := setSystemPrompt(ctx, textRequest, meta.SystemPrompt)
    // get model ratio & group ratio
    modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType)
    groupRatio := billingratio.GetGroupRatio(meta.Group)

@@ -79,12 +83,12 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
        return respErr
    }
    // post-consume quota
-   go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
+   go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio, systemPromptReset)
    return nil
}

func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) {
-   if meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan {
+   if !config.EnforceIncludeUsage && meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan {
        // no need to convert request for openai
        return c.Request.Body, nil
    }

@@ -1,12 +1,15 @@
package meta

import (
+   "strings"
+   "time"
+
    "github.com/gin-gonic/gin"

    "github.com/songquanpeng/one-api/common/ctxkey"
    "github.com/songquanpeng/one-api/model"
    "github.com/songquanpeng/one-api/relay/channeltype"
    "github.com/songquanpeng/one-api/relay/relaymode"
-   "strings"
)

type Meta struct {

@@ -30,6 +33,8 @@ type Meta struct {
    ActualModelName string
    RequestURLPath  string
    PromptTokens    int // only for DoResponse
+   SystemPrompt    string
+   StartTime       time.Time
}

func GetByContext(c *gin.Context) *Meta {

@@ -46,6 +51,8 @@ func GetByContext(c *gin.Context) *Meta {
        BaseURL:        c.GetString(ctxkey.BaseURL),
        APIKey:         strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
        RequestURLPath: c.Request.URL.String(),
+       SystemPrompt:   c.GetString(ctxkey.SystemPrompt),
+       StartTime:      time.Now(),
    }
    cfg, ok := c.Get(ctxkey.Config)
    if ok {

@@ -9,6 +9,7 @@ import (

func SetRelayRouter(router *gin.Engine) {
    router.Use(middleware.CORS())
+   router.Use(middleware.GzipDecodeMiddleware())
    // https://platform.openai.com/docs/api-reference/introduction
    modelsRouter := router.Group("/v1/models")
    modelsRouter.Use(middleware.TokenAuth())
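The new GzipDecodeMiddleware is presumably what lets clients send compressed request bodies to the relay endpoints. A client-side sketch (endpoint and token are placeholders, not from the diff):

```go
package main

import (
	"bytes"
	"compress/gzip"
	"fmt"
	"net/http"
)

func main() {
	// Compress the JSON body and declare it via Content-Encoding.
	var buf bytes.Buffer
	zw := gzip.NewWriter(&buf)
	_, _ = zw.Write([]byte(`{"model":"gpt-4o-mini","messages":[{"role":"user","content":"hi"}]}`))
	_ = zw.Close()

	req, _ := http.NewRequest("POST", "http://localhost:3000/v1/chat/completions", &buf)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Content-Encoding", "gzip")
	req.Header.Set("Authorization", "Bearer sk-your-token") // placeholder

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status)
}
```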
@@ -31,6 +31,7 @@ export const CHANNEL_OPTIONS = [
|
|||||||
{ key: 43, text: 'Proxy', value: 43, color: 'blue' },
|
{ key: 43, text: 'Proxy', value: 43, color: 'blue' },
|
||||||
{ key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
|
{ key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
|
||||||
{ key: 45, text: 'xAI', value: 45, color: 'blue' },
|
{ key: 45, text: 'xAI', value: 45, color: 'blue' },
|
||||||
|
{ key: 46, text: 'Replicate', value: 46, color: 'blue' },
|
||||||
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
||||||
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
|
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
|
||||||
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
|
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ const EditChannel = (props) => {
|
|||||||
base_url: '',
|
base_url: '',
|
||||||
other: '',
|
other: '',
|
||||||
model_mapping: '',
|
model_mapping: '',
|
||||||
|
system_prompt: '',
|
||||||
models: [],
|
models: [],
|
||||||
auto_ban: 1,
|
auto_ban: 1,
|
||||||
groups: ['default']
|
groups: ['default']
|
||||||
@@ -304,163 +305,163 @@ const EditChannel = (props) => {
|
|||||||
width={isMobile() ? '100%' : 600}
|
width={isMobile() ? '100%' : 600}
|
||||||
>
|
>
|
||||||
<Spin spinning={loading}>
|
<Spin spinning={loading}>
|
||||||
<div style={{marginTop: 10}}>
|
<div style={{ marginTop: 10 }}>
|
||||||
<Typography.Text strong>类型:</Typography.Text>
|
<Typography.Text strong>类型:</Typography.Text>
|
||||||
</div>
|
</div>
|
||||||
<Select
|
<Select
|
||||||
name='type'
|
name='type'
|
||||||
required
|
required
|
||||||
optionList={CHANNEL_OPTIONS}
|
optionList={CHANNEL_OPTIONS}
|
||||||
value={inputs.type}
|
value={inputs.type}
|
||||||
onChange={value => handleInputChange('type', value)}
|
onChange={value => handleInputChange('type', value)}
|
||||||
style={{width: '50%'}}
|
style={{ width: '50%' }}
|
||||||
/>
|
/>
|
||||||
{
|
{
|
||||||
inputs.type === 3 && (
|
inputs.type === 3 && (
|
||||||
<>
|
<>
|
||||||
<div style={{marginTop: 10}}>
|
<div style={{ marginTop: 10 }}>
|
||||||
<Banner type={"warning"} description={
|
<Banner type={"warning"} description={
|
||||||
<>
|
<>
|
||||||
注意,<strong>模型部署名称必须和模型名称保持一致</strong>,因为 One API 会把请求体中的
|
注意,<strong>模型部署名称必须和模型名称保持一致</strong>,因为 One API 会把请求体中的
|
||||||
model
|
model
|
||||||
参数替换为你的部署名称(模型名称中的点会被剔除),<a target='_blank'
|
参数替换为你的部署名称(模型名称中的点会被剔除),<a target='_blank'
|
||||||
href='https://github.com/songquanpeng/one-api/issues/133?notification_referrer_id=NT_kwDOAmJSYrM2NjIwMzI3NDgyOjM5OTk4MDUw#issuecomment-1571602271'>图片演示</a>。
|
href='https://github.com/songquanpeng/one-api/issues/133?notification_referrer_id=NT_kwDOAmJSYrM2NjIwMzI3NDgyOjM5OTk4MDUw#issuecomment-1571602271'>图片演示</a>。
|
||||||
</>
|
</>
|
||||||
}>
|
}>
|
||||||
</Banner>
|
</Banner>
|
||||||
</div>
|
</div>
|
||||||
<div style={{marginTop: 10}}>
|
<div style={{ marginTop: 10 }}>
|
||||||
<Typography.Text strong>AZURE_OPENAI_ENDPOINT:</Typography.Text>
|
<Typography.Text strong>AZURE_OPENAI_ENDPOINT:</Typography.Text>
|
||||||
</div>
|
</div>
|
||||||
<Input
|
<Input
|
||||||
label='AZURE_OPENAI_ENDPOINT'
|
label='AZURE_OPENAI_ENDPOINT'
|
||||||
name='azure_base_url'
|
name='azure_base_url'
|
||||||
placeholder={'请输入 AZURE_OPENAI_ENDPOINT,例如:https://docs-test-001.openai.azure.com'}
|
placeholder={'请输入 AZURE_OPENAI_ENDPOINT,例如:https://docs-test-001.openai.azure.com'}
|
||||||
onChange={value => {
|
onChange={value => {
|
||||||
handleInputChange('base_url', value)
|
handleInputChange('base_url', value)
|
||||||
}}
|
}}
|
||||||
value={inputs.base_url}
|
value={inputs.base_url}
|
||||||
autoComplete='new-password'
|
autoComplete='new-password'
|
||||||
/>
|
/>
|
||||||
<div style={{marginTop: 10}}>
|
<div style={{ marginTop: 10 }}>
|
||||||
<Typography.Text strong>默认 API 版本:</Typography.Text>
|
<Typography.Text strong>默认 API 版本:</Typography.Text>
|
||||||
</div>
|
</div>
|
||||||
<Input
|
<Input
|
||||||
label='默认 API 版本'
|
label='默认 API 版本'
|
||||||
name='azure_other'
|
name='azure_other'
|
||||||
placeholder={'请输入默认 API 版本,例如:2024-03-01-preview,该配置可以被实际的请求查询参数所覆盖'}
|
placeholder={'请输入默认 API 版本,例如:2024-03-01-preview,该配置可以被实际的请求查询参数所覆盖'}
|
||||||
onChange={value => {
|
onChange={value => {
|
||||||
handleInputChange('other', value)
|
handleInputChange('other', value)
|
||||||
}}
|
}}
|
||||||
value={inputs.other}
|
value={inputs.other}
|
||||||
autoComplete='new-password'
|
autoComplete='new-password'
|
||||||
/>
|
/>
|
||||||
</>
|
</>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
  {
    inputs.type === 8 && (
      <>
        <div style={{ marginTop: 10 }}>
          <Typography.Text strong>Base URL:</Typography.Text>
        </div>
        <Input
          name='base_url'
          placeholder={'请输入自定义渠道的 Base URL'}
          onChange={value => { handleInputChange('base_url', value) }}
          value={inputs.base_url}
          autoComplete='new-password'
        />
      </>
    )
  }
  <div style={{ marginTop: 10 }}>
    <Typography.Text strong>名称:</Typography.Text>
  </div>
  <Input
    required name='name'
    placeholder={'请为渠道命名'}
    onChange={value => { handleInputChange('name', value) }}
    value={inputs.name} autoComplete='new-password'
  />
  <div style={{ marginTop: 10 }}>
    <Typography.Text strong>分组:</Typography.Text>
  </div>
  <Select
    placeholder={'请选择可以使用该渠道的分组'}
    name='groups' required multiple selection allowAdditions
    additionLabel={'请在系统设置页面编辑分组倍率以添加新的分组:'}
    onChange={value => { handleInputChange('groups', value) }}
    value={inputs.groups} autoComplete='new-password' optionList={groupOptions}
  />
  {
    inputs.type === 18 && (
      <>
        <div style={{ marginTop: 10 }}>
          <Typography.Text strong>模型版本:</Typography.Text>
        </div>
        <Input
          name='other'
          placeholder={'请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1'}
          onChange={value => { handleInputChange('other', value) }}
          value={inputs.other} autoComplete='new-password'
        />
      </>
    )
  }
  {
    inputs.type === 21 && (
      <>
        <div style={{ marginTop: 10 }}>
          <Typography.Text strong>知识库 ID:</Typography.Text>
        </div>
        <Input
          label='知识库 ID' name='other'
          placeholder={'请输入知识库 ID,例如:123456'}
          onChange={value => { handleInputChange('other', value) }}
          value={inputs.other} autoComplete='new-password'
        />
      </>
    )
  }
  <div style={{ marginTop: 10 }}>
    <Typography.Text strong>模型:</Typography.Text>
  </div>
  <Select
    placeholder={'请选择该渠道所支持的模型'}
    name='models' required multiple selection
    onChange={value => { handleInputChange('models', value) }}
    value={inputs.models} autoComplete='new-password' optionList={modelOptions}
  />
  <div style={{ lineHeight: '40px', marginBottom: '12px' }}>
    <Space>
      <Button type='primary' onClick={() => {
        handleInputChange('models', basicModels);
@@ -473,28 +474,41 @@ const EditChannel = (props) => {
      }}>清除所有模型</Button>
    </Space>
    <Input
      addonAfter={
        <Button type='primary' onClick={addCustomModel}>填入</Button>
      }
      placeholder='输入自定义模型名称'
      value={customModel}
      onChange={(value) => { setCustomModel(value.trim()); }}
    />
  </div>
  <div style={{ marginTop: 10 }}>
    <Typography.Text strong>模型重定向:</Typography.Text>
  </div>
  <TextArea
    placeholder={`此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`}
    name='model_mapping'
    onChange={value => { handleInputChange('model_mapping', value) }}
    autosize
    value={inputs.model_mapping}
    autoComplete='new-password'
  />
+  <div style={{ marginTop: 10 }}>
+    <Typography.Text strong>系统提示词:</Typography.Text>
+  </div>
+  <TextArea
+    placeholder={`此项可选,用于强制设置给定的系统提示词,请配合自定义模型 & 模型重定向使用,首先创建一个唯一的自定义模型名称并在上面填入,之后将该自定义模型重定向映射到该渠道一个原生支持的模型`}
+    name='system_prompt'
+    onChange={value => { handleInputChange('system_prompt', value) }}
+    autosize
+    value={inputs.system_prompt}
+    autoComplete='new-password'
+  />
  <Typography.Text style={{
    color: 'rgba(var(--semi-blue-5), 1)',
@@ -507,116 +521,116 @@ const EditChannel = (props) => {
  }>
    填入模板
  </Typography.Text>
  <div style={{ marginTop: 10 }}>
    <Typography.Text strong>密钥:</Typography.Text>
  </div>
  {
    batch ?
      <TextArea
        label='密钥' name='key' required
        placeholder={'请输入密钥,一行一个'}
        onChange={value => { handleInputChange('key', value) }}
        value={inputs.key}
        style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
        autoComplete='new-password'
      />
      :
      <Input
        label='密钥' name='key' required
        placeholder={type2secretPrompt(inputs.type)}
        onChange={value => { handleInputChange('key', value) }}
        value={inputs.key} autoComplete='new-password'
      />
  }
  <div style={{ marginTop: 10 }}>
    <Typography.Text strong>组织:</Typography.Text>
  </div>
  <Input
    label='组织,可选,不填则为默认组织' name='openai_organization'
    placeholder='请输入组织org-xxx'
    onChange={value => { handleInputChange('openai_organization', value) }}
    value={inputs.openai_organization}
  />
  <div style={{ marginTop: 10, display: 'flex' }}>
    <Space>
      <Checkbox
        name='auto_ban'
        checked={autoBan}
        onChange={
          () => {
            setAutoBan(!autoBan);
          }
        }
        // onChange={handleInputChange}
      />
      <Typography.Text strong>是否自动禁用(仅当自动禁用开启时有效),关闭后不会自动禁用该渠道:</Typography.Text>
    </Space>
  </div>
  {
    !isEdit && (
      <div style={{ marginTop: 10, display: 'flex' }}>
        <Space>
          <Checkbox checked={batch} label='批量创建' name='batch' onChange={() => setBatch(!batch)} />
          <Typography.Text strong>批量创建</Typography.Text>
        </Space>
      </div>
    )
  }
  {
    inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && (
      <>
        <div style={{ marginTop: 10 }}>
          <Typography.Text strong>代理:</Typography.Text>
        </div>
        <Input
          label='代理' name='base_url'
          placeholder={'此项可选,用于通过代理站来进行 API 调用'}
          onChange={value => { handleInputChange('base_url', value) }}
          value={inputs.base_url} autoComplete='new-password'
        />
      </>
    )
  }
  {
    inputs.type === 22 && (
      <>
        <div style={{ marginTop: 10 }}>
          <Typography.Text strong>私有部署地址:</Typography.Text>
        </div>
        <Input
          name='base_url'
          placeholder={'请输入私有部署地址,格式为:https://fastgpt.run/api/openapi'}
          onChange={value => { handleInputChange('base_url', value) }}
          value={inputs.base_url} autoComplete='new-password'
        />
      </>
    )
  }
</Spin>

@@ -185,6 +185,12 @@ export const CHANNEL_OPTIONS = {
    value: 45,
    color: 'primary'
  },
+  45: {
+    key: 46,
+    text: 'Replicate',
+    value: 46,
+    color: 'primary'
+  },
  41: {
    key: 41,
    text: 'Novita',
@@ -1,247 +1,260 @@
import {enqueueSnackbar} from 'notistack';
import {snackbarConstants} from 'constants/SnackbarConstants';
import {API} from './api';

export function getSystemName() {
  let system_name = localStorage.getItem('system_name');
  if (!system_name) return 'One API';
  return system_name;
}

export function isMobile() {
  return window.innerWidth <= 600;
}

// eslint-disable-next-line
export function SnackbarHTMLContent({htmlContent}) {
  return <div dangerouslySetInnerHTML={{__html: htmlContent}}/>;
}

export function getSnackbarOptions(variant) {
  let options = snackbarConstants.Common[variant];
  if (isMobile()) {
    // 合并 options 和 snackbarConstants.Mobile
    options = {...options, ...snackbarConstants.Mobile};
  }
  return options;
}

export function showError(error) {
  if (error.message) {
    if (error.name === 'AxiosError') {
      switch (error.response.status) {
        case 429:
          enqueueSnackbar('错误:请求次数过多,请稍后再试!', getSnackbarOptions('ERROR'));
          break;
        case 500:
          enqueueSnackbar('错误:服务器内部错误,请联系管理员!', getSnackbarOptions('ERROR'));
          break;
        case 405:
          enqueueSnackbar('本站仅作演示之用,无服务端!', getSnackbarOptions('INFO'));
          break;
        default:
          enqueueSnackbar('错误:' + error.message, getSnackbarOptions('ERROR'));
      }
      return;
    }
  } else {
    enqueueSnackbar('错误:' + error, getSnackbarOptions('ERROR'));
  }
}

export function showNotice(message, isHTML = false) {
  if (isHTML) {
    enqueueSnackbar(<SnackbarHTMLContent htmlContent={message}/>, getSnackbarOptions('NOTICE'));
  } else {
    enqueueSnackbar(message, getSnackbarOptions('NOTICE'));
  }
}

export function showWarning(message) {
  enqueueSnackbar(message, getSnackbarOptions('WARNING'));
}

export function showSuccess(message) {
  enqueueSnackbar(message, getSnackbarOptions('SUCCESS'));
}

export function showInfo(message) {
  enqueueSnackbar(message, getSnackbarOptions('INFO'));
}

export async function getOAuthState() {
  const res = await API.get('/api/oauth/state');
  const {success, message, data} = res.data;
  if (success) {
    return data;
  } else {
    showError(message);
    return '';
  }
}

export async function onGitHubOAuthClicked(github_client_id, openInNewTab = false) {
  const state = await getOAuthState();
  if (!state) return;
  let url = `https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`;
  if (openInNewTab) {
    window.open(url);
  } else {
    window.location.href = url;
  }
}

export async function onLarkOAuthClicked(lark_client_id) {
  const state = await getOAuthState();
  if (!state) return;
  let redirect_uri = `${window.location.origin}/oauth/lark`;
-  window.open(`https://open.feishu.cn/open-apis/authen/v1/index?redirect_uri=${redirect_uri}&app_id=${lark_client_id}&state=${state}`);
+  window.open(`https://accounts.feishu.cn/open-apis/authen/v1/authorize?redirect_uri=${redirect_uri}&client_id=${lark_client_id}&state=${state}`);
}

export async function onOidcClicked(auth_url, client_id, openInNewTab = false) {
  const state = await getOAuthState();
  if (!state) return;
  const redirect_uri = `${window.location.origin}/oauth/oidc`;
  const response_type = "code";
  const scope = "openid profile email";
  const url = `${auth_url}?client_id=${client_id}&redirect_uri=${redirect_uri}&response_type=${response_type}&scope=${scope}&state=${state}`;
  if (openInNewTab) {
    window.open(url);
  } else {
    window.location.href = url;
  }
}

export function isAdmin() {
  let user = localStorage.getItem('user');
  if (!user) return false;
  user = JSON.parse(user);
  return user.role >= 10;
}

export function timestamp2string(timestamp) {
  let date = new Date(timestamp * 1000);
  let year = date.getFullYear().toString();
  let month = (date.getMonth() + 1).toString();
  let day = date.getDate().toString();
  let hour = date.getHours().toString();
  let minute = date.getMinutes().toString();
  let second = date.getSeconds().toString();
  if (month.length === 1) month = '0' + month;
  if (day.length === 1) day = '0' + day;
  if (hour.length === 1) hour = '0' + hour;
  if (minute.length === 1) minute = '0' + minute;
  if (second.length === 1) second = '0' + second;
  return year + '-' + month + '-' + day + ' ' + hour + ':' + minute + ':' + second;
}

export function calculateQuota(quota, digits = 2) {
  let quotaPerUnit = localStorage.getItem('quota_per_unit');
  quotaPerUnit = parseFloat(quotaPerUnit);
  return (quota / quotaPerUnit).toFixed(digits);
}

export function renderQuota(quota, digits = 2) {
  let displayInCurrency = localStorage.getItem('display_in_currency');
  displayInCurrency = displayInCurrency === 'true';
  if (displayInCurrency) {
    return '$' + calculateQuota(quota, digits);
  }
  return renderNumber(quota);
}

export const verifyJSON = (str) => {
  try {
    JSON.parse(str);
  } catch (e) {
    return false;
  }
  return true;
};

export function renderNumber(num) {
  if (num >= 1000000000) {
    return (num / 1000000000).toFixed(1) + 'B';
  } else if (num >= 1000000) {
    return (num / 1000000).toFixed(1) + 'M';
  } else if (num >= 10000) {
    return (num / 1000).toFixed(1) + 'k';
  } else {
    return num;
  }
}

export function renderQuotaWithPrompt(quota, digits) {
  let displayInCurrency = localStorage.getItem('display_in_currency');
  displayInCurrency = displayInCurrency === 'true';
  if (displayInCurrency) {
    return `(等价金额:${renderQuota(quota, digits)})`;
  }
  return '';
}

export function downloadTextAsFile(text, filename) {
  let blob = new Blob([text], {type: 'text/plain;charset=utf-8'});
  let url = URL.createObjectURL(blob);
  let a = document.createElement('a');
  a.href = url;
  a.download = filename;
  a.click();
}

export function removeTrailingSlash(url) {
  if (url.endsWith('/')) {
    return url.slice(0, -1);
  } else {
    return url;
  }
}

let channelModels = undefined;

export async function loadChannelModels() {
  const res = await API.get('/api/models');
  const {success, data} = res.data;
  if (!success) {
    return;
  }
  channelModels = data;
  localStorage.setItem('channel_models', JSON.stringify(data));
}

export function getChannelModels(type) {
  if (channelModels !== undefined && type in channelModels) {
    return channelModels[type];
  }
  let models = localStorage.getItem('channel_models');
  if (!models) {
    return [];
  }
  channelModels = JSON.parse(models);
  if (type in channelModels) {
    return channelModels[type];
  }
  return [];
}

export function copy(text, name = '') {
-  try {
-    navigator.clipboard.writeText(text);
-  } catch (error) {
-    text = `复制${name}失败,请手动复制:<br /><br />${text}`;
-    enqueueSnackbar(<SnackbarHTMLContent htmlContent={text} />, getSnackbarOptions('COPY'));
-    return;
-  }
-  showSuccess(`复制${name}成功!`);
+  if (navigator.clipboard && navigator.clipboard.writeText) {
+    navigator.clipboard.writeText(text).then(() => {
+      showNotice(`复制${name}成功!`, true);
+    }, () => {
+      text = `复制${name}失败,请手动复制:<br /><br />${text}`;
+      enqueueSnackbar(<SnackbarHTMLContent htmlContent={text}/>, getSnackbarOptions('COPY'));
+    });
+  } else {
+    const textArea = document.createElement("textarea");
+    textArea.value = text;
+    document.body.appendChild(textArea);
+    textArea.select();
+    try {
+      document.execCommand('copy');
+      showNotice(`复制${name}成功!`, true);
+    } catch (err) {
+      text = `复制${name}失败,请手动复制:<br /><br />${text}`;
+      enqueueSnackbar(<SnackbarHTMLContent htmlContent={text}/>, getSnackbarOptions('COPY'));
+    }
+    document.body.removeChild(textArea);
+  }
+}
@@ -595,6 +595,28 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
              <FormHelperText id="helper-tex-channel-model_mapping-label"> {inputPrompt.model_mapping} </FormHelperText>
            )}
          </FormControl>
+          <FormControl fullWidth error={Boolean(touched.system_prompt && errors.system_prompt)} sx={{ ...theme.typography.otherInput }}>
+            {/* <InputLabel htmlFor="channel-model_mapping-label">{inputLabel.model_mapping}</InputLabel> */}
+            <TextField
+              multiline
+              id="channel-system_prompt-label"
+              label={inputLabel.system_prompt}
+              value={values.system_prompt}
+              name="system_prompt"
+              onBlur={handleBlur}
+              onChange={handleChange}
+              aria-describedby="helper-text-channel-system_prompt-label"
+              minRows={5}
+              placeholder={inputPrompt.system_prompt}
+            />
+            {touched.system_prompt && errors.system_prompt ? (
+              <FormHelperText error id="helper-tex-channel-system_prompt-label">
+                {errors.system_prompt}
+              </FormHelperText>
+            ) : (
+              <FormHelperText id="helper-tex-channel-system_prompt-label"> {inputPrompt.system_prompt} </FormHelperText>
+            )}
+          </FormControl>
          <DialogActions>
            <Button onClick={onCancel}>取消</Button>
            <Button disableElevation disabled={isSubmitting} type="submit" variant="contained" color="primary">
@@ -268,6 +268,8 @@ function renderBalance(type, balance) {
      return <span>¥{balance.toFixed(2)}</span>;
    case 13: // AIGC2D
      return <span>{renderNumber(balance)}</span>;
+    case 36: // DeepSeek
+      return <span>¥{balance.toFixed(2)}</span>;
    case 44: // SiliconFlow
      return <span>¥{balance.toFixed(2)}</span>;
    default:
@@ -18,6 +18,7 @@ const defaultConfig = {
    other: '其他参数',
    models: '模型',
    model_mapping: '模型映射关系',
+    system_prompt: '系统提示词',
    groups: '用户组',
    config: null
  },
@@ -30,6 +31,7 @@ const defaultConfig = {
    models: '请选择该渠道所支持的模型',
    model_mapping:
      '请输入要修改的模型映射关系,格式为:api请求模型ID:实际转发给渠道的模型ID,使用JSON数组表示,例如:{"gpt-3.5": "gpt-35"}',
+    system_prompt: '此项可选,用于强制设置给定的系统提示词,请配合自定义模型 & 模型重定向使用,首先创建一个唯一的自定义模型名称并在上面填入,之后将该自定义模型重定向映射到该渠道一个原生支持的模型',
    groups: '请选择该渠道所支持的用户组',
    config: null
  },

0 web/build.sh (Normal file → Executable file)

@@ -1,5 +1,15 @@
import React, { useEffect, useState } from 'react';
import { Button, Dropdown, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react';
import { Link } from 'react-router-dom';
import {
  API,
@@ -9,31 +19,31 @@ import {
  showError,
  showInfo,
  showSuccess,
  timestamp2string,
} from '../helpers';

import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
import { renderGroup, renderNumber } from '../helpers/render';

function renderTimestamp(timestamp) {
  return <>{timestamp2string(timestamp)}</>;
}

let type2label = undefined;

function renderType(type) {
  if (!type2label) {
    type2label = new Map();
    for (let i = 0; i < CHANNEL_OPTIONS.length; i++) {
      type2label[CHANNEL_OPTIONS[i].value] = CHANNEL_OPTIONS[i];
    }
    type2label[0] = { value: 0, text: '未知类型', color: 'grey' };
  }
  return <Label basic color={type2label[type]?.color}>{type2label[type] ? type2label[type].text : type}</Label>;
}

function renderBalance(type, balance) {
@@ -52,6 +62,8 @@ function renderBalance(type, balance) {
      return <span>¥{balance.toFixed(2)}</span>;
    case 13: // AIGC2D
      return <span>{renderNumber(balance)}</span>;
+    case 36: // DeepSeek
+      return <span>¥{balance.toFixed(2)}</span>;
    case 44: // SiliconFlow
      return <span>¥{balance.toFixed(2)}</span>;
    default:
@@ -60,10 +72,10 @@ function renderBalance(type, balance) {
}

function isShowDetail() {
  return localStorage.getItem('show_detail') === 'true';
}

const promptID = 'detail';

const ChannelsTable = () => {
  const [channels, setChannels] = useState([]);
@@ -79,33 +91,37 @@ const ChannelsTable = () => {
    const res = await API.get(`/api/channel/?p=${startIdx}`);
    const { success, message, data } = res.data;
    if (success) {
      let localChannels = data.map((channel) => {
        if (channel.models === '') {
          channel.models = [];
          channel.test_model = '';
        } else {
          channel.models = channel.models.split(',');
          if (channel.models.length > 0) {
            channel.test_model = channel.models[0];
          }
          channel.model_options = channel.models.map((model) => {
            return { key: model, text: model, value: model };
          });
          console.log('channel', channel);
        }
        return channel;
      });
      if (startIdx === 0) {
        setChannels(localChannels);
      } else {
        let newChannels = [...channels];
        newChannels.splice(startIdx * ITEMS_PER_PAGE, data.length, ...localChannels);
        setChannels(newChannels);
      }
    } else {
      showError(message);
    }
@@ -129,8 +145,8 @@ const ChannelsTable = () => {

  const toggleShowDetail = () => {
    setShowDetail(!showDetail);
    localStorage.setItem('show_detail', (!showDetail).toString());
  };

  useEffect(() => {
    loadChannels(0)
@@ -194,13 +210,19 @@ const ChannelsTable = () => {
  const renderStatus = (status) => {
    switch (status) {
      case 1:
        return <Label basic color='green'>已启用</Label>;
      case 2:
        return (
          <Popup
            trigger={<Label basic color='red'>已禁用</Label>}
            content='本渠道被手动禁用'
            basic
          />
@@ -208,9 +230,11 @@ const ChannelsTable = () => {
      case 3:
        return (
          <Popup
            trigger={<Label basic color='yellow'>已禁用</Label>}
            content='本渠道被程序自动禁用'
            basic
          />
@@ -228,15 +252,35 @@ const ChannelsTable = () => {
    let time = responseTime / 1000;
    time = time.toFixed(2) + ' 秒';
    if (responseTime === 0) {
      return <Label basic color='grey'>未测试</Label>;
    } else if (responseTime <= 1000) {
      return <Label basic color='green'>{time}</Label>;
    } else if (responseTime <= 3000) {
      return <Label basic color='olive'>{time}</Label>;
    } else if (responseTime <= 5000) {
      return <Label basic color='yellow'>{time}</Label>;
    } else {
      return <Label basic color='red'>{time}</Label>;
    }
  };

@@ -275,7 +319,11 @@ const ChannelsTable = () => {
        newChannels[realIdx].response_time = time * 1000;
        newChannels[realIdx].test_time = Date.now() / 1000;
        setChannels(newChannels);
-        showInfo(`渠道 ${name} 测试成功,模型 ${model},耗时 ${time.toFixed(2)} 秒。`);
+        showInfo(`渠道 ${name} 测试成功,模型 ${model},耗时 ${time.toFixed(2)} 秒,模型输出:${message}`);
      } else {
        showError(message);
      }
@@ -358,7 +406,6 @@ const ChannelsTable = () => {
    setLoading(false);
  };

  return (
    <>
      <Form onSubmit={searchChannels}>
@@ -372,20 +419,22 @@ const ChannelsTable = () => {
          onChange={handleKeywordChange}
        />
      </Form>
      {showPrompt && (
        <Message
          onDismiss={() => {
            setShowPrompt(false);
            setPromptShown(promptID);
          }}
        >
          OpenAI 渠道已经不再支持通过 key 获取余额,因此余额显示为 0。对于支持的渠道类型,请点击余额进行刷新。
          <br />
          渠道测试仅支持 chat 模型,优先使用 gpt-3.5-turbo,如果该模型不可用则使用你所配置的模型列表中的第一个模型。
          <br />
          点击下方详情按钮可以显示余额以及设置额外的测试模型。
        </Message>
      )}
      <Table basic compact size='small'>
        <Table.Header>
          <Table.Row>
@@ -476,7 +525,11 @@ const ChannelsTable = () => {
                <Table.Cell>{renderStatus(channel.status)}</Table.Cell>
                <Table.Cell>
                  <Popup
                    content={channel.test_time ? renderTimestamp(channel.test_time) : '未测试'}
                    key={channel.id}
                    trigger={renderResponseTime(channel.response_time)}
                    basic
@@ -484,27 +537,38 @@ const ChannelsTable = () => {
                </Table.Cell>
                <Table.Cell hidden={!showDetail}>
                  <Popup
                    trigger={
                      <span
                        onClick={() => {
                          updateChannelBalance(channel.id, channel.name, idx);
                        }}
                        style={{ cursor: 'pointer' }}
                      >
                        {renderBalance(channel.type, channel.balance)}
                      </span>
                    }
                    content='点击更新'
                    basic
                  />
                </Table.Cell>
                <Table.Cell>
                  <Popup
                    trigger={
                      <Input
                        type='number'
                        defaultValue={channel.priority}
                        onBlur={(event) => {
                          manageChannel(channel.id, 'priority', idx, event.target.value);
                        }}
                      >
                        <input style={{ maxWidth: '60px' }} />
                      </Input>
                    }
                    content='渠道选择优先级,越高越优先'
                    basic
                  />
@@ -526,7 +590,12 @@ const ChannelsTable = () => {
                    size={'small'}
                    positive
                    onClick={() => {
                      testChannel(channel.id, channel.name, idx, channel.test_model);
                    }}
                  >
                    测试
@@ -588,14 +657,31 @@ const ChannelsTable = () => {

      <Table.Footer>
        <Table.Row>
          <Table.HeaderCell colSpan={showDetail ? '10' : '8'}>
            <Button size='small' as={Link} to='/channel/add' loading={loading}>
              添加新的渠道
            </Button>
            <Button size='small' loading={loading} onClick={() => { testChannels('all'); }}>
              测试所有渠道
            </Button>
            <Button size='small' loading={loading} onClick={() => { testChannels('disabled'); }}>
              测试禁用渠道
            </Button>
            {/*<Button size='small' onClick={updateAllChannelsBalance}*/}
@@ -610,7 +696,12 @@ const ChannelsTable = () => {
              flowing
              hoverable
            >
              <Button size='small' loading={loading} negative onClick={deleteAllDisabledChannels}>
                确认删除
              </Button>
            </Popup>
@@ -625,8 +716,12 @@ const ChannelsTable = () => {
                (channels.length % ITEMS_PER_PAGE === 0 ? 1 : 0)
              }
            />
            <Button size='small' onClick={refresh} loading={loading}>
              刷新
            </Button>
            <Button size='small' onClick={toggleShowDetail}>
              {showDetail ? '隐藏详情' : '详情'}
            </Button>
          </Table.HeaderCell>
        </Table.Row>
      </Table.Footer>
@@ -1,21 +1,48 @@
|
|||||||
import React, { useEffect, useState } from 'react';
|
import React, { useEffect, useState } from 'react';
|
||||||
import { Button, Form, Header, Label, Pagination, Segment, Select, Table } from 'semantic-ui-react';
|
import {
|
||||||
import { API, isAdmin, showError, timestamp2string } from '../helpers';
|
Button,
|
||||||
|
Form,
|
||||||
|
Header,
|
||||||
|
Label,
|
||||||
|
Pagination,
|
||||||
|
Segment,
|
||||||
|
Select,
|
||||||
|
Table,
|
||||||
|
} from 'semantic-ui-react';
|
||||||
|
import {
|
||||||
|
API,
|
||||||
|
copy,
|
||||||
|
isAdmin,
|
||||||
|
showError,
|
||||||
|
showSuccess,
|
||||||
|
showWarning,
|
||||||
|
timestamp2string,
|
||||||
|
} from '../helpers';
|
||||||
|
|
||||||
import { ITEMS_PER_PAGE } from '../constants';
|
import { ITEMS_PER_PAGE } from '../constants';
|
||||||
import { renderQuota } from '../helpers/render';
|
import { renderColorLabel, renderQuota } from '../helpers/render';
|
||||||
|
import { Link } from 'react-router-dom';
|
||||||
|
|
||||||
function renderTimestamp(timestamp) {
|
function renderTimestamp(timestamp, request_id) {
|
||||||
return (
|
return (
|
||||||
<>
|
<code
|
||||||
|
onClick={async () => {
|
||||||
|
if (await copy(request_id)) {
|
||||||
|
showSuccess(`已复制请求 ID:${request_id}`);
|
||||||
|
} else {
|
||||||
|
showWarning(`请求 ID 复制失败:${request_id}`);
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
style={{ cursor: 'pointer' }}
|
||||||
|
>
|
||||||
{timestamp2string(timestamp)}
|
{timestamp2string(timestamp)}
|
||||||
</>
|
</code>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const MODE_OPTIONS = [
|
const MODE_OPTIONS = [
|
||||||
{ key: 'all', text: '全部用户', value: 'all' },
|
{ key: 'all', text: '全部用户', value: 'all' },
|
||||||
{ key: 'self', text: '当前用户', value: 'self' }
|
{ key: 'self', text: '当前用户', value: 'self' },
|
||||||
];
|
];
|
||||||
|
|
||||||
const LOG_OPTIONS = [
|
const LOG_OPTIONS = [
|
||||||
@@ -23,24 +50,92 @@ const LOG_OPTIONS = [
|
|||||||
{ key: '1', text: '充值', value: 1 },
|
{ key: '1', text: '充值', value: 1 },
|
||||||
{ key: '2', text: '消费', value: 2 },
|
{ key: '2', text: '消费', value: 2 },
|
||||||
{ key: '3', text: '管理', value: 3 },
|
{ key: '3', text: '管理', value: 3 },
|
||||||
{ key: '4', text: '系统', value: 4 }
|
{ key: '4', text: '系统', value: 4 },
|
||||||
|
{ key: '5', text: '测试', value: 5 },
|
||||||
];
|
];
|
||||||
|
|
||||||
function renderType(type) {
|
function renderType(type) {
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case 1:
|
case 1:
|
||||||
return <Label basic color='green'> 充值 </Label>;
|
return (
|
||||||
|
<Label basic color='green'>
|
||||||
|
充值
|
||||||
|
</Label>
|
||||||
|
);
|
||||||
case 2:
|
case 2:
|
||||||
return <Label basic color='olive'> 消费 </Label>;
|
return (
|
||||||
|
<Label basic color='olive'>
|
||||||
|
消费
|
||||||
|
</Label>
|
||||||
|
);
|
||||||
case 3:
|
case 3:
|
||||||
return <Label basic color='orange'> 管理 </Label>;
|
return (
|
||||||
|
<Label basic color='orange'>
|
||||||
|
管理
|
||||||
|
</Label>
|
||||||
|
);
|
||||||
case 4:
|
case 4:
|
||||||
return <Label basic color='purple'> 系统 </Label>;
|
return (
|
||||||
|
<Label basic color='purple'>
|
||||||
|
系统
|
||||||
|
</Label>
|
||||||
|
);
|
||||||
|
case 5:
|
||||||
|
return (
|
||||||
|
<Label basic color='violet'>
|
||||||
|
测试
|
||||||
|
</Label>
|
||||||
|
);
|
||||||
default:
|
default:
|
||||||
return <Label basic color='black'> 未知 </Label>;
|
return (
|
||||||
|
<Label basic color='black'>
|
||||||
|
未知
|
||||||
|
</Label>
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function getColorByElapsedTime(elapsedTime) {
|
||||||
|
if (elapsedTime === undefined || 0) return 'black';
|
||||||
|
if (elapsedTime < 1000) return 'green';
|
||||||
|
if (elapsedTime < 3000) return 'olive';
|
||||||
|
if (elapsedTime < 5000) return 'yellow';
|
||||||
|
if (elapsedTime < 10000) return 'orange';
|
||||||
|
return 'red';
|
||||||
|
}
|
||||||
|
|
||||||
|
function renderDetail(log) {
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
{log.content}
|
||||||
|
<br />
|
||||||
|
{log.elapsed_time && (
|
||||||
|
<Label
|
||||||
|
basic
|
||||||
|
size={'mini'}
|
||||||
|
color={getColorByElapsedTime(log.elapsed_time)}
|
||||||
|
>
|
||||||
|
{log.elapsed_time} ms
|
||||||
|
</Label>
|
||||||
|
)}
|
||||||
|
{log.is_stream && (
|
||||||
|
<>
|
||||||
|
<Label size={'mini'} color='pink'>
|
||||||
|
Stream
|
||||||
|
</Label>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
{log.system_prompt_reset && (
|
||||||
|
<>
|
||||||
|
<Label basic size={'mini'} color='red'>
|
||||||
|
System Prompt Reset
|
||||||
|
</Label>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
const LogsTable = () => {
|
const LogsTable = () => {
|
||||||
const [logs, setLogs] = useState([]);
|
const [logs, setLogs] = useState([]);
|
||||||
const [showStat, setShowStat] = useState(false);
|
const [showStat, setShowStat] = useState(false);
|
||||||
@@ -57,13 +152,20 @@ const LogsTable = () => {
|
|||||||
model_name: '',
|
model_name: '',
|
||||||
start_timestamp: timestamp2string(0),
|
start_timestamp: timestamp2string(0),
|
||||||
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
|
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
|
||||||
channel: ''
|
channel: '',
|
||||||
});
|
});
|
||||||
const { username, token_name, model_name, start_timestamp, end_timestamp, channel } = inputs;
|
const {
|
||||||
|
username,
|
||||||
|
token_name,
|
||||||
|
model_name,
|
||||||
|
start_timestamp,
|
||||||
|
end_timestamp,
|
||||||
|
channel,
|
||||||
|
} = inputs;
|
||||||
|
|
||||||
const [stat, setStat] = useState({
|
const [stat, setStat] = useState({
|
||||||
quota: 0,
|
quota: 0,
|
||||||
token: 0
|
token: 0,
|
||||||
});
|
});
|
||||||
|
|
||||||
const handleInputChange = (e, { name, value }) => {
|
const handleInputChange = (e, { name, value }) => {
|
||||||
@@ -73,7 +175,9 @@ const LogsTable = () => {
|
|||||||
const getLogSelfStat = async () => {
|
const getLogSelfStat = async () => {
|
||||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||||
let res = await API.get(`/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`);
|
let res = await API.get(
|
||||||
|
`/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`
|
||||||
|
);
|
||||||
const { success, message, data } = res.data;
|
const { success, message, data } = res.data;
|
||||||
if (success) {
|
if (success) {
|
||||||
setStat(data);
|
setStat(data);
|
||||||
@@ -85,7 +189,9 @@ const LogsTable = () => {
|
|||||||
const getLogStat = async () => {
|
const getLogStat = async () => {
|
||||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||||
let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`);
|
let res = await API.get(
|
||||||
|
`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`
|
||||||
|
);
|
||||||
const { success, message, data } = res.data;
|
const { success, message, data } = res.data;
|
||||||
if (success) {
|
if (success) {
|
||||||
setStat(data);
|
setStat(data);
|
||||||
@@ -105,6 +211,10 @@ const LogsTable = () => {
|
|||||||
setShowStat(!showStat);
|
setShowStat(!showStat);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const showUserTokenQuota = () => {
|
||||||
|
return logType !== 5;
|
||||||
|
};
|
||||||
|
|
||||||
const loadLogs = async (startIdx) => {
|
const loadLogs = async (startIdx) => {
|
||||||
let url = '';
|
let url = '';
|
||||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||||
@@ -201,37 +311,82 @@ const LogsTable = () => {
       <Header as='h3'>
         使用明细(总消耗额度:
         {showStat && renderQuota(stat.quota)}
-        {!showStat && <span onClick={handleEyeClick} style={{ cursor: 'pointer', color: 'gray' }}>点击查看</span>}
+        {!showStat && (
+          <span
+            onClick={handleEyeClick}
+            style={{ cursor: 'pointer', color: 'gray' }}
+          >
+            点击查看
+          </span>
+        )}
         )
       </Header>
       <Form>
         <Form.Group>
-          <Form.Input fluid label={'令牌名称'} width={3} value={token_name}
-            placeholder={'可选值'} name='token_name' onChange={handleInputChange} />
-          <Form.Input fluid label='模型名称' width={3} value={model_name} placeholder='可选值'
-            name='model_name'
-            onChange={handleInputChange} />
-          <Form.Input fluid label='起始时间' width={4} value={start_timestamp} type='datetime-local'
-            name='start_timestamp'
-            onChange={handleInputChange} />
-          <Form.Input fluid label='结束时间' width={4} value={end_timestamp} type='datetime-local'
-            name='end_timestamp'
-            onChange={handleInputChange} />
-          <Form.Button fluid label='操作' width={2} onClick={refresh}>查询</Form.Button>
+          <Form.Input
+            fluid
+            label={'令牌名称'}
+            width={3}
+            value={token_name}
+            placeholder={'可选值'}
+            name='token_name'
+            onChange={handleInputChange}
+          />
+          <Form.Input
+            fluid
+            label='模型名称'
+            width={3}
+            value={model_name}
+            placeholder='可选值'
+            name='model_name'
+            onChange={handleInputChange}
+          />
+          <Form.Input
+            fluid
+            label='起始时间'
+            width={4}
+            value={start_timestamp}
+            type='datetime-local'
+            name='start_timestamp'
+            onChange={handleInputChange}
+          />
+          <Form.Input
+            fluid
+            label='结束时间'
+            width={4}
+            value={end_timestamp}
+            type='datetime-local'
+            name='end_timestamp'
+            onChange={handleInputChange}
+          />
+          <Form.Button fluid label='操作' width={2} onClick={refresh}>
+            查询
+          </Form.Button>
         </Form.Group>
-        {
-          isAdminUser && <>
+        {isAdminUser && (
+          <>
             <Form.Group>
-              <Form.Input fluid label={'渠道 ID'} width={3} value={channel}
-                placeholder='可选值' name='channel'
-                onChange={handleInputChange} />
-              <Form.Input fluid label={'用户名称'} width={3} value={username}
-                placeholder={'可选值'} name='username'
-                onChange={handleInputChange} />
+              <Form.Input
+                fluid
+                label={'渠道 ID'}
+                width={3}
+                value={channel}
+                placeholder='可选值'
+                name='channel'
+                onChange={handleInputChange}
+              />
+              <Form.Input
+                fluid
+                label={'用户名称'}
+                width={3}
+                value={username}
+                placeholder={'可选值'}
+                name='username'
+                onChange={handleInputChange}
+              />
             </Form.Group>
           </>
-        }
+        )}
       </Form>
       <Table basic compact size='small'>
         <Table.Header>
@@ -245,8 +400,8 @@ const LogsTable = () => {
           >
             时间
          </Table.HeaderCell>
-          {
-            isAdminUser && <Table.HeaderCell
+          {isAdminUser && (
+            <Table.HeaderCell
              style={{ cursor: 'pointer' }}
              onClick={() => {
                sortLog('channel');
@@ -255,27 +410,7 @@ const LogsTable = () => {
            >
              渠道
            </Table.HeaderCell>
-          }
-          {
-            isAdminUser && <Table.HeaderCell
-              style={{ cursor: 'pointer' }}
-              onClick={() => {
-                sortLog('username');
-              }}
-              width={1}
-            >
-              用户
-            </Table.HeaderCell>
-          }
-          <Table.HeaderCell
-            style={{ cursor: 'pointer' }}
-            onClick={() => {
-              sortLog('token_name');
-            }}
-            width={1}
-          >
-            令牌
-          </Table.HeaderCell>
+          )}
           <Table.HeaderCell
             style={{ cursor: 'pointer' }}
             onClick={() => {
@@ -294,33 +429,57 @@ const LogsTable = () => {
           >
             模型
           </Table.HeaderCell>
-          <Table.HeaderCell
-            style={{ cursor: 'pointer' }}
-            onClick={() => {
-              sortLog('prompt_tokens');
-            }}
-            width={1}
-          >
-            提示
-          </Table.HeaderCell>
-          <Table.HeaderCell
-            style={{ cursor: 'pointer' }}
-            onClick={() => {
-              sortLog('completion_tokens');
-            }}
-            width={1}
-          >
-            补全
-          </Table.HeaderCell>
-          <Table.HeaderCell
-            style={{ cursor: 'pointer' }}
-            onClick={() => {
-              sortLog('quota');
-            }}
-            width={1}
-          >
-            额度
-          </Table.HeaderCell>
+          {showUserTokenQuota() && (
+            <>
+              {isAdminUser && (
+                <Table.HeaderCell
+                  style={{ cursor: 'pointer' }}
+                  onClick={() => {
+                    sortLog('username');
+                  }}
+                  width={1}
+                >
+                  用户
+                </Table.HeaderCell>
+              )}
+              <Table.HeaderCell
+                style={{ cursor: 'pointer' }}
+                onClick={() => {
+                  sortLog('token_name');
+                }}
+                width={1}
+              >
+                令牌
+              </Table.HeaderCell>
+              <Table.HeaderCell
+                style={{ cursor: 'pointer' }}
+                onClick={() => {
+                  sortLog('prompt_tokens');
+                }}
+                width={1}
+              >
+                提示
+              </Table.HeaderCell>
+              <Table.HeaderCell
+                style={{ cursor: 'pointer' }}
+                onClick={() => {
+                  sortLog('completion_tokens');
+                }}
+                width={1}
+              >
+                补全
+              </Table.HeaderCell>
+              <Table.HeaderCell
+                style={{ cursor: 'pointer' }}
+                onClick={() => {
+                  sortLog('quota');
+                }}
+                width={1}
+              >
+                额度
+              </Table.HeaderCell>
+            </>
+          )}
           <Table.HeaderCell
             style={{ cursor: 'pointer' }}
             onClick={() => {
@@ -343,24 +502,64 @@ const LogsTable = () => {
            if (log.deleted) return <></>;
            return (
              <Table.Row key={log.id}>
-                <Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell>
-                {
-                  isAdminUser && (
-                    <Table.Cell>{log.channel ? <Label basic>{log.channel}</Label> : ''}</Table.Cell>
-                  )
-                }
-                {
-                  isAdminUser && (
-                    <Table.Cell>{log.username ? <Label>{log.username}</Label> : ''}</Table.Cell>
-                  )
-                }
-                <Table.Cell>{log.token_name ? <Label basic>{log.token_name}</Label> : ''}</Table.Cell>
+                <Table.Cell>
+                  {renderTimestamp(log.created_at, log.request_id)}
+                </Table.Cell>
+                {isAdminUser && (
+                  <Table.Cell>
+                    {log.channel ? (
+                      <Label
+                        basic
+                        as={Link}
+                        to={`/channel/edit/${log.channel}`}
+                      >
+                        {log.channel}
+                      </Label>
+                    ) : (
+                      ''
+                    )}
+                  </Table.Cell>
+                )}
                <Table.Cell>{renderType(log.type)}</Table.Cell>
-                <Table.Cell>{log.model_name ? <Label basic>{log.model_name}</Label> : ''}</Table.Cell>
-                <Table.Cell>{log.prompt_tokens ? log.prompt_tokens : ''}</Table.Cell>
-                <Table.Cell>{log.completion_tokens ? log.completion_tokens : ''}</Table.Cell>
-                <Table.Cell>{log.quota ? renderQuota(log.quota, 6) : ''}</Table.Cell>
-                <Table.Cell>{log.content}</Table.Cell>
+                <Table.Cell>
+                  {log.model_name ? renderColorLabel(log.model_name) : ''}
+                </Table.Cell>
+                {showUserTokenQuota() && (
+                  <>
+                    {isAdminUser && (
+                      <Table.Cell>
+                        {log.username ? (
+                          <Label
+                            basic
+                            as={Link}
+                            to={`/user/edit/${log.user_id}`}
+                          >
+                            {log.username}
+                          </Label>
+                        ) : (
+                          ''
+                        )}
+                      </Table.Cell>
+                    )}
+                    <Table.Cell>
+                      {log.token_name
+                        ? renderColorLabel(log.token_name)
+                        : ''}
+                    </Table.Cell>
+
+                    <Table.Cell>
+                      {log.prompt_tokens ? log.prompt_tokens : ''}
+                    </Table.Cell>
+                    <Table.Cell>
+                      {log.completion_tokens ? log.completion_tokens : ''}
+                    </Table.Cell>
+                    <Table.Cell>
+                      {log.quota ? renderQuota(log.quota, 6) : ''}
+                    </Table.Cell>
+                  </>
+                )}
+
+                <Table.Cell>{renderDetail(log)}</Table.Cell>
              </Table.Row>
            );
          })}
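The channel and username cells now render their Labels as router links (as={Link} to={...}) pointing at the corresponding edit pages, the timestamp cell additionally passes log.request_id, and the plain log.content cell is replaced by renderDetail(log). The Link identifier is presumably the react-router-dom export already imported at the top of this file; the import itself lies outside the hunk, so this sketch of the pattern states it explicitly:

import { Link } from 'react-router-dom'; // assumed import; not shown in this diff
import { Label } from 'semantic-ui-react';

// A Semantic UI Label that navigates like a router link when clicked.
const ChannelLabel = ({ channel }) => (
  <Label basic as={Link} to={`/channel/edit/${channel}`}>
    {channel}
  </Label>
);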
@@ -379,7 +578,9 @@ const LogsTable = () => {
              setLogType(value);
            }}
          />
-          <Button size='small' onClick={refresh} loading={loading}>刷新</Button>
+          <Button size='small' onClick={refresh} loading={loading}>
+            刷新
+          </Button>
          <Pagination
            floated='right'
            activePage={activePage}
@@ -31,6 +31,7 @@ export const CHANNEL_OPTIONS = [
   { key: 43, text: 'Proxy', value: 43, color: 'blue' },
   { key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
   { key: 45, text: 'xAI', value: 45, color: 'blue' },
+  { key: 46, text: 'Replicate', value: 46, color: 'blue' },
   { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
   { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
   { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
@@ -13,16 +13,18 @@ export function renderGroup(group) {
   }
   let groups = group.split(',');
   groups.sort();
-  return <>
-    {groups.map((group) => {
-      if (group === 'vip' || group === 'pro') {
-        return <Label color='yellow'>{group}</Label>;
-      } else if (group === 'svip' || group === 'premium') {
-        return <Label color='red'>{group}</Label>;
-      }
-      return <Label>{group}</Label>;
-    })}
-  </>;
+  return (
+    <>
+      {groups.map((group) => {
+        if (group === 'vip' || group === 'pro') {
+          return <Label color='yellow'>{group}</Label>;
+        } else if (group === 'svip' || group === 'premium') {
+          return <Label color='red'>{group}</Label>;
+        }
+        return <Label>{group}</Label>;
+      })}
+    </>
+  );
 }
 
 export function renderNumber(num) {
@@ -55,4 +57,33 @@ export function renderQuotaWithPrompt(quota, digits) {
     return `(等价金额:${renderQuota(quota, digits)})`;
   }
   return '';
 }
+
+const colors = [
+  'red',
+  'orange',
+  'yellow',
+  'olive',
+  'green',
+  'teal',
+  'blue',
+  'violet',
+  'purple',
+  'pink',
+  'brown',
+  'grey',
+  'black',
+];
+
+export function renderColorLabel(text) {
+  let hash = 0;
+  for (let i = 0; i < text.length; i++) {
+    hash = text.charCodeAt(i) + ((hash << 5) - hash);
+  }
+  let index = Math.abs(hash % colors.length);
+  return (
+    <Label basic color={colors[index]}>
+      {text}
+    </Label>
+  );
+}
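renderColorLabel picks a Label color deterministically from the text itself: a 31-multiplier string hash ((hash << 5) - hash is hash * 31) indexes into the fixed colors palette, so a given model or token name keeps the same color across renders and sessions. The same index computation as a standalone sketch (the sample input is illustrative):

const colors = [
  'red', 'orange', 'yellow', 'olive', 'green', 'teal', 'blue',
  'violet', 'purple', 'pink', 'brown', 'grey', 'black',
];

// Same hash as renderColorLabel, returning just the color name.
function colorFor(text) {
  let hash = 0;
  for (let i = 0; i < text.length; i++) {
    hash = text.charCodeAt(i) + ((hash << 5) - hash); // hash * 31 + charCode
  }
  return colors[Math.abs(hash % colors.length)];
}

console.log(colorFor('gpt-4o')); // the same name always maps to the same color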
@@ -43,6 +43,7 @@ const EditChannel = () => {
     base_url: '',
     other: '',
     model_mapping: '',
+    system_prompt: '',
     models: [],
     groups: ['default']
   };
@@ -425,7 +426,7 @@ const EditChannel = () => {
           )
         }
         {
-          inputs.type !== 43 && (
+          inputs.type !== 43 && (<>
             <Form.Field>
               <Form.TextArea
                 label='模型重定向'
@@ -437,6 +438,18 @@ const EditChannel = () => {
                autoComplete='new-password'
              />
            </Form.Field>
+            <Form.Field>
+              <Form.TextArea
+                label='系统提示词'
+                placeholder={`此项可选,用于强制设置给定的系统提示词,请配合自定义模型 & 模型重定向使用,首先创建一个唯一的自定义模型名称并在上面填入,之后将该自定义模型重定向映射到该渠道一个原生支持的模型`}
+                name='system_prompt'
+                onChange={handleInputChange}
+                value={inputs.system_prompt}
+                style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
+                autoComplete='new-password'
+              />
+            </Form.Field>
+          </>
          )
        }
        {
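Together with the new system_prompt default in the @@ -43,6 +43,7 @@ hunk above, this TextArea lets a channel force a fixed system prompt. The placeholder spells out the intended workflow: add a unique custom model name to the channel, redirect it to a natively supported model via 模型重定向 (model mapping), and set the system prompt here. A purely hypothetical channel configuration illustrating that combination (field names come from this diff, but the values and the JSON mapping format are assumptions):

// Hypothetical values for illustration only.
const channelInputs = {
  models: ['my-translator'],            // unique custom model name exposed by this channel
  model_mapping: JSON.stringify({
    'my-translator': 'gpt-4o-mini',     // redirect the custom name to a natively supported model
  }),
  system_prompt: 'You are a translation assistant. Reply with the translation only.',
};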