Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
97f98cd6ce | ||
|
|
51f2920661 | ||
|
|
7a5d141ce8 | ||
|
|
3cef02a0bb | ||
|
|
46a7ecc1ba | ||
|
|
4d2b037f5e | ||
|
|
323364b24e | ||
|
|
f19109cdf8 | ||
|
|
527260d60a |
4
.github/workflows/build.yml
vendored
4
.github/workflows/build.yml
vendored
@@ -115,12 +115,12 @@ jobs:
|
||||
zip -r ${{ matrix.job.goos}}-${{ matrix.job.platform }}.${{matrix.job.file_ext}} ./release
|
||||
else
|
||||
if [ "${{ matrix.job.platform }}" = "arm64" ]; then
|
||||
wget https://musl.cc/aarch64-linux-musl-cross.tgz
|
||||
wget https://musl.ljw.red/aarch64-linux-musl-cross.tgz
|
||||
tar -xf aarch64-linux-musl-cross.tgz
|
||||
export PATH=$PATH:$PWD/aarch64-linux-musl-cross/bin
|
||||
GOOS=${{ matrix.job.goos }} GOARCH=${{ matrix.job.platform }} CC=aarch64-linux-musl-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
|
||||
elif [ "${{ matrix.job.platform }}" = "armv7l" ]; then
|
||||
wget https://musl.cc/armv7l-linux-musleabihf-cross.tgz
|
||||
wget https://musl.ljw.red/armv7l-linux-musleabihf-cross.tgz
|
||||
tar -xf armv7l-linux-musleabihf-cross.tgz
|
||||
export PATH=$PATH:$PWD/armv7l-linux-musleabihf-cross/bin
|
||||
GOOS=${{ matrix.job.goos }} GOARCH=arm GOARM=7 CC=armv7l-linux-musleabihf-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
|
||||
|
||||
4
.github/workflows/build_test.yml
vendored
4
.github/workflows/build_test.yml
vendored
@@ -101,12 +101,12 @@ jobs:
|
||||
zip -r ${{ matrix.job.goos}}-${{ matrix.job.platform }}.${{matrix.job.file_ext}} ./release
|
||||
else
|
||||
if [ "${{ matrix.job.platform }}" = "arm64" ]; then
|
||||
wget https://musl.cc/aarch64-linux-musl-cross.tgz
|
||||
wget https://musl.ljw.red/aarch64-linux-musl-cross.tgz
|
||||
tar -xf aarch64-linux-musl-cross.tgz
|
||||
export PATH=$PATH:$PWD/aarch64-linux-musl-cross/bin
|
||||
GOOS=${{ matrix.job.goos }} GOARCH=${{ matrix.job.platform }} CC=aarch64-linux-musl-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
|
||||
elif [ "${{ matrix.job.platform }}" = "armv7l" ]; then
|
||||
wget https://musl.cc/armv7l-linux-musleabihf-cross.tgz
|
||||
wget https://musl.ljw.red/armv7l-linux-musleabihf-cross.tgz
|
||||
tar -xf armv7l-linux-musleabihf-cross.tgz
|
||||
export PATH=$PATH:$PWD/armv7l-linux-musleabihf-cross/bin
|
||||
GOOS=${{ matrix.job.goos }} GOARCH=arm GOARM=7 CC=armv7l-linux-musleabihf-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -5,4 +5,4 @@ runtime/*
|
||||
go.sum
|
||||
resources/admin
|
||||
release
|
||||
data
|
||||
data/rustdeskapi.db
|
||||
@@ -163,6 +163,9 @@
|
||||
| RUSTDESK_API_APP_SHOW_SWAGGER | 是否可见swagger文档;`1`显示,`0`不显示,默认`0`不显示 | `1` |
|
||||
| RUSTDESK_API_APP_TOKEN_EXPIRE | token有效时长 | `168h` |
|
||||
| RUSTDESK_API_APP_DISABLE_PWD_LOGIN | 是否禁用密码登录; `true`, `false` 默认`false` | `false` |
|
||||
| RUSTDESK_API_APP_REGISTER_STATUS | 注册用户默认状态; 1 启用,2 禁用, 默认 1 | `1` |
|
||||
| RUSTDESK_API_APP_CAPTCHA_THRESHOLD | 验证码触发次数; -1 不启用, 0 一直启用, >0 登录错误次数后启用 ;默认 `3` | `3` |
|
||||
| RUSTDESK_API_APP_BAN_THRESHOLD | 封禁IP触发次数; 0 不启用, >0 登录错误次数后封禁IP; 默认 `0` | `0` |
|
||||
| -----ADMIN配置----- | ---------- | ---------- |
|
||||
| RUSTDESK_API_ADMIN_TITLE | 后台标题 | `RustDesk Api Admin` |
|
||||
| RUSTDESK_API_ADMIN_HELLO | 后台欢迎语,可以使用`html` | |
|
||||
|
||||
@@ -162,6 +162,9 @@ The table below does not list all configurations. Please refer to the configurat
|
||||
| RUSTDESK_API_APP_SHOW_SWAGGER | swagger visible; 1: yes, 0: no; default: 0 | `0` |
|
||||
| RUSTDESK_API_APP_TOKEN_EXPIRE | token expire duration | `168h` |
|
||||
| RUSTDESK_API_APP_DISABLE_PWD_LOGIN | disable password login | `false` |
|
||||
| RUSTDESK_API_APP_REGISTER_STATUS | register user default status ; 1 enabled , 2 disabled ; default 1 | `1` |
|
||||
| RUSTDESK_API_APP_CAPTCHA_THRESHOLD | captcha threshold; -1 disabled, 0 always enable, >0 threshold ;default `3` | `3` |
|
||||
| RUSTDESK_API_APP_BAN_THRESHOLD | ban ip threshold; 0 disabled, >0 threshold ; default `0` | `0` |
|
||||
| ----- ADMIN Configuration----- | ---------- | ---------- |
|
||||
| RUSTDESK_API_ADMIN_TITLE | Admin Title | `RustDesk Api Admin` |
|
||||
| RUSTDESK_API_ADMIN_HELLO | Admin welcome message, you can use `html` | |
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// @title 管理系统API
|
||||
@@ -175,8 +176,16 @@ func InitGlobal() {
|
||||
//service
|
||||
service.New(&global.Config, global.DB, global.Logger, global.Jwt, global.Lock)
|
||||
|
||||
global.LoginLimiter = utils.NewLoginLimiter(utils.SecurityPolicy{
|
||||
CaptchaThreshold: global.Config.App.CaptchaThreshold,
|
||||
BanThreshold: global.Config.App.BanThreshold,
|
||||
AttemptsWindow: 10 * time.Minute,
|
||||
BanDuration: 30 * time.Minute,
|
||||
})
|
||||
global.LoginLimiter.RegisterProvider(utils.B64StringCaptchaProvider{})
|
||||
DatabaseAutoUpdate()
|
||||
}
|
||||
|
||||
func DatabaseAutoUpdate() {
|
||||
version := 262
|
||||
|
||||
|
||||
@@ -2,14 +2,21 @@ lang: "zh-CN"
|
||||
app:
|
||||
web-client: 1 # 1:启用 0:禁用
|
||||
register: false #是否开启注册
|
||||
register-status: 1 # 注册用户默认状态 1:启用 2:禁用
|
||||
captcha-threshold: 3 # <0:disabled, 0 always, >0:enabled
|
||||
ban-threshold: 0 # 0:disabled, >0:enabled
|
||||
show-swagger: 0 # 1:启用 0:禁用
|
||||
token-expire: 168h
|
||||
web-sso: true #web auth sso
|
||||
disable-pwd-login: false #禁用密码登录
|
||||
|
||||
admin:
|
||||
title: "RustDesk Api Admin"
|
||||
hello-file: "./conf/admin/hello.html" #优先使用file
|
||||
hello: ""
|
||||
# ID Server and Relay Server ports https://github.com/lejianwen/rustdesk-api/issues/257
|
||||
id-server-port: 21116 # ID Server port (for server cmd)
|
||||
relay-server-port: 21117 # ID Server port (for server cmd)
|
||||
gin:
|
||||
api-addr: "0.0.0.0:21114"
|
||||
mode: "release" #release,debug,test
|
||||
|
||||
@@ -14,17 +14,22 @@ const (
|
||||
)
|
||||
|
||||
type App struct {
|
||||
WebClient int `mapstructure:"web-client"`
|
||||
Register bool `mapstructure:"register"`
|
||||
ShowSwagger int `mapstructure:"show-swagger"`
|
||||
TokenExpire time.Duration `mapstructure:"token-expire"`
|
||||
WebSso bool `mapstructure:"web-sso"`
|
||||
DisablePwdLogin bool `mapstructure:"disable-pwd-login"`
|
||||
WebClient int `mapstructure:"web-client"`
|
||||
Register bool `mapstructure:"register"`
|
||||
RegisterStatus int `mapstructure:"register-status"`
|
||||
ShowSwagger int `mapstructure:"show-swagger"`
|
||||
TokenExpire time.Duration `mapstructure:"token-expire"`
|
||||
WebSso bool `mapstructure:"web-sso"`
|
||||
DisablePwdLogin bool `mapstructure:"disable-pwd-login"`
|
||||
CaptchaThreshold int `mapstructure:"captcha-threshold"`
|
||||
BanThreshold int `mapstructure:"ban-threshold"`
|
||||
}
|
||||
type Admin struct {
|
||||
Title string `mapstructure:"title"`
|
||||
Hello string `mapstructure:"hello"`
|
||||
HelloFile string `mapstructure:"hello-file"`
|
||||
Title string `mapstructure:"title"`
|
||||
Hello string `mapstructure:"hello"`
|
||||
HelloFile string `mapstructure:"hello-file"`
|
||||
IdServerPort int `mapstructure:"id-server-port"`
|
||||
RelayServerPort int `mapstructure:"relay-server-port"`
|
||||
}
|
||||
type Config struct {
|
||||
Lang string `mapstructure:"lang"`
|
||||
@@ -43,6 +48,15 @@ type Config struct {
|
||||
Ldap Ldap
|
||||
}
|
||||
|
||||
func (a *Admin) Init() {
|
||||
if a.IdServerPort == 0 {
|
||||
a.IdServerPort = DefaultIdServerPort
|
||||
}
|
||||
if a.RelayServerPort == 0 {
|
||||
a.RelayServerPort = DefaultRelayServerPort
|
||||
}
|
||||
}
|
||||
|
||||
// Init 初始化配置
|
||||
func Init(rowVal *Config, path string) *viper.Viper {
|
||||
if path == "" {
|
||||
@@ -77,7 +91,7 @@ func Init(rowVal *Config, path string) *viper.Viper {
|
||||
panic(fmt.Errorf("Fatal error config: %s \n", err))
|
||||
}
|
||||
rowVal.Rustdesk.LoadKeyFile()
|
||||
rowVal.Rustdesk.ParsePort()
|
||||
rowVal.Admin.Init()
|
||||
return v
|
||||
}
|
||||
|
||||
|
||||
@@ -2,8 +2,6 @@ package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -40,19 +38,3 @@ func (rd *Rustdesk) LoadKeyFile() {
|
||||
return
|
||||
}
|
||||
}
|
||||
func (rd *Rustdesk) ParsePort() {
|
||||
// Parse port
|
||||
idres := strings.Split(rd.IdServer, ":")
|
||||
if len(idres) == 1 {
|
||||
rd.IdServerPort = DefaultIdServerPort
|
||||
} else if len(idres) == 2 {
|
||||
rd.IdServerPort, _ = strconv.Atoi(idres[1])
|
||||
}
|
||||
|
||||
relayres := strings.Split(rd.RelayServer, ":")
|
||||
if len(relayres) == 1 {
|
||||
rd.RelayServerPort = DefaultRelayServerPort
|
||||
} else if len(relayres) == 2 {
|
||||
rd.RelayServerPort, _ = strconv.Atoi(relayres[1])
|
||||
}
|
||||
}
|
||||
|
||||
0
data/.gitkeep
Normal file
0
data/.gitkeep
Normal file
@@ -1,4 +1,4 @@
|
||||
// Package admin Code generated by swaggo/swag. DO NOT EDIT
|
||||
// Package admin Content generated by swaggo/swag. DO NOT EDIT
|
||||
package admin
|
||||
|
||||
import "github.com/swaggo/swag"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Package api Code generated by swaggo/swag. DO NOT EDIT
|
||||
// Package api Content generated by swaggo/swag. DO NOT EDIT
|
||||
package api
|
||||
|
||||
import "github.com/swaggo/swag"
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/lejianwen/rustdesk-api/v2/lib/jwt"
|
||||
"github.com/lejianwen/rustdesk-api/v2/lib/lock"
|
||||
"github.com/lejianwen/rustdesk-api/v2/lib/upload"
|
||||
"github.com/lejianwen/rustdesk-api/v2/utils"
|
||||
"github.com/nicksnyder/go-i18n/v2/i18n"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/viper"
|
||||
@@ -31,8 +32,9 @@ var (
|
||||
ValidStruct func(*gin.Context, interface{}) []string
|
||||
ValidVar func(ctx *gin.Context, field interface{}, tag string) []string
|
||||
}
|
||||
Oss *upload.Oss
|
||||
Jwt *jwt.Jwt
|
||||
Lock lock.Locker
|
||||
Localizer func(lang string) *i18n.Localizer
|
||||
Oss *upload.Oss
|
||||
Jwt *jwt.Jwt
|
||||
Lock lock.Locker
|
||||
Localizer func(lang string) *i18n.Localizer
|
||||
LoginLimiter *utils.LoginLimiter
|
||||
)
|
||||
|
||||
@@ -11,135 +11,11 @@ import (
|
||||
adResp "github.com/lejianwen/rustdesk-api/v2/http/response/admin"
|
||||
"github.com/lejianwen/rustdesk-api/v2/model"
|
||||
"github.com/lejianwen/rustdesk-api/v2/service"
|
||||
"github.com/mojocn/base64Captcha"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Login struct {
|
||||
}
|
||||
|
||||
// Captcha 验证码结构
|
||||
type Captcha struct {
|
||||
Id string `json:"id"` // 验证码 ID
|
||||
B64 string `json:"b64"` // base64 验证码
|
||||
Code string `json:"-"` // 验证码内容
|
||||
ExpiresAt time.Time `json:"-"` // 过期时间
|
||||
}
|
||||
type LoginLimiter struct {
|
||||
mu sync.RWMutex
|
||||
failCount map[string]int // 记录每个 IP 的失败次数
|
||||
timestamp map[string]time.Time // 记录每个 IP 的最后失败时间
|
||||
captchas map[string]Captcha // 每个 IP 的验证码
|
||||
threshold int // 失败阈值
|
||||
expiry time.Duration // 失败记录过期时间
|
||||
}
|
||||
|
||||
func NewLoginLimiter(threshold int, expiry time.Duration) *LoginLimiter {
|
||||
return &LoginLimiter{
|
||||
failCount: make(map[string]int),
|
||||
timestamp: make(map[string]time.Time),
|
||||
captchas: make(map[string]Captcha),
|
||||
threshold: threshold,
|
||||
expiry: expiry,
|
||||
}
|
||||
}
|
||||
|
||||
// RecordFailure 记录登录失败
|
||||
func (l *LoginLimiter) RecordFailure(ip string) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
// 如果该 IP 的记录已经过期,重置计数
|
||||
if lastTime, exists := l.timestamp[ip]; exists && time.Since(lastTime) > l.expiry {
|
||||
l.failCount[ip] = 0
|
||||
}
|
||||
|
||||
// 更新失败次数和时间戳
|
||||
l.failCount[ip]++
|
||||
l.timestamp[ip] = time.Now()
|
||||
}
|
||||
|
||||
// NeedsCaptcha 检查是否需要验证码
|
||||
func (l *LoginLimiter) NeedsCaptcha(ip string) bool {
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
|
||||
// 检查记录是否存在且未过期
|
||||
if lastTime, exists := l.timestamp[ip]; exists && time.Since(lastTime) <= l.expiry {
|
||||
return l.failCount[ip] >= l.threshold
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GenerateCaptcha 为指定 IP 生成验证码
|
||||
func (l *LoginLimiter) GenerateCaptcha(ip string) Captcha {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
capd := base64Captcha.NewDriverString(50, 150, 5, 10, 4, "1234567890abcdefghijklmnopqrstuvwxyz", nil, nil, nil)
|
||||
b64cap := base64Captcha.NewCaptcha(capd, base64Captcha.DefaultMemStore)
|
||||
id, b64s, answer, err := b64cap.Generate()
|
||||
if err != nil {
|
||||
global.Logger.Error("Generate captcha failed: " + err.Error())
|
||||
return Captcha{}
|
||||
}
|
||||
// 保存验证码到对应 IP
|
||||
l.captchas[ip] = Captcha{
|
||||
Id: id,
|
||||
B64: b64s,
|
||||
Code: answer,
|
||||
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||||
}
|
||||
return l.captchas[ip]
|
||||
}
|
||||
|
||||
// VerifyCaptcha 验证指定 IP 的验证码
|
||||
func (l *LoginLimiter) VerifyCaptcha(ip, code string) bool {
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
|
||||
// 检查验证码是否存在且未过期
|
||||
if captcha, exists := l.captchas[ip]; exists && time.Now().Before(captcha.ExpiresAt) {
|
||||
return captcha.Code == code
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RemoveCaptcha 移除指定 IP 的验证码
|
||||
func (l *LoginLimiter) RemoveCaptcha(ip string) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
delete(l.captchas, ip)
|
||||
}
|
||||
|
||||
// CleanupExpired 清理过期的记录
|
||||
func (l *LoginLimiter) CleanupExpired() {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for ip, lastTime := range l.timestamp {
|
||||
if now.Sub(lastTime) > l.expiry {
|
||||
delete(l.failCount, ip)
|
||||
delete(l.timestamp, ip)
|
||||
delete(l.captchas, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *LoginLimiter) RemoveRecord(ip string) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
delete(l.failCount, ip)
|
||||
delete(l.timestamp, ip)
|
||||
delete(l.captchas, ip)
|
||||
}
|
||||
|
||||
var loginLimiter = NewLoginLimiter(3, 5*time.Minute)
|
||||
|
||||
// Login 登录
|
||||
// @Tags 登录
|
||||
// @Summary 登录
|
||||
@@ -156,10 +32,16 @@ func (ct *Login) Login(c *gin.Context) {
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "PwdLoginDisabled"))
|
||||
return
|
||||
}
|
||||
|
||||
// 检查登录限制
|
||||
loginLimiter := global.LoginLimiter
|
||||
clientIp := c.ClientIP()
|
||||
_, needCaptcha := loginLimiter.CheckSecurityStatus(clientIp)
|
||||
|
||||
f := &admin.Login{}
|
||||
err := c.ShouldBindJSON(f)
|
||||
clientIp := c.ClientIP()
|
||||
if err != nil {
|
||||
loginLimiter.RecordFailedAttempt(clientIp)
|
||||
global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), clientIp))
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")+err.Error())
|
||||
return
|
||||
@@ -167,14 +49,15 @@ func (ct *Login) Login(c *gin.Context) {
|
||||
|
||||
errList := global.Validator.ValidStruct(c, f)
|
||||
if len(errList) > 0 {
|
||||
loginLimiter.RecordFailedAttempt(clientIp)
|
||||
global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), clientIp))
|
||||
response.Fail(c, 101, errList[0])
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否需要验证码
|
||||
if loginLimiter.NeedsCaptcha(clientIp) {
|
||||
if f.Captcha == "" || !loginLimiter.VerifyCaptcha(clientIp, f.Captcha) {
|
||||
if needCaptcha {
|
||||
if f.CaptchaId == "" || f.Captcha == "" || !loginLimiter.VerifyCaptcha(f.CaptchaId, f.Captcha) {
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError"))
|
||||
return
|
||||
}
|
||||
@@ -184,17 +67,19 @@ func (ct *Login) Login(c *gin.Context) {
|
||||
|
||||
if u.Id == 0 {
|
||||
global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "UsernameOrPasswordError", c.RemoteIP(), clientIp))
|
||||
loginLimiter.RecordFailure(clientIp)
|
||||
if loginLimiter.NeedsCaptcha(clientIp) {
|
||||
loginLimiter.RemoveCaptcha(clientIp)
|
||||
loginLimiter.RecordFailedAttempt(clientIp)
|
||||
if _, needCaptcha = loginLimiter.CheckSecurityStatus(clientIp); needCaptcha {
|
||||
response.Fail(c, 110, response.TranslateMsg(c, "UsernameOrPasswordError"))
|
||||
} else {
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "UsernameOrPasswordError"))
|
||||
}
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "UsernameOrPasswordError"))
|
||||
return
|
||||
}
|
||||
|
||||
if !service.AllService.UserService.CheckUserEnable(u) {
|
||||
if loginLimiter.NeedsCaptcha(clientIp) {
|
||||
loginLimiter.RemoveCaptcha(clientIp)
|
||||
if needCaptcha {
|
||||
response.Fail(c, 110, response.TranslateMsg(c, "UserDisabled"))
|
||||
return
|
||||
}
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "UserDisabled"))
|
||||
return
|
||||
@@ -209,23 +94,37 @@ func (ct *Login) Login(c *gin.Context) {
|
||||
Platform: f.Platform,
|
||||
})
|
||||
|
||||
// 成功后清除记录
|
||||
loginLimiter.RemoveRecord(clientIp)
|
||||
|
||||
// 清理过期记录
|
||||
go loginLimiter.CleanupExpired()
|
||||
|
||||
// 登录成功,清除登录限制
|
||||
loginLimiter.RemoveAttempts(clientIp)
|
||||
responseLoginSuccess(c, u, ut.Token)
|
||||
}
|
||||
func (ct *Login) Captcha(c *gin.Context) {
|
||||
loginLimiter := global.LoginLimiter
|
||||
clientIp := c.ClientIP()
|
||||
if !loginLimiter.NeedsCaptcha(clientIp) {
|
||||
banned, needCaptcha := loginLimiter.CheckSecurityStatus(clientIp)
|
||||
if banned {
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "LoginBanned"))
|
||||
return
|
||||
}
|
||||
if !needCaptcha {
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "NoCaptchaRequired"))
|
||||
return
|
||||
}
|
||||
captcha := loginLimiter.GenerateCaptcha(clientIp)
|
||||
err, captcha := loginLimiter.RequireCaptcha()
|
||||
if err != nil {
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError")+err.Error())
|
||||
return
|
||||
}
|
||||
err, b64 := loginLimiter.DrawCaptcha(captcha.Content)
|
||||
if err != nil {
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError")+err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{
|
||||
"captcha": captcha,
|
||||
"captcha": gin.H{
|
||||
"id": captcha.Id,
|
||||
"b64": b64,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -257,12 +156,18 @@ func (ct *Login) Logout(c *gin.Context) {
|
||||
// @Failure 500 {object} response.ErrorResponse
|
||||
// @Router /admin/login-options [post]
|
||||
func (ct *Login) LoginOptions(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
loginLimiter := global.LoginLimiter
|
||||
clientIp := c.ClientIP()
|
||||
banned, needCaptcha := loginLimiter.CheckSecurityStatus(clientIp)
|
||||
if banned {
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "LoginBanned"))
|
||||
return
|
||||
}
|
||||
ops := service.AllService.OauthService.GetOauthProviders()
|
||||
response.Success(c, gin.H{
|
||||
"ops": ops,
|
||||
"register": global.Config.App.Register,
|
||||
"need_captcha": loginLimiter.NeedsCaptcha(ip),
|
||||
"need_captcha": needCaptcha,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -119,7 +119,16 @@ func (r *Rustdesk) SendCmd(c *gin.Context) {
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError"))
|
||||
return
|
||||
}
|
||||
res, err := service.AllService.ServerCmdService.SendCmd(rc.Target, rc.Cmd, rc.Option)
|
||||
|
||||
port := 0
|
||||
switch rc.Target {
|
||||
case model.ServerCmdTargetIdServer:
|
||||
port = global.Config.Admin.IdServerPort - 1
|
||||
case model.ServerCmdTargetRelayServer:
|
||||
port = global.Config.Admin.RelayServerPort
|
||||
}
|
||||
|
||||
res, err := service.AllService.ServerCmdService.SendCmd(port, rc.Cmd, rc.Option)
|
||||
if err != nil {
|
||||
response.Fail(c, 101, err.Error())
|
||||
return
|
||||
|
||||
@@ -320,11 +320,22 @@ func (ct *User) Register(c *gin.Context) {
|
||||
response.Fail(c, 101, errList[0])
|
||||
return
|
||||
}
|
||||
u := service.AllService.UserService.Register(f.Username, f.Email, f.Password)
|
||||
regStatus := model.StatusCode(global.Config.App.RegisterStatus)
|
||||
// 注册状态可能未配置,默认启用
|
||||
if regStatus != model.COMMON_STATUS_DISABLED && regStatus != model.COMMON_STATUS_ENABLE {
|
||||
regStatus = model.COMMON_STATUS_ENABLE
|
||||
}
|
||||
|
||||
u := service.AllService.UserService.Register(f.Username, f.Email, f.Password, regStatus)
|
||||
if u == nil || u.Id == 0 {
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed"))
|
||||
return
|
||||
}
|
||||
if regStatus == model.COMMON_STATUS_DISABLED {
|
||||
// 需要管理员审核
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "RegisterSuccessWaitAdminConfirm"))
|
||||
return
|
||||
}
|
||||
// 注册成功后自动登录
|
||||
ut := service.AllService.UserService.Login(u, &model.LoginLog{
|
||||
UserId: u.Id,
|
||||
|
||||
@@ -31,10 +31,16 @@ func (l *Login) Login(c *gin.Context) {
|
||||
response.Error(c, response.TranslateMsg(c, "PwdLoginDisabled"))
|
||||
return
|
||||
}
|
||||
|
||||
// 检查登录限制
|
||||
loginLimiter := global.LoginLimiter
|
||||
clientIp := c.ClientIP()
|
||||
|
||||
f := &api.LoginForm{}
|
||||
err := c.ShouldBindJSON(f)
|
||||
//fmt.Println(f)
|
||||
if err != nil {
|
||||
loginLimiter.RecordFailedAttempt(clientIp)
|
||||
global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), c.ClientIP()))
|
||||
response.Error(c, response.TranslateMsg(c, "ParamsError")+err.Error())
|
||||
return
|
||||
@@ -42,6 +48,7 @@ func (l *Login) Login(c *gin.Context) {
|
||||
|
||||
errList := global.Validator.ValidStruct(c, f)
|
||||
if len(errList) > 0 {
|
||||
loginLimiter.RecordFailedAttempt(clientIp)
|
||||
global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), c.ClientIP()))
|
||||
response.Error(c, errList[0])
|
||||
return
|
||||
@@ -50,6 +57,7 @@ func (l *Login) Login(c *gin.Context) {
|
||||
u := service.AllService.UserService.InfoByUsernamePassword(f.Username, f.Password)
|
||||
|
||||
if u.Id == 0 {
|
||||
loginLimiter.RecordFailedAttempt(clientIp)
|
||||
global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "UsernameOrPasswordError", c.RemoteIP(), c.ClientIP()))
|
||||
response.Error(c, response.TranslateMsg(c, "UsernameOrPasswordError"))
|
||||
return
|
||||
|
||||
@@ -33,7 +33,7 @@ func ApiInit() {
|
||||
g.NoRoute(func(c *gin.Context) {
|
||||
c.String(http.StatusNotFound, "404 not found")
|
||||
})
|
||||
g.Use(middleware.Logger(), gin.Recovery())
|
||||
g.Use(middleware.Logger(), middleware.Limiter(), gin.Recovery())
|
||||
router.WebInit(g)
|
||||
router.Init(g)
|
||||
router.ApiInit(g)
|
||||
|
||||
22
http/middleware/limiter.go
Normal file
22
http/middleware/limiter.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/lejianwen/rustdesk-api/v2/global"
|
||||
"github.com/lejianwen/rustdesk-api/v2/http/response"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func Limiter() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
loginLimiter := global.LoginLimiter
|
||||
clientIp := c.ClientIP()
|
||||
banned, _ := loginLimiter.CheckSecurityStatus(clientIp)
|
||||
if banned {
|
||||
response.Fail(c, http.StatusLocked, response.TranslateMsg(c, "Banned"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,11 @@
|
||||
package admin
|
||||
|
||||
type Login struct {
|
||||
Username string `json:"username" validate:"required" label:"用户名"`
|
||||
Password string `json:"password,omitempty" validate:"required" label:"密码"`
|
||||
Platform string `json:"platform" label:"平台"`
|
||||
Captcha string `json:"captcha,omitempty" label:"验证码"`
|
||||
Username string `json:"username" validate:"required" label:"用户名"`
|
||||
Password string `json:"password,omitempty" validate:"required" label:"密码"`
|
||||
Platform string `json:"platform" label:"平台"`
|
||||
Captcha string `json:"captcha,omitempty" label:"验证码"`
|
||||
CaptchaId string `json:"captcha_id,omitempty"`
|
||||
}
|
||||
|
||||
type LoginLogQuery struct {
|
||||
|
||||
@@ -142,4 +142,14 @@ other = "Password login disabled."
|
||||
[CannotShareToSelf]
|
||||
description = "Cannot share to self."
|
||||
one = "Cannot share to self."
|
||||
other = "Cannot share to self."
|
||||
other = "Cannot share to self."
|
||||
|
||||
[Banned]
|
||||
description = "Banned."
|
||||
one = "Banned."
|
||||
other = "Banned."
|
||||
|
||||
[RegisterSuccessWaitAdminConfirm]
|
||||
description = "Register success, wait admin confirm."
|
||||
one = "Register success, wait admin confirm."
|
||||
other = "Register success, wait admin confirm."
|
||||
@@ -151,4 +151,14 @@ other = "Inicio de sesión con contraseña deshabilitado."
|
||||
[CannotShareToSelf]
|
||||
description = "Cannot share to self."
|
||||
one = "No se puede compartir con uno mismo."
|
||||
other = "No se puede compartir con uno mismo."
|
||||
other = "No se puede compartir con uno mismo."
|
||||
|
||||
[Banned]
|
||||
description = "Banned."
|
||||
one = "Prohibido."
|
||||
other = "Prohibido."
|
||||
|
||||
[RegisterSuccessWaitAdminConfirm]
|
||||
description = "Register success, wait admin confirm."
|
||||
one = "Registro exitoso, espere la confirmación del administrador."
|
||||
other = "Registro exitoso, espere la confirmación del administrador."
|
||||
@@ -151,4 +151,14 @@ other = "Connexion par mot de passe désactivée."
|
||||
[CannotShareToSelf]
|
||||
description = "Cannot share to self."
|
||||
one = "Impossible de partager avec soi-même."
|
||||
other = "Impossible de partager avec soi-même."
|
||||
other = "Impossible de partager avec soi-même."
|
||||
|
||||
[Banned]
|
||||
description = "Banned."
|
||||
one = "Banni."
|
||||
other = "Banni."
|
||||
|
||||
[RegisterSuccessWaitAdminConfirm]
|
||||
description = "Register success wait admin confirm."
|
||||
one = "Inscription réussie, veuillez attendre la confirmation de l'administrateur."
|
||||
other = "Inscription réussie, veuillez attendre la confirmation de l'administrateur."
|
||||
@@ -145,4 +145,14 @@ other = "비밀번호 로그인이 비활성화되었습니다."
|
||||
[CannotShareToSelf]
|
||||
description = "Cannot share to self."
|
||||
one = "자기 자신에게 공유할 수 없습니다."
|
||||
other = "자기 자신에게 공유할 수 없습니다."
|
||||
other = "자기 자신에게 공유할 수 없습니다."
|
||||
|
||||
[Banned]
|
||||
description = "Banned."
|
||||
one = "금지됨."
|
||||
other = "금지됨."
|
||||
|
||||
[RegisterSuccessWaitAdminConfirm]
|
||||
description = "Register success wait admin confirm."
|
||||
one = "가입 성공, 관리자 확인 대기 중."
|
||||
other = "가입 성공, 관리자 확인 대기 중."
|
||||
@@ -151,4 +151,14 @@ other = "Вход по паролю отключен."
|
||||
[CannotShareToSelf]
|
||||
description = "Cannot share to self."
|
||||
one = "Нельзя поделиться с собой."
|
||||
other = "Нельзя поделиться с собой."
|
||||
other = "Нельзя поделиться с собой."
|
||||
|
||||
[Banned]
|
||||
description = "Banned."
|
||||
one = "Заблокировано."
|
||||
other = "Заблокировано."
|
||||
|
||||
[RegisterSuccessWaitAdminConfirm]
|
||||
description = "Register success wait admin confirm."
|
||||
one = "Регистрация прошла успешно, ожидайте подтверждения администратора."
|
||||
other = "Регистрация прошла успешно, ожидайте подтверждения администратора."
|
||||
@@ -144,4 +144,14 @@ other = "密码登录已禁用。"
|
||||
[CannotShareToSelf]
|
||||
description = "Cannot share to self."
|
||||
one = "不能共享给自己。"
|
||||
other = "不能共享给自己。"
|
||||
other = "不能共享给自己。"
|
||||
|
||||
[Banned]
|
||||
description = "Banned."
|
||||
one = "已被封禁。"
|
||||
other = "已被封禁。"
|
||||
|
||||
[RegisterSuccessWaitAdminConfirm]
|
||||
description = "Register success, wait for admin confirm."
|
||||
one = "注册成功,请等待管理员审核。"
|
||||
other = "注册成功,请等待管理员审核。"
|
||||
@@ -144,4 +144,14 @@ other = "密碼登錄已禁用。"
|
||||
[CannotShareToSelf]
|
||||
description = "Cannot share to self."
|
||||
one = "無法共享給自己。"
|
||||
other = "無法共享給自己。"
|
||||
other = "無法共享給自己。"
|
||||
|
||||
[Banned]
|
||||
description = "Banned."
|
||||
one = "禁止使用。"
|
||||
other = "禁止使用。"
|
||||
|
||||
[RegisterSuccessWaitAdminConfirm]
|
||||
description = "Register success wait admin confirm."
|
||||
one = "註冊成功,請等待管理員確認。"
|
||||
other = "註冊成功,請等待管理員確認。"
|
||||
2
resources/web2/js/dist/index.js
vendored
2
resources/web2/js/dist/index.js
vendored
@@ -11550,7 +11550,7 @@ async function or(u) {
|
||||
let E = [], l = [];
|
||||
for (let d = 0; d < e.length; d++) {
|
||||
const c = 1 << 7 - d % 8;
|
||||
(s[d / 8] & c) === c ? E.push(e[d]) : l.push(e[d])
|
||||
(s[Math.floor(d / 8)] & c) === c ? E.push(e[d]) : l.push(e[d])
|
||||
}
|
||||
_t(E, l), n.close();
|
||||
return
|
||||
|
||||
@@ -411,7 +411,7 @@ func (ls *LdapService) isUserAdmin(cfg *config.Ldap, ldapUser *LdapUser) bool {
|
||||
// Check "memberOf" directly
|
||||
if len(ldapUser.MemberOf) > 0 {
|
||||
for _, group := range ldapUser.MemberOf {
|
||||
if group == adminGroup {
|
||||
if strings.EqualFold(group, adminGroup) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,14 +40,7 @@ func (is *ServerCmdService) Create(u *model.ServerCmd) error {
|
||||
}
|
||||
|
||||
// SendCmd 发送命令
|
||||
func (is *ServerCmdService) SendCmd(target string, cmd string, arg string) (string, error) {
|
||||
port := 0
|
||||
switch target {
|
||||
case model.ServerCmdTargetIdServer:
|
||||
port = Config.Rustdesk.IdServerPort - 1
|
||||
case model.ServerCmdTargetRelayServer:
|
||||
port = Config.Rustdesk.RelayServerPort
|
||||
}
|
||||
func (is *ServerCmdService) SendCmd(port int, cmd string, arg string) (string, error) {
|
||||
//组装命令
|
||||
cmd = cmd + " " + arg
|
||||
res, err := is.SendSocketCmd("v6", port, cmd)
|
||||
|
||||
@@ -412,12 +412,13 @@ func (us *UserService) IsPasswordEmptyByUser(u *model.User) bool {
|
||||
}
|
||||
|
||||
// Register 注册, 如果用户名已存在则返回nil
|
||||
func (us *UserService) Register(username string, email string, password string) *model.User {
|
||||
func (us *UserService) Register(username string, email string, password string, status model.StatusCode) *model.User {
|
||||
u := &model.User{
|
||||
Username: username,
|
||||
Email: email,
|
||||
Password: password,
|
||||
GroupId: 1,
|
||||
Status: status,
|
||||
}
|
||||
err := us.Create(u)
|
||||
if err != nil {
|
||||
|
||||
48
utils/captcha.go
Normal file
48
utils/captcha.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"github.com/mojocn/base64Captcha"
|
||||
"time"
|
||||
)
|
||||
|
||||
var capdString = base64Captcha.NewDriverString(50, 150, 0, 5, 4, "123456789abcdefghijklmnopqrstuvwxyz", nil, nil, nil)
|
||||
|
||||
var capdMath = base64Captcha.NewDriverMath(50, 150, 3, 10, nil, nil, nil)
|
||||
|
||||
type B64StringCaptchaProvider struct{}
|
||||
|
||||
func (p B64StringCaptchaProvider) Generate() (string, string, string, error) {
|
||||
id, content, answer := capdString.GenerateIdQuestionAnswer()
|
||||
return id, content, answer, nil
|
||||
}
|
||||
|
||||
func (p B64StringCaptchaProvider) Expiration() time.Duration {
|
||||
return 5 * time.Minute
|
||||
}
|
||||
func (p B64StringCaptchaProvider) Draw(content string) (string, error) {
|
||||
item, err := capdString.DrawCaptcha(content)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
b64str := item.EncodeB64string()
|
||||
return b64str, nil
|
||||
}
|
||||
|
||||
type B64MathCaptchaProvider struct{}
|
||||
|
||||
func (p B64MathCaptchaProvider) Generate() (string, string, string, error) {
|
||||
id, content, answer := capdMath.GenerateIdQuestionAnswer()
|
||||
return id, content, answer, nil
|
||||
}
|
||||
|
||||
func (p B64MathCaptchaProvider) Expiration() time.Duration {
|
||||
return 5 * time.Minute
|
||||
}
|
||||
func (p B64MathCaptchaProvider) Draw(content string) (string, error) {
|
||||
item, err := capdMath.DrawCaptcha(content)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
b64str := item.EncodeB64string()
|
||||
return b64str, nil
|
||||
}
|
||||
296
utils/login_limiter.go
Normal file
296
utils/login_limiter.go
Normal file
@@ -0,0 +1,296 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 安全策略配置
|
||||
type SecurityPolicy struct {
|
||||
CaptchaThreshold int // 尝试失败次数达到验证码阈值,小于0表示不启用, 0表示强制启用
|
||||
BanThreshold int // 尝试失败次数达到封禁阈值,为0表示不启用
|
||||
AttemptsWindow time.Duration
|
||||
BanDuration time.Duration
|
||||
}
|
||||
|
||||
// 验证码提供者接口
|
||||
type CaptchaProvider interface {
|
||||
Generate() (id string, content string, answer string, err error)
|
||||
//Validate(ip, code string) bool
|
||||
Expiration() time.Duration // 验证码过期时间, 应该小于 AttemptsWindow
|
||||
Draw(content string) (string, error) // 绘制验证码
|
||||
}
|
||||
|
||||
// 验证码元数据
|
||||
type CaptchaMeta struct {
|
||||
Id string
|
||||
Content string
|
||||
Answer string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// IP封禁记录
|
||||
type BanRecord struct {
|
||||
ExpiresAt time.Time
|
||||
Reason string
|
||||
}
|
||||
|
||||
// 登录限制器
|
||||
type LoginLimiter struct {
|
||||
mu sync.Mutex
|
||||
policy SecurityPolicy
|
||||
attempts map[string][]time.Time //
|
||||
captchas map[string]CaptchaMeta
|
||||
bannedIPs map[string]BanRecord
|
||||
provider CaptchaProvider
|
||||
cleanupStop chan struct{}
|
||||
}
|
||||
|
||||
var defaultSecurityPolicy = SecurityPolicy{
|
||||
CaptchaThreshold: 3,
|
||||
BanThreshold: 5,
|
||||
AttemptsWindow: 5 * time.Minute,
|
||||
BanDuration: 30 * time.Minute,
|
||||
}
|
||||
|
||||
func NewLoginLimiter(policy SecurityPolicy) *LoginLimiter {
|
||||
// 设置默认值
|
||||
if policy.AttemptsWindow == 0 {
|
||||
policy.AttemptsWindow = 5 * time.Minute
|
||||
}
|
||||
if policy.BanDuration == 0 {
|
||||
policy.BanDuration = 30 * time.Minute
|
||||
}
|
||||
|
||||
ll := &LoginLimiter{
|
||||
policy: policy,
|
||||
attempts: make(map[string][]time.Time),
|
||||
captchas: make(map[string]CaptchaMeta),
|
||||
bannedIPs: make(map[string]BanRecord),
|
||||
cleanupStop: make(chan struct{}),
|
||||
}
|
||||
go ll.cleanupRoutine()
|
||||
return ll
|
||||
}
|
||||
|
||||
// 注册验证码提供者
|
||||
func (ll *LoginLimiter) RegisterProvider(p CaptchaProvider) {
|
||||
ll.mu.Lock()
|
||||
defer ll.mu.Unlock()
|
||||
ll.provider = p
|
||||
}
|
||||
|
||||
// isDisabled 检查是否禁用登录限制
|
||||
func (ll *LoginLimiter) isDisabled() bool {
|
||||
return ll.policy.CaptchaThreshold < 0 && ll.policy.BanThreshold == 0
|
||||
}
|
||||
|
||||
// 记录登录失败尝试
|
||||
func (ll *LoginLimiter) RecordFailedAttempt(ip string) {
|
||||
if ll.isDisabled() {
|
||||
return
|
||||
}
|
||||
ll.mu.Lock()
|
||||
defer ll.mu.Unlock()
|
||||
|
||||
if banned, _ := ll.isBanned(ip); banned {
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
windowStart := now.Add(-ll.policy.AttemptsWindow)
|
||||
|
||||
// 清理过期尝试
|
||||
validAttempts := ll.pruneAttempts(ip, windowStart)
|
||||
|
||||
// 记录新尝试
|
||||
validAttempts = append(validAttempts, now)
|
||||
ll.attempts[ip] = validAttempts
|
||||
|
||||
// 检查封禁条件
|
||||
if ll.policy.BanThreshold > 0 && len(validAttempts) >= ll.policy.BanThreshold {
|
||||
ll.banIP(ip, "excessive failed attempts")
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// 生成验证码
|
||||
func (ll *LoginLimiter) RequireCaptcha() (error, CaptchaMeta) {
|
||||
ll.mu.Lock()
|
||||
defer ll.mu.Unlock()
|
||||
|
||||
if ll.provider == nil {
|
||||
return errors.New("no captcha provider available"), CaptchaMeta{}
|
||||
}
|
||||
|
||||
id, content, answer, err := ll.provider.Generate()
|
||||
if err != nil {
|
||||
return err, CaptchaMeta{}
|
||||
}
|
||||
|
||||
// 存储验证码
|
||||
ll.captchas[id] = CaptchaMeta{
|
||||
Id: id,
|
||||
Content: content,
|
||||
Answer: answer,
|
||||
ExpiresAt: time.Now().Add(ll.provider.Expiration()),
|
||||
}
|
||||
|
||||
return nil, ll.captchas[id]
|
||||
}
|
||||
|
||||
// 验证验证码
|
||||
func (ll *LoginLimiter) VerifyCaptcha(id, answer string) bool {
|
||||
ll.mu.Lock()
|
||||
defer ll.mu.Unlock()
|
||||
|
||||
// 查找匹配验证码
|
||||
if ll.provider == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 获取并验证验证码
|
||||
captcha, exists := ll.captchas[id]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
// 清理过期验证码
|
||||
if time.Now().After(captcha.ExpiresAt) {
|
||||
delete(ll.captchas, id)
|
||||
return false
|
||||
}
|
||||
|
||||
// 验证并清理状态
|
||||
if answer == captcha.Answer {
|
||||
delete(ll.captchas, id)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (ll *LoginLimiter) DrawCaptcha(content string) (err error, str string) {
|
||||
str, err = ll.provider.Draw(content)
|
||||
return
|
||||
}
|
||||
|
||||
// 清除记录窗口
|
||||
func (ll *LoginLimiter) RemoveAttempts(ip string) {
|
||||
ll.mu.Lock()
|
||||
defer ll.mu.Unlock()
|
||||
|
||||
_, exists := ll.attempts[ip]
|
||||
if exists {
|
||||
delete(ll.attempts, ip)
|
||||
}
|
||||
}
|
||||
|
||||
// CheckSecurityStatus 检查安全状态
|
||||
func (ll *LoginLimiter) CheckSecurityStatus(ip string) (banned bool, captchaRequired bool) {
|
||||
if ll.isDisabled() {
|
||||
return
|
||||
}
|
||||
ll.mu.Lock()
|
||||
defer ll.mu.Unlock()
|
||||
|
||||
// 检查封禁状态
|
||||
if banned, _ = ll.isBanned(ip); banned {
|
||||
return
|
||||
}
|
||||
|
||||
// 清理过期数据
|
||||
ll.pruneAttempts(ip, time.Now().Add(-ll.policy.AttemptsWindow))
|
||||
|
||||
// 检查验证码要求
|
||||
captchaRequired = len(ll.attempts[ip]) >= ll.policy.CaptchaThreshold
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// 后台清理任务
|
||||
func (ll *LoginLimiter) cleanupRoutine() {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
ll.cleanupExpired()
|
||||
case <-ll.cleanupStop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 内部工具方法
|
||||
func (ll *LoginLimiter) isBanned(ip string) (bool, BanRecord) {
|
||||
record, exists := ll.bannedIPs[ip]
|
||||
if !exists {
|
||||
return false, BanRecord{}
|
||||
}
|
||||
if time.Now().After(record.ExpiresAt) {
|
||||
delete(ll.bannedIPs, ip)
|
||||
return false, BanRecord{}
|
||||
}
|
||||
return true, record
|
||||
}
|
||||
|
||||
func (ll *LoginLimiter) banIP(ip, reason string) {
|
||||
ll.bannedIPs[ip] = BanRecord{
|
||||
ExpiresAt: time.Now().Add(ll.policy.BanDuration),
|
||||
Reason: reason,
|
||||
}
|
||||
delete(ll.attempts, ip)
|
||||
delete(ll.captchas, ip)
|
||||
}
|
||||
|
||||
func (ll *LoginLimiter) pruneAttempts(ip string, cutoff time.Time) []time.Time {
|
||||
var valid []time.Time
|
||||
for _, t := range ll.attempts[ip] {
|
||||
if t.After(cutoff) {
|
||||
valid = append(valid, t)
|
||||
}
|
||||
}
|
||||
if len(valid) == 0 {
|
||||
delete(ll.attempts, ip)
|
||||
} else {
|
||||
ll.attempts[ip] = valid
|
||||
}
|
||||
return valid
|
||||
}
|
||||
|
||||
func (ll *LoginLimiter) pruneCaptchas(id string) {
|
||||
if captcha, exists := ll.captchas[id]; exists {
|
||||
if time.Now().After(captcha.ExpiresAt) {
|
||||
delete(ll.captchas, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ll *LoginLimiter) cleanupExpired() {
|
||||
ll.mu.Lock()
|
||||
defer ll.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// 清理封禁记录
|
||||
for ip, record := range ll.bannedIPs {
|
||||
if now.After(record.ExpiresAt) {
|
||||
delete(ll.bannedIPs, ip)
|
||||
}
|
||||
}
|
||||
|
||||
// 清理尝试记录
|
||||
for ip := range ll.attempts {
|
||||
ll.pruneAttempts(ip, now.Add(-ll.policy.AttemptsWindow))
|
||||
}
|
||||
|
||||
// 清理验证码
|
||||
for id := range ll.captchas {
|
||||
ll.pruneCaptchas(id)
|
||||
}
|
||||
}
|
||||
290
utils/login_limiter_test.go
Normal file
290
utils/login_limiter_test.go
Normal file
@@ -0,0 +1,290 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type MockCaptchaProvider struct{}
|
||||
|
||||
func (p *MockCaptchaProvider) Generate() (string, string, string, error) {
|
||||
id := uuid.New().String()
|
||||
content := uuid.New().String()
|
||||
answer := uuid.New().String()
|
||||
return id, content, answer, nil
|
||||
}
|
||||
|
||||
func (p *MockCaptchaProvider) Expiration() time.Duration {
|
||||
return 2 * time.Second
|
||||
}
|
||||
func (p *MockCaptchaProvider) Draw(content string) (string, error) {
|
||||
return "MOCK", nil
|
||||
}
|
||||
|
||||
func TestSecurityWorkflow(t *testing.T) {
|
||||
policy := SecurityPolicy{
|
||||
CaptchaThreshold: 3,
|
||||
BanThreshold: 5,
|
||||
AttemptsWindow: 5 * time.Minute,
|
||||
BanDuration: 5 * time.Minute,
|
||||
}
|
||||
limiter := NewLoginLimiter(policy)
|
||||
ip := "192.168.1.100"
|
||||
|
||||
// 测试正常失败记录
|
||||
for i := 0; i < 3; i++ {
|
||||
limiter.RecordFailedAttempt(ip)
|
||||
}
|
||||
isBanned, capRequired := limiter.CheckSecurityStatus(ip)
|
||||
fmt.Printf("IP: %s, Banned: %v, Captcha Required: %v\n", ip, isBanned, capRequired)
|
||||
if isBanned {
|
||||
t.Error("IP should not be banned yet")
|
||||
}
|
||||
if !capRequired {
|
||||
t.Error("Captcha should be required")
|
||||
}
|
||||
// 测试触发封禁
|
||||
for i := 0; i < 3; i++ {
|
||||
limiter.RecordFailedAttempt(ip)
|
||||
isBanned, capRequired = limiter.CheckSecurityStatus(ip)
|
||||
fmt.Printf("IP: %s, Banned: %v, Captcha Required: %v\n", ip, isBanned, capRequired)
|
||||
}
|
||||
|
||||
// 测试封禁状态
|
||||
if isBanned, _ = limiter.CheckSecurityStatus(ip); !isBanned {
|
||||
t.Error("IP should be banned")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaptchaFlow(t *testing.T) {
|
||||
policy := SecurityPolicy{CaptchaThreshold: 2}
|
||||
limiter := NewLoginLimiter(policy)
|
||||
limiter.RegisterProvider(&MockCaptchaProvider{})
|
||||
ip := "10.0.0.1"
|
||||
|
||||
// 触发验证码要求
|
||||
limiter.RecordFailedAttempt(ip)
|
||||
limiter.RecordFailedAttempt(ip)
|
||||
|
||||
// 检查状态
|
||||
if _, need := limiter.CheckSecurityStatus(ip); !need {
|
||||
t.Error("应该需要验证码")
|
||||
}
|
||||
|
||||
// 生成验证码
|
||||
err, capc := limiter.RequireCaptcha()
|
||||
if err != nil {
|
||||
t.Fatalf("生成验证码失败: %v", err)
|
||||
}
|
||||
fmt.Printf("验证码内容: %#v\n", capc)
|
||||
|
||||
// 验证成功
|
||||
if !limiter.VerifyCaptcha(capc.Id, capc.Answer) {
|
||||
t.Error("验证码应该验证成功")
|
||||
}
|
||||
|
||||
// 验证已删除
|
||||
if limiter.VerifyCaptcha(capc.Id, capc.Answer) {
|
||||
t.Error("验证码应该已删除")
|
||||
}
|
||||
|
||||
limiter.RemoveAttempts(ip)
|
||||
// 验证后状态
|
||||
if banned, need := limiter.CheckSecurityStatus(ip); banned || need {
|
||||
t.Error("验证成功后应该重置状态")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaptchaMustFlow(t *testing.T) {
|
||||
policy := SecurityPolicy{CaptchaThreshold: 0}
|
||||
limiter := NewLoginLimiter(policy)
|
||||
limiter.RegisterProvider(&MockCaptchaProvider{})
|
||||
ip := "10.0.0.1"
|
||||
|
||||
// 检查状态
|
||||
if _, need := limiter.CheckSecurityStatus(ip); !need {
|
||||
t.Error("应该需要验证码")
|
||||
}
|
||||
|
||||
// 生成验证码
|
||||
err, capc := limiter.RequireCaptcha()
|
||||
if err != nil {
|
||||
t.Fatalf("生成验证码失败: %v", err)
|
||||
}
|
||||
fmt.Printf("验证码内容: %#v\n", capc)
|
||||
|
||||
// 验证成功
|
||||
if !limiter.VerifyCaptcha(capc.Id, capc.Answer) {
|
||||
t.Error("验证码应该验证成功")
|
||||
}
|
||||
|
||||
// 验证后状态
|
||||
if _, need := limiter.CheckSecurityStatus(ip); !need {
|
||||
t.Error("应该需要验证码")
|
||||
}
|
||||
}
|
||||
func TestAttemptTimeout(t *testing.T) {
|
||||
policy := SecurityPolicy{CaptchaThreshold: 2, AttemptsWindow: 1 * time.Second}
|
||||
limiter := NewLoginLimiter(policy)
|
||||
limiter.RegisterProvider(&MockCaptchaProvider{})
|
||||
ip := "10.0.0.1"
|
||||
|
||||
// 触发验证码要求
|
||||
limiter.RecordFailedAttempt(ip)
|
||||
limiter.RecordFailedAttempt(ip)
|
||||
|
||||
// 检查状态
|
||||
if _, need := limiter.CheckSecurityStatus(ip); !need {
|
||||
t.Error("应该需要验证码")
|
||||
}
|
||||
|
||||
// 生成验证码
|
||||
err, _ := limiter.RequireCaptcha()
|
||||
if err != nil {
|
||||
t.Fatalf("生成验证码失败: %v", err)
|
||||
}
|
||||
// 等待超过 AttemptsWindow
|
||||
time.Sleep(2 * time.Second)
|
||||
// 触发验证码要求
|
||||
limiter.RecordFailedAttempt(ip)
|
||||
|
||||
// 检查状态
|
||||
if _, need := limiter.CheckSecurityStatus(ip); need {
|
||||
t.Error("不应该需要验证码")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaptchaTimeout(t *testing.T) {
|
||||
policy := SecurityPolicy{CaptchaThreshold: 2}
|
||||
limiter := NewLoginLimiter(policy)
|
||||
limiter.RegisterProvider(&MockCaptchaProvider{})
|
||||
ip := "10.0.0.1"
|
||||
|
||||
// 触发验证码要求
|
||||
limiter.RecordFailedAttempt(ip)
|
||||
limiter.RecordFailedAttempt(ip)
|
||||
|
||||
// 检查状态
|
||||
if _, need := limiter.CheckSecurityStatus(ip); !need {
|
||||
t.Error("应该需要验证码")
|
||||
}
|
||||
|
||||
// 生成验证码
|
||||
err, capc := limiter.RequireCaptcha()
|
||||
if err != nil {
|
||||
t.Fatalf("生成验证码失败: %v", err)
|
||||
}
|
||||
|
||||
// 等待超过 CaptchaValidPeriod
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
// 验证成功
|
||||
if limiter.VerifyCaptcha(capc.Id, capc.Answer) {
|
||||
t.Error("验证码应该已过期")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestBanFlow(t *testing.T) {
|
||||
policy := SecurityPolicy{BanThreshold: 5}
|
||||
limiter := NewLoginLimiter(policy)
|
||||
ip := "10.0.0.1"
|
||||
// 触发ban
|
||||
for i := 0; i < 5; i++ {
|
||||
limiter.RecordFailedAttempt(ip)
|
||||
}
|
||||
|
||||
// 检查状态
|
||||
if banned, _ := limiter.CheckSecurityStatus(ip); !banned {
|
||||
t.Error("should be banned")
|
||||
}
|
||||
}
|
||||
func TestBanDisableFlow(t *testing.T) {
|
||||
policy := SecurityPolicy{BanThreshold: 0}
|
||||
limiter := NewLoginLimiter(policy)
|
||||
ip := "10.0.0.1"
|
||||
// 触发ban
|
||||
for i := 0; i < 5; i++ {
|
||||
limiter.RecordFailedAttempt(ip)
|
||||
}
|
||||
|
||||
// 检查状态
|
||||
if banned, _ := limiter.CheckSecurityStatus(ip); banned {
|
||||
t.Error("should not be banned")
|
||||
}
|
||||
}
|
||||
func TestBanTimeout(t *testing.T) {
|
||||
policy := SecurityPolicy{BanThreshold: 5, BanDuration: 1 * time.Second}
|
||||
limiter := NewLoginLimiter(policy)
|
||||
ip := "10.0.0.1"
|
||||
// 触发ban
|
||||
// 触发ban
|
||||
for i := 0; i < 5; i++ {
|
||||
limiter.RecordFailedAttempt(ip)
|
||||
}
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// 检查状态
|
||||
if banned, _ := limiter.CheckSecurityStatus(ip); banned {
|
||||
t.Error("should not be banned")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimiterDisabled(t *testing.T) {
|
||||
policy := SecurityPolicy{BanThreshold: 0, CaptchaThreshold: -1}
|
||||
limiter := NewLoginLimiter(policy)
|
||||
ip := "10.0.0.1"
|
||||
// 触发ban
|
||||
for i := 0; i < 5; i++ {
|
||||
limiter.RecordFailedAttempt(ip)
|
||||
}
|
||||
|
||||
// 检查状态
|
||||
if banned, capNeed := limiter.CheckSecurityStatus(ip); banned || capNeed {
|
||||
fmt.Printf("IP: %s, Banned: %v, Captcha Required: %v\n", ip, banned, capNeed)
|
||||
t.Error("should not be banned or need captcha")
|
||||
}
|
||||
}
|
||||
|
||||
func TestB64CaptchaFlow(t *testing.T) {
|
||||
limiter := NewLoginLimiter(defaultSecurityPolicy)
|
||||
limiter.RegisterProvider(B64StringCaptchaProvider{})
|
||||
ip := "10.0.0.1"
|
||||
|
||||
// 触发验证码要求
|
||||
limiter.RecordFailedAttempt(ip)
|
||||
limiter.RecordFailedAttempt(ip)
|
||||
limiter.RecordFailedAttempt(ip)
|
||||
|
||||
// 检查状态
|
||||
if _, need := limiter.CheckSecurityStatus(ip); !need {
|
||||
t.Error("应该需要验证码")
|
||||
}
|
||||
|
||||
// 生成验证码
|
||||
err, capc := limiter.RequireCaptcha()
|
||||
if err != nil {
|
||||
t.Fatalf("生成验证码失败: %v", err)
|
||||
}
|
||||
fmt.Printf("验证码内容: %#v\n", capc)
|
||||
|
||||
//draw
|
||||
err, b64 := limiter.DrawCaptcha(capc.Content)
|
||||
if err != nil {
|
||||
t.Fatalf("绘制验证码失败: %v", err)
|
||||
}
|
||||
fmt.Printf("验证码内容: %#v\n", b64)
|
||||
|
||||
// 验证成功
|
||||
if !limiter.VerifyCaptcha(capc.Id, capc.Answer) {
|
||||
t.Error("验证码应该验证成功")
|
||||
}
|
||||
limiter.RemoveAttempts(ip)
|
||||
// 验证后状态
|
||||
if banned, need := limiter.CheckSecurityStatus(ip); banned || need {
|
||||
t.Error("验证成功后应该重置状态")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user