Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
65f0a9e3cf | ||
|
|
77836a4e56 | ||
|
|
09f8316bf1 | ||
|
|
c52706e621 | ||
|
|
17dcff4f43 | ||
|
|
0b39c4e104 | ||
|
|
ee176b314e | ||
|
|
1ffc9c4a5b | ||
|
|
1257246552 | ||
|
|
2948eaaa5c | ||
|
|
8641ba5c0c | ||
|
|
60b7a18fe7 | ||
|
|
ca068816ae | ||
|
|
06648d9a6c | ||
|
|
8a8abd5163 | ||
|
|
97f98cd6ce | ||
|
|
51f2920661 | ||
|
|
7a5d141ce8 | ||
|
|
3cef02a0bb | ||
|
|
46a7ecc1ba |
7
.github/workflows/build.yml
vendored
7
.github/workflows/build.yml
vendored
@@ -66,7 +66,7 @@ jobs:
|
||||
- name: Set up Go environment
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '1.22' # 选择 Go 版本
|
||||
go-version: '1.23' # 选择 Go 版本
|
||||
|
||||
- name: Set up npm
|
||||
uses: actions/setup-node@v2
|
||||
@@ -115,12 +115,12 @@ jobs:
|
||||
zip -r ${{ matrix.job.goos}}-${{ matrix.job.platform }}.${{matrix.job.file_ext}} ./release
|
||||
else
|
||||
if [ "${{ matrix.job.platform }}" = "arm64" ]; then
|
||||
wget https://musl.cc/aarch64-linux-musl-cross.tgz
|
||||
wget https://musl.ljw.red/aarch64-linux-musl-cross.tgz
|
||||
tar -xf aarch64-linux-musl-cross.tgz
|
||||
export PATH=$PATH:$PWD/aarch64-linux-musl-cross/bin
|
||||
GOOS=${{ matrix.job.goos }} GOARCH=${{ matrix.job.platform }} CC=aarch64-linux-musl-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
|
||||
elif [ "${{ matrix.job.platform }}" = "armv7l" ]; then
|
||||
wget https://musl.cc/armv7l-linux-musleabihf-cross.tgz
|
||||
wget https://musl.ljw.red/armv7l-linux-musleabihf-cross.tgz
|
||||
tar -xf armv7l-linux-musleabihf-cross.tgz
|
||||
export PATH=$PATH:$PWD/armv7l-linux-musleabihf-cross/bin
|
||||
GOOS=${{ matrix.job.goos }} GOARCH=arm GOARM=7 CC=armv7l-linux-musleabihf-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
|
||||
@@ -147,6 +147,7 @@ jobs:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Generate Changelog
|
||||
if: startsWith(github.ref, 'refs/tags/') && github.event_name == 'push'
|
||||
run: npx changelogithub # or changelogithub@0.12 if ensure the stable result
|
||||
env:
|
||||
GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}
|
||||
|
||||
6
.github/workflows/build_test.yml
vendored
6
.github/workflows/build_test.yml
vendored
@@ -61,7 +61,7 @@ jobs:
|
||||
- name: Set up Go environment
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '1.22' # 选择 Go 版本
|
||||
go-version: '1.23' # 选择 Go 版本
|
||||
|
||||
- name: Set up npm
|
||||
uses: actions/setup-node@v2
|
||||
@@ -101,12 +101,12 @@ jobs:
|
||||
zip -r ${{ matrix.job.goos}}-${{ matrix.job.platform }}.${{matrix.job.file_ext}} ./release
|
||||
else
|
||||
if [ "${{ matrix.job.platform }}" = "arm64" ]; then
|
||||
wget https://musl.cc/aarch64-linux-musl-cross.tgz
|
||||
wget https://musl.ljw.red/aarch64-linux-musl-cross.tgz
|
||||
tar -xf aarch64-linux-musl-cross.tgz
|
||||
export PATH=$PATH:$PWD/aarch64-linux-musl-cross/bin
|
||||
GOOS=${{ matrix.job.goos }} GOARCH=${{ matrix.job.platform }} CC=aarch64-linux-musl-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
|
||||
elif [ "${{ matrix.job.platform }}" = "armv7l" ]; then
|
||||
wget https://musl.cc/armv7l-linux-musleabihf-cross.tgz
|
||||
wget https://musl.ljw.red/armv7l-linux-musleabihf-cross.tgz
|
||||
tar -xf armv7l-linux-musleabihf-cross.tgz
|
||||
export PATH=$PATH:$PWD/armv7l-linux-musleabihf-cross/bin
|
||||
GOOS=${{ matrix.job.goos }} GOARCH=arm GOARM=7 CC=armv7l-linux-musleabihf-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -5,4 +5,4 @@ runtime/*
|
||||
go.sum
|
||||
resources/admin
|
||||
release
|
||||
data
|
||||
data/rustdeskapi.db
|
||||
@@ -42,11 +42,11 @@ RUN if [ "$COUNTRY" = "CN" ] ; then \
|
||||
fi && \
|
||||
apk update && apk add --no-cache git
|
||||
|
||||
ARG FREONTEND_GIT_REPO=https://github.com/lejianwen/rustdesk-api-web.git
|
||||
ARG FRONTEND_GIT_REPO=https://github.com/lejianwen/rustdesk-api-web.git
|
||||
ARG FRONTEND_GIT_BRANCH=master
|
||||
# Clone the frontend repository
|
||||
|
||||
RUN git clone -b $FRONTEND_GIT_BRANCH $FREONTEND_GIT_REPO .
|
||||
RUN git clone -b $FRONTEND_GIT_BRANCH $FRONTEND_GIT_REPO .
|
||||
|
||||
# Install required tools without caching index to minimize image size
|
||||
RUN if [ "$COUNTRY" = "CN" ] ; then \
|
||||
@@ -91,4 +91,4 @@ VOLUME /app/data
|
||||
EXPOSE 21114
|
||||
|
||||
# Define the command to run the application
|
||||
CMD ["./apimain"]
|
||||
CMD ["./apimain"]
|
||||
|
||||
@@ -255,6 +255,12 @@
|
||||
#或者使用generate_api.go生成api并运行
|
||||
go generate generate_api.go
|
||||
```
|
||||
> 注意:使用 `go run` 或编译后的二进制时,当前目录下必须存在 `conf` 和 `resources`
|
||||
> 目录。如果在其他目录运行,可通过 `-c` 和环境变量
|
||||
> `RUSTDESK_API_GIN_RESOURCES_PATH` 指定绝对路径,例如:
|
||||
> ```bash
|
||||
> RUSTDESK_API_GIN_RESOURCES_PATH=/opt/rustdesk-api/resources ./apimain -c /opt/rustdesk-api/conf/config.yaml
|
||||
> ```
|
||||
5. 编译,如果想自己编译,先cd到项目根目录,然后windows下直接运行`build.bat`,linux下运行`build.sh`,编译后会在`release`
|
||||
目录下生成对应的可执行文件。直接运行编译后的可执行文件即可。
|
||||
|
||||
|
||||
18
README_EN.md
18
README_EN.md
@@ -164,7 +164,8 @@ The table below does not list all configurations. Please refer to the configurat
|
||||
| RUSTDESK_API_APP_DISABLE_PWD_LOGIN | disable password login | `false` |
|
||||
| RUSTDESK_API_APP_REGISTER_STATUS | register user default status ; 1 enabled , 2 disabled ; default 1 | `1` |
|
||||
| RUSTDESK_API_APP_CAPTCHA_THRESHOLD | captcha threshold; -1 disabled, 0 always enable, >0 threshold ;default `3` | `3` |
|
||||
| RUSTDESK_API_APP_BAN_THRESHOLD | ban ip threshold; 0 disabled, >0 threshold ; default `0` | `0` |
|
||||
| RUSTDESK_API_APP_BAN_THRESHOLD | ban ip threshold; 0 disabled, >0 threshold ; default `0`
|
||||
| `0` |
|
||||
| ----- ADMIN Configuration----- | ---------- | ---------- |
|
||||
| RUSTDESK_API_ADMIN_TITLE | Admin Title | `RustDesk Api Admin` |
|
||||
| RUSTDESK_API_ADMIN_HELLO | Admin welcome message, you can use `html` | |
|
||||
@@ -251,10 +252,17 @@ Download the release from [release](https://github.com/lejianwen/rustdesk-api/re
|
||||
4. Run:
|
||||
```bash
|
||||
# Run directly
|
||||
go run cmd/apimain.go
|
||||
# Or generate and run the API using generate_api.go
|
||||
go generate generate_api.go
|
||||
```
|
||||
go run cmd/apimain.go
|
||||
# Or generate and run the API using generate_api.go
|
||||
go generate generate_api.go
|
||||
```
|
||||
> **Note:** When using `go run` or the compiled binary, the `conf` and `resources`
|
||||
> directories must exist relative to the current working directory. If you run
|
||||
> the program from another location, specify absolute paths with `-c` and the
|
||||
> `RUSTDESK_API_GIN_RESOURCES_PATH` environment variable. Example:
|
||||
> ```bash
|
||||
> RUSTDESK_API_GIN_RESOURCES_PATH=/opt/rustdesk-api/resources ./apimain -c /opt/rustdesk-api/conf/config.yaml
|
||||
> ```
|
||||
|
||||
5. To compile, change to the project root directory. For Windows, run `build.bat`, and for Linux, run `build.sh`. After
|
||||
compiling, the corresponding executables will be generated in the `release` directory. Run the compiled executables
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/lejianwen/rustdesk-api/v2/config"
|
||||
"github.com/lejianwen/rustdesk-api/v2/global"
|
||||
@@ -16,11 +21,10 @@ import (
|
||||
"github.com/lejianwen/rustdesk-api/v2/utils"
|
||||
"github.com/nicksnyder/go-i18n/v2/i18n"
|
||||
"github.com/spf13/cobra"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
const DatabaseVersion = 264
|
||||
|
||||
// @title 管理系统API
|
||||
// @version 1.0
|
||||
// @description 接口
|
||||
@@ -140,18 +144,40 @@ func InitGlobal() {
|
||||
}
|
||||
//gorm
|
||||
if global.Config.Gorm.Type == config.TypeMysql {
|
||||
dns := global.Config.Mysql.Username + ":" + global.Config.Mysql.Password + "@(" + global.Config.Mysql.Addr + ")/" + global.Config.Mysql.Dbname + "?charset=utf8mb4&parseTime=True&loc=Local"
|
||||
|
||||
dsn := fmt.Sprintf("%s:%s@(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||||
global.Config.Mysql.Username,
|
||||
global.Config.Mysql.Password,
|
||||
global.Config.Mysql.Addr,
|
||||
global.Config.Mysql.Dbname,
|
||||
)
|
||||
|
||||
global.DB = orm.NewMysql(&orm.MysqlConfig{
|
||||
Dns: dns,
|
||||
Dsn: dsn,
|
||||
MaxIdleConns: global.Config.Gorm.MaxIdleConns,
|
||||
MaxOpenConns: global.Config.Gorm.MaxOpenConns,
|
||||
})
|
||||
}, global.Logger)
|
||||
} else if global.Config.Gorm.Type == config.TypePostgresql {
|
||||
dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s TimeZone=%s",
|
||||
global.Config.Postgresql.Host,
|
||||
global.Config.Postgresql.Port,
|
||||
global.Config.Postgresql.User,
|
||||
global.Config.Postgresql.Password,
|
||||
global.Config.Postgresql.Dbname,
|
||||
global.Config.Postgresql.Sslmode,
|
||||
global.Config.Postgresql.TimeZone,
|
||||
)
|
||||
global.DB = orm.NewPostgresql(&orm.PostgresqlConfig{
|
||||
Dsn: dsn,
|
||||
MaxIdleConns: global.Config.Gorm.MaxIdleConns,
|
||||
MaxOpenConns: global.Config.Gorm.MaxOpenConns,
|
||||
}, global.Logger)
|
||||
} else {
|
||||
//sqlite
|
||||
global.DB = orm.NewSqlite(&orm.SqliteConfig{
|
||||
MaxIdleConns: global.Config.Gorm.MaxIdleConns,
|
||||
MaxOpenConns: global.Config.Gorm.MaxOpenConns,
|
||||
})
|
||||
}, global.Logger)
|
||||
}
|
||||
|
||||
//validator
|
||||
@@ -187,7 +213,7 @@ func InitGlobal() {
|
||||
}
|
||||
|
||||
func DatabaseAutoUpdate() {
|
||||
version := 262
|
||||
version := DatabaseVersion
|
||||
|
||||
db := global.DB
|
||||
|
||||
@@ -197,11 +223,17 @@ func DatabaseAutoUpdate() {
|
||||
if dbName == "" {
|
||||
dbName = global.Config.Mysql.Dbname
|
||||
// 移除 DSN 中的数据库名称,以便初始连接时不指定数据库
|
||||
dsnWithoutDB := global.Config.Mysql.Username + ":" + global.Config.Mysql.Password + "@(" + global.Config.Mysql.Addr + ")/?charset=utf8mb4&parseTime=True&loc=Local"
|
||||
dsnWithoutDB := fmt.Sprintf("%s:%s@(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||||
global.Config.Mysql.Username,
|
||||
global.Config.Mysql.Password,
|
||||
global.Config.Mysql.Addr,
|
||||
"",
|
||||
)
|
||||
|
||||
//新链接
|
||||
dbWithoutDB := orm.NewMysql(&orm.MysqlConfig{
|
||||
Dns: dsnWithoutDB,
|
||||
})
|
||||
Dsn: dsnWithoutDB,
|
||||
}, global.Logger)
|
||||
// 获取底层的 *sql.DB 对象,并确保在程序退出时关闭连接
|
||||
sqlDBWithoutDB, err := dbWithoutDB.DB()
|
||||
if err != nil {
|
||||
@@ -313,7 +345,11 @@ func Migrate(version uint) {
|
||||
// 生成随机密码
|
||||
pwd := utils.RandomString(8)
|
||||
global.Logger.Info("Admin Password Is: ", pwd)
|
||||
admin.Password = service.AllService.UserService.EncryptPassword(pwd)
|
||||
var err error
|
||||
admin.Password, err = utils.EncryptPassword(pwd)
|
||||
if err != nil {
|
||||
global.Logger.Fatalf("failed to generate admin password: %v", err)
|
||||
}
|
||||
global.DB.Create(admin)
|
||||
}
|
||||
|
||||
|
||||
@@ -11,9 +11,12 @@ app:
|
||||
disable-pwd-login: false #禁用密码登录
|
||||
|
||||
admin:
|
||||
title: "RustDesk Api Admin"
|
||||
title: "RustDesk API Admin"
|
||||
hello-file: "./conf/admin/hello.html" #优先使用file
|
||||
hello: ""
|
||||
# ID Server and Relay Server ports https://github.com/lejianwen/rustdesk-api/issues/257
|
||||
id-server-port: 21116 # ID Server port (for server cmd)
|
||||
relay-server-port: 21117 # ID Server port (for server cmd)
|
||||
gin:
|
||||
api-addr: "0.0.0.0:21114"
|
||||
mode: "release" #release,debug,test
|
||||
@@ -28,6 +31,16 @@ mysql:
|
||||
password: ""
|
||||
addr: ""
|
||||
dbname: ""
|
||||
|
||||
postgresql:
|
||||
host: "127.0.0.1"
|
||||
port: "5432"
|
||||
user: ""
|
||||
password: ""
|
||||
dbname: "postgres"
|
||||
sslmode: "disable" # disable, require, verify-ca, verify-full
|
||||
time-zone: "Asia/Shanghai" # Time zone for PostgreSQL connection
|
||||
|
||||
rustdesk:
|
||||
id-server: "192.168.1.66:21116"
|
||||
relay-server: "192.168.1.66:21117"
|
||||
|
||||
@@ -25,25 +25,37 @@ type App struct {
|
||||
BanThreshold int `mapstructure:"ban-threshold"`
|
||||
}
|
||||
type Admin struct {
|
||||
Title string `mapstructure:"title"`
|
||||
Hello string `mapstructure:"hello"`
|
||||
HelloFile string `mapstructure:"hello-file"`
|
||||
Title string `mapstructure:"title"`
|
||||
Hello string `mapstructure:"hello"`
|
||||
HelloFile string `mapstructure:"hello-file"`
|
||||
IdServerPort int `mapstructure:"id-server-port"`
|
||||
RelayServerPort int `mapstructure:"relay-server-port"`
|
||||
}
|
||||
type Config struct {
|
||||
Lang string `mapstructure:"lang"`
|
||||
App App
|
||||
Admin Admin
|
||||
Gorm Gorm
|
||||
Mysql Mysql
|
||||
Gin Gin
|
||||
Logger Logger
|
||||
Redis Redis
|
||||
Cache Cache
|
||||
Oss Oss
|
||||
Jwt Jwt
|
||||
Rustdesk Rustdesk
|
||||
Proxy Proxy
|
||||
Ldap Ldap
|
||||
Lang string `mapstructure:"lang"`
|
||||
App App
|
||||
Admin Admin
|
||||
Gorm Gorm
|
||||
Mysql Mysql
|
||||
Postgresql Postgresql
|
||||
Gin Gin
|
||||
Logger Logger
|
||||
Redis Redis
|
||||
Cache Cache
|
||||
Oss Oss
|
||||
Jwt Jwt
|
||||
Rustdesk Rustdesk
|
||||
Proxy Proxy
|
||||
Ldap Ldap
|
||||
}
|
||||
|
||||
func (a *Admin) Init() {
|
||||
if a.IdServerPort == 0 {
|
||||
a.IdServerPort = DefaultIdServerPort
|
||||
}
|
||||
if a.RelayServerPort == 0 {
|
||||
a.RelayServerPort = DefaultRelayServerPort
|
||||
}
|
||||
}
|
||||
|
||||
// Init 初始化配置
|
||||
@@ -80,7 +92,7 @@ func Init(rowVal *Config, path string) *viper.Viper {
|
||||
panic(fmt.Errorf("Fatal error config: %s \n", err))
|
||||
}
|
||||
rowVal.Rustdesk.LoadKeyFile()
|
||||
rowVal.Rustdesk.ParsePort()
|
||||
rowVal.Admin.Init()
|
||||
return v
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
package config
|
||||
|
||||
const (
|
||||
TypeSqlite = "sqlite"
|
||||
TypeMysql = "mysql"
|
||||
TypeSqlite = "sqlite"
|
||||
TypeMysql = "mysql"
|
||||
TypePostgresql = "postgresql"
|
||||
)
|
||||
|
||||
type Gorm struct {
|
||||
@@ -17,3 +18,13 @@ type Mysql struct {
|
||||
Password string `mapstructure:"password"`
|
||||
Dbname string `mapstructure:"dbname"`
|
||||
}
|
||||
|
||||
type Postgresql struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port string `mapstructure:"port"`
|
||||
User string `mapstructure:"user"`
|
||||
Password string `mapstructure:"password"`
|
||||
Dbname string `mapstructure:"dbname"`
|
||||
Sslmode string `mapstructure:"sslmode"` // "disable", "require", "verify-ca", "verify-full"
|
||||
TimeZone string `mapstructure:"time-zone"` // e.g., "Asia/Shanghai"
|
||||
}
|
||||
|
||||
@@ -3,18 +3,20 @@ package config
|
||||
type GithubOauth struct {
|
||||
ClientId string `mapstructure:"client-id"`
|
||||
ClientSecret string `mapstructure:"client-secret"`
|
||||
RedirectUrl string `mapstructure:"redirect-url"`
|
||||
}
|
||||
|
||||
type GoogleOauth struct {
|
||||
ClientId string `mapstructure:"client-id"`
|
||||
ClientSecret string `mapstructure:"client-secret"`
|
||||
RedirectUrl string `mapstructure:"redirect-url"`
|
||||
}
|
||||
|
||||
type OidcOauth struct {
|
||||
Issuer string `mapstructure:"issuer"`
|
||||
ClientId string `mapstructure:"client-id"`
|
||||
ClientSecret string `mapstructure:"client-secret"`
|
||||
RedirectUrl string `mapstructure:"redirect-url"`
|
||||
}
|
||||
|
||||
type LinuxdoOauth struct {
|
||||
ClientId string `mapstructure:"client-id"`
|
||||
ClientSecret string `mapstructure:"client-secret"`
|
||||
}
|
||||
|
||||
@@ -2,8 +2,6 @@ package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -40,19 +38,3 @@ func (rd *Rustdesk) LoadKeyFile() {
|
||||
return
|
||||
}
|
||||
}
|
||||
func (rd *Rustdesk) ParsePort() {
|
||||
// Parse port
|
||||
idres := strings.Split(rd.IdServer, ":")
|
||||
if len(idres) == 1 {
|
||||
rd.IdServerPort = DefaultIdServerPort
|
||||
} else if len(idres) == 2 {
|
||||
rd.IdServerPort, _ = strconv.Atoi(idres[1])
|
||||
}
|
||||
|
||||
relayres := strings.Split(rd.RelayServer, ":")
|
||||
if len(relayres) == 1 {
|
||||
rd.RelayServerPort = DefaultRelayServerPort
|
||||
} else if len(relayres) == 2 {
|
||||
rd.RelayServerPort, _ = strconv.Atoi(relayres[1])
|
||||
}
|
||||
}
|
||||
|
||||
0
data/.gitkeep
Normal file
0
data/.gitkeep
Normal file
@@ -5,7 +5,7 @@ services:
|
||||
dockerfile: Dockerfile.dev
|
||||
args:
|
||||
COUNTRY: CN
|
||||
FREONTEND_GIT_REPO: https://github.com/lejianwen/rustdesk-api-web.git
|
||||
FRONTEND_GIT_REPO: https://github.com/lejianwen/rustdesk-api-web.git
|
||||
FRONTEND_GIT_BRANCH: master
|
||||
# image: lejianwen/rustdesk-api
|
||||
container_name: rustdesk-api
|
||||
@@ -21,4 +21,4 @@ services:
|
||||
- ./data/rustdesk/api:/app/data #将数据库挂载出来方便备份
|
||||
- ./conf:/app/conf # config
|
||||
# - ./resources:/app/resources # 静态资源
|
||||
restart: unless-stopped
|
||||
restart: unless-stopped
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Package admin Content generated by swaggo/swag. DO NOT EDIT
|
||||
// Package admin Code generated by swaggo/swag. DO NOT EDIT
|
||||
package admin
|
||||
|
||||
import "github.com/swaggo/swag"
|
||||
@@ -5569,8 +5569,7 @@ const docTemplateadmin = `{
|
||||
"required": [
|
||||
"client_id",
|
||||
"client_secret",
|
||||
"oauth_type",
|
||||
"redirect_url"
|
||||
"oauth_type"
|
||||
],
|
||||
"properties": {
|
||||
"auto_register": {
|
||||
@@ -5600,9 +5599,6 @@ const docTemplateadmin = `{
|
||||
"pkce_method": {
|
||||
"type": "string"
|
||||
},
|
||||
"redirect_url": {
|
||||
"type": "string"
|
||||
},
|
||||
"scopes": {
|
||||
"type": "string"
|
||||
}
|
||||
@@ -5828,6 +5824,9 @@ const docTemplateadmin = `{
|
||||
"captcha": {
|
||||
"type": "string"
|
||||
},
|
||||
"captcha_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"password": {
|
||||
"type": "string"
|
||||
},
|
||||
@@ -6293,9 +6292,6 @@ const docTemplateadmin = `{
|
||||
"pkce_method": {
|
||||
"type": "string"
|
||||
},
|
||||
"redirect_url": {
|
||||
"type": "string"
|
||||
},
|
||||
"scopes": {
|
||||
"type": "string"
|
||||
},
|
||||
|
||||
@@ -5562,8 +5562,7 @@
|
||||
"required": [
|
||||
"client_id",
|
||||
"client_secret",
|
||||
"oauth_type",
|
||||
"redirect_url"
|
||||
"oauth_type"
|
||||
],
|
||||
"properties": {
|
||||
"auto_register": {
|
||||
@@ -5593,9 +5592,6 @@
|
||||
"pkce_method": {
|
||||
"type": "string"
|
||||
},
|
||||
"redirect_url": {
|
||||
"type": "string"
|
||||
},
|
||||
"scopes": {
|
||||
"type": "string"
|
||||
}
|
||||
@@ -5821,6 +5817,9 @@
|
||||
"captcha": {
|
||||
"type": "string"
|
||||
},
|
||||
"captcha_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"password": {
|
||||
"type": "string"
|
||||
},
|
||||
@@ -6286,9 +6285,6 @@
|
||||
"pkce_method": {
|
||||
"type": "string"
|
||||
},
|
||||
"redirect_url": {
|
||||
"type": "string"
|
||||
},
|
||||
"scopes": {
|
||||
"type": "string"
|
||||
},
|
||||
@@ -6592,4 +6588,4 @@
|
||||
"in": "header"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,15 +143,12 @@ definitions:
|
||||
type: boolean
|
||||
pkce_method:
|
||||
type: string
|
||||
redirect_url:
|
||||
type: string
|
||||
scopes:
|
||||
type: string
|
||||
required:
|
||||
- client_id
|
||||
- client_secret
|
||||
- oauth_type
|
||||
- redirect_url
|
||||
type: object
|
||||
admin.PeerBatchDeleteForm:
|
||||
properties:
|
||||
@@ -297,6 +294,8 @@ definitions:
|
||||
properties:
|
||||
captcha:
|
||||
type: string
|
||||
captcha_id:
|
||||
type: string
|
||||
password:
|
||||
type: string
|
||||
platform:
|
||||
@@ -609,8 +608,6 @@ definitions:
|
||||
type: boolean
|
||||
pkce_method:
|
||||
type: string
|
||||
redirect_url:
|
||||
type: string
|
||||
scopes:
|
||||
type: string
|
||||
updated_at:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Package api Content generated by swaggo/swag. DO NOT EDIT
|
||||
// Package api Code generated by swaggo/swag. DO NOT EDIT
|
||||
package api
|
||||
|
||||
import "github.com/swaggo/swag"
|
||||
@@ -1208,7 +1208,7 @@ const docTemplateapi = `{
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"地址"
|
||||
"System"
|
||||
],
|
||||
"summary": "提交系统信息",
|
||||
"parameters": [
|
||||
@@ -1238,6 +1238,35 @@ const docTemplateapi = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"/sysinfo_ver": {
|
||||
"post": {
|
||||
"description": "获取系统版本信息",
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"System"
|
||||
],
|
||||
"summary": "获取系统版本信息",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"500": {
|
||||
"description": "Internal Server Error",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/response.ErrorResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/users": {
|
||||
"get": {
|
||||
"security": [
|
||||
|
||||
@@ -1201,7 +1201,7 @@
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"地址"
|
||||
"System"
|
||||
],
|
||||
"summary": "提交系统信息",
|
||||
"parameters": [
|
||||
@@ -1231,6 +1231,35 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/sysinfo_ver": {
|
||||
"post": {
|
||||
"description": "获取系统版本信息",
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"System"
|
||||
],
|
||||
"summary": "获取系统版本信息",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"500": {
|
||||
"description": "Internal Server Error",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/response.ErrorResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/users": {
|
||||
"get": {
|
||||
"security": [
|
||||
|
||||
@@ -973,7 +973,26 @@ paths:
|
||||
$ref: '#/definitions/response.ErrorResponse'
|
||||
summary: 提交系统信息
|
||||
tags:
|
||||
- 地址
|
||||
- System
|
||||
/sysinfo_ver:
|
||||
post:
|
||||
consumes:
|
||||
- application/json
|
||||
description: 获取系统版本信息
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
schema:
|
||||
type: string
|
||||
"500":
|
||||
description: Internal Server Error
|
||||
schema:
|
||||
$ref: '#/definitions/response.ErrorResponse'
|
||||
summary: 获取系统版本信息
|
||||
tags:
|
||||
- System
|
||||
/users:
|
||||
get:
|
||||
consumes:
|
||||
|
||||
23
go.mod
23
go.mod
@@ -1,19 +1,23 @@
|
||||
module github.com/lejianwen/rustdesk-api/v2
|
||||
|
||||
go 1.22
|
||||
go 1.23
|
||||
|
||||
toolchain go1.23.10
|
||||
|
||||
require (
|
||||
github.com/BurntSushi/toml v1.3.2
|
||||
github.com/antonfisher/nested-logrus-formatter v1.3.1
|
||||
github.com/fsnotify/fsnotify v1.5.1
|
||||
github.com/coreos/go-oidc/v3 v3.12.0
|
||||
github.com/fvbock/endless v0.0.0-20170109170031-447134032cb6
|
||||
github.com/gin-gonic/gin v1.9.0
|
||||
github.com/go-ldap/ldap/v3 v3.4.10
|
||||
github.com/go-playground/locales v0.14.1
|
||||
github.com/go-playground/universal-translator v0.18.1
|
||||
github.com/go-playground/validator/v10 v10.26.0
|
||||
github.com/go-redis/redis/v8 v8.11.4
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/mojocn/base64Captcha v1.3.6
|
||||
github.com/nicksnyder/go-i18n/v2 v2.4.0
|
||||
github.com/sirupsen/logrus v1.8.1
|
||||
github.com/spf13/cobra v1.8.1
|
||||
@@ -24,8 +28,9 @@ require (
|
||||
golang.org/x/oauth2 v0.23.0
|
||||
golang.org/x/text v0.22.0
|
||||
gorm.io/driver/mysql v1.5.7
|
||||
gorm.io/driver/postgres v1.6.0
|
||||
gorm.io/driver/sqlite v1.5.6
|
||||
gorm.io/gorm v1.25.7
|
||||
gorm.io/gorm v1.25.10
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -36,13 +41,12 @@ require (
|
||||
github.com/bytedance/sonic v1.8.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||
github.com/coreos/go-oidc/v3 v3.12.0 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/fsnotify/fsnotify v1.5.1 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.8 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.7 // indirect
|
||||
github.com/go-jose/go-jose/v4 v4.0.2 // indirect
|
||||
github.com/go-ldap/ldap/v3 v3.4.10 // indirect
|
||||
github.com/go-openapi/jsonpointer v0.19.5 // indirect
|
||||
github.com/go-openapi/jsonreference v0.19.6 // indirect
|
||||
github.com/go-openapi/spec v0.20.4 // indirect
|
||||
@@ -52,6 +56,10 @@ require (
|
||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/pgx/v5 v5.6.0 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/josharian/intern v1.0.0 // indirect
|
||||
@@ -65,9 +73,9 @@ require (
|
||||
github.com/mitchellh/mapstructure v1.4.2 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/mojocn/base64Captcha v1.3.6 // indirect
|
||||
github.com/pelletier/go-toml v1.9.4 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.0.6 // indirect
|
||||
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
||||
github.com/spf13/afero v1.6.0 // indirect
|
||||
github.com/spf13/cast v1.4.1 // indirect
|
||||
github.com/spf13/jwalterweatherman v1.1.0 // indirect
|
||||
@@ -79,8 +87,9 @@ require (
|
||||
golang.org/x/crypto v0.33.0 // indirect
|
||||
golang.org/x/image v0.13.0 // indirect
|
||||
golang.org/x/net v0.34.0 // indirect
|
||||
golang.org/x/sync v0.11.0 // indirect
|
||||
golang.org/x/sys v0.30.0 // indirect
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
|
||||
golang.org/x/tools v0.26.0 // indirect
|
||||
google.golang.org/protobuf v1.33.0 // indirect
|
||||
gopkg.in/ini.v1 v1.63.2 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
|
||||
@@ -78,11 +78,13 @@ func (co *Config) AdminConfig(c *gin.Context) {
|
||||
}
|
||||
|
||||
hello := global.Config.Admin.Hello
|
||||
helloFile := global.Config.Admin.HelloFile
|
||||
if helloFile != "" {
|
||||
b, err := os.ReadFile(helloFile)
|
||||
if err == nil && len(b) > 0 {
|
||||
hello = string(b)
|
||||
if hello == "" {
|
||||
helloFile := global.Config.Admin.HelloFile
|
||||
if helloFile != "" {
|
||||
b, err := os.ReadFile(helloFile)
|
||||
if err == nil && len(b) > 0 {
|
||||
hello = string(b)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ func (f *File) Notify(c *gin.Context) {
|
||||
|
||||
res := global.Oss.Verify(c.Request)
|
||||
if !res {
|
||||
response.Fail(c, 101, "权限错误")
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "NoAccess"))
|
||||
return
|
||||
}
|
||||
fm := &FileBack{}
|
||||
|
||||
@@ -2,6 +2,7 @@ package admin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/lejianwen/rustdesk-api/v2/global"
|
||||
"github.com/lejianwen/rustdesk-api/v2/http/controller/api"
|
||||
@@ -57,7 +58,7 @@ func (ct *Login) Login(c *gin.Context) {
|
||||
|
||||
// 检查是否需要验证码
|
||||
if needCaptcha {
|
||||
if f.Captcha == "" || !loginLimiter.VerifyCaptcha(clientIp, f.Captcha) {
|
||||
if f.CaptchaId == "" || f.Captcha == "" || !loginLimiter.VerifyCaptcha(f.CaptchaId, f.Captcha) {
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError"))
|
||||
return
|
||||
}
|
||||
@@ -68,8 +69,6 @@ func (ct *Login) Login(c *gin.Context) {
|
||||
if u.Id == 0 {
|
||||
global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "UsernameOrPasswordError", c.RemoteIP(), clientIp))
|
||||
loginLimiter.RecordFailedAttempt(clientIp)
|
||||
// 移除验证码,重新生成
|
||||
loginLimiter.RemoveCaptcha(clientIp)
|
||||
if _, needCaptcha = loginLimiter.CheckSecurityStatus(clientIp); needCaptcha {
|
||||
response.Fail(c, 110, response.TranslateMsg(c, "UsernameOrPasswordError"))
|
||||
} else {
|
||||
@@ -80,7 +79,6 @@ func (ct *Login) Login(c *gin.Context) {
|
||||
|
||||
if !service.AllService.UserService.CheckUserEnable(u) {
|
||||
if needCaptcha {
|
||||
loginLimiter.RemoveCaptcha(clientIp)
|
||||
response.Fail(c, 110, response.TranslateMsg(c, "UserDisabled"))
|
||||
return
|
||||
}
|
||||
@@ -113,7 +111,7 @@ func (ct *Login) Captcha(c *gin.Context) {
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "NoCaptchaRequired"))
|
||||
return
|
||||
}
|
||||
err, captcha := loginLimiter.RequireCaptcha(clientIp)
|
||||
err, captcha := loginLimiter.RequireCaptcha()
|
||||
if err != nil {
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError")+err.Error())
|
||||
return
|
||||
@@ -125,6 +123,7 @@ func (ct *Login) Captcha(c *gin.Context) {
|
||||
}
|
||||
response.Success(c, gin.H{
|
||||
"captcha": gin.H{
|
||||
"id": captcha.Id,
|
||||
"b64": b64,
|
||||
},
|
||||
})
|
||||
@@ -190,7 +189,7 @@ func (ct *Login) OidcAuth(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err, state, verifier, nonce, url := service.AllService.OauthService.BeginAuth(f.Op)
|
||||
err, state, verifier, nonce, url := service.AllService.OauthService.BeginAuth(c, f.Op)
|
||||
if err != nil {
|
||||
response.Error(c, response.TranslateMsg(c, err.Error()))
|
||||
return
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/lejianwen/rustdesk-api/v2/global"
|
||||
"github.com/lejianwen/rustdesk-api/v2/http/request/admin"
|
||||
adminReq "github.com/lejianwen/rustdesk-api/v2/http/request/admin"
|
||||
"github.com/lejianwen/rustdesk-api/v2/http/response"
|
||||
"github.com/lejianwen/rustdesk-api/v2/service"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type Oauth struct {
|
||||
@@ -43,7 +44,7 @@ func (o *Oauth) ToBind(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err, state, verifier, nonce, url := service.AllService.OauthService.BeginAuth(f.Op)
|
||||
err, state, verifier, nonce, url := service.AllService.OauthService.BeginAuth(c, f.Op)
|
||||
if err != nil {
|
||||
response.Error(c, response.TranslateMsg(c, err.Error()))
|
||||
return
|
||||
@@ -68,16 +69,16 @@ func (o *Oauth) Confirm(c *gin.Context) {
|
||||
j := &adminReq.OauthConfirmForm{}
|
||||
err := c.ShouldBindJSON(j)
|
||||
if err != nil {
|
||||
response.Fail(c, 101, "参数错误"+err.Error())
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")+err.Error())
|
||||
return
|
||||
}
|
||||
if j.Code == "" {
|
||||
response.Fail(c, 101, "参数错误: code 不存在")
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError"))
|
||||
return
|
||||
}
|
||||
v := service.AllService.OauthService.GetOauthCache(j.Code)
|
||||
if v == nil {
|
||||
response.Fail(c, 101, "授权已过期")
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "OauthExpired"))
|
||||
return
|
||||
}
|
||||
u := service.AllService.UserService.CurUser(c)
|
||||
|
||||
@@ -119,7 +119,16 @@ func (r *Rustdesk) SendCmd(c *gin.Context) {
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError"))
|
||||
return
|
||||
}
|
||||
res, err := service.AllService.ServerCmdService.SendCmd(rc.Target, rc.Cmd, rc.Option)
|
||||
|
||||
port := 0
|
||||
switch rc.Target {
|
||||
case model.ServerCmdTargetIdServer:
|
||||
port = global.Config.Admin.IdServerPort - 1
|
||||
case model.ServerCmdTargetRelayServer:
|
||||
port = global.Config.Admin.RelayServerPort
|
||||
}
|
||||
|
||||
res, err := service.AllService.ServerCmdService.SendCmd(port, rc.Cmd, rc.Option)
|
||||
if err != nil {
|
||||
response.Fail(c, 101, err.Error())
|
||||
return
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
adResp "github.com/lejianwen/rustdesk-api/v2/http/response/admin"
|
||||
"github.com/lejianwen/rustdesk-api/v2/model"
|
||||
"github.com/lejianwen/rustdesk-api/v2/service"
|
||||
"github.com/lejianwen/rustdesk-api/v2/utils"
|
||||
"gorm.io/gorm"
|
||||
"strconv"
|
||||
)
|
||||
@@ -243,11 +244,10 @@ func (ct *User) ChangeCurPwd(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
u := service.AllService.UserService.CurUser(c)
|
||||
// If the password is not empty, the old password is verified
|
||||
// otherwise, the old password is not verified
|
||||
// Verify the old password only when the account already has one set
|
||||
if !service.AllService.UserService.IsPasswordEmptyByUser(u) {
|
||||
oldPwd := service.AllService.UserService.EncryptPassword(f.OldPassword)
|
||||
if u.Password != oldPwd {
|
||||
ok, _, err := utils.VerifyPassword(u.Password, f.OldPassword)
|
||||
if err != nil || !ok {
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "OldPasswordError"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/lejianwen/rustdesk-api/v2/global"
|
||||
"github.com/lejianwen/rustdesk-api/v2/http/request/api"
|
||||
@@ -10,7 +12,6 @@ import (
|
||||
"github.com/lejianwen/rustdesk-api/v2/service"
|
||||
"github.com/lejianwen/rustdesk-api/v2/utils"
|
||||
"github.com/nicksnyder/go-i18n/v2/i18n"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Oauth struct {
|
||||
@@ -35,7 +36,7 @@ func (o *Oauth) OidcAuth(c *gin.Context) {
|
||||
|
||||
oauthService := service.AllService.OauthService
|
||||
|
||||
err, state, verifier, nonce, url := oauthService.BeginAuth(f.Op)
|
||||
err, state, verifier, nonce, url := oauthService.BeginAuth(c, f.Op)
|
||||
if err != nil {
|
||||
response.Error(c, response.TranslateMsg(c, err.Error()))
|
||||
return
|
||||
@@ -169,7 +170,7 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
|
||||
var user *model.User
|
||||
// 获取用户信息
|
||||
code := c.Query("code")
|
||||
err, oauthUser := oauthService.Callback(code, verifier, op, nonce)
|
||||
err, oauthUser := oauthService.Callback(c, code, verifier, op, nonce)
|
||||
if err != nil {
|
||||
c.HTML(http.StatusOK, "oauth_fail.html", gin.H{
|
||||
"message": "OauthFailed",
|
||||
@@ -225,8 +226,7 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
|
||||
if !*oauthConfig.AutoRegister {
|
||||
//c.String(http.StatusInternalServerError, "还未绑定用户,请先绑定")
|
||||
oauthCache.UpdateFromOauthUser(oauthUser)
|
||||
url := global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/bind/" + cacheKey
|
||||
c.Redirect(http.StatusFound, url)
|
||||
c.Redirect(http.StatusFound, "/_admin/#/oauth/bind/"+cacheKey)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -251,8 +251,7 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
|
||||
Type: model.LoginLogTypeOauth,
|
||||
Platform: oauthService.DeviceOs,
|
||||
})*/
|
||||
url := global.Config.Rustdesk.ApiServer + "/_admin/#/"
|
||||
c.Redirect(http.StatusFound, url)
|
||||
c.Redirect(http.StatusFound, "/_admin/#/")
|
||||
return
|
||||
}
|
||||
c.HTML(http.StatusOK, "oauth_success.html", gin.H{
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gin-gonic/gin/binding"
|
||||
requstform "github.com/lejianwen/rustdesk-api/v2/http/request/api"
|
||||
@@ -13,7 +14,7 @@ type Peer struct {
|
||||
}
|
||||
|
||||
// SysInfo
|
||||
// @Tags 地址
|
||||
// @Tags System
|
||||
// @Summary 提交系统信息
|
||||
// @Description 提交系统信息
|
||||
// @Accept json
|
||||
@@ -57,8 +58,19 @@ func (p *Peer) SysInfo(c *gin.Context) {
|
||||
c.String(http.StatusOK, "SYSINFO_UPDATED")
|
||||
}
|
||||
|
||||
// SysInfoVer
|
||||
// @Tags System
|
||||
// @Summary 获取系统版本信息
|
||||
// @Description 获取系统版本信息
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200 {string} string ""
|
||||
// @Failure 500 {object} response.ErrorResponse
|
||||
// @Router /sysinfo_ver [post]
|
||||
func (p *Peer) SysInfoVer(c *gin.Context) {
|
||||
//读取resources/version文件
|
||||
v := service.AllService.AppService.GetAppVersion()
|
||||
// 加上启动时间,方便client上传信息
|
||||
v = fmt.Sprintf("%s\n%s", v, service.AllService.AppService.GetStartTime())
|
||||
c.String(http.StatusOK, v)
|
||||
}
|
||||
|
||||
@@ -13,13 +13,13 @@ func BackendUserAuth() gin.HandlerFunc {
|
||||
//测试先关闭
|
||||
token := c.GetHeader("api-token")
|
||||
if token == "" {
|
||||
response.Fail(c, 403, "请先登录")
|
||||
response.Fail(c, 403, response.TranslateMsg(c, "NeedLogin"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
user, ut := service.AllService.UserService.InfoByAccessToken(token)
|
||||
if user.Id == 0 {
|
||||
response.Fail(c, 403, "请先登录")
|
||||
response.Fail(c, 403, response.TranslateMsg(c, "NeedLogin"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ func AdminPrivilege() gin.HandlerFunc {
|
||||
u := service.AllService.UserService.CurUser(c)
|
||||
|
||||
if !service.AllService.UserService.IsAdmin(u) {
|
||||
response.Fail(c, 403, "无权限")
|
||||
response.Fail(c, 403, response.TranslateMsg(c, "NoAccess"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -12,18 +12,18 @@ func JwtAuth() gin.HandlerFunc {
|
||||
//测试先关闭
|
||||
token := c.GetHeader("api-token")
|
||||
if token == "" {
|
||||
response.Fail(c, 403, "请先登录")
|
||||
response.Fail(c, 403, response.TranslateMsg(c, "NeedLogin"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
uid, err := global.Jwt.ParseToken(token)
|
||||
if err != nil {
|
||||
response.Fail(c, 403, "请先登录")
|
||||
response.Fail(c, 403, response.TranslateMsg(c, "NeedLogin"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if uid == 0 {
|
||||
response.Fail(c, 403, "请先登录")
|
||||
response.Fail(c, 403, response.TranslateMsg(c, "NeedLogin"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
@@ -34,12 +34,12 @@ func JwtAuth() gin.HandlerFunc {
|
||||
// Username: "测试用户",
|
||||
//}
|
||||
if user.Id == 0 {
|
||||
response.Fail(c, 403, "请先登录")
|
||||
response.Fail(c, 403, response.TranslateMsg(c, "NeedLogin"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if !service.AllService.UserService.CheckUserEnable(user) {
|
||||
response.Fail(c, 101, "你已被禁用")
|
||||
response.Fail(c, 101, response.TranslateMsg(c, "Banned"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
package admin
|
||||
|
||||
type Login struct {
|
||||
Username string `json:"username" validate:"required" label:"用户名"`
|
||||
Password string `json:"password,omitempty" validate:"required" label:"密码"`
|
||||
Platform string `json:"platform" label:"平台"`
|
||||
Captcha string `json:"captcha,omitempty" label:"验证码"`
|
||||
Username string `json:"username" validate:"required" label:"用户名"`
|
||||
Password string `json:"password,omitempty" validate:"required" label:"密码"`
|
||||
Platform string `json:"platform" label:"平台"`
|
||||
Captcha string `json:"captcha,omitempty" label:"验证码"`
|
||||
CaptchaId string `json:"captcha_id,omitempty"`
|
||||
}
|
||||
|
||||
type LoginLogQuery struct {
|
||||
|
||||
@@ -22,7 +22,6 @@ type OauthForm struct {
|
||||
Scopes string `json:"scopes" validate:"omitempty"`
|
||||
ClientId string `json:"client_id" validate:"required"`
|
||||
ClientSecret string `json:"client_secret" validate:"required"`
|
||||
RedirectUrl string `json:"redirect_url" validate:"required"`
|
||||
AutoRegister *bool `json:"auto_register"`
|
||||
PkceEnable *bool `json:"pkce_enable"`
|
||||
PkceMethod string `json:"pkce_method"`
|
||||
@@ -34,7 +33,6 @@ func (of *OauthForm) ToOauth() *model.Oauth {
|
||||
OauthType: of.OauthType,
|
||||
ClientId: of.ClientId,
|
||||
ClientSecret: of.ClientSecret,
|
||||
RedirectUrl: of.RedirectUrl,
|
||||
AutoRegister: of.AutoRegister,
|
||||
Issuer: of.Issuer,
|
||||
Scopes: of.Scopes,
|
||||
|
||||
@@ -14,6 +14,7 @@ type UserForm struct {
|
||||
GroupId uint `json:"group_id" validate:"required"`
|
||||
IsAdmin *bool `json:"is_admin" `
|
||||
Status model.StatusCode `json:"status" validate:"required,gte=0"`
|
||||
Remark string `json:"remark"`
|
||||
}
|
||||
|
||||
func (uf *UserForm) FromUser(user *model.User) *UserForm {
|
||||
@@ -25,6 +26,7 @@ func (uf *UserForm) FromUser(user *model.User) *UserForm {
|
||||
uf.GroupId = user.GroupId
|
||||
uf.IsAdmin = user.IsAdmin
|
||||
uf.Status = user.Status
|
||||
uf.Remark = user.Remark
|
||||
return uf
|
||||
}
|
||||
func (uf *UserForm) ToUser() *model.User {
|
||||
@@ -37,6 +39,7 @@ func (uf *UserForm) ToUser() *model.User {
|
||||
user.GroupId = uf.GroupId
|
||||
user.IsAdmin = uf.IsAdmin
|
||||
user.Status = uf.Status
|
||||
user.Remark = uf.Remark
|
||||
return user
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/lejianwen/rustdesk-api/v2/global"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
@@ -10,14 +9,14 @@ import (
|
||||
)
|
||||
|
||||
type MysqlConfig struct {
|
||||
Dns string
|
||||
Dsn string
|
||||
MaxIdleConns int
|
||||
MaxOpenConns int
|
||||
}
|
||||
|
||||
func NewMysql(mysqlConf *MysqlConfig) *gorm.DB {
|
||||
func NewMysql(mysqlConf *MysqlConfig, logwriter logger.Writer) *gorm.DB {
|
||||
db, err := gorm.Open(mysql.New(mysql.Config{
|
||||
DSN: mysqlConf.Dns, // DSN data source name
|
||||
DSN: mysqlConf.Dsn, // DSN data source name
|
||||
DefaultStringSize: 256, // string 类型字段的默认长度
|
||||
//DisableDatetimePrecision: true, // 禁用 datetime 精度,MySQL 5.6 之前的数据库不支持
|
||||
//DontSupportRenameIndex: true, // 重命名索引时采用删除并新建的方式,MySQL 5.7 之前的数据库和 MariaDB 不支持重命名索引
|
||||
@@ -26,7 +25,7 @@ func NewMysql(mysqlConf *MysqlConfig) *gorm.DB {
|
||||
}), &gorm.Config{
|
||||
DisableForeignKeyConstraintWhenMigrating: true,
|
||||
Logger: logger.New(
|
||||
global.Logger, // io writer
|
||||
logwriter, // io writer
|
||||
logger.Config{
|
||||
SlowThreshold: time.Second, // Slow SQL threshold
|
||||
LogLevel: logger.Warn, // Log level
|
||||
|
||||
45
lib/orm/postgresql.go
Normal file
45
lib/orm/postgresql.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
"time"
|
||||
)
|
||||
|
||||
type PostgresqlConfig struct {
|
||||
Dsn string
|
||||
MaxIdleConns int
|
||||
MaxOpenConns int
|
||||
}
|
||||
|
||||
func NewPostgresql(conf *PostgresqlConfig, logwriter logger.Writer) *gorm.DB {
|
||||
db, err := gorm.Open(postgres.Open(conf.Dsn), &gorm.Config{
|
||||
DisableForeignKeyConstraintWhenMigrating: true,
|
||||
Logger: logger.New(
|
||||
logwriter, // io writer
|
||||
logger.Config{
|
||||
SlowThreshold: time.Second, // Slow SQL threshold
|
||||
LogLevel: logger.Warn, // Log level
|
||||
//IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger
|
||||
ParameterizedQueries: true, // Don't include params in the SQL log
|
||||
Colorful: true,
|
||||
},
|
||||
),
|
||||
})
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
sqlDB, err2 := db.DB()
|
||||
if err2 != nil {
|
||||
fmt.Println(err2)
|
||||
}
|
||||
// SetMaxIdleConns 设置空闲连接池中连接的最大数量
|
||||
sqlDB.SetMaxIdleConns(conf.MaxIdleConns)
|
||||
|
||||
// SetMaxOpenConns 设置打开数据库连接的最大数量。
|
||||
sqlDB.SetMaxOpenConns(conf.MaxOpenConns)
|
||||
|
||||
return db
|
||||
}
|
||||
@@ -2,7 +2,6 @@ package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/lejianwen/rustdesk-api/v2/global"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
@@ -14,11 +13,11 @@ type SqliteConfig struct {
|
||||
MaxOpenConns int
|
||||
}
|
||||
|
||||
func NewSqlite(sqliteConf *SqliteConfig) *gorm.DB {
|
||||
func NewSqlite(sqliteConf *SqliteConfig, logwriter logger.Writer) *gorm.DB {
|
||||
db, err := gorm.Open(sqlite.Open("./data/rustdeskapi.db"), &gorm.Config{
|
||||
DisableForeignKeyConstraintWhenMigrating: true,
|
||||
Logger: logger.New(
|
||||
global.Logger, // io writer
|
||||
logwriter, // io writer
|
||||
logger.Config{
|
||||
SlowThreshold: time.Second, // Slow SQL threshold
|
||||
LogLevel: logger.Warn, // Log level
|
||||
|
||||
@@ -14,6 +14,7 @@ const (
|
||||
OauthTypeGoogle string = "google"
|
||||
OauthTypeOidc string = "oidc"
|
||||
OauthTypeWebauth string = "webauth"
|
||||
OauthTypeLinuxdo string = "linuxdo"
|
||||
PKCEMethodS256 string = "S256"
|
||||
PKCEMethodPlain string = "plain"
|
||||
)
|
||||
@@ -21,7 +22,7 @@ const (
|
||||
// Validate the oauth type
|
||||
func ValidateOauthType(oauthType string) error {
|
||||
switch oauthType {
|
||||
case OauthTypeGithub, OauthTypeGoogle, OauthTypeOidc, OauthTypeWebauth:
|
||||
case OauthTypeGithub, OauthTypeGoogle, OauthTypeOidc, OauthTypeWebauth, OauthTypeLinuxdo:
|
||||
return nil
|
||||
default:
|
||||
return errors.New("invalid Oauth type")
|
||||
@@ -29,8 +30,9 @@ func ValidateOauthType(oauthType string) error {
|
||||
}
|
||||
|
||||
const (
|
||||
UserEndpointGithub string = "https://api.github.com/user"
|
||||
IssuerGoogle string = "https://accounts.google.com"
|
||||
UserEndpointGithub string = "https://api.github.com/user"
|
||||
UserEndpointLinuxdo string = "https://connect.linux.do/api/user"
|
||||
IssuerGoogle string = "https://accounts.google.com"
|
||||
)
|
||||
|
||||
type Oauth struct {
|
||||
@@ -39,12 +41,11 @@ type Oauth struct {
|
||||
OauthType string `json:"oauth_type"`
|
||||
ClientId string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
RedirectUrl string `json:"redirect_url"`
|
||||
AutoRegister *bool `json:"auto_register"`
|
||||
Scopes string `json:"scopes"`
|
||||
Issuer string `json:"issuer"`
|
||||
PkceEnable *bool `json:"pkce_enable"`
|
||||
PkceMethod string `json:"pkce_method"`
|
||||
PkceEnable *bool `json:"pkce_enable"`
|
||||
PkceMethod string `json:"pkce_method"`
|
||||
TimeModel
|
||||
}
|
||||
|
||||
@@ -60,6 +61,8 @@ func (oa *Oauth) FormatOauthInfo() error {
|
||||
oa.Op = OauthTypeGithub
|
||||
case OauthTypeGoogle:
|
||||
oa.Op = OauthTypeGoogle
|
||||
case OauthTypeLinuxdo:
|
||||
oa.Op = OauthTypeLinuxdo
|
||||
}
|
||||
// check if the op is empty, set the default value
|
||||
op := strings.TrimSpace(oa.Op)
|
||||
@@ -152,6 +155,24 @@ func (gu *GithubUser) ToOauthUser() *OauthUser {
|
||||
}
|
||||
}
|
||||
|
||||
type LinuxdoUser struct {
|
||||
OauthUserBase
|
||||
Id int `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Avatar string `json:"avatar_url"`
|
||||
}
|
||||
|
||||
func (lu *LinuxdoUser) ToOauthUser() *OauthUser {
|
||||
return &OauthUser{
|
||||
OpenId: strconv.Itoa(lu.Id),
|
||||
Name: lu.Name,
|
||||
Username: strings.ToLower(lu.Username),
|
||||
Email: lu.Email,
|
||||
VerifiedEmail: true, // linux.do 用户邮箱默认已验证
|
||||
Picture: lu.Avatar,
|
||||
}
|
||||
}
|
||||
|
||||
type OauthList struct {
|
||||
Oauths []*Oauth `json:"list"`
|
||||
Pagination
|
||||
|
||||
@@ -11,6 +11,7 @@ type User struct {
|
||||
GroupId uint `json:"group_id" gorm:"default:0;not null;index"`
|
||||
IsAdmin *bool `json:"is_admin" gorm:"default:0;not null;"`
|
||||
Status StatusCode `json:"status" gorm:"default:1;not null;"`
|
||||
Remark string `json:"remark" gorm:"default:'';not null;"`
|
||||
TimeModel
|
||||
}
|
||||
|
||||
|
||||
@@ -33,6 +33,11 @@ description = "No access."
|
||||
one = "No access."
|
||||
other = "No access."
|
||||
|
||||
[NeedLogin]
|
||||
description = "Need login."
|
||||
one = "Please log in first."
|
||||
other = "Please log in first."
|
||||
|
||||
[UsernameOrPasswordError]
|
||||
description = "Username or password error."
|
||||
one = "Username or password error."
|
||||
|
||||
@@ -33,6 +33,11 @@ description = "No access."
|
||||
one = "Sin acceso."
|
||||
other = "Sin acceso."
|
||||
|
||||
[NeedLogin]
|
||||
description = "Need login."
|
||||
one = "Por favor inicie sesión primero."
|
||||
other = "Por favor inicie sesión primero."
|
||||
|
||||
[UsernameOrPasswordError]
|
||||
description = "Username or password error."
|
||||
one = "Error de usuario o contraseña."
|
||||
|
||||
@@ -33,6 +33,11 @@ description = "No access."
|
||||
one = "Aucun d'access."
|
||||
other = "Aucun d'access."
|
||||
|
||||
[NeedLogin]
|
||||
description = "Need login."
|
||||
one = "Veuillez d'abord vous connecter."
|
||||
other = "Veuillez d'abord vous connecter."
|
||||
|
||||
[UsernameOrPasswordError]
|
||||
description = "Username or password error."
|
||||
one = "Nom d'utilisateur ou de mot de passe incorrect."
|
||||
@@ -161,4 +166,4 @@ other = "Banni."
|
||||
[RegisterSuccessWaitAdminConfirm]
|
||||
description = "Register success wait admin confirm."
|
||||
one = "Inscription réussie, veuillez attendre la confirmation de l'administrateur."
|
||||
other = "Inscription réussie, veuillez attendre la confirmation de l'administrateur."
|
||||
other = "Inscription réussie, veuillez attendre la confirmation de l'administrateur."
|
||||
|
||||
@@ -33,6 +33,11 @@ description = "No access."
|
||||
one = "접근할 수 없습니다."
|
||||
other = "접근할 수 없습니다."
|
||||
|
||||
[NeedLogin]
|
||||
description = "Need login."
|
||||
one = "먼저 로그인해주세요."
|
||||
other = "먼저 로그인해주세요."
|
||||
|
||||
[UsernameOrPasswordError]
|
||||
description = "Username or password error."
|
||||
one = "사용자 이름이나 비밀번호가 올바르지 않습니다."
|
||||
|
||||
@@ -33,6 +33,11 @@ description = "No access."
|
||||
one = "Нет доступа."
|
||||
other = "Нет доступа."
|
||||
|
||||
[NeedLogin]
|
||||
description = "Need login."
|
||||
one = "Пожалуйста, войдите в систему."
|
||||
other = "Пожалуйста, войдите в систему."
|
||||
|
||||
[UsernameOrPasswordError]
|
||||
description = "Username or password error."
|
||||
one = "Неправильное имя пользователя или пароль."
|
||||
@@ -161,4 +166,4 @@ other = "Заблокировано."
|
||||
[RegisterSuccessWaitAdminConfirm]
|
||||
description = "Register success wait admin confirm."
|
||||
one = "Регистрация прошла успешно, ожидайте подтверждения администратора."
|
||||
other = "Регистрация прошла успешно, ожидайте подтверждения администратора."
|
||||
other = "Регистрация прошла успешно, ожидайте подтверждения администратора."
|
||||
|
||||
@@ -33,6 +33,11 @@ description = "No access."
|
||||
one = "无权限。"
|
||||
other = "无权限。"
|
||||
|
||||
[NeedLogin]
|
||||
description = "Need login."
|
||||
one = "请先登录。"
|
||||
other = "请先登录。"
|
||||
|
||||
[UsernameOrPasswordError]
|
||||
description = "Username or password error."
|
||||
one = "用户名或密码错误。"
|
||||
|
||||
@@ -33,6 +33,11 @@ description = "No access."
|
||||
one = "無許可權。"
|
||||
other = "無許可權。"
|
||||
|
||||
[NeedLogin]
|
||||
description = "Need login."
|
||||
one = "請先登入。"
|
||||
other = "請先登入。"
|
||||
|
||||
[UsernameOrPasswordError]
|
||||
description = "Username or password error."
|
||||
one = "使用者名稱或密碼錯誤。"
|
||||
|
||||
2
resources/web2/js/dist/index.js
vendored
2
resources/web2/js/dist/index.js
vendored
@@ -11550,7 +11550,7 @@ async function or(u) {
|
||||
let E = [], l = [];
|
||||
for (let d = 0; d < e.length; d++) {
|
||||
const c = 1 << 7 - d % 8;
|
||||
(s[d / 8] & c) === c ? E.push(e[d]) : l.push(e[d])
|
||||
(s[Math.floor(d / 8)] & c) === c ? E.push(e[d]) : l.push(e[d])
|
||||
}
|
||||
_t(E, l), n.close();
|
||||
return
|
||||
|
||||
@@ -3,13 +3,14 @@ package service
|
||||
import (
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type AppService struct {
|
||||
}
|
||||
|
||||
var version = ""
|
||||
|
||||
var startTime = ""
|
||||
var once = &sync.Once{}
|
||||
|
||||
func (a *AppService) GetAppVersion() string {
|
||||
@@ -26,3 +27,13 @@ func (a *AppService) GetAppVersion() string {
|
||||
})
|
||||
return version
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Initialize the AppService if needed
|
||||
startTime = time.Now().Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
// GetStartTime
|
||||
func (a *AppService) GetStartTime() string {
|
||||
return startTime
|
||||
}
|
||||
|
||||
@@ -4,11 +4,14 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/lejianwen/rustdesk-api/v2/model"
|
||||
"github.com/lejianwen/rustdesk-api/v2/utils"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/github"
|
||||
|
||||
// "golang.org/x/oauth2/google"
|
||||
"gorm.io/gorm"
|
||||
// "io"
|
||||
@@ -93,16 +96,20 @@ func (os *OauthService) DeleteOauthCache(key string) {
|
||||
OauthCache.Delete(key)
|
||||
}
|
||||
|
||||
func (os *OauthService) BeginAuth(op string) (error error, state, verifier, nonce, url string) {
|
||||
func (os *OauthService) BeginAuth(c *gin.Context, op string) (error error, state, verifier, nonce, url string) {
|
||||
state = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10)
|
||||
verifier = ""
|
||||
nonce = ""
|
||||
if op == model.OauthTypeWebauth {
|
||||
url = Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + state
|
||||
host := c.GetHeader("Origin")
|
||||
if host == "" {
|
||||
host = Config.Rustdesk.ApiServer
|
||||
}
|
||||
url = host + "/_admin/#/oauth/" + state
|
||||
//url = "http://localhost:8888/_admin/#/oauth/" + code
|
||||
return nil, state, verifier, nonce, url
|
||||
}
|
||||
err, oauthInfo, oauthConfig, _ := os.GetOauthConfig(op)
|
||||
err, oauthInfo, oauthConfig, _ := os.GetOauthConfig(c, op)
|
||||
if err == nil {
|
||||
extras := make([]oauth2.AuthCodeOption, 0, 3)
|
||||
|
||||
@@ -154,21 +161,33 @@ func (os *OauthService) GithubProvider() *oidc.Provider {
|
||||
}).NewProvider(context.Background())
|
||||
}
|
||||
|
||||
func (os *OauthService) LinuxdoProvider() *oidc.Provider {
|
||||
return (&oidc.ProviderConfig{
|
||||
IssuerURL: "",
|
||||
AuthURL: "https://connect.linux.do/oauth2/authorize",
|
||||
TokenURL: "https://connect.linux.do/oauth2/token",
|
||||
DeviceAuthURL: "",
|
||||
UserInfoURL: model.UserEndpointLinuxdo,
|
||||
JWKSURL: "",
|
||||
Algorithms: nil,
|
||||
}).NewProvider(context.Background())
|
||||
}
|
||||
|
||||
// GetOauthConfig retrieves the OAuth2 configuration based on the provider name
|
||||
func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config, provider *oidc.Provider) {
|
||||
func (os *OauthService) GetOauthConfig(c *gin.Context, op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config, provider *oidc.Provider) {
|
||||
//err, oauthInfo, oauthConfig = os.getOauthConfigGeneral(op)
|
||||
oauthInfo = os.InfoByOp(op)
|
||||
if oauthInfo.Id == 0 || oauthInfo.ClientId == "" || oauthInfo.ClientSecret == "" {
|
||||
return errors.New("ConfigNotFound"), nil, nil, nil
|
||||
}
|
||||
// If the redirect URL is empty, use the default redirect URL
|
||||
if oauthInfo.RedirectUrl == "" {
|
||||
oauthInfo.RedirectUrl = Config.Rustdesk.ApiServer + "/api/oidc/callback"
|
||||
host := c.GetHeader("Origin")
|
||||
if host == "" {
|
||||
host = Config.Rustdesk.ApiServer
|
||||
}
|
||||
oauthConfig = &oauth2.Config{
|
||||
ClientID: oauthInfo.ClientId,
|
||||
ClientSecret: oauthInfo.ClientSecret,
|
||||
RedirectURL: oauthInfo.RedirectUrl,
|
||||
RedirectURL: host + "/api/oidc/callback",
|
||||
}
|
||||
|
||||
// Maybe should validate the oauthConfig here
|
||||
@@ -182,6 +201,10 @@ func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.O
|
||||
oauthConfig.Endpoint = github.Endpoint
|
||||
oauthConfig.Scopes = []string{"read:user", "user:email"}
|
||||
provider = os.GithubProvider()
|
||||
case model.OauthTypeLinuxdo:
|
||||
provider = os.LinuxdoProvider()
|
||||
oauthConfig.Endpoint = provider.Endpoint()
|
||||
oauthConfig.Scopes = []string{"profile"}
|
||||
//case model.OauthTypeGoogle: //google单独出来,可以少一次FetchOidcEndpoint请求
|
||||
// oauthConfig.Endpoint = google.Endpoint
|
||||
// oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes)
|
||||
@@ -299,6 +322,16 @@ func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, provider *oid
|
||||
return nil, user.ToOauthUser()
|
||||
}
|
||||
|
||||
// linuxdoCallback linux.do回调
|
||||
func (os *OauthService) linuxdoCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code, verifier, nonce string) (error, *model.OauthUser) {
|
||||
var user = &model.LinuxdoUser{}
|
||||
err, _ := os.callbackBase(oauthConfig, provider, code, verifier, nonce, user)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
return nil, user.ToOauthUser()
|
||||
}
|
||||
|
||||
// oidcCallback oidc回调, 通过code获取用户信息
|
||||
func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code, verifier, nonce string) (error, *model.OauthUser) {
|
||||
var user = &model.OidcUser{}
|
||||
@@ -309,8 +342,8 @@ func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, provider *oidc.
|
||||
}
|
||||
|
||||
// Callback: Get user information by code and op(Oauth provider)
|
||||
func (os *OauthService) Callback(code, verifier, op, nonce string) (err error, oauthUser *model.OauthUser) {
|
||||
err, oauthInfo, oauthConfig, provider := os.GetOauthConfig(op)
|
||||
func (os *OauthService) Callback(c *gin.Context, code, verifier, op, nonce string) (err error, oauthUser *model.OauthUser) {
|
||||
err, oauthInfo, oauthConfig, provider := os.GetOauthConfig(c, op)
|
||||
// oauthType is already validated in GetOauthConfig
|
||||
if err != nil {
|
||||
return err, nil
|
||||
@@ -319,6 +352,8 @@ func (os *OauthService) Callback(code, verifier, op, nonce string) (err error, o
|
||||
switch oauthType {
|
||||
case model.OauthTypeGithub:
|
||||
err, oauthUser = os.githubCallback(oauthConfig, provider, code, verifier, nonce)
|
||||
case model.OauthTypeLinuxdo:
|
||||
err, oauthUser = os.linuxdoCallback(oauthConfig, provider, code, verifier, nonce)
|
||||
case model.OauthTypeOidc, model.OauthTypeGoogle:
|
||||
err, oauthUser = os.oidcCallback(oauthConfig, provider, code, verifier, nonce)
|
||||
default:
|
||||
|
||||
@@ -40,14 +40,7 @@ func (is *ServerCmdService) Create(u *model.ServerCmd) error {
|
||||
}
|
||||
|
||||
// SendCmd 发送命令
|
||||
func (is *ServerCmdService) SendCmd(target string, cmd string, arg string) (string, error) {
|
||||
port := 0
|
||||
switch target {
|
||||
case model.ServerCmdTargetIdServer:
|
||||
port = Config.Rustdesk.IdServerPort - 1
|
||||
case model.ServerCmdTargetRelayServer:
|
||||
port = Config.Rustdesk.RelayServerPort
|
||||
}
|
||||
func (is *ServerCmdService) SendCmd(port int, cmd string, arg string) (string, error) {
|
||||
//组装命令
|
||||
cmd = cmd + " " + arg
|
||||
res, err := is.SendSocketCmd("v6", port, cmd)
|
||||
|
||||
@@ -2,14 +2,14 @@ package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/lejianwen/rustdesk-api/v2/model"
|
||||
"github.com/lejianwen/rustdesk-api/v2/utils"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/lejianwen/rustdesk-api/v2/model"
|
||||
"github.com/lejianwen/rustdesk-api/v2/utils"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -55,7 +55,18 @@ func (us *UserService) InfoByUsernamePassword(username, password string) *model.
|
||||
Logger.Warn("Fallback to local database")
|
||||
}
|
||||
u := &model.User{}
|
||||
DB.Where("username = ? and password = ?", username, us.EncryptPassword(password)).First(u)
|
||||
DB.Where("username = ?", username).First(u)
|
||||
if u.Id == 0 {
|
||||
return u
|
||||
}
|
||||
ok, newHash, err := utils.VerifyPassword(u.Password, password)
|
||||
if err != nil || !ok {
|
||||
return &model.User{}
|
||||
}
|
||||
if newHash != "" {
|
||||
DB.Model(u).Update("password", newHash)
|
||||
u.Password = newHash
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
@@ -151,11 +162,6 @@ func (us *UserService) ListIdAndNameByGroupId(groupId uint) (res []*model.User)
|
||||
return res
|
||||
}
|
||||
|
||||
// EncryptPassword 加密密码
|
||||
func (us *UserService) EncryptPassword(password string) string {
|
||||
return utils.Md5(password + "rustdesk-api")
|
||||
}
|
||||
|
||||
// CheckUserEnable 判断用户是否禁用
|
||||
func (us *UserService) CheckUserEnable(u *model.User) bool {
|
||||
return u.Status == model.COMMON_STATUS_ENABLE
|
||||
@@ -168,7 +174,11 @@ func (us *UserService) Create(u *model.User) error {
|
||||
return errors.New("UsernameExists")
|
||||
}
|
||||
u.Username = us.formatUsername(u.Username)
|
||||
u.Password = us.EncryptPassword(u.Password)
|
||||
var err error
|
||||
u.Password, err = utils.EncryptPassword(u.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res := DB.Create(u).Error
|
||||
return res
|
||||
}
|
||||
@@ -268,8 +278,12 @@ func (us *UserService) FlushTokenByUuids(uuids []string) error {
|
||||
|
||||
// UpdatePassword 更新密码
|
||||
func (us *UserService) UpdatePassword(u *model.User, password string) error {
|
||||
u.Password = us.EncryptPassword(password)
|
||||
err := DB.Model(u).Update("password", u.Password).Error
|
||||
var err error
|
||||
u.Password, err = utils.EncryptPassword(password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = DB.Model(u).Update("password", u.Password).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -486,8 +500,9 @@ func (us *UserService) RefreshAccessToken(ut *model.UserToken) {
|
||||
ut.ExpiredAt = us.UserTokenExpireTimestamp()
|
||||
DB.Model(ut).Update("expired_at", ut.ExpiredAt)
|
||||
}
|
||||
|
||||
func (us *UserService) AutoRefreshAccessToken(ut *model.UserToken) {
|
||||
if ut.ExpiredAt-time.Now().Unix() < 86400 {
|
||||
if ut.ExpiredAt-time.Now().Unix() < Config.App.TokenExpire.Milliseconds()/3000 {
|
||||
us.RefreshAccessToken(ut)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,15 +5,15 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
var capdString = base64Captcha.NewDriverString(50, 150, 5, 10, 4, "123456789abcdefghijklmnopqrstuvwxyz", nil, nil, nil)
|
||||
var capdString = base64Captcha.NewDriverString(50, 150, 0, 5, 4, "123456789abcdefghijklmnopqrstuvwxyz", nil, nil, nil)
|
||||
|
||||
var capdMath = base64Captcha.NewDriverMath(50, 150, 5, 10, nil, nil, nil)
|
||||
var capdMath = base64Captcha.NewDriverMath(50, 150, 3, 10, nil, nil, nil)
|
||||
|
||||
type B64StringCaptchaProvider struct{}
|
||||
|
||||
func (p B64StringCaptchaProvider) Generate(ip string) (string, string, error) {
|
||||
_, content, answer := capdString.GenerateIdQuestionAnswer()
|
||||
return content, answer, nil
|
||||
func (p B64StringCaptchaProvider) Generate() (string, string, string, error) {
|
||||
id, content, answer := capdString.GenerateIdQuestionAnswer()
|
||||
return id, content, answer, nil
|
||||
}
|
||||
|
||||
func (p B64StringCaptchaProvider) Expiration() time.Duration {
|
||||
@@ -30,9 +30,9 @@ func (p B64StringCaptchaProvider) Draw(content string) (string, error) {
|
||||
|
||||
type B64MathCaptchaProvider struct{}
|
||||
|
||||
func (p B64MathCaptchaProvider) Generate(ip string) (string, string, error) {
|
||||
_, content, answer := capdMath.GenerateIdQuestionAnswer()
|
||||
return content, answer, nil
|
||||
func (p B64MathCaptchaProvider) Generate() (string, string, string, error) {
|
||||
id, content, answer := capdMath.GenerateIdQuestionAnswer()
|
||||
return id, content, answer, nil
|
||||
}
|
||||
|
||||
func (p B64MathCaptchaProvider) Expiration() time.Duration {
|
||||
|
||||
@@ -16,7 +16,7 @@ type SecurityPolicy struct {
|
||||
|
||||
// 验证码提供者接口
|
||||
type CaptchaProvider interface {
|
||||
Generate(ip string) (string, string, error)
|
||||
Generate() (id string, content string, answer string, err error)
|
||||
//Validate(ip, code string) bool
|
||||
Expiration() time.Duration // 验证码过期时间, 应该小于 AttemptsWindow
|
||||
Draw(content string) (string, error) // 绘制验证码
|
||||
@@ -24,6 +24,7 @@ type CaptchaProvider interface {
|
||||
|
||||
// 验证码元数据
|
||||
type CaptchaMeta struct {
|
||||
Id string
|
||||
Content string
|
||||
Answer string
|
||||
ExpiresAt time.Time
|
||||
@@ -117,7 +118,7 @@ func (ll *LoginLimiter) RecordFailedAttempt(ip string) {
|
||||
}
|
||||
|
||||
// 生成验证码
|
||||
func (ll *LoginLimiter) RequireCaptcha(ip string) (error, CaptchaMeta) {
|
||||
func (ll *LoginLimiter) RequireCaptcha() (error, CaptchaMeta) {
|
||||
ll.mu.Lock()
|
||||
defer ll.mu.Unlock()
|
||||
|
||||
@@ -125,23 +126,24 @@ func (ll *LoginLimiter) RequireCaptcha(ip string) (error, CaptchaMeta) {
|
||||
return errors.New("no captcha provider available"), CaptchaMeta{}
|
||||
}
|
||||
|
||||
content, answer, err := ll.provider.Generate(ip)
|
||||
id, content, answer, err := ll.provider.Generate()
|
||||
if err != nil {
|
||||
return err, CaptchaMeta{}
|
||||
}
|
||||
|
||||
// 存储验证码
|
||||
ll.captchas[ip] = CaptchaMeta{
|
||||
ll.captchas[id] = CaptchaMeta{
|
||||
Id: id,
|
||||
Content: content,
|
||||
Answer: answer,
|
||||
ExpiresAt: time.Now().Add(ll.provider.Expiration()),
|
||||
}
|
||||
|
||||
return nil, ll.captchas[ip]
|
||||
return nil, ll.captchas[id]
|
||||
}
|
||||
|
||||
// 验证验证码
|
||||
func (ll *LoginLimiter) VerifyCaptcha(ip, answer string) bool {
|
||||
func (ll *LoginLimiter) VerifyCaptcha(id, answer string) bool {
|
||||
ll.mu.Lock()
|
||||
defer ll.mu.Unlock()
|
||||
|
||||
@@ -151,20 +153,20 @@ func (ll *LoginLimiter) VerifyCaptcha(ip, answer string) bool {
|
||||
}
|
||||
|
||||
// 获取并验证验证码
|
||||
captcha, exists := ll.captchas[ip]
|
||||
captcha, exists := ll.captchas[id]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
// 清理过期验证码
|
||||
if time.Now().After(captcha.ExpiresAt) {
|
||||
delete(ll.captchas, ip)
|
||||
delete(ll.captchas, id)
|
||||
return false
|
||||
}
|
||||
|
||||
// 验证并清理状态
|
||||
if answer == captcha.Answer {
|
||||
delete(ll.captchas, ip)
|
||||
delete(ll.captchas, id)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -176,16 +178,6 @@ func (ll *LoginLimiter) DrawCaptcha(content string) (err error, str string) {
|
||||
return
|
||||
}
|
||||
|
||||
func (ll *LoginLimiter) RemoveCaptcha(ip string) {
|
||||
ll.mu.Lock()
|
||||
defer ll.mu.Unlock()
|
||||
|
||||
_, exists := ll.captchas[ip]
|
||||
if exists {
|
||||
delete(ll.captchas, ip)
|
||||
}
|
||||
}
|
||||
|
||||
// 清除记录窗口
|
||||
func (ll *LoginLimiter) RemoveAttempts(ip string) {
|
||||
ll.mu.Lock()
|
||||
@@ -212,7 +204,6 @@ func (ll *LoginLimiter) CheckSecurityStatus(ip string) (banned bool, captchaRequ
|
||||
|
||||
// 清理过期数据
|
||||
ll.pruneAttempts(ip, time.Now().Add(-ll.policy.AttemptsWindow))
|
||||
ll.pruneCaptchas(ip)
|
||||
|
||||
// 检查验证码要求
|
||||
captchaRequired = len(ll.attempts[ip]) >= ll.policy.CaptchaThreshold
|
||||
@@ -272,10 +263,10 @@ func (ll *LoginLimiter) pruneAttempts(ip string, cutoff time.Time) []time.Time {
|
||||
return valid
|
||||
}
|
||||
|
||||
func (ll *LoginLimiter) pruneCaptchas(ip string) {
|
||||
if captcha, exists := ll.captchas[ip]; exists {
|
||||
func (ll *LoginLimiter) pruneCaptchas(id string) {
|
||||
if captcha, exists := ll.captchas[id]; exists {
|
||||
if time.Now().After(captcha.ExpiresAt) {
|
||||
delete(ll.captchas, ip)
|
||||
delete(ll.captchas, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -299,7 +290,7 @@ func (ll *LoginLimiter) cleanupExpired() {
|
||||
}
|
||||
|
||||
// 清理验证码
|
||||
for ip := range ll.captchas {
|
||||
ll.pruneCaptchas(ip)
|
||||
for id := range ll.captchas {
|
||||
ll.pruneCaptchas(id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,18 +2,18 @@ package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type MockCaptchaProvider struct{}
|
||||
|
||||
func (p *MockCaptchaProvider) Generate(ip string) (string, string, error) {
|
||||
return "CONTENT", "MOCK", nil
|
||||
}
|
||||
|
||||
func (p *MockCaptchaProvider) Validate(ip, code string) bool {
|
||||
return code == "MOCK"
|
||||
func (p *MockCaptchaProvider) Generate() (string, string, string, error) {
|
||||
id := uuid.New().String()
|
||||
content := uuid.New().String()
|
||||
answer := uuid.New().String()
|
||||
return id, content, answer, nil
|
||||
}
|
||||
|
||||
func (p *MockCaptchaProvider) Expiration() time.Duration {
|
||||
@@ -74,17 +74,22 @@ func TestCaptchaFlow(t *testing.T) {
|
||||
}
|
||||
|
||||
// 生成验证码
|
||||
err, capc := limiter.RequireCaptcha(ip)
|
||||
err, capc := limiter.RequireCaptcha()
|
||||
if err != nil {
|
||||
t.Fatalf("生成验证码失败: %v", err)
|
||||
}
|
||||
fmt.Printf("验证码内容: %#v\n", capc)
|
||||
|
||||
// 验证成功
|
||||
if !limiter.VerifyCaptcha(ip, capc.Answer) {
|
||||
if !limiter.VerifyCaptcha(capc.Id, capc.Answer) {
|
||||
t.Error("验证码应该验证成功")
|
||||
}
|
||||
|
||||
// 验证已删除
|
||||
if limiter.VerifyCaptcha(capc.Id, capc.Answer) {
|
||||
t.Error("验证码应该已删除")
|
||||
}
|
||||
|
||||
limiter.RemoveAttempts(ip)
|
||||
// 验证后状态
|
||||
if banned, need := limiter.CheckSecurityStatus(ip); banned || need {
|
||||
@@ -104,14 +109,14 @@ func TestCaptchaMustFlow(t *testing.T) {
|
||||
}
|
||||
|
||||
// 生成验证码
|
||||
err, capc := limiter.RequireCaptcha(ip)
|
||||
err, capc := limiter.RequireCaptcha()
|
||||
if err != nil {
|
||||
t.Fatalf("生成验证码失败: %v", err)
|
||||
}
|
||||
fmt.Printf("验证码内容: %#v\n", capc)
|
||||
|
||||
// 验证成功
|
||||
if !limiter.VerifyCaptcha(ip, capc.Answer) {
|
||||
if !limiter.VerifyCaptcha(capc.Id, capc.Answer) {
|
||||
t.Error("验证码应该验证成功")
|
||||
}
|
||||
|
||||
@@ -136,7 +141,7 @@ func TestAttemptTimeout(t *testing.T) {
|
||||
}
|
||||
|
||||
// 生成验证码
|
||||
err, _ := limiter.RequireCaptcha(ip)
|
||||
err, _ := limiter.RequireCaptcha()
|
||||
if err != nil {
|
||||
t.Fatalf("生成验证码失败: %v", err)
|
||||
}
|
||||
@@ -167,7 +172,7 @@ func TestCaptchaTimeout(t *testing.T) {
|
||||
}
|
||||
|
||||
// 生成验证码
|
||||
err, _ := limiter.RequireCaptcha(ip)
|
||||
err, capc := limiter.RequireCaptcha()
|
||||
if err != nil {
|
||||
t.Fatalf("生成验证码失败: %v", err)
|
||||
}
|
||||
@@ -175,9 +180,8 @@ func TestCaptchaTimeout(t *testing.T) {
|
||||
// 等待超过 CaptchaValidPeriod
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
code := "MOCK"
|
||||
// 验证成功
|
||||
if limiter.VerifyCaptcha(ip, code) {
|
||||
if limiter.VerifyCaptcha(capc.Id, capc.Answer) {
|
||||
t.Error("验证码应该已过期")
|
||||
}
|
||||
|
||||
@@ -261,7 +265,7 @@ func TestB64CaptchaFlow(t *testing.T) {
|
||||
}
|
||||
|
||||
// 生成验证码
|
||||
err, capc := limiter.RequireCaptcha(ip)
|
||||
err, capc := limiter.RequireCaptcha()
|
||||
if err != nil {
|
||||
t.Fatalf("生成验证码失败: %v", err)
|
||||
}
|
||||
@@ -275,7 +279,7 @@ func TestB64CaptchaFlow(t *testing.T) {
|
||||
fmt.Printf("验证码内容: %#v\n", b64)
|
||||
|
||||
// 验证成功
|
||||
if !limiter.VerifyCaptcha(ip, capc.Answer) {
|
||||
if !limiter.VerifyCaptcha(capc.Id, capc.Answer) {
|
||||
t.Error("验证码应该验证成功")
|
||||
}
|
||||
limiter.RemoveAttempts(ip)
|
||||
|
||||
42
utils/password.go
Normal file
42
utils/password.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// EncryptPassword hashes the input password using bcrypt.
|
||||
// An error is returned if hashing fails.
|
||||
func EncryptPassword(password string) (string, error) {
|
||||
bs, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(bs), nil
|
||||
}
|
||||
|
||||
// VerifyPassword checks the input password against the stored hash.
|
||||
// When a legacy MD5 hash is provided, the password is rehashed with bcrypt
|
||||
// and the new hash is returned. Any internal bcrypt error is returned.
|
||||
func VerifyPassword(hash, input string) (bool, string, error) {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(input))
|
||||
if err == nil {
|
||||
return true, "", nil
|
||||
}
|
||||
|
||||
var invalidPrefixErr bcrypt.InvalidHashPrefixError
|
||||
if errors.As(err, &invalidPrefixErr) || errors.Is(err, bcrypt.ErrHashTooShort) {
|
||||
// Try fallback to legacy MD5 hash verification
|
||||
if hash == Md5(input+"rustdesk-api") {
|
||||
newHash, err2 := bcrypt.GenerateFromPassword([]byte(input), bcrypt.DefaultCost)
|
||||
if err2 != nil {
|
||||
return true, "", err2
|
||||
}
|
||||
return true, string(newHash), nil
|
||||
}
|
||||
}
|
||||
if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
|
||||
return false, "", nil
|
||||
}
|
||||
return false, "", err
|
||||
}
|
||||
40
utils/password_test.go
Normal file
40
utils/password_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func TestVerifyPasswordMD5(t *testing.T) {
|
||||
hash := Md5("secret" + "rustdesk-api")
|
||||
ok, newHash, err := VerifyPassword(hash, "secret")
|
||||
if err != nil {
|
||||
t.Fatalf("md5 verify failed: %v", err)
|
||||
}
|
||||
if !ok || newHash == "" {
|
||||
t.Fatalf("md5 migration failed")
|
||||
}
|
||||
if bcrypt.CompareHashAndPassword([]byte(newHash), []byte("secret")) != nil {
|
||||
t.Fatalf("invalid rehash")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyPasswordBcrypt(t *testing.T) {
|
||||
b, _ := bcrypt.GenerateFromPassword([]byte("pass"), bcrypt.DefaultCost)
|
||||
ok, newHash, err := VerifyPassword(string(b), "pass")
|
||||
if err != nil || !ok || newHash != "" {
|
||||
t.Fatalf("bcrypt verify failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyPasswordMigrate(t *testing.T) {
|
||||
md5hash := Md5("mypass" + "rustdesk-api")
|
||||
ok, newHash, err := VerifyPassword(md5hash, "mypass")
|
||||
if err != nil || !ok || newHash == "" {
|
||||
t.Fatalf("expected bcrypt rehash")
|
||||
}
|
||||
if bcrypt.CompareHashAndPassword([]byte(newHash), []byte("mypass")) != nil {
|
||||
t.Fatalf("rehash not valid bcrypt")
|
||||
}
|
||||
}
|
||||
@@ -2,9 +2,9 @@ package utils
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
crand "crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
@@ -69,8 +69,12 @@ func RandomString(n int) string {
|
||||
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
length := len(letterBytes)
|
||||
b := make([]byte, n)
|
||||
for i := range b {
|
||||
b[i] = letterBytes[rand.Intn(length)]
|
||||
randomBytes := make([]byte, n)
|
||||
if _, err := crand.Read(randomBytes); err != nil {
|
||||
return ""
|
||||
}
|
||||
for i, rb := range randomBytes {
|
||||
b[i] = letterBytes[int(rb)%length]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user