up oauth re
This commit is contained in:
@@ -101,7 +101,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func DatabaseAutoUpdate() {
|
func DatabaseAutoUpdate() {
|
||||||
version := 244
|
version := 245
|
||||||
|
|
||||||
db := global.DB
|
db := global.DB
|
||||||
|
|
||||||
@@ -146,6 +146,21 @@ func DatabaseAutoUpdate() {
|
|||||||
if v.Version < uint(version) {
|
if v.Version < uint(version) {
|
||||||
Migrate(uint(version))
|
Migrate(uint(version))
|
||||||
}
|
}
|
||||||
|
// 245迁移
|
||||||
|
if v.Version < 245 {
|
||||||
|
//oauths 表的 oauth_type 字段设置为 op同样的值
|
||||||
|
db.Exec("update oauths set oauth_type = op")
|
||||||
|
db.Exec("update oauths set issuer = 'https://accounts.google.com' where op = 'google' and issuer = ''")
|
||||||
|
db.Exec("update user_thirds set oauth_type = third_type, op = third_type")
|
||||||
|
//通过email迁移旧的google授权
|
||||||
|
uts := make([]model.UserThird, 0)
|
||||||
|
db.Where("oauth_type = ?", "google").Find(&uts)
|
||||||
|
for _, ut := range uts {
|
||||||
|
if ut.UserId > 0 {
|
||||||
|
db.Model(&model.User{}).Where("id = ?", ut.UserId).Update("email", ut.OpenId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -180,15 +180,18 @@ func (o *Oauth) Create(c *gin.Context) {
|
|||||||
response.Fail(c, 101, errList[0])
|
response.Fail(c, 101, errList[0])
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
u := f.ToOauth()
|
||||||
ex := service.AllService.OauthService.InfoByOp(f.Op)
|
err := u.FormatOauthInfo()
|
||||||
|
if err != nil {
|
||||||
|
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ex := service.AllService.OauthService.InfoByOp(u.Op)
|
||||||
if ex.Id > 0 {
|
if ex.Id > 0 {
|
||||||
response.Fail(c, 101, response.TranslateMsg(c, "ItemExists"))
|
response.Fail(c, 101, response.TranslateMsg(c, "ItemExists"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
err = service.AllService.OauthService.Create(u)
|
||||||
u := f.ToOauth()
|
|
||||||
err := service.AllService.OauthService.Create(u)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
|
response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -217,7 +217,7 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
|
|||||||
oauthCache.UserId = user.Id
|
oauthCache.UserId = user.Id
|
||||||
oauthService.SetOauthCache(cacheKey, oauthCache, 0)
|
oauthService.SetOauthCache(cacheKey, oauthCache, 0)
|
||||||
// 如果是webadmin,登录成功后跳转到webadmin
|
// 如果是webadmin,登录成功后跳转到webadmin
|
||||||
if oauthCache.DeviceType == "webadmin" {
|
if oauthCache.DeviceType == model.LoginLogClientWebAdmin {
|
||||||
/*service.AllService.UserService.Login(u, &model.LoginLog{
|
/*service.AllService.UserService.Login(u, &model.LoginLog{
|
||||||
UserId: u.Id,
|
UserId: u.Id,
|
||||||
Client: "webadmin",
|
Client: "webadmin",
|
||||||
|
|||||||
@@ -5,15 +5,15 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type UserForm struct {
|
type UserForm struct {
|
||||||
Id uint `json:"id"`
|
Id uint `json:"id"`
|
||||||
Username string `json:"username" validate:"required,gte=4,lte=10"`
|
Username string `json:"username" validate:"required,gte=4,lte=10"`
|
||||||
Email string `json:"email" validate:"required,email"`
|
Email string `json:"email"` //validate:"required,email" email不强制
|
||||||
//Password string `json:"password" validate:"required,gte=4,lte=20"`
|
//Password string `json:"password" validate:"required,gte=4,lte=20"`
|
||||||
Nickname string `json:"nickname"`
|
Nickname string `json:"nickname"`
|
||||||
Avatar string `json:"avatar"`
|
Avatar string `json:"avatar"`
|
||||||
GroupId uint `json:"group_id" validate:"required"`
|
GroupId uint `json:"group_id" validate:"required"`
|
||||||
IsAdmin *bool `json:"is_admin" `
|
IsAdmin *bool `json:"is_admin" `
|
||||||
Status model.StatusCode `json:"status" validate:"required,gte=0"`
|
Status model.StatusCode `json:"status" validate:"required,gte=0"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (uf *UserForm) FromUser(user *model.User) *UserForm {
|
func (uf *UserForm) FromUser(user *model.User) *UserForm {
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"errors"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const OIDC_DEFAULT_SCOPES = "openid,profile,email"
|
const OIDC_DEFAULT_SCOPES = "openid,profile,email"
|
||||||
@@ -27,32 +27,23 @@ func ValidateOauthType(oauthType string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
OauthNameGithub string = "GitHub"
|
UserEndpointGithub string = "https://api.github.com/user"
|
||||||
OauthNameGoogle string = "Google"
|
IssuerGoogle string = "https://accounts.google.com"
|
||||||
OauthNameOidc string = "OIDC"
|
|
||||||
OauthNameWebauth string = "WebAuth"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
UserEndpointGithub string = "https://api.github.com/user"
|
|
||||||
IssuerGoogle string = "https://accounts.google.com"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Oauth struct {
|
type Oauth struct {
|
||||||
IdModel
|
IdModel
|
||||||
Op string `json:"op"`
|
Op string `json:"op"`
|
||||||
OauthType string `json:"oauth_type"`
|
OauthType string `json:"oauth_type"`
|
||||||
ClientId string `json:"client_id"`
|
ClientId string `json:"client_id"`
|
||||||
ClientSecret string `json:"client_secret"`
|
ClientSecret string `json:"client_secret"`
|
||||||
RedirectUrl string `json:"redirect_url"`
|
RedirectUrl string `json:"redirect_url"`
|
||||||
AutoRegister *bool `json:"auto_register"`
|
AutoRegister *bool `json:"auto_register"`
|
||||||
Scopes string `json:"scopes"`
|
Scopes string `json:"scopes"`
|
||||||
Issuer string `json:"issuer"`
|
Issuer string `json:"issuer"`
|
||||||
TimeModel
|
TimeModel
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// Helper function to format oauth info, it's used in the update and create method
|
// Helper function to format oauth info, it's used in the update and create method
|
||||||
func (oa *Oauth) FormatOauthInfo() error {
|
func (oa *Oauth) FormatOauthInfo() error {
|
||||||
oauthType := strings.TrimSpace(oa.OauthType)
|
oauthType := strings.TrimSpace(oa.OauthType)
|
||||||
@@ -60,25 +51,20 @@ func (oa *Oauth) FormatOauthInfo() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
switch oauthType {
|
||||||
|
case OauthTypeGithub:
|
||||||
|
oa.Op = OauthTypeGithub
|
||||||
|
case OauthTypeGoogle:
|
||||||
|
oa.Op = OauthTypeGoogle
|
||||||
|
}
|
||||||
// check if the op is empty, set the default value
|
// check if the op is empty, set the default value
|
||||||
op := strings.TrimSpace(oa.Op)
|
op := strings.TrimSpace(oa.Op)
|
||||||
if op == "" {
|
if op == "" && oauthType == OauthTypeOidc {
|
||||||
switch oauthType {
|
oa.Op = OauthTypeOidc
|
||||||
case OauthTypeGithub:
|
|
||||||
oa.Op = OauthNameGithub
|
|
||||||
case OauthTypeGoogle:
|
|
||||||
oa.Op = OauthNameGoogle
|
|
||||||
case OauthTypeOidc:
|
|
||||||
oa.Op = OauthNameOidc
|
|
||||||
case OauthTypeWebauth:
|
|
||||||
oa.Op = OauthNameWebauth
|
|
||||||
default:
|
|
||||||
oa.Op = oauthType
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// check the issuer, if the oauth type is google and the issuer is empty, set the issuer to the default value
|
// check the issuer, if the oauth type is google and the issuer is empty, set the issuer to the default value
|
||||||
issuer := strings.TrimSpace(oa.Issuer)
|
issuer := strings.TrimSpace(oa.Issuer)
|
||||||
// If the oauth type is google and the issuer is empty, set the issuer to the default value
|
// If the oauth type is google and the issuer is empty, set the issuer to the default value
|
||||||
if oauthType == OauthTypeGoogle && issuer == "" {
|
if oauthType == OauthTypeGoogle && issuer == "" {
|
||||||
oa.Issuer = IssuerGoogle
|
oa.Issuer = IssuerGoogle
|
||||||
}
|
}
|
||||||
@@ -86,12 +72,12 @@ func (oa *Oauth) FormatOauthInfo() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OauthUser struct {
|
type OauthUser struct {
|
||||||
OpenId string `json:"open_id" gorm:"not null;index"`
|
OpenId string `json:"open_id" gorm:"not null;index"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
VerifiedEmail bool `json:"verified_email,omitempty"`
|
VerifiedEmail bool `json:"verified_email,omitempty"`
|
||||||
Picture string `json:"picture,omitempty"`
|
Picture string `json:"picture,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ou *OauthUser) ToUser(user *User, overideUsername bool) {
|
func (ou *OauthUser) ToUser(user *User, overideUsername bool) {
|
||||||
@@ -122,7 +108,7 @@ func (ou *OidcUser) ToOauthUser() *OauthUser {
|
|||||||
if ou.PreferredUsername != "" {
|
if ou.PreferredUsername != "" {
|
||||||
username = ou.PreferredUsername
|
username = ou.PreferredUsername
|
||||||
} else {
|
} else {
|
||||||
username = strings.ToLower(strings.Split(ou.Email, "@")[0])
|
username = strings.ToLower(ou.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &OauthUser{
|
return &OauthUser{
|
||||||
@@ -135,29 +121,26 @@ func (ou *OidcUser) ToOauthUser() *OauthUser {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
type GithubUser struct {
|
type GithubUser struct {
|
||||||
OauthUserBase
|
OauthUserBase
|
||||||
Id int `json:"id"`
|
Id int `json:"id"`
|
||||||
Login string `json:"login"`
|
Login string `json:"login"`
|
||||||
AvatarUrl string `json:"avatar_url"`
|
AvatarUrl string `json:"avatar_url"`
|
||||||
VerifiedEmail bool `json:"verified_email"`
|
VerifiedEmail bool `json:"verified_email"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (gu *GithubUser) ToOauthUser() *OauthUser {
|
func (gu *GithubUser) ToOauthUser() *OauthUser {
|
||||||
username := strings.ToLower(gu.Login)
|
username := strings.ToLower(gu.Login)
|
||||||
return &OauthUser{
|
return &OauthUser{
|
||||||
OpenId: strconv.Itoa(gu.Id),
|
OpenId: strconv.Itoa(gu.Id),
|
||||||
Name: gu.Name,
|
Name: gu.Name,
|
||||||
Username: username,
|
Username: username,
|
||||||
Email: gu.Email,
|
Email: gu.Email,
|
||||||
VerifiedEmail: gu.VerifiedEmail,
|
VerifiedEmail: gu.VerifiedEmail,
|
||||||
Picture: gu.AvatarUrl,
|
Picture: gu.AvatarUrl,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
type OauthList struct {
|
type OauthList struct {
|
||||||
Oauths []*Oauth `json:"list"`
|
Oauths []*Oauth `json:"list"`
|
||||||
Pagination
|
Pagination
|
||||||
|
|||||||
@@ -1,14 +1,9 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
IdModel
|
IdModel
|
||||||
Username string `json:"username" gorm:"default:'';not null;uniqueIndex"`
|
Username string `json:"username" gorm:"default:'';not null;uniqueIndex"`
|
||||||
Email string `json:"email" gorm:"default:'';not null;uniqueIndex"`
|
Email string `json:"email" gorm:"default:'';not null;index"`
|
||||||
// Email string `json:"email" `
|
// Email string `json:"email" `
|
||||||
Password string `json:"-" gorm:"default:'';not null;"`
|
Password string `json:"-" gorm:"default:'';not null;"`
|
||||||
Nickname string `json:"nickname" gorm:"default:'';not null;"`
|
Nickname string `json:"nickname" gorm:"default:'';not null;"`
|
||||||
@@ -20,13 +15,13 @@ type User struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// BeforeSave 钩子用于确保 email 字段有合理的默认值
|
// BeforeSave 钩子用于确保 email 字段有合理的默认值
|
||||||
func (u *User) BeforeSave(tx *gorm.DB) (err error) {
|
//func (u *User) BeforeSave(tx *gorm.DB) (err error) {
|
||||||
// 如果 email 为空,设置为默认值
|
// // 如果 email 为空,设置为默认值
|
||||||
if u.Email == "" {
|
// if u.Email == "" {
|
||||||
u.Email = fmt.Sprintf("%s@example.com", u.Username)
|
// u.Email = fmt.Sprintf("%s@example.com", u.Username)
|
||||||
}
|
// }
|
||||||
return nil
|
// return nil
|
||||||
}
|
//}
|
||||||
|
|
||||||
type UserList struct {
|
type UserList struct {
|
||||||
Users []*User `json:"list,omitempty"`
|
Users []*User `json:"list,omitempty"`
|
||||||
|
|||||||
@@ -6,20 +6,21 @@ import (
|
|||||||
|
|
||||||
type UserThird struct {
|
type UserThird struct {
|
||||||
IdModel
|
IdModel
|
||||||
UserId uint ` json:"user_id" gorm:"not null;index"`
|
UserId uint `json:"user_id" gorm:"not null;index"`
|
||||||
OauthUser
|
OauthUser
|
||||||
// UnionId string `json:"union_id" gorm:"not null;"`
|
UnionId string `json:"union_id" gorm:"default:'';not null;"`
|
||||||
// OauthType string `json:"oauth_type" gorm:"not null;"`
|
// OauthType string `json:"oauth_type" gorm:"not null;"`
|
||||||
OauthType string `json:"oauth_type"`
|
ThirdType string `json:"third_type" gorm:"default:'';not null;"` //deprecated
|
||||||
Op string `json:"op" gorm:"not null;"`
|
OauthType string `json:"oauth_type" gorm:"default:'';not null;"`
|
||||||
|
Op string `json:"op" gorm:"default:'';not null;"`
|
||||||
TimeModel
|
TimeModel
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *UserThird) FromOauthUser(userId uint, oauthUser *OauthUser, oauthType string, op string) {
|
func (u *UserThird) FromOauthUser(userId uint, oauthUser *OauthUser, oauthType string, op string) {
|
||||||
u.UserId = userId
|
u.UserId = userId
|
||||||
u.OauthUser = *oauthUser
|
u.OauthUser = *oauthUser
|
||||||
u.OauthType = oauthType
|
u.OauthType = oauthType
|
||||||
u.Op = op
|
u.Op = op
|
||||||
// make sure email is lower case
|
// make sure email is lower case
|
||||||
u.Email = strings.ToLower(u.Email)
|
u.Email = strings.ToLower(u.Email)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,16 +12,15 @@ import (
|
|||||||
// "golang.org/x/oauth2/google"
|
// "golang.org/x/oauth2/google"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
// "io"
|
// "io"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
"fmt"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
type OauthService struct {
|
type OauthService struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -34,26 +33,26 @@ type OidcEndpoint struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OauthCacheItem struct {
|
type OauthCacheItem struct {
|
||||||
UserId uint `json:"user_id"`
|
UserId uint `json:"user_id"`
|
||||||
Id string `json:"id"` //rustdesk的设备ID
|
Id string `json:"id"` //rustdesk的设备ID
|
||||||
Op string `json:"op"`
|
Op string `json:"op"`
|
||||||
Action string `json:"action"`
|
Action string `json:"action"`
|
||||||
Uuid string `json:"uuid"`
|
Uuid string `json:"uuid"`
|
||||||
DeviceName string `json:"device_name"`
|
DeviceName string `json:"device_name"`
|
||||||
DeviceOs string `json:"device_os"`
|
DeviceOs string `json:"device_os"`
|
||||||
DeviceType string `json:"device_type"`
|
DeviceType string `json:"device_type"`
|
||||||
OpenId string `json:"open_id"`
|
OpenId string `json:"open_id"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser {
|
func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser {
|
||||||
return &model.OauthUser{
|
return &model.OauthUser{
|
||||||
OpenId: oci.OpenId,
|
OpenId: oci.OpenId,
|
||||||
Username: oci.Username,
|
Username: oci.Username,
|
||||||
Name: oci.Name,
|
Name: oci.Name,
|
||||||
Email: oci.Email,
|
Email: oci.Email,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,14 +63,13 @@ const (
|
|||||||
OauthActionTypeBind = "bind"
|
OauthActionTypeBind = "bind"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (oa *OauthCacheItem) UpdateFromOauthUser(oauthUser *model.OauthUser) {
|
func (oci *OauthCacheItem) UpdateFromOauthUser(oauthUser *model.OauthUser) {
|
||||||
oa.OpenId = oauthUser.OpenId
|
oci.OpenId = oauthUser.OpenId
|
||||||
oa.Username = oauthUser.Username
|
oci.Username = oauthUser.Username
|
||||||
oa.Name = oauthUser.Name
|
oci.Name = oauthUser.Name
|
||||||
oa.Email = oauthUser.Email
|
oci.Email = oauthUser.Email
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func (os *OauthService) GetOauthCache(key string) *OauthCacheItem {
|
func (os *OauthService) GetOauthCache(key string) *OauthCacheItem {
|
||||||
v, ok := OauthCache.Load(key)
|
v, ok := OauthCache.Load(key)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -164,7 +162,7 @@ func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.O
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil, nil
|
return err, nil, nil
|
||||||
}
|
}
|
||||||
oauthConfig.Endpoint = oauth2.Endpoint{AuthURL: endpoint.AuthURL,TokenURL: endpoint.TokenURL,}
|
oauthConfig.Endpoint = oauth2.Endpoint{AuthURL: endpoint.AuthURL, TokenURL: endpoint.TokenURL}
|
||||||
oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes)
|
oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes)
|
||||||
default:
|
default:
|
||||||
return errors.New("unsupported OAuth type"), nil, nil
|
return errors.New("unsupported OAuth type"), nil, nil
|
||||||
@@ -259,9 +257,8 @@ func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string)
|
|||||||
return nil, user.ToOauthUser()
|
return nil, user.ToOauthUser()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// oidcCallback oidc回调, 通过code获取用户信息
|
// oidcCallback oidc回调, 通过code获取用户信息
|
||||||
func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser,) {
|
func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser) {
|
||||||
var user = &model.OidcUser{}
|
var user = &model.OidcUser{}
|
||||||
if err, _ := os.callbackBase(oauthConfig, code, userInfoEndpoint, user); err != nil {
|
if err, _ := os.callbackBase(oauthConfig, code, userInfoEndpoint, user); err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
@@ -280,21 +277,20 @@ func (os *OauthService) Callback(code string, op string) (err error, oauthUser *
|
|||||||
}
|
}
|
||||||
oauthType := oauthInfo.OauthType
|
oauthType := oauthInfo.OauthType
|
||||||
switch oauthType {
|
switch oauthType {
|
||||||
case model.OauthTypeGithub:
|
case model.OauthTypeGithub:
|
||||||
err, oauthUser = os.githubCallback(oauthConfig, code)
|
err, oauthUser = os.githubCallback(oauthConfig, code)
|
||||||
case model.OauthTypeOidc, model.OauthTypeGoogle:
|
case model.OauthTypeOidc, model.OauthTypeGoogle:
|
||||||
err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer)
|
err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
}
|
}
|
||||||
err, oauthUser = os.oidcCallback(oauthConfig, code, endpoint.UserInfo)
|
err, oauthUser = os.oidcCallback(oauthConfig, code, endpoint.UserInfo)
|
||||||
default:
|
default:
|
||||||
return errors.New("unsupported OAuth type"), nil
|
return errors.New("unsupported OAuth type"), nil
|
||||||
}
|
}
|
||||||
return err, oauthUser
|
return err, oauthUser
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func (os *OauthService) UserThirdInfo(op string, openId string) *model.UserThird {
|
func (os *OauthService) UserThirdInfo(op string, openId string) *model.UserThird {
|
||||||
ut := &model.UserThird{}
|
ut := &model.UserThird{}
|
||||||
global.DB.Where("open_id = ? and op = ?", openId, op).First(ut)
|
global.DB.Where("open_id = ? and op = ?", openId, op).First(ut)
|
||||||
@@ -343,17 +339,17 @@ func (os *OauthService) InfoByOp(op string) *model.Oauth {
|
|||||||
|
|
||||||
// Helper function to get scopes by operation
|
// Helper function to get scopes by operation
|
||||||
func (os *OauthService) getScopesByOp(op string) []string {
|
func (os *OauthService) getScopesByOp(op string) []string {
|
||||||
scopes := os.InfoByOp(op).Scopes
|
scopes := os.InfoByOp(op).Scopes
|
||||||
return os.constructScopes(scopes)
|
return os.constructScopes(scopes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to construct scopes
|
// Helper function to construct scopes
|
||||||
func (os *OauthService) constructScopes(scopes string) []string {
|
func (os *OauthService) constructScopes(scopes string) []string {
|
||||||
scopes = strings.TrimSpace(scopes)
|
scopes = strings.TrimSpace(scopes)
|
||||||
if scopes == "" {
|
if scopes == "" {
|
||||||
scopes = model.OIDC_DEFAULT_SCOPES
|
scopes = model.OIDC_DEFAULT_SCOPES
|
||||||
}
|
}
|
||||||
return strings.Split(scopes, ",")
|
return strings.Split(scopes, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *model.OauthList) {
|
func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *model.OauthList) {
|
||||||
@@ -461,4 +457,4 @@ func (os *OauthService) getGithubPrimaryEmail(client *http.Client, githubUser *m
|
|||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Errorf("no primary verified email found")
|
return fmt.Errorf("no primary verified email found")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,13 +5,13 @@ import (
|
|||||||
adResp "Gwen/http/response/admin"
|
adResp "Gwen/http/response/admin"
|
||||||
"Gwen/model"
|
"Gwen/model"
|
||||||
"Gwen/utils"
|
"Gwen/utils"
|
||||||
|
"errors"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
|
||||||
"strings"
|
"strings"
|
||||||
"errors"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UserService struct {
|
type UserService struct {
|
||||||
@@ -23,6 +23,7 @@ func (us *UserService) InfoById(id uint) *model.User {
|
|||||||
global.DB.Where("id = ?", id).First(u)
|
global.DB.Where("id = ?", id).First(u)
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
// InfoByUsername 根据用户名取用户信息
|
// InfoByUsername 根据用户名取用户信息
|
||||||
func (us *UserService) InfoByUsername(un string) *model.User {
|
func (us *UserService) InfoByUsername(un string) *model.User {
|
||||||
u := &model.User{}
|
u := &model.User{}
|
||||||
@@ -75,11 +76,11 @@ func (us *UserService) GenerateToken(u *model.User) string {
|
|||||||
func (us *UserService) Login(u *model.User, llog *model.LoginLog) *model.UserToken {
|
func (us *UserService) Login(u *model.User, llog *model.LoginLog) *model.UserToken {
|
||||||
token := us.GenerateToken(u)
|
token := us.GenerateToken(u)
|
||||||
ut := &model.UserToken{
|
ut := &model.UserToken{
|
||||||
UserId: u.Id,
|
UserId: u.Id,
|
||||||
Token: token,
|
Token: token,
|
||||||
DeviceUuid: llog.Uuid,
|
DeviceUuid: llog.Uuid,
|
||||||
DeviceId: llog.DeviceId,
|
DeviceId: llog.DeviceId,
|
||||||
ExpiredAt: time.Now().Add(time.Hour * 24 * 7).Unix(),
|
ExpiredAt: time.Now().Add(time.Hour * 24 * 7).Unix(),
|
||||||
}
|
}
|
||||||
global.DB.Create(ut)
|
global.DB.Create(ut)
|
||||||
llog.UserTokenId = ut.UserId
|
llog.UserTokenId = ut.UserId
|
||||||
@@ -162,7 +163,7 @@ func (us *UserService) Create(u *model.User) error {
|
|||||||
// GetUuidByToken 根据token和user取uuid
|
// GetUuidByToken 根据token和user取uuid
|
||||||
func (us *UserService) GetUuidByToken(u *model.User, token string) string {
|
func (us *UserService) GetUuidByToken(u *model.User, token string) string {
|
||||||
ut := &model.UserToken{}
|
ut := &model.UserToken{}
|
||||||
err :=global.DB.Where("user_id = ? and token = ?", u.Id, token).First(ut).Error
|
err := global.DB.Where("user_id = ? and token = ?", u.Id, token).First(ut).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@@ -214,12 +215,12 @@ func (us *UserService) Delete(u *model.User) error {
|
|||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
tx.Commit()
|
|
||||||
// 删除关联的peer
|
// 删除关联的peer
|
||||||
if err := AllService.PeerService.EraseUserId(u.Id); err != nil {
|
if err := AllService.PeerService.EraseUserId(u.Id); err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
tx.Commit()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -230,7 +231,7 @@ func (us *UserService) Update(u *model.User) error {
|
|||||||
if us.IsAdmin(currentUser) {
|
if us.IsAdmin(currentUser) {
|
||||||
adminCount := us.getAdminUserCount()
|
adminCount := us.getAdminUserCount()
|
||||||
// 如果这是唯一的管理员,确保不能禁用或取消管理员权限
|
// 如果这是唯一的管理员,确保不能禁用或取消管理员权限
|
||||||
if adminCount <= 1 && ( !us.IsAdmin(u) || u.Status == model.COMMON_STATUS_DISABLED) {
|
if adminCount <= 1 && (!us.IsAdmin(u) || u.Status == model.COMMON_STATUS_DISABLED) {
|
||||||
return errors.New("The last admin user cannot be disabled or demoted")
|
return errors.New("The last admin user cannot be disabled or demoted")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -290,48 +291,49 @@ func (us *UserService) InfoByOauthId(op string, openId string) *model.User {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RegisterByOauth 注册
|
// RegisterByOauth 注册
|
||||||
func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser , op string) (error, *model.User) {
|
func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser, op string) (error, *model.User) {
|
||||||
global.Lock.Lock("registerByOauth")
|
global.Lock.Lock("registerByOauth")
|
||||||
defer global.Lock.UnLock("registerByOauth")
|
defer global.Lock.UnLock("registerByOauth")
|
||||||
ut := AllService.OauthService.UserThirdInfo(op, oauthUser.OpenId)
|
ut := AllService.OauthService.UserThirdInfo(op, oauthUser.OpenId)
|
||||||
if ut.Id != 0 {
|
if ut.Id != 0 {
|
||||||
return nil, us.InfoById(ut.UserId)
|
return nil, us.InfoById(ut.UserId)
|
||||||
}
|
}
|
||||||
//check if this email has been registered
|
|
||||||
email := oauthUser.Email
|
|
||||||
err, oauthType := AllService.OauthService.GetTypeByOp(op)
|
err, oauthType := AllService.OauthService.GetTypeByOp(op)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
}
|
}
|
||||||
// if email is empty, use username and op as email
|
//check if this email has been registered
|
||||||
if email == "" {
|
email := oauthUser.Email
|
||||||
email = oauthUser.Username + "@" + op
|
// only email is not empty
|
||||||
}
|
if email != "" {
|
||||||
email = strings.ToLower(email)
|
email = strings.ToLower(email)
|
||||||
// update email to oauthUser, in case it contain upper case
|
// update email to oauthUser, in case it contain upper case
|
||||||
oauthUser.Email = email
|
oauthUser.Email = email
|
||||||
user := us.InfoByEmail(email)
|
user := us.InfoByEmail(email)
|
||||||
tx := global.DB.Begin()
|
if user.Id != 0 {
|
||||||
if user.Id != 0 {
|
ut.FromOauthUser(user.Id, oauthUser, oauthType, op)
|
||||||
ut.FromOauthUser(user.Id, oauthUser, oauthType, op)
|
global.DB.Create(ut)
|
||||||
} else {
|
return nil, user
|
||||||
ut = &model.UserThird{}
|
|
||||||
ut.FromOauthUser(0, oauthUser, oauthType, op)
|
|
||||||
// The initial username should be formatted
|
|
||||||
username := us.formatUsername(oauthUser.Username)
|
|
||||||
usernameUnique := us.GenerateUsernameByOauth(username)
|
|
||||||
user = &model.User{
|
|
||||||
Username: usernameUnique,
|
|
||||||
GroupId: 1,
|
|
||||||
}
|
}
|
||||||
oauthUser.ToUser(user, false)
|
|
||||||
tx.Create(user)
|
|
||||||
if user.Id == 0 {
|
|
||||||
tx.Rollback()
|
|
||||||
return errors.New("OauthRegisterFailed"), user
|
|
||||||
}
|
|
||||||
ut.UserId = user.Id
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tx := global.DB.Begin()
|
||||||
|
ut = &model.UserThird{}
|
||||||
|
ut.FromOauthUser(0, oauthUser, oauthType, op)
|
||||||
|
// The initial username should be formatted
|
||||||
|
username := us.formatUsername(oauthUser.Username)
|
||||||
|
usernameUnique := us.GenerateUsernameByOauth(username)
|
||||||
|
user := &model.User{
|
||||||
|
Username: usernameUnique,
|
||||||
|
GroupId: 1,
|
||||||
|
}
|
||||||
|
oauthUser.ToUser(user, false)
|
||||||
|
tx.Create(user)
|
||||||
|
if user.Id == 0 {
|
||||||
|
tx.Rollback()
|
||||||
|
return errors.New("OauthRegisterFailed"), user
|
||||||
|
}
|
||||||
|
ut.UserId = user.Id
|
||||||
tx.Create(ut)
|
tx.Create(ut)
|
||||||
tx.Commit()
|
tx.Commit()
|
||||||
return nil, user
|
return nil, user
|
||||||
@@ -433,7 +435,7 @@ func (us *UserService) formatUsername(username string) string {
|
|||||||
return username
|
return username
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper functions, getUserCount
|
// Helper functions, getUserCount
|
||||||
func (us *UserService) getUserCount() int64 {
|
func (us *UserService) getUserCount() int64 {
|
||||||
var count int64
|
var count int64
|
||||||
global.DB.Model(&model.User{}).Count(&count)
|
global.DB.Model(&model.User{}).Count(&count)
|
||||||
@@ -445,4 +447,4 @@ func (us *UserService) getAdminUserCount() int64 {
|
|||||||
var count int64
|
var count int64
|
||||||
global.DB.Model(&model.User{}).Where("is_admin = ?", true).Count(&count)
|
global.DB.Model(&model.User{}).Where("is_admin = ?", true).Count(&count)
|
||||||
return count
|
return count
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user