Compare commits
62 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
752201cb46 | ||
|
|
deebf61b5f | ||
|
|
d5e5b06e86 | ||
|
|
cb5975c102 | ||
|
|
5b1aee1b4d | ||
|
|
510c8b4236 | ||
|
|
89fc7b0553 | ||
|
|
123c21fcb3 | ||
|
|
75d62d66f9 | ||
|
|
23a8e989a5 | ||
|
|
9577e637f1 | ||
|
|
e51ef2201b | ||
|
|
f4ae503abf | ||
|
|
3424b658f3 | ||
|
|
3198f73f3d | ||
|
|
aa3262a8ab | ||
|
|
6acd7be547 | ||
|
|
fb7669ddad | ||
|
|
f2c4ef126e | ||
|
|
33dcc4c152 | ||
|
|
b9e331ebd6 | ||
|
|
7832ec386e | ||
|
|
b9828428cc | ||
|
|
da11034aec | ||
|
|
578c9e0695 | ||
|
|
cc675a9b4f | ||
|
|
08e7d4d0c6 | ||
|
|
553f1b8d83 | ||
|
|
73e7e2088d | ||
|
|
e40c9de610 | ||
|
|
2f4e0bb4f2 | ||
|
|
191976e22e | ||
|
|
52656b8586 | ||
|
|
998e29ded6 | ||
|
|
5bbe3f12d6 | ||
|
|
56aea81ed7 | ||
|
|
7b8a311dde | ||
|
|
b75d20a3e8 | ||
|
|
67faa587b6 | ||
|
|
15fde686d4 | ||
|
|
741284f6e8 | ||
|
|
8352fc269b | ||
|
|
5852f36557 | ||
|
|
cc1c723c12 | ||
|
|
adf5cbfeba | ||
|
|
d6d0516c9a | ||
|
|
8aab10aaf3 | ||
|
|
4fe5616ae1 | ||
|
|
7e1c76a3f5 | ||
|
|
f74665ff71 | ||
|
|
a96d64fe88 | ||
|
|
fd2aa0cba6 | ||
|
|
a92ea3db02 | ||
|
|
d7a513b640 | ||
|
|
8a017ff693 | ||
|
|
7d08f57b32 | ||
|
|
6f4ad7890b | ||
|
|
37488118a6 | ||
|
|
b2da0778ae | ||
|
|
cc887a5037 | ||
|
|
ca86a02d30 | ||
|
|
d652dc19a6 |
10
README.md
10
README.md
@@ -17,7 +17,7 @@
|
||||
</div>
|
||||
|
||||
## 🤔您可能想了解的
|
||||
- **如何部署?** [帮助文档](https://github.com/Soulter/QQChannelChatGPT/wiki)
|
||||
- **如何部署?** [帮助文档](https://github.com/Soulter/QQChannelChatGPT/wiki) (部署不成功欢迎进群捞人解决<3)
|
||||
- **go-cqhttp启动不成功、报登录失败?** [在这里搜索解决方法](https://github.com/Mrs4s/go-cqhttp/issues)
|
||||
- **程序闪退/机器人启动不成功?** [提交issue或加群反馈](https://github.com/Soulter/QQChannelChatGPT/issues)
|
||||
- **如何开启ChatGPT、Bard、Claude等语言模型?** [查看帮助](https://github.com/Soulter/QQChannelChatGPT/wiki/%E8%A1%A5%E5%85%85%EF%BC%9A%E5%A6%82%E4%BD%95%E5%BC%80%E5%90%AFChatGPT%E3%80%81Bard%E3%80%81Claude%E7%AD%89%E8%AF%AD%E8%A8%80%E6%A8%A1%E5%9E%8B%EF%BC%9F)
|
||||
@@ -48,7 +48,7 @@
|
||||
- 大模型对话
|
||||
- 大模型网页搜索能力 **(目前仅支持OpenAI系的模型,最新版本下使用web on指令打开)**
|
||||
- 插件安装(在QQ或QQ频道聊天框内输入`plugin`了解详情)
|
||||
- 回复文字图片渲染(以图片markdown格式回复,降低被风控概率,需手动在`cmd_config.json`内开启)
|
||||
- 回复文字图片渲染(以图片markdown格式回复,**大幅度降低被风控概率**,需手动在`cmd_config.json`内开启qq_pic_mode)
|
||||
- 人格设置
|
||||
- 关键词回复
|
||||
- 热更新(更新本项目时**仅需**在QQ或QQ频道聊天框内输入`update latest r`)
|
||||
@@ -121,7 +121,7 @@
|
||||
|
||||
插件开发教程:https://github.com/Soulter/QQChannelChatGPT/wiki/%E5%9B%9B%E3%80%81%E5%BC%80%E5%8F%91%E6%8F%92%E4%BB%B6
|
||||
|
||||
部分公开的插件:
|
||||
部分插件:
|
||||
|
||||
- `LLMS`: https://github.com/Soulter/llms | Claude, HuggingChat 大语言模型接入。
|
||||
|
||||
@@ -129,7 +129,9 @@
|
||||
|
||||
- `sysstat`: https://github.com/Soulter/sysstatqcbot | 查看系统状态
|
||||
|
||||
- `BiliMonitor`: https://github.com/Soulter/BiliMonitor | 订阅B站动态!
|
||||
- `BiliMonitor`: https://github.com/Soulter/BiliMonitor | 订阅B站动态
|
||||
|
||||
- `liferestart`: https://github.com/Soulter/liferestart | 人生重开模拟器
|
||||
|
||||
<!--
|
||||
### 指令
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
16
cores/qqbot/global_object.py
Normal file
16
cores/qqbot/global_object.py
Normal file
@@ -0,0 +1,16 @@
|
||||
class GlobalObject():
|
||||
'''
|
||||
存放一些公用的数据,用于在不同模块(如core与command)之间传递
|
||||
'''
|
||||
def __init__(self):
|
||||
self.nick = None # gocq 的昵称
|
||||
self.base_config = None # config.yaml
|
||||
self.cached_plugins = {} # 缓存的插件
|
||||
self.web_search = False # 是否开启了网页搜索
|
||||
self.reply_prefix = None
|
||||
self.admin_qq = "123456"
|
||||
self.admin_qqchan = "123456"
|
||||
self.uniqueSession = False
|
||||
self.cnt_total = 0
|
||||
self.platform_qq = None
|
||||
self.platform_qqchan = None
|
||||
80
main.py
80
main.py
@@ -1,19 +1,32 @@
|
||||
import os, sys
|
||||
from pip._internal import main as pipmain
|
||||
import warnings
|
||||
import traceback
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/'
|
||||
|
||||
def main():
|
||||
|
||||
# config.yaml 配置文件加载和环境确认
|
||||
try:
|
||||
import cores.qqbot.core as qqBot
|
||||
import yaml
|
||||
from yaml.scanner import ScannerError
|
||||
import util.general_utils as gu
|
||||
ymlfile = open(abs_path+"configs/config.yaml", 'r', encoding='utf-8')
|
||||
cfg = yaml.safe_load(ymlfile)
|
||||
except BaseException as e:
|
||||
print(e)
|
||||
input("第三方依赖库未完全安装完毕,请退出程序重试。")
|
||||
exit()
|
||||
import util.general_utils as gu
|
||||
except ImportError as import_error:
|
||||
print(import_error)
|
||||
input("第三方库未完全安装完毕,请退出程序重试。")
|
||||
except FileNotFoundError as file_not_found:
|
||||
print(file_not_found)
|
||||
input("配置文件不存在,请检查是否已经下载配置文件。")
|
||||
except ScannerError as e:
|
||||
print(traceback.format_exc())
|
||||
input("config.yaml 配置文件格式错误,请遵守 yaml 格式。")
|
||||
|
||||
# 设置代理
|
||||
if 'http_proxy' in cfg:
|
||||
os.environ['HTTP_PROXY'] = cfg['http_proxy']
|
||||
if 'https_proxy' in cfg:
|
||||
@@ -21,21 +34,20 @@ def main():
|
||||
|
||||
os.environ['NO_PROXY'] = 'cn.bing.com,https://api.sgroup.qq.com'
|
||||
|
||||
# 检查temp文件夹
|
||||
if not os.path.exists(abs_path+"temp"):
|
||||
# 检查并创建 temp 文件夹
|
||||
if not os.path.exists(abs_path + "temp"):
|
||||
os.mkdir(abs_path+"temp")
|
||||
|
||||
# 选择默认模型
|
||||
provider = privider_chooser(cfg)
|
||||
if len(provider) == 0:
|
||||
gu.log("未开启任何语言模型, 请在configs/config.yaml下选择开启相应语言模型。", gu.LEVEL_CRITICAL)
|
||||
input("按任意键退出...")
|
||||
exit()
|
||||
gu.log("注意:您目前未开启任何语言模型。", gu.LEVEL_WARNING)
|
||||
print('[System] 开启的语言模型: ' + str(provider))
|
||||
# 执行Bot
|
||||
|
||||
# 启动主程序(cores/qqbot/core.py)
|
||||
qqBot.initBot(cfg, provider)
|
||||
|
||||
# 语言模型提供商选择器
|
||||
# 目前有:OpenAI官方API、逆向库
|
||||
def privider_chooser(cfg):
|
||||
l = []
|
||||
if 'rev_ChatGPT' in cfg and cfg['rev_ChatGPT']['enable']:
|
||||
@@ -48,55 +60,44 @@ def privider_chooser(cfg):
|
||||
l.append('openai_official')
|
||||
return l
|
||||
|
||||
def check_env():
|
||||
if not (sys.version_info.major == 3 and sys.version_info.minor >= 8):
|
||||
print("请使用Python3.8运行本项目")
|
||||
def check_env(ch_mirror=False):
|
||||
if not (sys.version_info.major == 3 and sys.version_info.minor >= 9):
|
||||
print("请使用Python3.9+运行本项目")
|
||||
input("按任意键退出...")
|
||||
exit()
|
||||
|
||||
# 检查pip
|
||||
# pip_tag = "pip"
|
||||
# mm = os.system("pip -V")
|
||||
# if mm != 0:
|
||||
# mm1 = os.system("pip3 -V")
|
||||
# if mm1 != 0:
|
||||
# print("未检测到pip, 请安装Python(版本应>=3.9)")
|
||||
# input("按任意键退出...")
|
||||
# exit()
|
||||
# else:
|
||||
# pip_tag = "pip3"
|
||||
|
||||
if os.path.exists('requirements.txt'):
|
||||
pth = 'requirements.txt'
|
||||
else:
|
||||
pth = 'QQChannelChatGPT'+ os.sep +'requirements.txt'
|
||||
print("正在更新三方依赖库...")
|
||||
print("正在检查更新第三方库...")
|
||||
try:
|
||||
pipmain(['install', '-r', pth])
|
||||
print("依赖库安装完毕。")
|
||||
if ch_mirror:
|
||||
print("使用阿里云镜像")
|
||||
pipmain(['install', '-r', pth, '-i', 'https://mirrors.aliyun.com/pypi/simple/', '--quiet'])
|
||||
else:
|
||||
pipmain(['install', '-r', pth, '--quiet'])
|
||||
except BaseException as e:
|
||||
print(e)
|
||||
while True:
|
||||
res = input("依赖库可能安装失败了。\n如果是报错ValueError: check_hostname requires server_hostname,请尝试先关闭代理后重试。\n输入y回车重试\n输入c回车使用国内镜像源下载\n输入其他按键回车继续往下执行。")
|
||||
res = input("安装失败。\n如报错ValueError: check_hostname requires server_hostname,请尝试先关闭代理后重试。\n1.输入y回车重试\n2. 输入c回车使用国内镜像源下载\n3. 输入其他按键回车继续往下执行。")
|
||||
if res == "y":
|
||||
try:
|
||||
pipmain(['install', '-r', pth])
|
||||
print("依赖库安装完毕。")
|
||||
break
|
||||
except BaseException as e:
|
||||
print(e)
|
||||
continue
|
||||
|
||||
elif res == "c":
|
||||
try:
|
||||
pipmain(['install', '-r', pth, '-i', 'https://mirrors.aliyun.com/pypi/simple/'])
|
||||
print("依赖库安装完毕。")
|
||||
break
|
||||
except BaseException as e:
|
||||
print(e)
|
||||
continue
|
||||
else:
|
||||
break
|
||||
print("第三方库检查完毕。")
|
||||
|
||||
def get_platform():
|
||||
import platform
|
||||
@@ -111,12 +112,15 @@ def get_platform():
|
||||
print("other")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
args = sys.argv
|
||||
|
||||
if '-cn' in args:
|
||||
check_env(True)
|
||||
else:
|
||||
check_env()
|
||||
|
||||
# 获取参数
|
||||
args = sys.argv
|
||||
if len(args) > 1:
|
||||
if args[1] == '-replit':
|
||||
if '-replit' in args:
|
||||
print("[System] 启动Replit Web保活服务...")
|
||||
try:
|
||||
from webapp_replit import keep_alive
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
import json
|
||||
import git.exc
|
||||
from git.repo import Repo
|
||||
|
||||
has_git = True
|
||||
try:
|
||||
import git.exc
|
||||
from git.repo import Repo
|
||||
except BaseException as e:
|
||||
print("你正运行在无Git环境下,暂时将无法使用插件、热更新功能。")
|
||||
has_git = False
|
||||
|
||||
import os
|
||||
import sys
|
||||
import requests
|
||||
@@ -18,14 +25,71 @@ from nakuru.entities.components import (
|
||||
Image
|
||||
)
|
||||
from PIL import Image as PILImage
|
||||
from cores.qqbot.global_object import GlobalObject
|
||||
from pip._internal import main as pipmain
|
||||
|
||||
PLATFORM_QQCHAN = 'qqchan'
|
||||
PLATFORM_GOCQ = 'gocq'
|
||||
|
||||
# 指令功能的基类,通用的(不区分语言模型)的指令就在这实现
|
||||
class Command:
|
||||
def __init__(self, provider: Provider):
|
||||
self.provider = Provider
|
||||
def __init__(self, provider: Provider, global_object: GlobalObject = None):
|
||||
self.provider = provider
|
||||
self.global_object = global_object
|
||||
|
||||
def check_command(self,
|
||||
message,
|
||||
session_id: str,
|
||||
role,
|
||||
platform,
|
||||
message_obj):
|
||||
# 插件
|
||||
cached_plugins = self.global_object.cached_plugins
|
||||
for k, v in cached_plugins.items():
|
||||
try:
|
||||
hit, res = v["clsobj"].run(message, role, platform, message_obj, self.global_object.platform_qq)
|
||||
if hit:
|
||||
return True, res
|
||||
except BaseException as e:
|
||||
gu.log(f"{k}插件加载出现问题,原因: {str(e)}\n已安装插件: {cached_plugins.keys}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING)
|
||||
|
||||
if self.command_start_with(message, "nick"):
|
||||
return True, self.set_nick(message, platform, role)
|
||||
|
||||
if self.command_start_with(message, "plugin"):
|
||||
return True, self.plugin_oper(message, role, cached_plugins, platform)
|
||||
|
||||
if self.command_start_with(message, "myid") or self.command_start_with(message, "!myid"):
|
||||
return True, self.get_my_id(message_obj)
|
||||
if self.command_start_with(message, "nconf") or self.command_start_with(message, "newconf"):
|
||||
return True, self.get_new_conf(message, role)
|
||||
if self.command_start_with(message, "web"): # 网页搜索
|
||||
return True, self.web_search(message)
|
||||
if self.command_start_with(message, "keyword"):
|
||||
return True, self.keyword(message_obj, role)
|
||||
|
||||
return False, None
|
||||
|
||||
def web_search(self, message):
|
||||
if message == "web on":
|
||||
self.global_object.web_search = True
|
||||
return True, "已开启网页搜索", "web"
|
||||
elif message == "web off":
|
||||
self.global_object.web_search = False
|
||||
return True, "已关闭网页搜索", "web"
|
||||
return True, f"网页搜索功能当前状态: {self.global_object.web_search}", "web"
|
||||
|
||||
def get_my_id(self, message_obj):
|
||||
return True, f"你的ID:{str(message_obj.sender.tiny_id)}", "plugin"
|
||||
|
||||
def get_new_conf(self, message, role):
|
||||
if role != "admin":
|
||||
return False, f"你的身份组{role}没有权限使用此指令。", "newconf"
|
||||
l = message.split(" ")
|
||||
if len(l) <= 1:
|
||||
obj = cc.get_all()
|
||||
p = gu.create_text_image("【cmd_config.json】", json.dumps(obj, indent=4, ensure_ascii=False))
|
||||
return True, [Image.fromFileSystem(p)], "newconf"
|
||||
|
||||
def get_plugin_modules(self):
|
||||
plugins = []
|
||||
@@ -41,79 +105,20 @@ class Command:
|
||||
except BaseException as e:
|
||||
raise e
|
||||
|
||||
def check_command(self, message, role, platform,
|
||||
message_obj,
|
||||
cached_plugins: dict,
|
||||
qq_platform: QQ,
|
||||
global_object: dict):
|
||||
# 插件
|
||||
|
||||
for k, v in cached_plugins.items():
|
||||
try:
|
||||
hit, res = v["clsobj"].run(message, role, platform, message_obj, qq_platform)
|
||||
if hit:
|
||||
return True, res
|
||||
except BaseException as e:
|
||||
gu.log(f"{k}插件加载出现问题,原因: {str(e)}\n已安装插件: {cached_plugins.keys}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING)
|
||||
|
||||
if self.command_start_with(message, "nick"):
|
||||
return True, self.set_nick(message, platform, role)
|
||||
|
||||
if self.command_start_with(message, "plugin"):
|
||||
return True, self.plugin_oper(message, role, cached_plugins, platform)
|
||||
|
||||
if self.command_start_with(message, "myid"):
|
||||
return True, self.get_my_id(message_obj, platform)
|
||||
if self.command_start_with(message, "nconf") or self.command_start_with(message, "newconf"):
|
||||
return True, self.get_new_conf(message, role, platform)
|
||||
if self.command_start_with(message, "web"): # 网页搜索
|
||||
return True, self.web_search(message, global_object)
|
||||
|
||||
return False, None
|
||||
|
||||
def web_search(self, message, global_object):
|
||||
if "web_search" not in global_object:
|
||||
global_object["web_search"] = False
|
||||
if message == "web on":
|
||||
global_object["web_search"] = True
|
||||
return True, "已开启网页搜索", "web"
|
||||
elif message == "web off":
|
||||
global_object["web_search"] = False
|
||||
return True, "已关闭网页搜索", "web"
|
||||
return True, f"网页搜索功能当前状态: {global_object['web_search']}", "web"
|
||||
def get_my_id(self, message_obj, platform):
|
||||
print(message_obj)
|
||||
if platform == "gocq":
|
||||
if message_obj.type == "GuildMessage":
|
||||
return True, f"你的频道id是{str(message_obj.sender.tiny_id)}", "plugin"
|
||||
else:
|
||||
return True, f"你的QQ是{str(message_obj.sender.user_id)}", "plugin"
|
||||
else:
|
||||
return True, f"{str(message_obj)}\n(此指令为开发专用,为提供更多数据,请自行从中找出您的频道ID。在author->id中。)", "plugin"
|
||||
|
||||
def get_new_conf(self, message, role, platform):
|
||||
if role != "admin":
|
||||
return False, f"你的身份组{role}没有权限使用此指令。", "newconf"
|
||||
if platform == gu.PLATFORM_GOCQ:
|
||||
l = message.split(" ")
|
||||
if len(l) <= 1:
|
||||
obj = cc.get_all()
|
||||
p = gu.create_text_image("【cmd_config.json】", json.dumps(obj, indent=4, ensure_ascii=False))
|
||||
return True, [Image.fromFileSystem(p)], "newconf"
|
||||
return False, f"Not support or not implemented.", "newconf"
|
||||
|
||||
|
||||
|
||||
def plugin_reload(self, cached_plugins: dict, target: str = None, all: bool = False):
|
||||
plugins = self.get_plugin_modules()
|
||||
fail_rec = ""
|
||||
if plugins is None:
|
||||
return False, "未找到任何插件模块"
|
||||
|
||||
for p in plugins:
|
||||
print(plugins)
|
||||
|
||||
for plugin in plugins:
|
||||
try:
|
||||
p = plugin['module']
|
||||
root_dir_name = plugin['pname']
|
||||
if p not in cached_plugins or p == target or all:
|
||||
module = __import__("addons.plugins." + p + "." + p, fromlist=[p])
|
||||
module = __import__("addons.plugins." + root_dir_name + "." + p, fromlist=[p])
|
||||
if p in cached_plugins:
|
||||
module = importlib.reload(module)
|
||||
cls = putil.get_classes(p, module)
|
||||
@@ -129,13 +134,15 @@ class Command:
|
||||
except BaseException as e:
|
||||
fail_rec += f"调用插件{p} info失败, 原因: {str(e)}\n"
|
||||
continue
|
||||
cached_plugins[p] = {
|
||||
cached_plugins[info['name']] = {
|
||||
"module": module,
|
||||
"clsobj": obj,
|
||||
"info": info
|
||||
"info": info,
|
||||
"name": info['name'],
|
||||
"root_dir_name": root_dir_name,
|
||||
}
|
||||
except BaseException as e:
|
||||
fail_rec += f"加载{p}插件出现问题,原因{str(e)}\n"
|
||||
fail_rec += f"加载{p}插件出现问题,原因 {str(e)}\n"
|
||||
if fail_rec == "":
|
||||
return True, None
|
||||
else:
|
||||
@@ -145,12 +152,12 @@ class Command:
|
||||
插件指令
|
||||
'''
|
||||
def plugin_oper(self, message: str, role: str, cached_plugins: dict, platform: str):
|
||||
if not has_git:
|
||||
return False, "你正在运行在无Git环境下,暂时将无法使用插件、热更新功能。", "plugin"
|
||||
l = message.split(" ")
|
||||
if len(l) < 2:
|
||||
if platform == gu.PLATFORM_GOCQ:
|
||||
p = gu.create_text_image("【插件指令面板】", "安装插件: \nplugin i 插件Github地址\n卸载插件: \nplugin d 插件名 \n重载插件: \nplugin reload\n查看插件列表:\nplugin l\n更新插件: plugin u 插件名\n")
|
||||
return True, [Image.fromFileSystem(p)], "plugin"
|
||||
return True, "\n=====插件指令面板=====\n安装插件: \nplugin i 插件Github地址\n卸载插件: \nplugin d 插件名 \n重载插件: \nplugin reload\n查看插件列表:\nplugin l\n更新插件: plugin u 插件名\n===============", "plugin"
|
||||
else:
|
||||
ppath = ""
|
||||
if os.path.exists("addons/plugins"):
|
||||
@@ -163,8 +170,13 @@ class Command:
|
||||
if role != "admin":
|
||||
return False, f"你的身份组{role}没有权限安装插件", "plugin"
|
||||
try:
|
||||
# 删除末尾的/
|
||||
if l[2].endswith("/"):
|
||||
l[2] = l[2][:-1]
|
||||
# 得到url的最后一段
|
||||
d = l[2].split("/")[-1]
|
||||
# 转换非法字符:-
|
||||
d = d.replace("-", "_")
|
||||
# 创建文件夹
|
||||
plugin_path = os.path.join(ppath, d)
|
||||
if os.path.exists(plugin_path):
|
||||
@@ -174,9 +186,7 @@ class Command:
|
||||
|
||||
# 读取插件的requirements.txt
|
||||
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
|
||||
with open(os.path.join(plugin_path, "requirements.txt"), "r", encoding="utf-8") as f:
|
||||
for line in f.readlines():
|
||||
mm = os.system(f"pip3 install {line.strip()}")
|
||||
mm = pipmain(['install', '-r', os.path.join(plugin_path, "requirements.txt")])
|
||||
if mm != 0:
|
||||
return False, "插件依赖安装失败,需要您手动pip安装对应插件的依赖。", "plugin"
|
||||
# 加载没缓存的插件
|
||||
@@ -192,26 +202,27 @@ class Command:
|
||||
elif l[1] == "d":
|
||||
if role != "admin":
|
||||
return False, f"你的身份组{role}没有权限删除插件", "plugin"
|
||||
if l[2] not in cached_plugins:
|
||||
return False, "未找到该插件", "plugin"
|
||||
|
||||
try:
|
||||
# 删除文件夹
|
||||
# shutil.rmtree(os.path.join(ppath, l[2]))
|
||||
self.remove_dir(os.path.join(ppath, l[2]))
|
||||
if l[2] in cached_plugins:
|
||||
root_dir_name = cached_plugins[l[2]]["root_dir_name"]
|
||||
self.remove_dir(os.path.join(ppath, root_dir_name))
|
||||
del cached_plugins[l[2]]
|
||||
return True, "插件卸载成功~", "plugin"
|
||||
except BaseException as e:
|
||||
return False, f"卸载插件失败,原因: {str(e)}", "plugin"
|
||||
elif l[1] == "u":
|
||||
plugin_path = os.path.join(ppath, l[2])
|
||||
if l[2] not in cached_plugins:
|
||||
return False, "未找到该插件", "plugin"
|
||||
root_dir_name = cached_plugins[l[2]]["root_dir_name"]
|
||||
plugin_path = os.path.join(ppath, root_dir_name)
|
||||
try:
|
||||
repo = Repo(path = plugin_path)
|
||||
repo.remotes.origin.pull()
|
||||
|
||||
# 读取插件的requirements.txt
|
||||
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
|
||||
with open(os.path.join(plugin_path, "requirements.txt"), "r", encoding="utf-8") as f:
|
||||
for line in f.readlines():
|
||||
mm = os.system(f"pip3 install {line.strip()}")
|
||||
mm = pipmain(['install', '-r', os.path.join(plugin_path, "requirements.txt")])
|
||||
if mm != 0:
|
||||
return False, "插件依赖安装失败,需要您手动pip安装对应插件的依赖。", "plugin"
|
||||
|
||||
@@ -226,21 +237,16 @@ class Command:
|
||||
elif l[1] == "l":
|
||||
try:
|
||||
plugin_list_info = "\n".join([f"{k}: \n名称: {v['info']['name']}\n简介: {v['info']['desc']}\n版本: {v['info']['version']}\n作者: {v['info']['author']}\n" for k, v in cached_plugins.items()])
|
||||
if platform == gu.PLATFORM_GOCQ:
|
||||
p = gu.create_text_image("【已激活插件列表】", plugin_list_info + "\n使用plugin v 插件名 查看插件帮助\n")
|
||||
return True, [Image.fromFileSystem(p)], "plugin"
|
||||
return True, "\n=====已激活插件列表=====\n" + plugin_list_info + "\n使用plugin v 插件名 查看插件帮助\n=================", "plugin"
|
||||
except BaseException as e:
|
||||
return False, f"获取插件列表失败,原因: {str(e)}", "plugin"
|
||||
elif l[1] == "v":
|
||||
try:
|
||||
if l[2] in cached_plugins:
|
||||
info = cached_plugins[l[2]]["info"]
|
||||
if platform == gu.PLATFORM_GOCQ:
|
||||
p = gu.create_text_image(f"【插件信息】", f"名称: {info['name']}\n{info['desc']}\n版本: {info['version']}\n作者: {info['author']}\n\n帮助:\n{info['help']}")
|
||||
return True, [Image.fromFileSystem(p)], "plugin"
|
||||
res = f"\n=====插件信息=====\n名称: {info['name']}\n{info['desc']}\n版本: {info['version']}作者: {info['author']}\n\n帮助:\n{info['help']}"
|
||||
return True, res, "plugin"
|
||||
else:
|
||||
return False, "未找到该插件", "plugin"
|
||||
except BaseException as e:
|
||||
@@ -248,6 +254,16 @@ class Command:
|
||||
elif l[1] == "reload":
|
||||
if role != "admin":
|
||||
return False, f"你的身份组{role}没有权限重载插件", "plugin"
|
||||
for plugin in cached_plugins:
|
||||
try:
|
||||
print(f"更新插件 {plugin} 依赖...")
|
||||
plugin_path = os.path.join(ppath, cached_plugins[plugin]["root_dir_name"])
|
||||
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
|
||||
mm = pipmain(['install', '-r', os.path.join(plugin_path, "requirements.txt"), "--quiet"])
|
||||
if mm != 0:
|
||||
return False, "插件依赖安装失败,需要您手动pip安装对应插件的依赖。", "plugin"
|
||||
except BaseException as e:
|
||||
print(f"插件{plugin}依赖安装失败,原因: {str(e)}")
|
||||
try:
|
||||
ok, err = self.plugin_reload(cached_plugins, all = True)
|
||||
if ok:
|
||||
@@ -264,7 +280,6 @@ class Command:
|
||||
return False, f"你的身份组{role}没有权限开发者模式", "plugin"
|
||||
return True, "cached_plugins: \n" + str(cached_plugins), "plugin"
|
||||
|
||||
|
||||
def remove_dir(self, file_path):
|
||||
while 1:
|
||||
if not os.path.exists(file_path):
|
||||
@@ -276,7 +291,6 @@ class Command:
|
||||
if os.path.exists(err_file_path):
|
||||
os.chmod(err_file_path, stat.S_IWUSR)
|
||||
|
||||
|
||||
'''
|
||||
nick: 存储机器人的昵称
|
||||
'''
|
||||
@@ -288,27 +302,13 @@ class Command:
|
||||
if len(l) == 1:
|
||||
return True, "【设置机器人昵称】示例:\n支持多昵称\nnick 昵称1 昵称2 昵称3", "nick"
|
||||
nick = l[1:]
|
||||
self.general_command_storer("nick_qq", nick)
|
||||
cc.put("nick_qq", nick)
|
||||
self.global_object.nick = tuple(nick)
|
||||
return True, f"设置成功!现在你可以叫我这些昵称来提问我啦~", "nick"
|
||||
elif platform == PLATFORM_QQCHAN:
|
||||
nick = message.split(" ")[2]
|
||||
return False, "QQ频道平台不支持为机器人设置昵称。", "nick"
|
||||
|
||||
"""
|
||||
存储指令结果到cmd_config.json
|
||||
"""
|
||||
def general_command_storer(self, key, value):
|
||||
if not os.path.exists("cmd_config.json"):
|
||||
config = {}
|
||||
else:
|
||||
with open("cmd_config.json", "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
config[key] = value
|
||||
with open("cmd_config.json", "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, indent=4, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
|
||||
def general_commands(self):
|
||||
return {
|
||||
"help": "帮助",
|
||||
@@ -343,14 +343,13 @@ class Command:
|
||||
msg += plugin_list_info
|
||||
msg += notice
|
||||
|
||||
if platform == gu.PLATFORM_GOCQ:
|
||||
try:
|
||||
# p = gu.create_text_image("【Help Center】", msg)
|
||||
p = gu.create_markdown_image(msg)
|
||||
return [Image.fromFileSystem(p)]
|
||||
except BaseException as e:
|
||||
gu.log(str(e))
|
||||
return msg
|
||||
finally:
|
||||
return msg
|
||||
|
||||
# 接受可变参数
|
||||
@@ -361,14 +360,36 @@ class Command:
|
||||
return False
|
||||
|
||||
# keyword: 关键字
|
||||
def keyword(self, message: str, role: str):
|
||||
def keyword(self, message_obj, role: str):
|
||||
if role != "admin":
|
||||
return True, "你没有权限使用该指令", "keyword"
|
||||
|
||||
l = message.split(" ")
|
||||
plain_text = ""
|
||||
image_url = ""
|
||||
|
||||
if len(l) < 3:
|
||||
return True, "【设置关键词回复】示例:\nkeyword hi 你好\n当发送hi的时候会回复你好\nkeyword /hi 你好\n当发送/hi时会回复你好\n删除关键词: keyword d hi\n删除hi关键词的回复", "keyword"
|
||||
for comp in message_obj.message:
|
||||
if isinstance(comp, Plain):
|
||||
plain_text += comp.text
|
||||
elif isinstance(comp, Image) and image_url == "":
|
||||
if comp.url is None:
|
||||
image_url = comp.file
|
||||
else:
|
||||
image_url = comp.url
|
||||
|
||||
l = plain_text.split(" ")
|
||||
|
||||
if len(l) < 3 and image_url == "":
|
||||
return True, """
|
||||
【设置关键词回复】示例:
|
||||
1. keyword hi 你好
|
||||
当发送hi的时候会回复你好
|
||||
2. keyword /hi 你好
|
||||
当发送/hi时会回复你好
|
||||
3. keyword d hi
|
||||
删除hi关键词的回复
|
||||
4. keyword hi <图片>
|
||||
当发送hi时会回复图片
|
||||
""", "keyword"
|
||||
|
||||
del_mode = False
|
||||
if l[1] == "d":
|
||||
@@ -384,21 +405,34 @@ class Command:
|
||||
return False, "该关键词不存在", "keyword"
|
||||
else: del keyword[l[2]]
|
||||
else:
|
||||
keyword[l[1]] = l[2]
|
||||
keyword[l[1]] = {
|
||||
"plain_text": " ".join(l[2:]),
|
||||
"image_url": image_url
|
||||
}
|
||||
else:
|
||||
if del_mode:
|
||||
return False, "该关键词不存在", "keyword"
|
||||
keyword = {l[1]: l[2]}
|
||||
keyword = {
|
||||
l[1]: {
|
||||
"plain_text": " ".join(l[2:]),
|
||||
"image_url": image_url
|
||||
}
|
||||
}
|
||||
with open("keyword.json", "w", encoding="utf-8") as f:
|
||||
json.dump(keyword, f, ensure_ascii=False, indent=4)
|
||||
f.flush()
|
||||
if del_mode:
|
||||
return True, "删除成功: "+l[2], "keyword"
|
||||
return True, "设置成功: "+l[1]+" -> "+l[2], "keyword"
|
||||
if image_url == "":
|
||||
return True, "设置成功: "+l[1]+" "+" ".join(l[2:]), "keyword"
|
||||
else:
|
||||
return True, [Plain("设置成功: "+l[1]+" "+" ".join(l[2:])), Image.fromURL(image_url)], "keyword"
|
||||
except BaseException as e:
|
||||
return False, "设置失败: "+str(e), "keyword"
|
||||
|
||||
def update(self, message: str, role: str):
|
||||
if not has_git:
|
||||
return False, "你正在运行在无Git环境下,暂时将无法使用插件、热更新功能。", "update"
|
||||
if role != "admin":
|
||||
return True, "你没有权限使用该指令", "keyword"
|
||||
l = message.split(" ")
|
||||
@@ -436,11 +470,19 @@ class Command:
|
||||
pash_tag = "QQChannelChatGPT"+os.sep
|
||||
repo.remotes.origin.pull()
|
||||
|
||||
if len(l) == 3 and l[2] == "r":
|
||||
py = sys.executable
|
||||
os.execl(py, py, *sys.argv)
|
||||
try:
|
||||
origin = repo.remotes.origin
|
||||
origin.fetch()
|
||||
commits = list(repo.iter_commits('master', max_count=1))
|
||||
commit_log = commits[0].message
|
||||
except BaseException as e:
|
||||
commit_log = "无法获取commit信息"
|
||||
|
||||
return True, "更新成功~是否重启?输入update r重启(重启指令不返回任何确认信息)。", "update"
|
||||
tag = "update"
|
||||
if len(l) == 3 and l[2] == "r":
|
||||
tag = "update latest r"
|
||||
|
||||
return True, f"更新成功。新版本内容: \n{commit_log}\nps:重启后生效。输入update r重启(重启指令不返回任何确认信息)。", tag
|
||||
|
||||
except BaseException as e:
|
||||
return False, "更新失败: "+str(e), "update"
|
||||
@@ -448,7 +490,6 @@ class Command:
|
||||
py = sys.executable
|
||||
os.execl(py, py, *sys.argv)
|
||||
|
||||
|
||||
def reset(self):
|
||||
return False
|
||||
|
||||
|
||||
@@ -1,36 +1,40 @@
|
||||
from model.command.command import Command
|
||||
from model.provider.provider_openai_official import ProviderOpenAIOfficial
|
||||
from cores.qqbot.personality import personalities
|
||||
|
||||
from model.platform.qq import QQ
|
||||
from util import general_utils as gu
|
||||
|
||||
from cores.qqbot.global_object import GlobalObject
|
||||
|
||||
class CommandOpenAIOfficial(Command):
|
||||
def __init__(self, provider: ProviderOpenAIOfficial, global_object: dict):
|
||||
def __init__(self, provider: ProviderOpenAIOfficial, global_object: GlobalObject):
|
||||
self.provider = provider
|
||||
self.cached_plugins = {}
|
||||
self.global_object = global_object
|
||||
self.personality_str = ""
|
||||
super().__init__(provider, global_object)
|
||||
|
||||
def check_command(self,
|
||||
message: str,
|
||||
session_id: str,
|
||||
user_name: str,
|
||||
role: str,
|
||||
platform: str,
|
||||
message_obj,
|
||||
cached_plugins: dict,
|
||||
qq_platform: QQ,):
|
||||
message_obj):
|
||||
self.platform = platform
|
||||
hit, res = super().check_command(message, role, platform, message_obj=message_obj,
|
||||
cached_plugins=cached_plugins,
|
||||
qq_platform=qq_platform,
|
||||
global_object=self.global_object)
|
||||
hit, res = super().check_command(
|
||||
message,
|
||||
session_id,
|
||||
role,
|
||||
platform,
|
||||
message_obj
|
||||
)
|
||||
|
||||
if hit:
|
||||
return True, res
|
||||
if self.command_start_with(message, "reset", "重置"):
|
||||
return True, self.reset(session_id)
|
||||
return True, self.reset(session_id, message)
|
||||
elif self.command_start_with(message, "his", "历史"):
|
||||
return True, self.his(message, session_id, user_name)
|
||||
return True, self.his(message, session_id)
|
||||
elif self.command_start_with(message, "token"):
|
||||
return True, self.token(session_id)
|
||||
elif self.command_start_with(message, "gpt"):
|
||||
@@ -40,7 +44,7 @@ class CommandOpenAIOfficial(Command):
|
||||
elif self.command_start_with(message, "count"):
|
||||
return True, self.count()
|
||||
elif self.command_start_with(message, "help", "帮助"):
|
||||
return True, self.help(cached_plugins)
|
||||
return True, self.help()
|
||||
elif self.command_start_with(message, "unset"):
|
||||
return True, self.unset(session_id)
|
||||
elif self.command_start_with(message, "set"):
|
||||
@@ -49,17 +53,14 @@ class CommandOpenAIOfficial(Command):
|
||||
return True, self.update(message, role)
|
||||
elif self.command_start_with(message, "画", "draw"):
|
||||
return True, self.draw(message)
|
||||
elif self.command_start_with(message, "keyword"):
|
||||
return True, self.keyword(message, role)
|
||||
elif self.command_start_with(message, "key"):
|
||||
return True, self.key(message, user_name)
|
||||
|
||||
if self.command_start_with(message, "/"):
|
||||
return True, (False, "未知指令", "unknown_command")
|
||||
return True, self.key(message)
|
||||
elif self.command_start_with(message, "switch"):
|
||||
return True, self.switch(message)
|
||||
|
||||
return False, None
|
||||
|
||||
def help(self, cached_plugins):
|
||||
def help(self):
|
||||
commands = super().general_commands()
|
||||
commands['画'] = '画画'
|
||||
commands['key'] = '添加OpenAI key'
|
||||
@@ -67,16 +68,23 @@ class CommandOpenAIOfficial(Command):
|
||||
commands['gpt'] = '查看gpt配置信息'
|
||||
commands['status'] = '查看key使用状态'
|
||||
commands['token'] = '查看本轮会话token'
|
||||
return True, super().help_messager(commands, self.platform, cached_plugins), "help"
|
||||
return True, super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help"
|
||||
|
||||
|
||||
def reset(self, session_id: str):
|
||||
def reset(self, session_id: str, message: str = "reset"):
|
||||
if self.provider is None:
|
||||
return False, "未启动OpenAI ChatGPT语言模型.", "reset"
|
||||
l = message.split(" ")
|
||||
if len(l) == 1:
|
||||
self.provider.forget(session_id)
|
||||
return True, "重置成功", "reset"
|
||||
if len(l) == 2 and l[1] == "p":
|
||||
self.provider.forget(session_id)
|
||||
if self.personality_str != "":
|
||||
self.set(self.personality_str, session_id) # 重新设置人格
|
||||
return True, "重置成功", "reset"
|
||||
|
||||
def his(self, message: str, session_id: str, name: str):
|
||||
def his(self, message: str, session_id: str):
|
||||
if self.provider is None:
|
||||
return False, "未启动OpenAI ChatGPT语言模型.", "his"
|
||||
#分页,每页5条
|
||||
@@ -122,17 +130,17 @@ class CommandOpenAIOfficial(Command):
|
||||
continue
|
||||
if 'sponsor' in key_stat[key]:
|
||||
sponsor = key_stat[key]['sponsor']
|
||||
chatgpt_cfg_str += f" |-{index}: {key_stat[key]['used']}/{max} {sponsor}赞助{tag}\n"
|
||||
chatgpt_cfg_str += f" |-{index}: {key[-8:]} {key_stat[key]['used']}/{max} {sponsor}{tag}\n"
|
||||
index += 1
|
||||
return True, f"⭐使用情况({str(gg_count)}个已用):\n{chatgpt_cfg_str}⏰全频道已用{total}tokens", "status"
|
||||
return True, f"⭐使用情况({str(gg_count)}个已用):\n{chatgpt_cfg_str}", "status"
|
||||
|
||||
def count(self):
|
||||
if self.provider is None:
|
||||
return False, "未启动OpenAI ChatGPT语言模型.", "reset"
|
||||
return False, "未启动OpenAI ChatGPT语言模型。", "reset"
|
||||
guild_count, guild_msg_count, guild_direct_msg_count, session_count = self.provider.get_stat()
|
||||
return True, f"当前会话数: {len(self.provider.session_dict)}\n共有频道数: {guild_count} \n共有消息数: {guild_msg_count}\n私信数: {guild_direct_msg_count}\n历史会话数: {session_count}", "count"
|
||||
return True, f"【本指令部分统计可能已经过时】\n当前会话数: {len(self.provider.session_dict)}\n共有频道数: {guild_count} \n共有消息数: {guild_msg_count}\n私信数: {guild_direct_msg_count}\n历史会话数: {session_count}", "count"
|
||||
|
||||
def key(self, message: str, user_name: str):
|
||||
def key(self, message: str):
|
||||
if self.provider is None:
|
||||
return False, "未启动OpenAI ChatGPT语言模型.", "reset"
|
||||
l = message.split(" ")
|
||||
@@ -141,11 +149,41 @@ class CommandOpenAIOfficial(Command):
|
||||
return True, msg, "key"
|
||||
key = l[1]
|
||||
if self.provider.check_key(key):
|
||||
self.provider.append_key(key, user_name)
|
||||
return True, f"*★,°*:.☆( ̄▽ ̄)/$:*.°★* 。\n该Key被验证为有效。感谢{user_name}赞助~"
|
||||
self.provider.append_key(key)
|
||||
return True, f"*★,°*:.☆( ̄▽ ̄)/$:*.°★* 。\n该Key被验证为有效。感谢你的赞助~"
|
||||
else:
|
||||
return True, "该Key被验证为无效。也许是输入错误了,或者重试。", "key"
|
||||
|
||||
def switch(self, message: str):
|
||||
'''
|
||||
切换账号
|
||||
'''
|
||||
l = message.split(" ")
|
||||
if len(l) == 1:
|
||||
_, ret, _ = self.status()
|
||||
curr_ = self.provider.get_curr_key()
|
||||
if curr_ is None:
|
||||
ret += "当前您未选择账号。输入/switch <账号序号>切换账号。"
|
||||
else:
|
||||
ret += f"当前您选择的账号为:{curr_[-8:]}。输入/switch <账号序号>切换账号。"
|
||||
return True, ret, "switch"
|
||||
elif len(l) == 2:
|
||||
try:
|
||||
key_stat = self.provider.get_key_stat()
|
||||
index = int(l[1])
|
||||
if index > len(key_stat) or index < 1:
|
||||
return True, "账号序号不合法。", "switch"
|
||||
else:
|
||||
ret = self.provider.check_key(list(key_stat.keys())[index-1])
|
||||
if ret:
|
||||
return True, f"账号切换成功。", "switch"
|
||||
else:
|
||||
return True, f"账号切换失败,可能超额或超频。", "switch"
|
||||
except BaseException as e:
|
||||
return True, "未知错误: "+str(e), "switch"
|
||||
else:
|
||||
return True, "参数过多。", "switch"
|
||||
|
||||
def unset(self, session_id: str):
|
||||
if self.provider is None:
|
||||
return False, "未启动OpenAI ChatGPT语言模型.", "unset"
|
||||
@@ -158,9 +196,9 @@ class CommandOpenAIOfficial(Command):
|
||||
return False, "未启动OpenAI ChatGPT语言模型.", "set"
|
||||
l = message.split(" ")
|
||||
if len(l) == 1:
|
||||
return True, f"【由Github项目QQChannelChatGPT支持】\n\n【人格文本由PlexPt开源项目awesome-chatgpt-pr \
|
||||
ompts-zh提供】\n\n这个是人格设置指令。\n设置人格: \n/set 人格名。例如/set 编剧\n人格列表: /set list\n人格详细信息: \
|
||||
/set view 人格名\n自定义人格: /set 人格文本\n清除人格: /unset\n【当前人格】: {str(self.provider.now_personality)}", "set"
|
||||
return True, f"【人格文本由PlexPt开源项目awesome-chatgpt-pr \
|
||||
ompts-zh提供】\n设置人格: \n/set 人格名。例如/set 编剧\n人格列表: /set list\n人格详细信息: \
|
||||
/set view 人格名\n自定义人格: /set 人格文本\n重置会话(清除人格): /reset\n重置会话(保留人格): /reset p\n【当前人格】: {str(self.provider.now_personality)}", "set"
|
||||
elif l[1] == "list":
|
||||
msg = "人格列表:\n"
|
||||
for key in personalities.keys():
|
||||
@@ -188,14 +226,20 @@ class CommandOpenAIOfficial(Command):
|
||||
self.provider.session_dict[session_id] = []
|
||||
new_record = {
|
||||
"user": {
|
||||
"role": "system",
|
||||
"role": "user",
|
||||
"content": personalities[ps],
|
||||
},
|
||||
"AI": {
|
||||
"role": "assistant",
|
||||
"content": "好的,接下来我会扮演这个角色。"
|
||||
},
|
||||
'type': "personality",
|
||||
'usage_tokens': 0,
|
||||
'single-tokens': 0
|
||||
}
|
||||
self.provider.session_dict[session_id].append(new_record)
|
||||
return True, f"人格{ps}已设置.", "set"
|
||||
self.personality_str = message
|
||||
return True, f"人格{ps}已设置。", "set"
|
||||
else:
|
||||
self.provider.now_personality = {
|
||||
'name': '自定义人格',
|
||||
@@ -203,14 +247,20 @@ class CommandOpenAIOfficial(Command):
|
||||
}
|
||||
new_record = {
|
||||
"user": {
|
||||
"role": "system",
|
||||
"role": "user",
|
||||
"content": ps,
|
||||
},
|
||||
"AI": {
|
||||
"role": "assistant",
|
||||
"content": "好的,接下来我会扮演这个角色。"
|
||||
},
|
||||
'type': "personality",
|
||||
'usage_tokens': 0,
|
||||
'single-tokens': 0
|
||||
}
|
||||
self.provider.session_dict[session_id] = []
|
||||
self.provider.session_dict[session_id].append(new_record)
|
||||
self.personality_str = message
|
||||
return True, f"自定义人格已设置。 \n人格信息: {ps}", "set"
|
||||
|
||||
def draw(self, message):
|
||||
|
||||
@@ -1,43 +1,134 @@
|
||||
from model.command.command import Command
|
||||
from model.provider.provider_rev_chatgpt import ProviderRevChatGPT
|
||||
from model.platform.qq import QQ
|
||||
from cores.qqbot.personality import personalities
|
||||
from cores.qqbot.global_object import GlobalObject
|
||||
|
||||
class CommandRevChatGPT(Command):
|
||||
def __init__(self, provider: ProviderRevChatGPT, global_object: dict):
|
||||
def __init__(self, provider: ProviderRevChatGPT, global_object: GlobalObject):
|
||||
self.provider = provider
|
||||
self.cached_plugins = {}
|
||||
self.global_object = global_object
|
||||
self.personality_str = ""
|
||||
super().__init__(provider, global_object)
|
||||
|
||||
def check_command(self,
|
||||
message: str,
|
||||
session_id: str,
|
||||
role: str,
|
||||
platform: str,
|
||||
message_obj,
|
||||
cached_plugins: dict,
|
||||
qq_platform: QQ):
|
||||
message_obj):
|
||||
self.platform = platform
|
||||
hit, res = super().check_command(message, role, platform, message_obj=message_obj,
|
||||
cached_plugins=cached_plugins,
|
||||
qq_platform=qq_platform,
|
||||
global_object=self.global_object)
|
||||
hit, res = super().check_command(
|
||||
message,
|
||||
session_id,
|
||||
role,
|
||||
platform,
|
||||
message_obj
|
||||
)
|
||||
|
||||
if hit:
|
||||
return True, res
|
||||
if self.command_start_with(message, "help", "帮助"):
|
||||
return True, self.help(cached_plugins)
|
||||
return True, self.help()
|
||||
elif self.command_start_with(message, "reset"):
|
||||
return True, self.reset()
|
||||
return True, self.reset(session_id, message)
|
||||
elif self.command_start_with(message, "update"):
|
||||
return True, self.update(message, role)
|
||||
elif self.command_start_with(message, "keyword"):
|
||||
return True, self.keyword(message, role)
|
||||
|
||||
if self.command_start_with(message, "/"):
|
||||
return True, (False, "未知指令", "unknown_command")
|
||||
elif self.command_start_with(message, "set"):
|
||||
return True, self.set(message, session_id)
|
||||
elif self.command_start_with(message, "switch"):
|
||||
return True, self.switch(message, session_id)
|
||||
return False, None
|
||||
|
||||
def reset(self):
|
||||
return False, "此功能暂未开放", "reset"
|
||||
def reset(self, session_id, message: str):
|
||||
l = message.split(" ")
|
||||
if len(l) == 1:
|
||||
self.provider.forget(session_id)
|
||||
return True, "重置完毕。", "reset"
|
||||
if len(l) == 2 and l[1] == "p":
|
||||
self.provider.forget(session_id)
|
||||
ret = self.provider.text_chat(self.personality_str)
|
||||
return True, f"重置完毕(保留人格)。\n\n{ret}", "reset"
|
||||
|
||||
def set(self, message: str, session_id: str):
|
||||
l = message.split(" ")
|
||||
if len(l) == 1:
|
||||
return True, f"设置人格: \n/set 人格名或人格文本。例如/set 编剧\n人格列表: /set list\n人格详细信息: \
|
||||
/set view 人格名\n重置会话(清除人格): /reset\n重置会话(保留人格): /reset p", "set"
|
||||
elif l[1] == "list":
|
||||
msg = "人格列表:\n"
|
||||
for key in personalities.keys():
|
||||
msg += f" |-{key}\n"
|
||||
msg += '\n\n*输入/set view 人格名查看人格详细信息'
|
||||
msg += '\n*不定时更新人格库,请及时更新本项目。'
|
||||
return True, msg, "set"
|
||||
elif l[1] == "view":
|
||||
if len(l) == 2:
|
||||
return True, "请输入/set view 人格名", "set"
|
||||
ps = l[2].strip()
|
||||
if ps in personalities:
|
||||
msg = f"人格【{ps}】详细信息:\n"
|
||||
msg += f"{personalities[ps]}\n"
|
||||
else:
|
||||
msg = f"人格【{ps}】不存在。"
|
||||
return True, msg, "set"
|
||||
else:
|
||||
ps = l[1].strip()
|
||||
if ps in personalities:
|
||||
self.reset(session_id, "reset")
|
||||
self.personality_str = personalities[ps]
|
||||
ret = self.provider.text_chat(self.personality_str, session_id)
|
||||
return True, f"人格【{ps}】已设置。\n\n{ret}", "set"
|
||||
else:
|
||||
self.reset(session_id, "reset")
|
||||
self.personality_str = ps
|
||||
ret = self.provider.text_chat(ps, session_id)
|
||||
return True, f"人格信息已设置。\n\n{ret}", "set"
|
||||
|
||||
def help(self, cached_plugins: dict):
|
||||
return True, super().help_messager(super().general_commands(), self.platform, cached_plugins), "help"
|
||||
def switch(self, message: str, session_id: str):
|
||||
'''
|
||||
切换账号
|
||||
'''
|
||||
l = message.split(" ")
|
||||
rev_chatgpt = self.provider.get_revchatgpt()
|
||||
if len(l) == 1:
|
||||
ret = "当前账号:\n"
|
||||
index = 0
|
||||
curr_ = None
|
||||
for revstat in rev_chatgpt:
|
||||
index += 1
|
||||
ret += f"[{index}]. {revstat['id']}\n"
|
||||
# if session_id in revstat['user']:
|
||||
# curr_ = revstat['id']
|
||||
for user in revstat['user']:
|
||||
if session_id == user['id']:
|
||||
curr_ = revstat['id']
|
||||
break
|
||||
if curr_ is None:
|
||||
ret += "当前您未选择账号。输入/switch <账号序号>切换账号。"
|
||||
else:
|
||||
ret += f"当前您选择的账号为:{curr_}。输入/switch <账号序号>切换账号。"
|
||||
return True, ret, "switch"
|
||||
elif len(l) == 2:
|
||||
try:
|
||||
index = int(l[1])
|
||||
if index > len(self.provider.rev_chatgpt) or index < 1:
|
||||
return True, "账号序号不合法。", "switch"
|
||||
else:
|
||||
# pop
|
||||
for revstat in self.provider.rev_chatgpt:
|
||||
if session_id in revstat['user']:
|
||||
revstat['user'].remove(session_id)
|
||||
# append
|
||||
self.provider.rev_chatgpt[index - 1]['user'].append(session_id)
|
||||
return True, f"切换账号成功。当前账号为:{self.provider.rev_chatgpt[index - 1]['id']}", "switch"
|
||||
except BaseException:
|
||||
return True, "账号序号不合法。", "switch"
|
||||
else:
|
||||
return True, "参数过多。", "switch"
|
||||
|
||||
def help(self):
|
||||
commands = super().general_commands()
|
||||
commands['set'] = '设置人格'
|
||||
return True, super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help"
|
||||
|
||||
@@ -2,42 +2,43 @@ from model.command.command import Command
|
||||
from model.provider.provider_rev_edgegpt import ProviderRevEdgeGPT
|
||||
import asyncio
|
||||
from model.platform.qq import QQ
|
||||
from cores.qqbot.global_object import GlobalObject
|
||||
|
||||
class CommandRevEdgeGPT(Command):
|
||||
def __init__(self, provider: ProviderRevEdgeGPT, global_object: dict):
|
||||
def __init__(self, provider: ProviderRevEdgeGPT, global_object: GlobalObject):
|
||||
self.provider = provider
|
||||
self.cached_plugins = {}
|
||||
self.global_object = global_object
|
||||
super().__init__(provider, global_object)
|
||||
|
||||
def check_command(self,
|
||||
message: str,
|
||||
loop,
|
||||
session_id: str,
|
||||
role: str,
|
||||
platform: str,
|
||||
message_obj,
|
||||
cached_plugins: dict,
|
||||
qq_platform: QQ):
|
||||
message_obj):
|
||||
self.platform = platform
|
||||
hit, res = super().check_command(message, role, platform, message_obj=message_obj,
|
||||
cached_plugins=cached_plugins,
|
||||
qq_platform=qq_platform,
|
||||
global_object=self.global_object)
|
||||
|
||||
hit, res = super().check_command(
|
||||
message,
|
||||
session_id,
|
||||
role,
|
||||
platform,
|
||||
message_obj
|
||||
)
|
||||
|
||||
if hit:
|
||||
return True, res
|
||||
if self.command_start_with(message, "reset"):
|
||||
return True, self.reset(loop)
|
||||
return True, self.reset()
|
||||
elif self.command_start_with(message, "help"):
|
||||
return True, self.help(cached_plugins)
|
||||
return True, self.help()
|
||||
elif self.command_start_with(message, "update"):
|
||||
return True, self.update(message, role)
|
||||
elif self.command_start_with(message, "keyword"):
|
||||
return True, self.keyword(message, role)
|
||||
|
||||
if self.command_start_with(message, "/"):
|
||||
return True, (False, "未知指令", "unknown_command")
|
||||
return False, None
|
||||
|
||||
def reset(self, loop):
|
||||
def reset(self, loop = None):
|
||||
if self.provider is None:
|
||||
return False, "未启动Bing语言模型.", "reset"
|
||||
res = asyncio.run_coroutine_threadsafe(self.provider.forget(), loop).result()
|
||||
@@ -47,6 +48,5 @@ class CommandRevEdgeGPT(Command):
|
||||
else:
|
||||
return res, "重置失败", "reset"
|
||||
|
||||
def help(self, cached_plugins: dict):
|
||||
return True, super().help_messager(super().general_commands(), self.platform, cached_plugins), "help"
|
||||
|
||||
def help(self):
|
||||
return True, super().help_messager(super().general_commands(), self.platform, self.global_object.cached_plugins), "help"
|
||||
|
||||
@@ -19,6 +19,8 @@ class QQ:
|
||||
self.is_start = is_start
|
||||
self.gocq_loop = gocq_loop
|
||||
self.cc = cc
|
||||
self.waiting = {}
|
||||
self.gocq_cnt = 0
|
||||
|
||||
def run_bot(self, gocq):
|
||||
self.client: CQHTTP = gocq
|
||||
@@ -27,11 +29,17 @@ class QQ:
|
||||
def get_msg_loop(self):
|
||||
return self.gocq_loop
|
||||
|
||||
def get_cnt(self):
|
||||
return self.gocq_cnt
|
||||
|
||||
def set_cnt(self, cnt):
|
||||
self.gocq_cnt = cnt
|
||||
|
||||
async def send_qq_msg(self,
|
||||
source,
|
||||
res,
|
||||
image_mode: bool = False):
|
||||
|
||||
image_mode=None):
|
||||
self.gocq_cnt += 1
|
||||
if not self.is_start:
|
||||
raise Exception("管理员未启动GOCQ平台")
|
||||
"""
|
||||
@@ -47,11 +55,13 @@ class QQ:
|
||||
if isinstance(res, str):
|
||||
res_str = res
|
||||
res = []
|
||||
if source.type == "GroupMessage":
|
||||
if source.type == "GroupMessage" and not isinstance(source, FakeSource):
|
||||
res.append(At(qq=source.user_id))
|
||||
res.append(Plain(text=res_str))
|
||||
|
||||
# if image mode, put all Plain texts into a new picture.
|
||||
if image_mode is None:
|
||||
image_mode = self.cc.get('qq_pic_mode', False)
|
||||
if image_mode and isinstance(res, list):
|
||||
plains = []
|
||||
news = []
|
||||
@@ -60,11 +70,12 @@ class QQ:
|
||||
plains.append(i.text)
|
||||
else:
|
||||
news.append(i)
|
||||
plains_str = "".join(plains).strip()
|
||||
if plains_str != "" and len(plains_str) > 50:
|
||||
p = gu.create_markdown_image("".join(plains))
|
||||
news.append(Image.fromFileSystem(p))
|
||||
res = news
|
||||
|
||||
|
||||
# 回复消息链
|
||||
if isinstance(res, list) and len(res) > 0:
|
||||
if source.type == "GuildMessage":
|
||||
@@ -89,10 +100,10 @@ class QQ:
|
||||
res.remove(i)
|
||||
node = Node(res)
|
||||
# node.content = res
|
||||
node.uin = source.self_id
|
||||
node.name = f"To {source.sender.nickname}:"
|
||||
node.uin = 123456
|
||||
node.name = f"bot"
|
||||
node.time = int(time.time())
|
||||
print(node)
|
||||
# print(node)
|
||||
nodes=[node]
|
||||
await self.client.sendGroupForwardMessage(source.group_id, nodes)
|
||||
return
|
||||
@@ -102,13 +113,15 @@ class QQ:
|
||||
def send(self,
|
||||
to,
|
||||
res,
|
||||
image_mode=False,
|
||||
):
|
||||
'''
|
||||
提供给插件的发送QQ消息接口, 不用在外部await。
|
||||
参数说明:第一个参数可以是消息对象,也可以是QQ群号。第二个参数是消息内容(消息内容可以是消息链列表,也可以是纯文字信息)。
|
||||
第三个参数是是否开启图片模式,如果开启,那么所有纯文字信息都会被合并成一张图片。
|
||||
'''
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(self.send_qq_msg(to, res), self.gocq_loop).result()
|
||||
asyncio.run_coroutine_threadsafe(self.send_qq_msg(to, res, image_mode), self.gocq_loop).result()
|
||||
except BaseException as e:
|
||||
raise e
|
||||
|
||||
@@ -141,3 +154,30 @@ class QQ:
|
||||
return p
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def wait_for_message(self, group_id):
|
||||
'''
|
||||
等待下一条消息
|
||||
'''
|
||||
self.waiting[group_id] = ''
|
||||
while True:
|
||||
if group_id in self.waiting and self.waiting[group_id] != '':
|
||||
# 去掉
|
||||
ret = self.waiting[group_id]
|
||||
del self.waiting[group_id]
|
||||
return ret
|
||||
time.sleep(0.5)
|
||||
|
||||
def get_client(self):
|
||||
return self.client
|
||||
|
||||
def nakuru_method_invoker(self, func, *args, **kwargs):
|
||||
"""
|
||||
返回一个方法调用器,可以用来立即调用nakuru的方法。
|
||||
"""
|
||||
try:
|
||||
ret = asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self.gocq_loop).result()
|
||||
return ret
|
||||
except BaseException as e:
|
||||
raise e
|
||||
|
||||
|
||||
@@ -8,42 +8,123 @@ import requests
|
||||
from cores.qqbot.personality import personalities
|
||||
from util import general_utils as gu
|
||||
from nakuru.entities.components import Plain, At, Image
|
||||
from botpy.types.message import Reference
|
||||
|
||||
class NakuruGuildMember():
|
||||
tiny_id: int
|
||||
user_id: int
|
||||
title: str
|
||||
nickname: str
|
||||
role: int
|
||||
icon_url: str
|
||||
|
||||
class NakuruGuildMessage():
|
||||
type: str = "GuildMessage"
|
||||
self_id: int
|
||||
self_tiny_id: int
|
||||
sub_type: str
|
||||
message_id: str
|
||||
guild_id: int
|
||||
channel_id: int
|
||||
user_id: int
|
||||
message: list
|
||||
sender: NakuruGuildMember
|
||||
raw_message: Message
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.__dict__)
|
||||
|
||||
class QQChan():
|
||||
def __init__(self, cnt: dict = None) -> None:
|
||||
self.qqchan_cnt = 0
|
||||
|
||||
def get_cnt(self):
|
||||
return self.qqchan_cnt
|
||||
|
||||
def set_cnt(self, cnt):
|
||||
self.qqchan_cnt = cnt
|
||||
|
||||
def run_bot(self, botclient, appid, token):
|
||||
intents = botpy.Intents(public_guild_messages=True, direct_message=True)
|
||||
self.client = botclient
|
||||
self.client.run(appid=appid, token=token)
|
||||
|
||||
# gocq兼容层
|
||||
def gocq_compatible(self, gocq_message_chain: list):
|
||||
# gocq-频道SDK兼容层(发)
|
||||
def gocq_compatible_send(self, gocq_message_chain: list):
|
||||
plain_text = ""
|
||||
image_path = None # only one img supported
|
||||
for i in gocq_message_chain:
|
||||
if isinstance(i, Plain):
|
||||
plain_text += i.text
|
||||
elif isinstance(i, Image) and image_path == None:
|
||||
if i.path is not None:
|
||||
image_path = i.path
|
||||
else:
|
||||
image_path = i.file
|
||||
return plain_text, image_path
|
||||
|
||||
# gocq-频道SDK兼容层(收)
|
||||
def gocq_compatible_receive(self, message: Message) -> NakuruGuildMessage:
|
||||
ngm = NakuruGuildMessage()
|
||||
try:
|
||||
ngm.self_id = message.mentions[0].id
|
||||
ngm.self_tiny_id = message.mentions[0].id
|
||||
except:
|
||||
ngm.self_id = 0
|
||||
ngm.self_tiny_id = 0
|
||||
|
||||
ngm.sub_type = "normal"
|
||||
ngm.message_id = message.id
|
||||
ngm.guild_id = int(message.channel_id)
|
||||
ngm.channel_id = int(message.channel_id)
|
||||
ngm.user_id = int(message.author.id)
|
||||
msg = []
|
||||
plain_content = message.content.replace("<@!"+str(ngm.self_id)+">", "").strip()
|
||||
msg.append(Plain(plain_content))
|
||||
if message.attachments:
|
||||
for i in message.attachments:
|
||||
if i.content_type.startswith("image"):
|
||||
url = i.url
|
||||
if not url.startswith("http"):
|
||||
url = "https://"+url
|
||||
img = Image.fromURL(url)
|
||||
msg.append(img)
|
||||
ngm.message = msg
|
||||
ngm.sender = NakuruGuildMember()
|
||||
ngm.sender.tiny_id = int(message.author.id)
|
||||
ngm.sender.user_id = int(message.author.id)
|
||||
ngm.sender.title = ""
|
||||
ngm.sender.nickname = message.author.username
|
||||
ngm.sender.role = 0
|
||||
ngm.sender.icon_url = message.author.avatar
|
||||
ngm.raw_message = message
|
||||
return ngm
|
||||
|
||||
|
||||
def send_qq_msg(self, message: Message, res, msg_ref = None):
|
||||
|
||||
def send_qq_msg(self, message: NakuruGuildMessage, res):
|
||||
gu.log("回复QQ频道消息: "+str(res), level=gu.LEVEL_INFO, tag="QQ频道", max_len=500)
|
||||
|
||||
self.qqchan_cnt += 1
|
||||
plain_text = ""
|
||||
image_path = None
|
||||
if isinstance(res, list):
|
||||
# 兼容gocq
|
||||
plain_text, image_path = self.gocq_compatible(res)
|
||||
plain_text, image_path = self.gocq_compatible_send(res)
|
||||
elif isinstance(res, str):
|
||||
plain_text = res
|
||||
|
||||
print(plain_text, image_path)
|
||||
# print(plain_text, image_path)
|
||||
msg_ref = Reference(message_id=message.raw_message.id, ignore_get_message_error=False)
|
||||
if image_path is not None:
|
||||
msg_ref = None
|
||||
if image_path.startswith("http"):
|
||||
pic_res = requests.get(image_path, stream = True)
|
||||
if pic_res.status_code == 200:
|
||||
image = PILImage.open(io.BytesIO(pic_res.content))
|
||||
image_path = gu.save_temp_img(image)
|
||||
|
||||
try:
|
||||
reply_res = asyncio.run_coroutine_threadsafe(message.reply(content=str(plain_text), message_reference = msg_ref, file_image=image_path), self.client.loop)
|
||||
reply_res = asyncio.run_coroutine_threadsafe(message.raw_message.reply(content=str(plain_text), message_reference = msg_ref, file_image=image_path), self.client.loop)
|
||||
reply_res.result()
|
||||
except BaseException as e:
|
||||
# 分割过长的消息
|
||||
@@ -52,21 +133,21 @@ class QQChan():
|
||||
split_res.append(plain_text[:len(plain_text)//2])
|
||||
split_res.append(plain_text[len(plain_text)//2:])
|
||||
for i in split_res:
|
||||
reply_res = asyncio.run_coroutine_threadsafe(message.reply(content=str(i), message_reference = msg_ref, file_image=image_path), self.client.loop)
|
||||
reply_res = asyncio.run_coroutine_threadsafe(message.raw_message.reply(content=str(i), message_reference = msg_ref, file_image=image_path), self.client.loop)
|
||||
reply_res.result()
|
||||
else:
|
||||
# 发送qq信息
|
||||
try:
|
||||
# 防止被qq频道过滤消息
|
||||
plain_text = plain_text.replace(".", " . ")
|
||||
reply_res = asyncio.run_coroutine_threadsafe(message.reply(content=str(plain_text), message_reference = msg_ref, file_image=image_path), self.client.loop)
|
||||
reply_res = asyncio.run_coroutine_threadsafe(message.raw_message.reply(content=str(plain_text), message_reference = msg_ref, file_image=image_path), self.client.loop)
|
||||
# 发送信息
|
||||
except BaseException as e:
|
||||
print("QQ频道API错误: \n"+str(e))
|
||||
try:
|
||||
reply_res = asyncio.run_coroutine_threadsafe(message.reply(content=str(str.join(" ", plain_text)), message_reference = msg_ref, file_image=image_path), self.client.loop)
|
||||
reply_res = asyncio.run_coroutine_threadsafe(message.raw_message.reply(content=str(str.join(" ", plain_text)), message_reference = msg_ref, file_image=image_path), self.client.loop)
|
||||
except BaseException as e:
|
||||
plain_text = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE)
|
||||
plain_text = plain_text.replace(".", "·")
|
||||
asyncio.run_coroutine_threadsafe(message.reply(content=plain_text), self.client.loop).result()
|
||||
asyncio.run_coroutine_threadsafe(message.raw_message.reply(content=plain_text), self.client.loop).result()
|
||||
# send(message, f"QQ频道API错误:{str(e)}\n下面是格式化后的回答:\n{f_res}")
|
||||
@@ -4,15 +4,10 @@ class Provider:
|
||||
def __init__(self, cfg):
|
||||
pass
|
||||
|
||||
def text_chat(self, prompt):
|
||||
pass
|
||||
|
||||
def image_chat(self, prompt):
|
||||
pass
|
||||
|
||||
def memory(self):
|
||||
@abc.abstractmethod
|
||||
def text_chat(self, prompt, session_id, image_url: None, function_call: None):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def forget(self) -> bool:
|
||||
def forget(self, session_id = None) -> bool:
|
||||
pass
|
||||
@@ -1,4 +1,5 @@
|
||||
import openai
|
||||
from openai import OpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
@@ -7,26 +8,43 @@ from cores.database.conn import dbConn
|
||||
from model.provider.provider import Provider
|
||||
import threading
|
||||
from util import general_utils as gu
|
||||
import traceback
|
||||
import tiktoken
|
||||
|
||||
abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/'
|
||||
key_record_path = abs_path+'chatgpt_key_record'
|
||||
key_record_path = abs_path + 'chatgpt_key_record'
|
||||
|
||||
class ProviderOpenAIOfficial(Provider):
|
||||
def __init__(self, cfg):
|
||||
self.key_list = []
|
||||
if 'api_base' in cfg and cfg['api_base'] != 'none' and cfg['api_base'] != '':
|
||||
openai.api_base = cfg['api_base']
|
||||
# 如果 cfg['key']中有长度为1的字符串,那么是格式错误,直接报错
|
||||
for key in cfg['key']:
|
||||
if len(key) == 1:
|
||||
input("检查到了长度为 1 的Key。配置文件中的 openai.key 处的格式错误 (符号 - 的后面要加空格),请退出程序并检查配置文件,按回车跳过。")
|
||||
raise BaseException("配置文件格式错误")
|
||||
if cfg['key'] != '' and cfg['key'] != None:
|
||||
gu.log("读取ChatGPT Key成功")
|
||||
self.key_list = cfg['key']
|
||||
else:
|
||||
input("[System] 请先去完善ChatGPT的Key。详情请前往https://beta.openai.com/account/api-keys")
|
||||
if len(self.key_list) == 0:
|
||||
raise Exception("您打开了 OpenAI 模型服务,但是未填写 key。请前往填写。")
|
||||
|
||||
# init key record
|
||||
self.init_key_record()
|
||||
self.key_stat = {}
|
||||
for k in self.key_list:
|
||||
self.key_stat[k] = {'exceed': False, 'used': 0}
|
||||
|
||||
self.chatGPT_configs = cfg['chatGPTConfigs']
|
||||
gu.log(f'加载ChatGPTConfigs: {self.chatGPT_configs}')
|
||||
self.api_base = None
|
||||
if 'api_base' in cfg and cfg['api_base'] != 'none' and cfg['api_base'] != '':
|
||||
self.api_base = cfg['api_base']
|
||||
print(f"设置 api_base 为: {self.api_base}")
|
||||
# openai client
|
||||
self.client = OpenAI(
|
||||
api_key=self.key_list[0],
|
||||
base_url=self.api_base
|
||||
)
|
||||
|
||||
self.openai_model_configs: dict = cfg['chatGPTConfigs']
|
||||
gu.log(f'加载 OpenAI Chat Configs: {self.openai_model_configs}')
|
||||
self.openai_configs = cfg
|
||||
# 会话缓存
|
||||
self.session_dict = {}
|
||||
@@ -35,14 +53,16 @@ class ProviderOpenAIOfficial(Provider):
|
||||
# 历史记录持久化间隔时间
|
||||
self.history_dump_interval = 20
|
||||
|
||||
self.enc = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
# 读取历史记录
|
||||
try:
|
||||
db1 = dbConn()
|
||||
for session in db1.get_all_session():
|
||||
self.session_dict[session[0]] = json.loads(session[1])['data']
|
||||
gu.log("历史记录读取成功喵")
|
||||
gu.log("读取历史记录成功。")
|
||||
except BaseException as e:
|
||||
gu.log("历史记录读取失败喵", level=gu.LEVEL_ERROR)
|
||||
gu.log("读取历史记录失败,但不影响使用。", level=gu.LEVEL_ERROR)
|
||||
|
||||
|
||||
# 读取统计信息
|
||||
@@ -67,7 +87,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.now_personality = {}
|
||||
|
||||
|
||||
# 转储历史记录的定时器~ Soulter
|
||||
# 转储历史记录
|
||||
def dump_history(self):
|
||||
time.sleep(10)
|
||||
db = dbConn()
|
||||
@@ -90,9 +110,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
# 每隔10分钟转储一次
|
||||
time.sleep(10*self.history_dump_interval)
|
||||
|
||||
def text_chat(self, prompt, session_id = None):
|
||||
def text_chat(self, prompt, session_id = None, image_url = None, function_call=None):
|
||||
if session_id is None:
|
||||
session_id = "unknown"
|
||||
if "unknown" in self.session_dict:
|
||||
del self.session_dict["unknown"]
|
||||
# 会话机制
|
||||
if session_id not in self.session_dict:
|
||||
@@ -112,48 +133,92 @@ class ProviderOpenAIOfficial(Provider):
|
||||
f.flush()
|
||||
f.close()
|
||||
|
||||
cache_data_list, new_record, req = self.wrap(prompt, session_id)
|
||||
# 使用 tictoken 截断消息
|
||||
_encoded_prompt = self.enc.encode(prompt)
|
||||
prompt = self.enc.decode(_encoded_prompt[:self.openai_model_configs['max_tokens'] - 100])
|
||||
gu.log(f"注意,有一部分 prompt 文本由于超出 token 限制而被截断。", level=gu.LEVEL_WARNING, max_len=300)
|
||||
|
||||
cache_data_list, new_record, req = self.wrap(prompt, session_id, image_url)
|
||||
gu.log(f"CACHE_DATA_: {str(cache_data_list)}", level=gu.LEVEL_DEBUG, max_len=99999)
|
||||
gu.log(f"OPENAI REQUEST: {str(req)}", level=gu.LEVEL_DEBUG, max_len=9999)
|
||||
retry = 0
|
||||
response = None
|
||||
err = ''
|
||||
while retry < 5:
|
||||
|
||||
# 截断倍率
|
||||
truncate_rate = 0.75
|
||||
|
||||
use_gpt4v = False
|
||||
for i in req:
|
||||
if isinstance(i['content'], list):
|
||||
use_gpt4v = True
|
||||
break
|
||||
if image_url is not None:
|
||||
use_gpt4v = True
|
||||
if use_gpt4v:
|
||||
conf = self.openai_model_configs.copy()
|
||||
conf['model'] = 'gpt-4-vision-preview'
|
||||
else:
|
||||
conf = self.openai_model_configs
|
||||
print(req)
|
||||
while retry < 10:
|
||||
try:
|
||||
response = openai.ChatCompletion.create(
|
||||
if function_call is None:
|
||||
response = self.client.chat.completions.create(
|
||||
messages=req,
|
||||
**self.chatGPT_configs
|
||||
**conf
|
||||
)
|
||||
else:
|
||||
response = self.client.chat.completions.create(
|
||||
messages=req,
|
||||
tools = function_call,
|
||||
**conf
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
if 'Invalid content type. image_url is only supported by certain models.' in str(e):
|
||||
raise e
|
||||
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
|
||||
gu.log("当前Key已超额或异常, 正在切换", level=gu.LEVEL_WARNING)
|
||||
self.key_stat[openai.api_key]['exceed'] = True
|
||||
self.save_key_record()
|
||||
|
||||
response, is_switched = self.handle_switch_key(req)
|
||||
self.key_stat[self.client.api_key]['exceed'] = True
|
||||
is_switched = self.handle_switch_key()
|
||||
if not is_switched:
|
||||
# 所有Key都超额或不正常
|
||||
raise e
|
||||
else:
|
||||
break
|
||||
retry -= 1
|
||||
elif 'maximum context length' in str(e):
|
||||
gu.log("token超限, 清空对应缓存")
|
||||
gu.log("token超限, 清空对应缓存,并进行消息截断")
|
||||
self.session_dict[session_id] = []
|
||||
prompt = prompt[:int(len(prompt)*truncate_rate)]
|
||||
truncate_rate -= 0.05
|
||||
cache_data_list, new_record, req = self.wrap(prompt, session_id)
|
||||
elif 'Limit: 3 / min. Please try again in 20s.' in str(e):
|
||||
time.sleep(60)
|
||||
|
||||
elif 'Limit: 3 / min. Please try again in 20s.' in str(e) or "OpenAI response error" in str(e):
|
||||
time.sleep(30)
|
||||
continue
|
||||
else:
|
||||
gu.log(str(e), level=gu.LEVEL_ERROR)
|
||||
time.sleep(2)
|
||||
err = str(e)
|
||||
retry+=1
|
||||
if retry >= 5:
|
||||
retry += 1
|
||||
if retry >= 10:
|
||||
gu.log(r"如果报错, 且您的机器在中国大陆内, 请确保您的电脑已经设置好代理软件(梯子), 并在配置文件设置了系统代理地址。详见https://github.com/Soulter/QQChannelChatGPT/wiki/%E4%BA%8C%E3%80%81%E9%A1%B9%E7%9B%AE%E9%85%8D%E7%BD%AE%E6%96%87%E4%BB%B6%E9%85%8D%E7%BD%AE", max_len=999)
|
||||
raise BaseException("连接出错: "+str(err))
|
||||
assert isinstance(response, ChatCompletion)
|
||||
gu.log(f"OPENAI RESPONSE: {response.usage}", level=gu.LEVEL_DEBUG, max_len=9999)
|
||||
|
||||
self.key_stat[openai.api_key]['used'] += response['usage']['total_tokens']
|
||||
self.save_key_record()
|
||||
# print("[ChatGPT] "+str(response["choices"][0]["message"]["content"]))
|
||||
chatgpt_res = str(response["choices"][0]["message"]["content"]).strip()
|
||||
current_usage_tokens = response['usage']['total_tokens']
|
||||
# 结果分类
|
||||
choice = response.choices[0]
|
||||
if choice.message.content != None:
|
||||
# 文本形式
|
||||
chatgpt_res = str(choice.message.content).strip()
|
||||
elif choice.message.tool_calls != None and len(choice.message.tool_calls) > 0:
|
||||
# tools call (function calling)
|
||||
return choice.message.tool_calls[0].function
|
||||
|
||||
self.key_stat[self.client.api_key]['used'] += response.usage.total_tokens
|
||||
current_usage_tokens = response.usage.total_tokens
|
||||
|
||||
# 超过指定tokens, 尽可能的保留最多的条目,直到小于max_tokens
|
||||
if current_usage_tokens > self.max_tokens:
|
||||
@@ -163,7 +228,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if index >= len(cache_data_list):
|
||||
break
|
||||
# 保留人格信息
|
||||
if 'user' in cache_data_list[index] and cache_data_list[index]['user']['role'] != 'system':
|
||||
if cache_data_list[index]['type'] != 'personality':
|
||||
t -= int(cache_data_list[index]['single_tokens'])
|
||||
del cache_data_list[index]
|
||||
else:
|
||||
@@ -182,6 +247,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
new_record['single_tokens'] = current_usage_tokens - int(cache_data_list[-1]['usage_tokens'])
|
||||
else:
|
||||
new_record['single_tokens'] = current_usage_tokens
|
||||
|
||||
cache_data_list.append(new_record)
|
||||
|
||||
self.session_dict[session_id] = cache_data_list
|
||||
@@ -193,13 +259,11 @@ class ProviderOpenAIOfficial(Provider):
|
||||
image_url = ''
|
||||
while retry < 5:
|
||||
try:
|
||||
# print("test1")
|
||||
response = openai.Image.create(
|
||||
response = self.client.images.generate(
|
||||
prompt=prompt,
|
||||
n=img_num,
|
||||
size=img_size
|
||||
)
|
||||
# print("test2")
|
||||
image_url = []
|
||||
for i in range(img_num):
|
||||
image_url.append(response['data'][i]['url'])
|
||||
@@ -208,23 +272,22 @@ class ProviderOpenAIOfficial(Provider):
|
||||
gu.log(str(e), level=gu.LEVEL_ERROR)
|
||||
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(
|
||||
e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
|
||||
gu.log("当前Key已超额或者不正常, 正在切换", level=gu.LEVEL_WARNING)
|
||||
self.key_stat[openai.api_key]['exceed'] = True
|
||||
self.save_key_record()
|
||||
|
||||
response, is_switched = self.handle_switch_key(req)
|
||||
gu.log("当前 Key 已超额或者不正常, 正在切换", level=gu.LEVEL_WARNING)
|
||||
self.key_stat[self.client.api_key]['exceed'] = True
|
||||
is_switched = self.handle_switch_key()
|
||||
if not is_switched:
|
||||
# 所有Key都超额或不正常
|
||||
raise e
|
||||
else:
|
||||
break
|
||||
retry += 1
|
||||
if retry >= 5:
|
||||
raise BaseException("连接超时")
|
||||
|
||||
return image_url
|
||||
|
||||
def forget(self, session_id) -> bool:
|
||||
def forget(self, session_id = None) -> bool:
|
||||
if session_id is None:
|
||||
return False
|
||||
self.session_dict[session_id] = []
|
||||
return True
|
||||
|
||||
@@ -285,7 +348,20 @@ class ProviderOpenAIOfficial(Provider):
|
||||
return -1, -1, -1, -1
|
||||
|
||||
# 包装信息
|
||||
def wrap(self, prompt, session_id):
|
||||
def wrap(self, prompt, session_id, image_url = None):
|
||||
if image_url is not None:
|
||||
prompt = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}
|
||||
]
|
||||
# 获得缓存信息
|
||||
context = self.session_dict[session_id]
|
||||
new_record = {
|
||||
@@ -294,6 +370,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
"content": prompt,
|
||||
},
|
||||
"AI": {},
|
||||
'type': "common",
|
||||
'usage_tokens': 0,
|
||||
}
|
||||
req_list = []
|
||||
@@ -305,105 +382,53 @@ class ProviderOpenAIOfficial(Provider):
|
||||
req_list.append(new_record['user'])
|
||||
return context, new_record, req_list
|
||||
|
||||
def handle_switch_key(self, req):
|
||||
def handle_switch_key(self):
|
||||
# messages = [{"role": "user", "content": prompt}]
|
||||
while True:
|
||||
is_all_exceed = True
|
||||
for key in self.key_stat:
|
||||
if key == None:
|
||||
if key == None or self.key_stat[key]['exceed']:
|
||||
continue
|
||||
if not self.key_stat[key]['exceed']:
|
||||
is_all_exceed = False
|
||||
openai.api_key = key
|
||||
self.client.api_key = key
|
||||
gu.log(f"切换到Key: {key}, 已使用token: {self.key_stat[key]['used']}", level=gu.LEVEL_INFO)
|
||||
if len(req) > 0:
|
||||
try:
|
||||
response = openai.ChatCompletion.create(
|
||||
messages=req,
|
||||
**self.chatGPT_configs
|
||||
)
|
||||
return response, True
|
||||
except Exception as e:
|
||||
if 'You exceeded' in str(e):
|
||||
gu.log("当前Key已超额, 正在切换")
|
||||
self.key_stat[openai.api_key]['exceed'] = True
|
||||
self.save_key_record()
|
||||
time.sleep(1)
|
||||
continue
|
||||
else:
|
||||
gu.log(str(e), level=gu.LEVEL_ERROR)
|
||||
else:
|
||||
return True
|
||||
break
|
||||
if is_all_exceed:
|
||||
gu.log("所有Key已超额", level=gu.LEVEL_CRITICAL)
|
||||
return None, False
|
||||
else:
|
||||
gu.log("在切换key时程序异常。", level=gu.LEVEL_ERROR)
|
||||
return None, False
|
||||
return False
|
||||
return True
|
||||
|
||||
def getConfigs(self):
|
||||
def get_configs(self):
|
||||
return self.openai_configs
|
||||
|
||||
def save_key_record(self):
|
||||
with open(key_record_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(self.key_stat, f)
|
||||
|
||||
def get_key_stat(self):
|
||||
return self.key_stat
|
||||
|
||||
def get_key_list(self):
|
||||
return self.key_list
|
||||
|
||||
def get_curr_key(self):
|
||||
return self.client.api_key
|
||||
|
||||
# 添加key
|
||||
def append_key(self, key, sponsor):
|
||||
self.key_list.append(key)
|
||||
self.key_stat[key] = {'exceed': False, 'used': 0, 'sponsor': sponsor}
|
||||
self.save_key_record()
|
||||
self.init_key_record()
|
||||
|
||||
# 检查key是否可用
|
||||
def check_key(self, key):
|
||||
pre_key = openai.api_key
|
||||
openai.api_key = key
|
||||
messages = [{"role": "user", "content": "1"}]
|
||||
try:
|
||||
response = openai.ChatCompletion.create(
|
||||
messages=messages,
|
||||
**self.chatGPT_configs
|
||||
client_ = OpenAI(
|
||||
api_key=key,
|
||||
base_url=self.api_base
|
||||
)
|
||||
messages = [{"role": "user", "content": "please just echo `test`"}]
|
||||
try:
|
||||
client_.chat.completions.create(
|
||||
messages=messages,
|
||||
**self.openai_model_configs
|
||||
)
|
||||
openai.api_key = pre_key
|
||||
return True
|
||||
except Exception as e:
|
||||
pass
|
||||
openai.api_key = pre_key
|
||||
return False
|
||||
|
||||
#将key_list的key转储到key_record中,并记录相关数据
|
||||
def init_key_record(self):
|
||||
if not os.path.exists(key_record_path):
|
||||
with open(key_record_path, 'w', encoding='utf-8') as f:
|
||||
json.dump({}, f)
|
||||
with open(key_record_path, 'r', encoding='utf-8') as keyfile:
|
||||
try:
|
||||
self.key_stat = json.load(keyfile)
|
||||
except Exception as e:
|
||||
gu.log(str(e), level=gu.LEVEL_ERROR)
|
||||
self.key_stat = {}
|
||||
finally:
|
||||
for key in self.key_list:
|
||||
if key not in self.key_stat:
|
||||
self.key_stat[key] = {'exceed': False, 'used': 0}
|
||||
# if openai.api_key is None:
|
||||
# openai.api_key = key
|
||||
else:
|
||||
# if self.key_stat[key]['exceed']:
|
||||
# print(f"Key: {key} 已超额")
|
||||
# continue
|
||||
# else:
|
||||
# if openai.api_key is None:
|
||||
# openai.api_key = key
|
||||
# print(f"使用Key: {key}, 已使用token: {self.key_stat[key]['used']}")
|
||||
pass
|
||||
if openai.api_key == None:
|
||||
self.handle_switch_key("")
|
||||
self.save_key_record()
|
||||
|
||||
|
||||
@@ -7,8 +7,10 @@ import time
|
||||
|
||||
|
||||
class ProviderRevChatGPT(Provider):
|
||||
def __init__(self, config):
|
||||
self.rev_chatgpt = []
|
||||
def __init__(self, config, base_url = None):
|
||||
if base_url == "":
|
||||
base_url = None
|
||||
self.rev_chatgpt: list[dict] = []
|
||||
self.cc = cc.CmdConfig()
|
||||
for i in range(0, len(config['account'])):
|
||||
try:
|
||||
@@ -28,26 +30,34 @@ class ProviderRevChatGPT(Provider):
|
||||
rev_account_config['PUID'] = self.cc.get("rev_chatgpt_PUID")
|
||||
if len(self.cc.get("rev_chatgpt_unverified_plugin_domains")) > 0:
|
||||
rev_account_config['unverified_plugin_domains'] = self.cc.get("rev_chatgpt_unverified_plugin_domains")
|
||||
cb = Chatbot(config=rev_account_config)
|
||||
cb = Chatbot(config=rev_account_config, base_url=base_url)
|
||||
# cb.captcha_solver = self.__captcha_solver
|
||||
# 后八位c
|
||||
g_id = rev_account_config['access_token'][-8:]
|
||||
revstat = {
|
||||
'id': g_id,
|
||||
'obj': cb,
|
||||
'busy': False
|
||||
'busy': False,
|
||||
'user': []
|
||||
}
|
||||
self.rev_chatgpt.append(revstat)
|
||||
except BaseException as e:
|
||||
gu.log(f"创建逆向ChatGPT负载{str(i+1)}失败: {str(e)}", level=gu.LEVEL_ERROR, tag="RevChatGPT")
|
||||
|
||||
def forget(self) -> bool:
|
||||
def forget(self, session_id = None) -> bool:
|
||||
for i in self.rev_chatgpt:
|
||||
for user in i['user']:
|
||||
if session_id == user['id']:
|
||||
try:
|
||||
i['obj'].reset_chat()
|
||||
return True
|
||||
except BaseException as e:
|
||||
gu.log(f"重置RevChatGPT失败。原因: {str(e)}", level=gu.LEVEL_ERROR, tag="RevChatGPT")
|
||||
return False
|
||||
return False
|
||||
|
||||
# def __captcha_solver(images: list[str], challenge_details: dict) -> int:
|
||||
# # Create tempfile
|
||||
# print("Captcha solver called")
|
||||
# print(images)
|
||||
# print(challenge_details)
|
||||
# input("Press Enter to continue...")
|
||||
# return 0
|
||||
def get_revchatgpt(self) -> list:
|
||||
return self.rev_chatgpt
|
||||
|
||||
def request_text(self, prompt: str, bot) -> str:
|
||||
resp = ''
|
||||
@@ -66,7 +76,8 @@ class ProviderRevChatGPT(Provider):
|
||||
raise e
|
||||
if e.code == typings.ErrorType.PROHIBITED_CONCURRENT_QUERY_ERROR:
|
||||
raise e
|
||||
|
||||
if "Your authentication token has expired. Please try signing in again." in str(e):
|
||||
raise e
|
||||
if "The message you submitted was too long" in str(e):
|
||||
raise e
|
||||
if "You've reached our limit of messages per hour." in str(e):
|
||||
@@ -90,28 +101,110 @@ class ProviderRevChatGPT(Provider):
|
||||
# print("[RevChatGPT] "+str(resp))
|
||||
return resp
|
||||
|
||||
def text_chat(self, prompt) -> str:
|
||||
def text_chat(self, prompt, session_id = None, image_url = None, function_call=None) -> str:
|
||||
|
||||
# 选择一个人少的账号。
|
||||
selected_revstat = None
|
||||
min_revstat = None
|
||||
min_ = None
|
||||
new_user = False
|
||||
conversation_id = ''
|
||||
parent_id = ''
|
||||
for revstat in self.rev_chatgpt:
|
||||
for user in revstat['user']:
|
||||
if session_id == user['id']:
|
||||
selected_revstat = revstat
|
||||
conversation_id = user['conversation_id']
|
||||
parent_id = user['parent_id']
|
||||
break
|
||||
if min_ is None:
|
||||
min_ = len(revstat['user'])
|
||||
min_revstat = revstat
|
||||
elif len(revstat['user']) < min_:
|
||||
min_ = len(revstat['user'])
|
||||
min_revstat = revstat
|
||||
# if session_id in revstat['user']:
|
||||
# selected_revstat = revstat
|
||||
# break
|
||||
|
||||
if selected_revstat is None:
|
||||
selected_revstat = min_revstat
|
||||
selected_revstat['user'].append({
|
||||
'id': session_id,
|
||||
'conversation_id': '',
|
||||
'parent_id': ''
|
||||
})
|
||||
new_user = True
|
||||
|
||||
gu.log(f"选择账号{str(selected_revstat)}", tag="RevChatGPT", level=gu.LEVEL_DEBUG)
|
||||
|
||||
while selected_revstat['busy']:
|
||||
gu.log(f"账号忙碌,等待中...", tag="RevChatGPT", level=gu.LEVEL_DEBUG)
|
||||
time.sleep(1)
|
||||
selected_revstat['busy'] = True
|
||||
|
||||
if not new_user:
|
||||
# 非新用户,则使用其专用的会话
|
||||
selected_revstat['obj'].conversation_id = conversation_id
|
||||
selected_revstat['obj'].parent_id = parent_id
|
||||
else:
|
||||
# 新用户,则使用新的会话
|
||||
selected_revstat['obj'].reset_chat()
|
||||
|
||||
res = ''
|
||||
err_msg = ''
|
||||
cursor = 0
|
||||
for revstat in self.rev_chatgpt:
|
||||
cursor += 1
|
||||
if not revstat['busy']:
|
||||
err_cnt = 0
|
||||
while err_cnt < 15:
|
||||
try:
|
||||
revstat['busy'] = True
|
||||
res = self.request_text(prompt, revstat['obj'])
|
||||
revstat['busy'] = False
|
||||
res = self.request_text(prompt, selected_revstat['obj'])
|
||||
selected_revstat['busy'] = False
|
||||
# 记录新用户的会话
|
||||
if new_user:
|
||||
i = 0
|
||||
for user in selected_revstat['user']:
|
||||
if user['id'] == session_id:
|
||||
selected_revstat['user'][i]['conversation_id'] = selected_revstat['obj'].conversation_id
|
||||
selected_revstat['user'][i]['parent_id'] = selected_revstat['obj'].parent_id
|
||||
break
|
||||
i += 1
|
||||
return res.strip()
|
||||
# todo: 细化错误管理
|
||||
except BaseException as e:
|
||||
revstat['busy'] = False
|
||||
gu.log(f"请求出现问题: {str(e)}", level=gu.LEVEL_WARNING, tag="RevChatGPT")
|
||||
err_msg += f"账号{cursor} - 错误原因: {str(e)}"
|
||||
continue
|
||||
else:
|
||||
err_msg += f"账号{cursor} - 错误原因: 忙碌"
|
||||
continue
|
||||
raise Exception(f'回复失败。错误跟踪:{err_msg}')
|
||||
if "Your authentication token has expired. Please try signing in again." in str(e):
|
||||
raise Exception(f"此账号(access_token后8位为{selected_revstat['id']})的access_token已过期,请重新获取,或者切换账号。")
|
||||
if "The message you submitted was too long" in str(e):
|
||||
raise Exception("发送的消息太长,请分段发送。")
|
||||
if "You've reached our limit of messages per hour." in str(e):
|
||||
raise Exception("触发RevChatGPT请求频率限制。请1小时后再试,或者切换账号。")
|
||||
gu.log(f"请求异常: {str(e)}", level=gu.LEVEL_WARNING, tag="RevChatGPT")
|
||||
err_cnt += 1
|
||||
time.sleep(3)
|
||||
|
||||
raise Exception(f'回复失败。原因:{err_msg}。如果您设置了多个账号,可以使用/switch指令切换账号。输入/switch查看详情。')
|
||||
|
||||
|
||||
# while self.is_all_busy():
|
||||
# time.sleep(1)
|
||||
# res = ''
|
||||
# err_msg = ''
|
||||
# cursor = 0
|
||||
# for revstat in self.rev_chatgpt:
|
||||
# cursor += 1
|
||||
# if not revstat['busy']:
|
||||
# try:
|
||||
# revstat['busy'] = True
|
||||
# res = self.request_text(prompt, revstat['obj'])
|
||||
# revstat['busy'] = False
|
||||
# return res.strip()
|
||||
# # todo: 细化错误管理
|
||||
# except BaseException as e:
|
||||
# revstat['busy'] = False
|
||||
# gu.log(f"请求出现问题: {str(e)}", level=gu.LEVEL_WARNING, tag="RevChatGPT")
|
||||
# err_msg += f"账号{cursor} - 错误原因: {str(e)}"
|
||||
# continue
|
||||
# else:
|
||||
# err_msg += f"账号{cursor} - 错误原因: 忙碌"
|
||||
# continue
|
||||
# raise Exception(f'回复失败。错误跟踪:{err_msg}')
|
||||
|
||||
def is_all_busy(self) -> bool:
|
||||
for revstat in self.rev_chatgpt:
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from model.provider.provider import Provider
|
||||
from EdgeGPT import Chatbot, ConversationStyle
|
||||
# from EdgeGPT import Chatbot, ConversationStyle
|
||||
import json
|
||||
import os
|
||||
from util import general_utils as gu
|
||||
from util.cmd_config import CmdConfig as cc
|
||||
|
||||
import time
|
||||
from EdgeGPT.EdgeUtils import Query, Cookie
|
||||
from EdgeGPT.EdgeGPT import Chatbot as EdgeChatbot, ConversationStyle, NotAllowedToAccess
|
||||
|
||||
class ProviderRevEdgeGPT(Provider):
|
||||
def __init__(self):
|
||||
@@ -15,21 +17,27 @@ class ProviderRevEdgeGPT(Provider):
|
||||
proxy = cc.get("bing_proxy", None)
|
||||
if proxy == "":
|
||||
proxy = None
|
||||
self.bot = Chatbot(cookies=cookies, proxy = proxy)
|
||||
# q = Query("Hello, bing!", cookie_files="./cookies.json")
|
||||
# print(q)
|
||||
self.bot = EdgeChatbot(cookies=cookies, proxy = "http://127.0.0.1:7890")
|
||||
ret = self.bot.ask_stream("Hello, bing!", conversation_style=ConversationStyle.creative, wss_link="wss://ai.nothingnessvoid.tech/sydney/ChatHub")
|
||||
# self.bot = Chatbot(cookies=cookies, proxy = proxy)
|
||||
for i in ret:
|
||||
print(i, flush=True)
|
||||
|
||||
def is_busy(self):
|
||||
return self.busy
|
||||
|
||||
async def forget(self):
|
||||
async def forget(self, session_id = None):
|
||||
try:
|
||||
await self.bot.reset()
|
||||
return True
|
||||
except BaseException:
|
||||
return False
|
||||
|
||||
async def text_chat(self, prompt, platform = 'none'):
|
||||
if self.busy:
|
||||
return
|
||||
async def text_chat(self, prompt, platform = 'none', image_url=None, function_call=None):
|
||||
while self.busy:
|
||||
time.sleep(1)
|
||||
self.busy = True
|
||||
resp = 'err'
|
||||
err_count = 0
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
pydantic~=1.10.4
|
||||
requests~=2.28.1
|
||||
openai~=0.27.4
|
||||
qq-botpy~=1.1.2
|
||||
revChatGPT~=6.8.6
|
||||
baidu-aip~=4.16.9
|
||||
EdgeGPT~=0.1.22.1
|
||||
openai~=1.2.3
|
||||
qq-botpy
|
||||
chardet~=5.1.0
|
||||
Pillow~=9.4.0
|
||||
GitPython~=3.1.31
|
||||
nakuru-project
|
||||
beautifulsoup4
|
||||
googlesearch-python
|
||||
tictoken
|
||||
readability-lxml
|
||||
EdgeGPT
|
||||
revChatGPT~=6.8.6
|
||||
baidu-aip~=4.16.9
|
||||
@@ -1,3 +0,0 @@
|
||||
class PromptExceededError(Exception):
|
||||
|
||||
pass
|
||||
@@ -2,6 +2,7 @@
|
||||
import json
|
||||
import util.general_utils as gu
|
||||
|
||||
import time
|
||||
class FuncCallJsonFormatError(Exception):
|
||||
def __init__(self, msg):
|
||||
self.msg = msg
|
||||
@@ -24,9 +25,18 @@ class FuncCall():
|
||||
def add_func(self, name: str = None, func_args: list = None, desc: str = None, func_obj = None) -> None:
|
||||
if name == None or func_args == None or desc == None or func_obj == None:
|
||||
raise FuncCallJsonFormatError("name, func_args, desc must be provided.")
|
||||
params = {
|
||||
"type": "object", # hardcore here
|
||||
"properties": {}
|
||||
}
|
||||
for param in func_args:
|
||||
params['properties'][param['name']] = {
|
||||
"type": param['type'],
|
||||
"description": param['description']
|
||||
}
|
||||
self._func = {
|
||||
"name": name,
|
||||
"args": func_args,
|
||||
"parameters": params,
|
||||
"description": desc,
|
||||
"func_obj": func_obj,
|
||||
}
|
||||
@@ -37,18 +47,30 @@ class FuncCall():
|
||||
for f in self.func_list:
|
||||
_l.append({
|
||||
"name": f["name"],
|
||||
"args": f["args"],
|
||||
"parameters": f["parameters"],
|
||||
"description": f["description"],
|
||||
})
|
||||
return json.dumps(_l, indent=intent, ensur_ascii=False)
|
||||
|
||||
return json.dumps(_l, indent=intent, ensure_ascii=False)
|
||||
def get_func(self) -> list:
|
||||
_l = []
|
||||
for f in self.func_list:
|
||||
_l.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": f["name"],
|
||||
"parameters": f["parameters"],
|
||||
"description": f["description"],
|
||||
}
|
||||
})
|
||||
return _l
|
||||
|
||||
def func_call(self, question, func_definition, is_task = False, tasks = None, taskindex = -1, is_summary = True):
|
||||
def func_call(self, question, func_definition, is_task = False, tasks = None, taskindex = -1, is_summary = True, session_id = None):
|
||||
|
||||
funccall_prompt = """
|
||||
我正在实现function call功能,该功能旨在让你变成给定的问题到给定的函数的解析器(这意味着你不是创造函数)。
|
||||
下面会给你提供可能会用到函数的相关信息,和一个问题,你需要将其转换成给定的函数调用。
|
||||
- 你的返回信息只含json,且严格仿照以下内容(不含注释):
|
||||
我正实现function call功能,该功能旨在让你变成给定的问题到给定的函数的解析器(意味着你不是创造函数)。
|
||||
下面会给你提供可能用到的函数相关信息和一个问题,你需要将其转换成给定的函数调用。
|
||||
- 你的返回信息只含json,请严格仿照以下内容(不含注释),必须含有`res`,`func_call`字段:
|
||||
```
|
||||
{
|
||||
"res": string // 如果没有找到对应的函数,那么你可以在这里正常输出内容。如果有,这里是空字符串。
|
||||
@@ -95,7 +117,7 @@ class FuncCall():
|
||||
_c = 0
|
||||
while _c < 3:
|
||||
try:
|
||||
res = self.provider.text_chat(prompt)
|
||||
res = self.provider.text_chat(prompt, session_id)
|
||||
if res.find('```') != -1:
|
||||
res = res[res.find('```json') + 7: res.rfind('```')]
|
||||
gu.log("REVGPT func_call json result", bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
|
||||
@@ -111,7 +133,7 @@ class FuncCall():
|
||||
|
||||
invoke_func_res = ""
|
||||
|
||||
if len(res["func_call"]) > 0:
|
||||
if "func_call" in res and len(res["func_call"]) > 0:
|
||||
task_list = res["func_call"]
|
||||
|
||||
invoke_func_res_list = []
|
||||
@@ -140,7 +162,7 @@ class FuncCall():
|
||||
|
||||
# 生成返回结果
|
||||
after_prompt = """
|
||||
函数返回以下内容:"""+invoke_func_res+"""
|
||||
有以下内容:"""+invoke_func_res+"""
|
||||
请以AI助手的身份结合返回的内容对用户提问做详细全面的回答。
|
||||
用户的提问是:
|
||||
```""" + question + """```
|
||||
@@ -157,7 +179,7 @@ class FuncCall():
|
||||
_c = 0
|
||||
while _c < 5:
|
||||
try:
|
||||
res = self.provider.text_chat(after_prompt)
|
||||
res = self.provider.text_chat(after_prompt, session_id)
|
||||
# 截取```之间的内容
|
||||
gu.log("DEBUG BEGIN", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"])
|
||||
print(res)
|
||||
@@ -174,6 +196,7 @@ class FuncCall():
|
||||
raise e
|
||||
if "The message you submitted was too long" in str(e):
|
||||
# 如果返回的内容太长了,那么就截取一部分
|
||||
time.sleep(3)
|
||||
invoke_func_res = invoke_func_res[:int(len(invoke_func_res) / 2)]
|
||||
after_prompt = """
|
||||
函数返回以下内容:"""+invoke_func_res+"""
|
||||
|
||||
@@ -33,11 +33,20 @@ BG_COLORS = {
|
||||
"default": "49",
|
||||
}
|
||||
|
||||
LEVEL_DEBUG = "DEBUG"
|
||||
LEVEL_INFO = "INFO"
|
||||
LEVEL_WARNING = "WARNING"
|
||||
LEVEL_ERROR = "ERROR"
|
||||
LEVEL_CRITICAL = "CRITICAL"
|
||||
|
||||
level_codes = {
|
||||
LEVEL_DEBUG: 0,
|
||||
LEVEL_INFO: 1,
|
||||
LEVEL_WARNING: 2,
|
||||
LEVEL_ERROR: 3,
|
||||
LEVEL_CRITICAL: 4
|
||||
}
|
||||
|
||||
level_colors = {
|
||||
"INFO": "green",
|
||||
"WARNING": "yellow",
|
||||
@@ -51,10 +60,22 @@ def log(
|
||||
tag: str = "System",
|
||||
fg: str = None,
|
||||
bg: str = None,
|
||||
max_len: int = 300):
|
||||
max_len: int = 500,
|
||||
err: Exception = None,):
|
||||
"""
|
||||
日志记录函数
|
||||
日志打印函数
|
||||
"""
|
||||
_set_level_code = level_codes[LEVEL_INFO]
|
||||
if 'LOG_LEVEL' in os.environ and os.environ['LOG_LEVEL'] in level_codes:
|
||||
_set_level_code = level_codes[os.environ['LOG_LEVEL']]
|
||||
|
||||
if level in level_codes and level_codes[level] < _set_level_code:
|
||||
return
|
||||
|
||||
if err is not None:
|
||||
msg += "\n异常原因: " + str(err)
|
||||
level = LEVEL_ERROR
|
||||
|
||||
if len(msg) > max_len:
|
||||
msg = msg[:max_len] + "..."
|
||||
now = datetime.datetime.now().strftime("%m-%d %H:%M:%S")
|
||||
|
||||
219
util/gplugin.py
219
util/gplugin.py
@@ -7,33 +7,79 @@ from util.func_call import (
|
||||
FuncCallJsonFormatError,
|
||||
FuncNotFoundError
|
||||
)
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function
|
||||
import traceback
|
||||
from googlesearch import search, SearchResult
|
||||
from model.provider.provider import Provider
|
||||
import json
|
||||
from readability import Document
|
||||
|
||||
|
||||
def tidy_text(text: str) -> str:
|
||||
return text.strip().replace("\n", "").replace(" ", "").replace("\r", "")
|
||||
'''
|
||||
清理文本,去除空格、换行符等
|
||||
'''
|
||||
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
|
||||
|
||||
def special_fetch_zhihu(link: str) -> str:
|
||||
'''
|
||||
function-calling 函数, 用于获取知乎文章的内容
|
||||
'''
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
|
||||
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
||||
}
|
||||
response = requests.get(link, headers=headers)
|
||||
response.encoding = "utf-8"
|
||||
soup = BeautifulSoup(response.text, "html.parser")
|
||||
|
||||
if "zhuanlan.zhihu.com" in link:
|
||||
r = soup.find(class_="Post-RichTextContainer")
|
||||
else:
|
||||
r = soup.find(class_="List-item").find(class_="RichContent-inner")
|
||||
if r is None:
|
||||
print("debug: zhihu none")
|
||||
raise Exception("zhihu none")
|
||||
return tidy_text(r.text)
|
||||
|
||||
def google_web_search(keyword) -> str:
|
||||
'''
|
||||
获取 google 搜索结果, 得到 title、desc、link
|
||||
'''
|
||||
ret = ""
|
||||
index = 1
|
||||
try:
|
||||
ls = search(keyword, advanced=True, num_results=4)
|
||||
for i in ls:
|
||||
desc = i.description
|
||||
try:
|
||||
desc = fetch_website_content(i.url)
|
||||
except BaseException as e:
|
||||
print(f"(google) fetch_website_content err: {str(e)}")
|
||||
gu.log(f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n", level=gu.LEVEL_DEBUG, max_len=9999)
|
||||
ret += f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n"
|
||||
index += 1
|
||||
except Exception as e:
|
||||
print(f"google search err: {str(e)}")
|
||||
return web_keyword_search_via_bing(keyword)
|
||||
return ret
|
||||
|
||||
def web_keyword_search_via_bing(keyword) -> str:
|
||||
'''
|
||||
获取bing搜索结果, 得到 title、desc、link
|
||||
'''
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
|
||||
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
||||
}
|
||||
url = "https://cn.bing.com/search?q="+keyword
|
||||
url = "https://www.bing.com/search?q="+keyword
|
||||
_cnt = 0
|
||||
_detail_store = []
|
||||
while _cnt < 5:
|
||||
try:
|
||||
response = requests.get(url, headers=headers)
|
||||
response.encoding = "utf-8"
|
||||
gu.log(f"bing response: {response.text}", tag="bing", level=gu.LEVEL_DEBUG, max_len=9999)
|
||||
soup = BeautifulSoup(response.text, "html.parser")
|
||||
res = []
|
||||
ols = soup.find(id="b_results")
|
||||
@@ -47,28 +93,39 @@ def web_keyword_search_via_bing(keyword) -> str:
|
||||
"desc": desc,
|
||||
"link": link,
|
||||
})
|
||||
if len(_detail_store) < 2 and "zhihu.com" in link:
|
||||
try:
|
||||
_detail_store.append(special_fetch_zhihu(link)[:800])
|
||||
except BaseException as e:
|
||||
print(f"zhihu parse err: {str(e)}")
|
||||
if len(res) >= 5: # 限制5条
|
||||
break
|
||||
if len(_detail_store) >= 3:
|
||||
continue
|
||||
|
||||
# 爬取前两条的网页内容
|
||||
if "zhihu.com" in link:
|
||||
try:
|
||||
_detail_store.append(special_fetch_zhihu(link))
|
||||
except BaseException as e:
|
||||
print(f"zhihu parse err: {str(e)}")
|
||||
else:
|
||||
try:
|
||||
_detail_store.append(fetch_website_content(link))
|
||||
except BaseException as e:
|
||||
print(f"fetch_website_content err: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"bing parse err: {str(e)}")
|
||||
if len(res) == 0:
|
||||
break
|
||||
if len(_detail_store) > 0:
|
||||
ret = f"{str(res)} \n来源知乎的具体资料: {str(_detail_store)}"
|
||||
ret = f"{str(res)} \n具体网页内容: {str(_detail_store)}"
|
||||
else:
|
||||
ret = f"{str(res)}"
|
||||
return str(ret)
|
||||
except Exception as e:
|
||||
print(f"bing fetch err: {str(e)}")
|
||||
gu.log(f"bing fetch err: {str(e)}")
|
||||
_cnt += 1
|
||||
time.sleep(1)
|
||||
print("fail to fetch bing info, using sougou.")
|
||||
return web_keyword_search_via_sougou(keyword)
|
||||
|
||||
gu.log("fail to fetch bing info, using sougou.")
|
||||
return google_web_search(keyword)
|
||||
|
||||
def web_keyword_search_via_sougou(keyword) -> str:
|
||||
headers = {
|
||||
@@ -92,59 +149,131 @@ def web_keyword_search_via_sougou(keyword) -> str:
|
||||
"title": title,
|
||||
"link": link,
|
||||
})
|
||||
except:
|
||||
pass
|
||||
ret = f"{str(res)} \n全部内容: {tidy_text(soup.text)}"
|
||||
if len(res) >= 5: # 限制5条
|
||||
break
|
||||
except Exception as e:
|
||||
gu.log(f"sougou parse err: {str(e)}", tag="web_keyword_search_via_sougou", level=gu.LEVEL_ERROR)
|
||||
# 爬取网页内容
|
||||
_detail_store = []
|
||||
for i in res:
|
||||
if _detail_store >= 3:
|
||||
break
|
||||
try:
|
||||
_detail_store.append(fetch_website_content(i["link"]))
|
||||
except BaseException as e:
|
||||
print(f"fetch_website_content err: {str(e)}")
|
||||
ret = f"{str(res)}"
|
||||
if len(_detail_store) > 0:
|
||||
ret += f"\n网页内容: {str(_detail_store)}"
|
||||
return ret
|
||||
|
||||
def fetch_website_content(url):
|
||||
gu.log(f"fetch_website_content: {url}", tag="fetch_website_content", level=gu.LEVEL_DEBUG)
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
|
||||
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
||||
}
|
||||
response = requests.get(url, headers=headers)
|
||||
soup = BeautifulSoup(response.text, "html.parser")
|
||||
res = soup.text
|
||||
res = res.replace("\n", "")
|
||||
with open(f"temp_{time.time()}.html", "w", encoding="utf-8") as f:
|
||||
f.write(res)
|
||||
return res
|
||||
|
||||
def web_search(question, provider):
|
||||
response = requests.get(url, headers=headers, timeout=3)
|
||||
response.encoding = "utf-8"
|
||||
# soup = BeautifulSoup(response.text, "html.parser")
|
||||
# # 如果有container / content / main等的话,就只取这些部分
|
||||
# has = False
|
||||
# beleive_ls = ["container", "content", "main"]
|
||||
# res = ""
|
||||
# for cls in beleive_ls:
|
||||
# for i in soup.find_all(class_=cls):
|
||||
# has = True
|
||||
# res += i.text
|
||||
# if not has:
|
||||
# res = soup.text
|
||||
# res = res.replace("\n", "").replace(" ", " ").replace("\r", "").replace("\t", "")
|
||||
# if not has:
|
||||
# res = res[300:1100]
|
||||
# else:
|
||||
# res = res[100:800]
|
||||
# # with open(f"temp_{time.time()}.html", "w", encoding="utf-8") as f:
|
||||
# # f.write(res)
|
||||
# gu.log(f"fetch_website_content: end", tag="fetch_website_content", level=gu.LEVEL_DEBUG)
|
||||
# return res
|
||||
doc = Document(response.content)
|
||||
# print('title:', doc.title())
|
||||
ret = doc.summary(html_partial=True)
|
||||
soup = BeautifulSoup(ret, 'html.parser')
|
||||
ret = tidy_text(soup.get_text())
|
||||
return ret
|
||||
|
||||
def web_search(question, provider: Provider, session_id, official_fc=False):
|
||||
'''
|
||||
official_fc: 使用官方 function-calling
|
||||
'''
|
||||
new_func_call = FuncCall(provider)
|
||||
|
||||
new_func_call.add_func("web_keyword_search_via_bing", [{
|
||||
new_func_call.add_func("google_web_search", [{
|
||||
"type": "string",
|
||||
"name": "keyword",
|
||||
"brief": "必应搜索的关键词(分词,尽量保留所有信息)"
|
||||
"description": "google search query (分词,尽量保留所有信息)"
|
||||
}],
|
||||
"在必应搜索引擎上搜索给定的关键词,并且返回第一页的搜索结果列表(标题,简介和链接)",
|
||||
web_keyword_search_via_bing
|
||||
"通过搜索引擎搜索。如果问题需要在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
|
||||
google_web_search
|
||||
)
|
||||
|
||||
func_definition1 = new_func_call.func_dump()
|
||||
question1 = f"{question} \n(只能调用一个函数。)"
|
||||
res1, has_func = new_func_call.func_call(question1, func_definition1, is_task=False, is_summary=False)
|
||||
new_func_call.add_func("fetch_website_content", [{
|
||||
"type": "string",
|
||||
"name": "url",
|
||||
"description": "网址"
|
||||
}],
|
||||
"获取网页的内容。如果问题带有合法的网页链接(例如: `帮我总结一下https://github.com的内容`), 就调用此函数。如果没有,不要调用此函数。",
|
||||
fetch_website_content
|
||||
)
|
||||
question1 = f"{question} \n> hint: 最多只能调用1个function, 并且存在不会调用任何function的可能性。"
|
||||
has_func = False
|
||||
function_invoked_ret = ""
|
||||
if official_fc:
|
||||
func = provider.text_chat(question1, session_id, function_call=new_func_call.get_func())
|
||||
if isinstance(func, Function):
|
||||
# arguments='{\n "keyword": "北京今天的天气"\n}', name='google_web_search'
|
||||
# 执行对应的结果:
|
||||
func_obj = None
|
||||
for i in new_func_call.func_list:
|
||||
if i["name"] == func.name:
|
||||
func_obj = i["func_obj"]
|
||||
break
|
||||
if not func_obj:
|
||||
gu.log("找不到返回的 func name " + func.name, level=gu.LEVEL_ERROR)
|
||||
return provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
|
||||
try:
|
||||
args = json.loads(func.arguments)
|
||||
function_invoked_ret = func_obj(**args)
|
||||
has_func = True
|
||||
except BaseException as e:
|
||||
traceback.print_exc()
|
||||
return provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
|
||||
else:
|
||||
# now func is a string
|
||||
return func
|
||||
else:
|
||||
try:
|
||||
function_invoked_ret, has_func = new_func_call.func_call(question1, new_func_call.func_dump(), is_task=False, is_summary=False)
|
||||
except BaseException as e:
|
||||
res = provider.text_chat(question) + "\n(网页搜索失败, 此为默认回复)"
|
||||
return res
|
||||
has_func = True
|
||||
|
||||
if has_func:
|
||||
provider.forget()
|
||||
question3 = f"""请你回答`{question}`问题。\n以下是相关材料,你请直接拿此材料针对问题进行总结回答,然后再给出参考链接。不要提到任何函数调用的信息。```\n{res1}\n```\n"""
|
||||
print(question3)
|
||||
provider.forget(session_id)
|
||||
question3 = f"""请你用可爱的语气回答`{question}`问题。\n以下是相关材料,请直接拿此材料针对问题进行总结回答,再给参考链接, 参考链接首末有空格。不要提到任何函数调用的信息。在总结的末尾加上1-2个相关的emoji。```\n{function_invoked_ret}\n```\n"""
|
||||
gu.log(f"web_search: {question3}", tag="web_search", level=gu.LEVEL_DEBUG, max_len=99999)
|
||||
_c = 0
|
||||
while _c < 5:
|
||||
while _c < 3:
|
||||
try:
|
||||
print('text chat')
|
||||
res3 = provider.text_chat(question3)
|
||||
break
|
||||
final_ret = provider.text_chat(question3)
|
||||
return final_ret
|
||||
except Exception as e:
|
||||
print(e)
|
||||
_c += 1
|
||||
if _c == 5:
|
||||
raise e
|
||||
if _c == 3: raise e
|
||||
if "The message you submitted was too long" in str(e):
|
||||
res2 = res2[:int(len(res2) / 2)]
|
||||
question3 = f"""请你回答`{question}`问题。\n以下是相关材料,请直接拿此材料针对问题进行回答,然后再给出参考链接。```\n{res1}\n{res2}\n```\n"""
|
||||
return res3
|
||||
else:
|
||||
return res1
|
||||
provider.forget(session_id)
|
||||
function_invoked_ret = function_invoked_ret[:int(len(function_invoked_ret) / 2)]
|
||||
time.sleep(3)
|
||||
question3 = f"""请回答`{question}`问题。\n以下是相关材料,请直接拿此材料针对问题进行回答,再给参考链接, 参考链接首末有空格。```\n{function_invoked_ret}\n```\n"""
|
||||
return function_invoked_ret
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
'''
|
||||
插件工具函数
|
||||
'''
|
||||
import os
|
||||
import inspect
|
||||
|
||||
@@ -7,16 +10,25 @@ def get_classes(p_name, arg):
|
||||
clsmembers = inspect.getmembers(arg, inspect.isclass)
|
||||
for (name, _) in clsmembers:
|
||||
# print(name, p_name)
|
||||
if p_name.lower() == name.lower()[:-6]:
|
||||
if p_name.lower() == name.lower()[:-6] or name.lower() == "main":
|
||||
classes.append(name)
|
||||
break
|
||||
return classes
|
||||
|
||||
# 获取一个文件夹下所有的模块
|
||||
# 获取一个文件夹下所有的模块, 文件名和文件夹名相同
|
||||
def get_modules(path):
|
||||
modules = []
|
||||
for root, dirs, files in os.walk(path):
|
||||
# 获得所在目录名
|
||||
p_name = os.path.basename(root)
|
||||
for file in files:
|
||||
if file.endswith(".py") and not file.startswith("__"):
|
||||
modules.append(file[:-3])
|
||||
"""
|
||||
与文件夹名(不计大小写)相同或者是main.py的,都算启动模块
|
||||
"""
|
||||
if file.endswith(".py") and not file.startswith("__") and (p_name.lower() == file[:-3].lower() or file[:-3].lower() == "main"):
|
||||
modules.append({
|
||||
"pname": p_name,
|
||||
"module": file[:-3],
|
||||
})
|
||||
return modules
|
||||
|
||||
Reference in New Issue
Block a user