Compare commits

...

26 Commits

Author SHA1 Message Date
lejianwen
2948eaaa5c chore: Update Go version to 1.23 in build configurations 2025-06-16 15:41:16 +08:00
lejianwen
8641ba5c0c docs: Update swagger docs 2025-06-16 12:31:48 +08:00
lejianwen
60b7a18fe7 feat: Add PostgreSQL support and refactor MySQL DSN handling (#284) 2025-06-16 12:26:08 +08:00
lejianwen
ca068816ae feat: Add start time in /api/sysinfover 2025-06-16 12:23:48 +08:00
lejianwen
06648d9a6c fix(admin): Use admin-hello first
(#274) (#255)
2025-06-15 15:33:12 +08:00
puyujian
8a8abd5163 feat(oauth): 支持linux.do登录 (#280)
* 支持linux.do登录

* 修正
2025-06-15 15:32:20 +08:00
lejianwen
97f98cd6ce chore: update download links for musl cross-compilers 2025-06-05 12:14:17 +08:00
lejianwen
51f2920661 fix: Init sqlite fail(#266) 2025-06-04 09:31:43 +08:00
lejianwen
7a5d141ce8 fix(server): Port custom (#257) 2025-05-30 12:27:37 +08:00
lejianwen
3cef02a0bb fix(webclient): Peer online status 2025-05-29 18:51:37 +08:00
lejianwen
46a7ecc1ba fix: Captcha some problem when users login with same ip 2025-05-27 17:36:20 +08:00
lejianwen
4d2b037f5e docs: Readme 2025-05-25 17:44:29 +08:00
lejianwen
323364b24e feat(register): Register status can be set (#223) 2025-05-25 17:03:13 +08:00
lejianwen
f19109cdf8 feat(login): Captcha upgrade and add the function to ban IP addresses (#250) 2025-05-25 16:52:58 +08:00
Tao Chen
527260d60a fix: dn should be case-insensitive (#250) 2025-05-21 09:07:08 +08:00
lejianwen
46bb44f0ab fix(webclient): DefaultIdServerPort undefined (#238) 2025-05-16 20:14:36 +08:00
lejianwen
2f1380f24a fix(webclient): Remove license warning (#235) 2025-05-13 13:11:19 +08:00
lejianwen
ece3328e94 feat(webclient): Web client to 1.4.0 2025-05-12 20:16:08 +08:00
lejianwen
fdd26d87be fix: PageSize (#225) 2025-05-06 19:08:18 +08:00
lejianwen
2ade0dda42 chore: Noelware/docker-manifest-action 2025-04-25 16:20:36 +08:00
lejianwen
a87ae5cf65 chore: Noelware/docker-manifest-action 2025-04-25 14:34:45 +08:00
lejianwen
fe7b8b53a6 style: Oauth page languages 2025-04-24 21:52:43 +08:00
lejianwen
b929f3efdb style: Remove useless configurations 2025-04-15 10:52:46 +08:00
lejianwen
f847fc076f fix: Low case (#149) 2025-04-15 10:46:21 +08:00
lejianwen
60d0a701ce fix: Share pwd 2025-04-15 10:09:56 +08:00
lejianwen
0dedaf6824 feat: Peer share to group 2025-04-14 19:12:40 +08:00
68 changed files with 93285 additions and 86473 deletions

View File

@@ -66,7 +66,7 @@ jobs:
- name: Set up Go environment - name: Set up Go environment
uses: actions/setup-go@v4 uses: actions/setup-go@v4
with: with:
go-version: '1.22' # 选择 Go 版本 go-version: '1.23' # 选择 Go 版本
- name: Set up npm - name: Set up npm
uses: actions/setup-node@v2 uses: actions/setup-node@v2
@@ -115,12 +115,12 @@ jobs:
zip -r ${{ matrix.job.goos}}-${{ matrix.job.platform }}.${{matrix.job.file_ext}} ./release zip -r ${{ matrix.job.goos}}-${{ matrix.job.platform }}.${{matrix.job.file_ext}} ./release
else else
if [ "${{ matrix.job.platform }}" = "arm64" ]; then if [ "${{ matrix.job.platform }}" = "arm64" ]; then
wget https://musl.cc/aarch64-linux-musl-cross.tgz wget https://musl.ljw.red/aarch64-linux-musl-cross.tgz
tar -xf aarch64-linux-musl-cross.tgz tar -xf aarch64-linux-musl-cross.tgz
export PATH=$PATH:$PWD/aarch64-linux-musl-cross/bin export PATH=$PATH:$PWD/aarch64-linux-musl-cross/bin
GOOS=${{ matrix.job.goos }} GOARCH=${{ matrix.job.platform }} CC=aarch64-linux-musl-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go GOOS=${{ matrix.job.goos }} GOARCH=${{ matrix.job.platform }} CC=aarch64-linux-musl-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
elif [ "${{ matrix.job.platform }}" = "armv7l" ]; then elif [ "${{ matrix.job.platform }}" = "armv7l" ]; then
wget https://musl.cc/armv7l-linux-musleabihf-cross.tgz wget https://musl.ljw.red/armv7l-linux-musleabihf-cross.tgz
tar -xf armv7l-linux-musleabihf-cross.tgz tar -xf armv7l-linux-musleabihf-cross.tgz
export PATH=$PATH:$PWD/armv7l-linux-musleabihf-cross/bin export PATH=$PATH:$PWD/armv7l-linux-musleabihf-cross/bin
GOOS=${{ matrix.job.goos }} GOARCH=arm GOARM=7 CC=armv7l-linux-musleabihf-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go GOOS=${{ matrix.job.goos }} GOARCH=arm GOARM=7 CC=armv7l-linux-musleabihf-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
@@ -147,6 +147,7 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Generate Changelog - name: Generate Changelog
if: startsWith(github.ref, 'refs/tags/') && github.event_name == 'push'
run: npx changelogithub # or changelogithub@0.12 if ensure the stable result run: npx changelogithub # or changelogithub@0.12 if ensure the stable result
env: env:
GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}
@@ -380,7 +381,7 @@ jobs:
- name: Create and push manifest Docker Hub (:version) - name: Create and push manifest Docker Hub (:version)
if: ${{ env.SKIP_DOCKER_HUB == 'false' }} if: ${{ env.SKIP_DOCKER_HUB == 'false' }}
uses: Noelware/docker-manifest-action@master uses: Noelware/docker-manifest-action@v0.2.3
with: with:
base-image: ${{ env.BASE_IMAGE_NAMESPACE }}/rustdesk-api:${{ env.TAG }} base-image: ${{ env.BASE_IMAGE_NAMESPACE }}/rustdesk-api:${{ env.TAG }}
extra-images: ${{ env.DOCKERHUB_IMAGE_NAMESPACE }}/rustdesk-api:${{ env.TAG }}-amd64, extra-images: ${{ env.DOCKERHUB_IMAGE_NAMESPACE }}/rustdesk-api:${{ env.TAG }}-amd64,
@@ -390,7 +391,7 @@ jobs:
- name: Create and push manifest GHCR (:version) - name: Create and push manifest GHCR (:version)
if: ${{ env.SKIP_GHCR == 'false' }} if: ${{ env.SKIP_GHCR == 'false' }}
uses: Noelware/docker-manifest-action@master uses: Noelware/docker-manifest-action@v0.2.3
with: with:
base-image: ghcr.io/${{ env.BASE_IMAGE_NAMESPACE }}/rustdesk-api:${{ env.TAG }} base-image: ghcr.io/${{ env.BASE_IMAGE_NAMESPACE }}/rustdesk-api:${{ env.TAG }}
extra-images: ghcr.io/${{ env.GHCR_IMAGE_NAMESPACE }}/rustdesk-api:${{ env.TAG }}-amd64, extra-images: ghcr.io/${{ env.GHCR_IMAGE_NAMESPACE }}/rustdesk-api:${{ env.TAG }}-amd64,
@@ -401,7 +402,7 @@ jobs:
- name: Create and push manifest Docker Hub (:latest) - name: Create and push manifest Docker Hub (:latest)
if: ${{ env.SKIP_DOCKER_HUB == 'false' }} if: ${{ env.SKIP_DOCKER_HUB == 'false' }}
uses: Noelware/docker-manifest-action@master uses: Noelware/docker-manifest-action@v0.2.3
with: with:
base-image: ${{ env.BASE_IMAGE_NAMESPACE }}/rustdesk-api:latest base-image: ${{ env.BASE_IMAGE_NAMESPACE }}/rustdesk-api:latest
extra-images: ${{ env.DOCKERHUB_IMAGE_NAMESPACE }}/rustdesk-api:latest-amd64, extra-images: ${{ env.DOCKERHUB_IMAGE_NAMESPACE }}/rustdesk-api:latest-amd64,
@@ -411,7 +412,7 @@ jobs:
- name: Create and push manifest GHCR (:latest) - name: Create and push manifest GHCR (:latest)
if: ${{ env.SKIP_GHCR == 'false' }} if: ${{ env.SKIP_GHCR == 'false' }}
uses: Noelware/docker-manifest-action@master uses: Noelware/docker-manifest-action@v0.2.3
with: with:
base-image: ghcr.io/${{ env.BASE_IMAGE_NAMESPACE }}/rustdesk-api:latest base-image: ghcr.io/${{ env.BASE_IMAGE_NAMESPACE }}/rustdesk-api:latest
extra-images: ghcr.io/${{ env.GHCR_IMAGE_NAMESPACE }}/rustdesk-api:latest-amd64, extra-images: ghcr.io/${{ env.GHCR_IMAGE_NAMESPACE }}/rustdesk-api:latest-amd64,
@@ -422,7 +423,7 @@ jobs:
- name: Create and push Full S6 manifest Docker Hub (:version) - name: Create and push Full S6 manifest Docker Hub (:version)
if: ${{ env.SKIP_DOCKER_HUB == 'false' }} if: ${{ env.SKIP_DOCKER_HUB == 'false' }}
uses: Noelware/docker-manifest-action@master uses: Noelware/docker-manifest-action@v0.2.3
with: with:
base-image: ${{ env.BASE_IMAGE_NAMESPACE }}/rustdesk-api:full-s6 base-image: ${{ env.BASE_IMAGE_NAMESPACE }}/rustdesk-api:full-s6
extra-images: ${{ env.DOCKERHUB_IMAGE_NAMESPACE }}/rustdesk-api:full-s6-amd64, extra-images: ${{ env.DOCKERHUB_IMAGE_NAMESPACE }}/rustdesk-api:full-s6-amd64,
@@ -433,7 +434,7 @@ jobs:
- name: Create and push Full S6 manifest GHCR (:latest) - name: Create and push Full S6 manifest GHCR (:latest)
if: ${{ env.SKIP_GHCR == 'false' }} if: ${{ env.SKIP_GHCR == 'false' }}
uses: Noelware/docker-manifest-action@master uses: Noelware/docker-manifest-action@v0.2.3
with: with:
base-image: ghcr.io/${{ env.BASE_IMAGE_NAMESPACE }}/rustdesk-api:full-s6 base-image: ghcr.io/${{ env.BASE_IMAGE_NAMESPACE }}/rustdesk-api:full-s6
extra-images: ghcr.io/${{ env.GHCR_IMAGE_NAMESPACE }}/rustdesk-api:full-s6-amd64, extra-images: ghcr.io/${{ env.GHCR_IMAGE_NAMESPACE }}/rustdesk-api:full-s6-amd64,

View File

@@ -61,7 +61,7 @@ jobs:
- name: Set up Go environment - name: Set up Go environment
uses: actions/setup-go@v4 uses: actions/setup-go@v4
with: with:
go-version: '1.22' # 选择 Go 版本 go-version: '1.23' # 选择 Go 版本
- name: Set up npm - name: Set up npm
uses: actions/setup-node@v2 uses: actions/setup-node@v2
@@ -101,12 +101,12 @@ jobs:
zip -r ${{ matrix.job.goos}}-${{ matrix.job.platform }}.${{matrix.job.file_ext}} ./release zip -r ${{ matrix.job.goos}}-${{ matrix.job.platform }}.${{matrix.job.file_ext}} ./release
else else
if [ "${{ matrix.job.platform }}" = "arm64" ]; then if [ "${{ matrix.job.platform }}" = "arm64" ]; then
wget https://musl.cc/aarch64-linux-musl-cross.tgz wget https://musl.ljw.red/aarch64-linux-musl-cross.tgz
tar -xf aarch64-linux-musl-cross.tgz tar -xf aarch64-linux-musl-cross.tgz
export PATH=$PATH:$PWD/aarch64-linux-musl-cross/bin export PATH=$PATH:$PWD/aarch64-linux-musl-cross/bin
GOOS=${{ matrix.job.goos }} GOARCH=${{ matrix.job.platform }} CC=aarch64-linux-musl-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go GOOS=${{ matrix.job.goos }} GOARCH=${{ matrix.job.platform }} CC=aarch64-linux-musl-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
elif [ "${{ matrix.job.platform }}" = "armv7l" ]; then elif [ "${{ matrix.job.platform }}" = "armv7l" ]; then
wget https://musl.cc/armv7l-linux-musleabihf-cross.tgz wget https://musl.ljw.red/armv7l-linux-musleabihf-cross.tgz
tar -xf armv7l-linux-musleabihf-cross.tgz tar -xf armv7l-linux-musleabihf-cross.tgz
export PATH=$PATH:$PWD/armv7l-linux-musleabihf-cross/bin export PATH=$PATH:$PWD/armv7l-linux-musleabihf-cross/bin
GOOS=${{ matrix.job.goos }} GOARCH=arm GOARM=7 CC=armv7l-linux-musleabihf-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go GOOS=${{ matrix.job.goos }} GOARCH=arm GOARM=7 CC=armv7l-linux-musleabihf-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
@@ -317,7 +317,7 @@ jobs:
- name: Create and push manifest Docker Hub (:version) - name: Create and push manifest Docker Hub (:version)
if: ${{ env.SKIP_DOCKER_HUB == 'false' }} if: ${{ env.SKIP_DOCKER_HUB == 'false' }}
uses: Noelware/docker-manifest-action@master uses: Noelware/docker-manifest-action@v0.2.3
with: with:
base-image: ${{ env.BASE_IMAGE_NAMESPACE }}/rustdesk-api:${{ env.TAG }} base-image: ${{ env.BASE_IMAGE_NAMESPACE }}/rustdesk-api:${{ env.TAG }}
extra-images: ${{ env.DOCKERHUB_IMAGE_NAMESPACE }}/rustdesk-api:${{ env.TAG }}-amd64, extra-images: ${{ env.DOCKERHUB_IMAGE_NAMESPACE }}/rustdesk-api:${{ env.TAG }}-amd64,
@@ -327,7 +327,7 @@ jobs:
- name: Create and push manifest GHCR (:version) - name: Create and push manifest GHCR (:version)
if: ${{ env.SKIP_GHCR == 'false' }} if: ${{ env.SKIP_GHCR == 'false' }}
uses: Noelware/docker-manifest-action@master uses: Noelware/docker-manifest-action@v0.2.3
with: with:
base-image: ghcr.io/${{ env.BASE_IMAGE_NAMESPACE }}/rustdesk-api:${{ env.TAG }} base-image: ghcr.io/${{ env.BASE_IMAGE_NAMESPACE }}/rustdesk-api:${{ env.TAG }}
extra-images: ghcr.io/${{ env.GHCR_IMAGE_NAMESPACE }}/rustdesk-api:${{ env.TAG }}-amd64, extra-images: ghcr.io/${{ env.GHCR_IMAGE_NAMESPACE }}/rustdesk-api:${{ env.TAG }}-amd64,

2
.gitignore vendored
View File

@@ -5,4 +5,4 @@ runtime/*
go.sum go.sum
resources/admin resources/admin
release release
data data/rustdeskapi.db

View File

@@ -163,6 +163,9 @@
| RUSTDESK_API_APP_SHOW_SWAGGER | 是否可见swagger文档;`1`显示,`0`不显示,默认`0`不显示 | `1` | | RUSTDESK_API_APP_SHOW_SWAGGER | 是否可见swagger文档;`1`显示,`0`不显示,默认`0`不显示 | `1` |
| RUSTDESK_API_APP_TOKEN_EXPIRE | token有效时长 | `168h` | | RUSTDESK_API_APP_TOKEN_EXPIRE | token有效时长 | `168h` |
| RUSTDESK_API_APP_DISABLE_PWD_LOGIN | 是否禁用密码登录; `true`, `false` 默认`false` | `false` | | RUSTDESK_API_APP_DISABLE_PWD_LOGIN | 是否禁用密码登录; `true`, `false` 默认`false` | `false` |
| RUSTDESK_API_APP_REGISTER_STATUS | 注册用户默认状态; 1 启用2 禁用, 默认 1 | `1` |
| RUSTDESK_API_APP_CAPTCHA_THRESHOLD | 验证码触发次数; -1 不启用, 0 一直启用, >0 登录错误次数后启用 ;默认 `3` | `3` |
| RUSTDESK_API_APP_BAN_THRESHOLD | 封禁IP触发次数; 0 不启用, >0 登录错误次数后封禁IP; 默认 `0` | `0` |
| -----ADMIN配置----- | ---------- | ---------- | | -----ADMIN配置----- | ---------- | ---------- |
| RUSTDESK_API_ADMIN_TITLE | 后台标题 | `RustDesk Api Admin` | | RUSTDESK_API_ADMIN_TITLE | 后台标题 | `RustDesk Api Admin` |
| RUSTDESK_API_ADMIN_HELLO | 后台欢迎语,可以使用`html` | | | RUSTDESK_API_ADMIN_HELLO | 后台欢迎语,可以使用`html` | |

View File

@@ -162,6 +162,9 @@ The table below does not list all configurations. Please refer to the configurat
| RUSTDESK_API_APP_SHOW_SWAGGER | swagger visible; 1: yes, 0: no; default: 0 | `0` | | RUSTDESK_API_APP_SHOW_SWAGGER | swagger visible; 1: yes, 0: no; default: 0 | `0` |
| RUSTDESK_API_APP_TOKEN_EXPIRE | token expire duration | `168h` | | RUSTDESK_API_APP_TOKEN_EXPIRE | token expire duration | `168h` |
| RUSTDESK_API_APP_DISABLE_PWD_LOGIN | disable password login | `false` | | RUSTDESK_API_APP_DISABLE_PWD_LOGIN | disable password login | `false` |
| RUSTDESK_API_APP_REGISTER_STATUS | register user default status ; 1 enabled , 2 disabled ; default 1 | `1` |
| RUSTDESK_API_APP_CAPTCHA_THRESHOLD | captcha threshold; -1 disabled, 0 always enable, >0 threshold ;default `3` | `3` |
| RUSTDESK_API_APP_BAN_THRESHOLD | ban ip threshold; 0 disabled, >0 threshold ; default `0` | `0` |
| ----- ADMIN Configuration----- | ---------- | ---------- | | ----- ADMIN Configuration----- | ---------- | ---------- |
| RUSTDESK_API_ADMIN_TITLE | Admin Title | `RustDesk Api Admin` | | RUSTDESK_API_ADMIN_TITLE | Admin Title | `RustDesk Api Admin` |
| RUSTDESK_API_ADMIN_HELLO | Admin welcome message, you can use `html` | | | RUSTDESK_API_ADMIN_HELLO | Admin welcome message, you can use `html` | |

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"fmt"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"github.com/lejianwen/rustdesk-api/v2/config" "github.com/lejianwen/rustdesk-api/v2/config"
"github.com/lejianwen/rustdesk-api/v2/global" "github.com/lejianwen/rustdesk-api/v2/global"
@@ -18,6 +19,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"os" "os"
"strconv" "strconv"
"time"
) )
// @title 管理系统API // @title 管理系统API
@@ -139,18 +141,40 @@ func InitGlobal() {
} }
//gorm //gorm
if global.Config.Gorm.Type == config.TypeMysql { if global.Config.Gorm.Type == config.TypeMysql {
dns := global.Config.Mysql.Username + ":" + global.Config.Mysql.Password + "@(" + global.Config.Mysql.Addr + ")/" + global.Config.Mysql.Dbname + "?charset=utf8mb4&parseTime=True&loc=Local"
dsn := fmt.Sprintf("%s:%s@(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
global.Config.Mysql.Username,
global.Config.Mysql.Password,
global.Config.Mysql.Addr,
global.Config.Mysql.Dbname,
)
global.DB = orm.NewMysql(&orm.MysqlConfig{ global.DB = orm.NewMysql(&orm.MysqlConfig{
Dns: dns, Dsn: dsn,
MaxIdleConns: global.Config.Gorm.MaxIdleConns, MaxIdleConns: global.Config.Gorm.MaxIdleConns,
MaxOpenConns: global.Config.Gorm.MaxOpenConns, MaxOpenConns: global.Config.Gorm.MaxOpenConns,
}) }, global.Logger)
} else if global.Config.Gorm.Type == config.TypePostgresql {
dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s TimeZone=%s",
global.Config.Postgresql.Host,
global.Config.Postgresql.Port,
global.Config.Postgresql.User,
global.Config.Postgresql.Password,
global.Config.Postgresql.Dbname,
global.Config.Postgresql.Sslmode,
global.Config.Postgresql.TimeZone,
)
global.DB = orm.NewPostgresql(&orm.PostgresqlConfig{
Dsn: dsn,
MaxIdleConns: global.Config.Gorm.MaxIdleConns,
MaxOpenConns: global.Config.Gorm.MaxOpenConns,
}, global.Logger)
} else { } else {
//sqlite //sqlite
global.DB = orm.NewSqlite(&orm.SqliteConfig{ global.DB = orm.NewSqlite(&orm.SqliteConfig{
MaxIdleConns: global.Config.Gorm.MaxIdleConns, MaxIdleConns: global.Config.Gorm.MaxIdleConns,
MaxOpenConns: global.Config.Gorm.MaxOpenConns, MaxOpenConns: global.Config.Gorm.MaxOpenConns,
}) }, global.Logger)
} }
//validator //validator
@@ -175,8 +199,16 @@ func InitGlobal() {
//service //service
service.New(&global.Config, global.DB, global.Logger, global.Jwt, global.Lock) service.New(&global.Config, global.DB, global.Logger, global.Jwt, global.Lock)
global.LoginLimiter = utils.NewLoginLimiter(utils.SecurityPolicy{
CaptchaThreshold: global.Config.App.CaptchaThreshold,
BanThreshold: global.Config.App.BanThreshold,
AttemptsWindow: 10 * time.Minute,
BanDuration: 30 * time.Minute,
})
global.LoginLimiter.RegisterProvider(utils.B64StringCaptchaProvider{})
DatabaseAutoUpdate() DatabaseAutoUpdate()
} }
func DatabaseAutoUpdate() { func DatabaseAutoUpdate() {
version := 262 version := 262
@@ -188,11 +220,17 @@ func DatabaseAutoUpdate() {
if dbName == "" { if dbName == "" {
dbName = global.Config.Mysql.Dbname dbName = global.Config.Mysql.Dbname
// 移除 DSN 中的数据库名称,以便初始连接时不指定数据库 // 移除 DSN 中的数据库名称,以便初始连接时不指定数据库
dsnWithoutDB := global.Config.Mysql.Username + ":" + global.Config.Mysql.Password + "@(" + global.Config.Mysql.Addr + ")/?charset=utf8mb4&parseTime=True&loc=Local" dsnWithoutDB := fmt.Sprintf("%s:%s@(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
global.Config.Mysql.Username,
global.Config.Mysql.Password,
global.Config.Mysql.Addr,
"",
)
//新链接 //新链接
dbWithoutDB := orm.NewMysql(&orm.MysqlConfig{ dbWithoutDB := orm.NewMysql(&orm.MysqlConfig{
Dns: dsnWithoutDB, Dsn: dsnWithoutDB,
}) }, global.Logger)
// 获取底层的 *sql.DB 对象,并确保在程序退出时关闭连接 // 获取底层的 *sql.DB 对象,并确保在程序退出时关闭连接
sqlDBWithoutDB, err := dbWithoutDB.DB() sqlDBWithoutDB, err := dbWithoutDB.DB()
if err != nil { if err != nil {

View File

@@ -2,14 +2,21 @@ lang: "zh-CN"
app: app:
web-client: 1 # 1:启用 0:禁用 web-client: 1 # 1:启用 0:禁用
register: false #是否开启注册 register: false #是否开启注册
register-status: 1 # 注册用户默认状态 1:启用 2:禁用
captcha-threshold: 3 # <0:disabled, 0 always, >0:enabled
ban-threshold: 0 # 0:disabled, >0:enabled
show-swagger: 0 # 1:启用 0:禁用 show-swagger: 0 # 1:启用 0:禁用
token-expire: 168h token-expire: 168h
web-sso: true #web auth sso web-sso: true #web auth sso
disable-pwd-login: false #禁用密码登录 disable-pwd-login: false #禁用密码登录
admin: admin:
title: "RustDesk Api Admin" title: "RustDesk API Admin"
hello-file: "./conf/admin/hello.html" #优先使用file hello-file: "./conf/admin/hello.html" #优先使用file
hello: "" hello: ""
# ID Server and Relay Server ports https://github.com/lejianwen/rustdesk-api/issues/257
id-server-port: 21116 # ID Server port (for server cmd)
relay-server-port: 21117 # ID Server port (for server cmd)
gin: gin:
api-addr: "0.0.0.0:21114" api-addr: "0.0.0.0:21114"
mode: "release" #release,debug,test mode: "release" #release,debug,test
@@ -24,6 +31,16 @@ mysql:
password: "" password: ""
addr: "" addr: ""
dbname: "" dbname: ""
postgresql:
host: "127.0.0.1"
port: "5432"
user: ""
password: ""
dbname: "postgres"
sslmode: "disable" # disable, require, verify-ca, verify-full
time-zone: "Asia/Shanghai" # Time zone for PostgreSQL connection
rustdesk: rustdesk:
id-server: "192.168.1.66:21116" id-server: "192.168.1.66:21116"
relay-server: "192.168.1.66:21117" relay-server: "192.168.1.66:21117"
@@ -64,21 +81,3 @@ ldap:
sync: false # If true, the user will be synchronized to the database when the user logs in. If false, the user will be synchronized to the database when the user be created. sync: false # If true, the user will be synchronized to the database when the user logs in. If false, the user will be synchronized to the database when the user be created.
admin-group: "cn=admin,dc=example,dc=com" # The group name of the admin group, if the user is in this group, the user will be an admin. admin-group: "cn=admin,dc=example,dc=com" # The group name of the admin group, if the user is in this group, the user will be an admin.
redis:
addr: "127.0.0.1:6379"
password: ""
db: 0
cache:
type: "file"
file-dir: "./runtime/cache"
redis-addr: "127.0.0.1:6379"
redis-pwd: ""
redis-db: 0
oss:
access-key-id: ""
access-key-secret: ""
host: ""
callback-url: ""
expire-time: 30
max-byte: 10240

View File

@@ -16,15 +16,20 @@ const (
type App struct { type App struct {
WebClient int `mapstructure:"web-client"` WebClient int `mapstructure:"web-client"`
Register bool `mapstructure:"register"` Register bool `mapstructure:"register"`
RegisterStatus int `mapstructure:"register-status"`
ShowSwagger int `mapstructure:"show-swagger"` ShowSwagger int `mapstructure:"show-swagger"`
TokenExpire time.Duration `mapstructure:"token-expire"` TokenExpire time.Duration `mapstructure:"token-expire"`
WebSso bool `mapstructure:"web-sso"` WebSso bool `mapstructure:"web-sso"`
DisablePwdLogin bool `mapstructure:"disable-pwd-login"` DisablePwdLogin bool `mapstructure:"disable-pwd-login"`
CaptchaThreshold int `mapstructure:"captcha-threshold"`
BanThreshold int `mapstructure:"ban-threshold"`
} }
type Admin struct { type Admin struct {
Title string `mapstructure:"title"` Title string `mapstructure:"title"`
Hello string `mapstructure:"hello"` Hello string `mapstructure:"hello"`
HelloFile string `mapstructure:"hello-file"` HelloFile string `mapstructure:"hello-file"`
IdServerPort int `mapstructure:"id-server-port"`
RelayServerPort int `mapstructure:"relay-server-port"`
} }
type Config struct { type Config struct {
Lang string `mapstructure:"lang"` Lang string `mapstructure:"lang"`
@@ -32,6 +37,7 @@ type Config struct {
Admin Admin Admin Admin
Gorm Gorm Gorm Gorm
Mysql Mysql Mysql Mysql
Postgresql Postgresql
Gin Gin Gin Gin
Logger Logger Logger Logger
Redis Redis Redis Redis
@@ -43,6 +49,15 @@ type Config struct {
Ldap Ldap Ldap Ldap
} }
func (a *Admin) Init() {
if a.IdServerPort == 0 {
a.IdServerPort = DefaultIdServerPort
}
if a.RelayServerPort == 0 {
a.RelayServerPort = DefaultRelayServerPort
}
}
// Init 初始化配置 // Init 初始化配置
func Init(rowVal *Config, path string) *viper.Viper { func Init(rowVal *Config, path string) *viper.Viper {
if path == "" { if path == "" {
@@ -77,7 +92,7 @@ func Init(rowVal *Config, path string) *viper.Viper {
panic(fmt.Errorf("Fatal error config: %s \n", err)) panic(fmt.Errorf("Fatal error config: %s \n", err))
} }
rowVal.Rustdesk.LoadKeyFile() rowVal.Rustdesk.LoadKeyFile()
rowVal.Rustdesk.ParsePort() rowVal.Admin.Init()
return v return v
} }

View File

@@ -3,6 +3,7 @@ package config
const ( const (
TypeSqlite = "sqlite" TypeSqlite = "sqlite"
TypeMysql = "mysql" TypeMysql = "mysql"
TypePostgresql = "postgresql"
) )
type Gorm struct { type Gorm struct {
@@ -17,3 +18,13 @@ type Mysql struct {
Password string `mapstructure:"password"` Password string `mapstructure:"password"`
Dbname string `mapstructure:"dbname"` Dbname string `mapstructure:"dbname"`
} }
type Postgresql struct {
Host string `mapstructure:"host"`
Port string `mapstructure:"port"`
User string `mapstructure:"user"`
Password string `mapstructure:"password"`
Dbname string `mapstructure:"dbname"`
Sslmode string `mapstructure:"sslmode"` // "disable", "require", "verify-ca", "verify-full"
TimeZone string `mapstructure:"time-zone"` // e.g., "Asia/Shanghai"
}

View File

@@ -18,3 +18,9 @@ type OidcOauth struct {
ClientSecret string `mapstructure:"client-secret"` ClientSecret string `mapstructure:"client-secret"`
RedirectUrl string `mapstructure:"redirect-url"` RedirectUrl string `mapstructure:"redirect-url"`
} }
type LinuxdoOauth struct {
ClientId string `mapstructure:"client-id"`
ClientSecret string `mapstructure:"client-secret"`
RedirectUrl string `mapstructure:"redirect-url"`
}

View File

@@ -2,8 +2,6 @@ package config
import ( import (
"os" "os"
"strconv"
"strings"
) )
const ( const (
@@ -40,19 +38,3 @@ func (rd *Rustdesk) LoadKeyFile() {
return return
} }
} }
func (rd *Rustdesk) ParsePort() {
// Parse port
idres := strings.Split(rd.IdServer, ":")
if len(idres) == 1 {
rd.IdServerPort = DefaultIdServerPort
} else if len(idres) == 2 {
rd.IdServerPort, _ = strconv.Atoi(idres[1])
}
relayres := strings.Split(rd.RelayServer, ":")
if len(relayres) == 1 {
rd.RelayServerPort = DefaultRelayServerPort
} else if len(relayres) == 2 {
rd.RelayServerPort, _ = strconv.Atoi(relayres[1])
}
}

0
data/.gitkeep Normal file
View File

View File

@@ -5828,6 +5828,9 @@ const docTemplateadmin = `{
"captcha": { "captcha": {
"type": "string" "type": "string"
}, },
"captcha_id": {
"type": "string"
},
"password": { "password": {
"type": "string" "type": "string"
}, },

View File

@@ -5821,6 +5821,9 @@
"captcha": { "captcha": {
"type": "string" "type": "string"
}, },
"captcha_id": {
"type": "string"
},
"password": { "password": {
"type": "string" "type": "string"
}, },

View File

@@ -297,6 +297,8 @@ definitions:
properties: properties:
captcha: captcha:
type: string type: string
captcha_id:
type: string
password: password:
type: string type: string
platform: platform:

View File

@@ -1208,7 +1208,7 @@ const docTemplateapi = `{
"application/json" "application/json"
], ],
"tags": [ "tags": [
"地址" "System"
], ],
"summary": "提交系统信息", "summary": "提交系统信息",
"parameters": [ "parameters": [
@@ -1238,6 +1238,35 @@ const docTemplateapi = `{
} }
} }
}, },
"/sysinfo_ver": {
"post": {
"description": "获取系统版本信息",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"System"
],
"summary": "获取系统版本信息",
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "string"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/response.ErrorResponse"
}
}
}
}
},
"/users": { "/users": {
"get": { "get": {
"security": [ "security": [

View File

@@ -1201,7 +1201,7 @@
"application/json" "application/json"
], ],
"tags": [ "tags": [
"地址" "System"
], ],
"summary": "提交系统信息", "summary": "提交系统信息",
"parameters": [ "parameters": [
@@ -1231,6 +1231,35 @@
} }
} }
}, },
"/sysinfo_ver": {
"post": {
"description": "获取系统版本信息",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"System"
],
"summary": "获取系统版本信息",
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "string"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/response.ErrorResponse"
}
}
}
}
},
"/users": { "/users": {
"get": { "get": {
"security": [ "security": [

View File

@@ -973,7 +973,26 @@ paths:
$ref: '#/definitions/response.ErrorResponse' $ref: '#/definitions/response.ErrorResponse'
summary: 提交系统信息 summary: 提交系统信息
tags: tags:
- 地址 - System
/sysinfo_ver:
post:
consumes:
- application/json
description: 获取系统版本信息
produces:
- application/json
responses:
"200":
description: OK
schema:
type: string
"500":
description: Internal Server Error
schema:
$ref: '#/definitions/response.ErrorResponse'
summary: 获取系统版本信息
tags:
- System
/users: /users:
get: get:
consumes: consumes:

View File

@@ -10,6 +10,7 @@ import (
"github.com/lejianwen/rustdesk-api/v2/lib/jwt" "github.com/lejianwen/rustdesk-api/v2/lib/jwt"
"github.com/lejianwen/rustdesk-api/v2/lib/lock" "github.com/lejianwen/rustdesk-api/v2/lib/lock"
"github.com/lejianwen/rustdesk-api/v2/lib/upload" "github.com/lejianwen/rustdesk-api/v2/lib/upload"
"github.com/lejianwen/rustdesk-api/v2/utils"
"github.com/nicksnyder/go-i18n/v2/i18n" "github.com/nicksnyder/go-i18n/v2/i18n"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/spf13/viper" "github.com/spf13/viper"
@@ -35,4 +36,5 @@ var (
Jwt *jwt.Jwt Jwt *jwt.Jwt
Lock lock.Locker Lock lock.Locker
Localizer func(lang string) *i18n.Localizer Localizer func(lang string) *i18n.Localizer
LoginLimiter *utils.LoginLimiter
) )

23
go.mod
View File

@@ -1,19 +1,23 @@
module github.com/lejianwen/rustdesk-api/v2 module github.com/lejianwen/rustdesk-api/v2
go 1.22 go 1.23
toolchain go1.23.10
require ( require (
github.com/BurntSushi/toml v1.3.2 github.com/BurntSushi/toml v1.3.2
github.com/antonfisher/nested-logrus-formatter v1.3.1 github.com/antonfisher/nested-logrus-formatter v1.3.1
github.com/fsnotify/fsnotify v1.5.1 github.com/coreos/go-oidc/v3 v3.12.0
github.com/fvbock/endless v0.0.0-20170109170031-447134032cb6 github.com/fvbock/endless v0.0.0-20170109170031-447134032cb6
github.com/gin-gonic/gin v1.9.0 github.com/gin-gonic/gin v1.9.0
github.com/go-ldap/ldap/v3 v3.4.10
github.com/go-playground/locales v0.14.1 github.com/go-playground/locales v0.14.1
github.com/go-playground/universal-translator v0.18.1 github.com/go-playground/universal-translator v0.18.1
github.com/go-playground/validator/v10 v10.26.0 github.com/go-playground/validator/v10 v10.26.0
github.com/go-redis/redis/v8 v8.11.4 github.com/go-redis/redis/v8 v8.11.4
github.com/golang-jwt/jwt/v5 v5.2.1 github.com/golang-jwt/jwt/v5 v5.2.1
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/mojocn/base64Captcha v1.3.6
github.com/nicksnyder/go-i18n/v2 v2.4.0 github.com/nicksnyder/go-i18n/v2 v2.4.0
github.com/sirupsen/logrus v1.8.1 github.com/sirupsen/logrus v1.8.1
github.com/spf13/cobra v1.8.1 github.com/spf13/cobra v1.8.1
@@ -24,8 +28,9 @@ require (
golang.org/x/oauth2 v0.23.0 golang.org/x/oauth2 v0.23.0
golang.org/x/text v0.22.0 golang.org/x/text v0.22.0
gorm.io/driver/mysql v1.5.7 gorm.io/driver/mysql v1.5.7
gorm.io/driver/postgres v1.6.0
gorm.io/driver/sqlite v1.5.6 gorm.io/driver/sqlite v1.5.6
gorm.io/gorm v1.25.7 gorm.io/gorm v1.25.10
) )
require ( require (
@@ -36,13 +41,12 @@ require (
github.com/bytedance/sonic v1.8.0 // indirect github.com/bytedance/sonic v1.8.0 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/coreos/go-oidc/v3 v3.12.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/fsnotify/fsnotify v1.5.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.8 // indirect github.com/gabriel-vasile/mimetype v1.4.8 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-asn1-ber/asn1-ber v1.5.7 // indirect github.com/go-asn1-ber/asn1-ber v1.5.7 // indirect
github.com/go-jose/go-jose/v4 v4.0.2 // indirect github.com/go-jose/go-jose/v4 v4.0.2 // indirect
github.com/go-ldap/ldap/v3 v3.4.10 // indirect
github.com/go-openapi/jsonpointer v0.19.5 // indirect github.com/go-openapi/jsonpointer v0.19.5 // indirect
github.com/go-openapi/jsonreference v0.19.6 // indirect github.com/go-openapi/jsonreference v0.19.6 // indirect
github.com/go-openapi/spec v0.20.4 // indirect github.com/go-openapi/spec v0.20.4 // indirect
@@ -52,6 +56,10 @@ require (
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.6.0 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
github.com/josharian/intern v1.0.0 // indirect github.com/josharian/intern v1.0.0 // indirect
@@ -65,9 +73,9 @@ require (
github.com/mitchellh/mapstructure v1.4.2 // indirect github.com/mitchellh/mapstructure v1.4.2 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/mojocn/base64Captcha v1.3.6 // indirect
github.com/pelletier/go-toml v1.9.4 // indirect github.com/pelletier/go-toml v1.9.4 // indirect
github.com/pelletier/go-toml/v2 v2.0.6 // indirect github.com/pelletier/go-toml/v2 v2.0.6 // indirect
github.com/rogpeppe/go-internal v1.14.1 // indirect
github.com/spf13/afero v1.6.0 // indirect github.com/spf13/afero v1.6.0 // indirect
github.com/spf13/cast v1.4.1 // indirect github.com/spf13/cast v1.4.1 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect
@@ -79,8 +87,9 @@ require (
golang.org/x/crypto v0.33.0 // indirect golang.org/x/crypto v0.33.0 // indirect
golang.org/x/image v0.13.0 // indirect golang.org/x/image v0.13.0 // indirect
golang.org/x/net v0.34.0 // indirect golang.org/x/net v0.34.0 // indirect
golang.org/x/sync v0.11.0 // indirect
golang.org/x/sys v0.30.0 // indirect golang.org/x/sys v0.30.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.org/x/tools v0.26.0 // indirect
google.golang.org/protobuf v1.33.0 // indirect google.golang.org/protobuf v1.33.0 // indirect
gopkg.in/ini.v1 v1.63.2 // indirect gopkg.in/ini.v1 v1.63.2 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect

View File

@@ -120,7 +120,7 @@ func (abcr *AddressBookCollectionRule) CheckForm(t *model.AddressBookCollectionR
//check to_id //check to_id
if t.Type == model.ShareAddressBookRuleTypePersonal { if t.Type == model.ShareAddressBookRuleTypePersonal {
if t.ToId == t.UserId { if t.ToId == t.UserId {
return "ParamsError", false return "CannotShareToSelf", false
} }
tou := service.AllService.UserService.InfoById(t.ToId) tou := service.AllService.UserService.InfoById(t.ToId)
if tou.Id == 0 { if tou.Id == 0 {
@@ -135,7 +135,7 @@ func (abcr *AddressBookCollectionRule) CheckForm(t *model.AddressBookCollectionR
return "ParamsError", false return "ParamsError", false
} }
// 重复检查 // 重复检查
ex := service.AllService.AddressBookService.RulePersonalInfoByToIdAndCid(t.ToId, t.CollectionId) ex := service.AllService.AddressBookService.RuleInfoByToIdAndCid(t.Type, t.ToId, t.CollectionId)
if t.Id == 0 && ex.Id > 0 { if t.Id == 0 && ex.Id > 0 {
return "ItemExists", false return "ItemExists", false
} }

View File

@@ -78,6 +78,7 @@ func (co *Config) AdminConfig(c *gin.Context) {
} }
hello := global.Config.Admin.Hello hello := global.Config.Admin.Hello
if hello == "" {
helloFile := global.Config.Admin.HelloFile helloFile := global.Config.Admin.HelloFile
if helloFile != "" { if helloFile != "" {
b, err := os.ReadFile(helloFile) b, err := os.ReadFile(helloFile)
@@ -85,6 +86,7 @@ func (co *Config) AdminConfig(c *gin.Context) {
hello = string(b) hello = string(b)
} }
} }
}
//replace {{username}} to username //replace {{username}} to username
hello = strings.Replace(hello, "{{username}}", u.Username, -1) hello = strings.Replace(hello, "{{username}}", u.Username, -1)

View File

@@ -11,135 +11,11 @@ import (
adResp "github.com/lejianwen/rustdesk-api/v2/http/response/admin" adResp "github.com/lejianwen/rustdesk-api/v2/http/response/admin"
"github.com/lejianwen/rustdesk-api/v2/model" "github.com/lejianwen/rustdesk-api/v2/model"
"github.com/lejianwen/rustdesk-api/v2/service" "github.com/lejianwen/rustdesk-api/v2/service"
"github.com/mojocn/base64Captcha"
"sync"
"time"
) )
type Login struct { type Login struct {
} }
// Captcha 验证码结构
type Captcha struct {
Id string `json:"id"` // 验证码 ID
B64 string `json:"b64"` // base64 验证码
Code string `json:"-"` // 验证码内容
ExpiresAt time.Time `json:"-"` // 过期时间
}
type LoginLimiter struct {
mu sync.RWMutex
failCount map[string]int // 记录每个 IP 的失败次数
timestamp map[string]time.Time // 记录每个 IP 的最后失败时间
captchas map[string]Captcha // 每个 IP 的验证码
threshold int // 失败阈值
expiry time.Duration // 失败记录过期时间
}
func NewLoginLimiter(threshold int, expiry time.Duration) *LoginLimiter {
return &LoginLimiter{
failCount: make(map[string]int),
timestamp: make(map[string]time.Time),
captchas: make(map[string]Captcha),
threshold: threshold,
expiry: expiry,
}
}
// RecordFailure 记录登录失败
func (l *LoginLimiter) RecordFailure(ip string) {
l.mu.Lock()
defer l.mu.Unlock()
// 如果该 IP 的记录已经过期,重置计数
if lastTime, exists := l.timestamp[ip]; exists && time.Since(lastTime) > l.expiry {
l.failCount[ip] = 0
}
// 更新失败次数和时间戳
l.failCount[ip]++
l.timestamp[ip] = time.Now()
}
// NeedsCaptcha 检查是否需要验证码
func (l *LoginLimiter) NeedsCaptcha(ip string) bool {
l.mu.RLock()
defer l.mu.RUnlock()
// 检查记录是否存在且未过期
if lastTime, exists := l.timestamp[ip]; exists && time.Since(lastTime) <= l.expiry {
return l.failCount[ip] >= l.threshold
}
return false
}
// GenerateCaptcha 为指定 IP 生成验证码
func (l *LoginLimiter) GenerateCaptcha(ip string) Captcha {
l.mu.Lock()
defer l.mu.Unlock()
capd := base64Captcha.NewDriverString(50, 150, 5, 10, 4, "1234567890abcdefghijklmnopqrstuvwxyz", nil, nil, nil)
b64cap := base64Captcha.NewCaptcha(capd, base64Captcha.DefaultMemStore)
id, b64s, answer, err := b64cap.Generate()
if err != nil {
global.Logger.Error("Generate captcha failed: " + err.Error())
return Captcha{}
}
// 保存验证码到对应 IP
l.captchas[ip] = Captcha{
Id: id,
B64: b64s,
Code: answer,
ExpiresAt: time.Now().Add(5 * time.Minute),
}
return l.captchas[ip]
}
// VerifyCaptcha 验证指定 IP 的验证码
func (l *LoginLimiter) VerifyCaptcha(ip, code string) bool {
l.mu.RLock()
defer l.mu.RUnlock()
// 检查验证码是否存在且未过期
if captcha, exists := l.captchas[ip]; exists && time.Now().Before(captcha.ExpiresAt) {
return captcha.Code == code
}
return false
}
// RemoveCaptcha 移除指定 IP 的验证码
func (l *LoginLimiter) RemoveCaptcha(ip string) {
l.mu.Lock()
defer l.mu.Unlock()
delete(l.captchas, ip)
}
// CleanupExpired 清理过期的记录
func (l *LoginLimiter) CleanupExpired() {
l.mu.Lock()
defer l.mu.Unlock()
now := time.Now()
for ip, lastTime := range l.timestamp {
if now.Sub(lastTime) > l.expiry {
delete(l.failCount, ip)
delete(l.timestamp, ip)
delete(l.captchas, ip)
}
}
}
func (l *LoginLimiter) RemoveRecord(ip string) {
l.mu.Lock()
defer l.mu.Unlock()
delete(l.failCount, ip)
delete(l.timestamp, ip)
delete(l.captchas, ip)
}
var loginLimiter = NewLoginLimiter(3, 5*time.Minute)
// Login 登录 // Login 登录
// @Tags 登录 // @Tags 登录
// @Summary 登录 // @Summary 登录
@@ -156,10 +32,16 @@ func (ct *Login) Login(c *gin.Context) {
response.Fail(c, 101, response.TranslateMsg(c, "PwdLoginDisabled")) response.Fail(c, 101, response.TranslateMsg(c, "PwdLoginDisabled"))
return return
} }
// 检查登录限制
loginLimiter := global.LoginLimiter
clientIp := c.ClientIP()
_, needCaptcha := loginLimiter.CheckSecurityStatus(clientIp)
f := &admin.Login{} f := &admin.Login{}
err := c.ShouldBindJSON(f) err := c.ShouldBindJSON(f)
clientIp := c.ClientIP()
if err != nil { if err != nil {
loginLimiter.RecordFailedAttempt(clientIp)
global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), clientIp)) global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), clientIp))
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")+err.Error()) response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")+err.Error())
return return
@@ -167,14 +49,15 @@ func (ct *Login) Login(c *gin.Context) {
errList := global.Validator.ValidStruct(c, f) errList := global.Validator.ValidStruct(c, f)
if len(errList) > 0 { if len(errList) > 0 {
loginLimiter.RecordFailedAttempt(clientIp)
global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), clientIp)) global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), clientIp))
response.Fail(c, 101, errList[0]) response.Fail(c, 101, errList[0])
return return
} }
// 检查是否需要验证码 // 检查是否需要验证码
if loginLimiter.NeedsCaptcha(clientIp) { if needCaptcha {
if f.Captcha == "" || !loginLimiter.VerifyCaptcha(clientIp, f.Captcha) { if f.CaptchaId == "" || f.Captcha == "" || !loginLimiter.VerifyCaptcha(f.CaptchaId, f.Captcha) {
response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError")) response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError"))
return return
} }
@@ -184,17 +67,19 @@ func (ct *Login) Login(c *gin.Context) {
if u.Id == 0 { if u.Id == 0 {
global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "UsernameOrPasswordError", c.RemoteIP(), clientIp)) global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "UsernameOrPasswordError", c.RemoteIP(), clientIp))
loginLimiter.RecordFailure(clientIp) loginLimiter.RecordFailedAttempt(clientIp)
if loginLimiter.NeedsCaptcha(clientIp) { if _, needCaptcha = loginLimiter.CheckSecurityStatus(clientIp); needCaptcha {
loginLimiter.RemoveCaptcha(clientIp) response.Fail(c, 110, response.TranslateMsg(c, "UsernameOrPasswordError"))
} } else {
response.Fail(c, 101, response.TranslateMsg(c, "UsernameOrPasswordError")) response.Fail(c, 101, response.TranslateMsg(c, "UsernameOrPasswordError"))
}
return return
} }
if !service.AllService.UserService.CheckUserEnable(u) { if !service.AllService.UserService.CheckUserEnable(u) {
if loginLimiter.NeedsCaptcha(clientIp) { if needCaptcha {
loginLimiter.RemoveCaptcha(clientIp) response.Fail(c, 110, response.TranslateMsg(c, "UserDisabled"))
return
} }
response.Fail(c, 101, response.TranslateMsg(c, "UserDisabled")) response.Fail(c, 101, response.TranslateMsg(c, "UserDisabled"))
return return
@@ -209,23 +94,37 @@ func (ct *Login) Login(c *gin.Context) {
Platform: f.Platform, Platform: f.Platform,
}) })
// 成功清除记录 // 登录成功清除登录限制
loginLimiter.RemoveRecord(clientIp) loginLimiter.RemoveAttempts(clientIp)
// 清理过期记录
go loginLimiter.CleanupExpired()
responseLoginSuccess(c, u, ut.Token) responseLoginSuccess(c, u, ut.Token)
} }
func (ct *Login) Captcha(c *gin.Context) { func (ct *Login) Captcha(c *gin.Context) {
loginLimiter := global.LoginLimiter
clientIp := c.ClientIP() clientIp := c.ClientIP()
if !loginLimiter.NeedsCaptcha(clientIp) { banned, needCaptcha := loginLimiter.CheckSecurityStatus(clientIp)
if banned {
response.Fail(c, 101, response.TranslateMsg(c, "LoginBanned"))
return
}
if !needCaptcha {
response.Fail(c, 101, response.TranslateMsg(c, "NoCaptchaRequired")) response.Fail(c, 101, response.TranslateMsg(c, "NoCaptchaRequired"))
return return
} }
captcha := loginLimiter.GenerateCaptcha(clientIp) err, captcha := loginLimiter.RequireCaptcha()
if err != nil {
response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError")+err.Error())
return
}
err, b64 := loginLimiter.DrawCaptcha(captcha.Content)
if err != nil {
response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError")+err.Error())
return
}
response.Success(c, gin.H{ response.Success(c, gin.H{
"captcha": captcha, "captcha": gin.H{
"id": captcha.Id,
"b64": b64,
},
}) })
} }
@@ -257,12 +156,18 @@ func (ct *Login) Logout(c *gin.Context) {
// @Failure 500 {object} response.ErrorResponse // @Failure 500 {object} response.ErrorResponse
// @Router /admin/login-options [post] // @Router /admin/login-options [post]
func (ct *Login) LoginOptions(c *gin.Context) { func (ct *Login) LoginOptions(c *gin.Context) {
ip := c.ClientIP() loginLimiter := global.LoginLimiter
clientIp := c.ClientIP()
banned, needCaptcha := loginLimiter.CheckSecurityStatus(clientIp)
if banned {
response.Fail(c, 101, response.TranslateMsg(c, "LoginBanned"))
return
}
ops := service.AllService.OauthService.GetOauthProviders() ops := service.AllService.OauthService.GetOauthProviders()
response.Success(c, gin.H{ response.Success(c, gin.H{
"ops": ops, "ops": ops,
"register": global.Config.App.Register, "register": global.Config.App.Register,
"need_captcha": loginLimiter.NeedsCaptcha(ip), "need_captcha": needCaptcha,
}) })
} }

View File

@@ -100,21 +100,21 @@ func (abcr *AddressBookCollectionRule) CheckForm(u *model.User, t *model.Address
//check to_id //check to_id
if t.Type == model.ShareAddressBookRuleTypePersonal { if t.Type == model.ShareAddressBookRuleTypePersonal {
if t.ToId == t.UserId { if t.ToId == t.UserId {
return "ParamsError", false return "CannotShareToSelf", false
} }
tou := service.AllService.UserService.InfoById(t.ToId) tou := service.AllService.UserService.InfoById(t.ToId)
if tou.Id == 0 { if tou.Id == 0 {
return "ItemNotFound", false return "ItemNotFound", false
} }
//非管理员不能分享给非本组织用户 //非管理员不能分享给非本组织用户
if tou.GroupId != u.GroupId { //if tou.GroupId != u.GroupId {
return "NoAccess", false // return "NoAccess", false
} //}
} else if t.Type == model.ShareAddressBookRuleTypeGroup { } else if t.Type == model.ShareAddressBookRuleTypeGroup {
//非管理员不能分享给其他组 //非管理员不能分享给其他组
if t.ToId != u.GroupId { //if t.ToId != u.GroupId {
return "NoAccess", false // return "NoAccess", false
} //}
tog := service.AllService.GroupService.InfoById(t.ToId) tog := service.AllService.GroupService.InfoById(t.ToId)
if tog.Id == 0 { if tog.Id == 0 {
@@ -124,7 +124,7 @@ func (abcr *AddressBookCollectionRule) CheckForm(u *model.User, t *model.Address
return "ParamsError", false return "ParamsError", false
} }
// 重复检查 // 重复检查
ex := service.AllService.AddressBookService.RulePersonalInfoByToIdAndCid(t.ToId, t.CollectionId) ex := service.AllService.AddressBookService.RuleInfoByToIdAndCid(t.Type, t.ToId, t.CollectionId)
if t.Id == 0 && ex.Id > 0 { if t.Id == 0 && ex.Id > 0 {
return "ItemExists", false return "ItemExists", false
} }

View File

@@ -119,7 +119,16 @@ func (r *Rustdesk) SendCmd(c *gin.Context) {
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")) response.Fail(c, 101, response.TranslateMsg(c, "ParamsError"))
return return
} }
res, err := service.AllService.ServerCmdService.SendCmd(rc.Target, rc.Cmd, rc.Option)
port := 0
switch rc.Target {
case model.ServerCmdTargetIdServer:
port = global.Config.Admin.IdServerPort - 1
case model.ServerCmdTargetRelayServer:
port = global.Config.Admin.RelayServerPort
}
res, err := service.AllService.ServerCmdService.SendCmd(port, rc.Cmd, rc.Option)
if err != nil { if err != nil {
response.Fail(c, 101, err.Error()) response.Fail(c, 101, err.Error())
return return

View File

@@ -296,32 +296,12 @@ func (ct *User) MyOauth(c *gin.Context) {
// groupUsers // groupUsers
func (ct *User) GroupUsers(c *gin.Context) { func (ct *User) GroupUsers(c *gin.Context) {
q := &admin.GroupUsersQuery{} aG := service.AllService.GroupService.List(1, 999, nil)
if err := c.ShouldBindJSON(q); err != nil { aU := service.AllService.UserService.List(1, 9999, nil)
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")+err.Error()) response.Success(c, gin.H{
return "groups": aG.Groups,
} "users": aU.Users,
u := service.AllService.UserService.CurUser(c)
gid := u.GroupId
uid := u.Id
if service.AllService.UserService.IsAdmin(u) && q.UserId > 0 {
nu := service.AllService.UserService.InfoById(q.UserId)
gid = nu.GroupId
uid = q.UserId
}
res := service.AllService.UserService.List(1, 999, func(tx *gorm.DB) {
tx.Where("group_id = ?", gid)
}) })
var data []*adResp.GroupUsersPayload
for _, _u := range res.Users {
gup := &adResp.GroupUsersPayload{}
gup.FromUser(_u)
if _u.Id == uid {
gup.Status = 0
}
data = append(data, gup)
}
response.Success(c, data)
} }
// Register // Register
@@ -340,11 +320,22 @@ func (ct *User) Register(c *gin.Context) {
response.Fail(c, 101, errList[0]) response.Fail(c, 101, errList[0])
return return
} }
u := service.AllService.UserService.Register(f.Username, f.Email, f.Password) regStatus := model.StatusCode(global.Config.App.RegisterStatus)
// 注册状态可能未配置,默认启用
if regStatus != model.COMMON_STATUS_DISABLED && regStatus != model.COMMON_STATUS_ENABLE {
regStatus = model.COMMON_STATUS_ENABLE
}
u := service.AllService.UserService.Register(f.Username, f.Email, f.Password, regStatus)
if u == nil || u.Id == 0 { if u == nil || u.Id == 0 {
response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")) response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed"))
return return
} }
if regStatus == model.COMMON_STATUS_DISABLED {
// 需要管理员审核
response.Fail(c, 101, response.TranslateMsg(c, "RegisterSuccessWaitAdminConfirm"))
return
}
// 注册成功后自动登录 // 注册成功后自动登录
ut := service.AllService.UserService.Login(u, &model.LoginLog{ ut := service.AllService.UserService.Login(u, &model.LoginLog{
UserId: u.Id, UserId: u.Id,

View File

@@ -31,10 +31,16 @@ func (l *Login) Login(c *gin.Context) {
response.Error(c, response.TranslateMsg(c, "PwdLoginDisabled")) response.Error(c, response.TranslateMsg(c, "PwdLoginDisabled"))
return return
} }
// 检查登录限制
loginLimiter := global.LoginLimiter
clientIp := c.ClientIP()
f := &api.LoginForm{} f := &api.LoginForm{}
err := c.ShouldBindJSON(f) err := c.ShouldBindJSON(f)
//fmt.Println(f) //fmt.Println(f)
if err != nil { if err != nil {
loginLimiter.RecordFailedAttempt(clientIp)
global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), c.ClientIP())) global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), c.ClientIP()))
response.Error(c, response.TranslateMsg(c, "ParamsError")+err.Error()) response.Error(c, response.TranslateMsg(c, "ParamsError")+err.Error())
return return
@@ -42,6 +48,7 @@ func (l *Login) Login(c *gin.Context) {
errList := global.Validator.ValidStruct(c, f) errList := global.Validator.ValidStruct(c, f)
if len(errList) > 0 { if len(errList) > 0 {
loginLimiter.RecordFailedAttempt(clientIp)
global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), c.ClientIP())) global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), c.ClientIP()))
response.Error(c, errList[0]) response.Error(c, errList[0])
return return
@@ -50,6 +57,7 @@ func (l *Login) Login(c *gin.Context) {
u := service.AllService.UserService.InfoByUsernamePassword(f.Username, f.Password) u := service.AllService.UserService.InfoByUsernamePassword(f.Username, f.Password)
if u.Id == 0 { if u.Id == 0 {
loginLimiter.RecordFailedAttempt(clientIp)
global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "UsernameOrPasswordError", c.RemoteIP(), c.ClientIP())) global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "UsernameOrPasswordError", c.RemoteIP(), c.ClientIP()))
response.Error(c, response.TranslateMsg(c, "UsernameOrPasswordError")) response.Error(c, response.TranslateMsg(c, "UsernameOrPasswordError"))
return return

View File

@@ -8,6 +8,8 @@ import (
apiResp "github.com/lejianwen/rustdesk-api/v2/http/response/api" apiResp "github.com/lejianwen/rustdesk-api/v2/http/response/api"
"github.com/lejianwen/rustdesk-api/v2/model" "github.com/lejianwen/rustdesk-api/v2/model"
"github.com/lejianwen/rustdesk-api/v2/service" "github.com/lejianwen/rustdesk-api/v2/service"
"github.com/lejianwen/rustdesk-api/v2/utils"
"github.com/nicksnyder/go-i18n/v2/i18n"
"net/http" "net/http"
) )
@@ -145,7 +147,8 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
state := c.Query("state") state := c.Query("state")
if state == "" { if state == "" {
c.HTML(http.StatusOK, "oauth_fail.html", gin.H{ c.HTML(http.StatusOK, "oauth_fail.html", gin.H{
"message": response.TranslateParamMsg(c, "ParamIsEmpty", "state"), "message": "ParamIsEmpty",
"sub_message": "state",
}) })
return return
} }
@@ -155,7 +158,7 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
oauthCache := oauthService.GetOauthCache(cacheKey) oauthCache := oauthService.GetOauthCache(cacheKey)
if oauthCache == nil { if oauthCache == nil {
c.HTML(http.StatusOK, "oauth_fail.html", gin.H{ c.HTML(http.StatusOK, "oauth_fail.html", gin.H{
"message": response.TranslateMsg(c, "OauthExpired"), "message": "OauthExpired",
}) })
return return
} }
@@ -169,7 +172,8 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
err, oauthUser := oauthService.Callback(code, verifier, op, nonce) err, oauthUser := oauthService.Callback(code, verifier, op, nonce)
if err != nil { if err != nil {
c.HTML(http.StatusOK, "oauth_fail.html", gin.H{ c.HTML(http.StatusOK, "oauth_fail.html", gin.H{
"message": response.TranslateMsg(c, "OauthFailed") + response.TranslateMsg(c, err.Error()), "message": "OauthFailed",
"sub_message": err.Error(),
}) })
return return
} }
@@ -182,7 +186,7 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
utr := oauthService.UserThirdInfo(op, openid) utr := oauthService.UserThirdInfo(op, openid)
if utr.UserId > 0 { if utr.UserId > 0 {
c.HTML(http.StatusOK, "oauth_fail.html", gin.H{ c.HTML(http.StatusOK, "oauth_fail.html", gin.H{
"message": response.TranslateMsg(c, "OauthHasBindOtherUser"), "message": "OauthHasBindOtherUser",
}) })
return return
} }
@@ -190,7 +194,7 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
user = service.AllService.UserService.InfoById(userId) user = service.AllService.UserService.InfoById(userId)
if user == nil { if user == nil {
c.HTML(http.StatusOK, "oauth_fail.html", gin.H{ c.HTML(http.StatusOK, "oauth_fail.html", gin.H{
"message": response.TranslateMsg(c, "ItemNotFound"), "message": "ItemNotFound",
}) })
return return
} }
@@ -198,12 +202,12 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
err := oauthService.BindOauthUser(userId, oauthUser, op) err := oauthService.BindOauthUser(userId, oauthUser, op)
if err != nil { if err != nil {
c.HTML(http.StatusOK, "oauth_fail.html", gin.H{ c.HTML(http.StatusOK, "oauth_fail.html", gin.H{
"message": response.TranslateMsg(c, "BindFail"), "message": "BindFail",
}) })
return return
} }
c.HTML(http.StatusOK, "oauth_success.html", gin.H{ c.HTML(http.StatusOK, "oauth_success.html", gin.H{
"message": response.TranslateMsg(c, "BindSuccess"), "message": "BindSuccess",
}) })
return return
@@ -211,7 +215,7 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
//登录 //登录
if userId != 0 { if userId != 0 {
c.HTML(http.StatusOK, "oauth_fail.html", gin.H{ c.HTML(http.StatusOK, "oauth_fail.html", gin.H{
"message": response.TranslateMsg(c, "OauthHasBeenSuccess"), "message": "OauthHasBeenSuccess",
}) })
return return
} }
@@ -230,7 +234,7 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
err, user = service.AllService.UserService.RegisterByOauth(oauthUser, op) err, user = service.AllService.UserService.RegisterByOauth(oauthUser, op)
if err != nil { if err != nil {
c.HTML(http.StatusOK, "oauth_fail.html", gin.H{ c.HTML(http.StatusOK, "oauth_fail.html", gin.H{
"message": response.TranslateMsg(c, err.Error()), "message": err.Error(),
}) })
return return
} }
@@ -252,14 +256,50 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
return return
} }
c.HTML(http.StatusOK, "oauth_success.html", gin.H{ c.HTML(http.StatusOK, "oauth_success.html", gin.H{
"message": response.TranslateMsg(c, "OauthSuccess"), "message": "OauthSuccess",
}) })
return return
} else { } else {
c.HTML(http.StatusOK, "oauth_fail.html", gin.H{ c.HTML(http.StatusOK, "oauth_fail.html", gin.H{
"message": response.TranslateMsg(c, "ParamsError"), "message": "ParamsError",
}) })
return return
} }
} }
type MessageParams struct {
Lang string `json:"lang" form:"lang"`
Title string `json:"title" form:"title"`
Msg string `json:"msg" form:"msg"`
}
func (o *Oauth) Message(c *gin.Context) {
mp := &MessageParams{}
if err := c.ShouldBindQuery(mp); err != nil {
return
}
localizer := global.Localizer(mp.Lang)
res := ""
if mp.Title != "" {
title, err := localizer.LocalizeMessage(&i18n.Message{
ID: mp.Title,
})
if err == nil {
res = utils.StringConcat(";title='", title, "';")
}
}
if mp.Msg != "" {
msg, err := localizer.LocalizeMessage(&i18n.Message{
ID: mp.Msg,
})
if err == nil {
res = utils.StringConcat(res, "msg = '", msg, "';")
}
}
//返回js内容
c.Header("Content-Type", "application/javascript")
c.String(http.StatusOK, res)
}

View File

@@ -1,6 +1,7 @@
package api package api
import ( import (
"fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding" "github.com/gin-gonic/gin/binding"
requstform "github.com/lejianwen/rustdesk-api/v2/http/request/api" requstform "github.com/lejianwen/rustdesk-api/v2/http/request/api"
@@ -13,7 +14,7 @@ type Peer struct {
} }
// SysInfo // SysInfo
// @Tags 地址 // @Tags System
// @Summary 提交系统信息 // @Summary 提交系统信息
// @Description 提交系统信息 // @Description 提交系统信息
// @Accept json // @Accept json
@@ -57,8 +58,19 @@ func (p *Peer) SysInfo(c *gin.Context) {
c.String(http.StatusOK, "SYSINFO_UPDATED") c.String(http.StatusOK, "SYSINFO_UPDATED")
} }
// SysInfoVer
// @Tags System
// @Summary 获取系统版本信息
// @Description 获取系统版本信息
// @Accept json
// @Produce json
// @Success 200 {string} string ""
// @Failure 500 {object} response.ErrorResponse
// @Router /sysinfo_ver [post]
func (p *Peer) SysInfoVer(c *gin.Context) { func (p *Peer) SysInfoVer(c *gin.Context) {
//读取resources/version文件 //读取resources/version文件
v := service.AllService.AppService.GetAppVersion() v := service.AllService.AppService.GetAppVersion()
// 加上启动时间方便client上传信息
v = fmt.Sprintf("%s\n%s", v, service.AllService.AppService.GetStartTime())
c.String(http.StatusOK, v) c.String(http.StatusOK, v)
} }

View File

@@ -33,7 +33,7 @@ func ApiInit() {
g.NoRoute(func(c *gin.Context) { g.NoRoute(func(c *gin.Context) {
c.String(http.StatusNotFound, "404 not found") c.String(http.StatusNotFound, "404 not found")
}) })
g.Use(middleware.Logger(), gin.Recovery()) g.Use(middleware.Logger(), middleware.Limiter(), gin.Recovery())
router.WebInit(g) router.WebInit(g)
router.Init(g) router.Init(g)
router.ApiInit(g) router.ApiInit(g)

View File

@@ -0,0 +1,22 @@
package middleware
import (
"github.com/gin-gonic/gin"
"github.com/lejianwen/rustdesk-api/v2/global"
"github.com/lejianwen/rustdesk-api/v2/http/response"
"net/http"
)
func Limiter() gin.HandlerFunc {
return func(c *gin.Context) {
loginLimiter := global.LoginLimiter
clientIp := c.ClientIP()
banned, _ := loginLimiter.CheckSecurityStatus(clientIp)
if banned {
response.Fail(c, http.StatusLocked, response.TranslateMsg(c, "Banned"))
c.Abort()
return
}
c.Next()
}
}

View File

@@ -5,6 +5,7 @@ type Login struct {
Password string `json:"password,omitempty" validate:"required" label:"密码"` Password string `json:"password,omitempty" validate:"required" label:"密码"`
Platform string `json:"platform" label:"平台"` Platform string `json:"platform" label:"平台"`
Captcha string `json:"captcha,omitempty" label:"验证码"` Captcha string `json:"captcha,omitempty" label:"验证码"`
CaptchaId string `json:"captcha_id,omitempty"`
} }
type LoginLogQuery struct { type LoginLogQuery struct {

View File

@@ -40,14 +40,14 @@ type LoginForm struct {
type UserListQuery struct { type UserListQuery struct {
Page uint `json:"page" form:"page" validate:"required" label:"页码"` Page uint `json:"page" form:"page" validate:"required" label:"页码"`
PageSize uint `json:"page_size" form:"page_size" validate:"required" label:"每页数量"` PageSize uint `json:"pageSize" form:"pageSize" validate:"required" label:"每页数量"`
Status int `json:"status" form:"status" label:"状态"` Status int `json:"status" form:"status" label:"状态"`
Accessible string `json:"accessible" form:"accessible"` Accessible string `json:"accessible" form:"accessible"`
} }
type PeerListQuery struct { type PeerListQuery struct {
Page uint `json:"page" form:"page" validate:"required" label:"页码"` Page uint `json:"page" form:"page" validate:"required" label:"页码"`
PageSize uint `json:"page_size" form:"page_size" validate:"required" label:"每页数量"` PageSize uint `json:"pageSize" form:"pageSize" validate:"required" label:"每页数量"`
Status int `json:"status" form:"status" label:"状态"` Status int `json:"status" form:"status" label:"状态"`
Accessible string `json:"accessible" form:"accessible"` Accessible string `json:"accessible" form:"accessible"`
} }

View File

@@ -22,15 +22,3 @@ type UserOauthItem struct {
Op string `json:"op"` Op string `json:"op"`
Status int `json:"status"` Status int `json:"status"`
} }
type GroupUsersPayload struct {
Id uint `json:"id"`
Username string `json:"username"`
Status int `json:"status"`
}
func (g *GroupUsersPayload) FromUser(user *model.User) {
g.Id = user.Id
g.Username = user.Username
g.Status = 1
}

View File

@@ -48,6 +48,7 @@ func ApiInit(g *gin.Engine) {
//api/oauth/callback //api/oauth/callback
frg.GET("/oauth/callback", o.OauthCallback) frg.GET("/oauth/callback", o.OauthCallback)
frg.GET("/oauth/login", o.OauthCallback) frg.GET("/oauth/login", o.OauthCallback)
frg.GET("/oauth/msg", o.Message)
} }
{ {
pe := &api.Peer{} pe := &api.Peer{}

View File

@@ -2,7 +2,6 @@ package orm
import ( import (
"fmt" "fmt"
"github.com/lejianwen/rustdesk-api/v2/global"
"gorm.io/driver/mysql" "gorm.io/driver/mysql"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
@@ -10,14 +9,14 @@ import (
) )
type MysqlConfig struct { type MysqlConfig struct {
Dns string Dsn string
MaxIdleConns int MaxIdleConns int
MaxOpenConns int MaxOpenConns int
} }
func NewMysql(mysqlConf *MysqlConfig) *gorm.DB { func NewMysql(mysqlConf *MysqlConfig, logwriter logger.Writer) *gorm.DB {
db, err := gorm.Open(mysql.New(mysql.Config{ db, err := gorm.Open(mysql.New(mysql.Config{
DSN: mysqlConf.Dns, // DSN data source name DSN: mysqlConf.Dsn, // DSN data source name
DefaultStringSize: 256, // string 类型字段的默认长度 DefaultStringSize: 256, // string 类型字段的默认长度
//DisableDatetimePrecision: true, // 禁用 datetime 精度MySQL 5.6 之前的数据库不支持 //DisableDatetimePrecision: true, // 禁用 datetime 精度MySQL 5.6 之前的数据库不支持
//DontSupportRenameIndex: true, // 重命名索引时采用删除并新建的方式MySQL 5.7 之前的数据库和 MariaDB 不支持重命名索引 //DontSupportRenameIndex: true, // 重命名索引时采用删除并新建的方式MySQL 5.7 之前的数据库和 MariaDB 不支持重命名索引
@@ -26,7 +25,7 @@ func NewMysql(mysqlConf *MysqlConfig) *gorm.DB {
}), &gorm.Config{ }), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true, DisableForeignKeyConstraintWhenMigrating: true,
Logger: logger.New( Logger: logger.New(
global.Logger, // io writer logwriter, // io writer
logger.Config{ logger.Config{
SlowThreshold: time.Second, // Slow SQL threshold SlowThreshold: time.Second, // Slow SQL threshold
LogLevel: logger.Warn, // Log level LogLevel: logger.Warn, // Log level

45
lib/orm/postgresql.go Normal file
View File

@@ -0,0 +1,45 @@
package orm
import (
"fmt"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"time"
)
type PostgresqlConfig struct {
Dsn string
MaxIdleConns int
MaxOpenConns int
}
func NewPostgresql(conf *PostgresqlConfig, logwriter logger.Writer) *gorm.DB {
db, err := gorm.Open(postgres.Open(conf.Dsn), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
Logger: logger.New(
logwriter, // io writer
logger.Config{
SlowThreshold: time.Second, // Slow SQL threshold
LogLevel: logger.Warn, // Log level
//IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger
ParameterizedQueries: true, // Don't include params in the SQL log
Colorful: true,
},
),
})
if err != nil {
fmt.Println(err)
}
sqlDB, err2 := db.DB()
if err2 != nil {
fmt.Println(err2)
}
// SetMaxIdleConns 设置空闲连接池中连接的最大数量
sqlDB.SetMaxIdleConns(conf.MaxIdleConns)
// SetMaxOpenConns 设置打开数据库连接的最大数量。
sqlDB.SetMaxOpenConns(conf.MaxOpenConns)
return db
}

View File

@@ -2,7 +2,6 @@ package orm
import ( import (
"fmt" "fmt"
"github.com/lejianwen/rustdesk-api/v2/global"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
@@ -14,11 +13,11 @@ type SqliteConfig struct {
MaxOpenConns int MaxOpenConns int
} }
func NewSqlite(sqliteConf *SqliteConfig) *gorm.DB { func NewSqlite(sqliteConf *SqliteConfig, logwriter logger.Writer) *gorm.DB {
db, err := gorm.Open(sqlite.Open("./data/rustdeskapi.db"), &gorm.Config{ db, err := gorm.Open(sqlite.Open("./data/rustdeskapi.db"), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true, DisableForeignKeyConstraintWhenMigrating: true,
Logger: logger.New( Logger: logger.New(
global.Logger, // io writer logwriter, // io writer
logger.Config{ logger.Config{
SlowThreshold: time.Second, // Slow SQL threshold SlowThreshold: time.Second, // Slow SQL threshold
LogLevel: logger.Warn, // Log level LogLevel: logger.Warn, // Log level

View File

@@ -14,6 +14,7 @@ const (
OauthTypeGoogle string = "google" OauthTypeGoogle string = "google"
OauthTypeOidc string = "oidc" OauthTypeOidc string = "oidc"
OauthTypeWebauth string = "webauth" OauthTypeWebauth string = "webauth"
OauthTypeLinuxdo string = "linuxdo"
PKCEMethodS256 string = "S256" PKCEMethodS256 string = "S256"
PKCEMethodPlain string = "plain" PKCEMethodPlain string = "plain"
) )
@@ -21,7 +22,7 @@ const (
// Validate the oauth type // Validate the oauth type
func ValidateOauthType(oauthType string) error { func ValidateOauthType(oauthType string) error {
switch oauthType { switch oauthType {
case OauthTypeGithub, OauthTypeGoogle, OauthTypeOidc, OauthTypeWebauth: case OauthTypeGithub, OauthTypeGoogle, OauthTypeOidc, OauthTypeWebauth, OauthTypeLinuxdo:
return nil return nil
default: default:
return errors.New("invalid Oauth type") return errors.New("invalid Oauth type")
@@ -30,6 +31,7 @@ func ValidateOauthType(oauthType string) error {
const ( const (
UserEndpointGithub string = "https://api.github.com/user" UserEndpointGithub string = "https://api.github.com/user"
UserEndpointLinuxdo string = "https://connect.linux.do/api/user"
IssuerGoogle string = "https://accounts.google.com" IssuerGoogle string = "https://accounts.google.com"
) )
@@ -60,6 +62,8 @@ func (oa *Oauth) FormatOauthInfo() error {
oa.Op = OauthTypeGithub oa.Op = OauthTypeGithub
case OauthTypeGoogle: case OauthTypeGoogle:
oa.Op = OauthTypeGoogle oa.Op = OauthTypeGoogle
case OauthTypeLinuxdo:
oa.Op = OauthTypeLinuxdo
} }
// check if the op is empty, set the default value // check if the op is empty, set the default value
op := strings.TrimSpace(oa.Op) op := strings.TrimSpace(oa.Op)
@@ -152,6 +156,24 @@ func (gu *GithubUser) ToOauthUser() *OauthUser {
} }
} }
type LinuxdoUser struct {
OauthUserBase
Id int `json:"id"`
Username string `json:"username"`
Avatar string `json:"avatar_url"`
}
func (lu *LinuxdoUser) ToOauthUser() *OauthUser {
return &OauthUser{
OpenId: strconv.Itoa(lu.Id),
Name: lu.Name,
Username: strings.ToLower(lu.Username),
Email: lu.Email,
VerifiedEmail: true, // linux.do 用户邮箱默认已验证
Picture: lu.Avatar,
}
}
type OauthList struct { type OauthList struct {
Oauths []*Oauth `json:"list"` Oauths []*Oauth `json:"list"`
Pagination Pagination

View File

@@ -138,3 +138,18 @@ other = "Captcha error."
description = "Password login disabled." description = "Password login disabled."
one = "Password login disabled." one = "Password login disabled."
other = "Password login disabled." other = "Password login disabled."
[CannotShareToSelf]
description = "Cannot share to self."
one = "Cannot share to self."
other = "Cannot share to self."
[Banned]
description = "Banned."
one = "Banned."
other = "Banned."
[RegisterSuccessWaitAdminConfirm]
description = "Register success, wait admin confirm."
one = "Register success, wait admin confirm."
other = "Register success, wait admin confirm."

View File

@@ -147,3 +147,18 @@ other = "Error de captcha."
description = "Password login disabled." description = "Password login disabled."
one = "Inicio de sesión con contraseña deshabilitado." one = "Inicio de sesión con contraseña deshabilitado."
other = "Inicio de sesión con contraseña deshabilitado." other = "Inicio de sesión con contraseña deshabilitado."
[CannotShareToSelf]
description = "Cannot share to self."
one = "No se puede compartir con uno mismo."
other = "No se puede compartir con uno mismo."
[Banned]
description = "Banned."
one = "Prohibido."
other = "Prohibido."
[RegisterSuccessWaitAdminConfirm]
description = "Register success, wait admin confirm."
one = "Registro exitoso, espere la confirmación del administrador."
other = "Registro exitoso, espere la confirmación del administrador."

View File

@@ -147,3 +147,18 @@ other = "Erreur de captcha."
description = "Password login disabled." description = "Password login disabled."
one = "Connexion par mot de passe désactivée." one = "Connexion par mot de passe désactivée."
other = "Connexion par mot de passe désactivée." other = "Connexion par mot de passe désactivée."
[CannotShareToSelf]
description = "Cannot share to self."
one = "Impossible de partager avec soi-même."
other = "Impossible de partager avec soi-même."
[Banned]
description = "Banned."
one = "Banni."
other = "Banni."
[RegisterSuccessWaitAdminConfirm]
description = "Register success wait admin confirm."
one = "Inscription réussie, veuillez attendre la confirmation de l'administrateur."
other = "Inscription réussie, veuillez attendre la confirmation de l'administrateur."

View File

@@ -141,3 +141,18 @@ other = "Captcha 오류."
description = "Password login disabled." description = "Password login disabled."
one = "비밀번호 로그인이 비활성화되었습니다." one = "비밀번호 로그인이 비활성화되었습니다."
other = "비밀번호 로그인이 비활성화되었습니다." other = "비밀번호 로그인이 비활성화되었습니다."
[CannotShareToSelf]
description = "Cannot share to self."
one = "자기 자신에게 공유할 수 없습니다."
other = "자기 자신에게 공유할 수 없습니다."
[Banned]
description = "Banned."
one = "금지됨."
other = "금지됨."
[RegisterSuccessWaitAdminConfirm]
description = "Register success wait admin confirm."
one = "가입 성공, 관리자 확인 대기 중."
other = "가입 성공, 관리자 확인 대기 중."

View File

@@ -147,3 +147,18 @@ other = "Ошибка капчи."
description = "Password login disabled." description = "Password login disabled."
one = "Вход по паролю отключен." one = "Вход по паролю отключен."
other = "Вход по паролю отключен." other = "Вход по паролю отключен."
[CannotShareToSelf]
description = "Cannot share to self."
one = "Нельзя поделиться с собой."
other = "Нельзя поделиться с собой."
[Banned]
description = "Banned."
one = "Заблокировано."
other = "Заблокировано."
[RegisterSuccessWaitAdminConfirm]
description = "Register success wait admin confirm."
one = "Регистрация прошла успешно, ожидайте подтверждения администратора."
other = "Регистрация прошла успешно, ожидайте подтверждения администратора."

View File

@@ -140,3 +140,18 @@ other = "验证码错误。"
description = "Password login disabled." description = "Password login disabled."
one = "密码登录已禁用。" one = "密码登录已禁用。"
other = "密码登录已禁用。" other = "密码登录已禁用。"
[CannotShareToSelf]
description = "Cannot share to self."
one = "不能共享给自己。"
other = "不能共享给自己。"
[Banned]
description = "Banned."
one = "已被封禁。"
other = "已被封禁。"
[RegisterSuccessWaitAdminConfirm]
description = "Register success, wait for admin confirm."
one = "注册成功,请等待管理员审核。"
other = "注册成功,请等待管理员审核。"

View File

@@ -140,3 +140,18 @@ other = "驗證碼錯誤。"
description = "Password login disabled." description = "Password login disabled."
one = "密碼登錄已禁用。" one = "密碼登錄已禁用。"
other = "密碼登錄已禁用。" other = "密碼登錄已禁用。"
[CannotShareToSelf]
description = "Cannot share to self."
one = "無法共享給自己。"
other = "無法共享給自己。"
[Banned]
description = "Banned."
one = "禁止使用。"
other = "禁止使用。"
[RegisterSuccessWaitAdminConfirm]
description = "Register success wait admin confirm."
one = "註冊成功,請等待管理員確認。"
other = "註冊成功,請等待管理員確認。"

View File

View File

@@ -1,9 +1,9 @@
<!DOCTYPE html> <!DOCTYPE html>
<html lang="zh-CN"> <html>
<head> <head>
<meta charset="UTF-8"> <meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>授权失败 - RustDesk API</title> <title>OauthFailed - RustDesk API</title>
<style> <style>
body { body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Arial, sans-serif; font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Arial, sans-serif;
@@ -57,17 +57,25 @@
} }
</style> </style>
<link rel="stylesheet" href="https://lf9-cdn-tos.bytecdntp.com/cdn/expire-1-M/font-awesome/6.0.0/css/all.min.css"> <link rel="stylesheet" href="https://lf9-cdn-tos.bytecdntp.com/cdn/expire-1-M/font-awesome/6.0.0/css/all.min.css">
<script>
var lang = navigator.language || navigator.userLanguage || 'zh-CN';
var title = 'OauthFailed'
var msg = '{{.message}}'
var btn = 'Close'
document.writeln('<script src="/api/oauth/msg?lang=' + lang + '&msg=' + msg + '&title=OauthFailed"><\/script>');
</script>
</head> </head>
<body> <body>
<div class="success-container"> <div class="success-container">
<i class="fas fa-triangle-exclamation checkmark"></i> <i class="fas fa-triangle-exclamation checkmark"></i>
<h1>授权失败!</h1> <h1 id="h1"></h1>
<p>{{.message}}</p> <p id="msg"></p>
<a href="javascript:window.close()" class="return-link">关闭页面</a> <a href="javascript:window.close()" class="return-link" id="btn">Close</a>
</div> </div>
<script> <script>
document.title = title + ' - RustDesk API';
document.getElementById('h1').innerText = title;
document.getElementById('msg').innerText = msg;
</script> </script>
</body> </body>
</html> </html>

View File

@@ -3,7 +3,7 @@
<head> <head>
<meta charset="UTF-8"> <meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>授权成功 - RustDesk API</title> <title>OauthSuccess - RustDesk API</title>
<style> <style>
body { body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Arial, sans-serif; font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Arial, sans-serif;
@@ -56,18 +56,27 @@
background-color: #45a049; background-color: #45a049;
} }
</style> </style>
<script>
var lang = navigator.language || navigator.userLanguage || 'zh-CN';
var title = 'OauthSuccess'
var msg = '{{.message}}'
var btn = 'Close'
document.writeln('<script src="/api/oauth/msg?lang=' + lang + '&msg=' + msg + '&title=OauthSuccess"><\/script>');
</script>
</head> </head>
<body> <body>
<div class="success-container"> <div class="success-container">
<i class="fas fa-check-circle checkmark"></i> <i class="fas fa-check-circle checkmark"></i>
<h1>授权成功!</h1> <h1 id="h1"></h1>
<p>您已成功授权访问您的账户。</p> <!-- <p>您已成功授权访问您的账户。</p>-->
<p>现在可以关闭本页面或返回应用继续操作。</p> <!-- <p>现在可以关闭本页面或返回应用继续操作。</p>-->
<a href="javascript:window.close()" class="return-link">关闭页面</a> <a href="javascript:window.close()" class="return-link">Close</a>
</div> </div>
<script> <script>
document.title = title + ' - RustDesk API';
document.getElementById('h1').innerText = title;
document.getElementById('msg').innerText = msg;
</script> </script>
</body> </body>
</html> </html>

View File

@@ -38,5 +38,21 @@
"asset": "assets/address_book.ttf" "asset": "assets/address_book.ttf"
} }
] ]
},
{
"family": "DeviceGroup",
"fonts": [
{
"asset": "assets/device_group.ttf"
}
]
},
{
"family": "More",
"fonts": [
{
"asset": "assets/more.ttf"
}
]
} }
] ]

Binary file not shown.

BIN
resources/web2/assets/assets/more.ttf vendored Normal file

Binary file not shown.

View File

@@ -1,6 +1,6 @@
<!DOCTYPE html> <!DOCTYPE html>
<html> <html>
<head> <head>
<!-- <!--
If you are serving your web app in a path other than the root, change the If you are serving your web app in a path other than the root, change the
href value below to reflect the base path you are serving from. href value below to reflect the base path you are serving from.
@@ -16,24 +16,24 @@
--> -->
<base href="/webclient2/" /> <base href="/webclient2/" />
<meta charset="UTF-8" /> <meta charset="UTF-8"/>
<meta content="IE=Edge" http-equiv="X-UA-Compatible" /> <meta content="IE=Edge" http-equiv="X-UA-Compatible"/>
<meta name="description" content="Remote Desktop." /> <meta name="description" content="Remote Desktop."/>
<!-- iOS meta tags & icons --> <!-- iOS meta tags & icons -->
<meta name="apple-mobile-web-app-capable" content="yes" /> <meta name="apple-mobile-web-app-capable" content="yes"/>
<meta name="apple-mobile-web-app-status-bar-style" content="black" /> <meta name="apple-mobile-web-app-status-bar-style" content="black"/>
<meta name="apple-mobile-web-app-title" content="RustDesk" /> <meta name="apple-mobile-web-app-title" content="RustDesk"/>
<link rel="apple-touch-icon" href="icons/Icon-192.png?v=1a7ad736" /> <link rel="apple-touch-icon" href="icons/Icon-192.png?v=1a7ad736"/>
<!-- Favicon --> <!-- Favicon -->
<link rel="icon" type="image/svg+xml" href="favicon.svg?v=8fcccd9a" /> <link rel="icon" type="image/svg+xml" href="favicon.svg?v=8fcccd9a"/>
<title>RustDesk</title> <title>RustDesk</title>
<script src="/webclient-config/index.js"></script> <script src="/webclient-config/index.js"></script>
<link rel="manifest" href="manifest.json" /> <link rel="manifest" href="manifest.json"/>
<script type="module" crossorigin src="js/dist/index.js?v=cabfd933"></script> <script type="module" crossorigin src="js/dist/index.js?v=ddbe54f1"></script>
<link rel="modulepreload" href="js/dist/vendor.js?v=0b990c6e" /> <link rel="modulepreload" href="js/dist/vendor.js?v=0b990c6e"/>
<style> <style>
html, html,
body, body,
@@ -42,6 +42,7 @@
margin: 0; margin: 0;
padding: 0; padding: 0;
} }
#root { #root {
background-repeat: no-repeat; background-repeat: no-repeat;
background-size: 100% auto; background-size: 100% auto;
@@ -63,6 +64,7 @@
justify-content: center; justify-content: center;
padding: 26px; padding: 26px;
} }
.ant-spin { .ant-spin {
position: absolute; position: absolute;
display: none; display: none;
@@ -78,8 +80,7 @@
text-align: center; text-align: center;
list-style: none; list-style: none;
opacity: 0; opacity: 0;
-webkit-transition: -webkit-transform 0.3s -webkit-transition: -webkit-transform 0.3s cubic-bezier(0.78, 0.14, 0.15, 0.86);
cubic-bezier(0.78, 0.14, 0.15, 0.86);
transition: -webkit-transform 0.3s cubic-bezier(0.78, 0.14, 0.15, 0.86); transition: -webkit-transform 0.3s cubic-bezier(0.78, 0.14, 0.15, 0.86);
transition: transform 0.3s cubic-bezier(0.78, 0.14, 0.15, 0.86); transition: transform 0.3s cubic-bezier(0.78, 0.14, 0.15, 0.86);
transition: transform 0.3s cubic-bezier(0.78, 0.14, 0.15, 0.86), transition: transform 0.3s cubic-bezier(0.78, 0.14, 0.15, 0.86),
@@ -198,10 +199,10 @@
} }
} }
</style> </style>
</head> </head>
<body> <body>
<div id="root"> <div id="root">
<div <div
id="div-background" id="div-background"
style=" style="
@@ -213,7 +214,7 @@
min-height: 420px; min-height: 420px;
" "
> >
<img src="./favicon.svg?v=8fcccd9a" alt="logo" width="256" /> <img src="./favicon.svg?v=8fcccd9a" alt="logo" width="256"/>
<div class="page-loading-warp"> <div class="page-loading-warp">
<div class="ant-spin ant-spin-lg ant-spin-spinning"> <div class="ant-spin ant-spin-lg ant-spin-spinning">
<span class="ant-spin-dot ant-spin-dot-spin"> <span class="ant-spin-dot ant-spin-dot-spin">
@@ -226,15 +227,15 @@
<div <div
style="display: flex; align-items: center; justify-content: center" style="display: flex; align-items: center; justify-content: center"
> >
<img src="./favicon.svg?v=8fcccd9a" width="32" style="margin-right: 8px" /> <img src="./favicon.svg?v=8fcccd9a" width="32" style="margin-right: 8px"/>
<span id="span-text">RustDesk Web Client V2 Preview</span> <span id="span-text">RustDesk Web Client V2 Preview</span>
</div> </div>
</div> </div>
</div> </div>
<!-- This script installs service_worker.js to provide PWA functionality to <!-- This script installs service_worker.js to provide PWA functionality to
application. For more information, see: application. For more information, see:
https://developers.google.com/web/fundamentals/primers/service-workers --> https://developers.google.com/web/fundamentals/primers/service-workers -->
<script> <script>
const systemTheme = window.matchMedia("(prefers-color-scheme: dark)") const systemTheme = window.matchMedia("(prefers-color-scheme: dark)")
.matches .matches
? "dark" ? "dark"
@@ -251,15 +252,16 @@
spanConsole.style.color = them === "dark" ? "#fff" : "#000"; spanConsole.style.color = them === "dark" ? "#fff" : "#000";
} }
const serviceWorkerVersion = "3267265270"; const serviceWorkerVersion = "461457302";
var scriptLoaded = false; var scriptLoaded = false;
function loadMainDartJs() { function loadMainDartJs() {
if (scriptLoaded) { if (scriptLoaded) {
return; return;
} }
scriptLoaded = true; scriptLoaded = true;
var scriptTag = document.createElement("script"); var scriptTag = document.createElement("script");
scriptTag.src = "main.dart.js?v=060a626e"; scriptTag.src = "main.dart.js?v=6d16cb80";
scriptTag.type = "application/javascript"; scriptTag.type = "application/javascript";
document.body.append(scriptTag); document.body.append(scriptTag);
} }
@@ -281,6 +283,7 @@
} }
}); });
} }
if (!reg.active && (reg.installing || reg.waiting)) { if (!reg.active && (reg.installing || reg.waiting)) {
// No active web worker and we have installed or are installing // No active web worker and we have installed or are installing
// one for the first time. Simply wait for it to activate. // one for the first time. Simply wait for it to activate.
@@ -313,13 +316,13 @@
// Service workers not supported. Just drop the <script> tag. // Service workers not supported. Just drop the <script> tag.
loadMainDartJs(); loadMainDartJs();
} }
</script> </script>
<script src="libs/stream/ponyfill.min.js"></script> <script src="libs/stream/ponyfill.min.js"></script>
<script src="libs/stream/StreamSaver.min.js"></script> <script src="libs/stream/StreamSaver.min.js"></script>
<script src="libs/firebase-app.js?8.10.1"></script> <script src="libs/firebase-app.js?8.10.1"></script>
<script src="libs/firebase-analytics.js?8.10.1"></script> <script src="libs/firebase-analytics.js?8.10.1"></script>
<script> <script>
// Your web app's Firebase configuration // Your web app's Firebase configuration
// For Firebase JS SDK v7.20.0 and later, measurementId is optional // For Firebase JS SDK v7.20.0 and later, measurementId is optional
const firebaseConfig = { const firebaseConfig = {
@@ -336,6 +339,6 @@
// Initialize Firebase // Initialize Firebase
firebase.initializeApp(firebaseConfig); firebase.initializeApp(firebaseConfig);
firebase.analytics(); firebase.analytics();
</script> </script>
</body> </body>
</html> </html>

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,11 @@
window._gwen = {} window._gwen = {}
window._gwen.kv = {} window._gwen.kv = {}
//fix 语言
if(!localStorage.getItem('wc-option:local:lang') && navigator.language){
localStorage.setItem('wc-option:local:lang', navigator.language.toLowerCase())
}
const storage_prefix = 'wc-' const storage_prefix = 'wc-'
const apiserver = localStorage.getItem('wc-api-server') const apiserver = localStorage.getItem('wc-api-server')
@@ -46,7 +52,7 @@ if (share_token) {
password: peer.tmppwd, password: peer.tmppwd,
}*/ }*/
//修改location //修改location
window.location.href = `/webclient2/#/${peer.info.id}?password=${peer.tmppwd}` window.location.href = `/webclient2/#/${peer.info.id}?password=${encodeURIComponent(peer.tmppwd)}`
} }
}) })
} }

163526
resources/web2/main.dart.js vendored

File diff suppressed because one or more lines are too long

View File

@@ -293,8 +293,11 @@ func (s *AddressBookService) RuleInfoById(u uint) *model.AddressBookCollectionRu
return p return p
} }
func (s *AddressBookService) RulePersonalInfoByToIdAndCid(toid, cid uint) *model.AddressBookCollectionRule { func (s *AddressBookService) RulePersonalInfoByToIdAndCid(toid, cid uint) *model.AddressBookCollectionRule {
return s.RuleInfoByToIdAndCid(model.ShareAddressBookRuleTypePersonal, toid, cid)
}
func (s *AddressBookService) RuleInfoByToIdAndCid(t int, toid, cid uint) *model.AddressBookCollectionRule {
p := &model.AddressBookCollectionRule{} p := &model.AddressBookCollectionRule{}
DB.Where("type = ? and to_id = ? and collection_id = ?", model.ShareAddressBookRuleTypePersonal, toid, cid).First(p) DB.Where("type = ? and to_id = ? and collection_id = ?", t, toid, cid).First(p)
return p return p
} }
func (s *AddressBookService) CreateRule(t *model.AddressBookCollectionRule) error { func (s *AddressBookService) CreateRule(t *model.AddressBookCollectionRule) error {

View File

@@ -3,13 +3,14 @@ package service
import ( import (
"os" "os"
"sync" "sync"
"time"
) )
type AppService struct { type AppService struct {
} }
var version = "" var version = ""
var startTime = ""
var once = &sync.Once{} var once = &sync.Once{}
func (a *AppService) GetAppVersion() string { func (a *AppService) GetAppVersion() string {
@@ -26,3 +27,13 @@ func (a *AppService) GetAppVersion() string {
}) })
return version return version
} }
func init() {
// Initialize the AppService if needed
startTime = time.Now().Format("2006-01-02 15:04:05")
}
// GetStartTime
func (a *AppService) GetStartTime() string {
return startTime
}

View File

@@ -411,7 +411,7 @@ func (ls *LdapService) isUserAdmin(cfg *config.Ldap, ldapUser *LdapUser) bool {
// Check "memberOf" directly // Check "memberOf" directly
if len(ldapUser.MemberOf) > 0 { if len(ldapUser.MemberOf) > 0 {
for _, group := range ldapUser.MemberOf { for _, group := range ldapUser.MemberOf {
if group == adminGroup { if strings.EqualFold(group, adminGroup) {
return true return true
} }
} }

View File

@@ -154,6 +154,18 @@ func (os *OauthService) GithubProvider() *oidc.Provider {
}).NewProvider(context.Background()) }).NewProvider(context.Background())
} }
func (os *OauthService) LinuxdoProvider() *oidc.Provider {
return (&oidc.ProviderConfig{
IssuerURL: "",
AuthURL: "https://connect.linux.do/oauth2/authorize",
TokenURL: "https://connect.linux.do/oauth2/token",
DeviceAuthURL: "",
UserInfoURL: model.UserEndpointLinuxdo,
JWKSURL: "",
Algorithms: nil,
}).NewProvider(context.Background())
}
// GetOauthConfig retrieves the OAuth2 configuration based on the provider name // GetOauthConfig retrieves the OAuth2 configuration based on the provider name
func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config, provider *oidc.Provider) { func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config, provider *oidc.Provider) {
//err, oauthInfo, oauthConfig = os.getOauthConfigGeneral(op) //err, oauthInfo, oauthConfig = os.getOauthConfigGeneral(op)
@@ -182,6 +194,10 @@ func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.O
oauthConfig.Endpoint = github.Endpoint oauthConfig.Endpoint = github.Endpoint
oauthConfig.Scopes = []string{"read:user", "user:email"} oauthConfig.Scopes = []string{"read:user", "user:email"}
provider = os.GithubProvider() provider = os.GithubProvider()
case model.OauthTypeLinuxdo:
provider = os.LinuxdoProvider()
oauthConfig.Endpoint = provider.Endpoint()
oauthConfig.Scopes = []string{"profile"}
//case model.OauthTypeGoogle: //google单独出来可以少一次FetchOidcEndpoint请求 //case model.OauthTypeGoogle: //google单独出来可以少一次FetchOidcEndpoint请求
// oauthConfig.Endpoint = google.Endpoint // oauthConfig.Endpoint = google.Endpoint
// oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes) // oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes)
@@ -299,6 +315,16 @@ func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, provider *oid
return nil, user.ToOauthUser() return nil, user.ToOauthUser()
} }
// linuxdoCallback linux.do回调
func (os *OauthService) linuxdoCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code, verifier, nonce string) (error, *model.OauthUser) {
var user = &model.LinuxdoUser{}
err, _ := os.callbackBase(oauthConfig, provider, code, verifier, nonce, user)
if err != nil {
return err, nil
}
return nil, user.ToOauthUser()
}
// oidcCallback oidc回调, 通过code获取用户信息 // oidcCallback oidc回调, 通过code获取用户信息
func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code, verifier, nonce string) (error, *model.OauthUser) { func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code, verifier, nonce string) (error, *model.OauthUser) {
var user = &model.OidcUser{} var user = &model.OidcUser{}
@@ -319,6 +345,8 @@ func (os *OauthService) Callback(code, verifier, op, nonce string) (err error, o
switch oauthType { switch oauthType {
case model.OauthTypeGithub: case model.OauthTypeGithub:
err, oauthUser = os.githubCallback(oauthConfig, provider, code, verifier, nonce) err, oauthUser = os.githubCallback(oauthConfig, provider, code, verifier, nonce)
case model.OauthTypeLinuxdo:
err, oauthUser = os.linuxdoCallback(oauthConfig, provider, code, verifier, nonce)
case model.OauthTypeOidc, model.OauthTypeGoogle: case model.OauthTypeOidc, model.OauthTypeGoogle:
err, oauthUser = os.oidcCallback(oauthConfig, provider, code, verifier, nonce) err, oauthUser = os.oidcCallback(oauthConfig, provider, code, verifier, nonce)
default: default:

View File

@@ -40,14 +40,7 @@ func (is *ServerCmdService) Create(u *model.ServerCmd) error {
} }
// SendCmd 发送命令 // SendCmd 发送命令
func (is *ServerCmdService) SendCmd(target string, cmd string, arg string) (string, error) { func (is *ServerCmdService) SendCmd(port int, cmd string, arg string) (string, error) {
port := 0
switch target {
case model.ServerCmdTargetIdServer:
port = Config.Rustdesk.IdServerPort - 1
case model.ServerCmdTargetRelayServer:
port = Config.Rustdesk.RelayServerPort
}
//组装命令 //组装命令
cmd = cmd + " " + arg cmd = cmd + " " + arg
res, err := is.SendSocketCmd("v6", port, cmd) res, err := is.SendSocketCmd("v6", port, cmd)

View File

@@ -412,12 +412,13 @@ func (us *UserService) IsPasswordEmptyByUser(u *model.User) bool {
} }
// Register 注册, 如果用户名已存在则返回nil // Register 注册, 如果用户名已存在则返回nil
func (us *UserService) Register(username string, email string, password string) *model.User { func (us *UserService) Register(username string, email string, password string, status model.StatusCode) *model.User {
u := &model.User{ u := &model.User{
Username: username, Username: username,
Email: email, Email: email,
Password: password, Password: password,
GroupId: 1, GroupId: 1,
Status: status,
} }
err := us.Create(u) err := us.Create(u)
if err != nil { if err != nil {

48
utils/captcha.go Normal file
View File

@@ -0,0 +1,48 @@
package utils
import (
"github.com/mojocn/base64Captcha"
"time"
)
var capdString = base64Captcha.NewDriverString(50, 150, 0, 5, 4, "123456789abcdefghijklmnopqrstuvwxyz", nil, nil, nil)
var capdMath = base64Captcha.NewDriverMath(50, 150, 3, 10, nil, nil, nil)
type B64StringCaptchaProvider struct{}
func (p B64StringCaptchaProvider) Generate() (string, string, string, error) {
id, content, answer := capdString.GenerateIdQuestionAnswer()
return id, content, answer, nil
}
func (p B64StringCaptchaProvider) Expiration() time.Duration {
return 5 * time.Minute
}
func (p B64StringCaptchaProvider) Draw(content string) (string, error) {
item, err := capdString.DrawCaptcha(content)
if err != nil {
return "", err
}
b64str := item.EncodeB64string()
return b64str, nil
}
type B64MathCaptchaProvider struct{}
func (p B64MathCaptchaProvider) Generate() (string, string, string, error) {
id, content, answer := capdMath.GenerateIdQuestionAnswer()
return id, content, answer, nil
}
func (p B64MathCaptchaProvider) Expiration() time.Duration {
return 5 * time.Minute
}
func (p B64MathCaptchaProvider) Draw(content string) (string, error) {
item, err := capdMath.DrawCaptcha(content)
if err != nil {
return "", err
}
b64str := item.EncodeB64string()
return b64str, nil
}

296
utils/login_limiter.go Normal file
View File

@@ -0,0 +1,296 @@
package utils
import (
"errors"
"sync"
"time"
)
// 安全策略配置
type SecurityPolicy struct {
CaptchaThreshold int // 尝试失败次数达到验证码阈值小于0表示不启用, 0表示强制启用
BanThreshold int // 尝试失败次数达到封禁阈值为0表示不启用
AttemptsWindow time.Duration
BanDuration time.Duration
}
// 验证码提供者接口
type CaptchaProvider interface {
Generate() (id string, content string, answer string, err error)
//Validate(ip, code string) bool
Expiration() time.Duration // 验证码过期时间, 应该小于 AttemptsWindow
Draw(content string) (string, error) // 绘制验证码
}
// 验证码元数据
type CaptchaMeta struct {
Id string
Content string
Answer string
ExpiresAt time.Time
}
// IP封禁记录
type BanRecord struct {
ExpiresAt time.Time
Reason string
}
// 登录限制器
type LoginLimiter struct {
mu sync.Mutex
policy SecurityPolicy
attempts map[string][]time.Time //
captchas map[string]CaptchaMeta
bannedIPs map[string]BanRecord
provider CaptchaProvider
cleanupStop chan struct{}
}
var defaultSecurityPolicy = SecurityPolicy{
CaptchaThreshold: 3,
BanThreshold: 5,
AttemptsWindow: 5 * time.Minute,
BanDuration: 30 * time.Minute,
}
func NewLoginLimiter(policy SecurityPolicy) *LoginLimiter {
// 设置默认值
if policy.AttemptsWindow == 0 {
policy.AttemptsWindow = 5 * time.Minute
}
if policy.BanDuration == 0 {
policy.BanDuration = 30 * time.Minute
}
ll := &LoginLimiter{
policy: policy,
attempts: make(map[string][]time.Time),
captchas: make(map[string]CaptchaMeta),
bannedIPs: make(map[string]BanRecord),
cleanupStop: make(chan struct{}),
}
go ll.cleanupRoutine()
return ll
}
// 注册验证码提供者
func (ll *LoginLimiter) RegisterProvider(p CaptchaProvider) {
ll.mu.Lock()
defer ll.mu.Unlock()
ll.provider = p
}
// isDisabled 检查是否禁用登录限制
func (ll *LoginLimiter) isDisabled() bool {
return ll.policy.CaptchaThreshold < 0 && ll.policy.BanThreshold == 0
}
// 记录登录失败尝试
func (ll *LoginLimiter) RecordFailedAttempt(ip string) {
if ll.isDisabled() {
return
}
ll.mu.Lock()
defer ll.mu.Unlock()
if banned, _ := ll.isBanned(ip); banned {
return
}
now := time.Now()
windowStart := now.Add(-ll.policy.AttemptsWindow)
// 清理过期尝试
validAttempts := ll.pruneAttempts(ip, windowStart)
// 记录新尝试
validAttempts = append(validAttempts, now)
ll.attempts[ip] = validAttempts
// 检查封禁条件
if ll.policy.BanThreshold > 0 && len(validAttempts) >= ll.policy.BanThreshold {
ll.banIP(ip, "excessive failed attempts")
return
}
return
}
// 生成验证码
func (ll *LoginLimiter) RequireCaptcha() (error, CaptchaMeta) {
ll.mu.Lock()
defer ll.mu.Unlock()
if ll.provider == nil {
return errors.New("no captcha provider available"), CaptchaMeta{}
}
id, content, answer, err := ll.provider.Generate()
if err != nil {
return err, CaptchaMeta{}
}
// 存储验证码
ll.captchas[id] = CaptchaMeta{
Id: id,
Content: content,
Answer: answer,
ExpiresAt: time.Now().Add(ll.provider.Expiration()),
}
return nil, ll.captchas[id]
}
// 验证验证码
func (ll *LoginLimiter) VerifyCaptcha(id, answer string) bool {
ll.mu.Lock()
defer ll.mu.Unlock()
// 查找匹配验证码
if ll.provider == nil {
return false
}
// 获取并验证验证码
captcha, exists := ll.captchas[id]
if !exists {
return false
}
// 清理过期验证码
if time.Now().After(captcha.ExpiresAt) {
delete(ll.captchas, id)
return false
}
// 验证并清理状态
if answer == captcha.Answer {
delete(ll.captchas, id)
return true
}
return false
}
func (ll *LoginLimiter) DrawCaptcha(content string) (err error, str string) {
str, err = ll.provider.Draw(content)
return
}
// 清除记录窗口
func (ll *LoginLimiter) RemoveAttempts(ip string) {
ll.mu.Lock()
defer ll.mu.Unlock()
_, exists := ll.attempts[ip]
if exists {
delete(ll.attempts, ip)
}
}
// CheckSecurityStatus 检查安全状态
func (ll *LoginLimiter) CheckSecurityStatus(ip string) (banned bool, captchaRequired bool) {
if ll.isDisabled() {
return
}
ll.mu.Lock()
defer ll.mu.Unlock()
// 检查封禁状态
if banned, _ = ll.isBanned(ip); banned {
return
}
// 清理过期数据
ll.pruneAttempts(ip, time.Now().Add(-ll.policy.AttemptsWindow))
// 检查验证码要求
captchaRequired = len(ll.attempts[ip]) >= ll.policy.CaptchaThreshold
return
}
// 后台清理任务
func (ll *LoginLimiter) cleanupRoutine() {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
ll.cleanupExpired()
case <-ll.cleanupStop:
return
}
}
}
// 内部工具方法
func (ll *LoginLimiter) isBanned(ip string) (bool, BanRecord) {
record, exists := ll.bannedIPs[ip]
if !exists {
return false, BanRecord{}
}
if time.Now().After(record.ExpiresAt) {
delete(ll.bannedIPs, ip)
return false, BanRecord{}
}
return true, record
}
func (ll *LoginLimiter) banIP(ip, reason string) {
ll.bannedIPs[ip] = BanRecord{
ExpiresAt: time.Now().Add(ll.policy.BanDuration),
Reason: reason,
}
delete(ll.attempts, ip)
delete(ll.captchas, ip)
}
func (ll *LoginLimiter) pruneAttempts(ip string, cutoff time.Time) []time.Time {
var valid []time.Time
for _, t := range ll.attempts[ip] {
if t.After(cutoff) {
valid = append(valid, t)
}
}
if len(valid) == 0 {
delete(ll.attempts, ip)
} else {
ll.attempts[ip] = valid
}
return valid
}
func (ll *LoginLimiter) pruneCaptchas(id string) {
if captcha, exists := ll.captchas[id]; exists {
if time.Now().After(captcha.ExpiresAt) {
delete(ll.captchas, id)
}
}
}
func (ll *LoginLimiter) cleanupExpired() {
ll.mu.Lock()
defer ll.mu.Unlock()
now := time.Now()
// 清理封禁记录
for ip, record := range ll.bannedIPs {
if now.After(record.ExpiresAt) {
delete(ll.bannedIPs, ip)
}
}
// 清理尝试记录
for ip := range ll.attempts {
ll.pruneAttempts(ip, now.Add(-ll.policy.AttemptsWindow))
}
// 清理验证码
for id := range ll.captchas {
ll.pruneCaptchas(id)
}
}

290
utils/login_limiter_test.go Normal file
View File

@@ -0,0 +1,290 @@
package utils
import (
"fmt"
"github.com/google/uuid"
"testing"
"time"
)
type MockCaptchaProvider struct{}
func (p *MockCaptchaProvider) Generate() (string, string, string, error) {
id := uuid.New().String()
content := uuid.New().String()
answer := uuid.New().String()
return id, content, answer, nil
}
func (p *MockCaptchaProvider) Expiration() time.Duration {
return 2 * time.Second
}
func (p *MockCaptchaProvider) Draw(content string) (string, error) {
return "MOCK", nil
}
func TestSecurityWorkflow(t *testing.T) {
policy := SecurityPolicy{
CaptchaThreshold: 3,
BanThreshold: 5,
AttemptsWindow: 5 * time.Minute,
BanDuration: 5 * time.Minute,
}
limiter := NewLoginLimiter(policy)
ip := "192.168.1.100"
// 测试正常失败记录
for i := 0; i < 3; i++ {
limiter.RecordFailedAttempt(ip)
}
isBanned, capRequired := limiter.CheckSecurityStatus(ip)
fmt.Printf("IP: %s, Banned: %v, Captcha Required: %v\n", ip, isBanned, capRequired)
if isBanned {
t.Error("IP should not be banned yet")
}
if !capRequired {
t.Error("Captcha should be required")
}
// 测试触发封禁
for i := 0; i < 3; i++ {
limiter.RecordFailedAttempt(ip)
isBanned, capRequired = limiter.CheckSecurityStatus(ip)
fmt.Printf("IP: %s, Banned: %v, Captcha Required: %v\n", ip, isBanned, capRequired)
}
// 测试封禁状态
if isBanned, _ = limiter.CheckSecurityStatus(ip); !isBanned {
t.Error("IP should be banned")
}
}
func TestCaptchaFlow(t *testing.T) {
policy := SecurityPolicy{CaptchaThreshold: 2}
limiter := NewLoginLimiter(policy)
limiter.RegisterProvider(&MockCaptchaProvider{})
ip := "10.0.0.1"
// 触发验证码要求
limiter.RecordFailedAttempt(ip)
limiter.RecordFailedAttempt(ip)
// 检查状态
if _, need := limiter.CheckSecurityStatus(ip); !need {
t.Error("应该需要验证码")
}
// 生成验证码
err, capc := limiter.RequireCaptcha()
if err != nil {
t.Fatalf("生成验证码失败: %v", err)
}
fmt.Printf("验证码内容: %#v\n", capc)
// 验证成功
if !limiter.VerifyCaptcha(capc.Id, capc.Answer) {
t.Error("验证码应该验证成功")
}
// 验证已删除
if limiter.VerifyCaptcha(capc.Id, capc.Answer) {
t.Error("验证码应该已删除")
}
limiter.RemoveAttempts(ip)
// 验证后状态
if banned, need := limiter.CheckSecurityStatus(ip); banned || need {
t.Error("验证成功后应该重置状态")
}
}
func TestCaptchaMustFlow(t *testing.T) {
policy := SecurityPolicy{CaptchaThreshold: 0}
limiter := NewLoginLimiter(policy)
limiter.RegisterProvider(&MockCaptchaProvider{})
ip := "10.0.0.1"
// 检查状态
if _, need := limiter.CheckSecurityStatus(ip); !need {
t.Error("应该需要验证码")
}
// 生成验证码
err, capc := limiter.RequireCaptcha()
if err != nil {
t.Fatalf("生成验证码失败: %v", err)
}
fmt.Printf("验证码内容: %#v\n", capc)
// 验证成功
if !limiter.VerifyCaptcha(capc.Id, capc.Answer) {
t.Error("验证码应该验证成功")
}
// 验证后状态
if _, need := limiter.CheckSecurityStatus(ip); !need {
t.Error("应该需要验证码")
}
}
func TestAttemptTimeout(t *testing.T) {
policy := SecurityPolicy{CaptchaThreshold: 2, AttemptsWindow: 1 * time.Second}
limiter := NewLoginLimiter(policy)
limiter.RegisterProvider(&MockCaptchaProvider{})
ip := "10.0.0.1"
// 触发验证码要求
limiter.RecordFailedAttempt(ip)
limiter.RecordFailedAttempt(ip)
// 检查状态
if _, need := limiter.CheckSecurityStatus(ip); !need {
t.Error("应该需要验证码")
}
// 生成验证码
err, _ := limiter.RequireCaptcha()
if err != nil {
t.Fatalf("生成验证码失败: %v", err)
}
// 等待超过 AttemptsWindow
time.Sleep(2 * time.Second)
// 触发验证码要求
limiter.RecordFailedAttempt(ip)
// 检查状态
if _, need := limiter.CheckSecurityStatus(ip); need {
t.Error("不应该需要验证码")
}
}
func TestCaptchaTimeout(t *testing.T) {
policy := SecurityPolicy{CaptchaThreshold: 2}
limiter := NewLoginLimiter(policy)
limiter.RegisterProvider(&MockCaptchaProvider{})
ip := "10.0.0.1"
// 触发验证码要求
limiter.RecordFailedAttempt(ip)
limiter.RecordFailedAttempt(ip)
// 检查状态
if _, need := limiter.CheckSecurityStatus(ip); !need {
t.Error("应该需要验证码")
}
// 生成验证码
err, capc := limiter.RequireCaptcha()
if err != nil {
t.Fatalf("生成验证码失败: %v", err)
}
// 等待超过 CaptchaValidPeriod
time.Sleep(3 * time.Second)
// 验证成功
if limiter.VerifyCaptcha(capc.Id, capc.Answer) {
t.Error("验证码应该已过期")
}
}
func TestBanFlow(t *testing.T) {
policy := SecurityPolicy{BanThreshold: 5}
limiter := NewLoginLimiter(policy)
ip := "10.0.0.1"
// 触发ban
for i := 0; i < 5; i++ {
limiter.RecordFailedAttempt(ip)
}
// 检查状态
if banned, _ := limiter.CheckSecurityStatus(ip); !banned {
t.Error("should be banned")
}
}
func TestBanDisableFlow(t *testing.T) {
policy := SecurityPolicy{BanThreshold: 0}
limiter := NewLoginLimiter(policy)
ip := "10.0.0.1"
// 触发ban
for i := 0; i < 5; i++ {
limiter.RecordFailedAttempt(ip)
}
// 检查状态
if banned, _ := limiter.CheckSecurityStatus(ip); banned {
t.Error("should not be banned")
}
}
func TestBanTimeout(t *testing.T) {
policy := SecurityPolicy{BanThreshold: 5, BanDuration: 1 * time.Second}
limiter := NewLoginLimiter(policy)
ip := "10.0.0.1"
// 触发ban
// 触发ban
for i := 0; i < 5; i++ {
limiter.RecordFailedAttempt(ip)
}
time.Sleep(2 * time.Second)
// 检查状态
if banned, _ := limiter.CheckSecurityStatus(ip); banned {
t.Error("should not be banned")
}
}
func TestLimiterDisabled(t *testing.T) {
policy := SecurityPolicy{BanThreshold: 0, CaptchaThreshold: -1}
limiter := NewLoginLimiter(policy)
ip := "10.0.0.1"
// 触发ban
for i := 0; i < 5; i++ {
limiter.RecordFailedAttempt(ip)
}
// 检查状态
if banned, capNeed := limiter.CheckSecurityStatus(ip); banned || capNeed {
fmt.Printf("IP: %s, Banned: %v, Captcha Required: %v\n", ip, banned, capNeed)
t.Error("should not be banned or need captcha")
}
}
func TestB64CaptchaFlow(t *testing.T) {
limiter := NewLoginLimiter(defaultSecurityPolicy)
limiter.RegisterProvider(B64StringCaptchaProvider{})
ip := "10.0.0.1"
// 触发验证码要求
limiter.RecordFailedAttempt(ip)
limiter.RecordFailedAttempt(ip)
limiter.RecordFailedAttempt(ip)
// 检查状态
if _, need := limiter.CheckSecurityStatus(ip); !need {
t.Error("应该需要验证码")
}
// 生成验证码
err, capc := limiter.RequireCaptcha()
if err != nil {
t.Fatalf("生成验证码失败: %v", err)
}
fmt.Printf("验证码内容: %#v\n", capc)
//draw
err, b64 := limiter.DrawCaptcha(capc.Content)
if err != nil {
t.Fatalf("绘制验证码失败: %v", err)
}
fmt.Printf("验证码内容: %#v\n", b64)
// 验证成功
if !limiter.VerifyCaptcha(capc.Id, capc.Answer) {
t.Error("验证码应该验证成功")
}
limiter.RemoveAttempts(ip)
// 验证后状态
if banned, need := limiter.CheckSecurityStatus(ip); banned || need {
t.Error("验证成功后应该重置状态")
}
}

View File

@@ -7,6 +7,7 @@ import (
"math/rand" "math/rand"
"reflect" "reflect"
"runtime/debug" "runtime/debug"
"strings"
) )
func Md5(str string) string { func Md5(str string) string {
@@ -100,3 +101,11 @@ func InArray(k string, arr []string) bool {
} }
return false return false
} }
func StringConcat(strs ...string) string {
var builder strings.Builder
for _, str := range strs {
builder.WriteString(str)
}
return builder.String()
}