feat!: Add JWT

- `RUSTDESK_API_JWT_KEY`如果设置,将会启用JWT,token自动续期功能将失效
- 此功能是为了server端校验token的合法性
This commit is contained in:
lejianwen
2025-01-15 19:25:28 +08:00
parent 3c608463e6
commit f41b9d5887
7 changed files with 39 additions and 22 deletions

View File

@@ -5,6 +5,7 @@ import (
"Gwen/global" "Gwen/global"
"Gwen/http" "Gwen/http"
"Gwen/lib/cache" "Gwen/lib/cache"
"Gwen/lib/jwt"
"Gwen/lib/lock" "Gwen/lib/lock"
"Gwen/lib/logger" "Gwen/lib/logger"
"Gwen/lib/orm" "Gwen/lib/orm"
@@ -17,6 +18,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"os" "os"
"strconv" "strconv"
"time"
) )
// @title 管理系统API // @title 管理系统API
@@ -163,7 +165,7 @@ func InitGlobal() {
//jwt //jwt
//fmt.Println(global.Config.Jwt.PrivateKey) //fmt.Println(global.Config.Jwt.PrivateKey)
//global.Jwt = jwt.NewJwt(global.Config.Jwt.PrivateKey, global.Config.Jwt.ExpireDuration*time.Second) global.Jwt = jwt.NewJwt(global.Config.Jwt.Key, global.Config.Jwt.ExpireDuration*time.Second)
//locker //locker
global.Lock = lock.NewLocal() global.Lock = lock.NewLocal()

View File

@@ -36,6 +36,9 @@ logger:
proxy: proxy:
enable: false enable: false
host: "http://127.0.0.1:1080" host: "http://127.0.0.1:1080"
jwt:
key: ""
expire-duration: 360000
redis: redis:
addr: "127.0.0.1:6379" addr: "127.0.0.1:6379"
password: "" password: ""
@@ -53,6 +56,4 @@ oss:
callback-url: "" callback-url: ""
expire-time: 30 expire-time: 30
max-byte: 10240 max-byte: 10240
jwt:
private-key: "./conf/jwt_pri.pem"
expire-duration: 360000

View File

View File

@@ -3,6 +3,6 @@ package config
import "time" import "time"
type Jwt struct { type Jwt struct {
PrivateKey string `mapstructure:"private-key"` Key string `mapstructure:"key"`
ExpireDuration time.Duration `mapstructure:"expire-duration"` ExpireDuration time.Duration `mapstructure:"expire-duration"`
} }

View File

@@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"Gwen/global"
"Gwen/service" "Gwen/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -27,7 +28,21 @@ func RustAuth() gin.HandlerFunc {
//提取token格式是Bearer {token} //提取token格式是Bearer {token}
//这里只是简单的提取 //这里只是简单的提取
token = token[7:] token = token[7:]
//验证token //验证token
//检查是否设置了jwt key
if global.Config.Jwt.Key != "" {
uid, _ := service.AllService.UserService.VerifyJWT(token)
if uid == 0 {
c.JSON(401, gin.H{
"error": "Unauthorized",
})
c.Abort()
return
}
}
user, ut := service.AllService.UserService.InfoByAccessToken(token) user, ut := service.AllService.UserService.InfoByAccessToken(token)
if user.Id == 0 { if user.Id == 0 {
c.JSON(401, gin.H{ c.JSON(401, gin.H{
@@ -38,7 +53,7 @@ func RustAuth() gin.HandlerFunc {
} }
if !service.AllService.UserService.CheckUserEnable(user) { if !service.AllService.UserService.CheckUserEnable(user) {
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "账号已被禁用", "error": "Unauthorized",
}) })
c.Abort() c.Abort()
return return

View File

@@ -1,14 +1,13 @@
package jwt package jwt
import ( import (
"crypto/rsa" "fmt"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"os"
"time" "time"
) )
type Jwt struct { type Jwt struct {
privateKey *rsa.PrivateKey Key []byte
TokenExpireDuration time.Duration TokenExpireDuration time.Duration
} }
@@ -17,31 +16,24 @@ type UserClaims struct {
jwt.RegisteredClaims jwt.RegisteredClaims
} }
func NewJwt(privateKeyFile string, tokenExpireDuration time.Duration) *Jwt { func NewJwt(key string, tokenExpireDuration time.Duration) *Jwt {
privateKeyContent, err := os.ReadFile(privateKeyFile)
if err != nil {
panic(err)
}
privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(privateKeyContent)
if err != nil {
panic(err)
}
return &Jwt{ return &Jwt{
privateKey: privateKey, Key: []byte(key),
TokenExpireDuration: tokenExpireDuration, TokenExpireDuration: tokenExpireDuration,
} }
} }
func (s *Jwt) GenerateToken(userId uint) string { func (s *Jwt) GenerateToken(userId uint) string {
t := jwt.NewWithClaims(jwt.SigningMethodRS256, t := jwt.NewWithClaims(jwt.SigningMethodHS256,
UserClaims{ UserClaims{
UserId: userId, UserId: userId,
RegisteredClaims: jwt.RegisteredClaims{ RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(s.TokenExpireDuration)), ExpiresAt: jwt.NewNumericDate(time.Now().Add(s.TokenExpireDuration)),
}, },
}) })
token, err := t.SignedString(s.privateKey) token, err := t.SignedString(s.Key)
if err != nil { if err != nil {
fmt.Println(err)
return "" return ""
} }
return token return token
@@ -49,7 +41,7 @@ func (s *Jwt) GenerateToken(userId uint) string {
func (s *Jwt) ParseToken(tokenString string) (uint, error) { func (s *Jwt) ParseToken(tokenString string) (uint, error) {
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) { token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
return s.privateKey.Public(), nil return s.Key, nil
}) })
if err != nil { if err != nil {
return 0, err return 0, err

View File

@@ -68,6 +68,9 @@ func (us *UserService) InfoByAccessToken(token string) (*model.User, *model.User
// GenerateToken 生成token // GenerateToken 生成token
func (us *UserService) GenerateToken(u *model.User) string { func (us *UserService) GenerateToken(u *model.User) string {
if global.Config.Jwt.Key != "" {
return global.Jwt.GenerateToken(u.Id)
}
return utils.Md5(u.Username + time.Now().String()) return utils.Md5(u.Username + time.Now().String())
} }
@@ -461,3 +464,7 @@ func (us *UserService) AutoRefreshAccessToken(ut *model.UserToken) {
func (us *UserService) BatchDeleteUserToken(ids []uint) error { func (us *UserService) BatchDeleteUserToken(ids []uint) error {
return global.DB.Where("id in ?", ids).Delete(&model.UserToken{}).Error return global.DB.Where("id in ?", ids).Delete(&model.UserToken{}).Error
} }
func (us *UserService) VerifyJWT(token string) (uint, error) {
return global.Jwt.ParseToken(token)
}