Compare commits

...

52 Commits

Author SHA1 Message Date
lejianwen
2948eaaa5c chore: Update Go version to 1.23 in build configurations 2025-06-16 15:41:16 +08:00
lejianwen
8641ba5c0c docs: Update swagger docs 2025-06-16 12:31:48 +08:00
lejianwen
60b7a18fe7 feat: Add PostgreSQL support and refactor MySQL DSN handling (#284) 2025-06-16 12:26:08 +08:00
lejianwen
ca068816ae feat: Add start time in /api/sysinfover 2025-06-16 12:23:48 +08:00
lejianwen
06648d9a6c fix(admin): Use admin-hello first
(#274) (#255)
2025-06-15 15:33:12 +08:00
puyujian
8a8abd5163 feat(oauth): 支持linux.do登录 (#280)
* 支持linux.do登录

* 修正
2025-06-15 15:32:20 +08:00
lejianwen
97f98cd6ce chore: update download links for musl cross-compilers 2025-06-05 12:14:17 +08:00
lejianwen
51f2920661 fix: Init sqlite fail(#266) 2025-06-04 09:31:43 +08:00
lejianwen
7a5d141ce8 fix(server): Port custom (#257) 2025-05-30 12:27:37 +08:00
lejianwen
3cef02a0bb fix(webclient): Peer online status 2025-05-29 18:51:37 +08:00
lejianwen
46a7ecc1ba fix: Captcha some problem when users login with same ip 2025-05-27 17:36:20 +08:00
lejianwen
4d2b037f5e docs: Readme 2025-05-25 17:44:29 +08:00
lejianwen
323364b24e feat(register): Register status can be set (#223) 2025-05-25 17:03:13 +08:00
lejianwen
f19109cdf8 feat(login): Captcha upgrade and add the function to ban IP addresses (#250) 2025-05-25 16:52:58 +08:00
Tao Chen
527260d60a fix: dn should be case-insensitive (#250) 2025-05-21 09:07:08 +08:00
lejianwen
46bb44f0ab fix(webclient): DefaultIdServerPort undefined (#238) 2025-05-16 20:14:36 +08:00
lejianwen
2f1380f24a fix(webclient): Remove license warning (#235) 2025-05-13 13:11:19 +08:00
lejianwen
ece3328e94 feat(webclient): Web client to 1.4.0 2025-05-12 20:16:08 +08:00
lejianwen
fdd26d87be fix: PageSize (#225) 2025-05-06 19:08:18 +08:00
lejianwen
2ade0dda42 chore: Noelware/docker-manifest-action 2025-04-25 16:20:36 +08:00
lejianwen
a87ae5cf65 chore: Noelware/docker-manifest-action 2025-04-25 14:34:45 +08:00
lejianwen
fe7b8b53a6 style: Oauth page languages 2025-04-24 21:52:43 +08:00
lejianwen
b929f3efdb style: Remove useless configurations 2025-04-15 10:52:46 +08:00
lejianwen
f847fc076f fix: Low case (#149) 2025-04-15 10:46:21 +08:00
lejianwen
60d0a701ce fix: Share pwd 2025-04-15 10:09:56 +08:00
lejianwen
0dedaf6824 feat: Peer share to group 2025-04-14 19:12:40 +08:00
lejianwen
ab231b3fed feat: Add SysInfoVer endpoint and AppService for version retrieval 2025-04-07 16:38:21 +08:00
lejianwen
e7f28cca36 fix: Update peer based on the UUID (#180) 2025-04-02 09:50:16 +08:00
lejianwen
505e8aac4b feat: Add Korean translations validator (#168) 2025-04-02 09:42:29 +08:00
lejianwen
746e2a6052 fix: Get Uuids 2025-03-15 21:02:47 +08:00
lejianwen
dc03d5d83d style: Update peer last online time logic (#173) 2025-03-15 21:02:08 +08:00
lejianwen
b770ab178d feat(admin): Add filter by ip and username (#172) 2025-03-15 19:49:49 +08:00
Tao Chen
fd7e022e88 fix: rm varify password accidentally (#176) 2025-03-15 19:40:02 +08:00
lejianwen
ac5df6826b fix: Init database err (#166) 2025-03-04 18:14:25 +08:00
lejianwen
91908859bc docs: Readme 2025-03-04 16:30:12 +08:00
lejianwen
9c8822f857 fix: templates 2025-03-04 16:14:34 +08:00
lejianwen
9709da7fb6 style(webclient): ws-host 2025-03-04 15:48:25 +08:00
lejianwen
1852f10131 feat(config): add ws-host configuration (#156) 2025-03-04 15:43:25 +08:00
Tao Chen
77d7b43e21 fix: Fix/ldap tls (#162)
* optimize and fix tls of LDAP

* fix
2025-03-02 22:59:01 +08:00
lejianwen
7410b539a0 style(service): refactor global dependencies to use local variables 2025-02-26 19:15:38 +08:00
lejianwen
0403d71502 feat(oauth): Oauth nonce (#148) 2025-02-26 16:36:53 +08:00
lejianwen
2af1d93a7d feat(oauth): Oauth callback page beautification (#115) 2025-02-25 13:51:49 +08:00
lejianwen
76281ad761 feat(api): Add device group for 1.3.8 2025-02-23 21:35:42 +08:00
lejianwen
d1ec9b4916 feat(api): Add /device-group/accessible for 1.3.8 2025-02-23 15:32:44 +08:00
lejianwen
448e7460a9 style(oidc): Oidc style 2025-02-21 09:49:41 +08:00
lejianwen
ee0cbabffc fix(admin): Admin hello 2025-02-20 19:37:34 +08:00
lejianwen
d6a5af890a style: No need exec sql on version 261 2025-02-20 19:22:22 +08:00
lejianwen
dc313441e5 fix: Js content-type 2025-02-19 16:04:03 +08:00
Tao Chen
c75320f4f4 feat(oidc): add pkce (#150) 2025-02-19 09:31:25 +08:00
lejianwen
c788f78416 docs: Readme 2025-02-17 10:59:38 +08:00
lejianwen
49cf954d4a fix(config)!: Token expire time (#145)
将配置中的过期时间单位统一为time.Duration,可以设置为`h`,`m`,`s`
2025-02-17 10:49:59 +08:00
lejianwen
014e3db54f docs: Readme 2025-02-16 21:06:55 +08:00
91 changed files with 95327 additions and 86972 deletions

View File

@@ -66,7 +66,7 @@ jobs:
- name: Set up Go environment
uses: actions/setup-go@v4
with:
go-version: '1.22' # 选择 Go 版本
go-version: '1.23' # 选择 Go 版本
- name: Set up npm
uses: actions/setup-node@v2
@@ -115,12 +115,12 @@ jobs:
zip -r ${{ matrix.job.goos}}-${{ matrix.job.platform }}.${{matrix.job.file_ext}} ./release
else
if [ "${{ matrix.job.platform }}" = "arm64" ]; then
wget https://musl.cc/aarch64-linux-musl-cross.tgz
wget https://musl.ljw.red/aarch64-linux-musl-cross.tgz
tar -xf aarch64-linux-musl-cross.tgz
export PATH=$PATH:$PWD/aarch64-linux-musl-cross/bin
GOOS=${{ matrix.job.goos }} GOARCH=${{ matrix.job.platform }} CC=aarch64-linux-musl-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
elif [ "${{ matrix.job.platform }}" = "armv7l" ]; then
wget https://musl.cc/armv7l-linux-musleabihf-cross.tgz
wget https://musl.ljw.red/armv7l-linux-musleabihf-cross.tgz
tar -xf armv7l-linux-musleabihf-cross.tgz
export PATH=$PATH:$PWD/armv7l-linux-musleabihf-cross/bin
GOOS=${{ matrix.job.goos }} GOARCH=arm GOARM=7 CC=armv7l-linux-musleabihf-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
@@ -147,6 +147,7 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Generate Changelog
if: startsWith(github.ref, 'refs/tags/') && github.event_name == 'push'
run: npx changelogithub # or changelogithub@0.12 if ensure the stable result
env:
GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}
@@ -380,7 +381,7 @@ jobs:
- name: Create and push manifest Docker Hub (:version)
if: ${{ env.SKIP_DOCKER_HUB == 'false' }}
uses: Noelware/docker-manifest-action@master
uses: Noelware/docker-manifest-action@v0.2.3
with:
base-image: ${{ env.BASE_IMAGE_NAMESPACE }}/rustdesk-api:${{ env.TAG }}
extra-images: ${{ env.DOCKERHUB_IMAGE_NAMESPACE }}/rustdesk-api:${{ env.TAG }}-amd64,
@@ -390,7 +391,7 @@ jobs:
- name: Create and push manifest GHCR (:version)
if: ${{ env.SKIP_GHCR == 'false' }}
uses: Noelware/docker-manifest-action@master
uses: Noelware/docker-manifest-action@v0.2.3
with:
base-image: ghcr.io/${{ env.BASE_IMAGE_NAMESPACE }}/rustdesk-api:${{ env.TAG }}
extra-images: ghcr.io/${{ env.GHCR_IMAGE_NAMESPACE }}/rustdesk-api:${{ env.TAG }}-amd64,
@@ -401,7 +402,7 @@ jobs:
- name: Create and push manifest Docker Hub (:latest)
if: ${{ env.SKIP_DOCKER_HUB == 'false' }}
uses: Noelware/docker-manifest-action@master
uses: Noelware/docker-manifest-action@v0.2.3
with:
base-image: ${{ env.BASE_IMAGE_NAMESPACE }}/rustdesk-api:latest
extra-images: ${{ env.DOCKERHUB_IMAGE_NAMESPACE }}/rustdesk-api:latest-amd64,
@@ -411,7 +412,7 @@ jobs:
- name: Create and push manifest GHCR (:latest)
if: ${{ env.SKIP_GHCR == 'false' }}
uses: Noelware/docker-manifest-action@master
uses: Noelware/docker-manifest-action@v0.2.3
with:
base-image: ghcr.io/${{ env.BASE_IMAGE_NAMESPACE }}/rustdesk-api:latest
extra-images: ghcr.io/${{ env.GHCR_IMAGE_NAMESPACE }}/rustdesk-api:latest-amd64,
@@ -422,7 +423,7 @@ jobs:
- name: Create and push Full S6 manifest Docker Hub (:version)
if: ${{ env.SKIP_DOCKER_HUB == 'false' }}
uses: Noelware/docker-manifest-action@master
uses: Noelware/docker-manifest-action@v0.2.3
with:
base-image: ${{ env.BASE_IMAGE_NAMESPACE }}/rustdesk-api:full-s6
extra-images: ${{ env.DOCKERHUB_IMAGE_NAMESPACE }}/rustdesk-api:full-s6-amd64,
@@ -433,7 +434,7 @@ jobs:
- name: Create and push Full S6 manifest GHCR (:latest)
if: ${{ env.SKIP_GHCR == 'false' }}
uses: Noelware/docker-manifest-action@master
uses: Noelware/docker-manifest-action@v0.2.3
with:
base-image: ghcr.io/${{ env.BASE_IMAGE_NAMESPACE }}/rustdesk-api:full-s6
extra-images: ghcr.io/${{ env.GHCR_IMAGE_NAMESPACE }}/rustdesk-api:full-s6-amd64,

View File

@@ -61,7 +61,7 @@ jobs:
- name: Set up Go environment
uses: actions/setup-go@v4
with:
go-version: '1.22' # 选择 Go 版本
go-version: '1.23' # 选择 Go 版本
- name: Set up npm
uses: actions/setup-node@v2
@@ -101,12 +101,12 @@ jobs:
zip -r ${{ matrix.job.goos}}-${{ matrix.job.platform }}.${{matrix.job.file_ext}} ./release
else
if [ "${{ matrix.job.platform }}" = "arm64" ]; then
wget https://musl.cc/aarch64-linux-musl-cross.tgz
wget https://musl.ljw.red/aarch64-linux-musl-cross.tgz
tar -xf aarch64-linux-musl-cross.tgz
export PATH=$PATH:$PWD/aarch64-linux-musl-cross/bin
GOOS=${{ matrix.job.goos }} GOARCH=${{ matrix.job.platform }} CC=aarch64-linux-musl-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
elif [ "${{ matrix.job.platform }}" = "armv7l" ]; then
wget https://musl.cc/armv7l-linux-musleabihf-cross.tgz
wget https://musl.ljw.red/armv7l-linux-musleabihf-cross.tgz
tar -xf armv7l-linux-musleabihf-cross.tgz
export PATH=$PATH:$PWD/armv7l-linux-musleabihf-cross/bin
GOOS=${{ matrix.job.goos }} GOARCH=arm GOARM=7 CC=armv7l-linux-musleabihf-gcc CGO_LDFLAGS="-static" CGO_ENABLED=1 go build -ldflags "-s -w" -o ./release/apimain ./cmd/apimain.go
@@ -317,7 +317,7 @@ jobs:
- name: Create and push manifest Docker Hub (:version)
if: ${{ env.SKIP_DOCKER_HUB == 'false' }}
uses: Noelware/docker-manifest-action@master
uses: Noelware/docker-manifest-action@v0.2.3
with:
base-image: ${{ env.BASE_IMAGE_NAMESPACE }}/rustdesk-api:${{ env.TAG }}
extra-images: ${{ env.DOCKERHUB_IMAGE_NAMESPACE }}/rustdesk-api:${{ env.TAG }}-amd64,
@@ -327,7 +327,7 @@ jobs:
- name: Create and push manifest GHCR (:version)
if: ${{ env.SKIP_GHCR == 'false' }}
uses: Noelware/docker-manifest-action@master
uses: Noelware/docker-manifest-action@v0.2.3
with:
base-image: ghcr.io/${{ env.BASE_IMAGE_NAMESPACE }}/rustdesk-api:${{ env.TAG }}
extra-images: ghcr.io/${{ env.GHCR_IMAGE_NAMESPACE }}/rustdesk-api:${{ env.TAG }}-amd64,

2
.gitignore vendored
View File

@@ -5,4 +5,4 @@ runtime/*
go.sum
resources/admin
release
data
data/rustdeskapi.db

View File

@@ -4,12 +4,12 @@
本项目使用 Go 实现了 RustDesk 的 API并包含了 Web Admin 和 Web 客户端。RustDesk 是一个远程桌面软件,提供了自托管的解决方案。
<div align=center>
<div align=center>
<img src="https://img.shields.io/badge/golang-1.22-blue"/>
<img src="https://img.shields.io/badge/gin-v1.9.0-lightBlue"/>
<img src="https://img.shields.io/badge/gorm-v1.25.7-green"/>
<img src="https://img.shields.io/badge/swag-v1.16.3-yellow"/>
<img src="https://img.shields.io/badge/i18n-7-green"/>
<img src="https://goreportcard.com/badge/github.com/lejianwen/rustdesk-api/v2"/>
<img src="https://github.com/lejianwen/rustdesk-api/actions/workflows/build.yml/badge.svg"/>
</div>
@@ -108,8 +108,7 @@
* 可以官方指令
* 可以添加自定义指令
* 可以执行自定义指令
![rustdesk_command_advance](./docs/rustdesk_command_advance.png)
11. **LDAP 支持**, 当在API Server上设置了LDAP(已测试AD和LDAP),可以通过LDAP中的用户信息进行登录 https://github.com/lejianwen/rustdesk-api/issues/114 ,如果LDAP验证失败返回本地用户
@@ -146,69 +145,11 @@
### 相关配置
* [配置文件](./conf/config.yaml)
* 参考`conf/config.yaml`配置文件,修改相关配置。
* 如果`gorm.type``sqlite`则不需要配置mysql相关配置。
* 语言如果不设置默认为`zh-CN`
```yaml
lang: "en"
app:
web-client: 1 # 1:启用 0:禁用
register: false #是否开启注册
show-swagger: 0 #是否显示swagger文档
gin:
api-addr: "0.0.0.0:21114"
mode: "release"
resources-path: 'resources'
trust-proxy: ""
gorm:
type: "sqlite"
max-idle-conns: 10
max-open-conns: 100
mysql:
username: "root"
password: "111111"
addr: "192.168.1.66:3308"
dbname: "rustdesk"
rustdesk:
id-server: "192.168.1.66:21116"
relay-server: "192.168.1.66:21117"
api-server: "http://192.168.1.66:21114"
key: "123456789"
personal: 1
logger:
path: "./runtime/log.txt"
level: "warn" #trace,debug,info,warn,error,fatal
report-caller: true
proxy:
enable: false
host: ""
jwt:
key: ""
expire-duration: 360000
ldap:
enable: false
url: "ldap://ldap.example.com:389"
tls: false
tls-verify: false
base-dn: "dc=example,dc=com"
bind-dn: "cn=admin,dc=example,dc=com"
bind-password: "password"
user:
base-dn: "ou=users,dc=example,dc=com"
enable-attr: "" #The attribute name of the user for enabling, in AD it is "userAccountControl", empty means no enable attribute, all users are enabled
enable-attr-value: "" # The value of the enable attribute when the user is enabled. If you are using AD, just set random value, it will be ignored.
filter: "(cn=*)"
username: "uid" # The attribute name of the user for usernamem if you are using AD, it should be "sAMAccountName"
email: "mail"
first-name: "givenName"
last-name: "sn"
sync: false # If true, the user will be synchronized to the database when the user logs in. If false, the user will be synchronized to the database when the user be created.
admin-group: "cn=admin,dc=example,dc=com" # The group name of the admin group, if the user is in this group, the user will be an admin.
```
### 环境变量
环境变量和配置文件`conf/config.yaml`中的配置一一对应,变量名前缀是`RUSTDESK_API`
下面表格并未全部列出,可以参考`conf/config.yaml`中的配置。
@@ -220,7 +161,11 @@ ldap:
| RUSTDESK_API_APP_WEB_CLIENT | 是否启用web-client; 1:启用,0:不启用; 默认启用 | 1 |
| RUSTDESK_API_APP_REGISTER | 是否开启注册; `true`, `false` 默认`false` | `false` |
| RUSTDESK_API_APP_SHOW_SWAGGER | 是否可见swagger文档;`1`显示,`0`不显示,默认`0`不显示 | `1` |
| RUSTDESK_API_APP_TOKEN_EXPIRE | token有效时长(秒) | `3600` |
| RUSTDESK_API_APP_TOKEN_EXPIRE | token有效时长 | `168h` |
| RUSTDESK_API_APP_DISABLE_PWD_LOGIN | 是否禁用密码登录; `true`, `false` 默认`false` | `false` |
| RUSTDESK_API_APP_REGISTER_STATUS | 注册用户默认状态; 1 启用2 禁用, 默认 1 | `1` |
| RUSTDESK_API_APP_CAPTCHA_THRESHOLD | 验证码触发次数; -1 不启用, 0 一直启用, >0 登录错误次数后启用 ;默认 `3` | `3` |
| RUSTDESK_API_APP_BAN_THRESHOLD | 封禁IP触发次数; 0 不启用, >0 登录错误次数后封禁IP; 默认 `0` | `0` |
| -----ADMIN配置----- | ---------- | ---------- |
| RUSTDESK_API_ADMIN_TITLE | 后台标题 | `RustDesk Api Admin` |
| RUSTDESK_API_ADMIN_HELLO | 后台欢迎语,可以使用`html` | |
@@ -244,12 +189,13 @@ ldap:
| RUSTDESK_API_RUSTDESK_KEY | Rustdesk的key | 123456789 |
| RUSTDESK_API_RUSTDESK_KEY_FILE | Rustdesk存放key的文件 | `./conf/data/id_ed25519.pub` |
| RUSTDESK_API_RUSTDESK_WEBCLIENT<br/>_MAGIC_QUERYONLINE | Web client v2 中是否启用新的在线状态查询方法; `1`:启用,`0`:不启用,默认不启用 | `0` |
| RUSTDESK_API_RUSTDESK_WS_HOST | 自定义Websocket Host | `wss://192.168.1.123:1234` |
| ----PROXY配置----- | ---------- | ---------- |
| RUSTDESK_API_PROXY_ENABLE | 是否启用代理:`false`, `true` | `false` |
| RUSTDESK_API_PROXY_HOST | 代理地址 | `http://127.0.0.1:1080` |
| ----JWT配置---- | -------- | -------- |
| RUSTDESK_API_JWT_KEY | 自定义JWT KEY,为空则不启用JWT<br/>如果没使用`lejianwen/rustdesk-server`中的`MUST_LOGIN`,建议设置为空 | |
| RUSTDESK_API_JWT_EXPIRE_DURATION | JWT有效时间 | 360000 |
| RUSTDESK_API_JWT_EXPIRE_DURATION | JWT有效时间 | `168h` |
### 运行

View File

@@ -8,7 +8,7 @@ desktop software that provides self-hosted solutions.
<img src="https://img.shields.io/badge/gin-v1.9.0-lightBlue"/>
<img src="https://img.shields.io/badge/gorm-v1.25.7-green"/>
<img src="https://img.shields.io/badge/swag-v1.16.3-yellow"/>
<img src="https://img.shields.io/badge/i18n-7-green"/>
<img src="https://goreportcard.com/badge/github.com/lejianwen/rustdesk-api/v2"/>
<img src="https://github.com/lejianwen/rustdesk-api/actions/workflows/build.yml/badge.svg"/>
</div>
@@ -109,8 +109,6 @@ displaying data.Frontend code is available at [rustdesk-api-web](https://github.
* Custom commands can be added
* Custom commands can be executed
![rustdesk_command_advance](./docs/en_img/rustdesk_command_advance.png)
11. **LDAP Support**, When you setup the LDAP(test for OpenLDAP and AD), you can login with the LDAP's user. https://github.com/lejianwen/rustdesk-api/issues/114 , if LDAP fail fallback local user
### Web Client:
@@ -145,110 +143,58 @@ displaying data.Frontend code is available at [rustdesk-api-web](https://github.
### Configuration
* [Config File](./conf/config.yaml)
* Modify the configuration in `conf/config.yaml`.
* If `gorm.type` is set to `sqlite`, MySQL-related configurations are not required.
* Language support: `en` and `zh-CN` are supported. The default is `zh-CN`.
```yaml
lang: "en"
app:
web-client: 1 # web client route 1:open 0:close
register: false #register enable
show-swagger: 0 #show swagger 1:open 0:close
gin:
api-addr: "0.0.0.0:21114"
mode: "release"
resources-path: 'resources'
trust-proxy: ""
gorm:
type: "sqlite"
max-idle-conns: 10
max-open-conns: 100
mysql:
username: "root"
password: "111111"
addr: "192.168.1.66:3308"
dbname: "rustdesk"
rustdesk:
id-server: "192.168.1.66:21116"
relay-server: "192.168.1.66:21117"
api-server: "http://192.168.1.66:21114"
key: "123456789"
personal: 1
logger:
path: "./runtime/log.txt"
level: "warn" #trace,debug,info,warn,error,fatal
report-caller: true
proxy:
enable: false
host: ""
jwt:
key: ""
expire-duration: 360000
ldap:
enable: false
url: "ldap://ldap.example.com:389"
tls: false
tls-verify: false
base-dn: "dc=example,dc=com"
bind-dn: "cn=admin,dc=example,dc=com"
bind-password: "password"
user:
base-dn: "ou=users,dc=example,dc=com"
enable-attr: "" #The attribute name of the user for enabling, in AD it is "userAccountControl", empty means no enable attribute, all users are enabled
enable-attr-value: "" # The value of the enable attribute when the user is enabled. If you are using AD, just set random value, it will be ignored.
filter: "(cn=*)"
username: "uid" # The attribute name of the user for usernamem if you are using AD, it should be "sAMAccountName"
email: "mail"
first-name: "givenName"
last-name: "sn"
sync: false # If true, the user will be synchronized to the database when the user logs in. If false, the user will be synchronized to the database when the user be created.
admin-group: "cn=admin,dc=example,dc=com" # The group name of the admin group, if the user is in this group, the user will be an admin.
```
### Environment Variables
The environment variables correspond one-to-one with the configurations in the `conf/config.yaml` file. The prefix for variable names is `RUSTDESK_API`.
The table below does not list all configurations. Please refer to the configurations in `conf/config.yaml`.
| Variable Name | Description | Example |
|---------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------|
| TZ | timezone | Asia/Shanghai |
| RUSTDESK_API_LANG | Language | `en`,`zh-CN` |
| RUSTDESK_API_APP_WEB_CLIENT | web client on/off; 1: on, 0 off, default: 1 | 1 |
| RUSTDESK_API_APP_REGISTER | register enable; `true`, `false`; default:`false` | `false` |
| RUSTDESK_API_APP_SHOW_SWAGGER | swagger visible; 1: yes, 0: no; default: 0 | `0` |
| RUSTDESK_API_APP_TOKEN_EXPIRE | token expire duration(second) | `3600` |
| ----- ADMIN Configuration----- | ---------- | ---------- |
| RUSTDESK_API_ADMIN_TITLE | Admin Title | `RustDesk Api Admin` |
| RUSTDESK_API_ADMIN_HELLO | Admin welcome message, you can use `html` | |
| RUSTDESK_API_ADMIN_HELLO_FILE | Admin welcome message file,<br>will override `RUSTDESK_API_ADMIN_HELLO` | `./conf/admin/hello.html` |
| ----- GIN Configuration ----- | --------------------------------------- | ----------------------------- |
| RUSTDESK_API_GIN_TRUST_PROXY | Trusted proxy IPs, separated by commas. | 192.168.1.2,192.168.1.3 |
| ----- GORM Configuration ----- | --------------------------------------- | ----------------------------- |
| RUSTDESK_API_GORM_TYPE | Database type (`sqlite` or `mysql`). Default is `sqlite`. | sqlite |
| RUSTDESK_API_GORM_MAX_IDLE_CONNS | Maximum idle connections | 10 |
| RUSTDESK_API_GORM_MAX_OPEN_CONNS | Maximum open connections | 100 |
| RUSTDESK_API_RUSTDESK_PERSONAL | Open Personal Api 1:Enable,0:Disable | 1 |
| ----- MYSQL Configuration ----- | --------------------------------------- | ----------------------------- |
| RUSTDESK_API_MYSQL_USERNAME | MySQL username | root |
| RUSTDESK_API_MYSQL_PASSWORD | MySQL password | 111111 |
| RUSTDESK_API_MYSQL_ADDR | MySQL address | 192.168.1.66:3306 |
| RUSTDESK_API_MYSQL_DBNAME | MySQL database name | rustdesk |
| ----- RUSTDESK Configuration ----- | --------------------------------------- | ----------------------------- |
| RUSTDESK_API_RUSTDESK_ID_SERVER | Rustdesk ID server address | 192.168.1.66:21116 |
| RUSTDESK_API_RUSTDESK_RELAY_SERVER | Rustdesk relay server address | 192.168.1.66:21117 |
| RUSTDESK_API_RUSTDESK_API_SERVER | Rustdesk API server address | http://192.168.1.66:21114 |
| RUSTDESK_API_RUSTDESK_KEY | Rustdesk key | 123456789 |
| RUSTDESK_API_RUSTDESK_KEY_FILE | Rustdesk key file | `./conf/data/id_ed25519.pub` |
| RUSTDESK_API_RUSTDESK<br/>_WEBCLIENT_MAGIC_QUERYONLINE | New online query method is enabled in the web client v2; '1': Enabled, '0': Disabled, not enabled by default | `0` |
| ---- PROXY ----- | --------------- | ---------- |
| RUSTDESK_API_PROXY_ENABLE | proxy_enable :`false`, `true` | `false` |
| RUSTDESK_API_PROXY_HOST | proxy_host | `http://127.0.0.1:1080` |
| ----JWT---- | -------- | -------- |
| RUSTDESK_API_JWT_KEY | Custom JWT KEY, if empty JWT is not enabled.<br/>If `MUST_LOGIN` from `lejianwen/rustdesk-server` is not used, it is recommended to leave it empty. | |
| RUSTDESK_API_JWT_EXPIRE_DURATION | JWT expire duration | 360000 |
| Variable Name | Description | Example |
|--------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------|
| TZ | timezone | Asia/Shanghai |
| RUSTDESK_API_LANG | Language | `en`,`zh-CN` |
| RUSTDESK_API_APP_WEB_CLIENT | web client on/off; 1: on, 0 off, default: 1 | 1 |
| RUSTDESK_API_APP_REGISTER | register enable; `true`, `false`; default:`false` | `false` |
| RUSTDESK_API_APP_SHOW_SWAGGER | swagger visible; 1: yes, 0: no; default: 0 | `0` |
| RUSTDESK_API_APP_TOKEN_EXPIRE | token expire duration | `168h` |
| RUSTDESK_API_APP_DISABLE_PWD_LOGIN | disable password login | `false` |
| RUSTDESK_API_APP_REGISTER_STATUS | register user default status ; 1 enabled , 2 disabled ; default 1 | `1` |
| RUSTDESK_API_APP_CAPTCHA_THRESHOLD | captcha threshold; -1 disabled, 0 always enable, >0 threshold ;default `3` | `3` |
| RUSTDESK_API_APP_BAN_THRESHOLD | ban ip threshold; 0 disabled, >0 threshold ; default `0` | `0` |
| ----- ADMIN Configuration----- | ---------- | ---------- |
| RUSTDESK_API_ADMIN_TITLE | Admin Title | `RustDesk Api Admin` |
| RUSTDESK_API_ADMIN_HELLO | Admin welcome message, you can use `html` | |
| RUSTDESK_API_ADMIN_HELLO_FILE | Admin welcome message file,<br>will override `RUSTDESK_API_ADMIN_HELLO` | `./conf/admin/hello.html` |
| ----- GIN Configuration ----- | --------------------------------------- | ----------------------------- |
| RUSTDESK_API_GIN_TRUST_PROXY | Trusted proxy IPs, separated by commas. | 192.168.1.2,192.168.1.3 |
| ----- GORM Configuration ----- | --------------------------------------- | ----------------------------- |
| RUSTDESK_API_GORM_TYPE | Database type (`sqlite` or `mysql`). Default is `sqlite`. | sqlite |
| RUSTDESK_API_GORM_MAX_IDLE_CONNS | Maximum idle connections | 10 |
| RUSTDESK_API_GORM_MAX_OPEN_CONNS | Maximum open connections | 100 |
| RUSTDESK_API_RUSTDESK_PERSONAL | Open Personal Api 1:Enable,0:Disable | 1 |
| ----- MYSQL Configuration ----- | --------------------------------------- | ----------------------------- |
| RUSTDESK_API_MYSQL_USERNAME | MySQL username | root |
| RUSTDESK_API_MYSQL_PASSWORD | MySQL password | 111111 |
| RUSTDESK_API_MYSQL_ADDR | MySQL address | 192.168.1.66:3306 |
| RUSTDESK_API_MYSQL_DBNAME | MySQL database name | rustdesk |
| ----- RUSTDESK Configuration ----- | --------------------------------------- | ----------------------------- |
| RUSTDESK_API_RUSTDESK_ID_SERVER | Rustdesk ID server address | 192.168.1.66:21116 |
| RUSTDESK_API_RUSTDESK_RELAY_SERVER | Rustdesk relay server address | 192.168.1.66:21117 |
| RUSTDESK_API_RUSTDESK_API_SERVER | Rustdesk API server address | http://192.168.1.66:21114 |
| RUSTDESK_API_RUSTDESK_KEY | Rustdesk key | 123456789 |
| RUSTDESK_API_RUSTDESK_KEY_FILE | Rustdesk key file | `./conf/data/id_ed25519.pub` |
| RUSTDESK_API_RUSTDESK<br/>_WEBCLIENT_MAGIC_QUERYONLINE | New online query method is enabled in the web client v2; '1': Enabled, '0': Disabled, not enabled by default | `0` |
| RUSTDESK_API_RUSTDESK_WS_HOST | Custom Websocket Host | `wss://192.168.1.123:1234` |
| ---- PROXY ----- | --------------- | ---------- |
| RUSTDESK_API_PROXY_ENABLE | proxy_enable :`false`, `true` | `false` |
| RUSTDESK_API_PROXY_HOST | proxy_host | `http://127.0.0.1:1080` |
| ----JWT---- | -------- | -------- |
| RUSTDESK_API_JWT_KEY | Custom JWT KEY, if empty JWT is not enabled.<br/>If `MUST_LOGIN` from `lejianwen/rustdesk-server` is not used, it is recommended to leave it empty. | |
| RUSTDESK_API_JWT_EXPIRE_DURATION | JWT expire duration | `168h` |
### Installation Steps

View File

@@ -1,6 +1,7 @@
package main
import (
"fmt"
"github.com/go-redis/redis/v8"
"github.com/lejianwen/rustdesk-api/v2/config"
"github.com/lejianwen/rustdesk-api/v2/global"
@@ -52,6 +53,10 @@ var resetPwdCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
pwd := args[0]
admin := service.AllService.UserService.InfoById(1)
if admin.Id == 0 {
global.Logger.Warn("user not found! ")
return
}
err := service.AllService.UserService.UpdatePassword(admin, pwd)
if err != nil {
global.Logger.Error("reset password fail! ", err)
@@ -78,6 +83,10 @@ var resetUserPwdCmd = &cobra.Command{
return
}
u := service.AllService.UserService.InfoById(uint(uid))
if u.Id == 0 {
global.Logger.Warn("user not found! ")
return
}
err = service.AllService.UserService.UpdatePassword(u, pwd)
if err != nil {
global.Logger.Warn("reset password fail! ", err)
@@ -132,20 +141,41 @@ func InitGlobal() {
}
//gorm
if global.Config.Gorm.Type == config.TypeMysql {
dns := global.Config.Mysql.Username + ":" + global.Config.Mysql.Password + "@(" + global.Config.Mysql.Addr + ")/" + global.Config.Mysql.Dbname + "?charset=utf8mb4&parseTime=True&loc=Local"
dsn := fmt.Sprintf("%s:%s@(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
global.Config.Mysql.Username,
global.Config.Mysql.Password,
global.Config.Mysql.Addr,
global.Config.Mysql.Dbname,
)
global.DB = orm.NewMysql(&orm.MysqlConfig{
Dns: dns,
Dsn: dsn,
MaxIdleConns: global.Config.Gorm.MaxIdleConns,
MaxOpenConns: global.Config.Gorm.MaxOpenConns,
})
}, global.Logger)
} else if global.Config.Gorm.Type == config.TypePostgresql {
dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s TimeZone=%s",
global.Config.Postgresql.Host,
global.Config.Postgresql.Port,
global.Config.Postgresql.User,
global.Config.Postgresql.Password,
global.Config.Postgresql.Dbname,
global.Config.Postgresql.Sslmode,
global.Config.Postgresql.TimeZone,
)
global.DB = orm.NewPostgresql(&orm.PostgresqlConfig{
Dsn: dsn,
MaxIdleConns: global.Config.Gorm.MaxIdleConns,
MaxOpenConns: global.Config.Gorm.MaxOpenConns,
}, global.Logger)
} else {
//sqlite
global.DB = orm.NewSqlite(&orm.SqliteConfig{
MaxIdleConns: global.Config.Gorm.MaxIdleConns,
MaxOpenConns: global.Config.Gorm.MaxOpenConns,
})
}, global.Logger)
}
DatabaseAutoUpdate()
//validator
global.ApiInitValidator()
@@ -162,13 +192,25 @@ func InitGlobal() {
//jwt
//fmt.Println(global.Config.Jwt.PrivateKey)
global.Jwt = jwt.NewJwt(global.Config.Jwt.Key, global.Config.Jwt.ExpireDuration*time.Second)
global.Jwt = jwt.NewJwt(global.Config.Jwt.Key, global.Config.Jwt.ExpireDuration)
//locker
global.Lock = lock.NewLocal()
//service
service.New(&global.Config, global.DB, global.Logger, global.Jwt, global.Lock)
global.LoginLimiter = utils.NewLoginLimiter(utils.SecurityPolicy{
CaptchaThreshold: global.Config.App.CaptchaThreshold,
BanThreshold: global.Config.App.BanThreshold,
AttemptsWindow: 10 * time.Minute,
BanDuration: 30 * time.Minute,
})
global.LoginLimiter.RegisterProvider(utils.B64StringCaptchaProvider{})
DatabaseAutoUpdate()
}
func DatabaseAutoUpdate() {
version := 260
version := 262
db := global.DB
@@ -178,11 +220,17 @@ func DatabaseAutoUpdate() {
if dbName == "" {
dbName = global.Config.Mysql.Dbname
// 移除 DSN 中的数据库名称,以便初始连接时不指定数据库
dsnWithoutDB := global.Config.Mysql.Username + ":" + global.Config.Mysql.Password + "@(" + global.Config.Mysql.Addr + ")/?charset=utf8mb4&parseTime=True&loc=Local"
dsnWithoutDB := fmt.Sprintf("%s:%s@(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
global.Config.Mysql.Username,
global.Config.Mysql.Password,
global.Config.Mysql.Addr,
"",
)
//新链接
dbWithoutDB := orm.NewMysql(&orm.MysqlConfig{
Dns: dsnWithoutDB,
})
Dsn: dsnWithoutDB,
}, global.Logger)
// 获取底层的 *sql.DB 对象,并确保在程序退出时关闭连接
sqlDBWithoutDB, err := dbWithoutDB.DB()
if err != nil {
@@ -212,6 +260,7 @@ func DatabaseAutoUpdate() {
if v.Version < uint(version) {
Migrate(uint(version))
}
// 245迁移
if v.Version < 245 {
//oauths 表的 oauth_type 字段设置为 op同样的值
@@ -234,7 +283,7 @@ func DatabaseAutoUpdate() {
}
func Migrate(version uint) {
global.Logger.Info("migrating....", version)
global.Logger.Info("Migrating....", version)
err := global.DB.AutoMigrate(
&model.Version{},
&model.User{},
@@ -252,6 +301,7 @@ func Migrate(version uint) {
&model.AddressBookCollection{},
&model.AddressBookCollectionRule{},
&model.ServerCmd{},
&model.DeviceGroup{},
)
if err != nil {
global.Logger.Error("migrate err :=>", err)

View File

@@ -2,14 +2,21 @@ lang: "zh-CN"
app:
web-client: 1 # 1:启用 0:禁用
register: false #是否开启注册
register-status: 1 # 注册用户默认状态 1:启用 2:禁用
captcha-threshold: 3 # <0:disabled, 0 always, >0:enabled
ban-threshold: 0 # 0:disabled, >0:enabled
show-swagger: 0 # 1:启用 0:禁用
token-expire: 360000
token-expire: 168h
web-sso: true #web auth sso
disable-pwd-login: false #禁用密码登录
admin:
title: "RustDesk Api Admin"
title: "RustDesk API Admin"
hello-file: "./conf/admin/hello.html" #优先使用file
hello: ""
# ID Server and Relay Server ports https://github.com/lejianwen/rustdesk-api/issues/257
id-server-port: 21116 # ID Server port (for server cmd)
relay-server-port: 21117 # ID Server port (for server cmd)
gin:
api-addr: "0.0.0.0:21114"
mode: "release" #release,debug,test
@@ -24,6 +31,16 @@ mysql:
password: ""
addr: ""
dbname: ""
postgresql:
host: "127.0.0.1"
port: "5432"
user: ""
password: ""
dbname: "postgres"
sslmode: "disable" # disable, require, verify-ca, verify-full
time-zone: "Asia/Shanghai" # Time zone for PostgreSQL connection
rustdesk:
id-server: "192.168.1.66:21116"
relay-server: "192.168.1.66:21117"
@@ -32,6 +49,7 @@ rustdesk:
key-file: "/data/id_ed25519.pub"
personal: 1
webclient-magic-queryonline: 0
ws-host: "" #eg: wss://192.168.1.3:4443
logger:
path: "./runtime/log.txt"
level: "info" #trace,debug,info,warn,error,fatal
@@ -41,11 +59,11 @@ proxy:
host: "http://127.0.0.1:1080"
jwt:
key: ""
expire-duration: 360000
expire-duration: 168h
ldap:
enable: false
url: "ldap://ldap.example.com:389"
tls: false
tls-ca-file: ""
tls-verify: false
base-dn: "dc=example,dc=com"
bind-dn: "cn=admin,dc=example,dc=com"
@@ -63,21 +81,3 @@ ldap:
sync: false # If true, the user will be synchronized to the database when the user logs in. If false, the user will be synchronized to the database when the user be created.
admin-group: "cn=admin,dc=example,dc=com" # The group name of the admin group, if the user is in this group, the user will be an admin.
redis:
addr: "127.0.0.1:6379"
password: ""
db: 0
cache:
type: "file"
file-dir: "./runtime/cache"
redis-addr: "127.0.0.1:6379"
redis-pwd: ""
redis-db: 0
oss:
access-key-id: ""
access-key-secret: ""
host: ""
callback-url: ""
expire-time: 30
max-byte: 10240

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"github.com/spf13/viper"
"strings"
"time"
)
const (
@@ -13,33 +14,48 @@ const (
)
type App struct {
WebClient int `mapstructure:"web-client"`
Register bool `mapstructure:"register"`
ShowSwagger int `mapstructure:"show-swagger"`
TokenExpire int `mapstructure:"token-expire"`
WebSso bool `mapstructure:"web-sso"`
DisablePwdLogin bool `mapstructure:"disable-pwd-login"`
WebClient int `mapstructure:"web-client"`
Register bool `mapstructure:"register"`
RegisterStatus int `mapstructure:"register-status"`
ShowSwagger int `mapstructure:"show-swagger"`
TokenExpire time.Duration `mapstructure:"token-expire"`
WebSso bool `mapstructure:"web-sso"`
DisablePwdLogin bool `mapstructure:"disable-pwd-login"`
CaptchaThreshold int `mapstructure:"captcha-threshold"`
BanThreshold int `mapstructure:"ban-threshold"`
}
type Admin struct {
Title string `mapstructure:"title"`
Hello string `mapstructure:"hello"`
HelloFile string `mapstructure:"hello-file"`
Title string `mapstructure:"title"`
Hello string `mapstructure:"hello"`
HelloFile string `mapstructure:"hello-file"`
IdServerPort int `mapstructure:"id-server-port"`
RelayServerPort int `mapstructure:"relay-server-port"`
}
type Config struct {
Lang string `mapstructure:"lang"`
App App
Admin Admin
Gorm Gorm
Mysql Mysql
Gin Gin
Logger Logger
Redis Redis
Cache Cache
Oss Oss
Jwt Jwt
Rustdesk Rustdesk
Proxy Proxy
Ldap Ldap
Lang string `mapstructure:"lang"`
App App
Admin Admin
Gorm Gorm
Mysql Mysql
Postgresql Postgresql
Gin Gin
Logger Logger
Redis Redis
Cache Cache
Oss Oss
Jwt Jwt
Rustdesk Rustdesk
Proxy Proxy
Ldap Ldap
}
func (a *Admin) Init() {
if a.IdServerPort == 0 {
a.IdServerPort = DefaultIdServerPort
}
if a.RelayServerPort == 0 {
a.RelayServerPort = DefaultRelayServerPort
}
}
// Init 初始化配置
@@ -73,10 +89,10 @@ func Init(rowVal *Config, path string) *viper.Viper {
})
*/
if err := v.Unmarshal(rowVal); err != nil {
fmt.Println(err)
panic(fmt.Errorf("Fatal error config: %s \n", err))
}
rowVal.Rustdesk.LoadKeyFile()
rowVal.Rustdesk.ParsePort()
rowVal.Admin.Init()
return v
}

View File

@@ -1,8 +1,9 @@
package config
const (
TypeSqlite = "sqlite"
TypeMysql = "mysql"
TypeSqlite = "sqlite"
TypeMysql = "mysql"
TypePostgresql = "postgresql"
)
type Gorm struct {
@@ -17,3 +18,13 @@ type Mysql struct {
Password string `mapstructure:"password"`
Dbname string `mapstructure:"dbname"`
}
type Postgresql struct {
Host string `mapstructure:"host"`
Port string `mapstructure:"port"`
User string `mapstructure:"user"`
Password string `mapstructure:"password"`
Dbname string `mapstructure:"dbname"`
Sslmode string `mapstructure:"sslmode"` // "disable", "require", "verify-ca", "verify-full"
TimeZone string `mapstructure:"time-zone"` // e.g., "Asia/Shanghai"
}

View File

@@ -26,7 +26,7 @@ type LdapUser struct {
type Ldap struct {
Enable bool `mapstructure:"enable"`
Url string `mapstructure:"url"`
TLS bool `mapstructure:"tls"`
TlsCaFile string `mapstructure:"tls-ca-file"`
TlsVerify bool `mapstructure:"tls-verify"`
BaseDn string `mapstructure:"base-dn"`
BindDn string `mapstructure:"bind-dn"`

View File

@@ -18,3 +18,9 @@ type OidcOauth struct {
ClientSecret string `mapstructure:"client-secret"`
RedirectUrl string `mapstructure:"redirect-url"`
}
type LinuxdoOauth struct {
ClientId string `mapstructure:"client-id"`
ClientSecret string `mapstructure:"client-secret"`
RedirectUrl string `mapstructure:"redirect-url"`
}

View File

@@ -2,8 +2,6 @@ package config
import (
"os"
"strconv"
"strings"
)
const (
@@ -21,7 +19,8 @@ type Rustdesk struct {
KeyFile string `mapstructure:"key-file"`
Personal int `mapstructure:"personal"`
//webclient-magic-queryonline
WebclientMagicQueryonline int `mapstructure:"webclient-magic-queryonline"`
WebclientMagicQueryonline int `mapstructure:"webclient-magic-queryonline"`
WsHost string `mapstructure:"ws-host"`
}
func (rd *Rustdesk) LoadKeyFile() {
@@ -39,19 +38,3 @@ func (rd *Rustdesk) LoadKeyFile() {
return
}
}
func (rd *Rustdesk) ParsePort() {
// Parse port
idres := strings.Split(rd.IdServer, ":")
if len(idres) == 1 {
rd.IdServerPort = DefaultIdServerPort
} else if len(idres) == 2 {
rd.IdServerPort, _ = strconv.Atoi(idres[1])
}
relayres := strings.Split(rd.RelayServer, ":")
if len(relayres) == 1 {
rd.RelayServerPort = DefaultRelayServerPort
} else if len(relayres) == 2 {
rd.RelayServerPort, _ = strconv.Atoi(relayres[1])
}
}

0
data/.gitkeep Normal file
View File

View File

@@ -1407,6 +1407,280 @@ const docTemplateadmin = `{
}
}
},
"/admin/device_group/create": {
"post": {
"security": [
{
"token": []
}
],
"description": "创建设备群组",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"设备群组"
],
"summary": "创建设备群组",
"parameters": [
{
"description": "设备群组信息",
"name": "body",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/admin.DeviceGroupForm"
}
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"allOf": [
{
"$ref": "#/definitions/response.Response"
},
{
"type": "object",
"properties": {
"data": {
"$ref": "#/definitions/model.DeviceGroup"
}
}
}
]
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/response.Response"
}
}
}
}
},
"/admin/device_group/delete": {
"post": {
"security": [
{
"token": []
}
],
"description": "设备群组删除",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"设备群组"
],
"summary": "设备群组删除",
"parameters": [
{
"description": "群组信息",
"name": "body",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/admin.DeviceGroupForm"
}
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/response.Response"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/response.Response"
}
}
}
}
},
"/admin/device_group/detail/{id}": {
"get": {
"security": [
{
"token": []
}
],
"description": "设备群组详情",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"设备群组"
],
"summary": "设备群组详情",
"parameters": [
{
"type": "integer",
"description": "ID",
"name": "id",
"in": "path",
"required": true
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"allOf": [
{
"$ref": "#/definitions/response.Response"
},
{
"type": "object",
"properties": {
"data": {
"$ref": "#/definitions/model.Group"
}
}
}
]
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/response.Response"
}
}
}
}
},
"/admin/device_group/list": {
"get": {
"security": [
{
"token": []
}
],
"description": "群组列表",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"群组"
],
"summary": "群组列表",
"parameters": [
{
"type": "integer",
"description": "页码",
"name": "page",
"in": "query"
},
{
"type": "integer",
"description": "页大小",
"name": "page_size",
"in": "query"
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"allOf": [
{
"$ref": "#/definitions/response.Response"
},
{
"type": "object",
"properties": {
"data": {
"$ref": "#/definitions/model.GroupList"
}
}
}
]
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/response.Response"
}
}
}
}
},
"/admin/device_group/update": {
"post": {
"security": [
{
"token": []
}
],
"description": "设备群组编辑",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"设备群组"
],
"summary": "设备群组编辑",
"parameters": [
{
"description": "群组信息",
"name": "body",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/admin.DeviceGroupForm"
}
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"allOf": [
{
"$ref": "#/definitions/response.Response"
},
{
"type": "object",
"properties": {
"data": {
"$ref": "#/definitions/model.Group"
}
}
}
]
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/response.Response"
}
}
}
}
},
"/admin/file/oss_token": {
"get": {
"security": [
@@ -1783,7 +2057,7 @@ const docTemplateadmin = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/github_com_lejianwen_rustdesk-api_http_request_admin.Login"
"$ref": "#/definitions/github_com_lejianwen_rustdesk-api_v2_http_request_admin.Login"
}
}
],
@@ -5219,6 +5493,20 @@ const docTemplateadmin = `{
}
}
},
"admin.DeviceGroupForm": {
"type": "object",
"required": [
"name"
],
"properties": {
"id": {
"type": "integer"
},
"name": {
"type": "string"
}
}
},
"admin.GroupForm": {
"type": "object",
"required": [
@@ -5306,6 +5594,12 @@ const docTemplateadmin = `{
"op": {
"type": "string"
},
"pkce_enable": {
"type": "boolean"
},
"pkce_method": {
"type": "string"
},
"redirect_url": {
"type": "string"
},
@@ -5334,6 +5628,9 @@ const docTemplateadmin = `{
"cpu": {
"type": "string"
},
"group_id": {
"type": "integer"
},
"hostname": {
"type": "string"
},
@@ -5521,7 +5818,7 @@ const docTemplateadmin = `{
}
}
},
"github_com_lejianwen_rustdesk-api_http_request_admin.Login": {
"github_com_lejianwen_rustdesk-api_v2_http_request_admin.Login": {
"type": "object",
"required": [
"password",
@@ -5531,6 +5828,9 @@ const docTemplateadmin = `{
"captcha": {
"type": "string"
},
"captcha_id": {
"type": "string"
},
"password": {
"type": "string"
},
@@ -5842,6 +6142,23 @@ const docTemplateadmin = `{
}
}
},
"model.DeviceGroup": {
"type": "object",
"properties": {
"created_at": {
"type": "string"
},
"id": {
"type": "integer"
},
"name": {
"type": "string"
},
"updated_at": {
"type": "string"
}
}
},
"model.Group": {
"type": "object",
"properties": {
@@ -5973,6 +6290,12 @@ const docTemplateadmin = `{
"op": {
"type": "string"
},
"pkce_enable": {
"type": "boolean"
},
"pkce_method": {
"type": "string"
},
"redirect_url": {
"type": "string"
},
@@ -6013,6 +6336,9 @@ const docTemplateadmin = `{
"created_at": {
"type": "string"
},
"group_id": {
"type": "integer"
},
"hostname": {
"type": "string"
},

View File

@@ -1400,6 +1400,280 @@
}
}
},
"/admin/device_group/create": {
"post": {
"security": [
{
"token": []
}
],
"description": "创建设备群组",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"设备群组"
],
"summary": "创建设备群组",
"parameters": [
{
"description": "设备群组信息",
"name": "body",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/admin.DeviceGroupForm"
}
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"allOf": [
{
"$ref": "#/definitions/response.Response"
},
{
"type": "object",
"properties": {
"data": {
"$ref": "#/definitions/model.DeviceGroup"
}
}
}
]
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/response.Response"
}
}
}
}
},
"/admin/device_group/delete": {
"post": {
"security": [
{
"token": []
}
],
"description": "设备群组删除",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"设备群组"
],
"summary": "设备群组删除",
"parameters": [
{
"description": "群组信息",
"name": "body",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/admin.DeviceGroupForm"
}
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/response.Response"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/response.Response"
}
}
}
}
},
"/admin/device_group/detail/{id}": {
"get": {
"security": [
{
"token": []
}
],
"description": "设备群组详情",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"设备群组"
],
"summary": "设备群组详情",
"parameters": [
{
"type": "integer",
"description": "ID",
"name": "id",
"in": "path",
"required": true
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"allOf": [
{
"$ref": "#/definitions/response.Response"
},
{
"type": "object",
"properties": {
"data": {
"$ref": "#/definitions/model.Group"
}
}
}
]
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/response.Response"
}
}
}
}
},
"/admin/device_group/list": {
"get": {
"security": [
{
"token": []
}
],
"description": "群组列表",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"群组"
],
"summary": "群组列表",
"parameters": [
{
"type": "integer",
"description": "页码",
"name": "page",
"in": "query"
},
{
"type": "integer",
"description": "页大小",
"name": "page_size",
"in": "query"
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"allOf": [
{
"$ref": "#/definitions/response.Response"
},
{
"type": "object",
"properties": {
"data": {
"$ref": "#/definitions/model.GroupList"
}
}
}
]
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/response.Response"
}
}
}
}
},
"/admin/device_group/update": {
"post": {
"security": [
{
"token": []
}
],
"description": "设备群组编辑",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"设备群组"
],
"summary": "设备群组编辑",
"parameters": [
{
"description": "群组信息",
"name": "body",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/admin.DeviceGroupForm"
}
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"allOf": [
{
"$ref": "#/definitions/response.Response"
},
{
"type": "object",
"properties": {
"data": {
"$ref": "#/definitions/model.Group"
}
}
}
]
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/response.Response"
}
}
}
}
},
"/admin/file/oss_token": {
"get": {
"security": [
@@ -1776,7 +2050,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/github_com_lejianwen_rustdesk-api_http_request_admin.Login"
"$ref": "#/definitions/github_com_lejianwen_rustdesk-api_v2_http_request_admin.Login"
}
}
],
@@ -5212,6 +5486,20 @@
}
}
},
"admin.DeviceGroupForm": {
"type": "object",
"required": [
"name"
],
"properties": {
"id": {
"type": "integer"
},
"name": {
"type": "string"
}
}
},
"admin.GroupForm": {
"type": "object",
"required": [
@@ -5299,6 +5587,12 @@
"op": {
"type": "string"
},
"pkce_enable": {
"type": "boolean"
},
"pkce_method": {
"type": "string"
},
"redirect_url": {
"type": "string"
},
@@ -5327,6 +5621,9 @@
"cpu": {
"type": "string"
},
"group_id": {
"type": "integer"
},
"hostname": {
"type": "string"
},
@@ -5514,7 +5811,7 @@
}
}
},
"github_com_lejianwen_rustdesk-api_http_request_admin.Login": {
"github_com_lejianwen_rustdesk-api_v2_http_request_admin.Login": {
"type": "object",
"required": [
"password",
@@ -5524,6 +5821,9 @@
"captcha": {
"type": "string"
},
"captcha_id": {
"type": "string"
},
"password": {
"type": "string"
},
@@ -5835,6 +6135,23 @@
}
}
},
"model.DeviceGroup": {
"type": "object",
"properties": {
"created_at": {
"type": "string"
},
"id": {
"type": "integer"
},
"name": {
"type": "string"
},
"updated_at": {
"type": "string"
}
}
},
"model.Group": {
"type": "object",
"properties": {
@@ -5966,6 +6283,12 @@
"op": {
"type": "string"
},
"pkce_enable": {
"type": "boolean"
},
"pkce_method": {
"type": "string"
},
"redirect_url": {
"type": "string"
},
@@ -6006,6 +6329,9 @@
"created_at": {
"type": "string"
},
"group_id": {
"type": "integer"
},
"hostname": {
"type": "string"
},

View File

@@ -77,6 +77,15 @@ definitions:
- new_password
- old_password
type: object
admin.DeviceGroupForm:
properties:
id:
type: integer
name:
type: string
required:
- name
type: object
admin.GroupForm:
properties:
id:
@@ -130,6 +139,10 @@ definitions:
type: string
op:
type: string
pkce_enable:
type: boolean
pkce_method:
type: string
redirect_url:
type: string
scopes:
@@ -153,6 +166,8 @@ definitions:
properties:
cpu:
type: string
group_id:
type: integer
hostname:
type: string
id:
@@ -278,10 +293,12 @@ definitions:
required:
- ids
type: object
github_com_lejianwen_rustdesk-api_http_request_admin.Login:
github_com_lejianwen_rustdesk-api_v2_http_request_admin.Login:
properties:
captcha:
type: string
captcha_id:
type: string
password:
type: string
platform:
@@ -492,6 +509,17 @@ definitions:
total:
type: integer
type: object
model.DeviceGroup:
properties:
created_at:
type: string
id:
type: integer
name:
type: string
updated_at:
type: string
type: object
model.Group:
properties:
created_at:
@@ -579,6 +607,10 @@ definitions:
type: string
op:
type: string
pkce_enable:
type: boolean
pkce_method:
type: string
redirect_url:
type: string
scopes:
@@ -605,6 +637,8 @@ definitions:
type: string
created_at:
type: string
group_id:
type: integer
hostname:
type: string
id:
@@ -1610,6 +1644,167 @@ paths:
summary: RUSTDESK服务配置
tags:
- ADMIN
/admin/device_group/create:
post:
consumes:
- application/json
description: 创建设备群组
parameters:
- description: 设备群组信息
in: body
name: body
required: true
schema:
$ref: '#/definitions/admin.DeviceGroupForm'
produces:
- application/json
responses:
"200":
description: OK
schema:
allOf:
- $ref: '#/definitions/response.Response'
- properties:
data:
$ref: '#/definitions/model.DeviceGroup'
type: object
"500":
description: Internal Server Error
schema:
$ref: '#/definitions/response.Response'
security:
- token: []
summary: 创建设备群组
tags:
- 设备群组
/admin/device_group/delete:
post:
consumes:
- application/json
description: 设备群组删除
parameters:
- description: 群组信息
in: body
name: body
required: true
schema:
$ref: '#/definitions/admin.DeviceGroupForm'
produces:
- application/json
responses:
"200":
description: OK
schema:
$ref: '#/definitions/response.Response'
"500":
description: Internal Server Error
schema:
$ref: '#/definitions/response.Response'
security:
- token: []
summary: 设备群组删除
tags:
- 设备群组
/admin/device_group/detail/{id}:
get:
consumes:
- application/json
description: 设备群组详情
parameters:
- description: ID
in: path
name: id
required: true
type: integer
produces:
- application/json
responses:
"200":
description: OK
schema:
allOf:
- $ref: '#/definitions/response.Response'
- properties:
data:
$ref: '#/definitions/model.Group'
type: object
"500":
description: Internal Server Error
schema:
$ref: '#/definitions/response.Response'
security:
- token: []
summary: 设备群组详情
tags:
- 设备群组
/admin/device_group/list:
get:
consumes:
- application/json
description: 群组列表
parameters:
- description: 页码
in: query
name: page
type: integer
- description: 页大小
in: query
name: page_size
type: integer
produces:
- application/json
responses:
"200":
description: OK
schema:
allOf:
- $ref: '#/definitions/response.Response'
- properties:
data:
$ref: '#/definitions/model.GroupList'
type: object
"500":
description: Internal Server Error
schema:
$ref: '#/definitions/response.Response'
security:
- token: []
summary: 群组列表
tags:
- 群组
/admin/device_group/update:
post:
consumes:
- application/json
description: 设备群组编辑
parameters:
- description: 群组信息
in: body
name: body
required: true
schema:
$ref: '#/definitions/admin.DeviceGroupForm'
produces:
- application/json
responses:
"200":
description: OK
schema:
allOf:
- $ref: '#/definitions/response.Response'
- properties:
data:
$ref: '#/definitions/model.Group'
type: object
"500":
description: Internal Server Error
schema:
$ref: '#/definitions/response.Response'
security:
- token: []
summary: 设备群组编辑
tags:
- 设备群组
/admin/file/oss_token:
get:
consumes:
@@ -1830,7 +2025,7 @@ paths:
name: body
required: true
schema:
$ref: '#/definitions/github_com_lejianwen_rustdesk-api_http_request_admin.Login'
$ref: '#/definitions/github_com_lejianwen_rustdesk-api_v2_http_request_admin.Login'
produces:
- application/json
responses:

View File

@@ -767,6 +767,66 @@ const docTemplateapi = `{
}
}
},
"/device-group/accessible": {
"get": {
"security": [
{
"BearerAuth": []
}
],
"description": "机器",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"群组"
],
"summary": "设备",
"parameters": [
{
"type": "integer",
"description": "页码",
"name": "page",
"in": "query"
},
{
"type": "integer",
"description": "每页数量",
"name": "pageSize",
"in": "query"
},
{
"type": "integer",
"description": "状态",
"name": "status",
"in": "query"
},
{
"type": "string",
"description": "accessible",
"name": "accessible",
"in": "query"
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/response.DataResponse"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/response.Response"
}
}
}
}
},
"/heartbeat": {
"post": {
"description": "心跳",
@@ -1148,7 +1208,7 @@ const docTemplateapi = `{
"application/json"
],
"tags": [
"地址"
"System"
],
"summary": "提交系统信息",
"parameters": [
@@ -1178,6 +1238,35 @@ const docTemplateapi = `{
}
}
},
"/sysinfo_ver": {
"post": {
"description": "获取系统版本信息",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"System"
],
"summary": "获取系统版本信息",
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "string"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/response.ErrorResponse"
}
}
}
}
},
"/users": {
"get": {
"security": [

View File

@@ -760,6 +760,66 @@
}
}
},
"/device-group/accessible": {
"get": {
"security": [
{
"BearerAuth": []
}
],
"description": "机器",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"群组"
],
"summary": "设备",
"parameters": [
{
"type": "integer",
"description": "页码",
"name": "page",
"in": "query"
},
{
"type": "integer",
"description": "每页数量",
"name": "pageSize",
"in": "query"
},
{
"type": "integer",
"description": "状态",
"name": "status",
"in": "query"
},
{
"type": "string",
"description": "accessible",
"name": "accessible",
"in": "query"
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/response.DataResponse"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/response.Response"
}
}
}
}
},
"/heartbeat": {
"post": {
"description": "心跳",
@@ -1141,7 +1201,7 @@
"application/json"
],
"tags": [
"地址"
"System"
],
"summary": "提交系统信息",
"parameters": [
@@ -1171,6 +1231,35 @@
}
}
},
"/sysinfo_ver": {
"post": {
"description": "获取系统版本信息",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"System"
],
"summary": "获取系统版本信息",
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "string"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/response.ErrorResponse"
}
}
}
}
},
"/users": {
"get": {
"security": [

View File

@@ -671,6 +671,44 @@ paths:
summary: 用户信息
tags:
- 用户
/device-group/accessible:
get:
consumes:
- application/json
description: 机器
parameters:
- description: 页码
in: query
name: page
type: integer
- description: 每页数量
in: query
name: pageSize
type: integer
- description: 状态
in: query
name: status
type: integer
- description: accessible
in: query
name: accessible
type: string
produces:
- application/json
responses:
"200":
description: OK
schema:
$ref: '#/definitions/response.DataResponse'
"500":
description: Internal Server Error
schema:
$ref: '#/definitions/response.Response'
security:
- BearerAuth: []
summary: 设备
tags:
- 群组
/heartbeat:
post:
consumes:
@@ -935,7 +973,26 @@ paths:
$ref: '#/definitions/response.ErrorResponse'
summary: 提交系统信息
tags:
- 地址
- System
/sysinfo_ver:
post:
consumes:
- application/json
description: 获取系统版本信息
produces:
- application/json
responses:
"200":
description: OK
schema:
type: string
"500":
description: Internal Server Error
schema:
$ref: '#/definitions/response.ErrorResponse'
summary: 获取系统版本信息
tags:
- System
/users:
get:
consumes:

View File

@@ -14,6 +14,7 @@ import (
en_translations "github.com/go-playground/validator/v10/translations/en"
es_translations "github.com/go-playground/validator/v10/translations/es"
fr_translations "github.com/go-playground/validator/v10/translations/fr"
ko_translations "github.com/go-playground/validator/v10/translations/ko"
ru_translations "github.com/go-playground/validator/v10/translations/ru"
zh_translations "github.com/go-playground/validator/v10/translations/zh"
zh_tw_translations "github.com/go-playground/validator/v10/translations/zh_tw"
@@ -51,8 +52,7 @@ func ApiInitValidator() {
panic(err)
}
//validate没有ko的翻译使用zh的翻译
err = zh_translations.RegisterDefaultTranslations(validate, koTrans)
err = ko_translations.RegisterDefaultTranslations(validate, koTrans)
if err != nil {
panic(err)
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/lejianwen/rustdesk-api/v2/lib/jwt"
"github.com/lejianwen/rustdesk-api/v2/lib/lock"
"github.com/lejianwen/rustdesk-api/v2/lib/upload"
"github.com/lejianwen/rustdesk-api/v2/utils"
"github.com/nicksnyder/go-i18n/v2/i18n"
"github.com/sirupsen/logrus"
"github.com/spf13/viper"
@@ -31,8 +32,9 @@ var (
ValidStruct func(*gin.Context, interface{}) []string
ValidVar func(ctx *gin.Context, field interface{}, tag string) []string
}
Oss *upload.Oss
Jwt *jwt.Jwt
Lock lock.Locker
Localizer func(lang string) *i18n.Localizer
Oss *upload.Oss
Jwt *jwt.Jwt
Lock lock.Locker
Localizer func(lang string) *i18n.Localizer
LoginLimiter *utils.LoginLimiter
)

36
go.mod
View File

@@ -1,19 +1,23 @@
module github.com/lejianwen/rustdesk-api/v2
go 1.22
go 1.23
toolchain go1.23.10
require (
github.com/BurntSushi/toml v1.3.2
github.com/antonfisher/nested-logrus-formatter v1.3.1
github.com/fsnotify/fsnotify v1.5.1
github.com/coreos/go-oidc/v3 v3.12.0
github.com/fvbock/endless v0.0.0-20170109170031-447134032cb6
github.com/gin-gonic/gin v1.9.0
github.com/go-ldap/ldap/v3 v3.4.10
github.com/go-playground/locales v0.14.1
github.com/go-playground/universal-translator v0.18.1
github.com/go-playground/validator/v10 v10.11.2
github.com/go-playground/validator/v10 v10.26.0
github.com/go-redis/redis/v8 v8.11.4
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/google/uuid v1.6.0
github.com/mojocn/base64Captcha v1.3.6
github.com/nicksnyder/go-i18n/v2 v2.4.0
github.com/sirupsen/logrus v1.8.1
github.com/spf13/cobra v1.8.1
@@ -22,10 +26,11 @@ require (
github.com/swaggo/gin-swagger v1.6.0
github.com/swaggo/swag v1.16.3
golang.org/x/oauth2 v0.23.0
golang.org/x/text v0.21.0
golang.org/x/text v0.22.0
gorm.io/driver/mysql v1.5.7
gorm.io/driver/postgres v1.6.0
gorm.io/driver/sqlite v1.5.6
gorm.io/gorm v1.25.7
gorm.io/gorm v1.25.10
)
require (
@@ -37,9 +42,11 @@ require (
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/fsnotify/fsnotify v1.5.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.8 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-asn1-ber/asn1-ber v1.5.7 // indirect
github.com/go-ldap/ldap/v3 v3.4.10 // indirect
github.com/go-jose/go-jose/v4 v4.0.2 // indirect
github.com/go-openapi/jsonpointer v0.19.5 // indirect
github.com/go-openapi/jsonreference v0.19.6 // indirect
github.com/go-openapi/spec v0.20.4 // indirect
@@ -49,12 +56,16 @@ require (
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.6.0 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.0.9 // indirect
github.com/leodido/go-urn v1.2.1 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/magiconair/properties v1.8.5 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-isatty v0.0.17 // indirect
@@ -62,9 +73,9 @@ require (
github.com/mitchellh/mapstructure v1.4.2 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/mojocn/base64Captcha v1.3.6 // indirect
github.com/pelletier/go-toml v1.9.4 // indirect
github.com/pelletier/go-toml/v2 v2.0.6 // indirect
github.com/rogpeppe/go-internal v1.14.1 // indirect
github.com/spf13/afero v1.6.0 // indirect
github.com/spf13/cast v1.4.1 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
@@ -73,11 +84,12 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.9 // indirect
golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect
golang.org/x/crypto v0.31.0 // indirect
golang.org/x/crypto v0.33.0 // indirect
golang.org/x/image v0.13.0 // indirect
golang.org/x/net v0.33.0 // indirect
golang.org/x/sys v0.28.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
golang.org/x/net v0.34.0 // indirect
golang.org/x/sync v0.11.0 // indirect
golang.org/x/sys v0.30.0 // indirect
golang.org/x/tools v0.26.0 // indirect
google.golang.org/protobuf v1.33.0 // indirect
gopkg.in/ini.v1 v1.63.2 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect

View File

@@ -120,7 +120,7 @@ func (abcr *AddressBookCollectionRule) CheckForm(t *model.AddressBookCollectionR
//check to_id
if t.Type == model.ShareAddressBookRuleTypePersonal {
if t.ToId == t.UserId {
return "ParamsError", false
return "CannotShareToSelf", false
}
tou := service.AllService.UserService.InfoById(t.ToId)
if tou.Id == 0 {
@@ -135,7 +135,7 @@ func (abcr *AddressBookCollectionRule) CheckForm(t *model.AddressBookCollectionR
return "ParamsError", false
}
// 重复检查
ex := service.AllService.AddressBookService.RulePersonalInfoByToIdAndCid(t.ToId, t.CollectionId)
ex := service.AllService.AddressBookService.RuleInfoByToIdAndCid(t.Type, t.ToId, t.CollectionId)
if t.Id == 0 && ex.Id > 0 {
return "ItemExists", false
}

View File

@@ -4,6 +4,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/lejianwen/rustdesk-api/v2/global"
"github.com/lejianwen/rustdesk-api/v2/http/response"
"github.com/lejianwen/rustdesk-api/v2/model"
"github.com/lejianwen/rustdesk-api/v2/service"
"os"
"strings"
@@ -60,19 +61,30 @@ func (co *Config) AppConfig(c *gin.Context) {
// @Security token
func (co *Config) AdminConfig(c *gin.Context) {
u := service.AllService.UserService.CurUser(c)
if u == nil || u.Id == 0 {
u := &model.User{}
token := c.GetHeader("api-token")
if token != "" {
u, _ = service.AllService.UserService.InfoByAccessToken(token)
if !service.AllService.UserService.CheckUserEnable(u) {
u.Id = 0
}
}
if u.Id == 0 {
response.Success(c, &gin.H{
"title": global.Config.Admin.Title,
})
return
}
hello := global.Config.Admin.Hello
helloFile := global.Config.Admin.HelloFile
if helloFile != "" {
b, err := os.ReadFile(helloFile)
if err == nil && len(b) > 0 {
hello = string(b)
if hello == "" {
helloFile := global.Config.Admin.HelloFile
if helloFile != "" {
b, err := os.ReadFile(helloFile)
if err == nil && len(b) > 0 {
hello = string(b)
}
}
}

View File

@@ -0,0 +1,160 @@
package admin
import (
"github.com/gin-gonic/gin"
"github.com/lejianwen/rustdesk-api/v2/global"
"github.com/lejianwen/rustdesk-api/v2/http/request/admin"
"github.com/lejianwen/rustdesk-api/v2/http/response"
"github.com/lejianwen/rustdesk-api/v2/service"
"strconv"
)
type DeviceGroup struct {
}
// Detail 设备群组
// @Tags 设备群组
// @Summary 设备群组详情
// @Description 设备群组详情
// @Accept json
// @Produce json
// @Param id path int true "ID"
// @Success 200 {object} response.Response{data=model.Group}
// @Failure 500 {object} response.Response
// @Router /admin/device_group/detail/{id} [get]
// @Security token
func (ct *DeviceGroup) Detail(c *gin.Context) {
id := c.Param("id")
iid, _ := strconv.Atoi(id)
u := service.AllService.GroupService.DeviceGroupInfoById(uint(iid))
if u.Id > 0 {
response.Success(c, u)
return
}
response.Fail(c, 101, response.TranslateMsg(c, "ItemNotFound"))
return
}
// Create 创建设备群组
// @Tags 设备群组
// @Summary 创建设备群组
// @Description 创建设备群组
// @Accept json
// @Produce json
// @Param body body admin.DeviceGroupForm true "设备群组信息"
// @Success 200 {object} response.Response{data=model.DeviceGroup}
// @Failure 500 {object} response.Response
// @Router /admin/device_group/create [post]
// @Security token
func (ct *DeviceGroup) Create(c *gin.Context) {
f := &admin.DeviceGroupForm{}
if err := c.ShouldBindJSON(f); err != nil {
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")+err.Error())
return
}
errList := global.Validator.ValidStruct(c, f)
if len(errList) > 0 {
response.Fail(c, 101, errList[0])
return
}
u := f.ToDeviceGroup()
err := service.AllService.GroupService.DeviceGroupCreate(u)
if err != nil {
response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
return
}
response.Success(c, nil)
}
// List 列表
// @Tags 群组
// @Summary 群组列表
// @Description 群组列表
// @Accept json
// @Produce json
// @Param page query int false "页码"
// @Param page_size query int false "页大小"
// @Success 200 {object} response.Response{data=model.GroupList}
// @Failure 500 {object} response.Response
// @Router /admin/device_group/list [get]
// @Security token
func (ct *DeviceGroup) List(c *gin.Context) {
query := &admin.PageQuery{}
if err := c.ShouldBindQuery(query); err != nil {
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")+err.Error())
return
}
res := service.AllService.GroupService.DeviceGroupList(query.Page, query.PageSize, nil)
response.Success(c, res)
}
// Update 编辑
// @Tags 设备群组
// @Summary 设备群组编辑
// @Description 设备群组编辑
// @Accept json
// @Produce json
// @Param body body admin.DeviceGroupForm true "群组信息"
// @Success 200 {object} response.Response{data=model.Group}
// @Failure 500 {object} response.Response
// @Router /admin/device_group/update [post]
// @Security token
func (ct *DeviceGroup) Update(c *gin.Context) {
f := &admin.DeviceGroupForm{}
if err := c.ShouldBindJSON(f); err != nil {
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")+err.Error())
return
}
if f.Id == 0 {
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError"))
return
}
errList := global.Validator.ValidStruct(c, f)
if len(errList) > 0 {
response.Fail(c, 101, errList[0])
return
}
u := f.ToDeviceGroup()
err := service.AllService.GroupService.DeviceGroupUpdate(u)
if err != nil {
response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
return
}
response.Success(c, nil)
}
// Delete 删除
// @Tags 设备群组
// @Summary 设备群组删除
// @Description 设备群组删除
// @Accept json
// @Produce json
// @Param body body admin.DeviceGroupForm true "群组信息"
// @Success 200 {object} response.Response
// @Failure 500 {object} response.Response
// @Router /admin/device_group/delete [post]
// @Security token
func (ct *DeviceGroup) Delete(c *gin.Context) {
f := &admin.DeviceGroupForm{}
if err := c.ShouldBindJSON(f); err != nil {
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")+err.Error())
return
}
id := f.Id
errList := global.Validator.ValidVar(c, id, "required,gt=0")
if len(errList) > 0 {
response.Fail(c, 101, errList[0])
return
}
u := service.AllService.GroupService.DeviceGroupInfoById(f.Id)
if u.Id > 0 {
err := service.AllService.GroupService.DeviceGroupDelete(u)
if err == nil {
response.Success(c, nil)
return
}
response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
return
}
response.Fail(c, 101, response.TranslateMsg(c, "ItemNotFound"))
}

View File

@@ -11,135 +11,11 @@ import (
adResp "github.com/lejianwen/rustdesk-api/v2/http/response/admin"
"github.com/lejianwen/rustdesk-api/v2/model"
"github.com/lejianwen/rustdesk-api/v2/service"
"github.com/mojocn/base64Captcha"
"sync"
"time"
)
type Login struct {
}
// Captcha 验证码结构
type Captcha struct {
Id string `json:"id"` // 验证码 ID
B64 string `json:"b64"` // base64 验证码
Code string `json:"-"` // 验证码内容
ExpiresAt time.Time `json:"-"` // 过期时间
}
type LoginLimiter struct {
mu sync.RWMutex
failCount map[string]int // 记录每个 IP 的失败次数
timestamp map[string]time.Time // 记录每个 IP 的最后失败时间
captchas map[string]Captcha // 每个 IP 的验证码
threshold int // 失败阈值
expiry time.Duration // 失败记录过期时间
}
func NewLoginLimiter(threshold int, expiry time.Duration) *LoginLimiter {
return &LoginLimiter{
failCount: make(map[string]int),
timestamp: make(map[string]time.Time),
captchas: make(map[string]Captcha),
threshold: threshold,
expiry: expiry,
}
}
// RecordFailure 记录登录失败
func (l *LoginLimiter) RecordFailure(ip string) {
l.mu.Lock()
defer l.mu.Unlock()
// 如果该 IP 的记录已经过期,重置计数
if lastTime, exists := l.timestamp[ip]; exists && time.Since(lastTime) > l.expiry {
l.failCount[ip] = 0
}
// 更新失败次数和时间戳
l.failCount[ip]++
l.timestamp[ip] = time.Now()
}
// NeedsCaptcha 检查是否需要验证码
func (l *LoginLimiter) NeedsCaptcha(ip string) bool {
l.mu.RLock()
defer l.mu.RUnlock()
// 检查记录是否存在且未过期
if lastTime, exists := l.timestamp[ip]; exists && time.Since(lastTime) <= l.expiry {
return l.failCount[ip] >= l.threshold
}
return false
}
// GenerateCaptcha 为指定 IP 生成验证码
func (l *LoginLimiter) GenerateCaptcha(ip string) Captcha {
l.mu.Lock()
defer l.mu.Unlock()
capd := base64Captcha.NewDriverString(50, 150, 5, 10, 4, "1234567890abcdefghijklmnopqrstuvwxyz", nil, nil, nil)
b64cap := base64Captcha.NewCaptcha(capd, base64Captcha.DefaultMemStore)
id, b64s, answer, err := b64cap.Generate()
if err != nil {
global.Logger.Error("Generate captcha failed: " + err.Error())
return Captcha{}
}
// 保存验证码到对应 IP
l.captchas[ip] = Captcha{
Id: id,
B64: b64s,
Code: answer,
ExpiresAt: time.Now().Add(5 * time.Minute),
}
return l.captchas[ip]
}
// VerifyCaptcha 验证指定 IP 的验证码
func (l *LoginLimiter) VerifyCaptcha(ip, code string) bool {
l.mu.RLock()
defer l.mu.RUnlock()
// 检查验证码是否存在且未过期
if captcha, exists := l.captchas[ip]; exists && time.Now().Before(captcha.ExpiresAt) {
return captcha.Code == code
}
return false
}
// RemoveCaptcha 移除指定 IP 的验证码
func (l *LoginLimiter) RemoveCaptcha(ip string) {
l.mu.Lock()
defer l.mu.Unlock()
delete(l.captchas, ip)
}
// CleanupExpired 清理过期的记录
func (l *LoginLimiter) CleanupExpired() {
l.mu.Lock()
defer l.mu.Unlock()
now := time.Now()
for ip, lastTime := range l.timestamp {
if now.Sub(lastTime) > l.expiry {
delete(l.failCount, ip)
delete(l.timestamp, ip)
delete(l.captchas, ip)
}
}
}
func (l *LoginLimiter) RemoveRecord(ip string) {
l.mu.Lock()
defer l.mu.Unlock()
delete(l.failCount, ip)
delete(l.timestamp, ip)
delete(l.captchas, ip)
}
var loginLimiter = NewLoginLimiter(3, 5*time.Minute)
// Login 登录
// @Tags 登录
// @Summary 登录
@@ -156,10 +32,16 @@ func (ct *Login) Login(c *gin.Context) {
response.Fail(c, 101, response.TranslateMsg(c, "PwdLoginDisabled"))
return
}
// 检查登录限制
loginLimiter := global.LoginLimiter
clientIp := c.ClientIP()
_, needCaptcha := loginLimiter.CheckSecurityStatus(clientIp)
f := &admin.Login{}
err := c.ShouldBindJSON(f)
clientIp := c.ClientIP()
if err != nil {
loginLimiter.RecordFailedAttempt(clientIp)
global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), clientIp))
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")+err.Error())
return
@@ -167,14 +49,15 @@ func (ct *Login) Login(c *gin.Context) {
errList := global.Validator.ValidStruct(c, f)
if len(errList) > 0 {
loginLimiter.RecordFailedAttempt(clientIp)
global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), clientIp))
response.Fail(c, 101, errList[0])
return
}
// 检查是否需要验证码
if loginLimiter.NeedsCaptcha(clientIp) {
if f.Captcha == "" || !loginLimiter.VerifyCaptcha(clientIp, f.Captcha) {
if needCaptcha {
if f.CaptchaId == "" || f.Captcha == "" || !loginLimiter.VerifyCaptcha(f.CaptchaId, f.Captcha) {
response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError"))
return
}
@@ -184,17 +67,19 @@ func (ct *Login) Login(c *gin.Context) {
if u.Id == 0 {
global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "UsernameOrPasswordError", c.RemoteIP(), clientIp))
loginLimiter.RecordFailure(clientIp)
if loginLimiter.NeedsCaptcha(clientIp) {
loginLimiter.RemoveCaptcha(clientIp)
loginLimiter.RecordFailedAttempt(clientIp)
if _, needCaptcha = loginLimiter.CheckSecurityStatus(clientIp); needCaptcha {
response.Fail(c, 110, response.TranslateMsg(c, "UsernameOrPasswordError"))
} else {
response.Fail(c, 101, response.TranslateMsg(c, "UsernameOrPasswordError"))
}
response.Fail(c, 101, response.TranslateMsg(c, "UsernameOrPasswordError"))
return
}
if !service.AllService.UserService.CheckUserEnable(u) {
if loginLimiter.NeedsCaptcha(clientIp) {
loginLimiter.RemoveCaptcha(clientIp)
if needCaptcha {
response.Fail(c, 110, response.TranslateMsg(c, "UserDisabled"))
return
}
response.Fail(c, 101, response.TranslateMsg(c, "UserDisabled"))
return
@@ -209,23 +94,37 @@ func (ct *Login) Login(c *gin.Context) {
Platform: f.Platform,
})
// 成功清除记录
loginLimiter.RemoveRecord(clientIp)
// 清理过期记录
go loginLimiter.CleanupExpired()
// 登录成功清除登录限制
loginLimiter.RemoveAttempts(clientIp)
responseLoginSuccess(c, u, ut.Token)
}
func (ct *Login) Captcha(c *gin.Context) {
loginLimiter := global.LoginLimiter
clientIp := c.ClientIP()
if !loginLimiter.NeedsCaptcha(clientIp) {
banned, needCaptcha := loginLimiter.CheckSecurityStatus(clientIp)
if banned {
response.Fail(c, 101, response.TranslateMsg(c, "LoginBanned"))
return
}
if !needCaptcha {
response.Fail(c, 101, response.TranslateMsg(c, "NoCaptchaRequired"))
return
}
captcha := loginLimiter.GenerateCaptcha(clientIp)
err, captcha := loginLimiter.RequireCaptcha()
if err != nil {
response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError")+err.Error())
return
}
err, b64 := loginLimiter.DrawCaptcha(captcha.Content)
if err != nil {
response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError")+err.Error())
return
}
response.Success(c, gin.H{
"captcha": captcha,
"captcha": gin.H{
"id": captcha.Id,
"b64": b64,
},
})
}
@@ -257,12 +156,18 @@ func (ct *Login) Logout(c *gin.Context) {
// @Failure 500 {object} response.ErrorResponse
// @Router /admin/login-options [post]
func (ct *Login) LoginOptions(c *gin.Context) {
ip := c.ClientIP()
loginLimiter := global.LoginLimiter
clientIp := c.ClientIP()
banned, needCaptcha := loginLimiter.CheckSecurityStatus(clientIp)
if banned {
response.Fail(c, 101, response.TranslateMsg(c, "LoginBanned"))
return
}
ops := service.AllService.OauthService.GetOauthProviders()
response.Success(c, gin.H{
"ops": ops,
"register": global.Config.App.Register,
"need_captcha": loginLimiter.NeedsCaptcha(ip),
"need_captcha": needCaptcha,
})
}
@@ -283,13 +188,13 @@ func (ct *Login) OidcAuth(c *gin.Context) {
return
}
err, code, url := service.AllService.OauthService.BeginAuth(f.Op)
err, state, verifier, nonce, url := service.AllService.OauthService.BeginAuth(f.Op)
if err != nil {
response.Error(c, response.TranslateMsg(c, err.Error()))
return
}
service.AllService.OauthService.SetOauthCache(code, &service.OauthCacheItem{
service.AllService.OauthService.SetOauthCache(state, &service.OauthCacheItem{
Action: service.OauthActionTypeLogin,
Op: f.Op,
Id: f.Id,
@@ -297,10 +202,12 @@ func (ct *Login) OidcAuth(c *gin.Context) {
// DeviceOs: ct.Platform(c),
DeviceOs: f.DeviceInfo.Os,
Uuid: f.Uuid,
Verifier: verifier,
Nonce: nonce,
}, 5*60)
response.Success(c, gin.H{
"code": code,
"code": state,
"url": url,
})
}

View File

@@ -100,21 +100,21 @@ func (abcr *AddressBookCollectionRule) CheckForm(u *model.User, t *model.Address
//check to_id
if t.Type == model.ShareAddressBookRuleTypePersonal {
if t.ToId == t.UserId {
return "ParamsError", false
return "CannotShareToSelf", false
}
tou := service.AllService.UserService.InfoById(t.ToId)
if tou.Id == 0 {
return "ItemNotFound", false
}
//非管理员不能分享给非本组织用户
if tou.GroupId != u.GroupId {
return "NoAccess", false
}
//if tou.GroupId != u.GroupId {
// return "NoAccess", false
//}
} else if t.Type == model.ShareAddressBookRuleTypeGroup {
//非管理员不能分享给其他组
if t.ToId != u.GroupId {
return "NoAccess", false
}
//if t.ToId != u.GroupId {
// return "NoAccess", false
//}
tog := service.AllService.GroupService.InfoById(t.ToId)
if tog.Id == 0 {
@@ -124,7 +124,7 @@ func (abcr *AddressBookCollectionRule) CheckForm(u *model.User, t *model.Address
return "ParamsError", false
}
// 重复检查
ex := service.AllService.AddressBookService.RulePersonalInfoByToIdAndCid(t.ToId, t.CollectionId)
ex := service.AllService.AddressBookService.RuleInfoByToIdAndCid(t.Type, t.ToId, t.CollectionId)
if t.Id == 0 && ex.Id > 0 {
return "ItemExists", false
}

View File

@@ -43,20 +43,22 @@ func (o *Oauth) ToBind(c *gin.Context) {
return
}
err, code, url := service.AllService.OauthService.BeginAuth(f.Op)
err, state, verifier, nonce, url := service.AllService.OauthService.BeginAuth(f.Op)
if err != nil {
response.Error(c, response.TranslateMsg(c, err.Error()))
return
}
service.AllService.OauthService.SetOauthCache(code, &service.OauthCacheItem{
Action: service.OauthActionTypeBind,
Op: f.Op,
UserId: u.Id,
service.AllService.OauthService.SetOauthCache(state, &service.OauthCacheItem{
Action: service.OauthActionTypeBind,
Op: f.Op,
UserId: u.Id,
Verifier: verifier,
Nonce: nonce,
}, 5*60)
response.Success(c, gin.H{
"code": code,
"code": state,
"url": url,
})
}

View File

@@ -108,6 +108,12 @@ func (ct *Peer) List(c *gin.Context) {
if query.Uuids != "" {
tx.Where("uuid in (?)", query.Uuids)
}
if query.Username != "" {
tx.Where("username like ?", "%"+query.Username+"%")
}
if query.Ip != "" {
tx.Where("last_online_ip like ?", "%"+query.Ip+"%")
}
})
response.Success(c, res)
}

View File

@@ -119,7 +119,16 @@ func (r *Rustdesk) SendCmd(c *gin.Context) {
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError"))
return
}
res, err := service.AllService.ServerCmdService.SendCmd(rc.Target, rc.Cmd, rc.Option)
port := 0
switch rc.Target {
case model.ServerCmdTargetIdServer:
port = global.Config.Admin.IdServerPort - 1
case model.ServerCmdTargetRelayServer:
port = global.Config.Admin.RelayServerPort
}
res, err := service.AllService.ServerCmdService.SendCmd(port, rc.Cmd, rc.Option)
if err != nil {
response.Fail(c, 101, err.Error())
return

View File

@@ -296,32 +296,12 @@ func (ct *User) MyOauth(c *gin.Context) {
// groupUsers
func (ct *User) GroupUsers(c *gin.Context) {
q := &admin.GroupUsersQuery{}
if err := c.ShouldBindJSON(q); err != nil {
response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")+err.Error())
return
}
u := service.AllService.UserService.CurUser(c)
gid := u.GroupId
uid := u.Id
if service.AllService.UserService.IsAdmin(u) && q.UserId > 0 {
nu := service.AllService.UserService.InfoById(q.UserId)
gid = nu.GroupId
uid = q.UserId
}
res := service.AllService.UserService.List(1, 999, func(tx *gorm.DB) {
tx.Where("group_id = ?", gid)
aG := service.AllService.GroupService.List(1, 999, nil)
aU := service.AllService.UserService.List(1, 9999, nil)
response.Success(c, gin.H{
"groups": aG.Groups,
"users": aU.Users,
})
var data []*adResp.GroupUsersPayload
for _, _u := range res.Users {
gup := &adResp.GroupUsersPayload{}
gup.FromUser(_u)
if _u.Id == uid {
gup.Status = 0
}
data = append(data, gup)
}
response.Success(c, data)
}
// Register
@@ -340,11 +320,22 @@ func (ct *User) Register(c *gin.Context) {
response.Fail(c, 101, errList[0])
return
}
u := service.AllService.UserService.Register(f.Username, f.Email, f.Password)
regStatus := model.StatusCode(global.Config.App.RegisterStatus)
// 注册状态可能未配置,默认启用
if regStatus != model.COMMON_STATUS_DISABLED && regStatus != model.COMMON_STATUS_ENABLE {
regStatus = model.COMMON_STATUS_ENABLE
}
u := service.AllService.UserService.Register(f.Username, f.Email, f.Password, regStatus)
if u == nil || u.Id == 0 {
response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed"))
return
}
if regStatus == model.COMMON_STATUS_DISABLED {
// 需要管理员审核
response.Fail(c, 101, response.TranslateMsg(c, "RegisterSuccessWaitAdminConfirm"))
return
}
// 注册成功后自动登录
ut := service.AllService.UserService.Login(u, &model.LoginLog{
UserId: u.Id,

View File

@@ -45,7 +45,7 @@ func (g *Group) Users(c *gin.Context) {
userList = service.AllService.UserService.ListByGroupId(u.GroupId, q.Page, q.PageSize)
}
var data []*apiResp.UserPayload
data := make([]*apiResp.UserPayload, 0, len(userList.Users))
for _, user := range userList.Users {
up := &apiResp.UserPayload{}
up.FromUser(user)
@@ -88,21 +88,30 @@ func (g *Group) Peers(c *gin.Context) {
users = service.AllService.UserService.ListIdAndNameByGroupId(u.GroupId)
}
namesById := make(map[uint]string)
userIds := make([]uint, 0)
namesById := make(map[uint]string, len(users))
userIds := make([]uint, 0, len(users))
for _, user := range users {
namesById[user.Id] = user.Username
userIds = append(userIds, user.Id)
}
dGroupNameById := make(map[uint]string)
allGroup := service.AllService.GroupService.DeviceGroupList(1, 999, nil)
for _, group := range allGroup.DeviceGroups {
dGroupNameById[group.Id] = group.Name
}
peerList := service.AllService.PeerService.ListByUserIds(userIds, q.Page, q.PageSize)
var data []*apiResp.GroupPeerPayload
data := make([]*apiResp.GroupPeerPayload, 0, len(peerList.Peers))
for _, peer := range peerList.Peers {
uname, ok := namesById[peer.UserId]
if !ok {
uname = ""
}
dGroupName, ok2 := dGroupNameById[peer.GroupId]
if !ok2 {
dGroupName = ""
}
pp := &apiResp.GroupPeerPayload{}
pp.FromPeer(peer, uname)
pp.FromPeer(peer, uname, dGroupName)
data = append(data, pp)
}
@@ -111,3 +120,31 @@ func (g *Group) Peers(c *gin.Context) {
Data: data,
})
}
// Device
// @Tags 群组
// @Summary 设备
// @Description 机器
// @Accept json
// @Produce json
// @Param page query int false "页码"
// @Param pageSize query int false "每页数量"
// @Param status query int false "状态"
// @Param accessible query string false "accessible"
// @Success 200 {object} response.DataResponse
// @Failure 500 {object} response.Response
// @Router /device-group/accessible [get]
// @Security BearerAuth
func (g *Group) Device(c *gin.Context) {
u := service.AllService.UserService.CurUser(c)
if !service.AllService.UserService.IsAdmin(u) {
response.Error(c, "Permission denied")
return
}
allGroup := service.AllService.GroupService.DeviceGroupList(1, 999, nil)
c.JSON(http.StatusOK, response.DataResponse{
Total: 0,
Data: allGroup.DeviceGroups,
})
}

View File

@@ -7,7 +7,6 @@ import (
"github.com/lejianwen/rustdesk-api/v2/model"
"github.com/lejianwen/rustdesk-api/v2/service"
"net/http"
"os"
"time"
)
@@ -56,7 +55,7 @@ func (i *Index) Heartbeat(c *gin.Context) {
return
}
//如果在40s以内则不更新
if time.Now().Unix()-peer.LastOnlineTime > 40 {
if time.Now().Unix()-peer.LastOnlineTime >= 30 {
upp := &model.Peer{RowId: peer.RowId, LastOnlineTime: time.Now().Unix(), LastOnlineIp: c.ClientIP()}
service.AllService.PeerService.Update(upp)
}
@@ -74,13 +73,9 @@ func (i *Index) Heartbeat(c *gin.Context) {
// @Router /version [get]
func (i *Index) Version(c *gin.Context) {
//读取resources/version文件
v, err := os.ReadFile("resources/version")
if err != nil {
response.Fail(c, 101, err.Error())
return
}
v := service.AllService.AppService.GetAppVersion()
response.Success(
c,
string(v),
v,
)
}

View File

@@ -31,10 +31,16 @@ func (l *Login) Login(c *gin.Context) {
response.Error(c, response.TranslateMsg(c, "PwdLoginDisabled"))
return
}
// 检查登录限制
loginLimiter := global.LoginLimiter
clientIp := c.ClientIP()
f := &api.LoginForm{}
err := c.ShouldBindJSON(f)
//fmt.Println(f)
if err != nil {
loginLimiter.RecordFailedAttempt(clientIp)
global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), c.ClientIP()))
response.Error(c, response.TranslateMsg(c, "ParamsError")+err.Error())
return
@@ -42,6 +48,7 @@ func (l *Login) Login(c *gin.Context) {
errList := global.Validator.ValidStruct(c, f)
if len(errList) > 0 {
loginLimiter.RecordFailedAttempt(clientIp)
global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), c.ClientIP()))
response.Error(c, errList[0])
return
@@ -50,6 +57,7 @@ func (l *Login) Login(c *gin.Context) {
u := service.AllService.UserService.InfoByUsernamePassword(f.Username, f.Password)
if u.Id == 0 {
loginLimiter.RecordFailedAttempt(clientIp)
global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "UsernameOrPasswordError", c.RemoteIP(), c.ClientIP()))
response.Error(c, response.TranslateMsg(c, "UsernameOrPasswordError"))
return

View File

@@ -8,6 +8,8 @@ import (
apiResp "github.com/lejianwen/rustdesk-api/v2/http/response/api"
"github.com/lejianwen/rustdesk-api/v2/model"
"github.com/lejianwen/rustdesk-api/v2/service"
"github.com/lejianwen/rustdesk-api/v2/utils"
"github.com/nicksnyder/go-i18n/v2/i18n"
"net/http"
)
@@ -32,15 +34,14 @@ func (o *Oauth) OidcAuth(c *gin.Context) {
}
oauthService := service.AllService.OauthService
var code string
var url string
err, code, url = oauthService.BeginAuth(f.Op)
err, state, verifier, nonce, url := oauthService.BeginAuth(f.Op)
if err != nil {
response.Error(c, response.TranslateMsg(c, err.Error()))
return
}
service.AllService.OauthService.SetOauthCache(code, &service.OauthCacheItem{
service.AllService.OauthService.SetOauthCache(state, &service.OauthCacheItem{
Action: service.OauthActionTypeLogin,
Id: f.Id,
Op: f.Op,
@@ -48,10 +49,12 @@ func (o *Oauth) OidcAuth(c *gin.Context) {
DeviceName: f.DeviceInfo.Name,
DeviceOs: f.DeviceInfo.Os,
DeviceType: f.DeviceInfo.Type,
Verifier: verifier,
Nonce: nonce,
}, 5*60)
//fmt.Println("code url", code, url)
c.JSON(http.StatusOK, gin.H{
"code": code,
"code": state,
"url": url,
})
}
@@ -143,7 +146,10 @@ func (o *Oauth) OidcAuthQuery(c *gin.Context) {
func (o *Oauth) OauthCallback(c *gin.Context) {
state := c.Query("state")
if state == "" {
c.String(http.StatusInternalServerError, response.TranslateParamMsg(c, "ParamIsEmpty", "state"))
c.HTML(http.StatusOK, "oauth_fail.html", gin.H{
"message": "ParamIsEmpty",
"sub_message": "state",
})
return
}
cacheKey := state
@@ -151,17 +157,24 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
//从缓存中获取
oauthCache := oauthService.GetOauthCache(cacheKey)
if oauthCache == nil {
c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthExpired"))
c.HTML(http.StatusOK, "oauth_fail.html", gin.H{
"message": "OauthExpired",
})
return
}
nonce := oauthCache.Nonce
op := oauthCache.Op
action := oauthCache.Action
verifier := oauthCache.Verifier
var user *model.User
// 获取用户信息
code := c.Query("code")
err, oauthUser := oauthService.Callback(code, op)
err, oauthUser := oauthService.Callback(code, verifier, op, nonce)
if err != nil {
c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error()))
c.HTML(http.StatusOK, "oauth_fail.html", gin.H{
"message": "OauthFailed",
"sub_message": err.Error(),
})
return
}
userId := oauthCache.UserId
@@ -172,28 +185,38 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
// 检查此openid是否已经绑定过
utr := oauthService.UserThirdInfo(op, openid)
if utr.UserId > 0 {
c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthHasBindOtherUser"))
c.HTML(http.StatusOK, "oauth_fail.html", gin.H{
"message": "OauthHasBindOtherUser",
})
return
}
//绑定
user = service.AllService.UserService.InfoById(userId)
if user == nil {
c.String(http.StatusInternalServerError, response.TranslateMsg(c, "ItemNotFound"))
c.HTML(http.StatusOK, "oauth_fail.html", gin.H{
"message": "ItemNotFound",
})
return
}
//绑定
err := oauthService.BindOauthUser(userId, oauthUser, op)
if err != nil {
c.String(http.StatusInternalServerError, response.TranslateMsg(c, "BindFail"))
c.HTML(http.StatusOK, "oauth_fail.html", gin.H{
"message": "BindFail",
})
return
}
c.String(http.StatusOK, response.TranslateMsg(c, "BindSuccess"))
c.HTML(http.StatusOK, "oauth_success.html", gin.H{
"message": "BindSuccess",
})
return
} else if action == service.OauthActionTypeLogin {
//登录
if userId != 0 {
c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthHasBeenSuccess"))
c.HTML(http.StatusOK, "oauth_fail.html", gin.H{
"message": "OauthHasBeenSuccess",
})
return
}
user = service.AllService.UserService.InfoByOauthId(op, openid)
@@ -210,7 +233,9 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
//自动注册
err, user = service.AllService.UserService.RegisterByOauth(oauthUser, op)
if err != nil {
c.String(http.StatusInternalServerError, response.TranslateMsg(c, err.Error()))
c.HTML(http.StatusOK, "oauth_fail.html", gin.H{
"message": err.Error(),
})
return
}
}
@@ -230,11 +255,51 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
c.Redirect(http.StatusFound, url)
return
}
c.String(http.StatusOK, response.TranslateMsg(c, "OauthSuccess"))
c.HTML(http.StatusOK, "oauth_success.html", gin.H{
"message": "OauthSuccess",
})
return
} else {
c.String(http.StatusInternalServerError, response.TranslateMsg(c, "ParamsError"))
c.HTML(http.StatusOK, "oauth_fail.html", gin.H{
"message": "ParamsError",
})
return
}
}
type MessageParams struct {
Lang string `json:"lang" form:"lang"`
Title string `json:"title" form:"title"`
Msg string `json:"msg" form:"msg"`
}
func (o *Oauth) Message(c *gin.Context) {
mp := &MessageParams{}
if err := c.ShouldBindQuery(mp); err != nil {
return
}
localizer := global.Localizer(mp.Lang)
res := ""
if mp.Title != "" {
title, err := localizer.LocalizeMessage(&i18n.Message{
ID: mp.Title,
})
if err == nil {
res = utils.StringConcat(";title='", title, "';")
}
}
if mp.Msg != "" {
msg, err := localizer.LocalizeMessage(&i18n.Message{
ID: mp.Msg,
})
if err == nil {
res = utils.StringConcat(res, "msg = '", msg, "';")
}
}
//返回js内容
c.Header("Content-Type", "application/javascript")
c.String(http.StatusOK, res)
}

View File

@@ -1,6 +1,7 @@
package api
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
requstform "github.com/lejianwen/rustdesk-api/v2/http/request/api"
@@ -13,7 +14,7 @@ type Peer struct {
}
// SysInfo
// @Tags 地址
// @Tags System
// @Summary 提交系统信息
// @Description 提交系统信息
// @Accept json
@@ -30,7 +31,7 @@ func (p *Peer) SysInfo(c *gin.Context) {
return
}
fpe := f.ToPeer()
pe := service.AllService.PeerService.FindById(f.Id)
pe := service.AllService.PeerService.FindByUuid(f.Uuid)
if pe.RowId == 0 {
pe = f.ToPeer()
pe.UserId = service.AllService.UserService.FindLatestUserIdFromLoginLogByUuid(pe.Uuid)
@@ -56,3 +57,20 @@ func (p *Peer) SysInfo(c *gin.Context) {
//直接响应文本
c.String(http.StatusOK, "SYSINFO_UPDATED")
}
// SysInfoVer
// @Tags System
// @Summary 获取系统版本信息
// @Description 获取系统版本信息
// @Accept json
// @Produce json
// @Success 200 {string} string ""
// @Failure 500 {object} response.ErrorResponse
// @Router /sysinfo_ver [post]
func (p *Peer) SysInfoVer(c *gin.Context) {
//读取resources/version文件
v := service.AllService.AppService.GetAppVersion()
// 加上启动时间方便client上传信息
v = fmt.Sprintf("%s\n%s", v, service.AllService.AppService.GetStartTime())
c.String(http.StatusOK, v)
}

View File

@@ -1,9 +1,9 @@
package web
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/lejianwen/rustdesk-api/v2/global"
"strconv"
)
type Index struct {
@@ -15,13 +15,21 @@ func (i *Index) Index(c *gin.Context) {
func (i *Index) ConfigJs(c *gin.Context) {
apiServer := global.Config.Rustdesk.ApiServer
magicQueryonline := strconv.Itoa(global.Config.Rustdesk.WebclientMagicQueryonline)
tmp := `
localStorage.setItem('api-server', "` + apiServer + `")
const ws2_prefix = 'wc-'
localStorage.setItem(ws2_prefix+'api-server', "` + apiServer + `")
magicQueryonline := global.Config.Rustdesk.WebclientMagicQueryonline
tmp := fmt.Sprintf(`localStorage.setItem('api-server', '%v');
const ws2_prefix = 'wc-';
localStorage.setItem(ws2_prefix+'api-server', '%v');
window.webclient_magic_queryonline = ` + magicQueryonline + ``
window.webclient_magic_queryonline = %d;
window.ws_host = '%v';
`, apiServer, apiServer, magicQueryonline, global.Config.Rustdesk.WsHost)
// tmp := `
//localStorage.setItem('api-server', "` + apiServer + `")
//const ws2_prefix = 'wc-'
//localStorage.setItem(ws2_prefix+'api-server', "` + apiServer + `")
//
//window.webclient_magic_queryonline = ` + magicQueryonline + ``
c.Header("Content-Type", "application/javascript")
c.String(200, tmp)
}

View File

@@ -33,7 +33,7 @@ func ApiInit() {
g.NoRoute(func(c *gin.Context) {
c.String(http.StatusNotFound, "404 not found")
})
g.Use(middleware.Logger(), gin.Recovery())
g.Use(middleware.Logger(), middleware.Limiter(), gin.Recovery())
router.WebInit(g)
router.Init(g)
router.ApiInit(g)

View File

@@ -0,0 +1,22 @@
package middleware
import (
"github.com/gin-gonic/gin"
"github.com/lejianwen/rustdesk-api/v2/global"
"github.com/lejianwen/rustdesk-api/v2/http/response"
"net/http"
)
func Limiter() gin.HandlerFunc {
return func(c *gin.Context) {
loginLimiter := global.LoginLimiter
clientIp := c.ClientIP()
banned, _ := loginLimiter.CheckSecurityStatus(clientIp)
if banned {
response.Fail(c, http.StatusLocked, response.TranslateMsg(c, "Banned"))
c.Abort()
return
}
c.Next()
}
}

View File

@@ -22,3 +22,15 @@ func (gf *GroupForm) ToGroup() *model.Group {
group.Type = gf.Type
return group
}
type DeviceGroupForm struct {
Id uint `json:"id"`
Name string `json:"name" validate:"required"`
}
func (gf *DeviceGroupForm) ToDeviceGroup() *model.DeviceGroup {
group := &model.DeviceGroup{}
group.Id = gf.Id
group.Name = gf.Name
return group
}

View File

@@ -1,10 +1,11 @@
package admin
type Login struct {
Username string `json:"username" validate:"required" label:"用户名"`
Password string `json:"password,omitempty" validate:"required" label:"密码"`
Platform string `json:"platform" label:"平台"`
Captcha string `json:"captcha,omitempty" label:"验证码"`
Username string `json:"username" validate:"required" label:"用户名"`
Password string `json:"password,omitempty" validate:"required" label:"密码"`
Platform string `json:"platform" label:"平台"`
Captcha string `json:"captcha,omitempty" label:"验证码"`
CaptchaId string `json:"captcha_id,omitempty"`
}
type LoginLogQuery struct {

View File

@@ -24,6 +24,8 @@ type OauthForm struct {
ClientSecret string `json:"client_secret" validate:"required"`
RedirectUrl string `json:"redirect_url" validate:"required"`
AutoRegister *bool `json:"auto_register"`
PkceEnable *bool `json:"pkce_enable"`
PkceMethod string `json:"pkce_method"`
}
func (of *OauthForm) ToOauth() *model.Oauth {
@@ -36,6 +38,8 @@ func (of *OauthForm) ToOauth() *model.Oauth {
AutoRegister: of.AutoRegister,
Issuer: of.Issuer,
Scopes: of.Scopes,
PkceEnable: of.PkceEnable,
PkceMethod: of.PkceMethod,
}
oa.Id = of.Id
return oa

View File

@@ -12,6 +12,7 @@ type PeerForm struct {
Username string `json:"username"`
Uuid string `json:"uuid"`
Version string `json:"version"`
GroupId uint `json:"group_id"`
}
type PeerBatchDeleteForm struct {
@@ -30,6 +31,7 @@ func (f *PeerForm) ToPeer() *model.Peer {
Username: f.Username,
Uuid: f.Uuid,
Version: f.Version,
GroupId: f.GroupId,
}
}
@@ -39,6 +41,8 @@ type PeerQuery struct {
Id string `json:"id" form:"id"`
Hostname string `json:"hostname" form:"hostname"`
Uuids string `json:"uuids" form:"uuids"`
Ip string `json:"ip" form:"ip"`
Username string `json:"username" form:"username"`
}
type SimpleDataQuery struct {

View File

@@ -40,14 +40,14 @@ type LoginForm struct {
type UserListQuery struct {
Page uint `json:"page" form:"page" validate:"required" label:"页码"`
PageSize uint `json:"page_size" form:"page_size" validate:"required" label:"每页数量"`
PageSize uint `json:"pageSize" form:"pageSize" validate:"required" label:"每页数量"`
Status int `json:"status" form:"status" label:"状态"`
Accessible string `json:"accessible" form:"accessible"`
}
type PeerListQuery struct {
Page uint `json:"page" form:"page" validate:"required" label:"页码"`
PageSize uint `json:"page_size" form:"page_size" validate:"required" label:"每页数量"`
PageSize uint `json:"pageSize" form:"pageSize" validate:"required" label:"每页数量"`
Status int `json:"status" form:"status" label:"状态"`
Accessible string `json:"accessible" form:"accessible"`
}

View File

@@ -22,15 +22,3 @@ type UserOauthItem struct {
Op string `json:"op"`
Status int `json:"status"`
}
type GroupUsersPayload struct {
Id uint `json:"id"`
Username string `json:"username"`
Status int `json:"status"`
}
func (g *GroupUsersPayload) FromUser(user *model.User) {
g.Id = user.Id
g.Username = user.Username
g.Status = 1
}

View File

@@ -32,12 +32,13 @@ https://github.com/rustdesk/rustdesk/blob/master/flutter/lib/common/hbbs/hbbs.da
}
*/
type GroupPeerPayload struct {
Id string `json:"id"`
Info *PeerPayloadInfo `json:"info"`
Status int `json:"status"`
User string `json:"user"`
UserName string `json:"user_name"`
Note string `json:"note"`
Id string `json:"id"`
Info *PeerPayloadInfo `json:"info"`
Status int `json:"status"`
User string `json:"user"`
UserName string `json:"user_name"`
Note string `json:"note"`
DeviceGroupName string `json:"device_group_name"`
}
type PeerPayloadInfo struct {
DeviceName string `json:"device_name"`
@@ -59,7 +60,7 @@ func (gpp *GroupPeerPayload) FromAddressBook(a *model.AddressBook, username stri
gpp.UserName = username
}
func (gpp *GroupPeerPayload) FromPeer(p *model.Peer, username string) {
func (gpp *GroupPeerPayload) FromPeer(p *model.Peer, username string, dGroupName string) {
gpp.Id = p.Id
gpp.Info = &PeerPayloadInfo{
DeviceName: p.Hostname,
@@ -68,4 +69,5 @@ func (gpp *GroupPeerPayload) FromPeer(p *model.Peer, username string) {
}
gpp.Note = ""
gpp.UserName = username
gpp.DeviceGroupName = dGroupName
}

View File

@@ -49,7 +49,7 @@ func Init(g *gin.Engine) {
MyBind(adg)
RustdeskCmdBind(adg)
DeviceGroupBind(adg)
//访问静态文件
//g.StaticFS("/upload", http.Dir(global.Config.Gin.ResourcesPath+"/upload"))
}
@@ -106,6 +106,18 @@ func GroupBind(rg *gin.RouterGroup) {
}
}
func DeviceGroupBind(rg *gin.RouterGroup) {
aR := rg.Group("/device_group").Use(middleware.AdminPrivilege())
{
cont := &admin.DeviceGroup{}
aR.GET("/list", cont.List)
aR.GET("/detail/:id", cont.Detail)
aR.POST("/create", cont.Create)
aR.POST("/update", cont.Update)
aR.POST("/delete", cont.Delete)
}
}
func TagBind(rg *gin.RouterGroup) {
aR := rg.Group("/tag").Use(middleware.AdminPrivilege())
{

View File

@@ -18,6 +18,8 @@ func ApiInit(g *gin.Engine) {
if global.Config.App.ShowSwagger == 1 {
g.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler, ginSwagger.InstanceName("api")))
}
// 加载 HTML 模板
g.LoadHTMLGlob("resources/templates/*")
frg := g.Group("/api")
@@ -46,11 +48,13 @@ func ApiInit(g *gin.Engine) {
//api/oauth/callback
frg.GET("/oauth/callback", o.OauthCallback)
frg.GET("/oauth/login", o.OauthCallback)
frg.GET("/oauth/msg", o.Message)
}
{
pe := &api.Peer{}
//提交系统信息
frg.POST("/sysinfo", pe.SysInfo)
frg.POST("/sysinfo_ver", pe.SysInfoVer)
}
if global.Config.App.WebClient == 1 {
@@ -79,6 +83,8 @@ func ApiInit(g *gin.Engine) {
gr := &api.Group{}
frg.GET("/users", gr.Users)
frg.GET("/peers", gr.Peers)
// /api/device-group/accessible?current=1&pageSize=100
frg.GET("/device-group/accessible", gr.Device)
}
{
@@ -88,6 +94,7 @@ func ApiInit(g *gin.Engine) {
//更新地址
frg.POST("/ab", ab.UpAb)
}
PersonalRoutes(frg)
//访问静态文件
g.StaticFS("/upload", http.Dir(global.Config.Gin.ResourcesPath+"/public/upload"))

View File

@@ -2,7 +2,6 @@ package orm
import (
"fmt"
"github.com/lejianwen/rustdesk-api/v2/global"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
@@ -10,14 +9,14 @@ import (
)
type MysqlConfig struct {
Dns string
Dsn string
MaxIdleConns int
MaxOpenConns int
}
func NewMysql(mysqlConf *MysqlConfig) *gorm.DB {
func NewMysql(mysqlConf *MysqlConfig, logwriter logger.Writer) *gorm.DB {
db, err := gorm.Open(mysql.New(mysql.Config{
DSN: mysqlConf.Dns, // DSN data source name
DSN: mysqlConf.Dsn, // DSN data source name
DefaultStringSize: 256, // string 类型字段的默认长度
//DisableDatetimePrecision: true, // 禁用 datetime 精度MySQL 5.6 之前的数据库不支持
//DontSupportRenameIndex: true, // 重命名索引时采用删除并新建的方式MySQL 5.7 之前的数据库和 MariaDB 不支持重命名索引
@@ -26,7 +25,7 @@ func NewMysql(mysqlConf *MysqlConfig) *gorm.DB {
}), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
Logger: logger.New(
global.Logger, // io writer
logwriter, // io writer
logger.Config{
SlowThreshold: time.Second, // Slow SQL threshold
LogLevel: logger.Warn, // Log level

45
lib/orm/postgresql.go Normal file
View File

@@ -0,0 +1,45 @@
package orm
import (
"fmt"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"time"
)
type PostgresqlConfig struct {
Dsn string
MaxIdleConns int
MaxOpenConns int
}
func NewPostgresql(conf *PostgresqlConfig, logwriter logger.Writer) *gorm.DB {
db, err := gorm.Open(postgres.Open(conf.Dsn), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
Logger: logger.New(
logwriter, // io writer
logger.Config{
SlowThreshold: time.Second, // Slow SQL threshold
LogLevel: logger.Warn, // Log level
//IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger
ParameterizedQueries: true, // Don't include params in the SQL log
Colorful: true,
},
),
})
if err != nil {
fmt.Println(err)
}
sqlDB, err2 := db.DB()
if err2 != nil {
fmt.Println(err2)
}
// SetMaxIdleConns 设置空闲连接池中连接的最大数量
sqlDB.SetMaxIdleConns(conf.MaxIdleConns)
// SetMaxOpenConns 设置打开数据库连接的最大数量。
sqlDB.SetMaxOpenConns(conf.MaxOpenConns)
return db
}

View File

@@ -2,7 +2,6 @@ package orm
import (
"fmt"
"github.com/lejianwen/rustdesk-api/v2/global"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
@@ -14,11 +13,11 @@ type SqliteConfig struct {
MaxOpenConns int
}
func NewSqlite(sqliteConf *SqliteConfig) *gorm.DB {
func NewSqlite(sqliteConf *SqliteConfig, logwriter logger.Writer) *gorm.DB {
db, err := gorm.Open(sqlite.Open("./data/rustdeskapi.db"), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
Logger: logger.New(
global.Logger, // io writer
logwriter, // io writer
logger.Config{
SlowThreshold: time.Second, // Slow SQL threshold
LogLevel: logger.Warn, // Log level

View File

@@ -16,3 +16,14 @@ type GroupList struct {
Groups []*Group `json:"list"`
Pagination
}
type DeviceGroup struct {
IdModel
Name string `json:"name" gorm:"default:'';not null;"`
TimeModel
}
type DeviceGroupList struct {
DeviceGroups []*DeviceGroup `json:"list"`
Pagination
}

View File

@@ -14,12 +14,15 @@ const (
OauthTypeGoogle string = "google"
OauthTypeOidc string = "oidc"
OauthTypeWebauth string = "webauth"
OauthTypeLinuxdo string = "linuxdo"
PKCEMethodS256 string = "S256"
PKCEMethodPlain string = "plain"
)
// Validate the oauth type
func ValidateOauthType(oauthType string) error {
switch oauthType {
case OauthTypeGithub, OauthTypeGoogle, OauthTypeOidc, OauthTypeWebauth:
case OauthTypeGithub, OauthTypeGoogle, OauthTypeOidc, OauthTypeWebauth, OauthTypeLinuxdo:
return nil
default:
return errors.New("invalid Oauth type")
@@ -28,6 +31,7 @@ func ValidateOauthType(oauthType string) error {
const (
UserEndpointGithub string = "https://api.github.com/user"
UserEndpointLinuxdo string = "https://connect.linux.do/api/user"
IssuerGoogle string = "https://accounts.google.com"
)
@@ -41,6 +45,8 @@ type Oauth struct {
AutoRegister *bool `json:"auto_register"`
Scopes string `json:"scopes"`
Issuer string `json:"issuer"`
PkceEnable *bool `json:"pkce_enable"`
PkceMethod string `json:"pkce_method"`
TimeModel
}
@@ -56,6 +62,8 @@ func (oa *Oauth) FormatOauthInfo() error {
oa.Op = OauthTypeGithub
case OauthTypeGoogle:
oa.Op = OauthTypeGoogle
case OauthTypeLinuxdo:
oa.Op = OauthTypeLinuxdo
}
// check if the op is empty, set the default value
op := strings.TrimSpace(oa.Op)
@@ -68,6 +76,13 @@ func (oa *Oauth) FormatOauthInfo() error {
if oauthType == OauthTypeGoogle && issuer == "" {
oa.Issuer = IssuerGoogle
}
if oa.PkceEnable == nil {
oa.PkceEnable = new(bool)
*oa.PkceEnable = false
}
if oa.PkceMethod == "" {
oa.PkceMethod = PKCEMethodS256
}
return nil
}
@@ -141,6 +156,24 @@ func (gu *GithubUser) ToOauthUser() *OauthUser {
}
}
type LinuxdoUser struct {
OauthUserBase
Id int `json:"id"`
Username string `json:"username"`
Avatar string `json:"avatar_url"`
}
func (lu *LinuxdoUser) ToOauthUser() *OauthUser {
return &OauthUser{
OpenId: strconv.Itoa(lu.Id),
Name: lu.Name,
Username: strings.ToLower(lu.Username),
Email: lu.Email,
VerifiedEmail: true, // linux.do 用户邮箱默认已验证
Picture: lu.Avatar,
}
}
type OauthList struct {
Oauths []*Oauth `json:"list"`
Pagination

View File

@@ -14,6 +14,7 @@ type Peer struct {
User *User `json:"user,omitempty"`
LastOnlineTime int64 `json:"last_online_time" gorm:"default:0;not null;"`
LastOnlineIp string `json:"last_online_ip" gorm:"default:'';not null;"`
GroupId uint `json:"group_id" gorm:"default:0;not null;index"`
TimeModel
}

View File

@@ -137,4 +137,19 @@ other = "Captcha error."
[PwdLoginDisabled]
description = "Password login disabled."
one = "Password login disabled."
other = "Password login disabled."
other = "Password login disabled."
[CannotShareToSelf]
description = "Cannot share to self."
one = "Cannot share to self."
other = "Cannot share to self."
[Banned]
description = "Banned."
one = "Banned."
other = "Banned."
[RegisterSuccessWaitAdminConfirm]
description = "Register success, wait admin confirm."
one = "Register success, wait admin confirm."
other = "Register success, wait admin confirm."

View File

@@ -146,4 +146,19 @@ other = "Error de captcha."
[PwdLoginDisabled]
description = "Password login disabled."
one = "Inicio de sesión con contraseña deshabilitado."
other = "Inicio de sesión con contraseña deshabilitado."
other = "Inicio de sesión con contraseña deshabilitado."
[CannotShareToSelf]
description = "Cannot share to self."
one = "No se puede compartir con uno mismo."
other = "No se puede compartir con uno mismo."
[Banned]
description = "Banned."
one = "Prohibido."
other = "Prohibido."
[RegisterSuccessWaitAdminConfirm]
description = "Register success, wait admin confirm."
one = "Registro exitoso, espere la confirmación del administrador."
other = "Registro exitoso, espere la confirmación del administrador."

View File

@@ -146,4 +146,19 @@ other = "Erreur de captcha."
[PwdLoginDisabled]
description = "Password login disabled."
one = "Connexion par mot de passe désactivée."
other = "Connexion par mot de passe désactivée."
other = "Connexion par mot de passe désactivée."
[CannotShareToSelf]
description = "Cannot share to self."
one = "Impossible de partager avec soi-même."
other = "Impossible de partager avec soi-même."
[Banned]
description = "Banned."
one = "Banni."
other = "Banni."
[RegisterSuccessWaitAdminConfirm]
description = "Register success wait admin confirm."
one = "Inscription réussie, veuillez attendre la confirmation de l'administrateur."
other = "Inscription réussie, veuillez attendre la confirmation de l'administrateur."

View File

@@ -141,3 +141,18 @@ other = "Captcha 오류."
description = "Password login disabled."
one = "비밀번호 로그인이 비활성화되었습니다."
other = "비밀번호 로그인이 비활성화되었습니다."
[CannotShareToSelf]
description = "Cannot share to self."
one = "자기 자신에게 공유할 수 없습니다."
other = "자기 자신에게 공유할 수 없습니다."
[Banned]
description = "Banned."
one = "금지됨."
other = "금지됨."
[RegisterSuccessWaitAdminConfirm]
description = "Register success wait admin confirm."
one = "가입 성공, 관리자 확인 대기 중."
other = "가입 성공, 관리자 확인 대기 중."

View File

@@ -146,4 +146,19 @@ other = "Ошибка капчи."
[PwdLoginDisabled]
description = "Password login disabled."
one = "Вход по паролю отключен."
other = "Вход по паролю отключен."
other = "Вход по паролю отключен."
[CannotShareToSelf]
description = "Cannot share to self."
one = "Нельзя поделиться с собой."
other = "Нельзя поделиться с собой."
[Banned]
description = "Banned."
one = "Заблокировано."
other = "Заблокировано."
[RegisterSuccessWaitAdminConfirm]
description = "Register success wait admin confirm."
one = "Регистрация прошла успешно, ожидайте подтверждения администратора."
other = "Регистрация прошла успешно, ожидайте подтверждения администратора."

View File

@@ -139,4 +139,19 @@ other = "验证码错误。"
[PwdLoginDisabled]
description = "Password login disabled."
one = "密码登录已禁用。"
other = "密码登录已禁用。"
other = "密码登录已禁用。"
[CannotShareToSelf]
description = "Cannot share to self."
one = "不能共享给自己。"
other = "不能共享给自己。"
[Banned]
description = "Banned."
one = "已被封禁。"
other = "已被封禁。"
[RegisterSuccessWaitAdminConfirm]
description = "Register success, wait for admin confirm."
one = "注册成功,请等待管理员审核。"
other = "注册成功,请等待管理员审核。"

View File

@@ -140,3 +140,18 @@ other = "驗證碼錯誤。"
description = "Password login disabled."
one = "密碼登錄已禁用。"
other = "密碼登錄已禁用。"
[CannotShareToSelf]
description = "Cannot share to self."
one = "無法共享給自己。"
other = "無法共享給自己。"
[Banned]
description = "Banned."
one = "禁止使用。"
other = "禁止使用。"
[RegisterSuccessWaitAdminConfirm]
description = "Register success wait admin confirm."
one = "註冊成功,請等待管理員確認。"
other = "註冊成功,請等待管理員確認。"

View File

View File

@@ -0,0 +1,81 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>OauthFailed - RustDesk API</title>
<style>
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Arial, sans-serif;
background-color: #f5f5f5;
margin: 0;
display: flex;
justify-content: center;
align-items: center;
min-height: 100vh;
}
.success-container {
text-align: center;
background: white;
padding: 2rem;
border-radius: 10px;
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
max-width: 400px;
width: 90%;
}
.checkmark {
color: #ba363a;
font-size: 4rem;
margin-bottom: 1rem;
}
h1 {
color: #333;
margin-bottom: 1rem;
}
p {
color: #666;
line-height: 1.6;
margin-bottom: 1.5rem;
}
.return-link {
display: inline-block;
padding: 10px 20px;
background-color: #ba363a;
color: white;
text-decoration: none;
border-radius: 5px;
transition: background-color 0.3s;
}
.return-link:hover {
background-color: #ba363a;
}
</style>
<link rel="stylesheet" href="https://lf9-cdn-tos.bytecdntp.com/cdn/expire-1-M/font-awesome/6.0.0/css/all.min.css">
<script>
var lang = navigator.language || navigator.userLanguage || 'zh-CN';
var title = 'OauthFailed'
var msg = '{{.message}}'
var btn = 'Close'
document.writeln('<script src="/api/oauth/msg?lang=' + lang + '&msg=' + msg + '&title=OauthFailed"><\/script>');
</script>
</head>
<body>
<div class="success-container">
<i class="fas fa-triangle-exclamation checkmark"></i>
<h1 id="h1"></h1>
<p id="msg"></p>
<a href="javascript:window.close()" class="return-link" id="btn">Close</a>
</div>
<script>
document.title = title + ' - RustDesk API';
document.getElementById('h1').innerText = title;
document.getElementById('msg').innerText = msg;
</script>
</body>
</html>

View File

@@ -0,0 +1,82 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>OauthSuccess - RustDesk API</title>
<style>
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Arial, sans-serif;
background-color: #f5f5f5;
margin: 0;
display: flex;
justify-content: center;
align-items: center;
min-height: 100vh;
}
.success-container {
text-align: center;
background: white;
padding: 2rem;
border-radius: 10px;
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
max-width: 400px;
width: 90%;
}
.checkmark {
color: #4CAF50;
font-size: 4rem;
margin-bottom: 1rem;
}
h1 {
color: #333;
margin-bottom: 1rem;
}
p {
color: #666;
line-height: 1.6;
margin-bottom: 1.5rem;
}
.return-link {
display: inline-block;
padding: 10px 20px;
background-color: #4CAF50;
color: white;
text-decoration: none;
border-radius: 5px;
transition: background-color 0.3s;
}
.return-link:hover {
background-color: #45a049;
}
</style>
<script>
var lang = navigator.language || navigator.userLanguage || 'zh-CN';
var title = 'OauthSuccess'
var msg = '{{.message}}'
var btn = 'Close'
document.writeln('<script src="/api/oauth/msg?lang=' + lang + '&msg=' + msg + '&title=OauthSuccess"><\/script>');
</script>
</head>
<body>
<div class="success-container">
<i class="fas fa-check-circle checkmark"></i>
<h1 id="h1"></h1>
<!-- <p>您已成功授权访问您的账户。</p>-->
<!-- <p>现在可以关闭本页面或返回应用继续操作。</p>-->
<a href="javascript:window.close()" class="return-link">Close</a>
</div>
<script>
document.title = title + ' - RustDesk API';
document.getElementById('h1').innerText = title;
document.getElementById('msg').innerText = msg;
</script>
</body>
</html>

View File

@@ -38,5 +38,21 @@
"asset": "assets/address_book.ttf"
}
]
},
{
"family": "DeviceGroup",
"fonts": [
{
"asset": "assets/device_group.ttf"
}
]
},
{
"family": "More",
"fonts": [
{
"asset": "assets/more.ttf"
}
]
}
]

Binary file not shown.

BIN
resources/web2/assets/assets/more.ttf vendored Normal file

Binary file not shown.

View File

@@ -1,6 +1,6 @@
<!DOCTYPE html>
<html>
<head>
<head>
<!--
If you are serving your web app in a path other than the root, change the
href value below to reflect the base path you are serving from.
@@ -16,195 +16,196 @@
-->
<base href="/webclient2/" />
<meta charset="UTF-8" />
<meta content="IE=Edge" http-equiv="X-UA-Compatible" />
<meta name="description" content="Remote Desktop." />
<meta charset="UTF-8"/>
<meta content="IE=Edge" http-equiv="X-UA-Compatible"/>
<meta name="description" content="Remote Desktop."/>
<!-- iOS meta tags & icons -->
<meta name="apple-mobile-web-app-capable" content="yes" />
<meta name="apple-mobile-web-app-status-bar-style" content="black" />
<meta name="apple-mobile-web-app-title" content="RustDesk" />
<link rel="apple-touch-icon" href="icons/Icon-192.png?v=1a7ad736" />
<meta name="apple-mobile-web-app-capable" content="yes"/>
<meta name="apple-mobile-web-app-status-bar-style" content="black"/>
<meta name="apple-mobile-web-app-title" content="RustDesk"/>
<link rel="apple-touch-icon" href="icons/Icon-192.png?v=1a7ad736"/>
<!-- Favicon -->
<link rel="icon" type="image/svg+xml" href="favicon.svg?v=8fcccd9a" />
<link rel="icon" type="image/svg+xml" href="favicon.svg?v=8fcccd9a"/>
<title>RustDesk</title>
<script src="/webclient-config/index.js"></script>
<link rel="manifest" href="manifest.json" />
<script type="module" crossorigin src="js/dist/index.js?v=cabfd933"></script>
<link rel="modulepreload" href="js/dist/vendor.js?v=0b990c6e" />
<link rel="manifest" href="manifest.json"/>
<script type="module" crossorigin src="js/dist/index.js?v=ddbe54f1"></script>
<link rel="modulepreload" href="js/dist/vendor.js?v=0b990c6e"/>
<style>
html,
body,
#root {
height: 100%;
margin: 0;
padding: 0;
}
#root {
background-repeat: no-repeat;
background-size: 100% auto;
}
.loading-title {
font-size: 1.1rem;
}
.loading-sub-title {
margin-top: 20px;
font-size: 1rem;
color: #888;
}
.page-loading-warp {
display: flex;
align-items: center;
justify-content: center;
padding: 26px;
}
.ant-spin {
position: absolute;
display: none;
-webkit-box-sizing: border-box;
box-sizing: border-box;
margin: 0;
padding: 0;
color: rgba(0, 0, 0, 0.65);
color: #1890ff;
font-size: 14px;
font-variant: tabular-nums;
line-height: 1.5;
text-align: center;
list-style: none;
opacity: 0;
-webkit-transition: -webkit-transform 0.3s
cubic-bezier(0.78, 0.14, 0.15, 0.86);
transition: -webkit-transform 0.3s cubic-bezier(0.78, 0.14, 0.15, 0.86);
transition: transform 0.3s cubic-bezier(0.78, 0.14, 0.15, 0.86);
transition: transform 0.3s cubic-bezier(0.78, 0.14, 0.15, 0.86),
-webkit-transform 0.3s cubic-bezier(0.78, 0.14, 0.15, 0.86);
-webkit-font-feature-settings: "tnum";
font-feature-settings: "tnum";
}
.ant-spin-spinning {
position: static;
display: inline-block;
opacity: 1;
}
.ant-spin-dot {
position: relative;
display: inline-block;
width: 20px;
height: 20px;
font-size: 20px;
}
.ant-spin-dot-item {
position: absolute;
display: block;
width: 9px;
height: 9px;
background-color: #1890ff;
border-radius: 100%;
-webkit-transform: scale(0.75);
-ms-transform: scale(0.75);
transform: scale(0.75);
-webkit-transform-origin: 50% 50%;
-ms-transform-origin: 50% 50%;
transform-origin: 50% 50%;
opacity: 0.3;
-webkit-animation: antspinmove 1s infinite linear alternate;
animation: antSpinMove 1s infinite linear alternate;
}
.ant-spin-dot-item:nth-child(1) {
top: 0;
left: 0;
}
.ant-spin-dot-item:nth-child(2) {
top: 0;
right: 0;
-webkit-animation-delay: 0.4s;
animation-delay: 0.4s;
}
.ant-spin-dot-item:nth-child(3) {
right: 0;
bottom: 0;
-webkit-animation-delay: 0.8s;
animation-delay: 0.8s;
}
.ant-spin-dot-item:nth-child(4) {
bottom: 0;
left: 0;
-webkit-animation-delay: 1.2s;
animation-delay: 1.2s;
}
.ant-spin-dot-spin {
-webkit-transform: rotate(45deg);
-ms-transform: rotate(45deg);
transform: rotate(45deg);
-webkit-animation: antrotate 1.2s infinite linear;
animation: antRotate 1.2s infinite linear;
}
.ant-spin-lg .ant-spin-dot {
width: 32px;
height: 32px;
font-size: 32px;
}
.ant-spin-lg .ant-spin-dot i {
width: 14px;
height: 14px;
}
@media all and (-ms-high-contrast: none), (-ms-high-contrast: active) {
.ant-spin-blur {
background: #fff;
opacity: 0.5;
html,
body,
#root {
height: 100%;
margin: 0;
padding: 0;
}
}
@-webkit-keyframes antSpinMove {
to {
opacity: 1;
#root {
background-repeat: no-repeat;
background-size: 100% auto;
}
}
@keyframes antSpinMove {
to {
opacity: 1;
.loading-title {
font-size: 1.1rem;
}
}
@-webkit-keyframes antRotate {
to {
-webkit-transform: rotate(405deg);
transform: rotate(405deg);
.loading-sub-title {
margin-top: 20px;
font-size: 1rem;
color: #888;
}
}
@keyframes antRotate {
to {
-webkit-transform: rotate(405deg);
transform: rotate(405deg);
.page-loading-warp {
display: flex;
align-items: center;
justify-content: center;
padding: 26px;
}
.ant-spin {
position: absolute;
display: none;
-webkit-box-sizing: border-box;
box-sizing: border-box;
margin: 0;
padding: 0;
color: rgba(0, 0, 0, 0.65);
color: #1890ff;
font-size: 14px;
font-variant: tabular-nums;
line-height: 1.5;
text-align: center;
list-style: none;
opacity: 0;
-webkit-transition: -webkit-transform 0.3s cubic-bezier(0.78, 0.14, 0.15, 0.86);
transition: -webkit-transform 0.3s cubic-bezier(0.78, 0.14, 0.15, 0.86);
transition: transform 0.3s cubic-bezier(0.78, 0.14, 0.15, 0.86);
transition: transform 0.3s cubic-bezier(0.78, 0.14, 0.15, 0.86),
-webkit-transform 0.3s cubic-bezier(0.78, 0.14, 0.15, 0.86);
-webkit-font-feature-settings: "tnum";
font-feature-settings: "tnum";
}
.ant-spin-spinning {
position: static;
display: inline-block;
opacity: 1;
}
.ant-spin-dot {
position: relative;
display: inline-block;
width: 20px;
height: 20px;
font-size: 20px;
}
.ant-spin-dot-item {
position: absolute;
display: block;
width: 9px;
height: 9px;
background-color: #1890ff;
border-radius: 100%;
-webkit-transform: scale(0.75);
-ms-transform: scale(0.75);
transform: scale(0.75);
-webkit-transform-origin: 50% 50%;
-ms-transform-origin: 50% 50%;
transform-origin: 50% 50%;
opacity: 0.3;
-webkit-animation: antspinmove 1s infinite linear alternate;
animation: antSpinMove 1s infinite linear alternate;
}
.ant-spin-dot-item:nth-child(1) {
top: 0;
left: 0;
}
.ant-spin-dot-item:nth-child(2) {
top: 0;
right: 0;
-webkit-animation-delay: 0.4s;
animation-delay: 0.4s;
}
.ant-spin-dot-item:nth-child(3) {
right: 0;
bottom: 0;
-webkit-animation-delay: 0.8s;
animation-delay: 0.8s;
}
.ant-spin-dot-item:nth-child(4) {
bottom: 0;
left: 0;
-webkit-animation-delay: 1.2s;
animation-delay: 1.2s;
}
.ant-spin-dot-spin {
-webkit-transform: rotate(45deg);
-ms-transform: rotate(45deg);
transform: rotate(45deg);
-webkit-animation: antrotate 1.2s infinite linear;
animation: antRotate 1.2s infinite linear;
}
.ant-spin-lg .ant-spin-dot {
width: 32px;
height: 32px;
font-size: 32px;
}
.ant-spin-lg .ant-spin-dot i {
width: 14px;
height: 14px;
}
@media all and (-ms-high-contrast: none), (-ms-high-contrast: active) {
.ant-spin-blur {
background: #fff;
opacity: 0.5;
}
}
@-webkit-keyframes antSpinMove {
to {
opacity: 1;
}
}
@keyframes antSpinMove {
to {
opacity: 1;
}
}
@-webkit-keyframes antRotate {
to {
-webkit-transform: rotate(405deg);
transform: rotate(405deg);
}
}
@keyframes antRotate {
to {
-webkit-transform: rotate(405deg);
transform: rotate(405deg);
}
}
}
</style>
</head>
</head>
<body>
<div id="root">
<div
id="div-background"
style="
<body>
<div id="root">
<div
id="div-background"
style="
display: flex;
flex-direction: column;
align-items: center;
@@ -212,117 +213,119 @@
height: 100%;
min-height: 420px;
"
>
<img src="./favicon.svg?v=8fcccd9a" alt="logo" width="256" />
>
<img src="./favicon.svg?v=8fcccd9a" alt="logo" width="256"/>
<div class="page-loading-warp">
<div class="ant-spin ant-spin-lg ant-spin-spinning">
<div class="ant-spin ant-spin-lg ant-spin-spinning">
<span class="ant-spin-dot ant-spin-dot-spin">
<i class="ant-spin-dot-item"></i>
<i class="ant-spin-dot-item"></i>
<i class="ant-spin-dot-item"></i><i class="ant-spin-dot-item"></i>
</span>
</div>
</div>
</div>
<div
style="display: flex; align-items: center; justify-content: center"
style="display: flex; align-items: center; justify-content: center"
>
<img src="./favicon.svg?v=8fcccd9a" width="32" style="margin-right: 8px" />
<span id="span-text">RustDesk Web Client V2 Preview</span>
<img src="./favicon.svg?v=8fcccd9a" width="32" style="margin-right: 8px"/>
<span id="span-text">RustDesk Web Client V2 Preview</span>
</div>
</div>
</div>
<!-- This script installs service_worker.js to provide PWA functionality to
application. For more information, see:
https://developers.google.com/web/fundamentals/primers/service-workers -->
<script>
const systemTheme = window.matchMedia("(prefers-color-scheme: dark)")
</div>
<!-- This script installs service_worker.js to provide PWA functionality to
application. For more information, see:
https://developers.google.com/web/fundamentals/primers/service-workers -->
<script>
const systemTheme = window.matchMedia("(prefers-color-scheme: dark)")
.matches
? "dark"
: "light";
const myTheme = localStorage.getItem("wc-option:local:theme");
const them = myTheme || systemTheme;
const myTheme = localStorage.getItem("wc-option:local:theme");
const them = myTheme || systemTheme;
const divBackground = document.querySelector("#div-background");
if (divBackground) {
const divBackground = document.querySelector("#div-background");
if (divBackground) {
divBackground.style.backgroundColor = them === "dark" ? "#000" : "#fff";
}
const spanConsole = document.querySelector("#span-text");
if (spanConsole) {
}
const spanConsole = document.querySelector("#span-text");
if (spanConsole) {
spanConsole.style.color = them === "dark" ? "#fff" : "#000";
}
}
const serviceWorkerVersion = "3267265270";
var scriptLoaded = false;
function loadMainDartJs() {
const serviceWorkerVersion = "461457302";
var scriptLoaded = false;
function loadMainDartJs() {
if (scriptLoaded) {
return;
return;
}
scriptLoaded = true;
var scriptTag = document.createElement("script");
scriptTag.src = "main.dart.js?v=060a626e";
scriptTag.src = "main.dart.js?v=6d16cb80";
scriptTag.type = "application/javascript";
document.body.append(scriptTag);
}
}
if ("serviceWorker" in navigator) {
if ("serviceWorker" in navigator) {
// Service workers are supported. Use them.
window.addEventListener("load", function () {
// Wait for registration to finish before dropping the <script> tag.
// Otherwise, the browser will load the script multiple times,
// potentially different versions.
var serviceWorkerUrl =
"flutter_service_worker.js?v=" + serviceWorkerVersion;
navigator.serviceWorker.register(serviceWorkerUrl).then((reg) => {
function waitForActivation(serviceWorker) {
serviceWorker.addEventListener("statechange", () => {
if (serviceWorker.state == "activated") {
console.log("Installed new service worker.");
loadMainDartJs();
// Wait for registration to finish before dropping the <script> tag.
// Otherwise, the browser will load the script multiple times,
// potentially different versions.
var serviceWorkerUrl =
"flutter_service_worker.js?v=" + serviceWorkerVersion;
navigator.serviceWorker.register(serviceWorkerUrl).then((reg) => {
function waitForActivation(serviceWorker) {
serviceWorker.addEventListener("statechange", () => {
if (serviceWorker.state == "activated") {
console.log("Installed new service worker.");
loadMainDartJs();
}
});
}
});
}
if (!reg.active && (reg.installing || reg.waiting)) {
// No active web worker and we have installed or are installing
// one for the first time. Simply wait for it to activate.
waitForActivation(reg.installing || reg.waiting);
} else if (!reg.active.scriptURL.endsWith(serviceWorkerVersion)) {
// When the app updates the serviceWorkerVersion changes, so we
// need to ask the service worker to update.
console.log("New service worker available.");
reg.update();
waitForActivation(reg.installing);
} else {
// Existing service worker is still good.
console.log("Loading app from service worker.");
loadMainDartJs();
}
});
// If service worker doesn't succeed in a reasonable amount of time,
// fallback to plaint <script> tag.
setTimeout(() => {
if (!scriptLoaded) {
console.warn(
"Failed to load app from service worker. Falling back to plain <script> tag."
);
loadMainDartJs();
}
}, 4000);
if (!reg.active && (reg.installing || reg.waiting)) {
// No active web worker and we have installed or are installing
// one for the first time. Simply wait for it to activate.
waitForActivation(reg.installing || reg.waiting);
} else if (!reg.active.scriptURL.endsWith(serviceWorkerVersion)) {
// When the app updates the serviceWorkerVersion changes, so we
// need to ask the service worker to update.
console.log("New service worker available.");
reg.update();
waitForActivation(reg.installing);
} else {
// Existing service worker is still good.
console.log("Loading app from service worker.");
loadMainDartJs();
}
});
// If service worker doesn't succeed in a reasonable amount of time,
// fallback to plaint <script> tag.
setTimeout(() => {
if (!scriptLoaded) {
console.warn(
"Failed to load app from service worker. Falling back to plain <script> tag."
);
loadMainDartJs();
}
}, 4000);
});
} else {
} else {
// Service workers not supported. Just drop the <script> tag.
loadMainDartJs();
}
</script>
<script src="libs/stream/ponyfill.min.js"></script>
<script src="libs/stream/StreamSaver.min.js"></script>
<script src="libs/firebase-app.js?8.10.1"></script>
<script src="libs/firebase-analytics.js?8.10.1"></script>
}
</script>
<script src="libs/stream/ponyfill.min.js"></script>
<script src="libs/stream/StreamSaver.min.js"></script>
<script src="libs/firebase-app.js?8.10.1"></script>
<script src="libs/firebase-analytics.js?8.10.1"></script>
<script>
// Your web app's Firebase configuration
// For Firebase JS SDK v7.20.0 and later, measurementId is optional
const firebaseConfig = {
<script>
// Your web app's Firebase configuration
// For Firebase JS SDK v7.20.0 and later, measurementId is optional
const firebaseConfig = {
apiKey: "AIzaSyCgehIZk1aFP0E7wZtYRRqrfvNiNAF39-A",
authDomain: "rustdesk.firebaseapp.com",
databaseURL: "https://rustdesk.firebaseio.com",
@@ -331,11 +334,11 @@
messagingSenderId: "768133699366",
appId: "1:768133699366:web:d50faf0792cb208d7993e7",
measurementId: "G-9PEH85N6ZQ",
};
};
// Initialize Firebase
firebase.initializeApp(firebaseConfig);
firebase.analytics();
</script>
</body>
// Initialize Firebase
firebase.initializeApp(firebaseConfig);
firebase.analytics();
</script>
</body>
</html>

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,11 @@
window._gwen = {}
window._gwen.kv = {}
//fix 语言
if(!localStorage.getItem('wc-option:local:lang') && navigator.language){
localStorage.setItem('wc-option:local:lang', navigator.language.toLowerCase())
}
const storage_prefix = 'wc-'
const apiserver = localStorage.getItem('wc-api-server')
@@ -46,7 +52,7 @@ if (share_token) {
password: peer.tmppwd,
}*/
//修改location
window.location.href = `/webclient2/#/${peer.info.id}?password=${peer.tmppwd}`
window.location.href = `/webclient2/#/${peer.info.id}?password=${encodeURIComponent(peer.tmppwd)}`
}
})
}

163526
resources/web2/main.dart.js vendored

File diff suppressed because one or more lines are too long

View File

@@ -3,7 +3,6 @@ package service
import (
"encoding/json"
"github.com/google/uuid"
"github.com/lejianwen/rustdesk-api/v2/global"
"github.com/lejianwen/rustdesk-api/v2/model"
"gorm.io/gorm"
"strings"
@@ -14,24 +13,24 @@ type AddressBookService struct {
func (s *AddressBookService) Info(id string) *model.AddressBook {
p := &model.AddressBook{}
global.DB.Where("id = ?", id).First(p)
DB.Where("id = ?", id).First(p)
return p
}
func (s *AddressBookService) InfoByUserIdAndId(userid uint, id string) *model.AddressBook {
p := &model.AddressBook{}
global.DB.Where("user_id = ? and id = ?", userid, id).First(p)
DB.Where("user_id = ? and id = ?", userid, id).First(p)
return p
}
func (s *AddressBookService) InfoByUserIdAndIdAndCid(userid uint, id string, cid uint) *model.AddressBook {
p := &model.AddressBook{}
global.DB.Where("user_id = ? and id = ? and collection_id = ?", userid, id, cid).First(p)
DB.Where("user_id = ? and id = ? and collection_id = ?", userid, id, cid).First(p)
return p
}
func (s *AddressBookService) InfoByRowId(id uint) *model.AddressBook {
p := &model.AddressBook{}
global.DB.Where("row_id = ?", id).First(p)
DB.Where("row_id = ?", id).First(p)
return p
}
func (s *AddressBookService) ListByUserId(userId, page, pageSize uint) (res *model.AddressBookList) {
@@ -49,14 +48,14 @@ func (s *AddressBookService) ListByUserIds(userIds []uint, page, pageSize uint)
// AddAddressBook
func (s *AddressBookService) AddAddressBook(ab *model.AddressBook) error {
return global.DB.Create(ab).Error
return DB.Create(ab).Error
}
// UpdateAddressBook
func (s *AddressBookService) UpdateAddressBook(abs []*model.AddressBook, userId uint) error {
//比较peers和数据库中的数据如果peers中的数据在数据库中不存在则添加如果存在则更新如果数据库中的数据在peers中不存在则删除
// 开始事务
tx := global.DB.Begin()
tx := DB.Begin()
//1. 获取数据库中的数据
var dbABs []*model.AddressBook
tx.Where("user_id = ?", userId).Find(&dbABs)
@@ -107,7 +106,7 @@ func (s *AddressBookService) List(page, pageSize uint, where func(tx *gorm.DB))
res = &model.AddressBookList{}
res.Page = int64(page)
res.PageSize = int64(pageSize)
tx := global.DB.Model(&model.AddressBook{})
tx := DB.Model(&model.AddressBook{})
if where != nil {
where(tx)
}
@@ -129,38 +128,38 @@ func (s *AddressBookService) FromPeer(peer *model.Peer) (a *model.AddressBook) {
// Create 创建
func (s *AddressBookService) Create(u *model.AddressBook) error {
res := global.DB.Create(u).Error
res := DB.Create(u).Error
return res
}
func (s *AddressBookService) Delete(u *model.AddressBook) error {
return global.DB.Delete(u).Error
return DB.Delete(u).Error
}
// Update 更新
func (s *AddressBookService) Update(u *model.AddressBook) error {
return global.DB.Model(u).Updates(u).Error
return DB.Model(u).Updates(u).Error
}
// UpdateByMap 更新
func (s *AddressBookService) UpdateByMap(u *model.AddressBook, data map[string]interface{}) error {
return global.DB.Model(u).Updates(data).Error
return DB.Model(u).Updates(data).Error
}
// UpdateAll 更新
func (s *AddressBookService) UpdateAll(u *model.AddressBook) error {
return global.DB.Model(u).Select("*").Omit("created_at").Updates(u).Error
return DB.Model(u).Select("*").Omit("created_at").Updates(u).Error
}
// ShareByWebClient 分享
func (s *AddressBookService) ShareByWebClient(m *model.ShareRecord) error {
m.ShareToken = uuid.New().String()
return global.DB.Create(m).Error
return DB.Create(m).Error
}
// SharedPeer
func (s *AddressBookService) SharedPeer(shareToken string) *model.ShareRecord {
m := &model.ShareRecord{}
global.DB.Where("share_token = ?", shareToken).First(m)
DB.Where("share_token = ?", shareToken).First(m)
return m
}
@@ -190,7 +189,7 @@ func (s *AddressBookService) ListCollection(page, pageSize uint, where func(tx *
res = &model.AddressBookCollectionList{}
res.Page = int64(page)
res.PageSize = int64(pageSize)
tx := global.DB.Model(&model.AddressBookCollection{})
tx := DB.Model(&model.AddressBookCollection{})
if where != nil {
where(tx)
}
@@ -200,7 +199,7 @@ func (s *AddressBookService) ListCollection(page, pageSize uint, where func(tx *
return
}
func (s *AddressBookService) ListCollectionByIds(ids []uint) (res []*model.AddressBookCollection) {
global.DB.Where("id in ?", ids).Find(&res)
DB.Where("id in ?", ids).Find(&res)
return res
}
@@ -212,20 +211,20 @@ func (s *AddressBookService) ListCollectionByUserId(userId uint) (res *model.Add
}
func (s *AddressBookService) CollectionInfoById(id uint) *model.AddressBookCollection {
p := &model.AddressBookCollection{}
global.DB.Where("id = ?", id).First(p)
DB.Where("id = ?", id).First(p)
return p
}
func (s *AddressBookService) CollectionReadRules(user *model.User) (res []*model.AddressBookCollectionRule) {
// personalRules
var personalRules []*model.AddressBookCollectionRule
tx2 := global.DB.Model(&model.AddressBookCollectionRule{})
tx2 := DB.Model(&model.AddressBookCollectionRule{})
tx2.Where("type = ? and to_id = ? and rule > 0", model.ShareAddressBookRuleTypePersonal, user.Id).Find(&personalRules)
res = append(res, personalRules...)
//group
var groupRules []*model.AddressBookCollectionRule
tx3 := global.DB.Model(&model.AddressBookCollectionRule{})
tx3 := DB.Model(&model.AddressBookCollectionRule{})
tx3.Where("type = ? and to_id = ? and rule > 0", model.ShareAddressBookRuleTypeGroup, user.GroupId).Find(&groupRules)
res = append(res, groupRules...)
return
@@ -238,7 +237,7 @@ func (s *AddressBookService) UserMaxRule(user *model.User, uid, cid uint) int {
}
max := 0
personalRules := &model.AddressBookCollectionRule{}
tx := global.DB.Model(personalRules)
tx := DB.Model(personalRules)
tx.Where("type = ? and collection_id = ? and to_id = ?", model.ShareAddressBookRuleTypePersonal, cid, user.Id).First(&personalRules)
if personalRules.Id != 0 {
max = personalRules.Rule
@@ -248,7 +247,7 @@ func (s *AddressBookService) UserMaxRule(user *model.User, uid, cid uint) int {
}
groupRules := &model.AddressBookCollectionRule{}
tx2 := global.DB.Model(groupRules)
tx2 := DB.Model(groupRules)
tx2.Where("type = ? and collection_id = ? and to_id = ?", model.ShareAddressBookRuleTypeGroup, cid, user.GroupId).First(&groupRules)
if groupRules.Id != 0 {
if groupRules.Rule > max {
@@ -272,16 +271,16 @@ func (s *AddressBookService) CheckUserFullControlPrivilege(user *model.User, uid
}
func (s *AddressBookService) CreateCollection(t *model.AddressBookCollection) error {
return global.DB.Create(t).Error
return DB.Create(t).Error
}
func (s *AddressBookService) UpdateCollection(t *model.AddressBookCollection) error {
return global.DB.Model(t).Updates(t).Error
return DB.Model(t).Updates(t).Error
}
func (s *AddressBookService) DeleteCollection(t *model.AddressBookCollection) error {
//删除集合下的所有规则、地址簿,再删除集合
tx := global.DB.Begin()
tx := DB.Begin()
tx.Where("collection_id = ?", t.Id).Delete(&model.AddressBookCollectionRule{})
tx.Where("collection_id = ?", t.Id).Delete(&model.AddressBook{})
tx.Delete(t)
@@ -290,23 +289,26 @@ func (s *AddressBookService) DeleteCollection(t *model.AddressBookCollection) er
func (s *AddressBookService) RuleInfoById(u uint) *model.AddressBookCollectionRule {
p := &model.AddressBookCollectionRule{}
global.DB.Where("id = ?", u).First(p)
DB.Where("id = ?", u).First(p)
return p
}
func (s *AddressBookService) RulePersonalInfoByToIdAndCid(toid, cid uint) *model.AddressBookCollectionRule {
return s.RuleInfoByToIdAndCid(model.ShareAddressBookRuleTypePersonal, toid, cid)
}
func (s *AddressBookService) RuleInfoByToIdAndCid(t int, toid, cid uint) *model.AddressBookCollectionRule {
p := &model.AddressBookCollectionRule{}
global.DB.Where("type = ? and to_id = ? and collection_id = ?", model.ShareAddressBookRuleTypePersonal, toid, cid).First(p)
DB.Where("type = ? and to_id = ? and collection_id = ?", t, toid, cid).First(p)
return p
}
func (s *AddressBookService) CreateRule(t *model.AddressBookCollectionRule) error {
return global.DB.Create(t).Error
return DB.Create(t).Error
}
func (s *AddressBookService) ListRules(page uint, size uint, f func(tx *gorm.DB)) *model.AddressBookCollectionRuleList {
res := &model.AddressBookCollectionRuleList{}
res.Page = int64(page)
res.PageSize = int64(size)
tx := global.DB.Model(&model.AddressBookCollectionRule{})
tx := DB.Model(&model.AddressBookCollectionRule{})
if f != nil {
f(tx)
}
@@ -317,11 +319,11 @@ func (s *AddressBookService) ListRules(page uint, size uint, f func(tx *gorm.DB)
}
func (s *AddressBookService) UpdateRule(t *model.AddressBookCollectionRule) error {
return global.DB.Model(t).Updates(t).Error
return DB.Model(t).Updates(t).Error
}
func (s *AddressBookService) DeleteRule(t *model.AddressBookCollectionRule) error {
return global.DB.Delete(t).Error
return DB.Delete(t).Error
}
// CheckCollectionOwner 检查Collection的所有者
@@ -336,5 +338,5 @@ func (s *AddressBookService) BatchUpdateTags(abs []*model.AddressBook, tags []st
ids = append(ids, ab.RowId)
}
tagsv, _ := json.Marshal(tags)
return global.DB.Model(&model.AddressBook{}).Where("row_id in ?", ids).Update("tags", tagsv).Error
return DB.Model(&model.AddressBook{}).Where("row_id in ?", ids).Update("tags", tagsv).Error
}

39
service/app.go Normal file
View File

@@ -0,0 +1,39 @@
package service
import (
"os"
"sync"
"time"
)
type AppService struct {
}
var version = ""
var startTime = ""
var once = &sync.Once{}
func (a *AppService) GetAppVersion() string {
if version != "" {
return version
}
once.Do(func() {
v, err := os.ReadFile("resources/version")
if err != nil {
return
}
version = string(v)
})
return version
}
func init() {
// Initialize the AppService if needed
startTime = time.Now().Format("2006-01-02 15:04:05")
}
// GetStartTime
func (a *AppService) GetStartTime() string {
return startTime
}

33
service/app_test.go Normal file
View File

@@ -0,0 +1,33 @@
package service
import (
"sync"
"testing"
)
// TestGetAppVersion
func TestGetAppVersion(t *testing.T) {
s := &AppService{}
v := s.GetAppVersion()
// 打印结果
t.Logf("App Version: %s", v)
}
func TestMultipleGetAppVersion(t *testing.T) {
s := &AppService{}
//并发测试
// 使用 WaitGroup 等待所有 goroutine 完成
wg := sync.WaitGroup{}
wg.Add(10) // 启动 10 个 goroutine
// 启动 10 个 goroutine
for i := 0; i < 10; i++ {
go func() {
defer wg.Done() // 完成后减少计数
v := s.GetAppVersion()
// 打印结果
t.Logf("App Version: %s", v)
}()
}
// 等待所有 goroutine 完成
wg.Wait()
}

View File

@@ -1,7 +1,6 @@
package service
import (
"github.com/lejianwen/rustdesk-api/v2/global"
"github.com/lejianwen/rustdesk-api/v2/model"
"gorm.io/gorm"
)
@@ -13,7 +12,7 @@ func (as *AuditService) AuditConnList(page, pageSize uint, where func(tx *gorm.D
res = &model.AuditConnList{}
res.Page = int64(page)
res.PageSize = int64(pageSize)
tx := global.DB.Model(&model.AuditConn{})
tx := DB.Model(&model.AuditConn{})
if where != nil {
where(tx)
}
@@ -25,36 +24,36 @@ func (as *AuditService) AuditConnList(page, pageSize uint, where func(tx *gorm.D
// Create 创建
func (as *AuditService) CreateAuditConn(u *model.AuditConn) error {
res := global.DB.Create(u).Error
res := DB.Create(u).Error
return res
}
func (as *AuditService) DeleteAuditConn(u *model.AuditConn) error {
return global.DB.Delete(u).Error
return DB.Delete(u).Error
}
// Update 更新
func (as *AuditService) UpdateAuditConn(u *model.AuditConn) error {
return global.DB.Model(u).Updates(u).Error
return DB.Model(u).Updates(u).Error
}
// InfoByPeerIdAndConnId
func (as *AuditService) InfoByPeerIdAndConnId(peerId string, connId int64) (res *model.AuditConn) {
res = &model.AuditConn{}
global.DB.Where("peer_id = ? and conn_id = ?", peerId, connId).First(res)
DB.Where("peer_id = ? and conn_id = ?", peerId, connId).First(res)
return
}
// ConnInfoById
func (as *AuditService) ConnInfoById(id uint) (res *model.AuditConn) {
res = &model.AuditConn{}
global.DB.Where("id = ?", id).First(res)
DB.Where("id = ?", id).First(res)
return
}
// FileInfoById
func (as *AuditService) FileInfoById(id uint) (res *model.AuditFile) {
res = &model.AuditFile{}
global.DB.Where("id = ?", id).First(res)
DB.Where("id = ?", id).First(res)
return
}
@@ -62,7 +61,7 @@ func (as *AuditService) AuditFileList(page, pageSize uint, where func(tx *gorm.D
res = &model.AuditFileList{}
res.Page = int64(page)
res.PageSize = int64(pageSize)
tx := global.DB.Model(&model.AuditFile{})
tx := DB.Model(&model.AuditFile{})
if where != nil {
where(tx)
}
@@ -74,22 +73,22 @@ func (as *AuditService) AuditFileList(page, pageSize uint, where func(tx *gorm.D
// CreateAuditFile
func (as *AuditService) CreateAuditFile(u *model.AuditFile) error {
res := global.DB.Create(u).Error
res := DB.Create(u).Error
return res
}
func (as *AuditService) DeleteAuditFile(u *model.AuditFile) error {
return global.DB.Delete(u).Error
return DB.Delete(u).Error
}
// Update 更新
func (as *AuditService) UpdateAuditFile(u *model.AuditFile) error {
return global.DB.Model(u).Updates(u).Error
return DB.Model(u).Updates(u).Error
}
func (as *AuditService) BatchDeleteAuditConn(ids []uint) error {
return global.DB.Where("id in (?)", ids).Delete(&model.AuditConn{}).Error
return DB.Where("id in (?)", ids).Delete(&model.AuditConn{}).Error
}
func (as *AuditService) BatchDeleteAuditFile(ids []uint) error {
return global.DB.Where("id in (?)", ids).Delete(&model.AuditFile{}).Error
return DB.Where("id in (?)", ids).Delete(&model.AuditFile{}).Error
}

View File

@@ -1,7 +1,6 @@
package service
import (
"github.com/lejianwen/rustdesk-api/v2/global"
"github.com/lejianwen/rustdesk-api/v2/model"
"gorm.io/gorm"
)
@@ -12,7 +11,7 @@ type GroupService struct {
// InfoById 根据用户id取用户信息
func (us *GroupService) InfoById(id uint) *model.Group {
u := &model.Group{}
global.DB.Where("id = ?", id).First(u)
DB.Where("id = ?", id).First(u)
return u
}
@@ -20,7 +19,7 @@ func (us *GroupService) List(page, pageSize uint, where func(tx *gorm.DB)) (res
res = &model.GroupList{}
res.Page = int64(page)
res.PageSize = int64(pageSize)
tx := global.DB.Model(&model.Group{})
tx := DB.Model(&model.Group{})
if where != nil {
where(tx)
}
@@ -32,14 +31,47 @@ func (us *GroupService) List(page, pageSize uint, where func(tx *gorm.DB)) (res
// Create 创建
func (us *GroupService) Create(u *model.Group) error {
res := global.DB.Create(u).Error
res := DB.Create(u).Error
return res
}
func (us *GroupService) Delete(u *model.Group) error {
return global.DB.Delete(u).Error
return DB.Delete(u).Error
}
// Update 更新
func (us *GroupService) Update(u *model.Group) error {
return global.DB.Model(u).Updates(u).Error
return DB.Model(u).Updates(u).Error
}
// DeviceGroupInfoById 根据用户id取用户信息
func (us *GroupService) DeviceGroupInfoById(id uint) *model.DeviceGroup {
u := &model.DeviceGroup{}
DB.Where("id = ?", id).First(u)
return u
}
func (us *GroupService) DeviceGroupList(page, pageSize uint, where func(tx *gorm.DB)) (res *model.DeviceGroupList) {
res = &model.DeviceGroupList{}
res.Page = int64(page)
res.PageSize = int64(pageSize)
tx := DB.Model(&model.DeviceGroup{})
if where != nil {
where(tx)
}
tx.Count(&res.Total)
tx.Scopes(Paginate(page, pageSize))
tx.Find(&res.DeviceGroups)
return
}
func (us *GroupService) DeviceGroupCreate(u *model.DeviceGroup) error {
res := DB.Create(u).Error
return res
}
func (us *GroupService) DeviceGroupDelete(u *model.DeviceGroup) error {
return DB.Delete(u).Error
}
func (us *GroupService) DeviceGroupUpdate(u *model.DeviceGroup) error {
return DB.Model(u).Updates(u).Error
}

View File

@@ -2,19 +2,23 @@ package service
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net/url"
"os"
"strconv"
"strings"
"github.com/go-ldap/ldap/v3"
"github.com/lejianwen/rustdesk-api/v2/config"
"github.com/lejianwen/rustdesk-api/v2/global"
"github.com/lejianwen/rustdesk-api/v2/model"
)
var (
ErrUrlParseFailed = errors.New("UrlParseFailed")
ErrFileReadFailed = errors.New("FileReadFailed")
ErrLdapNotEnabled = errors.New("LdapNotEnabled")
ErrLdapUserDisabled = errors.New("UserDisabledAtLdap")
ErrLdapUserNotFound = errors.New("UserNotFound")
@@ -26,6 +30,7 @@ var (
ErrLdapBindFailed = errors.New("LdapBindFailed")
ErrLdapToLocalUserFailed = errors.New("LdapToLocalUserFailed")
ErrLdapCreateUserFailed = errors.New("LdapCreateUserFailed")
ErrLdapPasswordNotMatch = errors.New("PasswordNotMatch")
)
// LdapService is responsible for LDAP authentication and user synchronization.
@@ -68,21 +73,38 @@ func (lu *LdapUser) ToUser(u *model.User) *model.User {
// connectAndBind creates an LDAP connection, optionally starts TLS, and then binds using the provided credentials.
func (ls *LdapService) connectAndBind(cfg *config.Ldap, username, password string) (*ldap.Conn, error) {
conn, err := ldap.DialURL(cfg.Url)
u, err := url.Parse(cfg.Url)
if err != nil {
return nil, errors.Join(ErrUrlParseFailed, err)
}
var conn *ldap.Conn
if u.Scheme == "ldaps" {
// WARNING: InsecureSkipVerify: true is not recommended for production
tlsConfig := &tls.Config{InsecureSkipVerify: !cfg.TlsVerify}
if cfg.TlsCaFile != "" {
caCert, err := os.ReadFile(cfg.TlsCaFile)
if err != nil {
return nil, errors.Join(ErrFileReadFailed, err)
}
caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM(caCert) {
return nil, errors.Join(ErrLdapTlsFailed, errors.New("failed to append CA certificate"))
}
tlsConfig.RootCAs = caCertPool
}
conn, err = ldap.DialURL(cfg.Url, ldap.DialWithTLSConfig(tlsConfig))
} else {
conn, err = ldap.DialURL(cfg.Url)
}
if err != nil {
return nil, errors.Join(ErrLdapConnectFailed, err)
}
if cfg.TLS {
// WARNING: InsecureSkipVerify: true is not recommended for production
if err = conn.StartTLS(&tls.Config{InsecureSkipVerify: !cfg.TlsVerify}); err != nil {
conn.Close()
return nil, errors.Join(ErrLdapTlsFailed, err)
}
}
// Bind as the "service" user
if err = conn.Bind(username, password); err != nil {
fmt.Println("Bind failed")
conn.Close()
return nil, errors.Join(ErrLdapBindService, err)
}
@@ -98,7 +120,7 @@ func (ls *LdapService) connectAndBindAdmin(cfg *config.Ldap) (*ldap.Conn, error)
func (ls *LdapService) verifyCredentials(cfg *config.Ldap, username, password string) error {
ldapConn, err := ls.connectAndBind(cfg, username, password)
if err != nil {
return err
return ErrLdapPasswordNotMatch
}
defer ldapConn.Close()
return nil
@@ -114,7 +136,11 @@ func (ls *LdapService) Authenticate(username, password string) (*model.User, err
if !ldapUser.Enabled {
return nil, ErrLdapUserDisabled
}
cfg := &global.Config.Ldap
cfg := &Config.Ldap
err = ls.verifyCredentials(cfg, ldapUser.Dn, password)
if err != nil {
return nil, err
}
user, err := ls.mapToLocalUser(cfg, ldapUser)
if err != nil {
return nil, errors.Join(ErrLdapToLocalUserFailed, err)
@@ -135,7 +161,7 @@ func (ls *LdapService) mapToLocalUser(cfg *config.Ldap, lu *LdapUser) (*model.Us
// If needed, you can set a random password here.
newUser.IsAdmin = &isAdmin
newUser.GroupId = 1
if err := global.DB.Create(newUser).Error; err != nil {
if err := DB.Create(newUser).Error; err != nil {
return nil, errors.Join(ErrLdapCreateUserFailed, err)
}
return userService.InfoByUsername(lu.Username), nil
@@ -164,7 +190,7 @@ func (ls *LdapService) mapToLocalUser(cfg *config.Ldap, lu *LdapUser) (*model.Us
// IsUsernameExists checks if a username exists in LDAP (can be useful for local registration checks).
func (ls *LdapService) IsUsernameExists(username string) bool {
cfg := &global.Config.Ldap
cfg := &Config.Ldap
if !cfg.Enable {
return false
}
@@ -177,7 +203,7 @@ func (ls *LdapService) IsUsernameExists(username string) bool {
// IsEmailExists checks if an email exists in LDAP (can be useful for local registration checks).
func (ls *LdapService) IsEmailExists(email string) bool {
cfg := &global.Config.Ldap
cfg := &Config.Ldap
if !cfg.Enable {
return false
}
@@ -190,7 +216,7 @@ func (ls *LdapService) IsEmailExists(email string) bool {
// GetUserInfoByUsernameLdap returns the user info from LDAP for the given username.
func (ls *LdapService) GetUserInfoByUsernameLdap(username string) (*LdapUser, error) {
cfg := &global.Config.Ldap
cfg := &Config.Ldap
if !cfg.Enable {
return nil, ErrLdapNotEnabled
}
@@ -210,12 +236,12 @@ func (ls *LdapService) GetUserInfoByUsernameLocal(username string) (*model.User,
if err != nil {
return &model.User{}, err
}
return ls.mapToLocalUser(&global.Config.Ldap, ldapUser)
return ls.mapToLocalUser(&Config.Ldap, ldapUser)
}
// GetUserInfoByEmailLdap returns the user info from LDAP for the given email.
func (ls *LdapService) GetUserInfoByEmailLdap(email string) (*LdapUser, error) {
cfg := &global.Config.Ldap
cfg := &Config.Ldap
if !cfg.Enable {
return nil, ErrLdapNotEnabled
}
@@ -235,7 +261,7 @@ func (ls *LdapService) GetUserInfoByEmailLocal(email string) (*model.User, error
if err != nil {
return &model.User{}, err
}
return ls.mapToLocalUser(&global.Config.Ldap, ldapUser)
return ls.mapToLocalUser(&Config.Ldap, ldapUser)
}
// usernameSearchResult returns the search result for the given username.
@@ -385,7 +411,7 @@ func (ls *LdapService) isUserAdmin(cfg *config.Ldap, ldapUser *LdapUser) bool {
// Check "memberOf" directly
if len(ldapUser.MemberOf) > 0 {
for _, group := range ldapUser.MemberOf {
if group == adminGroup {
if strings.EqualFold(group, adminGroup) {
return true
}
}
@@ -453,12 +479,12 @@ func (ls *LdapService) isUserEnabled(cfg *config.Ldap, ldapUser *LdapUser) bool
// Account is disabled if the ACCOUNTDISABLE flag (0x2) is set
const ACCOUNTDISABLE = 0x2
ldapUser.Enabled = (userAccountControl&ACCOUNTDISABLE == 0)
ldapUser.Enabled = userAccountControl&ACCOUNTDISABLE == 0
return ldapUser.Enabled
}
// For other attributes, perform a direct comparison with the expected value
ldapUser.Enabled = (ldapUser.EnableAttrValue == enableAttrValue)
ldapUser.Enabled = ldapUser.EnableAttrValue == enableAttrValue
return ldapUser.Enabled
}

View File

@@ -1,7 +1,6 @@
package service
import (
"github.com/lejianwen/rustdesk-api/v2/global"
"github.com/lejianwen/rustdesk-api/v2/model"
"gorm.io/gorm"
)
@@ -12,7 +11,7 @@ type LoginLogService struct {
// InfoById 根据用户id取用户信息
func (us *LoginLogService) InfoById(id uint) *model.LoginLog {
u := &model.LoginLog{}
global.DB.Where("id = ?", id).First(u)
DB.Where("id = ?", id).First(u)
return u
}
@@ -20,7 +19,7 @@ func (us *LoginLogService) List(page, pageSize uint, where func(tx *gorm.DB)) (r
res = &model.LoginLogList{}
res.Page = int64(page)
res.PageSize = int64(pageSize)
tx := global.DB.Model(&model.LoginLog{})
tx := DB.Model(&model.LoginLog{})
if where != nil {
where(tx)
}
@@ -32,20 +31,20 @@ func (us *LoginLogService) List(page, pageSize uint, where func(tx *gorm.DB)) (r
// Create 创建
func (us *LoginLogService) Create(u *model.LoginLog) error {
res := global.DB.Create(u).Error
res := DB.Create(u).Error
return res
}
func (us *LoginLogService) Delete(u *model.LoginLog) error {
return global.DB.Delete(u).Error
return DB.Delete(u).Error
}
// Update 更新
func (us *LoginLogService) Update(u *model.LoginLog) error {
return global.DB.Model(u).Updates(u).Error
return DB.Model(u).Updates(u).Error
}
func (us *LoginLogService) BatchDelete(ids []uint) error {
return global.DB.Where("id in (?)", ids).Delete(&model.LoginLog{}).Error
return DB.Where("id in (?)", ids).Delete(&model.LoginLog{}).Error
}
func (us *LoginLogService) SoftDelete(l *model.LoginLog) error {
@@ -54,5 +53,5 @@ func (us *LoginLogService) SoftDelete(l *model.LoginLog) error {
}
func (us *LoginLogService) BatchSoftDelete(uid uint, ids []uint) error {
return global.DB.Model(&model.LoginLog{}).Where("user_id = ? and id in (?)", uid, ids).Update("is_deleted", model.IsDeletedYes).Error
return DB.Model(&model.LoginLog{}).Where("user_id = ? and id in (?)", uid, ids).Update("is_deleted", model.IsDeletedYes).Error
}

View File

@@ -4,7 +4,7 @@ import (
"context"
"encoding/json"
"errors"
"github.com/lejianwen/rustdesk-api/v2/global"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/lejianwen/rustdesk-api/v2/model"
"github.com/lejianwen/rustdesk-api/v2/utils"
"golang.org/x/oauth2"
@@ -45,6 +45,8 @@ type OauthCacheItem struct {
Username string `json:"username"`
Name string `json:"name"`
Email string `json:"email"`
Verifier string `json:"verifier"` // used for oauth pkce
Nonce string `json:"nonce"`
}
func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser {
@@ -81,10 +83,9 @@ func (os *OauthService) GetOauthCache(key string) *OauthCacheItem {
func (os *OauthService) SetOauthCache(key string, item *OauthCacheItem, expire uint) {
OauthCache.Store(key, item)
if expire > 0 {
go func() {
time.Sleep(time.Duration(expire) * time.Second)
time.AfterFunc(time.Duration(expire)*time.Second, func() {
os.DeleteOauthCache(key)
}()
})
}
}
@@ -92,151 +93,208 @@ func (os *OauthService) DeleteOauthCache(key string) {
OauthCache.Delete(key)
}
func (os *OauthService) BeginAuth(op string) (error error, code, url string) {
code = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10)
if op == string(model.OauthTypeWebauth) {
url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + code
func (os *OauthService) BeginAuth(op string) (error error, state, verifier, nonce, url string) {
state = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10)
verifier = ""
nonce = ""
if op == model.OauthTypeWebauth {
url = Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + state
//url = "http://localhost:8888/_admin/#/oauth/" + code
return nil, code, url
return nil, state, verifier, nonce, url
}
err, _, oauthConfig := os.GetOauthConfig(op)
err, oauthInfo, oauthConfig, _ := os.GetOauthConfig(op)
if err == nil {
return err, code, oauthConfig.AuthCodeURL(code)
extras := make([]oauth2.AuthCodeOption, 0, 3)
nonce = utils.RandomString(10)
extras = append(extras, oauth2.SetAuthURLParam("nonce", nonce))
if oauthInfo.PkceEnable != nil && *oauthInfo.PkceEnable {
extras = append(extras, oauth2.AccessTypeOffline)
verifier = oauth2.GenerateVerifier()
switch oauthInfo.PkceMethod {
case model.PKCEMethodS256:
extras = append(extras, oauth2.S256ChallengeOption(verifier))
case model.PKCEMethodPlain:
// oauth2 does not have a plain challenge option, so we add it manually
extras = append(extras, oauth2.SetAuthURLParam("code_challenge_method", "plain"), oauth2.SetAuthURLParam("code_challenge", verifier))
}
}
return err, state, verifier, nonce, oauthConfig.AuthCodeURL(state, extras...)
}
return err, code, ""
return err, state, verifier, nonce, ""
}
// Method to fetch OIDC configuration dynamically
func (os *OauthService) FetchOidcEndpoint(issuer string) (error, OidcEndpoint) {
configURL := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
func (os *OauthService) FetchOidcProvider(issuer string) (error, *oidc.Provider) {
// Get the HTTP client (with or without proxy based on configuration)
client := getHTTPClientWithProxy()
resp, err := client.Get(configURL)
ctx := oidc.ClientContext(context.Background(), client)
provider, err := oidc.NewProvider(ctx, issuer)
if err != nil {
return errors.New("failed to fetch OIDC configuration"), OidcEndpoint{}
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return errors.New("OIDC configuration not found, status code: %d"), OidcEndpoint{}
return err, nil
}
var endpoint OidcEndpoint
if err := json.NewDecoder(resp.Body).Decode(&endpoint); err != nil {
return errors.New("failed to parse OIDC configuration"), OidcEndpoint{}
}
return nil, endpoint
return nil, provider
}
func (os *OauthService) FetchOidcEndpointByOp(op string) (error, OidcEndpoint) {
oauthInfo := os.InfoByOp(op)
if oauthInfo.Issuer == "" {
return errors.New("issuer is empty"), OidcEndpoint{}
}
return os.FetchOidcEndpoint(oauthInfo.Issuer)
func (os *OauthService) GithubProvider() *oidc.Provider {
return (&oidc.ProviderConfig{
IssuerURL: "",
AuthURL: github.Endpoint.AuthURL,
TokenURL: github.Endpoint.TokenURL,
DeviceAuthURL: github.Endpoint.DeviceAuthURL,
UserInfoURL: model.UserEndpointGithub,
JWKSURL: "",
Algorithms: nil,
}).NewProvider(context.Background())
}
func (os *OauthService) LinuxdoProvider() *oidc.Provider {
return (&oidc.ProviderConfig{
IssuerURL: "",
AuthURL: "https://connect.linux.do/oauth2/authorize",
TokenURL: "https://connect.linux.do/oauth2/token",
DeviceAuthURL: "",
UserInfoURL: model.UserEndpointLinuxdo,
JWKSURL: "",
Algorithms: nil,
}).NewProvider(context.Background())
}
// GetOauthConfig retrieves the OAuth2 configuration based on the provider name
func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config) {
err, oauthInfo, oauthConfig = os.getOauthConfigGeneral(op)
if err != nil {
return err, nil, nil
func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config, provider *oidc.Provider) {
//err, oauthInfo, oauthConfig = os.getOauthConfigGeneral(op)
oauthInfo = os.InfoByOp(op)
if oauthInfo.Id == 0 || oauthInfo.ClientId == "" || oauthInfo.ClientSecret == "" {
return errors.New("ConfigNotFound"), nil, nil, nil
}
// If the redirect URL is empty, use the default redirect URL
if oauthInfo.RedirectUrl == "" {
oauthInfo.RedirectUrl = Config.Rustdesk.ApiServer + "/api/oidc/callback"
}
oauthConfig = &oauth2.Config{
ClientID: oauthInfo.ClientId,
ClientSecret: oauthInfo.ClientSecret,
RedirectURL: oauthInfo.RedirectUrl,
}
// Maybe should validate the oauthConfig here
oauthType := oauthInfo.OauthType
err = model.ValidateOauthType(oauthType)
if err != nil {
return err, nil, nil
return err, nil, nil, nil
}
switch oauthType {
case model.OauthTypeGithub:
oauthConfig.Endpoint = github.Endpoint
oauthConfig.Scopes = []string{"read:user", "user:email"}
provider = os.GithubProvider()
case model.OauthTypeLinuxdo:
provider = os.LinuxdoProvider()
oauthConfig.Endpoint = provider.Endpoint()
oauthConfig.Scopes = []string{"profile"}
//case model.OauthTypeGoogle: //google单独出来可以少一次FetchOidcEndpoint请求
// oauthConfig.Endpoint = google.Endpoint
// oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes)
case model.OauthTypeOidc, model.OauthTypeGoogle:
var endpoint OidcEndpoint
err, endpoint = os.FetchOidcEndpoint(oauthInfo.Issuer)
err, provider = os.FetchOidcProvider(oauthInfo.Issuer)
if err != nil {
return err, nil, nil
return err, nil, nil, nil
}
oauthConfig.Endpoint = oauth2.Endpoint{AuthURL: endpoint.AuthURL, TokenURL: endpoint.TokenURL}
oauthConfig.Endpoint = provider.Endpoint()
oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes)
default:
return errors.New("unsupported OAuth type"), nil, nil
}
return nil, oauthInfo, oauthConfig
}
// GetOauthConfig retrieves the OAuth2 configuration based on the provider name
func (os *OauthService) getOauthConfigGeneral(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config) {
oauthInfo = os.InfoByOp(op)
if oauthInfo.Id == 0 || oauthInfo.ClientId == "" || oauthInfo.ClientSecret == "" {
return errors.New("ConfigNotFound"), nil, nil
}
// If the redirect URL is empty, use the default redirect URL
if oauthInfo.RedirectUrl == "" {
oauthInfo.RedirectUrl = global.Config.Rustdesk.ApiServer + "/api/oidc/callback"
}
return nil, oauthInfo, &oauth2.Config{
ClientID: oauthInfo.ClientId,
ClientSecret: oauthInfo.ClientSecret,
RedirectURL: oauthInfo.RedirectUrl,
return errors.New("unsupported OAuth type"), nil, nil, nil
}
return nil, oauthInfo, oauthConfig, provider
}
func getHTTPClientWithProxy() *http.Client {
//todo add timeout
if global.Config.Proxy.Enable {
if global.Config.Proxy.Host == "" {
global.Logger.Warn("Proxy is enabled but proxy host is empty.")
//add timeout 30s
timeout := time.Duration(60) * time.Second
if Config.Proxy.Enable {
if Config.Proxy.Host == "" {
Logger.Warn("Proxy is enabled but proxy host is empty.")
return http.DefaultClient
}
proxyURL, err := url.Parse(global.Config.Proxy.Host)
proxyURL, err := url.Parse(Config.Proxy.Host)
if err != nil {
global.Logger.Warn("Invalid proxy URL: ", err)
Logger.Warn("Invalid proxy URL: ", err)
return http.DefaultClient
}
transport := &http.Transport{
Proxy: http.ProxyURL(proxyURL),
}
return &http.Client{Transport: transport}
return &http.Client{Transport: transport, Timeout: timeout}
}
return http.DefaultClient
}
func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, userEndpoint string, userData interface{}) (err error, client *http.Client) {
func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, provider *oidc.Provider, code string, verifier string, nonce string, userData interface{}) (err error, client *http.Client) {
// 设置代理客户端
httpClient := getHTTPClientWithProxy()
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
// 使用 code 换取 token
var token *oauth2.Token
token, err = oauthConfig.Exchange(ctx, code)
exchangeOpts := make([]oauth2.AuthCodeOption, 0, 1)
if verifier != "" {
exchangeOpts = append(exchangeOpts, oauth2.VerifierOption(verifier))
}
token, err := oauthConfig.Exchange(ctx, code, exchangeOpts...)
if err != nil {
global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
Logger.Warn("oauthConfig.Exchange() failed: ", err)
return errors.New("GetOauthTokenError"), nil
}
// 获取 ID Token github没有id_token
rawIDToken, ok := token.Extra("id_token").(string)
if ok && rawIDToken != "" {
// 验证 ID Token
v := provider.Verifier(&oidc.Config{ClientID: oauthConfig.ClientID})
idToken, err2 := v.Verify(ctx, rawIDToken)
if err2 != nil {
Logger.Warn("IdTokenVerifyError: ", err2)
return errors.New("IdTokenVerifyError"), nil
}
if nonce != "" {
// 验证 nonce
var claims struct {
Nonce string `json:"nonce"`
}
if err2 = idToken.Claims(&claims); err2 != nil {
Logger.Warn("Failed to parse ID Token claims: ", err)
return errors.New("IDTokenClaimsError"), nil
}
if claims.Nonce != nonce {
Logger.Warn("Nonce does not match")
return errors.New("NonceDoesNotMatch"), nil
}
}
}
// 获取用户信息
client = oauthConfig.Client(ctx, token)
resp, err := client.Get(userEndpoint)
resp, err := client.Get(provider.UserInfoEndpoint())
if err != nil {
global.Logger.Warn("failed getting user info: ", err)
Logger.Warn("failed getting user info: ", err)
return errors.New("GetOauthUserInfoError"), nil
}
defer func() {
if closeErr := resp.Body.Close(); closeErr != nil {
global.Logger.Warn("failed closing response body: ", closeErr)
Logger.Warn("failed closing response body: ", closeErr)
}
}()
// 解析用户信息
if err = json.NewDecoder(resp.Body).Decode(userData); err != nil {
global.Logger.Warn("failed decoding user info: ", err)
Logger.Warn("failed decoding user info: ", err)
return errors.New("DecodeOauthUserInfoError"), nil
}
@@ -244,9 +302,9 @@ func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, us
}
// githubCallback github回调
func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string) (error, *model.OauthUser) {
func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code, verifier, nonce string) (error, *model.OauthUser) {
var user = &model.GithubUser{}
err, client := os.callbackBase(oauthConfig, code, model.UserEndpointGithub, user)
err, client := os.callbackBase(oauthConfig, provider, code, verifier, nonce, user)
if err != nil {
return err, nil
}
@@ -257,20 +315,28 @@ func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string)
return nil, user.ToOauthUser()
}
// linuxdoCallback linux.do回调
func (os *OauthService) linuxdoCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code, verifier, nonce string) (error, *model.OauthUser) {
var user = &model.LinuxdoUser{}
err, _ := os.callbackBase(oauthConfig, provider, code, verifier, nonce, user)
if err != nil {
return err, nil
}
return nil, user.ToOauthUser()
}
// oidcCallback oidc回调, 通过code获取用户信息
func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser) {
func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code, verifier, nonce string) (error, *model.OauthUser) {
var user = &model.OidcUser{}
if err, _ := os.callbackBase(oauthConfig, code, userInfoEndpoint, user); err != nil {
if err, _ := os.callbackBase(oauthConfig, provider, code, verifier, nonce, user); err != nil {
return err, nil
}
return nil, user.ToOauthUser()
}
// Callback: Get user information by code and op(Oauth provider)
func (os *OauthService) Callback(code string, op string) (err error, oauthUser *model.OauthUser) {
var oauthInfo *model.Oauth
var oauthConfig *oauth2.Config
err, oauthInfo, oauthConfig = os.GetOauthConfig(op)
func (os *OauthService) Callback(code, verifier, op, nonce string) (err error, oauthUser *model.OauthUser) {
err, oauthInfo, oauthConfig, provider := os.GetOauthConfig(op)
// oauthType is already validated in GetOauthConfig
if err != nil {
return err, nil
@@ -278,13 +344,11 @@ func (os *OauthService) Callback(code string, op string) (err error, oauthUser *
oauthType := oauthInfo.OauthType
switch oauthType {
case model.OauthTypeGithub:
err, oauthUser = os.githubCallback(oauthConfig, code)
err, oauthUser = os.githubCallback(oauthConfig, provider, code, verifier, nonce)
case model.OauthTypeLinuxdo:
err, oauthUser = os.linuxdoCallback(oauthConfig, provider, code, verifier, nonce)
case model.OauthTypeOidc, model.OauthTypeGoogle:
err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer)
if err != nil {
return err, nil
}
err, oauthUser = os.oidcCallback(oauthConfig, code, endpoint.UserInfo)
err, oauthUser = os.oidcCallback(oauthConfig, provider, code, verifier, nonce)
default:
return errors.New("unsupported OAuth type"), nil
}
@@ -293,7 +357,7 @@ func (os *OauthService) Callback(code string, op string) (err error, oauthUser *
func (os *OauthService) UserThirdInfo(op string, openId string) *model.UserThird {
ut := &model.UserThird{}
global.DB.Where("open_id = ? and op = ?", openId, op).First(ut)
DB.Where("open_id = ? and op = ?", openId, op).First(ut)
return ut
}
@@ -305,7 +369,7 @@ func (os *OauthService) BindOauthUser(userId uint, oauthUser *model.OauthUser, o
return err
}
utr.FromOauthUser(userId, oauthUser, oauthType, op)
return global.DB.Create(utr).Error
return DB.Create(utr).Error
}
// UnBindOauthUser: Unbind third party account
@@ -315,25 +379,25 @@ func (os *OauthService) UnBindOauthUser(userId uint, op string) error {
// UnBindThird: Unbind third party account
func (os *OauthService) UnBindThird(op string, userId uint) error {
return global.DB.Where("user_id = ? and op = ?", userId, op).Delete(&model.UserThird{}).Error
return DB.Where("user_id = ? and op = ?", userId, op).Delete(&model.UserThird{}).Error
}
// DeleteUserByUserId: When user is deleted, delete all third party bindings
func (os *OauthService) DeleteUserByUserId(userId uint) error {
return global.DB.Where("user_id = ?", userId).Delete(&model.UserThird{}).Error
return DB.Where("user_id = ?", userId).Delete(&model.UserThird{}).Error
}
// InfoById 根据id获取Oauth信息
func (os *OauthService) InfoById(id uint) *model.Oauth {
oauthInfo := &model.Oauth{}
global.DB.Where("id = ?", id).First(oauthInfo)
DB.Where("id = ?", id).First(oauthInfo)
return oauthInfo
}
// InfoByOp 根据op获取Oauth信息
func (os *OauthService) InfoByOp(op string) *model.Oauth {
oauthInfo := &model.Oauth{}
global.DB.Where("op = ?", op).First(oauthInfo)
DB.Where("op = ?", op).First(oauthInfo)
return oauthInfo
}
@@ -356,7 +420,7 @@ func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res
res = &model.OauthList{}
res.Page = int64(page)
res.PageSize = int64(pageSize)
tx := global.DB.Model(&model.Oauth{})
tx := DB.Model(&model.Oauth{})
if where != nil {
where(tx)
}
@@ -369,7 +433,7 @@ func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res
// GetTypeByOp 根据op获取OauthType
func (os *OauthService) GetTypeByOp(op string) (error, string) {
oauthInfo := &model.Oauth{}
if global.DB.Where("op = ?", op).First(oauthInfo).Error != nil {
if DB.Where("op = ?", op).First(oauthInfo).Error != nil {
return fmt.Errorf("OAuth provider with op '%s' not found", op), ""
}
return nil, oauthInfo.OauthType
@@ -387,7 +451,7 @@ func (os *OauthService) ValidateOauthProvider(op string) error {
func (os *OauthService) IsOauthProviderExist(op string) bool {
oauthInfo := &model.Oauth{}
// 使用 Gorm 的 Take 方法查找符合条件的记录
if err := global.DB.Where("op = ?", op).Take(oauthInfo).Error; err != nil {
if err := DB.Where("op = ?", op).Take(oauthInfo).Error; err != nil {
return false
}
return true
@@ -399,11 +463,11 @@ func (os *OauthService) Create(oauthInfo *model.Oauth) error {
if err != nil {
return err
}
res := global.DB.Create(oauthInfo).Error
res := DB.Create(oauthInfo).Error
return res
}
func (os *OauthService) Delete(oauthInfo *model.Oauth) error {
return global.DB.Delete(oauthInfo).Error
return DB.Delete(oauthInfo).Error
}
// Update 更新
@@ -412,13 +476,13 @@ func (os *OauthService) Update(oauthInfo *model.Oauth) error {
if err != nil {
return err
}
return global.DB.Model(oauthInfo).Updates(oauthInfo).Error
return DB.Model(oauthInfo).Updates(oauthInfo).Error
}
// GetOauthProviders 获取所有的provider
func (os *OauthService) GetOauthProviders() []string {
var res []string
global.DB.Model(&model.Oauth{}).Pluck("op", &res)
DB.Model(&model.Oauth{}).Pluck("op", &res)
return res
}

View File

@@ -1,7 +1,6 @@
package service
import (
"github.com/lejianwen/rustdesk-api/v2/global"
"github.com/lejianwen/rustdesk-api/v2/model"
"gorm.io/gorm"
)
@@ -12,24 +11,24 @@ type PeerService struct {
// FindById 根据id查找
func (ps *PeerService) FindById(id string) *model.Peer {
p := &model.Peer{}
global.DB.Where("id = ?", id).First(p)
DB.Where("id = ?", id).First(p)
return p
}
func (ps *PeerService) FindByUuid(uuid string) *model.Peer {
p := &model.Peer{}
global.DB.Where("uuid = ?", uuid).First(p)
DB.Where("uuid = ?", uuid).First(p)
return p
}
func (ps *PeerService) InfoByRowId(id uint) *model.Peer {
p := &model.Peer{}
global.DB.Where("row_id = ?", id).First(p)
DB.Where("row_id = ?", id).First(p)
return p
}
// FindByUserIdAndUuid 根据用户id和uuid查找peer
func (ps *PeerService) FindByUserIdAndUuid(uuid string, userId uint) *model.Peer {
p := &model.Peer{}
global.DB.Where("uuid = ? and user_id = ?", uuid, userId).First(p)
DB.Where("uuid = ? and user_id = ?", uuid, userId).First(p)
return p
}
@@ -43,7 +42,7 @@ func (ps *PeerService) UuidBindUserId(deviceId string, uuid string, userId uint)
} else {
// 不存在则创建
/*if deviceId != "" {
global.DB.Create(&model.Peer{
DB.Create(&model.Peer{
Id: deviceId,
Uuid: uuid,
UserId: userId,
@@ -56,13 +55,13 @@ func (ps *PeerService) UuidBindUserId(deviceId string, uuid string, userId uint)
func (ps *PeerService) UuidUnbindUserId(uuid string, userId uint) {
peer := ps.FindByUserIdAndUuid(uuid, userId)
if peer.RowId > 0 {
global.DB.Model(peer).Update("user_id", 0)
DB.Model(peer).Update("user_id", 0)
}
}
// EraseUserId 清除用户id, 用于用户删除
func (ps *PeerService) EraseUserId(userId uint) error {
return global.DB.Model(&model.Peer{}).Where("user_id = ?", userId).Update("user_id", 0).Error
return DB.Model(&model.Peer{}).Where("user_id = ?", userId).Update("user_id", 0).Error
}
// ListByUserIds 根据用户id取列表
@@ -70,7 +69,7 @@ func (ps *PeerService) ListByUserIds(userIds []uint, page, pageSize uint) (res *
res = &model.PeerList{}
res.Page = int64(page)
res.PageSize = int64(pageSize)
tx := global.DB.Model(&model.Peer{})
tx := DB.Model(&model.Peer{})
tx.Where("user_id in (?)", userIds)
tx.Count(&res.Total)
tx.Scopes(Paginate(page, pageSize))
@@ -82,7 +81,7 @@ func (ps *PeerService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *
res = &model.PeerList{}
res.Page = int64(page)
res.PageSize = int64(pageSize)
tx := global.DB.Model(&model.Peer{})
tx := DB.Model(&model.Peer{})
if where != nil {
where(tx)
}
@@ -106,14 +105,14 @@ func (ps *PeerService) ListFilterByUserId(page, pageSize uint, where func(tx *go
// Create 创建
func (ps *PeerService) Create(u *model.Peer) error {
res := global.DB.Create(u).Error
res := DB.Create(u).Error
return res
}
// Delete 删除, 同时也应该删除token
func (ps *PeerService) Delete(u *model.Peer) error {
uuid := u.Uuid
err := global.DB.Delete(u).Error
err := DB.Delete(u).Error
if err != nil {
return err
}
@@ -124,16 +123,23 @@ func (ps *PeerService) Delete(u *model.Peer) error {
// GetUuidListByIDs 根据ids获取uuid列表
func (ps *PeerService) GetUuidListByIDs(ids []uint) ([]string, error) {
var uuids []string
err := global.DB.Model(&model.Peer{}).
err := DB.Model(&model.Peer{}).
Where("row_id in (?)", ids).
Pluck("uuid", &uuids).Error
return uuids, err
//过滤uuids中的空字符串
var newUuids []string
for _, uuid := range uuids {
if uuid != "" {
newUuids = append(newUuids, uuid)
}
}
return newUuids, err
}
// BatchDelete 批量删除, 同时也应该删除token
func (ps *PeerService) BatchDelete(ids []uint) error {
uuids, err := ps.GetUuidListByIDs(ids)
err = global.DB.Where("row_id in (?)", ids).Delete(&model.Peer{}).Error
err = DB.Where("row_id in (?)", ids).Delete(&model.Peer{}).Error
if err != nil {
return err
}
@@ -143,5 +149,5 @@ func (ps *PeerService) BatchDelete(ids []uint) error {
// Update 更新
func (ps *PeerService) Update(u *model.Peer) error {
return global.DB.Model(u).Updates(u).Error
return DB.Model(u).Updates(u).Error
}

View File

@@ -2,7 +2,6 @@ package service
import (
"fmt"
"github.com/lejianwen/rustdesk-api/v2/global"
"github.com/lejianwen/rustdesk-api/v2/model"
"net"
"time"
@@ -15,7 +14,7 @@ func (is *ServerCmdService) List(page, pageSize uint) (res *model.ServerCmdList)
res = &model.ServerCmdList{}
res.Page = int64(page)
res.PageSize = int64(pageSize)
tx := global.DB.Model(&model.ServerCmd{})
tx := DB.Model(&model.ServerCmd{})
tx.Count(&res.Total)
tx.Scopes(Paginate(page, pageSize))
tx.Find(&res.ServerCmds)
@@ -25,30 +24,23 @@ func (is *ServerCmdService) List(page, pageSize uint) (res *model.ServerCmdList)
// Info
func (is *ServerCmdService) Info(id uint) *model.ServerCmd {
u := &model.ServerCmd{}
global.DB.Where("id = ?", id).First(u)
DB.Where("id = ?", id).First(u)
return u
}
// Delete
func (is *ServerCmdService) Delete(u *model.ServerCmd) error {
return global.DB.Delete(u).Error
return DB.Delete(u).Error
}
// Create
func (is *ServerCmdService) Create(u *model.ServerCmd) error {
res := global.DB.Create(u).Error
res := DB.Create(u).Error
return res
}
// SendCmd 发送命令
func (is *ServerCmdService) SendCmd(target string, cmd string, arg string) (string, error) {
port := 0
switch target {
case model.ServerCmdTargetIdServer:
port = global.Config.Rustdesk.IdServerPort - 1
case model.ServerCmdTargetRelayServer:
port = global.Config.Rustdesk.RelayServerPort
}
func (is *ServerCmdService) SendCmd(port int, cmd string, arg string) (string, error) {
//组装命令
cmd = cmd + " " + arg
res, err := is.SendSocketCmd("v6", port, cmd)
@@ -73,14 +65,14 @@ func (is *ServerCmdService) SendSocketCmd(ty string, port int, cmd string) (stri
}
conn, err := net.Dial(tcp, fmt.Sprintf("%s:%v", addr, port))
if err != nil {
global.Logger.Debugf("%s connect to id server failed: %v", ty, err)
Logger.Debugf("%s connect to id server failed: %v", ty, err)
return "", err
}
defer conn.Close()
//发送命令
_, err = conn.Write([]byte(cmd))
if err != nil {
global.Logger.Debugf("%s send cmd failed: %v", ty, err)
Logger.Debugf("%s send cmd failed: %v", ty, err)
return "", err
}
time.Sleep(100 * time.Millisecond)
@@ -88,12 +80,12 @@ func (is *ServerCmdService) SendSocketCmd(ty string, port int, cmd string) (stri
buf := make([]byte, 1024)
n, err := conn.Read(buf)
if err != nil && err.Error() != "EOF" {
global.Logger.Debugf("%s read response failed: %v", ty, err)
Logger.Debugf("%s read response failed: %v", ty, err)
return "", err
}
return string(buf[:n]), nil
}
func (is *ServerCmdService) Update(f *model.ServerCmd) error {
return global.DB.Model(f).Updates(f).Error
return DB.Model(f).Updates(f).Error
}

View File

@@ -1,7 +1,11 @@
package service
import (
"github.com/lejianwen/rustdesk-api/v2/config"
"github.com/lejianwen/rustdesk-api/v2/lib/jwt"
"github.com/lejianwen/rustdesk-api/v2/lib/lock"
"github.com/lejianwen/rustdesk-api/v2/model"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
)
@@ -19,14 +23,34 @@ type Service struct {
*ShareRecordService
*ServerCmdService
*LdapService
*AppService
}
func New() *Service {
all := new(Service)
return all
type Dependencies struct {
Config *config.Config
DB *gorm.DB
Logger *log.Logger
Jwt *jwt.Jwt
Lock *lock.Locker
}
var AllService = New()
var Config *config.Config
var DB *gorm.DB
var Logger *log.Logger
var Jwt *jwt.Jwt
var Lock lock.Locker
var AllService *Service
func New(c *config.Config, g *gorm.DB, l *log.Logger, j *jwt.Jwt, lo lock.Locker) *Service {
Config = c
DB = g
Logger = l
Jwt = j
Lock = lo
AllService = new(Service)
return AllService
}
func Paginate(page, pageSize uint) func(db *gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB {

View File

@@ -1,7 +1,6 @@
package service
import (
"github.com/lejianwen/rustdesk-api/v2/global"
"github.com/lejianwen/rustdesk-api/v2/model"
"gorm.io/gorm"
)
@@ -12,7 +11,7 @@ type ShareRecordService struct {
// InfoById 根据用户id取用户信息
func (srs *ShareRecordService) InfoById(id uint) *model.ShareRecord {
u := &model.ShareRecord{}
global.DB.Where("id = ?", id).First(u)
DB.Where("id = ?", id).First(u)
return u
}
@@ -20,7 +19,7 @@ func (srs *ShareRecordService) List(page, pageSize uint, where func(tx *gorm.DB)
res = &model.ShareRecordList{}
res.Page = int64(page)
res.PageSize = int64(pageSize)
tx := global.DB.Model(&model.ShareRecord{})
tx := DB.Model(&model.ShareRecord{})
if where != nil {
where(tx)
}
@@ -32,18 +31,18 @@ func (srs *ShareRecordService) List(page, pageSize uint, where func(tx *gorm.DB)
// Create 创建
func (srs *ShareRecordService) Create(u *model.ShareRecord) error {
res := global.DB.Create(u).Error
res := DB.Create(u).Error
return res
}
func (srs *ShareRecordService) Delete(u *model.ShareRecord) error {
return global.DB.Delete(u).Error
return DB.Delete(u).Error
}
// Update 更新
func (srs *ShareRecordService) Update(u *model.ShareRecord) error {
return global.DB.Model(u).Updates(u).Error
return DB.Model(u).Updates(u).Error
}
func (srs *ShareRecordService) BatchDelete(ids []uint) error {
return global.DB.Where("id in (?)", ids).Delete(&model.ShareRecord{}).Error
return DB.Where("id in (?)", ids).Delete(&model.ShareRecord{}).Error
}

View File

@@ -1,7 +1,6 @@
package service
import (
"github.com/lejianwen/rustdesk-api/v2/global"
"github.com/lejianwen/rustdesk-api/v2/model"
"gorm.io/gorm"
)
@@ -11,12 +10,12 @@ type TagService struct {
func (s *TagService) Info(id uint) *model.Tag {
p := &model.Tag{}
global.DB.Where("id = ?", id).First(p)
DB.Where("id = ?", id).First(p)
return p
}
func (s *TagService) InfoByUserIdAndNameAndCollectionId(userid uint, name string, cid uint) *model.Tag {
p := &model.Tag{}
global.DB.Where("user_id = ? and name = ? and collection_id = ?", userid, name, cid).First(p)
DB.Where("user_id = ? and name = ? and collection_id = ?", userid, name, cid).First(p)
return p
}
@@ -34,7 +33,7 @@ func (s *TagService) ListByUserIdAndCollectionId(userId, cid uint) (res *model.T
return
}
func (s *TagService) UpdateTags(userId uint, tags map[string]uint) {
tx := global.DB.Begin()
tx := DB.Begin()
//先查询所有tag
var allTags []*model.Tag
tx.Where("user_id = ?", userId).Find(&allTags)
@@ -66,7 +65,7 @@ func (s *TagService) UpdateTags(userId uint, tags map[string]uint) {
// InfoById 根据用户id取用户信息
func (s *TagService) InfoById(id uint) *model.Tag {
u := &model.Tag{}
global.DB.Where("id = ?", id).First(u)
DB.Where("id = ?", id).First(u)
return u
}
@@ -74,7 +73,7 @@ func (s *TagService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *mo
res = &model.TagList{}
res.Page = int64(page)
res.PageSize = int64(pageSize)
tx := global.DB.Model(&model.Tag{})
tx := DB.Model(&model.Tag{})
if where != nil {
where(tx)
}
@@ -86,14 +85,14 @@ func (s *TagService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *mo
// Create 创建
func (s *TagService) Create(u *model.Tag) error {
res := global.DB.Create(u).Error
res := DB.Create(u).Error
return res
}
func (s *TagService) Delete(u *model.Tag) error {
return global.DB.Delete(u).Error
return DB.Delete(u).Error
}
// Update 更新
func (s *TagService) Update(u *model.Tag) error {
return global.DB.Model(u).Select("*").Omit("created_at").Updates(u).Error
return DB.Model(u).Select("*").Omit("created_at").Updates(u).Error
}

View File

@@ -2,7 +2,6 @@ package service
import (
"errors"
"github.com/lejianwen/rustdesk-api/v2/global"
"github.com/lejianwen/rustdesk-api/v2/model"
"github.com/lejianwen/rustdesk-api/v2/utils"
"math/rand"
@@ -20,43 +19,43 @@ type UserService struct {
// InfoById 根据用户id取用户信息
func (us *UserService) InfoById(id uint) *model.User {
u := &model.User{}
global.DB.Where("id = ?", id).First(u)
DB.Where("id = ?", id).First(u)
return u
}
// InfoByUsername 根据用户名取用户信息
func (us *UserService) InfoByUsername(un string) *model.User {
u := &model.User{}
global.DB.Where("username = ?", un).First(u)
DB.Where("username = ?", un).First(u)
return u
}
// InfoByEmail 根据邮箱取用户信息
func (us *UserService) InfoByEmail(email string) *model.User {
u := &model.User{}
global.DB.Where("email = ?", email).First(u)
DB.Where("email = ?", email).First(u)
return u
}
// InfoByOpenid 根据openid取用户信息
func (us *UserService) InfoByOpenid(openid string) *model.User {
u := &model.User{}
global.DB.Where("openid = ?", openid).First(u)
DB.Where("openid = ?", openid).First(u)
return u
}
// InfoByUsernamePassword 根据用户名密码取用户信息
func (us *UserService) InfoByUsernamePassword(username, password string) *model.User {
if global.Config.Ldap.Enable {
if Config.Ldap.Enable {
u, err := AllService.LdapService.Authenticate(username, password)
if err == nil {
return u
}
global.Logger.Errorf("LDAP authentication failed, %v", err)
global.Logger.Warn("Fallback to local database")
Logger.Errorf("LDAP authentication failed, %v", err)
Logger.Warn("Fallback to local database")
}
u := &model.User{}
global.DB.Where("username = ? and password = ?", username, us.EncryptPassword(password)).First(u)
DB.Where("username = ? and password = ?", username, us.EncryptPassword(password)).First(u)
return u
}
@@ -64,21 +63,21 @@ func (us *UserService) InfoByUsernamePassword(username, password string) *model.
func (us *UserService) InfoByAccessToken(token string) (*model.User, *model.UserToken) {
u := &model.User{}
ut := &model.UserToken{}
global.DB.Where("token = ?", token).First(ut)
DB.Where("token = ?", token).First(ut)
if ut.Id == 0 {
return u, ut
}
if ut.ExpiredAt < time.Now().Unix() {
return u, ut
}
global.DB.Where("id = ?", ut.UserId).First(u)
DB.Where("id = ?", ut.UserId).First(u)
return u, ut
}
// GenerateToken 生成token
func (us *UserService) GenerateToken(u *model.User) string {
if len(global.Jwt.Key) > 0 {
return global.Jwt.GenerateToken(u.Id)
if len(Jwt.Key) > 0 {
return Jwt.GenerateToken(u.Id)
}
return utils.Md5(u.Username + time.Now().String())
}
@@ -93,9 +92,9 @@ func (us *UserService) Login(u *model.User, llog *model.LoginLog) *model.UserTok
DeviceId: llog.DeviceId,
ExpiredAt: us.UserTokenExpireTimestamp(),
}
global.DB.Create(ut)
DB.Create(ut)
llog.UserTokenId = ut.UserId
global.DB.Create(llog)
DB.Create(llog)
if llog.Uuid != "" {
AllService.PeerService.UuidBindUserId(llog.DeviceId, llog.Uuid, u.Id)
}
@@ -116,7 +115,7 @@ func (us *UserService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *
res = &model.UserList{}
res.Page = int64(page)
res.PageSize = int64(pageSize)
tx := global.DB.Model(&model.User{})
tx := DB.Model(&model.User{})
if where != nil {
where(tx)
}
@@ -127,7 +126,7 @@ func (us *UserService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *
}
func (us *UserService) ListByIds(ids []uint) (res []*model.User) {
global.DB.Where("id in ?", ids).Find(&res)
DB.Where("id in ?", ids).Find(&res)
return res
}
@@ -141,14 +140,14 @@ func (us *UserService) ListByGroupId(groupId, page, pageSize uint) (res *model.U
// ListIdsByGroupId 根据组id取用户id列表
func (us *UserService) ListIdsByGroupId(groupId uint) (ids []uint) {
global.DB.Model(&model.User{}).Where("group_id = ?", groupId).Pluck("id", &ids)
DB.Model(&model.User{}).Where("group_id = ?", groupId).Pluck("id", &ids)
return ids
}
// ListIdAndNameByGroupId 根据组id取用户id和用户名列表
func (us *UserService) ListIdAndNameByGroupId(groupId uint) (res []*model.User) {
global.DB.Model(&model.User{}).Where("group_id = ?", groupId).Select("id, username").Find(&res)
DB.Model(&model.User{}).Where("group_id = ?", groupId).Select("id, username").Find(&res)
return res
}
@@ -170,14 +169,14 @@ func (us *UserService) Create(u *model.User) error {
}
u.Username = us.formatUsername(u.Username)
u.Password = us.EncryptPassword(u.Password)
res := global.DB.Create(u).Error
res := DB.Create(u).Error
return res
}
// GetUuidByToken 根据token和user取uuid
func (us *UserService) GetUuidByToken(u *model.User, token string) string {
ut := &model.UserToken{}
err := global.DB.Where("user_id = ? and token = ?", u.Id, token).First(ut).Error
err := DB.Where("user_id = ? and token = ?", u.Id, token).First(ut).Error
if err != nil {
return ""
}
@@ -187,7 +186,7 @@ func (us *UserService) GetUuidByToken(u *model.User, token string) string {
// Logout 退出登录 -> 删除token, 解绑uuid
func (us *UserService) Logout(u *model.User, token string) error {
uuid := us.GetUuidByToken(u, token)
err := global.DB.Where("user_id = ? and token = ?", u.Id, token).Delete(&model.UserToken{}).Error
err := DB.Where("user_id = ? and token = ?", u.Id, token).Delete(&model.UserToken{}).Error
if err != nil {
return err
}
@@ -203,7 +202,7 @@ func (us *UserService) Delete(u *model.User) error {
if userCount <= 1 && us.IsAdmin(u) {
return errors.New("The last admin user cannot be deleted")
}
tx := global.DB.Begin()
tx := DB.Begin()
// 删除用户
if err := tx.Delete(u).Error; err != nil {
tx.Rollback()
@@ -232,7 +231,7 @@ func (us *UserService) Delete(u *model.User) error {
tx.Commit()
// 删除关联的peer
if err := AllService.PeerService.EraseUserId(u.Id); err != nil {
global.Logger.Warn("User deleted successfully, but failed to unlink peer.")
Logger.Warn("User deleted successfully, but failed to unlink peer.")
return nil
}
return nil
@@ -249,28 +248,28 @@ func (us *UserService) Update(u *model.User) error {
return errors.New("The last admin user cannot be disabled or demoted")
}
}
return global.DB.Model(u).Updates(u).Error
return DB.Model(u).Updates(u).Error
}
// FlushToken 清空token
func (us *UserService) FlushToken(u *model.User) error {
return global.DB.Where("user_id = ?", u.Id).Delete(&model.UserToken{}).Error
return DB.Where("user_id = ?", u.Id).Delete(&model.UserToken{}).Error
}
// FlushTokenByUuid 清空token
func (us *UserService) FlushTokenByUuid(uuid string) error {
return global.DB.Where("device_uuid = ?", uuid).Delete(&model.UserToken{}).Error
return DB.Where("device_uuid = ?", uuid).Delete(&model.UserToken{}).Error
}
// FlushTokenByUuids 清空token
func (us *UserService) FlushTokenByUuids(uuids []string) error {
return global.DB.Where("device_uuid in (?)", uuids).Delete(&model.UserToken{}).Error
return DB.Where("device_uuid in (?)", uuids).Delete(&model.UserToken{}).Error
}
// UpdatePassword 更新密码
func (us *UserService) UpdatePassword(u *model.User, password string) error {
u.Password = us.EncryptPassword(password)
err := global.DB.Model(u).Update("password", u.Password).Error
err := DB.Model(u).Update("password", u.Password).Error
if err != nil {
return err
}
@@ -306,8 +305,8 @@ func (us *UserService) InfoByOauthId(op string, openId string) *model.User {
// RegisterByOauth 注册
func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser, op string) (error, *model.User) {
global.Lock.Lock("registerByOauth")
defer global.Lock.UnLock("registerByOauth")
Lock.Lock("registerByOauth")
defer Lock.UnLock("registerByOauth")
ut := AllService.OauthService.UserThirdInfo(op, oauthUser.OpenId)
if ut.Id != 0 {
return nil, us.InfoById(ut.UserId)
@@ -335,12 +334,12 @@ func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser, op string) (e
}
if user.Id != 0 {
ut.FromOauthUser(user.Id, oauthUser, oauthType, op)
global.DB.Create(ut)
DB.Create(ut)
return nil, user
}
}
tx := global.DB.Begin()
tx := DB.Begin()
ut = &model.UserThird{}
ut.FromOauthUser(0, oauthUser, oauthType, op)
// The initial username should be formatted
@@ -372,27 +371,27 @@ func (us *UserService) GenerateUsernameByOauth(name string) string {
// UserThirdsByUserId
func (us *UserService) UserThirdsByUserId(userId uint) (res []*model.UserThird) {
global.DB.Where("user_id = ?", userId).Find(&res)
DB.Where("user_id = ?", userId).Find(&res)
return res
}
func (us *UserService) UserThirdInfo(userId uint, op string) *model.UserThird {
ut := &model.UserThird{}
global.DB.Where("user_id = ? and op = ?", userId, op).First(ut)
DB.Where("user_id = ? and op = ?", userId, op).First(ut)
return ut
}
// FindLatestUserIdFromLoginLogByUuid 根据uuid查找最后登录的用户id
func (us *UserService) FindLatestUserIdFromLoginLogByUuid(uuid string) uint {
llog := &model.LoginLog{}
global.DB.Where("uuid = ?", uuid).Order("id desc").First(llog)
DB.Where("uuid = ?", uuid).Order("id desc").First(llog)
return llog.UserId
}
// IsPasswordEmptyById 根据用户id判断密码是否为空主要用于第三方登录的自动注册
func (us *UserService) IsPasswordEmptyById(id uint) bool {
u := &model.User{}
if global.DB.Where("id = ?", id).First(u).Error != nil {
if DB.Where("id = ?", id).First(u).Error != nil {
return false
}
return u.Password == ""
@@ -401,7 +400,7 @@ func (us *UserService) IsPasswordEmptyById(id uint) bool {
// IsPasswordEmptyByUsername 根据用户id判断密码是否为空主要用于第三方登录的自动注册
func (us *UserService) IsPasswordEmptyByUsername(username string) bool {
u := &model.User{}
if global.DB.Where("username = ?", username).First(u).Error != nil {
if DB.Where("username = ?", username).First(u).Error != nil {
return false
}
return u.Password == ""
@@ -413,12 +412,13 @@ func (us *UserService) IsPasswordEmptyByUser(u *model.User) bool {
}
// Register 注册, 如果用户名已存在则返回nil
func (us *UserService) Register(username string, email string, password string) *model.User {
func (us *UserService) Register(username string, email string, password string, status model.StatusCode) *model.User {
u := &model.User{
Username: username,
Email: email,
Password: password,
GroupId: 1,
Status: status,
}
err := us.Create(u)
if err != nil {
@@ -431,7 +431,7 @@ func (us *UserService) TokenList(page uint, size uint, f func(tx *gorm.DB)) *mod
res := &model.UserTokenList{}
res.Page = int64(page)
res.PageSize = int64(size)
tx := global.DB.Model(&model.UserToken{})
tx := DB.Model(&model.UserToken{})
if f != nil {
f(tx)
}
@@ -443,12 +443,12 @@ func (us *UserService) TokenList(page uint, size uint, f func(tx *gorm.DB)) *mod
func (us *UserService) TokenInfoById(id uint) *model.UserToken {
ut := &model.UserToken{}
global.DB.Where("id = ?", id).First(ut)
DB.Where("id = ?", id).First(ut)
return ut
}
func (us *UserService) DeleteToken(l *model.UserToken) error {
return global.DB.Delete(l).Error
return DB.Delete(l).Error
}
// Helper functions, used for formatting username
@@ -461,29 +461,30 @@ func (us *UserService) formatUsername(username string) string {
// Helper functions, getUserCount
func (us *UserService) getUserCount() int64 {
var count int64
global.DB.Model(&model.User{}).Count(&count)
DB.Model(&model.User{}).Count(&count)
return count
}
// helper functions, getAdminUserCount
func (us *UserService) getAdminUserCount() int64 {
var count int64
global.DB.Model(&model.User{}).Where("is_admin = ?", true).Count(&count)
DB.Model(&model.User{}).Where("is_admin = ?", true).Count(&count)
return count
}
// UserTokenExpireTimestamp 生成用户token过期时间
func (us *UserService) UserTokenExpireTimestamp() int64 {
exp := global.Config.App.TokenExpire
exp := Config.App.TokenExpire
if exp == 0 {
exp = 3600 * 24 * 7
//默认七天
exp = 604800
}
return time.Now().Add(time.Second * time.Duration(exp)).Unix()
return time.Now().Add(exp).Unix()
}
func (us *UserService) RefreshAccessToken(ut *model.UserToken) {
ut.ExpiredAt = us.UserTokenExpireTimestamp()
global.DB.Model(ut).Update("expired_at", ut.ExpiredAt)
DB.Model(ut).Update("expired_at", ut.ExpiredAt)
}
func (us *UserService) AutoRefreshAccessToken(ut *model.UserToken) {
if ut.ExpiredAt-time.Now().Unix() < 86400 {
@@ -492,11 +493,11 @@ func (us *UserService) AutoRefreshAccessToken(ut *model.UserToken) {
}
func (us *UserService) BatchDeleteUserToken(ids []uint) error {
return global.DB.Where("id in ?", ids).Delete(&model.UserToken{}).Error
return DB.Where("id in ?", ids).Delete(&model.UserToken{}).Error
}
func (us *UserService) VerifyJWT(token string) (uint, error) {
return global.Jwt.ParseToken(token)
return Jwt.ParseToken(token)
}
// IsUsernameExists 判断用户名是否存在, it will check the internal database and LDAP(if enabled)
@@ -506,7 +507,7 @@ func (us *UserService) IsUsernameExists(username string) bool {
func (us *UserService) IsUsernameExistsLocal(username string) bool {
u := &model.User{}
global.DB.Where("username = ?", username).First(u)
DB.Where("username = ?", username).First(u)
return u.Id != 0
}

48
utils/captcha.go Normal file
View File

@@ -0,0 +1,48 @@
package utils
import (
"github.com/mojocn/base64Captcha"
"time"
)
var capdString = base64Captcha.NewDriverString(50, 150, 0, 5, 4, "123456789abcdefghijklmnopqrstuvwxyz", nil, nil, nil)
var capdMath = base64Captcha.NewDriverMath(50, 150, 3, 10, nil, nil, nil)
type B64StringCaptchaProvider struct{}
func (p B64StringCaptchaProvider) Generate() (string, string, string, error) {
id, content, answer := capdString.GenerateIdQuestionAnswer()
return id, content, answer, nil
}
func (p B64StringCaptchaProvider) Expiration() time.Duration {
return 5 * time.Minute
}
func (p B64StringCaptchaProvider) Draw(content string) (string, error) {
item, err := capdString.DrawCaptcha(content)
if err != nil {
return "", err
}
b64str := item.EncodeB64string()
return b64str, nil
}
type B64MathCaptchaProvider struct{}
func (p B64MathCaptchaProvider) Generate() (string, string, string, error) {
id, content, answer := capdMath.GenerateIdQuestionAnswer()
return id, content, answer, nil
}
func (p B64MathCaptchaProvider) Expiration() time.Duration {
return 5 * time.Minute
}
func (p B64MathCaptchaProvider) Draw(content string) (string, error) {
item, err := capdMath.DrawCaptcha(content)
if err != nil {
return "", err
}
b64str := item.EncodeB64string()
return b64str, nil
}

296
utils/login_limiter.go Normal file
View File

@@ -0,0 +1,296 @@
package utils
import (
"errors"
"sync"
"time"
)
// 安全策略配置
type SecurityPolicy struct {
CaptchaThreshold int // 尝试失败次数达到验证码阈值小于0表示不启用, 0表示强制启用
BanThreshold int // 尝试失败次数达到封禁阈值为0表示不启用
AttemptsWindow time.Duration
BanDuration time.Duration
}
// 验证码提供者接口
type CaptchaProvider interface {
Generate() (id string, content string, answer string, err error)
//Validate(ip, code string) bool
Expiration() time.Duration // 验证码过期时间, 应该小于 AttemptsWindow
Draw(content string) (string, error) // 绘制验证码
}
// 验证码元数据
type CaptchaMeta struct {
Id string
Content string
Answer string
ExpiresAt time.Time
}
// IP封禁记录
type BanRecord struct {
ExpiresAt time.Time
Reason string
}
// 登录限制器
type LoginLimiter struct {
mu sync.Mutex
policy SecurityPolicy
attempts map[string][]time.Time //
captchas map[string]CaptchaMeta
bannedIPs map[string]BanRecord
provider CaptchaProvider
cleanupStop chan struct{}
}
var defaultSecurityPolicy = SecurityPolicy{
CaptchaThreshold: 3,
BanThreshold: 5,
AttemptsWindow: 5 * time.Minute,
BanDuration: 30 * time.Minute,
}
func NewLoginLimiter(policy SecurityPolicy) *LoginLimiter {
// 设置默认值
if policy.AttemptsWindow == 0 {
policy.AttemptsWindow = 5 * time.Minute
}
if policy.BanDuration == 0 {
policy.BanDuration = 30 * time.Minute
}
ll := &LoginLimiter{
policy: policy,
attempts: make(map[string][]time.Time),
captchas: make(map[string]CaptchaMeta),
bannedIPs: make(map[string]BanRecord),
cleanupStop: make(chan struct{}),
}
go ll.cleanupRoutine()
return ll
}
// 注册验证码提供者
func (ll *LoginLimiter) RegisterProvider(p CaptchaProvider) {
ll.mu.Lock()
defer ll.mu.Unlock()
ll.provider = p
}
// isDisabled 检查是否禁用登录限制
func (ll *LoginLimiter) isDisabled() bool {
return ll.policy.CaptchaThreshold < 0 && ll.policy.BanThreshold == 0
}
// 记录登录失败尝试
func (ll *LoginLimiter) RecordFailedAttempt(ip string) {
if ll.isDisabled() {
return
}
ll.mu.Lock()
defer ll.mu.Unlock()
if banned, _ := ll.isBanned(ip); banned {
return
}
now := time.Now()
windowStart := now.Add(-ll.policy.AttemptsWindow)
// 清理过期尝试
validAttempts := ll.pruneAttempts(ip, windowStart)
// 记录新尝试
validAttempts = append(validAttempts, now)
ll.attempts[ip] = validAttempts
// 检查封禁条件
if ll.policy.BanThreshold > 0 && len(validAttempts) >= ll.policy.BanThreshold {
ll.banIP(ip, "excessive failed attempts")
return
}
return
}
// 生成验证码
func (ll *LoginLimiter) RequireCaptcha() (error, CaptchaMeta) {
ll.mu.Lock()
defer ll.mu.Unlock()
if ll.provider == nil {
return errors.New("no captcha provider available"), CaptchaMeta{}
}
id, content, answer, err := ll.provider.Generate()
if err != nil {
return err, CaptchaMeta{}
}
// 存储验证码
ll.captchas[id] = CaptchaMeta{
Id: id,
Content: content,
Answer: answer,
ExpiresAt: time.Now().Add(ll.provider.Expiration()),
}
return nil, ll.captchas[id]
}
// 验证验证码
func (ll *LoginLimiter) VerifyCaptcha(id, answer string) bool {
ll.mu.Lock()
defer ll.mu.Unlock()
// 查找匹配验证码
if ll.provider == nil {
return false
}
// 获取并验证验证码
captcha, exists := ll.captchas[id]
if !exists {
return false
}
// 清理过期验证码
if time.Now().After(captcha.ExpiresAt) {
delete(ll.captchas, id)
return false
}
// 验证并清理状态
if answer == captcha.Answer {
delete(ll.captchas, id)
return true
}
return false
}
func (ll *LoginLimiter) DrawCaptcha(content string) (err error, str string) {
str, err = ll.provider.Draw(content)
return
}
// 清除记录窗口
func (ll *LoginLimiter) RemoveAttempts(ip string) {
ll.mu.Lock()
defer ll.mu.Unlock()
_, exists := ll.attempts[ip]
if exists {
delete(ll.attempts, ip)
}
}
// CheckSecurityStatus 检查安全状态
func (ll *LoginLimiter) CheckSecurityStatus(ip string) (banned bool, captchaRequired bool) {
if ll.isDisabled() {
return
}
ll.mu.Lock()
defer ll.mu.Unlock()
// 检查封禁状态
if banned, _ = ll.isBanned(ip); banned {
return
}
// 清理过期数据
ll.pruneAttempts(ip, time.Now().Add(-ll.policy.AttemptsWindow))
// 检查验证码要求
captchaRequired = len(ll.attempts[ip]) >= ll.policy.CaptchaThreshold
return
}
// 后台清理任务
func (ll *LoginLimiter) cleanupRoutine() {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
ll.cleanupExpired()
case <-ll.cleanupStop:
return
}
}
}
// 内部工具方法
func (ll *LoginLimiter) isBanned(ip string) (bool, BanRecord) {
record, exists := ll.bannedIPs[ip]
if !exists {
return false, BanRecord{}
}
if time.Now().After(record.ExpiresAt) {
delete(ll.bannedIPs, ip)
return false, BanRecord{}
}
return true, record
}
func (ll *LoginLimiter) banIP(ip, reason string) {
ll.bannedIPs[ip] = BanRecord{
ExpiresAt: time.Now().Add(ll.policy.BanDuration),
Reason: reason,
}
delete(ll.attempts, ip)
delete(ll.captchas, ip)
}
func (ll *LoginLimiter) pruneAttempts(ip string, cutoff time.Time) []time.Time {
var valid []time.Time
for _, t := range ll.attempts[ip] {
if t.After(cutoff) {
valid = append(valid, t)
}
}
if len(valid) == 0 {
delete(ll.attempts, ip)
} else {
ll.attempts[ip] = valid
}
return valid
}
func (ll *LoginLimiter) pruneCaptchas(id string) {
if captcha, exists := ll.captchas[id]; exists {
if time.Now().After(captcha.ExpiresAt) {
delete(ll.captchas, id)
}
}
}
func (ll *LoginLimiter) cleanupExpired() {
ll.mu.Lock()
defer ll.mu.Unlock()
now := time.Now()
// 清理封禁记录
for ip, record := range ll.bannedIPs {
if now.After(record.ExpiresAt) {
delete(ll.bannedIPs, ip)
}
}
// 清理尝试记录
for ip := range ll.attempts {
ll.pruneAttempts(ip, now.Add(-ll.policy.AttemptsWindow))
}
// 清理验证码
for id := range ll.captchas {
ll.pruneCaptchas(id)
}
}

290
utils/login_limiter_test.go Normal file
View File

@@ -0,0 +1,290 @@
package utils
import (
"fmt"
"github.com/google/uuid"
"testing"
"time"
)
type MockCaptchaProvider struct{}
func (p *MockCaptchaProvider) Generate() (string, string, string, error) {
id := uuid.New().String()
content := uuid.New().String()
answer := uuid.New().String()
return id, content, answer, nil
}
func (p *MockCaptchaProvider) Expiration() time.Duration {
return 2 * time.Second
}
func (p *MockCaptchaProvider) Draw(content string) (string, error) {
return "MOCK", nil
}
func TestSecurityWorkflow(t *testing.T) {
policy := SecurityPolicy{
CaptchaThreshold: 3,
BanThreshold: 5,
AttemptsWindow: 5 * time.Minute,
BanDuration: 5 * time.Minute,
}
limiter := NewLoginLimiter(policy)
ip := "192.168.1.100"
// 测试正常失败记录
for i := 0; i < 3; i++ {
limiter.RecordFailedAttempt(ip)
}
isBanned, capRequired := limiter.CheckSecurityStatus(ip)
fmt.Printf("IP: %s, Banned: %v, Captcha Required: %v\n", ip, isBanned, capRequired)
if isBanned {
t.Error("IP should not be banned yet")
}
if !capRequired {
t.Error("Captcha should be required")
}
// 测试触发封禁
for i := 0; i < 3; i++ {
limiter.RecordFailedAttempt(ip)
isBanned, capRequired = limiter.CheckSecurityStatus(ip)
fmt.Printf("IP: %s, Banned: %v, Captcha Required: %v\n", ip, isBanned, capRequired)
}
// 测试封禁状态
if isBanned, _ = limiter.CheckSecurityStatus(ip); !isBanned {
t.Error("IP should be banned")
}
}
func TestCaptchaFlow(t *testing.T) {
policy := SecurityPolicy{CaptchaThreshold: 2}
limiter := NewLoginLimiter(policy)
limiter.RegisterProvider(&MockCaptchaProvider{})
ip := "10.0.0.1"
// 触发验证码要求
limiter.RecordFailedAttempt(ip)
limiter.RecordFailedAttempt(ip)
// 检查状态
if _, need := limiter.CheckSecurityStatus(ip); !need {
t.Error("应该需要验证码")
}
// 生成验证码
err, capc := limiter.RequireCaptcha()
if err != nil {
t.Fatalf("生成验证码失败: %v", err)
}
fmt.Printf("验证码内容: %#v\n", capc)
// 验证成功
if !limiter.VerifyCaptcha(capc.Id, capc.Answer) {
t.Error("验证码应该验证成功")
}
// 验证已删除
if limiter.VerifyCaptcha(capc.Id, capc.Answer) {
t.Error("验证码应该已删除")
}
limiter.RemoveAttempts(ip)
// 验证后状态
if banned, need := limiter.CheckSecurityStatus(ip); banned || need {
t.Error("验证成功后应该重置状态")
}
}
func TestCaptchaMustFlow(t *testing.T) {
policy := SecurityPolicy{CaptchaThreshold: 0}
limiter := NewLoginLimiter(policy)
limiter.RegisterProvider(&MockCaptchaProvider{})
ip := "10.0.0.1"
// 检查状态
if _, need := limiter.CheckSecurityStatus(ip); !need {
t.Error("应该需要验证码")
}
// 生成验证码
err, capc := limiter.RequireCaptcha()
if err != nil {
t.Fatalf("生成验证码失败: %v", err)
}
fmt.Printf("验证码内容: %#v\n", capc)
// 验证成功
if !limiter.VerifyCaptcha(capc.Id, capc.Answer) {
t.Error("验证码应该验证成功")
}
// 验证后状态
if _, need := limiter.CheckSecurityStatus(ip); !need {
t.Error("应该需要验证码")
}
}
func TestAttemptTimeout(t *testing.T) {
policy := SecurityPolicy{CaptchaThreshold: 2, AttemptsWindow: 1 * time.Second}
limiter := NewLoginLimiter(policy)
limiter.RegisterProvider(&MockCaptchaProvider{})
ip := "10.0.0.1"
// 触发验证码要求
limiter.RecordFailedAttempt(ip)
limiter.RecordFailedAttempt(ip)
// 检查状态
if _, need := limiter.CheckSecurityStatus(ip); !need {
t.Error("应该需要验证码")
}
// 生成验证码
err, _ := limiter.RequireCaptcha()
if err != nil {
t.Fatalf("生成验证码失败: %v", err)
}
// 等待超过 AttemptsWindow
time.Sleep(2 * time.Second)
// 触发验证码要求
limiter.RecordFailedAttempt(ip)
// 检查状态
if _, need := limiter.CheckSecurityStatus(ip); need {
t.Error("不应该需要验证码")
}
}
func TestCaptchaTimeout(t *testing.T) {
policy := SecurityPolicy{CaptchaThreshold: 2}
limiter := NewLoginLimiter(policy)
limiter.RegisterProvider(&MockCaptchaProvider{})
ip := "10.0.0.1"
// 触发验证码要求
limiter.RecordFailedAttempt(ip)
limiter.RecordFailedAttempt(ip)
// 检查状态
if _, need := limiter.CheckSecurityStatus(ip); !need {
t.Error("应该需要验证码")
}
// 生成验证码
err, capc := limiter.RequireCaptcha()
if err != nil {
t.Fatalf("生成验证码失败: %v", err)
}
// 等待超过 CaptchaValidPeriod
time.Sleep(3 * time.Second)
// 验证成功
if limiter.VerifyCaptcha(capc.Id, capc.Answer) {
t.Error("验证码应该已过期")
}
}
func TestBanFlow(t *testing.T) {
policy := SecurityPolicy{BanThreshold: 5}
limiter := NewLoginLimiter(policy)
ip := "10.0.0.1"
// 触发ban
for i := 0; i < 5; i++ {
limiter.RecordFailedAttempt(ip)
}
// 检查状态
if banned, _ := limiter.CheckSecurityStatus(ip); !banned {
t.Error("should be banned")
}
}
func TestBanDisableFlow(t *testing.T) {
policy := SecurityPolicy{BanThreshold: 0}
limiter := NewLoginLimiter(policy)
ip := "10.0.0.1"
// 触发ban
for i := 0; i < 5; i++ {
limiter.RecordFailedAttempt(ip)
}
// 检查状态
if banned, _ := limiter.CheckSecurityStatus(ip); banned {
t.Error("should not be banned")
}
}
func TestBanTimeout(t *testing.T) {
policy := SecurityPolicy{BanThreshold: 5, BanDuration: 1 * time.Second}
limiter := NewLoginLimiter(policy)
ip := "10.0.0.1"
// 触发ban
// 触发ban
for i := 0; i < 5; i++ {
limiter.RecordFailedAttempt(ip)
}
time.Sleep(2 * time.Second)
// 检查状态
if banned, _ := limiter.CheckSecurityStatus(ip); banned {
t.Error("should not be banned")
}
}
func TestLimiterDisabled(t *testing.T) {
policy := SecurityPolicy{BanThreshold: 0, CaptchaThreshold: -1}
limiter := NewLoginLimiter(policy)
ip := "10.0.0.1"
// 触发ban
for i := 0; i < 5; i++ {
limiter.RecordFailedAttempt(ip)
}
// 检查状态
if banned, capNeed := limiter.CheckSecurityStatus(ip); banned || capNeed {
fmt.Printf("IP: %s, Banned: %v, Captcha Required: %v\n", ip, banned, capNeed)
t.Error("should not be banned or need captcha")
}
}
func TestB64CaptchaFlow(t *testing.T) {
limiter := NewLoginLimiter(defaultSecurityPolicy)
limiter.RegisterProvider(B64StringCaptchaProvider{})
ip := "10.0.0.1"
// 触发验证码要求
limiter.RecordFailedAttempt(ip)
limiter.RecordFailedAttempt(ip)
limiter.RecordFailedAttempt(ip)
// 检查状态
if _, need := limiter.CheckSecurityStatus(ip); !need {
t.Error("应该需要验证码")
}
// 生成验证码
err, capc := limiter.RequireCaptcha()
if err != nil {
t.Fatalf("生成验证码失败: %v", err)
}
fmt.Printf("验证码内容: %#v\n", capc)
//draw
err, b64 := limiter.DrawCaptcha(capc.Content)
if err != nil {
t.Fatalf("绘制验证码失败: %v", err)
}
fmt.Printf("验证码内容: %#v\n", b64)
// 验证成功
if !limiter.VerifyCaptcha(capc.Id, capc.Answer) {
t.Error("验证码应该验证成功")
}
limiter.RemoveAttempts(ip)
// 验证后状态
if banned, need := limiter.CheckSecurityStatus(ip); banned || need {
t.Error("验证成功后应该重置状态")
}
}

View File

@@ -7,6 +7,7 @@ import (
"math/rand"
"reflect"
"runtime/debug"
"strings"
)
func Md5(str string) string {
@@ -100,3 +101,11 @@ func InArray(k string, arr []string) bool {
}
return false
}
func StringConcat(strs ...string) string {
var builder strings.Builder
for _, str := range strs {
builder.WriteString(str)
}
return builder.String()
}