up oauth re
This commit is contained in:
@@ -101,7 +101,7 @@ func main() {
|
||||
}
|
||||
|
||||
func DatabaseAutoUpdate() {
|
||||
version := 244
|
||||
version := 245
|
||||
|
||||
db := global.DB
|
||||
|
||||
@@ -146,6 +146,21 @@ func DatabaseAutoUpdate() {
|
||||
if v.Version < uint(version) {
|
||||
Migrate(uint(version))
|
||||
}
|
||||
// 245迁移
|
||||
if v.Version < 245 {
|
||||
//oauths 表的 oauth_type 字段设置为 op同样的值
|
||||
db.Exec("update oauths set oauth_type = op")
|
||||
db.Exec("update oauths set issuer = 'https://accounts.google.com' where op = 'google' and issuer = ''")
|
||||
db.Exec("update user_thirds set oauth_type = third_type, op = third_type")
|
||||
//通过email迁移旧的google授权
|
||||
uts := make([]model.UserThird, 0)
|
||||
db.Where("oauth_type = ?", "google").Find(&uts)
|
||||
for _, ut := range uts {
|
||||
if ut.UserId > 0 {
|
||||
db.Model(&model.User{}).Where("id = ?", ut.UserId).Update("email", ut.OpenId)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -180,15 +180,18 @@ func (o *Oauth) Create(c *gin.Context) {
|
||||
response.Fail(c, 101, errList[0])
|
||||
return
|
||||
}
|
||||
|
||||
ex := service.AllService.OauthService.InfoByOp(f.Op)
|
||||
u := f.ToOauth()
|
||||
err := u.FormatOauthInfo()
|
||||
if err != nil {
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")+err.Error())
|
||||
return
|
||||
}
|
||||
ex := service.AllService.OauthService.InfoByOp(u.Op)
|
||||
if ex.Id > 0 {
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "ItemExists"))
|
||||
return
|
||||
}
|
||||
|
||||
u := f.ToOauth()
|
||||
err := service.AllService.OauthService.Create(u)
|
||||
err = service.AllService.OauthService.Create(u)
|
||||
if err != nil {
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
|
||||
return
|
||||
|
||||
@@ -217,7 +217,7 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
|
||||
oauthCache.UserId = user.Id
|
||||
oauthService.SetOauthCache(cacheKey, oauthCache, 0)
|
||||
// 如果是webadmin,登录成功后跳转到webadmin
|
||||
if oauthCache.DeviceType == "webadmin" {
|
||||
if oauthCache.DeviceType == model.LoginLogClientWebAdmin {
|
||||
/*service.AllService.UserService.Login(u, &model.LoginLog{
|
||||
UserId: u.Id,
|
||||
Client: "webadmin",
|
||||
|
||||
@@ -5,15 +5,15 @@ import (
|
||||
)
|
||||
|
||||
type UserForm struct {
|
||||
Id uint `json:"id"`
|
||||
Username string `json:"username" validate:"required,gte=4,lte=10"`
|
||||
Email string `json:"email" validate:"required,email"`
|
||||
Id uint `json:"id"`
|
||||
Username string `json:"username" validate:"required,gte=4,lte=10"`
|
||||
Email string `json:"email"` //validate:"required,email" email不强制
|
||||
//Password string `json:"password" validate:"required,gte=4,lte=20"`
|
||||
Nickname string `json:"nickname"`
|
||||
Avatar string `json:"avatar"`
|
||||
GroupId uint `json:"group_id" validate:"required"`
|
||||
IsAdmin *bool `json:"is_admin" `
|
||||
Status model.StatusCode `json:"status" validate:"required,gte=0"`
|
||||
Nickname string `json:"nickname"`
|
||||
Avatar string `json:"avatar"`
|
||||
GroupId uint `json:"group_id" validate:"required"`
|
||||
IsAdmin *bool `json:"is_admin" `
|
||||
Status model.StatusCode `json:"status" validate:"required,gte=0"`
|
||||
}
|
||||
|
||||
func (uf *UserForm) FromUser(user *model.User) *UserForm {
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"errors"
|
||||
)
|
||||
|
||||
const OIDC_DEFAULT_SCOPES = "openid,profile,email"
|
||||
@@ -27,32 +27,23 @@ func ValidateOauthType(oauthType string) error {
|
||||
}
|
||||
|
||||
const (
|
||||
OauthNameGithub string = "GitHub"
|
||||
OauthNameGoogle string = "Google"
|
||||
OauthNameOidc string = "OIDC"
|
||||
OauthNameWebauth string = "WebAuth"
|
||||
)
|
||||
|
||||
const (
|
||||
UserEndpointGithub string = "https://api.github.com/user"
|
||||
IssuerGoogle string = "https://accounts.google.com"
|
||||
UserEndpointGithub string = "https://api.github.com/user"
|
||||
IssuerGoogle string = "https://accounts.google.com"
|
||||
)
|
||||
|
||||
type Oauth struct {
|
||||
IdModel
|
||||
Op string `json:"op"`
|
||||
OauthType string `json:"oauth_type"`
|
||||
ClientId string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
RedirectUrl string `json:"redirect_url"`
|
||||
AutoRegister *bool `json:"auto_register"`
|
||||
Scopes string `json:"scopes"`
|
||||
Issuer string `json:"issuer"`
|
||||
Op string `json:"op"`
|
||||
OauthType string `json:"oauth_type"`
|
||||
ClientId string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
RedirectUrl string `json:"redirect_url"`
|
||||
AutoRegister *bool `json:"auto_register"`
|
||||
Scopes string `json:"scopes"`
|
||||
Issuer string `json:"issuer"`
|
||||
TimeModel
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Helper function to format oauth info, it's used in the update and create method
|
||||
func (oa *Oauth) FormatOauthInfo() error {
|
||||
oauthType := strings.TrimSpace(oa.OauthType)
|
||||
@@ -60,25 +51,20 @@ func (oa *Oauth) FormatOauthInfo() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch oauthType {
|
||||
case OauthTypeGithub:
|
||||
oa.Op = OauthTypeGithub
|
||||
case OauthTypeGoogle:
|
||||
oa.Op = OauthTypeGoogle
|
||||
}
|
||||
// check if the op is empty, set the default value
|
||||
op := strings.TrimSpace(oa.Op)
|
||||
if op == "" {
|
||||
switch oauthType {
|
||||
case OauthTypeGithub:
|
||||
oa.Op = OauthNameGithub
|
||||
case OauthTypeGoogle:
|
||||
oa.Op = OauthNameGoogle
|
||||
case OauthTypeOidc:
|
||||
oa.Op = OauthNameOidc
|
||||
case OauthTypeWebauth:
|
||||
oa.Op = OauthNameWebauth
|
||||
default:
|
||||
oa.Op = oauthType
|
||||
}
|
||||
if op == "" && oauthType == OauthTypeOidc {
|
||||
oa.Op = OauthTypeOidc
|
||||
}
|
||||
// check the issuer, if the oauth type is google and the issuer is empty, set the issuer to the default value
|
||||
issuer := strings.TrimSpace(oa.Issuer)
|
||||
// If the oauth type is google and the issuer is empty, set the issuer to the default value
|
||||
// If the oauth type is google and the issuer is empty, set the issuer to the default value
|
||||
if oauthType == OauthTypeGoogle && issuer == "" {
|
||||
oa.Issuer = IssuerGoogle
|
||||
}
|
||||
@@ -86,12 +72,12 @@ func (oa *Oauth) FormatOauthInfo() error {
|
||||
}
|
||||
|
||||
type OauthUser struct {
|
||||
OpenId string `json:"open_id" gorm:"not null;index"`
|
||||
Name string `json:"name"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
VerifiedEmail bool `json:"verified_email,omitempty"`
|
||||
Picture string `json:"picture,omitempty"`
|
||||
OpenId string `json:"open_id" gorm:"not null;index"`
|
||||
Name string `json:"name"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
VerifiedEmail bool `json:"verified_email,omitempty"`
|
||||
Picture string `json:"picture,omitempty"`
|
||||
}
|
||||
|
||||
func (ou *OauthUser) ToUser(user *User, overideUsername bool) {
|
||||
@@ -122,7 +108,7 @@ func (ou *OidcUser) ToOauthUser() *OauthUser {
|
||||
if ou.PreferredUsername != "" {
|
||||
username = ou.PreferredUsername
|
||||
} else {
|
||||
username = strings.ToLower(strings.Split(ou.Email, "@")[0])
|
||||
username = strings.ToLower(ou.Email)
|
||||
}
|
||||
|
||||
return &OauthUser{
|
||||
@@ -135,29 +121,26 @@ func (ou *OidcUser) ToOauthUser() *OauthUser {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
type GithubUser struct {
|
||||
OauthUserBase
|
||||
Id int `json:"id"`
|
||||
Login string `json:"login"`
|
||||
AvatarUrl string `json:"avatar_url"`
|
||||
VerifiedEmail bool `json:"verified_email"`
|
||||
Id int `json:"id"`
|
||||
Login string `json:"login"`
|
||||
AvatarUrl string `json:"avatar_url"`
|
||||
VerifiedEmail bool `json:"verified_email"`
|
||||
}
|
||||
|
||||
func (gu *GithubUser) ToOauthUser() *OauthUser {
|
||||
username := strings.ToLower(gu.Login)
|
||||
return &OauthUser{
|
||||
OpenId: strconv.Itoa(gu.Id),
|
||||
Name: gu.Name,
|
||||
Username: username,
|
||||
Email: gu.Email,
|
||||
VerifiedEmail: gu.VerifiedEmail,
|
||||
Picture: gu.AvatarUrl,
|
||||
OpenId: strconv.Itoa(gu.Id),
|
||||
Name: gu.Name,
|
||||
Username: username,
|
||||
Email: gu.Email,
|
||||
VerifiedEmail: gu.VerifiedEmail,
|
||||
Picture: gu.AvatarUrl,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
type OauthList struct {
|
||||
Oauths []*Oauth `json:"list"`
|
||||
Pagination
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
IdModel
|
||||
Username string `json:"username" gorm:"default:'';not null;uniqueIndex"`
|
||||
Email string `json:"email" gorm:"default:'';not null;uniqueIndex"`
|
||||
Username string `json:"username" gorm:"default:'';not null;uniqueIndex"`
|
||||
Email string `json:"email" gorm:"default:'';not null;index"`
|
||||
// Email string `json:"email" `
|
||||
Password string `json:"-" gorm:"default:'';not null;"`
|
||||
Nickname string `json:"nickname" gorm:"default:'';not null;"`
|
||||
@@ -20,13 +15,13 @@ type User struct {
|
||||
}
|
||||
|
||||
// BeforeSave 钩子用于确保 email 字段有合理的默认值
|
||||
func (u *User) BeforeSave(tx *gorm.DB) (err error) {
|
||||
// 如果 email 为空,设置为默认值
|
||||
if u.Email == "" {
|
||||
u.Email = fmt.Sprintf("%s@example.com", u.Username)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
//func (u *User) BeforeSave(tx *gorm.DB) (err error) {
|
||||
// // 如果 email 为空,设置为默认值
|
||||
// if u.Email == "" {
|
||||
// u.Email = fmt.Sprintf("%s@example.com", u.Username)
|
||||
// }
|
||||
// return nil
|
||||
//}
|
||||
|
||||
type UserList struct {
|
||||
Users []*User `json:"list,omitempty"`
|
||||
|
||||
@@ -6,20 +6,21 @@ import (
|
||||
|
||||
type UserThird struct {
|
||||
IdModel
|
||||
UserId uint ` json:"user_id" gorm:"not null;index"`
|
||||
UserId uint `json:"user_id" gorm:"not null;index"`
|
||||
OauthUser
|
||||
// UnionId string `json:"union_id" gorm:"not null;"`
|
||||
UnionId string `json:"union_id" gorm:"default:'';not null;"`
|
||||
// OauthType string `json:"oauth_type" gorm:"not null;"`
|
||||
OauthType string `json:"oauth_type"`
|
||||
Op string `json:"op" gorm:"not null;"`
|
||||
ThirdType string `json:"third_type" gorm:"default:'';not null;"` //deprecated
|
||||
OauthType string `json:"oauth_type" gorm:"default:'';not null;"`
|
||||
Op string `json:"op" gorm:"default:'';not null;"`
|
||||
TimeModel
|
||||
}
|
||||
|
||||
func (u *UserThird) FromOauthUser(userId uint, oauthUser *OauthUser, oauthType string, op string) {
|
||||
u.UserId = userId
|
||||
u.OauthUser = *oauthUser
|
||||
u.OauthType = oauthType
|
||||
u.Op = op
|
||||
u.UserId = userId
|
||||
u.OauthUser = *oauthUser
|
||||
u.OauthType = oauthType
|
||||
u.Op = op
|
||||
// make sure email is lower case
|
||||
u.Email = strings.ToLower(u.Email)
|
||||
}
|
||||
u.Email = strings.ToLower(u.Email)
|
||||
}
|
||||
|
||||
@@ -12,16 +12,15 @@ import (
|
||||
// "golang.org/x/oauth2/google"
|
||||
"gorm.io/gorm"
|
||||
// "io"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
|
||||
type OauthService struct {
|
||||
}
|
||||
|
||||
@@ -34,26 +33,26 @@ type OidcEndpoint struct {
|
||||
}
|
||||
|
||||
type OauthCacheItem struct {
|
||||
UserId uint `json:"user_id"`
|
||||
Id string `json:"id"` //rustdesk的设备ID
|
||||
Op string `json:"op"`
|
||||
Action string `json:"action"`
|
||||
Uuid string `json:"uuid"`
|
||||
DeviceName string `json:"device_name"`
|
||||
DeviceOs string `json:"device_os"`
|
||||
DeviceType string `json:"device_type"`
|
||||
OpenId string `json:"open_id"`
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
UserId uint `json:"user_id"`
|
||||
Id string `json:"id"` //rustdesk的设备ID
|
||||
Op string `json:"op"`
|
||||
Action string `json:"action"`
|
||||
Uuid string `json:"uuid"`
|
||||
DeviceName string `json:"device_name"`
|
||||
DeviceOs string `json:"device_os"`
|
||||
DeviceType string `json:"device_type"`
|
||||
OpenId string `json:"open_id"`
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser {
|
||||
return &model.OauthUser{
|
||||
OpenId: oci.OpenId,
|
||||
OpenId: oci.OpenId,
|
||||
Username: oci.Username,
|
||||
Name: oci.Name,
|
||||
Email: oci.Email,
|
||||
Name: oci.Name,
|
||||
Email: oci.Email,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,14 +63,13 @@ const (
|
||||
OauthActionTypeBind = "bind"
|
||||
)
|
||||
|
||||
func (oa *OauthCacheItem) UpdateFromOauthUser(oauthUser *model.OauthUser) {
|
||||
oa.OpenId = oauthUser.OpenId
|
||||
oa.Username = oauthUser.Username
|
||||
oa.Name = oauthUser.Name
|
||||
oa.Email = oauthUser.Email
|
||||
func (oci *OauthCacheItem) UpdateFromOauthUser(oauthUser *model.OauthUser) {
|
||||
oci.OpenId = oauthUser.OpenId
|
||||
oci.Username = oauthUser.Username
|
||||
oci.Name = oauthUser.Name
|
||||
oci.Email = oauthUser.Email
|
||||
}
|
||||
|
||||
|
||||
func (os *OauthService) GetOauthCache(key string) *OauthCacheItem {
|
||||
v, ok := OauthCache.Load(key)
|
||||
if !ok {
|
||||
@@ -164,7 +162,7 @@ func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.O
|
||||
if err != nil {
|
||||
return err, nil, nil
|
||||
}
|
||||
oauthConfig.Endpoint = oauth2.Endpoint{AuthURL: endpoint.AuthURL,TokenURL: endpoint.TokenURL,}
|
||||
oauthConfig.Endpoint = oauth2.Endpoint{AuthURL: endpoint.AuthURL, TokenURL: endpoint.TokenURL}
|
||||
oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes)
|
||||
default:
|
||||
return errors.New("unsupported OAuth type"), nil, nil
|
||||
@@ -259,9 +257,8 @@ func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string)
|
||||
return nil, user.ToOauthUser()
|
||||
}
|
||||
|
||||
|
||||
// oidcCallback oidc回调, 通过code获取用户信息
|
||||
func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser,) {
|
||||
func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser) {
|
||||
var user = &model.OidcUser{}
|
||||
if err, _ := os.callbackBase(oauthConfig, code, userInfoEndpoint, user); err != nil {
|
||||
return err, nil
|
||||
@@ -280,21 +277,20 @@ func (os *OauthService) Callback(code string, op string) (err error, oauthUser *
|
||||
}
|
||||
oauthType := oauthInfo.OauthType
|
||||
switch oauthType {
|
||||
case model.OauthTypeGithub:
|
||||
err, oauthUser = os.githubCallback(oauthConfig, code)
|
||||
case model.OauthTypeOidc, model.OauthTypeGoogle:
|
||||
case model.OauthTypeGithub:
|
||||
err, oauthUser = os.githubCallback(oauthConfig, code)
|
||||
case model.OauthTypeOidc, model.OauthTypeGoogle:
|
||||
err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
err, oauthUser = os.oidcCallback(oauthConfig, code, endpoint.UserInfo)
|
||||
default:
|
||||
return errors.New("unsupported OAuth type"), nil
|
||||
}
|
||||
return err, oauthUser
|
||||
err, oauthUser = os.oidcCallback(oauthConfig, code, endpoint.UserInfo)
|
||||
default:
|
||||
return errors.New("unsupported OAuth type"), nil
|
||||
}
|
||||
return err, oauthUser
|
||||
}
|
||||
|
||||
|
||||
func (os *OauthService) UserThirdInfo(op string, openId string) *model.UserThird {
|
||||
ut := &model.UserThird{}
|
||||
global.DB.Where("open_id = ? and op = ?", openId, op).First(ut)
|
||||
@@ -343,17 +339,17 @@ func (os *OauthService) InfoByOp(op string) *model.Oauth {
|
||||
|
||||
// Helper function to get scopes by operation
|
||||
func (os *OauthService) getScopesByOp(op string) []string {
|
||||
scopes := os.InfoByOp(op).Scopes
|
||||
scopes := os.InfoByOp(op).Scopes
|
||||
return os.constructScopes(scopes)
|
||||
}
|
||||
|
||||
// Helper function to construct scopes
|
||||
func (os *OauthService) constructScopes(scopes string) []string {
|
||||
scopes = strings.TrimSpace(scopes)
|
||||
if scopes == "" {
|
||||
scopes = model.OIDC_DEFAULT_SCOPES
|
||||
}
|
||||
return strings.Split(scopes, ",")
|
||||
scopes = strings.TrimSpace(scopes)
|
||||
if scopes == "" {
|
||||
scopes = model.OIDC_DEFAULT_SCOPES
|
||||
}
|
||||
return strings.Split(scopes, ",")
|
||||
}
|
||||
|
||||
func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *model.OauthList) {
|
||||
@@ -461,4 +457,4 @@ func (os *OauthService) getGithubPrimaryEmail(client *http.Client, githubUser *m
|
||||
}
|
||||
|
||||
return fmt.Errorf("no primary verified email found")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,13 +5,13 @@ import (
|
||||
adResp "Gwen/http/response/admin"
|
||||
"Gwen/model"
|
||||
"Gwen/utils"
|
||||
"errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"time"
|
||||
"strings"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
type UserService struct {
|
||||
@@ -23,6 +23,7 @@ func (us *UserService) InfoById(id uint) *model.User {
|
||||
global.DB.Where("id = ?", id).First(u)
|
||||
return u
|
||||
}
|
||||
|
||||
// InfoByUsername 根据用户名取用户信息
|
||||
func (us *UserService) InfoByUsername(un string) *model.User {
|
||||
u := &model.User{}
|
||||
@@ -75,11 +76,11 @@ func (us *UserService) GenerateToken(u *model.User) string {
|
||||
func (us *UserService) Login(u *model.User, llog *model.LoginLog) *model.UserToken {
|
||||
token := us.GenerateToken(u)
|
||||
ut := &model.UserToken{
|
||||
UserId: u.Id,
|
||||
Token: token,
|
||||
UserId: u.Id,
|
||||
Token: token,
|
||||
DeviceUuid: llog.Uuid,
|
||||
DeviceId: llog.DeviceId,
|
||||
ExpiredAt: time.Now().Add(time.Hour * 24 * 7).Unix(),
|
||||
ExpiredAt: time.Now().Add(time.Hour * 24 * 7).Unix(),
|
||||
}
|
||||
global.DB.Create(ut)
|
||||
llog.UserTokenId = ut.UserId
|
||||
@@ -162,7 +163,7 @@ func (us *UserService) Create(u *model.User) error {
|
||||
// GetUuidByToken 根据token和user取uuid
|
||||
func (us *UserService) GetUuidByToken(u *model.User, token string) string {
|
||||
ut := &model.UserToken{}
|
||||
err :=global.DB.Where("user_id = ? and token = ?", u.Id, token).First(ut).Error
|
||||
err := global.DB.Where("user_id = ? and token = ?", u.Id, token).First(ut).Error
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
@@ -214,12 +215,12 @@ func (us *UserService) Delete(u *model.User) error {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
tx.Commit()
|
||||
// 删除关联的peer
|
||||
if err := AllService.PeerService.EraseUserId(u.Id); err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -230,7 +231,7 @@ func (us *UserService) Update(u *model.User) error {
|
||||
if us.IsAdmin(currentUser) {
|
||||
adminCount := us.getAdminUserCount()
|
||||
// 如果这是唯一的管理员,确保不能禁用或取消管理员权限
|
||||
if adminCount <= 1 && ( !us.IsAdmin(u) || u.Status == model.COMMON_STATUS_DISABLED) {
|
||||
if adminCount <= 1 && (!us.IsAdmin(u) || u.Status == model.COMMON_STATUS_DISABLED) {
|
||||
return errors.New("The last admin user cannot be disabled or demoted")
|
||||
}
|
||||
}
|
||||
@@ -290,48 +291,49 @@ func (us *UserService) InfoByOauthId(op string, openId string) *model.User {
|
||||
}
|
||||
|
||||
// RegisterByOauth 注册
|
||||
func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser , op string) (error, *model.User) {
|
||||
func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser, op string) (error, *model.User) {
|
||||
global.Lock.Lock("registerByOauth")
|
||||
defer global.Lock.UnLock("registerByOauth")
|
||||
ut := AllService.OauthService.UserThirdInfo(op, oauthUser.OpenId)
|
||||
if ut.Id != 0 {
|
||||
return nil, us.InfoById(ut.UserId)
|
||||
}
|
||||
//check if this email has been registered
|
||||
email := oauthUser.Email
|
||||
err, oauthType := AllService.OauthService.GetTypeByOp(op)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
// if email is empty, use username and op as email
|
||||
if email == "" {
|
||||
email = oauthUser.Username + "@" + op
|
||||
}
|
||||
email = strings.ToLower(email)
|
||||
// update email to oauthUser, in case it contain upper case
|
||||
oauthUser.Email = email
|
||||
user := us.InfoByEmail(email)
|
||||
tx := global.DB.Begin()
|
||||
if user.Id != 0 {
|
||||
ut.FromOauthUser(user.Id, oauthUser, oauthType, op)
|
||||
} else {
|
||||
ut = &model.UserThird{}
|
||||
ut.FromOauthUser(0, oauthUser, oauthType, op)
|
||||
// The initial username should be formatted
|
||||
username := us.formatUsername(oauthUser.Username)
|
||||
usernameUnique := us.GenerateUsernameByOauth(username)
|
||||
user = &model.User{
|
||||
Username: usernameUnique,
|
||||
GroupId: 1,
|
||||
//check if this email has been registered
|
||||
email := oauthUser.Email
|
||||
// only email is not empty
|
||||
if email != "" {
|
||||
email = strings.ToLower(email)
|
||||
// update email to oauthUser, in case it contain upper case
|
||||
oauthUser.Email = email
|
||||
user := us.InfoByEmail(email)
|
||||
if user.Id != 0 {
|
||||
ut.FromOauthUser(user.Id, oauthUser, oauthType, op)
|
||||
global.DB.Create(ut)
|
||||
return nil, user
|
||||
}
|
||||
oauthUser.ToUser(user, false)
|
||||
tx.Create(user)
|
||||
if user.Id == 0 {
|
||||
tx.Rollback()
|
||||
return errors.New("OauthRegisterFailed"), user
|
||||
}
|
||||
ut.UserId = user.Id
|
||||
}
|
||||
|
||||
tx := global.DB.Begin()
|
||||
ut = &model.UserThird{}
|
||||
ut.FromOauthUser(0, oauthUser, oauthType, op)
|
||||
// The initial username should be formatted
|
||||
username := us.formatUsername(oauthUser.Username)
|
||||
usernameUnique := us.GenerateUsernameByOauth(username)
|
||||
user := &model.User{
|
||||
Username: usernameUnique,
|
||||
GroupId: 1,
|
||||
}
|
||||
oauthUser.ToUser(user, false)
|
||||
tx.Create(user)
|
||||
if user.Id == 0 {
|
||||
tx.Rollback()
|
||||
return errors.New("OauthRegisterFailed"), user
|
||||
}
|
||||
ut.UserId = user.Id
|
||||
tx.Create(ut)
|
||||
tx.Commit()
|
||||
return nil, user
|
||||
@@ -433,7 +435,7 @@ func (us *UserService) formatUsername(username string) string {
|
||||
return username
|
||||
}
|
||||
|
||||
// Helper functions, getUserCount
|
||||
// Helper functions, getUserCount
|
||||
func (us *UserService) getUserCount() int64 {
|
||||
var count int64
|
||||
global.DB.Model(&model.User{}).Count(&count)
|
||||
@@ -445,4 +447,4 @@ func (us *UserService) getAdminUserCount() int64 {
|
||||
var count int64
|
||||
global.DB.Model(&model.User{}).Where("is_admin = ?", true).Count(&count)
|
||||
return count
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user