style(oidc): Oidc style
This commit is contained in:
2
go.mod
2
go.mod
@@ -36,9 +36,11 @@ require (
|
|||||||
github.com/bytedance/sonic v1.8.0 // indirect
|
github.com/bytedance/sonic v1.8.0 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||||
|
github.com/coreos/go-oidc/v3 v3.12.0 // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||||
github.com/go-asn1-ber/asn1-ber v1.5.7 // indirect
|
github.com/go-asn1-ber/asn1-ber v1.5.7 // indirect
|
||||||
|
github.com/go-jose/go-jose/v4 v4.0.2 // indirect
|
||||||
github.com/go-ldap/ldap/v3 v3.4.10 // indirect
|
github.com/go-ldap/ldap/v3 v3.4.10 // indirect
|
||||||
github.com/go-openapi/jsonpointer v0.19.5 // indirect
|
github.com/go-openapi/jsonpointer v0.19.5 // indirect
|
||||||
github.com/go-openapi/jsonreference v0.19.6 // indirect
|
github.com/go-openapi/jsonreference v0.19.6 // indirect
|
||||||
|
|||||||
177
service/oauth.go
177
service/oauth.go
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
"github.com/lejianwen/rustdesk-api/v2/global"
|
"github.com/lejianwen/rustdesk-api/v2/global"
|
||||||
"github.com/lejianwen/rustdesk-api/v2/model"
|
"github.com/lejianwen/rustdesk-api/v2/model"
|
||||||
"github.com/lejianwen/rustdesk-api/v2/utils"
|
"github.com/lejianwen/rustdesk-api/v2/utils"
|
||||||
@@ -82,10 +83,9 @@ func (os *OauthService) GetOauthCache(key string) *OauthCacheItem {
|
|||||||
func (os *OauthService) SetOauthCache(key string, item *OauthCacheItem, expire uint) {
|
func (os *OauthService) SetOauthCache(key string, item *OauthCacheItem, expire uint) {
|
||||||
OauthCache.Store(key, item)
|
OauthCache.Store(key, item)
|
||||||
if expire > 0 {
|
if expire > 0 {
|
||||||
go func() {
|
time.AfterFunc(time.Duration(expire)*time.Second, func() {
|
||||||
time.Sleep(time.Duration(expire) * time.Second)
|
|
||||||
os.DeleteOauthCache(key)
|
os.DeleteOauthCache(key)
|
||||||
}()
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -96,12 +96,12 @@ func (os *OauthService) DeleteOauthCache(key string) {
|
|||||||
func (os *OauthService) BeginAuth(op string) (error error, state, verifier, url string) {
|
func (os *OauthService) BeginAuth(op string) (error error, state, verifier, url string) {
|
||||||
state = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10)
|
state = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10)
|
||||||
verifier = ""
|
verifier = ""
|
||||||
if op == string(model.OauthTypeWebauth) {
|
if op == model.OauthTypeWebauth {
|
||||||
url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + state
|
url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + state
|
||||||
//url = "http://localhost:8888/_admin/#/oauth/" + code
|
//url = "http://localhost:8888/_admin/#/oauth/" + code
|
||||||
return nil, state, verifier, url
|
return nil, state, verifier, url
|
||||||
}
|
}
|
||||||
err, oauthInfo, oauthConfig := os.GetOauthConfig(op)
|
err, oauthInfo, oauthConfig, _ := os.GetOauthConfig(op)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
extras := make([]oauth2.AuthCodeOption, 0, 3)
|
extras := make([]oauth2.AuthCodeOption, 0, 3)
|
||||||
if oauthInfo.PkceEnable != nil && *oauthInfo.PkceEnable {
|
if oauthInfo.PkceEnable != nil && *oauthInfo.PkceEnable {
|
||||||
@@ -121,88 +121,80 @@ func (os *OauthService) BeginAuth(op string) (error error, state, verifier, url
|
|||||||
return err, state, verifier, ""
|
return err, state, verifier, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Method to fetch OIDC configuration dynamically
|
func (os *OauthService) FetchOidcProvider(issuer string) (error, *oidc.Provider) {
|
||||||
func (os *OauthService) FetchOidcEndpoint(issuer string) (error, OidcEndpoint) {
|
|
||||||
configURL := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
|
|
||||||
|
|
||||||
// Get the HTTP client (with or without proxy based on configuration)
|
// Get the HTTP client (with or without proxy based on configuration)
|
||||||
client := getHTTPClientWithProxy()
|
client := getHTTPClientWithProxy()
|
||||||
|
|
||||||
resp, err := client.Get(configURL)
|
ctx := oidc.ClientContext(context.Background(), client)
|
||||||
|
|
||||||
|
provider, err := oidc.NewProvider(ctx, issuer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.New("failed to fetch OIDC configuration"), OidcEndpoint{}
|
return err, nil
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return errors.New("OIDC configuration not found, status code: %d"), OidcEndpoint{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var endpoint OidcEndpoint
|
return nil, provider
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&endpoint); err != nil {
|
|
||||||
return errors.New("failed to parse OIDC configuration"), OidcEndpoint{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, endpoint
|
func (os *OauthService) GithubProvider() *oidc.Provider {
|
||||||
}
|
return (&oidc.ProviderConfig{
|
||||||
|
IssuerURL: "",
|
||||||
func (os *OauthService) FetchOidcEndpointByOp(op string) (error, OidcEndpoint) {
|
AuthURL: github.Endpoint.AuthURL,
|
||||||
oauthInfo := os.InfoByOp(op)
|
TokenURL: github.Endpoint.TokenURL,
|
||||||
if oauthInfo.Issuer == "" {
|
DeviceAuthURL: github.Endpoint.DeviceAuthURL,
|
||||||
return errors.New("issuer is empty"), OidcEndpoint{}
|
UserInfoURL: model.UserEndpointGithub,
|
||||||
}
|
JWKSURL: "",
|
||||||
return os.FetchOidcEndpoint(oauthInfo.Issuer)
|
Algorithms: nil,
|
||||||
|
}).NewProvider(context.Background())
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOauthConfig retrieves the OAuth2 configuration based on the provider name
|
// GetOauthConfig retrieves the OAuth2 configuration based on the provider name
|
||||||
func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config) {
|
func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config, provider *oidc.Provider) {
|
||||||
err, oauthInfo, oauthConfig = os.getOauthConfigGeneral(op)
|
//err, oauthInfo, oauthConfig = os.getOauthConfigGeneral(op)
|
||||||
if err != nil {
|
|
||||||
return err, nil, nil
|
|
||||||
}
|
|
||||||
// Maybe should validate the oauthConfig here
|
|
||||||
oauthType := oauthInfo.OauthType
|
|
||||||
err = model.ValidateOauthType(oauthType)
|
|
||||||
if err != nil {
|
|
||||||
return err, nil, nil
|
|
||||||
}
|
|
||||||
switch oauthType {
|
|
||||||
case model.OauthTypeGithub:
|
|
||||||
oauthConfig.Endpoint = github.Endpoint
|
|
||||||
oauthConfig.Scopes = []string{"read:user", "user:email"}
|
|
||||||
case model.OauthTypeOidc, model.OauthTypeGoogle:
|
|
||||||
var endpoint OidcEndpoint
|
|
||||||
err, endpoint = os.FetchOidcEndpoint(oauthInfo.Issuer)
|
|
||||||
if err != nil {
|
|
||||||
return err, nil, nil
|
|
||||||
}
|
|
||||||
oauthConfig.Endpoint = oauth2.Endpoint{AuthURL: endpoint.AuthURL, TokenURL: endpoint.TokenURL}
|
|
||||||
oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes)
|
|
||||||
default:
|
|
||||||
return errors.New("unsupported OAuth type"), nil, nil
|
|
||||||
}
|
|
||||||
return nil, oauthInfo, oauthConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetOauthConfig retrieves the OAuth2 configuration based on the provider name
|
|
||||||
func (os *OauthService) getOauthConfigGeneral(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config) {
|
|
||||||
oauthInfo = os.InfoByOp(op)
|
oauthInfo = os.InfoByOp(op)
|
||||||
if oauthInfo.Id == 0 || oauthInfo.ClientId == "" || oauthInfo.ClientSecret == "" {
|
if oauthInfo.Id == 0 || oauthInfo.ClientId == "" || oauthInfo.ClientSecret == "" {
|
||||||
return errors.New("ConfigNotFound"), nil, nil
|
return errors.New("ConfigNotFound"), nil, nil, nil
|
||||||
}
|
}
|
||||||
// If the redirect URL is empty, use the default redirect URL
|
// If the redirect URL is empty, use the default redirect URL
|
||||||
if oauthInfo.RedirectUrl == "" {
|
if oauthInfo.RedirectUrl == "" {
|
||||||
oauthInfo.RedirectUrl = global.Config.Rustdesk.ApiServer + "/api/oidc/callback"
|
oauthInfo.RedirectUrl = global.Config.Rustdesk.ApiServer + "/api/oidc/callback"
|
||||||
}
|
}
|
||||||
return nil, oauthInfo, &oauth2.Config{
|
oauthConfig = &oauth2.Config{
|
||||||
ClientID: oauthInfo.ClientId,
|
ClientID: oauthInfo.ClientId,
|
||||||
ClientSecret: oauthInfo.ClientSecret,
|
ClientSecret: oauthInfo.ClientSecret,
|
||||||
RedirectURL: oauthInfo.RedirectUrl,
|
RedirectURL: oauthInfo.RedirectUrl,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Maybe should validate the oauthConfig here
|
||||||
|
oauthType := oauthInfo.OauthType
|
||||||
|
err = model.ValidateOauthType(oauthType)
|
||||||
|
if err != nil {
|
||||||
|
return err, nil, nil, nil
|
||||||
|
}
|
||||||
|
switch oauthType {
|
||||||
|
case model.OauthTypeGithub:
|
||||||
|
oauthConfig.Endpoint = github.Endpoint
|
||||||
|
oauthConfig.Scopes = []string{"read:user", "user:email"}
|
||||||
|
provider = os.GithubProvider()
|
||||||
|
//case model.OauthTypeGoogle: //google单独出来,可以少一次FetchOidcEndpoint请求
|
||||||
|
// oauthConfig.Endpoint = google.Endpoint
|
||||||
|
// oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes)
|
||||||
|
case model.OauthTypeOidc, model.OauthTypeGoogle:
|
||||||
|
err, provider = os.FetchOidcProvider(oauthInfo.Issuer)
|
||||||
|
if err != nil {
|
||||||
|
return err, nil, nil, nil
|
||||||
|
}
|
||||||
|
oauthConfig.Endpoint = provider.Endpoint()
|
||||||
|
oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes)
|
||||||
|
default:
|
||||||
|
return errors.New("unsupported OAuth type"), nil, nil, nil
|
||||||
|
}
|
||||||
|
return nil, oauthInfo, oauthConfig, provider
|
||||||
}
|
}
|
||||||
|
|
||||||
func getHTTPClientWithProxy() *http.Client {
|
func getHTTPClientWithProxy() *http.Client {
|
||||||
//todo add timeout
|
//add timeout 30s
|
||||||
|
timeout := time.Duration(60) * time.Second
|
||||||
if global.Config.Proxy.Enable {
|
if global.Config.Proxy.Enable {
|
||||||
if global.Config.Proxy.Host == "" {
|
if global.Config.Proxy.Host == "" {
|
||||||
global.Logger.Warn("Proxy is enabled but proxy host is empty.")
|
global.Logger.Warn("Proxy is enabled but proxy host is empty.")
|
||||||
@@ -216,33 +208,58 @@ func getHTTPClientWithProxy() *http.Client {
|
|||||||
transport := &http.Transport{
|
transport := &http.Transport{
|
||||||
Proxy: http.ProxyURL(proxyURL),
|
Proxy: http.ProxyURL(proxyURL),
|
||||||
}
|
}
|
||||||
return &http.Client{Transport: transport}
|
return &http.Client{Transport: transport, Timeout: timeout}
|
||||||
}
|
}
|
||||||
return http.DefaultClient
|
return http.DefaultClient
|
||||||
}
|
}
|
||||||
|
func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, provider *oidc.Provider, code string, verifier string, nonce string, userData interface{}) (err error, client *http.Client) {
|
||||||
func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, verifier string, userEndpoint string, userData interface{}) (err error, client *http.Client) {
|
|
||||||
|
|
||||||
// 设置代理客户端
|
// 设置代理客户端
|
||||||
httpClient := getHTTPClientWithProxy()
|
httpClient := getHTTPClientWithProxy()
|
||||||
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
|
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
|
||||||
|
|
||||||
var exchangeOpts []oauth2.AuthCodeOption
|
exchangeOpts := make([]oauth2.AuthCodeOption, 0, 1)
|
||||||
if verifier != "" {
|
if verifier != "" {
|
||||||
exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(verifier)}
|
exchangeOpts = append(exchangeOpts, oauth2.VerifierOption(verifier))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用 code 换取 token
|
token, err := oauthConfig.Exchange(ctx, code, exchangeOpts...)
|
||||||
var token *oauth2.Token
|
|
||||||
token, err = oauthConfig.Exchange(ctx, code, exchangeOpts...)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
|
global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
|
||||||
return errors.New("GetOauthTokenError"), nil
|
return errors.New("GetOauthTokenError"), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 获取 ID Token, github没有id_token
|
||||||
|
rawIDToken, ok := token.Extra("id_token").(string)
|
||||||
|
if ok && rawIDToken != "" {
|
||||||
|
// 验证 ID Token
|
||||||
|
v := provider.Verifier(&oidc.Config{ClientID: oauthConfig.ClientID})
|
||||||
|
idToken, err2 := v.Verify(ctx, rawIDToken)
|
||||||
|
if err2 != nil {
|
||||||
|
global.Logger.Warn("IdTokenVerifyError: ", err2)
|
||||||
|
return errors.New("IdTokenVerifyError"), nil
|
||||||
|
}
|
||||||
|
if nonce != "" {
|
||||||
|
// 验证 nonce
|
||||||
|
var claims struct {
|
||||||
|
Nonce string `json:"nonce"`
|
||||||
|
}
|
||||||
|
if err2 = idToken.Claims(&claims); err2 != nil {
|
||||||
|
global.Logger.Warn("Failed to parse ID Token claims: ", err)
|
||||||
|
return errors.New("IDTokenClaimsError"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims.Nonce != nonce {
|
||||||
|
global.Logger.Warn("Nonce does not match")
|
||||||
|
return errors.New("NonceDoesNotMatch"), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 获取用户信息
|
// 获取用户信息
|
||||||
client = oauthConfig.Client(ctx, token)
|
client = oauthConfig.Client(ctx, token)
|
||||||
resp, err := client.Get(userEndpoint)
|
resp, err := client.Get(provider.UserInfoEndpoint())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
global.Logger.Warn("failed getting user info: ", err)
|
global.Logger.Warn("failed getting user info: ", err)
|
||||||
return errors.New("GetOauthUserInfoError"), nil
|
return errors.New("GetOauthUserInfoError"), nil
|
||||||
@@ -263,9 +280,9 @@ func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, ve
|
|||||||
}
|
}
|
||||||
|
|
||||||
// githubCallback github回调
|
// githubCallback github回调
|
||||||
func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string, verifier string) (error, *model.OauthUser) {
|
func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code string, verifier string) (error, *model.OauthUser) {
|
||||||
var user = &model.GithubUser{}
|
var user = &model.GithubUser{}
|
||||||
err, client := os.callbackBase(oauthConfig, code, verifier, model.UserEndpointGithub, user)
|
err, client := os.callbackBase(oauthConfig, provider, code, verifier, "", user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
}
|
}
|
||||||
@@ -277,9 +294,9 @@ func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// oidcCallback oidc回调, 通过code获取用户信息
|
// oidcCallback oidc回调, 通过code获取用户信息
|
||||||
func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, verifier string, userInfoEndpoint string) (error, *model.OauthUser) {
|
func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code string, verifier string) (error, *model.OauthUser) {
|
||||||
var user = &model.OidcUser{}
|
var user = &model.OidcUser{}
|
||||||
if err, _ := os.callbackBase(oauthConfig, code, verifier, userInfoEndpoint, user); err != nil {
|
if err, _ := os.callbackBase(oauthConfig, provider, code, verifier, "", user); err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
}
|
}
|
||||||
return nil, user.ToOauthUser()
|
return nil, user.ToOauthUser()
|
||||||
@@ -287,9 +304,7 @@ func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, ve
|
|||||||
|
|
||||||
// Callback: Get user information by code and op(Oauth provider)
|
// Callback: Get user information by code and op(Oauth provider)
|
||||||
func (os *OauthService) Callback(code, verifier, op string) (err error, oauthUser *model.OauthUser) {
|
func (os *OauthService) Callback(code, verifier, op string) (err error, oauthUser *model.OauthUser) {
|
||||||
var oauthInfo *model.Oauth
|
err, oauthInfo, oauthConfig, provider := os.GetOauthConfig(op)
|
||||||
var oauthConfig *oauth2.Config
|
|
||||||
err, oauthInfo, oauthConfig = os.GetOauthConfig(op)
|
|
||||||
// oauthType is already validated in GetOauthConfig
|
// oauthType is already validated in GetOauthConfig
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
@@ -297,13 +312,9 @@ func (os *OauthService) Callback(code, verifier, op string) (err error, oauthUse
|
|||||||
oauthType := oauthInfo.OauthType
|
oauthType := oauthInfo.OauthType
|
||||||
switch oauthType {
|
switch oauthType {
|
||||||
case model.OauthTypeGithub:
|
case model.OauthTypeGithub:
|
||||||
err, oauthUser = os.githubCallback(oauthConfig, code, verifier)
|
err, oauthUser = os.githubCallback(oauthConfig, provider, code, verifier)
|
||||||
case model.OauthTypeOidc, model.OauthTypeGoogle:
|
case model.OauthTypeOidc, model.OauthTypeGoogle:
|
||||||
err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer)
|
err, oauthUser = os.oidcCallback(oauthConfig, provider, code, verifier)
|
||||||
if err != nil {
|
|
||||||
return err, nil
|
|
||||||
}
|
|
||||||
err, oauthUser = os.oidcCallback(oauthConfig, code, verifier, endpoint.UserInfo)
|
|
||||||
default:
|
default:
|
||||||
return errors.New("unsupported OAuth type"), nil
|
return errors.New("unsupported OAuth type"), nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user