🎨 结构中继控制器

This commit is contained in:
MartialBE
2023-12-02 03:28:18 +08:00
parent 2114bc1982
commit be364ae09b
45 changed files with 1267 additions and 204 deletions

View File

@@ -9,9 +9,104 @@ import (
"net/http"
"one-api/common"
"one-api/model"
"one-api/providers"
providersBase "one-api/providers/base"
"one-api/types"
"reflect"
"strconv"
"github.com/gin-gonic/gin"
"github.com/go-playground/validator/v10"
)
func GetValidFieldName(err error, obj interface{}) string {
getObj := reflect.TypeOf(obj)
if errs, ok := err.(validator.ValidationErrors); ok {
for _, e := range errs {
if f, exist := getObj.Elem().FieldByName(e.Field()); exist {
return f.Name
}
}
}
return err.Error()
}
func fetchChannel(c *gin.Context, modelName string) (*model.Channel, bool) {
channelId, ok := c.Get("channelId")
if ok {
return fetchChannelById(c, channelId.(int))
}
return fetchChannelByModel(c, modelName)
}
func fetchChannelById(c *gin.Context, channelId any) (*model.Channel, bool) {
id, err := strconv.Atoi(channelId.(string))
if err != nil {
common.AbortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
return nil, true
}
channel, err := model.GetChannelById(id, true)
if err != nil {
common.AbortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
return nil, true
}
if channel.Status != common.ChannelStatusEnabled {
common.AbortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
return nil, true
}
return channel, false
}
func fetchChannelByModel(c *gin.Context, modelName string) (*model.Channel, bool) {
group := c.GetString("group")
channel, err := model.CacheGetRandomSatisfiedChannel(group, modelName)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, modelName)
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
}
common.AbortWithMessage(c, http.StatusServiceUnavailable, message)
return nil, true
}
return channel, false
}
func getProvider(c *gin.Context, channelType int, relayMode int) (providersBase.ProviderInterface, bool) {
provider := providers.GetProvider(channelType, c)
if provider == nil {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not found")
return nil, true
}
if !provider.SupportAPI(relayMode) {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel does not support this API")
return nil, true
}
return provider, false
}
func setChannelToContext(c *gin.Context, channel *model.Channel) {
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Set("model_mapping", channel.GetModelMapping())
c.Set("api_key", channel.Key)
c.Set("base_url", channel.GetBaseURL())
switch channel.Type {
case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
case common.ChannelTypeAIProxyLibrary:
c.Set("library_id", channel.Other)
}
}
func shouldDisableChannel(err *types.OpenAIError, statusCode int) bool {
if !common.AutomaticDisableChannelEnabled {
return false
@@ -68,6 +163,26 @@ type QuotaInfo struct {
userId int
channelId int
tokenId int
HandelStatus bool
}
func generateQuotaInfo(c *gin.Context, modelName string, promptTokens int) (*QuotaInfo, *types.OpenAIErrorWithStatusCode) {
quotaInfo := &QuotaInfo{
modelName: modelName,
promptTokens: promptTokens,
userId: c.GetInt("id"),
channelId: c.GetInt("channel_id"),
tokenId: c.GetInt("token_id"),
HandelStatus: false,
}
quotaInfo.initQuotaInfo(c.GetString("group"))
errWithCode := quotaInfo.preQuotaConsumption()
if errWithCode != nil {
return nil, errWithCode
}
return quotaInfo, nil
}
func (q *QuotaInfo) initQuotaInfo(groupName string) {
@@ -89,16 +204,16 @@ func (q *QuotaInfo) initQuotaInfo(groupName string) {
func (q *QuotaInfo) preQuotaConsumption() *types.OpenAIErrorWithStatusCode {
userQuota, err := model.CacheGetUserQuota(q.userId)
if err != nil {
return types.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
return common.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota < q.preConsumedQuota {
return types.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
return common.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
err = model.CacheDecreaseUserQuota(q.userId, q.preConsumedQuota)
if err != nil {
return types.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
return common.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*q.preConsumedQuota {
@@ -111,8 +226,9 @@ func (q *QuotaInfo) preQuotaConsumption() *types.OpenAIErrorWithStatusCode {
if q.preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(q.tokenId, q.preConsumedQuota)
if err != nil {
return types.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
return common.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
q.HandelStatus = true
}
return nil