remove: 移除了 nakuru-project 库

但仍然使用其对 OneBot 的数据格式封装。
This commit is contained in:
Soulter
2024-12-02 19:31:33 +08:00
parent ba12d65792
commit 750a93a1aa
49 changed files with 904 additions and 420 deletions

View File

@@ -1,10 +1,9 @@
from astrbot.core.plugin import Context
from astrbot.core.platform import AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.core.message_event_result import MessageEventResult, MessageChain, CommandResult
from astrbot.core.provider import Provider
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain, CommandResult
from astrbot.core.provider import Provider, Personality
from astrbot.core.config.astrbot_config import AstrBotConfig
from nakuru.entities.components import *
from astrbot import logger
from astrbot.core.utils.personality import personalities

View File

@@ -0,0 +1 @@
from astrbot.core.message.components import *

View File

@@ -1,5 +1,5 @@
from .log import LogManager, LogBroker
from core.utils.t2i.renderer import HtmlRenderer
from astrbot.core.utils.t2i.renderer import HtmlRenderer
html_renderer = HtmlRenderer()
logger = LogManager.GetLogger(log_name='astrbot')

View File

@@ -67,6 +67,11 @@ class ImageGenerationModelConfig:
style: str = "vivid"
quality: str = "standard"
@dataclass
class EmbeddingModel:
enable: bool = False
model: str = ""
@dataclass
class LLMConfig:
id: str = ""
@@ -77,12 +82,17 @@ class LLMConfig:
prompt_prefix: str = ""
default_personality: str = ""
model_config: ModelConfig = field(default_factory=ModelConfig)
image_generation_model_config: Optional[ImageGenerationModelConfig] = None
image_generation_model_config: Optional[ImageGenerationModelConfig] = field(default_factory=ImageGenerationModelConfig)
embedding_model: Optional[EmbeddingModel] = field(default_factory=EmbeddingModel)
def __post_init__(self):
self.model_config = ModelConfig(**self.model_config)
if self.image_generation_model_config:
self.image_generation_model_config = ImageGenerationModelConfig(**self.image_generation_model_config)
if isinstance(self.model_config, dict):
self.model_config = ModelConfig(**self.model_config)
if isinstance(self.image_generation_model_config, dict):
self.image_generation_model_config = ImageGenerationModelConfig(**self.image_generation_model_config) if self.image_generation_model_config else None
if isinstance(self.embedding_model, dict):
self.embedding_model = EmbeddingModel(**self.embedding_model) if self.embedding_model else None
@dataclass
class LLMSettings:
wake_prefix: str = ""
@@ -115,6 +125,35 @@ class DashboardConfig:
enable: bool = True
username: str = ""
password: str = ""
@dataclass
class ATRILongTermMemory:
enable: bool = False
summary_threshold_cnt: int = 5
@dataclass
class ATRIActiveMessage:
enable: bool = False
@dataclass
class ProjectATRI:
enable: bool = False
long_term_memory: ATRILongTermMemory = field(default_factory=ATRILongTermMemory)
active_message: ATRIActiveMessage = field(default_factory=ATRIActiveMessage)
persona: str = ""
embedding_provider_id: str = ""
summarize_provider_id: str = ""
chat_provider_id: str = ""
chat_base_model_path: str = ""
chat_adapter_model_path: str = ""
quantization_bit: int = 4
def __post_init__(self):
if isinstance(self.long_term_memory, dict):
self.long_term_memory = ATRILongTermMemory(**self.long_term_memory)
if isinstance(self.active_message, dict):
self.active_message = ATRIActiveMessage(**self.active_message)
@dataclass
class AstrBotConfig():
@@ -134,6 +173,7 @@ class AstrBotConfig():
t2i_endpoint: str = ""
pip_install_arg: str = ""
plugin_repo_mirror: str = ""
project_atri: ProjectATRI = field(default_factory=ProjectATRI)
def __init__(self) -> None:
self.init_configs()
@@ -190,6 +230,7 @@ class AstrBotConfig():
self.t2i_endpoint=data.get("t2i_endpoint", "")
self.pip_install_arg=data.get("pip_install_arg", "")
self.plugin_repo_mirror=data.get("plugin_repo_mirror", "")
self.project_atri=ProjectATRI(**data.get("project_atri", {}))
def flush_config(self, config: dict = None):
'''将配置写入文件, 如果没有传入配置,则写入默认配置'''

View File

@@ -27,6 +27,10 @@ PROVIDER_CONFIG_TEMPLATE = {
"size": "1024x1024",
"style": "vivid",
"quality": "standard",
},
"embedding_model": {
"enable": False,
"model": "text-embedding-3-small"
}
},
"ollama": {
@@ -147,6 +151,23 @@ DEFAULT_CONFIG_VERSION_2 = {
"t2i_endpoint": "",
"pip_install_arg": "",
"plugin_repo_mirror": "default",
"project_atri": {
"enable": False,
"long_term_memory": {
"enable": False,
"summary_threshold_cnt": 6,
},
"active_message": {
"enable": False,
},
"persona": "",
"embedding_provider_id": "",
"summarize_provider_id": "",
"chat_provider_id": "",
"chat_base_model_path": "",
"chat_adapter_model_path": "",
"quantization_bit": 4
}
}
# 配置项的中文描述、值类型
@@ -167,7 +188,7 @@ CONFIG_METADATA_2 = {
"ws_reverse_port": {"description": "反向 Websocket 端口", "type": "int", "hint": "aiocqhttp 适配器的反向 Websocket 端口。"},
"qq_id_whitelist": {"description": "QQ 号白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的 QQ 号发来的消息事件。为空时表示不启用白名单过滤。"},
"qq_group_id_whitelist": {"description": "QQ 群号白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的 QQ 群发来的消息事件。为空时表示不启用白名单过滤。"},
"wechat_id_whitelist": {"description": "微信私聊/群聊白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的微信私聊/群聊发来的消息事件。为空时表示不启用白名单过滤。使用 /wechatid 指令获取微信 ID不是微信号"},
"wechat_id_whitelist": {"description": "微信私聊/群聊白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的微信私聊/群聊发来的消息事件。为空时表示不启用白名单过滤。使用 /wechatid 指令获取微信 ID不是微信号注意:每次扫码登录之后,相同联系人的 ID 会发生变化,白名单内的 ID 会失效。"},
}
},
"platform_settings": {
@@ -200,17 +221,17 @@ CONFIG_METADATA_2 = {
"prompt_prefix": {"description": "Prompt 前缀", "type": "text", "hint": "每次与 LLM 对话时在对话前加上的自定义文本。默认为空。"},
"default_personality": {"description": "默认人格", "type": "text", "hint": "在当前版本下,默认人格文本会被添加到 LLM 对话的 `system` 字段中。"},
"model_config": {
"description": "模型配置",
"description": "文本生成模型",
"type": "object",
"items": {
"model": {"description": "模型名称", "type": "string", "hint": "大语言模型的名称,一般是小写的英文。如 gpt-4o-mini, deepseek-chat 等。"},
"max_tokens": {"description": "最大令牌数", "type": "int"},
"max_tokens": {"description": "模型最大输出长度tokens", "type": "int"},
"temperature": {"description": "温度", "type": "float"},
"top_p": {"description": "Top P值", "type": "float"},
}
},
"image_generation_model_config": {
"description": "图像生成模型配置",
"description": "图像生成模型",
"type": "object",
"items": {
"enable": {"description": "启用", "type": "bool", "hint": "启用该功能需要提供商支持图像生成。如 dall-e-3"},
@@ -220,6 +241,14 @@ CONFIG_METADATA_2 = {
"quality": {"description": "图像质量", "type": "string"},
}
},
"embedding_model": {
"description": "文本嵌入模型",
"type": "object",
"items": {
"enable": {"description": "启用", "type": "bool", "hint": "启用该功能需要提供商支持文本嵌入。"},
"model": {"description": "模型名称", "type": "string", "hint": "文本嵌入模型的名称,一般是小写的英文。如 text-embedding-3-small"},
}
}
}
},
"llm_settings": {
@@ -273,6 +302,35 @@ CONFIG_METADATA_2 = {
"t2i_endpoint": {"description": "文本转图像服务接口", "type": "string", "hint": "为空时使用 AstrBot API 服务"},
"pip_install_arg": {"description": "pip 安装参数", "type": "string", "hint": "安装插件依赖时,会使用 Python 的 pip 工具。这里可以填写额外的参数,如 `--break-system-package` 等。"},
"plugin_repo_mirror": {"description": "插件仓库镜像", "type": "string", "hint": "插件仓库的镜像地址,用于加速插件的下载。", "options": ["default", "https://ghp.ci/", "https://github-mirror.us.kg/"]},
"project_atri": {
"description": "Project ATRI 配置",
"type": "object",
"items": {
"enable": {"description": "启用", "type": "bool"},
"long_term_memory": {
"description": "长期记忆",
"type": "object",
"items": {
"enable": {"description": "启用", "type": "bool"},
"summary_threshold_cnt": {"description": "摘要阈值", "type": "int", "hint": "当一个会话的对话记录数量超过该阈值时,会自动进行摘要。"},
}
},
"active_message": {
"description": "主动消息",
"type": "object",
"items": {
"enable": {"description": "启用", "type": "bool"},
}
},
"persona": {"description": "人格", "type": "string", "hint": "默认人格。当启动 ATRI 之后,在 Provider 处设置的人格将会失效。", "obvious_hint": True},
"embedding_provider_id": {"description": "Embedding provider ID", "type": "string", "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Embedding请确保所填的 provider id 在 `配置页` 中存在并且设置了 Embedding 配置", "obvious_hint": True},
"summarize_provider_id": {"description": "Summary provider ID", "type": "string", "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Summary请确保所填的 provider id 在 `配置页` 中存在。", "obvious_hint": True},
"chat_provider_id": {"description": "Chat provider ID", "type": "string", "hint": "将会使用指定的 provider 来进行文本聊天,请确保所填的 provider id 在 `配置页` 中存在。", "obvious_hint": True},
"chat_base_model_path": {"description": "用于聊天的基座模型路径", "type": "string", "hint": "用于聊天的基座模型路径。当填写此项和 Lora 路径后,将会忽略上面设置的 Chat provider ID。", "obvious_hint": True},
"chat_adapter_model_path": {"description": "用于聊天的 Lora 模型路径", "type": "string", "hint": "Lora 模型路径。", "obvious_hint": True},
"quantization_bit": {"description": "量化位数", "type": "int", "hint": "模型量化位数。如果你不知道这是什么,请不要修改。默认为 4。", "obvious_hint": True},
}
}
}
DEFAULT_VALUE_MAP = {

View File

@@ -2,14 +2,14 @@ import asyncio, time, threading
from .event_bus import EventBus
from asyncio import Queue
from typing import List
from core.config.astrbot_config import AstrBotConfig
from core.message_event_handler import MessageEventHandler
from core.plugin import PluginManager
from core import LogBroker
from core.db import BaseDatabase
from core.updator import AstrBotUpdator
from core import logger
from core.config.default import VERSION
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.message.message_event_handler import MessageEventHandler
from astrbot.core.plugin import PluginManager
from astrbot.core import LogBroker
from astrbot.core.db import BaseDatabase
from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger
from astrbot.core.config.default import VERSION
class AstrBotCoreLifecycle:
def __init__(self, log_broker: LogBroker, db: BaseDatabase):

View File

@@ -1,7 +1,7 @@
import abc
from dataclasses import dataclass
from typing import List
from core.db.po import Stats, LLMHistory
from astrbot.core.db.po import Stats, LLMHistory
@dataclass
class BaseDatabase(abc.ABC):

View File

@@ -1,7 +1,7 @@
import sqlite3
import os
import time
from core.db.po import (
from astrbot.core.db.po import (
Platform,
Command,
Provider,

View File

@@ -2,10 +2,10 @@ import asyncio
from asyncio import Queue
from collections import defaultdict
from typing import List
from .message_event_handler import MessageEventHandler
from core import logger
from astrbot.core.message.message_event_handler import MessageEventHandler
from astrbot.core import logger
from .platform import AstrMessageEvent
from nakuru.entities.components import Plain, Image
from astrbot.core.message.components import Image, Plain
class EventBus:
def __init__(self, event_queue: Queue, message_event_handler: MessageEventHandler):

View File

@@ -0,0 +1,443 @@
'''
MIT License
Copyright (c) 2021 Lxns-Network
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
'''
import base64
import json
import os
import typing as T
from enum import Enum
from pydantic.v1 import BaseModel
class ComponentType(Enum):
Plain = "Plain"
Face = "Face"
Record = "Record"
Video = "Video"
At = "At"
RPS = "RPS" # TODO
Dice = "Dice" # TODO
Shake = "Shake" # TODO
Anonymous = "Anonymous" # TODO
Share = "Share"
Contact = "Contact" # TODO
Location = "Location" # TODO
Music = "Music"
Image = "Image"
Reply = "Reply"
RedBag = "RedBag"
Poke = "Poke"
Forward = "Forward"
Node = "Node"
Xml = "Xml"
Json = "Json"
CardImage = "CardImage"
TTS = "TTS"
Unknown = "Unknown"
class BaseMessageComponent(BaseModel):
type: ComponentType
def toString(self):
output = f"[CQ:{self.type.lower()}"
for k, v in self.__dict__.items():
if k == "type" or v is None:
continue
if k == "_type":
k = "type"
if isinstance(v, bool):
v = 1 if v else 0
output += ",%s=%s" % (k, str(v).replace("&", "&") \
.replace(",", ",") \
.replace("[", "[") \
.replace("]", "]"))
output += "]"
return output
def toDict(self):
data = dict()
for k, v in self.__dict__.items():
if k == "type" or v is None:
continue
if k == "_type":
k = "type"
data[k] = v
return {
"type": self.type.lower(),
"data": data
}
class Plain(BaseMessageComponent):
type: ComponentType = "Plain"
text: str
convert: T.Optional[bool] = True # 若为 False 则直接发送未转换 CQ 码的消息
def __init__(self, text: str, convert: bool = True, **_):
super().__init__(text=text, convert=convert, **_)
def toString(self): # 没有 [CQ:plain] 这种东西,所以直接导出纯文本
if not self.convert:
return self.text
return self.text.replace("&", "&") \
.replace("[", "[") \
.replace("]", "]")
class Face(BaseMessageComponent):
type: ComponentType = "Face"
id: int
def __init__(self, **_):
super().__init__(**_)
class Record(BaseMessageComponent):
type: ComponentType = "Record"
file: T.Optional[str] = ""
magic: T.Optional[bool] = False
url: T.Optional[str] = ""
cache: T.Optional[bool] = True
proxy: T.Optional[bool] = True
timeout: T.Optional[int] = 0
# 额外
path: T.Optional[str]
def __init__(self, file: T.Optional[str], **_):
for k in _.keys():
if k == "url":
pass
# Protocol.warn(f"go-cqhttp doesn't support send {self.type} by {k}")
super().__init__(file=file, **_)
@staticmethod
def fromFileSystem(path, **_):
return Record(file=f"file:///{os.path.abspath(path)}", path=path, **_)
@staticmethod
def fromURL(url: str, **_):
if url.startswith("http://") or url.startswith("https://"):
return Record(file=url, **_)
raise Exception("not a valid url")
class Video(BaseMessageComponent):
type: ComponentType = "Video"
file: str
cover: T.Optional[str] = ""
c: T.Optional[int] = 2
# 额外
path: T.Optional[str] = ""
def __init__(self, file: str, **_):
# for k in _.keys():
# if k == "c" and _[k] not in [2, 3]:
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
super().__init__(file=file, **_)
@staticmethod
def fromFileSystem(path, **_):
return Video(file=f"file:///{os.path.abspath(path)}", path=path, **_)
@staticmethod
def fromURL(url: str, **_):
if url.startswith("http://") or url.startswith("https://"):
return Video(file=url, **_)
raise Exception("not a valid url")
class At(BaseMessageComponent):
type: ComponentType = "At"
qq: T.Union[int, str] # 此处str为all时代表所有人
name: T.Optional[str] = ""
def __init__(self, **_):
super().__init__(**_)
class AtAll(At):
qq: str = "all"
def __init__(self, **_):
super().__init__(**_)
class RPS(BaseMessageComponent): # TODO
type: ComponentType = "RPS"
def __init__(self, **_):
super().__init__(**_)
class Dice(BaseMessageComponent): # TODO
type: ComponentType = "Dice"
def __init__(self, **_):
super().__init__(**_)
class Shake(BaseMessageComponent): # TODO
type: ComponentType = "Shake"
def __init__(self, **_):
super().__init__(**_)
class Anonymous(BaseMessageComponent): # TODO
type: ComponentType = "Anonymous"
ignore: T.Optional[bool] = False
def __init__(self, **_):
super().__init__(**_)
class Share(BaseMessageComponent):
type: ComponentType = "Share"
url: str
title: str
content: T.Optional[str] = ""
image: T.Optional[str] = ""
def __init__(self, **_):
super().__init__(**_)
class Contact(BaseMessageComponent): # TODO
type: ComponentType = "Contact"
_type: str # type 字段冲突
id: T.Optional[int] = 0
def __init__(self, **_):
super().__init__(**_)
class Location(BaseMessageComponent): # TODO
type: ComponentType = "Location"
lat: float
lon: float
title: T.Optional[str] = ""
content: T.Optional[str] = ""
def __init__(self, **_):
super().__init__(**_)
class Music(BaseMessageComponent):
type: ComponentType = "Music"
_type: str
id: T.Optional[int] = 0
url: T.Optional[str] = ""
audio: T.Optional[str] = ""
title: T.Optional[str] = ""
content: T.Optional[str] = ""
image: T.Optional[str] = ""
def __init__(self, **_):
# for k in _.keys():
# if k == "_type" and _[k] not in ["qq", "163", "xm", "custom"]:
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
super().__init__(**_)
class Image(BaseMessageComponent):
type: ComponentType = "Image"
file: T.Optional[str] = ""
_type: T.Optional[str] = ""
subType: T.Optional[int] = 0
url: T.Optional[str] = ""
cache: T.Optional[bool] = True
id: T.Optional[int] = 40000
c: T.Optional[int] = 2
# 额外
path: T.Optional[str] = ""
def __init__(self, file: T.Optional[str], **_):
# for k in _.keys():
# if (k == "_type" and _[k] not in ["flash", "show", None]) or \
# (k == "c" and _[k] not in [2, 3]):
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
super().__init__(file=file, **_)
@staticmethod
def fromURL(url: str, **_):
if url.startswith("http://") or url.startswith("https://"):
return Image(file=url, **_)
raise Exception("not a valid url")
@staticmethod
def fromFileSystem(path, **_):
return Image(file=f"file:///{os.path.abspath(path)}", path=path, **_)
@staticmethod
def fromBase64(base64: str, **_):
return Image(f"base64://{base64}", **_)
@staticmethod
def fromBytes(byte: bytes):
return Image.fromBase64(base64.b64encode(byte).decode())
@staticmethod
def fromIO(IO):
return Image.fromBytes(IO.read())
class Reply(BaseMessageComponent):
type: ComponentType = "Reply"
id: int
text: T.Optional[str] = ""
qq: T.Optional[int] = 0
time: T.Optional[int] = 0
seq: T.Optional[int] = 0
def __init__(self, **_):
super().__init__(**_)
class RedBag(BaseMessageComponent):
type: ComponentType = "RedBag"
title: str
def __init__(self, **_):
super().__init__(**_)
class Poke(BaseMessageComponent):
type: ComponentType = "Poke"
qq: int
def __init__(self, **_):
super().__init__(**_)
class Forward(BaseMessageComponent):
type: ComponentType = "Forward"
id: str
def __init__(self, **_):
super().__init__(**_)
class Node(BaseMessageComponent): # 该 component 仅支持使用 sendGroupForwardMessage 发送
type: ComponentType = "Node"
id: T.Optional[int] = 0
name: T.Optional[str] = ""
uin: T.Optional[int] = 0
content: T.Optional[T.Union[str, list]] = ""
seq: T.Optional[T.Union[str, list]] = "" # 不清楚是什么
time: T.Optional[int] = 0
def __init__(self, content: T.Union[str, list], **_):
if isinstance(content, list):
_content = ""
for chain in content:
_content += chain.toString()
content = _content
super().__init__(content=content, **_)
def toString(self):
# logger.warn("Protocol: node doesn't support stringify")
return ""
class Xml(BaseMessageComponent):
type: ComponentType = "Xml"
data: str
resid: T.Optional[int] = 0
def __init__(self, **_):
super().__init__(**_)
class Json(BaseMessageComponent):
type: ComponentType = "Json"
data: T.Union[str, dict]
resid: T.Optional[int] = 0
def __init__(self, data, **_):
if isinstance(data, dict):
data = json.dumps(data)
super().__init__(data=data, **_)
class CardImage(BaseMessageComponent):
type: ComponentType = "CardImage"
file: str
cache: T.Optional[bool] = True
minwidth: T.Optional[int] = 400
minheight: T.Optional[int] = 400
maxwidth: T.Optional[int] = 500
maxheight: T.Optional[int] = 500
source: T.Optional[str] = ""
icon: T.Optional[str] = ""
def __init__(self, **_):
super().__init__(**_)
@staticmethod
def fromFileSystem(path, **_):
return CardImage(file=f"file:///{os.path.abspath(path)}", **_)
class TTS(BaseMessageComponent):
type: ComponentType = "TTS"
text: str
def __init__(self, **_):
super().__init__(**_)
class Unknown(BaseMessageComponent):
type: ComponentType = "Unknown"
text: str
def toString(self):
return ""
ComponentTypes = {
"plain": Plain,
"face": Face,
"record": Record,
"video": Video,
"at": At,
"rps": RPS,
"dice": Dice,
"shake": Shake,
"anonymous": Anonymous,
"share": Share,
"contact": Contact,
"location": Location,
"music": Music,
"image": Image,
"reply": Reply,
"redbag": RedBag,
"poke": Poke,
"forward": Forward,
"node": Node,
"xml": Xml,
"json": Json,
"cardimage": CardImage,
"tts": TTS,
"unknown": Unknown
}

View File

@@ -2,13 +2,13 @@ import asyncio, re, time
import inspect
import traceback
from typing import List, Union
from .platform import AstrMessageEvent
from .config.astrbot_config import AstrBotConfig
from astrbot.core.platform import AstrMessageEvent
from astrbot.core.config.astrbot_config import AstrBotConfig
from .message_event_result import MessageEventResult, CommandResult, MessageChain
from .plugin import PluginManager, Context, CommandMetadata
from nakuru.entities.components import *
from core import logger
from core import html_renderer
from astrbot.core.plugin import PluginManager, Context, CommandMetadata
from .components import *
from astrbot.core import logger
from astrbot.core import html_renderer
class CommandTokens():
def __init__(self) -> None:

View File

@@ -1,11 +1,12 @@
from typing import List, Union, Optional
from dataclasses import dataclass, field
from nakuru.entities.components import *
from astrbot.core.message.components import *
@dataclass
class MessageChain():
chain: List[BaseMessageComponent] = field(default_factory=list)
use_t2i_: Optional[bool] = None # None 为跟随用户设置
is_split_: Optional[bool] = False # 是否将消息分条发送。默认为 False。启用后将会依次发送 chain 中的每个 component。
def message(self, message: str):
'''
@@ -49,6 +50,15 @@ class MessageChain():
'''
self.use_t2i_ = use_t2i
return self
def is_split(self, is_split: bool):
'''
设置是否分条发送消息默认为 False启用后将会依次发送 chain 中的每个 component
具体的效果以各适配器实现为准
'''
self.is_split_ = is_split
return self
@dataclass
class MessageEventResult(MessageChain):

View File

@@ -2,11 +2,11 @@ import abc
from dataclasses import dataclass
from .astrbot_message import AstrBotMessage
from .platform_metadata import PlatformMetadata
from core.message_event_result import MessageEventResult, MessageChain
from core.platform.message_type import MessageType
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
from astrbot.core.platform.message_type import MessageType
from typing import List
from nakuru.entities.components import BaseMessageComponent, Plain, Image
from core.utils.metrics import Metric
from astrbot.core.message.components import BaseMessageComponent, Plain, Image
from astrbot.core.utils.metrics import Metric
@dataclass
class MessageSesion:

View File

@@ -1,7 +1,7 @@
import time
from typing import List
from dataclasses import dataclass
from nakuru.entities.components import BaseMessageComponent
from astrbot.core.message.components import BaseMessageComponent
from .message_type import MessageType
@dataclass

View File

@@ -3,9 +3,9 @@ from typing import Awaitable, Any
from asyncio import Queue
from .platform_metadata import PlatformMetadata
from .astr_message_event import AstrMessageEvent
from core.message_event_result import MessageChain
from astrbot.core.message.message_event_result import MessageChain
from .astr_message_event import MessageSesion
from core.utils.metrics import Metric
from astrbot.core.utils.metrics import Metric
class Platform(abc.ABC):
def __init__(self, event_queue: Queue):

View File

@@ -1,4 +1,4 @@
from .plugin import Plugin, RegisteredPlugin, PluginMetadata
from .plugin_manager import PluginManager
from .context import CommandMetadata, Context
from core.provider import Provider
from astrbot.core.provider import Provider

View File

@@ -4,12 +4,12 @@ from . import RegisteredPlugin, PluginMetadata
from typing import List, Dict, Awaitable, Union
from dataclasses import dataclass
from core.platform import Platform
from core.db import BaseDatabase
from core.config.astrbot_config import AstrBotConfig
from core.utils.func_call import FuncCall
from core.platform.astr_message_event import MessageSesion
from core.message_event_result import MessageChain
from astrbot.core.platform import Platform
from astrbot.core.db import BaseDatabase
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.utils.func_call import FuncCall
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.message.message_event_result import MessageChain
@dataclass
class CommandMetadata():
@@ -67,6 +67,9 @@ class Context:
# 维护了 LLM Tools 信息
llm_tools: FuncCall = FuncCall()
# 维护插件存储的数据
plugin_data: Dict[str, Dict[str, any]] = {}
def __init__(self, event_queue: Queue, config: AstrBotConfig, db: BaseDatabase):
self._event_queue = event_queue
self._config = config
@@ -205,4 +208,10 @@ class Context:
if platform.meta().name == session.platform_name:
await platform.send_by_session(session, message_chain)
return True
return False
return False
def set_data(self, plugin_name: str, key: str, value: any):
'''
设置插件数据。
'''
self.plugin_data[plugin_name][key] = value

View File

@@ -10,13 +10,13 @@ from asyncio import Queue
from types import ModuleType
from typing import List, Awaitable
from pip import main as pip_main
from core.config.astrbot_config import AstrBotConfig
from core import logger
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core import logger
from .context import Context
from . import RegisteredPlugin, PluginMetadata
from .updator import PluginUpdator
from core.db import BaseDatabase
from core.utils.io import remove_dir
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.io import remove_dir
class PluginManager:
def __init__(self, config: AstrBotConfig, event_queue: Queue, db: BaseDatabase):

View File

@@ -1,10 +1,10 @@
import os, zipfile, shutil
from ..updator import RepoZipUpdator
from core.utils.io import remove_dir, on_error
from astrbot.core.utils.io import remove_dir, on_error
from ..plugin import RegisteredPlugin
from typing import Union
from core import logger
from astrbot.core import logger
class PluginUpdator(RepoZipUpdator):
def __init__(self, repo_mirror: str = "") -> None:

View File

@@ -1 +1 @@
from .provider import Provider
from .provider import Provider, Personality

View File

@@ -1,11 +1,30 @@
import abc
import abc, json, threading, time
from collections import defaultdict
from typing import List
# from core.utils.func_call import FuncCall
from astrbot.core.db import BaseDatabase
from astrbot.core import logger
from typing import TypedDict
class Personality(TypedDict):
prompt: str
name: str
class Provider(abc.ABC):
def __init__(self) -> None:
def __init__(self, db_helper: BaseDatabase, default_personality: str = None, persistant_history: bool = True) -> None:
self.model_name = "unknown"
# 维护了 session_id 的上下文,不包含 system 指令
self.session_memory = defaultdict(list)
self.curr_personality = Personality(prompt=default_personality, name="")
if persistant_history:
# 读取历史记录
self.db_helper = db_helper
try:
for history in db_helper.get_llm_history():
self.session_memory[history.session_id] = json.loads(history.content)
except BaseException as e:
logger.warning(f"读取 LLM 对话历史记录 失败:{e}。仍可正常使用。")
def set_model(self, model_name: str):
self.model_name = model_name
@@ -13,12 +32,32 @@ class Provider(abc.ABC):
def get_model(self):
return self.model_name
async def get_human_readable_context(self, session_id: str) -> List[str]:
'''
获取人类可读的上下文
example:
["User: 你好", "Assistant: 你好"]
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
contexts = []
for record in self.session_memory[session_id]:
if record['role'] == "user":
contexts.append(f"User: {record['content']}")
elif record['role'] == "assistant":
contexts.append(f"Assistant: {record['content']}")
return contexts
@abc.abstractmethod
async def text_chat(self,
prompt: str,
session_id: str,
image_urls: List[str] = None,
tool = None,
tools = None,
contexts=None,
**kwargs) -> str:
'''
prompt: 提示词
@@ -38,6 +77,13 @@ class Provider(abc.ABC):
'''
raise NotImplementedError()
@abc.abstractmethod
async def get_embedding(self, text: str) -> List[float]:
'''
获取文本的嵌入
'''
raise NotImplementedError()
@abc.abstractmethod
async def forget(self, session_id: str) -> bool:
'''

View File

@@ -1,8 +1,8 @@
import os, psutil, sys, time
from .zip_updator import ReleaseInfo, RepoZipUpdator
from core import logger
from core.config.default import VERSION
from core.utils.io import download_file
from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.utils.io import download_file
class AstrBotUpdator(RepoZipUpdator):
def __init__(self, repo_mirror: str = "") -> None:

View File

@@ -1,4 +1,4 @@
from core.provider import Provider
from astrbot.core.provider import Provider
from typing import Awaitable
import json
import textwrap

View File

@@ -1,7 +1,7 @@
import aiohttp
import sys
import logging
from core.config import VERSION
from astrbot.core.config import VERSION
logger = logging.getLogger("astrbot")

View File

@@ -4,7 +4,7 @@ from io import BytesIO
from . import RenderStrategy
from PIL import ImageFont, Image, ImageDraw
from core.utils.io import save_temp_img
from astrbot.core.utils.io import save_temp_img
class LocalRenderStrategy(RenderStrategy):

View File

@@ -2,8 +2,8 @@ import aiohttp
import os
from . import RenderStrategy
from core.config import VERSION
from core.utils.io import download_image_by_url
from astrbot.core.config import VERSION
from astrbot.core.utils.io import download_image_by_url
ASTRBOT_T2I_DEFAULT_ENDPOINT = "https://t2i.soulter.top/text2img"

View File

@@ -1,6 +1,6 @@
from .network_strategy import NetworkRenderStrategy
from .local_strategy import LocalRenderStrategy
from core.log import LogManager
from astrbot.core.log import LogManager
logger = LogManager.GetLogger(log_name='astrbot')

View File

@@ -1,6 +1,6 @@
import aiohttp, os, zipfile, shutil
from core.utils.io import on_error, download_file
from core import logger
from astrbot.core.utils.io import on_error, download_file
from astrbot.core import logger
class ReleaseInfo():
version: str

View File

@@ -1,9 +1,9 @@
import asyncio
from multiprocessing import Process
from core import logger
from core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from .server import AstrBotDashboard
from core.db import BaseDatabase
from astrbot.core.db import BaseDatabase
class AstrBotDashBoardLifecycle:
def __init__(self, db: BaseDatabase):

View File

@@ -1,6 +1,6 @@
from .route import Route, Response
from quart import Quart, request
from core.config.astrbot_config import AstrBotConfig
from astrbot.core.config.astrbot_config import AstrBotConfig
class AuthRoute(Route):
def __init__(self, config: AstrBotConfig, app: Quart) -> None:

View File

@@ -1,10 +1,10 @@
import os, json
from .route import Route, Response
from quart import Quart, request
from core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP, PROVIDER_CONFIG_TEMPLATE
from core.config.astrbot_config import AstrBotConfig
from core.plugin.config import update_config
from core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP, PROVIDER_CONFIG_TEMPLATE
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.plugin.config import update_config
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from dataclasses import asdict
def try_cast(value: str, type_: str):

View File

@@ -1,8 +1,8 @@
import asyncio
from quart import websocket
from quart import Quart
from core.config.astrbot_config import AstrBotConfig
from core import logger, LogBroker
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core import logger, LogBroker
from .route import Route, Response
class LogRoute(Route):

View File

@@ -1,10 +1,10 @@
import threading, traceback, uuid
from .route import Route, Response
from core import logger
from astrbot.core import logger
from quart import Quart, request
from core.config.astrbot_config import AstrBotConfig
from core.plugin.plugin_manager import PluginManager
from core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.plugin.plugin_manager import PluginManager
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
class PluginRoute(Route):
def __init__(self, config: AstrBotConfig, app: Quart, core_lifecycle: AstrBotCoreLifecycle, plugin_manager: PluginManager) -> None:

View File

@@ -1,4 +1,4 @@
from core.config.astrbot_config import AstrBotConfig
from astrbot.core.config.astrbot_config import AstrBotConfig
from dataclasses import dataclass
from quart import Quart

View File

@@ -1,11 +1,11 @@
import traceback, psutil, time, aiohttp
from .route import Route, Response
from core import logger
from astrbot.core import logger
from quart import Quart, request
from core.config.astrbot_config import AstrBotConfig
from core.core_lifecycle import AstrBotCoreLifecycle
from core.db import BaseDatabase
from core.config import VERSION
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.config import VERSION
class StatRoute(Route):
def __init__(self, config: AstrBotConfig, app: Quart, db_helper: BaseDatabase, core_lifecycle: AstrBotCoreLifecycle) -> None:

View File

@@ -1,6 +1,6 @@
from .route import Route
from quart import Quart
from core.config.astrbot_config import AstrBotConfig
from astrbot.core.config.astrbot_config import AstrBotConfig
class StaticFileRoute(Route):
def __init__(self, config: AstrBotConfig, app: Quart) -> None:

View File

@@ -1,9 +1,9 @@
import threading, traceback
from .route import Route, Response
from quart import Quart, request
from core.config.astrbot_config import AstrBotConfig
from core.updator import AstrBotUpdator
from core import logger
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger
class UpdateRoute(Route):
def __init__(self, config: AstrBotConfig, app: Quart, astrbot_updator: AstrBotUpdator) -> None:

View File

@@ -2,22 +2,23 @@ import logging
import asyncio, os
from quart import Quart
from quart.logging import default_handler
from core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from .routes import *
from core import logger
from core.db import BaseDatabase
from core.plugin.plugin_manager import PluginManager
from core.updator import AstrBotUpdator
from core.utils.io import get_local_ip_addresses
from core.config import AstrBotConfig
from core.db import BaseDatabase
from astrbot.core import logger
from astrbot.core.db import BaseDatabase
from astrbot.core.plugin.plugin_manager import PluginManager
from astrbot.core.updator import AstrBotUpdator
from astrbot.core.utils.io import get_local_ip_addresses
from astrbot.core.config import AstrBotConfig
from astrbot.core.db import BaseDatabase
class AstrBotDashboard():
def __init__(self, core_lifecycle: AstrBotCoreLifecycle, db: BaseDatabase) -> None:
self.core_lifecycle = core_lifecycle
self.config = core_lifecycle.astrbot_config
self.data_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../data/dist"))
self.app = Quart("dashboard", static_folder="dist", static_url_path="/")
logger.info(f"Dashboard data path: {self.data_path}")
self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/")
self.app.json.sort_keys = False
logging.getLogger(self.app.name).removeHandler(default_handler)

View File

@@ -6,12 +6,12 @@ import mimetypes
import aiohttp
import zipfile
from typing import List
from core.core_lifecycle import AstrBotCoreLifecycle
from core.db.sqlite import SQLiteDatabase
from core.config import DB_PATH
from dashboard import AstrBotDashBoardLifecycle
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.config import DB_PATH
from astrbot.dashboard import AstrBotDashBoardLifecycle
from core import logger, LogManager, LogBroker
from astrbot.core import logger, LogManager, LogBroker
# add parent path to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

View File

@@ -1,7 +1,7 @@
import os, traceback
import os, traceback, random, asyncio
from astrbot.api import AstrMessageEvent, MessageChain, logger
from astrbot.api import Plain, Image
from astrbot.api.message_components import Plain, Image
from aiocqhttp import CQHttp
from astrbot.core.utils.io import file_to_base64, download_image_by_url
@@ -11,7 +11,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
self.bot = bot
@staticmethod
async def _parse_onebot_josn(message_chain: MessageChain):
async def _parse_onebot_json(message_chain: MessageChain):
'''解析成 OneBot json 格式'''
ret = []
for segment in message_chain.chain:
@@ -31,8 +31,14 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
return ret
async def send(self, message: MessageChain):
ret = await AiocqhttpMessageEvent._parse_onebot_josn(message)
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
if os.environ.get('TEST_MODE', 'off') == 'on':
return
await self.bot.send(self.message_obj.raw_message, ret)
if message.is_split_: # 分条发送
for m in ret:
await self.bot.send(self.message_obj.raw_message, [m])
await asyncio.sleep(random.uniform(0.75, 2.5))
else:
await self.bot.send(self.message_obj.raw_message, ret)
await super().send(message)

View File

@@ -7,7 +7,7 @@ from aiocqhttp import CQHttp, Event
from astrbot.api import Platform
from astrbot.api import MessageChain, MessageEventResult, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from .aiocqhttp_message_event import *
from nakuru.entities.components import *
from astrbot.api.message_components import *
from astrbot.api import logger
from .aiocqhttp_message_event import AiocqhttpMessageEvent
from astrbot.core.config.astrbot_config import PlatformConfig, AiocqhttpPlatformConfig, PlatformSettings
@@ -29,7 +29,7 @@ class AiocqhttpAdapter(Platform):
)
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
ret = await AiocqhttpMessageEvent._parse_onebot_josn(message_chain)
ret = await AiocqhttpMessageEvent._parse_onebot_json(message_chain)
match session.message_type.value:
case MessageType.GROUP_MESSAGE.value:
if "_" in session.session_id:

View File

@@ -4,7 +4,7 @@ import botpy.types
import botpy.types.message
from astrbot.core.utils.io import file_to_base64, download_image_by_url
from astrbot.api import AstrMessageEvent, MessageChain, logger, AstrBotMessage, PlatformMetadata, MessageType
from astrbot.api import Plain, Image
from astrbot.api.message_components import Plain, Image
from botpy import Client
from botpy.http import Route

View File

@@ -9,7 +9,7 @@ from botpy import Client
from astrbot.api import Platform
from astrbot.api import MessageChain, MessageEventResult, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from typing import Union, List, Dict
from nakuru.entities.components import *
from astrbot.api.message_components import *
from astrbot.api import logger
from astrbot.core.platform.astr_message_event import MessageSesion
from .qqofficial_message_event import QQOfficialMessageEvent

View File

@@ -1,7 +1,7 @@
import random, asyncio
from astrbot.core.utils.io import download_image_by_url
from astrbot.api import AstrMessageEvent, MessageChain, logger, AstrBotMessage, PlatformMetadata
from astrbot.api import Plain, Image
from astrbot.api.message_components import Plain, Image
from vchat import Core
class WechatPlatformEvent(AstrMessageEvent):
@@ -14,7 +14,10 @@ class WechatPlatformEvent(AstrMessageEvent):
plain = ""
for comp in message.chain:
if isinstance(comp, Plain):
plain += comp.text
if message.is_split_:
await client.send_msg(comp.text, user_name)
else:
plain += comp.text
elif isinstance(comp, Image):
if comp.file and comp.file.startswith("file:///"):
file_path = comp.file.replace("file:///", "")

View File

@@ -4,7 +4,7 @@ import asyncio
from astrbot.api import Platform
from astrbot.api import MessageChain, MessageEventResult, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from typing import Union, List, Dict
from nakuru.entities.components import *
from astrbot.api.message_components import *
from astrbot.api import logger
from astrbot.core.platform.astr_message_event import MessageSesion
from .wechat_message_event import WechatPlatformEvent
@@ -24,6 +24,7 @@ class WechatPlatformAdapter(Platform):
def __init__(self, platform_config: WechatPlatformConfig, platform_settings: PlatformSettings, event_queue: asyncio.Queue) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settingss = platform_settings
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
self.client_self_id = uuid.uuid4().hex[:8]
@@ -51,6 +52,7 @@ class WechatPlatformAdapter(Platform):
if msg.create_time < self.start_time:
logger.debug(f"忽略旧消息: {msg}")
return
logger.debug(f"收到消息: {msg.todict()}")
if self.config.wechat_id_whitelist and msg.from_.username not in self.config.wechat_id_whitelist:
logger.debug(f"忽略不在白名单的微信消息。username: {msg.from_.username}")
return
@@ -80,7 +82,11 @@ class WechatPlatformAdapter(Platform):
sender = msg.chatroom_sender or msg.from_
amsg.sender = MessageMember(sender.username, sender.nickname)
amsg.message_str = msg.content.content
if msg.content.is_at_me:
amsg.message_str = msg.content.content.split("\u2005")[1].strip()
else:
amsg.message_str = msg.content.content
amsg.message_id = msg.message_id
if isinstance(msg.from_, model.User):
amsg.type = MessageType.FRIEND_MESSAGE
@@ -91,10 +97,13 @@ class WechatPlatformAdapter(Platform):
amsg.raw_message = msg
session_id = msg.from_.username + "$$" + msg.to.username
if msg.chatroom_sender is not None:
session_id += '$$' + msg.chatroom_sender.username
if self.settingss.unique_session:
session_id = msg.from_.username + "$$" + msg.to.username
if msg.chatroom_sender is not None:
session_id += '$$' + msg.chatroom_sender.username
else:
session_id = msg.from_.username
amsg.session_id = session_id
return amsg

View File

@@ -1,9 +1,10 @@
from astrbot.api import Context, AstrMessageEvent, MessageEventResult, MessageChain
from . import PLUGIN_NAME
from astrbot.api import logger, Image, Plain
from astrbot.api import logger
from astrbot.api.message_components import Image, Plain
from astrbot.api import personalities
from astrbot.api import command_parser
from astrbot.api import Provider
from astrbot.api import Provider, Personality
class OpenAIAdapterCommand:
@@ -25,7 +26,7 @@ class OpenAIAdapterCommand:
async def reset(self, message: AstrMessageEvent):
tokens = command_parser.parse(message.message_str)
if tokens.len == 1:
await self.provider.forget(message.session_id, keep_system_prompt=True)
await self.provider.forget(message.session_id)
message.set_result(MessageEventResult().message("重置成功"))
elif tokens.get(1) == 'p':
await self.provider.forget(message.session_id)
@@ -81,17 +82,13 @@ class OpenAIAdapterCommand:
message.set_result(MessageEventResult().message(f"历史记录:\n\n{contexts}\n{page} 页 | 共 {t_pages}\n\n*输入 /his 2 跳转到第 2 页"))
def status(self, message: AstrMessageEvent):
keys_data = self.provider.get_keys_data()
ret = "OpenAI Key"
keys_data = self.provider.get_all_keys()
ret = "{} Key"
for k in keys_data:
status = "🟢" if keys_data[k] else "🔴"
ret += "\n|- " + k[:8] + " " + status
ret += "\n|- " + k[:8]
ret += "\n当前模型: " + self.provider.get_model()
if message.session_id in self.provider.session_memory and len(self.provider.session_memory[message.session_id]):
ret += "\n你的会话上下文: " + str(self.provider.session_memory[message.session_id][-1]['usage_tokens']) + " tokens"
message.set_result(MessageEventResult().message(ret).use_t2i(False))
async def switch(self, message: AstrMessageEvent):
@@ -160,18 +157,10 @@ class OpenAIAdapterCommand:
else:
ps = "".join(l[1:]).strip()
if ps in personalities:
self.provider.curr_personality = {
'name': ps,
'prompt': personalities[ps]
}
self.provider.personality_set(self.provider.curr_personality, message.session_id)
self.provider.curr_personality = Personality(name=ps, prompt=personalities[ps])
message.set_result(MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}"))
else:
self.provider.curr_personality = {
'name': '自定义人格',
'prompt': ps
}
self.provider.personality_set(self.provider.curr_personality, message.session_id)
self.provider.curr_personality = Personality(name="自定义人格", prompt=ps)
message.set_result(MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}"))
async def draw(self, message: AstrMessageEvent):

View File

@@ -5,28 +5,34 @@ from .openai_adapter import ProviderOpenAIOfficial
from .commands import OpenAIAdapterCommand
from astrbot.api import logger
from . import PLUGIN_NAME
from astrbot.api import Image, Plain, MessageChain
from astrbot.api import MessageChain
from astrbot.api.message_components import Image, Plain
from openai._exceptions import *
from openai.types.chat.chat_completion_message_tool_call import Function
from astrbot.api import command_parser
from .web_searcher import search_from_bing, fetch_website_content
from astrbot.core.utils.metrics import Metric
from astrbot.core.config.astrbot_config import LLMConfig
from .atri import ATRI
class Main:
def __init__(self, context: Context) -> None:
supported_provider_names = ["openai", "ollama", "gemini", "deepseek", "zhipu"]
self.context = context
# 各 Provider 实例
self.provider_insts: List[ProviderOpenAIOfficial] = []
# Provider 的配置
self.provider_llm_configs: List[LLMConfig] = []
# 当前使用的 Provider
self.provider = None
# 当前使用的 Provider 的配置
self.provider_config = None
llms_config = self.context.get_config().llm
atri_config = self.context.get_config().project_atri
loaded = False
for llm in llms_config:
for llm in self.context.get_config().llm:
if llm.enable:
if llm.name in supported_provider_names:
if not llm.key or not llm.enable:
@@ -36,20 +42,33 @@ class Main:
self.provider_llm_configs.append(llm)
loaded = True
logger.info(f"已启用 LLM Provider(OpenAI API 适配器): {llm.id}({llm.name})。")
if loaded:
self.command_handler = OpenAIAdapterCommand(self.context)
self.command_handler.set_provider(self.provider_insts[0])
self.context.register_listener(PLUGIN_NAME, "openai_adapter_chat", self.chat, "OpenAI Adapter LLM 调用监听器", after_commands=True)
self.context.register_listener(PLUGIN_NAME, "llm_chat_listener", self.chat, "llm_chat_listener", after_commands=True)
self.provider = self.command_handler.provider
self.provider_config = self.provider_llm_configs[0]
self.context.register_commands(PLUGIN_NAME, "provider", "查看当前 LLM Provider", 10, self.provider_info)
self.context.register_commands(PLUGIN_NAME, "websearch", "启用/关闭网页搜索", 10, self.web_search)
if self.context.get_config().llm_settings.web_search:
self.add_web_search_tools()
# load atri
self.atri = None
if atri_config.enable:
try:
self.atri = ATRI(self.provider_llm_configs, atri_config, self.context)
self.command_handler.provider = self.atri.atri_chat_provider
except ImportError as e:
logger.error(traceback.format_exc())
logger.error("载入 ATRI 失败。请确保使用 pip 安装了 requirements_atri.txt 下的库。")
self.atri = None
except BaseException as e:
logger.error(traceback.format_exc())
logger.error("载入 ATRI 失败。")
self.atri = None
def add_web_search_tools(self):
self.context.register_llm_tool("web_search", [{
"type": "string",
@@ -121,7 +140,10 @@ class Main:
async def chat(self, event: AstrMessageEvent):
if not event.is_wake_up():
return
if self.atri:
await self.atri.chat(event)
return
# prompt 前缀
if self.provider_config.prompt_prefix:
event.message_str = self.provider_config.prompt_prefix + event.message_str
@@ -131,6 +153,8 @@ class Main:
if isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
break
tool_use_flag = False
llm_result = None
try:
if not self.context.llm_tools.empty():
@@ -177,7 +201,6 @@ class Main:
return
else:
# normal chat
tool_use_flag = False
# add user info to the prompt
if self.context.get_config().llm_settings.identifier:
user_id = event.message_obj.sender.user_id

View File

@@ -1,11 +1,8 @@
import os
import asyncio
import json
import time
import tiktoken
import threading
import traceback
import base64
import json
from openai import AsyncOpenAI
from openai.types.chat.chat_completion import ChatCompletion
@@ -17,90 +14,38 @@ from astrbot.api import Provider
from astrbot.core.config.astrbot_config import LLMConfig
from astrbot import logger
from typing import List, Dict
from dataclasses import asdict
class ProviderOpenAIOfficial(Provider):
def __init__(self, llm_config: LLMConfig, db_helper: BaseDatabase) -> None:
super().__init__()
def __init__(self, llm_config: LLMConfig, db_helper: BaseDatabase, persistant_history = True) -> None:
super().__init__(db_helper, llm_config.default_personality, persistant_history)
self.api_keys = []
self.chosen_api_key = None
self.base_url = None
self.llm_config = llm_config
self.keys_data = {} # 记录超额
if llm_config.key: self.api_keys = llm_config.key
if llm_config.api_base: self.base_url = llm_config.api_base
if not self.api_keys:
logger.warn("看起来你没有添加 OpenAI 的 API 密钥OpenAI LLM 能力将不会启用。")
else:
self.chosen_api_key = self.api_keys[0]
for key in self.api_keys:
self.keys_data[key] = True
self.api_keys = llm_config.key
if llm_config.api_base:
self.base_url = llm_config.api_base
self.chosen_api_key = self.api_keys[0]
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=self.base_url
)
self.set_model(llm_config.model_config.model)
if llm_config.image_generation_model_config:
self.image_generator_model_configs: Dict = asdict(llm_config.image_generation_model_config)
self.session_memory: Dict[str, List] = {} # 会话记忆
self.session_memory_lock = threading.Lock()
self.max_tokens = self.llm_config.model_config.max_tokens # 上下文窗口大小
logger.info("正在载入分词器 cl100k_base...")
self.tokenizer = tiktoken.get_encoding("cl100k_base") # todo: 根据 model 切换分词器
self.DEFAULT_PERSONALITY = {
"prompt": self.llm_config.default_personality,
"name": "default"
}
self.curr_personality = self.DEFAULT_PERSONALITY
self.session_personality = {} # 记录了某个session是否已设置人格。
# 读取历史记录
self.db_helper = db_helper
try:
for history in db_helper.get_llm_history():
self.session_memory_lock.acquire()
self.session_memory[history.session_id] = json.loads(history.content)
self.session_memory_lock.release()
except BaseException as e:
logger.warning(f"读取 OpenAI LLM 对话历史记录 失败:{e}。仍可正常使用。")
# 定时保存历史记录
threading.Thread(target=self.dump_history, daemon=True).start()
def dump_history(self):
'''转储历史记录'''
time.sleep(30)
while True:
try:
for session_id, content in self.session_memory.items():
self.db_helper.update_llm_history(session_id, json.dumps(content))
except BaseException as e:
logger.error("保存 LLM 历史记录失败: " + str(e))
finally:
time.sleep(10*60)
def personality_set(self, personality: dict, session_id: str):
if not personality or not personality['prompt']: return
if session_id not in self.session_memory:
self.session_memory[session_id] = []
self.curr_personality = personality
self.session_personality = {} # 重置
new_record = {
"user": {
"role": "system",
"content": personality['prompt'],
},
'usage_tokens': 0, # 到该条目的总 token 数
'single-tokens': 0 # 该条目的 token 数
}
self.session_memory[session_id] = [new_record]
# 各类模型的配置
self.image_generator_model_configs = None
self.embedding_model_configs = None
if llm_config.image_generation_model_config and llm_config.image_generation_model_config.enable:
self.image_generator_model_configs: Dict = asdict(
llm_config.image_generation_model_config)
self.image_generator_model_configs.pop("enable")
if llm_config.embedding_model and llm_config.embedding_model.enable:
self.embedding_model_configs: Dict = asdict(
llm_config.embedding_model)
self.embedding_model_configs.pop("enable")
async def encode_image_bs64(self, image_url: str) -> str:
'''
@@ -108,29 +53,12 @@ class ProviderOpenAIOfficial(Provider):
'''
if image_url.startswith("http"):
image_url = await download_image_by_url(image_url)
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode('utf-8')
return "data:image/jpeg;base64," + image_bs64
return ''
async def retrieve_context(self, session_id: str):
'''
根据 session_id 获取保存的 OpenAI 格式的上下文
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
# 转换为 openai 要求的格式
context = []
for record in self.session_memory[session_id]:
if "user" in record and record['user']:
context.append(record['user'])
if "AI" in record and record['AI']:
context.append(record['AI'])
return context
async def get_models(self):
models = []
try:
@@ -140,47 +68,6 @@ class ProviderOpenAIOfficial(Provider):
self.client.base_url = bu + "/v1"
models = await self.client.models.list()
return models
async def assemble_context(self, session_id: str, prompt: str, image_url: str = None):
'''
组装上下文,并且根据当前上下文窗口大小截断
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
tokens_num = len(self.tokenizer.encode(prompt))
previous_total_tokens_num = 0 if not self.session_memory[session_id] else self.session_memory[session_id][-1]['usage_tokens']
message = {
"usage_tokens": previous_total_tokens_num + tokens_num,
"single_tokens": tokens_num,
"AI": None
}
if image_url:
base_64_image = await self.encode_image_bs64(image_url)
user_content = {
"role": "user",
"content": [
{
"type": "text",
"text": prompt
},
{
"type": "image_url",
"image_url": {
"url": base_64_image
}
}
]
}
else:
user_content = {
"role": "user",
"content": prompt
}
message["user"] = user_content
self.session_memory[session_id].append(message)
async def pop_record(self, session_id: str, pop_system_prompt: bool = False):
'''
@@ -188,10 +75,10 @@ class ProviderOpenAIOfficial(Provider):
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
if len(self.session_memory[session_id]) == 0:
return None
for i in range(len(self.session_memory[session_id])):
# 检查是否是 system prompt
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
@@ -206,156 +93,111 @@ class ProviderOpenAIOfficial(Provider):
record = self.session_memory[session_id].pop(i)
break
# 更新之后所有记录的 usage_tokens
for i in range(len(self.session_memory[session_id])):
self.session_memory[session_id][i]['usage_tokens'] -= record['single-tokens']
logger.debug(f"淘汰上下文记录 1 条,释放 {record['single-tokens']} 个 token。当前上下文总 token 为 {self.session_memory[session_id][-1]['usage_tokens']}")
return record
async def text_chat(self,
prompt: str,
session_id: str,
image_url=None,
tools=None,
async def assemble_context(self, contexts: List, text: str, image_urls: List[str] = None):
'''
组装上下文。
'''
if image_urls:
for image_url in image_urls:
base_64_image = await self.encode_image_bs64(image_url)
user_content = {"role": "user","content": [
{"type": "text", "text": text},
{"type": "image_url", "image_url": {"url": base_64_image}}
]}
contexts.append(user_content)
else:
user_content = {"role": "user","content": text}
contexts.append(user_content)
async def text_chat(self,
prompt: str,
session_id: str,
image_urls=None,
tools=None,
contexts=None,
**kwargs
) -> str:
) -> str:
'''
调用 LLM 进行文本对话。
@param tools: LLM Function-calling 的工具函数
@param contexts: 如果不为 None则会原封不动地使用这个上下文进行对话。
'''
if os.environ.get("TEST_LLM", "off") != "on" and os.environ.get("TEST_MODE", "off") == "on":
return "这是一个测试消息。"
if not session_id:
session_id = "unknown"
if "unknown" in self.session_memory:
del self.session_memory["unknown"]
if session_id not in self.session_memory:
self.session_memory[session_id] = []
if session_id not in self.session_personality or not self.session_personality[session_id]:
self.personality_set(self.curr_personality, session_id)
self.session_personality[session_id] = True
# 组装上下文,并且根据当前上下文窗口大小截断
await self.assemble_context(session_id, prompt, image_url)
# 获取上下文openai 格式
contexts = await self.retrieve_context(session_id)
logger.debug(f"OpenAI 请求上下文:{contexts}")
await self.assemble_context(self.session_memory[session_id], prompt, image_urls)
if not contexts:
contexts = [*self.session_memory[session_id]]
if self.curr_personality["prompt"]:
contexts.insert(0, {"role": "system", "content": self.curr_personality["prompt"]})
logger.debug(f"请求上下文:{contexts}")
conf = asdict(self.llm_config.model_config)
if tools:
conf['tools'] = tools
# start request
retry = 0
rate_limit_retry = 0
while retry < 3 or rate_limit_retry < 5:
if tools:
completion_coro = self.client.chat.completions.create(
messages=contexts,
stream=False,
tools=tools,
**conf
)
else:
completion_coro = self.client.chat.completions.create(
messages=contexts,
stream=False,
**conf
)
while retry < 3:
completion_coro = self.client.chat.completions.create(
messages=contexts,
stream=False,
**conf
)
try:
completion = await completion_coro
break
except AuthenticationError as e:
api_key = self.chosen_api_key[10:] + "..."
logger.error(f"OpenAI API Key {api_key} 验证错误。详细原因:{e}。正在切换到下一个可用的 Key如果有的话")
self.keys_data[self.chosen_api_key] = False
ok = await self.switch_to_next_key()
if ok: continue
else: raise Exception("所有 OpenAI API Key 目前都不可用。")
except RateLimitError as e:
if "You exceeded your current quota" in str(e):
self.keys_data[self.chosen_api_key] = False
ok = await self.switch_to_next_key()
if ok: continue
else: raise Exception("所有 OpenAI API Key 目前都不可用。")
logger.error(f"OpenAI API Key {self.chosen_api_key} 达到请求速率限制或者官方服务器当前超载。详细原因:{e}")
await self.switch_to_next_key()
rate_limit_retry += 1
await asyncio.sleep(1)
except BadRequestError as e:
raise e
except NotFoundError as e:
raise e
except Exception as e:
retry += 1
if retry >= 3:
logger.error(traceback.format_exc())
raise Exception(f"OpenAI 请求失败:{e}。重试次数已达到上限。")
raise Exception(f"请求失败:{e}。重试次数已达到上限。")
if "maximum context length" in str(e):
logger.warn(f"OpenAI 请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
self.pop_record(session_id)
logger.warning(traceback.format_exc())
logger.warning(f"OpenAI 请求失败:{e}。重试第 {retry} 次。")
logger.warning(f"请求失败:{e}。重试第 {retry} 次。")
await asyncio.sleep(1)
assert isinstance(completion, ChatCompletion)
logger.debug(f"openai completion: {completion.usage}")
logger.debug(f"completion: {completion.usage}")
if len(completion.choices) == 0:
raise Exception("OpenAI API 返回的 completion 为空。")
raise Exception("API 返回的 completion 为空。")
choice = completion.choices[0]
usage_tokens = completion.usage.total_tokens
completion_tokens = completion.usage.completion_tokens
self.session_memory[session_id][-1]['usage_tokens'] = usage_tokens
self.session_memory[session_id][-1]['single_tokens'] += completion_tokens
if choice.message.content:
# 返回文本
completion_text = str(choice.message.content).strip()
self.session_memory[session_id].append({
"role": "assistant",
"content": completion_text
})
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]))
return completion_text
elif choice.message.tool_calls and choice.message.tool_calls:
# tools call (function calling)
return choice.message.tool_calls[0].function
self.session_memory[session_id][-1]['AI'] = {
"role": "assistant",
"content": completion_text
}
else:
raise Exception("Internal Error")
return completion_text
async def switch_to_next_key(self):
'''
切换到下一个 API Key
'''
if not self.api_keys:
logger.error("OpenAI API Key 不存在。")
return False
for key in self.keys_data:
if self.keys_data[key]:
# 没超额
self.chosen_api_key = key
self.client.api_key = key
logger.info(f"OpenAI 切换到 API Key {key[:10]}... 成功。")
return True
return False
async def image_generate(self, prompt: str, session_id: str = None, **kwargs) -> str:
'''
生成图片
'''
retry = 0
conf = self.image_generator_model_configs
if not conf:
logger.error("图片生成模型配置不存在。")
raise Exception("图片生成模型配置不存在。")
conf.pop("enable")
if not self.image_generator_model_configs:
return
while retry < 3:
try:
images_response = await self.client.images.generate(
prompt=prompt,
**conf
**self.image_generator_model_configs
)
image_url = images_response.data[0].url
return image_url
@@ -367,15 +209,25 @@ class ProviderOpenAIOfficial(Provider):
logger.warning(f"图片生成请求失败:{e}。重试第 {retry} 次。")
await asyncio.sleep(1)
async def forget(self, session_id=None, keep_system_prompt: bool=False) -> bool:
if session_id is None: return False
async def get_embedding(self, text) -> List[float]:
'''
获取文本的嵌入
'''
if not self.embedding_model_configs:
return
try:
embedding = await self.client.embeddings.create(
input=text,
**self.embedding_model_configs
)
return embedding.data[0].embedding
except Exception as e:
logger.error(f"获取文本嵌入失败:{e}")
async def forget(self, session_id: str) -> bool:
self.session_memory[session_id] = []
if keep_system_prompt:
self.personality_set(self.curr_personality, session_id)
else:
self.curr_personality = self.DEFAULT_PERSONALITY
return True
def dump_contexts_page(self, session_id: str, size=5, page=1,):
'''
获取缓存的会话
@@ -383,25 +235,21 @@ class ProviderOpenAIOfficial(Provider):
contexts_str = ""
if session_id in self.session_memory:
for record in self.session_memory[session_id]:
if "user" in record and record['user']:
text = record['user']['content'][:100] + "..." if len(record['user']['content']) > 100 else record['user']['content']
if record['role'] == "user":
text = record['content'][:100] + "..." if len(
record['content']) > 100 else record['content']
contexts_str += f"User: {text}\n\n"
if "AI" in record and record['AI']:
text = record['AI']['content'][:100] + "..." if len(record['AI']['content']) > 100 else record['AI']['content']
elif record['role'] == "assistant":
text = record['content'][:100] + "..." if len(
record['content']) > 100 else record['content']
contexts_str += f"Assistant: {text}\n\n"
else:
contexts_str = "会话 ID 不存在。"
return contexts_str, len(self.session_memory[session_id])
def get_configs(self):
return asdict(self.llm_config)
def get_keys_data(self):
return self.keys_data
def get_curr_key(self):
return self.chosen_api_key
def set_key(self, key):
self.client.api_key = key
def get_all_keys(self):
return self.api_keys

View File

@@ -1,14 +1,12 @@
pydantic~=1.10.4
pydantic
vchat
aiohttp
openai
qq-botpy
chardet~=5.1.0
Pillow
nakuru-project
beautifulsoup4
googlesearch-python
tiktoken
readability-lxml
quart
psutil