Compare commits

..

38 Commits

Author SHA1 Message Date
JustSong
fa2a772731 feat: able to query test log 2025-01-31 21:23:12 +08:00
JustSong
4f68f3e1b3 chore: update log content 2025-01-31 20:16:56 +08:00
JustSong
0bab887b2d chore: update log content 2025-01-31 20:15:04 +08:00
JustSong
0230d36643 feat: update log table style 2025-01-31 20:06:43 +08:00
JustSong
bad57d049a feat: update log table style 2025-01-31 20:02:51 +08:00
JustSong
dc470ce82e feat: show stream & elapsed time in log detail 2025-01-31 19:34:22 +08:00
JustSong
ea0721d525 feat: update log content format 2025-01-31 18:15:43 +08:00
JustSong
d0402f9086 feat: record request_id 2025-01-31 17:54:04 +08:00
JustSong
1fead8e7f7 chore: add debug log for distributor 2025-01-31 17:26:33 +08:00
Fennng
09911a301d feat: support hunyuan-embedding (#2035)
* feat: support hunyuan-embedding

* chore: improve implementation

---------

Co-authored-by: LUO Feng <luofeng@flowpp.com>
Co-authored-by: JustSong <quanpengsong@gmail.com>
2025-01-31 16:48:02 +08:00
chenzikun
f95e6b78b8 fix: fix berry copy token (#2041)
* [bugfix]修复copy问题

* [update]两阶段编译代码

---------

Co-authored-by: zicorn <a24395@autel.com>
2025-01-31 16:12:59 +08:00
JustSong
605bb06667 feat: update logger 2025-01-31 16:00:53 +08:00
Laisky.Cai
d88e07fd9a feat: add deepseek-reasoner & gemini-2.0-flash-thinking-exp-01-21 (#2045)
* feat: add MILLI_USD constant and update pricing for deepseek services

* feat: add support for new Gemini model version 'gemini-2.0-flash-thinking-exp-01-21'
2025-01-31 15:15:59 +08:00
JustSong
3915ce9814 chore: update ci yaml 2024-12-27 22:01:37 +08:00
JustSong
999defc88b chore: update readme 2024-12-27 21:59:38 +08:00
JustSong
b51c47bc77 docs: update README.md 2024-12-27 21:45:51 +08:00
JustSong
4f25cde132 fix: add branch check 2024-12-27 20:41:20 +08:00
JustSong
d89e9d7e44 fix: add branch limitation and drop pull_request trigger for ci.yml 2024-12-27 20:34:04 +08:00
Qiying Wang
a858292b54 feat: support gpt-4o-2024-11-20 (#1941) 2024-12-22 19:49:50 +08:00
Yuwei Ba
ff589b5e4a chore: update model mapping implementation for audio (#1932)
* fixed model mapping

* chore: update implementation

---------

Co-authored-by: JustSong <quanpengsong@gmail.com>
2024-12-22 19:33:11 +08:00
Ke Wang
95e8c16338 feat: add balance query support for DeepSeek (#1946)
* Support Balance Query for DeepSeek

* Fix
2024-12-22 19:26:33 +08:00
lihangfu
381172cb36 feat: support Redis Sentinel and Redis Cluster (#1952)
* feature: support Redis Sentinel and Redis Cluster

* chore: update implementation

---------

Co-authored-by: JustSong <quanpengsong@gmail.com>
2024-12-22 19:21:24 +08:00
ZhangTianrong
59eae186a3 fix: remove the duplicate claude-3-5-haiku-20241022 in Anthropic's base model list (#1957)
* Update constants.go

Remove the duplicate `claude-3-5-haiku-20241022` causing issue 1928

* fix: fix syntax error

---------

Co-authored-by: JustSong <quanpengsong@gmail.com>
2024-12-22 18:58:29 +08:00
bestlaw66
ce52f355bb docs: add tutorial section for BT Panel installation (#1985)
* Update README.md

在国内有大部分用户都在使用宝塔面板管理服务器,因此增加使用宝塔面板部署的教程,可视化的部署方式可以帮助用户更加便捷的部署one-api

* docs: update readme

---------

Co-authored-by: JustSong <quanpengsong@gmail.com>
2024-12-22 18:55:04 +08:00
Ke Wang
cb9d0a74c9 fix: fix balance query for siliconflow (#1960) 2024-12-22 18:48:47 +08:00
Wei Tingjiang
49ffb1c60d feat: enhance response handling to support gemini-2.0-thinking (#1995) 2024-12-22 18:25:44 +08:00
longkeyy
2f16649896 feat: update qwen model and price (#1966) 2024-12-22 18:22:57 +08:00
dependabot[bot]
af3aa57bd6 chore(deps): bump golang.org/x/crypto from 0.24.0 to 0.31.0 (#1976)
Bumps [golang.org/x/crypto](https://github.com/golang/crypto) from 0.24.0 to 0.31.0.
- [Commits](https://github.com/golang/crypto/compare/v0.24.0...v0.31.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-12-22 18:21:00 +08:00
Laisky.Cai
e9f117ff72 feat: add gemini-2.0-flash-exp and fix race condition in processChannelRelayError (#1983)
Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com>
2024-12-21 20:32:30 +08:00
Laisky.Cai
6bb5247bd6 feat: add support for new OpenAI models and update billing ratios (#1990) 2024-12-21 20:28:51 +08:00
Laisky.Cai
305ce14fe3 feat: support replicate chat models (#1989)
* feat: add Replicate adaptor and integrate into channel and API types

* feat: support llm chat on replicate
2024-12-21 14:41:19 +08:00
JustSong
36c8f4f15c docs: update readme 2024-12-21 00:33:20 +08:00
JustSong
45b51ea0ee feat: update feishu oauth login 2024-12-20 23:27:00 +08:00
Calcium-Ion
7c8628bd95 feat: support gzip decode (#1962) 2024-12-04 23:34:24 +08:00
JustSong
6ab87f8a08 feat: add warning in log when system prompt is reset 2024-11-10 17:18:46 +08:00
JustSong
833fa7ad6f feat: support set system_prompt for theme air & berry 2024-11-10 15:09:02 +08:00
JustSong
6eb0770a89 feat: support set system prompt for channel (close #1920) 2024-11-10 14:53:34 +08:00
JustSong
92cd46d64f feat: able to use ENFORCE_INCLUDE_USAGE to enforce include usage in response 2024-11-10 00:36:08 +08:00
75 changed files with 2705 additions and 897 deletions

View File

@@ -1,19 +1,17 @@
name: CI name: CI
# This setup assumes that you run the unit tests with code coverage in the same # This setup assumes that you run the unit tests with code coverage in the same
# workflow that will also print the coverage report as comment to the pull request. # workflow that will also print the coverage report as comment to the pull request.
# Therefore, you need to trigger this workflow when a pull request is (re)opened or # Therefore, you need to trigger this workflow when a pull request is (re)opened or
# when new code is pushed to the branch of the pull request. In addition, you also # when new code is pushed to the branch of the pull request. In addition, you also
# need to trigger this workflow when new code is pushed to the main branch because # need to trigger this workflow when new code is pushed to the main branch because
# we need to upload the code coverage results as artifact for the main branch as # we need to upload the code coverage results as artifact for the main branch as
# well since it will be the baseline code coverage. # well since it will be the baseline code coverage.
# #
# We do not want to trigger the workflow for pushes to *any* branch because this # We do not want to trigger the workflow for pushes to *any* branch because this
# would trigger our jobs twice on pull requests (once from "push" event and once # would trigger our jobs twice on pull requests (once from "push" event and once
# from "pull_request->synchronize") # from "pull_request->synchronize")
on: on:
pull_request:
types: [opened, reopened, synchronize]
push: push:
branches: branches:
- 'main' - 'main'
@@ -31,7 +29,7 @@ jobs:
with: with:
go-version: ^1.22 go-version: ^1.22
# When you execute your unit tests, make sure to use the "-coverprofile" flag to write a # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a
# coverage profile to a file. You will need the name of the file (e.g. "coverage.txt") # coverage profile to a file. You will need the name of the file (e.g. "coverage.txt")
# in the next step as well as the next job. # in the next step as well as the next job.
- name: Test - name: Test

3
.gitignore vendored
View File

@@ -9,4 +9,5 @@ logs
data data
/web/node_modules /web/node_modules
cmd.md cmd.md
.env .env
/one-api

View File

@@ -115,8 +115,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
21. 支持 Cloudflare Turnstile 用户校验。 21. 支持 Cloudflare Turnstile 用户校验。
22. 支持用户管理,支持**多种用户登录注册方式** 22. 支持用户管理,支持**多种用户登录注册方式**
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
+ 支持使用飞书进行授权登录 + 支持[飞书授权登录](https://open.feishu.cn/document/uAjLw4CM/ukTMukTMukTM/reference/authen-v1/authorize/get)[这里有 One API 的实现细节阐述供参考](https://iamazing.cn/page/feishu-oauth-login)
+ [GitHub 开放授权](https://github.com/settings/applications/new)。 + 支持 [GitHub 授权登录](https://github.com/settings/applications/new)。
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。
24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。 24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。
@@ -175,6 +175,10 @@ sudo service nginx restart
初始账号用户名为 `root`,密码为 `123456` 初始账号用户名为 `root`,密码为 `123456`
### 通过宝塔面板进行一键部署
1. 安装宝塔面板9.2.0及以上版本,前往 [宝塔面板](https://www.bt.cn/new/download.html?r=dk_oneapi) 官网,选择正式版的脚本下载安装;
2. 安装后登录宝塔面板,在左侧菜单栏中点击 `Docker`,首次进入会提示安装 `Docker` 服务,点击立即安装,按提示完成安装;
3. 安装完成后在应用商店中搜索 `One-API`,点击安装,配置域名等基本信息即可完成安装;
### 基于 Docker Compose 进行部署 ### 基于 Docker Compose 进行部署
@@ -218,7 +222,7 @@ docker-compose ps
3. 所有从服务器必须设置 `NODE_TYPE` 为 `slave`,不设置则默认为主服务器。 3. 所有从服务器必须设置 `NODE_TYPE` 为 `slave`,不设置则默认为主服务器。
4. 设置 `SYNC_FREQUENCY` 后服务器将定期从数据库同步配置,在使用远程数据库的情况下,推荐设置该项并启用 Redis无论主从。 4. 设置 `SYNC_FREQUENCY` 后服务器将定期从数据库同步配置,在使用远程数据库的情况下,推荐设置该项并启用 Redis无论主从。
5. 从服务器可以选择设置 `FRONTEND_BASE_URL`,以重定向页面请求到主服务器。 5. 从服务器可以选择设置 `FRONTEND_BASE_URL`,以重定向页面请求到主服务器。
6. 从服务器上**分别**装好 Redis设置好 `REDIS_CONN_STRING`,这样可以做到在缓存未过期的情况下数据库零访问,可以减少延迟。 6. 从服务器上**分别**装好 Redis设置好 `REDIS_CONN_STRING`,这样可以做到在缓存未过期的情况下数据库零访问,可以减少延迟Redis 集群或者哨兵模式的支持请参考环境变量说明)
7. 如果主服务器访问数据库延迟也比较高,则也需要启用 Redis并设置 `SYNC_FREQUENCY`,以定期从数据库同步配置。 7. 如果主服务器访问数据库延迟也比较高,则也需要启用 Redis并设置 `SYNC_FREQUENCY`,以定期从数据库同步配置。
环境变量的具体使用方法详见[此处](#环境变量)。 环境变量的具体使用方法详见[此处](#环境变量)。
@@ -347,6 +351,11 @@ graph LR
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
+ 如果数据库访问延迟很低,没有必要启用 Redis启用后反而会出现数据滞后的问题。 + 如果数据库访问延迟很低,没有必要启用 Redis启用后反而会出现数据滞后的问题。
+ 如果需要使用哨兵或者集群模式:
+ 则需要把该环境变量设置为节点列表,例如:`localhost:49153,localhost:49154,localhost:49155`。
+ 除此之外还需要设置以下环境变量:
+ `REDIS_PASSWORD`Redis 集群或者哨兵模式下的密码设置。
+ `REDIS_MASTER_NAME`Redis 哨兵模式下主节点的名称。
2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。 2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。
+ 例子:`SESSION_SECRET=random_string` + 例子:`SESSION_SECRET=random_string`
3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite请使用 MySQL 或 PostgreSQL。 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite请使用 MySQL 或 PostgreSQL。
@@ -400,6 +409,8 @@ graph LR
26. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 26. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。
27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。 27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。
28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量的 root 用户创建系统管理令牌。 28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量的 root 用户创建系统管理令牌。
29. `ENFORCE_INCLUDE_USAGE`:是否强制在 stream 模型下返回 usage默认不开启可选值为 `true` 和 `false`。
30. `TEST_PROMPT`:测试模型时的用户 prompt默认为 `Print your model name exactly and do not output without any other text.`。
### 命令行参数 ### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。

View File

@@ -1,13 +1,14 @@
package config package config
import ( import (
"github.com/songquanpeng/one-api/common/env"
"os" "os"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/songquanpeng/one-api/common/env"
"github.com/google/uuid" "github.com/google/uuid"
) )
@@ -160,3 +161,6 @@ var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false)
var RelayProxy = env.String("RELAY_PROXY", "") var RelayProxy = env.String("RELAY_PROXY", "")
var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "") var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "")
var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30) var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)
var EnforceIncludeUsage = env.Bool("ENFORCE_INCLUDE_USAGE", false)
var TestPrompt = env.String("TEST_PROMPT", "Print your model name exactly and do not output without any other text.")

View File

@@ -20,4 +20,5 @@ const (
BaseURL = "base_url" BaseURL = "base_url"
AvailableModels = "available_models" AvailableModels = "available_models"
KeyRequestBody = "key_request_body" KeyRequestBody = "key_request_body"
SystemPrompt = "system_prompt"
) )

View File

@@ -1,9 +1,8 @@
package helper package helper
import ( import (
"context"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/random"
"html/template" "html/template"
"log" "log"
"net" "net"
@@ -11,6 +10,10 @@ import (
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/random"
) )
func OpenBrowser(url string) { func OpenBrowser(url string) {
@@ -106,6 +109,18 @@ func GenRequestID() string {
return GetTimeString() + random.GetRandomNumberString(8) return GetTimeString() + random.GetRandomNumberString(8)
} }
func SetRequestID(ctx context.Context, id string) context.Context {
return context.WithValue(ctx, RequestIdKey, id)
}
func GetRequestID(ctx context.Context) string {
rawRequestId := ctx.Value(RequestIdKey)
if rawRequestId == nil {
return ""
}
return rawRequestId.(string)
}
func GetResponseID(c *gin.Context) string { func GetResponseID(c *gin.Context) string {
logID := c.GetString(RequestIdKey) logID := c.GetString(RequestIdKey)
return fmt.Sprintf("chatcmpl-%s", logID) return fmt.Sprintf("chatcmpl-%s", logID)

View File

@@ -13,3 +13,8 @@ func GetTimeString() string {
now := time.Now() now := time.Now()
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
} }
// CalcElapsedTime return the elapsed time in milliseconds (ms)
func CalcElapsedTime(start time.Time) int64 {
return time.Now().Sub(start).Milliseconds()
}

View File

@@ -7,19 +7,25 @@ import (
"log" "log"
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"strings"
"sync" "sync"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
) )
type loggerLevel string
const ( const (
loggerDEBUG = "DEBUG" loggerDEBUG loggerLevel = "DEBUG"
loggerINFO = "INFO" loggerINFO loggerLevel = "INFO"
loggerWarn = "WARN" loggerWarn loggerLevel = "WARN"
loggerError = "ERR" loggerError loggerLevel = "ERROR"
loggerFatal loggerLevel = "FATAL"
) )
var setupLogOnce sync.Once var setupLogOnce sync.Once
@@ -44,27 +50,26 @@ func SetupLogger() {
} }
func SysLog(s string) { func SysLog(s string) {
t := time.Now() logHelper(nil, loggerINFO, s)
_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
} }
func SysLogf(format string, a ...any) { func SysLogf(format string, a ...any) {
SysLog(fmt.Sprintf(format, a...)) logHelper(nil, loggerINFO, fmt.Sprintf(format, a...))
} }
func SysError(s string) { func SysError(s string) {
t := time.Now() logHelper(nil, loggerError, s)
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
} }
func SysErrorf(format string, a ...any) { func SysErrorf(format string, a ...any) {
SysError(fmt.Sprintf(format, a...)) logHelper(nil, loggerError, fmt.Sprintf(format, a...))
} }
func Debug(ctx context.Context, msg string) { func Debug(ctx context.Context, msg string) {
if config.DebugEnabled { if !config.DebugEnabled {
logHelper(ctx, loggerDEBUG, msg) return
} }
logHelper(ctx, loggerDEBUG, msg)
} }
func Info(ctx context.Context, msg string) { func Info(ctx context.Context, msg string) {
@@ -80,37 +85,65 @@ func Error(ctx context.Context, msg string) {
} }
func Debugf(ctx context.Context, format string, a ...any) { func Debugf(ctx context.Context, format string, a ...any) {
Debug(ctx, fmt.Sprintf(format, a...)) logHelper(ctx, loggerDEBUG, fmt.Sprintf(format, a...))
} }
func Infof(ctx context.Context, format string, a ...any) { func Infof(ctx context.Context, format string, a ...any) {
Info(ctx, fmt.Sprintf(format, a...)) logHelper(ctx, loggerINFO, fmt.Sprintf(format, a...))
} }
func Warnf(ctx context.Context, format string, a ...any) { func Warnf(ctx context.Context, format string, a ...any) {
Warn(ctx, fmt.Sprintf(format, a...)) logHelper(ctx, loggerWarn, fmt.Sprintf(format, a...))
} }
func Errorf(ctx context.Context, format string, a ...any) { func Errorf(ctx context.Context, format string, a ...any) {
Error(ctx, fmt.Sprintf(format, a...)) logHelper(ctx, loggerError, fmt.Sprintf(format, a...))
} }
func logHelper(ctx context.Context, level string, msg string) { func FatalLog(s string) {
logHelper(nil, loggerFatal, s)
}
func FatalLogf(format string, a ...any) {
logHelper(nil, loggerFatal, fmt.Sprintf(format, a...))
}
func logHelper(ctx context.Context, level loggerLevel, msg string) {
writer := gin.DefaultErrorWriter writer := gin.DefaultErrorWriter
if level == loggerINFO { if level == loggerINFO {
writer = gin.DefaultWriter writer = gin.DefaultWriter
} }
id := ctx.Value(helper.RequestIdKey) var requestId string
if id == nil { if ctx != nil {
id = helper.GenRequestID() rawRequestId := helper.GetRequestID(ctx)
if rawRequestId != "" {
requestId = fmt.Sprintf(" | %s", rawRequestId)
}
} }
lineInfo, funcName := getLineInfo()
now := time.Now() now := time.Now()
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) _, _ = fmt.Fprintf(writer, "[%s] %v%s%s %s%s \n", level, now.Format("2006/01/02 - 15:04:05"), requestId, lineInfo, funcName, msg)
SetupLogger() SetupLogger()
if level == loggerFatal {
os.Exit(1)
}
} }
func FatalLog(v ...any) { func getLineInfo() (string, string) {
t := time.Now() funcName := "[unknown] "
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) pc, file, line, ok := runtime.Caller(3)
os.Exit(1) if ok {
if fn := runtime.FuncForPC(pc); fn != nil {
parts := strings.Split(fn.Name(), ".")
funcName = "[" + parts[len(parts)-1] + "] "
}
} else {
file = "unknown"
line = 0
}
parts := strings.Split(file, "one-api/")
if len(parts) > 1 {
file = parts[1]
}
return fmt.Sprintf(" | %s:%d", file, line), funcName
} }

View File

@@ -2,13 +2,15 @@ package common
import ( import (
"context" "context"
"os"
"strings"
"time"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"os"
"time"
) )
var RDB *redis.Client var RDB redis.Cmdable
var RedisEnabled = true var RedisEnabled = true
// InitRedisClient This function is called after init() // InitRedisClient This function is called after init()
@@ -23,13 +25,23 @@ func InitRedisClient() (err error) {
logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled") logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled")
return nil return nil
} }
logger.SysLog("Redis is enabled") redisConnString := os.Getenv("REDIS_CONN_STRING")
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) if os.Getenv("REDIS_MASTER_NAME") == "" {
if err != nil { logger.SysLog("Redis is enabled")
logger.FatalLog("failed to parse Redis connection string: " + err.Error()) opt, err := redis.ParseURL(redisConnString)
if err != nil {
logger.FatalLog("failed to parse Redis connection string: " + err.Error())
}
RDB = redis.NewClient(opt)
} else {
// cluster mode
logger.SysLog("Redis cluster mode enabled")
RDB = redis.NewUniversalClient(&redis.UniversalOptions{
Addrs: strings.Split(redisConnString, ","),
Password: os.Getenv("REDIS_PASSWORD"),
MasterName: os.Getenv("REDIS_MASTER_NAME"),
})
} }
RDB = redis.NewClient(opt)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()

View File

@@ -3,9 +3,10 @@ package render
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"strings"
) )
func StringData(c *gin.Context, str string) { func StringData(c *gin.Context, str string) {

View File

@@ -5,16 +5,18 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http"
"strconv"
"time"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
"time"
) )
type GitHubOAuthResponse struct { type GitHubOAuthResponse struct {
@@ -81,6 +83,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
} }
func GitHubOAuth(c *gin.Context) { func GitHubOAuth(c *gin.Context) {
ctx := c.Request.Context()
session := sessions.Default(c) session := sessions.Default(c)
state := c.Query("state") state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
@@ -136,7 +139,7 @@ func GitHubOAuth(c *gin.Context) {
user.Role = model.RoleCommonUser user.Role = model.RoleCommonUser
user.Status = model.UserStatusEnabled user.Status = model.UserStatusEnabled
if err := user.Insert(0); err != nil { if err := user.Insert(ctx, 0); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),

View File

@@ -5,15 +5,17 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http"
"strconv"
"time"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
"time"
) )
type LarkOAuthResponse struct { type LarkOAuthResponse struct {
@@ -40,7 +42,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData)) req, err := http.NewRequest("POST", "https://open.feishu.cn/open-apis/authen/v2/oauth/token", bytes.NewBuffer(jsonData))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -79,6 +81,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) {
} }
func LarkOAuth(c *gin.Context) { func LarkOAuth(c *gin.Context) {
ctx := c.Request.Context()
session := sessions.Default(c) session := sessions.Default(c)
state := c.Query("state") state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
@@ -125,7 +128,7 @@ func LarkOAuth(c *gin.Context) {
user.Role = model.RoleCommonUser user.Role = model.RoleCommonUser
user.Status = model.UserStatusEnabled user.Status = model.UserStatusEnabled
if err := user.Insert(0); err != nil { if err := user.Insert(ctx, 0); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),

View File

@@ -5,15 +5,17 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http"
"strconv"
"time"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
"time"
) )
type OidcResponse struct { type OidcResponse struct {
@@ -87,6 +89,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
} }
func OidcAuth(c *gin.Context) { func OidcAuth(c *gin.Context) {
ctx := c.Request.Context()
session := sessions.Default(c) session := sessions.Default(c)
state := c.Query("state") state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
@@ -142,7 +145,7 @@ func OidcAuth(c *gin.Context) {
} else { } else {
user.DisplayName = "OIDC User" user.DisplayName = "OIDC User"
} }
err := user.Insert(0) err := user.Insert(ctx, 0)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,

View File

@@ -4,14 +4,16 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http"
"strconv"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
"time"
) )
type wechatLoginResponse struct { type wechatLoginResponse struct {
@@ -52,6 +54,7 @@ func getWeChatIdByCode(code string) (string, error) {
} }
func WeChatAuth(c *gin.Context) { func WeChatAuth(c *gin.Context) {
ctx := c.Request.Context()
if !config.WeChatAuthEnabled { if !config.WeChatAuthEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员未开启通过微信登录以及注册", "message": "管理员未开启通过微信登录以及注册",
@@ -87,7 +90,7 @@ func WeChatAuth(c *gin.Context) {
user.Role = model.RoleCommonUser user.Role = model.RoleCommonUser
user.Status = model.UserStatusEnabled user.Status = model.UserStatusEnabled
if err := user.Insert(0); err != nil { if err := user.Insert(ctx, 0); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),

View File

@@ -4,16 +4,17 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"net/http"
"strconv"
"time"
"github.com/songquanpeng/one-api/common/client" "github.com/songquanpeng/one-api/common/client"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor" "github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/channeltype"
"io"
"net/http"
"strconv"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -101,6 +102,16 @@ type SiliconFlowUsageResponse struct {
} `json:"data"` } `json:"data"`
} }
type DeepSeekUsageResponse struct {
IsAvailable bool `json:"is_available"`
BalanceInfos []struct {
Currency string `json:"currency"`
TotalBalance string `json:"total_balance"`
GrantedBalance string `json:"granted_balance"`
ToppedUpBalance string `json:"topped_up_balance"`
} `json:"balance_infos"`
}
// GetAuthHeader get auth header // GetAuthHeader get auth header
func GetAuthHeader(token string) http.Header { func GetAuthHeader(token string) http.Header {
h := http.Header{} h := http.Header{}
@@ -237,7 +248,36 @@ func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) {
if response.Code != 20000 { if response.Code != 20000 {
return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message) return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message)
} }
balance, err := strconv.ParseFloat(response.Data.Balance, 64) balance, err := strconv.ParseFloat(response.Data.TotalBalance, 64)
if err != nil {
return 0, err
}
channel.UpdateBalance(balance)
return balance, nil
}
func updateChannelDeepSeekBalance(channel *model.Channel) (float64, error) {
url := "https://api.deepseek.com/user/balance"
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
response := DeepSeekUsageResponse{}
err = json.Unmarshal(body, &response)
if err != nil {
return 0, err
}
index := -1
for i, balanceInfo := range response.BalanceInfos {
if balanceInfo.Currency == "CNY" {
index = i
break
}
}
if index == -1 {
return 0, errors.New("currency CNY not found")
}
balance, err := strconv.ParseFloat(response.BalanceInfos[index].TotalBalance, 64)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -271,6 +311,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
return updateChannelAIGC2DBalance(channel) return updateChannelAIGC2DBalance(channel)
case channeltype.SiliconFlow: case channeltype.SiliconFlow:
return updateChannelSiliconFlowBalance(channel) return updateChannelSiliconFlowBalance(channel)
case channeltype.DeepSeek:
return updateChannelDeepSeekBalance(channel)
default: default:
return 0, errors.New("尚未实现") return 0, errors.New("尚未实现")
} }

View File

@@ -2,6 +2,7 @@ package controller
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -15,14 +16,17 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/message" "github.com/songquanpeng/one-api/common/message"
"github.com/songquanpeng/one-api/middleware" "github.com/songquanpeng/one-api/middleware"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor" "github.com/songquanpeng/one-api/monitor"
relay "github.com/songquanpeng/one-api/relay" "github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/controller" "github.com/songquanpeng/one-api/relay/controller"
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
@@ -35,18 +39,34 @@ func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest {
model = "gpt-3.5-turbo" model = "gpt-3.5-turbo"
} }
testRequest := &relaymodel.GeneralOpenAIRequest{ testRequest := &relaymodel.GeneralOpenAIRequest{
MaxTokens: 2, Model: model,
Model: model,
} }
testMessage := relaymodel.Message{ testMessage := relaymodel.Message{
Role: "user", Role: "user",
Content: "hi", Content: config.TestPrompt,
} }
testRequest.Messages = append(testRequest.Messages, testMessage) testRequest.Messages = append(testRequest.Messages, testMessage)
return testRequest return testRequest
} }
func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (err error, openaiErr *relaymodel.Error) { func parseTestResponse(resp string) (*openai.TextResponse, string, error) {
var response openai.TextResponse
err := json.Unmarshal([]byte(resp), &response)
if err != nil {
return nil, "", err
}
if len(response.Choices) == 0 {
return nil, "", errors.New("response has no choices")
}
stringContent, ok := response.Choices[0].Content.(string)
if !ok {
return nil, "", errors.New("response content is not string")
}
return &response, stringContent, nil
}
func testChannel(ctx context.Context, channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (responseMessage string, err error, openaiErr *relaymodel.Error) {
startTime := time.Now()
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Request = &http.Request{ c.Request = &http.Request{
@@ -66,7 +86,7 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques
apiType := channeltype.ToAPIType(channel.Type) apiType := channeltype.ToAPIType(channel.Type)
adaptor := relay.GetAdaptor(apiType) adaptor := relay.GetAdaptor(apiType)
if adaptor == nil { if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil return "", fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
} }
adaptor.Init(meta) adaptor.Init(meta)
modelName := request.Model modelName := request.Model
@@ -84,41 +104,69 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques
request.Model = modelName request.Model = modelName
convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request) convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request)
if err != nil { if err != nil {
return err, nil return "", err, nil
} }
jsonData, err := json.Marshal(convertedRequest) jsonData, err := json.Marshal(convertedRequest)
if err != nil { if err != nil {
return err, nil return "", err, nil
} }
defer func() {
logContent := fmt.Sprintf("渠道 %s 测试成功,响应:%s", channel.Name, responseMessage)
if err != nil || openaiErr != nil {
errorMessage := ""
if err != nil {
errorMessage = err.Error()
} else {
errorMessage = openaiErr.Message
}
logContent = fmt.Sprintf("渠道 %s 测试失败,错误:%s", channel.Name, errorMessage)
}
go model.RecordTestLog(ctx, &model.Log{
ChannelId: channel.Id,
ModelName: modelName,
Content: logContent,
ElapsedTime: helper.CalcElapsedTime(startTime),
})
}()
logger.SysLog(string(jsonData)) logger.SysLog(string(jsonData))
requestBody := bytes.NewBuffer(jsonData) requestBody := bytes.NewBuffer(jsonData)
c.Request.Body = io.NopCloser(requestBody) c.Request.Body = io.NopCloser(requestBody)
resp, err := adaptor.DoRequest(c, meta, requestBody) resp, err := adaptor.DoRequest(c, meta, requestBody)
if err != nil { if err != nil {
return err, nil return "", err, nil
} }
if resp != nil && resp.StatusCode != http.StatusOK { if resp != nil && resp.StatusCode != http.StatusOK {
err := controller.RelayErrorHandler(resp) err := controller.RelayErrorHandler(resp)
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error errorMessage := err.Error.Message
if errorMessage != "" {
errorMessage = ", error message: " + errorMessage
}
return "", fmt.Errorf("http status code: %d%s", resp.StatusCode, errorMessage), &err.Error
} }
usage, respErr := adaptor.DoResponse(c, resp, meta) usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil { if respErr != nil {
return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error return "", fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
} }
if usage == nil { if usage == nil {
return errors.New("usage is nil"), nil return "", errors.New("usage is nil"), nil
}
rawResponse := w.Body.String()
_, responseMessage, err = parseTestResponse(rawResponse)
if err != nil {
return "", err, nil
} }
result := w.Result() result := w.Result()
// print result.Body // print result.Body
respBody, err := io.ReadAll(result.Body) respBody, err := io.ReadAll(result.Body)
if err != nil { if err != nil {
return err, nil return "", err, nil
} }
logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
return nil, nil return responseMessage, nil, nil
} }
func TestChannel(c *gin.Context) { func TestChannel(c *gin.Context) {
ctx := c.Request.Context()
id, err := strconv.Atoi(c.Param("id")) id, err := strconv.Atoi(c.Param("id"))
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -135,10 +183,10 @@ func TestChannel(c *gin.Context) {
}) })
return return
} }
model := c.Query("model") modelName := c.Query("model")
testRequest := buildTestRequest(model) testRequest := buildTestRequest(modelName)
tik := time.Now() tik := time.Now()
err, _ = testChannel(channel, testRequest) responseMessage, err, _ := testChannel(ctx, channel, testRequest)
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
if err != nil { if err != nil {
@@ -148,18 +196,18 @@ func TestChannel(c *gin.Context) {
consumedTime := float64(milliseconds) / 1000.0 consumedTime := float64(milliseconds) / 1000.0
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),
"time": consumedTime, "time": consumedTime,
"model": model, "modelName": modelName,
}) })
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": responseMessage,
"time": consumedTime, "time": consumedTime,
"model": model, "modelName": modelName,
}) })
return return
} }
@@ -167,7 +215,7 @@ func TestChannel(c *gin.Context) {
var testAllChannelsLock sync.Mutex var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false var testAllChannelsRunning bool = false
func testChannels(notify bool, scope string) error { func testChannels(ctx context.Context, notify bool, scope string) error {
if config.RootUserEmail == "" { if config.RootUserEmail == "" {
config.RootUserEmail = model.GetRootUserEmail() config.RootUserEmail = model.GetRootUserEmail()
} }
@@ -191,7 +239,7 @@ func testChannels(notify bool, scope string) error {
isChannelEnabled := channel.Status == model.ChannelStatusEnabled isChannelEnabled := channel.Status == model.ChannelStatusEnabled
tik := time.Now() tik := time.Now()
testRequest := buildTestRequest("") testRequest := buildTestRequest("")
err, openaiErr := testChannel(channel, testRequest) _, err, openaiErr := testChannel(ctx, channel, testRequest)
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
if isChannelEnabled && milliseconds > disableThreshold { if isChannelEnabled && milliseconds > disableThreshold {
@@ -225,11 +273,12 @@ func testChannels(notify bool, scope string) error {
} }
func TestChannels(c *gin.Context) { func TestChannels(c *gin.Context) {
ctx := c.Request.Context()
scope := c.Query("scope") scope := c.Query("scope")
if scope == "" { if scope == "" {
scope = "all" scope = "all"
} }
err := testChannels(true, scope) err := testChannels(ctx, true, scope)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -245,10 +294,11 @@ func TestChannels(c *gin.Context) {
} }
func AutomaticallyTestChannels(frequency int) { func AutomaticallyTestChannels(frequency int) {
ctx := context.Background()
for { for {
time.Sleep(time.Duration(frequency) * time.Minute) time.Sleep(time.Duration(frequency) * time.Minute)
logger.SysLog("testing all channels") logger.SysLog("testing all channels")
_ = testChannels(false, "all") _ = testChannels(ctx, false, "all")
logger.SysLog("channel test finished") logger.SysLog("channel test finished")
} }
} }

View File

@@ -60,7 +60,7 @@ func Relay(c *gin.Context) {
channelName := c.GetString(ctxkey.ChannelName) channelName := c.GetString(ctxkey.ChannelName)
group := c.GetString(ctxkey.Group) group := c.GetString(ctxkey.Group)
originalModel := c.GetString(ctxkey.OriginalModel) originalModel := c.GetString(ctxkey.OriginalModel)
go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr)
requestId := c.GetString(helper.RequestIdKey) requestId := c.GetString(helper.RequestIdKey)
retryTimes := config.RetryTimes retryTimes := config.RetryTimes
if !shouldRetry(c, bizErr.StatusCode) { if !shouldRetry(c, bizErr.StatusCode) {
@@ -87,8 +87,7 @@ func Relay(c *gin.Context) {
channelId := c.GetInt(ctxkey.ChannelId) channelId := c.GetInt(ctxkey.ChannelId)
lastFailedChannelId = channelId lastFailedChannelId = channelId
channelName := c.GetString(ctxkey.ChannelName) channelName := c.GetString(ctxkey.ChannelName)
// BUG: bizErr is in race condition go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr)
go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
} }
if bizErr != nil { if bizErr != nil {
if bizErr.StatusCode == http.StatusTooManyRequests { if bizErr.StatusCode == http.StatusTooManyRequests {
@@ -122,7 +121,7 @@ func shouldRetry(c *gin.Context, statusCode int) bool {
return true return true
} }
func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) { func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err model.ErrorWithStatusCode) {
logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message) logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message)
// https://platform.openai.com/docs/guides/error-codes/api-errors // https://platform.openai.com/docs/guides/error-codes/api-errors
if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) { if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) {

View File

@@ -109,6 +109,7 @@ func Logout(c *gin.Context) {
} }
func Register(c *gin.Context) { func Register(c *gin.Context) {
ctx := c.Request.Context()
if !config.RegisterEnabled { if !config.RegisterEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了新用户注册", "message": "管理员关闭了新用户注册",
@@ -166,7 +167,7 @@ func Register(c *gin.Context) {
if config.EmailVerificationEnabled { if config.EmailVerificationEnabled {
cleanUser.Email = user.Email cleanUser.Email = user.Email
} }
if err := cleanUser.Insert(inviterId); err != nil { if err := cleanUser.Insert(ctx, inviterId); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),
@@ -362,6 +363,7 @@ func GetSelf(c *gin.Context) {
} }
func UpdateUser(c *gin.Context) { func UpdateUser(c *gin.Context) {
ctx := c.Request.Context()
var updatedUser model.User var updatedUser model.User
err := json.NewDecoder(c.Request.Body).Decode(&updatedUser) err := json.NewDecoder(c.Request.Body).Decode(&updatedUser)
if err != nil || updatedUser.Id == 0 { if err != nil || updatedUser.Id == 0 {
@@ -416,7 +418,7 @@ func UpdateUser(c *gin.Context) {
return return
} }
if originUser.Quota != updatedUser.Quota { if originUser.Quota != updatedUser.Quota {
model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota))) model.RecordLog(ctx, originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota)))
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
@@ -535,6 +537,7 @@ func DeleteSelf(c *gin.Context) {
} }
func CreateUser(c *gin.Context) { func CreateUser(c *gin.Context) {
ctx := c.Request.Context()
var user model.User var user model.User
err := json.NewDecoder(c.Request.Body).Decode(&user) err := json.NewDecoder(c.Request.Body).Decode(&user)
if err != nil || user.Username == "" || user.Password == "" { if err != nil || user.Username == "" || user.Password == "" {
@@ -568,7 +571,7 @@ func CreateUser(c *gin.Context) {
Password: user.Password, Password: user.Password,
DisplayName: user.DisplayName, DisplayName: user.DisplayName,
} }
if err := cleanUser.Insert(0); err != nil { if err := cleanUser.Insert(ctx, 0); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),
@@ -747,6 +750,7 @@ type topUpRequest struct {
} }
func TopUp(c *gin.Context) { func TopUp(c *gin.Context) {
ctx := c.Request.Context()
req := topUpRequest{} req := topUpRequest{}
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {
@@ -757,7 +761,7 @@ func TopUp(c *gin.Context) {
return return
} }
id := c.GetInt("id") id := c.GetInt("id")
quota, err := model.Redeem(req.Key, id) quota, err := model.Redeem(ctx, req.Key, id)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -780,6 +784,7 @@ type adminTopUpRequest struct {
} }
func AdminTopUp(c *gin.Context) { func AdminTopUp(c *gin.Context) {
ctx := c.Request.Context()
req := adminTopUpRequest{} req := adminTopUpRequest{}
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {
@@ -800,7 +805,7 @@ func AdminTopUp(c *gin.Context) {
if req.Remark == "" { if req.Remark == "" {
req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota))) req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota)))
} }
model.RecordTopupLog(req.UserId, req.Remark, req.Quota) model.RecordTopupLog(ctx, req.UserId, req.Remark, req.Quota)
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",

8
go.mod
View File

@@ -25,7 +25,7 @@ require (
github.com/pkoukk/tiktoken-go v0.1.7 github.com/pkoukk/tiktoken-go v0.1.7
github.com/smartystreets/goconvey v1.8.1 github.com/smartystreets/goconvey v1.8.1
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
golang.org/x/crypto v0.24.0 golang.org/x/crypto v0.31.0
golang.org/x/image v0.18.0 golang.org/x/image v0.18.0
google.golang.org/api v0.187.0 google.golang.org/api v0.187.0
gorm.io/driver/mysql v1.5.6 gorm.io/driver/mysql v1.5.6
@@ -99,9 +99,9 @@ require (
golang.org/x/arch v0.8.0 // indirect golang.org/x/arch v0.8.0 // indirect
golang.org/x/net v0.26.0 // indirect golang.org/x/net v0.26.0 // indirect
golang.org/x/oauth2 v0.21.0 // indirect golang.org/x/oauth2 v0.21.0 // indirect
golang.org/x/sync v0.7.0 // indirect golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.21.0 // indirect golang.org/x/sys v0.28.0 // indirect
golang.org/x/text v0.16.0 // indirect golang.org/x/text v0.21.0 // indirect
golang.org/x/time v0.5.0 // indirect golang.org/x/time v0.5.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect

16
go.sum
View File

@@ -222,8 +222,8 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ= golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ=
golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E= golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E=
@@ -244,20 +244,20 @@ golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbht
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@@ -2,13 +2,15 @@ package middleware
import ( import (
"fmt" "fmt"
"net/http"
"strconv"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/channeltype"
"net/http"
"strconv"
) )
type ModelRequest struct { type ModelRequest struct {
@@ -17,6 +19,7 @@ type ModelRequest struct {
func Distribute() func(c *gin.Context) { func Distribute() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
ctx := c.Request.Context()
userId := c.GetInt(ctxkey.Id) userId := c.GetInt(ctxkey.Id)
userGroup, _ := model.CacheGetUserGroup(userId) userGroup, _ := model.CacheGetUserGroup(userId)
c.Set(ctxkey.Group, userGroup) c.Set(ctxkey.Group, userGroup)
@@ -52,6 +55,7 @@ func Distribute() func(c *gin.Context) {
return return
} }
} }
logger.Debugf(ctx, "user id %d, user group: %s, request model: %s, using channel #%d", userId, userGroup, requestModel, channel.Id)
SetupContextForSelectedChannel(c, channel, requestModel) SetupContextForSelectedChannel(c, channel, requestModel)
c.Next() c.Next()
} }
@@ -61,6 +65,9 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set(ctxkey.Channel, channel.Type) c.Set(ctxkey.Channel, channel.Type)
c.Set(ctxkey.ChannelId, channel.Id) c.Set(ctxkey.ChannelId, channel.Id)
c.Set(ctxkey.ChannelName, channel.Name) c.Set(ctxkey.ChannelName, channel.Name)
if channel.SystemPrompt != nil && *channel.SystemPrompt != "" {
c.Set(ctxkey.SystemPrompt, *channel.SystemPrompt)
}
c.Set(ctxkey.ModelMapping, channel.GetModelMapping()) c.Set(ctxkey.ModelMapping, channel.GetModelMapping())
c.Set(ctxkey.OriginalModel, modelName) // for retry c.Set(ctxkey.OriginalModel, modelName) // for retry
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))

27
middleware/gzip.go Normal file
View File

@@ -0,0 +1,27 @@
package middleware
import (
"compress/gzip"
"github.com/gin-gonic/gin"
"io"
"net/http"
)
func GzipDecodeMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
if c.GetHeader("Content-Encoding") == "gzip" {
gzipReader, err := gzip.NewReader(c.Request.Body)
if err != nil {
c.AbortWithStatus(http.StatusBadRequest)
return
}
defer gzipReader.Close()
// Replace the request body with the decompressed data
c.Request.Body = io.NopCloser(gzipReader)
}
// Continue processing the request
c.Next()
}
}

View File

@@ -1,8 +1,8 @@
package middleware package middleware
import ( import (
"context"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
) )
@@ -10,7 +10,7 @@ func RequestId() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
id := helper.GenRequestID() id := helper.GenRequestID()
c.Set(helper.RequestIdKey, id) c.Set(helper.RequestIdKey, id)
ctx := context.WithValue(c.Request.Context(), helper.RequestIdKey, id) ctx := helper.SetRequestID(c.Request.Context(), id)
c.Request = c.Request.WithContext(ctx) c.Request = c.Request.WithContext(ctx)
c.Header(helper.RequestIdKey, id) c.Header(helper.RequestIdKey, id)
c.Next() c.Next()

View File

@@ -37,6 +37,7 @@ type Channel struct {
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"` Priority *int64 `json:"priority" gorm:"bigint;default:0"`
Config string `json:"config"` Config string `json:"config"`
SystemPrompt *string `json:"system_prompt" gorm:"type:text"`
} }
type ChannelConfig struct { type ChannelConfig struct {

View File

@@ -4,26 +4,31 @@ import (
"context" "context"
"fmt" "fmt"
"gorm.io/gorm"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm"
) )
type Log struct { type Log struct {
Id int `json:"id"` Id int `json:"id"`
UserId int `json:"user_id" gorm:"index"` UserId int `json:"user_id" gorm:"index"`
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_type"` CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_type"`
Type int `json:"type" gorm:"index:idx_created_at_type"` Type int `json:"type" gorm:"index:idx_created_at_type"`
Content string `json:"content"` Content string `json:"content"`
Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"` Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"`
TokenName string `json:"token_name" gorm:"index;default:''"` TokenName string `json:"token_name" gorm:"index;default:''"`
ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"` ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
Quota int `json:"quota" gorm:"default:0"` Quota int `json:"quota" gorm:"default:0"`
PromptTokens int `json:"prompt_tokens" gorm:"default:0"` PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
CompletionTokens int `json:"completion_tokens" gorm:"default:0"` CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
ChannelId int `json:"channel" gorm:"index"` ChannelId int `json:"channel" gorm:"index"`
RequestId string `json:"request_id" gorm:"default:''"`
ElapsedTime int64 `json:"elapsed_time" gorm:"default:0"` // unit is ms
IsStream bool `json:"is_stream" gorm:"default:false"`
SystemPromptReset bool `json:"system_prompt_reset" gorm:"default:false"`
} }
const ( const (
@@ -32,9 +37,21 @@ const (
LogTypeConsume LogTypeConsume
LogTypeManage LogTypeManage
LogTypeSystem LogTypeSystem
LogTypeTest
) )
func RecordLog(userId int, logType int, content string) { func recordLogHelper(ctx context.Context, log *Log) {
requestId := helper.GetRequestID(ctx)
log.RequestId = requestId
err := LOG_DB.Create(log).Error
if err != nil {
logger.Error(ctx, "failed to record log: "+err.Error())
return
}
logger.Infof(ctx, "record log: %+v", log)
}
func RecordLog(ctx context.Context, userId int, logType int, content string) {
if logType == LogTypeConsume && !config.LogConsumeEnabled { if logType == LogTypeConsume && !config.LogConsumeEnabled {
return return
} }
@@ -45,13 +62,10 @@ func RecordLog(userId int, logType int, content string) {
Type: logType, Type: logType,
Content: content, Content: content,
} }
err := LOG_DB.Create(log).Error recordLogHelper(ctx, log)
if err != nil {
logger.SysError("failed to record log: " + err.Error())
}
} }
func RecordTopupLog(userId int, content string, quota int) { func RecordTopupLog(ctx context.Context, userId int, content string, quota int) {
log := &Log{ log := &Log{
UserId: userId, UserId: userId,
Username: GetUsernameById(userId), Username: GetUsernameById(userId),
@@ -60,34 +74,23 @@ func RecordTopupLog(userId int, content string, quota int) {
Content: content, Content: content,
Quota: quota, Quota: quota,
} }
err := LOG_DB.Create(log).Error recordLogHelper(ctx, log)
if err != nil {
logger.SysError("failed to record log: " + err.Error())
}
} }
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) { func RecordConsumeLog(ctx context.Context, log *Log) {
logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !config.LogConsumeEnabled { if !config.LogConsumeEnabled {
return return
} }
log := &Log{ log.Username = GetUsernameById(log.UserId)
UserId: userId, log.CreatedAt = helper.GetTimestamp()
Username: GetUsernameById(userId), log.Type = LogTypeConsume
CreatedAt: helper.GetTimestamp(), recordLogHelper(ctx, log)
Type: LogTypeConsume, }
Content: content,
PromptTokens: promptTokens, func RecordTestLog(ctx context.Context, log *Log) {
CompletionTokens: completionTokens, log.CreatedAt = helper.GetTimestamp()
TokenName: tokenName, log.Type = LogTypeTest
ModelName: modelName, recordLogHelper(ctx, log)
Quota: int(quota),
ChannelId: channelId,
}
err := LOG_DB.Create(log).Error
if err != nil {
logger.Error(ctx, "failed to record log: "+err.Error())
}
} }
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {

View File

@@ -1,11 +1,14 @@
package model package model
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"gorm.io/gorm"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"gorm.io/gorm"
) )
const ( const (
@@ -48,7 +51,7 @@ func GetRedemptionById(id int) (*Redemption, error) {
return &redemption, err return &redemption, err
} }
func Redeem(key string, userId int) (quota int64, err error) { func Redeem(ctx context.Context, key string, userId int) (quota int64, err error) {
if key == "" { if key == "" {
return 0, errors.New("未提供兑换码") return 0, errors.New("未提供兑换码")
} }
@@ -82,7 +85,7 @@ func Redeem(key string, userId int) (quota int64, err error) {
if err != nil { if err != nil {
return 0, errors.New("兑换失败," + err.Error()) return 0, errors.New("兑换失败," + err.Error())
} }
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota))) RecordLog(ctx, userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota)))
return redemption.Quota, nil return redemption.Quota, nil
} }

View File

@@ -1,16 +1,19 @@
package model package model
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"strings"
"gorm.io/gorm"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist" "github.com/songquanpeng/one-api/common/blacklist"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/common/random"
"gorm.io/gorm"
"strings"
) )
const ( const (
@@ -114,7 +117,7 @@ func DeleteUserById(id int) (err error) {
return user.Delete() return user.Delete()
} }
func (user *User) Insert(inviterId int) error { func (user *User) Insert(ctx context.Context, inviterId int) error {
var err error var err error
if user.Password != "" { if user.Password != "" {
user.Password, err = common.Password2Hash(user.Password) user.Password, err = common.Password2Hash(user.Password)
@@ -130,16 +133,16 @@ func (user *User) Insert(inviterId int) error {
return result.Error return result.Error
} }
if config.QuotaForNewUser > 0 { if config.QuotaForNewUser > 0 {
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser))) RecordLog(ctx, user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser)))
} }
if inviterId != 0 { if inviterId != 0 {
if config.QuotaForInvitee > 0 { if config.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, config.QuotaForInvitee) _ = IncreaseUserQuota(user.Id, config.QuotaForInvitee)
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee))) RecordLog(ctx, user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee)))
} }
if config.QuotaForInviter > 0 { if config.QuotaForInviter > 0 {
_ = IncreaseUserQuota(inviterId, config.QuotaForInviter) _ = IncreaseUserQuota(inviterId, config.QuotaForInviter)
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter))) RecordLog(ctx, inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter)))
} }
} }
// create default token // create default token

View File

@@ -34,7 +34,7 @@ func ShouldDisableChannel(err *model.Error, statusCode int) bool {
strings.Contains(lowerMessage, "credit") || strings.Contains(lowerMessage, "credit") ||
strings.Contains(lowerMessage, "balance") || strings.Contains(lowerMessage, "balance") ||
strings.Contains(lowerMessage, "permission denied") || strings.Contains(lowerMessage, "permission denied") ||
strings.Contains(lowerMessage, "organization has been restricted") || // groq strings.Contains(lowerMessage, "organization has been restricted") || // groq
strings.Contains(lowerMessage, "已欠费") { strings.Contains(lowerMessage, "已欠费") {
return true return true
} }

BIN
one-api

Binary file not shown.

View File

@@ -16,6 +16,7 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/adaptor/palm" "github.com/songquanpeng/one-api/relay/adaptor/palm"
"github.com/songquanpeng/one-api/relay/adaptor/proxy" "github.com/songquanpeng/one-api/relay/adaptor/proxy"
"github.com/songquanpeng/one-api/relay/adaptor/replicate"
"github.com/songquanpeng/one-api/relay/adaptor/tencent" "github.com/songquanpeng/one-api/relay/adaptor/tencent"
"github.com/songquanpeng/one-api/relay/adaptor/vertexai" "github.com/songquanpeng/one-api/relay/adaptor/vertexai"
"github.com/songquanpeng/one-api/relay/adaptor/xunfei" "github.com/songquanpeng/one-api/relay/adaptor/xunfei"
@@ -61,6 +62,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor {
return &vertexai.Adaptor{} return &vertexai.Adaptor{}
case apitype.Proxy: case apitype.Proxy:
return &proxy.Adaptor{} return &proxy.Adaptor{}
case apitype.Replicate:
return &replicate.Adaptor{}
} }
return nil return nil
} }

View File

@@ -1,7 +1,23 @@
package ali package ali
var ModelList = []string{ var ModelList = []string{
"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext", "qwen-turbo", "qwen-turbo-latest",
"text-embedding-v1", "qwen-plus", "qwen-plus-latest",
"qwen-max", "qwen-max-latest",
"qwen-max-longcontext",
"qwen-vl-max", "qwen-vl-max-latest", "qwen-vl-plus", "qwen-vl-plus-latest",
"qwen-vl-ocr", "qwen-vl-ocr-latest",
"qwen-audio-turbo",
"qwen-math-plus", "qwen-math-plus-latest", "qwen-math-turbo", "qwen-math-turbo-latest",
"qwen-coder-plus", "qwen-coder-plus-latest", "qwen-coder-turbo", "qwen-coder-turbo-latest",
"qwq-32b-preview", "qwen2.5-72b-instruct", "qwen2.5-32b-instruct", "qwen2.5-14b-instruct", "qwen2.5-7b-instruct", "qwen2.5-3b-instruct", "qwen2.5-1.5b-instruct", "qwen2.5-0.5b-instruct",
"qwen2-72b-instruct", "qwen2-57b-a14b-instruct", "qwen2-7b-instruct", "qwen2-1.5b-instruct", "qwen2-0.5b-instruct",
"qwen1.5-110b-chat", "qwen1.5-72b-chat", "qwen1.5-32b-chat", "qwen1.5-14b-chat", "qwen1.5-7b-chat", "qwen1.5-1.8b-chat", "qwen1.5-0.5b-chat",
"qwen-72b-chat", "qwen-14b-chat", "qwen-7b-chat", "qwen-1.8b-chat", "qwen-1.8b-longcontext-chat",
"qwen2-vl-7b-instruct", "qwen2-vl-2b-instruct", "qwen-vl-v1", "qwen-vl-chat-v1",
"qwen2-audio-instruct", "qwen-audio-chat",
"qwen2.5-math-72b-instruct", "qwen2.5-math-7b-instruct", "qwen2.5-math-1.5b-instruct", "qwen2-math-72b-instruct", "qwen2-math-7b-instruct", "qwen2-math-1.5b-instruct",
"qwen2.5-coder-32b-instruct", "qwen2.5-coder-14b-instruct", "qwen2.5-coder-7b-instruct", "qwen2.5-coder-3b-instruct", "qwen2.5-coder-1.5b-instruct", "qwen2.5-coder-0.5b-instruct",
"text-embedding-v1", "text-embedding-v3", "text-embedding-v2", "text-embedding-async-v2", "text-embedding-async-v1",
"ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1", "ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1",
} }

View File

@@ -9,5 +9,4 @@ var ModelList = []string{
"claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20240620",
"claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20241022",
"claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest",
"claude-3-5-haiku-20241022",
} }

View File

@@ -7,7 +7,6 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
channelhelper "github.com/songquanpeng/one-api/relay/adaptor" channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
@@ -24,7 +23,15 @@ func (a *Adaptor) Init(meta *meta.Meta) {
} }
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion) var defaultVersion string
switch meta.ActualModelName {
case "gemini-2.0-flash-exp",
"gemini-2.0-flash-thinking-exp",
"gemini-2.0-flash-thinking-exp-01-21":
defaultVersion = "v1beta"
}
version := helper.AssignOrDefault(meta.Config.APIVersion, defaultVersion)
action := "" action := ""
switch meta.Mode { switch meta.Mode {
case relaymode.Embeddings: case relaymode.Embeddings:
@@ -36,6 +43,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
if meta.IsStream { if meta.IsStream {
action = "streamGenerateContent?alt=sse" action = "streamGenerateContent?alt=sse"
} }
return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
} }

View File

@@ -3,5 +3,9 @@ package gemini
// https://ai.google.dev/models/gemini // https://ai.google.dev/models/gemini
var ModelList = []string{ var ModelList = []string{
"gemini-pro", "gemini-1.0-pro", "gemini-1.5-flash", "gemini-1.5-pro", "text-embedding-004", "aqa", "gemini-pro", "gemini-1.0-pro",
"gemini-1.5-flash", "gemini-1.5-pro",
"text-embedding-004", "aqa",
"gemini-2.0-flash-exp",
"gemini-2.0-flash-thinking-exp", "gemini-2.0-flash-thinking-exp-01-21",
} }

View File

@@ -55,6 +55,10 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
Category: "HARM_CATEGORY_DANGEROUS_CONTENT", Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
Threshold: config.GeminiSafetySetting, Threshold: config.GeminiSafetySetting,
}, },
{
Category: "HARM_CATEGORY_CIVIC_INTEGRITY",
Threshold: config.GeminiSafetySetting,
},
}, },
GenerationConfig: ChatGenerationConfig{ GenerationConfig: ChatGenerationConfig{
Temperature: textRequest.Temperature, Temperature: textRequest.Temperature,
@@ -247,7 +251,14 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
if candidate.Content.Parts[0].FunctionCall != nil { if candidate.Content.Parts[0].FunctionCall != nil {
choice.Message.ToolCalls = getToolCalls(&candidate) choice.Message.ToolCalls = getToolCalls(&candidate)
} else { } else {
choice.Message.Content = candidate.Content.Parts[0].Text var builder strings.Builder
for _, part := range candidate.Content.Parts {
if i > 0 {
builder.WriteString("\n")
}
builder.WriteString(part.Text)
}
choice.Message.Content = builder.String()
} }
} else { } else {
choice.Message.Content = "" choice.Message.Content = ""

View File

@@ -31,8 +31,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
TopP: request.TopP, TopP: request.TopP,
FrequencyPenalty: request.FrequencyPenalty, FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty, PresencePenalty: request.PresencePenalty,
NumPredict: request.MaxTokens, NumPredict: request.MaxTokens,
NumCtx: request.NumCtx, NumCtx: request.NumCtx,
}, },
Stream: request.Stream, Stream: request.Stream,
} }
@@ -122,7 +122,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := scanner.Text()
if strings.HasPrefix(data, "}") { if strings.HasPrefix(data, "}") {
data = strings.TrimPrefix(data, "}") + "}" data = strings.TrimPrefix(data, "}") + "}"
} }
var ollamaResponse ChatResponse var ollamaResponse ChatResponse

View File

@@ -9,6 +9,7 @@ var ModelList = []string{
"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09", "gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o", "gpt-4o-2024-05-13",
"gpt-4o-2024-08-06", "gpt-4o-2024-08-06",
"gpt-4o-2024-11-20",
"chatgpt-4o-latest", "chatgpt-4o-latest",
"gpt-4o-mini", "gpt-4o-mini-2024-07-18", "gpt-4o-mini", "gpt-4o-mini-2024-07-18",
"gpt-4-vision-preview", "gpt-4-vision-preview",
@@ -20,4 +21,7 @@ var ModelList = []string{
"dall-e-2", "dall-e-3", "dall-e-2", "dall-e-3",
"whisper-1", "whisper-1",
"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106", "tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
"o1", "o1-2024-12-17",
"o1-preview", "o1-preview-2024-09-12",
"o1-mini", "o1-mini-2024-09-12",
} }

View File

@@ -2,15 +2,16 @@ package openai
import ( import (
"fmt" "fmt"
"strings"
"github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"strings"
) )
func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage { func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage {
usage := &model.Usage{} usage := &model.Usage{}
usage.PromptTokens = promptTokens usage.PromptTokens = promptTokens
usage.CompletionTokens = CountTokenText(responseText, modeName) usage.CompletionTokens = CountTokenText(responseText, modelName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage return usage
} }

View File

@@ -1,8 +1,16 @@
package openai package openai
import "github.com/songquanpeng/one-api/relay/model" import (
"context"
"fmt"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/model"
)
func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode { func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode {
logger.Error(context.TODO(), fmt.Sprintf("[%s]%+v", code, err))
Error := model.Error{ Error := model.Error{
Message: err.Error(), Message: err.Error(),
Type: "one_api_error", Type: "one_api_error",

View File

@@ -0,0 +1,136 @@
package replicate
import (
"fmt"
"io"
"net/http"
"slices"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
)
type Adaptor struct {
meta *meta.Meta
}
// ConvertImageRequest implements adaptor.Adaptor.
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
return DrawImageRequest{
Input: ImageInput{
Steps: 25,
Prompt: request.Prompt,
Guidance: 3,
Seed: int(time.Now().UnixNano()),
SafetyTolerance: 5,
NImages: 1, // replicate will always return 1 image
Width: 1440,
Height: 1440,
AspectRatio: "1:1",
},
}, nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if !request.Stream {
// TODO: support non-stream mode
return nil, errors.Errorf("replicate models only support stream mode now, please set stream=true")
}
// Build the prompt from OpenAI messages
var promptBuilder strings.Builder
for _, message := range request.Messages {
switch msgCnt := message.Content.(type) {
case string:
promptBuilder.WriteString(message.Role)
promptBuilder.WriteString(": ")
promptBuilder.WriteString(msgCnt)
promptBuilder.WriteString("\n")
default:
}
}
replicateRequest := ReplicateChatRequest{
Input: ChatInput{
Prompt: promptBuilder.String(),
MaxTokens: request.MaxTokens,
Temperature: 1.0,
TopP: 1.0,
PresencePenalty: 0.0,
FrequencyPenalty: 0.0,
},
}
// Map optional fields
if request.Temperature != nil {
replicateRequest.Input.Temperature = *request.Temperature
}
if request.TopP != nil {
replicateRequest.Input.TopP = *request.TopP
}
if request.PresencePenalty != nil {
replicateRequest.Input.PresencePenalty = *request.PresencePenalty
}
if request.FrequencyPenalty != nil {
replicateRequest.Input.FrequencyPenalty = *request.FrequencyPenalty
}
if request.MaxTokens > 0 {
replicateRequest.Input.MaxTokens = request.MaxTokens
} else if request.MaxTokens == 0 {
replicateRequest.Input.MaxTokens = 500
}
return replicateRequest, nil
}
func (a *Adaptor) Init(meta *meta.Meta) {
a.meta = meta
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
if !slices.Contains(ModelList, meta.OriginModelName) {
return "", errors.Errorf("model %s not supported", meta.OriginModelName)
}
return fmt.Sprintf("https://api.replicate.com/v1/models/%s/predictions", meta.OriginModelName), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
logger.Info(c, "send request to replicate")
return adaptor.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
switch meta.Mode {
case relaymode.ImagesGenerations:
err, usage = ImageHandler(c, resp)
case relaymode.ChatCompletions:
err, usage = ChatHandler(c, resp)
default:
err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "replicate"
}

View File

@@ -0,0 +1,191 @@
package replicate
import (
"bufio"
"encoding/json"
"io"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/render"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
func ChatHandler(c *gin.Context, resp *http.Response) (
srvErr *model.ErrorWithStatusCode, usage *model.Usage) {
if resp.StatusCode != http.StatusCreated {
payload, _ := io.ReadAll(resp.Body)
return openai.ErrorWrapper(
errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
"bad_status_code", http.StatusInternalServerError),
nil
}
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
respData := new(ChatResponse)
if err = json.Unmarshal(respBody, respData); err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
for {
err = func() error {
// get task
taskReq, err := http.NewRequestWithContext(c.Request.Context(),
http.MethodGet, respData.URLs.Get, nil)
if err != nil {
return errors.Wrap(err, "new request")
}
taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
taskResp, err := http.DefaultClient.Do(taskReq)
if err != nil {
return errors.Wrap(err, "get task")
}
defer taskResp.Body.Close()
if taskResp.StatusCode != http.StatusOK {
payload, _ := io.ReadAll(taskResp.Body)
return errors.Errorf("bad status code [%d]%s",
taskResp.StatusCode, string(payload))
}
taskBody, err := io.ReadAll(taskResp.Body)
if err != nil {
return errors.Wrap(err, "read task response")
}
taskData := new(ChatResponse)
if err = json.Unmarshal(taskBody, taskData); err != nil {
return errors.Wrap(err, "decode task response")
}
switch taskData.Status {
case "succeeded":
case "failed", "canceled":
return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error)
default:
time.Sleep(time.Second * 3)
return errNextLoop
}
if taskData.URLs.Stream == "" {
return errors.New("stream url is empty")
}
// request stream url
responseText, err := chatStreamHandler(c, taskData.URLs.Stream)
if err != nil {
return errors.Wrap(err, "chat stream handler")
}
ctxMeta := meta.GetByContext(c)
usage = openai.ResponseText2Usage(responseText,
ctxMeta.ActualModelName, ctxMeta.PromptTokens)
return nil
}()
if err != nil {
if errors.Is(err, errNextLoop) {
continue
}
return openai.ErrorWrapper(err, "chat_task_failed", http.StatusInternalServerError), nil
}
break
}
return nil, usage
}
const (
eventPrefix = "event: "
dataPrefix = "data: "
done = "[DONE]"
)
func chatStreamHandler(c *gin.Context, streamUrl string) (responseText string, err error) {
// request stream endpoint
streamReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, streamUrl, nil)
if err != nil {
return "", errors.Wrap(err, "new request to stream")
}
streamReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
streamReq.Header.Set("Accept", "text/event-stream")
streamReq.Header.Set("Cache-Control", "no-store")
resp, err := http.DefaultClient.Do(streamReq)
if err != nil {
return "", errors.Wrap(err, "do request to stream")
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
payload, _ := io.ReadAll(resp.Body)
return "", errors.Errorf("bad status code [%d]%s", resp.StatusCode, string(payload))
}
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
common.SetEventStreamHeaders(c)
doneRendered := false
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
// Handle comments starting with ':'
if strings.HasPrefix(line, ":") {
continue
}
// Parse SSE fields
if strings.HasPrefix(line, eventPrefix) {
event := strings.TrimSpace(line[len(eventPrefix):])
var data string
// Read the following lines to get data and id
for scanner.Scan() {
nextLine := scanner.Text()
if nextLine == "" {
break
}
if strings.HasPrefix(nextLine, dataPrefix) {
data = nextLine[len(dataPrefix):]
} else if strings.HasPrefix(nextLine, "id:") {
// id = strings.TrimSpace(nextLine[len("id:"):])
}
}
if event == "output" {
render.StringData(c, data)
responseText += data
} else if event == "done" {
render.Done(c)
doneRendered = true
break
}
}
}
if err := scanner.Err(); err != nil {
return "", errors.Wrap(err, "scan stream")
}
if !doneRendered {
render.Done(c)
}
return responseText, nil
}

View File

@@ -0,0 +1,58 @@
package replicate
// ModelList is a list of models that can be used with Replicate.
//
// https://replicate.com/pricing
var ModelList = []string{
// -------------------------------------
// image model
// -------------------------------------
"black-forest-labs/flux-1.1-pro",
"black-forest-labs/flux-1.1-pro-ultra",
"black-forest-labs/flux-canny-dev",
"black-forest-labs/flux-canny-pro",
"black-forest-labs/flux-depth-dev",
"black-forest-labs/flux-depth-pro",
"black-forest-labs/flux-dev",
"black-forest-labs/flux-dev-lora",
"black-forest-labs/flux-fill-dev",
"black-forest-labs/flux-fill-pro",
"black-forest-labs/flux-pro",
"black-forest-labs/flux-redux-dev",
"black-forest-labs/flux-redux-schnell",
"black-forest-labs/flux-schnell",
"black-forest-labs/flux-schnell-lora",
"ideogram-ai/ideogram-v2",
"ideogram-ai/ideogram-v2-turbo",
"recraft-ai/recraft-v3",
"recraft-ai/recraft-v3-svg",
"stability-ai/stable-diffusion-3",
"stability-ai/stable-diffusion-3.5-large",
"stability-ai/stable-diffusion-3.5-large-turbo",
"stability-ai/stable-diffusion-3.5-medium",
// -------------------------------------
// language model
// -------------------------------------
"ibm-granite/granite-20b-code-instruct-8k",
"ibm-granite/granite-3.0-2b-instruct",
"ibm-granite/granite-3.0-8b-instruct",
"ibm-granite/granite-8b-code-instruct-128k",
"meta/llama-2-13b",
"meta/llama-2-13b-chat",
"meta/llama-2-70b",
"meta/llama-2-70b-chat",
"meta/llama-2-7b",
"meta/llama-2-7b-chat",
"meta/meta-llama-3.1-405b-instruct",
"meta/meta-llama-3-70b",
"meta/meta-llama-3-70b-instruct",
"meta/meta-llama-3-8b",
"meta/meta-llama-3-8b-instruct",
"mistralai/mistral-7b-instruct-v0.2",
"mistralai/mistral-7b-v0.1",
"mistralai/mixtral-8x7b-instruct-v0.1",
// -------------------------------------
// video model
// -------------------------------------
// "minimax/video-01", // TODO: implement the adaptor
}

View File

@@ -0,0 +1,222 @@
package replicate
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"image"
"image/png"
"io"
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"golang.org/x/image/webp"
"golang.org/x/sync/errgroup"
)
// ImagesEditsHandler just copy response body to client
//
// https://replicate.com/black-forest-labs/flux-fill-pro
// func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
// c.Writer.WriteHeader(resp.StatusCode)
// for k, v := range resp.Header {
// c.Writer.Header().Set(k, v[0])
// }
// if _, err := io.Copy(c.Writer, resp.Body); err != nil {
// return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
// }
// defer resp.Body.Close()
// return nil, nil
// }
var errNextLoop = errors.New("next_loop")
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
if resp.StatusCode != http.StatusCreated {
payload, _ := io.ReadAll(resp.Body)
return openai.ErrorWrapper(
errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
"bad_status_code", http.StatusInternalServerError),
nil
}
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
respData := new(ImageResponse)
if err = json.Unmarshal(respBody, respData); err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
for {
err = func() error {
// get task
taskReq, err := http.NewRequestWithContext(c.Request.Context(),
http.MethodGet, respData.URLs.Get, nil)
if err != nil {
return errors.Wrap(err, "new request")
}
taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
taskResp, err := http.DefaultClient.Do(taskReq)
if err != nil {
return errors.Wrap(err, "get task")
}
defer taskResp.Body.Close()
if taskResp.StatusCode != http.StatusOK {
payload, _ := io.ReadAll(taskResp.Body)
return errors.Errorf("bad status code [%d]%s",
taskResp.StatusCode, string(payload))
}
taskBody, err := io.ReadAll(taskResp.Body)
if err != nil {
return errors.Wrap(err, "read task response")
}
taskData := new(ImageResponse)
if err = json.Unmarshal(taskBody, taskData); err != nil {
return errors.Wrap(err, "decode task response")
}
switch taskData.Status {
case "succeeded":
case "failed", "canceled":
return errors.Errorf("task failed: %s", taskData.Status)
default:
time.Sleep(time.Second * 3)
return errNextLoop
}
output, err := taskData.GetOutput()
if err != nil {
return errors.Wrap(err, "get output")
}
if len(output) == 0 {
return errors.New("response output is empty")
}
var mu sync.Mutex
var pool errgroup.Group
respBody := &openai.ImageResponse{
Created: taskData.CompletedAt.Unix(),
Data: []openai.ImageData{},
}
for _, imgOut := range output {
imgOut := imgOut
pool.Go(func() error {
// download image
downloadReq, err := http.NewRequestWithContext(c.Request.Context(),
http.MethodGet, imgOut, nil)
if err != nil {
return errors.Wrap(err, "new request")
}
imgResp, err := http.DefaultClient.Do(downloadReq)
if err != nil {
return errors.Wrap(err, "download image")
}
defer imgResp.Body.Close()
if imgResp.StatusCode != http.StatusOK {
payload, _ := io.ReadAll(imgResp.Body)
return errors.Errorf("bad status code [%d]%s",
imgResp.StatusCode, string(payload))
}
imgData, err := io.ReadAll(imgResp.Body)
if err != nil {
return errors.Wrap(err, "read image")
}
imgData, err = ConvertImageToPNG(imgData)
if err != nil {
return errors.Wrap(err, "convert image")
}
mu.Lock()
respBody.Data = append(respBody.Data, openai.ImageData{
B64Json: fmt.Sprintf("data:image/png;base64,%s",
base64.StdEncoding.EncodeToString(imgData)),
})
mu.Unlock()
return nil
})
}
if err := pool.Wait(); err != nil {
if len(respBody.Data) == 0 {
return errors.WithStack(err)
}
logger.Error(c, fmt.Sprintf("some images failed to download: %+v", err))
}
c.JSON(http.StatusOK, respBody)
return nil
}()
if err != nil {
if errors.Is(err, errNextLoop) {
continue
}
return openai.ErrorWrapper(err, "image_task_failed", http.StatusInternalServerError), nil
}
break
}
return nil, nil
}
// ConvertImageToPNG converts a WebP image to PNG format
func ConvertImageToPNG(webpData []byte) ([]byte, error) {
// bypass if it's already a PNG image
if bytes.HasPrefix(webpData, []byte("\x89PNG")) {
return webpData, nil
}
// check if is jpeg, convert to png
if bytes.HasPrefix(webpData, []byte("\xff\xd8\xff")) {
img, _, err := image.Decode(bytes.NewReader(webpData))
if err != nil {
return nil, errors.Wrap(err, "decode jpeg")
}
var pngBuffer bytes.Buffer
if err := png.Encode(&pngBuffer, img); err != nil {
return nil, errors.Wrap(err, "encode png")
}
return pngBuffer.Bytes(), nil
}
// Decode the WebP image
img, err := webp.Decode(bytes.NewReader(webpData))
if err != nil {
return nil, errors.Wrap(err, "decode webp")
}
// Encode the image as PNG
var pngBuffer bytes.Buffer
if err := png.Encode(&pngBuffer, img); err != nil {
return nil, errors.Wrap(err, "encode png")
}
return pngBuffer.Bytes(), nil
}

View File

@@ -0,0 +1,159 @@
package replicate
import (
"time"
"github.com/pkg/errors"
)
// DrawImageRequest draw image by fluxpro
//
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
type DrawImageRequest struct {
Input ImageInput `json:"input"`
}
// ImageInput is input of DrawImageByFluxProRequest
//
// https://replicate.com/black-forest-labs/flux-1.1-pro/api/schema
type ImageInput struct {
Steps int `json:"steps" binding:"required,min=1"`
Prompt string `json:"prompt" binding:"required,min=5"`
ImagePrompt string `json:"image_prompt"`
Guidance int `json:"guidance" binding:"required,min=2,max=5"`
Interval int `json:"interval" binding:"required,min=1,max=4"`
AspectRatio string `json:"aspect_ratio" binding:"required,oneof=1:1 16:9 2:3 3:2 4:5 5:4 9:16"`
SafetyTolerance int `json:"safety_tolerance" binding:"required,min=1,max=5"`
Seed int `json:"seed"`
NImages int `json:"n_images" binding:"required,min=1,max=8"`
Width int `json:"width" binding:"required,min=256,max=1440"`
Height int `json:"height" binding:"required,min=256,max=1440"`
}
// InpaintingImageByFlusReplicateRequest is request to inpainting image by flux pro
//
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
type InpaintingImageByFlusReplicateRequest struct {
Input FluxInpaintingInput `json:"input"`
}
// FluxInpaintingInput is input of DrawImageByFluxProRequest
//
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
type FluxInpaintingInput struct {
Mask string `json:"mask" binding:"required"`
Image string `json:"image" binding:"required"`
Seed int `json:"seed"`
Steps int `json:"steps" binding:"required,min=1"`
Prompt string `json:"prompt" binding:"required,min=5"`
Guidance int `json:"guidance" binding:"required,min=2,max=5"`
OutputFormat string `json:"output_format"`
SafetyTolerance int `json:"safety_tolerance" binding:"required,min=1,max=5"`
PromptUnsampling bool `json:"prompt_unsampling"`
}
// ImageResponse is response of DrawImageByFluxProRequest
//
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
type ImageResponse struct {
CompletedAt time.Time `json:"completed_at"`
CreatedAt time.Time `json:"created_at"`
DataRemoved bool `json:"data_removed"`
Error string `json:"error"`
ID string `json:"id"`
Input DrawImageRequest `json:"input"`
Logs string `json:"logs"`
Metrics FluxMetrics `json:"metrics"`
// Output could be `string` or `[]string`
Output any `json:"output"`
StartedAt time.Time `json:"started_at"`
Status string `json:"status"`
URLs FluxURLs `json:"urls"`
Version string `json:"version"`
}
func (r *ImageResponse) GetOutput() ([]string, error) {
switch v := r.Output.(type) {
case string:
return []string{v}, nil
case []string:
return v, nil
case nil:
return nil, nil
case []interface{}:
// convert []interface{} to []string
ret := make([]string, len(v))
for idx, vv := range v {
if vvv, ok := vv.(string); ok {
ret[idx] = vvv
} else {
return nil, errors.Errorf("unknown output type: [%T]%v", vv, vv)
}
}
return ret, nil
default:
return nil, errors.Errorf("unknown output type: [%T]%v", r.Output, r.Output)
}
}
// FluxMetrics is metrics of ImageResponse
type FluxMetrics struct {
ImageCount int `json:"image_count"`
PredictTime float64 `json:"predict_time"`
TotalTime float64 `json:"total_time"`
}
// FluxURLs is urls of ImageResponse
type FluxURLs struct {
Get string `json:"get"`
Cancel string `json:"cancel"`
}
type ReplicateChatRequest struct {
Input ChatInput `json:"input" form:"input" binding:"required"`
}
// ChatInput is input of ChatByReplicateRequest
//
// https://replicate.com/meta/meta-llama-3.1-405b-instruct/api/schema
type ChatInput struct {
TopK int `json:"top_k"`
TopP float64 `json:"top_p"`
Prompt string `json:"prompt"`
MaxTokens int `json:"max_tokens"`
MinTokens int `json:"min_tokens"`
Temperature float64 `json:"temperature"`
SystemPrompt string `json:"system_prompt"`
StopSequences string `json:"stop_sequences"`
PromptTemplate string `json:"prompt_template"`
PresencePenalty float64 `json:"presence_penalty"`
FrequencyPenalty float64 `json:"frequency_penalty"`
}
// ChatResponse is response of ChatByReplicateRequest
//
// https://replicate.com/meta/meta-llama-3.1-405b-instruct/examples?input=http&output=json
type ChatResponse struct {
CompletedAt time.Time `json:"completed_at"`
CreatedAt time.Time `json:"created_at"`
DataRemoved bool `json:"data_removed"`
Error string `json:"error"`
ID string `json:"id"`
Input ChatInput `json:"input"`
Logs string `json:"logs"`
Metrics FluxMetrics `json:"metrics"`
// Output could be `string` or `[]string`
Output []string `json:"output"`
StartedAt time.Time `json:"started_at"`
Status string `json:"status"`
URLs ChatResponseUrl `json:"urls"`
Version string `json:"version"`
}
// ChatResponseUrl is task urls of ChatResponse
type ChatResponseUrl struct {
Stream string `json:"stream"`
Get string `json:"get"`
Cancel string `json:"cancel"`
}

View File

@@ -2,16 +2,19 @@ package tencent
import ( import (
"errors" "errors"
"io"
"net/http"
"strconv"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io" "github.com/songquanpeng/one-api/relay/relaymode"
"net/http"
"strconv"
"strings"
) )
// https://cloud.tencent.com/document/api/1729/101837 // https://cloud.tencent.com/document/api/1729/101837
@@ -52,10 +55,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
if err != nil { if err != nil {
return nil, err return nil, err
} }
tencentRequest := ConvertRequest(*request) var convertedRequest any
switch relayMode {
case relaymode.Embeddings:
a.Action = "GetEmbedding"
convertedRequest = ConvertEmbeddingRequest(*request)
default:
a.Action = "ChatCompletions"
convertedRequest = ConvertRequest(*request)
}
// we have to calculate the sign here // we have to calculate the sign here
a.Sign = GetSign(*tencentRequest, a, secretId, secretKey) a.Sign = GetSign(convertedRequest, a, secretId, secretKey)
return tencentRequest, nil return convertedRequest, nil
} }
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
@@ -75,7 +86,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
err, responseText = StreamHandler(c, resp) err, responseText = StreamHandler(c, resp)
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
} else { } else {
err, usage = Handler(c, resp) switch meta.Mode {
case relaymode.Embeddings:
err, usage = EmbeddingHandler(c, resp)
default:
err, usage = Handler(c, resp)
}
} }
return return
} }

View File

@@ -6,4 +6,5 @@ var ModelList = []string{
"hunyuan-standard-256K", "hunyuan-standard-256K",
"hunyuan-pro", "hunyuan-pro",
"hunyuan-vision", "hunyuan-vision",
"hunyuan-embedding",
} }

View File

@@ -8,7 +8,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
"strconv" "strconv"
@@ -16,11 +15,14 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/common/render"
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
@@ -44,8 +46,68 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
} }
} }
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{
InputList: request.ParseInput(),
}
}
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var tencentResponseP EmbeddingResponseP
err := json.NewDecoder(resp.Body).Decode(&tencentResponseP)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
tencentResponse := tencentResponseP.Response
if tencentResponse.Error.Code != "" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: tencentResponse.Error.Message,
Code: tencentResponse.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
requestModel := c.GetString(ctxkey.RequestModel)
fullTextResponse := embeddingResponseTencent2OpenAI(&tencentResponse)
fullTextResponse.Model = requestModel
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func embeddingResponseTencent2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
openAIEmbeddingResponse := openai.EmbeddingResponse{
Object: "list",
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Data)),
Model: "hunyuan-embedding",
Usage: model.Usage{TotalTokens: response.EmbeddingUsage.TotalTokens},
}
for _, item := range response.Data {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
Object: item.Object,
Index: item.Index,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}
func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse { func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{ fullTextResponse := openai.TextResponse{
Id: response.ReqID,
Object: "chat.completion", Object: "chat.completion",
Created: helper.GetTimestamp(), Created: helper.GetTimestamp(),
Usage: model.Usage{ Usage: model.Usage{
@@ -148,7 +210,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
TencentResponse = responseP.Response TencentResponse = responseP.Response
if TencentResponse.Error.Code != 0 { if TencentResponse.Error.Code != "" {
return &model.ErrorWithStatusCode{ return &model.ErrorWithStatusCode{
Error: model.Error{ Error: model.Error{
Message: TencentResponse.Error.Message, Message: TencentResponse.Error.Message,
@@ -195,7 +257,7 @@ func hmacSha256(s, key string) string {
return string(hashed.Sum(nil)) return string(hashed.Sum(nil))
} }
func GetSign(req ChatRequest, adaptor *Adaptor, secId, secKey string) string { func GetSign(req any, adaptor *Adaptor, secId, secKey string) string {
// build canonical request string // build canonical request string
host := "hunyuan.tencentcloudapi.com" host := "hunyuan.tencentcloudapi.com"
httpRequestMethod := "POST" httpRequestMethod := "POST"

View File

@@ -35,16 +35,16 @@ type ChatRequest struct {
// 1. 影响输出文本的多样性,取值越大,生成文本的多样性越强。 // 1. 影响输出文本的多样性,取值越大,生成文本的多样性越强。
// 2. 取值区间为 [0.0, 1.0],未传值时使用各模型推荐值。 // 2. 取值区间为 [0.0, 1.0],未传值时使用各模型推荐值。
// 3. 非必要不建议使用,不合理的取值会影响效果。 // 3. 非必要不建议使用,不合理的取值会影响效果。
TopP *float64 `json:"TopP"` TopP *float64 `json:"TopP,omitempty"`
// 说明: // 说明:
// 1. 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定。 // 1. 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定。
// 2. 取值区间为 [0.0, 2.0],未传值时使用各模型推荐值。 // 2. 取值区间为 [0.0, 2.0],未传值时使用各模型推荐值。
// 3. 非必要不建议使用,不合理的取值会影响效果。 // 3. 非必要不建议使用,不合理的取值会影响效果。
Temperature *float64 `json:"Temperature"` Temperature *float64 `json:"Temperature,omitempty"`
} }
type Error struct { type Error struct {
Code int `json:"Code"` Code string `json:"Code"`
Message string `json:"Message"` Message string `json:"Message"`
} }
@@ -61,15 +61,41 @@ type ResponseChoices struct {
} }
type ChatResponse struct { type ChatResponse struct {
Choices []ResponseChoices `json:"Choices,omitempty"` // 结果 Choices []ResponseChoices `json:"Choices,omitempty"` // 结果
Created int64 `json:"Created,omitempty"` // unix 时间戳的字符串 Created int64 `json:"Created,omitempty"` // unix 时间戳的字符串
Id string `json:"Id,omitempty"` // 会话 id Id string `json:"Id,omitempty"` // 会话 id
Usage Usage `json:"Usage,omitempty"` // token 数量 Usage Usage `json:"Usage,omitempty"` // token 数量
Error Error `json:"Error,omitempty"` // 错误信息 注意:此字段可能返回 null表示取不到有效值 Error Error `json:"Error,omitempty"` // 错误信息 注意:此字段可能返回 null表示取不到有效值
Note string `json:"Note,omitempty"` // 注释 Note string `json:"Note,omitempty"` // 注释
ReqID string `json:"Req_id,omitempty"` // 唯一请求 Id每次请求都会返回。用于反馈接口入参 ReqID string `json:"RequestId,omitempty"` // 唯一请求 Id每次请求都会返回。用于反馈接口入参
} }
type ChatResponseP struct { type ChatResponseP struct {
Response ChatResponse `json:"Response,omitempty"` Response ChatResponse `json:"Response,omitempty"`
} }
type EmbeddingRequest struct {
InputList []string `json:"InputList"`
}
type EmbeddingData struct {
Embedding []float64 `json:"Embedding"`
Index int `json:"Index"`
Object string `json:"Object"`
}
type EmbeddingUsage struct {
PromptTokens int `json:"PromptTokens"`
TotalTokens int `json:"TotalTokens"`
}
type EmbeddingResponse struct {
Data []EmbeddingData `json:"Data"`
EmbeddingUsage EmbeddingUsage `json:"Usage,omitempty"`
RequestId string `json:"RequestId,omitempty"`
Error Error `json:"Error,omitempty"`
}
type EmbeddingResponseP struct {
Response EmbeddingResponse `json:"Response,omitempty"`
}

View File

@@ -15,7 +15,11 @@ import (
) )
var ModelList = []string{ var ModelList = []string{
"gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision", "gemini-1.5-pro-002", "gemini-1.5-flash-002", "gemini-pro", "gemini-pro-vision",
"gemini-1.5-pro-001", "gemini-1.5-flash-001",
"gemini-1.5-pro-002", "gemini-1.5-flash-002",
"gemini-2.0-flash-exp",
"gemini-2.0-flash-thinking-exp", "gemini-2.0-flash-thinking-exp-01-21",
} }
type Adaptor struct { type Adaptor struct {

View File

@@ -19,6 +19,7 @@ const (
DeepL DeepL
VertexAI VertexAI
Proxy Proxy
Replicate
Dummy // this one is only for count, do not add any channel after this Dummy // this one is only for count, do not add any channel after this
) )

View File

@@ -3,6 +3,7 @@ package billing
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
) )
@@ -31,8 +32,17 @@ func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQ
} }
// totalQuota is total quota consumed // totalQuota is total quota consumed
if totalQuota != 0 { if totalQuota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) logContent := fmt.Sprintf("倍率%.2f × %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent) model.RecordConsumeLog(ctx, &model.Log{
UserId: userId,
ChannelId: channelId,
PromptTokens: int(totalQuota),
CompletionTokens: 0,
ModelName: modelName,
TokenName: tokenName,
Quota: int(totalQuota),
Content: logContent,
})
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
model.UpdateChannelUsedQuota(channelId, totalQuota) model.UpdateChannelUsedQuota(channelId, totalQuota)
} }

View File

@@ -9,9 +9,10 @@ import (
) )
const ( const (
USD2RMB = 7 USD2RMB = 7
USD = 500 // $0.002 = 1 -> $1 = 500 USD = 500 // $0.002 = 1 -> $1 = 500
RMB = USD / USD2RMB MILLI_USD = 1.0 / 1000 * USD
RMB = USD / USD2RMB
) )
// ModelRatio // ModelRatio
@@ -37,6 +38,7 @@ var ModelRatio = map[string]float64{
"chatgpt-4o-latest": 2.5, // $0.005 / 1K tokens "chatgpt-4o-latest": 2.5, // $0.005 / 1K tokens
"gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens "gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens
"gpt-4o-2024-08-06": 1.25, // $0.0025 / 1K tokens "gpt-4o-2024-08-06": 1.25, // $0.0025 / 1K tokens
"gpt-4o-2024-11-20": 1.25, // $0.0025 / 1K tokens
"gpt-4o-mini": 0.075, // $0.00015 / 1K tokens "gpt-4o-mini": 0.075, // $0.00015 / 1K tokens
"gpt-4o-mini-2024-07-18": 0.075, // $0.00015 / 1K tokens "gpt-4o-mini-2024-07-18": 0.075, // $0.00015 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens "gpt-4-vision-preview": 5, // $0.01 / 1K tokens
@@ -48,8 +50,14 @@ var ModelRatio = map[string]float64{
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens "gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
"gpt-3.5-turbo-0125": 0.25, // $0.0005 / 1K tokens "gpt-3.5-turbo-0125": 0.25, // $0.0005 / 1K tokens
"davinci-002": 1, // $0.002 / 1K tokens "o1": 7.5, // $15.00 / 1M input tokens
"babbage-002": 0.2, // $0.0004 / 1K tokens "o1-2024-12-17": 7.5,
"o1-preview": 7.5, // $15.00 / 1M input tokens
"o1-preview-2024-09-12": 7.5,
"o1-mini": 1.5, // $3.00 / 1M input tokens
"o1-mini-2024-09-12": 1.5,
"davinci-002": 1, // $0.002 / 1K tokens
"babbage-002": 0.2, // $0.0004 / 1K tokens
"text-ada-001": 0.2, "text-ada-001": 0.2,
"text-babbage-001": 0.25, "text-babbage-001": 0.25,
"text-curie-001": 1, "text-curie-001": 1,
@@ -102,11 +110,16 @@ var ModelRatio = map[string]float64{
"bge-large-en": 0.002 * RMB, "bge-large-en": 0.002 * RMB,
"tao-8k": 0.002 * RMB, "tao-8k": 0.002 * RMB,
// https://ai.google.dev/pricing // https://ai.google.dev/pricing
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro": 1, "gemini-1.0-pro": 1,
"gemini-1.5-flash": 1, "gemini-1.5-pro": 1,
"gemini-1.5-pro": 1, "gemini-1.5-pro-001": 1,
"aqa": 1, "gemini-1.5-flash": 1,
"gemini-1.5-flash-001": 1,
"gemini-2.0-flash-exp": 1,
"gemini-2.0-flash-thinking-exp": 1,
"gemini-2.0-flash-thinking-exp-01-21": 1,
"aqa": 1,
// https://open.bigmodel.cn/pricing // https://open.bigmodel.cn/pricing
"glm-4": 0.1 * RMB, "glm-4": 0.1 * RMB,
"glm-4v": 0.1 * RMB, "glm-4v": 0.1 * RMB,
@@ -118,29 +131,94 @@ var ModelRatio = map[string]float64{
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"cogview-3": 0.25 * RMB, "cogview-3": 0.25 * RMB,
// https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
"qwen-turbo": 0.5715, // ¥0.008 / 1k tokens "qwen-turbo": 1.4286, // ¥0.02 / 1k tokens
"qwen-plus": 1.4286, // ¥0.02 / 1k tokens "qwen-turbo-latest": 1.4286,
"qwen-max": 1.4286, // ¥0.02 / 1k tokens "qwen-plus": 1.4286,
"qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens "qwen-plus-latest": 1.4286,
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens "qwen-max": 1.4286,
"ali-stable-diffusion-xl": 8, "qwen-max-latest": 1.4286,
"ali-stable-diffusion-v1.5": 8, "qwen-max-longcontext": 1.4286,
"wanx-v1": 8, "qwen-vl-max": 1.4286,
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens "qwen-vl-max-latest": 1.4286,
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens "qwen-vl-plus": 1.4286,
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens "qwen-vl-plus-latest": 1.4286,
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens "qwen-vl-ocr": 1.4286,
"SparkDesk-v3.1-128K": 1.2858, // ¥0.018 / 1k tokens "qwen-vl-ocr-latest": 1.4286,
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens "qwen-audio-turbo": 1.4286,
"SparkDesk-v3.5-32K": 1.2858, // ¥0.018 / 1k tokens "qwen-math-plus": 1.4286,
"SparkDesk-v4.0": 1.2858, // ¥0.018 / 1k tokens "qwen-math-plus-latest": 1.4286,
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens "qwen-math-turbo": 1.4286,
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens "qwen-math-turbo-latest": 1.4286,
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens "qwen-coder-plus": 1.4286,
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens "qwen-coder-plus-latest": 1.4286,
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 "qwen-coder-turbo": 1.4286,
"ChatStd": 0.01 * RMB, "qwen-coder-turbo-latest": 1.4286,
"ChatPro": 0.1 * RMB, "qwq-32b-preview": 1.4286,
"qwen2.5-72b-instruct": 1.4286,
"qwen2.5-32b-instruct": 1.4286,
"qwen2.5-14b-instruct": 1.4286,
"qwen2.5-7b-instruct": 1.4286,
"qwen2.5-3b-instruct": 1.4286,
"qwen2.5-1.5b-instruct": 1.4286,
"qwen2.5-0.5b-instruct": 1.4286,
"qwen2-72b-instruct": 1.4286,
"qwen2-57b-a14b-instruct": 1.4286,
"qwen2-7b-instruct": 1.4286,
"qwen2-1.5b-instruct": 1.4286,
"qwen2-0.5b-instruct": 1.4286,
"qwen1.5-110b-chat": 1.4286,
"qwen1.5-72b-chat": 1.4286,
"qwen1.5-32b-chat": 1.4286,
"qwen1.5-14b-chat": 1.4286,
"qwen1.5-7b-chat": 1.4286,
"qwen1.5-1.8b-chat": 1.4286,
"qwen1.5-0.5b-chat": 1.4286,
"qwen-72b-chat": 1.4286,
"qwen-14b-chat": 1.4286,
"qwen-7b-chat": 1.4286,
"qwen-1.8b-chat": 1.4286,
"qwen-1.8b-longcontext-chat": 1.4286,
"qwen2-vl-7b-instruct": 1.4286,
"qwen2-vl-2b-instruct": 1.4286,
"qwen-vl-v1": 1.4286,
"qwen-vl-chat-v1": 1.4286,
"qwen2-audio-instruct": 1.4286,
"qwen-audio-chat": 1.4286,
"qwen2.5-math-72b-instruct": 1.4286,
"qwen2.5-math-7b-instruct": 1.4286,
"qwen2.5-math-1.5b-instruct": 1.4286,
"qwen2-math-72b-instruct": 1.4286,
"qwen2-math-7b-instruct": 1.4286,
"qwen2-math-1.5b-instruct": 1.4286,
"qwen2.5-coder-32b-instruct": 1.4286,
"qwen2.5-coder-14b-instruct": 1.4286,
"qwen2.5-coder-7b-instruct": 1.4286,
"qwen2.5-coder-3b-instruct": 1.4286,
"qwen2.5-coder-1.5b-instruct": 1.4286,
"qwen2.5-coder-0.5b-instruct": 1.4286,
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"text-embedding-v3": 0.05,
"text-embedding-v2": 0.05,
"text-embedding-async-v2": 0.05,
"text-embedding-async-v1": 0.05,
"ali-stable-diffusion-xl": 8.00,
"ali-stable-diffusion-v1.5": 8.00,
"wanx-v1": 8.00,
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1-128K": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5-32K": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v4.0": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
"ChatStd": 0.01 * RMB,
"ChatPro": 0.1 * RMB,
// https://platform.moonshot.cn/pricing // https://platform.moonshot.cn/pricing
"moonshot-v1-8k": 0.012 * RMB, "moonshot-v1-8k": 0.012 * RMB,
"moonshot-v1-32k": 0.024 * RMB, "moonshot-v1-32k": 0.024 * RMB,
@@ -203,20 +281,69 @@ var ModelRatio = map[string]float64{
"command-r": 0.5 / 1000 * USD, "command-r": 0.5 / 1000 * USD,
"command-r-plus": 3.0 / 1000 * USD, "command-r-plus": 3.0 / 1000 * USD,
// https://platform.deepseek.com/api-docs/pricing/ // https://platform.deepseek.com/api-docs/pricing/
"deepseek-chat": 1.0 / 1000 * RMB, "deepseek-chat": 0.14 * MILLI_USD,
"deepseek-coder": 1.0 / 1000 * RMB, "deepseek-reasoner": 0.55 * MILLI_USD,
// https://www.deepl.com/pro?cta=header-prices // https://www.deepl.com/pro?cta=header-prices
"deepl-zh": 25.0 / 1000 * USD, "deepl-zh": 25.0 / 1000 * USD,
"deepl-en": 25.0 / 1000 * USD, "deepl-en": 25.0 / 1000 * USD,
"deepl-ja": 25.0 / 1000 * USD, "deepl-ja": 25.0 / 1000 * USD,
// https://console.x.ai/ // https://console.x.ai/
"grok-beta": 5.0 / 1000 * USD, "grok-beta": 5.0 / 1000 * USD,
// replicate charges based on the number of generated images
// https://replicate.com/pricing
"black-forest-labs/flux-1.1-pro": 0.04 * USD,
"black-forest-labs/flux-1.1-pro-ultra": 0.06 * USD,
"black-forest-labs/flux-canny-dev": 0.025 * USD,
"black-forest-labs/flux-canny-pro": 0.05 * USD,
"black-forest-labs/flux-depth-dev": 0.025 * USD,
"black-forest-labs/flux-depth-pro": 0.05 * USD,
"black-forest-labs/flux-dev": 0.025 * USD,
"black-forest-labs/flux-dev-lora": 0.032 * USD,
"black-forest-labs/flux-fill-dev": 0.04 * USD,
"black-forest-labs/flux-fill-pro": 0.05 * USD,
"black-forest-labs/flux-pro": 0.055 * USD,
"black-forest-labs/flux-redux-dev": 0.025 * USD,
"black-forest-labs/flux-redux-schnell": 0.003 * USD,
"black-forest-labs/flux-schnell": 0.003 * USD,
"black-forest-labs/flux-schnell-lora": 0.02 * USD,
"ideogram-ai/ideogram-v2": 0.08 * USD,
"ideogram-ai/ideogram-v2-turbo": 0.05 * USD,
"recraft-ai/recraft-v3": 0.04 * USD,
"recraft-ai/recraft-v3-svg": 0.08 * USD,
"stability-ai/stable-diffusion-3": 0.035 * USD,
"stability-ai/stable-diffusion-3.5-large": 0.065 * USD,
"stability-ai/stable-diffusion-3.5-large-turbo": 0.04 * USD,
"stability-ai/stable-diffusion-3.5-medium": 0.035 * USD,
// replicate chat models
"ibm-granite/granite-20b-code-instruct-8k": 0.100 * USD,
"ibm-granite/granite-3.0-2b-instruct": 0.030 * USD,
"ibm-granite/granite-3.0-8b-instruct": 0.050 * USD,
"ibm-granite/granite-8b-code-instruct-128k": 0.050 * USD,
"meta/llama-2-13b": 0.100 * USD,
"meta/llama-2-13b-chat": 0.100 * USD,
"meta/llama-2-70b": 0.650 * USD,
"meta/llama-2-70b-chat": 0.650 * USD,
"meta/llama-2-7b": 0.050 * USD,
"meta/llama-2-7b-chat": 0.050 * USD,
"meta/meta-llama-3.1-405b-instruct": 9.500 * USD,
"meta/meta-llama-3-70b": 0.650 * USD,
"meta/meta-llama-3-70b-instruct": 0.650 * USD,
"meta/meta-llama-3-8b": 0.050 * USD,
"meta/meta-llama-3-8b-instruct": 0.050 * USD,
"mistralai/mistral-7b-instruct-v0.2": 0.050 * USD,
"mistralai/mistral-7b-v0.1": 0.050 * USD,
"mistralai/mixtral-8x7b-instruct-v0.1": 0.300 * USD,
} }
var CompletionRatio = map[string]float64{ var CompletionRatio = map[string]float64{
// aws llama3 // aws llama3
"llama3-8b-8192(33)": 0.0006 / 0.0003, "llama3-8b-8192(33)": 0.0006 / 0.0003,
"llama3-70b-8192(33)": 0.0035 / 0.00265, "llama3-70b-8192(33)": 0.0035 / 0.00265,
// whisper
"whisper-1": 0, // only count input tokens
// deepseek
"deepseek-chat": 0.28 / 0.14,
"deepseek-reasoner": 2.19 / 0.55,
} }
var ( var (
@@ -334,16 +461,22 @@ func GetCompletionRatio(name string, channelType int) float64 {
return 4.0 / 3.0 return 4.0 / 3.0
} }
if strings.HasPrefix(name, "gpt-4") { if strings.HasPrefix(name, "gpt-4") {
if strings.HasPrefix(name, "gpt-4o-mini") || name == "gpt-4o-2024-08-06" { if strings.HasPrefix(name, "gpt-4o") {
if name == "gpt-4o-2024-05-13" {
return 3
}
return 4 return 4
} }
if strings.HasPrefix(name, "gpt-4-turbo") || if strings.HasPrefix(name, "gpt-4-turbo") ||
strings.HasPrefix(name, "gpt-4o") ||
strings.HasSuffix(name, "preview") { strings.HasSuffix(name, "preview") {
return 3 return 3
} }
return 2 return 2
} }
// including o1, o1-preview, o1-mini
if strings.HasPrefix(name, "o1") {
return 4
}
if name == "chatgpt-4o-latest" { if name == "chatgpt-4o-latest" {
return 3 return 3
} }
@@ -362,6 +495,7 @@ func GetCompletionRatio(name string, channelType int) float64 {
if strings.HasPrefix(name, "deepseek-") { if strings.HasPrefix(name, "deepseek-") {
return 2 return 2
} }
switch name { switch name {
case "llama2-70b-4096": case "llama2-70b-4096":
return 0.8 / 0.64 return 0.8 / 0.64
@@ -377,6 +511,35 @@ func GetCompletionRatio(name string, channelType int) float64 {
return 5 return 5
case "grok-beta": case "grok-beta":
return 3 return 3
// Replicate Models
// https://replicate.com/pricing
case "ibm-granite/granite-20b-code-instruct-8k":
return 5
case "ibm-granite/granite-3.0-2b-instruct":
return 8.333333333333334
case "ibm-granite/granite-3.0-8b-instruct",
"ibm-granite/granite-8b-code-instruct-128k":
return 5
case "meta/llama-2-13b",
"meta/llama-2-13b-chat",
"meta/llama-2-7b",
"meta/llama-2-7b-chat",
"meta/meta-llama-3-8b",
"meta/meta-llama-3-8b-instruct":
return 5
case "meta/llama-2-70b",
"meta/llama-2-70b-chat",
"meta/meta-llama-3-70b",
"meta/meta-llama-3-70b-instruct":
return 2.750 / 0.650 // ≈4.230769
case "meta/meta-llama-3.1-405b-instruct":
return 1
case "mistralai/mistral-7b-instruct-v0.2",
"mistralai/mistral-7b-v0.1":
return 5
case "mistralai/mixtral-8x7b-instruct-v0.1":
return 1.000 / 0.300 // ≈3.333333
} }
return 1 return 1
} }

View File

@@ -47,5 +47,6 @@ const (
Proxy Proxy
SiliconFlow SiliconFlow
XAI XAI
Replicate
Dummy Dummy
) )

View File

@@ -37,6 +37,8 @@ func ToAPIType(channelType int) int {
apiType = apitype.DeepL apiType = apitype.DeepL
case VertextAI: case VertextAI:
apiType = apitype.VertexAI apiType = apitype.VertexAI
case Replicate:
apiType = apitype.Replicate
case Proxy: case Proxy:
apiType = apitype.Proxy apiType = apitype.Proxy
} }

View File

@@ -47,6 +47,7 @@ var ChannelBaseURLs = []string{
"", // 43 "", // 43
"https://api.siliconflow.cn", // 44 "https://api.siliconflow.cn", // 44
"https://api.x.ai", // 45 "https://api.x.ai", // 45
"https://api.replicate.com/v1/models/", // 46
} }
func init() { func init() {

View File

@@ -1,5 +1,6 @@
package role package role
const ( const (
System = "system"
Assistant = "assistant" Assistant = "assistant"
) )

View File

@@ -110,16 +110,9 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}() }()
// map model name // map model name
modelMapping := c.GetString(ctxkey.ModelMapping) modelMapping := c.GetStringMapString(ctxkey.ModelMapping)
if modelMapping != "" { if modelMapping != nil && modelMapping[audioModel] != "" {
modelMap := make(map[string]string) audioModel = modelMapping[audioModel]
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[audioModel] != "" {
audioModel = modelMap[audioModel]
}
} }
baseURL := channeltype.ChannelBaseURLs[channelType] baseURL := channeltype.ChannelBaseURLs[channelType]

View File

@@ -8,7 +8,11 @@ import (
"net/http" "net/http"
"strings" "strings"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/relay/constant/role"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
@@ -90,7 +94,7 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR
return preConsumedQuota, nil return preConsumedQuota, nil
} }
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) { func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64, systemPromptReset bool) {
if usage == nil { if usage == nil {
logger.Error(ctx, "usage is nil, which is unexpected") logger.Error(ctx, "usage is nil, which is unexpected")
return return
@@ -118,8 +122,20 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M
if err != nil { if err != nil {
logger.Error(ctx, "error update user quota cache: "+err.Error()) logger.Error(ctx, "error update user quota cache: "+err.Error())
} }
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) logContent := fmt.Sprintf("倍率%.2f × %.2f × %.2f", modelRatio, groupRatio, completionRatio)
model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) model.RecordConsumeLog(ctx, &model.Log{
UserId: meta.UserId,
ChannelId: meta.ChannelId,
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
ModelName: textRequest.Model,
TokenName: meta.TokenName,
Quota: int(quota),
Content: logContent,
IsStream: meta.IsStream,
ElapsedTime: helper.CalcElapsedTime(meta.StartTime),
SystemPromptReset: systemPromptReset,
})
model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
model.UpdateChannelUsedQuota(meta.ChannelId, quota) model.UpdateChannelUsedQuota(meta.ChannelId, quota)
} }
@@ -142,15 +158,41 @@ func isErrorHappened(meta *meta.Meta, resp *http.Response) bool {
} }
return true return true
} }
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK &&
// replicate return 201 to create a task
resp.StatusCode != http.StatusCreated {
return true return true
} }
if meta.ChannelType == channeltype.DeepL { if meta.ChannelType == channeltype.DeepL {
// skip stream check for deepl // skip stream check for deepl
return false return false
} }
if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") {
if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") &&
// Even if stream mode is enabled, replicate will first return a task info in JSON format,
// requiring the client to request the stream endpoint in the task info
meta.ChannelType != channeltype.Replicate {
return true return true
} }
return false return false
} }
func setSystemPrompt(ctx context.Context, request *relaymodel.GeneralOpenAIRequest, prompt string) (reset bool) {
if prompt == "" {
return false
}
if len(request.Messages) == 0 {
return false
}
if request.Messages[0].Role == role.System {
request.Messages[0].Content = prompt
logger.Infof(ctx, "rewrite system prompt")
return true
}
request.Messages = append([]relaymodel.Message{{
Role: role.System,
Content: prompt,
}}, request.Messages...)
logger.Infof(ctx, "add system prompt")
return true
}

View File

@@ -10,6 +10,7 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
@@ -22,7 +23,7 @@ import (
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
) )
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { func getImageRequest(c *gin.Context, _ int) (*relaymodel.ImageRequest, error) {
imageRequest := &relaymodel.ImageRequest{} imageRequest := &relaymodel.ImageRequest{}
err := common.UnmarshalBodyReusable(c, imageRequest) err := common.UnmarshalBodyReusable(c, imageRequest)
if err != nil { if err != nil {
@@ -65,7 +66,7 @@ func getImageSizeRatio(model string, size string) float64 {
return 1 return 1
} }
func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode { func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta) *relaymodel.ErrorWithStatusCode {
// check prompt length // check prompt length
if imageRequest.Prompt == "" { if imageRequest.Prompt == "" {
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
@@ -150,12 +151,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
} }
adaptor.Init(meta) adaptor.Init(meta)
// these adaptors need to convert the request
switch meta.ChannelType { switch meta.ChannelType {
case channeltype.Ali: case channeltype.Zhipu,
fallthrough channeltype.Ali,
case channeltype.Baidu: channeltype.Replicate,
fallthrough channeltype.Baidu:
case channeltype.Zhipu:
finalRequest, err := adaptor.ConvertImageRequest(imageRequest) finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
@@ -172,7 +173,14 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
ratio := modelRatio * groupRatio ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) var quota int64
switch meta.ChannelType {
case channeltype.Replicate:
// replicate always return 1 image
quota = int64(ratio * imageCostRatio * 1000)
default:
quota = int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
}
if userQuota-quota < 0 { if userQuota-quota < 0 {
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
@@ -186,7 +194,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
} }
defer func(ctx context.Context) { defer func(ctx context.Context) {
if resp != nil && resp.StatusCode != http.StatusOK { if resp != nil &&
resp.StatusCode != http.StatusCreated && // replicate returns 201
resp.StatusCode != http.StatusOK {
return return
} }
@@ -200,8 +210,17 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
} }
if quota != 0 { if quota != 0 {
tokenName := c.GetString(ctxkey.TokenName) tokenName := c.GetString(ctxkey.TokenName)
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) logContent := fmt.Sprintf("倍率%.2f × %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, 0, 0, imageRequest.Model, tokenName, quota, logContent) model.RecordConsumeLog(ctx, &model.Log{
UserId: meta.UserId,
ChannelId: meta.ChannelId,
PromptTokens: 0,
CompletionTokens: 0,
ModelName: imageRequest.Model,
TokenName: tokenName,
Quota: int(quota),
Content: logContent,
})
model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
channelId := c.GetInt(ctxkey.ChannelId) channelId := c.GetInt(ctxkey.ChannelId)
model.UpdateChannelUsedQuota(channelId, quota) model.UpdateChannelUsedQuota(channelId, quota)

View File

@@ -8,6 +8,8 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay" "github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor"
@@ -35,6 +37,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
meta.OriginModelName = textRequest.Model meta.OriginModelName = textRequest.Model
textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping) textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping)
meta.ActualModelName = textRequest.Model meta.ActualModelName = textRequest.Model
// set system prompt if not empty
systemPromptReset := setSystemPrompt(ctx, textRequest, meta.SystemPrompt)
// get model ratio & group ratio // get model ratio & group ratio
modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType) modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType)
groupRatio := billingratio.GetGroupRatio(meta.Group) groupRatio := billingratio.GetGroupRatio(meta.Group)
@@ -79,12 +83,12 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
return respErr return respErr
} }
// post-consume quota // post-consume quota
go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio) go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio, systemPromptReset)
return nil return nil
} }
func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) { func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) {
if meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan { if !config.EnforceIncludeUsage && meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan {
// no need to convert request for openai // no need to convert request for openai
return c.Request.Body, nil return c.Request.Body, nil
} }

View File

@@ -1,12 +1,15 @@
package meta package meta
import ( import (
"strings"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/relaymode" "github.com/songquanpeng/one-api/relay/relaymode"
"strings"
) )
type Meta struct { type Meta struct {
@@ -30,6 +33,8 @@ type Meta struct {
ActualModelName string ActualModelName string
RequestURLPath string RequestURLPath string
PromptTokens int // only for DoResponse PromptTokens int // only for DoResponse
SystemPrompt string
StartTime time.Time
} }
func GetByContext(c *gin.Context) *Meta { func GetByContext(c *gin.Context) *Meta {
@@ -46,6 +51,8 @@ func GetByContext(c *gin.Context) *Meta {
BaseURL: c.GetString(ctxkey.BaseURL), BaseURL: c.GetString(ctxkey.BaseURL),
APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
RequestURLPath: c.Request.URL.String(), RequestURLPath: c.Request.URL.String(),
SystemPrompt: c.GetString(ctxkey.SystemPrompt),
StartTime: time.Now(),
} }
cfg, ok := c.Get(ctxkey.Config) cfg, ok := c.Get(ctxkey.Config)
if ok { if ok {

View File

@@ -9,6 +9,7 @@ import (
func SetRelayRouter(router *gin.Engine) { func SetRelayRouter(router *gin.Engine) {
router.Use(middleware.CORS()) router.Use(middleware.CORS())
router.Use(middleware.GzipDecodeMiddleware())
// https://platform.openai.com/docs/api-reference/introduction // https://platform.openai.com/docs/api-reference/introduction
modelsRouter := router.Group("/v1/models") modelsRouter := router.Group("/v1/models")
modelsRouter.Use(middleware.TokenAuth()) modelsRouter.Use(middleware.TokenAuth())

View File

@@ -31,6 +31,7 @@ export const CHANNEL_OPTIONS = [
{ key: 43, text: 'Proxy', value: 43, color: 'blue' }, { key: 43, text: 'Proxy', value: 43, color: 'blue' },
{ key: 44, text: 'SiliconFlow', value: 44, color: 'blue' }, { key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
{ key: 45, text: 'xAI', value: 45, color: 'blue' }, { key: 45, text: 'xAI', value: 45, color: 'blue' },
{ key: 46, text: 'Replicate', value: 46, color: 'blue' },
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
{ key: 22, text: '知识库FastGPT', value: 22, color: 'blue' }, { key: 22, text: '知识库FastGPT', value: 22, color: 'blue' },
{ key: 21, text: '知识库AI Proxy', value: 21, color: 'purple' }, { key: 21, text: '知识库AI Proxy', value: 21, color: 'purple' },

View File

@@ -43,6 +43,7 @@ const EditChannel = (props) => {
base_url: '', base_url: '',
other: '', other: '',
model_mapping: '', model_mapping: '',
system_prompt: '',
models: [], models: [],
auto_ban: 1, auto_ban: 1,
groups: ['default'] groups: ['default']
@@ -304,163 +305,163 @@ const EditChannel = (props) => {
width={isMobile() ? '100%' : 600} width={isMobile() ? '100%' : 600}
> >
<Spin spinning={loading}> <Spin spinning={loading}>
<div style={{marginTop: 10}}> <div style={{ marginTop: 10 }}>
<Typography.Text strong>类型</Typography.Text> <Typography.Text strong>类型</Typography.Text>
</div> </div>
<Select <Select
name='type' name='type'
required required
optionList={CHANNEL_OPTIONS} optionList={CHANNEL_OPTIONS}
value={inputs.type} value={inputs.type}
onChange={value => handleInputChange('type', value)} onChange={value => handleInputChange('type', value)}
style={{width: '50%'}} style={{ width: '50%' }}
/> />
{ {
inputs.type === 3 && ( inputs.type === 3 && (
<> <>
<div style={{marginTop: 10}}> <div style={{ marginTop: 10 }}>
<Banner type={"warning"} description={ <Banner type={"warning"} description={
<> <>
注意<strong>模型部署名称必须和模型名称保持一致</strong> One API 注意<strong>模型部署名称必须和模型名称保持一致</strong> One API
model model
参数替换为你的部署名称模型名称中的点会被剔除<a target='_blank' 参数替换为你的部署名称模型名称中的点会被剔除<a target='_blank'
href='https://github.com/songquanpeng/one-api/issues/133?notification_referrer_id=NT_kwDOAmJSYrM2NjIwMzI3NDgyOjM5OTk4MDUw#issuecomment-1571602271'>图片演示</a> href='https://github.com/songquanpeng/one-api/issues/133?notification_referrer_id=NT_kwDOAmJSYrM2NjIwMzI3NDgyOjM5OTk4MDUw#issuecomment-1571602271'>图片演示</a>
</> </>
}> }>
</Banner> </Banner>
</div> </div>
<div style={{marginTop: 10}}> <div style={{ marginTop: 10 }}>
<Typography.Text strong>AZURE_OPENAI_ENDPOINT</Typography.Text> <Typography.Text strong>AZURE_OPENAI_ENDPOINT</Typography.Text>
</div> </div>
<Input <Input
label='AZURE_OPENAI_ENDPOINT' label='AZURE_OPENAI_ENDPOINT'
name='azure_base_url' name='azure_base_url'
placeholder={'请输入 AZURE_OPENAI_ENDPOINT例如https://docs-test-001.openai.azure.com'} placeholder={'请输入 AZURE_OPENAI_ENDPOINT例如https://docs-test-001.openai.azure.com'}
onChange={value => { onChange={value => {
handleInputChange('base_url', value) handleInputChange('base_url', value)
}} }}
value={inputs.base_url} value={inputs.base_url}
autoComplete='new-password' autoComplete='new-password'
/> />
<div style={{marginTop: 10}}> <div style={{ marginTop: 10 }}>
<Typography.Text strong>默认 API 版本</Typography.Text> <Typography.Text strong>默认 API 版本</Typography.Text>
</div> </div>
<Input <Input
label='默认 API 版本' label='默认 API 版本'
name='azure_other' name='azure_other'
placeholder={'请输入默认 API 版本例如2024-03-01-preview该配置可以被实际的请求查询参数所覆盖'} placeholder={'请输入默认 API 版本例如2024-03-01-preview该配置可以被实际的请求查询参数所覆盖'}
onChange={value => { onChange={value => {
handleInputChange('other', value) handleInputChange('other', value)
}} }}
value={inputs.other} value={inputs.other}
autoComplete='new-password' autoComplete='new-password'
/> />
</> </>
) )
} }
{ {
inputs.type === 8 && ( inputs.type === 8 && (
<> <>
<div style={{marginTop: 10}}> <div style={{ marginTop: 10 }}>
<Typography.Text strong>Base URL</Typography.Text> <Typography.Text strong>Base URL</Typography.Text>
</div> </div>
<Input <Input
name='base_url' name='base_url'
placeholder={'请输入自定义渠道的 Base URL'} placeholder={'请输入自定义渠道的 Base URL'}
onChange={value => { onChange={value => {
handleInputChange('base_url', value) handleInputChange('base_url', value)
}} }}
value={inputs.base_url} value={inputs.base_url}
autoComplete='new-password' autoComplete='new-password'
/> />
</> </>
) )
} }
<div style={{marginTop: 10}}> <div style={{ marginTop: 10 }}>
<Typography.Text strong>名称</Typography.Text> <Typography.Text strong>名称</Typography.Text>
</div> </div>
<Input <Input
required required
name='name' name='name'
placeholder={'请为渠道命名'} placeholder={'请为渠道命名'}
onChange={value => { onChange={value => {
handleInputChange('name', value) handleInputChange('name', value)
}} }}
value={inputs.name} value={inputs.name}
autoComplete='new-password' autoComplete='new-password'
/> />
<div style={{marginTop: 10}}> <div style={{ marginTop: 10 }}>
<Typography.Text strong>分组</Typography.Text> <Typography.Text strong>分组</Typography.Text>
</div> </div>
<Select <Select
placeholder={'请选择可以使用该渠道的分组'} placeholder={'请选择可以使用该渠道的分组'}
name='groups' name='groups'
required required
multiple multiple
selection selection
allowAdditions allowAdditions
additionLabel={'请在系统设置页面编辑分组倍率以添加新的分组:'} additionLabel={'请在系统设置页面编辑分组倍率以添加新的分组:'}
onChange={value => { onChange={value => {
handleInputChange('groups', value) handleInputChange('groups', value)
}} }}
value={inputs.groups} value={inputs.groups}
autoComplete='new-password' autoComplete='new-password'
optionList={groupOptions} optionList={groupOptions}
/> />
{ {
inputs.type === 18 && ( inputs.type === 18 && (
<> <>
<div style={{marginTop: 10}}> <div style={{ marginTop: 10 }}>
<Typography.Text strong>模型版本</Typography.Text> <Typography.Text strong>模型版本</Typography.Text>
</div> </div>
<Input <Input
name='other' name='other'
placeholder={'请输入星火大模型版本注意是接口地址中的版本号例如v2.1'} placeholder={'请输入星火大模型版本注意是接口地址中的版本号例如v2.1'}
onChange={value => { onChange={value => {
handleInputChange('other', value) handleInputChange('other', value)
}} }}
value={inputs.other} value={inputs.other}
autoComplete='new-password' autoComplete='new-password'
/> />
</> </>
) )
} }
{ {
inputs.type === 21 && ( inputs.type === 21 && (
<> <>
<div style={{marginTop: 10}}> <div style={{ marginTop: 10 }}>
<Typography.Text strong>知识库 ID</Typography.Text> <Typography.Text strong>知识库 ID</Typography.Text>
</div> </div>
<Input <Input
label='知识库 ID' label='知识库 ID'
name='other' name='other'
placeholder={'请输入知识库 ID例如123456'} placeholder={'请输入知识库 ID例如123456'}
onChange={value => { onChange={value => {
handleInputChange('other', value) handleInputChange('other', value)
}} }}
value={inputs.other} value={inputs.other}
autoComplete='new-password' autoComplete='new-password'
/> />
</> </>
) )
} }
<div style={{marginTop: 10}}> <div style={{ marginTop: 10 }}>
<Typography.Text strong>模型</Typography.Text> <Typography.Text strong>模型</Typography.Text>
</div> </div>
<Select <Select
placeholder={'请选择该渠道所支持的模型'} placeholder={'请选择该渠道所支持的模型'}
name='models' name='models'
required required
multiple multiple
selection selection
onChange={value => { onChange={value => {
handleInputChange('models', value) handleInputChange('models', value)
}} }}
value={inputs.models} value={inputs.models}
autoComplete='new-password' autoComplete='new-password'
optionList={modelOptions} optionList={modelOptions}
/> />
<div style={{lineHeight: '40px', marginBottom: '12px'}}> <div style={{ lineHeight: '40px', marginBottom: '12px' }}>
<Space> <Space>
<Button type='primary' onClick={() => { <Button type='primary' onClick={() => {
handleInputChange('models', basicModels); handleInputChange('models', basicModels);
@@ -473,28 +474,41 @@ const EditChannel = (props) => {
}}>清除所有模型</Button> }}>清除所有模型</Button>
</Space> </Space>
<Input <Input
addonAfter={ addonAfter={
<Button type='primary' onClick={addCustomModel}>填入</Button> <Button type='primary' onClick={addCustomModel}>填入</Button>
} }
placeholder='输入自定义模型名称' placeholder='输入自定义模型名称'
value={customModel} value={customModel}
onChange={(value) => { onChange={(value) => {
setCustomModel(value.trim()); setCustomModel(value.trim());
}} }}
/> />
</div> </div>
<div style={{marginTop: 10}}> <div style={{ marginTop: 10 }}>
<Typography.Text strong>模型重定向</Typography.Text> <Typography.Text strong>模型重定向</Typography.Text>
</div> </div>
<TextArea <TextArea
placeholder={`此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`} placeholder={`此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`}
name='model_mapping' name='model_mapping'
onChange={value => { onChange={value => {
handleInputChange('model_mapping', value) handleInputChange('model_mapping', value)
}} }}
autosize autosize
value={inputs.model_mapping} value={inputs.model_mapping}
autoComplete='new-password' autoComplete='new-password'
/>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>系统提示词</Typography.Text>
</div>
<TextArea
placeholder={`此项可选,用于强制设置给定的系统提示词,请配合自定义模型 & 模型重定向使用,首先创建一个唯一的自定义模型名称并在上面填入,之后将该自定义模型重定向映射到该渠道一个原生支持的模型`}
name='system_prompt'
onChange={value => {
handleInputChange('system_prompt', value)
}}
autosize
value={inputs.system_prompt}
autoComplete='new-password'
/> />
<Typography.Text style={{ <Typography.Text style={{
color: 'rgba(var(--semi-blue-5), 1)', color: 'rgba(var(--semi-blue-5), 1)',
@@ -507,116 +521,116 @@ const EditChannel = (props) => {
}> }>
填入模板 填入模板
</Typography.Text> </Typography.Text>
<div style={{marginTop: 10}}> <div style={{ marginTop: 10 }}>
<Typography.Text strong>密钥</Typography.Text> <Typography.Text strong>密钥</Typography.Text>
</div> </div>
{ {
batch ? batch ?
<TextArea <TextArea
label='密钥' label='密钥'
name='key' name='key'
required required
placeholder={'请输入密钥,一行一个'} placeholder={'请输入密钥,一行一个'}
onChange={value => { onChange={value => {
handleInputChange('key', value) handleInputChange('key', value)
}} }}
value={inputs.key} value={inputs.key}
style={{minHeight: 150, fontFamily: 'JetBrains Mono, Consolas'}} style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
autoComplete='new-password' autoComplete='new-password'
/> />
: :
<Input <Input
label='密钥' label='密钥'
name='key' name='key'
required required
placeholder={type2secretPrompt(inputs.type)} placeholder={type2secretPrompt(inputs.type)}
onChange={value => { onChange={value => {
handleInputChange('key', value) handleInputChange('key', value)
}} }}
value={inputs.key} value={inputs.key}
autoComplete='new-password' autoComplete='new-password'
/> />
} }
<div style={{marginTop: 10}}> <div style={{ marginTop: 10 }}>
<Typography.Text strong>组织</Typography.Text> <Typography.Text strong>组织</Typography.Text>
</div> </div>
<Input <Input
label='组织,可选,不填则为默认组织' label='组织,可选,不填则为默认组织'
name='openai_organization' name='openai_organization'
placeholder='请输入组织org-xxx' placeholder='请输入组织org-xxx'
onChange={value => { onChange={value => {
handleInputChange('openai_organization', value) handleInputChange('openai_organization', value)
}} }}
value={inputs.openai_organization} value={inputs.openai_organization}
/> />
<div style={{marginTop: 10, display: 'flex'}}> <div style={{ marginTop: 10, display: 'flex' }}>
<Space> <Space>
<Checkbox <Checkbox
name='auto_ban' name='auto_ban'
checked={autoBan} checked={autoBan}
onChange={ onChange={
() => { () => {
setAutoBan(!autoBan); setAutoBan(!autoBan);
} }
} }
// onChange={handleInputChange} // onChange={handleInputChange}
/> />
<Typography.Text <Typography.Text
strong>是否自动禁用仅当自动禁用开启时有效关闭后不会自动禁用该渠道</Typography.Text> strong>是否自动禁用仅当自动禁用开启时有效关闭后不会自动禁用该渠道</Typography.Text>
</Space> </Space>
</div> </div>
{ {
!isEdit && ( !isEdit && (
<div style={{marginTop: 10, display: 'flex'}}> <div style={{ marginTop: 10, display: 'flex' }}>
<Space> <Space>
<Checkbox <Checkbox
checked={batch} checked={batch}
label='批量创建' label='批量创建'
name='batch' name='batch'
onChange={() => setBatch(!batch)} onChange={() => setBatch(!batch)}
/> />
<Typography.Text strong>批量创建</Typography.Text> <Typography.Text strong>批量创建</Typography.Text>
</Space> </Space>
</div>
)
}
{
inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && (
<>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>代理</Typography.Text>
</div> </div>
) <Input
label='代理'
name='base_url'
placeholder={'此项可选,用于通过代理站来进行 API 调用'}
onChange={value => {
handleInputChange('base_url', value)
}}
value={inputs.base_url}
autoComplete='new-password'
/>
</>
)
} }
{ {
inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && ( inputs.type === 22 && (
<> <>
<div style={{marginTop: 10}}> <div style={{ marginTop: 10 }}>
<Typography.Text strong>代理</Typography.Text> <Typography.Text strong>私有部署地址</Typography.Text>
</div> </div>
<Input <Input
label='代理' name='base_url'
name='base_url' placeholder={'请输入私有部署地址格式为https://fastgpt.run/api/openapi'}
placeholder={'此项可选,用于通过代理站来进行 API 调用'} onChange={value => {
onChange={value => { handleInputChange('base_url', value)
handleInputChange('base_url', value) }}
}} value={inputs.base_url}
value={inputs.base_url} autoComplete='new-password'
autoComplete='new-password' />
/> </>
</> )
)
}
{
inputs.type === 22 && (
<>
<div style={{marginTop: 10}}>
<Typography.Text strong>私有部署地址</Typography.Text>
</div>
<Input
name='base_url'
placeholder={'请输入私有部署地址格式为https://fastgpt.run/api/openapi'}
onChange={value => {
handleInputChange('base_url', value)
}}
value={inputs.base_url}
autoComplete='new-password'
/>
</>
)
} }
</Spin> </Spin>

View File

@@ -185,6 +185,12 @@ export const CHANNEL_OPTIONS = {
value: 45, value: 45,
color: 'primary' color: 'primary'
}, },
45: {
key: 46,
text: 'Replicate',
value: 46,
color: 'primary'
},
41: { 41: {
key: 41, key: 41,
text: 'Novita', text: 'Novita',

View File

@@ -1,247 +1,260 @@
import { enqueueSnackbar } from 'notistack'; import {enqueueSnackbar} from 'notistack';
import { snackbarConstants } from 'constants/SnackbarConstants'; import {snackbarConstants} from 'constants/SnackbarConstants';
import { API } from './api'; import {API} from './api';
export function getSystemName() { export function getSystemName() {
let system_name = localStorage.getItem('system_name'); let system_name = localStorage.getItem('system_name');
if (!system_name) return 'One API'; if (!system_name) return 'One API';
return system_name; return system_name;
} }
export function isMobile() { export function isMobile() {
return window.innerWidth <= 600; return window.innerWidth <= 600;
} }
// eslint-disable-next-line // eslint-disable-next-line
export function SnackbarHTMLContent({ htmlContent }) { export function SnackbarHTMLContent({htmlContent}) {
return <div dangerouslySetInnerHTML={{ __html: htmlContent }} />; return <div dangerouslySetInnerHTML={{__html: htmlContent}}/>;
} }
export function getSnackbarOptions(variant) { export function getSnackbarOptions(variant) {
let options = snackbarConstants.Common[variant]; let options = snackbarConstants.Common[variant];
if (isMobile()) { if (isMobile()) {
// 合并 options 和 snackbarConstants.Mobile // 合并 options 和 snackbarConstants.Mobile
options = { ...options, ...snackbarConstants.Mobile }; options = {...options, ...snackbarConstants.Mobile};
} }
return options; return options;
} }
export function showError(error) { export function showError(error) {
if (error.message) { if (error.message) {
if (error.name === 'AxiosError') { if (error.name === 'AxiosError') {
switch (error.response.status) { switch (error.response.status) {
case 429: case 429:
enqueueSnackbar('错误:请求次数过多,请稍后再试!', getSnackbarOptions('ERROR')); enqueueSnackbar('错误:请求次数过多,请稍后再试!', getSnackbarOptions('ERROR'));
break; break;
case 500: case 500:
enqueueSnackbar('错误:服务器内部错误,请联系管理员!', getSnackbarOptions('ERROR')); enqueueSnackbar('错误:服务器内部错误,请联系管理员!', getSnackbarOptions('ERROR'));
break; break;
case 405: case 405:
enqueueSnackbar('本站仅作演示之用,无服务端!', getSnackbarOptions('INFO')); enqueueSnackbar('本站仅作演示之用,无服务端!', getSnackbarOptions('INFO'));
break; break;
default: default:
enqueueSnackbar('错误:' + error.message, getSnackbarOptions('ERROR')); enqueueSnackbar('错误:' + error.message, getSnackbarOptions('ERROR'));
} }
return; return;
}
} else {
enqueueSnackbar('错误:' + error, getSnackbarOptions('ERROR'));
} }
} else {
enqueueSnackbar('错误:' + error, getSnackbarOptions('ERROR'));
}
} }
export function showNotice(message, isHTML = false) { export function showNotice(message, isHTML = false) {
if (isHTML) { if (isHTML) {
enqueueSnackbar(<SnackbarHTMLContent htmlContent={message} />, getSnackbarOptions('NOTICE')); enqueueSnackbar(<SnackbarHTMLContent htmlContent={message}/>, getSnackbarOptions('NOTICE'));
} else { } else {
enqueueSnackbar(message, getSnackbarOptions('NOTICE')); enqueueSnackbar(message, getSnackbarOptions('NOTICE'));
} }
} }
export function showWarning(message) { export function showWarning(message) {
enqueueSnackbar(message, getSnackbarOptions('WARNING')); enqueueSnackbar(message, getSnackbarOptions('WARNING'));
} }
export function showSuccess(message) { export function showSuccess(message) {
enqueueSnackbar(message, getSnackbarOptions('SUCCESS')); enqueueSnackbar(message, getSnackbarOptions('SUCCESS'));
} }
export function showInfo(message) { export function showInfo(message) {
enqueueSnackbar(message, getSnackbarOptions('INFO')); enqueueSnackbar(message, getSnackbarOptions('INFO'));
} }
export async function getOAuthState() { export async function getOAuthState() {
const res = await API.get('/api/oauth/state'); const res = await API.get('/api/oauth/state');
const { success, message, data } = res.data; const {success, message, data} = res.data;
if (success) { if (success) {
return data; return data;
} else { } else {
showError(message); showError(message);
return ''; return '';
} }
} }
export async function onGitHubOAuthClicked(github_client_id, openInNewTab = false) { export async function onGitHubOAuthClicked(github_client_id, openInNewTab = false) {
const state = await getOAuthState(); const state = await getOAuthState();
if (!state) return; if (!state) return;
let url = `https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`; let url = `https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`;
if (openInNewTab) { if (openInNewTab) {
window.open(url); window.open(url);
} else { } else {
window.location.href = url; window.location.href = url;
} }
} }
export async function onLarkOAuthClicked(lark_client_id) { export async function onLarkOAuthClicked(lark_client_id) {
const state = await getOAuthState(); const state = await getOAuthState();
if (!state) return; if (!state) return;
let redirect_uri = `${window.location.origin}/oauth/lark`; let redirect_uri = `${window.location.origin}/oauth/lark`;
window.open(`https://open.feishu.cn/open-apis/authen/v1/index?redirect_uri=${redirect_uri}&app_id=${lark_client_id}&state=${state}`); window.open(`https://accounts.feishu.cn/open-apis/authen/v1/authorize?redirect_uri=${redirect_uri}&client_id=${lark_client_id}&state=${state}`);
} }
export async function onOidcClicked(auth_url, client_id, openInNewTab = false) { export async function onOidcClicked(auth_url, client_id, openInNewTab = false) {
const state = await getOAuthState(); const state = await getOAuthState();
if (!state) return; if (!state) return;
const redirect_uri = `${window.location.origin}/oauth/oidc`; const redirect_uri = `${window.location.origin}/oauth/oidc`;
const response_type = "code"; const response_type = "code";
const scope = "openid profile email"; const scope = "openid profile email";
const url = `${auth_url}?client_id=${client_id}&redirect_uri=${redirect_uri}&response_type=${response_type}&scope=${scope}&state=${state}`; const url = `${auth_url}?client_id=${client_id}&redirect_uri=${redirect_uri}&response_type=${response_type}&scope=${scope}&state=${state}`;
if (openInNewTab) { if (openInNewTab) {
window.open(url); window.open(url);
} else } else {
{ window.location.href = url;
window.location.href = url; }
}
} }
export function isAdmin() { export function isAdmin() {
let user = localStorage.getItem('user'); let user = localStorage.getItem('user');
if (!user) return false; if (!user) return false;
user = JSON.parse(user); user = JSON.parse(user);
return user.role >= 10; return user.role >= 10;
} }
export function timestamp2string(timestamp) { export function timestamp2string(timestamp) {
let date = new Date(timestamp * 1000); let date = new Date(timestamp * 1000);
let year = date.getFullYear().toString(); let year = date.getFullYear().toString();
let month = (date.getMonth() + 1).toString(); let month = (date.getMonth() + 1).toString();
let day = date.getDate().toString(); let day = date.getDate().toString();
let hour = date.getHours().toString(); let hour = date.getHours().toString();
let minute = date.getMinutes().toString(); let minute = date.getMinutes().toString();
let second = date.getSeconds().toString(); let second = date.getSeconds().toString();
if (month.length === 1) { if (month.length === 1) {
month = '0' + month; month = '0' + month;
} }
if (day.length === 1) { if (day.length === 1) {
day = '0' + day; day = '0' + day;
} }
if (hour.length === 1) { if (hour.length === 1) {
hour = '0' + hour; hour = '0' + hour;
} }
if (minute.length === 1) { if (minute.length === 1) {
minute = '0' + minute; minute = '0' + minute;
} }
if (second.length === 1) { if (second.length === 1) {
second = '0' + second; second = '0' + second;
} }
return year + '-' + month + '-' + day + ' ' + hour + ':' + minute + ':' + second; return year + '-' + month + '-' + day + ' ' + hour + ':' + minute + ':' + second;
} }
export function calculateQuota(quota, digits = 2) { export function calculateQuota(quota, digits = 2) {
let quotaPerUnit = localStorage.getItem('quota_per_unit'); let quotaPerUnit = localStorage.getItem('quota_per_unit');
quotaPerUnit = parseFloat(quotaPerUnit); quotaPerUnit = parseFloat(quotaPerUnit);
return (quota / quotaPerUnit).toFixed(digits); return (quota / quotaPerUnit).toFixed(digits);
} }
export function renderQuota(quota, digits = 2) { export function renderQuota(quota, digits = 2) {
let displayInCurrency = localStorage.getItem('display_in_currency'); let displayInCurrency = localStorage.getItem('display_in_currency');
displayInCurrency = displayInCurrency === 'true'; displayInCurrency = displayInCurrency === 'true';
if (displayInCurrency) { if (displayInCurrency) {
return '$' + calculateQuota(quota, digits); return '$' + calculateQuota(quota, digits);
} }
return renderNumber(quota); return renderNumber(quota);
} }
export const verifyJSON = (str) => { export const verifyJSON = (str) => {
try { try {
JSON.parse(str); JSON.parse(str);
} catch (e) { } catch (e) {
return false; return false;
} }
return true; return true;
}; };
export function renderNumber(num) { export function renderNumber(num) {
if (num >= 1000000000) { if (num >= 1000000000) {
return (num / 1000000000).toFixed(1) + 'B'; return (num / 1000000000).toFixed(1) + 'B';
} else if (num >= 1000000) { } else if (num >= 1000000) {
return (num / 1000000).toFixed(1) + 'M'; return (num / 1000000).toFixed(1) + 'M';
} else if (num >= 10000) { } else if (num >= 10000) {
return (num / 1000).toFixed(1) + 'k'; return (num / 1000).toFixed(1) + 'k';
} else { } else {
return num; return num;
} }
} }
export function renderQuotaWithPrompt(quota, digits) { export function renderQuotaWithPrompt(quota, digits) {
let displayInCurrency = localStorage.getItem('display_in_currency'); let displayInCurrency = localStorage.getItem('display_in_currency');
displayInCurrency = displayInCurrency === 'true'; displayInCurrency = displayInCurrency === 'true';
if (displayInCurrency) { if (displayInCurrency) {
return `(等价金额:${renderQuota(quota, digits)}`; return `(等价金额:${renderQuota(quota, digits)}`;
} }
return ''; return '';
} }
export function downloadTextAsFile(text, filename) { export function downloadTextAsFile(text, filename) {
let blob = new Blob([text], { type: 'text/plain;charset=utf-8' }); let blob = new Blob([text], {type: 'text/plain;charset=utf-8'});
let url = URL.createObjectURL(blob); let url = URL.createObjectURL(blob);
let a = document.createElement('a'); let a = document.createElement('a');
a.href = url; a.href = url;
a.download = filename; a.download = filename;
a.click(); a.click();
} }
export function removeTrailingSlash(url) { export function removeTrailingSlash(url) {
if (url.endsWith('/')) { if (url.endsWith('/')) {
return url.slice(0, -1); return url.slice(0, -1);
} else { } else {
return url; return url;
} }
} }
let channelModels = undefined; let channelModels = undefined;
export async function loadChannelModels() { export async function loadChannelModels() {
const res = await API.get('/api/models'); const res = await API.get('/api/models');
const { success, data } = res.data; const {success, data} = res.data;
if (!success) { if (!success) {
return; return;
} }
channelModels = data; channelModels = data;
localStorage.setItem('channel_models', JSON.stringify(data)); localStorage.setItem('channel_models', JSON.stringify(data));
} }
export function getChannelModels(type) { export function getChannelModels(type) {
if (channelModels !== undefined && type in channelModels) { if (channelModels !== undefined && type in channelModels) {
return channelModels[type]; return channelModels[type];
} }
let models = localStorage.getItem('channel_models'); let models = localStorage.getItem('channel_models');
if (!models) { if (!models) {
return [];
}
channelModels = JSON.parse(models);
if (type in channelModels) {
return channelModels[type];
}
return []; return [];
}
channelModels = JSON.parse(models);
if (type in channelModels) {
return channelModels[type];
}
return [];
} }
export function copy(text, name = '') { export function copy(text, name = '') {
try { if (navigator.clipboard && navigator.clipboard.writeText) {
navigator.clipboard.writeText(text); navigator.clipboard.writeText(text).then(() => {
} catch (error) { showNotice(`复制${name}成功!`, true);
text = `复制${name}失败,请手动复制:<br /><br />${text}`; }, () => {
enqueueSnackbar(<SnackbarHTMLContent htmlContent={text} />, getSnackbarOptions('COPY')); text = `复制${name}失败,请手动复制:<br /><br />${text}`;
return; enqueueSnackbar(<SnackbarHTMLContent htmlContent={text}/>, getSnackbarOptions('COPY'));
} });
showSuccess(`复制${name}成功!`); } else {
const textArea = document.createElement("textarea");
textArea.value = text;
document.body.appendChild(textArea);
textArea.select();
try {
document.execCommand('copy');
showNotice(`复制${name}成功!`, true);
} catch (err) {
text = `复制${name}失败,请手动复制:<br /><br />${text}`;
enqueueSnackbar(<SnackbarHTMLContent htmlContent={text}/>, getSnackbarOptions('COPY'));
}
document.body.removeChild(textArea);
}
} }

View File

@@ -595,6 +595,28 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
<FormHelperText id="helper-tex-channel-model_mapping-label"> {inputPrompt.model_mapping} </FormHelperText> <FormHelperText id="helper-tex-channel-model_mapping-label"> {inputPrompt.model_mapping} </FormHelperText>
)} )}
</FormControl> </FormControl>
<FormControl fullWidth error={Boolean(touched.system_prompt && errors.system_prompt)} sx={{ ...theme.typography.otherInput }}>
{/* <InputLabel htmlFor="channel-model_mapping-label">{inputLabel.model_mapping}</InputLabel> */}
<TextField
multiline
id="channel-system_prompt-label"
label={inputLabel.system_prompt}
value={values.system_prompt}
name="system_prompt"
onBlur={handleBlur}
onChange={handleChange}
aria-describedby="helper-text-channel-system_prompt-label"
minRows={5}
placeholder={inputPrompt.system_prompt}
/>
{touched.system_prompt && errors.system_prompt ? (
<FormHelperText error id="helper-tex-channel-system_prompt-label">
{errors.system_prompt}
</FormHelperText>
) : (
<FormHelperText id="helper-tex-channel-system_prompt-label"> {inputPrompt.system_prompt} </FormHelperText>
)}
</FormControl>
<DialogActions> <DialogActions>
<Button onClick={onCancel}>取消</Button> <Button onClick={onCancel}>取消</Button>
<Button disableElevation disabled={isSubmitting} type="submit" variant="contained" color="primary"> <Button disableElevation disabled={isSubmitting} type="submit" variant="contained" color="primary">

View File

@@ -268,6 +268,8 @@ function renderBalance(type, balance) {
return <span>¥{balance.toFixed(2)}</span>; return <span>¥{balance.toFixed(2)}</span>;
case 13: // AIGC2D case 13: // AIGC2D
return <span>{renderNumber(balance)}</span>; return <span>{renderNumber(balance)}</span>;
case 36: // DeepSeek
return <span>¥{balance.toFixed(2)}</span>;
case 44: // SiliconFlow case 44: // SiliconFlow
return <span>¥{balance.toFixed(2)}</span>; return <span>¥{balance.toFixed(2)}</span>;
default: default:

View File

@@ -18,6 +18,7 @@ const defaultConfig = {
other: '其他参数', other: '其他参数',
models: '模型', models: '模型',
model_mapping: '模型映射关系', model_mapping: '模型映射关系',
system_prompt: '系统提示词',
groups: '用户组', groups: '用户组',
config: null config: null
}, },
@@ -30,6 +31,7 @@ const defaultConfig = {
models: '请选择该渠道所支持的模型', models: '请选择该渠道所支持的模型',
model_mapping: model_mapping:
'请输入要修改的模型映射关系格式为api请求模型ID:实际转发给渠道的模型ID使用JSON数组表示例如{"gpt-3.5": "gpt-35"}', '请输入要修改的模型映射关系格式为api请求模型ID:实际转发给渠道的模型ID使用JSON数组表示例如{"gpt-3.5": "gpt-35"}',
system_prompt:"此项可选,用于强制设置给定的系统提示词,请配合自定义模型 & 模型重定向使用,首先创建一个唯一的自定义模型名称并在上面填入,之后将该自定义模型重定向映射到该渠道一个原生支持的模型此项可选,用于强制设置给定的系统提示词,请配合自定义模型 & 模型重定向使用,首先创建一个唯一的自定义模型名称并在上面填入,之后将该自定义模型重定向映射到该渠道一个原生支持的模型",
groups: '请选择该渠道所支持的用户组', groups: '请选择该渠道所支持的用户组',
config: null config: null
}, },

0
web/build.sh Normal file → Executable file
View File

View File

@@ -1,5 +1,15 @@
import React, { useEffect, useState } from 'react'; import React, { useEffect, useState } from 'react';
import { Button, Dropdown, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react'; import {
Button,
Dropdown,
Form,
Input,
Label,
Message,
Pagination,
Popup,
Table,
} from 'semantic-ui-react';
import { Link } from 'react-router-dom'; import { Link } from 'react-router-dom';
import { import {
API, API,
@@ -9,31 +19,31 @@ import {
showError, showError,
showInfo, showInfo,
showSuccess, showSuccess,
timestamp2string timestamp2string,
} from '../helpers'; } from '../helpers';
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
import { renderGroup, renderNumber } from '../helpers/render'; import { renderGroup, renderNumber } from '../helpers/render';
function renderTimestamp(timestamp) { function renderTimestamp(timestamp) {
return ( return <>{timestamp2string(timestamp)}</>;
<>
{timestamp2string(timestamp)}
</>
);
} }
let type2label = undefined; let type2label = undefined;
function renderType(type) { function renderType(type) {
if (!type2label) { if (!type2label) {
type2label = new Map; type2label = new Map();
for (let i = 0; i < CHANNEL_OPTIONS.length; i++) { for (let i = 0; i < CHANNEL_OPTIONS.length; i++) {
type2label[CHANNEL_OPTIONS[i].value] = CHANNEL_OPTIONS[i]; type2label[CHANNEL_OPTIONS[i].value] = CHANNEL_OPTIONS[i];
} }
type2label[0] = { value: 0, text: '未知类型', color: 'grey' }; type2label[0] = { value: 0, text: '未知类型', color: 'grey' };
} }
return <Label basic color={type2label[type]?.color}>{type2label[type] ? type2label[type].text : type}</Label>; return (
<Label basic color={type2label[type]?.color}>
{type2label[type] ? type2label[type].text : type}
</Label>
);
} }
function renderBalance(type, balance) { function renderBalance(type, balance) {
@@ -52,6 +62,8 @@ function renderBalance(type, balance) {
return <span>¥{balance.toFixed(2)}</span>; return <span>¥{balance.toFixed(2)}</span>;
case 13: // AIGC2D case 13: // AIGC2D
return <span>{renderNumber(balance)}</span>; return <span>{renderNumber(balance)}</span>;
case 36: // DeepSeek
return <span>¥{balance.toFixed(2)}</span>;
case 44: // SiliconFlow case 44: // SiliconFlow
return <span>¥{balance.toFixed(2)}</span>; return <span>¥{balance.toFixed(2)}</span>;
default: default:
@@ -60,10 +72,10 @@ function renderBalance(type, balance) {
} }
function isShowDetail() { function isShowDetail() {
return localStorage.getItem("show_detail") === "true"; return localStorage.getItem('show_detail') === 'true';
} }
const promptID = "detail" const promptID = 'detail';
const ChannelsTable = () => { const ChannelsTable = () => {
const [channels, setChannels] = useState([]); const [channels, setChannels] = useState([]);
@@ -79,33 +91,37 @@ const ChannelsTable = () => {
const res = await API.get(`/api/channel/?p=${startIdx}`); const res = await API.get(`/api/channel/?p=${startIdx}`);
const { success, message, data } = res.data; const { success, message, data } = res.data;
if (success) { if (success) {
let localChannels = data.map((channel) => { let localChannels = data.map((channel) => {
if (channel.models === '') { if (channel.models === '') {
channel.models = []; channel.models = [];
channel.test_model = ""; channel.test_model = '';
} else {
channel.models = channel.models.split(',');
if (channel.models.length > 0) {
channel.test_model = channel.models[0];
}
channel.model_options = channel.models.map((model) => {
return {
key: model,
text: model,
value: model,
}
})
console.log('channel', channel)
}
return channel;
});
if (startIdx === 0) {
setChannels(localChannels);
} else { } else {
let newChannels = [...channels]; channel.models = channel.models.split(',');
newChannels.splice(startIdx * ITEMS_PER_PAGE, data.length, ...localChannels); if (channel.models.length > 0) {
setChannels(newChannels); channel.test_model = channel.models[0];
}
channel.model_options = channel.models.map((model) => {
return {
key: model,
text: model,
value: model,
};
});
console.log('channel', channel);
} }
return channel;
});
if (startIdx === 0) {
setChannels(localChannels);
} else {
let newChannels = [...channels];
newChannels.splice(
startIdx * ITEMS_PER_PAGE,
data.length,
...localChannels
);
setChannels(newChannels);
}
} else { } else {
showError(message); showError(message);
} }
@@ -129,8 +145,8 @@ const ChannelsTable = () => {
const toggleShowDetail = () => { const toggleShowDetail = () => {
setShowDetail(!showDetail); setShowDetail(!showDetail);
localStorage.setItem("show_detail", (!showDetail).toString()); localStorage.setItem('show_detail', (!showDetail).toString());
} };
useEffect(() => { useEffect(() => {
loadChannels(0) loadChannels(0)
@@ -194,13 +210,19 @@ const ChannelsTable = () => {
const renderStatus = (status) => { const renderStatus = (status) => {
switch (status) { switch (status) {
case 1: case 1:
return <Label basic color='green'>已启用</Label>; return (
<Label basic color='green'>
已启用
</Label>
);
case 2: case 2:
return ( return (
<Popup <Popup
trigger={<Label basic color='red'> trigger={
已禁用 <Label basic color='red'>
</Label>} 已禁用
</Label>
}
content='本渠道被手动禁用' content='本渠道被手动禁用'
basic basic
/> />
@@ -208,9 +230,11 @@ const ChannelsTable = () => {
case 3: case 3:
return ( return (
<Popup <Popup
trigger={<Label basic color='yellow'> trigger={
已禁用 <Label basic color='yellow'>
</Label>} 已禁用
</Label>
}
content='本渠道被程序自动禁用' content='本渠道被程序自动禁用'
basic basic
/> />
@@ -228,15 +252,35 @@ const ChannelsTable = () => {
let time = responseTime / 1000; let time = responseTime / 1000;
time = time.toFixed(2) + ' 秒'; time = time.toFixed(2) + ' 秒';
if (responseTime === 0) { if (responseTime === 0) {
return <Label basic color='grey'>未测试</Label>; return (
<Label basic color='grey'>
未测试
</Label>
);
} else if (responseTime <= 1000) { } else if (responseTime <= 1000) {
return <Label basic color='green'>{time}</Label>; return (
<Label basic color='green'>
{time}
</Label>
);
} else if (responseTime <= 3000) { } else if (responseTime <= 3000) {
return <Label basic color='olive'>{time}</Label>; return (
<Label basic color='olive'>
{time}
</Label>
);
} else if (responseTime <= 5000) { } else if (responseTime <= 5000) {
return <Label basic color='yellow'>{time}</Label>; return (
<Label basic color='yellow'>
{time}
</Label>
);
} else { } else {
return <Label basic color='red'>{time}</Label>; return (
<Label basic color='red'>
{time}
</Label>
);
} }
}; };
@@ -275,7 +319,11 @@ const ChannelsTable = () => {
newChannels[realIdx].response_time = time * 1000; newChannels[realIdx].response_time = time * 1000;
newChannels[realIdx].test_time = Date.now() / 1000; newChannels[realIdx].test_time = Date.now() / 1000;
setChannels(newChannels); setChannels(newChannels);
showInfo(`渠道 ${name} 测试成功,模型 ${model},耗时 ${time.toFixed(2)} 秒。`); showInfo(
`渠道 ${name} 测试成功,模型 ${model},耗时 ${time.toFixed(
2
)} 秒,模型输出:${message}`
);
} else { } else {
showError(message); showError(message);
} }
@@ -358,7 +406,6 @@ const ChannelsTable = () => {
setLoading(false); setLoading(false);
}; };
return ( return (
<> <>
<Form onSubmit={searchChannels}> <Form onSubmit={searchChannels}>
@@ -372,20 +419,22 @@ const ChannelsTable = () => {
onChange={handleKeywordChange} onChange={handleKeywordChange}
/> />
</Form> </Form>
{ {showPrompt && (
showPrompt && ( <Message
<Message onDismiss={() => { onDismiss={() => {
setShowPrompt(false); setShowPrompt(false);
setPromptShown(promptID); setPromptShown(promptID);
}}> }}
OpenAI 渠道已经不再支持通过 key 获取余额因此余额显示为 0对于支持的渠道类型请点击余额进行刷新 >
<br/> OpenAI 渠道已经不再支持通过 key 获取余额因此余额显示为
渠道测试仅支持 chat 模型优先使用 gpt-3.5-turbo如果该模型不可用则使用你所配置的模型列表中的第一个模型 0对于支持的渠道类型请点击余额进行刷新
<br/> <br />
点击下方详情按钮可以显示余额以及设置额外的测试模型 渠道测试仅支持 chat 模型优先使用
</Message> gpt-3.5-turbo如果该模型不可用则使用你所配置的模型列表中的第一个模型
) <br />
} 点击下方详情按钮可以显示余额以及设置额外的测试模型
</Message>
)}
<Table basic compact size='small'> <Table basic compact size='small'>
<Table.Header> <Table.Header>
<Table.Row> <Table.Row>
@@ -476,7 +525,11 @@ const ChannelsTable = () => {
<Table.Cell>{renderStatus(channel.status)}</Table.Cell> <Table.Cell>{renderStatus(channel.status)}</Table.Cell>
<Table.Cell> <Table.Cell>
<Popup <Popup
content={channel.test_time ? renderTimestamp(channel.test_time) : '未测试'} content={
channel.test_time
? renderTimestamp(channel.test_time)
: '未测试'
}
key={channel.id} key={channel.id}
trigger={renderResponseTime(channel.response_time)} trigger={renderResponseTime(channel.response_time)}
basic basic
@@ -484,27 +537,38 @@ const ChannelsTable = () => {
</Table.Cell> </Table.Cell>
<Table.Cell hidden={!showDetail}> <Table.Cell hidden={!showDetail}>
<Popup <Popup
trigger={<span onClick={() => { trigger={
updateChannelBalance(channel.id, channel.name, idx); <span
}} style={{ cursor: 'pointer' }}> onClick={() => {
{renderBalance(channel.type, channel.balance)} updateChannelBalance(channel.id, channel.name, idx);
</span>} }}
style={{ cursor: 'pointer' }}
>
{renderBalance(channel.type, channel.balance)}
</span>
}
content='点击更新' content='点击更新'
basic basic
/> />
</Table.Cell> </Table.Cell>
<Table.Cell> <Table.Cell>
<Popup <Popup
trigger={<Input type='number' defaultValue={channel.priority} onBlur={(event) => { trigger={
manageChannel( <Input
channel.id, type='number'
'priority', defaultValue={channel.priority}
idx, onBlur={(event) => {
event.target.value manageChannel(
); channel.id,
}}> 'priority',
<input style={{ maxWidth: '60px' }} /> idx,
</Input>} event.target.value
);
}}
>
<input style={{ maxWidth: '60px' }} />
</Input>
}
content='渠道选择优先级,越高越优先' content='渠道选择优先级,越高越优先'
basic basic
/> />
@@ -526,7 +590,12 @@ const ChannelsTable = () => {
size={'small'} size={'small'}
positive positive
onClick={() => { onClick={() => {
testChannel(channel.id, channel.name, idx, channel.test_model); testChannel(
channel.id,
channel.name,
idx,
channel.test_model
);
}} }}
> >
测试 测试
@@ -588,14 +657,31 @@ const ChannelsTable = () => {
<Table.Footer> <Table.Footer>
<Table.Row> <Table.Row>
<Table.HeaderCell colSpan={showDetail ? "10" : "8"}> <Table.HeaderCell colSpan={showDetail ? '10' : '8'}>
<Button size='small' as={Link} to='/channel/add' loading={loading}> <Button
size='small'
as={Link}
to='/channel/add'
loading={loading}
>
添加新的渠道 添加新的渠道
</Button> </Button>
<Button size='small' loading={loading} onClick={()=>{testChannels("all")}}> <Button
size='small'
loading={loading}
onClick={() => {
testChannels('all');
}}
>
测试所有渠道 测试所有渠道
</Button> </Button>
<Button size='small' loading={loading} onClick={()=>{testChannels("disabled")}}> <Button
size='small'
loading={loading}
onClick={() => {
testChannels('disabled');
}}
>
测试禁用渠道 测试禁用渠道
</Button> </Button>
{/*<Button size='small' onClick={updateAllChannelsBalance}*/} {/*<Button size='small' onClick={updateAllChannelsBalance}*/}
@@ -610,7 +696,12 @@ const ChannelsTable = () => {
flowing flowing
hoverable hoverable
> >
<Button size='small' loading={loading} negative onClick={deleteAllDisabledChannels}> <Button
size='small'
loading={loading}
negative
onClick={deleteAllDisabledChannels}
>
确认删除 确认删除
</Button> </Button>
</Popup> </Popup>
@@ -625,8 +716,12 @@ const ChannelsTable = () => {
(channels.length % ITEMS_PER_PAGE === 0 ? 1 : 0) (channels.length % ITEMS_PER_PAGE === 0 ? 1 : 0)
} }
/> />
<Button size='small' onClick={refresh} loading={loading}>刷新</Button> <Button size='small' onClick={refresh} loading={loading}>
<Button size='small' onClick={toggleShowDetail}>{showDetail ? "隐藏详情" : "详情"}</Button> 刷新
</Button>
<Button size='small' onClick={toggleShowDetail}>
{showDetail ? '隐藏详情' : '详情'}
</Button>
</Table.HeaderCell> </Table.HeaderCell>
</Table.Row> </Table.Row>
</Table.Footer> </Table.Footer>

View File

@@ -1,21 +1,48 @@
import React, { useEffect, useState } from 'react'; import React, { useEffect, useState } from 'react';
import { Button, Form, Header, Label, Pagination, Segment, Select, Table } from 'semantic-ui-react'; import {
import { API, isAdmin, showError, timestamp2string } from '../helpers'; Button,
Form,
Header,
Label,
Pagination,
Segment,
Select,
Table,
} from 'semantic-ui-react';
import {
API,
copy,
isAdmin,
showError,
showSuccess,
showWarning,
timestamp2string,
} from '../helpers';
import { ITEMS_PER_PAGE } from '../constants'; import { ITEMS_PER_PAGE } from '../constants';
import { renderQuota } from '../helpers/render'; import { renderColorLabel, renderQuota } from '../helpers/render';
import { Link } from 'react-router-dom';
function renderTimestamp(timestamp) { function renderTimestamp(timestamp, request_id) {
return ( return (
<> <code
onClick={async () => {
if (await copy(request_id)) {
showSuccess(`已复制请求 ID${request_id}`);
} else {
showWarning(`请求 ID 复制失败:${request_id}`);
}
}}
style={{ cursor: 'pointer' }}
>
{timestamp2string(timestamp)} {timestamp2string(timestamp)}
</> </code>
); );
} }
const MODE_OPTIONS = [ const MODE_OPTIONS = [
{ key: 'all', text: '全部用户', value: 'all' }, { key: 'all', text: '全部用户', value: 'all' },
{ key: 'self', text: '当前用户', value: 'self' } { key: 'self', text: '当前用户', value: 'self' },
]; ];
const LOG_OPTIONS = [ const LOG_OPTIONS = [
@@ -23,24 +50,92 @@ const LOG_OPTIONS = [
{ key: '1', text: '充值', value: 1 }, { key: '1', text: '充值', value: 1 },
{ key: '2', text: '消费', value: 2 }, { key: '2', text: '消费', value: 2 },
{ key: '3', text: '管理', value: 3 }, { key: '3', text: '管理', value: 3 },
{ key: '4', text: '系统', value: 4 } { key: '4', text: '系统', value: 4 },
{ key: '5', text: '测试', value: 5 },
]; ];
function renderType(type) { function renderType(type) {
switch (type) { switch (type) {
case 1: case 1:
return <Label basic color='green'> 充值 </Label>; return (
<Label basic color='green'>
充值
</Label>
);
case 2: case 2:
return <Label basic color='olive'> 消费 </Label>; return (
<Label basic color='olive'>
消费
</Label>
);
case 3: case 3:
return <Label basic color='orange'> 管理 </Label>; return (
<Label basic color='orange'>
管理
</Label>
);
case 4: case 4:
return <Label basic color='purple'> 系统 </Label>; return (
<Label basic color='purple'>
系统
</Label>
);
case 5:
return (
<Label basic color='violet'>
测试
</Label>
);
default: default:
return <Label basic color='black'> 未知 </Label>; return (
<Label basic color='black'>
未知
</Label>
);
} }
} }
function getColorByElapsedTime(elapsedTime) {
if (elapsedTime === undefined || 0) return 'black';
if (elapsedTime < 1000) return 'green';
if (elapsedTime < 3000) return 'olive';
if (elapsedTime < 5000) return 'yellow';
if (elapsedTime < 10000) return 'orange';
return 'red';
}
function renderDetail(log) {
return (
<>
{log.content}
<br />
{log.elapsed_time && (
<Label
basic
size={'mini'}
color={getColorByElapsedTime(log.elapsed_time)}
>
{log.elapsed_time} ms
</Label>
)}
{log.is_stream && (
<>
<Label size={'mini'} color='pink'>
Stream
</Label>
</>
)}
{log.system_prompt_reset && (
<>
<Label basic size={'mini'} color='red'>
System Prompt Reset
</Label>
</>
)}
</>
);
}
const LogsTable = () => { const LogsTable = () => {
const [logs, setLogs] = useState([]); const [logs, setLogs] = useState([]);
const [showStat, setShowStat] = useState(false); const [showStat, setShowStat] = useState(false);
@@ -57,13 +152,20 @@ const LogsTable = () => {
model_name: '', model_name: '',
start_timestamp: timestamp2string(0), start_timestamp: timestamp2string(0),
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600), end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
channel: '' channel: '',
}); });
const { username, token_name, model_name, start_timestamp, end_timestamp, channel } = inputs; const {
username,
token_name,
model_name,
start_timestamp,
end_timestamp,
channel,
} = inputs;
const [stat, setStat] = useState({ const [stat, setStat] = useState({
quota: 0, quota: 0,
token: 0 token: 0,
}); });
const handleInputChange = (e, { name, value }) => { const handleInputChange = (e, { name, value }) => {
@@ -73,7 +175,9 @@ const LogsTable = () => {
const getLogSelfStat = async () => { const getLogSelfStat = async () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000; let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000; let localEndTimestamp = Date.parse(end_timestamp) / 1000;
let res = await API.get(`/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`); let res = await API.get(
`/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`
);
const { success, message, data } = res.data; const { success, message, data } = res.data;
if (success) { if (success) {
setStat(data); setStat(data);
@@ -85,7 +189,9 @@ const LogsTable = () => {
const getLogStat = async () => { const getLogStat = async () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000; let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000; let localEndTimestamp = Date.parse(end_timestamp) / 1000;
let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`); let res = await API.get(
`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`
);
const { success, message, data } = res.data; const { success, message, data } = res.data;
if (success) { if (success) {
setStat(data); setStat(data);
@@ -105,6 +211,10 @@ const LogsTable = () => {
setShowStat(!showStat); setShowStat(!showStat);
}; };
const showUserTokenQuota = () => {
return logType !== 5;
};
const loadLogs = async (startIdx) => { const loadLogs = async (startIdx) => {
let url = ''; let url = '';
let localStartTimestamp = Date.parse(start_timestamp) / 1000; let localStartTimestamp = Date.parse(start_timestamp) / 1000;
@@ -201,37 +311,82 @@ const LogsTable = () => {
<Header as='h3'> <Header as='h3'>
使用明细总消耗额度 使用明细总消耗额度
{showStat && renderQuota(stat.quota)} {showStat && renderQuota(stat.quota)}
{!showStat && <span onClick={handleEyeClick} style={{ cursor: 'pointer', color: 'gray' }}>点击查看</span>} {!showStat && (
<span
onClick={handleEyeClick}
style={{ cursor: 'pointer', color: 'gray' }}
>
点击查看
</span>
)}
</Header> </Header>
<Form> <Form>
<Form.Group> <Form.Group>
<Form.Input fluid label={'令牌名称'} width={3} value={token_name} <Form.Input
placeholder={'可选值'} name='token_name' onChange={handleInputChange} /> fluid
<Form.Input fluid label='模型名称' width={3} value={model_name} placeholder='可选值' label={'令牌名称'}
name='model_name' width={3}
onChange={handleInputChange} /> value={token_name}
<Form.Input fluid label='起始时间' width={4} value={start_timestamp} type='datetime-local' placeholder={'可选值'}
name='start_timestamp' name='token_name'
onChange={handleInputChange} /> onChange={handleInputChange}
<Form.Input fluid label='结束时间' width={4} value={end_timestamp} type='datetime-local' />
name='end_timestamp' <Form.Input
onChange={handleInputChange} /> fluid
<Form.Button fluid label='操作' width={2} onClick={refresh}>查询</Form.Button> label='模型名称'
width={3}
value={model_name}
placeholder='可选值'
name='model_name'
onChange={handleInputChange}
/>
<Form.Input
fluid
label='起始时间'
width={4}
value={start_timestamp}
type='datetime-local'
name='start_timestamp'
onChange={handleInputChange}
/>
<Form.Input
fluid
label='结束时间'
width={4}
value={end_timestamp}
type='datetime-local'
name='end_timestamp'
onChange={handleInputChange}
/>
<Form.Button fluid label='操作' width={2} onClick={refresh}>
查询
</Form.Button>
</Form.Group> </Form.Group>
{ {isAdminUser && (
isAdminUser && <> <>
<Form.Group> <Form.Group>
<Form.Input fluid label={'渠道 ID'} width={3} value={channel} <Form.Input
placeholder='可选值' name='channel' fluid
onChange={handleInputChange} /> label={'渠道 ID'}
<Form.Input fluid label={'用户名称'} width={3} value={username} width={3}
placeholder={'可选值'} name='username' value={channel}
onChange={handleInputChange} /> placeholder='可选值'
name='channel'
onChange={handleInputChange}
/>
<Form.Input
fluid
label={'用户名称'}
width={3}
value={username}
placeholder={'可选值'}
name='username'
onChange={handleInputChange}
/>
</Form.Group> </Form.Group>
</> </>
} )}
</Form> </Form>
<Table basic compact size='small'> <Table basic compact size='small'>
<Table.Header> <Table.Header>
@@ -245,8 +400,8 @@ const LogsTable = () => {
> >
时间 时间
</Table.HeaderCell> </Table.HeaderCell>
{ {isAdminUser && (
isAdminUser && <Table.HeaderCell <Table.HeaderCell
style={{ cursor: 'pointer' }} style={{ cursor: 'pointer' }}
onClick={() => { onClick={() => {
sortLog('channel'); sortLog('channel');
@@ -255,27 +410,7 @@ const LogsTable = () => {
> >
渠道 渠道
</Table.HeaderCell> </Table.HeaderCell>
} )}
{
isAdminUser && <Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortLog('username');
}}
width={1}
>
用户
</Table.HeaderCell>
}
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortLog('token_name');
}}
width={1}
>
令牌
</Table.HeaderCell>
<Table.HeaderCell <Table.HeaderCell
style={{ cursor: 'pointer' }} style={{ cursor: 'pointer' }}
onClick={() => { onClick={() => {
@@ -294,33 +429,57 @@ const LogsTable = () => {
> >
模型 模型
</Table.HeaderCell> </Table.HeaderCell>
<Table.HeaderCell {showUserTokenQuota() && (
style={{ cursor: 'pointer' }} <>
onClick={() => { {isAdminUser && (
sortLog('prompt_tokens'); <Table.HeaderCell
}} style={{ cursor: 'pointer' }}
width={1} onClick={() => {
> sortLog('username');
提示 }}
</Table.HeaderCell> width={1}
<Table.HeaderCell >
style={{ cursor: 'pointer' }} 用户
onClick={() => { </Table.HeaderCell>
sortLog('completion_tokens'); )}
}} <Table.HeaderCell
width={1} style={{ cursor: 'pointer' }}
> onClick={() => {
补全 sortLog('token_name');
</Table.HeaderCell> }}
<Table.HeaderCell width={1}
style={{ cursor: 'pointer' }} >
onClick={() => { 令牌
sortLog('quota'); </Table.HeaderCell>
}} <Table.HeaderCell
width={1} style={{ cursor: 'pointer' }}
> onClick={() => {
额度 sortLog('prompt_tokens');
</Table.HeaderCell> }}
width={1}
>
提示
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortLog('completion_tokens');
}}
width={1}
>
补全
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortLog('quota');
}}
width={1}
>
额度
</Table.HeaderCell>
</>
)}
<Table.HeaderCell <Table.HeaderCell
style={{ cursor: 'pointer' }} style={{ cursor: 'pointer' }}
onClick={() => { onClick={() => {
@@ -343,24 +502,64 @@ const LogsTable = () => {
if (log.deleted) return <></>; if (log.deleted) return <></>;
return ( return (
<Table.Row key={log.id}> <Table.Row key={log.id}>
<Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell> <Table.Cell>
{ {renderTimestamp(log.created_at, log.request_id)}
isAdminUser && ( </Table.Cell>
<Table.Cell>{log.channel ? <Label basic>{log.channel}</Label> : ''}</Table.Cell> {isAdminUser && (
) <Table.Cell>
} {log.channel ? (
{ <Label
isAdminUser && ( basic
<Table.Cell>{log.username ? <Label>{log.username}</Label> : ''}</Table.Cell> as={Link}
) to={`/channel/edit/${log.channel}`}
} >
<Table.Cell>{log.token_name ? <Label basic>{log.token_name}</Label> : ''}</Table.Cell> {log.channel}
</Label>
) : (
''
)}
</Table.Cell>
)}
<Table.Cell>{renderType(log.type)}</Table.Cell> <Table.Cell>{renderType(log.type)}</Table.Cell>
<Table.Cell>{log.model_name ? <Label basic>{log.model_name}</Label> : ''}</Table.Cell> <Table.Cell>
<Table.Cell>{log.prompt_tokens ? log.prompt_tokens : ''}</Table.Cell> {log.model_name ? renderColorLabel(log.model_name) : ''}
<Table.Cell>{log.completion_tokens ? log.completion_tokens : ''}</Table.Cell> </Table.Cell>
<Table.Cell>{log.quota ? renderQuota(log.quota, 6) : ''}</Table.Cell> {showUserTokenQuota() && (
<Table.Cell>{log.content}</Table.Cell> <>
{isAdminUser && (
<Table.Cell>
{log.username ? (
<Label
basic
as={Link}
to={`/user/edit/${log.user_id}`}
>
{log.username}
</Label>
) : (
''
)}
</Table.Cell>
)}
<Table.Cell>
{log.token_name
? renderColorLabel(log.token_name)
: ''}
</Table.Cell>
<Table.Cell>
{log.prompt_tokens ? log.prompt_tokens : ''}
</Table.Cell>
<Table.Cell>
{log.completion_tokens ? log.completion_tokens : ''}
</Table.Cell>
<Table.Cell>
{log.quota ? renderQuota(log.quota, 6) : ''}
</Table.Cell>
</>
)}
<Table.Cell>{renderDetail(log)}</Table.Cell>
</Table.Row> </Table.Row>
); );
})} })}
@@ -379,7 +578,9 @@ const LogsTable = () => {
setLogType(value); setLogType(value);
}} }}
/> />
<Button size='small' onClick={refresh} loading={loading}>刷新</Button> <Button size='small' onClick={refresh} loading={loading}>
刷新
</Button>
<Pagination <Pagination
floated='right' floated='right'
activePage={activePage} activePage={activePage}

View File

@@ -31,6 +31,7 @@ export const CHANNEL_OPTIONS = [
{ key: 43, text: 'Proxy', value: 43, color: 'blue' }, { key: 43, text: 'Proxy', value: 43, color: 'blue' },
{ key: 44, text: 'SiliconFlow', value: 44, color: 'blue' }, { key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
{ key: 45, text: 'xAI', value: 45, color: 'blue' }, { key: 45, text: 'xAI', value: 45, color: 'blue' },
{ key: 46, text: 'Replicate', value: 46, color: 'blue' },
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
{ key: 22, text: '知识库FastGPT', value: 22, color: 'blue' }, { key: 22, text: '知识库FastGPT', value: 22, color: 'blue' },
{ key: 21, text: '知识库AI Proxy', value: 21, color: 'purple' }, { key: 21, text: '知识库AI Proxy', value: 21, color: 'purple' },

View File

@@ -13,16 +13,18 @@ export function renderGroup(group) {
} }
let groups = group.split(','); let groups = group.split(',');
groups.sort(); groups.sort();
return <> return (
{groups.map((group) => { <>
if (group === 'vip' || group === 'pro') { {groups.map((group) => {
return <Label color='yellow'>{group}</Label>; if (group === 'vip' || group === 'pro') {
} else if (group === 'svip' || group === 'premium') { return <Label color='yellow'>{group}</Label>;
return <Label color='red'>{group}</Label>; } else if (group === 'svip' || group === 'premium') {
} return <Label color='red'>{group}</Label>;
return <Label>{group}</Label>; }
})} return <Label>{group}</Label>;
</>; })}
</>
);
} }
export function renderNumber(num) { export function renderNumber(num) {
@@ -55,4 +57,33 @@ export function renderQuotaWithPrompt(quota, digits) {
return `(等价金额:${renderQuota(quota, digits)}`; return `(等价金额:${renderQuota(quota, digits)}`;
} }
return ''; return '';
} }
const colors = [
'red',
'orange',
'yellow',
'olive',
'green',
'teal',
'blue',
'violet',
'purple',
'pink',
'brown',
'grey',
'black',
];
export function renderColorLabel(text) {
let hash = 0;
for (let i = 0; i < text.length; i++) {
hash = text.charCodeAt(i) + ((hash << 5) - hash);
}
let index = Math.abs(hash % colors.length);
return (
<Label basic color={colors[index]}>
{text}
</Label>
);
}

View File

@@ -43,6 +43,7 @@ const EditChannel = () => {
base_url: '', base_url: '',
other: '', other: '',
model_mapping: '', model_mapping: '',
system_prompt: '',
models: [], models: [],
groups: ['default'] groups: ['default']
}; };
@@ -425,7 +426,7 @@ const EditChannel = () => {
) )
} }
{ {
inputs.type !== 43 && ( inputs.type !== 43 && (<>
<Form.Field> <Form.Field>
<Form.TextArea <Form.TextArea
label='模型重定向' label='模型重定向'
@@ -437,6 +438,18 @@ const EditChannel = () => {
autoComplete='new-password' autoComplete='new-password'
/> />
</Form.Field> </Form.Field>
<Form.Field>
<Form.TextArea
label='系统提示词'
placeholder={`此项可选,用于强制设置给定的系统提示词,请配合自定义模型 & 模型重定向使用,首先创建一个唯一的自定义模型名称并在上面填入,之后将该自定义模型重定向映射到该渠道一个原生支持的模型`}
name='system_prompt'
onChange={handleInputChange}
value={inputs.system_prompt}
style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
autoComplete='new-password'
/>
</Form.Field>
</>
) )
} }
{ {