Files
AstrBot/packages/astrbot_plugin_openai/commands.py
Soulter 750a93a1aa remove: 移除了 nakuru-project 库
但仍然使用其对 OneBot 的数据格式封装。
2024-12-02 19:31:33 +08:00

169 lines
8.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from astrbot.api import Context, AstrMessageEvent, MessageEventResult, MessageChain
from . import PLUGIN_NAME
from astrbot.api import logger
from astrbot.api.message_components import Image, Plain
from astrbot.api import personalities
from astrbot.api import command_parser
from astrbot.api import Provider, Personality
class OpenAIAdapterCommand:
def __init__(self, context: Context) -> None:
self.provider: Provider = None
self.context = context
context.register_commands(PLUGIN_NAME, "reset", "重置会话", 10, self.reset)
context.register_commands(PLUGIN_NAME, "his", "查看历史记录", 10, self.his)
context.register_commands(PLUGIN_NAME, "status", "查看当前状态", 10, self.status)
context.register_commands(PLUGIN_NAME, "switch", "切换账号", 10, self.switch)
context.register_commands(PLUGIN_NAME, "persona", "设置个性化人格", 10, self.persona)
context.register_commands(PLUGIN_NAME, "draw", "调用 DallE 模型画图", 10, self.draw)
context.register_commands(PLUGIN_NAME, "model", "切换 LLM 模型", 10, self.model)
context.register_commands(PLUGIN_NAME, "", "调用 DallE 模型画图", 10, self.draw)
def set_provider(self, provider: Provider):
self.provider = provider
async def reset(self, message: AstrMessageEvent):
tokens = command_parser.parse(message.message_str)
if tokens.len == 1:
await self.provider.forget(message.session_id)
message.set_result(MessageEventResult().message("重置成功"))
elif tokens.get(1) == 'p':
await self.provider.forget(message.session_id)
async def model(self, message: AstrMessageEvent):
tokens = command_parser.parse(message.message_str)
if tokens.len == 1:
ret = await self._print_models()
message.set_result(MessageEventResult().message(ret).use_t2i(False))
return
model = tokens.get(1)
if model.isdigit():
try:
models = await self.provider.get_models()
except BaseException as e:
logger.error(f"获取模型列表失败: {str(e)}。如果出现 404可能与服务提供商未提供模型列表有关。")
message.set_result(MessageEventResult().message("获取模型列表失败,无法使用编号切换模型。可以尝试直接输入模型名来切换,如 gpt-4o。"))
models = list(models)
if int(model) <= len(models) and int(model) >= 1:
model = models[int(model)-1]
self.provider.set_model(model.id)
message.set_result(MessageEventResult().message(f"模型已设置为 {model.id}"))
else:
self.provider.set_model(model)
message.set_result(MessageEventResult().message(f"模型已设置为 {model} (自定义)"))
async def _print_models(self):
models = []
try:
models = await self.provider.get_models()
except BaseException as e:
return "获取模型列表失败: " + str(e)
i = 1
ret = "下面列出了此服务提供商可用模型:"
for model in models:
ret += f"\n{i}. {model.id}"
i += 1
ret += "\nTips: 使用 /model 模型名/编号,即可实时更换模型。如目标模型不存在于上表,请输入模型名。"
logger.debug(ret)
return ret
def his(self, message: AstrMessageEvent):
tokens = command_parser.parse(message.message_str)
size_per_page = 3
page = 1
if tokens.len == 2:
try:
page = int(tokens.get(1))
except BaseException as e:
message.set_result(MessageEventResult().message("页码格式错误"))
contexts, total_num = self.provider.dump_contexts_page(message.session_id, size_per_page, page=page)
t_pages = total_num // size_per_page + 1
message.set_result(MessageEventResult().message(f"历史记录:\n\n{contexts}\n{page} 页 | 共 {t_pages}\n\n*输入 /his 2 跳转到第 2 页"))
def status(self, message: AstrMessageEvent):
keys_data = self.provider.get_all_keys()
ret = "{} Key"
for k in keys_data:
ret += "\n|- " + k[:8]
ret += "\n当前模型: " + self.provider.get_model()
message.set_result(MessageEventResult().message(ret).use_t2i(False))
async def switch(self, message: AstrMessageEvent):
'''
切换账号
'''
tokens = command_parser.parse(message.message_str)
if tokens.len == 1:
ret = ""
curr_ = self.provider.get_curr_key()
if curr_ is None:
ret += "当前您未选择账号。输入/switch <账号序号>切换账号。使用 /status 查看账号列表。"
else:
ret += f"当前您选择的账号为:{curr_[:8]}。输入/switch <账号序号>切换账号。使用 /status 查看账号列表。"
message.set_result(MessageEventResult().message(ret))
elif tokens.len == 2:
try:
key_stat = self.provider.get_keys_data()
index = int(tokens.get(1))
if index > len(key_stat) or index < 1:
message.set_result(MessageEventResult().message("账号序号错误。"))
else:
try:
new_key = list(key_stat.keys())[index-1]
self.provider.set_key(new_key)
except BaseException as e:
message.set_result(MessageEventResult().message("切换账号未知错误: "+str(e)))
message.set_result(MessageEventResult().message("切换账号成功。") )
except BaseException as e:
message.set_result(MessageEventResult().message("切换账号错误。"))
else:
message.set_result(MessageEventResult().message("参数过多。"))
def persona(self, message: AstrMessageEvent):
l = message.message_str.split(" ")
if len(l) == 1:
message.set_result(
MessageEventResult().message(f"""[Persona]
- 设置人格: `/persona 人格名`, 如 /persona 编剧
- 人格列表: `/persona list`
- 人格详细信息: `/persona view 人格名`
- 自定义人格: /persona 人格文本
- 重置 LLM 会话(清除人格): /reset
- 重置 LLM 会话(保留人格): /reset p
【当前人格】: {str(self.provider.curr_personality['prompt'])}
"""))
elif l[1] == "list":
msg = "人格列表:\n"
for key in personalities.keys():
msg += f"- {key}\n"
msg += '\n\n*输入 `/persona view 人格名` 查看人格详细信息'
message.set_result(MessageEventResult().message(msg))
elif l[1] == "view":
if len(l) == 2:
message.set_result(MessageEventResult().message("请输入人格名"))
ps = l[2].strip()
if ps in personalities:
msg = f"人格{ps}的详细信息:\n"
msg += f"{personalities[ps]}\n"
else:
msg = f"人格{ps}不存在"
message.set_result(MessageEventResult().message(msg))
else:
ps = "".join(l[1:]).strip()
if ps in personalities:
self.provider.curr_personality = Personality(name=ps, prompt=personalities[ps])
message.set_result(MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}"))
else:
self.provider.curr_personality = Personality(name="自定义人格", prompt=ps)
message.set_result(MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}"))
async def draw(self, message: AstrMessageEvent):
prompt = message.message_str.removeprefix("")
img_url = await self.provider.image_generate(prompt)
message.set_result(MessageEventResult().url_image(img_url))