Compare commits

...

4 Commits

Author SHA1 Message Date
JustSong
fa71daa8a7 fix: fix wrong implementation for /v1/models (close #128) 2023-05-31 14:43:29 +08:00
JustSong
54215dc303 chore: make channel test related code separated 2023-05-23 10:01:09 +08:00
JustSong
f9f42997b2 chore: only check OpenAI channel & custom channel 2023-05-23 10:00:36 +08:00
JustSong
25eab0b224 style: fix UI related problems 2023-05-22 22:41:39 +08:00
6 changed files with 235 additions and 227 deletions

View File

@@ -144,6 +144,10 @@ func updateAllChannelsBalance() error {
if channel.Status != common.ChannelStatusEnabled {
continue
}
// TODO: support Azure
if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
continue
}
balance, err := updateChannelBalance(channel)
if err != nil {
continue

199
controller/channel-test.go Normal file
View File

@@ -0,0 +1,199 @@
package controller
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"sync"
"time"
)
func testChannel(channel *model.Channel, request *ChatRequest) error {
if request.Model == "" {
request.Model = "gpt-3.5-turbo"
if channel.Type == common.ChannelTypeAzure {
request.Model = "gpt-35-turbo"
}
}
requestURL := common.ChannelBaseURLs[channel.Type]
if channel.Type == common.ChannelTypeAzure {
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
} else {
if channel.Type == common.ChannelTypeCustom {
requestURL = channel.BaseURL
}
requestURL += "/v1/chat/completions"
}
jsonData, err := json.Marshal(request)
if err != nil {
return err
}
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return err
}
if channel.Type == common.ChannelTypeAzure {
req.Header.Set("api-key", channel.Key)
} else {
req.Header.Set("Authorization", "Bearer "+channel.Key)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
var response TextResponse
err = json.NewDecoder(resp.Body).Decode(&response)
if err != nil {
return err
}
if response.Error.Message != "" || response.Error.Code != "" {
return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
}
return nil
}
func buildTestRequest(c *gin.Context) *ChatRequest {
model_ := c.Query("model")
testRequest := &ChatRequest{
Model: model_,
MaxTokens: 1,
}
testMessage := Message{
Role: "user",
Content: "hi",
}
testRequest.Messages = append(testRequest.Messages, testMessage)
return testRequest
}
func TestChannel(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
channel, err := model.GetChannelById(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
testRequest := buildTestRequest(c)
tik := time.Now()
err = testChannel(channel, testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
go channel.UpdateResponseTime(milliseconds)
consumedTime := float64(milliseconds) / 1000.0
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
"time": consumedTime,
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"time": consumedTime,
})
return
}
var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
// disable & notify
func disableChannel(channelId int, channelName string, reason string) {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
err := common.SendEmail(subject, common.RootUserEmail, content)
if err != nil {
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
}
}
func testAllChannels(c *gin.Context) error {
testAllChannelsLock.Lock()
if testAllChannelsRunning {
testAllChannelsLock.Unlock()
return errors.New("测试已在运行中")
}
testAllChannelsRunning = true
testAllChannelsLock.Unlock()
channels, err := model.GetAllChannels(0, 0, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return err
}
testRequest := buildTestRequest(c)
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value
}
go func() {
for _, channel := range channels {
if channel.Status != common.ChannelStatusEnabled {
continue
}
tik := time.Now()
err := testChannel(channel, testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
if err != nil || milliseconds > disableThreshold {
if milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
}
disableChannel(channel.Id, channel.Name, err.Error())
}
channel.UpdateResponseTime(milliseconds)
}
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
if err != nil {
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
}
testAllChannelsLock.Lock()
testAllChannelsRunning = false
testAllChannelsLock.Unlock()
}()
return nil
}
func TestAllChannels(c *gin.Context) {
err := testAllChannels(c)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}

View File

@@ -1,18 +1,12 @@
package controller
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
"sync"
"time"
)
func GetAllChannels(c *gin.Context) {
@@ -158,187 +152,3 @@ func UpdateChannel(c *gin.Context) {
})
return
}
func testChannel(channel *model.Channel, request *ChatRequest) error {
if request.Model == "" {
request.Model = "gpt-3.5-turbo"
if channel.Type == common.ChannelTypeAzure {
request.Model = "gpt-35-turbo"
}
}
requestURL := common.ChannelBaseURLs[channel.Type]
if channel.Type == common.ChannelTypeAzure {
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
} else {
if channel.Type == common.ChannelTypeCustom {
requestURL = channel.BaseURL
}
requestURL += "/v1/chat/completions"
}
jsonData, err := json.Marshal(request)
if err != nil {
return err
}
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return err
}
if channel.Type == common.ChannelTypeAzure {
req.Header.Set("api-key", channel.Key)
} else {
req.Header.Set("Authorization", "Bearer "+channel.Key)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
var response TextResponse
err = json.NewDecoder(resp.Body).Decode(&response)
if err != nil {
return err
}
if response.Error.Message != "" || response.Error.Code != "" {
return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
}
return nil
}
func buildTestRequest(c *gin.Context) *ChatRequest {
model_ := c.Query("model")
testRequest := &ChatRequest{
Model: model_,
MaxTokens: 1,
}
testMessage := Message{
Role: "user",
Content: "hi",
}
testRequest.Messages = append(testRequest.Messages, testMessage)
return testRequest
}
func TestChannel(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
channel, err := model.GetChannelById(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
testRequest := buildTestRequest(c)
tik := time.Now()
err = testChannel(channel, testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
go channel.UpdateResponseTime(milliseconds)
consumedTime := float64(milliseconds) / 1000.0
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
"time": consumedTime,
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"time": consumedTime,
})
return
}
var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
// disable & notify
func disableChannel(channelId int, channelName string, reason string) {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
err := common.SendEmail(subject, common.RootUserEmail, content)
if err != nil {
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
}
}
func testAllChannels(c *gin.Context) error {
testAllChannelsLock.Lock()
if testAllChannelsRunning {
testAllChannelsLock.Unlock()
return errors.New("测试已在运行中")
}
testAllChannelsRunning = true
testAllChannelsLock.Unlock()
channels, err := model.GetAllChannels(0, 0, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return err
}
testRequest := buildTestRequest(c)
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value
}
go func() {
for _, channel := range channels {
if channel.Status != common.ChannelStatusEnabled {
continue
}
tik := time.Now()
err := testChannel(channel, testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
if err != nil || milliseconds > disableThreshold {
if milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
}
disableChannel(channel.Id, channel.Name, err.Error())
}
channel.UpdateResponseTime(milliseconds)
}
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
if err != nil {
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
}
testAllChannelsLock.Lock()
testAllChannelsRunning = false
testAllChannelsLock.Unlock()
}()
return nil
}
func TestAllChannels(c *gin.Context) {
err := testAllChannels(c)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}

View File

@@ -23,20 +23,21 @@ type OpenAIModelPermission struct {
}
type OpenAIModels struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
OwnedBy string `json:"owned_by"`
Permission OpenAIModelPermission `json:"permission"`
Root string `json:"root"`
Parent *string `json:"parent"`
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
OwnedBy string `json:"owned_by"`
Permission []OpenAIModelPermission `json:"permission"`
Root string `json:"root"`
Parent *string `json:"parent"`
}
var openAIModels []OpenAIModels
var openAIModelsMap map[string]OpenAIModels
func init() {
permission := OpenAIModelPermission{
var permission []OpenAIModelPermission
permission = append(permission, OpenAIModelPermission{
Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
Object: "model_permission",
Created: 1626777600,
@@ -49,7 +50,7 @@ func init() {
Organization: "*",
Group: nil,
IsBlocking: false,
}
})
// https://platform.openai.com/docs/models/model-endpoint-compatibility
openAIModels = []OpenAIModels{
{
@@ -106,15 +107,6 @@ func init() {
Root: "gpt-4-32k-0314",
Parent: nil,
},
{
Id: "gpt-3.5-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo",
Parent: nil,
},
{
Id: "text-embedding-ada-002",
Object: "model",
@@ -132,7 +124,10 @@ func init() {
}
func ListModels(c *gin.Context) {
c.JSON(200, openAIModels)
c.JSON(200, gin.H{
"object": "list",
"data": openAIModels,
})
}
func RetrieveModel(c *gin.Context) {

View File

@@ -405,7 +405,7 @@ const ChannelsTable = () => {
<Button size='small' loading={loading} onClick={testAllChannels}>
测试所有已启用通道
</Button>
<Button size='small' onClick={updateAllChannelsBalance} loading={updatingBalance}>更新所有已启用通道余额</Button>
<Button size='small' onClick={updateAllChannelsBalance} loading={loading || updatingBalance}>更新所有已启用通道余额</Button>
<Pagination
floated='right'
activePage={activePage}

View File

@@ -106,22 +106,6 @@ const EditToken = () => {
required={!isEdit}
/>
</Form.Field>
<Message>注意令牌的额度仅用于限制令牌本身的最大额度使用量实际的使用受到账户的剩余额度限制</Message>
<Form.Field>
<Form.Input
label='额度'
name='remain_quota'
placeholder={'请输入额度'}
onChange={handleInputChange}
value={remain_quota}
autoComplete='new-password'
type='number'
disabled={unlimited_quota}
/>
</Form.Field>
<Button type={'button'} style={{ marginBottom: '14px' }} onClick={() => {
setUnlimitedQuota();
}}>{unlimited_quota ? '取消无限额度' : '设置为无限额度'}</Button>
<Form.Field>
<Form.Input
label='过期时间'
@@ -150,7 +134,23 @@ const EditToken = () => {
setExpiredTime(0, 0, 0, 1);
}}>一分钟后过期</Button>
</div>
<Button positive onClick={submit} style={{marginTop: '12px'}}>提交</Button>
<Message>注意令牌的额度仅用于限制令牌本身的最大额度使用量实际的使用受到账户的剩余额度限制</Message>
<Form.Field>
<Form.Input
label='额度'
name='remain_quota'
placeholder={'请输入额度'}
onChange={handleInputChange}
value={remain_quota}
autoComplete='new-password'
type='number'
disabled={unlimited_quota}
/>
</Form.Field>
<Button type={'button'} onClick={() => {
setUnlimitedQuota();
}}>{unlimited_quota ? '取消无限额度' : '设置为无限额度'}</Button>
<Button positive onClick={submit}>提交</Button>
</Form>
</Segment>
</>