up oauth re
This commit is contained in:
@@ -101,7 +101,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func DatabaseAutoUpdate() {
|
func DatabaseAutoUpdate() {
|
||||||
version := 244
|
version := 245
|
||||||
|
|
||||||
db := global.DB
|
db := global.DB
|
||||||
|
|
||||||
@@ -146,6 +146,21 @@ func DatabaseAutoUpdate() {
|
|||||||
if v.Version < uint(version) {
|
if v.Version < uint(version) {
|
||||||
Migrate(uint(version))
|
Migrate(uint(version))
|
||||||
}
|
}
|
||||||
|
// 245迁移
|
||||||
|
if v.Version < 245 {
|
||||||
|
//oauths 表的 oauth_type 字段设置为 op同样的值
|
||||||
|
db.Exec("update oauths set oauth_type = op")
|
||||||
|
db.Exec("update oauths set issuer = 'https://accounts.google.com' where op = 'google' and issuer = ''")
|
||||||
|
db.Exec("update user_thirds set oauth_type = third_type, op = third_type")
|
||||||
|
//通过email迁移旧的google授权
|
||||||
|
uts := make([]model.UserThird, 0)
|
||||||
|
db.Where("oauth_type = ?", "google").Find(&uts)
|
||||||
|
for _, ut := range uts {
|
||||||
|
if ut.UserId > 0 {
|
||||||
|
db.Model(&model.User{}).Where("id = ?", ut.UserId).Update("email", ut.OpenId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -180,15 +180,18 @@ func (o *Oauth) Create(c *gin.Context) {
|
|||||||
response.Fail(c, 101, errList[0])
|
response.Fail(c, 101, errList[0])
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
u := f.ToOauth()
|
||||||
ex := service.AllService.OauthService.InfoByOp(f.Op)
|
err := u.FormatOauthInfo()
|
||||||
|
if err != nil {
|
||||||
|
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ex := service.AllService.OauthService.InfoByOp(u.Op)
|
||||||
if ex.Id > 0 {
|
if ex.Id > 0 {
|
||||||
response.Fail(c, 101, response.TranslateMsg(c, "ItemExists"))
|
response.Fail(c, 101, response.TranslateMsg(c, "ItemExists"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
err = service.AllService.OauthService.Create(u)
|
||||||
u := f.ToOauth()
|
|
||||||
err := service.AllService.OauthService.Create(u)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
|
response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -217,7 +217,7 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
|
|||||||
oauthCache.UserId = user.Id
|
oauthCache.UserId = user.Id
|
||||||
oauthService.SetOauthCache(cacheKey, oauthCache, 0)
|
oauthService.SetOauthCache(cacheKey, oauthCache, 0)
|
||||||
// 如果是webadmin,登录成功后跳转到webadmin
|
// 如果是webadmin,登录成功后跳转到webadmin
|
||||||
if oauthCache.DeviceType == "webadmin" {
|
if oauthCache.DeviceType == model.LoginLogClientWebAdmin {
|
||||||
/*service.AllService.UserService.Login(u, &model.LoginLog{
|
/*service.AllService.UserService.Login(u, &model.LoginLog{
|
||||||
UserId: u.Id,
|
UserId: u.Id,
|
||||||
Client: "webadmin",
|
Client: "webadmin",
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
type UserForm struct {
|
type UserForm struct {
|
||||||
Id uint `json:"id"`
|
Id uint `json:"id"`
|
||||||
Username string `json:"username" validate:"required,gte=4,lte=10"`
|
Username string `json:"username" validate:"required,gte=4,lte=10"`
|
||||||
Email string `json:"email" validate:"required,email"`
|
Email string `json:"email"` //validate:"required,email" email不强制
|
||||||
//Password string `json:"password" validate:"required,gte=4,lte=20"`
|
//Password string `json:"password" validate:"required,gte=4,lte=20"`
|
||||||
Nickname string `json:"nickname"`
|
Nickname string `json:"nickname"`
|
||||||
Avatar string `json:"avatar"`
|
Avatar string `json:"avatar"`
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"errors"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const OIDC_DEFAULT_SCOPES = "openid,profile,email"
|
const OIDC_DEFAULT_SCOPES = "openid,profile,email"
|
||||||
@@ -26,13 +26,6 @@ func ValidateOauthType(oauthType string) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
|
||||||
OauthNameGithub string = "GitHub"
|
|
||||||
OauthNameGoogle string = "Google"
|
|
||||||
OauthNameOidc string = "OIDC"
|
|
||||||
OauthNameWebauth string = "WebAuth"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
UserEndpointGithub string = "https://api.github.com/user"
|
UserEndpointGithub string = "https://api.github.com/user"
|
||||||
IssuerGoogle string = "https://accounts.google.com"
|
IssuerGoogle string = "https://accounts.google.com"
|
||||||
@@ -51,8 +44,6 @@ type Oauth struct {
|
|||||||
TimeModel
|
TimeModel
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// Helper function to format oauth info, it's used in the update and create method
|
// Helper function to format oauth info, it's used in the update and create method
|
||||||
func (oa *Oauth) FormatOauthInfo() error {
|
func (oa *Oauth) FormatOauthInfo() error {
|
||||||
oauthType := strings.TrimSpace(oa.OauthType)
|
oauthType := strings.TrimSpace(oa.OauthType)
|
||||||
@@ -60,21 +51,16 @@ func (oa *Oauth) FormatOauthInfo() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// check if the op is empty, set the default value
|
|
||||||
op := strings.TrimSpace(oa.Op)
|
|
||||||
if op == "" {
|
|
||||||
switch oauthType {
|
switch oauthType {
|
||||||
case OauthTypeGithub:
|
case OauthTypeGithub:
|
||||||
oa.Op = OauthNameGithub
|
oa.Op = OauthTypeGithub
|
||||||
case OauthTypeGoogle:
|
case OauthTypeGoogle:
|
||||||
oa.Op = OauthNameGoogle
|
oa.Op = OauthTypeGoogle
|
||||||
case OauthTypeOidc:
|
|
||||||
oa.Op = OauthNameOidc
|
|
||||||
case OauthTypeWebauth:
|
|
||||||
oa.Op = OauthNameWebauth
|
|
||||||
default:
|
|
||||||
oa.Op = oauthType
|
|
||||||
}
|
}
|
||||||
|
// check if the op is empty, set the default value
|
||||||
|
op := strings.TrimSpace(oa.Op)
|
||||||
|
if op == "" && oauthType == OauthTypeOidc {
|
||||||
|
oa.Op = OauthTypeOidc
|
||||||
}
|
}
|
||||||
// check the issuer, if the oauth type is google and the issuer is empty, set the issuer to the default value
|
// check the issuer, if the oauth type is google and the issuer is empty, set the issuer to the default value
|
||||||
issuer := strings.TrimSpace(oa.Issuer)
|
issuer := strings.TrimSpace(oa.Issuer)
|
||||||
@@ -122,7 +108,7 @@ func (ou *OidcUser) ToOauthUser() *OauthUser {
|
|||||||
if ou.PreferredUsername != "" {
|
if ou.PreferredUsername != "" {
|
||||||
username = ou.PreferredUsername
|
username = ou.PreferredUsername
|
||||||
} else {
|
} else {
|
||||||
username = strings.ToLower(strings.Split(ou.Email, "@")[0])
|
username = strings.ToLower(ou.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &OauthUser{
|
return &OauthUser{
|
||||||
@@ -135,7 +121,6 @@ func (ou *OidcUser) ToOauthUser() *OauthUser {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
type GithubUser struct {
|
type GithubUser struct {
|
||||||
OauthUserBase
|
OauthUserBase
|
||||||
Id int `json:"id"`
|
Id int `json:"id"`
|
||||||
@@ -156,8 +141,6 @@ func (gu *GithubUser) ToOauthUser() *OauthUser {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
type OauthList struct {
|
type OauthList struct {
|
||||||
Oauths []*Oauth `json:"list"`
|
Oauths []*Oauth `json:"list"`
|
||||||
Pagination
|
Pagination
|
||||||
|
|||||||
@@ -1,14 +1,9 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
IdModel
|
IdModel
|
||||||
Username string `json:"username" gorm:"default:'';not null;uniqueIndex"`
|
Username string `json:"username" gorm:"default:'';not null;uniqueIndex"`
|
||||||
Email string `json:"email" gorm:"default:'';not null;uniqueIndex"`
|
Email string `json:"email" gorm:"default:'';not null;index"`
|
||||||
// Email string `json:"email" `
|
// Email string `json:"email" `
|
||||||
Password string `json:"-" gorm:"default:'';not null;"`
|
Password string `json:"-" gorm:"default:'';not null;"`
|
||||||
Nickname string `json:"nickname" gorm:"default:'';not null;"`
|
Nickname string `json:"nickname" gorm:"default:'';not null;"`
|
||||||
@@ -20,13 +15,13 @@ type User struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// BeforeSave 钩子用于确保 email 字段有合理的默认值
|
// BeforeSave 钩子用于确保 email 字段有合理的默认值
|
||||||
func (u *User) BeforeSave(tx *gorm.DB) (err error) {
|
//func (u *User) BeforeSave(tx *gorm.DB) (err error) {
|
||||||
// 如果 email 为空,设置为默认值
|
// // 如果 email 为空,设置为默认值
|
||||||
if u.Email == "" {
|
// if u.Email == "" {
|
||||||
u.Email = fmt.Sprintf("%s@example.com", u.Username)
|
// u.Email = fmt.Sprintf("%s@example.com", u.Username)
|
||||||
}
|
// }
|
||||||
return nil
|
// return nil
|
||||||
}
|
//}
|
||||||
|
|
||||||
type UserList struct {
|
type UserList struct {
|
||||||
Users []*User `json:"list,omitempty"`
|
Users []*User `json:"list,omitempty"`
|
||||||
|
|||||||
@@ -8,10 +8,11 @@ type UserThird struct {
|
|||||||
IdModel
|
IdModel
|
||||||
UserId uint `json:"user_id" gorm:"not null;index"`
|
UserId uint `json:"user_id" gorm:"not null;index"`
|
||||||
OauthUser
|
OauthUser
|
||||||
// UnionId string `json:"union_id" gorm:"not null;"`
|
UnionId string `json:"union_id" gorm:"default:'';not null;"`
|
||||||
// OauthType string `json:"oauth_type" gorm:"not null;"`
|
// OauthType string `json:"oauth_type" gorm:"not null;"`
|
||||||
OauthType string `json:"oauth_type"`
|
ThirdType string `json:"third_type" gorm:"default:'';not null;"` //deprecated
|
||||||
Op string `json:"op" gorm:"not null;"`
|
OauthType string `json:"oauth_type" gorm:"default:'';not null;"`
|
||||||
|
Op string `json:"op" gorm:"default:'';not null;"`
|
||||||
TimeModel
|
TimeModel
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,16 +12,15 @@ import (
|
|||||||
// "golang.org/x/oauth2/google"
|
// "golang.org/x/oauth2/google"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
// "io"
|
// "io"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
"fmt"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
type OauthService struct {
|
type OauthService struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,14 +63,13 @@ const (
|
|||||||
OauthActionTypeBind = "bind"
|
OauthActionTypeBind = "bind"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (oa *OauthCacheItem) UpdateFromOauthUser(oauthUser *model.OauthUser) {
|
func (oci *OauthCacheItem) UpdateFromOauthUser(oauthUser *model.OauthUser) {
|
||||||
oa.OpenId = oauthUser.OpenId
|
oci.OpenId = oauthUser.OpenId
|
||||||
oa.Username = oauthUser.Username
|
oci.Username = oauthUser.Username
|
||||||
oa.Name = oauthUser.Name
|
oci.Name = oauthUser.Name
|
||||||
oa.Email = oauthUser.Email
|
oci.Email = oauthUser.Email
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func (os *OauthService) GetOauthCache(key string) *OauthCacheItem {
|
func (os *OauthService) GetOauthCache(key string) *OauthCacheItem {
|
||||||
v, ok := OauthCache.Load(key)
|
v, ok := OauthCache.Load(key)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -164,7 +162,7 @@ func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.O
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil, nil
|
return err, nil, nil
|
||||||
}
|
}
|
||||||
oauthConfig.Endpoint = oauth2.Endpoint{AuthURL: endpoint.AuthURL,TokenURL: endpoint.TokenURL,}
|
oauthConfig.Endpoint = oauth2.Endpoint{AuthURL: endpoint.AuthURL, TokenURL: endpoint.TokenURL}
|
||||||
oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes)
|
oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes)
|
||||||
default:
|
default:
|
||||||
return errors.New("unsupported OAuth type"), nil, nil
|
return errors.New("unsupported OAuth type"), nil, nil
|
||||||
@@ -259,9 +257,8 @@ func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string)
|
|||||||
return nil, user.ToOauthUser()
|
return nil, user.ToOauthUser()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// oidcCallback oidc回调, 通过code获取用户信息
|
// oidcCallback oidc回调, 通过code获取用户信息
|
||||||
func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser,) {
|
func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser) {
|
||||||
var user = &model.OidcUser{}
|
var user = &model.OidcUser{}
|
||||||
if err, _ := os.callbackBase(oauthConfig, code, userInfoEndpoint, user); err != nil {
|
if err, _ := os.callbackBase(oauthConfig, code, userInfoEndpoint, user); err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
@@ -294,7 +291,6 @@ func (os *OauthService) Callback(code string, op string) (err error, oauthUser *
|
|||||||
return err, oauthUser
|
return err, oauthUser
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func (os *OauthService) UserThirdInfo(op string, openId string) *model.UserThird {
|
func (os *OauthService) UserThirdInfo(op string, openId string) *model.UserThird {
|
||||||
ut := &model.UserThird{}
|
ut := &model.UserThird{}
|
||||||
global.DB.Where("open_id = ? and op = ?", openId, op).First(ut)
|
global.DB.Where("open_id = ? and op = ?", openId, op).First(ut)
|
||||||
|
|||||||
@@ -5,13 +5,13 @@ import (
|
|||||||
adResp "Gwen/http/response/admin"
|
adResp "Gwen/http/response/admin"
|
||||||
"Gwen/model"
|
"Gwen/model"
|
||||||
"Gwen/utils"
|
"Gwen/utils"
|
||||||
|
"errors"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
|
||||||
"strings"
|
"strings"
|
||||||
"errors"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UserService struct {
|
type UserService struct {
|
||||||
@@ -23,6 +23,7 @@ func (us *UserService) InfoById(id uint) *model.User {
|
|||||||
global.DB.Where("id = ?", id).First(u)
|
global.DB.Where("id = ?", id).First(u)
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
// InfoByUsername 根据用户名取用户信息
|
// InfoByUsername 根据用户名取用户信息
|
||||||
func (us *UserService) InfoByUsername(un string) *model.User {
|
func (us *UserService) InfoByUsername(un string) *model.User {
|
||||||
u := &model.User{}
|
u := &model.User{}
|
||||||
@@ -214,12 +215,12 @@ func (us *UserService) Delete(u *model.User) error {
|
|||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
tx.Commit()
|
|
||||||
// 删除关联的peer
|
// 删除关联的peer
|
||||||
if err := AllService.PeerService.EraseUserId(u.Id); err != nil {
|
if err := AllService.PeerService.EraseUserId(u.Id); err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
tx.Commit()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -297,30 +298,32 @@ func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser , op string) (
|
|||||||
if ut.Id != 0 {
|
if ut.Id != 0 {
|
||||||
return nil, us.InfoById(ut.UserId)
|
return nil, us.InfoById(ut.UserId)
|
||||||
}
|
}
|
||||||
//check if this email has been registered
|
|
||||||
email := oauthUser.Email
|
|
||||||
err, oauthType := AllService.OauthService.GetTypeByOp(op)
|
err, oauthType := AllService.OauthService.GetTypeByOp(op)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
}
|
}
|
||||||
// if email is empty, use username and op as email
|
//check if this email has been registered
|
||||||
if email == "" {
|
email := oauthUser.Email
|
||||||
email = oauthUser.Username + "@" + op
|
// only email is not empty
|
||||||
}
|
if email != "" {
|
||||||
email = strings.ToLower(email)
|
email = strings.ToLower(email)
|
||||||
// update email to oauthUser, in case it contain upper case
|
// update email to oauthUser, in case it contain upper case
|
||||||
oauthUser.Email = email
|
oauthUser.Email = email
|
||||||
user := us.InfoByEmail(email)
|
user := us.InfoByEmail(email)
|
||||||
tx := global.DB.Begin()
|
|
||||||
if user.Id != 0 {
|
if user.Id != 0 {
|
||||||
ut.FromOauthUser(user.Id, oauthUser, oauthType, op)
|
ut.FromOauthUser(user.Id, oauthUser, oauthType, op)
|
||||||
} else {
|
global.DB.Create(ut)
|
||||||
|
return nil, user
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tx := global.DB.Begin()
|
||||||
ut = &model.UserThird{}
|
ut = &model.UserThird{}
|
||||||
ut.FromOauthUser(0, oauthUser, oauthType, op)
|
ut.FromOauthUser(0, oauthUser, oauthType, op)
|
||||||
// The initial username should be formatted
|
// The initial username should be formatted
|
||||||
username := us.formatUsername(oauthUser.Username)
|
username := us.formatUsername(oauthUser.Username)
|
||||||
usernameUnique := us.GenerateUsernameByOauth(username)
|
usernameUnique := us.GenerateUsernameByOauth(username)
|
||||||
user = &model.User{
|
user := &model.User{
|
||||||
Username: usernameUnique,
|
Username: usernameUnique,
|
||||||
GroupId: 1,
|
GroupId: 1,
|
||||||
}
|
}
|
||||||
@@ -331,7 +334,6 @@ func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser , op string) (
|
|||||||
return errors.New("OauthRegisterFailed"), user
|
return errors.New("OauthRegisterFailed"), user
|
||||||
}
|
}
|
||||||
ut.UserId = user.Id
|
ut.UserId = user.Id
|
||||||
}
|
|
||||||
tx.Create(ut)
|
tx.Create(ut)
|
||||||
tx.Commit()
|
tx.Commit()
|
||||||
return nil, user
|
return nil, user
|
||||||
|
|||||||
Reference in New Issue
Block a user