up oauth re

This commit is contained in:
ljw
2024-11-05 09:48:02 +08:00
parent 8af01c859c
commit d4d39eecaa
9 changed files with 170 additions and 175 deletions

View File

@@ -101,7 +101,7 @@ func main() {
} }
func DatabaseAutoUpdate() { func DatabaseAutoUpdate() {
version := 244 version := 245
db := global.DB db := global.DB
@@ -146,6 +146,21 @@ func DatabaseAutoUpdate() {
if v.Version < uint(version) { if v.Version < uint(version) {
Migrate(uint(version)) Migrate(uint(version))
} }
// 245迁移
if v.Version < 245 {
//oauths 表的 oauth_type 字段设置为 op同样的值
db.Exec("update oauths set oauth_type = op")
db.Exec("update oauths set issuer = 'https://accounts.google.com' where op = 'google' and issuer = ''")
db.Exec("update user_thirds set oauth_type = third_type, op = third_type")
//通过email迁移旧的google授权
uts := make([]model.UserThird, 0)
db.Where("oauth_type = ?", "google").Find(&uts)
for _, ut := range uts {
if ut.UserId > 0 {
db.Model(&model.User{}).Where("id = ?", ut.UserId).Update("email", ut.OpenId)
}
}
}
} }
} }

View File

@@ -180,15 +180,18 @@ func (o *Oauth) Create(c *gin.Context) {
response.Fail(c, 101, errList[0]) response.Fail(c, 101, errList[0])
return return
} }
u := f.ToOauth()
ex := service.AllService.OauthService.InfoByOp(f.Op) err := u.FormatOauthInfo()
if err != nil {
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")+err.Error())
return
}
ex := service.AllService.OauthService.InfoByOp(u.Op)
if ex.Id > 0 { if ex.Id > 0 {
response.Fail(c, 101, response.TranslateMsg(c, "ItemExists")) response.Fail(c, 101, response.TranslateMsg(c, "ItemExists"))
return return
} }
err = service.AllService.OauthService.Create(u)
u := f.ToOauth()
err := service.AllService.OauthService.Create(u)
if err != nil { if err != nil {
response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error()) response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
return return

View File

@@ -217,7 +217,7 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
oauthCache.UserId = user.Id oauthCache.UserId = user.Id
oauthService.SetOauthCache(cacheKey, oauthCache, 0) oauthService.SetOauthCache(cacheKey, oauthCache, 0)
// 如果是webadmin登录成功后跳转到webadmin // 如果是webadmin登录成功后跳转到webadmin
if oauthCache.DeviceType == "webadmin" { if oauthCache.DeviceType == model.LoginLogClientWebAdmin {
/*service.AllService.UserService.Login(u, &model.LoginLog{ /*service.AllService.UserService.Login(u, &model.LoginLog{
UserId: u.Id, UserId: u.Id,
Client: "webadmin", Client: "webadmin",

View File

@@ -5,15 +5,15 @@ import (
) )
type UserForm struct { type UserForm struct {
Id uint `json:"id"` Id uint `json:"id"`
Username string `json:"username" validate:"required,gte=4,lte=10"` Username string `json:"username" validate:"required,gte=4,lte=10"`
Email string `json:"email" validate:"required,email"` Email string `json:"email"` //validate:"required,email" email不强制
//Password string `json:"password" validate:"required,gte=4,lte=20"` //Password string `json:"password" validate:"required,gte=4,lte=20"`
Nickname string `json:"nickname"` Nickname string `json:"nickname"`
Avatar string `json:"avatar"` Avatar string `json:"avatar"`
GroupId uint `json:"group_id" validate:"required"` GroupId uint `json:"group_id" validate:"required"`
IsAdmin *bool `json:"is_admin" ` IsAdmin *bool `json:"is_admin" `
Status model.StatusCode `json:"status" validate:"required,gte=0"` Status model.StatusCode `json:"status" validate:"required,gte=0"`
} }
func (uf *UserForm) FromUser(user *model.User) *UserForm { func (uf *UserForm) FromUser(user *model.User) *UserForm {

View File

@@ -1,9 +1,9 @@
package model package model
import ( import (
"errors"
"strconv" "strconv"
"strings" "strings"
"errors"
) )
const OIDC_DEFAULT_SCOPES = "openid,profile,email" const OIDC_DEFAULT_SCOPES = "openid,profile,email"
@@ -27,32 +27,23 @@ func ValidateOauthType(oauthType string) error {
} }
const ( const (
OauthNameGithub string = "GitHub" UserEndpointGithub string = "https://api.github.com/user"
OauthNameGoogle string = "Google" IssuerGoogle string = "https://accounts.google.com"
OauthNameOidc string = "OIDC"
OauthNameWebauth string = "WebAuth"
)
const (
UserEndpointGithub string = "https://api.github.com/user"
IssuerGoogle string = "https://accounts.google.com"
) )
type Oauth struct { type Oauth struct {
IdModel IdModel
Op string `json:"op"` Op string `json:"op"`
OauthType string `json:"oauth_type"` OauthType string `json:"oauth_type"`
ClientId string `json:"client_id"` ClientId string `json:"client_id"`
ClientSecret string `json:"client_secret"` ClientSecret string `json:"client_secret"`
RedirectUrl string `json:"redirect_url"` RedirectUrl string `json:"redirect_url"`
AutoRegister *bool `json:"auto_register"` AutoRegister *bool `json:"auto_register"`
Scopes string `json:"scopes"` Scopes string `json:"scopes"`
Issuer string `json:"issuer"` Issuer string `json:"issuer"`
TimeModel TimeModel
} }
// Helper function to format oauth info, it's used in the update and create method // Helper function to format oauth info, it's used in the update and create method
func (oa *Oauth) FormatOauthInfo() error { func (oa *Oauth) FormatOauthInfo() error {
oauthType := strings.TrimSpace(oa.OauthType) oauthType := strings.TrimSpace(oa.OauthType)
@@ -60,25 +51,20 @@ func (oa *Oauth) FormatOauthInfo() error {
if err != nil { if err != nil {
return err return err
} }
switch oauthType {
case OauthTypeGithub:
oa.Op = OauthTypeGithub
case OauthTypeGoogle:
oa.Op = OauthTypeGoogle
}
// check if the op is empty, set the default value // check if the op is empty, set the default value
op := strings.TrimSpace(oa.Op) op := strings.TrimSpace(oa.Op)
if op == "" { if op == "" && oauthType == OauthTypeOidc {
switch oauthType { oa.Op = OauthTypeOidc
case OauthTypeGithub:
oa.Op = OauthNameGithub
case OauthTypeGoogle:
oa.Op = OauthNameGoogle
case OauthTypeOidc:
oa.Op = OauthNameOidc
case OauthTypeWebauth:
oa.Op = OauthNameWebauth
default:
oa.Op = oauthType
}
} }
// check the issuer, if the oauth type is google and the issuer is empty, set the issuer to the default value // check the issuer, if the oauth type is google and the issuer is empty, set the issuer to the default value
issuer := strings.TrimSpace(oa.Issuer) issuer := strings.TrimSpace(oa.Issuer)
// If the oauth type is google and the issuer is empty, set the issuer to the default value // If the oauth type is google and the issuer is empty, set the issuer to the default value
if oauthType == OauthTypeGoogle && issuer == "" { if oauthType == OauthTypeGoogle && issuer == "" {
oa.Issuer = IssuerGoogle oa.Issuer = IssuerGoogle
} }
@@ -86,12 +72,12 @@ func (oa *Oauth) FormatOauthInfo() error {
} }
type OauthUser struct { type OauthUser struct {
OpenId string `json:"open_id" gorm:"not null;index"` OpenId string `json:"open_id" gorm:"not null;index"`
Name string `json:"name"` Name string `json:"name"`
Username string `json:"username"` Username string `json:"username"`
Email string `json:"email"` Email string `json:"email"`
VerifiedEmail bool `json:"verified_email,omitempty"` VerifiedEmail bool `json:"verified_email,omitempty"`
Picture string `json:"picture,omitempty"` Picture string `json:"picture,omitempty"`
} }
func (ou *OauthUser) ToUser(user *User, overideUsername bool) { func (ou *OauthUser) ToUser(user *User, overideUsername bool) {
@@ -122,7 +108,7 @@ func (ou *OidcUser) ToOauthUser() *OauthUser {
if ou.PreferredUsername != "" { if ou.PreferredUsername != "" {
username = ou.PreferredUsername username = ou.PreferredUsername
} else { } else {
username = strings.ToLower(strings.Split(ou.Email, "@")[0]) username = strings.ToLower(ou.Email)
} }
return &OauthUser{ return &OauthUser{
@@ -135,29 +121,26 @@ func (ou *OidcUser) ToOauthUser() *OauthUser {
} }
} }
type GithubUser struct { type GithubUser struct {
OauthUserBase OauthUserBase
Id int `json:"id"` Id int `json:"id"`
Login string `json:"login"` Login string `json:"login"`
AvatarUrl string `json:"avatar_url"` AvatarUrl string `json:"avatar_url"`
VerifiedEmail bool `json:"verified_email"` VerifiedEmail bool `json:"verified_email"`
} }
func (gu *GithubUser) ToOauthUser() *OauthUser { func (gu *GithubUser) ToOauthUser() *OauthUser {
username := strings.ToLower(gu.Login) username := strings.ToLower(gu.Login)
return &OauthUser{ return &OauthUser{
OpenId: strconv.Itoa(gu.Id), OpenId: strconv.Itoa(gu.Id),
Name: gu.Name, Name: gu.Name,
Username: username, Username: username,
Email: gu.Email, Email: gu.Email,
VerifiedEmail: gu.VerifiedEmail, VerifiedEmail: gu.VerifiedEmail,
Picture: gu.AvatarUrl, Picture: gu.AvatarUrl,
} }
} }
type OauthList struct { type OauthList struct {
Oauths []*Oauth `json:"list"` Oauths []*Oauth `json:"list"`
Pagination Pagination

View File

@@ -1,14 +1,9 @@
package model package model
import (
"fmt"
"gorm.io/gorm"
)
type User struct { type User struct {
IdModel IdModel
Username string `json:"username" gorm:"default:'';not null;uniqueIndex"` Username string `json:"username" gorm:"default:'';not null;uniqueIndex"`
Email string `json:"email" gorm:"default:'';not null;uniqueIndex"` Email string `json:"email" gorm:"default:'';not null;index"`
// Email string `json:"email" ` // Email string `json:"email" `
Password string `json:"-" gorm:"default:'';not null;"` Password string `json:"-" gorm:"default:'';not null;"`
Nickname string `json:"nickname" gorm:"default:'';not null;"` Nickname string `json:"nickname" gorm:"default:'';not null;"`
@@ -20,13 +15,13 @@ type User struct {
} }
// BeforeSave 钩子用于确保 email 字段有合理的默认值 // BeforeSave 钩子用于确保 email 字段有合理的默认值
func (u *User) BeforeSave(tx *gorm.DB) (err error) { //func (u *User) BeforeSave(tx *gorm.DB) (err error) {
// 如果 email 为空,设置为默认值 // // 如果 email 为空,设置为默认值
if u.Email == "" { // if u.Email == "" {
u.Email = fmt.Sprintf("%s@example.com", u.Username) // u.Email = fmt.Sprintf("%s@example.com", u.Username)
} // }
return nil // return nil
} //}
type UserList struct { type UserList struct {
Users []*User `json:"list,omitempty"` Users []*User `json:"list,omitempty"`

View File

@@ -6,20 +6,21 @@ import (
type UserThird struct { type UserThird struct {
IdModel IdModel
UserId uint ` json:"user_id" gorm:"not null;index"` UserId uint `json:"user_id" gorm:"not null;index"`
OauthUser OauthUser
// UnionId string `json:"union_id" gorm:"not null;"` UnionId string `json:"union_id" gorm:"default:'';not null;"`
// OauthType string `json:"oauth_type" gorm:"not null;"` // OauthType string `json:"oauth_type" gorm:"not null;"`
OauthType string `json:"oauth_type"` ThirdType string `json:"third_type" gorm:"default:'';not null;"` //deprecated
Op string `json:"op" gorm:"not null;"` OauthType string `json:"oauth_type" gorm:"default:'';not null;"`
Op string `json:"op" gorm:"default:'';not null;"`
TimeModel TimeModel
} }
func (u *UserThird) FromOauthUser(userId uint, oauthUser *OauthUser, oauthType string, op string) { func (u *UserThird) FromOauthUser(userId uint, oauthUser *OauthUser, oauthType string, op string) {
u.UserId = userId u.UserId = userId
u.OauthUser = *oauthUser u.OauthUser = *oauthUser
u.OauthType = oauthType u.OauthType = oauthType
u.Op = op u.Op = op
// make sure email is lower case // make sure email is lower case
u.Email = strings.ToLower(u.Email) u.Email = strings.ToLower(u.Email)
} }

View File

@@ -12,16 +12,15 @@ import (
// "golang.org/x/oauth2/google" // "golang.org/x/oauth2/google"
"gorm.io/gorm" "gorm.io/gorm"
// "io" // "io"
"fmt"
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"fmt"
) )
type OauthService struct { type OauthService struct {
} }
@@ -34,26 +33,26 @@ type OidcEndpoint struct {
} }
type OauthCacheItem struct { type OauthCacheItem struct {
UserId uint `json:"user_id"` UserId uint `json:"user_id"`
Id string `json:"id"` //rustdesk的设备ID Id string `json:"id"` //rustdesk的设备ID
Op string `json:"op"` Op string `json:"op"`
Action string `json:"action"` Action string `json:"action"`
Uuid string `json:"uuid"` Uuid string `json:"uuid"`
DeviceName string `json:"device_name"` DeviceName string `json:"device_name"`
DeviceOs string `json:"device_os"` DeviceOs string `json:"device_os"`
DeviceType string `json:"device_type"` DeviceType string `json:"device_type"`
OpenId string `json:"open_id"` OpenId string `json:"open_id"`
Username string `json:"username"` Username string `json:"username"`
Name string `json:"name"` Name string `json:"name"`
Email string `json:"email"` Email string `json:"email"`
} }
func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser { func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser {
return &model.OauthUser{ return &model.OauthUser{
OpenId: oci.OpenId, OpenId: oci.OpenId,
Username: oci.Username, Username: oci.Username,
Name: oci.Name, Name: oci.Name,
Email: oci.Email, Email: oci.Email,
} }
} }
@@ -64,14 +63,13 @@ const (
OauthActionTypeBind = "bind" OauthActionTypeBind = "bind"
) )
func (oa *OauthCacheItem) UpdateFromOauthUser(oauthUser *model.OauthUser) { func (oci *OauthCacheItem) UpdateFromOauthUser(oauthUser *model.OauthUser) {
oa.OpenId = oauthUser.OpenId oci.OpenId = oauthUser.OpenId
oa.Username = oauthUser.Username oci.Username = oauthUser.Username
oa.Name = oauthUser.Name oci.Name = oauthUser.Name
oa.Email = oauthUser.Email oci.Email = oauthUser.Email
} }
func (os *OauthService) GetOauthCache(key string) *OauthCacheItem { func (os *OauthService) GetOauthCache(key string) *OauthCacheItem {
v, ok := OauthCache.Load(key) v, ok := OauthCache.Load(key)
if !ok { if !ok {
@@ -164,7 +162,7 @@ func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.O
if err != nil { if err != nil {
return err, nil, nil return err, nil, nil
} }
oauthConfig.Endpoint = oauth2.Endpoint{AuthURL: endpoint.AuthURL,TokenURL: endpoint.TokenURL,} oauthConfig.Endpoint = oauth2.Endpoint{AuthURL: endpoint.AuthURL, TokenURL: endpoint.TokenURL}
oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes) oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes)
default: default:
return errors.New("unsupported OAuth type"), nil, nil return errors.New("unsupported OAuth type"), nil, nil
@@ -259,9 +257,8 @@ func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string)
return nil, user.ToOauthUser() return nil, user.ToOauthUser()
} }
// oidcCallback oidc回调, 通过code获取用户信息 // oidcCallback oidc回调, 通过code获取用户信息
func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser,) { func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser) {
var user = &model.OidcUser{} var user = &model.OidcUser{}
if err, _ := os.callbackBase(oauthConfig, code, userInfoEndpoint, user); err != nil { if err, _ := os.callbackBase(oauthConfig, code, userInfoEndpoint, user); err != nil {
return err, nil return err, nil
@@ -280,21 +277,20 @@ func (os *OauthService) Callback(code string, op string) (err error, oauthUser *
} }
oauthType := oauthInfo.OauthType oauthType := oauthInfo.OauthType
switch oauthType { switch oauthType {
case model.OauthTypeGithub: case model.OauthTypeGithub:
err, oauthUser = os.githubCallback(oauthConfig, code) err, oauthUser = os.githubCallback(oauthConfig, code)
case model.OauthTypeOidc, model.OauthTypeGoogle: case model.OauthTypeOidc, model.OauthTypeGoogle:
err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer) err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer)
if err != nil { if err != nil {
return err, nil return err, nil
} }
err, oauthUser = os.oidcCallback(oauthConfig, code, endpoint.UserInfo) err, oauthUser = os.oidcCallback(oauthConfig, code, endpoint.UserInfo)
default: default:
return errors.New("unsupported OAuth type"), nil return errors.New("unsupported OAuth type"), nil
} }
return err, oauthUser return err, oauthUser
} }
func (os *OauthService) UserThirdInfo(op string, openId string) *model.UserThird { func (os *OauthService) UserThirdInfo(op string, openId string) *model.UserThird {
ut := &model.UserThird{} ut := &model.UserThird{}
global.DB.Where("open_id = ? and op = ?", openId, op).First(ut) global.DB.Where("open_id = ? and op = ?", openId, op).First(ut)
@@ -343,17 +339,17 @@ func (os *OauthService) InfoByOp(op string) *model.Oauth {
// Helper function to get scopes by operation // Helper function to get scopes by operation
func (os *OauthService) getScopesByOp(op string) []string { func (os *OauthService) getScopesByOp(op string) []string {
scopes := os.InfoByOp(op).Scopes scopes := os.InfoByOp(op).Scopes
return os.constructScopes(scopes) return os.constructScopes(scopes)
} }
// Helper function to construct scopes // Helper function to construct scopes
func (os *OauthService) constructScopes(scopes string) []string { func (os *OauthService) constructScopes(scopes string) []string {
scopes = strings.TrimSpace(scopes) scopes = strings.TrimSpace(scopes)
if scopes == "" { if scopes == "" {
scopes = model.OIDC_DEFAULT_SCOPES scopes = model.OIDC_DEFAULT_SCOPES
} }
return strings.Split(scopes, ",") return strings.Split(scopes, ",")
} }
func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *model.OauthList) { func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *model.OauthList) {
@@ -461,4 +457,4 @@ func (os *OauthService) getGithubPrimaryEmail(client *http.Client, githubUser *m
} }
return fmt.Errorf("no primary verified email found") return fmt.Errorf("no primary verified email found")
} }

View File

@@ -5,13 +5,13 @@ import (
adResp "Gwen/http/response/admin" adResp "Gwen/http/response/admin"
"Gwen/model" "Gwen/model"
"Gwen/utils" "Gwen/utils"
"errors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
"math/rand" "math/rand"
"strconv" "strconv"
"time"
"strings" "strings"
"errors" "time"
) )
type UserService struct { type UserService struct {
@@ -23,6 +23,7 @@ func (us *UserService) InfoById(id uint) *model.User {
global.DB.Where("id = ?", id).First(u) global.DB.Where("id = ?", id).First(u)
return u return u
} }
// InfoByUsername 根据用户名取用户信息 // InfoByUsername 根据用户名取用户信息
func (us *UserService) InfoByUsername(un string) *model.User { func (us *UserService) InfoByUsername(un string) *model.User {
u := &model.User{} u := &model.User{}
@@ -75,11 +76,11 @@ func (us *UserService) GenerateToken(u *model.User) string {
func (us *UserService) Login(u *model.User, llog *model.LoginLog) *model.UserToken { func (us *UserService) Login(u *model.User, llog *model.LoginLog) *model.UserToken {
token := us.GenerateToken(u) token := us.GenerateToken(u)
ut := &model.UserToken{ ut := &model.UserToken{
UserId: u.Id, UserId: u.Id,
Token: token, Token: token,
DeviceUuid: llog.Uuid, DeviceUuid: llog.Uuid,
DeviceId: llog.DeviceId, DeviceId: llog.DeviceId,
ExpiredAt: time.Now().Add(time.Hour * 24 * 7).Unix(), ExpiredAt: time.Now().Add(time.Hour * 24 * 7).Unix(),
} }
global.DB.Create(ut) global.DB.Create(ut)
llog.UserTokenId = ut.UserId llog.UserTokenId = ut.UserId
@@ -162,7 +163,7 @@ func (us *UserService) Create(u *model.User) error {
// GetUuidByToken 根据token和user取uuid // GetUuidByToken 根据token和user取uuid
func (us *UserService) GetUuidByToken(u *model.User, token string) string { func (us *UserService) GetUuidByToken(u *model.User, token string) string {
ut := &model.UserToken{} ut := &model.UserToken{}
err :=global.DB.Where("user_id = ? and token = ?", u.Id, token).First(ut).Error err := global.DB.Where("user_id = ? and token = ?", u.Id, token).First(ut).Error
if err != nil { if err != nil {
return "" return ""
} }
@@ -214,12 +215,12 @@ func (us *UserService) Delete(u *model.User) error {
tx.Rollback() tx.Rollback()
return err return err
} }
tx.Commit()
// 删除关联的peer // 删除关联的peer
if err := AllService.PeerService.EraseUserId(u.Id); err != nil { if err := AllService.PeerService.EraseUserId(u.Id); err != nil {
tx.Rollback() tx.Rollback()
return err return err
} }
tx.Commit()
return nil return nil
} }
@@ -230,7 +231,7 @@ func (us *UserService) Update(u *model.User) error {
if us.IsAdmin(currentUser) { if us.IsAdmin(currentUser) {
adminCount := us.getAdminUserCount() adminCount := us.getAdminUserCount()
// 如果这是唯一的管理员,确保不能禁用或取消管理员权限 // 如果这是唯一的管理员,确保不能禁用或取消管理员权限
if adminCount <= 1 && ( !us.IsAdmin(u) || u.Status == model.COMMON_STATUS_DISABLED) { if adminCount <= 1 && (!us.IsAdmin(u) || u.Status == model.COMMON_STATUS_DISABLED) {
return errors.New("The last admin user cannot be disabled or demoted") return errors.New("The last admin user cannot be disabled or demoted")
} }
} }
@@ -290,48 +291,49 @@ func (us *UserService) InfoByOauthId(op string, openId string) *model.User {
} }
// RegisterByOauth 注册 // RegisterByOauth 注册
func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser , op string) (error, *model.User) { func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser, op string) (error, *model.User) {
global.Lock.Lock("registerByOauth") global.Lock.Lock("registerByOauth")
defer global.Lock.UnLock("registerByOauth") defer global.Lock.UnLock("registerByOauth")
ut := AllService.OauthService.UserThirdInfo(op, oauthUser.OpenId) ut := AllService.OauthService.UserThirdInfo(op, oauthUser.OpenId)
if ut.Id != 0 { if ut.Id != 0 {
return nil, us.InfoById(ut.UserId) return nil, us.InfoById(ut.UserId)
} }
//check if this email has been registered
email := oauthUser.Email
err, oauthType := AllService.OauthService.GetTypeByOp(op) err, oauthType := AllService.OauthService.GetTypeByOp(op)
if err != nil { if err != nil {
return err, nil return err, nil
} }
// if email is empty, use username and op as email //check if this email has been registered
if email == "" { email := oauthUser.Email
email = oauthUser.Username + "@" + op // only email is not empty
} if email != "" {
email = strings.ToLower(email) email = strings.ToLower(email)
// update email to oauthUser, in case it contain upper case // update email to oauthUser, in case it contain upper case
oauthUser.Email = email oauthUser.Email = email
user := us.InfoByEmail(email) user := us.InfoByEmail(email)
tx := global.DB.Begin() if user.Id != 0 {
if user.Id != 0 { ut.FromOauthUser(user.Id, oauthUser, oauthType, op)
ut.FromOauthUser(user.Id, oauthUser, oauthType, op) global.DB.Create(ut)
} else { return nil, user
ut = &model.UserThird{}
ut.FromOauthUser(0, oauthUser, oauthType, op)
// The initial username should be formatted
username := us.formatUsername(oauthUser.Username)
usernameUnique := us.GenerateUsernameByOauth(username)
user = &model.User{
Username: usernameUnique,
GroupId: 1,
} }
oauthUser.ToUser(user, false)
tx.Create(user)
if user.Id == 0 {
tx.Rollback()
return errors.New("OauthRegisterFailed"), user
}
ut.UserId = user.Id
} }
tx := global.DB.Begin()
ut = &model.UserThird{}
ut.FromOauthUser(0, oauthUser, oauthType, op)
// The initial username should be formatted
username := us.formatUsername(oauthUser.Username)
usernameUnique := us.GenerateUsernameByOauth(username)
user := &model.User{
Username: usernameUnique,
GroupId: 1,
}
oauthUser.ToUser(user, false)
tx.Create(user)
if user.Id == 0 {
tx.Rollback()
return errors.New("OauthRegisterFailed"), user
}
ut.UserId = user.Id
tx.Create(ut) tx.Create(ut)
tx.Commit() tx.Commit()
return nil, user return nil, user
@@ -433,7 +435,7 @@ func (us *UserService) formatUsername(username string) string {
return username return username
} }
// Helper functions, getUserCount // Helper functions, getUserCount
func (us *UserService) getUserCount() int64 { func (us *UserService) getUserCount() int64 {
var count int64 var count int64
global.DB.Model(&model.User{}).Count(&count) global.DB.Model(&model.User{}).Count(&count)
@@ -445,4 +447,4 @@ func (us *UserService) getAdminUserCount() int64 {
var count int64 var count int64
global.DB.Model(&model.User{}).Where("is_admin = ?", true).Count(&count) global.DB.Model(&model.User{}).Where("is_admin = ?", true).Count(&count)
return count return count
} }