remove: 移除了 nakuru-project 库
但仍然使用其对 OneBot 的数据格式封装。
This commit is contained in:
@@ -1,10 +1,9 @@
|
||||
|
||||
from astrbot.core.plugin import Context
|
||||
from astrbot.core.platform import AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from astrbot.core.message_event_result import MessageEventResult, MessageChain, CommandResult
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain, CommandResult
|
||||
from astrbot.core.provider import Provider, Personality
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from nakuru.entities.components import *
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.personality import personalities
|
||||
|
||||
|
||||
1
astrbot/api/message_components.py
Normal file
1
astrbot/api/message_components.py
Normal file
@@ -0,0 +1 @@
|
||||
from astrbot.core.message.components import *
|
||||
@@ -1,5 +1,5 @@
|
||||
from .log import LogManager, LogBroker
|
||||
from core.utils.t2i.renderer import HtmlRenderer
|
||||
from astrbot.core.utils.t2i.renderer import HtmlRenderer
|
||||
|
||||
html_renderer = HtmlRenderer()
|
||||
logger = LogManager.GetLogger(log_name='astrbot')
|
||||
@@ -67,6 +67,11 @@ class ImageGenerationModelConfig:
|
||||
style: str = "vivid"
|
||||
quality: str = "standard"
|
||||
|
||||
@dataclass
|
||||
class EmbeddingModel:
|
||||
enable: bool = False
|
||||
model: str = ""
|
||||
|
||||
@dataclass
|
||||
class LLMConfig:
|
||||
id: str = ""
|
||||
@@ -77,12 +82,17 @@ class LLMConfig:
|
||||
prompt_prefix: str = ""
|
||||
default_personality: str = ""
|
||||
model_config: ModelConfig = field(default_factory=ModelConfig)
|
||||
image_generation_model_config: Optional[ImageGenerationModelConfig] = None
|
||||
image_generation_model_config: Optional[ImageGenerationModelConfig] = field(default_factory=ImageGenerationModelConfig)
|
||||
embedding_model: Optional[EmbeddingModel] = field(default_factory=EmbeddingModel)
|
||||
|
||||
def __post_init__(self):
|
||||
self.model_config = ModelConfig(**self.model_config)
|
||||
if self.image_generation_model_config:
|
||||
self.image_generation_model_config = ImageGenerationModelConfig(**self.image_generation_model_config)
|
||||
if isinstance(self.model_config, dict):
|
||||
self.model_config = ModelConfig(**self.model_config)
|
||||
if isinstance(self.image_generation_model_config, dict):
|
||||
self.image_generation_model_config = ImageGenerationModelConfig(**self.image_generation_model_config) if self.image_generation_model_config else None
|
||||
if isinstance(self.embedding_model, dict):
|
||||
self.embedding_model = EmbeddingModel(**self.embedding_model) if self.embedding_model else None
|
||||
|
||||
@dataclass
|
||||
class LLMSettings:
|
||||
wake_prefix: str = ""
|
||||
@@ -115,6 +125,35 @@ class DashboardConfig:
|
||||
enable: bool = True
|
||||
username: str = ""
|
||||
password: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ATRILongTermMemory:
|
||||
enable: bool = False
|
||||
summary_threshold_cnt: int = 5
|
||||
|
||||
@dataclass
|
||||
class ATRIActiveMessage:
|
||||
enable: bool = False
|
||||
|
||||
@dataclass
|
||||
class ProjectATRI:
|
||||
enable: bool = False
|
||||
long_term_memory: ATRILongTermMemory = field(default_factory=ATRILongTermMemory)
|
||||
active_message: ATRIActiveMessage = field(default_factory=ATRIActiveMessage)
|
||||
persona: str = ""
|
||||
embedding_provider_id: str = ""
|
||||
summarize_provider_id: str = ""
|
||||
chat_provider_id: str = ""
|
||||
chat_base_model_path: str = ""
|
||||
chat_adapter_model_path: str = ""
|
||||
quantization_bit: int = 4
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.long_term_memory, dict):
|
||||
self.long_term_memory = ATRILongTermMemory(**self.long_term_memory)
|
||||
if isinstance(self.active_message, dict):
|
||||
self.active_message = ATRIActiveMessage(**self.active_message)
|
||||
|
||||
@dataclass
|
||||
class AstrBotConfig():
|
||||
@@ -134,6 +173,7 @@ class AstrBotConfig():
|
||||
t2i_endpoint: str = ""
|
||||
pip_install_arg: str = ""
|
||||
plugin_repo_mirror: str = ""
|
||||
project_atri: ProjectATRI = field(default_factory=ProjectATRI)
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.init_configs()
|
||||
@@ -190,6 +230,7 @@ class AstrBotConfig():
|
||||
self.t2i_endpoint=data.get("t2i_endpoint", "")
|
||||
self.pip_install_arg=data.get("pip_install_arg", "")
|
||||
self.plugin_repo_mirror=data.get("plugin_repo_mirror", "")
|
||||
self.project_atri=ProjectATRI(**data.get("project_atri", {}))
|
||||
|
||||
def flush_config(self, config: dict = None):
|
||||
'''将配置写入文件, 如果没有传入配置,则写入默认配置'''
|
||||
|
||||
@@ -27,6 +27,10 @@ PROVIDER_CONFIG_TEMPLATE = {
|
||||
"size": "1024x1024",
|
||||
"style": "vivid",
|
||||
"quality": "standard",
|
||||
},
|
||||
"embedding_model": {
|
||||
"enable": False,
|
||||
"model": "text-embedding-3-small"
|
||||
}
|
||||
},
|
||||
"ollama": {
|
||||
@@ -147,6 +151,23 @@ DEFAULT_CONFIG_VERSION_2 = {
|
||||
"t2i_endpoint": "",
|
||||
"pip_install_arg": "",
|
||||
"plugin_repo_mirror": "default",
|
||||
"project_atri": {
|
||||
"enable": False,
|
||||
"long_term_memory": {
|
||||
"enable": False,
|
||||
"summary_threshold_cnt": 6,
|
||||
},
|
||||
"active_message": {
|
||||
"enable": False,
|
||||
},
|
||||
"persona": "",
|
||||
"embedding_provider_id": "",
|
||||
"summarize_provider_id": "",
|
||||
"chat_provider_id": "",
|
||||
"chat_base_model_path": "",
|
||||
"chat_adapter_model_path": "",
|
||||
"quantization_bit": 4
|
||||
}
|
||||
}
|
||||
|
||||
# 配置项的中文描述、值类型
|
||||
@@ -167,7 +188,7 @@ CONFIG_METADATA_2 = {
|
||||
"ws_reverse_port": {"description": "反向 Websocket 端口", "type": "int", "hint": "aiocqhttp 适配器的反向 Websocket 端口。"},
|
||||
"qq_id_whitelist": {"description": "QQ 号白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的 QQ 号发来的消息事件。为空时表示不启用白名单过滤。"},
|
||||
"qq_group_id_whitelist": {"description": "QQ 群号白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的 QQ 群发来的消息事件。为空时表示不启用白名单过滤。"},
|
||||
"wechat_id_whitelist": {"description": "微信私聊/群聊白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的微信私聊/群聊发来的消息事件。为空时表示不启用白名单过滤。使用 /wechatid 指令获取微信 ID(不是微信号)。"},
|
||||
"wechat_id_whitelist": {"description": "微信私聊/群聊白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的微信私聊/群聊发来的消息事件。为空时表示不启用白名单过滤。使用 /wechatid 指令获取微信 ID(不是微信号)。注意:每次扫码登录之后,相同联系人的 ID 会发生变化,白名单内的 ID 会失效。"},
|
||||
}
|
||||
},
|
||||
"platform_settings": {
|
||||
@@ -200,17 +221,17 @@ CONFIG_METADATA_2 = {
|
||||
"prompt_prefix": {"description": "Prompt 前缀", "type": "text", "hint": "每次与 LLM 对话时在对话前加上的自定义文本。默认为空。"},
|
||||
"default_personality": {"description": "默认人格", "type": "text", "hint": "在当前版本下,默认人格文本会被添加到 LLM 对话的 `system` 字段中。"},
|
||||
"model_config": {
|
||||
"description": "模型配置",
|
||||
"description": "文本生成模型",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"model": {"description": "模型名称", "type": "string", "hint": "大语言模型的名称,一般是小写的英文。如 gpt-4o-mini, deepseek-chat 等。"},
|
||||
"max_tokens": {"description": "最大令牌数", "type": "int"},
|
||||
"max_tokens": {"description": "模型最大输出长度(tokens)", "type": "int"},
|
||||
"temperature": {"description": "温度", "type": "float"},
|
||||
"top_p": {"description": "Top P值", "type": "float"},
|
||||
}
|
||||
},
|
||||
"image_generation_model_config": {
|
||||
"description": "图像生成模型配置",
|
||||
"description": "图像生成模型",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {"description": "启用", "type": "bool", "hint": "启用该功能需要提供商支持图像生成。如 dall-e-3"},
|
||||
@@ -220,6 +241,14 @@ CONFIG_METADATA_2 = {
|
||||
"quality": {"description": "图像质量", "type": "string"},
|
||||
}
|
||||
},
|
||||
"embedding_model": {
|
||||
"description": "文本嵌入模型",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {"description": "启用", "type": "bool", "hint": "启用该功能需要提供商支持文本嵌入。"},
|
||||
"model": {"description": "模型名称", "type": "string", "hint": "文本嵌入模型的名称,一般是小写的英文。如 text-embedding-3-small"},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"llm_settings": {
|
||||
@@ -273,6 +302,35 @@ CONFIG_METADATA_2 = {
|
||||
"t2i_endpoint": {"description": "文本转图像服务接口", "type": "string", "hint": "为空时使用 AstrBot API 服务"},
|
||||
"pip_install_arg": {"description": "pip 安装参数", "type": "string", "hint": "安装插件依赖时,会使用 Python 的 pip 工具。这里可以填写额外的参数,如 `--break-system-package` 等。"},
|
||||
"plugin_repo_mirror": {"description": "插件仓库镜像", "type": "string", "hint": "插件仓库的镜像地址,用于加速插件的下载。", "options": ["default", "https://ghp.ci/", "https://github-mirror.us.kg/"]},
|
||||
"project_atri": {
|
||||
"description": "Project ATRI 配置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {"description": "启用", "type": "bool"},
|
||||
"long_term_memory": {
|
||||
"description": "长期记忆",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {"description": "启用", "type": "bool"},
|
||||
"summary_threshold_cnt": {"description": "摘要阈值", "type": "int", "hint": "当一个会话的对话记录数量超过该阈值时,会自动进行摘要。"},
|
||||
}
|
||||
},
|
||||
"active_message": {
|
||||
"description": "主动消息",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {"description": "启用", "type": "bool"},
|
||||
}
|
||||
},
|
||||
"persona": {"description": "人格", "type": "string", "hint": "默认人格。当启动 ATRI 之后,在 Provider 处设置的人格将会失效。", "obvious_hint": True},
|
||||
"embedding_provider_id": {"description": "Embedding provider ID", "type": "string", "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Embedding,请确保所填的 provider id 在 `配置页` 中存在并且设置了 Embedding 配置", "obvious_hint": True},
|
||||
"summarize_provider_id": {"description": "Summary provider ID", "type": "string", "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Summary,请确保所填的 provider id 在 `配置页` 中存在。", "obvious_hint": True},
|
||||
"chat_provider_id": {"description": "Chat provider ID", "type": "string", "hint": "将会使用指定的 provider 来进行文本聊天,请确保所填的 provider id 在 `配置页` 中存在。", "obvious_hint": True},
|
||||
"chat_base_model_path": {"description": "用于聊天的基座模型路径", "type": "string", "hint": "用于聊天的基座模型路径。当填写此项和 Lora 路径后,将会忽略上面设置的 Chat provider ID。", "obvious_hint": True},
|
||||
"chat_adapter_model_path": {"description": "用于聊天的 Lora 模型路径", "type": "string", "hint": "Lora 模型路径。", "obvious_hint": True},
|
||||
"quantization_bit": {"description": "量化位数", "type": "int", "hint": "模型量化位数。如果你不知道这是什么,请不要修改。默认为 4。", "obvious_hint": True},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
DEFAULT_VALUE_MAP = {
|
||||
|
||||
@@ -2,14 +2,14 @@ import asyncio, time, threading
|
||||
from .event_bus import EventBus
|
||||
from asyncio import Queue
|
||||
from typing import List
|
||||
from core.config.astrbot_config import AstrBotConfig
|
||||
from core.message_event_handler import MessageEventHandler
|
||||
from core.plugin import PluginManager
|
||||
from core import LogBroker
|
||||
from core.db import BaseDatabase
|
||||
from core.updator import AstrBotUpdator
|
||||
from core import logger
|
||||
from core.config.default import VERSION
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.message.message_event_handler import MessageEventHandler
|
||||
from astrbot.core.plugin import PluginManager
|
||||
from astrbot.core import LogBroker
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
|
||||
class AstrBotCoreLifecycle:
|
||||
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from core.db.po import Stats, LLMHistory
|
||||
from astrbot.core.db.po import Stats, LLMHistory
|
||||
|
||||
@dataclass
|
||||
class BaseDatabase(abc.ABC):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import sqlite3
|
||||
import os
|
||||
import time
|
||||
from core.db.po import (
|
||||
from astrbot.core.db.po import (
|
||||
Platform,
|
||||
Command,
|
||||
Provider,
|
||||
|
||||
@@ -2,10 +2,10 @@ import asyncio
|
||||
from asyncio import Queue
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
from .message_event_handler import MessageEventHandler
|
||||
from core import logger
|
||||
from astrbot.core.message.message_event_handler import MessageEventHandler
|
||||
from astrbot.core import logger
|
||||
from .platform import AstrMessageEvent
|
||||
from nakuru.entities.components import Plain, Image
|
||||
from astrbot.core.message.components import Image, Plain
|
||||
|
||||
class EventBus:
|
||||
def __init__(self, event_queue: Queue, message_event_handler: MessageEventHandler):
|
||||
|
||||
443
astrbot/core/message/components.py
Normal file
443
astrbot/core/message/components.py
Normal file
@@ -0,0 +1,443 @@
|
||||
'''
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2021 Lxns-Network
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
'''
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import typing as T
|
||||
from enum import Enum
|
||||
from pydantic.v1 import BaseModel
|
||||
|
||||
class ComponentType(Enum):
|
||||
Plain = "Plain"
|
||||
Face = "Face"
|
||||
Record = "Record"
|
||||
Video = "Video"
|
||||
At = "At"
|
||||
RPS = "RPS" # TODO
|
||||
Dice = "Dice" # TODO
|
||||
Shake = "Shake" # TODO
|
||||
Anonymous = "Anonymous" # TODO
|
||||
Share = "Share"
|
||||
Contact = "Contact" # TODO
|
||||
Location = "Location" # TODO
|
||||
Music = "Music"
|
||||
Image = "Image"
|
||||
Reply = "Reply"
|
||||
RedBag = "RedBag"
|
||||
Poke = "Poke"
|
||||
Forward = "Forward"
|
||||
Node = "Node"
|
||||
Xml = "Xml"
|
||||
Json = "Json"
|
||||
CardImage = "CardImage"
|
||||
TTS = "TTS"
|
||||
Unknown = "Unknown"
|
||||
|
||||
|
||||
class BaseMessageComponent(BaseModel):
|
||||
type: ComponentType
|
||||
|
||||
def toString(self):
|
||||
output = f"[CQ:{self.type.lower()}"
|
||||
for k, v in self.__dict__.items():
|
||||
if k == "type" or v is None:
|
||||
continue
|
||||
if k == "_type":
|
||||
k = "type"
|
||||
if isinstance(v, bool):
|
||||
v = 1 if v else 0
|
||||
output += ",%s=%s" % (k, str(v).replace("&", "&") \
|
||||
.replace(",", ",") \
|
||||
.replace("[", "[") \
|
||||
.replace("]", "]"))
|
||||
output += "]"
|
||||
return output
|
||||
|
||||
def toDict(self):
|
||||
data = dict()
|
||||
for k, v in self.__dict__.items():
|
||||
if k == "type" or v is None:
|
||||
continue
|
||||
if k == "_type":
|
||||
k = "type"
|
||||
data[k] = v
|
||||
return {
|
||||
"type": self.type.lower(),
|
||||
"data": data
|
||||
}
|
||||
|
||||
|
||||
class Plain(BaseMessageComponent):
|
||||
type: ComponentType = "Plain"
|
||||
text: str
|
||||
convert: T.Optional[bool] = True # 若为 False 则直接发送未转换 CQ 码的消息
|
||||
|
||||
def __init__(self, text: str, convert: bool = True, **_):
|
||||
super().__init__(text=text, convert=convert, **_)
|
||||
|
||||
def toString(self): # 没有 [CQ:plain] 这种东西,所以直接导出纯文本
|
||||
if not self.convert:
|
||||
return self.text
|
||||
return self.text.replace("&", "&") \
|
||||
.replace("[", "[") \
|
||||
.replace("]", "]")
|
||||
|
||||
|
||||
class Face(BaseMessageComponent):
|
||||
type: ComponentType = "Face"
|
||||
id: int
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class Record(BaseMessageComponent):
|
||||
type: ComponentType = "Record"
|
||||
file: T.Optional[str] = ""
|
||||
magic: T.Optional[bool] = False
|
||||
url: T.Optional[str] = ""
|
||||
cache: T.Optional[bool] = True
|
||||
proxy: T.Optional[bool] = True
|
||||
timeout: T.Optional[int] = 0
|
||||
# 额外
|
||||
path: T.Optional[str]
|
||||
|
||||
def __init__(self, file: T.Optional[str], **_):
|
||||
for k in _.keys():
|
||||
if k == "url":
|
||||
pass
|
||||
# Protocol.warn(f"go-cqhttp doesn't support send {self.type} by {k}")
|
||||
super().__init__(file=file, **_)
|
||||
|
||||
@staticmethod
|
||||
def fromFileSystem(path, **_):
|
||||
return Record(file=f"file:///{os.path.abspath(path)}", path=path, **_)
|
||||
|
||||
@staticmethod
|
||||
def fromURL(url: str, **_):
|
||||
if url.startswith("http://") or url.startswith("https://"):
|
||||
return Record(file=url, **_)
|
||||
raise Exception("not a valid url")
|
||||
|
||||
|
||||
class Video(BaseMessageComponent):
|
||||
type: ComponentType = "Video"
|
||||
file: str
|
||||
cover: T.Optional[str] = ""
|
||||
c: T.Optional[int] = 2
|
||||
# 额外
|
||||
path: T.Optional[str] = ""
|
||||
|
||||
def __init__(self, file: str, **_):
|
||||
# for k in _.keys():
|
||||
# if k == "c" and _[k] not in [2, 3]:
|
||||
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
|
||||
super().__init__(file=file, **_)
|
||||
|
||||
@staticmethod
|
||||
def fromFileSystem(path, **_):
|
||||
return Video(file=f"file:///{os.path.abspath(path)}", path=path, **_)
|
||||
|
||||
@staticmethod
|
||||
def fromURL(url: str, **_):
|
||||
if url.startswith("http://") or url.startswith("https://"):
|
||||
return Video(file=url, **_)
|
||||
raise Exception("not a valid url")
|
||||
|
||||
|
||||
class At(BaseMessageComponent):
|
||||
type: ComponentType = "At"
|
||||
qq: T.Union[int, str] # 此处str为all时代表所有人
|
||||
name: T.Optional[str] = ""
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class AtAll(At):
|
||||
qq: str = "all"
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class RPS(BaseMessageComponent): # TODO
|
||||
type: ComponentType = "RPS"
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class Dice(BaseMessageComponent): # TODO
|
||||
type: ComponentType = "Dice"
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class Shake(BaseMessageComponent): # TODO
|
||||
type: ComponentType = "Shake"
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class Anonymous(BaseMessageComponent): # TODO
|
||||
type: ComponentType = "Anonymous"
|
||||
ignore: T.Optional[bool] = False
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class Share(BaseMessageComponent):
|
||||
type: ComponentType = "Share"
|
||||
url: str
|
||||
title: str
|
||||
content: T.Optional[str] = ""
|
||||
image: T.Optional[str] = ""
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class Contact(BaseMessageComponent): # TODO
|
||||
type: ComponentType = "Contact"
|
||||
_type: str # type 字段冲突
|
||||
id: T.Optional[int] = 0
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class Location(BaseMessageComponent): # TODO
|
||||
type: ComponentType = "Location"
|
||||
lat: float
|
||||
lon: float
|
||||
title: T.Optional[str] = ""
|
||||
content: T.Optional[str] = ""
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class Music(BaseMessageComponent):
|
||||
type: ComponentType = "Music"
|
||||
_type: str
|
||||
id: T.Optional[int] = 0
|
||||
url: T.Optional[str] = ""
|
||||
audio: T.Optional[str] = ""
|
||||
title: T.Optional[str] = ""
|
||||
content: T.Optional[str] = ""
|
||||
image: T.Optional[str] = ""
|
||||
|
||||
def __init__(self, **_):
|
||||
# for k in _.keys():
|
||||
# if k == "_type" and _[k] not in ["qq", "163", "xm", "custom"]:
|
||||
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class Image(BaseMessageComponent):
|
||||
type: ComponentType = "Image"
|
||||
file: T.Optional[str] = ""
|
||||
_type: T.Optional[str] = ""
|
||||
subType: T.Optional[int] = 0
|
||||
url: T.Optional[str] = ""
|
||||
cache: T.Optional[bool] = True
|
||||
id: T.Optional[int] = 40000
|
||||
c: T.Optional[int] = 2
|
||||
# 额外
|
||||
path: T.Optional[str] = ""
|
||||
|
||||
def __init__(self, file: T.Optional[str], **_):
|
||||
# for k in _.keys():
|
||||
# if (k == "_type" and _[k] not in ["flash", "show", None]) or \
|
||||
# (k == "c" and _[k] not in [2, 3]):
|
||||
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
|
||||
super().__init__(file=file, **_)
|
||||
|
||||
@staticmethod
|
||||
def fromURL(url: str, **_):
|
||||
if url.startswith("http://") or url.startswith("https://"):
|
||||
return Image(file=url, **_)
|
||||
raise Exception("not a valid url")
|
||||
|
||||
@staticmethod
|
||||
def fromFileSystem(path, **_):
|
||||
return Image(file=f"file:///{os.path.abspath(path)}", path=path, **_)
|
||||
|
||||
@staticmethod
|
||||
def fromBase64(base64: str, **_):
|
||||
return Image(f"base64://{base64}", **_)
|
||||
|
||||
@staticmethod
|
||||
def fromBytes(byte: bytes):
|
||||
return Image.fromBase64(base64.b64encode(byte).decode())
|
||||
|
||||
@staticmethod
|
||||
def fromIO(IO):
|
||||
return Image.fromBytes(IO.read())
|
||||
|
||||
|
||||
class Reply(BaseMessageComponent):
|
||||
type: ComponentType = "Reply"
|
||||
id: int
|
||||
text: T.Optional[str] = ""
|
||||
qq: T.Optional[int] = 0
|
||||
time: T.Optional[int] = 0
|
||||
seq: T.Optional[int] = 0
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class RedBag(BaseMessageComponent):
|
||||
type: ComponentType = "RedBag"
|
||||
title: str
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class Poke(BaseMessageComponent):
|
||||
type: ComponentType = "Poke"
|
||||
qq: int
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class Forward(BaseMessageComponent):
|
||||
type: ComponentType = "Forward"
|
||||
id: str
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class Node(BaseMessageComponent): # 该 component 仅支持使用 sendGroupForwardMessage 发送
|
||||
type: ComponentType = "Node"
|
||||
id: T.Optional[int] = 0
|
||||
name: T.Optional[str] = ""
|
||||
uin: T.Optional[int] = 0
|
||||
content: T.Optional[T.Union[str, list]] = ""
|
||||
seq: T.Optional[T.Union[str, list]] = "" # 不清楚是什么
|
||||
time: T.Optional[int] = 0
|
||||
|
||||
def __init__(self, content: T.Union[str, list], **_):
|
||||
if isinstance(content, list):
|
||||
_content = ""
|
||||
for chain in content:
|
||||
_content += chain.toString()
|
||||
content = _content
|
||||
super().__init__(content=content, **_)
|
||||
|
||||
def toString(self):
|
||||
# logger.warn("Protocol: node doesn't support stringify")
|
||||
return ""
|
||||
|
||||
|
||||
class Xml(BaseMessageComponent):
|
||||
type: ComponentType = "Xml"
|
||||
data: str
|
||||
resid: T.Optional[int] = 0
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class Json(BaseMessageComponent):
|
||||
type: ComponentType = "Json"
|
||||
data: T.Union[str, dict]
|
||||
resid: T.Optional[int] = 0
|
||||
|
||||
def __init__(self, data, **_):
|
||||
if isinstance(data, dict):
|
||||
data = json.dumps(data)
|
||||
super().__init__(data=data, **_)
|
||||
|
||||
|
||||
class CardImage(BaseMessageComponent):
|
||||
type: ComponentType = "CardImage"
|
||||
file: str
|
||||
cache: T.Optional[bool] = True
|
||||
minwidth: T.Optional[int] = 400
|
||||
minheight: T.Optional[int] = 400
|
||||
maxwidth: T.Optional[int] = 500
|
||||
maxheight: T.Optional[int] = 500
|
||||
source: T.Optional[str] = ""
|
||||
icon: T.Optional[str] = ""
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
@staticmethod
|
||||
def fromFileSystem(path, **_):
|
||||
return CardImage(file=f"file:///{os.path.abspath(path)}", **_)
|
||||
|
||||
|
||||
class TTS(BaseMessageComponent):
|
||||
type: ComponentType = "TTS"
|
||||
text: str
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class Unknown(BaseMessageComponent):
|
||||
type: ComponentType = "Unknown"
|
||||
text: str
|
||||
|
||||
def toString(self):
|
||||
return ""
|
||||
|
||||
|
||||
ComponentTypes = {
|
||||
"plain": Plain,
|
||||
"face": Face,
|
||||
"record": Record,
|
||||
"video": Video,
|
||||
"at": At,
|
||||
"rps": RPS,
|
||||
"dice": Dice,
|
||||
"shake": Shake,
|
||||
"anonymous": Anonymous,
|
||||
"share": Share,
|
||||
"contact": Contact,
|
||||
"location": Location,
|
||||
"music": Music,
|
||||
"image": Image,
|
||||
"reply": Reply,
|
||||
"redbag": RedBag,
|
||||
"poke": Poke,
|
||||
"forward": Forward,
|
||||
"node": Node,
|
||||
"xml": Xml,
|
||||
"json": Json,
|
||||
"cardimage": CardImage,
|
||||
"tts": TTS,
|
||||
"unknown": Unknown
|
||||
}
|
||||
@@ -2,13 +2,13 @@ import asyncio, re, time
|
||||
import inspect
|
||||
import traceback
|
||||
from typing import List, Union
|
||||
from .platform import AstrMessageEvent
|
||||
from .config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.platform import AstrMessageEvent
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from .message_event_result import MessageEventResult, CommandResult, MessageChain
|
||||
from .plugin import PluginManager, Context, CommandMetadata
|
||||
from nakuru.entities.components import *
|
||||
from core import logger
|
||||
from core import html_renderer
|
||||
from astrbot.core.plugin import PluginManager, Context, CommandMetadata
|
||||
from .components import *
|
||||
from astrbot.core import logger
|
||||
from astrbot.core import html_renderer
|
||||
|
||||
class CommandTokens():
|
||||
def __init__(self) -> None:
|
||||
@@ -1,11 +1,12 @@
|
||||
from typing import List, Union, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from nakuru.entities.components import *
|
||||
from astrbot.core.message.components import *
|
||||
|
||||
@dataclass
|
||||
class MessageChain():
|
||||
chain: List[BaseMessageComponent] = field(default_factory=list)
|
||||
use_t2i_: Optional[bool] = None # None 为跟随用户设置
|
||||
is_split_: Optional[bool] = False # 是否将消息分条发送。默认为 False。启用后,将会依次发送 chain 中的每个 component。
|
||||
|
||||
def message(self, message: str):
|
||||
'''
|
||||
@@ -49,6 +50,15 @@ class MessageChain():
|
||||
'''
|
||||
self.use_t2i_ = use_t2i
|
||||
return self
|
||||
|
||||
def is_split(self, is_split: bool):
|
||||
'''
|
||||
设置是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
|
||||
|
||||
具体的效果以各适配器实现为准。
|
||||
'''
|
||||
self.is_split_ = is_split
|
||||
return self
|
||||
|
||||
@dataclass
|
||||
class MessageEventResult(MessageChain):
|
||||
@@ -2,11 +2,11 @@ import abc
|
||||
from dataclasses import dataclass
|
||||
from .astrbot_message import AstrBotMessage
|
||||
from .platform_metadata import PlatformMetadata
|
||||
from core.message_event_result import MessageEventResult, MessageChain
|
||||
from core.platform.message_type import MessageType
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from typing import List
|
||||
from nakuru.entities.components import BaseMessageComponent, Plain, Image
|
||||
from core.utils.metrics import Metric
|
||||
from astrbot.core.message.components import BaseMessageComponent, Plain, Image
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
|
||||
@dataclass
|
||||
class MessageSesion:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import time
|
||||
from typing import List
|
||||
from dataclasses import dataclass
|
||||
from nakuru.entities.components import BaseMessageComponent
|
||||
from astrbot.core.message.components import BaseMessageComponent
|
||||
from .message_type import MessageType
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -3,9 +3,9 @@ from typing import Awaitable, Any
|
||||
from asyncio import Queue
|
||||
from .platform_metadata import PlatformMetadata
|
||||
from .astr_message_event import AstrMessageEvent
|
||||
from core.message_event_result import MessageChain
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from .astr_message_event import MessageSesion
|
||||
from core.utils.metrics import Metric
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
|
||||
class Platform(abc.ABC):
|
||||
def __init__(self, event_queue: Queue):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .plugin import Plugin, RegisteredPlugin, PluginMetadata
|
||||
from .plugin_manager import PluginManager
|
||||
from .context import CommandMetadata, Context
|
||||
from core.provider import Provider
|
||||
from astrbot.core.provider import Provider
|
||||
@@ -4,12 +4,12 @@ from . import RegisteredPlugin, PluginMetadata
|
||||
from typing import List, Dict, Awaitable, Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
from core.platform import Platform
|
||||
from core.db import BaseDatabase
|
||||
from core.config.astrbot_config import AstrBotConfig
|
||||
from core.utils.func_call import FuncCall
|
||||
from core.platform.astr_message_event import MessageSesion
|
||||
from core.message_event_result import MessageChain
|
||||
from astrbot.core.platform import Platform
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.utils.func_call import FuncCall
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
|
||||
@dataclass
|
||||
class CommandMetadata():
|
||||
@@ -67,6 +67,9 @@ class Context:
|
||||
# 维护了 LLM Tools 信息
|
||||
llm_tools: FuncCall = FuncCall()
|
||||
|
||||
# 维护插件存储的数据
|
||||
plugin_data: Dict[str, Dict[str, any]] = {}
|
||||
|
||||
def __init__(self, event_queue: Queue, config: AstrBotConfig, db: BaseDatabase):
|
||||
self._event_queue = event_queue
|
||||
self._config = config
|
||||
@@ -205,4 +208,10 @@ class Context:
|
||||
if platform.meta().name == session.platform_name:
|
||||
await platform.send_by_session(session, message_chain)
|
||||
return True
|
||||
return False
|
||||
return False
|
||||
|
||||
def set_data(self, plugin_name: str, key: str, value: any):
|
||||
'''
|
||||
设置插件数据。
|
||||
'''
|
||||
self.plugin_data[plugin_name][key] = value
|
||||
@@ -10,13 +10,13 @@ from asyncio import Queue
|
||||
from types import ModuleType
|
||||
from typing import List, Awaitable
|
||||
from pip import main as pip_main
|
||||
from core.config.astrbot_config import AstrBotConfig
|
||||
from core import logger
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core import logger
|
||||
from .context import Context
|
||||
from . import RegisteredPlugin, PluginMetadata
|
||||
from .updator import PluginUpdator
|
||||
from core.db import BaseDatabase
|
||||
from core.utils.io import remove_dir
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.io import remove_dir
|
||||
|
||||
class PluginManager:
|
||||
def __init__(self, config: AstrBotConfig, event_queue: Queue, db: BaseDatabase):
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import os, zipfile, shutil
|
||||
|
||||
from ..updator import RepoZipUpdator
|
||||
from core.utils.io import remove_dir, on_error
|
||||
from astrbot.core.utils.io import remove_dir, on_error
|
||||
from ..plugin import RegisteredPlugin
|
||||
from typing import Union
|
||||
from core import logger
|
||||
from astrbot.core import logger
|
||||
|
||||
class PluginUpdator(RepoZipUpdator):
|
||||
def __init__(self, repo_mirror: str = "") -> None:
|
||||
|
||||
@@ -1 +1 @@
|
||||
from .provider import Provider
|
||||
from .provider import Provider, Personality
|
||||
@@ -1,11 +1,30 @@
|
||||
import abc
|
||||
import abc, json, threading, time
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
# from core.utils.func_call import FuncCall
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core import logger
|
||||
from typing import TypedDict
|
||||
|
||||
class Personality(TypedDict):
|
||||
prompt: str
|
||||
name: str
|
||||
|
||||
class Provider(abc.ABC):
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, db_helper: BaseDatabase, default_personality: str = None, persistant_history: bool = True) -> None:
|
||||
self.model_name = "unknown"
|
||||
# 维护了 session_id 的上下文,不包含 system 指令
|
||||
self.session_memory = defaultdict(list)
|
||||
self.curr_personality = Personality(prompt=default_personality, name="")
|
||||
|
||||
if persistant_history:
|
||||
# 读取历史记录
|
||||
self.db_helper = db_helper
|
||||
try:
|
||||
for history in db_helper.get_llm_history():
|
||||
self.session_memory[history.session_id] = json.loads(history.content)
|
||||
except BaseException as e:
|
||||
logger.warning(f"读取 LLM 对话历史记录 失败:{e}。仍可正常使用。")
|
||||
|
||||
|
||||
def set_model(self, model_name: str):
|
||||
self.model_name = model_name
|
||||
@@ -13,12 +32,32 @@ class Provider(abc.ABC):
|
||||
def get_model(self):
|
||||
return self.model_name
|
||||
|
||||
async def get_human_readable_context(self, session_id: str) -> List[str]:
|
||||
'''
|
||||
获取人类可读的上下文
|
||||
|
||||
example:
|
||||
["User: 你好", "Assistant: 你好"]
|
||||
'''
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
|
||||
contexts = []
|
||||
for record in self.session_memory[session_id]:
|
||||
if record['role'] == "user":
|
||||
contexts.append(f"User: {record['content']}")
|
||||
elif record['role'] == "assistant":
|
||||
contexts.append(f"Assistant: {record['content']}")
|
||||
|
||||
return contexts
|
||||
|
||||
@abc.abstractmethod
|
||||
async def text_chat(self,
|
||||
prompt: str,
|
||||
session_id: str,
|
||||
image_urls: List[str] = None,
|
||||
tool = None,
|
||||
tools = None,
|
||||
contexts=None,
|
||||
**kwargs) -> str:
|
||||
'''
|
||||
prompt: 提示词
|
||||
@@ -38,6 +77,13 @@ class Provider(abc.ABC):
|
||||
'''
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_embedding(self, text: str) -> List[float]:
|
||||
'''
|
||||
获取文本的嵌入
|
||||
'''
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def forget(self, session_id: str) -> bool:
|
||||
'''
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import os, psutil, sys, time
|
||||
from .zip_updator import ReleaseInfo, RepoZipUpdator
|
||||
from core import logger
|
||||
from core.config.default import VERSION
|
||||
from core.utils.io import download_file
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.utils.io import download_file
|
||||
|
||||
class AstrBotUpdator(RepoZipUpdator):
|
||||
def __init__(self, repo_mirror: str = "") -> None:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from core.provider import Provider
|
||||
from astrbot.core.provider import Provider
|
||||
from typing import Awaitable
|
||||
import json
|
||||
import textwrap
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import aiohttp
|
||||
import sys
|
||||
import logging
|
||||
from core.config import VERSION
|
||||
from astrbot.core.config import VERSION
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from io import BytesIO
|
||||
|
||||
from . import RenderStrategy
|
||||
from PIL import ImageFont, Image, ImageDraw
|
||||
from core.utils.io import save_temp_img
|
||||
from astrbot.core.utils.io import save_temp_img
|
||||
|
||||
class LocalRenderStrategy(RenderStrategy):
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@ import aiohttp
|
||||
import os
|
||||
|
||||
from . import RenderStrategy
|
||||
from core.config import VERSION
|
||||
from core.utils.io import download_image_by_url
|
||||
from astrbot.core.config import VERSION
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
ASTRBOT_T2I_DEFAULT_ENDPOINT = "https://t2i.soulter.top/text2img"
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from .network_strategy import NetworkRenderStrategy
|
||||
from .local_strategy import LocalRenderStrategy
|
||||
from core.log import LogManager
|
||||
from astrbot.core.log import LogManager
|
||||
|
||||
logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import aiohttp, os, zipfile, shutil
|
||||
from core.utils.io import on_error, download_file
|
||||
from core import logger
|
||||
from astrbot.core.utils.io import on_error, download_file
|
||||
from astrbot.core import logger
|
||||
|
||||
class ReleaseInfo():
|
||||
version: str
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import asyncio
|
||||
from multiprocessing import Process
|
||||
from core import logger
|
||||
from core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from .server import AstrBotDashboard
|
||||
from core.db import BaseDatabase
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
class AstrBotDashBoardLifecycle:
|
||||
def __init__(self, db: BaseDatabase):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from .route import Route, Response
|
||||
from quart import Quart, request
|
||||
from core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
|
||||
class AuthRoute(Route):
|
||||
def __init__(self, config: AstrBotConfig, app: Quart) -> None:
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import os, json
|
||||
from .route import Route, Response
|
||||
from quart import Quart, request
|
||||
from core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP, PROVIDER_CONFIG_TEMPLATE
|
||||
from core.config.astrbot_config import AstrBotConfig
|
||||
from core.plugin.config import update_config
|
||||
from core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP, PROVIDER_CONFIG_TEMPLATE
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.plugin.config import update_config
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from dataclasses import asdict
|
||||
|
||||
def try_cast(value: str, type_: str):
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import asyncio
|
||||
from quart import websocket
|
||||
from quart import Quart
|
||||
from core.config.astrbot_config import AstrBotConfig
|
||||
from core import logger, LogBroker
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core import logger, LogBroker
|
||||
from .route import Route, Response
|
||||
|
||||
class LogRoute(Route):
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import threading, traceback, uuid
|
||||
from .route import Route, Response
|
||||
from core import logger
|
||||
from astrbot.core import logger
|
||||
from quart import Quart, request
|
||||
from core.config.astrbot_config import AstrBotConfig
|
||||
from core.plugin.plugin_manager import PluginManager
|
||||
from core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.plugin.plugin_manager import PluginManager
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
|
||||
class PluginRoute(Route):
|
||||
def __init__(self, config: AstrBotConfig, app: Quart, core_lifecycle: AstrBotCoreLifecycle, plugin_manager: PluginManager) -> None:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from dataclasses import dataclass
|
||||
from quart import Quart
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import traceback, psutil, time, aiohttp
|
||||
from .route import Route, Response
|
||||
from core import logger
|
||||
from astrbot.core import logger
|
||||
from quart import Quart, request
|
||||
from core.config.astrbot_config import AstrBotConfig
|
||||
from core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from core.db import BaseDatabase
|
||||
from core.config import VERSION
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config import VERSION
|
||||
|
||||
class StatRoute(Route):
|
||||
def __init__(self, config: AstrBotConfig, app: Quart, db_helper: BaseDatabase, core_lifecycle: AstrBotCoreLifecycle) -> None:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from .route import Route
|
||||
from quart import Quart
|
||||
from core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
|
||||
class StaticFileRoute(Route):
|
||||
def __init__(self, config: AstrBotConfig, app: Quart) -> None:
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import threading, traceback
|
||||
from .route import Route, Response
|
||||
from quart import Quart, request
|
||||
from core.config.astrbot_config import AstrBotConfig
|
||||
from core.updator import AstrBotUpdator
|
||||
from core import logger
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core import logger
|
||||
|
||||
class UpdateRoute(Route):
|
||||
def __init__(self, config: AstrBotConfig, app: Quart, astrbot_updator: AstrBotUpdator) -> None:
|
||||
|
||||
@@ -2,22 +2,23 @@ import logging
|
||||
import asyncio, os
|
||||
from quart import Quart
|
||||
from quart.logging import default_handler
|
||||
from core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from .routes import *
|
||||
from core import logger
|
||||
from core.db import BaseDatabase
|
||||
from core.plugin.plugin_manager import PluginManager
|
||||
from core.updator import AstrBotUpdator
|
||||
from core.utils.io import get_local_ip_addresses
|
||||
from core.config import AstrBotConfig
|
||||
from core.db import BaseDatabase
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.plugin.plugin_manager import PluginManager
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core.utils.io import get_local_ip_addresses
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
class AstrBotDashboard():
|
||||
def __init__(self, core_lifecycle: AstrBotCoreLifecycle, db: BaseDatabase) -> None:
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.config = core_lifecycle.astrbot_config
|
||||
self.data_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../data/dist"))
|
||||
self.app = Quart("dashboard", static_folder="dist", static_url_path="/")
|
||||
logger.info(f"Dashboard data path: {self.data_path}")
|
||||
self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/")
|
||||
self.app.json.sort_keys = False
|
||||
|
||||
logging.getLogger(self.app.name).removeHandler(default_handler)
|
||||
|
||||
@@ -6,12 +6,12 @@ import mimetypes
|
||||
import aiohttp
|
||||
import zipfile
|
||||
from typing import List
|
||||
from core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from core.db.sqlite import SQLiteDatabase
|
||||
from core.config import DB_PATH
|
||||
from dashboard import AstrBotDashBoardLifecycle
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from astrbot.core.config import DB_PATH
|
||||
from astrbot.dashboard import AstrBotDashBoardLifecycle
|
||||
|
||||
from core import logger, LogManager, LogBroker
|
||||
from astrbot.core import logger, LogManager, LogBroker
|
||||
|
||||
# add parent path to sys.path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
@@ -1,7 +1,7 @@
|
||||
import os, traceback
|
||||
import os, traceback, random, asyncio
|
||||
|
||||
from astrbot.api import AstrMessageEvent, MessageChain, logger
|
||||
from astrbot.api import Plain, Image
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from aiocqhttp import CQHttp
|
||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
||||
|
||||
@@ -11,7 +11,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
self.bot = bot
|
||||
|
||||
@staticmethod
|
||||
async def _parse_onebot_josn(message_chain: MessageChain):
|
||||
async def _parse_onebot_json(message_chain: MessageChain):
|
||||
'''解析成 OneBot json 格式'''
|
||||
ret = []
|
||||
for segment in message_chain.chain:
|
||||
@@ -31,8 +31,14 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
return ret
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
ret = await AiocqhttpMessageEvent._parse_onebot_josn(message)
|
||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
|
||||
if os.environ.get('TEST_MODE', 'off') == 'on':
|
||||
return
|
||||
await self.bot.send(self.message_obj.raw_message, ret)
|
||||
|
||||
if message.is_split_: # 分条发送
|
||||
for m in ret:
|
||||
await self.bot.send(self.message_obj.raw_message, [m])
|
||||
await asyncio.sleep(random.uniform(0.75, 2.5))
|
||||
else:
|
||||
await self.bot.send(self.message_obj.raw_message, ret)
|
||||
await super().send(message)
|
||||
@@ -7,7 +7,7 @@ from aiocqhttp import CQHttp, Event
|
||||
from astrbot.api import Platform
|
||||
from astrbot.api import MessageChain, MessageEventResult, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from .aiocqhttp_message_event import *
|
||||
from nakuru.entities.components import *
|
||||
from astrbot.api.message_components import *
|
||||
from astrbot.api import logger
|
||||
from .aiocqhttp_message_event import AiocqhttpMessageEvent
|
||||
from astrbot.core.config.astrbot_config import PlatformConfig, AiocqhttpPlatformConfig, PlatformSettings
|
||||
@@ -29,7 +29,7 @@ class AiocqhttpAdapter(Platform):
|
||||
)
|
||||
|
||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
||||
ret = await AiocqhttpMessageEvent._parse_onebot_josn(message_chain)
|
||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message_chain)
|
||||
match session.message_type.value:
|
||||
case MessageType.GROUP_MESSAGE.value:
|
||||
if "_" in session.session_id:
|
||||
|
||||
@@ -4,7 +4,7 @@ import botpy.types
|
||||
import botpy.types.message
|
||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
||||
from astrbot.api import AstrMessageEvent, MessageChain, logger, AstrBotMessage, PlatformMetadata, MessageType
|
||||
from astrbot.api import Plain, Image
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from botpy import Client
|
||||
from botpy.http import Route
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from botpy import Client
|
||||
from astrbot.api import Platform
|
||||
from astrbot.api import MessageChain, MessageEventResult, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from typing import Union, List, Dict
|
||||
from nakuru.entities.components import *
|
||||
from astrbot.api.message_components import *
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from .qqofficial_message_event import QQOfficialMessageEvent
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import random, asyncio
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.api import AstrMessageEvent, MessageChain, logger, AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api import Plain, Image
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from vchat import Core
|
||||
|
||||
class WechatPlatformEvent(AstrMessageEvent):
|
||||
@@ -14,7 +14,10 @@ class WechatPlatformEvent(AstrMessageEvent):
|
||||
plain = ""
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
plain += comp.text
|
||||
if message.is_split_:
|
||||
await client.send_msg(comp.text, user_name)
|
||||
else:
|
||||
plain += comp.text
|
||||
elif isinstance(comp, Image):
|
||||
if comp.file and comp.file.startswith("file:///"):
|
||||
file_path = comp.file.replace("file:///", "")
|
||||
|
||||
@@ -4,7 +4,7 @@ import asyncio
|
||||
from astrbot.api import Platform
|
||||
from astrbot.api import MessageChain, MessageEventResult, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from typing import Union, List, Dict
|
||||
from nakuru.entities.components import *
|
||||
from astrbot.api.message_components import *
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from .wechat_message_event import WechatPlatformEvent
|
||||
@@ -24,6 +24,7 @@ class WechatPlatformAdapter(Platform):
|
||||
def __init__(self, platform_config: WechatPlatformConfig, platform_settings: PlatformSettings, event_queue: asyncio.Queue) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
self.settingss = platform_settings
|
||||
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
|
||||
self.client_self_id = uuid.uuid4().hex[:8]
|
||||
|
||||
@@ -51,6 +52,7 @@ class WechatPlatformAdapter(Platform):
|
||||
if msg.create_time < self.start_time:
|
||||
logger.debug(f"忽略旧消息: {msg}")
|
||||
return
|
||||
logger.debug(f"收到消息: {msg.todict()}")
|
||||
if self.config.wechat_id_whitelist and msg.from_.username not in self.config.wechat_id_whitelist:
|
||||
logger.debug(f"忽略不在白名单的微信消息。username: {msg.from_.username}")
|
||||
return
|
||||
@@ -80,7 +82,11 @@ class WechatPlatformAdapter(Platform):
|
||||
|
||||
sender = msg.chatroom_sender or msg.from_
|
||||
amsg.sender = MessageMember(sender.username, sender.nickname)
|
||||
amsg.message_str = msg.content.content
|
||||
|
||||
if msg.content.is_at_me:
|
||||
amsg.message_str = msg.content.content.split("\u2005")[1].strip()
|
||||
else:
|
||||
amsg.message_str = msg.content.content
|
||||
amsg.message_id = msg.message_id
|
||||
if isinstance(msg.from_, model.User):
|
||||
amsg.type = MessageType.FRIEND_MESSAGE
|
||||
@@ -91,10 +97,13 @@ class WechatPlatformAdapter(Platform):
|
||||
|
||||
amsg.raw_message = msg
|
||||
|
||||
session_id = msg.from_.username + "$$" + msg.to.username
|
||||
if msg.chatroom_sender is not None:
|
||||
session_id += '$$' + msg.chatroom_sender.username
|
||||
|
||||
if self.settingss.unique_session:
|
||||
session_id = msg.from_.username + "$$" + msg.to.username
|
||||
if msg.chatroom_sender is not None:
|
||||
session_id += '$$' + msg.chatroom_sender.username
|
||||
else:
|
||||
session_id = msg.from_.username
|
||||
|
||||
amsg.session_id = session_id
|
||||
return amsg
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from astrbot.api import Context, AstrMessageEvent, MessageEventResult, MessageChain
|
||||
from . import PLUGIN_NAME
|
||||
from astrbot.api import logger, Image, Plain
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.message_components import Image, Plain
|
||||
from astrbot.api import personalities
|
||||
from astrbot.api import command_parser
|
||||
from astrbot.api import Provider
|
||||
from astrbot.api import Provider, Personality
|
||||
|
||||
|
||||
class OpenAIAdapterCommand:
|
||||
@@ -25,7 +26,7 @@ class OpenAIAdapterCommand:
|
||||
async def reset(self, message: AstrMessageEvent):
|
||||
tokens = command_parser.parse(message.message_str)
|
||||
if tokens.len == 1:
|
||||
await self.provider.forget(message.session_id, keep_system_prompt=True)
|
||||
await self.provider.forget(message.session_id)
|
||||
message.set_result(MessageEventResult().message("重置成功"))
|
||||
elif tokens.get(1) == 'p':
|
||||
await self.provider.forget(message.session_id)
|
||||
@@ -81,17 +82,13 @@ class OpenAIAdapterCommand:
|
||||
message.set_result(MessageEventResult().message(f"历史记录:\n\n{contexts}\n第 {page} 页 | 共 {t_pages} 页\n\n*输入 /his 2 跳转到第 2 页"))
|
||||
|
||||
def status(self, message: AstrMessageEvent):
|
||||
keys_data = self.provider.get_keys_data()
|
||||
ret = "OpenAI Key"
|
||||
keys_data = self.provider.get_all_keys()
|
||||
ret = "{} Key"
|
||||
for k in keys_data:
|
||||
status = "🟢" if keys_data[k] else "🔴"
|
||||
ret += "\n|- " + k[:8] + " " + status
|
||||
ret += "\n|- " + k[:8]
|
||||
|
||||
ret += "\n当前模型: " + self.provider.get_model()
|
||||
|
||||
if message.session_id in self.provider.session_memory and len(self.provider.session_memory[message.session_id]):
|
||||
ret += "\n你的会话上下文: " + str(self.provider.session_memory[message.session_id][-1]['usage_tokens']) + " tokens"
|
||||
|
||||
message.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
|
||||
async def switch(self, message: AstrMessageEvent):
|
||||
@@ -160,18 +157,10 @@ class OpenAIAdapterCommand:
|
||||
else:
|
||||
ps = "".join(l[1:]).strip()
|
||||
if ps in personalities:
|
||||
self.provider.curr_personality = {
|
||||
'name': ps,
|
||||
'prompt': personalities[ps]
|
||||
}
|
||||
self.provider.personality_set(self.provider.curr_personality, message.session_id)
|
||||
self.provider.curr_personality = Personality(name=ps, prompt=personalities[ps])
|
||||
message.set_result(MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}"))
|
||||
else:
|
||||
self.provider.curr_personality = {
|
||||
'name': '自定义人格',
|
||||
'prompt': ps
|
||||
}
|
||||
self.provider.personality_set(self.provider.curr_personality, message.session_id)
|
||||
self.provider.curr_personality = Personality(name="自定义人格", prompt=ps)
|
||||
message.set_result(MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}"))
|
||||
|
||||
async def draw(self, message: AstrMessageEvent):
|
||||
|
||||
@@ -5,28 +5,34 @@ from .openai_adapter import ProviderOpenAIOfficial
|
||||
from .commands import OpenAIAdapterCommand
|
||||
from astrbot.api import logger
|
||||
from . import PLUGIN_NAME
|
||||
from astrbot.api import Image, Plain, MessageChain
|
||||
from astrbot.api import MessageChain
|
||||
from astrbot.api.message_components import Image, Plain
|
||||
from openai._exceptions import *
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function
|
||||
from astrbot.api import command_parser
|
||||
from .web_searcher import search_from_bing, fetch_website_content
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.config.astrbot_config import LLMConfig
|
||||
from .atri import ATRI
|
||||
|
||||
class Main:
|
||||
def __init__(self, context: Context) -> None:
|
||||
supported_provider_names = ["openai", "ollama", "gemini", "deepseek", "zhipu"]
|
||||
|
||||
self.context = context
|
||||
|
||||
# 各 Provider 实例
|
||||
self.provider_insts: List[ProviderOpenAIOfficial] = []
|
||||
# Provider 的配置
|
||||
self.provider_llm_configs: List[LLMConfig] = []
|
||||
# 当前使用的 Provider
|
||||
self.provider = None
|
||||
# 当前使用的 Provider 的配置
|
||||
self.provider_config = None
|
||||
|
||||
llms_config = self.context.get_config().llm
|
||||
atri_config = self.context.get_config().project_atri
|
||||
|
||||
loaded = False
|
||||
for llm in llms_config:
|
||||
for llm in self.context.get_config().llm:
|
||||
if llm.enable:
|
||||
if llm.name in supported_provider_names:
|
||||
if not llm.key or not llm.enable:
|
||||
@@ -36,20 +42,33 @@ class Main:
|
||||
self.provider_llm_configs.append(llm)
|
||||
loaded = True
|
||||
logger.info(f"已启用 LLM Provider(OpenAI API 适配器): {llm.id}({llm.name})。")
|
||||
|
||||
if loaded:
|
||||
self.command_handler = OpenAIAdapterCommand(self.context)
|
||||
self.command_handler.set_provider(self.provider_insts[0])
|
||||
self.context.register_listener(PLUGIN_NAME, "openai_adapter_chat", self.chat, "OpenAI Adapter LLM 调用监听器", after_commands=True)
|
||||
self.context.register_listener(PLUGIN_NAME, "llm_chat_listener", self.chat, "llm_chat_listener", after_commands=True)
|
||||
self.provider = self.command_handler.provider
|
||||
self.provider_config = self.provider_llm_configs[0]
|
||||
|
||||
self.context.register_commands(PLUGIN_NAME, "provider", "查看当前 LLM Provider", 10, self.provider_info)
|
||||
self.context.register_commands(PLUGIN_NAME, "websearch", "启用/关闭网页搜索", 10, self.web_search)
|
||||
|
||||
if self.context.get_config().llm_settings.web_search:
|
||||
self.add_web_search_tools()
|
||||
|
||||
|
||||
# load atri
|
||||
self.atri = None
|
||||
if atri_config.enable:
|
||||
try:
|
||||
self.atri = ATRI(self.provider_llm_configs, atri_config, self.context)
|
||||
self.command_handler.provider = self.atri.atri_chat_provider
|
||||
except ImportError as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("载入 ATRI 失败。请确保使用 pip 安装了 requirements_atri.txt 下的库。")
|
||||
self.atri = None
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("载入 ATRI 失败。")
|
||||
self.atri = None
|
||||
|
||||
def add_web_search_tools(self):
|
||||
self.context.register_llm_tool("web_search", [{
|
||||
"type": "string",
|
||||
@@ -121,7 +140,10 @@ class Main:
|
||||
async def chat(self, event: AstrMessageEvent):
|
||||
if not event.is_wake_up():
|
||||
return
|
||||
|
||||
if self.atri:
|
||||
await self.atri.chat(event)
|
||||
return
|
||||
|
||||
# prompt 前缀
|
||||
if self.provider_config.prompt_prefix:
|
||||
event.message_str = self.provider_config.prompt_prefix + event.message_str
|
||||
@@ -131,6 +153,8 @@ class Main:
|
||||
if isinstance(comp, Image):
|
||||
image_url = comp.url if comp.url else comp.file
|
||||
break
|
||||
|
||||
tool_use_flag = False
|
||||
llm_result = None
|
||||
try:
|
||||
if not self.context.llm_tools.empty():
|
||||
@@ -177,7 +201,6 @@ class Main:
|
||||
return
|
||||
else:
|
||||
# normal chat
|
||||
tool_use_flag = False
|
||||
# add user info to the prompt
|
||||
if self.context.get_config().llm_settings.identifier:
|
||||
user_id = event.message_obj.sender.user_id
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import tiktoken
|
||||
import threading
|
||||
import traceback
|
||||
import base64
|
||||
import json
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
@@ -17,90 +14,38 @@ from astrbot.api import Provider
|
||||
from astrbot.core.config.astrbot_config import LLMConfig
|
||||
from astrbot import logger
|
||||
from typing import List, Dict
|
||||
|
||||
from dataclasses import asdict
|
||||
|
||||
class ProviderOpenAIOfficial(Provider):
|
||||
def __init__(self, llm_config: LLMConfig, db_helper: BaseDatabase) -> None:
|
||||
super().__init__()
|
||||
def __init__(self, llm_config: LLMConfig, db_helper: BaseDatabase, persistant_history = True) -> None:
|
||||
super().__init__(db_helper, llm_config.default_personality, persistant_history)
|
||||
|
||||
self.api_keys = []
|
||||
self.chosen_api_key = None
|
||||
self.base_url = None
|
||||
self.llm_config = llm_config
|
||||
self.keys_data = {} # 记录超额
|
||||
if llm_config.key: self.api_keys = llm_config.key
|
||||
if llm_config.api_base: self.base_url = llm_config.api_base
|
||||
if not self.api_keys:
|
||||
logger.warn("看起来你没有添加 OpenAI 的 API 密钥,OpenAI LLM 能力将不会启用。")
|
||||
else:
|
||||
self.chosen_api_key = self.api_keys[0]
|
||||
|
||||
for key in self.api_keys:
|
||||
self.keys_data[key] = True
|
||||
self.api_keys = llm_config.key
|
||||
if llm_config.api_base:
|
||||
self.base_url = llm_config.api_base
|
||||
self.chosen_api_key = self.api_keys[0]
|
||||
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
base_url=self.base_url
|
||||
)
|
||||
self.set_model(llm_config.model_config.model)
|
||||
if llm_config.image_generation_model_config:
|
||||
self.image_generator_model_configs: Dict = asdict(llm_config.image_generation_model_config)
|
||||
self.session_memory: Dict[str, List] = {} # 会话记忆
|
||||
self.session_memory_lock = threading.Lock()
|
||||
self.max_tokens = self.llm_config.model_config.max_tokens # 上下文窗口大小
|
||||
|
||||
logger.info("正在载入分词器 cl100k_base...")
|
||||
self.tokenizer = tiktoken.get_encoding("cl100k_base") # todo: 根据 model 切换分词器
|
||||
|
||||
self.DEFAULT_PERSONALITY = {
|
||||
"prompt": self.llm_config.default_personality,
|
||||
"name": "default"
|
||||
}
|
||||
self.curr_personality = self.DEFAULT_PERSONALITY
|
||||
self.session_personality = {} # 记录了某个session是否已设置人格。
|
||||
# 读取历史记录
|
||||
self.db_helper = db_helper
|
||||
try:
|
||||
for history in db_helper.get_llm_history():
|
||||
self.session_memory_lock.acquire()
|
||||
self.session_memory[history.session_id] = json.loads(history.content)
|
||||
self.session_memory_lock.release()
|
||||
except BaseException as e:
|
||||
logger.warning(f"读取 OpenAI LLM 对话历史记录 失败:{e}。仍可正常使用。")
|
||||
|
||||
# 定时保存历史记录
|
||||
threading.Thread(target=self.dump_history, daemon=True).start()
|
||||
|
||||
def dump_history(self):
|
||||
'''转储历史记录'''
|
||||
time.sleep(30)
|
||||
while True:
|
||||
try:
|
||||
for session_id, content in self.session_memory.items():
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(content))
|
||||
except BaseException as e:
|
||||
logger.error("保存 LLM 历史记录失败: " + str(e))
|
||||
finally:
|
||||
time.sleep(10*60)
|
||||
|
||||
def personality_set(self, personality: dict, session_id: str):
|
||||
if not personality or not personality['prompt']: return
|
||||
if session_id not in self.session_memory:
|
||||
self.session_memory[session_id] = []
|
||||
self.curr_personality = personality
|
||||
self.session_personality = {} # 重置
|
||||
|
||||
new_record = {
|
||||
"user": {
|
||||
"role": "system",
|
||||
"content": personality['prompt'],
|
||||
},
|
||||
'usage_tokens': 0, # 到该条目的总 token 数
|
||||
'single-tokens': 0 # 该条目的 token 数
|
||||
}
|
||||
|
||||
self.session_memory[session_id] = [new_record]
|
||||
# 各类模型的配置
|
||||
self.image_generator_model_configs = None
|
||||
self.embedding_model_configs = None
|
||||
if llm_config.image_generation_model_config and llm_config.image_generation_model_config.enable:
|
||||
self.image_generator_model_configs: Dict = asdict(
|
||||
llm_config.image_generation_model_config)
|
||||
self.image_generator_model_configs.pop("enable")
|
||||
if llm_config.embedding_model and llm_config.embedding_model.enable:
|
||||
self.embedding_model_configs: Dict = asdict(
|
||||
llm_config.embedding_model)
|
||||
self.embedding_model_configs.pop("enable")
|
||||
|
||||
async def encode_image_bs64(self, image_url: str) -> str:
|
||||
'''
|
||||
@@ -108,29 +53,12 @@ class ProviderOpenAIOfficial(Provider):
|
||||
'''
|
||||
if image_url.startswith("http"):
|
||||
image_url = await download_image_by_url(image_url)
|
||||
|
||||
|
||||
with open(image_url, "rb") as f:
|
||||
image_bs64 = base64.b64encode(f.read()).decode('utf-8')
|
||||
return "data:image/jpeg;base64," + image_bs64
|
||||
return ''
|
||||
|
||||
async def retrieve_context(self, session_id: str):
|
||||
'''
|
||||
根据 session_id 获取保存的 OpenAI 格式的上下文
|
||||
'''
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
|
||||
# 转换为 openai 要求的格式
|
||||
context = []
|
||||
for record in self.session_memory[session_id]:
|
||||
if "user" in record and record['user']:
|
||||
context.append(record['user'])
|
||||
if "AI" in record and record['AI']:
|
||||
context.append(record['AI'])
|
||||
|
||||
return context
|
||||
|
||||
async def get_models(self):
|
||||
models = []
|
||||
try:
|
||||
@@ -140,47 +68,6 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.client.base_url = bu + "/v1"
|
||||
models = await self.client.models.list()
|
||||
return models
|
||||
|
||||
async def assemble_context(self, session_id: str, prompt: str, image_url: str = None):
|
||||
'''
|
||||
组装上下文,并且根据当前上下文窗口大小截断
|
||||
'''
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
|
||||
tokens_num = len(self.tokenizer.encode(prompt))
|
||||
previous_total_tokens_num = 0 if not self.session_memory[session_id] else self.session_memory[session_id][-1]['usage_tokens']
|
||||
|
||||
message = {
|
||||
"usage_tokens": previous_total_tokens_num + tokens_num,
|
||||
"single_tokens": tokens_num,
|
||||
"AI": None
|
||||
}
|
||||
if image_url:
|
||||
base_64_image = await self.encode_image_bs64(image_url)
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": base_64_image
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}
|
||||
|
||||
message["user"] = user_content
|
||||
self.session_memory[session_id].append(message)
|
||||
|
||||
async def pop_record(self, session_id: str, pop_system_prompt: bool = False):
|
||||
'''
|
||||
@@ -188,10 +75,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
'''
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
|
||||
|
||||
if len(self.session_memory[session_id]) == 0:
|
||||
return None
|
||||
|
||||
|
||||
for i in range(len(self.session_memory[session_id])):
|
||||
# 检查是否是 system prompt
|
||||
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
|
||||
@@ -206,156 +93,111 @@ class ProviderOpenAIOfficial(Provider):
|
||||
record = self.session_memory[session_id].pop(i)
|
||||
break
|
||||
|
||||
# 更新之后所有记录的 usage_tokens
|
||||
for i in range(len(self.session_memory[session_id])):
|
||||
self.session_memory[session_id][i]['usage_tokens'] -= record['single-tokens']
|
||||
logger.debug(f"淘汰上下文记录 1 条,释放 {record['single-tokens']} 个 token。当前上下文总 token 为 {self.session_memory[session_id][-1]['usage_tokens']}。")
|
||||
return record
|
||||
|
||||
async def text_chat(self,
|
||||
prompt: str,
|
||||
session_id: str,
|
||||
image_url=None,
|
||||
tools=None,
|
||||
async def assemble_context(self, contexts: List, text: str, image_urls: List[str] = None):
|
||||
'''
|
||||
组装上下文。
|
||||
'''
|
||||
if image_urls:
|
||||
for image_url in image_urls:
|
||||
base_64_image = await self.encode_image_bs64(image_url)
|
||||
user_content = {"role": "user","content": [
|
||||
{"type": "text", "text": text},
|
||||
{"type": "image_url", "image_url": {"url": base_64_image}}
|
||||
]}
|
||||
contexts.append(user_content)
|
||||
else:
|
||||
user_content = {"role": "user","content": text}
|
||||
contexts.append(user_content)
|
||||
|
||||
async def text_chat(self,
|
||||
prompt: str,
|
||||
session_id: str,
|
||||
image_urls=None,
|
||||
tools=None,
|
||||
contexts=None,
|
||||
**kwargs
|
||||
) -> str:
|
||||
) -> str:
|
||||
'''
|
||||
调用 LLM 进行文本对话。
|
||||
|
||||
@param tools: LLM Function-calling 的工具函数
|
||||
@param contexts: 如果不为 None,则会原封不动地使用这个上下文进行对话。
|
||||
'''
|
||||
if os.environ.get("TEST_LLM", "off") != "on" and os.environ.get("TEST_MODE", "off") == "on":
|
||||
return "这是一个测试消息。"
|
||||
if not session_id:
|
||||
session_id = "unknown"
|
||||
if "unknown" in self.session_memory:
|
||||
del self.session_memory["unknown"]
|
||||
|
||||
if session_id not in self.session_memory:
|
||||
self.session_memory[session_id] = []
|
||||
|
||||
if session_id not in self.session_personality or not self.session_personality[session_id]:
|
||||
self.personality_set(self.curr_personality, session_id)
|
||||
self.session_personality[session_id] = True
|
||||
|
||||
# 组装上下文,并且根据当前上下文窗口大小截断
|
||||
await self.assemble_context(session_id, prompt, image_url)
|
||||
|
||||
# 获取上下文,openai 格式
|
||||
contexts = await self.retrieve_context(session_id)
|
||||
|
||||
logger.debug(f"OpenAI 请求上下文:{contexts}")
|
||||
|
||||
|
||||
await self.assemble_context(self.session_memory[session_id], prompt, image_urls)
|
||||
if not contexts:
|
||||
contexts = [*self.session_memory[session_id]]
|
||||
if self.curr_personality["prompt"]:
|
||||
contexts.insert(0, {"role": "system", "content": self.curr_personality["prompt"]})
|
||||
|
||||
|
||||
logger.debug(f"请求上下文:{contexts}")
|
||||
conf = asdict(self.llm_config.model_config)
|
||||
if tools:
|
||||
conf['tools'] = tools
|
||||
|
||||
# start request
|
||||
retry = 0
|
||||
rate_limit_retry = 0
|
||||
while retry < 3 or rate_limit_retry < 5:
|
||||
if tools:
|
||||
completion_coro = self.client.chat.completions.create(
|
||||
messages=contexts,
|
||||
stream=False,
|
||||
tools=tools,
|
||||
**conf
|
||||
)
|
||||
else:
|
||||
completion_coro = self.client.chat.completions.create(
|
||||
messages=contexts,
|
||||
stream=False,
|
||||
**conf
|
||||
)
|
||||
while retry < 3:
|
||||
completion_coro = self.client.chat.completions.create(
|
||||
messages=contexts,
|
||||
stream=False,
|
||||
**conf
|
||||
)
|
||||
try:
|
||||
completion = await completion_coro
|
||||
break
|
||||
except AuthenticationError as e:
|
||||
api_key = self.chosen_api_key[10:] + "..."
|
||||
logger.error(f"OpenAI API Key {api_key} 验证错误。详细原因:{e}。正在切换到下一个可用的 Key(如果有的话)")
|
||||
self.keys_data[self.chosen_api_key] = False
|
||||
ok = await self.switch_to_next_key()
|
||||
if ok: continue
|
||||
else: raise Exception("所有 OpenAI API Key 目前都不可用。")
|
||||
except RateLimitError as e:
|
||||
if "You exceeded your current quota" in str(e):
|
||||
self.keys_data[self.chosen_api_key] = False
|
||||
ok = await self.switch_to_next_key()
|
||||
if ok: continue
|
||||
else: raise Exception("所有 OpenAI API Key 目前都不可用。")
|
||||
logger.error(f"OpenAI API Key {self.chosen_api_key} 达到请求速率限制或者官方服务器当前超载。详细原因:{e}")
|
||||
await self.switch_to_next_key()
|
||||
rate_limit_retry += 1
|
||||
await asyncio.sleep(1)
|
||||
except BadRequestError as e:
|
||||
raise e
|
||||
except NotFoundError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
retry += 1
|
||||
if retry >= 3:
|
||||
logger.error(traceback.format_exc())
|
||||
raise Exception(f"OpenAI 请求失败:{e}。重试次数已达到上限。")
|
||||
raise Exception(f"请求失败:{e}。重试次数已达到上限。")
|
||||
if "maximum context length" in str(e):
|
||||
logger.warn(f"OpenAI 请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
self.pop_record(session_id)
|
||||
|
||||
|
||||
logger.warning(traceback.format_exc())
|
||||
logger.warning(f"OpenAI 请求失败:{e}。重试第 {retry} 次。")
|
||||
logger.warning(f"请求失败:{e}。重试第 {retry} 次。")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
assert isinstance(completion, ChatCompletion)
|
||||
logger.debug(f"openai completion: {completion.usage}")
|
||||
|
||||
logger.debug(f"completion: {completion.usage}")
|
||||
|
||||
if len(completion.choices) == 0:
|
||||
raise Exception("OpenAI API 返回的 completion 为空。")
|
||||
raise Exception("API 返回的 completion 为空。")
|
||||
choice = completion.choices[0]
|
||||
|
||||
usage_tokens = completion.usage.total_tokens
|
||||
completion_tokens = completion.usage.completion_tokens
|
||||
self.session_memory[session_id][-1]['usage_tokens'] = usage_tokens
|
||||
self.session_memory[session_id][-1]['single_tokens'] += completion_tokens
|
||||
|
||||
|
||||
if choice.message.content:
|
||||
# 返回文本
|
||||
completion_text = str(choice.message.content).strip()
|
||||
self.session_memory[session_id].append({
|
||||
"role": "assistant",
|
||||
"content": completion_text
|
||||
})
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]))
|
||||
return completion_text
|
||||
elif choice.message.tool_calls and choice.message.tool_calls:
|
||||
# tools call (function calling)
|
||||
return choice.message.tool_calls[0].function
|
||||
|
||||
self.session_memory[session_id][-1]['AI'] = {
|
||||
"role": "assistant",
|
||||
"content": completion_text
|
||||
}
|
||||
else:
|
||||
raise Exception("Internal Error")
|
||||
|
||||
return completion_text
|
||||
|
||||
async def switch_to_next_key(self):
|
||||
'''
|
||||
切换到下一个 API Key
|
||||
'''
|
||||
if not self.api_keys:
|
||||
logger.error("OpenAI API Key 不存在。")
|
||||
return False
|
||||
|
||||
for key in self.keys_data:
|
||||
if self.keys_data[key]:
|
||||
# 没超额
|
||||
self.chosen_api_key = key
|
||||
self.client.api_key = key
|
||||
logger.info(f"OpenAI 切换到 API Key {key[:10]}... 成功。")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def image_generate(self, prompt: str, session_id: str = None, **kwargs) -> str:
|
||||
'''
|
||||
生成图片
|
||||
'''
|
||||
retry = 0
|
||||
conf = self.image_generator_model_configs
|
||||
if not conf:
|
||||
logger.error("图片生成模型配置不存在。")
|
||||
raise Exception("图片生成模型配置不存在。")
|
||||
conf.pop("enable")
|
||||
if not self.image_generator_model_configs:
|
||||
return
|
||||
while retry < 3:
|
||||
try:
|
||||
images_response = await self.client.images.generate(
|
||||
prompt=prompt,
|
||||
**conf
|
||||
**self.image_generator_model_configs
|
||||
)
|
||||
image_url = images_response.data[0].url
|
||||
return image_url
|
||||
@@ -367,15 +209,25 @@ class ProviderOpenAIOfficial(Provider):
|
||||
logger.warning(f"图片生成请求失败:{e}。重试第 {retry} 次。")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def forget(self, session_id=None, keep_system_prompt: bool=False) -> bool:
|
||||
if session_id is None: return False
|
||||
async def get_embedding(self, text) -> List[float]:
|
||||
'''
|
||||
获取文本的嵌入
|
||||
'''
|
||||
if not self.embedding_model_configs:
|
||||
return
|
||||
try:
|
||||
embedding = await self.client.embeddings.create(
|
||||
input=text,
|
||||
**self.embedding_model_configs
|
||||
)
|
||||
return embedding.data[0].embedding
|
||||
except Exception as e:
|
||||
logger.error(f"获取文本嵌入失败:{e}")
|
||||
|
||||
async def forget(self, session_id: str) -> bool:
|
||||
self.session_memory[session_id] = []
|
||||
if keep_system_prompt:
|
||||
self.personality_set(self.curr_personality, session_id)
|
||||
else:
|
||||
self.curr_personality = self.DEFAULT_PERSONALITY
|
||||
return True
|
||||
|
||||
|
||||
def dump_contexts_page(self, session_id: str, size=5, page=1,):
|
||||
'''
|
||||
获取缓存的会话
|
||||
@@ -383,25 +235,21 @@ class ProviderOpenAIOfficial(Provider):
|
||||
contexts_str = ""
|
||||
if session_id in self.session_memory:
|
||||
for record in self.session_memory[session_id]:
|
||||
if "user" in record and record['user']:
|
||||
text = record['user']['content'][:100] + "..." if len(record['user']['content']) > 100 else record['user']['content']
|
||||
if record['role'] == "user":
|
||||
text = record['content'][:100] + "..." if len(
|
||||
record['content']) > 100 else record['content']
|
||||
contexts_str += f"User: {text}\n\n"
|
||||
if "AI" in record and record['AI']:
|
||||
text = record['AI']['content'][:100] + "..." if len(record['AI']['content']) > 100 else record['AI']['content']
|
||||
elif record['role'] == "assistant":
|
||||
text = record['content'][:100] + "..." if len(
|
||||
record['content']) > 100 else record['content']
|
||||
contexts_str += f"Assistant: {text}\n\n"
|
||||
else:
|
||||
contexts_str = "会话 ID 不存在。"
|
||||
|
||||
return contexts_str, len(self.session_memory[session_id])
|
||||
|
||||
def get_configs(self):
|
||||
return asdict(self.llm_config)
|
||||
|
||||
def get_keys_data(self):
|
||||
return self.keys_data
|
||||
|
||||
def get_curr_key(self):
|
||||
return self.chosen_api_key
|
||||
|
||||
def set_key(self, key):
|
||||
self.client.api_key = key
|
||||
def get_all_keys(self):
|
||||
return self.api_keys
|
||||
@@ -1,14 +1,12 @@
|
||||
pydantic~=1.10.4
|
||||
pydantic
|
||||
vchat
|
||||
aiohttp
|
||||
openai
|
||||
qq-botpy
|
||||
chardet~=5.1.0
|
||||
Pillow
|
||||
nakuru-project
|
||||
beautifulsoup4
|
||||
googlesearch-python
|
||||
tiktoken
|
||||
readability-lxml
|
||||
quart
|
||||
psutil
|
||||
|
||||
Reference in New Issue
Block a user