216 lines
8.6 KiB
Python
216 lines
8.6 KiB
Python
import time
|
|
import re
|
|
import asyncio
|
|
import traceback
|
|
import astrbot.message.unfit_words as uw
|
|
|
|
from typing import Dict
|
|
from astrbot.persist.helper import dbConn
|
|
from model.provider.provider import Provider
|
|
from model.command.manager import CommandManager
|
|
from type.message_event import AstrMessageEvent, MessageResult
|
|
from type.types import Context
|
|
from type.command import CommandResult
|
|
from SparkleLogging.utils.core import LogManager
|
|
from logging import Logger
|
|
from nakuru.entities.components import Image
|
|
import util.agent.web_searcher as web_searcher
|
|
|
|
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
|
|
|
|
|
class RateLimitHelper():
    '''
    Fixed-window message counter keyed by an arbitrary id.

    Defaults to 10 messages per 60 seconds; both values can be overridden
    through the `limit` entry (`count` / `time`) of `context.base_config`.
    '''

    def __init__(self, context: Context) -> None:
        '''
        Read the optional `limit` settings from `context.base_config`.
        '''
        self.user_rate_limit: Dict[int, int] = {}
        self.rate_limit_time: int = 60
        self.rate_limit_count: int = 10
        # key -> {'time': start of the current window, 'count': messages seen in it}
        self.user_frequency = {}

        if 'limit' in context.base_config:
            limit_cfg = context.base_config['limit']
            if 'count' in limit_cfg:
                self.rate_limit_count = limit_cfg['count']
            if 'time' in limit_cfg:
                self.rate_limit_time = limit_cfg['time']

    def check_frequency(self, session_id: str) -> bool:
        '''
        Check the message frequency for `session_id` and update its counter.

        Returns True when the message is allowed, False when the key has
        already used up `rate_limit_count` messages in the current window.
        '''
        now = int(time.time())

        # First message ever seen for this key: open a new window.
        if session_id not in self.user_frequency:
            self.user_frequency[session_id] = {'time': now, 'count': 1}
            return True

        record = self.user_frequency[session_id]

        # The previous window has expired: restart it with this message.
        if now - record['time'] > self.rate_limit_time:
            record['time'] = now
            record['count'] = 1
            return True

        # Still inside the window: refuse once the quota is used up.
        if record['count'] >= self.rate_limit_count:
            return False

        record['count'] += 1
        return True
|
|
|
|
class ContentSafetyHelper():
    '''
    Content-safety checks for incoming and outgoing text.

    Two mechanisms are used: the local regex list `uw.unfit_words_q`, and,
    when enabled through the `baidu_aip` entry of `context.base_config`,
    a `BaiduJudge` instance stored in `self.baidu_judge`.
    '''

    def __init__(self, context: Context) -> None:
        '''
        Create the optional Baidu moderation backend.

        `self.baidu_judge` stays None unless `context.base_config['baidu_aip']`
        exists, has a truthy `enable` entry, and `BaiduJudge` can be imported
        and constructed.
        '''
        self.baidu_judge = None
        # The presence check uses the same key ('baidu_aip') that is read
        # afterwards; checking a different key either skipped the feature or
        # raised KeyError.
        if 'baidu_aip' in context.base_config and \
                'enable' in context.base_config['baidu_aip'] and \
                context.base_config['baidu_aip']['enable']:
            try:
                # Imported lazily so the dependency is only needed when enabled.
                from astrbot.message.baidu_aip_judge import BaiduJudge
                self.baidu_judge = BaiduJudge(context.base_config['baidu_aip'])
                logger.info("已启用百度 AI 内容审核。")
            except Exception as e:
                # Best effort: a failed initialisation only disables the
                # Baidu check, it does not prevent the helper from working.
                logger.error("百度 AI 内容审核初始化失败。")
                logger.error(e)

    async def check_content(self, content: str) -> bool:
        '''
        Check whether the text content is acceptable.

        Returns False if any pattern of `uw.unfit_words_q` occurs anywhere in
        the stripped content (case-insensitive), or if the Baidu check
        rejects it; True otherwise.
        '''
        stripped = content.strip()
        for pattern in uw.unfit_words_q:
            # re.search, not re.match: a pattern must be detected wherever it
            # occurs, matching what filter_content() masks with re.sub.
            if re.search(pattern, stripped, re.I | re.M):
                return False
        # baidu_check() is a blocking call, so it is run in a worker thread.
        return await asyncio.to_thread(self.baidu_check, content)

    def filter_content(self, content: str) -> str:
        '''
        Filter the text content.

        Every case-insensitive match of a pattern from `uw.unfit_words_q`
        is replaced by "*"; the resulting string is returned.
        '''
        for pattern in uw.unfit_words_q:
            content = re.sub(pattern, "*", content, flags=re.I)
        return content

    def baidu_check(self, content: str) -> bool:
        '''
        Check the text content with the Baidu AI moderation backend.

        Returns True when no backend is configured or when the backend
        accepts the content; returns False (after logging the reported
        reason) when the backend rejects it. This call is synchronous.
        '''
        if self.baidu_judge is not None:
            check, msg = self.baidu_judge.judge(content)
            if not check:
                logger.info(f"百度 AI 内容审核发现以下违规:{msg}")
                return False
        return True
|
|
|
|
class MessageHandler():
    '''
    Entry point for processing one incoming message.

    A message is recorded, rate limited, offered to the command manager and,
    if no command handled it, sent to an LLM provider with content-safety
    checks applied before and after the LLM call.
    '''

    def __init__(self, context: Context,
                 command_manager: CommandManager,
                 persist_manager: dbConn,
                 provider: Provider) -> None:
        '''
        Store the collaborators and read the handler settings from `context`.

        `context.base_config['llm_wake_prefix']` must exist; it is stripped
        when it is truthy.
        '''
        self.context = context
        self.command_manager = command_manager
        self.persist_manager = persist_manager
        self.rate_limit_helper = RateLimitHelper(context)
        self.content_safety_helper = ContentSafetyHelper(context)
        self.llm_wake_prefix = self.context.base_config['llm_wake_prefix']
        if self.llm_wake_prefix:
            self.llm_wake_prefix = self.llm_wake_prefix.strip()
        self.nicks = self.context.nick
        self.provider = provider
        # str(None) is the truthy text "None", which would be prepended to
        # every reply; treat a missing prefix as "no prefix" instead.
        reply_prefix = self.context.reply_prefix
        self.reply_prefix = "" if reply_prefix is None else str(reply_prefix)

    def set_provider(self, provider: Provider):
        '''
        Replace the default LLM provider used by `handle`.
        '''
        self.provider = provider

    async def handle(self, message: AstrMessageEvent, llm_provider: Provider = None) -> MessageResult:
        '''
        Handle the message event, including commands, plugins, etc.

        `llm_provider`: the provider to use for LLM. If None, use the default provider

        Returns a `MessageResult` when there is something to reply with, or
        None when the message is ignored: rate limited, command-only without
        a matching command, missing LLM wake prefix, or no provider available.
        As a side effect `message.message_str` is rewritten to the text with
        the nick prefix removed.

        Exceptions raised by the LLM call are logged and converted into an
        error `MessageResult`; cancellation and other non-`Exception`
        signals are allowed to propagate.
        '''
        msg_plain = message.message_str.strip()
        provider = llm_provider if llm_provider else self.provider

        self.persist_manager.record_message(message.platform.platform_name, message.session_id)

        # TODO: replying to an empty message should be configurable.

        # check the rate limit (command-only messages are exempt)
        if not message.only_command and not self.rate_limit_helper.check_frequency(message.message_obj.sender.user_id):
            logger.warning(f"用户 {message.message_obj.sender.user_id} 的发言频率超过限制, 跳过。")
            return None

        # remove the nick prefix; only the first matching nick is removed.
        # Empty nicks are skipped: "" is a prefix of every string and would
        # otherwise stop the loop before a real nick is tried.
        for nick in self.nicks:
            if nick and msg_plain.startswith(nick):
                msg_plain = msg_plain.removeprefix(nick)
                break
        message.message_str = msg_plain

        # scan candidate commands
        cmd_res = await self.command_manager.scan_command(message, self.context)
        if cmd_res:
            assert(isinstance(cmd_res, CommandResult))
            return MessageResult(
                cmd_res.message_chain,
                is_command_call=True,
                use_t2i=cmd_res.is_use_t2i
            )

        # next is the LLM part

        if message.only_command:
            return None

        # check if the message is a llm-wake-up command
        if self.llm_wake_prefix and not msg_plain.startswith(self.llm_wake_prefix):
            logger.debug(f"消息 `{msg_plain}` 没有以 LLM 唤醒前缀 `{self.llm_wake_prefix}` 开头,忽略。")
            return None

        if not provider:
            logger.debug("没有任何 LLM 可用,忽略。")
            return None

        # check the content safety
        if not await self.content_safety_helper.check_content(msg_plain):
            return MessageResult("信息包含违规内容,由于机器人管理者开启内容安全审核,你的此条消息已被停止继续处理。")

        # use the first image component of the message, if there is one
        image_url = None
        for comp in message.message_obj.message:
            if isinstance(comp, Image):
                image_url = comp.url if comp.url else comp.file
                break

        web_search = self.context.web_search
        if not web_search and msg_plain.startswith("ws"):
            # leverage web search feature for this message only
            web_search = True
            msg_plain = msg_plain.removeprefix("ws").strip()

        try:
            if web_search:
                llm_result = await web_searcher.web_search(msg_plain, provider, message.session_id, official_fc=True)
            else:
                llm_result = await provider.text_chat(
                    prompt=msg_plain,
                    session_id=message.session_id,
                    image_url=image_url
                )
        except Exception as e:
            # Exception, not BaseException: asyncio.CancelledError must reach
            # the caller so that cancelling this task actually cancels it.
            logger.error(traceback.format_exc())
            logger.error("LLM 调用失败。")
            return MessageResult("AstrBot 请求 LLM 资源失败:" + str(e))

        # concatenate the reply prefix
        if self.reply_prefix:
            llm_result = self.reply_prefix + llm_result

        # mask the unsafe content
        llm_result = self.content_safety_helper.filter_content(llm_result)
        # baidu_check() is synchronous; run it in a worker thread so the
        # event loop is not blocked while it runs.
        check = await asyncio.to_thread(self.content_safety_helper.baidu_check, llm_result)
        if not check:
            return MessageResult("LLM 输出的信息包含违规内容,由于机器人管理者开启了内容安全审核,该条消息已拦截。")

        return MessageResult(llm_result)