Compare commits
228 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
16ec462abd | ||
|
|
ca55465d3c | ||
|
|
7098c98dde | ||
|
|
f56355da89 | ||
|
|
422160debd | ||
|
|
8062cf406a | ||
|
|
0e802232ec | ||
|
|
f650a9205d | ||
|
|
c85dbb2347 | ||
|
|
a6a79128c8 | ||
|
|
42839627e8 | ||
|
|
267e68a894 | ||
|
|
b32b444438 | ||
|
|
522d0f8313 | ||
|
|
5715e5de67 | ||
|
|
cc6b05e8b3 | ||
|
|
417747d5d0 | ||
|
|
a34f439226 | ||
|
|
b7ca014fd0 | ||
|
|
fa098d585a | ||
|
|
c35a14e3ec | ||
|
|
60651736a5 | ||
|
|
581f9b7bd3 | ||
|
|
124eb04807 | ||
|
|
1d561da7fb | ||
|
|
16e3cd0784 | ||
|
|
a6d91933dc | ||
|
|
445c40f758 | ||
|
|
725a841a3b | ||
|
|
f77c453843 | ||
|
|
ba6718d5bc | ||
|
|
cdb7a1b3fa | ||
|
|
a03c79b89d | ||
|
|
98800d3426 | ||
|
|
a616adaac4 | ||
|
|
ffb5605c99 | ||
|
|
621b556856 | ||
|
|
a3ffecbb2a | ||
|
|
ea64cebe2a | ||
|
|
e79487dd5f | ||
|
|
7fe1c1ec89 | ||
|
|
ab2bbff369 | ||
|
|
ec32825309 | ||
|
|
fd0c182087 | ||
|
|
49fcff1daf | ||
|
|
33b64ddf39 | ||
|
|
4c447aa648 | ||
|
|
ccbfc3d274 | ||
|
|
f83fe43bbb | ||
|
|
19022d67f8 | ||
|
|
58a815dd6b | ||
|
|
bc9fe82860 | ||
|
|
b3cd9bf2b9 | ||
|
|
c5c2b829ec | ||
|
|
9713f96401 | ||
|
|
11f35ebf96 | ||
|
|
7d403aa181 | ||
|
|
64af810a4a | ||
|
|
30821905af | ||
|
|
a9dbff756b | ||
|
|
a6aba10d3d | ||
|
|
9c276c37fe | ||
|
|
6ab6c0fd4c | ||
|
|
b6b0fe3fff | ||
|
|
0d5825bda9 | ||
|
|
cdfb64631a | ||
|
|
d161c281c8 | ||
|
|
8fed5bf2a1 | ||
|
|
98d2e9bd27 | ||
|
|
a03af55edd | ||
|
|
86e2fd9aee | ||
|
|
97bd0e5e58 | ||
|
|
ceaba21986 | ||
|
|
172a77d942 | ||
|
|
4f9d2d2a7d | ||
|
|
8c929f6e05 | ||
|
|
3319b71f5b | ||
|
|
46ec028a5b | ||
|
|
0ce0ef3e5c | ||
|
|
375b071cb2 | ||
|
|
29e1417ff2 | ||
|
|
75db2bd366 | ||
|
|
60ca1efbda | ||
|
|
2692e4978b | ||
|
|
91982eb002 | ||
|
|
bb1dec76fa | ||
|
|
f618b8fcdc | ||
|
|
9147cab75b | ||
|
|
5f07bcc8e6 | ||
|
|
705cf2ea1b | ||
|
|
42c4394484 | ||
|
|
221221a3c1 | ||
|
|
9564166297 | ||
|
|
f5cf3c3c8e | ||
|
|
18f919fb6b | ||
|
|
0924835253 | ||
|
|
20d2e5c578 | ||
|
|
907801605c | ||
|
|
93bc684e8c | ||
|
|
a76c98d57e | ||
|
|
d937a800d0 | ||
|
|
d16f3a227f | ||
|
|
80c9a3eeda | ||
|
|
e68173b451 | ||
|
|
40c27d87f5 | ||
|
|
3c13b5049d | ||
|
|
8288d5e51f | ||
|
|
6e1449900a | ||
|
|
4ffbb18ab4 | ||
|
|
b27271b7a3 | ||
|
|
ebb6665f64 | ||
|
|
e4e5731ffd | ||
|
|
2ab5810f13 | ||
|
|
af934c5d09 | ||
|
|
1e0cf7c112 | ||
|
|
46859c93c9 | ||
|
|
ea1f9cb3b2 | ||
|
|
1641549016 | ||
|
|
716a5dbb8a | ||
|
|
af98cb11c5 | ||
|
|
9a4c2cf341 | ||
|
|
2bc3bcd102 | ||
|
|
d6c663f79d | ||
|
|
9ed86e5f53 | ||
|
|
303e0bc037 | ||
|
|
2cc24019f9 | ||
|
|
83ce774d19 | ||
|
|
2b4ee13b5e | ||
|
|
3a964561f0 | ||
|
|
6959f86632 | ||
|
|
537d373e10 | ||
|
|
cceadf222c | ||
|
|
cf5a4af623 | ||
|
|
39aea11c22 | ||
|
|
c2f1227700 | ||
|
|
900f14d37c | ||
|
|
598249b1d6 | ||
|
|
7ed15bdf04 | ||
|
|
2fc0ec0f72 | ||
|
|
5e9c2a669b | ||
|
|
b310521884 | ||
|
|
288945bf7e | ||
|
|
4fc07cff36 | ||
|
|
b884fe0e86 | ||
|
|
855858c236 | ||
|
|
c11a2a5419 | ||
|
|
773a6572af | ||
|
|
88ad373c9b | ||
|
|
51666464b9 | ||
|
|
5af9cf2f52 | ||
|
|
12c4ae4b10 | ||
|
|
4e1bef414a | ||
|
|
e896c18644 | ||
|
|
c852685e74 | ||
|
|
1e99797df8 | ||
|
|
52a4c986a8 | ||
|
|
c501728204 | ||
|
|
6b067fa6a7 | ||
|
|
a1cd5c53a9 | ||
|
|
a46d487e03 | ||
|
|
3deb6d3ab3 | ||
|
|
af34cdd5d2 | ||
|
|
6e1393235a | ||
|
|
343e0b54b9 | ||
|
|
ecb70cb6f7 | ||
|
|
ca50618af6 | ||
|
|
29c07ba83e | ||
|
|
45fbb83a9f | ||
|
|
ae7ba2df25 | ||
|
|
c3ef57cc32 | ||
|
|
7bb4ca5a14 | ||
|
|
063783d81d | ||
|
|
42116c9b65 | ||
|
|
a36e11973d | ||
|
|
5125568ea2 | ||
|
|
0fa164e50d | ||
|
|
cf814e81ee | ||
|
|
43a45f18ce | ||
|
|
ad51381063 | ||
|
|
0b0e4ce904 | ||
|
|
6a3e04d688 | ||
|
|
4107a17370 | ||
|
|
06b4d8f169 | ||
|
|
1c0c820746 | ||
|
|
d061403a28 | ||
|
|
5c092321a6 | ||
|
|
bdd3f61c1f | ||
|
|
8023557d6e | ||
|
|
074b0ced7a | ||
|
|
3864b1ac9b | ||
|
|
6e9b43457d | ||
|
|
ca1aec8920 | ||
|
|
acac580862 | ||
|
|
673e1b2980 | ||
|
|
f62157be72 | ||
|
|
f894ecf3b6 | ||
|
|
66dd4e28ad | ||
|
|
939dc1b0fb | ||
|
|
56bf5d38a1 | ||
|
|
d09b70b295 | ||
|
|
205180387a | ||
|
|
39c8cfeda5 | ||
|
|
f38a329be5 | ||
|
|
a0cd069539 | ||
|
|
bf306a2f01 | ||
|
|
c31f93a8d1 | ||
|
|
4730ab6309 | ||
|
|
1ae78ca98c | ||
|
|
d2379da478 | ||
|
|
0f64981b20 | ||
|
|
0002e49bb5 | ||
|
|
db13a60274 | ||
|
|
db0f11a359 | ||
|
|
ac7f43520b | ||
|
|
f67b9f5f6e | ||
|
|
c75156c4ce | ||
|
|
10270b5595 | ||
|
|
f7458572ed | ||
|
|
d57b7222b2 | ||
|
|
62e70a673a | ||
|
|
5e9eba6478 | ||
|
|
c5ccc1a084 | ||
|
|
6439917cbe | ||
|
|
d21c18f657 | ||
|
|
e6981290bc | ||
|
|
75c3d8abbd | ||
|
|
d88683f498 | ||
|
|
40b9aa3a4c |
30
.github/workflows/auto_release.yml
vendored
30
.github/workflows/auto_release.yml
vendored
@@ -23,6 +23,36 @@ jobs:
|
||||
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
|
||||
echo ${{ github.ref_name }} > dist/assets/version
|
||||
zip -r dist.zip dist
|
||||
|
||||
- name: Upload to Cloudflare R2
|
||||
env:
|
||||
R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }}
|
||||
R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }}
|
||||
R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }}
|
||||
R2_BUCKET_NAME: "astrbot"
|
||||
R2_OBJECT_NAME: "astrbot-webui-latest.zip"
|
||||
VERSION_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
echo "Installing rclone..."
|
||||
curl https://rclone.org/install.sh | sudo bash
|
||||
|
||||
echo "Configuring rclone remote..."
|
||||
mkdir -p ~/.config/rclone
|
||||
cat <<EOF > ~/.config/rclone/rclone.conf
|
||||
[r2]
|
||||
type = s3
|
||||
provider = Cloudflare
|
||||
access_key_id = $R2_ACCESS_KEY_ID
|
||||
secret_access_key = $R2_SECRET_ACCESS_KEY
|
||||
endpoint = https://${R2_ACCOUNT_ID}.r2.cloudflarestorage.com
|
||||
EOF
|
||||
|
||||
echo "Uploading dist.zip to R2 bucket: $R2_BUCKET_NAME/$R2_OBJECT_NAME"
|
||||
mv dashboard/dist.zip dashboard/$R2_OBJECT_NAME
|
||||
rclone copy dashboard/$R2_OBJECT_NAME r2:$R2_BUCKET_NAME --progress
|
||||
mv dashboard/$R2_OBJECT_NAME dashboard/astrbot-webui-${VERSION_TAG}.zip
|
||||
rclone copy dashboard/astrbot-webui-${VERSION_TAG}.zip r2:$R2_BUCKET_NAME --progress
|
||||
mv dashboard/astrbot-webui-${VERSION_TAG}.zip dashboard/dist.zip
|
||||
|
||||
- name: Fetch Changelog
|
||||
run: |
|
||||
|
||||
35
.github/workflows/docker-image.yml
vendored
35
.github/workflows/docker-image.yml
vendored
@@ -11,24 +11,42 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: 拉取源码
|
||||
- name: Pull The Codes
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
fetch-depth: 0 # Must be 0 so we can fetch tags
|
||||
|
||||
- name: 设置 QEMU
|
||||
- name: Get latest tag (only on manual trigger)
|
||||
id: get-latest-tag
|
||||
if: github.event_name == 'workflow_dispatch'
|
||||
run: |
|
||||
tag=$(git describe --tags --abbrev=0)
|
||||
echo "latest_tag=$tag" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Checkout to latest tag (only on manual trigger)
|
||||
if: github.event_name == 'workflow_dispatch'
|
||||
run: git checkout ${{ steps.get-latest-tag.outputs.latest_tag }}
|
||||
|
||||
- name: Set QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: 设置 Docker Buildx
|
||||
- name: Set Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: 登录到 DockerHub
|
||||
- name: Log in to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
|
||||
|
||||
- name: 构建和推送 Docker hub
|
||||
- name: Login to GitHub Container Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: Soulter
|
||||
password: ${{ secrets.GHCR_GITHUB_TOKEN }}
|
||||
|
||||
- name: Build and Push Docker to DockerHub and Github GHCR
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
@@ -36,8 +54,9 @@ jobs:
|
||||
push: true
|
||||
tags: |
|
||||
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest
|
||||
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.ref_name }}
|
||||
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
|
||||
ghcr.io/soulter/astrbot:latest
|
||||
ghcr.io/soulter/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
|
||||
|
||||
- name: Post build notifications
|
||||
run: echo "Docker image has been built and pushed successfully"
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM python:3.10-slim
|
||||
FROM python:3.11-slim
|
||||
WORKDIR /AstrBot
|
||||
|
||||
COPY . /AstrBot/
|
||||
|
||||
15
README.md
15
README.md
@@ -31,13 +31,21 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
||||
<!-- [](https://codecov.io/gh/Soulter/AstrBot)
|
||||
-->
|
||||
|
||||
> [!NOTE]
|
||||
> [!WARNING]
|
||||
>
|
||||
> 个人微信接入所依赖的开源项目 Gewechat 近期已停止维护,`v3.5.10` 已经支持接入 WeChatPadPro 替换 gewechat 方式。详见文档 [WeChatPadPro](https://astrbot.app/deploy/platform/wechat/wechatpadpro.html)
|
||||
> 请务必修改默认密码以及保证 AstrBot 版本 >= 3.5.13。
|
||||
|
||||
## ✨ 近期更新
|
||||
|
||||
1. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器!
|
||||
<details><summary>1. AstrBot 现已自带知识库能力</summary>
|
||||
|
||||
📚 详见[文档](https://astrbot.app/use/knowledge-base.html)
|
||||
|
||||

|
||||
|
||||
</details>
|
||||
|
||||
2. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器!
|
||||
|
||||
## ✨ 主要功能
|
||||
|
||||
@@ -171,7 +179,6 @@ pre-commit install
|
||||
|
||||
- Star 这个项目!
|
||||
- 在[爱发电](https://afdian.com/a/soulter)支持我!
|
||||
- 在[微信](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)支持我~
|
||||
|
||||
## ✨ Demo
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ import tempfile
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
import re
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
@@ -59,7 +58,16 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None):
|
||||
proxy=proxy if proxy else None, follow_redirects=True
|
||||
) as client:
|
||||
resp = client.get(download_url)
|
||||
resp.raise_for_status()
|
||||
if (
|
||||
resp.status_code == 404
|
||||
and "archive/refs/heads/master.zip" in download_url
|
||||
):
|
||||
alt_url = download_url.replace("master.zip", "main.zip")
|
||||
click.echo("master 分支不存在,尝试下载 main 分支")
|
||||
resp = client.get(alt_url)
|
||||
resp.raise_for_status()
|
||||
else:
|
||||
resp.raise_for_status()
|
||||
zip_content = BytesIO(resp.content)
|
||||
with ZipFile(zip_content) as z:
|
||||
z.extractall(temp_dir)
|
||||
@@ -91,39 +99,6 @@ def load_yaml_metadata(plugin_dir: Path) -> dict:
|
||||
return {}
|
||||
|
||||
|
||||
def extract_py_metadata(plugin_dir: Path) -> dict:
|
||||
"""从 Python 文件中提取插件元数据
|
||||
|
||||
Args:
|
||||
plugin_dir: 插件目录路径
|
||||
|
||||
Returns:
|
||||
dict: 包含元数据的字典,如果提取失败则返回空字典
|
||||
"""
|
||||
# 检查 main.py 或与目录同名的 py 文件
|
||||
for pattern in ["main.py", f"{plugin_dir.name}.py"]:
|
||||
for py_file in plugin_dir.glob(pattern):
|
||||
try:
|
||||
content = py_file.read_text(encoding="utf-8")
|
||||
register_match = re.search(
|
||||
r'@register_star\s*\(\s*"([^"]+)"\s*,\s*"([^"]+)"\s*,\s*"([^"]+)"\s*,\s*"([^"]+)"(?:\s*,\s*"?([^")]+)"?)?\s*\)',
|
||||
content,
|
||||
)
|
||||
if register_match:
|
||||
# 映射匹配组到元数据键
|
||||
metadata = {}
|
||||
keys = ["name", "author", "desc", "version", "repo"]
|
||||
for i, key in enumerate(keys):
|
||||
if i + 1 <= len(
|
||||
register_match.groups()
|
||||
) and register_match.group(i + 1):
|
||||
metadata[key] = register_match.group(i + 1)
|
||||
return metadata
|
||||
except Exception as e:
|
||||
click.echo(f"读取 {py_file} 失败: {e}", err=True)
|
||||
return {}
|
||||
|
||||
|
||||
def build_plug_list(plugins_dir: Path) -> list:
|
||||
"""构建插件列表,包含本地和在线插件信息
|
||||
|
||||
@@ -139,31 +114,22 @@ def build_plug_list(plugins_dir: Path) -> list:
|
||||
for plugin_name in [d.name for d in plugins_dir.glob("*") if d.is_dir()]:
|
||||
plugin_dir = plugins_dir / plugin_name
|
||||
|
||||
# 从不同来源加载元数据
|
||||
# 从 metadata.yaml 加载元数据
|
||||
metadata = load_yaml_metadata(plugin_dir)
|
||||
|
||||
# 如果元数据不完整,尝试从 Python 文件提取
|
||||
if not metadata or not all(
|
||||
# 如果成功加载元数据,添加到结果列表
|
||||
if metadata and all(
|
||||
k in metadata for k in ["name", "desc", "version", "author", "repo"]
|
||||
):
|
||||
py_metadata = extract_py_metadata(plugin_dir)
|
||||
# 合并元数据,保留已有的值
|
||||
for key, value in py_metadata.items():
|
||||
if key not in metadata or not metadata[key]:
|
||||
metadata[key] = value
|
||||
# 如果成功提取元数据,添加到结果列表
|
||||
if metadata:
|
||||
result.append(
|
||||
{
|
||||
"name": str(metadata.get("name", "")),
|
||||
"desc": str(metadata.get("desc", "")),
|
||||
"version": str(metadata.get("version", "")),
|
||||
"author": str(metadata.get("author", "")),
|
||||
"repo": str(metadata.get("repo", "")),
|
||||
"status": PluginStatus.INSTALLED,
|
||||
"local_path": str(plugin_dir),
|
||||
}
|
||||
)
|
||||
result.append({
|
||||
"name": str(metadata.get("name", "")),
|
||||
"desc": str(metadata.get("desc", "")),
|
||||
"version": str(metadata.get("version", "")),
|
||||
"author": str(metadata.get("author", "")),
|
||||
"repo": str(metadata.get("repo", "")),
|
||||
"status": PluginStatus.INSTALLED,
|
||||
"local_path": str(plugin_dir),
|
||||
})
|
||||
|
||||
# 获取在线插件列表
|
||||
online_plugins = []
|
||||
@@ -173,17 +139,15 @@ def build_plug_list(plugins_dir: Path) -> list:
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
for plugin_id, plugin_info in data.items():
|
||||
online_plugins.append(
|
||||
{
|
||||
"name": str(plugin_id),
|
||||
"desc": str(plugin_info.get("desc", "")),
|
||||
"version": str(plugin_info.get("version", "")),
|
||||
"author": str(plugin_info.get("author", "")),
|
||||
"repo": str(plugin_info.get("repo", "")),
|
||||
"status": PluginStatus.NOT_INSTALLED,
|
||||
"local_path": None,
|
||||
}
|
||||
)
|
||||
online_plugins.append({
|
||||
"name": str(plugin_id),
|
||||
"desc": str(plugin_info.get("desc", "")),
|
||||
"version": str(plugin_info.get("version", "")),
|
||||
"author": str(plugin_info.get("author", "")),
|
||||
"repo": str(plugin_info.get("repo", "")),
|
||||
"status": PluginStatus.NOT_INSTALLED,
|
||||
"local_path": None,
|
||||
})
|
||||
except Exception as e:
|
||||
click.echo(f"获取在线插件列表失败: {e}", err=True)
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ class AstrBotConfig(dict):
|
||||
"""不存在时载入默认配置"""
|
||||
with open(config_path, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(default_config, f, indent=4, ensure_ascii=False)
|
||||
object.__setattr__(self, "first_deploy", True) # 标记第一次部署
|
||||
|
||||
with open(config_path, "r", encoding="utf-8-sig") as f:
|
||||
conf_str = f.read()
|
||||
@@ -82,23 +83,61 @@ class AstrBotConfig(dict):
|
||||
return conf
|
||||
|
||||
def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
|
||||
"""检查配置完整性,如果有新的配置项则返回 True"""
|
||||
"""检查配置完整性,如果有新的配置项或顺序不一致则返回 True"""
|
||||
has_new = False
|
||||
|
||||
# 创建一个新的有序字典以保持参考配置的顺序
|
||||
new_conf = {}
|
||||
|
||||
# 先按照参考配置的顺序添加配置项
|
||||
for key, value in refer_conf.items():
|
||||
if key not in conf:
|
||||
# logger.info(f"检查到配置项 {path + "." + key if path else key} 不存在,已插入默认值 {value}")
|
||||
# 配置项不存在,插入默认值
|
||||
path_ = path + "." + key if path else key
|
||||
logger.info(f"检查到配置项 {path_} 不存在,已插入默认值 {value}")
|
||||
conf[key] = value
|
||||
new_conf[key] = value
|
||||
has_new = True
|
||||
else:
|
||||
if conf[key] is None:
|
||||
conf[key] = value
|
||||
# 配置项为 None,使用默认值
|
||||
new_conf[key] = value
|
||||
has_new = True
|
||||
elif isinstance(value, dict):
|
||||
has_new |= self.check_config_integrity(
|
||||
value, conf[key], path + "." + key if path else key
|
||||
)
|
||||
# 递归检查子配置项
|
||||
if not isinstance(conf[key], dict):
|
||||
# 类型不匹配,使用默认值
|
||||
new_conf[key] = value
|
||||
has_new = True
|
||||
else:
|
||||
# 递归检查并同步顺序
|
||||
child_has_new = self.check_config_integrity(
|
||||
value, conf[key], path + "." + key if path else key
|
||||
)
|
||||
new_conf[key] = conf[key]
|
||||
has_new |= child_has_new
|
||||
else:
|
||||
# 直接使用现有配置
|
||||
new_conf[key] = conf[key]
|
||||
|
||||
# 检查是否存在参考配置中没有的配置项
|
||||
for key in list(conf.keys()):
|
||||
if key not in refer_conf:
|
||||
path_ = path + "." + key if path else key
|
||||
logger.info(f"检查到配置项 {path_} 不存在,将从当前配置中删除")
|
||||
has_new = True
|
||||
|
||||
# 顺序不一致也算作变更
|
||||
if list(conf.keys()) != list(new_conf.keys()):
|
||||
if path:
|
||||
logger.info(f"检查到配置项 {path} 的子项顺序不一致,已重新排序")
|
||||
else:
|
||||
logger.info("检查到配置项顺序不一致,已重新排序")
|
||||
has_new = True
|
||||
|
||||
# 更新原始配置
|
||||
conf.clear()
|
||||
conf.update(new_conf)
|
||||
|
||||
return has_new
|
||||
|
||||
def save_config(self, replace_config: Dict = None):
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
import os
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "3.5.10"
|
||||
VERSION = "3.5.15"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v3.db")
|
||||
|
||||
# 默认配置
|
||||
@@ -40,12 +40,15 @@ DEFAULT_CONFIG = {
|
||||
},
|
||||
"no_permission_reply": True,
|
||||
"empty_mention_waiting": True,
|
||||
"empty_mention_waiting_need_reply": True,
|
||||
"friend_message_needs_wake_prefix": False,
|
||||
"ignore_bot_self_message": False,
|
||||
"ignore_at_all": False,
|
||||
},
|
||||
"provider": [],
|
||||
"provider_settings": {
|
||||
"enable": True,
|
||||
"default_provider_id": "",
|
||||
"wake_prefix": "",
|
||||
"web_search": False,
|
||||
"web_search_link": False,
|
||||
@@ -57,6 +60,7 @@ DEFAULT_CONFIG = {
|
||||
"dequeue_context_length": 1,
|
||||
"streaming_response": False,
|
||||
"streaming_segmented": False,
|
||||
"separate_provider": False,
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
"enable": False,
|
||||
@@ -66,6 +70,7 @@ DEFAULT_CONFIG = {
|
||||
"enable": False,
|
||||
"provider_id": "",
|
||||
"dual_output": False,
|
||||
"use_file_service": False,
|
||||
},
|
||||
"provider_ltm_settings": {
|
||||
"group_icl_enable": False,
|
||||
@@ -91,6 +96,7 @@ DEFAULT_CONFIG = {
|
||||
"t2i_word_threshold": 150,
|
||||
"t2i_strategy": "remote",
|
||||
"t2i_endpoint": "",
|
||||
"t2i_use_file_service": False,
|
||||
"http_proxy": "",
|
||||
"dashboard": {
|
||||
"enable": True,
|
||||
@@ -176,6 +182,7 @@ CONFIG_METADATA_2 = {
|
||||
"api_base_url": "https://api.weixin.qq.com/cgi-bin/",
|
||||
"callback_server_host": "0.0.0.0",
|
||||
"port": 6194,
|
||||
"active_send_mode": False,
|
||||
},
|
||||
"wecom(企业微信)": {
|
||||
"id": "wecom",
|
||||
@@ -220,20 +227,25 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"active_send_mode": {
|
||||
"description": "是否换用主动发送接口",
|
||||
"type": "bool",
|
||||
"desc": "只有企业认证的公众号才能主动发送。主动发送接口的限制会少一些。",
|
||||
},
|
||||
"wpp_active_message_poll": {
|
||||
"description": "是否启用主动消息轮询",
|
||||
"type": "bool",
|
||||
"hint": "只有当你发现微信消息没有按时同步到 AstrBot 时,才需要启用这个功能,默认不启用。"
|
||||
"description": "是否启用主动消息轮询",
|
||||
"type": "bool",
|
||||
"hint": "只有当你发现微信消息没有按时同步到 AstrBot 时,才需要启用这个功能,默认不启用。",
|
||||
},
|
||||
"wpp_active_message_poll_interval": {
|
||||
"description": "主动消息轮询间隔",
|
||||
"type": "int",
|
||||
"hint": "主动消息轮询间隔,单位为秒,默认 3 秒,最大不要超过 60 秒,否则可能被认为是旧消息。"
|
||||
"description": "主动消息轮询间隔",
|
||||
"type": "int",
|
||||
"hint": "主动消息轮询间隔,单位为秒,默认 3 秒,最大不要超过 60 秒,否则可能被认为是旧消息。",
|
||||
},
|
||||
"kf_name": {
|
||||
"description": "微信客服账号名",
|
||||
"type": "string",
|
||||
"hint": "可选。微信客服账号名(不是 ID)。可在 https://kf.weixin.qq.com/kf/frame#/accounts 获取"
|
||||
"description": "微信客服账号名",
|
||||
"type": "string",
|
||||
"hint": "可选。微信客服账号名(不是 ID)。可在 https://kf.weixin.qq.com/kf/frame#/accounts 获取",
|
||||
},
|
||||
"telegram_token": {
|
||||
"description": "Bot Token",
|
||||
@@ -256,10 +268,10 @@ CONFIG_METADATA_2 = {
|
||||
"hint": "Telegram 命令自动刷新间隔,单位为秒。",
|
||||
},
|
||||
"id": {
|
||||
"description": "ID",
|
||||
"description": "机器人名称",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "ID 不能和其它的平台适配器重复,否则将发生严重冲突。",
|
||||
"hint": "机器人名称(ID)不能和其它的平台适配器重复。",
|
||||
},
|
||||
"type": {
|
||||
"description": "适配器类型",
|
||||
@@ -347,9 +359,14 @@ CONFIG_METADATA_2 = {
|
||||
"hint": "启用后,当用户没有权限执行某个操作时,机器人会回复一条消息。",
|
||||
},
|
||||
"empty_mention_waiting": {
|
||||
"description": "只 @ 机器人是否触发等待回复",
|
||||
"description": "只 @ 机器人是否触发等待",
|
||||
"type": "bool",
|
||||
"hint": "启用后,当消息内容只有 @ 机器人时,会触发等待回复,在 60 秒内的该用户的任意一条消息均会唤醒机器人。这在某些平台不支持 @ 和语音/图片等消息同时发送时特别有用。",
|
||||
"hint": "启用后,当消息内容只有 @ 机器人时,会触发等待,在 60 秒内的该用户的任意一条消息均会唤醒机器人。这在某些平台不支持 @ 和语音/图片等消息同时发送时特别有用。",
|
||||
},
|
||||
"empty_mention_waiting_need_reply": {
|
||||
"description": "只 @ 机器人触发等待时是否需要回复提醒",
|
||||
"type": "bool",
|
||||
"hint": "在上面一个配置项中,如果启用了触发等待,启用此项后,机器人会使用 LLM 生成一条回复。否则,将不回复而只是等待。",
|
||||
},
|
||||
"friend_message_needs_wake_prefix": {
|
||||
"description": "私聊消息是否需要唤醒前缀",
|
||||
@@ -361,6 +378,11 @@ CONFIG_METADATA_2 = {
|
||||
"type": "bool",
|
||||
"hint": "某些平台如 gewechat 会将自身账号在其他 APP 端发送的消息也当做消息事件下发导致给自己发消息时唤醒机器人",
|
||||
},
|
||||
"ignore_at_all": {
|
||||
"description": "是否忽略 @ 全体成员",
|
||||
"type": "bool",
|
||||
"hint": "启用后,机器人会忽略 @ 全体成员 的消息事件。",
|
||||
},
|
||||
"segmented_reply": {
|
||||
"description": "分段回复",
|
||||
"type": "object",
|
||||
@@ -612,6 +634,7 @@ CONFIG_METADATA_2 = {
|
||||
"gm_resp_image_modal": False,
|
||||
"gm_native_search": False,
|
||||
"gm_native_coderunner": False,
|
||||
"gm_url_context": False,
|
||||
"gm_safety_settings": {
|
||||
"harassment": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"hate_speech": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
@@ -818,7 +841,7 @@ CONFIG_METADATA_2 = {
|
||||
"azure_tts_rate": "1",
|
||||
"azure_tts_volume": "100",
|
||||
"azure_tts_subscription_key": "",
|
||||
"azure_tts_region": "eastus"
|
||||
"azure_tts_region": "eastus",
|
||||
},
|
||||
"MiniMax TTS(API)": {
|
||||
"id": "minimax_tts",
|
||||
@@ -841,44 +864,158 @@ CONFIG_METADATA_2 = {
|
||||
"minimax-voice-english-normalization": False,
|
||||
"timeout": 20,
|
||||
},
|
||||
"火山引擎_TTS(API)": {
|
||||
"id": "volcengine_tts",
|
||||
"type": "volcengine_tts",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
"appid": "",
|
||||
"volcengine_cluster": "volcano_tts",
|
||||
"volcengine_voice_type": "",
|
||||
"volcengine_speed_ratio": 1.0,
|
||||
"api_base": "https://openspeech.bytedance.com/api/v1/tts",
|
||||
"timeout": 20,
|
||||
},
|
||||
"OpenAI Embedding": {
|
||||
"id": "openai_embedding",
|
||||
"type": "openai_embedding",
|
||||
"provider_type": "embedding",
|
||||
"enable": True,
|
||||
"embedding_api_key": "",
|
||||
"embedding_api_base": "",
|
||||
"embedding_model": "",
|
||||
"embedding_dimensions": 1536,
|
||||
"timeout": 20,
|
||||
},
|
||||
"Gemini Embedding": {
|
||||
"id": "gemini_embedding",
|
||||
"type": "gemini_embedding",
|
||||
"provider_type": "embedding",
|
||||
"enable": True,
|
||||
"embedding_api_key": "",
|
||||
"embedding_api_base": "",
|
||||
"embedding_model": "gemini-embedding-exp-03-07",
|
||||
"embedding_dimensions": 768,
|
||||
"timeout": 20,
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"embedding_dimensions": {
|
||||
"description": "嵌入维度",
|
||||
"type": "int",
|
||||
"hint": "嵌入向量的维度。根据模型不同,可能需要调整,请参考具体模型的文档。此配置项请务必填写正确,否则将导致向量数据库无法正常工作。",
|
||||
},
|
||||
"embedding_model": {
|
||||
"description": "嵌入模型",
|
||||
"type": "string",
|
||||
"hint": "嵌入模型名称。",
|
||||
},
|
||||
"embedding_api_key": {
|
||||
"description": "API Key",
|
||||
"type": "string",
|
||||
},
|
||||
"embedding_api_base": {
|
||||
"description": "API Base URL",
|
||||
"type": "string",
|
||||
},
|
||||
"volcengine_cluster": {
|
||||
"type": "string",
|
||||
"description": "火山引擎集群",
|
||||
"hint": "若使用语音复刻大模型,可选volcano_icl或volcano_icl_concurr,默认使用volcano_tts",
|
||||
},
|
||||
"volcengine_voice_type": {
|
||||
"type": "string",
|
||||
"description": "火山引擎音色",
|
||||
"hint": "输入声音id(Voice_type)",
|
||||
},
|
||||
"volcengine_speed_ratio": {
|
||||
"type": "float",
|
||||
"description": "语速设置",
|
||||
"hint": "语速设置,范围为 0.2 到 3.0,默认值为 1.0",
|
||||
},
|
||||
"volcengine_volume_ratio": {
|
||||
"type": "float",
|
||||
"description": "音量设置",
|
||||
"hint": "音量设置,范围为 0.0 到 2.0,默认值为 1.0",
|
||||
},
|
||||
"azure_tts_voice": {
|
||||
"type": "string",
|
||||
"description": "音色设置",
|
||||
"hint": "API 音色"
|
||||
"hint": "API 音色",
|
||||
},
|
||||
"azure_tts_style": {
|
||||
"type": "string",
|
||||
"description": "风格设置",
|
||||
"hint": "声音特定的讲话风格。 可以表达快乐、同情和平静等情绪。"
|
||||
"hint": "声音特定的讲话风格。 可以表达快乐、同情和平静等情绪。",
|
||||
},
|
||||
"azure_tts_role": {
|
||||
"type": "string",
|
||||
"description": "模仿设置(可选)",
|
||||
"hint": "讲话角色扮演。 声音可以模仿不同的年龄和性别,但声音名称不会更改。 例如,男性语音可以提高音调和改变语调来模拟女性语音,但语音名称不会更改。 如果角色缺失或不受声音的支持,则会忽略此属性。",
|
||||
"options": ["Boy","Girl","YoungAdultFemale","YoungAdultMale","OlderAdultFemale","OlderAdultMale","SeniorFemale","SeniorMale","禁用"]
|
||||
"options": [
|
||||
"Boy",
|
||||
"Girl",
|
||||
"YoungAdultFemale",
|
||||
"YoungAdultMale",
|
||||
"OlderAdultFemale",
|
||||
"OlderAdultMale",
|
||||
"SeniorFemale",
|
||||
"SeniorMale",
|
||||
"禁用",
|
||||
],
|
||||
},
|
||||
"azure_tts_rate": {
|
||||
"type": "string",
|
||||
"description": "语速设置",
|
||||
"hint": "指示文本的讲出速率。可在字词或句子层面应用语速。 速率变化应为原始音频的 0.5 到 2 倍。"
|
||||
"hint": "指示文本的讲出速率。可在字词或句子层面应用语速。 速率变化应为原始音频的 0.5 到 2 倍。",
|
||||
},
|
||||
"azure_tts_volume": {
|
||||
"type": "string",
|
||||
"description": "语音音量设置",
|
||||
"hint": "指示语音的音量级别。 可在句子层面应用音量的变化。以从 0.0 到 100.0(从最安静到最大声,例如 75)的数字表示。 默认值为 100.0。"
|
||||
"hint": "指示语音的音量级别。 可在句子层面应用音量的变化。以从 0.0 到 100.0(从最安静到最大声,例如 75)的数字表示。 默认值为 100.0。",
|
||||
},
|
||||
"azure_tts_region": {
|
||||
"type": "string",
|
||||
"description": "API 地区",
|
||||
"hint": "Azure_TTS 处理数据所在区域,具体参考 https://learn.microsoft.com/zh-cn/azure/ai-services/speech-service/regions",
|
||||
"options": ["southafricanorth", "eastasia", "southeastasia", "australiaeast", "centralindia", "japaneast", "japanwest", "koreacentral", "canadacentral", "northeurope", "westeurope", "francecentral", "germanywestcentral", "norwayeast", "swedencentral", "switzerlandnorth", "switzerlandwest", "uksouth", "uaenorth", "brazilsouth", "qatarcentral", "centralus", "eastus", "eastus2", "northcentralus", "southcentralus", "westcentralus", "westus", "westus2", "westus3"]
|
||||
"options": [
|
||||
"southafricanorth",
|
||||
"eastasia",
|
||||
"southeastasia",
|
||||
"australiaeast",
|
||||
"centralindia",
|
||||
"japaneast",
|
||||
"japanwest",
|
||||
"koreacentral",
|
||||
"canadacentral",
|
||||
"northeurope",
|
||||
"westeurope",
|
||||
"francecentral",
|
||||
"germanywestcentral",
|
||||
"norwayeast",
|
||||
"swedencentral",
|
||||
"switzerlandnorth",
|
||||
"switzerlandwest",
|
||||
"uksouth",
|
||||
"uaenorth",
|
||||
"brazilsouth",
|
||||
"qatarcentral",
|
||||
"centralus",
|
||||
"eastus",
|
||||
"eastus2",
|
||||
"northcentralus",
|
||||
"southcentralus",
|
||||
"westcentralus",
|
||||
"westus",
|
||||
"westus2",
|
||||
"westus3",
|
||||
],
|
||||
},
|
||||
"azure_tts_subscription_key": {
|
||||
"type": "string",
|
||||
"description": "服务订阅密钥",
|
||||
"hint": "Azure_TTS 服务的订阅密钥(注意不是令牌)"
|
||||
"hint": "Azure_TTS 服务的订阅密钥(注意不是令牌)",
|
||||
},
|
||||
"dashscope_tts_voice": {
|
||||
"description": "语音合成模型",
|
||||
@@ -902,6 +1039,12 @@ CONFIG_METADATA_2 = {
|
||||
"hint": "启用后所有函数工具将全部失效",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"gm_url_context": {
|
||||
"description": "启用URL上下文功能",
|
||||
"type": "bool",
|
||||
"hint": "启用后所有函数工具将全部失效",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"gm_safety_settings": {
|
||||
"description": "安全过滤器",
|
||||
"type": "object",
|
||||
@@ -973,7 +1116,33 @@ CONFIG_METADATA_2 = {
|
||||
"type": "string",
|
||||
"description": "指定语言/方言",
|
||||
"hint": "增强对指定的小语种和方言的识别能力,设置后可以提升在指定小语种/方言场景下的语音表现",
|
||||
"options": [ "Chinese","Chinese,Yue","English","Arabic","Russian","Spanish","French","Portuguese","German","Turkish","Dutch","Ukrainian","Vietnamese","Indonesian","Japanese","Italian","Korean","Thai","Polish","Romanian","Greek","Czech","Finnish","Hindi","auto",],
|
||||
"options": [
|
||||
"Chinese",
|
||||
"Chinese,Yue",
|
||||
"English",
|
||||
"Arabic",
|
||||
"Russian",
|
||||
"Spanish",
|
||||
"French",
|
||||
"Portuguese",
|
||||
"German",
|
||||
"Turkish",
|
||||
"Dutch",
|
||||
"Ukrainian",
|
||||
"Vietnamese",
|
||||
"Indonesian",
|
||||
"Japanese",
|
||||
"Italian",
|
||||
"Korean",
|
||||
"Thai",
|
||||
"Polish",
|
||||
"Romanian",
|
||||
"Greek",
|
||||
"Czech",
|
||||
"Finnish",
|
||||
"Hindi",
|
||||
"auto",
|
||||
],
|
||||
},
|
||||
"minimax-voice-speed": {
|
||||
"type": "float",
|
||||
@@ -1010,7 +1179,15 @@ CONFIG_METADATA_2 = {
|
||||
"type": "string",
|
||||
"description": "情绪",
|
||||
"hint": "控制合成语音的情绪",
|
||||
"options": ["happy","sad","angry","fearful","disgusted","surprised","neutral",],
|
||||
"options": [
|
||||
"happy",
|
||||
"sad",
|
||||
"angry",
|
||||
"fearful",
|
||||
"disgusted",
|
||||
"surprised",
|
||||
"neutral",
|
||||
],
|
||||
},
|
||||
"minimax-voice-latex": {
|
||||
"type": "bool",
|
||||
@@ -1223,9 +1400,19 @@ CONFIG_METADATA_2 = {
|
||||
"enable": {
|
||||
"description": "启用大语言模型聊天",
|
||||
"type": "bool",
|
||||
"hint": "如需切换大语言模型提供商,请使用 `/provider` 命令。",
|
||||
"hint": "如需切换大语言模型提供商,请使用 /provider 命令。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"separate_provider": {
|
||||
"description": "提供商会话隔离",
|
||||
"type": "bool",
|
||||
"hint": "启用后,每个会话支持独立选择文本生成、STT、TTS 等提供商。如果会话在使用 /provider 指令时提示无权限,可以将会话加入管理员名单或者使用 /alter_cmd provider member 将指令设为非管理员指令。",
|
||||
},
|
||||
"default_provider_id": {
|
||||
"description": "默认模型提供商 ID",
|
||||
"type": "string",
|
||||
"hint": "可选。每个聊天会话的默认提供商 ID。",
|
||||
},
|
||||
"wake_prefix": {
|
||||
"description": "LLM 聊天额外唤醒前缀",
|
||||
"type": "string",
|
||||
@@ -1338,7 +1525,7 @@ CONFIG_METADATA_2 = {
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"provider_id": {
|
||||
"description": "提供商 ID,不填则默认第一个STT提供商",
|
||||
"description": "提供商 ID",
|
||||
"type": "string",
|
||||
"hint": "语音转文本提供商 ID。如果不填写将使用载入的第一个提供商。",
|
||||
},
|
||||
@@ -1355,7 +1542,7 @@ CONFIG_METADATA_2 = {
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"provider_id": {
|
||||
"description": "提供商 ID,不填则默认第一个TTS提供商",
|
||||
"description": "提供商 ID",
|
||||
"type": "string",
|
||||
"hint": "文本转语音提供商 ID。如果不填写将使用载入的第一个提供商。",
|
||||
},
|
||||
@@ -1365,6 +1552,11 @@ CONFIG_METADATA_2 = {
|
||||
"hint": "启用后,Bot 将同时输出语音和文字消息。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"use_file_service": {
|
||||
"description": "使用文件服务提供 TTS 语音文件",
|
||||
"type": "bool",
|
||||
"hint": "启用后,如已配置 callback_api_base ,将会使用文件服务提供TTS语音文件",
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_ltm_settings": {
|
||||
@@ -1481,7 +1673,7 @@ CONFIG_METADATA_2 = {
|
||||
"description": "对外可达的回调接口地址",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "外部服务可能会通过 AstrBot 生成的回调链接(如文件下载链接)访问 AstrBot 后端。由于 AstrBot 无法自动判断部署环境中对外可达的主机地址(host),因此需要通过此配置项显式指定 “外部服务如何访问 AstrBot” 的地址。如 http://localhost:6185,https://example.com 等。"
|
||||
"hint": "外部服务可能会通过 AstrBot 生成的回调链接(如文件下载链接)访问 AstrBot 后端。由于 AstrBot 无法自动判断部署环境中对外可达的主机地址(host),因此需要通过此配置项显式指定 “外部服务如何访问 AstrBot” 的地址。如 http://localhost:6185,https://example.com 等。",
|
||||
},
|
||||
"log_level": {
|
||||
"description": "控制台日志级别",
|
||||
@@ -1500,6 +1692,11 @@ CONFIG_METADATA_2 = {
|
||||
"type": "string",
|
||||
"hint": "当 t2i_strategy 为 remote 时生效。为空时使用 AstrBot API 服务",
|
||||
},
|
||||
"t2i_use_file_service": {
|
||||
"description": "本地文本转图像使用文件服务提供文件",
|
||||
"type": "bool",
|
||||
"hint": "当 t2i_strategy 为 local 并且配置 callback_api_base 时生效。是否使用文件服务提供文件。",
|
||||
},
|
||||
"pip_install_arg": {
|
||||
"description": "pip 安装参数",
|
||||
"type": "string",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
|
||||
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
|
||||
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
|
||||
该类还负责加载和执行插件, 以及处理事件总线的分发。
|
||||
|
||||
工作流程:
|
||||
@@ -28,7 +28,6 @@ from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star_handler import star_map
|
||||
@@ -37,7 +36,7 @@ from astrbot.core.star.star_handler import star_map
|
||||
class AstrBotCoreLifecycle:
|
||||
"""
|
||||
AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
|
||||
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、
|
||||
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、
|
||||
EventBus 等。
|
||||
该类还负责加载和执行插件, 以及处理事件总线的分发。
|
||||
"""
|
||||
@@ -54,7 +53,7 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
async def initialize(self):
|
||||
"""
|
||||
初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
|
||||
初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
|
||||
"""
|
||||
|
||||
# 初始化日志代理
|
||||
@@ -73,9 +72,6 @@ class AstrBotCoreLifecycle:
|
||||
# 初始化平台管理器
|
||||
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
|
||||
|
||||
# 初始化知识库管理器
|
||||
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
|
||||
|
||||
# 初始化对话管理器
|
||||
self.conversation_manager = ConversationManager(self.db)
|
||||
|
||||
@@ -87,7 +83,6 @@ class AstrBotCoreLifecycle:
|
||||
self.provider_manager,
|
||||
self.platform_manager,
|
||||
self.conversation_manager,
|
||||
self.knowledge_db_manager,
|
||||
)
|
||||
|
||||
# 初始化插件管理器
|
||||
|
||||
@@ -1,113 +0,0 @@
|
||||
import json
|
||||
import aiosqlite
|
||||
import os
|
||||
from typing import Any
|
||||
from .plugin_storage import PluginStorage
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
DBPATH = os.path.join(get_astrbot_data_path(), "plugin_data", "sqlite", "plugin_data.db")
|
||||
|
||||
|
||||
class SQLitePluginStorage(PluginStorage):
|
||||
"""插件数据的 SQLite 存储实现类。
|
||||
|
||||
该类提供异步方式将插件数据存储到 SQLite 数据库中,支持数据的增删改查操作。
|
||||
所有数据以 (plugin, key) 作为复合主键进行索引。
|
||||
"""
|
||||
|
||||
_instance = None # Standalone instance of the class
|
||||
_db_conn = None
|
||||
db_path = None
|
||||
|
||||
def __new__(cls):
|
||||
"""
|
||||
创建或获取 SQLitePluginStorage 的单例实例。
|
||||
如果实例已存在,则返回现有实例;否则创建一个新实例。
|
||||
数据在 `data/plugin_data/sqlite/plugin_data.db` 下。
|
||||
"""
|
||||
os.makedirs(os.path.dirname(DBPATH), exist_ok=True)
|
||||
if cls._instance is None:
|
||||
cls._instance = super(SQLitePluginStorage, cls).__new__(cls)
|
||||
cls._instance.db_path = DBPATH
|
||||
return cls._instance
|
||||
|
||||
async def _init_db(self):
|
||||
"""初始化数据库连接(只执行一次)"""
|
||||
if SQLitePluginStorage._db_conn is None:
|
||||
SQLitePluginStorage._db_conn = await aiosqlite.connect(self.db_path)
|
||||
await self._setup_db()
|
||||
|
||||
async def _setup_db(self):
|
||||
"""
|
||||
异步初始化数据库。
|
||||
|
||||
创建插件数据表,如果表不存在则创建,表结构包含 plugin、key 和 value 字段,
|
||||
其中 plugin 和 key 组合作为主键。
|
||||
"""
|
||||
await self._db_conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS plugin_data (
|
||||
plugin TEXT,
|
||||
key TEXT,
|
||||
value TEXT,
|
||||
PRIMARY KEY (plugin, key)
|
||||
)
|
||||
""")
|
||||
await self._db_conn.commit()
|
||||
|
||||
async def set(self, plugin: str, key: str, value: Any):
|
||||
"""
|
||||
异步存储数据。
|
||||
|
||||
将指定插件的键值对存入数据库,如果键已存在则更新值。
|
||||
值会被序列化为 JSON 字符串后存储。
|
||||
|
||||
Args:
|
||||
plugin: 插件标识符
|
||||
key: 数据键名
|
||||
value: 要存储的数据值(任意类型,将被 JSON 序列化)
|
||||
"""
|
||||
await self._init_db()
|
||||
await self._db_conn.execute(
|
||||
"INSERT INTO plugin_data (plugin, key, value) VALUES (?, ?, ?) "
|
||||
"ON CONFLICT(plugin, key) DO UPDATE SET value = excluded.value",
|
||||
(plugin, key, json.dumps(value)),
|
||||
)
|
||||
await self._db_conn.commit()
|
||||
|
||||
async def get(self, plugin: str, key: str) -> Any:
|
||||
"""
|
||||
异步获取数据。
|
||||
|
||||
从数据库中获取指定插件和键名对应的值,
|
||||
返回的值会从 JSON 字符串反序列化为原始数据类型。
|
||||
|
||||
Args:
|
||||
plugin: 插件标识符
|
||||
key: 数据键名
|
||||
|
||||
Returns:
|
||||
Any: 存储的数据值,如果未找到则返回 None
|
||||
"""
|
||||
await self._init_db()
|
||||
async with self._db_conn.execute(
|
||||
"SELECT value FROM plugin_data WHERE plugin = ? AND key = ?",
|
||||
(plugin, key),
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
return json.loads(row[0]) if row else None
|
||||
|
||||
async def delete(self, plugin: str, key: str):
|
||||
"""
|
||||
异步删除数据。
|
||||
|
||||
从数据库中删除指定插件和键名对应的数据项。
|
||||
|
||||
Args:
|
||||
plugin: 插件标识符
|
||||
key: 要删除的数据键名
|
||||
"""
|
||||
await self._init_db()
|
||||
await self._db_conn.execute(
|
||||
"DELETE FROM plugin_data WHERE plugin = ? AND key = ?", (plugin, key)
|
||||
)
|
||||
await self._db_conn.commit()
|
||||
@@ -11,7 +11,9 @@ class SQLiteDatabase(BaseDatabase):
|
||||
super().__init__()
|
||||
self.db_path = db_path
|
||||
|
||||
with open(os.path.dirname(__file__) + "/sqlite_init.sql", "r") as f:
|
||||
with open(
|
||||
os.path.dirname(__file__) + "/sqlite_init.sql", "r", encoding="utf-8"
|
||||
) as f:
|
||||
sql = f.read()
|
||||
|
||||
# 初始化数据库
|
||||
|
||||
46
astrbot/core/db/vec_db/base.py
Normal file
46
astrbot/core/db/vec_db/base.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Result:
|
||||
similarity: float
|
||||
data: dict
|
||||
|
||||
|
||||
class BaseVecDB:
|
||||
async def initialize(self):
|
||||
"""
|
||||
初始化向量数据库
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
|
||||
"""
|
||||
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def retrieve(self, query: str, top_k: int = 5) -> list[Result]:
|
||||
"""
|
||||
搜索最相似的文档。
|
||||
Args:
|
||||
query (str): 查询文本
|
||||
top_k (int): 返回的最相似文档的数量
|
||||
Returns:
|
||||
List[Result]: 查询结果
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete(self, doc_id: str) -> bool:
|
||||
"""
|
||||
删除指定文档。
|
||||
Args:
|
||||
doc_id (str): 要删除的文档 ID
|
||||
Returns:
|
||||
bool: 删除是否成功
|
||||
"""
|
||||
...
|
||||
3
astrbot/core/db/vec_db/faiss_impl/__init__.py
Normal file
3
astrbot/core/db/vec_db/faiss_impl/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .vec_db import FaissVecDB
|
||||
|
||||
__all__ = ["FaissVecDB"]
|
||||
121
astrbot/core/db/vec_db/faiss_impl/document_storage.py
Normal file
121
astrbot/core/db/vec_db/faiss_impl/document_storage.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import aiosqlite
|
||||
import os
|
||||
|
||||
|
||||
class DocumentStorage:
|
||||
def __init__(self, db_path: str):
|
||||
self.db_path = db_path
|
||||
self.connection = None
|
||||
self.sqlite_init_path = os.path.join(
|
||||
os.path.dirname(__file__), "sqlite_init.sql"
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the SQLite database and create the documents table if it doesn't exist."""
|
||||
if not os.path.exists(self.db_path):
|
||||
await self.connect()
|
||||
async with self.connection.cursor() as cursor:
|
||||
with open(self.sqlite_init_path, "r", encoding="utf-8") as f:
|
||||
sql_script = f.read()
|
||||
await cursor.executescript(sql_script)
|
||||
await self.connection.commit()
|
||||
else:
|
||||
await self.connect()
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to the SQLite database."""
|
||||
self.connection = await aiosqlite.connect(self.db_path)
|
||||
|
||||
async def get_documents(self, metadata_filters: dict, ids: list = None):
|
||||
"""Retrieve documents by metadata filters and ids.
|
||||
|
||||
Args:
|
||||
metadata_filters (dict): The metadata filters to apply.
|
||||
|
||||
Returns:
|
||||
list: The list of document IDs(primary key, not doc_id) that match the filters.
|
||||
"""
|
||||
# metadata filter -> SQL WHERE clause
|
||||
where_clauses = []
|
||||
values = []
|
||||
for key, val in metadata_filters.items():
|
||||
where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
|
||||
values.append(val)
|
||||
if ids is not None and len(ids) > 0:
|
||||
ids = [str(i) for i in ids if i != -1]
|
||||
where_clauses.append("id IN ({})".format(",".join("?" * len(ids))))
|
||||
values.extend(ids)
|
||||
where_sql = " AND ".join(where_clauses) or "1=1"
|
||||
|
||||
result = []
|
||||
async with self.connection.cursor() as cursor:
|
||||
sql = "SELECT * FROM documents WHERE " + where_sql
|
||||
await cursor.execute(sql, values)
|
||||
for row in await cursor.fetchall():
|
||||
result.append(await self.tuple_to_dict(row))
|
||||
return result
|
||||
|
||||
async def get_document_by_doc_id(self, doc_id: str):
|
||||
"""Retrieve a document by its doc_id.
|
||||
|
||||
Args:
|
||||
doc_id (str): The doc_id of the document to retrieve.
|
||||
|
||||
Returns:
|
||||
dict: The document data.
|
||||
"""
|
||||
async with self.connection.cursor() as cursor:
|
||||
await cursor.execute("SELECT * FROM documents WHERE doc_id = ?", (doc_id,))
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return await self.tuple_to_dict(row)
|
||||
else:
|
||||
return None
|
||||
|
||||
async def update_document_by_doc_id(self, doc_id: str, new_text: str):
|
||||
"""Retrieve a document by its doc_id.
|
||||
|
||||
Args:
|
||||
doc_id (str): The doc_id.
|
||||
new_text (str): The new text to update the document with.
|
||||
"""
|
||||
async with self.connection.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
"UPDATE documents SET text = ? WHERE doc_id = ?", (new_text, doc_id)
|
||||
)
|
||||
await self.connection.commit()
|
||||
|
||||
async def get_user_ids(self) -> list[str]:
|
||||
"""Retrieve all user IDs from the documents table.
|
||||
|
||||
Returns:
|
||||
list: A list of user IDs.
|
||||
"""
|
||||
async with self.connection.cursor() as cursor:
|
||||
await cursor.execute("SELECT DISTINCT user_id FROM documents")
|
||||
rows = await cursor.fetchall()
|
||||
return [row[0] for row in rows]
|
||||
|
||||
async def tuple_to_dict(self, row):
|
||||
"""Convert a tuple to a dictionary.
|
||||
|
||||
Args:
|
||||
row (tuple): The row to convert.
|
||||
|
||||
Returns:
|
||||
dict: The converted dictionary.
|
||||
"""
|
||||
return {
|
||||
"id": row[0],
|
||||
"doc_id": row[1],
|
||||
"text": row[2],
|
||||
"metadata": row[3],
|
||||
"created_at": row[4],
|
||||
"updated_at": row[5],
|
||||
}
|
||||
|
||||
async def close(self):
|
||||
"""Close the connection to the SQLite database."""
|
||||
if self.connection:
|
||||
await self.connection.close()
|
||||
self.connection = None
|
||||
59
astrbot/core/db/vec_db/faiss_impl/embedding_storage.py
Normal file
59
astrbot/core/db/vec_db/faiss_impl/embedding_storage.py
Normal file
@@ -0,0 +1,59 @@
|
||||
try:
|
||||
import faiss
|
||||
except ModuleNotFoundError:
|
||||
raise ImportError(
|
||||
"faiss 未安装。请使用 'pip install faiss-cpu' 或 'pip install faiss-gpu' 安装。"
|
||||
)
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
class EmbeddingStorage:
|
||||
def __init__(self, dimension: int, path: str = None):
|
||||
self.dimension = dimension
|
||||
self.path = path
|
||||
self.index = None
|
||||
if path and os.path.exists(path):
|
||||
self.index = faiss.read_index(path)
|
||||
else:
|
||||
base_index = faiss.IndexFlatL2(dimension)
|
||||
self.index = faiss.IndexIDMap(base_index)
|
||||
self.storage = {}
|
||||
|
||||
async def insert(self, vector: np.ndarray, id: int):
|
||||
"""插入向量
|
||||
|
||||
Args:
|
||||
vector (np.ndarray): 要插入的向量
|
||||
id (int): 向量的ID
|
||||
Raises:
|
||||
ValueError: 如果向量的维度与存储的维度不匹配
|
||||
"""
|
||||
if vector.shape[0] != self.dimension:
|
||||
raise ValueError(
|
||||
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}"
|
||||
)
|
||||
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
|
||||
self.storage[id] = vector
|
||||
await self.save_index()
|
||||
|
||||
async def search(self, vector: np.ndarray, k: int) -> tuple:
|
||||
"""搜索最相似的向量
|
||||
|
||||
Args:
|
||||
vector (np.ndarray): 查询向量
|
||||
k (int): 返回的最相似向量的数量
|
||||
Returns:
|
||||
tuple: (距离, 索引)
|
||||
"""
|
||||
faiss.normalize_L2(vector)
|
||||
distances, indices = self.index.search(vector, k)
|
||||
return distances, indices
|
||||
|
||||
async def save_index(self):
|
||||
"""保存索引
|
||||
|
||||
Args:
|
||||
path (str): 保存索引的路径
|
||||
"""
|
||||
faiss.write_index(self.index, self.path)
|
||||
17
astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql
Normal file
17
astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql
Normal file
@@ -0,0 +1,17 @@
|
||||
-- 创建文档存储表,包含 faiss 中文档的 id,文档文本,create_at,updated_at
|
||||
CREATE TABLE documents (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
doc_id TEXT NOT NULL,
|
||||
text TEXT NOT NULL,
|
||||
metadata TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
ALTER TABLE documents
|
||||
ADD COLUMN group_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.group_id')) STORED;
|
||||
ALTER TABLE documents
|
||||
ADD COLUMN user_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED;
|
||||
|
||||
CREATE INDEX idx_documents_user_id ON documents(user_id);
|
||||
CREATE INDEX idx_documents_group_id ON documents(group_id);
|
||||
117
astrbot/core/db/vec_db/faiss_impl/vec_db.py
Normal file
117
astrbot/core/db/vec_db/faiss_impl/vec_db.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import uuid
|
||||
import json
|
||||
import numpy as np
|
||||
from .document_storage import DocumentStorage
|
||||
from .embedding_storage import EmbeddingStorage
|
||||
from ..base import Result, BaseVecDB
|
||||
from astrbot.core.provider.provider import EmbeddingProvider
|
||||
|
||||
|
||||
class FaissVecDB(BaseVecDB):
|
||||
"""
|
||||
A class to represent a vector database.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
doc_store_path: str,
|
||||
index_store_path: str,
|
||||
embedding_provider: EmbeddingProvider,
|
||||
):
|
||||
self.doc_store_path = doc_store_path
|
||||
self.index_store_path = index_store_path
|
||||
self.embedding_provider = embedding_provider
|
||||
self.document_storage = DocumentStorage(doc_store_path)
|
||||
self.embedding_storage = EmbeddingStorage(
|
||||
embedding_provider.get_dim(), index_store_path
|
||||
)
|
||||
self.embedding_provider = embedding_provider
|
||||
|
||||
async def initialize(self):
|
||||
await self.document_storage.initialize()
|
||||
|
||||
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
|
||||
"""
|
||||
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
|
||||
"""
|
||||
metadata = metadata or {}
|
||||
str_id = id or str(uuid.uuid4()) # 使用 UUID 作为原始 ID
|
||||
|
||||
vector = await self.embedding_provider.get_embedding(content)
|
||||
vector = np.array(vector, dtype=np.float32)
|
||||
async with self.document_storage.connection.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
"INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)",
|
||||
(str_id, content, json.dumps(metadata)),
|
||||
)
|
||||
await self.document_storage.connection.commit()
|
||||
result = await self.document_storage.get_document_by_doc_id(str_id)
|
||||
int_id = result["id"]
|
||||
|
||||
# 插入向量到 FAISS
|
||||
await self.embedding_storage.insert(vector, int_id)
|
||||
return int_id
|
||||
|
||||
async def retrieve(
|
||||
self, query: str, k: int = 5, fetch_k: int = 20, metadata_filters: dict = None
|
||||
) -> list[Result]:
|
||||
"""
|
||||
搜索最相似的文档。
|
||||
|
||||
Args:
|
||||
query (str): 查询文本
|
||||
k (int): 返回的最相似文档的数量
|
||||
fetch_k (int): 在根据 metadata 过滤前从 FAISS 中获取的数量
|
||||
metadata_filters (dict): 元数据过滤器
|
||||
|
||||
Returns:
|
||||
List[Result]: 查询结果
|
||||
"""
|
||||
embedding = await self.embedding_provider.get_embedding(query)
|
||||
scores, indices = await self.embedding_storage.search(
|
||||
vector=np.array([embedding]).astype("float32"),
|
||||
k=fetch_k if metadata_filters else k,
|
||||
)
|
||||
# TODO: rerank
|
||||
if len(indices[0]) == 0 or indices[0][0] == -1:
|
||||
return []
|
||||
# normalize scores
|
||||
scores[0] = 1.0 - (scores[0] / 2.0)
|
||||
# NOTE: maybe the size is less than k.
|
||||
fetched_docs = await self.document_storage.get_documents(
|
||||
metadata_filters=metadata_filters or {}, ids=indices[0]
|
||||
)
|
||||
if not fetched_docs:
|
||||
return []
|
||||
result_docs = []
|
||||
|
||||
idx_pos = {fetch_doc["id"]: idx for idx, fetch_doc in enumerate(fetched_docs)}
|
||||
for i, indice_idx in enumerate(indices[0]):
|
||||
pos = idx_pos.get(indice_idx)
|
||||
if pos is None:
|
||||
continue
|
||||
fetch_doc = fetched_docs[pos]
|
||||
score = scores[0][i]
|
||||
result_docs.append(Result(similarity=float(score), data=fetch_doc))
|
||||
return result_docs[:k]
|
||||
|
||||
async def delete(self, doc_id: int):
|
||||
"""
|
||||
删除一条文档
|
||||
"""
|
||||
await self.document_storage.connection.execute(
|
||||
"DELETE FROM documents WHERE doc_id = ?", (doc_id,)
|
||||
)
|
||||
await self.document_storage.connection.commit()
|
||||
|
||||
async def close(self):
|
||||
await self.document_storage.close()
|
||||
|
||||
async def count_documents(self) -> int:
|
||||
"""
|
||||
计算文档数量
|
||||
"""
|
||||
async with self.document_storage.connection.cursor() as cursor:
|
||||
await cursor.execute("SELECT COUNT(*) FROM documents")
|
||||
count = await cursor.fetchone()
|
||||
return count[0] if count else 0
|
||||
@@ -26,13 +26,14 @@ class InitialLoader:
|
||||
async def start(self):
|
||||
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
|
||||
|
||||
core_task = []
|
||||
try:
|
||||
await core_lifecycle.initialize()
|
||||
core_task = core_lifecycle.start()
|
||||
except Exception as e:
|
||||
logger.critical(traceback.format_exc())
|
||||
logger.critical(f"😭 初始化 AstrBot 失败:{e} !!!")
|
||||
return
|
||||
|
||||
core_task = core_lifecycle.start()
|
||||
|
||||
self.dashboard_server = AstrBotDashboard(
|
||||
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event
|
||||
|
||||
@@ -102,6 +102,10 @@ class BaseMessageComponent(BaseModel):
|
||||
data[k] = v
|
||||
return {"type": self.type.lower(), "data": data}
|
||||
|
||||
async def to_dict(self) -> dict:
|
||||
# 默认情况下,回退到旧的同步 toDict()
|
||||
return self.toDict()
|
||||
|
||||
|
||||
class Plain(BaseMessageComponent):
|
||||
type: ComponentType = "Plain"
|
||||
@@ -118,6 +122,9 @@ class Plain(BaseMessageComponent):
|
||||
self.text.replace("&", "&").replace("[", "[").replace("]", "]")
|
||||
)
|
||||
|
||||
def toDict(self):
|
||||
return {"type": "text", "data": {"text": self.text.strip()}}
|
||||
|
||||
|
||||
class Face(BaseMessageComponent):
|
||||
type: ComponentType = "Face"
|
||||
@@ -235,9 +242,6 @@ class Video(BaseMessageComponent):
|
||||
path: T.Optional[str] = ""
|
||||
|
||||
def __init__(self, file: str, **_):
|
||||
# for k in _.keys():
|
||||
# if k == "c" and _[k] not in [2, 3]:
|
||||
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
|
||||
super().__init__(file=file, **_)
|
||||
|
||||
@staticmethod
|
||||
@@ -250,6 +254,70 @@ class Video(BaseMessageComponent):
|
||||
return Video(file=url, **_)
|
||||
raise Exception("not a valid url")
|
||||
|
||||
async def convert_to_file_path(self) -> str:
|
||||
"""将这个视频统一转换为本地文件路径。这个方法避免了手动判断视频数据类型,直接返回视频数据的本地路径(如果是网络 URL,则会自动进行下载)。
|
||||
|
||||
Returns:
|
||||
str: 视频的本地路径,以绝对路径表示。
|
||||
"""
|
||||
url = self.file
|
||||
if url and url.startswith("file:///"):
|
||||
return url[8:]
|
||||
elif url and url.startswith("http"):
|
||||
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
video_file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
|
||||
await download_file(url, video_file_path)
|
||||
if os.path.exists(video_file_path):
|
||||
return os.path.abspath(video_file_path)
|
||||
else:
|
||||
raise Exception(f"download failed: {url}")
|
||||
elif os.path.exists(url):
|
||||
return os.path.abspath(url)
|
||||
else:
|
||||
raise Exception(f"not a valid file: {url}")
|
||||
|
||||
async def register_to_file_service(self):
|
||||
"""
|
||||
将视频注册到文件服务。
|
||||
|
||||
Returns:
|
||||
str: 注册后的URL
|
||||
|
||||
Raises:
|
||||
Exception: 如果未配置 callback_api_base
|
||||
"""
|
||||
callback_host = astrbot_config.get("callback_api_base")
|
||||
|
||||
if not callback_host:
|
||||
raise Exception("未配置 callback_api_base,文件服务不可用")
|
||||
|
||||
file_path = await self.convert_to_file_path()
|
||||
|
||||
token = await file_token_service.register_file(file_path)
|
||||
|
||||
logger.debug(f"已注册:{callback_host}/api/file/{token}")
|
||||
|
||||
return f"{callback_host}/api/file/{token}"
|
||||
|
||||
async def to_dict(self):
|
||||
"""需要和 toDict 区分开,toDict 是同步方法"""
|
||||
url_or_path = self.file
|
||||
if url_or_path.startswith("http"):
|
||||
payload_file = url_or_path
|
||||
elif callback_host := astrbot_config.get("callback_api_base"):
|
||||
callback_host = str(callback_host).removesuffix("/")
|
||||
token = await file_token_service.register_file(url_or_path)
|
||||
payload_file = f"{callback_host}/api/file/{token}"
|
||||
logger.debug(f"Generated video file callback link: {payload_file}")
|
||||
else:
|
||||
payload_file = url_or_path
|
||||
return {
|
||||
"type": "video",
|
||||
"data": {
|
||||
"file": payload_file,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class At(BaseMessageComponent):
|
||||
type: ComponentType = "At"
|
||||
@@ -259,6 +327,12 @@ class At(BaseMessageComponent):
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
def toDict(self):
|
||||
return {
|
||||
"type": "at",
|
||||
"data": {"qq": str(self.qq)},
|
||||
}
|
||||
|
||||
|
||||
class AtAll(At):
|
||||
qq: str = "all"
|
||||
@@ -514,27 +588,47 @@ class Node(BaseMessageComponent):
|
||||
id: T.Optional[int] = 0 # 忽略
|
||||
name: T.Optional[str] = "" # qq昵称
|
||||
uin: T.Optional[str] = "0" # qq号
|
||||
content: T.Optional[T.Union[str, list, dict]] = "" # 子消息段列表
|
||||
content: T.Optional[list[BaseMessageComponent]] = []
|
||||
seq: T.Optional[T.Union[str, list]] = "" # 忽略
|
||||
time: T.Optional[int] = 0 # 忽略
|
||||
|
||||
def __init__(self, content: T.Union[str, list, dict, "Node", T.List["Node"]], **_):
|
||||
if isinstance(content, list):
|
||||
_content = None
|
||||
if all(isinstance(item, Node) for item in content):
|
||||
_content = [node.toDict() for node in content]
|
||||
else:
|
||||
_content = ""
|
||||
for chain in content:
|
||||
_content += chain.toString()
|
||||
content = _content
|
||||
elif isinstance(content, Node):
|
||||
content = content.toDict()
|
||||
def __init__(self, content: list[BaseMessageComponent], **_):
|
||||
if isinstance(content, Node):
|
||||
# back
|
||||
content = [content]
|
||||
super().__init__(content=content, **_)
|
||||
|
||||
def toString(self):
|
||||
# logger.warn("Protocol: node doesn't support stringify")
|
||||
return ""
|
||||
async def to_dict(self):
|
||||
data_content = []
|
||||
for comp in self.content:
|
||||
if isinstance(comp, (Image, Record)):
|
||||
# For Image and Record segments, we convert them to base64
|
||||
bs64 = await comp.convert_to_base64()
|
||||
data_content.append(
|
||||
{
|
||||
"type": comp.type.lower(),
|
||||
"data": {"file": f"base64://{bs64}"},
|
||||
}
|
||||
)
|
||||
elif isinstance(comp, File):
|
||||
# For File segments, we need to handle the file differently
|
||||
d = await comp.to_dict()
|
||||
data_content.append(d)
|
||||
elif isinstance(comp, (Node, Nodes)):
|
||||
# For Node segments, we recursively convert them to dict
|
||||
d = await comp.to_dict()
|
||||
data_content.append(d)
|
||||
else:
|
||||
d = comp.toDict()
|
||||
data_content.append(d)
|
||||
return {
|
||||
"type": "node",
|
||||
"data": {
|
||||
"user_id": str(self.uin),
|
||||
"nickname": self.name,
|
||||
"content": data_content,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class Nodes(BaseMessageComponent):
|
||||
@@ -545,12 +639,20 @@ class Nodes(BaseMessageComponent):
|
||||
super().__init__(nodes=nodes, **_)
|
||||
|
||||
def toDict(self):
|
||||
"""Deprecated. Use to_dict instead"""
|
||||
ret = {
|
||||
"messages": [],
|
||||
}
|
||||
for node in self.nodes:
|
||||
d = node.toDict()
|
||||
d["data"]["uin"] = str(node.uin) # 转为字符串
|
||||
ret["messages"].append(d)
|
||||
return ret
|
||||
|
||||
async def to_dict(self):
|
||||
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
|
||||
ret = {"messages": []}
|
||||
for node in self.nodes:
|
||||
d = await node.to_dict()
|
||||
ret["messages"].append(d)
|
||||
return ret
|
||||
|
||||
@@ -723,6 +825,26 @@ class File(BaseMessageComponent):
|
||||
|
||||
return f"{callback_host}/api/file/{token}"
|
||||
|
||||
async def to_dict(self):
|
||||
"""需要和 toDict 区分开,toDict 是同步方法"""
|
||||
url_or_path = await self.get_file(allow_return_url=True)
|
||||
if url_or_path.startswith("http"):
|
||||
payload_file = url_or_path
|
||||
elif callback_host := astrbot_config.get("callback_api_base"):
|
||||
callback_host = str(callback_host).removesuffix("/")
|
||||
token = await file_token_service.register_file(url_or_path)
|
||||
payload_file = f"{callback_host}/api/file/{token}"
|
||||
logger.debug(f"Generated file callback link: {payload_file}")
|
||||
else:
|
||||
payload_file = url_or_path
|
||||
return {
|
||||
"type": "file",
|
||||
"data": {
|
||||
"name": self.name,
|
||||
"file": payload_file,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class WechatEmoji(BaseMessageComponent):
|
||||
type: ComponentType = "WechatEmoji"
|
||||
|
||||
@@ -43,31 +43,31 @@ class PreProcessStage(Stage):
|
||||
# STT
|
||||
if self.stt_settings.get("enable", False):
|
||||
# TODO: 独立
|
||||
stt_provider = (
|
||||
self.plugin_manager.context.provider_manager.curr_stt_provider_inst
|
||||
)
|
||||
if stt_provider:
|
||||
message_chain = event.get_messages()
|
||||
for idx, component in enumerate(message_chain):
|
||||
if isinstance(component, Record) and component.url:
|
||||
path = component.url.removeprefix("file://")
|
||||
retry = 5
|
||||
for i in range(retry):
|
||||
try:
|
||||
result = await stt_provider.get_text(audio_url=path)
|
||||
if result:
|
||||
logger.info("语音转文本结果: " + result)
|
||||
message_chain[idx] = Plain(result)
|
||||
event.message_str += result
|
||||
event.message_obj.message_str += result
|
||||
break
|
||||
except FileNotFoundError as e:
|
||||
# napcat workaround
|
||||
logger.warning(e)
|
||||
logger.warning(f"重试中: {i + 1}/{retry}")
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"语音转文本失败: {e}")
|
||||
break
|
||||
ctx = self.plugin_manager.context
|
||||
stt_provider = ctx.get_using_stt_provider(event.unified_msg_origin)
|
||||
if not stt_provider:
|
||||
return
|
||||
message_chain = event.get_messages()
|
||||
for idx, component in enumerate(message_chain):
|
||||
if isinstance(component, Record) and component.url:
|
||||
path = component.url.removeprefix("file://")
|
||||
retry = 5
|
||||
for i in range(retry):
|
||||
try:
|
||||
result = await stt_provider.get_text(audio_url=path)
|
||||
if result:
|
||||
logger.info("语音转文本结果: " + result)
|
||||
message_chain[idx] = Plain(result)
|
||||
event.message_str += result
|
||||
event.message_obj.message_str += result
|
||||
break
|
||||
except FileNotFoundError as e:
|
||||
# napcat workaround
|
||||
logger.warning(e)
|
||||
logger.warning(f"重试中: {i + 1}/{retry}")
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"语音转文本失败: {e}")
|
||||
break
|
||||
|
||||
@@ -33,6 +33,7 @@ from mcp.types import (
|
||||
TextResourceContents,
|
||||
BlobResourceContents,
|
||||
)
|
||||
from astrbot.core import web_chat_back_queue
|
||||
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
@@ -67,7 +68,11 @@ class LLMRequestSubStage(Stage):
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
req: ProviderRequest = None
|
||||
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
||||
logger.debug("未启用 LLM 能力,跳过处理。")
|
||||
return
|
||||
umo = event.unified_msg_origin
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider(umo=umo)
|
||||
if provider is None:
|
||||
return
|
||||
|
||||
@@ -283,7 +288,66 @@ class LLMRequestSubStage(Stage):
|
||||
if img_b64 := event.get_extra("tool_call_img_respond"):
|
||||
await event.send(MessageChain(chain=[Image.fromBase64(img_b64)]))
|
||||
event.set_extra("tool_call_img_respond", None)
|
||||
yield
|
||||
|
||||
if event.get_platform_name() == "webchat":
|
||||
# 异步处理 WebChat 特殊情况
|
||||
asyncio.create_task(self._handle_webchat(event, req))
|
||||
|
||||
async def _handle_webchat(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
|
||||
conversation = await self.conv_manager.get_conversation(
|
||||
event.unified_msg_origin, req.conversation.cid
|
||||
)
|
||||
if conversation and not req.conversation.title:
|
||||
messages = json.loads(conversation.history)
|
||||
latest_pair = messages[-2:]
|
||||
if not latest_pair:
|
||||
return
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
cleaned_text = "User: " + latest_pair[0].get("content", "").strip()
|
||||
# if len(latest_pair) > 1:
|
||||
# cleaned_text += (
|
||||
# "\nAssistant: " + latest_pair[1].get("content", "").strip()
|
||||
# )
|
||||
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
|
||||
llm_resp = await provider.text_chat(
|
||||
system_prompt="You are expert in summarizing user's query.",
|
||||
prompt=(
|
||||
f"Please summarize the following query of user:\n"
|
||||
f"{cleaned_text}\n"
|
||||
"Only output the summary within 10 words, DO NOT INCLUDE any other text."
|
||||
"You must use the same language as the user."
|
||||
"If you think the dialog is too short to summarize, only output a special mark: `None`"
|
||||
),
|
||||
)
|
||||
if llm_resp and llm_resp.completion_text:
|
||||
logger.debug(
|
||||
f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}"
|
||||
)
|
||||
title = llm_resp.completion_text.strip()
|
||||
if not title or "None" == title:
|
||||
return
|
||||
await self.conv_manager.update_conversation_title(
|
||||
event.unified_msg_origin, title=title
|
||||
)
|
||||
# 由于 WebChat 平台特殊性,其有两个对话,因此我们要更新两个对话的标题
|
||||
# webchat adapter 中,session_id 的格式是 f"webchat!{username}!{cid}"
|
||||
# TODO: 优化 WebChat 适配器的对话管理
|
||||
if event.session_id:
|
||||
username, cid = event.session_id.split("!")[1:3]
|
||||
db_helper = self.ctx.plugin_manager.context._db
|
||||
db_helper.update_conversation_title(
|
||||
user_id=username,
|
||||
cid=cid,
|
||||
title=title,
|
||||
)
|
||||
web_chat_back_queue.put_nowait(
|
||||
{
|
||||
"type": "update_title",
|
||||
"cid": cid,
|
||||
"data": title,
|
||||
}
|
||||
)
|
||||
|
||||
async def _handle_llm_response(
|
||||
self,
|
||||
|
||||
@@ -29,11 +29,10 @@ class RespondStage(Stage):
|
||||
Comp.Image: lambda comp: bool(comp.file), # 图片
|
||||
Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复
|
||||
Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳
|
||||
Comp.Node: lambda comp: bool(comp.name)
|
||||
and comp.uin != 0
|
||||
and bool(comp.content), # 一个转发节点
|
||||
Comp.Node: lambda comp: bool(comp.content), # 转发节点
|
||||
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
|
||||
Comp.File: lambda comp: bool(comp.file_ or comp.url),
|
||||
Comp.WechatEmoji: lambda comp: comp.md5 is not None, # 微信表情
|
||||
}
|
||||
|
||||
async def initialize(self, ctx: PipelineContext):
|
||||
@@ -192,6 +191,7 @@ class RespondStage(Stage):
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
await event.send(MessageChain([*decorated_comps, comp]))
|
||||
decorated_comps = [] # 清空已发送的装饰组件
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
import time
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import Stage, register_stage, registered_stages
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from typing import AsyncGenerator, Union
|
||||
|
||||
from astrbot.core import html_renderer, logger, file_token_service
|
||||
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
|
||||
from astrbot.core.message.message_event_result import ResultContentType
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
|
||||
from ..context import PipelineContext
|
||||
from ..stage import Stage, register_stage, registered_stages
|
||||
|
||||
|
||||
@register_stage
|
||||
@@ -168,30 +169,55 @@ class ResultDecorateStage(Stage):
|
||||
result.chain = new_chain
|
||||
|
||||
# TTS
|
||||
tts_provider = self.ctx.plugin_manager.context.get_using_tts_provider(
|
||||
event.unified_msg_origin
|
||||
)
|
||||
if (
|
||||
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
|
||||
and result.is_llm_result()
|
||||
and tts_provider
|
||||
):
|
||||
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||
try:
|
||||
logger.info("TTS 请求: " + comp.text)
|
||||
logger.info(f"TTS 请求: {comp.text}")
|
||||
audio_path = await tts_provider.get_audio(comp.text)
|
||||
logger.info("TTS 结果: " + audio_path)
|
||||
if audio_path:
|
||||
new_chain.append(
|
||||
Record(file=audio_path, url=audio_path)
|
||||
)
|
||||
if(self.ctx.astrbot_config["provider_tts_settings"]["dual_output"]):
|
||||
new_chain.append(comp)
|
||||
else:
|
||||
logger.info(f"TTS 结果: {audio_path}")
|
||||
if not audio_path:
|
||||
logger.error(
|
||||
f"由于 TTS 音频文件没找到,消息段转语音失败: {comp.text}"
|
||||
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}"
|
||||
)
|
||||
new_chain.append(comp)
|
||||
except BaseException:
|
||||
continue
|
||||
|
||||
use_file_service = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["use_file_service"]
|
||||
callback_api_base = self.ctx.astrbot_config[
|
||||
"callback_api_base"
|
||||
]
|
||||
dual_output = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["dual_output"]
|
||||
|
||||
url = None
|
||||
if use_file_service and callback_api_base:
|
||||
token = await file_token_service.register_file(
|
||||
audio_path
|
||||
)
|
||||
url = f"{callback_api_base}/api/file/{token}"
|
||||
logger.debug(f"已注册:{url}")
|
||||
|
||||
new_chain.append(
|
||||
Record(
|
||||
file=url or audio_path,
|
||||
url=url or audio_path,
|
||||
)
|
||||
)
|
||||
if dual_output:
|
||||
new_chain.append(comp)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
new_chain.append(comp)
|
||||
@@ -225,6 +251,14 @@ class ResultDecorateStage(Stage):
|
||||
if url:
|
||||
if url.startswith("http"):
|
||||
result.chain = [Image.fromURL(url)]
|
||||
elif (
|
||||
self.ctx.astrbot_config["t2i_use_file_service"]
|
||||
and self.ctx.astrbot_config["callback_api_base"]
|
||||
):
|
||||
token = await file_token_service.register_file(url)
|
||||
url = f"{self.ctx.astrbot_config['callback_api_base']}/api/file/{token}"
|
||||
logger.debug(f"已注册:{url}")
|
||||
result.chain = [Image.fromURL(url)]
|
||||
else:
|
||||
result.chain = [Image.fromFileSystem(url)]
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from astrbot import logger
|
||||
from typing import Union, AsyncGenerator
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
||||
from astrbot.core.message.components import At
|
||||
from astrbot.core.message.components import At, AtAll
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
@@ -39,6 +39,9 @@ class WakingCheckStage(Stage):
|
||||
self.ignore_bot_self_message = self.ctx.astrbot_config["platform_settings"].get(
|
||||
"ignore_bot_self_message", False
|
||||
)
|
||||
self.ignore_at_all = self.ctx.astrbot_config["platform_settings"].get(
|
||||
"ignore_at_all", False
|
||||
)
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
@@ -79,10 +82,9 @@ class WakingCheckStage(Stage):
|
||||
if not is_wake:
|
||||
# 检查是否有 at 消息
|
||||
for message in messages:
|
||||
if isinstance(message, At) and (
|
||||
if (isinstance(message, At) and (
|
||||
str(message.qq) == str(event.get_self_id())
|
||||
or str(message.qq) == "all"
|
||||
):
|
||||
)) or (isinstance(message, AtAll) and not self.ignore_at_all):
|
||||
is_wake = True
|
||||
event.is_wake = True
|
||||
wake_prefix = ""
|
||||
|
||||
@@ -3,9 +3,17 @@ import re
|
||||
from typing import AsyncGenerator, Dict, List
|
||||
from aiocqhttp import CQHttp
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import At, Image, Node, Nodes, Plain, Record, File
|
||||
from astrbot.api.message_components import (
|
||||
Image,
|
||||
Node,
|
||||
Nodes,
|
||||
Plain,
|
||||
Record,
|
||||
Video,
|
||||
File,
|
||||
BaseMessageComponent,
|
||||
)
|
||||
from astrbot.api.platform import Group, MessageMember
|
||||
from astrbot.core import file_token_service, astrbot_config, logger
|
||||
|
||||
|
||||
class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
@@ -15,28 +23,38 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.bot = bot
|
||||
|
||||
@staticmethod
|
||||
async def _from_segment_to_dict(segment: BaseMessageComponent) -> dict:
|
||||
"""修复部分字段"""
|
||||
if isinstance(segment, (Image, Record)):
|
||||
# For Image and Record segments, we convert them to base64
|
||||
bs64 = await segment.convert_to_base64()
|
||||
return {
|
||||
"type": segment.type.lower(),
|
||||
"data": {
|
||||
"file": f"base64://{bs64}",
|
||||
},
|
||||
}
|
||||
elif isinstance(segment, File):
|
||||
# For File segments, we need to handle the file differently
|
||||
d = await segment.to_dict()
|
||||
return d
|
||||
elif isinstance(segment, Video):
|
||||
d = await segment.to_dict()
|
||||
return d
|
||||
else:
|
||||
# For other segments, we simply convert them to a dict by calling toDict
|
||||
return segment.toDict()
|
||||
|
||||
@staticmethod
|
||||
async def _parse_onebot_json(message_chain: MessageChain):
|
||||
"""解析成 OneBot json 格式"""
|
||||
ret = []
|
||||
for segment in message_chain.chain:
|
||||
d = segment.toDict()
|
||||
if isinstance(segment, Plain):
|
||||
d["type"] = "text"
|
||||
d["data"]["text"] = segment.text.strip()
|
||||
# 如果是空文本或者只带换行符的文本,不发送
|
||||
if not d["data"]["text"]:
|
||||
if not segment.text.strip():
|
||||
continue
|
||||
elif isinstance(segment, (Image, Record)):
|
||||
# convert to base64
|
||||
bs64 = await segment.convert_to_base64()
|
||||
d["data"] = {
|
||||
"file": f"base64://{bs64}",
|
||||
}
|
||||
elif isinstance(segment, At):
|
||||
d["data"] = {
|
||||
"qq": str(segment.qq), # 转换为字符串
|
||||
}
|
||||
d = await AiocqhttpMessageEvent._from_segment_to_dict(segment)
|
||||
ret.append(d)
|
||||
return ret
|
||||
|
||||
@@ -54,7 +72,8 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
nodes = Nodes([seg])
|
||||
seg = nodes
|
||||
|
||||
payload = seg.toDict()
|
||||
payload = await seg.to_dict()
|
||||
|
||||
if self.get_group_id():
|
||||
payload["group_id"] = self.get_group_id()
|
||||
await self.bot.call_action("send_group_forward_msg", **payload)
|
||||
@@ -64,21 +83,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
"send_private_forward_msg", **payload
|
||||
)
|
||||
elif isinstance(seg, File):
|
||||
d = seg.toDict()
|
||||
url_or_path = await seg.get_file(allow_return_url=True)
|
||||
if url_or_path.startswith("http"):
|
||||
payload_file = url_or_path
|
||||
elif callback_host := astrbot_config.get("callback_api_base"):
|
||||
callback_host = str(callback_host).removesuffix("/")
|
||||
token = await file_token_service.register_file(url_or_path)
|
||||
payload_file = f"{callback_host}/api/file/{token}"
|
||||
logger.debug(f"Generated file callback link: {payload_file}")
|
||||
else:
|
||||
payload_file = url_or_path
|
||||
d["data"] = {
|
||||
"name": seg.name,
|
||||
"file": payload_file,
|
||||
}
|
||||
d = await AiocqhttpMessageEvent._from_segment_to_dict(seg)
|
||||
await self.bot.send(
|
||||
self.message_obj.raw_message,
|
||||
[d],
|
||||
|
||||
@@ -221,6 +221,9 @@ class AiocqhttpAdapter(Platform):
|
||||
a = None
|
||||
if t == "text":
|
||||
current_text = "".join(m["data"]["text"] for m in m_group).strip()
|
||||
if not current_text:
|
||||
# 如果文本段为空,则跳过
|
||||
continue
|
||||
message_str += current_text
|
||||
a = ComponentTypes[t](text=current_text) # noqa: F405
|
||||
abm.message.append(a)
|
||||
|
||||
@@ -144,8 +144,8 @@ class TelegramPlatformAdapter(Platform):
|
||||
command_dict = {}
|
||||
skip_commands = {"start"}
|
||||
|
||||
for handler_md in star_handlers_registry._handlers:
|
||||
handler_metadata = handler_md[1]
|
||||
for handler_md in star_handlers_registry:
|
||||
handler_metadata = handler_md
|
||||
if not star_map[handler_metadata.handler_module_path].activated:
|
||||
continue
|
||||
for event_filter in handler_metadata.event_filters:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import re
|
||||
import asyncio
|
||||
import telegramify_markdown
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
@@ -18,6 +19,16 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class TelegramPlatformEvent(AstrMessageEvent):
|
||||
# Telegram 的最大消息长度限制
|
||||
MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
SPLIT_PATTERNS = {
|
||||
"paragraph": re.compile(r"\n\n"),
|
||||
"line": re.compile(r"\n"),
|
||||
"sentence": re.compile(r"[.!?。!?]"),
|
||||
"word": re.compile(r"\s"),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
@@ -29,8 +40,33 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
|
||||
@staticmethod
|
||||
async def send_with_client(client: ExtBot, message: MessageChain, user_name: str):
|
||||
def _split_message(self, text: str) -> list[str]:
|
||||
if len(text) <= self.MAX_MESSAGE_LENGTH:
|
||||
return [text]
|
||||
|
||||
chunks = []
|
||||
while text:
|
||||
if len(text) <= self.MAX_MESSAGE_LENGTH:
|
||||
chunks.append(text)
|
||||
break
|
||||
|
||||
split_point = self.MAX_MESSAGE_LENGTH
|
||||
segment = text[: self.MAX_MESSAGE_LENGTH]
|
||||
|
||||
for _, pattern in self.SPLIT_PATTERNS.items():
|
||||
if matches := list(pattern.finditer(segment)):
|
||||
last_match = matches[-1]
|
||||
split_point = last_match.end()
|
||||
break
|
||||
|
||||
chunks.append(text[:split_point])
|
||||
text = text[split_point:].lstrip()
|
||||
|
||||
return chunks
|
||||
|
||||
async def send_with_client(
|
||||
self, client: ExtBot, message: MessageChain, user_name: str
|
||||
):
|
||||
image_path = None
|
||||
|
||||
has_reply = False
|
||||
@@ -59,19 +95,22 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
|
||||
if isinstance(i, Plain):
|
||||
if at_user_id and not at_flag:
|
||||
i.text = f"@{at_user_id} " + i.text
|
||||
i.text = f"@{at_user_id} {i.text}"
|
||||
at_flag = True
|
||||
text = i.text
|
||||
try:
|
||||
text = telegramify_markdown.markdownify(
|
||||
i.text, max_line_length=None, normalize_whitespace=False
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"MarkdownV2 conversion failed: {e}. Using plain text instead."
|
||||
)
|
||||
return
|
||||
await client.send_message(text=text, parse_mode="MarkdownV2", **payload)
|
||||
chunks = self._split_message(i.text)
|
||||
for chunk in chunks:
|
||||
try:
|
||||
md_text = telegramify_markdown.markdownify(
|
||||
chunk, max_line_length=None, normalize_whitespace=False
|
||||
)
|
||||
await client.send_message(
|
||||
text=md_text, parse_mode="MarkdownV2", **payload
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"MarkdownV2 send failed: {e}. Using plain text instead."
|
||||
)
|
||||
await client.send_message(text=chunk, **payload)
|
||||
elif isinstance(i, Image):
|
||||
image_path = await i.convert_to_file_path()
|
||||
await client.send_photo(photo=image_path, **payload)
|
||||
@@ -147,17 +186,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
continue
|
||||
|
||||
# Plain
|
||||
if not message_id:
|
||||
try:
|
||||
msg = await self.client.send_message(text=delta, **payload)
|
||||
current_content = delta
|
||||
except Exception as e:
|
||||
logger.warning(f"发送消息失败(streaming): {e!s}")
|
||||
message_id = msg.message_id
|
||||
last_edit_time = (
|
||||
asyncio.get_event_loop().time()
|
||||
) # 记录初始消息发送时间
|
||||
else:
|
||||
if message_id and len(delta) <= self.MAX_MESSAGE_LENGTH:
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
time_since_last_edit = current_time - last_edit_time
|
||||
|
||||
@@ -176,6 +205,18 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
last_edit_time = (
|
||||
asyncio.get_event_loop().time()
|
||||
) # 更新上次编辑的时间
|
||||
else:
|
||||
# delta 长度一般不会大于 4096,因此这里直接发送
|
||||
try:
|
||||
msg = await self.client.send_message(text=delta, **payload)
|
||||
current_content = delta
|
||||
delta = ""
|
||||
except Exception as e:
|
||||
logger.warning(f"发送消息失败(streaming): {e!s}")
|
||||
message_id = msg.message_id
|
||||
last_edit_time = (
|
||||
asyncio.get_event_loop().time()
|
||||
) # 记录初始消息发送时间
|
||||
|
||||
try:
|
||||
if delta and current_content != delta:
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
import anyio
|
||||
import websockets
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from astrbot.api.message_components import Plain, Image, At, Record
|
||||
from astrbot.api.platform import Platform, PlatformMetadata
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.platform.astrbot_message import (
|
||||
@@ -22,6 +23,13 @@ from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
from .wechatpadpro_message_event import WeChatPadProMessageEvent
|
||||
|
||||
try:
|
||||
from .xml_data_parser import GeweDataParser
|
||||
except ImportError as e:
|
||||
logger.warning(
|
||||
f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@register_platform_adapter("wechatpadpro", "WeChatPadPro 消息平台适配器")
|
||||
class WeChatPadProAdapter(Platform):
|
||||
@@ -59,6 +67,18 @@ class WeChatPadProAdapter(Platform):
|
||||
) # 持久化文件路径
|
||||
self.ws_handle_task = None
|
||||
|
||||
# 添加图片消息缓存,用于引用消息处理
|
||||
self.cached_images = {}
|
||||
"""缓存图片消息。key是NewMsgId (对应引用消息的svrid),value是图片的base64数据"""
|
||||
# 设置缓存大小限制,避免内存占用过大
|
||||
self.max_image_cache = 50
|
||||
|
||||
# 添加文本消息缓存,用于引用消息处理
|
||||
self.cached_texts = {}
|
||||
"""缓存文本消息。key是NewMsgId (对应引用消息的svrid),value是消息文本内容"""
|
||||
# 设置文本缓存大小限制
|
||||
self.max_text_cache = 100
|
||||
|
||||
async def run(self) -> None:
|
||||
"""
|
||||
启动平台适配器的运行实例。
|
||||
@@ -69,39 +89,42 @@ class WeChatPadProAdapter(Platform):
|
||||
self.auth_key = loaded_credentials.get("auth_key")
|
||||
self.wxid = loaded_credentials.get("wxid")
|
||||
|
||||
isLoginIn = await self.check_online_status()
|
||||
|
||||
# 检查在线状态
|
||||
if self.auth_key and await self.check_online_status():
|
||||
logger.info("WeChatPadPro 设备已在线,跳过扫码登录。")
|
||||
if self.auth_key and isLoginIn:
|
||||
logger.info("WeChatPadPro 设备已在线,凭据存在,跳过扫码登录。")
|
||||
# 如果在线,连接 WebSocket 接收消息
|
||||
self.ws_handle_task = asyncio.create_task(self.connect_websocket())
|
||||
else:
|
||||
logger.info("WeChatPadPro 设备不在线或无可用凭据,开始扫码登录流程。")
|
||||
# 1. 生成授权码
|
||||
await self.generate_auth_key()
|
||||
|
||||
if not self.auth_key:
|
||||
logger.error("无法获取授权码,WeChatPadPro 适配器启动失败。")
|
||||
return
|
||||
logger.info("WeChatPadPro 无可用凭据,将生成新的授权码。")
|
||||
await self.generate_auth_key()
|
||||
|
||||
# 2. 获取登录二维码
|
||||
qr_code_url = await self.get_login_qr_code()
|
||||
if not isLoginIn:
|
||||
logger.info("WeChatPadPro 设备已离线,开始扫码登录。")
|
||||
qr_code_url = await self.get_login_qr_code()
|
||||
|
||||
if qr_code_url:
|
||||
logger.info(f"请扫描以下二维码登录: {qr_code_url}")
|
||||
else:
|
||||
logger.error("无法获取登录二维码。")
|
||||
return
|
||||
if qr_code_url:
|
||||
logger.info(f"请扫描以下二维码登录: {qr_code_url}")
|
||||
else:
|
||||
logger.error("无法获取登录二维码。")
|
||||
return
|
||||
|
||||
# 3. 检测扫码状态
|
||||
login_successful = await self.check_login_status()
|
||||
# 3. 检测扫码状态
|
||||
login_successful = await self.check_login_status()
|
||||
|
||||
if login_successful:
|
||||
# 登录成功后,连接 WebSocket 接收消息
|
||||
self.ws_handle_task = asyncio.create_task(self.connect_websocket())
|
||||
else:
|
||||
logger.warning("登录失败或超时,WeChatPadPro 适配器将关闭。")
|
||||
await self.terminate()
|
||||
return
|
||||
if login_successful:
|
||||
logger.info("登录成功,WeChatPadPro适配器已连接。")
|
||||
else:
|
||||
logger.warning("登录失败或超时,WeChatPadPro 适配器将关闭。")
|
||||
await self.terminate()
|
||||
return
|
||||
|
||||
# 登录成功后,连接 WebSocket 接收消息
|
||||
self.ws_handle_task = asyncio.create_task(self.connect_websocket())
|
||||
|
||||
self._shutdown_event = asyncio.Event()
|
||||
await self._shutdown_event.wait()
|
||||
@@ -156,16 +179,23 @@ class WeChatPadProAdapter(Platform):
|
||||
if login_state == 1:
|
||||
logger.info("WeChatPadPro 设备当前在线。")
|
||||
return True
|
||||
else:
|
||||
logger.info(
|
||||
f"WeChatPadPro 设备不在线,登录状态: {login_state}"
|
||||
)
|
||||
# login_state == 3 为离线状态
|
||||
elif login_state == 3:
|
||||
logger.info("WeChatPadPro 设备不在线。")
|
||||
return False
|
||||
else:
|
||||
logger.error(f"未知的在线状态: {login_state:}")
|
||||
return False
|
||||
# Code == 300 为微信退出状态。
|
||||
elif response.status == 200 and response_data.get("Code") == 300:
|
||||
logger.info("WeChatPadPro 设备已退出。")
|
||||
return False
|
||||
else:
|
||||
logger.error(
|
||||
f"检查在线状态失败: {response.status}, {response_data}"
|
||||
)
|
||||
return False
|
||||
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
return False
|
||||
@@ -179,7 +209,7 @@ class WeChatPadProAdapter(Platform):
|
||||
"""
|
||||
url = f"{self.base_url}/admin/GenAuthKey1"
|
||||
params = {"key": self.admin_key}
|
||||
payload = {"Count": 1, "Days": 30} # 生成一个有效期30天的授权码
|
||||
payload = {"Count": 1, "Days": 365} # 生成一个有效期365天的授权码
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
@@ -336,12 +366,10 @@ class WeChatPadProAdapter(Platform):
|
||||
message = await asyncio.wait_for(
|
||||
websocket.recv(), timeout=wait_time
|
||||
)
|
||||
logger.info(message)
|
||||
# logger.debug(message) # 不显示原始消息内容
|
||||
asyncio.create_task(self.handle_websocket_message(message))
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"WebSocket 连接空闲超过 {wait_time} s"
|
||||
)
|
||||
logger.warning(f"WebSocket 连接空闲超过 {wait_time} s")
|
||||
break
|
||||
except websockets.exceptions.ConnectionClosedOK:
|
||||
logger.info("WebSocket 连接正常关闭。")
|
||||
@@ -350,7 +378,9 @@ class WeChatPadProAdapter(Platform):
|
||||
logger.error(f"处理 WebSocket 消息时发生错误: {e}")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket 连接失败: {e}")
|
||||
logger.error(
|
||||
f"WebSocket 连接失败: {e}, 请检查WeChatPadPro服务状态,或尝试重启WeChatPadPro适配器。"
|
||||
)
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def handle_websocket_message(self, message: str):
|
||||
@@ -443,6 +473,7 @@ class WeChatPadProAdapter(Platform):
|
||||
"""
|
||||
if from_user_name == "weixin":
|
||||
return False
|
||||
at_me = False
|
||||
if "@chatroom" in from_user_name:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
abm.group_id = from_user_name
|
||||
@@ -464,6 +495,14 @@ class WeChatPadProAdapter(Platform):
|
||||
abm.session_id = f"{from_user_name}_{to_user_name}"
|
||||
else:
|
||||
abm.session_id = from_user_name
|
||||
|
||||
msg_source = raw_message.get("msg_source", "")
|
||||
if self.wxid in msg_source:
|
||||
at_me = True
|
||||
if "在群聊中@了你" in raw_message.get("push_content", ""):
|
||||
at_me = True
|
||||
if at_me:
|
||||
abm.message.insert(0, At(qq=abm.self_id, name=""))
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.group_id = ""
|
||||
@@ -544,6 +583,32 @@ class WeChatPadProAdapter(Platform):
|
||||
logger.error(f"下载图片时发生错误: {e}")
|
||||
return None
|
||||
|
||||
async def download_voice(
|
||||
self, to_user_name: str, new_msg_id: str, bufid: str, length: int
|
||||
):
|
||||
"""下载原始音频。"""
|
||||
url = f"{self.base_url}/message/GetMsgVoice"
|
||||
params = {"key": self.auth_key}
|
||||
payload = {
|
||||
"Bufid": bufid,
|
||||
"ToUserName": to_user_name,
|
||||
"NewMsgId": new_msg_id,
|
||||
"Length": length,
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
logger.error(f"下载音频失败: {response.status}")
|
||||
return None
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"下载音频时发生错误: {e}")
|
||||
return None
|
||||
|
||||
async def _process_message_content(
|
||||
self, abm: AstrBotMessage, raw_message: dict, msg_type: int, content: str
|
||||
):
|
||||
@@ -555,12 +620,69 @@ class WeChatPadProAdapter(Platform):
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
parts = content.split(":\n", 1)
|
||||
if len(parts) == 2:
|
||||
abm.message_str = parts[1]
|
||||
abm.message.append(Plain(abm.message_str))
|
||||
message_content = parts[1]
|
||||
abm.message_str = message_content
|
||||
|
||||
# 检查是否@了机器人,参考 gewechat 的实现方式
|
||||
# 微信大部分客户端在@用户昵称后面,紧接着是一个\u2005字符(四分之一空格)
|
||||
at_me = False
|
||||
|
||||
# 检查 msg_source 中是否包含机器人的 wxid
|
||||
# wechatpadpro 的格式: <atuserlist>wxid</atuserlist>
|
||||
# gewechat 的格式: <atuserlist><![CDATA[wxid]]></atuserlist>
|
||||
msg_source = raw_message.get("msg_source", "")
|
||||
if f"<atuserlist>{abm.self_id}</atuserlist>" in msg_source or f"<atuserlist>{abm.self_id}," in msg_source or f",{abm.self_id}</atuserlist>" in msg_source:
|
||||
at_me = True
|
||||
|
||||
# 也检查 push_content 中是否有@提示
|
||||
push_content = raw_message.get("push_content", "")
|
||||
if "在群聊中@了你" in push_content:
|
||||
at_me = True
|
||||
|
||||
if at_me:
|
||||
# 被@了,在消息开头插入At组件(参考gewechat的做法)
|
||||
bot_nickname = await self._get_group_member_nickname(abm.group_id, abm.self_id)
|
||||
abm.message.insert(0, At(qq=abm.self_id, name=bot_nickname or abm.self_id))
|
||||
|
||||
# 只有当消息内容不仅仅是@时才添加Plain组件
|
||||
if "\u2005" in message_content:
|
||||
# 检查@之后是否还有其他内容
|
||||
parts = message_content.split("\u2005")
|
||||
if len(parts) > 1 and any(part.strip() for part in parts[1:]):
|
||||
abm.message.append(Plain(message_content))
|
||||
else:
|
||||
# 检查是否只包含@机器人
|
||||
is_pure_at = False
|
||||
if bot_nickname and message_content.strip() == f"@{bot_nickname}":
|
||||
is_pure_at = True
|
||||
if not is_pure_at:
|
||||
abm.message.append(Plain(message_content))
|
||||
else:
|
||||
# 没有@机器人,作为普通文本处理
|
||||
abm.message.append(Plain(message_content))
|
||||
else:
|
||||
abm.message.append(Plain(abm.message_str))
|
||||
else: # 私聊消息
|
||||
abm.message.append(Plain(abm.message_str))
|
||||
|
||||
# 缓存文本消息,以便引用消息可以查找
|
||||
try:
|
||||
# 获取msg_id作为缓存的key
|
||||
new_msg_id = raw_message.get("new_msg_id")
|
||||
if new_msg_id:
|
||||
# 限制缓存大小
|
||||
if (
|
||||
len(self.cached_texts) >= self.max_text_cache
|
||||
and self.cached_texts
|
||||
):
|
||||
# 删除最早的一条缓存
|
||||
oldest_key = next(iter(self.cached_texts))
|
||||
self.cached_texts.pop(oldest_key)
|
||||
|
||||
logger.debug(f"缓存文本消息,new_msg_id={new_msg_id}")
|
||||
self.cached_texts[str(new_msg_id)] = content
|
||||
except Exception as e:
|
||||
logger.error(f"缓存文本消息失败: {e}")
|
||||
elif msg_type == 3:
|
||||
# 图片消息
|
||||
from_user_name = raw_message.get("from_user_name", {}).get("str", "")
|
||||
@@ -574,15 +696,87 @@ class WeChatPadProAdapter(Platform):
|
||||
)
|
||||
if image_bs64_data:
|
||||
abm.message.append(Image.fromBase64(image_bs64_data))
|
||||
# 缓存图片,以便引用消息可以查找
|
||||
try:
|
||||
# 获取msg_id作为缓存的key
|
||||
new_msg_id = raw_message.get("new_msg_id")
|
||||
if new_msg_id:
|
||||
# 限制缓存大小
|
||||
if (
|
||||
len(self.cached_images) >= self.max_image_cache
|
||||
and self.cached_images
|
||||
):
|
||||
# 删除最早的一条缓存
|
||||
oldest_key = next(iter(self.cached_images))
|
||||
self.cached_images.pop(oldest_key)
|
||||
|
||||
logger.debug(f"缓存图片消息,new_msg_id={new_msg_id}")
|
||||
self.cached_images[str(new_msg_id)] = image_bs64_data
|
||||
except Exception as e:
|
||||
logger.error(f"缓存图片消息失败: {e}")
|
||||
elif msg_type == 47:
|
||||
# 视频消息 (注意:表情消息也是 47,需要区分)
|
||||
logger.warning("收到视频消息,待实现。")
|
||||
data_parser = GeweDataParser(
|
||||
content=content,
|
||||
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
|
||||
raw_message=raw_message,
|
||||
)
|
||||
emoji_message = data_parser.parse_emoji()
|
||||
if emoji_message is not None:
|
||||
abm.message.append(emoji_message)
|
||||
elif msg_type == 50:
|
||||
# 语音/视频
|
||||
logger.warning("收到语音/视频消息,待实现。")
|
||||
elif msg_type == 34:
|
||||
# 语音消息
|
||||
bufid = 0
|
||||
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||
new_msg_id = raw_message.get("new_msg_id")
|
||||
data_parser = GeweDataParser(
|
||||
content=content,
|
||||
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
|
||||
raw_message=raw_message,
|
||||
)
|
||||
|
||||
voicemsg = data_parser._format_to_xml().find("voicemsg")
|
||||
bufid = voicemsg.get("bufid") or "0"
|
||||
length = int(voicemsg.get("length") or 0)
|
||||
voice_resp = await self.download_voice(
|
||||
to_user_name=to_user_name,
|
||||
new_msg_id=new_msg_id,
|
||||
bufid=bufid,
|
||||
length=length,
|
||||
)
|
||||
voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None)
|
||||
if voice_bs64_data:
|
||||
voice_bs64_data = base64.b64decode(voice_bs64_data)
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
file_path = os.path.join(
|
||||
temp_dir, f"wechatpadpro_voice_{abm.message_id}.silk"
|
||||
)
|
||||
|
||||
async with await anyio.open_file(file_path, "wb") as f:
|
||||
await f.write(voice_bs64_data)
|
||||
abm.message.append(Record(file=file_path, url=file_path))
|
||||
elif msg_type == 49:
|
||||
# 引用消息
|
||||
logger.warning("收到引用消息,待实现。")
|
||||
try:
|
||||
parser = GeweDataParser(
|
||||
content=content,
|
||||
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
|
||||
cached_texts=self.cached_texts,
|
||||
cached_images=self.cached_images,
|
||||
raw_message=raw_message,
|
||||
downloader=self._download_raw_image,
|
||||
)
|
||||
components = await parser.parse_mutil_49()
|
||||
if components:
|
||||
abm.message.extend(components)
|
||||
abm.message_str = "\n".join(
|
||||
c.text for c in components if isinstance(c, Plain)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"msg_type 49 处理失败: {e}")
|
||||
abm.message.append(Plain("[XML 消息处理失败]"))
|
||||
abm.message_str = "[XML 消息处理失败]"
|
||||
else:
|
||||
logger.warning(f"收到未处理的消息类型: {msg_type}。")
|
||||
|
||||
@@ -627,3 +821,67 @@ class WeChatPadProAdapter(Platform):
|
||||
)
|
||||
# 调用实例方法 send
|
||||
await sending_event.send(message_chain)
|
||||
|
||||
async def get_contact_list(self):
|
||||
"""
|
||||
获取联系人列表。
|
||||
"""
|
||||
url = f"{self.base_url}/friend/GetContactList"
|
||||
params = {"key": self.auth_key}
|
||||
payload = {"CurrentChatRoomContactSeq": 0, "CurrentWxcontactSeq": 0}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
if response.status != 200:
|
||||
logger.error(f"获取联系人列表失败: {response.status}")
|
||||
return None
|
||||
result = await response.json()
|
||||
if result.get("Code") == 200 and result.get("Data"):
|
||||
contact_list = (
|
||||
result.get("Data", {})
|
||||
.get("ContactList", {})
|
||||
.get("contactUsernameList", [])
|
||||
)
|
||||
return contact_list
|
||||
else:
|
||||
logger.error(f"获取联系人列表失败: {result}")
|
||||
return None
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取联系人列表时发生错误: {e}")
|
||||
return None
|
||||
|
||||
async def get_contact_details_list(
|
||||
self, room_wx_id_list: list[str] = None, user_names: list[str] = None
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
获取联系人详情列表。
|
||||
"""
|
||||
if room_wx_id_list is None:
|
||||
room_wx_id_list = []
|
||||
if user_names is None:
|
||||
user_names = []
|
||||
url = f"{self.base_url}/friend/GetContactDetailsList"
|
||||
params = {"key": self.auth_key}
|
||||
payload = {"RoomWxIDList": room_wx_id_list, "UserNames": user_names}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
if response.status != 200:
|
||||
logger.error(f"获取联系人详情列表失败: {response.status}")
|
||||
return None
|
||||
result = await response.json()
|
||||
if result.get("Code") == 200 and result.get("Data"):
|
||||
contact_list = result.get("Data", {}).get("contactList", {})
|
||||
return contact_list
|
||||
else:
|
||||
logger.error(f"获取联系人详情列表失败: {result}")
|
||||
return None
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取联系人详情列表时发生错误: {e}")
|
||||
return None
|
||||
|
||||
@@ -7,11 +7,17 @@ import aiohttp
|
||||
from PIL import Image as PILImage # 使用别名避免冲突
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.message.components import Image, Plain # Import Image
|
||||
from astrbot.core.message.components import (
|
||||
Image,
|
||||
Plain,
|
||||
WechatEmoji,
|
||||
Record,
|
||||
) # Import Image
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageType
|
||||
from astrbot.core.platform.platform_metadata import PlatformMetadata
|
||||
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk_base64
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .wechatpadpro_adapter import WeChatPadProAdapter
|
||||
@@ -38,6 +44,10 @@ class WeChatPadProMessageEvent(AstrMessageEvent):
|
||||
await self._send_text(session, comp.text)
|
||||
elif isinstance(comp, Image):
|
||||
await self._send_image(session, comp)
|
||||
elif isinstance(comp, WechatEmoji):
|
||||
await self._send_emoji(session, comp)
|
||||
elif isinstance(comp, Record):
|
||||
await self._send_voice(session, comp)
|
||||
await super().send(message)
|
||||
|
||||
async def _send_image(self, session: aiohttp.ClientSession, comp: Image):
|
||||
@@ -73,12 +83,42 @@ class WeChatPadProMessageEvent(AstrMessageEvent):
|
||||
message_text = text
|
||||
payload = {
|
||||
"MsgItem": [
|
||||
{"MsgType": 1, "TextContent": message_text, "ToUserName": self.session_id}
|
||||
{
|
||||
"MsgType": 1,
|
||||
"TextContent": message_text,
|
||||
"ToUserName": self.session_id,
|
||||
}
|
||||
]
|
||||
}
|
||||
url = f"{self.adapter.base_url}/message/SendTextMessage"
|
||||
await self._post(session, url, payload)
|
||||
|
||||
async def _send_emoji(self, session: aiohttp.ClientSession, comp: WechatEmoji):
|
||||
payload = {
|
||||
"EmojiList": [
|
||||
{
|
||||
"EmojiMd5": comp.md5,
|
||||
"EmojiSize": comp.md5_len,
|
||||
"ToUserName": self.session_id,
|
||||
}
|
||||
]
|
||||
}
|
||||
url = f"{self.adapter.base_url}/message/SendEmojiMessage"
|
||||
await self._post(session, url, payload)
|
||||
|
||||
async def _send_voice(self, session: aiohttp.ClientSession, comp: Record):
|
||||
record_path = await comp.convert_to_file_path()
|
||||
# 默认已经存在 data/temp 中
|
||||
b64, duration = await wav_to_tencent_silk_base64(record_path)
|
||||
payload = {
|
||||
"ToUserName": self.session_id,
|
||||
"VoiceData": b64,
|
||||
"VoiceFormat": 4,
|
||||
"VoiceSecond": duration,
|
||||
}
|
||||
url = f"{self.adapter.base_url}/message/SendVoice"
|
||||
await self._post(session, url, payload)
|
||||
|
||||
@staticmethod
|
||||
def _validate_base64(b64: str) -> bytes:
|
||||
return base64.b64decode(b64, validate=True)
|
||||
|
||||
160
astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py
Normal file
160
astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py
Normal file
@@ -0,0 +1,160 @@
|
||||
from defusedxml import ElementTree as eT
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.message_components import (
|
||||
WechatEmoji as Emoji,
|
||||
Plain,
|
||||
Image,
|
||||
BaseMessageComponent,
|
||||
)
|
||||
|
||||
|
||||
class GeweDataParser:
|
||||
def __init__(
|
||||
self,
|
||||
content: str,
|
||||
is_private_chat: bool = False,
|
||||
cached_texts=None,
|
||||
cached_images=None,
|
||||
raw_message: dict = None,
|
||||
downloader=None,
|
||||
):
|
||||
self._xml = None
|
||||
self.content = content
|
||||
self.is_private_chat = is_private_chat
|
||||
self.cached_texts = cached_texts or {}
|
||||
self.cached_images = cached_images or {}
|
||||
self.downloader = downloader
|
||||
|
||||
raw_message = raw_message or {}
|
||||
self.from_user_name = raw_message.get("from_user_name", {}).get("str", "")
|
||||
self.to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||
self.msg_id = raw_message.get("msg_id", "")
|
||||
|
||||
def _format_to_xml(self):
|
||||
if self._xml:
|
||||
return self._xml
|
||||
|
||||
try:
|
||||
msg_str = self.content
|
||||
if not self.is_private_chat:
|
||||
parts = self.content.split(":\n", 1)
|
||||
msg_str = parts[1] if len(parts) == 2 else self.content
|
||||
|
||||
self._xml = eT.fromstring(msg_str)
|
||||
return self._xml
|
||||
except Exception as e:
|
||||
logger.error(f"[XML解析失败] {e}")
|
||||
raise
|
||||
|
||||
async def parse_mutil_49(self) -> list[BaseMessageComponent] | None:
|
||||
"""
|
||||
处理 msg_type == 49 的多种 appmsg 类型(目前支持 type==57)
|
||||
"""
|
||||
try:
|
||||
appmsg_type = self._format_to_xml().findtext(".//appmsg/type")
|
||||
if appmsg_type == "57":
|
||||
return await self.parse_reply()
|
||||
except Exception as e:
|
||||
logger.warning(f"[parse_mutil_49] 解析失败: {e}")
|
||||
return None
|
||||
|
||||
async def parse_reply(self) -> list[BaseMessageComponent]:
|
||||
"""
|
||||
处理 type == 57 的引用消息:支持文本(1)、图片(3)、嵌套49(49)
|
||||
"""
|
||||
components = []
|
||||
|
||||
try:
|
||||
appmsg = self._format_to_xml().find("appmsg")
|
||||
if appmsg is None:
|
||||
return [Plain("[引用消息解析失败]")]
|
||||
|
||||
refermsg = appmsg.find("refermsg")
|
||||
if refermsg is None:
|
||||
return [Plain("[引用消息解析失败]")]
|
||||
|
||||
quote_type = int(refermsg.findtext("type", "0"))
|
||||
nickname = refermsg.findtext("displayname", "未知发送者")
|
||||
quote_content = refermsg.findtext("content", "")
|
||||
svrid = refermsg.findtext("svrid")
|
||||
|
||||
match quote_type:
|
||||
case 1: # 文本引用
|
||||
quoted_text = self.cached_texts.get(str(svrid), quote_content)
|
||||
components.append(Plain(f"[引用] {nickname}: {quoted_text}"))
|
||||
|
||||
case 3: # 图片引用
|
||||
quoted_image_b64 = self.cached_images.get(str(svrid))
|
||||
if not quoted_image_b64:
|
||||
try:
|
||||
quote_xml = eT.fromstring(quote_content)
|
||||
img = quote_xml.find("img")
|
||||
cdn_url = (
|
||||
img.get("cdnbigimgurl") or img.get("cdnmidimgurl")
|
||||
if img is not None
|
||||
else None
|
||||
)
|
||||
if cdn_url and self.downloader:
|
||||
image_resp = await self.downloader(
|
||||
self.from_user_name, self.to_user_name, self.msg_id
|
||||
)
|
||||
quoted_image_b64 = (
|
||||
image_resp.get("Data", {})
|
||||
.get("Data", {})
|
||||
.get("Buffer")
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[引用图片解析失败] svrid={svrid} err={e}")
|
||||
|
||||
if quoted_image_b64:
|
||||
components.extend(
|
||||
[
|
||||
Image.fromBase64(quoted_image_b64),
|
||||
Plain(f"[引用] {nickname}: [引用的图片]"),
|
||||
]
|
||||
)
|
||||
else:
|
||||
components.append(
|
||||
Plain(f"[引用] {nickname}: [引用的图片 - 未能获取]")
|
||||
)
|
||||
|
||||
case 49: # 嵌套引用
|
||||
try:
|
||||
nested_root = eT.fromstring(quote_content)
|
||||
nested_title = nested_root.findtext(".//appmsg/title", "")
|
||||
components.append(Plain(f"[引用] {nickname}: {nested_title}"))
|
||||
except Exception as e:
|
||||
logger.warning(f"[嵌套引用解析失败] err={e}")
|
||||
components.append(Plain(f"[引用] {nickname}: [嵌套引用消息]"))
|
||||
|
||||
case _: # 其他未识别类型
|
||||
logger.info(f"[未知引用类型] quote_type={quote_type}")
|
||||
components.append(Plain(f"[引用] {nickname}: [不支持的引用类型]"))
|
||||
|
||||
# 主消息标题
|
||||
title = appmsg.findtext("title", "")
|
||||
if title:
|
||||
components.append(Plain(title))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[parse_reply] 总体解析失败: {e}")
|
||||
return [Plain("[引用消息解析失败]")]
|
||||
|
||||
return components
|
||||
|
||||
def parse_emoji(self) -> Emoji | None:
|
||||
"""
|
||||
处理 msg_type == 47 的表情消息(emoji)
|
||||
"""
|
||||
try:
|
||||
emoji_element = self._format_to_xml().find(".//emoji")
|
||||
if emoji_element is not None:
|
||||
return Emoji(
|
||||
md5=emoji_element.get("md5"),
|
||||
md5_len=emoji_element.get("len"),
|
||||
cdnurl=emoji_element.get("cdnurl"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[parse_emoji] 解析失败: {e}")
|
||||
|
||||
return None
|
||||
@@ -20,7 +20,7 @@ from requests import Response
|
||||
from wechatpy.utils import check_signature
|
||||
from wechatpy.crypto import WeChatCrypto
|
||||
from wechatpy import WeChatClient
|
||||
from wechatpy.messages import TextMessage, ImageMessage, VoiceMessage
|
||||
from wechatpy.messages import TextMessage, ImageMessage, VoiceMessage, BaseMessage
|
||||
from wechatpy.exceptions import InvalidSignatureException
|
||||
from wechatpy import parse_message
|
||||
from .weixin_offacc_event import WeixinOfficialAccountPlatformEvent
|
||||
@@ -87,7 +87,11 @@ class WecomServer:
|
||||
logger.info(f"解析成功: {msg}")
|
||||
|
||||
if self.callback:
|
||||
await self.callback(msg)
|
||||
result_xml = await self.callback(msg)
|
||||
if not result_xml:
|
||||
return "success"
|
||||
if isinstance(result_xml, str):
|
||||
return result_xml
|
||||
|
||||
return "success"
|
||||
|
||||
@@ -117,6 +121,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
self.api_base_url = platform_config.get(
|
||||
"api_base_url", "https://api.weixin.qq.com/cgi-bin/"
|
||||
)
|
||||
self.active_send_mode = self.config.get("active_send_mode", False)
|
||||
|
||||
if not self.api_base_url:
|
||||
self.api_base_url = "https://api.weixin.qq.com/cgi-bin/"
|
||||
@@ -138,9 +143,29 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
|
||||
self.client.API_BASE_URL = self.api_base_url
|
||||
|
||||
async def callback(msg):
|
||||
# 微信公众号必须 5 秒内进行回复,否则会重试 3 次,我们需要对其进行消息排重
|
||||
# msgid -> Future
|
||||
self.wexin_event_workers: dict[str, asyncio.Future] = {}
|
||||
|
||||
async def callback(msg: BaseMessage):
|
||||
try:
|
||||
await self.convert_message(msg)
|
||||
if self.active_send_mode:
|
||||
await self.convert_message(msg, None)
|
||||
else:
|
||||
if msg.id in self.wexin_event_workers:
|
||||
future = self.wexin_event_workers[msg.id]
|
||||
logger.debug(f"duplicate message id checked: {msg.id}")
|
||||
else:
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self.wexin_event_workers[msg.id] = future
|
||||
await self.convert_message(msg, future)
|
||||
# I love shield so much!
|
||||
result = await asyncio.wait_for(asyncio.shield(future), 60) # wait for 60s
|
||||
logger.debug(f"Got future result: {result}")
|
||||
self.wexin_event_workers.pop(msg.id, None)
|
||||
return result # xml. see weixin_offacc_event.py
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"转换消息时出现异常: {e}")
|
||||
|
||||
@@ -163,7 +188,9 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
async def run(self):
|
||||
await self.server.start_polling()
|
||||
|
||||
async def convert_message(self, msg) -> AstrBotMessage | None:
|
||||
async def convert_message(
|
||||
self, msg, future: asyncio.Future = None
|
||||
) -> AstrBotMessage | None:
|
||||
abm = AstrBotMessage()
|
||||
if isinstance(msg, TextMessage):
|
||||
abm.message_str = msg.content
|
||||
@@ -177,7 +204,6 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.session_id = abm.sender.user_id
|
||||
abm.raw_message = msg
|
||||
elif msg.type == "image":
|
||||
assert isinstance(msg, ImageMessage)
|
||||
abm.message_str = "[图片]"
|
||||
@@ -191,7 +217,6 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.session_id = abm.sender.user_id
|
||||
abm.raw_message = msg
|
||||
elif msg.type == "voice":
|
||||
assert isinstance(msg, VoiceMessage)
|
||||
|
||||
@@ -209,7 +234,9 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
audio = AudioSegment.from_file(path)
|
||||
audio.export(path_wav, format="wav")
|
||||
except Exception as e:
|
||||
logger.error(f"转换音频失败: {e}。如果没有安装 pydub 和 ffmpeg 请先安装。")
|
||||
logger.error(
|
||||
f"转换音频失败: {e}。如果没有安装 pydub 和 ffmpeg 请先安装。"
|
||||
)
|
||||
path_wav = path
|
||||
return
|
||||
|
||||
@@ -224,11 +251,16 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.session_id = abm.sender.user_id
|
||||
abm.raw_message = msg
|
||||
else:
|
||||
logger.warning(f"暂未实现的事件: {msg.type}")
|
||||
future.set_result(None)
|
||||
return
|
||||
|
||||
# 很不优雅 :(
|
||||
abm.raw_message = {
|
||||
"message": msg,
|
||||
"future": future,
|
||||
"active_send_mode": self.active_send_mode,
|
||||
}
|
||||
logger.info(f"abm: {abm}")
|
||||
await self.handle_msg(abm)
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from wechatpy import WeChatClient
|
||||
from wechatpy.replies import TextReply, ImageReply, VoiceReply
|
||||
|
||||
|
||||
from astrbot.api import logger
|
||||
|
||||
@@ -82,12 +84,23 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
message_obj = self.message_obj
|
||||
active_send_mode = message_obj.raw_message.get("active_send_mode", False)
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
# Split long text messages if needed
|
||||
plain_chunks = await self.split_plain(comp.text)
|
||||
for chunk in plain_chunks:
|
||||
self.client.message.send_text(message_obj.sender.user_id, chunk)
|
||||
if active_send_mode:
|
||||
self.client.message.send_text(message_obj.sender.user_id, chunk)
|
||||
else:
|
||||
reply = TextReply(
|
||||
content=chunk,
|
||||
message=self.message_obj.raw_message["message"],
|
||||
)
|
||||
xml = reply.render()
|
||||
future = self.message_obj.raw_message["future"]
|
||||
assert isinstance(future, asyncio.Future)
|
||||
future.set_result(xml)
|
||||
await asyncio.sleep(0.5) # Avoid sending too fast
|
||||
elif isinstance(comp, Image):
|
||||
img_path = await comp.convert_to_file_path()
|
||||
@@ -102,10 +115,22 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
)
|
||||
return
|
||||
logger.debug(f"微信公众平台上传图片返回: {response}")
|
||||
self.client.message.send_image(
|
||||
message_obj.sender.user_id,
|
||||
response["media_id"],
|
||||
)
|
||||
|
||||
if active_send_mode:
|
||||
self.client.message.send_image(
|
||||
message_obj.sender.user_id,
|
||||
response["media_id"],
|
||||
)
|
||||
else:
|
||||
reply = ImageReply(
|
||||
media_id=response["media_id"],
|
||||
message=self.message_obj.raw_message["message"],
|
||||
)
|
||||
xml = reply.render()
|
||||
future = self.message_obj.raw_message["future"]
|
||||
assert isinstance(future, asyncio.Future)
|
||||
future.set_result(xml)
|
||||
|
||||
elif isinstance(comp, Record):
|
||||
record_path = await comp.convert_to_file_path()
|
||||
# 转成amr
|
||||
@@ -124,10 +149,23 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
)
|
||||
return
|
||||
logger.info(f"微信公众平台上传语音返回: {response}")
|
||||
self.client.message.send_voice(
|
||||
message_obj.sender.user_id,
|
||||
response["media_id"],
|
||||
)
|
||||
|
||||
|
||||
if active_send_mode:
|
||||
self.client.message.send_voice(
|
||||
message_obj.sender.user_id,
|
||||
response["media_id"],
|
||||
)
|
||||
else:
|
||||
reply = VoiceReply(
|
||||
media_id=response["media_id"],
|
||||
message=self.message_obj.raw_message["message"],
|
||||
)
|
||||
xml = reply.render()
|
||||
future = self.message_obj.raw_message["future"]
|
||||
assert isinstance(future, asyncio.Future)
|
||||
future.set_result(xml)
|
||||
|
||||
else:
|
||||
logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。")
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ class ProviderType(enum.Enum):
|
||||
CHAT_COMPLETION = "chat_completion"
|
||||
SPEECH_TO_TEXT = "speech_to_text"
|
||||
TEXT_TO_SPEECH = "text_to_speech"
|
||||
EMBEDDING = "embedding"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -155,7 +156,9 @@ class ProviderRequest:
|
||||
if self.image_urls:
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": self.prompt if self.prompt else "[图片]"}],
|
||||
"content": [
|
||||
{"type": "text", "text": self.prompt if self.prompt else "[图片]"}
|
||||
],
|
||||
}
|
||||
for image_url in self.image_urls:
|
||||
if image_url.startswith("http"):
|
||||
|
||||
@@ -4,6 +4,7 @@ import textwrap
|
||||
import os
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
|
||||
from typing import Dict, List, Awaitable, Literal, Any
|
||||
from dataclasses import dataclass
|
||||
@@ -20,6 +21,13 @@ try:
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
|
||||
|
||||
try:
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning(
|
||||
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
|
||||
)
|
||||
|
||||
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
|
||||
|
||||
SUPPORTED_TYPES = [
|
||||
@@ -96,7 +104,10 @@ class MCPClient:
|
||||
async def connect_to_server(self, mcp_server_config: dict, name: str):
|
||||
"""连接到 MCP 服务器
|
||||
|
||||
如果 `url` 参数存在,则使用 SSE 的方式连接到 MCP 服务。
|
||||
如果 `url` 参数存在:
|
||||
1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
|
||||
1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。
|
||||
2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
|
||||
|
||||
Args:
|
||||
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
||||
@@ -108,15 +119,41 @@ class MCPClient:
|
||||
cfg.pop("active", None) # Remove active flag from config
|
||||
|
||||
if "url" in cfg:
|
||||
# SSE transport method
|
||||
self._streams_context = sse_client(url=cfg["url"])
|
||||
streams = await self._streams_context.__aenter__()
|
||||
is_sse = True
|
||||
if cfg.get("transport") == "streamable_http":
|
||||
is_sse = False
|
||||
if is_sse:
|
||||
# SSE transport method
|
||||
self._streams_context = sse_client(
|
||||
url=cfg["url"],
|
||||
headers=cfg.get("headers", {}),
|
||||
timeout=cfg.get("timeout", 5),
|
||||
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
|
||||
)
|
||||
streams = await self._streams_context.__aenter__()
|
||||
|
||||
# Create a new client session
|
||||
# self.session = await self._session_context.__aenter__()
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(*streams)
|
||||
)
|
||||
# Create a new client session
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(*streams)
|
||||
)
|
||||
else:
|
||||
timeout = timedelta(seconds=cfg.get("timeout", 30))
|
||||
sse_read_timeout = timedelta(
|
||||
seconds=cfg.get("sse_read_timeout", 60 * 5)
|
||||
)
|
||||
self._streams_context = streamablehttp_client(
|
||||
url=cfg["url"],
|
||||
headers=cfg.get("headers", {}),
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
terminate_on_close=cfg.get("terminate_on_close", True),
|
||||
)
|
||||
read_s, write_s, _ = await self._streams_context.__aenter__()
|
||||
|
||||
# Create a new client session
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(read_stream=read_s, write_stream=write_s)
|
||||
)
|
||||
|
||||
else:
|
||||
server_params = mcp.StdioServerParameters(
|
||||
|
||||
@@ -18,13 +18,6 @@ class ProviderManager:
|
||||
self.persona_configs: list = config.get("persona", [])
|
||||
self.astrbot_config = config
|
||||
|
||||
self.selected_provider_id = sp.get("curr_provider")
|
||||
self.selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
|
||||
self.selected_tts_provider_id = self.provider_settings.get("provider_id")
|
||||
self.provider_enabled = self.provider_settings.get("enable", False)
|
||||
self.stt_enabled = self.provider_stt_settings.get("enable", False)
|
||||
self.tts_enabled = self.provider_tts_settings.get("enable", False)
|
||||
|
||||
# 人格情景管理
|
||||
# 目前没有拆成独立的模块
|
||||
self.default_persona_name = self.provider_settings.get(
|
||||
@@ -98,15 +91,18 @@ class ProviderManager:
|
||||
"""加载的 Speech To Text Provider 的实例"""
|
||||
self.tts_provider_insts: List[TTSProvider] = []
|
||||
"""加载的 Text To Speech Provider 的实例"""
|
||||
self.embedding_provider_insts: List[Provider] = []
|
||||
"""加载的 Embedding Provider 的实例"""
|
||||
self.inst_map = {}
|
||||
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
||||
self.llm_tools = llm_tools
|
||||
|
||||
self.curr_provider_inst: Provider = None
|
||||
"""当前使用的 Provider 实例"""
|
||||
"""默认的 Provider 实例"""
|
||||
self.curr_stt_provider_inst: STTProvider = None
|
||||
"""当前使用的 Speech To Text Provider 实例"""
|
||||
"""默认的 Speech To Text Provider 实例"""
|
||||
self.curr_tts_provider_inst: TTSProvider = None
|
||||
"""当前使用的 Text To Speech Provider 实例"""
|
||||
"""默认的 Text To Speech Provider 实例"""
|
||||
self.db_helper = db_helper
|
||||
|
||||
# kdb(experimental)
|
||||
@@ -115,18 +111,57 @@ class ProviderManager:
|
||||
if kdb_cfg and len(kdb_cfg):
|
||||
self.curr_kdb_name = list(kdb_cfg.keys())[0]
|
||||
|
||||
async def set_provider(
|
||||
self, provider_id: str, provider_type: ProviderType, umo: str = None
|
||||
):
|
||||
"""设置提供商。
|
||||
|
||||
Args:
|
||||
provider_id (str): 提供商 ID。
|
||||
provider_type (ProviderType): 提供商类型。
|
||||
umo (str, optional): 用户会话 ID,用于提供商会话隔离。当用户启用了提供商会话隔离时此参数才生效。
|
||||
"""
|
||||
if provider_id not in self.inst_map:
|
||||
raise ValueError(f"提供商 {provider_id} 不存在,无法设置。")
|
||||
if umo and self.provider_settings["separate_provider"]:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
session_perf = perf.get(umo, {})
|
||||
session_perf[provider_type.value] = provider_id
|
||||
perf[umo] = session_perf
|
||||
sp.put("session_provider_perf", perf)
|
||||
return
|
||||
# 不启用提供商会话隔离模式的情况
|
||||
self.curr_provider_inst = self.inst_map[provider_id]
|
||||
if provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||
sp.put("curr_provider_tts", provider_id)
|
||||
elif provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||
sp.put("curr_provider_stt", provider_id)
|
||||
elif provider_type == ProviderType.CHAT_COMPLETION:
|
||||
sp.put("curr_provider", provider_id)
|
||||
|
||||
async def initialize(self):
|
||||
# 逐个初始化提供商
|
||||
for provider_config in self.providers_config:
|
||||
await self.load_provider(provider_config)
|
||||
|
||||
if not self.curr_provider_inst:
|
||||
logger.warning("未启用任何用于 文本生成 的提供商适配器。")
|
||||
# 设置默认提供商
|
||||
self.curr_provider_inst = self.inst_map.get(
|
||||
self.provider_settings.get("default_provider_id")
|
||||
)
|
||||
if not self.curr_provider_inst and self.provider_insts:
|
||||
self.curr_provider_inst = self.provider_insts[0]
|
||||
|
||||
if self.stt_enabled and not self.curr_stt_provider_inst:
|
||||
logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
|
||||
self.curr_stt_provider_inst = self.inst_map.get(
|
||||
self.provider_stt_settings.get("provider_id")
|
||||
)
|
||||
if not self.curr_stt_provider_inst and self.stt_provider_insts:
|
||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||
|
||||
if self.tts_enabled and not self.curr_tts_provider_inst:
|
||||
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
|
||||
self.curr_tts_provider_inst = self.inst_map.get(
|
||||
self.provider_tts_settings.get("provider_id")
|
||||
)
|
||||
if not self.curr_tts_provider_inst and self.tts_provider_insts:
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
|
||||
# 初始化 MCP Client 连接
|
||||
asyncio.create_task(
|
||||
@@ -210,6 +245,18 @@ class ProviderManager:
|
||||
from .sources.minimax_tts_api_source import (
|
||||
ProviderMiniMaxTTSAPI as ProviderMiniMaxTTSAPI,
|
||||
)
|
||||
case "volcengine_tts":
|
||||
from .sources.volcengine_tts import (
|
||||
ProviderVolcengineTTS as ProviderVolcengineTTS,
|
||||
)
|
||||
case "openai_embedding":
|
||||
from .sources.openai_embedding_source import (
|
||||
OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
|
||||
)
|
||||
case "gemini_embedding":
|
||||
from .sources.gemini_embedding_source import (
|
||||
GeminiEmbeddingProvider as GeminiEmbeddingProvider,
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.critical(
|
||||
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
|
||||
@@ -242,14 +289,14 @@ class ProviderManager:
|
||||
|
||||
self.stt_provider_insts.append(inst)
|
||||
if (
|
||||
self.selected_stt_provider_id == provider_config["id"]
|
||||
and self.stt_enabled
|
||||
self.provider_stt_settings.get("provider_id")
|
||||
== provider_config["id"]
|
||||
):
|
||||
self.curr_stt_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。"
|
||||
)
|
||||
if not self.curr_stt_provider_inst and self.stt_enabled:
|
||||
if not self.curr_stt_provider_inst:
|
||||
self.curr_stt_provider_inst = inst
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||
@@ -262,15 +309,12 @@ class ProviderManager:
|
||||
await inst.initialize()
|
||||
|
||||
self.tts_provider_insts.append(inst)
|
||||
if (
|
||||
self.selected_tts_provider_id == provider_config["id"]
|
||||
and self.tts_enabled
|
||||
):
|
||||
if self.provider_settings.get("provider_id") == provider_config["id"]:
|
||||
self.curr_tts_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。"
|
||||
)
|
||||
if not self.curr_tts_provider_inst and self.tts_enabled:
|
||||
if not self.curr_tts_provider_inst:
|
||||
self.curr_tts_provider_inst = inst
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
||||
@@ -288,16 +332,24 @@ class ProviderManager:
|
||||
|
||||
self.provider_insts.append(inst)
|
||||
if (
|
||||
self.selected_provider_id == provider_config["id"]
|
||||
and self.provider_enabled
|
||||
self.provider_settings.get("default_provider_id")
|
||||
== provider_config["id"]
|
||||
):
|
||||
self.curr_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。"
|
||||
)
|
||||
if not self.curr_provider_inst and self.provider_enabled:
|
||||
if not self.curr_provider_inst:
|
||||
self.curr_provider_inst = inst
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
|
||||
inst = provider_metadata.cls_type(
|
||||
provider_config, self.provider_settings
|
||||
)
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
self.embedding_provider_insts.append(inst)
|
||||
|
||||
self.inst_map[provider_config["id"]] = inst
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
@@ -318,39 +370,24 @@ class ProviderManager:
|
||||
|
||||
if len(self.provider_insts) == 0:
|
||||
self.curr_provider_inst = None
|
||||
elif (
|
||||
self.curr_provider_inst is None
|
||||
and len(self.provider_insts) > 0
|
||||
and self.provider_enabled
|
||||
):
|
||||
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
|
||||
self.curr_provider_inst = self.provider_insts[0]
|
||||
self.selected_provider_id = self.curr_provider_inst.meta().id
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。"
|
||||
)
|
||||
|
||||
if len(self.stt_provider_insts) == 0:
|
||||
self.curr_stt_provider_inst = None
|
||||
elif (
|
||||
self.curr_stt_provider_inst is None
|
||||
and len(self.stt_provider_insts) > 0
|
||||
and self.stt_enabled
|
||||
):
|
||||
elif self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0:
|
||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||
self.selected_stt_provider_id = self.curr_stt_provider_inst.meta().id
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。"
|
||||
)
|
||||
|
||||
if len(self.tts_provider_insts) == 0:
|
||||
self.curr_tts_provider_inst = None
|
||||
elif (
|
||||
self.curr_tts_provider_inst is None
|
||||
and len(self.tts_provider_insts) > 0
|
||||
and self.tts_enabled
|
||||
):
|
||||
elif self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0:
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
self.selected_tts_provider_id = self.curr_tts_provider_inst.meta().id
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。"
|
||||
)
|
||||
|
||||
@@ -179,3 +179,25 @@ class TTSProvider(AbstractProvider):
|
||||
async def get_audio(self, text: str) -> str:
|
||||
"""获取文本的音频,返回音频文件路径"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class EmbeddingProvider(AbstractProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_embedding(self, text: str) -> list[float]:
|
||||
"""获取文本的向量"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
|
||||
"""批量获取文本的向量"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_dim(self) -> int:
|
||||
"""获取向量的维度"""
|
||||
...
|
||||
|
||||
@@ -104,11 +104,13 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
func_tool: FuncCall = None,
|
||||
contexts=[],
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result: ToolCallsResult = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
if not prompt:
|
||||
prompt = "<image>"
|
||||
|
||||
|
||||
@@ -53,8 +53,8 @@ class OTTSProvider:
|
||||
async def _generate_signature(self) -> str:
|
||||
await self._sync_time()
|
||||
timestamp = int(time.time()) + self.time_offset
|
||||
nonce = ''.join(random.choices('abcdefghijklmnopqrstuvwxyz0123456789', k=10))
|
||||
path = re.sub(r'^https?://[^/]+', '', self.api_url) or '/'
|
||||
nonce = "".join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=10))
|
||||
path = re.sub(r"^https?://[^/]+", "", self.api_url) or "/"
|
||||
return f"{timestamp}-{nonce}-0-{hashlib.md5(f'{path}-{timestamp}-{nonce}-0-{self.skey}'.encode()).hexdigest()}"
|
||||
|
||||
async def get_audio(self, text: str, voice_params: Dict) -> str:
|
||||
@@ -92,7 +92,7 @@ class AzureNativeProvider(TTSProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict):
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.subscription_key = provider_config.get("azure_tts_subscription_key", "").strip()
|
||||
if not re.fullmatch(r'^[a-zA-Z0-9]{32}$', self.subscription_key):
|
||||
if not re.fullmatch(r"^[a-zA-Z0-9]{32}$", self.subscription_key):
|
||||
raise ValueError("无效的Azure订阅密钥")
|
||||
self.region = provider_config.get("azure_tts_region", "eastus").strip()
|
||||
self.endpoint = f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||
@@ -188,7 +188,7 @@ class AzureTTSProvider(TTSProvider):
|
||||
raise ValueError(error_msg) from e
|
||||
except KeyError as e:
|
||||
raise ValueError(f"配置错误: 缺少必要参数 {e}") from e
|
||||
if re.fullmatch(r'^[a-zA-Z0-9]{32}$', key_value):
|
||||
if re.fullmatch(r"^[a-zA-Z0-9]{32}$", key_value):
|
||||
return AzureNativeProvider(config, self.provider_settings)
|
||||
raise ValueError("订阅密钥格式无效,应为32位字母数字或other[...]格式")
|
||||
|
||||
|
||||
@@ -74,6 +74,8 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
system_prompt: str = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
# 获得会话变量
|
||||
payload_vars = self.variables.copy()
|
||||
# 动态变量
|
||||
|
||||
@@ -61,12 +61,14 @@ class ProviderDify(Provider):
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
image_urls: List[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = None,
|
||||
system_prompt: str = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if image_urls is None:
|
||||
image_urls = []
|
||||
result = ""
|
||||
conversation_id = self.conversation_ids.get(session_id, "")
|
||||
|
||||
|
||||
63
astrbot/core/provider/sources/gemini_embedding_source.py
Normal file
63
astrbot/core/provider/sources/gemini_embedding_source.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from google.genai.errors import APIError
|
||||
from ..provider import EmbeddingProvider
|
||||
from ..register import register_provider_adapter
|
||||
from ..entities import ProviderType
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"gemini_embedding",
|
||||
"Google Gemini Embedding 提供商适配器",
|
||||
provider_type=ProviderType.EMBEDDING,
|
||||
)
|
||||
class GeminiEmbeddingProvider(EmbeddingProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
api_key: str = provider_config.get("embedding_api_key")
|
||||
api_base: str = provider_config.get("embedding_api_base", None)
|
||||
timeout: int = int(provider_config.get("timeout", 20))
|
||||
|
||||
http_options = types.HttpOptions(timeout=timeout * 1000)
|
||||
if api_base:
|
||||
if api_base.endswith("/"):
|
||||
api_base = api_base[:-1]
|
||||
http_options.base_url = api_base
|
||||
|
||||
self.client = genai.Client(api_key=api_key, http_options=http_options).aio
|
||||
|
||||
self.model = provider_config.get(
|
||||
"embedding_model", "gemini-embedding-exp-03-07"
|
||||
)
|
||||
self.dimension = provider_config.get("embedding_dimensions", 768)
|
||||
|
||||
async def get_embedding(self, text: str) -> list[float]:
|
||||
"""
|
||||
获取文本的嵌入
|
||||
"""
|
||||
try:
|
||||
result = await self.client.models.embed_content(
|
||||
model=self.model, contents=text
|
||||
)
|
||||
return result.embeddings[0].values
|
||||
except APIError as e:
|
||||
raise Exception(f"Gemini Embedding API请求失败: {e.message}")
|
||||
|
||||
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
|
||||
"""
|
||||
批量获取文本的嵌入
|
||||
"""
|
||||
try:
|
||||
result = await self.client.models.embed_content(
|
||||
model=self.model, contents=texts
|
||||
)
|
||||
return [embedding.values for embedding in result.embeddings]
|
||||
except APIError as e:
|
||||
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")
|
||||
|
||||
def get_dim(self) -> int:
|
||||
"""获取向量的维度"""
|
||||
return self.dimension
|
||||
@@ -141,24 +141,66 @@ class ProviderGoogleGenAI(Provider):
|
||||
logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
|
||||
modalities = ["Text"]
|
||||
|
||||
tool_list = None
|
||||
tool_list = []
|
||||
model_name = self.get_model()
|
||||
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
|
||||
native_search = self.provider_config.get("gm_native_search", False)
|
||||
url_context = self.provider_config.get("gm_url_context", False)
|
||||
|
||||
if native_coderunner:
|
||||
tool_list = [types.Tool(code_execution=types.ToolCodeExecution())]
|
||||
if native_search:
|
||||
logger.warning("已启用代码执行工具,搜索工具将被忽略")
|
||||
if tools:
|
||||
logger.warning("已启用代码执行工具,函数工具将被忽略")
|
||||
elif native_search:
|
||||
tool_list = [types.Tool(google_search=types.GoogleSearch())]
|
||||
if tools:
|
||||
logger.warning("已启用搜索工具,函数工具将被忽略")
|
||||
if "gemini-2.5" in model_name:
|
||||
if native_coderunner:
|
||||
tool_list.append(types.Tool(code_execution=types.ToolCodeExecution()))
|
||||
if native_search:
|
||||
logger.warning("代码执行工具与搜索工具互斥,已忽略搜索工具")
|
||||
if url_context:
|
||||
logger.warning(
|
||||
"代码执行工具与URL上下文工具互斥,已忽略URL上下文工具"
|
||||
)
|
||||
else:
|
||||
if native_search:
|
||||
tool_list.append(types.Tool(google_search=types.GoogleSearch()))
|
||||
|
||||
if url_context:
|
||||
if hasattr(types, "UrlContext"):
|
||||
tool_list.append(types.Tool(url_context=types.UrlContext()))
|
||||
else:
|
||||
logger.warning(
|
||||
"当前 SDK 版本不支持 URL 上下文工具,已忽略该设置,请升级 google-genai 包"
|
||||
)
|
||||
|
||||
elif "gemini-2.0-lite" in model_name:
|
||||
if native_coderunner or native_search or url_context:
|
||||
logger.warning(
|
||||
"gemini-2.0-lite 不支持代码执行、搜索工具和URL上下文,将忽略这些设置"
|
||||
)
|
||||
tool_list = None
|
||||
|
||||
else:
|
||||
if native_coderunner:
|
||||
tool_list.append(types.Tool(code_execution=types.ToolCodeExecution()))
|
||||
if native_search:
|
||||
logger.warning("代码执行工具与搜索工具互斥,已忽略搜索工具")
|
||||
elif native_search:
|
||||
tool_list.append(types.Tool(google_search=types.GoogleSearch()))
|
||||
|
||||
if url_context and not native_coderunner:
|
||||
if hasattr(types, "UrlContext"):
|
||||
tool_list.append(types.Tool(url_context=types.UrlContext()))
|
||||
else:
|
||||
logger.warning(
|
||||
"当前 SDK 版本不支持 URL 上下文工具,已忽略该设置,请升级 google-genai 包"
|
||||
)
|
||||
|
||||
if not tool_list:
|
||||
tool_list = None
|
||||
|
||||
if tools and tool_list:
|
||||
logger.warning("已启用原生工具,函数工具将被忽略")
|
||||
elif tools and (func_desc := tools.get_func_desc_google_genai_style()):
|
||||
tool_list = [
|
||||
types.Tool(function_declarations=func_desc["function_declarations"])
|
||||
]
|
||||
|
||||
return types.GenerateContentConfig(
|
||||
system_instruction=system_instruction,
|
||||
temperature=temperature,
|
||||
@@ -291,19 +333,19 @@ class ProviderGoogleGenAI(Provider):
|
||||
result_parts: Optional[types.Part] = result.candidates[0].content.parts
|
||||
|
||||
if finish_reason == types.FinishReason.SAFETY:
|
||||
raise Exception("模型生成内容未通过用户定义的内容安全检查")
|
||||
raise Exception("模型生成内容未通过 Gemini 平台的安全检查")
|
||||
|
||||
if finish_reason in {
|
||||
types.FinishReason.PROHIBITED_CONTENT,
|
||||
types.FinishReason.SPII,
|
||||
types.FinishReason.BLOCKLIST,
|
||||
}:
|
||||
raise Exception("模型生成内容违反Gemini平台政策")
|
||||
raise Exception("模型生成内容违反 Gemini 平台政策")
|
||||
|
||||
# 防止旧版本SDK不存在IMAGE_SAFETY
|
||||
if hasattr(types.FinishReason, "IMAGE_SAFETY"):
|
||||
if finish_reason == types.FinishReason.IMAGE_SAFETY:
|
||||
raise Exception("模型生成内容违反Gemini平台政策")
|
||||
raise Exception("模型生成内容违反 Gemini 平台政策")
|
||||
|
||||
if not result_parts:
|
||||
logger.debug(result.candidates)
|
||||
|
||||
@@ -60,10 +60,12 @@ class LLMTunerModelLoader(Provider):
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = [],
|
||||
contexts: List = None,
|
||||
system_prompt: str = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
system_prompt = ""
|
||||
new_record = {"role": "user", "content": prompt}
|
||||
query_context = [*contexts, new_record]
|
||||
|
||||
43
astrbot/core/provider/sources/openai_embedding_source.py
Normal file
43
astrbot/core/provider/sources/openai_embedding_source.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from openai import AsyncOpenAI
|
||||
from ..provider import EmbeddingProvider
|
||||
from ..register import register_provider_adapter
|
||||
from ..entities import ProviderType
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"openai_embedding",
|
||||
"OpenAI API Embedding 提供商适配器",
|
||||
provider_type=ProviderType.EMBEDDING,
|
||||
)
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=provider_config.get("embedding_api_key"),
|
||||
base_url=provider_config.get(
|
||||
"embedding_api_base", "https://api.openai.com/v1"
|
||||
),
|
||||
timeout=int(provider_config.get("timeout", 20)),
|
||||
)
|
||||
self.model = provider_config.get("embedding_model", "text-embedding-3-small")
|
||||
self.dimension = provider_config.get("embedding_dimensions", 1536)
|
||||
|
||||
async def get_embedding(self, text: str) -> list[float]:
|
||||
"""
|
||||
获取文本的嵌入
|
||||
"""
|
||||
embedding = await self.client.embeddings.create(input=text, model=self.model)
|
||||
return embedding.data[0].embedding
|
||||
|
||||
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
|
||||
"""
|
||||
批量获取文本的嵌入
|
||||
"""
|
||||
embeddings = await self.client.embeddings.create(input=texts, model=self.model)
|
||||
return [item.embedding for item in embeddings.data]
|
||||
|
||||
def get_dim(self) -> int:
|
||||
"""获取向量的维度"""
|
||||
return self.dimension
|
||||
@@ -195,7 +195,11 @@ class ProviderOpenAIOfficial(Provider):
|
||||
for tool_call in choice.message.tool_calls:
|
||||
for tool in tools.func_list:
|
||||
if tool.name == tool_call.function.name:
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
# workaround for #1454
|
||||
if isinstance(tool_call.function.arguments, str):
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
else:
|
||||
args = tool_call.function.arguments
|
||||
args_ls.append(args)
|
||||
func_name_ls.append(tool_call.function.name)
|
||||
tool_call_ids.append(tool_call.id)
|
||||
@@ -223,9 +227,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
session_id: str = None,
|
||||
image_urls: list[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: list=None,
|
||||
system_prompt: str=None,
|
||||
tool_calls_result: ToolCallsResult=None,
|
||||
contexts: list = None,
|
||||
system_prompt: str = None,
|
||||
tool_calls_result: ToolCallsResult = None,
|
||||
**kwargs,
|
||||
) -> tuple:
|
||||
"""准备聊天所需的有效载荷和上下文"""
|
||||
@@ -340,9 +344,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt,
|
||||
session_id = None,
|
||||
image_urls = None,
|
||||
func_tool = None,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
|
||||
107
astrbot/core/provider/sources/volcengine_tts.py
Normal file
107
astrbot/core/provider/sources/volcengine_tts.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import uuid
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import requests
|
||||
from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot import logger
|
||||
|
||||
@register_provider_adapter(
|
||||
"volcengine_tts", "火山引擎 TTS", provider_type=ProviderType.TEXT_TO_SPEECH
|
||||
)
|
||||
class ProviderVolcengineTTS(TTSProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.api_key = provider_config.get("api_key", "")
|
||||
self.appid = provider_config.get("appid", "")
|
||||
self.cluster = provider_config.get("volcengine_cluster", "")
|
||||
self.voice_type = provider_config.get("volcengine_voice_type", "")
|
||||
self.speed_ratio = provider_config.get("volcengine_speed_ratio", 1.0)
|
||||
self.api_base = provider_config.get("api_base", f"https://openspeech.bytedance.com/api/v1/tts")
|
||||
self.timeout = provider_config.get("timeout", 20)
|
||||
|
||||
def _build_request_payload(self, text: str) -> dict:
|
||||
return {
|
||||
"app": {
|
||||
"appid": self.appid,
|
||||
"token": self.api_key,
|
||||
"cluster": self.cluster
|
||||
},
|
||||
"user": {
|
||||
"uid": str(uuid.uuid4())
|
||||
},
|
||||
"audio": {
|
||||
"voice_type": self.voice_type,
|
||||
"encoding": "mp3",
|
||||
"speed_ratio": self.speed_ratio,
|
||||
"volume_ratio": 1.0,
|
||||
"pitch_ratio": 1.0,
|
||||
},
|
||||
"request": {
|
||||
"reqid": str(uuid.uuid4()),
|
||||
"text": text,
|
||||
"text_type": "plain",
|
||||
"operation": "query",
|
||||
"with_frontend": 1,
|
||||
"frontend_type": "unitTson"
|
||||
}
|
||||
}
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
"""异步方法获取语音文件路径"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer; {self.api_key}"
|
||||
}
|
||||
|
||||
payload = self._build_request_payload(text)
|
||||
|
||||
logger.debug(f"请求头: {headers}")
|
||||
logger.debug(f"请求 URL: {self.api_base}")
|
||||
logger.debug(f"请求体: {json.dumps(payload, ensure_ascii=False)[:100]}...")
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
self.api_base,
|
||||
data=json.dumps(payload),
|
||||
headers=headers,
|
||||
timeout=self.timeout
|
||||
) as response:
|
||||
logger.debug(f"响应状态码: {response.status}")
|
||||
|
||||
response_text = await response.text()
|
||||
logger.debug(f"响应内容: {response_text[:200]}...")
|
||||
|
||||
if response.status == 200:
|
||||
resp_data = json.loads(response_text)
|
||||
|
||||
if "data" in resp_data:
|
||||
audio_data = base64.b64decode(resp_data["data"])
|
||||
|
||||
os.makedirs("data/temp", exist_ok=True)
|
||||
|
||||
file_path = f"data/temp/volcengine_tts_{uuid.uuid4()}.mp3"
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
lambda: open(file_path, "wb").write(audio_data)
|
||||
)
|
||||
|
||||
return file_path
|
||||
else:
|
||||
error_msg = resp_data.get("message", "未知错误")
|
||||
raise Exception(f"火山引擎 TTS API 返回错误: {error_msg}")
|
||||
else:
|
||||
raise Exception(f"火山引擎 TTS API 请求失败: {response.status}, {response_text}")
|
||||
|
||||
except Exception as e:
|
||||
error_details = traceback.format_exc()
|
||||
logger.debug(f"火山引擎 TTS 异常详情: {error_details}")
|
||||
raise Exception(f"火山引擎 TTS 异常: {str(e)}")
|
||||
@@ -31,10 +31,12 @@ class ProviderZhipu(ProviderOpenAIOfficial):
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts=[],
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = []
|
||||
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
from typing import List
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
|
||||
class SimpleOpenAIEmbedding:
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
api_key,
|
||||
api_base=None,
|
||||
) -> None:
|
||||
self.client = AsyncOpenAI(api_key=api_key, base_url=api_base)
|
||||
self.model = model
|
||||
|
||||
async def get_embedding(self, text) -> List[float]:
|
||||
"""
|
||||
获取文本的嵌入
|
||||
"""
|
||||
embedding = await self.client.embeddings.create(input=text, model=self.model)
|
||||
return embedding.data[0].embedding
|
||||
@@ -1,95 +0,0 @@
|
||||
import os
|
||||
from typing import List, Dict
|
||||
from astrbot.core import logger
|
||||
from .store import Store
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class KnowledgeDBManager:
|
||||
def __init__(self, astrbot_config: AstrBotConfig) -> None:
|
||||
self.db_path = os.path.join(get_astrbot_data_path(), "knowledge_db")
|
||||
self.config = astrbot_config.get("knowledge_db", {})
|
||||
self.astrbot_config = astrbot_config
|
||||
if not os.path.exists(self.db_path):
|
||||
os.makedirs(self.db_path)
|
||||
self.store_insts: Dict[str, Store] = {}
|
||||
for name, cfg in self.config.items():
|
||||
if cfg["strategy"] == "embedding":
|
||||
logger.info(f"加载 Chroma Vector Store:{name}")
|
||||
try:
|
||||
from .store.chroma_db import ChromaVectorStore
|
||||
except ImportError as ie:
|
||||
logger.error(f"{ie} 可能未安装 chromadb 库。")
|
||||
continue
|
||||
self.store_insts[name] = ChromaVectorStore(
|
||||
name, cfg["embedding_config"]
|
||||
)
|
||||
else:
|
||||
logger.error(f"不支持的策略:{cfg['strategy']}")
|
||||
|
||||
async def list_knowledge_db(self) -> List[str]:
|
||||
return [
|
||||
f
|
||||
for f in os.listdir(self.db_path)
|
||||
if os.path.isfile(os.path.join(self.db_path, f))
|
||||
]
|
||||
|
||||
async def create_knowledge_db(self, name: str, config: Dict):
|
||||
"""
|
||||
config 格式:
|
||||
```
|
||||
{
|
||||
"strategy": "embedding", # 目前只支持 embedding
|
||||
"chunk_method": {
|
||||
"strategy": "fixed",
|
||||
"chunk_size": 100,
|
||||
"overlap_size": 10
|
||||
},
|
||||
"embedding_config": {
|
||||
"strategy": "openai",
|
||||
"base_url": "",
|
||||
"model": "",
|
||||
"api_key": ""
|
||||
}
|
||||
}
|
||||
```
|
||||
"""
|
||||
if name in self.config:
|
||||
raise ValueError(f"知识库已存在:{name}")
|
||||
|
||||
self.config[name] = config
|
||||
self.astrbot_config["knowledge_db"] = self.config
|
||||
self.astrbot_config.save_config()
|
||||
|
||||
async def insert_record(self, name: str, text: str):
|
||||
if name not in self.store_insts:
|
||||
raise ValueError(f"未找到知识库:{name}")
|
||||
|
||||
ret = []
|
||||
match self.config[name]["chunk_method"]["strategy"]:
|
||||
case "fixed":
|
||||
chunk_size = self.config[name]["chunk_method"]["chunk_size"]
|
||||
chunk_overlap = self.config[name]["chunk_method"]["overlap_size"]
|
||||
ret = self._fixed_chunk(text, chunk_size, chunk_overlap)
|
||||
case _:
|
||||
pass
|
||||
|
||||
for chunk in ret:
|
||||
await self.store_insts[name].save(chunk)
|
||||
|
||||
async def retrive_records(self, name: str, query: str, top_n: int = 3) -> List[str]:
|
||||
if name not in self.store_insts:
|
||||
raise ValueError(f"未找到知识库:{name}")
|
||||
|
||||
inst = self.store_insts[name]
|
||||
return await inst.query(query, top_n)
|
||||
|
||||
def _fixed_chunk(self, text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
|
||||
chunks = []
|
||||
start = 0
|
||||
while start < len(text):
|
||||
end = start + chunk_size
|
||||
chunks.append(text[start:end])
|
||||
start += chunk_size - chunk_overlap
|
||||
return chunks
|
||||
@@ -1,9 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
|
||||
class Store:
|
||||
async def save(self, text: str):
|
||||
pass
|
||||
|
||||
async def query(self, query: str, top_n: int = 3) -> List[str]:
|
||||
pass
|
||||
@@ -1,44 +0,0 @@
|
||||
import chromadb
|
||||
import uuid
|
||||
from typing import List, Dict
|
||||
from astrbot.api import logger
|
||||
from ..embedding.openai_source import SimpleOpenAIEmbedding
|
||||
from . import Store
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class ChromaVectorStore(Store):
|
||||
def __init__(self, name: str, embedding_cfg: Dict) -> None:
|
||||
import os
|
||||
self.chroma_client = chromadb.PersistentClient(
|
||||
path=os.path.join(get_astrbot_data_path(), "long_term_memory_chroma.db")
|
||||
)
|
||||
self.collection = self.chroma_client.get_or_create_collection(name=name)
|
||||
self.embedding = None
|
||||
if embedding_cfg["strategy"] == "openai":
|
||||
self.embedding = SimpleOpenAIEmbedding(
|
||||
model=embedding_cfg["model"],
|
||||
api_key=embedding_cfg["api_key"],
|
||||
api_base=embedding_cfg.get("base_url", None),
|
||||
)
|
||||
|
||||
async def save(self, text: str, metadata: Dict = None):
|
||||
logger.debug(f"Saving text: {text}")
|
||||
embedding = await self.embedding.get_embedding(text)
|
||||
|
||||
self.collection.upsert(
|
||||
documents=text,
|
||||
metadatas=metadata,
|
||||
ids=str(uuid.uuid4()),
|
||||
embeddings=embedding,
|
||||
)
|
||||
|
||||
async def query(
|
||||
self, query: str, top_n=3, metadata_filter: Dict = None
|
||||
) -> List[str]:
|
||||
embedding = await self.embedding.get_embedding(query)
|
||||
|
||||
results = self.collection.query(
|
||||
query_embeddings=embedding, n_results=top_n, where=metadata_filter
|
||||
)
|
||||
return results["documents"][0]
|
||||
@@ -3,6 +3,7 @@ from typing import List, Union
|
||||
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
@@ -16,7 +17,6 @@ from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
||||
from .filter.command import CommandFilter
|
||||
from .filter.regex import RegexFilter
|
||||
from typing import Awaitable
|
||||
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
from astrbot.core.star.filter.platform_adapter_type import (
|
||||
PlatformAdapterType,
|
||||
@@ -42,6 +42,8 @@ class Context:
|
||||
|
||||
platform_manager: PlatformManager = None
|
||||
|
||||
registered_web_apis: list = []
|
||||
|
||||
# back compatibility
|
||||
_register_tasks: List[Awaitable] = []
|
||||
_star_manager = None
|
||||
@@ -54,14 +56,12 @@ class Context:
|
||||
provider_manager: ProviderManager = None,
|
||||
platform_manager: PlatformManager = None,
|
||||
conversation_manager: ConversationManager = None,
|
||||
knowledge_db_manager: KnowledgeDBManager = None,
|
||||
):
|
||||
self._event_queue = event_queue
|
||||
self._config = config
|
||||
self._db = db
|
||||
self.provider_manager = provider_manager
|
||||
self.platform_manager = platform_manager
|
||||
self.knowledge_db_manager = knowledge_db_manager
|
||||
self.conversation_manager = conversation_manager
|
||||
|
||||
def get_registered_star(self, star_name: str) -> StarMetadata:
|
||||
@@ -126,11 +126,8 @@ class Context:
|
||||
self.provider_manager.provider_insts.append(provider)
|
||||
|
||||
def get_provider_by_id(self, provider_id: str) -> Provider:
|
||||
"""通过 ID 获取用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
|
||||
for provider in self.provider_manager.provider_insts:
|
||||
if provider.meta().id == provider_id:
|
||||
return provider
|
||||
return None
|
||||
"""通过 ID 获取对应的 LLM Provider(Chat_Completion 类型)。"""
|
||||
return self.provider_manager.inst_map.get(provider_id)
|
||||
|
||||
def get_all_providers(self) -> List[Provider]:
|
||||
"""获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
|
||||
@@ -144,24 +141,46 @@ class Context:
|
||||
"""获取所有用于 STT 任务的 Provider。"""
|
||||
return self.provider_manager.stt_provider_insts
|
||||
|
||||
def get_using_provider(self) -> Provider:
|
||||
def get_using_provider(self, umo: str = None) -> Provider:
|
||||
"""
|
||||
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。
|
||||
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
|
||||
|
||||
通过 /provider 指令切换。
|
||||
Args:
|
||||
umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。
|
||||
"""
|
||||
if umo and self._config["provider_settings"]["separate_provider"]:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
prov_id = perf.get(umo, {}).get(ProviderType.CHAT_COMPLETION.value, None)
|
||||
if inst := self.provider_manager.inst_map.get(prov_id, None):
|
||||
return inst
|
||||
return self.provider_manager.curr_provider_inst
|
||||
|
||||
def get_using_tts_provider(self) -> TTSProvider:
|
||||
def get_using_tts_provider(self, umo: str = None) -> TTSProvider:
|
||||
"""
|
||||
获取当前使用的用于 TTS 任务的 Provider。
|
||||
|
||||
Args:
|
||||
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
||||
"""
|
||||
if umo and self._config["provider_settings"]["separate_provider"]:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
prov_id = perf.get(umo, {}).get(ProviderType.TEXT_TO_SPEECH.value, None)
|
||||
if inst := self.provider_manager.inst_map.get(prov_id, None):
|
||||
return inst
|
||||
return self.provider_manager.curr_tts_provider_inst
|
||||
|
||||
def get_using_stt_provider(self) -> STTProvider:
|
||||
def get_using_stt_provider(self, umo: str = None) -> STTProvider:
|
||||
"""
|
||||
获取当前使用的用于 STT 任务的 Provider。
|
||||
|
||||
Args:
|
||||
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
||||
"""
|
||||
if umo and self._config["provider_settings"]["separate_provider"]:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
prov_id = perf.get(umo, {}).get(ProviderType.SPEECH_TO_TEXT.value, None)
|
||||
if inst := self.provider_manager.inst_map.get(prov_id, None):
|
||||
return inst
|
||||
return self.provider_manager.curr_stt_provider_inst
|
||||
|
||||
def get_config(self) -> AstrBotConfig:
|
||||
@@ -301,3 +320,12 @@ class Context:
|
||||
注册一个异步任务。
|
||||
"""
|
||||
self._register_tasks.append(task)
|
||||
|
||||
def register_web_api(
|
||||
self, route: str, view_handler: Awaitable, methods: list, desc: str
|
||||
):
|
||||
for idx, api in enumerate(self.registered_web_apis):
|
||||
if api[0] == route and methods == api[2]:
|
||||
self.registered_web_apis[idx] = (route, view_handler, methods, desc)
|
||||
return
|
||||
self.registered_web_apis.append((route, view_handler, methods, desc))
|
||||
|
||||
@@ -7,6 +7,9 @@ from astrbot.core.config import AstrBotConfig
|
||||
from .custom_filter import CustomFilter
|
||||
from ..star_handler import StarHandlerMetadata
|
||||
|
||||
class GreedyStr(str):
|
||||
"""标记指令完成其他参数接收后的所有剩余文本。"""
|
||||
pass
|
||||
|
||||
# 标准指令受到 wake_prefix 的制约。
|
||||
class CommandFilter(HandlerFilter):
|
||||
@@ -68,7 +71,22 @@ class CommandFilter(HandlerFilter):
|
||||
) -> Dict[str, Any]:
|
||||
"""将参数列表 params 根据 param_type 转换为参数字典。"""
|
||||
result = {}
|
||||
for i, (param_name, param_type_or_default_val) in enumerate(param_type.items()):
|
||||
param_items = list(param_type.items())
|
||||
for i, (param_name, param_type_or_default_val) in enumerate(param_items):
|
||||
is_greedy = param_type_or_default_val is GreedyStr
|
||||
|
||||
if is_greedy:
|
||||
# GreedyStr 必须是最后一个参数
|
||||
if i != len(param_items) - 1:
|
||||
raise ValueError(
|
||||
f"参数 '{param_name}' (GreedyStr) 必须是最后一个参数。"
|
||||
)
|
||||
|
||||
# 将剩余的所有部分合并成一个字符串
|
||||
remaining_params = params[i:]
|
||||
result[param_name] = " ".join(remaining_params)
|
||||
break
|
||||
# 没有 GreedyStr 的情况
|
||||
if i >= len(params):
|
||||
if (
|
||||
isinstance(param_type_or_default_val, Type)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
import heapq
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Awaitable, List, Dict, TypeVar, Generic
|
||||
from .filter import HandlerFilter
|
||||
@@ -8,100 +7,66 @@ from .star import star_map
|
||||
|
||||
T = TypeVar("T", bound="StarHandlerMetadata")
|
||||
|
||||
|
||||
class StarHandlerRegistry(Generic[T]):
|
||||
"""用于存储所有的 Star Handler"""
|
||||
|
||||
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
|
||||
"""用于快速查找。key 是 handler_full_name"""
|
||||
_handlers = []
|
||||
def __init__(self):
|
||||
self.star_handlers_map: Dict[str, StarHandlerMetadata] = {}
|
||||
self._handlers: List[StarHandlerMetadata] = []
|
||||
|
||||
def append(self, handler: StarHandlerMetadata):
|
||||
"""添加一个 Handler"""
|
||||
"""添加一个 Handler,并保持按优先级有序"""
|
||||
if "priority" not in handler.extras_configs:
|
||||
handler.extras_configs["priority"] = 0
|
||||
|
||||
heapq.heappush(self._handlers, (-handler.extras_configs["priority"], handler))
|
||||
self.star_handlers_map[handler.handler_full_name] = handler
|
||||
self._handlers.append(handler)
|
||||
self._handlers.sort(key=lambda h: -h.extras_configs["priority"])
|
||||
|
||||
def _print_handlers(self):
|
||||
"""打印所有的 Handler"""
|
||||
for _, handler in self._handlers:
|
||||
for handler in self._handlers:
|
||||
print(handler.handler_full_name)
|
||||
|
||||
def get_handlers_by_event_type(
|
||||
self, event_type: EventType, only_activated=True, platform_id=None
|
||||
) -> List[StarHandlerMetadata]:
|
||||
"""通过事件类型获取 Handler
|
||||
|
||||
Args:
|
||||
event_type: 事件类型
|
||||
only_activated: 是否只返回已激活的插件的处理器
|
||||
platform_id: 平台ID,如果提供此参数,将过滤掉在此平台不兼容的处理器
|
||||
|
||||
Returns:
|
||||
List[StarHandlerMetadata]: 处理器列表
|
||||
"""
|
||||
handlers = []
|
||||
for _, handler in self._handlers:
|
||||
for handler in self._handlers:
|
||||
if handler.event_type != event_type:
|
||||
continue
|
||||
|
||||
# 只激活的插件处理器
|
||||
if only_activated:
|
||||
plugin = star_map.get(handler.handler_module_path)
|
||||
if not (plugin and plugin.activated):
|
||||
continue
|
||||
|
||||
# 平台兼容性过滤
|
||||
if platform_id and event_type != EventType.OnAstrBotLoadedEvent:
|
||||
if not handler.is_enabled_for_platform(platform_id):
|
||||
continue
|
||||
|
||||
handlers.append(handler)
|
||||
|
||||
return handlers
|
||||
|
||||
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
|
||||
"""通过 Handler 的全名获取 Handler"""
|
||||
return self.star_handlers_map.get(full_name, None)
|
||||
|
||||
def get_handlers_by_module_name(
|
||||
self, module_name: str
|
||||
) -> List[StarHandlerMetadata]:
|
||||
"""通过模块名获取 Handler"""
|
||||
return [
|
||||
handler
|
||||
for _, handler in self._handlers
|
||||
handler for handler in self._handlers
|
||||
if handler.handler_module_path == module_name
|
||||
]
|
||||
|
||||
def clear(self):
|
||||
"""清空所有的 Handler"""
|
||||
self.star_handlers_map.clear()
|
||||
self._handlers.clear()
|
||||
|
||||
def remove(self, handler: StarHandlerMetadata):
|
||||
"""删除一个 Handler"""
|
||||
# self._handlers.remove(handler)
|
||||
for i, h in enumerate(self._handlers):
|
||||
if h[1] == handler:
|
||||
self._handlers.pop(i)
|
||||
break
|
||||
try:
|
||||
del self.star_handlers_map[handler.handler_full_name]
|
||||
except KeyError:
|
||||
pass
|
||||
self.star_handlers_map.pop(handler.handler_full_name, None)
|
||||
self._handlers = [h for h in self._handlers if h != handler]
|
||||
|
||||
def __iter__(self):
|
||||
"""使 StarHandlerRegistry 支持迭代"""
|
||||
return (handler for _, handler in self._handlers)
|
||||
return iter(self._handlers)
|
||||
|
||||
def __len__(self):
|
||||
"""返回 Handler 的数量"""
|
||||
return len(self._handlers)
|
||||
|
||||
|
||||
star_handlers_registry = StarHandlerRegistry()
|
||||
|
||||
|
||||
|
||||
@@ -37,6 +37,12 @@ except ImportError:
|
||||
if os.getenv("ASTRBOT_RELOAD", "0") == "1":
|
||||
logger.warning("未安装 watchfiles,无法实现插件的热重载。")
|
||||
|
||||
try:
|
||||
import nh3
|
||||
except ImportError:
|
||||
logger.warning("未安装 nh3 库,无法清理插件 README.md 中的 HTML 标签。")
|
||||
nh3 = None
|
||||
|
||||
|
||||
class PluginManager:
|
||||
def __init__(self, context: Context, config: AstrBotConfig):
|
||||
@@ -140,11 +146,13 @@ class PluginManager:
|
||||
if os.path.exists(os.path.join(path, d, "main.py")) or os.path.exists(
|
||||
os.path.join(path, d, d + ".py")
|
||||
):
|
||||
modules.append({
|
||||
"pname": d,
|
||||
"module": module_str,
|
||||
"module_path": os.path.join(path, d, module_str),
|
||||
})
|
||||
modules.append(
|
||||
{
|
||||
"pname": d,
|
||||
"module": module_str,
|
||||
"module_path": os.path.join(path, d, module_str),
|
||||
}
|
||||
)
|
||||
return modules
|
||||
|
||||
def _get_plugin_modules(self) -> List[dict]:
|
||||
@@ -158,7 +166,7 @@ class PluginManager:
|
||||
plugins.extend(_p)
|
||||
return plugins
|
||||
|
||||
def _check_plugin_dept_update(self, target_plugin: str = None):
|
||||
async def _check_plugin_dept_update(self, target_plugin: str = None):
|
||||
"""检查插件的依赖
|
||||
如果 target_plugin 为 None,则检查所有插件的依赖
|
||||
"""
|
||||
@@ -177,7 +185,7 @@ class PluginManager:
|
||||
pth = os.path.join(plugin_path, "requirements.txt")
|
||||
logger.info(f"正在安装插件 {p} 所需的依赖库: {pth}")
|
||||
try:
|
||||
pip_installer.install(requirements_path=pth)
|
||||
await pip_installer.install(requirements_path=pth)
|
||||
except Exception as e:
|
||||
logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}")
|
||||
|
||||
@@ -399,7 +407,7 @@ class PluginManager:
|
||||
module = __import__(path, fromlist=[module_str])
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
# 尝试安装依赖
|
||||
self._check_plugin_dept_update(target_plugin=root_dir_name)
|
||||
await self._check_plugin_dept_update(target_plugin=root_dir_name)
|
||||
module = __import__(path, fromlist=[module_str])
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
@@ -443,11 +451,11 @@ class PluginManager:
|
||||
metadata.repo = metadata_yaml.repo
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
metadata.config = plugin_config
|
||||
if path not in inactivated_plugins:
|
||||
# 只有没有禁用插件时才实例化插件类
|
||||
if plugin_config:
|
||||
metadata.config = plugin_config
|
||||
# metadata.config = plugin_config
|
||||
try:
|
||||
metadata.star_cls = metadata.star_cls_type(
|
||||
context=self.context, config=plugin_config
|
||||
@@ -634,16 +642,17 @@ class PluginManager:
|
||||
if not os.path.exists(readme_path):
|
||||
readme_path = os.path.join(plugin_path, "readme.md")
|
||||
|
||||
if os.path.exists(readme_path):
|
||||
if os.path.exists(readme_path) and nh3:
|
||||
try:
|
||||
with open(readme_path, "r", encoding="utf-8") as f:
|
||||
readme_content = f.read()
|
||||
cleaned_content = nh3.clean(readme_content)
|
||||
except Exception as e:
|
||||
logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {str(e)}")
|
||||
|
||||
plugin_info = None
|
||||
if plugin:
|
||||
plugin_info = {"repo": plugin.repo, "readme": readme_content}
|
||||
plugin_info = {"repo": plugin.repo, "readme": cleaned_content}
|
||||
|
||||
return plugin_info
|
||||
|
||||
|
||||
@@ -18,7 +18,8 @@ class PluginUpdator(RepoZipUpdator):
|
||||
return self.plugin_store_path
|
||||
|
||||
async def install(self, repo_url: str, proxy="") -> str:
|
||||
repo_name = self.format_repo_name(repo_url)
|
||||
_, repo_name, _ = self.parse_github_url(repo_url)
|
||||
repo_name = self.format_name(repo_name)
|
||||
plugin_path = os.path.join(self.plugin_store_path, repo_name)
|
||||
await self.download_from_repo_url(plugin_path, repo_url, proxy)
|
||||
self.unzip_file(plugin_path + ".zip", plugin_path)
|
||||
@@ -31,10 +32,6 @@ class PluginUpdator(RepoZipUpdator):
|
||||
if not repo_url:
|
||||
raise Exception(f"插件 {plugin.name} 没有指定仓库地址。")
|
||||
|
||||
if proxy:
|
||||
proxy = proxy.removesuffix("/")
|
||||
repo_url = f"{proxy}/{repo_url}"
|
||||
|
||||
plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
|
||||
|
||||
logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}")
|
||||
@@ -54,7 +51,7 @@ class PluginUpdator(RepoZipUpdator):
|
||||
def unzip_file(self, zip_path: str, target_dir: str):
|
||||
os.makedirs(target_dir, exist_ok=True)
|
||||
update_dir = ""
|
||||
logger.info(f"解压文件: {zip_path}")
|
||||
logger.info(f"正在解压压缩包: {zip_path}")
|
||||
with zipfile.ZipFile(zip_path, "r") as z:
|
||||
update_dir = z.namelist()[0]
|
||||
z.extractall(target_dir)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from pip import main as pip_main
|
||||
import asyncio
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
@@ -9,7 +9,7 @@ class PipInstaller:
|
||||
self.pip_install_arg = pip_install_arg
|
||||
self.pypi_index_url = pypi_index_url
|
||||
|
||||
def install(
|
||||
async def install(
|
||||
self,
|
||||
package_name: str = None,
|
||||
requirements_path: str = None,
|
||||
@@ -29,12 +29,29 @@ class PipInstaller:
|
||||
args.extend(self.pip_install_arg.split())
|
||||
|
||||
logger.info(f"Pip 包管理器: pip {' '.join(args)}")
|
||||
try:
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
"pip", *args,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.STDOUT,
|
||||
)
|
||||
|
||||
result_code = pip_main(args)
|
||||
assert process.stdout is not None
|
||||
async for line in process.stdout:
|
||||
logger.info(line.decode().strip())
|
||||
|
||||
# 清除 pip.main 导致的多余的 logging handlers
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
await process.wait()
|
||||
|
||||
if result_code != 0:
|
||||
raise Exception(f"安装失败,错误码:{result_code}")
|
||||
if process.returncode != 0:
|
||||
raise Exception(f"安装失败,错误码:{process.returncode}")
|
||||
except FileNotFoundError:
|
||||
# 没有 pip
|
||||
from pip import main as pip_main
|
||||
result_code = await asyncio.to_thread(pip_main, args)
|
||||
|
||||
# 清除 pip.main 导致的多余的 logging handlers
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
if result_code != 0:
|
||||
raise Exception(f"安装失败,错误码:{result_code}")
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
import base64
|
||||
import wave
|
||||
import os
|
||||
from io import BytesIO
|
||||
import asyncio
|
||||
import tempfile
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
async def tencent_silk_to_wav(silk_path: str, output_path: str) -> str:
|
||||
@@ -50,3 +55,46 @@ async def wav_to_tencent_silk(wav_path: str, output_path: str) -> int:
|
||||
rate = wav.getframerate()
|
||||
duration = pilk.encode(wav_path, output_path, pcm_rate=rate, tencent=True)
|
||||
return duration
|
||||
|
||||
|
||||
async def wav_to_tencent_silk_base64(wav_path: str) -> str:
|
||||
"""
|
||||
将 WAV 文件转为 Silk,并返回 Base64 字符串。
|
||||
默认采样率为 24000,输出临时文件为 temp/output.silk。
|
||||
|
||||
参数:
|
||||
- wav_path: 输入 .wav 文件路径(需为 PCM 16bit)
|
||||
|
||||
返回:
|
||||
- Base64 编码的 Silk 字符串
|
||||
- duration: 音频时长(秒)
|
||||
"""
|
||||
try:
|
||||
import pilk
|
||||
except ImportError as e:
|
||||
raise Exception("pysilk 模块未安装,请安装 pysilk") from e
|
||||
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
with wave.open(wav_path, "rb") as wav:
|
||||
rate = wav.getframerate()
|
||||
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix=".silk", delete=False, dir=temp_dir
|
||||
) as tmp_file:
|
||||
silk_path = tmp_file.name
|
||||
|
||||
try:
|
||||
duration = await asyncio.to_thread(
|
||||
pilk.encode, wav_path, silk_path, pcm_rate=rate, tencent=True
|
||||
)
|
||||
|
||||
with open(silk_path, "rb") as f:
|
||||
silk_bytes = await asyncio.to_thread(f.read)
|
||||
silk_b64 = base64.b64encode(silk_bytes).decode("utf-8")
|
||||
|
||||
return silk_b64, duration # 已是秒
|
||||
finally:
|
||||
if os.path.exists(silk_path):
|
||||
os.remove(silk_path)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import aiohttp
|
||||
import os
|
||||
import re
|
||||
import zipfile
|
||||
import shutil
|
||||
|
||||
@@ -119,28 +120,61 @@ class RepoZipUpdator:
|
||||
)
|
||||
|
||||
async def download_from_repo_url(self, target_path: str, repo_url: str, proxy=""):
|
||||
repo_namespace = repo_url.split("/")[-2:]
|
||||
author = repo_namespace[0]
|
||||
repo = repo_namespace[1]
|
||||
author, repo, branch = self.parse_github_url(repo_url)
|
||||
|
||||
logger.info(f"正在下载更新 {repo} ...")
|
||||
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
|
||||
releases = await self.fetch_release_info(url=release_url)
|
||||
if not releases:
|
||||
# download from the default branch directly.
|
||||
logger.info(f"正在从默认分支下载 {author}/{repo} ")
|
||||
|
||||
if branch:
|
||||
logger.info(f"正在从指定分支 {branch} 下载 {author}/{repo}")
|
||||
release_url = (
|
||||
f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
|
||||
f"https://github.com/{author}/{repo}/archive/refs/heads/{branch}.zip"
|
||||
)
|
||||
else:
|
||||
release_url = releases[0]["zipball_url"]
|
||||
try:
|
||||
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
|
||||
releases = await self.fetch_release_info(url=release_url)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"获取 {author}/{repo} 的 GitHub Releases 失败: {e},将尝试下载默认分支"
|
||||
)
|
||||
releases = []
|
||||
if not releases:
|
||||
# 如果没有最新版本,下载默认分支
|
||||
logger.info(f"正在从默认分支下载 {author}/{repo}")
|
||||
release_url = (
|
||||
f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
|
||||
)
|
||||
else:
|
||||
release_url = releases[0]["zipball_url"]
|
||||
|
||||
if proxy:
|
||||
proxy = proxy.rstrip("/")
|
||||
release_url = f"{proxy}/{release_url}"
|
||||
logger.info(f"使用代理下载: {release_url}")
|
||||
logger.info(
|
||||
f"检查到设置了镜像站,将使用镜像站下载 {author}/{repo} 仓库源码: {release_url}"
|
||||
)
|
||||
|
||||
await download_file(release_url, target_path + ".zip")
|
||||
|
||||
def parse_github_url(self, url: str):
|
||||
"""使用正则表达式解析 GitHub 仓库 URL,支持 `.git` 后缀和 `tree/branch` 结构
|
||||
Returns:
|
||||
tuple[str, str, str]: 返回作者名、仓库名和分支名
|
||||
Raises:
|
||||
ValueError: 如果 URL 格式不正确
|
||||
"""
|
||||
cleaned_url = url.rstrip("/")
|
||||
pattern = r"^https://github\.com/([a-zA-Z0-9_-]+)/([a-zA-Z0-9_-]+)(\.git)?(?:/tree/([a-zA-Z0-9_-]+))?$"
|
||||
match = re.match(pattern, cleaned_url)
|
||||
|
||||
if match:
|
||||
author = match.group(1)
|
||||
repo = match.group(2)
|
||||
branch = match.group(4)
|
||||
return author, repo, branch
|
||||
else:
|
||||
raise ValueError("无效的 GitHub URL")
|
||||
|
||||
def unzip_file(self, zip_path: str, target_dir: str):
|
||||
"""
|
||||
解压缩文件, 并将压缩包内**第一个**文件夹内的文件移动到 target_dir
|
||||
@@ -174,16 +208,5 @@ class RepoZipUpdator:
|
||||
f"删除更新文件失败,可以手动删除 {zip_path} 和 {os.path.join(target_dir, update_dir)}"
|
||||
)
|
||||
|
||||
def format_repo_name(self, repo_url: str) -> str:
|
||||
if repo_url.endswith("/"):
|
||||
repo_url = repo_url[:-1]
|
||||
|
||||
repo_namespace = repo_url.split("/")[-2:]
|
||||
repo = repo_namespace[1]
|
||||
|
||||
repo = self.format_name(repo)
|
||||
|
||||
return repo
|
||||
|
||||
def format_name(self, name: str) -> str:
|
||||
return name.replace("-", "_").lower()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import jwt
|
||||
import datetime
|
||||
import asyncio
|
||||
from .route import Route, Response, RouteContext
|
||||
from quart import request
|
||||
from astrbot.core import WEBUI_SK, DEMO_MODE
|
||||
@@ -21,7 +22,11 @@ class AuthRoute(Route):
|
||||
post_data = await request.json
|
||||
if post_data["username"] == username and post_data["password"] == password:
|
||||
change_pwd_hint = False
|
||||
if username == "astrbot" and password == "77b90590a8945a7d36c963981a307dc9":
|
||||
if (
|
||||
username == "astrbot"
|
||||
and password == "77b90590a8945a7d36c963981a307dc9"
|
||||
and not DEMO_MODE
|
||||
):
|
||||
change_pwd_hint = True
|
||||
logger.warning("为了保证安全,请尽快修改默认密码。")
|
||||
|
||||
@@ -37,6 +42,7 @@ class AuthRoute(Route):
|
||||
.__dict__
|
||||
)
|
||||
else:
|
||||
await asyncio.sleep(3)
|
||||
return Response().error("用户名或密码错误").__dict__
|
||||
|
||||
async def edit_account(self):
|
||||
@@ -72,7 +78,7 @@ class AuthRoute(Route):
|
||||
def generate_jwt(self, username):
|
||||
payload = {
|
||||
"username": username,
|
||||
"exp": datetime.datetime.utcnow() + datetime.timedelta(days=30),
|
||||
"exp": datetime.datetime.utcnow() + datetime.timedelta(days=7),
|
||||
}
|
||||
token = jwt.encode(payload, WEBUI_SK, algorithm="HS256")
|
||||
return token
|
||||
|
||||
@@ -26,6 +26,7 @@ class ChatRoute(Route):
|
||||
"/chat/conversations": ("GET", self.get_conversations),
|
||||
"/chat/get_conversation": ("GET", self.get_conversation),
|
||||
"/chat/delete_conversation": ("GET", self.delete_conversation),
|
||||
"/chat/rename_conversation": ("POST", self.rename_conversation),
|
||||
"/chat/get_file": ("GET", self.get_file),
|
||||
"/chat/post_image": ("POST", self.post_image),
|
||||
"/chat/post_file": ("POST", self.post_file),
|
||||
@@ -61,16 +62,25 @@ class ChatRoute(Route):
|
||||
return Response().error("Missing key: filename").__dict__
|
||||
|
||||
try:
|
||||
with open(os.path.join(self.imgs_dir, filename), "rb") as f:
|
||||
if filename.endswith(".wav"):
|
||||
file_path = os.path.join(self.imgs_dir, os.path.basename(filename))
|
||||
real_file_path = os.path.realpath(file_path)
|
||||
real_imgs_dir = os.path.realpath(self.imgs_dir)
|
||||
|
||||
if not real_file_path.startswith(real_imgs_dir):
|
||||
return Response().error("Invalid file path").__dict__
|
||||
|
||||
with open(real_file_path, "rb") as f:
|
||||
filename_ext = os.path.splitext(filename)[1].lower()
|
||||
|
||||
if filename_ext == ".wav":
|
||||
return QuartResponse(f.read(), mimetype="audio/wav")
|
||||
elif filename.split(".")[-1] in self.supported_imgs:
|
||||
elif filename_ext[1:] in self.supported_imgs:
|
||||
return QuartResponse(f.read(), mimetype="image/jpeg")
|
||||
else:
|
||||
return QuartResponse(f.read())
|
||||
|
||||
except FileNotFoundError:
|
||||
return Response().error("File not found").__dict__
|
||||
except (FileNotFoundError, OSError):
|
||||
return Response().error("File access error").__dict__
|
||||
|
||||
async def post_image(self):
|
||||
post_data = await request.files
|
||||
@@ -91,7 +101,6 @@ class ChatRoute(Route):
|
||||
|
||||
file = post_data["file"]
|
||||
filename = f"{str(uuid.uuid4())}"
|
||||
print(file)
|
||||
# 通过文件格式判断文件类型
|
||||
if file.content_type.startswith("audio"):
|
||||
filename += ".wav"
|
||||
@@ -143,7 +152,7 @@ class ChatRoute(Route):
|
||||
try:
|
||||
history = json.loads(conversation.history)
|
||||
except BaseException as e:
|
||||
print(e)
|
||||
logger.error(f"Failed to parse conversation history: {e}")
|
||||
history = []
|
||||
new_his = {"type": "user", "message": message}
|
||||
if image_url:
|
||||
@@ -197,6 +206,9 @@ class ChatRoute(Route):
|
||||
if streaming and type != "end":
|
||||
continue
|
||||
|
||||
if type == "update_title":
|
||||
continue
|
||||
|
||||
if result_text:
|
||||
conversation = self.db.get_conversation_by_user_id(
|
||||
username, cid
|
||||
@@ -204,7 +216,7 @@ class ChatRoute(Route):
|
||||
try:
|
||||
history = json.loads(conversation.history)
|
||||
except BaseException as e:
|
||||
print(e)
|
||||
logger.error(f"Failed to parse conversation history: {e}")
|
||||
history = []
|
||||
history.append({"type": "bot", "message": result_text})
|
||||
self.db.update_conversation(
|
||||
@@ -242,6 +254,18 @@ class ChatRoute(Route):
|
||||
self.db.new_conversation(username, conversation_id)
|
||||
return Response().ok(data={"conversation_id": conversation_id}).__dict__
|
||||
|
||||
async def rename_conversation(self):
|
||||
username = g.get("username", "guest")
|
||||
post_data = await request.json
|
||||
if "conversation_id" not in post_data or "title" not in post_data:
|
||||
return Response().error("Missing key: conversation_id or title").__dict__
|
||||
|
||||
conversation_id = post_data["conversation_id"]
|
||||
title = post_data["title"]
|
||||
|
||||
self.db.update_conversation_title(username, conversation_id, title=title)
|
||||
return Response().ok(message="重命名成功!").__dict__
|
||||
|
||||
async def get_conversations(self):
|
||||
username = g.get("username", "guest")
|
||||
conversations = self.db.get_conversations(username)
|
||||
|
||||
@@ -9,6 +9,7 @@ from astrbot.core.platform.register import platform_registry
|
||||
from astrbot.core.provider.register import provider_registry
|
||||
from astrbot.core.star.star import star_registry
|
||||
from astrbot.core import logger
|
||||
import asyncio
|
||||
|
||||
|
||||
def try_cast(value: str, type_: str):
|
||||
@@ -153,6 +154,7 @@ class ConfigRoute(Route):
|
||||
) -> None:
|
||||
super().__init__(context)
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.config: AstrBotConfig = core_lifecycle.astrbot_config
|
||||
self.routes = {
|
||||
"/config/get": ("GET", self.get_configs),
|
||||
"/config/astrbot/update": ("POST", self.post_astrbot_configs),
|
||||
@@ -164,9 +166,125 @@ class ConfigRoute(Route):
|
||||
"/config/provider/update": ("POST", self.post_update_provider),
|
||||
"/config/provider/delete": ("POST", self.post_delete_provider),
|
||||
"/config/llmtools": ("GET", self.get_llm_tools),
|
||||
"/config/provider/check_status": ("GET", self.check_all_providers_status),
|
||||
"/config/provider/list": ("GET", self.get_provider_config_list),
|
||||
"/config/provider/get_session_seperate": (
|
||||
"GET",
|
||||
lambda: Response()
|
||||
.ok({"enable": self.config["provider_settings"]["separate_provider"]})
|
||||
.__dict__,
|
||||
),
|
||||
"/config/provider/set_session_seperate": (
|
||||
"POST",
|
||||
self.post_session_seperate,
|
||||
),
|
||||
}
|
||||
self.register_routes()
|
||||
|
||||
async def _test_single_provider(self, provider):
|
||||
"""辅助函数:测试单个 provider 的可用性"""
|
||||
meta = provider.meta()
|
||||
provider_name = provider.provider_config.get("id", "Unknown Provider")
|
||||
logger.debug(f"Got provider meta: {meta}")
|
||||
if not provider_name and meta:
|
||||
provider_name = meta.id
|
||||
elif not provider_name:
|
||||
provider_name = "Unknown Provider"
|
||||
status_info = {
|
||||
"id": getattr(meta, "id", "Unknown ID"),
|
||||
"model": getattr(meta, "model", "Unknown Model"),
|
||||
"type": getattr(meta, "type", "Unknown Type"),
|
||||
"name": provider_name,
|
||||
"status": "unavailable", # 默认为不可用
|
||||
"error": None,
|
||||
}
|
||||
logger.debug(
|
||||
f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})"
|
||||
)
|
||||
try:
|
||||
logger.debug(f"Sending 'Ping' to provider: {status_info['name']}")
|
||||
response = await asyncio.wait_for(
|
||||
provider.text_chat(prompt="REPLY `PONG` ONLY"), timeout=45.0
|
||||
)
|
||||
logger.debug(f"Received response from {status_info['name']}: {response}")
|
||||
# 只要 text_chat 调用成功返回一个 LLMResponse 对象 (即 response 不为 None),就认为可用
|
||||
if response is not None:
|
||||
status_info["status"] = "available"
|
||||
response_text_snippet = ""
|
||||
if hasattr(response, "completion_text") and response.completion_text:
|
||||
response_text_snippet = (
|
||||
response.completion_text[:70] + "..."
|
||||
if len(response.completion_text) > 70
|
||||
else response.completion_text
|
||||
)
|
||||
elif hasattr(response, "result_chain") and response.result_chain:
|
||||
try:
|
||||
response_text_snippet = (
|
||||
response.result_chain.get_plain_text()[:70] + "..."
|
||||
if len(response.result_chain.get_plain_text()) > 70
|
||||
else response.result_chain.get_plain_text()
|
||||
)
|
||||
except Exception as _:
|
||||
pass
|
||||
logger.info(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'"
|
||||
)
|
||||
else:
|
||||
# 这个分支理论上不应该被走到,除非 text_chat 实现可能返回 None
|
||||
status_info["error"] = (
|
||||
"Test call returned None, but expected an LLMResponse object."
|
||||
)
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None."
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
status_info["error"] = (
|
||||
"Connection timed out after 45 seconds during test call."
|
||||
)
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) timed out."
|
||||
)
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
status_info["error"] = error_message
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}"
|
||||
)
|
||||
return status_info
|
||||
|
||||
async def check_all_providers_status(self):
|
||||
"""
|
||||
API 接口: 检查所有 LLM Providers 的状态
|
||||
"""
|
||||
logger.info("API call received: /config/provider/check_status")
|
||||
try:
|
||||
all_providers: typing.List = (
|
||||
self.core_lifecycle.star_context.get_all_providers()
|
||||
)
|
||||
logger.debug(f"Found {len(all_providers)} providers to check.")
|
||||
|
||||
if not all_providers:
|
||||
logger.info("No providers found to check.")
|
||||
return Response().ok([]).__dict__
|
||||
|
||||
tasks = [self._test_single_provider(p) for p in all_providers]
|
||||
logger.debug(f"Created {len(tasks)} tasks for concurrent provider checks.")
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
logger.info(f"Provider status check completed. Results: {results}")
|
||||
|
||||
return Response().ok(results).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"Critical error in check_all_providers_status: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
return (
|
||||
Response().error(f"检查 Provider 状态时发生严重错误: {str(e)}").__dict__
|
||||
)
|
||||
|
||||
async def get_configs(self):
|
||||
# plugin_name 为空时返回 AstrBot 配置
|
||||
# 否则返回指定 plugin_name 的插件配置
|
||||
@@ -175,6 +293,32 @@ class ConfigRoute(Route):
|
||||
return Response().ok(await self._get_astrbot_config()).__dict__
|
||||
return Response().ok(await self._get_plugin_config(plugin_name)).__dict__
|
||||
|
||||
async def post_session_seperate(self):
|
||||
"""设置提供商会话隔离"""
|
||||
post_config = await request.json
|
||||
enable = post_config.get("enable", None)
|
||||
if enable is None:
|
||||
return Response().error("缺少参数 enable").__dict__
|
||||
|
||||
astrbot_config = self.core_lifecycle.astrbot_config
|
||||
astrbot_config["provider_settings"]["separate_provider"] = enable
|
||||
try:
|
||||
astrbot_config.save_config()
|
||||
except Exception as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
return Response().ok(None, "设置成功~").__dict__
|
||||
|
||||
async def get_provider_config_list(self):
|
||||
provider_type = request.args.get("provider_type", None)
|
||||
if not provider_type:
|
||||
return Response().error("缺少参数 provider_type").__dict__
|
||||
provider_list = []
|
||||
astrbot_config = self.core_lifecycle.astrbot_config
|
||||
for provider in astrbot_config["provider"]:
|
||||
if provider.get("provider_type", None) == provider_type:
|
||||
provider_list.append(provider)
|
||||
return Response().ok(provider_list).__dict__
|
||||
|
||||
async def post_astrbot_configs(self):
|
||||
post_configs = await request.json
|
||||
try:
|
||||
|
||||
@@ -23,6 +23,7 @@ class LogRoute(Route):
|
||||
**message, # see astrbot/core/log.py
|
||||
}
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
await asyncio.sleep(0.07) # 控制发送频率,避免过快
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except BaseException as e:
|
||||
|
||||
@@ -18,6 +18,12 @@ from astrbot.core.star.filter.regex import RegexFilter
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from astrbot.core import DEMO_MODE
|
||||
|
||||
try:
|
||||
import nh3
|
||||
except ImportError:
|
||||
logger.warning("未安装 nh3 库,无法清理插件 README.md 中的 HTML 标签。")
|
||||
nh3 = None
|
||||
|
||||
|
||||
class PluginRoute(Route):
|
||||
def __init__(
|
||||
@@ -102,7 +108,10 @@ class PluginRoute(Route):
|
||||
|
||||
async def get_plugins(self):
|
||||
_plugin_resp = []
|
||||
plugin_name = request.args.get("name")
|
||||
for plugin in self.plugin_manager.context.get_all_stars():
|
||||
if plugin_name and plugin.name != plugin_name:
|
||||
continue
|
||||
_t = {
|
||||
"name": plugin.name,
|
||||
"repo": "" if plugin.repo is None else plugin.repo,
|
||||
@@ -145,9 +154,7 @@ class PluginRoute(Route):
|
||||
if handler.event_type == EventType.AdapterMessageEvent:
|
||||
# 处理平台适配器消息事件
|
||||
has_admin = False
|
||||
for (
|
||||
filter
|
||||
) in (
|
||||
for filter in (
|
||||
handler.event_filters
|
||||
): # 正常handler就只有 1~2 个 filter,因此这里时间复杂度不会太高
|
||||
if isinstance(filter, CommandFilter):
|
||||
@@ -325,6 +332,9 @@ class PluginRoute(Route):
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def get_plugin_readme(self):
|
||||
if not nh3:
|
||||
return Response().error("未安装 nh3 库").__dict__
|
||||
|
||||
plugin_name = request.args.get("name")
|
||||
logger.debug(f"正在获取插件 {plugin_name} 的README文件内容")
|
||||
|
||||
@@ -360,9 +370,11 @@ class PluginRoute(Route):
|
||||
with open(readme_path, "r", encoding="utf-8") as f:
|
||||
readme_content = f.read()
|
||||
|
||||
cleaned_content = nh3.clean(readme_content)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok({"content": readme_content}, "成功获取README内容")
|
||||
.ok({"content": cleaned_content}, "成功获取README内容")
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -383,14 +395,12 @@ class PluginRoute(Route):
|
||||
platform_type = platform.get("type", "")
|
||||
platform_id = platform.get("id", "")
|
||||
|
||||
platforms.append(
|
||||
{
|
||||
"name": platform_id, # 使用type作为name,这是系统内部使用的平台名称
|
||||
"id": platform_id, # 保留id字段以便前端可以显示
|
||||
"type": platform_type,
|
||||
"display_name": f"{platform_type}({platform_id})",
|
||||
}
|
||||
)
|
||||
platforms.append({
|
||||
"name": platform_id, # 使用type作为name,这是系统内部使用的平台名称
|
||||
"id": platform_id, # 保留id字段以便前端可以显示
|
||||
"type": platform_type,
|
||||
"display_name": f"{platform_type}({platform_id})",
|
||||
})
|
||||
|
||||
adjusted_platform_enable = {}
|
||||
for platform_id, plugins in platform_enable.items():
|
||||
@@ -399,13 +409,11 @@ class PluginRoute(Route):
|
||||
# 获取所有插件,包括系统内部插件
|
||||
plugins = []
|
||||
for plugin in self.plugin_manager.context.get_all_stars():
|
||||
plugins.append(
|
||||
{
|
||||
"name": plugin.name,
|
||||
"desc": plugin.desc,
|
||||
"reserved": plugin.reserved, # 添加reserved标志
|
||||
}
|
||||
)
|
||||
plugins.append({
|
||||
"name": plugin.name,
|
||||
"desc": plugin.desc,
|
||||
"reserved": plugin.reserved, # 添加reserved标志
|
||||
})
|
||||
|
||||
logger.debug(
|
||||
f"获取插件平台配置: 原始配置={platform_enable}, 调整后={adjusted_platform_enable}"
|
||||
@@ -413,13 +421,11 @@ class PluginRoute(Route):
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"platforms": platforms,
|
||||
"plugins": plugins,
|
||||
"platform_enable": adjusted_platform_enable,
|
||||
}
|
||||
)
|
||||
.ok({
|
||||
"platforms": platforms,
|
||||
"plugins": plugins,
|
||||
"platform_enable": adjusted_platform_enable,
|
||||
})
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -8,6 +8,7 @@ from quart import request
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config import VERSION
|
||||
from astrbot.core.utils.io import get_dashboard_version
|
||||
from astrbot.core import DEMO_MODE
|
||||
|
||||
|
||||
@@ -45,8 +46,27 @@ class StatRoute(Route):
|
||||
h, m = divmod(m, 60)
|
||||
return f"{h}小时{m}分{s}秒"
|
||||
|
||||
def is_default_cred(self):
|
||||
username = self.config["dashboard"]["username"]
|
||||
password = self.config["dashboard"]["password"]
|
||||
return (
|
||||
username == "astrbot"
|
||||
and password == "77b90590a8945a7d36c963981a307dc9"
|
||||
and not DEMO_MODE
|
||||
)
|
||||
|
||||
async def get_version(self):
|
||||
return Response().ok({"version": VERSION}).__dict__
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"version": VERSION,
|
||||
"dashboard_version": await get_dashboard_version(),
|
||||
"change_pwd_hint": self.is_default_cred(),
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
async def get_start_time(self):
|
||||
return Response().ok({"start_time": self.core_lifecycle.start_time}).__dict__
|
||||
|
||||
@@ -12,7 +12,10 @@ class StaticFileRoute(Route):
|
||||
"/logs",
|
||||
"/extension",
|
||||
"/dashboard/default",
|
||||
"/project-atri",
|
||||
"/alkaid",
|
||||
"/alkaid/knowledge-base",
|
||||
"/alkaid/long-term-memory",
|
||||
"/alkaid/other",
|
||||
"/console",
|
||||
"/chat",
|
||||
"/settings",
|
||||
|
||||
@@ -91,7 +91,7 @@ class UpdateRoute(Route):
|
||||
# pip 更新依赖
|
||||
logger.info("更新依赖中...")
|
||||
try:
|
||||
pip_installer.install(requirements_path="requirements.txt")
|
||||
await pip_installer.install(requirements_path="requirements.txt")
|
||||
except Exception as e:
|
||||
logger.error(f"更新依赖失败: {e}")
|
||||
|
||||
@@ -140,7 +140,7 @@ class UpdateRoute(Route):
|
||||
if not package:
|
||||
return Response().error("缺少参数 package 或不合法。").__dict__
|
||||
try:
|
||||
pip_installer.install(package, mirror=mirror)
|
||||
await pip_installer.install(package, mirror=mirror)
|
||||
return Response().ok(None, "安装成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/update_pip: {traceback.format_exc()}")
|
||||
|
||||
@@ -15,6 +15,8 @@ from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.io import get_local_ip_addresses
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
APP: Quart = None
|
||||
|
||||
|
||||
class AstrBotDashboard:
|
||||
def __init__(
|
||||
@@ -27,6 +29,7 @@ class AstrBotDashboard:
|
||||
self.config = core_lifecycle.astrbot_config
|
||||
self.data_path = os.path.abspath(os.path.join(get_astrbot_data_path(), "dist"))
|
||||
self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/")
|
||||
APP = self.app # noqa
|
||||
self.app.config["MAX_CONTENT_LENGTH"] = (
|
||||
128 * 1024 * 1024
|
||||
) # 将 Flask 允许的最大上传文件体大小设置为 128 MB
|
||||
@@ -51,12 +54,29 @@ class AstrBotDashboard:
|
||||
self.conversation_route = ConversationRoute(self.context, db, core_lifecycle)
|
||||
self.file_route = FileRoute(self.context)
|
||||
|
||||
self.app.add_url_rule(
|
||||
"/api/plug/<path:subpath>",
|
||||
view_func=self.srv_plug_route,
|
||||
methods=["GET", "POST"],
|
||||
)
|
||||
|
||||
self.shutdown_event = shutdown_event
|
||||
|
||||
async def srv_plug_route(self, subpath, *args, **kwargs):
|
||||
"""
|
||||
插件路由
|
||||
"""
|
||||
registered_web_apis = self.core_lifecycle.star_context.registered_web_apis
|
||||
for api in registered_web_apis:
|
||||
route, view_handler, methods, _ = api
|
||||
if route == f"/{subpath}" and request.method in methods:
|
||||
return await view_handler(*args, **kwargs)
|
||||
return jsonify(Response().error("未找到该路由").__dict__)
|
||||
|
||||
async def auth_middleware(self):
|
||||
if not request.path.startswith("/api"):
|
||||
return
|
||||
allowed_endpoints = ["/api/auth/login", "/api/chat/get_file", "/api/file"]
|
||||
allowed_endpoints = ["/api/auth/login", "/api/file"]
|
||||
if any(request.path.startswith(prefix) for prefix in allowed_endpoints):
|
||||
return
|
||||
# claim jwt
|
||||
|
||||
7
changelogs/v3.5.11.md
Normal file
7
changelogs/v3.5.11.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# What's Changed
|
||||
|
||||
1. 新增:火山引擎 TTS
|
||||
2. 修复:修复了 WeChatPadPro 在重新登录时为新设备的问题
|
||||
2. ‼️修复:微信公众号(个人认证或者未认证)的情况下能接收但无法回复消息的问题
|
||||
3. 修复:Minimax TTS 相关问题
|
||||
4. 优化:登录界面侧边栏、关于页面样式,修复如果此前已经登录但未自行跳转的问题
|
||||
18
changelogs/v3.5.12.md
Normal file
18
changelogs/v3.5.12.md
Normal file
@@ -0,0 +1,18 @@
|
||||
# What's Changed
|
||||
|
||||
1. 新增:支持 MCP 的 Streamable HTTP 传输方式。详见 [#1637](https://github.com/Soulter/AstrBot/issues/1637)
|
||||
2. 新增:支持 MCP 的 SSE 传输方式的自定义请求头。详见 [#1659](https://github.com/Soulter/AstrBot/issues/1659)
|
||||
3. 优化:将 /llm 和 /model 和 /provider 指令设置为管理员指令
|
||||
4. 修复:修复插件的 priority 部分失效的问题
|
||||
5. 修复:修复 QQ 下合并转发消息内无法发送文件等问题,尽可能修复了各种文件、语音、视频、图片无法发送的问题
|
||||
6. 优化:Telegram 支持长消息分段发送,优化消息编辑的逻辑
|
||||
7. 优化:WebUI 强制默认修改密码
|
||||
8. 优化:移除了 vpet
|
||||
9. 新增:插件接口:支持动态路由注册
|
||||
10. 优化:CLI 模式下的插件下载
|
||||
11. 新增:WeChatPadPro 对接获取联系人接口
|
||||
12. 新增:T2I、语音、视频支持文件服务
|
||||
13. 优化:硅基流动下某些工具调用返回的 argument 格式适配
|
||||
14. 优化:在使用 /llm 指令关闭后重启 AstrBot 后,模型提供商未被加载
|
||||
15. 新增:新增基于 FAISS + SQLite 的向量存储接口
|
||||
16. 新增:Alkaid Page
|
||||
9
changelogs/v3.5.13.md
Normal file
9
changelogs/v3.5.13.md
Normal file
@@ -0,0 +1,9 @@
|
||||
# What's Changed
|
||||
|
||||
1. 新增:WebUI 支持暗夜模式。
|
||||
2. 修复:修复 WebUI Chat 接口的未授权访问安全漏洞、插件 README 可能存在的 XSS 注入漏洞。
|
||||
3. 优化:优化 Vec DB 在 indexing 过程时的数据库事务处理。
|
||||
4. 修复:WebUI 下,插件市场的推荐卡片无法点击帮助文档的问题。
|
||||
5. 新增:知识库。
|
||||
6. 新增:WebUI 提供商测试功能,一键检测可用性。
|
||||
7. 新增:WebUI 提供商分类功能,按能力分类提供商。
|
||||
11
changelogs/v3.5.14.md
Normal file
11
changelogs/v3.5.14.md
Normal file
@@ -0,0 +1,11 @@
|
||||
# What's Changed
|
||||
|
||||
1. 优化:强化了 WebUI 安全性
|
||||
2. 修复:测试文本生成提供商时可能出现的误报
|
||||
3. 修复:刷新知识库页面时出现404
|
||||
4. 新增:WeChatPadPro 支持获取引用、语音收发、视频等消息段
|
||||
5. 优化:WebUI 账户修改页面的设计逻辑
|
||||
6. 优化:插件更新后自动刷新插件列表
|
||||
7. 新增:支持下载插件的指定分支
|
||||
8. 修复:WeChatPadPro 群聊模式下 @ 不回复等问题
|
||||
9. 其他更新、优化及修复
|
||||
13
changelogs/v3.5.15.md
Normal file
13
changelogs/v3.5.15.md
Normal file
@@ -0,0 +1,13 @@
|
||||
# What's Changed
|
||||
|
||||
1. 修复:如果设置了 GitHub 加速地址,更新插件会报错
|
||||
2. 修复:部分场景下,`只@触发等待` 配置项功能无效的问题
|
||||
3. 新增:增加 `只@触发等待时是否回复` 配置项
|
||||
4. 新增:**支持模型提供商使用时会话隔离(需要手动开启配置项:提供商会话隔离)**
|
||||
5. 新增:Google Gemini 提供商支持 URL 上下文功能
|
||||
6. 新增:优化 WebChat 的 UI 显示,WebChat 支持修改标题和自动生成标题,支持 WebChatBox
|
||||
7. 新增:支持可配置是否忽略 @ 全体成员
|
||||
8. 优化:WebUI 顶栏移动端显示
|
||||
9. 优化:插件/AstrBot 配置项完整性检查的同时也保证**配置项相对顺序一致性**
|
||||
10. 优化:perf: 分段回复时,仅在输出的第一句话带上回复/引用
|
||||
11. 修复: Windows 下部署项目时可能出现的 UnicodeDecodeError。
|
||||
@@ -20,6 +20,7 @@
|
||||
"axios": "^1.6.2",
|
||||
"axios-mock-adapter": "^1.22.0",
|
||||
"chance": "1.1.11",
|
||||
"d3": "^7.9.0",
|
||||
"date-fns": "2.30.0",
|
||||
"highlight.js": "^11.11.1",
|
||||
"js-md5": "^0.8.3",
|
||||
|
||||
BIN
dashboard/src/assets/images/astrbot_logo_mini.webp
Normal file
BIN
dashboard/src/assets/images/astrbot_logo_mini.webp
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 10 KiB |
@@ -340,12 +340,12 @@ export default {
|
||||
.config-title {
|
||||
font-weight: 600;
|
||||
font-size: 1rem;
|
||||
color: var(--v-primary-darken1);
|
||||
color: var(--v-theme-primaryText);
|
||||
}
|
||||
|
||||
.config-hint {
|
||||
font-size: 0.75rem;
|
||||
color: rgba(0, 0, 0, 0.6);
|
||||
color: var(--v-theme-secondaryText);
|
||||
margin-top: 2px;
|
||||
}
|
||||
|
||||
@@ -400,12 +400,12 @@ export default {
|
||||
.property-name {
|
||||
font-size: 0.875rem;
|
||||
font-weight: 600;
|
||||
color: rgba(0, 0, 0, 0.87);
|
||||
color: var(--v-theme-primaryText);
|
||||
}
|
||||
|
||||
.property-hint {
|
||||
font-size: 0.75rem;
|
||||
color: rgba(0, 0, 0, 0.6);
|
||||
color: var(--v-theme-secondaryText);
|
||||
margin-top: 2px;
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import { useCommonStore } from '@/stores/common';
|
||||
<template>
|
||||
<div>
|
||||
<!-- 添加筛选级别控件 -->
|
||||
<div class="filter-controls mb-2">
|
||||
<div class="filter-controls mb-2" v-if="showLevelBtns">
|
||||
<v-chip-group v-model="selectedLevels" column multiple>
|
||||
<v-chip v-for="level in logLevels" :key="level" :color="getLevelColor(level)" filter
|
||||
:text-color="level === 'DEBUG' || level === 'INFO' ? 'black' : 'white'">
|
||||
@@ -52,6 +52,10 @@ export default {
|
||||
historyNum: {
|
||||
type: String,
|
||||
default: -1
|
||||
},
|
||||
showLevelBtns: {
|
||||
type: Boolean,
|
||||
default: true
|
||||
}
|
||||
},
|
||||
watch: {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, inject } from 'vue';
|
||||
import {useCustomizerStore} from "@/stores/customizer";
|
||||
|
||||
const props = defineProps({
|
||||
extension: {
|
||||
@@ -75,7 +76,9 @@ const viewReadme = () => {
|
||||
|
||||
<template>
|
||||
<v-card class="mx-auto d-flex flex-column" :elevation="highlight ? 0 : 1"
|
||||
:style="{ height: $vuetify.display.xs ? '250px' : '220px', backgroundColor: highlight ? '#FAF0DB' : '#ffffff', color: highlight ? '#000' : '#000000' }">
|
||||
:style="{ height: $vuetify.display.xs ? '250px' : '220px',
|
||||
backgroundColor: useCustomizerStore().uiTheme==='PurpleTheme' ? marketMode ? '#f8f0dd' : '#ffffff' : '#282833',
|
||||
color: useCustomizerStore().uiTheme==='PurpleTheme' ? '#000000dd' : '#ffffff'}">
|
||||
<v-card-text style="padding: 16px; padding-bottom: 0px; display: flex; justify-content: space-between;">
|
||||
|
||||
<div class="flex-grow-1">
|
||||
@@ -128,7 +131,7 @@ const viewReadme = () => {
|
||||
</div>
|
||||
</v-card-text>
|
||||
|
||||
<v-card-actions style="padding: 0px; margin-top: auto;">
|
||||
<v-card-actions style="margin-left: 0px; gap: 2px;">
|
||||
<v-btn color="teal-accent-4" text="查看文档" variant="text" @click="viewReadme"></v-btn>
|
||||
<v-btn v-if="!marketMode" color="teal-accent-4" text="操作" variant="text" @click="reveal = true"></v-btn>
|
||||
<v-btn v-if="marketMode && !extension?.installed" color="teal-accent-4" text="安装" variant="text"
|
||||
|
||||
@@ -104,11 +104,11 @@ export default {
|
||||
|
||||
<style scoped>
|
||||
.list-config-item {
|
||||
border: 1px solid #e0e0e0;
|
||||
border: 1px solid var(--v-theme-border);
|
||||
padding: 16px;
|
||||
margin-bottom: 8px;
|
||||
border-radius: 10px;
|
||||
background-color: #ffffff;
|
||||
background-color: var(--v-theme-background);
|
||||
}
|
||||
|
||||
.v-list-item {
|
||||
|
||||
78
dashboard/src/components/shared/Logo.vue
Normal file
78
dashboard/src/components/shared/Logo.vue
Normal file
@@ -0,0 +1,78 @@
|
||||
<template>
|
||||
<div class="logo-container">
|
||||
<div class="logo-content">
|
||||
<div class="logo-image">
|
||||
<img width="110" src="@/assets/images/astrbot_logo_mini.webp" alt="AstrBot Logo">
|
||||
</div>
|
||||
<div class="logo-text">
|
||||
<h2 class="text-secondary">{{ title }}</h2>
|
||||
<!-- 父子组件传递css变量可能会出错,暂时使用十六进制颜色值 -->
|
||||
<h4 :style="{color: useCustomizerStore().uiTheme === 'PurpleTheme' ? '#000000aa' : '#ffffffcc'}"
|
||||
class="hint-text">{{ subtitle }}</h4>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { useCustomizerStore } from "@/stores/customizer";
|
||||
|
||||
const props = withDefaults(defineProps<{
|
||||
title?: string;
|
||||
subtitle?: string;
|
||||
}>(), {
|
||||
title: 'AstrBot 仪表盘',
|
||||
subtitle: '欢迎使用'
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.logo-container {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
width: 100%;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
.logo-content {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 20px;
|
||||
padding: 10px;
|
||||
}
|
||||
|
||||
.logo-image {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.logo-image img {
|
||||
transition: transform 0.3s ease;
|
||||
}
|
||||
|
||||
.logo-image img:hover {
|
||||
transform: scale(1.05);
|
||||
}
|
||||
|
||||
.logo-text {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: flex-start;
|
||||
}
|
||||
|
||||
.logo-text h2 {
|
||||
margin: 0;
|
||||
font-size: 1.8rem;
|
||||
font-weight: 600;
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
|
||||
.logo-text h4 {
|
||||
margin: 4px 0 0 0;
|
||||
font-size: 1rem;
|
||||
font-weight: 400;
|
||||
letter-spacing: 0.3px;
|
||||
}
|
||||
</style>
|
||||
@@ -3,14 +3,25 @@ export type ConfigProps = {
|
||||
Customizer_drawer: boolean;
|
||||
mini_sidebar: boolean;
|
||||
fontTheme: string;
|
||||
uiTheme: string;
|
||||
inputBg: boolean;
|
||||
};
|
||||
|
||||
function checkUITheme() {
|
||||
/* 检查localStorage有无记忆的主题选项,如有则使用,否则使用默认值 */
|
||||
const theme = localStorage.getItem("uiTheme");
|
||||
if (!theme || !(['PurpleTheme', 'PurpleThemeDark'].includes(theme))) {
|
||||
localStorage.setItem("uiTheme", "PurpleTheme"); // todo: 这部分可以根据vuetify.ts的默认主题动态调整
|
||||
return 'PurpleTheme';
|
||||
} else return theme;
|
||||
}
|
||||
|
||||
const config: ConfigProps = {
|
||||
Sidebar_drawer: true,
|
||||
Customizer_drawer: false,
|
||||
mini_sidebar: false,
|
||||
fontTheme: 'Roboto',
|
||||
uiTheme: checkUITheme(),
|
||||
inputBg: false
|
||||
};
|
||||
|
||||
|
||||
@@ -2,14 +2,13 @@
|
||||
import { RouterView } from 'vue-router';
|
||||
import VerticalSidebarVue from './vertical-sidebar/VerticalSidebar.vue';
|
||||
import VerticalHeaderVue from './vertical-header/VerticalHeader.vue';
|
||||
import { useCustomizerStore } from '../../stores/customizer';
|
||||
import { useCustomizerStore } from '@/stores/customizer';
|
||||
const customizer = useCustomizerStore();
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<v-locale-provider>
|
||||
<v-app
|
||||
theme="PurpleTheme"
|
||||
<v-app :theme="useCustomizerStore().uiTheme"
|
||||
:class="[customizer.fontTheme, customizer.mini_sidebar ? 'mini-sidebar' : '', customizer.inputBg ? 'inputWithbg' : '']"
|
||||
>
|
||||
<VerticalHeaderVue />
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
<script setup lang="ts">
|
||||
import { ref } from 'vue';
|
||||
import { useCustomizerStore } from '../../../stores/customizer';
|
||||
import {ref, computed} from 'vue';
|
||||
import {useCustomizerStore} from '@/stores/customizer';
|
||||
import axios from 'axios';
|
||||
import { md5 } from 'js-md5';
|
||||
import { useAuthStore } from '@/stores/auth';
|
||||
import { useCommonStore } from '@/stores/common';
|
||||
import { marked } from 'marked';
|
||||
import Logo from '@/components/shared/Logo.vue';
|
||||
import {md5} from 'js-md5';
|
||||
import {useAuthStore} from '@/stores/auth';
|
||||
import {useCommonStore} from '@/stores/common';
|
||||
import {marked} from 'marked';
|
||||
|
||||
const customizer = useCustomizerStore();
|
||||
let dialog = ref(false);
|
||||
let accountWarning = ref(false)
|
||||
let updateStatusDialog = ref(false);
|
||||
const username = localStorage.getItem('user');
|
||||
let password = ref('');
|
||||
let newPassword = ref('');
|
||||
let newUsername = ref('');
|
||||
@@ -23,26 +25,52 @@ let dashboardHasNewVersion = ref(false);
|
||||
let dashboardCurrentVersion = ref('');
|
||||
let version = ref('');
|
||||
let releases = ref([]);
|
||||
let devCommits = ref([]); // 新增的 ref
|
||||
let devCommits = ref([]);
|
||||
|
||||
let installLoading = ref(false);
|
||||
|
||||
let tab = ref(0);
|
||||
|
||||
let releasesHeader = [
|
||||
{ title: '标签', key: 'tag_name' },
|
||||
{ title: '发布时间', key: 'published_at' },
|
||||
{ title: '内容', key: 'body' },
|
||||
{ title: '源码地址', key: 'zipball_url' },
|
||||
{ title: '操作', key: 'switch' }
|
||||
{title: '标签', key: 'tag_name'},
|
||||
{title: '发布时间', key: 'published_at'},
|
||||
{title: '内容', key: 'body'},
|
||||
{title: '源码地址', key: 'zipball_url'},
|
||||
{title: '操作', key: 'switch'}
|
||||
];
|
||||
|
||||
// Form validation
|
||||
const formValid = ref(true);
|
||||
const passwordRules = [
|
||||
(v: string) => !!v || '请输入密码',
|
||||
(v: string) => v.length >= 8 || '密码长度至少 8 位'
|
||||
];
|
||||
const usernameRules = [
|
||||
(v: string) => !v || v.length >= 3 || '用户名长度至少3位'
|
||||
];
|
||||
|
||||
// 显示密码相关
|
||||
const showPassword = ref(false);
|
||||
const showNewPassword = ref(false);
|
||||
|
||||
// 账户修改状态
|
||||
const accountEditStatus = ref({
|
||||
loading: false,
|
||||
success: false,
|
||||
error: false,
|
||||
message: ''
|
||||
});
|
||||
|
||||
const open = (link: string) => {
|
||||
window.open(link, '_blank');
|
||||
};
|
||||
|
||||
// 账户修改
|
||||
function accountEdit() {
|
||||
accountEditStatus.value.loading = true;
|
||||
accountEditStatus.value.error = false;
|
||||
accountEditStatus.value.success = false;
|
||||
|
||||
// md5加密
|
||||
// @ts-ignore
|
||||
if (password.value != '') {
|
||||
@@ -54,71 +82,92 @@ function accountEdit() {
|
||||
axios.post('/api/auth/account/edit', {
|
||||
password: password.value,
|
||||
new_password: newPassword.value,
|
||||
new_username: newUsername.value
|
||||
new_username: newUsername.value ? newUsername.value : username
|
||||
})
|
||||
.then((res) => {
|
||||
if (res.data.status == 'error') {
|
||||
status.value = res.data.message;
|
||||
.then((res) => {
|
||||
if (res.data.status == 'error') {
|
||||
accountEditStatus.value.error = true;
|
||||
accountEditStatus.value.message = res.data.message;
|
||||
password.value = '';
|
||||
newPassword.value = '';
|
||||
return;
|
||||
}
|
||||
accountEditStatus.value.success = true;
|
||||
accountEditStatus.value.message = res.data.message;
|
||||
setTimeout(() => {
|
||||
dialog.value = !dialog.value;
|
||||
const authStore = useAuthStore();
|
||||
authStore.logout();
|
||||
}, 2000);
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
accountEditStatus.value.error = true;
|
||||
accountEditStatus.value.message = typeof err === 'string' ? err : '修改失败,请重试';
|
||||
password.value = '';
|
||||
newPassword.value = '';
|
||||
return;
|
||||
}
|
||||
dialog.value = !dialog.value;
|
||||
status.value = res.data.message;
|
||||
setTimeout(() => {
|
||||
const authStore = useAuthStore();
|
||||
authStore.logout();
|
||||
}, 1000);
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
status.value = err
|
||||
password.value = '';
|
||||
newPassword.value = '';
|
||||
});
|
||||
})
|
||||
.finally(() => {
|
||||
accountEditStatus.value.loading = false;
|
||||
});
|
||||
}
|
||||
|
||||
function getVersion() {
|
||||
axios.get('/api/stat/version')
|
||||
.then((res) => {
|
||||
botCurrVersion.value = "v" + res.data.data.version;
|
||||
dashboardCurrentVersion.value = res.data.data?.dashboard_version;
|
||||
let change_pwd_hint = res.data.data?.change_pwd_hint;
|
||||
if (change_pwd_hint) {
|
||||
dialog.value = true;
|
||||
accountWarning.value = true;
|
||||
localStorage.setItem('change_pwd_hint', 'true');
|
||||
} else {
|
||||
localStorage.removeItem('change_pwd_hint');
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
});
|
||||
}
|
||||
|
||||
function checkUpdate() {
|
||||
updateStatus.value = '正在检查更新...';
|
||||
axios.get('/api/update/check')
|
||||
.then((res) => {
|
||||
hasNewVersion.value = res.data.data.has_new_version;
|
||||
.then((res) => {
|
||||
hasNewVersion.value = res.data.data.has_new_version;
|
||||
|
||||
if (res.data.data.has_new_version) {
|
||||
releaseMessage.value = res.data.message;
|
||||
updateStatus.value = '有新版本!';
|
||||
} else {
|
||||
updateStatus.value = res.data.message;
|
||||
}
|
||||
botCurrVersion.value = res.data.data.version;
|
||||
dashboardCurrentVersion.value = res.data.data.dashboard_version;
|
||||
dashboardHasNewVersion.value = res.data.data.dashboard_has_new_version;
|
||||
})
|
||||
.catch((err) => {
|
||||
if (err.response.status == 401) {
|
||||
console.log("401");
|
||||
const authStore = useAuthStore();
|
||||
authStore.logout();
|
||||
return;
|
||||
}
|
||||
console.log(err);
|
||||
updateStatus.value = err
|
||||
});
|
||||
if (res.data.data.has_new_version) {
|
||||
releaseMessage.value = res.data.message;
|
||||
updateStatus.value = '有新版本!';
|
||||
} else {
|
||||
updateStatus.value = res.data.message;
|
||||
}
|
||||
dashboardHasNewVersion.value = res.data.data.dashboard_has_new_version;
|
||||
})
|
||||
.catch((err) => {
|
||||
if (err.response.status == 401) {
|
||||
console.log("401");
|
||||
const authStore = useAuthStore();
|
||||
authStore.logout();
|
||||
return;
|
||||
}
|
||||
console.log(err);
|
||||
updateStatus.value = err
|
||||
});
|
||||
}
|
||||
|
||||
function getReleases() {
|
||||
axios.get('/api/update/releases')
|
||||
.then((res) => {
|
||||
// releases.value = res.data.data;
|
||||
// 更新 published_at 的时间为本地时间
|
||||
releases.value = res.data.data.map((item: any) => {
|
||||
item.published_at = new Date(item.published_at).toLocaleString();
|
||||
return item;
|
||||
.then((res) => {
|
||||
releases.value = res.data.data.map((item: any) => {
|
||||
item.published_at = new Date(item.published_at).toLocaleString();
|
||||
return item;
|
||||
})
|
||||
})
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
});
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
});
|
||||
}
|
||||
|
||||
function getDevCommits() {
|
||||
@@ -128,17 +177,17 @@ function getDevCommits() {
|
||||
'Referer': 'https://api.github.com'
|
||||
}
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
devCommits.value = data.map((commit: any) => ({
|
||||
sha: commit.sha,
|
||||
date: new Date(commit.commit.author.date).toLocaleString(),
|
||||
message: commit.commit.message
|
||||
}));
|
||||
})
|
||||
.catch(err => {
|
||||
console.log(err);
|
||||
});
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
devCommits.value = data.map((commit: any) => ({
|
||||
sha: commit.sha,
|
||||
date: new Date(commit.commit.author.date).toLocaleString(),
|
||||
message: commit.commit.message
|
||||
}));
|
||||
})
|
||||
.catch(err => {
|
||||
console.log(err);
|
||||
});
|
||||
}
|
||||
|
||||
function switchVersion(version: string) {
|
||||
@@ -148,88 +197,111 @@ function switchVersion(version: string) {
|
||||
version: version,
|
||||
proxy: localStorage.getItem('selectedGitHubProxy') || ''
|
||||
})
|
||||
.then((res) => {
|
||||
updateStatus.value = res.data.message;
|
||||
if (res.data.status == 'ok') {
|
||||
setTimeout(() => {
|
||||
window.location.reload();
|
||||
}, 1000);
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
updateStatus.value = err
|
||||
}).finally(() => {
|
||||
installLoading.value = false;
|
||||
});
|
||||
.then((res) => {
|
||||
updateStatus.value = res.data.message;
|
||||
if (res.data.status == 'ok') {
|
||||
setTimeout(() => {
|
||||
window.location.reload();
|
||||
}, 1000);
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
updateStatus.value = err
|
||||
}).finally(() => {
|
||||
installLoading.value = false;
|
||||
});
|
||||
}
|
||||
|
||||
function updateDashboard() {
|
||||
updateStatus.value = '正在更新...';
|
||||
axios.post('/api/update/dashboard')
|
||||
.then((res) => {
|
||||
updateStatus.value = res.data.message;
|
||||
if (res.data.status == 'ok') {
|
||||
setTimeout(() => {
|
||||
window.location.reload();
|
||||
}, 1000);
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
updateStatus.value = err
|
||||
});
|
||||
.then((res) => {
|
||||
updateStatus.value = res.data.message;
|
||||
if (res.data.status == 'ok') {
|
||||
setTimeout(() => {
|
||||
window.location.reload();
|
||||
}, 1000);
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
updateStatus.value = err
|
||||
});
|
||||
}
|
||||
|
||||
function toggleDarkMode() {
|
||||
customizer.SET_UI_THEME(customizer.uiTheme === 'PurpleThemeDark' ? 'PurpleTheme' : 'PurpleThemeDark');
|
||||
}
|
||||
|
||||
getVersion();
|
||||
checkUpdate();
|
||||
|
||||
const commonStore = useCommonStore();
|
||||
commonStore.createEventSource(); // log
|
||||
commonStore.getStartTime();
|
||||
|
||||
|
||||
if (localStorage.getItem('change_pwd_hint') != null && localStorage.getItem('change_pwd_hint') == 'true') {
|
||||
dialog.value = true;
|
||||
accountWarning.value = true;
|
||||
localStorage.removeItem('change_pwd_hint');
|
||||
}
|
||||
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<v-app-bar elevation="0" height="55">
|
||||
|
||||
<v-btn style="margin-left: 22px;" class="hidden-md-and-down text-secondary" color="lightsecondary" icon rounded="sm"
|
||||
variant="flat" @click.stop="customizer.SET_MINI_SIDEBAR(!customizer.mini_sidebar)" size="small">
|
||||
<v-btn v-if="useCustomizerStore().uiTheme==='PurpleTheme'" style="margin-left: 22px;" class="hidden-md-and-down text-secondary" color="lightsecondary" icon rounded="sm"
|
||||
variant="flat" @click.stop="customizer.SET_MINI_SIDEBAR(!customizer.mini_sidebar)" size="small">
|
||||
<v-icon>mdi-menu</v-icon>
|
||||
</v-btn>
|
||||
<v-btn class="hidden-lg-and-up text-secondary ms-3" color="lightsecondary" icon rounded="sm" variant="flat"
|
||||
@click.stop="customizer.SET_SIDEBAR_DRAWER" size="small">
|
||||
<v-btn v-else style="margin-left: 22px; color: var(--v-theme-primaryText); background-color: var(--v-theme-secondary)" class="hidden-md-and-down" icon rounded="sm"
|
||||
variant="flat" @click.stop="customizer.SET_MINI_SIDEBAR(!customizer.mini_sidebar)" size="small">
|
||||
<v-icon>mdi-menu</v-icon>
|
||||
</v-btn>
|
||||
<v-btn v-if="useCustomizerStore().uiTheme==='PurpleTheme'" class="hidden-lg-and-up ms-3" color="lightsecondary" icon rounded="sm" variant="flat"
|
||||
@click.stop="customizer.SET_SIDEBAR_DRAWER" size="small">
|
||||
<v-icon>mdi-menu</v-icon>
|
||||
</v-btn>
|
||||
<v-btn v-else class="hidden-lg-and-up ms-3" icon rounded="sm" variant="flat"
|
||||
@click.stop="customizer.SET_SIDEBAR_DRAWER" size="small">
|
||||
<v-icon>mdi-menu</v-icon>
|
||||
</v-btn>
|
||||
|
||||
<span style="margin-left: 16px; font-size: 24px; font-weight: 1000;">Astr<span
|
||||
style="font-weight: normal;">Bot</span></span>
|
||||
<div class="logo-container" :class="{'mobile-logo': $vuetify.display.xs}">
|
||||
<span class="logo-text">Astr<span class="logo-text-light">Bot</span></span>
|
||||
<span class="version-text hidden-xs">{{ botCurrVersion }}</span>
|
||||
</div>
|
||||
|
||||
<v-spacer />
|
||||
<v-spacer/>
|
||||
|
||||
<div class="mr-4">
|
||||
<!-- 版本提示信息 - 在手机上隐藏 -->
|
||||
<div class="mr-4 hidden-xs">
|
||||
<small v-if="hasNewVersion">
|
||||
有新版本!
|
||||
AstrBot 有新版本!
|
||||
</small>
|
||||
<small v-else-if="dashboardHasNewVersion">
|
||||
WebUI 有新版本!
|
||||
</small>
|
||||
</div>
|
||||
|
||||
<!-- 主题切换按钮 -->
|
||||
<v-btn size="small" @click="toggleDarkMode();" class="action-btn"
|
||||
color="var(--v-theme-surface)" variant="flat" rounded="sm">
|
||||
<v-icon v-if="useCustomizerStore().uiTheme === 'PurpleThemeDark'">mdi-weather-night</v-icon>
|
||||
<v-icon v-else>mdi-white-balance-sunny</v-icon>
|
||||
</v-btn>
|
||||
|
||||
<v-dialog v-model="updateStatusDialog" width="1000">
|
||||
<!-- 更新对话框 -->
|
||||
<v-dialog v-model="updateStatusDialog" :width="$vuetify.display.smAndDown ? '100%' : '1000'" :fullscreen="$vuetify.display.xs">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-btn @click="checkUpdate(); getReleases(); getDevCommits();" class="text-primary mr-4" color="lightprimary"
|
||||
variant="flat" rounded="sm" v-bind="props">
|
||||
更新 🔄
|
||||
<v-btn size="small" @click="checkUpdate(); getReleases(); getDevCommits();" class="action-btn"
|
||||
color="var(--v-theme-surface)" variant="flat" rounded="sm" v-bind="props">
|
||||
<v-icon class="hidden-sm-and-up">mdi-update</v-icon>
|
||||
<span class="hidden-xs">更新</span>
|
||||
</v-btn>
|
||||
</template>
|
||||
<v-card>
|
||||
<v-card-title>
|
||||
<v-card-title class="mobile-card-title">
|
||||
<span class="text-h5">更新 AstrBot</span>
|
||||
<v-btn v-if="$vuetify.display.xs" icon @click="updateStatusDialog = false">
|
||||
<v-icon>mdi-close</v-icon>
|
||||
</v-btn>
|
||||
</v-card-title>
|
||||
<v-card-text>
|
||||
<v-container>
|
||||
@@ -240,16 +312,16 @@ if (localStorage.getItem('change_pwd_hint') != null && localStorage.getItem('cha
|
||||
<small style="margin-left: 4px;">{{ updateStatus }}</small>
|
||||
</div>
|
||||
|
||||
<div
|
||||
style="background-color: #646cff24; padding: 16px; border-radius: 10px; font-size: 14px; max-height: 400px; overflow-y: auto;"
|
||||
v-html="marked(releaseMessage)" class="markdown-content">
|
||||
|
||||
<div v-if="releaseMessage"
|
||||
style="background-color: #646cff24; padding: 16px; border-radius: 10px; font-size: 14px; max-height: 400px; overflow-y: auto;"
|
||||
v-html="marked(releaseMessage)" class="markdown-content">
|
||||
</div>
|
||||
|
||||
<div class="mb-4 mt-4">
|
||||
<small>💡 TIP: 跳到旧版本或者切换到某个版本不会重新下载管理面板文件,这可能会造成部分数据显示错误。您可在 <a
|
||||
href="https://github.com/Soulter/AstrBot/releases">此处</a>
|
||||
找到对应的面板文件 dist.zip,解压后替换 data/dist 文件夹即可。当然,前端源代码在 dashboard 目录下,你也可以自己使用 npm install 和 npm build
|
||||
找到对应的面板文件 dist.zip,解压后替换 data/dist 文件夹即可。当然,前端源代码在 dashboard 目录下,你也可以自己使用
|
||||
npm install 和 npm build
|
||||
构建。</small>
|
||||
</div>
|
||||
|
||||
@@ -262,12 +334,13 @@ if (localStorage.getItem('change_pwd_hint') != null && localStorage.getItem('cha
|
||||
<!-- 发行版 -->
|
||||
<v-tabs-window-item key="0" v-show="tab == 0">
|
||||
<v-btn class="mt-4 mb-4" @click="switchVersion('latest')" color="primary" style="border-radius: 10px;"
|
||||
:disabled="!hasNewVersion">
|
||||
:disabled="!hasNewVersion">
|
||||
更新到最新版本
|
||||
</v-btn>
|
||||
<div class="mb-4">
|
||||
<small>`更新到最新版本` 按钮会同时尝试更新机器人主程序和管理面板。如果您正在使用 Docker 部署,也可以重新拉取镜像或者使用 <a
|
||||
href="https://containrrr.dev/watchtower/usage-overview/">watchtower</a> 来自动监控拉取。</small>
|
||||
<small>`更新到最新版本` 按钮会同时尝试更新机器人主程序和管理面板。如果您正在使用 Docker
|
||||
部署,也可以重新拉取镜像或者使用 <a
|
||||
href="https://containrrr.dev/watchtower/usage-overview/">watchtower</a> 来自动监控拉取。</small>
|
||||
</div>
|
||||
|
||||
<v-data-table :headers="releasesHeader" :items="releases" item-key="name">
|
||||
@@ -290,8 +363,8 @@ if (localStorage.getItem('change_pwd_hint') != null && localStorage.getItem('cha
|
||||
<v-tabs-window-item key="1" v-show="tab == 1">
|
||||
<div style="margin-top: 16px;">
|
||||
<v-data-table
|
||||
:headers="[{ title: 'SHA', key: 'sha' }, { title: '日期', key: 'date' }, { title: '信息', key: 'message' }, { title: '操作', key: 'switch' }]"
|
||||
:items="devCommits" item-key="sha">
|
||||
:headers="[{ title: 'SHA', key: 'sha' }, { title: '日期', key: 'date' }, { title: '信息', key: 'message' }, { title: '操作', key: 'switch' }]"
|
||||
:items="devCommits" item-key="sha">
|
||||
<template v-slot:item.switch="{ item }: { item: { sha: string } }">
|
||||
<v-btn @click="switchVersion(item.sha)" rounded="xl" variant="plain" color="primary">
|
||||
切换
|
||||
@@ -306,12 +379,13 @@ if (localStorage.getItem('change_pwd_hint') != null && localStorage.getItem('cha
|
||||
<h3 class="mb-4">手动输入版本号或 Commit SHA</h3>
|
||||
|
||||
<v-text-field label="输入版本号或 master 分支下的 commit hash。" v-model="version" required
|
||||
variant="outlined"></v-text-field>
|
||||
variant="outlined"></v-text-field>
|
||||
<div class="mb-4">
|
||||
<small>如 v3.3.16 (不带 SHA) 或 42e5ec5d80b93b6bfe8b566754d45ffac4c3fe0b</small>
|
||||
<br>
|
||||
<a href="https://github.com/Soulter/AstrBot/commits/master"><small>查看 master 分支提交记录(点击右边的 copy
|
||||
即可复制)</small></a>
|
||||
<a href="https://github.com/Soulter/AstrBot/commits/master"><small>查看 master 分支提交记录(点击右边的
|
||||
copy
|
||||
即可复制)</small></a>
|
||||
</div>
|
||||
<v-btn color="error" style="border-radius: 10px;" @click="switchVersion(version)">
|
||||
确定切换
|
||||
@@ -336,7 +410,7 @@ if (localStorage.getItem('change_pwd_hint') != null && localStorage.getItem('cha
|
||||
</div>
|
||||
|
||||
<v-btn color="primary" style="border-radius: 10px;" @click="updateDashboard()"
|
||||
:disabled="!dashboardHasNewVersion">
|
||||
:disabled="!dashboardHasNewVersion">
|
||||
下载并更新
|
||||
</v-btn>
|
||||
</div>
|
||||
@@ -351,46 +425,119 @@ if (localStorage.getItem('change_pwd_hint') != null && localStorage.getItem('cha
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<v-dialog v-model="dialog" persistent width="700">
|
||||
<!-- 账户对话框 -->
|
||||
<v-dialog v-model="dialog" persistent :max-width="$vuetify.display.xs ? '90%' : '500'">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-btn class="text-primary mr-4" color="lightprimary" variant="flat" rounded="sm" v-bind="props">
|
||||
账户 📰
|
||||
<v-btn size="small" class="action-btn mr-4" color="var(--v-theme-surface)" variant="flat" rounded="sm" v-bind="props">
|
||||
<v-icon>mdi-account</v-icon>
|
||||
<span class="hidden-xs ml-1">账户</span>
|
||||
</v-btn>
|
||||
</template>
|
||||
<v-card>
|
||||
<v-card-title>
|
||||
<span class="text-h5">账户</span>
|
||||
</v-card-title>
|
||||
<v-card-text>
|
||||
<v-container>
|
||||
<v-row>
|
||||
<v-col cols="12">
|
||||
<v-card class="account-dialog">
|
||||
<v-card-text class="py-6">
|
||||
<div class="d-flex flex-column align-center mb-6">
|
||||
<logo title="AstrBot 仪表盘" subtitle="修改账户"></logo>
|
||||
</div>
|
||||
<v-alert
|
||||
v-if="accountWarning"
|
||||
type="warning"
|
||||
variant="tonal"
|
||||
border="start"
|
||||
class="mb-4"
|
||||
>
|
||||
<strong>安全提醒:</strong> 请修改默认密码以确保账户安全
|
||||
</v-alert>
|
||||
|
||||
<v-alert v-if="accountWarning" color="warning" style="margin-bottom: 16px;">
|
||||
<div>为了安全,请尽快修改默认密码。</div>
|
||||
</v-alert>
|
||||
<v-alert
|
||||
v-if="accountEditStatus.success"
|
||||
type="success"
|
||||
variant="tonal"
|
||||
border="start"
|
||||
class="mb-4"
|
||||
>
|
||||
{{ accountEditStatus.message }}
|
||||
</v-alert>
|
||||
|
||||
<v-text-field label="原密码*" type="password" v-model="password" required
|
||||
variant="outlined"></v-text-field>
|
||||
<v-alert
|
||||
v-if="accountEditStatus.error"
|
||||
type="error"
|
||||
variant="tonal"
|
||||
border="start"
|
||||
class="mb-4"
|
||||
>
|
||||
{{ accountEditStatus.message }}
|
||||
</v-alert>
|
||||
|
||||
<v-text-field label="新用户名" v-model="newUsername" required variant="outlined"></v-text-field>
|
||||
<v-form v-model="formValid" @submit.prevent="accountEdit">
|
||||
<v-text-field
|
||||
v-model="password"
|
||||
:append-inner-icon="showPassword ? 'mdi-eye-off' : 'mdi-eye'"
|
||||
:type="showPassword ? 'text' : 'password'"
|
||||
label="当前密码"
|
||||
variant="outlined"
|
||||
required
|
||||
clearable
|
||||
@click:append-inner="showPassword = !showPassword"
|
||||
prepend-inner-icon="mdi-lock-outline"
|
||||
hide-details="auto"
|
||||
class="mb-4"
|
||||
></v-text-field>
|
||||
|
||||
<v-text-field label="新密码" type="password" v-model="newPassword" required
|
||||
variant="outlined"></v-text-field>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</v-container>
|
||||
<small>默认用户名和密码是 astrbot。</small>
|
||||
<br>
|
||||
<small>{{ status }}</small>
|
||||
<v-text-field
|
||||
v-model="newPassword"
|
||||
:append-inner-icon="showNewPassword ? 'mdi-eye-off' : 'mdi-eye'"
|
||||
:type="showNewPassword ? 'text' : 'password'"
|
||||
:rules="passwordRules"
|
||||
label="新密码"
|
||||
variant="outlined"
|
||||
required
|
||||
clearable
|
||||
@click:append-inner="showNewPassword = !showNewPassword"
|
||||
prepend-inner-icon="mdi-lock-plus-outline"
|
||||
hint="密码长度至少 8 位"
|
||||
persistent-hint
|
||||
class="mb-4"
|
||||
></v-text-field>
|
||||
|
||||
<v-text-field
|
||||
v-model="newUsername"
|
||||
:rules="usernameRules"
|
||||
label="新用户名 (可选)"
|
||||
variant="outlined"
|
||||
clearable
|
||||
prepend-inner-icon="mdi-account-edit-outline"
|
||||
hint="留空表示不修改用户名"
|
||||
persistent-hint
|
||||
class="mb-3"
|
||||
></v-text-field>
|
||||
</v-form>
|
||||
|
||||
<div class="text-caption text-medium-emphasis mt-2">
|
||||
默认用户名和密码均为 astrbot
|
||||
</div>
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
|
||||
<v-divider></v-divider>
|
||||
|
||||
<v-card-actions class="pa-4">
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="blue-darken-1" variant="text" @click="dialog = false">
|
||||
关闭
|
||||
<v-btn
|
||||
v-if="!accountWarning"
|
||||
variant="tonal"
|
||||
color="secondary"
|
||||
@click="dialog = false"
|
||||
:disabled="accountEditStatus.loading"
|
||||
>
|
||||
取消
|
||||
</v-btn>
|
||||
<v-btn color="blue-darken-1" variant="text" @click="accountEdit">
|
||||
提交
|
||||
<v-btn
|
||||
color="primary"
|
||||
@click="accountEdit"
|
||||
:loading="accountEditStatus.loading"
|
||||
:disabled="!formValid"
|
||||
prepend-icon="mdi-content-save"
|
||||
>
|
||||
保存修改
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
@@ -416,4 +563,91 @@ if (localStorage.getItem('change_pwd_hint') != null && localStorage.getItem('cha
|
||||
margin-top: 8px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.account-dialog .v-card-text {
|
||||
padding-top: 24px;
|
||||
padding-bottom: 24px;
|
||||
}
|
||||
|
||||
.account-dialog .v-alert {
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
.account-dialog .v-btn {
|
||||
text-transform: none;
|
||||
font-weight: 500;
|
||||
border-radius: 8px;
|
||||
}
|
||||
|
||||
.account-dialog .v-avatar {
|
||||
transition: transform 0.3s ease;
|
||||
}
|
||||
|
||||
.account-dialog .v-avatar:hover {
|
||||
transform: scale(1.05);
|
||||
}
|
||||
|
||||
/* 响应式布局样式 */
|
||||
.logo-container {
|
||||
margin-left: 16px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.mobile-logo {
|
||||
margin-left: 8px;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.logo-text {
|
||||
font-size: 24px;
|
||||
font-weight: 1000;
|
||||
}
|
||||
|
||||
.logo-text-light {
|
||||
font-weight: normal;
|
||||
}
|
||||
|
||||
.version-text {
|
||||
font-size: 12px;
|
||||
color: var(--v-theme-secondaryText);
|
||||
}
|
||||
|
||||
.action-btn {
|
||||
margin-right: 6px;
|
||||
}
|
||||
|
||||
/* 移动端对话框标题样式 */
|
||||
.mobile-card-title {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
/* 移动端样式优化 */
|
||||
@media (max-width: 600px) {
|
||||
.logo-text {
|
||||
font-size: 20px;
|
||||
}
|
||||
|
||||
.action-btn {
|
||||
margin-right: 4px;
|
||||
min-width: 32px !important;
|
||||
width: 32px;
|
||||
}
|
||||
|
||||
.v-card-title {
|
||||
padding: 12px 16px;
|
||||
}
|
||||
|
||||
.v-card-text {
|
||||
padding: 16px;
|
||||
}
|
||||
|
||||
.v-tabs .v-tab {
|
||||
padding: 0 10px;
|
||||
font-size: 0.9rem;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
@@ -9,9 +9,6 @@ const customizer = useCustomizerStore();
|
||||
const sidebarMenu = shallowRef(sidebarItems);
|
||||
|
||||
const showIframe = ref(false);
|
||||
const version = ref("");
|
||||
const buildVer = ref("");
|
||||
const hasWebUIUpdate = ref(false);
|
||||
|
||||
// 默认桌面端 iframe 样式
|
||||
const iframeStyle = ref({
|
||||
@@ -68,9 +65,10 @@ function toggleIframe() {
|
||||
showIframe.value = !showIframe.value;
|
||||
}
|
||||
|
||||
function openIframeLink() {
|
||||
function openIframeLink(url) {
|
||||
if (typeof window !== 'undefined') {
|
||||
window.open("https://astrbot.app", "_blank");
|
||||
let url_ = url || "https://astrbot.app";
|
||||
window.open(url_, "_blank");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -149,25 +147,6 @@ function endDrag() {
|
||||
document.removeEventListener('touchend', onTouchEnd);
|
||||
}
|
||||
|
||||
// 获取版本和更新信息
|
||||
onMounted(() => {
|
||||
axios.get('/api/stat/version')
|
||||
.then((res) => {
|
||||
version.value = "v" + res.data.data.version;
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
});
|
||||
|
||||
axios.get('/api/update/check?type=dashboard')
|
||||
.then((res) => {
|
||||
hasWebUIUpdate.value = res.data.data.has_new_version;
|
||||
buildVer.value = res.data.data.current_version;
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
});
|
||||
});
|
||||
</script>
|
||||
|
||||
<template>
|
||||
@@ -186,27 +165,23 @@ onMounted(() => {
|
||||
<NavItem :item="item" class="leftPadding" />
|
||||
</template>
|
||||
</v-list>
|
||||
<div class="text-center">
|
||||
<v-chip color="inputBorder" size="small"> {{ version }} </v-chip>
|
||||
</div>
|
||||
<div style="position: absolute; bottom: 32px; width: 100%; font-size: 13px;" class="text-center">
|
||||
<v-list-item v-if="!customizer.mini_sidebar" @click="toggleIframe">
|
||||
<v-btn variant="plain" size="small">
|
||||
🤔 点击此处 查看/关闭 悬浮文档!
|
||||
</v-btn>
|
||||
</v-list-item>
|
||||
<small style="display: block;" v-if="buildVer">WebUI 版本: {{ buildVer }}</small>
|
||||
<small style="display: block;" v-else>构建: embedded</small>
|
||||
<v-tooltip text="使用 /dashboard_update 指令更新管理面板">
|
||||
<template v-slot:activator="{ props }">
|
||||
<small v-bind="props" v-if="hasWebUIUpdate" style="display: block; margin-top: 4px;">面板有更新</small>
|
||||
</template>
|
||||
</v-tooltip>
|
||||
<small style="display: block; margin-top: 8px;">AGPL-3.0</small>
|
||||
<div style="position: absolute; bottom: 16px; width: 100%; font-size: 13px;" class="text-center">
|
||||
<v-btn style="margin-bottom: 8px;" size="small" variant="primary" v-if="!customizer.mini_sidebar" to="/settings">
|
||||
🔧 设置
|
||||
</v-btn>
|
||||
<br/>
|
||||
<v-btn style="margin-bottom: 8px;" size="small" variant="plain" v-if="!customizer.mini_sidebar" @click="toggleIframe">
|
||||
官方文档
|
||||
</v-btn>
|
||||
<br/>
|
||||
<v-btn style="margin-bottom: 8px;" size="small" variant="plain" v-if="!customizer.mini_sidebar" @click="openIframeLink('https://github.com/AstrBotDevs/AstrBot')">
|
||||
GitHub
|
||||
</v-btn>
|
||||
<br/>
|
||||
|
||||
</div>
|
||||
</v-navigation-drawer>
|
||||
|
||||
<!-- 优化后的悬浮 iframe -->
|
||||
<div
|
||||
v-if="showIframe"
|
||||
id="draggable-iframe"
|
||||
|
||||
@@ -66,9 +66,9 @@ const sidebarItem: menu[] = [
|
||||
to: '/console'
|
||||
},
|
||||
{
|
||||
title: '设置',
|
||||
icon: 'mdi-wrench',
|
||||
to: '/settings'
|
||||
title: 'Alkaid',
|
||||
icon: 'mdi-test-tube',
|
||||
to: '/alkaid'
|
||||
},
|
||||
{
|
||||
title: '关于',
|
||||
|
||||
@@ -3,6 +3,7 @@ import '@mdi/font/css/materialdesignicons.css';
|
||||
import * as components from 'vuetify/components';
|
||||
import * as directives from 'vuetify/directives';
|
||||
import { PurpleTheme } from '@/theme/LightTheme';
|
||||
import { PurpleThemeDark } from "@/theme/DarkTheme";
|
||||
|
||||
export default createVuetify({
|
||||
components,
|
||||
@@ -11,7 +12,8 @@ export default createVuetify({
|
||||
theme: {
|
||||
defaultTheme: 'PurpleTheme',
|
||||
themes: {
|
||||
PurpleTheme
|
||||
PurpleTheme,
|
||||
PurpleThemeDark
|
||||
}
|
||||
},
|
||||
defaults: {
|
||||
|
||||
21
dashboard/src/router/ChatBoxRoutes.ts
Normal file
21
dashboard/src/router/ChatBoxRoutes.ts
Normal file
@@ -0,0 +1,21 @@
|
||||
const ChatBoxRoutes = {
|
||||
path: '/chatbox',
|
||||
component: () => import('@/layouts/blank/BlankLayout.vue'),
|
||||
children: [
|
||||
{
|
||||
name: 'ChatBox',
|
||||
path: '/chatbox',
|
||||
component: () => import('@/views/ChatBoxPage.vue'),
|
||||
children: [
|
||||
{
|
||||
path: ':conversationId',
|
||||
name: 'ChatBoxDetail',
|
||||
component: () => import('@/views/ChatBoxPage.vue'),
|
||||
props: true
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
};
|
||||
|
||||
export default ChatBoxRoutes;
|
||||
@@ -57,14 +57,39 @@ const MainRoutes = {
|
||||
component: () => import('@/views/ConsolePage.vue')
|
||||
},
|
||||
{
|
||||
name: 'Project ATRI',
|
||||
path: '/project-atri',
|
||||
component: () => import('@/views/ATRIProject.vue')
|
||||
name: 'Alkaid',
|
||||
path: '/alkaid',
|
||||
component: () => import('@/views/AlkaidPage.vue'),
|
||||
children: [
|
||||
{
|
||||
path: 'knowledge-base',
|
||||
name: 'KnowledgeBase',
|
||||
component: () => import('@/views/alkaid/KnowledgeBase.vue')
|
||||
},
|
||||
{
|
||||
path: 'long-term-memory',
|
||||
name: 'LongTermMemory',
|
||||
component: () => import('@/views/alkaid/LongTermMemory.vue')
|
||||
},
|
||||
{
|
||||
path: 'other',
|
||||
name: 'OtherFeatures',
|
||||
component: () => import('@/views/alkaid/Other.vue')
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
name: 'Chat',
|
||||
path: '/chat',
|
||||
component: () => import('@/views/ChatPage.vue')
|
||||
component: () => import('@/views/ChatPage.vue'),
|
||||
children: [
|
||||
{
|
||||
path: ':conversationId',
|
||||
name: 'ChatDetail',
|
||||
component: () => import('@/views/ChatPage.vue'),
|
||||
props: true
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
name: 'Settings',
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
import { createRouter, createWebHistory } from 'vue-router';
|
||||
import MainRoutes from './MainRoutes';
|
||||
import AuthRoutes from './AuthRoutes';
|
||||
import ChatBoxRoutes from './ChatBoxRoutes';
|
||||
import { useAuthStore } from '@/stores/auth';
|
||||
|
||||
export const router = createRouter({
|
||||
history: createWebHistory(import.meta.env.BASE_URL),
|
||||
routes: [
|
||||
MainRoutes,
|
||||
AuthRoutes
|
||||
AuthRoutes,
|
||||
ChatBoxRoutes
|
||||
]
|
||||
});
|
||||
|
||||
@@ -24,6 +26,11 @@ router.beforeEach(async (to, from, next) => {
|
||||
const authRequired = !publicPages.includes(to.path);
|
||||
const auth: AuthStore = useAuthStore();
|
||||
|
||||
// 如果用户已登录且试图访问登录页面,则重定向到首页或之前尝试访问的页面
|
||||
if (to.path === '/auth/login' && auth.has_token()) {
|
||||
return next(auth.returnUrl || '/');
|
||||
}
|
||||
|
||||
if (to.matched.some((record) => record.meta.requiresAuth)) {
|
||||
if (authRequired && !auth.has_token()) {
|
||||
auth.returnUrl = to.fullPath;
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
.listitem {
|
||||
height: calc(100vh - 100px);
|
||||
.v-list {
|
||||
color: rgb(var(--v-theme-lightText));
|
||||
color: rgb(var(--v-theme-secondaryText));
|
||||
}
|
||||
.v-list-group__items .v-list-item,
|
||||
.v-list-item {
|
||||
|
||||
@@ -32,7 +32,7 @@ export const useAuthStore = defineStore({
|
||||
},
|
||||
logout() {
|
||||
this.username = '';
|
||||
localStorage.removeItem('username');
|
||||
localStorage.removeItem('user');
|
||||
localStorage.removeItem('token');
|
||||
router.push('/auth/login');
|
||||
},
|
||||
|
||||
@@ -8,6 +8,7 @@ export const useCustomizerStore = defineStore({
|
||||
Customizer_drawer: config.Customizer_drawer,
|
||||
mini_sidebar: config.mini_sidebar,
|
||||
fontTheme: "Poppins",
|
||||
uiTheme: config.uiTheme,
|
||||
inputBg: config.inputBg
|
||||
}),
|
||||
|
||||
@@ -21,6 +22,10 @@ export const useCustomizerStore = defineStore({
|
||||
},
|
||||
SET_FONT(payload: string) {
|
||||
this.fontTheme = payload;
|
||||
}
|
||||
},
|
||||
SET_UI_THEME(payload: string) {
|
||||
this.uiTheme = payload;
|
||||
localStorage.setItem("uiTheme", payload);
|
||||
},
|
||||
}
|
||||
});
|
||||
|
||||
46
dashboard/src/theme/DarkTheme.ts
Normal file
46
dashboard/src/theme/DarkTheme.ts
Normal file
@@ -0,0 +1,46 @@
|
||||
import type { ThemeTypes } from '@/types/themeTypes/ThemeType';
|
||||
|
||||
const PurpleThemeDark: ThemeTypes = {
|
||||
name: 'PurpleThemeDark',
|
||||
dark: true,
|
||||
variables: {
|
||||
'border-color': '#1677ff',
|
||||
'carousel-control-size': 10
|
||||
},
|
||||
colors: {
|
||||
primary: '#1677ff',
|
||||
secondary: '#722ed1',
|
||||
info: '#03c9d7',
|
||||
success: '#52c41a',
|
||||
accent: '#FFAB91',
|
||||
warning: '#faad14',
|
||||
error: '#ff4d4f',
|
||||
lightprimary: '#eef2f6',
|
||||
lightsecondary: '#ede7f6',
|
||||
lightsuccess: '#b9f6ca',
|
||||
lighterror: '#f9d8d8',
|
||||
lightwarning: '#fff8e1',
|
||||
primaryText: '#ffffff',
|
||||
secondaryText: '#ffffffcc',
|
||||
darkprimary: '#1565c0',
|
||||
darksecondary: '#4527a0',
|
||||
borderLight: '#d0d0d0',
|
||||
border: '#333333ee',
|
||||
inputBorder: '#787878',
|
||||
containerBg: '#1a1a1a',
|
||||
surface: '#1f1f1f',
|
||||
'on-surface-variant': '#000',
|
||||
facebook: '#4267b2',
|
||||
twitter: '#1da1f2',
|
||||
linkedin: '#0e76a8',
|
||||
gray100: '#cccccccc',
|
||||
primary200: '#90caf9',
|
||||
secondary200: '#b39ddb',
|
||||
background: '#111111',
|
||||
overlay: '#111111aa',
|
||||
codeBg: '#282833',
|
||||
code: '#ffffffdd'
|
||||
}
|
||||
};
|
||||
|
||||
export { PurpleThemeDark };
|
||||
@@ -20,11 +20,12 @@ const PurpleTheme: ThemeTypes = {
|
||||
lightsuccess: '#b9f6ca',
|
||||
lighterror: '#f9d8d8',
|
||||
lightwarning: '#fff8e1',
|
||||
darkText: '#212121',
|
||||
lightText: '#616161',
|
||||
primaryText: '#000000dd',
|
||||
secondaryText: '#000000aa',
|
||||
darkprimary: '#1565c0',
|
||||
darksecondary: '#4527a0',
|
||||
borderLight: '#d0d0d0',
|
||||
border: '#d0d0d0',
|
||||
inputBorder: '#787878',
|
||||
containerBg: '#eef2f6',
|
||||
surface: '#fff',
|
||||
@@ -32,9 +33,13 @@ const PurpleTheme: ThemeTypes = {
|
||||
facebook: '#4267b2',
|
||||
twitter: '#1da1f2',
|
||||
linkedin: '#0e76a8',
|
||||
gray100: '#fafafa',
|
||||
gray100: '#fafafacc',
|
||||
primary200: '#90caf9',
|
||||
secondary200: '#b39ddb'
|
||||
secondary200: '#b39ddb',
|
||||
background: '#f9fafcf4',
|
||||
overlay: '#ffffffaa',
|
||||
codeBg: '#f5f0ff',
|
||||
code: '#673ab7'
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -17,13 +17,15 @@ export type ThemeTypes = {
|
||||
lightwarning?: string;
|
||||
darkprimary?: string;
|
||||
darksecondary?: string;
|
||||
darkText?: string;
|
||||
lightText?: string;
|
||||
primaryText?: string;
|
||||
secondaryText?: string;
|
||||
borderLight?: string;
|
||||
border?: string;
|
||||
inputBorder?: string;
|
||||
containerBg?: string;
|
||||
surface?: string;
|
||||
background?: string;
|
||||
overlay?: string;
|
||||
'on-surface-variant'?: string;
|
||||
facebook?: string;
|
||||
twitter?: string;
|
||||
@@ -31,5 +33,7 @@ export type ThemeTypes = {
|
||||
gray100?: string;
|
||||
primary200?: string;
|
||||
secondary200?: string;
|
||||
codeBg?: string;
|
||||
code?: string;
|
||||
};
|
||||
};
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
<script setup>
|
||||
</script>
|
||||
|
||||
|
||||
<template>
|
||||
<v-alert style="margin-bottom: 16px"
|
||||
text="这是一个长期实验性功能,目标是实现更具人类机能的 LLM 对话。推荐使用 gpt-4o-mini 作为文本生成和视觉理解模型,成本很低。推荐使用 text-embedding-3-small 作为 Embedding 模型,成本忽略不计。"
|
||||
title="💡实验性功能" type="info" variant="tonal">
|
||||
</v-alert>
|
||||
<v-card>
|
||||
<v-card-text>
|
||||
<v-container fluid>
|
||||
<AstrBotConfig :metadata="project_atri_config_metadata" :iterable="project_atri_config?.project_atri"
|
||||
metadataKey="project_atri">
|
||||
</AstrBotConfig>
|
||||
</v-container>
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
|
||||
<v-btn icon="mdi-content-save" size="x-large" style="position: fixed; right: 52px; bottom: 52px;" color="darkprimary"
|
||||
@click="updateConfig">
|
||||
</v-btn>
|
||||
<v-snackbar :timeout="3000" elevation="24" :color="save_message_success" v-model="save_message_snack">
|
||||
{{ save_message }}
|
||||
</v-snackbar>
|
||||
<WaitingForRestart ref="wfr"></WaitingForRestart>
|
||||
</template>
|
||||
|
||||
<script>
|
||||
import axios from 'axios';
|
||||
import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
|
||||
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
|
||||
export default {
|
||||
name: 'AtriProject',
|
||||
components: {
|
||||
AstrBotConfig,
|
||||
WaitingForRestart
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
project_atri_config: {},
|
||||
fetched: false,
|
||||
project_atri_config_metadata: {},
|
||||
save_message_snack: false,
|
||||
save_message: "",
|
||||
save_message_success: "",
|
||||
}
|
||||
},
|
||||
mounted() {
|
||||
this.getConfig();
|
||||
},
|
||||
methods: {
|
||||
getConfig() {
|
||||
// 获取配置
|
||||
axios.get('/api/config/get').then((res) => {
|
||||
this.project_atri_config = res.data.data.config;
|
||||
this.fetched = true
|
||||
this.project_atri_config_metadata = res.data.data.metadata;
|
||||
}).catch((err) => {
|
||||
save_message = err;
|
||||
save_message_snack = true;
|
||||
save_message_success = "error";
|
||||
});
|
||||
},
|
||||
updateConfig() {
|
||||
if (!this.fetched) return;
|
||||
axios.post('/api/config/astrbot/update', this.project_atri_config).then((res) => {
|
||||
if (res.data.status === "ok") {
|
||||
this.save_message = res.data.message;
|
||||
this.save_message_snack = true;
|
||||
this.save_message_success = "success";
|
||||
this.$refs.wfr.check();
|
||||
} else {
|
||||
this.save_message = res.data.message;
|
||||
this.save_message_snack = true;
|
||||
this.save_message_success = "error";
|
||||
}
|
||||
}).catch((err) => {
|
||||
this.save_message = err;
|
||||
this.save_message_snack = true;
|
||||
this.save_message_success = "error";
|
||||
});
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
</script>
|
||||
@@ -1,55 +1,95 @@
|
||||
<template>
|
||||
<v-card style="height: 100%;">
|
||||
<v-card-text style="padding: 0; height: 100%; overflow-y: auto;">
|
||||
<div
|
||||
style="display: flex; justify-content: center; align-items: center; height: 100%; flex-direction: column;">
|
||||
<div @click="selectedLogo = selectedLogo == 0 ? 1 : 0" style="height: 300px;">
|
||||
<img v-if="selectedLogo == 0" width="300" src="@/assets/images/logo-waifu.png" alt="AstrBot Logo"
|
||||
class="fade-in">
|
||||
<img v-if="selectedLogo == 1" width="300" src="@/assets/images/logo-normal.svg" alt="AstrBot Logo"
|
||||
class="fade-in">
|
||||
</div>
|
||||
<v-card style="height: 100%;" elevation="0" class="bg-surface">
|
||||
<v-card-text style="padding: 0; height: 100%; overflow-y: hidden;">
|
||||
<div class="about-wrapper">
|
||||
<!-- Hero Section -->
|
||||
<section class="hero-section">
|
||||
<div class="logo-title-container">
|
||||
<div @click="selectedLogo = selectedLogo == 0 ? 1 : 0" class="logo-container">
|
||||
<img v-if="selectedLogo == 0" width="280" src="@/assets/images/logo-waifu.png" alt="AstrBot Logo" class="fade-in">
|
||||
<img v-if="selectedLogo == 1" width="280" src="@/assets/images/logo-normal.svg" alt="AstrBot Logo" class="fade-in">
|
||||
</div>
|
||||
<div class="title-container">
|
||||
<h1 class="text-h2 font-weight-bold">AstrBot</h1>
|
||||
<p class="text-subtitle-1" style="color: var(--v-theme-secondaryText);">A project out of interests and loves ❤️</p>
|
||||
<div class="action-buttons">
|
||||
<v-btn @click="open('https://github.com/Soulter/AstrBot')"
|
||||
color="primary" variant="elevated" prepend-icon="mdi-star">
|
||||
Star 这个项目! 🌟
|
||||
</v-btn>
|
||||
<v-btn class="ml-4" @click="open('https://github.com/Soulter/AstrBot/issues')"
|
||||
color="secondary" variant="elevated" prepend-icon="mdi-comment-question">
|
||||
提交 Issue
|
||||
</v-btn>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<h1 class="mt-8">AstrBot</h1>
|
||||
<!-- Contributors Section -->
|
||||
<section class="contributors-section">
|
||||
<v-container>
|
||||
<v-row justify="center" align="center">
|
||||
<v-col cols="12" md="6" class="pr-md-8 contributors-info">
|
||||
<h2 class="text-h4 font-weight-medium">贡献者</h2>
|
||||
<p class="mb-4 text-body-1" style="color: var(--v-theme-secondaryText);">
|
||||
本项目由众多开源社区成员共同维护。感谢每一位贡献者的付出!
|
||||
</p>
|
||||
<p class="text-body-1" style="color: var(--v-theme-secondaryText);">
|
||||
<a href="https://github.com/Soulter/AstrBot/graphs/contributors" class="text-decoration-none custom-link">查看 AstrBot 贡献者</a>
|
||||
</p>
|
||||
</v-col>
|
||||
<v-col cols="12" md="6">
|
||||
<v-card variant="outlined" class="overflow-hidden" elevation="2">
|
||||
<v-img v-if="useCustomizerStore().uiTheme==='PurpleThemeDark'"
|
||||
alt="Active Contributors of Soulter/AstrBot"
|
||||
src="https://next.ossinsight.io/widgets/official/compose-recent-active-contributors/thumbnail.png?repo_id=575865240&limit=365&image_size=auto&color_scheme=dark">
|
||||
</v-img>
|
||||
<v-img v-else
|
||||
alt="Active Contributors of Soulter/AstrBot"
|
||||
src="https://next.ossinsight.io/widgets/official/compose-recent-active-contributors/thumbnail.png?repo_id=575865240&limit=365&image_size=auto&color_scheme=light">
|
||||
</v-img>
|
||||
</v-card>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</v-container>
|
||||
</section>
|
||||
|
||||
<span class="mt-2" style="color: #777;">A project out of interests and loves ❤️</span>
|
||||
|
||||
<span style="color: #777; margin-left: 32px; margin-right: 32px" class="mt-4">By <a
|
||||
href="https://soulter.top">Soulter</a>, <a
|
||||
href="https://github.com/Soulter/AstrBot/graphs/contributors">AstrBot Contributors</a>
|
||||
and <a href="https://github.com/Soulter/AstrBot_Plugins_Collection/graphs/contributors">AstrBot
|
||||
Plugin Authors</a>
|
||||
</span>
|
||||
|
||||
<!-- Copy-paste in your Readme.md file -->
|
||||
|
||||
<img style="margin-top: 16px; width: 50%; max-width: 500px; margin-left: 32px; margin-right: 32px"
|
||||
alt="Active Contributors of Soulter/AstrBot - Last 28 days"
|
||||
src="https://next.ossinsight.io/widgets/official/compose-recent-active-contributors/thumbnail.png?repo_id=575865240&limit=365&image_size=auto&color_scheme=light">
|
||||
|
||||
<img style="margin-top: 16px; width: 50%; max-width: 500px; margin-left: 32px; margin-right: 32px"
|
||||
alt="Active Contributors of Soulter/AstrBot - Last 28 days"
|
||||
src="https://next.ossinsight.io/widgets/official/analyze-repo-stars-map/thumbnail.png?activity=stars&repo_id=575865240&image_size=auto&color_scheme=light
|
||||
">
|
||||
|
||||
|
||||
<!-- Made with [OSS Insight](https://ossinsight.io/) -->
|
||||
|
||||
<v-btn class="text-primary mt-8" @click="open('https://github.com/Soulter/AstrBot')"
|
||||
color="lightprimary" variant="flat" rounded="sm">
|
||||
Star 这个项目! 🌟
|
||||
</v-btn>
|
||||
|
||||
<v-btn class="text-primary mt-4" @click="open('https://github.com/Soulter/AstrBot/issues')"
|
||||
color="lightprimary" variant="flat" rounded="sm">
|
||||
有使用问题或者功能建议?提交 Issue!
|
||||
</v-btn>
|
||||
<!-- Stats Section -->
|
||||
<section class="stats-section">
|
||||
<v-container>
|
||||
<v-row justify="center" align="center" class="flex-md-row-reverse">
|
||||
<v-col cols="12" md="6" class="pl-md-8 stats-info">
|
||||
<h2 class="text-h4 font-weight-medium">全球部署</h2>
|
||||
|
||||
<div class="license-container mt-8">
|
||||
<img v-bind="props" src="https://www.gnu.org/graphics/agplv3-with-text-100x42.png" style="cursor: pointer;"/>
|
||||
<p class="text-caption mt-2" style="color: var(--v-theme-secondaryText);">AstrBot 采用 AGPL v3 协议开源</p>
|
||||
</div>
|
||||
</v-col>
|
||||
<v-col cols="12" md="6">
|
||||
<v-card variant="outlined" class="overflow-hidden" elevation="2">
|
||||
<v-img v-if="useCustomizerStore().uiTheme==='PurpleThemeDark'"
|
||||
alt="Stars Map of Soulter/AstrBot"
|
||||
src="https://next.ossinsight.io/widgets/official/analyze-repo-stars-map/thumbnail.png?activity=stars&repo_id=575865240&image_size=auto&color_scheme=dark">
|
||||
</v-img>
|
||||
<v-img v-else
|
||||
alt="Stars Map of Soulter/AstrBot"
|
||||
src="https://next.ossinsight.io/widgets/official/analyze-repo-stars-map/thumbnail.png?activity=stars&repo_id=575865240&image_size=auto&color_scheme=light">
|
||||
</v-img>
|
||||
</v-card>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</v-container>
|
||||
</section>
|
||||
</div>
|
||||
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
</template>
|
||||
|
||||
<script>
|
||||
import {useCustomizerStore} from "@/stores/customizer";
|
||||
|
||||
export default {
|
||||
name: 'AboutPage',
|
||||
data() {
|
||||
@@ -59,26 +99,141 @@ export default {
|
||||
},
|
||||
|
||||
methods: {
|
||||
useCustomizerStore,
|
||||
open(url) {
|
||||
window.open(url, '_blank');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
</script>
|
||||
|
||||
<style>
|
||||
@keyframes fadeIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
}
|
||||
<style scoped>
|
||||
.about-wrapper {
|
||||
min-height: 100%;
|
||||
}
|
||||
|
||||
to {
|
||||
opacity: 1;
|
||||
}
|
||||
.hero-section {
|
||||
padding: 40px 20px;
|
||||
background: linear-gradient(to right bottom, rgba(255,255,255,0.7), rgba(240,240,250,0.3));
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.logo-title-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
flex-direction: row;
|
||||
max-width: 900px;
|
||||
gap: 20px;
|
||||
}
|
||||
|
||||
.logo-container {
|
||||
cursor: pointer;
|
||||
transition: all 0.3s ease;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.logo-container:hover {
|
||||
transform: scale(1.05);
|
||||
}
|
||||
|
||||
.title-container {
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
.contributors-section, .stats-section {
|
||||
padding: 60px 20px;
|
||||
}
|
||||
|
||||
.contributors-section {
|
||||
background-color: var(--v-theme-containerBg, #f9f9fb);
|
||||
}
|
||||
|
||||
.contributors-info, .stats-info {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.custom-link {
|
||||
display: inline-block;
|
||||
padding: 5px 0;
|
||||
position: relative;
|
||||
color: var(--v-primary-base);
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.custom-link::after {
|
||||
content: '';
|
||||
position: absolute;
|
||||
width: 100%;
|
||||
transform: scaleX(0);
|
||||
height: 2px;
|
||||
bottom: 0;
|
||||
left: 0;
|
||||
background-color: var(--v-primary-base);
|
||||
transform-origin: bottom right;
|
||||
transition: transform 0.25s ease-out;
|
||||
}
|
||||
|
||||
.custom-link:hover::after {
|
||||
transform: scaleX(1);
|
||||
transform-origin: bottom left;
|
||||
}
|
||||
|
||||
.license-container {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: flex-start;
|
||||
}
|
||||
|
||||
.action-buttons {
|
||||
display: flex;
|
||||
margin-top: 24px;
|
||||
}
|
||||
|
||||
@keyframes fadeIn {
|
||||
from { opacity: 0; }
|
||||
to { opacity: 1; }
|
||||
}
|
||||
|
||||
.fade-in {
|
||||
animation: fadeIn 0.2s ease-in-out;
|
||||
}
|
||||
|
||||
@media (max-width: 960px) {
|
||||
.logo-title-container {
|
||||
flex-direction: column;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.title-container {
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.action-buttons {
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.license-container {
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.contributors-section, .stats-section {
|
||||
padding: 40px 20px;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 600px) {
|
||||
.action-buttons {
|
||||
flex-direction: column;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.action-buttons .v-btn + .v-btn {
|
||||
margin-left: 0 !important;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
80
dashboard/src/views/AlkaidPage.vue
Normal file
80
dashboard/src/views/AlkaidPage.vue
Normal file
@@ -0,0 +1,80 @@
|
||||
<template>
|
||||
<v-card style="height: 100%; width: 100%;">
|
||||
<v-card-text class="pa-4" style="height: 100%;">
|
||||
<v-container fluid class="d-flex flex-column" style="height: 100%;">
|
||||
<div style="margin-bottom: 32px;">
|
||||
<h1 class="gradient-text">The Alkaid Project.</h1>
|
||||
<small style="color: #a3a3a3;">AstrBot Alpha 项目</small>
|
||||
</div>
|
||||
|
||||
<div style="display: flex; gap: 8px; margin-bottom: 16px;">
|
||||
<v-btn size="large" :variant="isActive('knowledge-base') ? 'flat' : 'tonal'"
|
||||
:color="isActive('knowledge-base') ? '#9b72cb' : ''" rounded="lg"
|
||||
@click="navigateTo('knowledge-base')">
|
||||
<v-icon start>mdi-text-box-search</v-icon>
|
||||
知识库
|
||||
</v-btn>
|
||||
<v-btn size="large" :variant="isActive('long-term-memory') ? 'flat' : 'tonal'"
|
||||
:color="isActive('long-term-memory') ? '#9b72cb' : ''" rounded="lg"
|
||||
@click="navigateTo('long-term-memory')">
|
||||
<v-icon start>mdi-dots-hexagon</v-icon>
|
||||
长期记忆层
|
||||
</v-btn>
|
||||
<v-btn size="large" :variant="isActive('other') ? 'flat' : 'tonal'"
|
||||
:color="isActive('other') ? '#9b72cb' : ''" rounded="lg"
|
||||
@click="navigateTo('other')">
|
||||
<v-icon start>mdi-tools</v-icon>
|
||||
...
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<div id="sub-view" class="flex-grow-1" style="max-height: 100%;">
|
||||
<router-view></router-view>
|
||||
</div>
|
||||
</v-container>
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
</template>
|
||||
|
||||
<script>
|
||||
export default {
|
||||
name: 'AlkaidPage',
|
||||
components: {},
|
||||
data() {
|
||||
return {}
|
||||
},
|
||||
methods: {
|
||||
navigateTo(tab) {
|
||||
this.$router.push(`/alkaid/${tab}`);
|
||||
},
|
||||
isActive(tab) {
|
||||
return this.$route.path.includes(`/alkaid/${tab}`);
|
||||
}
|
||||
},
|
||||
mounted() {
|
||||
// 如果在根路径 /alkaid,默认跳转到知识库页面
|
||||
if (this.$route.path === '/alkaid') {
|
||||
this.navigateTo('knowledge-base');
|
||||
}
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.gradient-text {
|
||||
background: linear-gradient(74deg, #2abfe1 0, #9b72cb 25%, #b55908 50%, #d93025 100%);
|
||||
|
||||
-webkit-background-clip: text;
|
||||
background-clip: text;
|
||||
color: transparent;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
#subview {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
flex-grow: 1;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
</style>
|
||||
432
dashboard/src/views/AlkaidPage_sigma.vue
Normal file
432
dashboard/src/views/AlkaidPage_sigma.vue
Normal file
@@ -0,0 +1,432 @@
|
||||
<script setup>
|
||||
import Graph from "graphology";
|
||||
import Sigma from "sigma";
|
||||
import ForceSupervisor from "graphology-layout-force/worker";
|
||||
</script>
|
||||
|
||||
|
||||
<template>
|
||||
<v-card style="height: 100%; width: 100%;">
|
||||
<v-card-text class="pa-4" style="height: 100%;">
|
||||
<v-container fluid class="d-flex flex-column" style="height: 100%;">
|
||||
<div style="margin-bottom: 32px;">
|
||||
<h1 class="gradient-text">The Alkaid Project.</h1>
|
||||
<small style="color: #a3a3a3;">AstrBot 实验性项目</small>
|
||||
</div>
|
||||
|
||||
<div style="display: flex; gap: 8px; margin-bottom: 16px;">
|
||||
<v-btn size="large" :variant="activeTab === 'long-term-memory' ? 'flat' : 'tonal'"
|
||||
:color="activeTab === 'long-term-memory' ? '#9b72cb' : ''" rounded="lg"
|
||||
@click="activeTab = 'long-term-memory'">
|
||||
<v-icon start>mdi-dots-hexagon</v-icon>
|
||||
长期记忆层
|
||||
</v-btn>
|
||||
<v-btn size="large" :variant="activeTab === 'other' ? 'flat' : 'tonal'"
|
||||
:color="activeTab === 'other' ? '#9b72cb' : ''" rounded="lg" @click="activeTab = 'other'">
|
||||
<v-icon start>mdi-dots-horizontal</v-icon>
|
||||
其他
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<div v-if="activeTab === 'long-term-memory'" id="long-term-memory" class="flex-grow-1"
|
||||
style="display: flex; flex-direction: row;">
|
||||
<div id="graph-container" style="flex-grow: 1; width: 100%; border: 1px solid #eee; border-radius: 8px;">
|
||||
</div>
|
||||
<div id="graph-control-panel"
|
||||
style="min-width: 450px; border: 1px solid #eee; border-radius: 8px; padding: 16px; margin-left: 16px;">
|
||||
<div>
|
||||
<span style="color: #333333;">可视化</span>
|
||||
<div style="margin-top: 8px;">
|
||||
<v-autocomplete v-model="searchUserId" :items="userIdList" variant="outlined"
|
||||
label="筛选用户 ID"></v-autocomplete>
|
||||
<v-btn color="primary" @click="onNodeSelect" variant="tonal" style="margin-top: 8px;">
|
||||
<v-icon start>mdi-magnify</v-icon>
|
||||
筛选
|
||||
</v-btn>
|
||||
<v-btn color="secondary" @click="resetFilter" variant="tonal"
|
||||
style="margin-top: 8px; margin-left: 8px;">
|
||||
<v-icon start>mdi-filter-remove</v-icon>
|
||||
重置筛选
|
||||
</v-btn>
|
||||
</div>
|
||||
<div style="margin-top: 16px;">
|
||||
<v-btn color="primary" @click="refreshGraph" variant="tonal">
|
||||
<v-icon start>mdi-refresh</v-icon>
|
||||
刷新图形
|
||||
</v-btn>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<v-divider class="my-4"></v-divider>
|
||||
|
||||
<div v-if="selectedNode" class="mt-4">
|
||||
<h3>节点详情</h3>
|
||||
<v-card variant="outlined" class="mt-2 pa-3">
|
||||
<div v-if="selectedNode.id">
|
||||
<div class="d-flex justify-space-between">
|
||||
<span class="text-subtitle-2">ID:</span>
|
||||
<span>{{ selectedNode.id }}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div v-if="selectedNode._label">
|
||||
<div class="d-flex justify-space-between">
|
||||
<span class="text-subtitle-2">类型:</span>
|
||||
<span>{{ selectedNode._label }}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div v-if="selectedNode.name">
|
||||
<div class="d-flex justify-space-between">
|
||||
<span class="text-subtitle-2">名称:</span>
|
||||
<span>{{ selectedNode.name }}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div v-if="selectedNode.user_id">
|
||||
<div class="d-flex justify-space-between">
|
||||
<span class="text-subtitle-2">用户ID:</span>
|
||||
<span>{{ selectedNode.user_id }}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div v-if="selectedNode.ts">
|
||||
<div class="d-flex justify-space-between">
|
||||
<span class="text-subtitle-2">时间戳:</span>
|
||||
<span>{{ selectedNode.ts }}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div v-if="selectedNode.type">
|
||||
<div class="d-flex justify-space-between">
|
||||
<span class="text-subtitle-2">类型:</span>
|
||||
<span>{{ selectedNode.type }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</v-card>
|
||||
</div>
|
||||
|
||||
<div v-if="graphStats" class="mt-4">
|
||||
<h3>图形统计</h3>
|
||||
<v-card variant="outlined" class="mt-2 pa-3">
|
||||
<div class="d-flex justify-space-between">
|
||||
<span class="text-subtitle-2">节点数:</span>
|
||||
<span>{{ graphStats.nodeCount }}</span>
|
||||
</div>
|
||||
<div class="d-flex justify-space-between">
|
||||
<span class="text-subtitle-2">边数:</span>
|
||||
<span>{{ graphStats.edgeCount }}</span>
|
||||
</div>
|
||||
</v-card>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-if="activeTab === 'other'" class="flex-grow-1" style="display: flex; flex-direction: column;">
|
||||
<div class="d-flex align-center justify-center"
|
||||
style="flex-grow: 1; width: 100%; border: 1px solid #eee; border-radius: 8px;">
|
||||
<v-icon size="64" color="grey-lighten-1">mdi-tools</v-icon>
|
||||
<p class="text-h6 text-grey ml-4">功能开发中</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</v-container>
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
</template>
|
||||
|
||||
<script>
|
||||
import axios from 'axios';
|
||||
import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
|
||||
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
|
||||
export default {
|
||||
name: 'AlkaidPage',
|
||||
components: {
|
||||
AstrBotConfig,
|
||||
WaitingForRestart
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
renderer: null,
|
||||
graph: null,
|
||||
layout: null,
|
||||
activeTab: 'long-term-memory',
|
||||
node_data: [],
|
||||
edge_data: [],
|
||||
searchUserId: null,
|
||||
userIdList: [],
|
||||
selectedNode: null,
|
||||
graphStats: null,
|
||||
nodeColors: {
|
||||
'PhaseNode': '#4CAF50', // 绿色
|
||||
'PassageNode': '#2196F3', // 蓝色
|
||||
'FactNode': '#FF9800', // 橙色
|
||||
'default': '#9C27B0' // 紫色作为默认
|
||||
},
|
||||
edgeColors: {
|
||||
'_include_': '#607D8B',
|
||||
'_related_': '#9E9E9E',
|
||||
'default': '#BDBDBD'
|
||||
},
|
||||
isLoading: false
|
||||
}
|
||||
},
|
||||
mounted() {
|
||||
this.initSigma();
|
||||
this.ltmGetGraph();
|
||||
this.ltmGetUserIds();
|
||||
},
|
||||
beforeUnmount() {
|
||||
if (this.renderer) {
|
||||
this.renderer.kill();
|
||||
}
|
||||
if (this.layout) {
|
||||
this.layout.stop();
|
||||
}
|
||||
},
|
||||
watch: {
|
||||
activeTab(newVal) {
|
||||
if (newVal === 'long-term-memory') {
|
||||
this.$nextTick(() => {
|
||||
if (!this.renderer) {
|
||||
this.initSigma();
|
||||
}
|
||||
});
|
||||
} else {
|
||||
if (this.renderer) {
|
||||
this.renderer.kill();
|
||||
this.renderer = null;
|
||||
}
|
||||
if (this.layout) {
|
||||
this.layout.stop();
|
||||
this.layout = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
methods: {
|
||||
ltmGetGraph(userId = null) {
|
||||
this.isLoading = true;
|
||||
const params = userId ? { user_id: userId } : {};
|
||||
|
||||
axios.get('/api/plug/alkaid/ltm/graph', { params })
|
||||
.then(response => {
|
||||
let nodes = response.data.data.nodes;
|
||||
let edges = response.data.data.edges;
|
||||
|
||||
this.node_data = nodes;
|
||||
this.edge_data = edges;
|
||||
|
||||
if (this.graph) {
|
||||
this.graph.clear();
|
||||
}
|
||||
|
||||
|
||||
|
||||
nodes.forEach(node => {
|
||||
const nodeId = node[0];
|
||||
const nodeData = node[1];
|
||||
|
||||
if (!this.graph.hasNode(nodeId)) {
|
||||
const nodeType = nodeData._label || 'default';
|
||||
const color = this.nodeColors[nodeType] || this.nodeColors['default'];
|
||||
|
||||
this.graph.addNode(nodeId, {
|
||||
x: Math.random(),
|
||||
y: Math.random(),
|
||||
size: 5,
|
||||
label: nodeData.name || nodeId.split('_')[0],
|
||||
color: color,
|
||||
originalData: nodeData
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// 添加边
|
||||
edges.forEach(edge => {
|
||||
const sourceId = edge[0];
|
||||
const targetId = edge[1];
|
||||
const edgeData = edge[2];
|
||||
|
||||
if (this.graph.hasNode(sourceId) && this.graph.hasNode(targetId)) {
|
||||
const edgeId = `${sourceId}->${targetId}`;
|
||||
const relationType = edgeData.relation_type || 'default';
|
||||
const color = this.edgeColors[relationType] || this.edgeColors['default'];
|
||||
this.graph.addEdge(sourceId, targetId, {
|
||||
size: 1,
|
||||
color: color,
|
||||
originalData: edgeData,
|
||||
label: relationType,
|
||||
type: "line"
|
||||
});
|
||||
} else {
|
||||
console.warn(`Edge ${sourceId} -> ${targetId} has missing nodes.`);
|
||||
}
|
||||
});
|
||||
|
||||
this.updateGraphStats();
|
||||
|
||||
console.log('Graph initialized with', nodes.length, 'nodes and', edges.length, 'edges');
|
||||
})
|
||||
.catch(error => {
|
||||
console.error('Error fetching graph data:', error);
|
||||
})
|
||||
.finally(() => {
|
||||
this.isLoading = false;
|
||||
});
|
||||
|
||||
if (this.layout) {
|
||||
this.layout.start();
|
||||
}
|
||||
|
||||
},
|
||||
|
||||
ltmGetUserIds() {
|
||||
axios.get('/api/plug/alkaid/ltm/user_ids')
|
||||
.then(response => {
|
||||
this.userIdList = response.data.data;
|
||||
})
|
||||
.catch(error => {
|
||||
console.error('Error fetching user IDs:', error);
|
||||
});
|
||||
},
|
||||
|
||||
updateGraphStats() {
|
||||
if (this.graph) {
|
||||
this.graphStats = {
|
||||
nodeCount: this.graph.order,
|
||||
edgeCount: this.graph.size
|
||||
};
|
||||
}
|
||||
},
|
||||
|
||||
refreshGraph() {
|
||||
this.ltmGetGraph(this.searchUserId);
|
||||
},
|
||||
|
||||
onNodeSelect() {
|
||||
console.log('Selected user ID:', this.searchUserId);
|
||||
if (!this.searchUserId || !this.graph) return;
|
||||
|
||||
// 使用API的user_id参数筛选数据
|
||||
this.ltmGetGraph(this.searchUserId);
|
||||
},
|
||||
|
||||
resetFilter() {
|
||||
this.searchUserId = null;
|
||||
this.ltmGetGraph();
|
||||
},
|
||||
|
||||
initSigma() {
|
||||
const container = document.getElementById("graph-container");
|
||||
if (!container) return;
|
||||
|
||||
if (this.renderer) {
|
||||
this.renderer.kill();
|
||||
this.renderer = null;
|
||||
}
|
||||
if (this.layout) {
|
||||
this.layout.stop();
|
||||
this.layout = null;
|
||||
}
|
||||
|
||||
const graph = new Graph({
|
||||
multi: true,
|
||||
});
|
||||
|
||||
const layout = new ForceSupervisor(graph, {
|
||||
isNodeFixed: (_, attr) => attr.highlighted, settings: {
|
||||
gravity: 0.0001,
|
||||
repulsion: 0.001
|
||||
}
|
||||
});
|
||||
layout.start();
|
||||
|
||||
this.layout = layout;
|
||||
this.graph = graph;
|
||||
const renderer = new Sigma(graph, container, {
|
||||
minCameraRatio: 0.01,
|
||||
maxCameraRatio: 2,
|
||||
labelRenderedSizeThreshold: 1,
|
||||
renderLabels: true,
|
||||
renderEdgeLabels: true,
|
||||
labelSize: 14,
|
||||
labelColor: "#333333",
|
||||
});
|
||||
this.renderer = renderer;
|
||||
|
||||
let draggedNode = null;
|
||||
let isDragging = false;
|
||||
|
||||
renderer.on("downNode", (e) => {
|
||||
isDragging = true;
|
||||
draggedNode = e.node;
|
||||
graph.setNodeAttribute(draggedNode, "highlighted", true);
|
||||
if (!renderer.getCustomBBox()) renderer.setCustomBBox(renderer.getBBox());
|
||||
});
|
||||
|
||||
renderer.on("moveBody", ({ event }) => {
|
||||
if (!isDragging || !draggedNode) return;
|
||||
const pos = renderer.viewportToGraph(event);
|
||||
|
||||
graph.setNodeAttribute(draggedNode, "x", pos.x);
|
||||
graph.setNodeAttribute(draggedNode, "y", pos.y);
|
||||
event.preventSigmaDefault();
|
||||
event.original.preventDefault();
|
||||
event.original.stopPropagation();
|
||||
});
|
||||
const handleUp = () => {
|
||||
if (draggedNode) {
|
||||
graph.removeNodeAttribute(draggedNode, "highlighted");
|
||||
}
|
||||
isDragging = false;
|
||||
draggedNode = null;
|
||||
};
|
||||
renderer.on("upNode", handleUp);
|
||||
renderer.on("upStage", handleUp);
|
||||
|
||||
renderer.on("clickNode", (e) => {
|
||||
const nodeId = e.node;
|
||||
const nodeAttributes = graph.getNodeAttributes(nodeId);
|
||||
this.selectedNode = nodeAttributes.originalData;
|
||||
});
|
||||
|
||||
renderer.on("clickStage", () => {
|
||||
this.selectedNode = null;
|
||||
});
|
||||
|
||||
},
|
||||
|
||||
getRandomColor() {
|
||||
const letters = '0123456789ABCDEF';
|
||||
let color = '#';
|
||||
for (let i = 0; i < 6; i++) {
|
||||
color += letters[Math.floor(Math.random() * 16)];
|
||||
}
|
||||
return color;
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.gradient-text {
|
||||
background: linear-gradient(74deg, #2abfe1 0, #9b72cb 25%, #b55908 50%, #d93025 100%);
|
||||
|
||||
-webkit-background-clip: text;
|
||||
background-clip: text;
|
||||
color: transparent;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
#graph-container {
|
||||
position: relative;
|
||||
background-color: #f2f6f9;
|
||||
overflow: hidden;
|
||||
min-height: 200px;
|
||||
}
|
||||
|
||||
#graph-container:hover {
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.memory-header {
|
||||
padding: 0 8px;
|
||||
}
|
||||
</style>
|
||||
36
dashboard/src/views/ChatBoxPage.vue
Normal file
36
dashboard/src/views/ChatBoxPage.vue
Normal file
@@ -0,0 +1,36 @@
|
||||
<script setup>
|
||||
import ChatPage from './ChatPage.vue';
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div style="height: 100%; width: 100%; display: flex; flex-direction: column; align-items: center; justify-content: center;">
|
||||
<div id="container">
|
||||
<ChatPage chatbox-mode="true"></ChatPage>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
#container {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
@media (min-width: 768px) {
|
||||
#container {
|
||||
min-width: 600px;
|
||||
min-height: 370px;
|
||||
max-width: 1100px;
|
||||
max-height: 860px;
|
||||
padding: 36px;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 767px) {
|
||||
#container {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
padding: 0;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user