Compare commits

...

62 Commits

Author SHA1 Message Date
Soulter
752201cb46 update: requirements.txt 2023-11-14 09:33:30 +08:00
Soulter
deebf61b5f feat: 大幅优化网页搜索的信息提取准确性
perf: 使用 tictoken 预先计算 token
2023-11-14 09:33:18 +08:00
Soulter
d5e5b06e86 perf: 让回复末尾添加1-2个emoji 2023-11-13 23:05:19 +08:00
Soulter
cb5975c102 feat: 1. 适配新版openai sdk
2. 适配官方 function calling
2023-11-13 21:54:23 +08:00
Soulter
5b1aee1b4d feat: web search support prefix keyword call 2023-11-09 16:05:42 +00:00
Soulter
510c8b4236 feat: support gpt-4-vision-preview 2023-11-09 20:53:02 +08:00
Soulter
89fc7b0553 perf: 使用异步重写部分代码 2023-10-12 11:16:49 +08:00
Soulter
123c21fcb3 perf: 重载插件支持更新依赖库 2023-10-05 22:34:26 +08:00
Soulter
75d62d66f9 fix: 修复折叠发送时可能发送失败的问题 2023-10-05 21:38:35 +08:00
Soulter
23a8e989a5 perf: 优化插件加载机制 2023-10-05 13:38:10 +08:00
Soulter
9577e637f1 perf: 优化代码结构、稳定性和插件加载机制 2023-10-05 13:21:39 +08:00
Soulter
e51ef2201b Merge remote-tracking branch 'refs/remotes/origin/master' 2023-10-05 10:49:49 +08:00
Soulter
f4ae503abf perf: 优化报错提示和代码结构 2023-10-05 10:48:35 +08:00
Soulter
3424b658f3 bugfixes 2023-10-02 10:35:51 +08:00
Soulter
3198f73f3d perf: 清除警告;适配新版启动器 2023-10-02 10:17:10 +08:00
Soulter
aa3262a8ab chore: fix some typos 2023-10-02 10:10:04 +08:00
Soulter
6acd7be547 perf: 优化一些库的导入机制 2023-10-01 17:46:51 +08:00
Soulter
fb7669ddad perf: 依赖库安装优化 2023-10-01 16:20:51 +08:00
Soulter
f2c4ef126e perf: 优化openai模型消息截断机制 2023-09-30 15:11:06 +08:00
Soulter
33dcc4c152 perf: openai模型超限时截断消息(0.75x) 2023-09-30 15:06:57 +08:00
Soulter
b9e331ebd6 perf: 网页搜索改用google search,是改善效果 2023-09-30 14:59:25 +08:00
Soulter
7832ec386e perf: 优化web search 2023-09-30 14:06:50 +08:00
Soulter
b9828428cc perf: web search优化 2023-09-30 13:37:10 +08:00
Soulter
da11034aec feat: 支持在cmd_config中修改配置文件 2023-09-29 10:06:41 +08:00
Soulter
578c9e0695 feat: 支持戳一戳消息 2023-09-28 20:51:50 +08:00
Soulter
cc675a9b4f perf: 对插件开放更多接口 2023-09-28 20:12:39 +08:00
Soulter
08e7d4d0c6 fix: 修复一部分超限的报错
perf: web search稳定性和精确度优化
2023-09-27 22:06:08 +08:00
Soulter
553f1b8d83 fix: 修复官方模型下web search报错的问题 2023-09-27 21:14:03 +08:00
Soulter
73e7e2088d perf: 完善报错堆栈显示 2023-09-27 21:02:50 +08:00
Soulter
e40c9de610 perf: 优化聊天会话管理 2023-09-27 16:42:39 +08:00
Soulter
2f4e0bb4f2 fix: 修复人格一段时间后消失的问题 2023-09-25 15:55:51 +08:00
Soulter
191976e22e fix: 修复一些权限上的问题 2023-09-25 13:55:00 +08:00
Soulter
52656b8586 perf: 支持多管理员配置 2023-09-25 13:51:12 +08:00
Soulter
998e29ded6 fix: myid显示异常 2023-09-25 13:43:33 +08:00
Soulter
5bbe3f12d6 feat: OpenAI官方模型支持切换账号 2023-09-25 13:25:38 +08:00
Soulter
56aea81ed7 Merge remote-tracking branch 'refs/remotes/origin/master' 2023-09-25 12:04:04 +08:00
Soulter
7b8a311dde fix: 修复gocq启动下QQ频道无法通过@回复消息的问题
feat:  支持重置会话时保留人格
perf: 清除部分无用日志输出
2023-09-25 12:03:17 +08:00
Soulter
b75d20a3e8 Update README.md 2023-09-20 10:46:09 +08:00
Soulter
67faa587b6 fix: 修复初次调用/keyword指令时报错文件不存在的bug 2023-09-20 10:31:31 +08:00
Soulter
15fde686d4 perf: 精简日志输出和冗余的日志文件 2023-09-14 14:04:47 +08:00
Soulter
741284f6e8 perf: 去除启动时检查更新产生的大量的日志 2023-09-14 13:50:00 +08:00
Soulter
8352fc269b 1. 修复qq频道发不了图片的问题 2023-09-14 08:39:05 +08:00
Soulter
5852f36557 1. gocq支持选择不回复群、私聊、频道消息。
(在cmd_config.json文件设置gocq_react_xxx等项);
2. update指令升级成功后返回新版本信息
2023-09-10 09:03:26 +08:00
Soulter
cc1c723c12 fix: 修复OpenAI官方模型无法启用的问题 2023-09-09 09:45:34 +08:00
Soulter
adf5cbfeba fix: 优化网页搜索的稳定性 2023-09-08 16:41:37 +08:00
Soulter
d6d0516c9a feat: gocq服务器地址支持在cmd_config自定义。 2023-09-08 14:19:07 +08:00
Soulter
8aab10aaf3 websearch bugfixes 2023-09-08 13:46:57 +08:00
Soulter
4fe5616ae1 Merge remote-tracking branch 'refs/remotes/origin/master' 2023-09-08 13:40:03 +08:00
Soulter
7e1c76a3f5 fix: 修复openai官方模型一些指令报错的问题
feat: revChatGPT支持人格设置
2023-09-08 13:38:48 +08:00
Soulter
f74665ff71 Update README.md 2023-09-08 12:01:39 +08:00
Soulter
a96d64fe88 fix: 修复qq频道下无法发送图片的bug 2023-09-04 10:14:46 +08:00
Soulter
fd2aa0cba6 bugfixes 2023-09-02 19:59:14 +08:00
Soulter
a92ea3db02 fix: 修复只启动频道官方SDK下,不显示管理者QQ设置的问题 2023-09-02 19:39:38 +08:00
Soulter
d7a513b640 fix: 关键词指令 2023-09-02 18:30:11 +08:00
Soulter
8a017ff693 bugfixes 2023-09-02 11:11:54 +08:00
Soulter
7d08f57b32 bugfixes 2023-09-02 10:31:13 +08:00
Soulter
6f4ad7890b bugfixes 2023-09-02 10:05:06 +08:00
Soulter
37488118a6 feat: 1. keyword指令支持记录图片;
2. qq频道转gocq数据结构兼容层实现;
perf: 1. 优化代码结构;
2. log 支持环境变量指定log等级
2023-09-02 00:24:13 +08:00
Soulter
b2da0778ae Merge branch 'master' of https://github.com/Soulter/QQChannelChatGPT 2023-09-01 15:12:18 +08:00
Soulter
cc887a5037 perf: 优化代码结构 2023-09-01 15:11:58 +08:00
Soulter
ca86a02d30 Update requirements.txt 2023-08-31 21:27:26 +08:00
Soulter
d652dc19a6 Update README.md 2023-08-31 18:39:37 +08:00
20 changed files with 1533 additions and 950 deletions

View File

@@ -17,7 +17,7 @@
</div> </div>
## 🤔您可能想了解的 ## 🤔您可能想了解的
- **如何部署?** [帮助文档](https://github.com/Soulter/QQChannelChatGPT/wiki) - **如何部署?** [帮助文档](https://github.com/Soulter/QQChannelChatGPT/wiki) (部署不成功欢迎进群捞人解决<3)
- **go-cqhttp启动不成功报登录失败** [在这里搜索解决方法](https://github.com/Mrs4s/go-cqhttp/issues) - **go-cqhttp启动不成功报登录失败** [在这里搜索解决方法](https://github.com/Mrs4s/go-cqhttp/issues)
- **程序闪退/机器人启动不成功** [提交issue或加群反馈](https://github.com/Soulter/QQChannelChatGPT/issues) - **程序闪退/机器人启动不成功** [提交issue或加群反馈](https://github.com/Soulter/QQChannelChatGPT/issues)
- **如何开启ChatGPTBardClaude等语言模型** [查看帮助](https://github.com/Soulter/QQChannelChatGPT/wiki/%E8%A1%A5%E5%85%85%EF%BC%9A%E5%A6%82%E4%BD%95%E5%BC%80%E5%90%AFChatGPT%E3%80%81Bard%E3%80%81Claude%E7%AD%89%E8%AF%AD%E8%A8%80%E6%A8%A1%E5%9E%8B%EF%BC%9F) - **如何开启ChatGPTBardClaude等语言模型** [查看帮助](https://github.com/Soulter/QQChannelChatGPT/wiki/%E8%A1%A5%E5%85%85%EF%BC%9A%E5%A6%82%E4%BD%95%E5%BC%80%E5%90%AFChatGPT%E3%80%81Bard%E3%80%81Claude%E7%AD%89%E8%AF%AD%E8%A8%80%E6%A8%A1%E5%9E%8B%EF%BC%9F)
@@ -48,7 +48,7 @@
- 大模型对话 - 大模型对话
- 大模型网页搜索能力 **(目前仅支持OpenAI系的模型最新版本下使用web on指令打开)** - 大模型网页搜索能力 **(目前仅支持OpenAI系的模型最新版本下使用web on指令打开)**
- 插件安装在QQ或QQ频道聊天框内输入`plugin`了解详情 - 插件安装在QQ或QQ频道聊天框内输入`plugin`了解详情
- 回复文字图片渲染以图片markdown格式回复,降低被风控概率,需手动在`cmd_config.json`内开启 - 回复文字图片渲染以图片markdown格式回复**大幅度降低被风控概率**需手动在`cmd_config.json`内开启qq_pic_mode
- 人格设置 - 人格设置
- 关键词回复 - 关键词回复
- 热更新更新本项目时**仅需**在QQ或QQ频道聊天框内输入`update latest r` - 热更新更新本项目时**仅需**在QQ或QQ频道聊天框内输入`update latest r`
@@ -121,7 +121,7 @@
插件开发教程https://github.com/Soulter/QQChannelChatGPT/wiki/%E5%9B%9B%E3%80%81%E5%BC%80%E5%8F%91%E6%8F%92%E4%BB%B6 插件开发教程https://github.com/Soulter/QQChannelChatGPT/wiki/%E5%9B%9B%E3%80%81%E5%BC%80%E5%8F%91%E6%8F%92%E4%BB%B6
部分公开的插件: 部分插件:
- `LLMS`: https://github.com/Soulter/llms | Claude, HuggingChat 大语言模型接入。 - `LLMS`: https://github.com/Soulter/llms | Claude, HuggingChat 大语言模型接入。
@@ -129,7 +129,9 @@
- `sysstat`: https://github.com/Soulter/sysstatqcbot | 查看系统状态 - `sysstat`: https://github.com/Soulter/sysstatqcbot | 查看系统状态
- `BiliMonitor`: https://github.com/Soulter/BiliMonitor | 订阅B站动态 - `BiliMonitor`: https://github.com/Soulter/BiliMonitor | 订阅B站动态
- `liferestart`: https://github.com/Soulter/liferestart | 人生重开模拟器
<!-- <!--
### 指令 ### 指令

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,16 @@
class GlobalObject():
'''
存放一些公用的数据,用于在不同模块(如core与command)之间传递
'''
def __init__(self):
self.nick = None # gocq 的昵称
self.base_config = None # config.yaml
self.cached_plugins = {} # 缓存的插件
self.web_search = False # 是否开启了网页搜索
self.reply_prefix = None
self.admin_qq = "123456"
self.admin_qqchan = "123456"
self.uniqueSession = False
self.cnt_total = 0
self.platform_qq = None
self.platform_qqchan = None

94
main.py
View File

@@ -1,19 +1,32 @@
import os, sys import os, sys
from pip._internal import main as pipmain from pip._internal import main as pipmain
import warnings
import traceback
warnings.filterwarnings("ignore")
abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/' abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/'
def main(): def main():
# config.yaml 配置文件加载和环境确认
try: try:
import cores.qqbot.core as qqBot import cores.qqbot.core as qqBot
import yaml import yaml
from yaml.scanner import ScannerError
import util.general_utils as gu
ymlfile = open(abs_path+"configs/config.yaml", 'r', encoding='utf-8') ymlfile = open(abs_path+"configs/config.yaml", 'r', encoding='utf-8')
cfg = yaml.safe_load(ymlfile) cfg = yaml.safe_load(ymlfile)
except BaseException as e: except ImportError as import_error:
print(e) print(import_error)
input("第三方依赖库未完全安装完毕,请退出程序重试。") input("第三方库未完全安装完毕,请退出程序重试。")
exit() except FileNotFoundError as file_not_found:
import util.general_utils as gu print(file_not_found)
input("配置文件不存在,请检查是否已经下载配置文件。")
except ScannerError as e:
print(traceback.format_exc())
input("config.yaml 配置文件格式错误,请遵守 yaml 格式。")
# 设置代理
if 'http_proxy' in cfg: if 'http_proxy' in cfg:
os.environ['HTTP_PROXY'] = cfg['http_proxy'] os.environ['HTTP_PROXY'] = cfg['http_proxy']
if 'https_proxy' in cfg: if 'https_proxy' in cfg:
@@ -21,21 +34,20 @@ def main():
os.environ['NO_PROXY'] = 'cn.bing.com,https://api.sgroup.qq.com' os.environ['NO_PROXY'] = 'cn.bing.com,https://api.sgroup.qq.com'
# 检查temp文件夹 # 检查并创建 temp 文件夹
if not os.path.exists(abs_path+"temp"): if not os.path.exists(abs_path + "temp"):
os.mkdir(abs_path+"temp") os.mkdir(abs_path+"temp")
# 选择默认模型
provider = privider_chooser(cfg) provider = privider_chooser(cfg)
if len(provider) == 0: if len(provider) == 0:
gu.log("未开启任何语言模型, 请在configs/config.yaml下选择开启相应语言模型", gu.LEVEL_CRITICAL) gu.log("注意:您目前未开启任何语言模型。", gu.LEVEL_WARNING)
input("按任意键退出...")
exit()
print('[System] 开启的语言模型: ' + str(provider)) print('[System] 开启的语言模型: ' + str(provider))
# 执行Bot
# 启动主程序cores/qqbot/core.py
qqBot.initBot(cfg, provider) qqBot.initBot(cfg, provider)
# 语言模型提供商选择器 # 语言模型提供商选择器
# 目前有OpenAI官方API、逆向库
def privider_chooser(cfg): def privider_chooser(cfg):
l = [] l = []
if 'rev_ChatGPT' in cfg and cfg['rev_ChatGPT']['enable']: if 'rev_ChatGPT' in cfg and cfg['rev_ChatGPT']['enable']:
@@ -48,55 +60,44 @@ def privider_chooser(cfg):
l.append('openai_official') l.append('openai_official')
return l return l
def check_env(): def check_env(ch_mirror=False):
if not (sys.version_info.major == 3 and sys.version_info.minor >= 8): if not (sys.version_info.major == 3 and sys.version_info.minor >= 9):
print("请使用Python3.8运行本项目") print("请使用Python3.9+运行本项目")
input("按任意键退出...") input("按任意键退出...")
exit() exit()
# 检查pip
# pip_tag = "pip"
# mm = os.system("pip -V")
# if mm != 0:
# mm1 = os.system("pip3 -V")
# if mm1 != 0:
# print("未检测到pip, 请安装Python(版本应>=3.9)")
# input("按任意键退出...")
# exit()
# else:
# pip_tag = "pip3"
if os.path.exists('requirements.txt'): if os.path.exists('requirements.txt'):
pth = 'requirements.txt' pth = 'requirements.txt'
else: else:
pth = 'QQChannelChatGPT'+ os.sep +'requirements.txt' pth = 'QQChannelChatGPT'+ os.sep +'requirements.txt'
print("正在更新三方依赖库...") print("正在检查更新三方库...")
try: try:
pipmain(['install', '-r', pth]) if ch_mirror:
print("依赖库安装完毕。") print("使用阿里云镜像")
pipmain(['install', '-r', pth, '-i', 'https://mirrors.aliyun.com/pypi/simple/', '--quiet'])
else:
pipmain(['install', '-r', pth, '--quiet'])
except BaseException as e: except BaseException as e:
print(e) print(e)
while True: while True:
res = input("依赖库可能安装失败\n果是报错ValueError: check_hostname requires server_hostname请尝试先关闭代理后重试。\n输入y回车重试\n输入c回车使用国内镜像源下载\n输入其他按键回车继续往下执行。") res = input("安装失败。\n如报错ValueError: check_hostname requires server_hostname请尝试先关闭代理后重试。\n1.输入y回车重试\n2. 输入c回车使用国内镜像源下载\n3. 输入其他按键回车继续往下执行。")
if res == "y": if res == "y":
try: try:
pipmain(['install', '-r', pth]) pipmain(['install', '-r', pth])
print("依赖库安装完毕。")
break break
except BaseException as e: except BaseException as e:
print(e) print(e)
continue continue
elif res == "c": elif res == "c":
try: try:
pipmain(['install', '-r', pth, '-i', 'https://mirrors.aliyun.com/pypi/simple/']) pipmain(['install', '-r', pth, '-i', 'https://mirrors.aliyun.com/pypi/simple/'])
print("依赖库安装完毕。")
break break
except BaseException as e: except BaseException as e:
print(e) print(e)
continue continue
else: else:
break break
print("第三方库检查完毕。")
def get_platform(): def get_platform():
import platform import platform
@@ -111,17 +112,20 @@ def get_platform():
print("other") print("other")
if __name__ == "__main__": if __name__ == "__main__":
check_env()
# 获取参数
args = sys.argv args = sys.argv
if len(args) > 1:
if args[1] == '-replit': if '-cn' in args:
print("[System] 启动Replit Web保活服务...") check_env(True)
try: else:
from webapp_replit import keep_alive check_env()
keep_alive()
except BaseException as e: if '-replit' in args:
print(e) print("[System] 启动Replit Web保活服务...")
print(f"[System-err] Replit Web保活服务启动失败:{str(e)}") try:
from webapp_replit import keep_alive
keep_alive()
except BaseException as e:
print(e)
print(f"[System-err] Replit Web保活服务启动失败:{str(e)}")
main() main()

View File

@@ -1,6 +1,13 @@
import json import json
import git.exc
from git.repo import Repo has_git = True
try:
import git.exc
from git.repo import Repo
except BaseException as e:
print("你正运行在无Git环境下暂时将无法使用插件、热更新功能。")
has_git = False
import os import os
import sys import sys
import requests import requests
@@ -18,15 +25,72 @@ from nakuru.entities.components import (
Image Image
) )
from PIL import Image as PILImage from PIL import Image as PILImage
from cores.qqbot.global_object import GlobalObject
from pip._internal import main as pipmain
PLATFORM_QQCHAN = 'qqchan' PLATFORM_QQCHAN = 'qqchan'
PLATFORM_GOCQ = 'gocq' PLATFORM_GOCQ = 'gocq'
# 指令功能的基类,通用的(不区分语言模型)的指令就在这实现 # 指令功能的基类,通用的(不区分语言模型)的指令就在这实现
class Command: class Command:
def __init__(self, provider: Provider): def __init__(self, provider: Provider, global_object: GlobalObject = None):
self.provider = Provider self.provider = provider
self.global_object = global_object
def check_command(self,
message,
session_id: str,
role,
platform,
message_obj):
# 插件
cached_plugins = self.global_object.cached_plugins
for k, v in cached_plugins.items():
try:
hit, res = v["clsobj"].run(message, role, platform, message_obj, self.global_object.platform_qq)
if hit:
return True, res
except BaseException as e:
gu.log(f"{k}插件加载出现问题,原因: {str(e)}\n已安装插件: {cached_plugins.keys}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING)
if self.command_start_with(message, "nick"):
return True, self.set_nick(message, platform, role)
if self.command_start_with(message, "plugin"):
return True, self.plugin_oper(message, role, cached_plugins, platform)
if self.command_start_with(message, "myid") or self.command_start_with(message, "!myid"):
return True, self.get_my_id(message_obj)
if self.command_start_with(message, "nconf") or self.command_start_with(message, "newconf"):
return True, self.get_new_conf(message, role)
if self.command_start_with(message, "web"): # 网页搜索
return True, self.web_search(message)
if self.command_start_with(message, "keyword"):
return True, self.keyword(message_obj, role)
return False, None
def web_search(self, message):
if message == "web on":
self.global_object.web_search = True
return True, "已开启网页搜索", "web"
elif message == "web off":
self.global_object.web_search = False
return True, "已关闭网页搜索", "web"
return True, f"网页搜索功能当前状态: {self.global_object.web_search}", "web"
def get_my_id(self, message_obj):
return True, f"你的ID{str(message_obj.sender.tiny_id)}", "plugin"
def get_new_conf(self, message, role):
if role != "admin":
return False, f"你的身份组{role}没有权限使用此指令。", "newconf"
l = message.split(" ")
if len(l) <= 1:
obj = cc.get_all()
p = gu.create_text_image("【cmd_config.json】", json.dumps(obj, indent=4, ensure_ascii=False))
return True, [Image.fromFileSystem(p)], "newconf"
def get_plugin_modules(self): def get_plugin_modules(self):
plugins = [] plugins = []
try: try:
@@ -41,79 +105,20 @@ class Command:
except BaseException as e: except BaseException as e:
raise e raise e
def check_command(self, message, role, platform,
message_obj,
cached_plugins: dict,
qq_platform: QQ,
global_object: dict):
# 插件
for k, v in cached_plugins.items():
try:
hit, res = v["clsobj"].run(message, role, platform, message_obj, qq_platform)
if hit:
return True, res
except BaseException as e:
gu.log(f"{k}插件加载出现问题,原因: {str(e)}\n已安装插件: {cached_plugins.keys}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING)
if self.command_start_with(message, "nick"):
return True, self.set_nick(message, platform, role)
if self.command_start_with(message, "plugin"):
return True, self.plugin_oper(message, role, cached_plugins, platform)
if self.command_start_with(message, "myid"):
return True, self.get_my_id(message_obj, platform)
if self.command_start_with(message, "nconf") or self.command_start_with(message, "newconf"):
return True, self.get_new_conf(message, role, platform)
if self.command_start_with(message, "web"): # 网页搜索
return True, self.web_search(message, global_object)
return False, None
def web_search(self, message, global_object):
if "web_search" not in global_object:
global_object["web_search"] = False
if message == "web on":
global_object["web_search"] = True
return True, "已开启网页搜索", "web"
elif message == "web off":
global_object["web_search"] = False
return True, "已关闭网页搜索", "web"
return True, f"网页搜索功能当前状态: {global_object['web_search']}", "web"
def get_my_id(self, message_obj, platform):
print(message_obj)
if platform == "gocq":
if message_obj.type == "GuildMessage":
return True, f"你的频道id是{str(message_obj.sender.tiny_id)}", "plugin"
else:
return True, f"你的QQ是{str(message_obj.sender.user_id)}", "plugin"
else:
return True, f"{str(message_obj)}\n此指令为开发专用为提供更多数据请自行从中找出您的频道ID。在author->id中。", "plugin"
def get_new_conf(self, message, role, platform):
if role != "admin":
return False, f"你的身份组{role}没有权限使用此指令。", "newconf"
if platform == gu.PLATFORM_GOCQ:
l = message.split(" ")
if len(l) <= 1:
obj = cc.get_all()
p = gu.create_text_image("【cmd_config.json】", json.dumps(obj, indent=4, ensure_ascii=False))
return True, [Image.fromFileSystem(p)], "newconf"
return False, f"Not support or not implemented.", "newconf"
def plugin_reload(self, cached_plugins: dict, target: str = None, all: bool = False): def plugin_reload(self, cached_plugins: dict, target: str = None, all: bool = False):
plugins = self.get_plugin_modules() plugins = self.get_plugin_modules()
fail_rec = "" fail_rec = ""
if plugins is None: if plugins is None:
return False, "未找到任何插件模块" return False, "未找到任何插件模块"
print(plugins)
for p in plugins: for plugin in plugins:
try: try:
p = plugin['module']
root_dir_name = plugin['pname']
if p not in cached_plugins or p == target or all: if p not in cached_plugins or p == target or all:
module = __import__("addons.plugins." + p + "." + p, fromlist=[p]) module = __import__("addons.plugins." + root_dir_name + "." + p, fromlist=[p])
if p in cached_plugins: if p in cached_plugins:
module = importlib.reload(module) module = importlib.reload(module)
cls = putil.get_classes(p, module) cls = putil.get_classes(p, module)
@@ -129,13 +134,15 @@ class Command:
except BaseException as e: except BaseException as e:
fail_rec += f"调用插件{p} info失败, 原因: {str(e)}\n" fail_rec += f"调用插件{p} info失败, 原因: {str(e)}\n"
continue continue
cached_plugins[p] = { cached_plugins[info['name']] = {
"module": module, "module": module,
"clsobj": obj, "clsobj": obj,
"info": info "info": info,
} "name": info['name'],
"root_dir_name": root_dir_name,
}
except BaseException as e: except BaseException as e:
fail_rec += f"加载{p}插件出现问题,原因{str(e)}\n" fail_rec += f"加载{p}插件出现问题,原因 {str(e)}\n"
if fail_rec == "": if fail_rec == "":
return True, None return True, None
else: else:
@@ -145,12 +152,12 @@ class Command:
插件指令 插件指令
''' '''
def plugin_oper(self, message: str, role: str, cached_plugins: dict, platform: str): def plugin_oper(self, message: str, role: str, cached_plugins: dict, platform: str):
if not has_git:
return False, "你正在运行在无Git环境下暂时将无法使用插件、热更新功能。", "plugin"
l = message.split(" ") l = message.split(" ")
if len(l) < 2: if len(l) < 2:
if platform == gu.PLATFORM_GOCQ: p = gu.create_text_image("【插件指令面板】", "安装插件: \nplugin i 插件Github地址\n卸载插件: \nplugin d 插件名 \n重载插件: \nplugin reload\n查看插件列表:\nplugin l\n更新插件: plugin u 插件名\n")
p = gu.create_text_image("【插件指令面板】", "安装插件: \nplugin i 插件Github地址\n卸载插件: \nplugin d 插件名 \n重载插件: \nplugin reload\n查看插件列表:\nplugin l\n更新插件: plugin u 插件名\n") return True, [Image.fromFileSystem(p)], "plugin"
return True, [Image.fromFileSystem(p)], "plugin"
return True, "\n=====插件指令面板=====\n安装插件: \nplugin i 插件Github地址\n卸载插件: \nplugin d 插件名 \n重载插件: \nplugin reload\n查看插件列表:\nplugin l\n更新插件: plugin u 插件名\n===============", "plugin"
else: else:
ppath = "" ppath = ""
if os.path.exists("addons/plugins"): if os.path.exists("addons/plugins"):
@@ -163,8 +170,13 @@ class Command:
if role != "admin": if role != "admin":
return False, f"你的身份组{role}没有权限安装插件", "plugin" return False, f"你的身份组{role}没有权限安装插件", "plugin"
try: try:
# 删除末尾的/
if l[2].endswith("/"):
l[2] = l[2][:-1]
# 得到url的最后一段 # 得到url的最后一段
d = l[2].split("/")[-1] d = l[2].split("/")[-1]
# 转换非法字符:-
d = d.replace("-", "_")
# 创建文件夹 # 创建文件夹
plugin_path = os.path.join(ppath, d) plugin_path = os.path.join(ppath, d)
if os.path.exists(plugin_path): if os.path.exists(plugin_path):
@@ -174,11 +186,9 @@ class Command:
# 读取插件的requirements.txt # 读取插件的requirements.txt
if os.path.exists(os.path.join(plugin_path, "requirements.txt")): if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
with open(os.path.join(plugin_path, "requirements.txt"), "r", encoding="utf-8") as f: mm = pipmain(['install', '-r', os.path.join(plugin_path, "requirements.txt")])
for line in f.readlines(): if mm != 0:
mm = os.system(f"pip3 install {line.strip()}") return False, "插件依赖安装失败需要您手动pip安装对应插件的依赖。", "plugin"
if mm != 0:
return False, "插件依赖安装失败需要您手动pip安装对应插件的依赖。", "plugin"
# 加载没缓存的插件 # 加载没缓存的插件
ok, err = self.plugin_reload(cached_plugins, target=d) ok, err = self.plugin_reload(cached_plugins, target=d)
if ok: if ok:
@@ -192,28 +202,29 @@ class Command:
elif l[1] == "d": elif l[1] == "d":
if role != "admin": if role != "admin":
return False, f"你的身份组{role}没有权限删除插件", "plugin" return False, f"你的身份组{role}没有权限删除插件", "plugin"
if l[2] not in cached_plugins:
return False, "未找到该插件", "plugin"
try: try:
# 删除文件夹 root_dir_name = cached_plugins[l[2]]["root_dir_name"]
# shutil.rmtree(os.path.join(ppath, l[2])) self.remove_dir(os.path.join(ppath, root_dir_name))
self.remove_dir(os.path.join(ppath, l[2])) del cached_plugins[l[2]]
if l[2] in cached_plugins:
del cached_plugins[l[2]]
return True, "插件卸载成功~", "plugin" return True, "插件卸载成功~", "plugin"
except BaseException as e: except BaseException as e:
return False, f"卸载插件失败,原因: {str(e)}", "plugin" return False, f"卸载插件失败,原因: {str(e)}", "plugin"
elif l[1] == "u": elif l[1] == "u":
plugin_path = os.path.join(ppath, l[2]) if l[2] not in cached_plugins:
return False, "未找到该插件", "plugin"
root_dir_name = cached_plugins[l[2]]["root_dir_name"]
plugin_path = os.path.join(ppath, root_dir_name)
try: try:
repo = Repo(path = plugin_path) repo = Repo(path = plugin_path)
repo.remotes.origin.pull() repo.remotes.origin.pull()
# 读取插件的requirements.txt # 读取插件的requirements.txt
if os.path.exists(os.path.join(plugin_path, "requirements.txt")): if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
with open(os.path.join(plugin_path, "requirements.txt"), "r", encoding="utf-8") as f: mm = pipmain(['install', '-r', os.path.join(plugin_path, "requirements.txt")])
for line in f.readlines(): if mm != 0:
mm = os.system(f"pip3 install {line.strip()}") return False, "插件依赖安装失败需要您手动pip安装对应插件的依赖。", "plugin"
if mm != 0:
return False, "插件依赖安装失败需要您手动pip安装对应插件的依赖。", "plugin"
ok, err = self.plugin_reload(cached_plugins, target=l[2]) ok, err = self.plugin_reload(cached_plugins, target=l[2])
if ok: if ok:
@@ -226,21 +237,16 @@ class Command:
elif l[1] == "l": elif l[1] == "l":
try: try:
plugin_list_info = "\n".join([f"{k}: \n名称: {v['info']['name']}\n简介: {v['info']['desc']}\n版本: {v['info']['version']}\n作者: {v['info']['author']}\n" for k, v in cached_plugins.items()]) plugin_list_info = "\n".join([f"{k}: \n名称: {v['info']['name']}\n简介: {v['info']['desc']}\n版本: {v['info']['version']}\n作者: {v['info']['author']}\n" for k, v in cached_plugins.items()])
if platform == gu.PLATFORM_GOCQ: p = gu.create_text_image("【已激活插件列表】", plugin_list_info + "\n使用plugin v 插件名 查看插件帮助\n")
p = gu.create_text_image("【已激活插件列表】", plugin_list_info + "\n使用plugin v 插件名 查看插件帮助\n") return True, [Image.fromFileSystem(p)], "plugin"
return True, [Image.fromFileSystem(p)], "plugin"
return True, "\n=====已激活插件列表=====\n" + plugin_list_info + "\n使用plugin v 插件名 查看插件帮助\n=================", "plugin"
except BaseException as e: except BaseException as e:
return False, f"获取插件列表失败,原因: {str(e)}", "plugin" return False, f"获取插件列表失败,原因: {str(e)}", "plugin"
elif l[1] == "v": elif l[1] == "v":
try: try:
if l[2] in cached_plugins: if l[2] in cached_plugins:
info = cached_plugins[l[2]]["info"] info = cached_plugins[l[2]]["info"]
if platform == gu.PLATFORM_GOCQ: p = gu.create_text_image(f"【插件信息】", f"名称: {info['name']}\n{info['desc']}\n版本: {info['version']}\n作者: {info['author']}\n\n帮助:\n{info['help']}")
p = gu.create_text_image(f"【插件信息】", f"名称: {info['name']}\n{info['desc']}\n版本: {info['version']}\n作者: {info['author']}\n\n帮助:\n{info['help']}") return True, [Image.fromFileSystem(p)], "plugin"
return True, [Image.fromFileSystem(p)], "plugin"
res = f"\n=====插件信息=====\n名称: {info['name']}\n{info['desc']}\n版本: {info['version']}作者: {info['author']}\n\n帮助:\n{info['help']}"
return True, res, "plugin"
else: else:
return False, "未找到该插件", "plugin" return False, "未找到该插件", "plugin"
except BaseException as e: except BaseException as e:
@@ -248,6 +254,16 @@ class Command:
elif l[1] == "reload": elif l[1] == "reload":
if role != "admin": if role != "admin":
return False, f"你的身份组{role}没有权限重载插件", "plugin" return False, f"你的身份组{role}没有权限重载插件", "plugin"
for plugin in cached_plugins:
try:
print(f"更新插件 {plugin} 依赖...")
plugin_path = os.path.join(ppath, cached_plugins[plugin]["root_dir_name"])
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
mm = pipmain(['install', '-r', os.path.join(plugin_path, "requirements.txt"), "--quiet"])
if mm != 0:
return False, "插件依赖安装失败需要您手动pip安装对应插件的依赖。", "plugin"
except BaseException as e:
print(f"插件{plugin}依赖安装失败,原因: {str(e)}")
try: try:
ok, err = self.plugin_reload(cached_plugins, all = True) ok, err = self.plugin_reload(cached_plugins, all = True)
if ok: if ok:
@@ -264,7 +280,6 @@ class Command:
return False, f"你的身份组{role}没有权限开发者模式", "plugin" return False, f"你的身份组{role}没有权限开发者模式", "plugin"
return True, "cached_plugins: \n" + str(cached_plugins), "plugin" return True, "cached_plugins: \n" + str(cached_plugins), "plugin"
def remove_dir(self, file_path): def remove_dir(self, file_path):
while 1: while 1:
if not os.path.exists(file_path): if not os.path.exists(file_path):
@@ -276,7 +291,6 @@ class Command:
if os.path.exists(err_file_path): if os.path.exists(err_file_path):
os.chmod(err_file_path, stat.S_IWUSR) os.chmod(err_file_path, stat.S_IWUSR)
''' '''
nick: 存储机器人的昵称 nick: 存储机器人的昵称
''' '''
@@ -288,27 +302,13 @@ class Command:
if len(l) == 1: if len(l) == 1:
return True, "【设置机器人昵称】示例:\n支持多昵称\nnick 昵称1 昵称2 昵称3", "nick" return True, "【设置机器人昵称】示例:\n支持多昵称\nnick 昵称1 昵称2 昵称3", "nick"
nick = l[1:] nick = l[1:]
self.general_command_storer("nick_qq", nick) cc.put("nick_qq", nick)
self.global_object.nick = tuple(nick)
return True, f"设置成功!现在你可以叫我这些昵称来提问我啦~", "nick" return True, f"设置成功!现在你可以叫我这些昵称来提问我啦~", "nick"
elif platform == PLATFORM_QQCHAN: elif platform == PLATFORM_QQCHAN:
nick = message.split(" ")[2] nick = message.split(" ")[2]
return False, "QQ频道平台不支持为机器人设置昵称。", "nick" return False, "QQ频道平台不支持为机器人设置昵称。", "nick"
"""
存储指令结果到cmd_config.json
"""
def general_command_storer(self, key, value):
if not os.path.exists("cmd_config.json"):
config = {}
else:
with open("cmd_config.json", "r", encoding="utf-8") as f:
config = json.load(f)
config[key] = value
with open("cmd_config.json", "w", encoding="utf-8") as f:
json.dump(config, f, indent=4, ensure_ascii=False)
f.flush()
def general_commands(self): def general_commands(self):
return { return {
"help": "帮助", "help": "帮助",
@@ -343,15 +343,14 @@ class Command:
msg += plugin_list_info msg += plugin_list_info
msg += notice msg += notice
if platform == gu.PLATFORM_GOCQ: try:
try: # p = gu.create_text_image("【Help Center】", msg)
# p = gu.create_text_image("【Help Center】", msg) p = gu.create_markdown_image(msg)
p = gu.create_markdown_image(msg) return [Image.fromFileSystem(p)]
return [Image.fromFileSystem(p)] except BaseException as e:
except BaseException as e: gu.log(str(e))
gu.log(str(e)) finally:
return msg return msg
return msg
# 接受可变参数 # 接受可变参数
def command_start_with(self, message: str, *args): def command_start_with(self, message: str, *args):
@@ -361,14 +360,36 @@ class Command:
return False return False
# keyword: 关键字 # keyword: 关键字
def keyword(self, message: str, role: str): def keyword(self, message_obj, role: str):
if role != "admin": if role != "admin":
return True, "你没有权限使用该指令", "keyword" return True, "你没有权限使用该指令", "keyword"
plain_text = ""
image_url = ""
l = message.split(" ") for comp in message_obj.message:
if isinstance(comp, Plain):
plain_text += comp.text
elif isinstance(comp, Image) and image_url == "":
if comp.url is None:
image_url = comp.file
else:
image_url = comp.url
if len(l) < 3: l = plain_text.split(" ")
return True, "【设置关键词回复】示例:\nkeyword hi 你好\n当发送hi的时候会回复你好\nkeyword /hi 你好\n当发送/hi时会回复你好\n删除关键词: keyword d hi\n删除hi关键词的回复", "keyword"
if len(l) < 3 and image_url == "":
return True, """
【设置关键词回复】示例:
1. keyword hi 你好
当发送hi的时候会回复你好
2. keyword /hi 你好
当发送/hi时会回复你好
3. keyword d hi
删除hi关键词的回复
4. keyword hi <图片>
当发送hi时会回复图片
""", "keyword"
del_mode = False del_mode = False
if l[1] == "d": if l[1] == "d":
@@ -384,21 +405,34 @@ class Command:
return False, "该关键词不存在", "keyword" return False, "该关键词不存在", "keyword"
else: del keyword[l[2]] else: del keyword[l[2]]
else: else:
keyword[l[1]] = l[2] keyword[l[1]] = {
"plain_text": " ".join(l[2:]),
"image_url": image_url
}
else: else:
if del_mode: if del_mode:
return False, "该关键词不存在", "keyword" return False, "该关键词不存在", "keyword"
keyword = {l[1]: l[2]} keyword = {
l[1]: {
"plain_text": " ".join(l[2:]),
"image_url": image_url
}
}
with open("keyword.json", "w", encoding="utf-8") as f: with open("keyword.json", "w", encoding="utf-8") as f:
json.dump(keyword, f, ensure_ascii=False, indent=4) json.dump(keyword, f, ensure_ascii=False, indent=4)
f.flush() f.flush()
if del_mode: if del_mode:
return True, "删除成功: "+l[2], "keyword" return True, "删除成功: "+l[2], "keyword"
return True, "设置成功: "+l[1]+" -> "+l[2], "keyword" if image_url == "":
return True, "设置成功: "+l[1]+" "+" ".join(l[2:]), "keyword"
else:
return True, [Plain("设置成功: "+l[1]+" "+" ".join(l[2:])), Image.fromURL(image_url)], "keyword"
except BaseException as e: except BaseException as e:
return False, "设置失败: "+str(e), "keyword" return False, "设置失败: "+str(e), "keyword"
def update(self, message: str, role: str): def update(self, message: str, role: str):
if not has_git:
return False, "你正在运行在无Git环境下暂时将无法使用插件、热更新功能。", "update"
if role != "admin": if role != "admin":
return True, "你没有权限使用该指令", "keyword" return True, "你没有权限使用该指令", "keyword"
l = message.split(" ") l = message.split(" ")
@@ -436,11 +470,19 @@ class Command:
pash_tag = "QQChannelChatGPT"+os.sep pash_tag = "QQChannelChatGPT"+os.sep
repo.remotes.origin.pull() repo.remotes.origin.pull()
if len(l) == 3 and l[2] == "r": try:
py = sys.executable origin = repo.remotes.origin
os.execl(py, py, *sys.argv) origin.fetch()
commits = list(repo.iter_commits('master', max_count=1))
commit_log = commits[0].message
except BaseException as e:
commit_log = "无法获取commit信息"
return True, "更新成功~是否重启输入update r重启重启指令不返回任何确认信息", "update" tag = "update"
if len(l) == 3 and l[2] == "r":
tag = "update latest r"
return True, f"更新成功。新版本内容: \n{commit_log}\nps:重启后生效。输入update r重启重启指令不返回任何确认信息", tag
except BaseException as e: except BaseException as e:
return False, "更新失败: "+str(e), "update" return False, "更新失败: "+str(e), "update"
@@ -448,7 +490,6 @@ class Command:
py = sys.executable py = sys.executable
os.execl(py, py, *sys.argv) os.execl(py, py, *sys.argv)
def reset(self): def reset(self):
return False return False

View File

@@ -1,36 +1,40 @@
from model.command.command import Command from model.command.command import Command
from model.provider.provider_openai_official import ProviderOpenAIOfficial from model.provider.provider_openai_official import ProviderOpenAIOfficial
from cores.qqbot.personality import personalities from cores.qqbot.personality import personalities
from model.platform.qq import QQ from model.platform.qq import QQ
from util import general_utils as gu from util import general_utils as gu
from cores.qqbot.global_object import GlobalObject
class CommandOpenAIOfficial(Command): class CommandOpenAIOfficial(Command):
def __init__(self, provider: ProviderOpenAIOfficial, global_object: dict): def __init__(self, provider: ProviderOpenAIOfficial, global_object: GlobalObject):
self.provider = provider self.provider = provider
self.cached_plugins = {} self.cached_plugins = {}
self.global_object = global_object self.global_object = global_object
self.personality_str = ""
super().__init__(provider, global_object)
def check_command(self, def check_command(self,
message: str, message: str,
session_id: str, session_id: str,
user_name: str,
role: str, role: str,
platform: str, platform: str,
message_obj, message_obj):
cached_plugins: dict,
qq_platform: QQ,):
self.platform = platform self.platform = platform
hit, res = super().check_command(message, role, platform, message_obj=message_obj, hit, res = super().check_command(
cached_plugins=cached_plugins, message,
qq_platform=qq_platform, session_id,
global_object=self.global_object) role,
platform,
message_obj
)
if hit: if hit:
return True, res return True, res
if self.command_start_with(message, "reset", "重置"): if self.command_start_with(message, "reset", "重置"):
return True, self.reset(session_id) return True, self.reset(session_id, message)
elif self.command_start_with(message, "his", "历史"): elif self.command_start_with(message, "his", "历史"):
return True, self.his(message, session_id, user_name) return True, self.his(message, session_id)
elif self.command_start_with(message, "token"): elif self.command_start_with(message, "token"):
return True, self.token(session_id) return True, self.token(session_id)
elif self.command_start_with(message, "gpt"): elif self.command_start_with(message, "gpt"):
@@ -40,7 +44,7 @@ class CommandOpenAIOfficial(Command):
elif self.command_start_with(message, "count"): elif self.command_start_with(message, "count"):
return True, self.count() return True, self.count()
elif self.command_start_with(message, "help", "帮助"): elif self.command_start_with(message, "help", "帮助"):
return True, self.help(cached_plugins) return True, self.help()
elif self.command_start_with(message, "unset"): elif self.command_start_with(message, "unset"):
return True, self.unset(session_id) return True, self.unset(session_id)
elif self.command_start_with(message, "set"): elif self.command_start_with(message, "set"):
@@ -49,17 +53,14 @@ class CommandOpenAIOfficial(Command):
return True, self.update(message, role) return True, self.update(message, role)
elif self.command_start_with(message, "", "draw"): elif self.command_start_with(message, "", "draw"):
return True, self.draw(message) return True, self.draw(message)
elif self.command_start_with(message, "keyword"):
return True, self.keyword(message, role)
elif self.command_start_with(message, "key"): elif self.command_start_with(message, "key"):
return True, self.key(message, user_name) return True, self.key(message)
elif self.command_start_with(message, "switch"):
if self.command_start_with(message, "/"): return True, self.switch(message)
return True, (False, "未知指令", "unknown_command")
return False, None return False, None
def help(self, cached_plugins): def help(self):
commands = super().general_commands() commands = super().general_commands()
commands[''] = '画画' commands[''] = '画画'
commands['key'] = '添加OpenAI key' commands['key'] = '添加OpenAI key'
@@ -67,16 +68,23 @@ class CommandOpenAIOfficial(Command):
commands['gpt'] = '查看gpt配置信息' commands['gpt'] = '查看gpt配置信息'
commands['status'] = '查看key使用状态' commands['status'] = '查看key使用状态'
commands['token'] = '查看本轮会话token' commands['token'] = '查看本轮会话token'
return True, super().help_messager(commands, self.platform, cached_plugins), "help" return True, super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help"
def reset(self, session_id: str): def reset(self, session_id: str, message: str = "reset"):
if self.provider is None: if self.provider is None:
return False, "未启动OpenAI ChatGPT语言模型.", "reset" return False, "未启动OpenAI ChatGPT语言模型.", "reset"
self.provider.forget(session_id) l = message.split(" ")
return True, "重置成功", "reset" if len(l) == 1:
self.provider.forget(session_id)
return True, "重置成功", "reset"
if len(l) == 2 and l[1] == "p":
self.provider.forget(session_id)
if self.personality_str != "":
self.set(self.personality_str, session_id) # 重新设置人格
return True, "重置成功", "reset"
def his(self, message: str, session_id: str, name: str): def his(self, message: str, session_id: str):
if self.provider is None: if self.provider is None:
return False, "未启动OpenAI ChatGPT语言模型.", "his" return False, "未启动OpenAI ChatGPT语言模型.", "his"
#分页每页5条 #分页每页5条
@@ -122,17 +130,17 @@ class CommandOpenAIOfficial(Command):
continue continue
if 'sponsor' in key_stat[key]: if 'sponsor' in key_stat[key]:
sponsor = key_stat[key]['sponsor'] sponsor = key_stat[key]['sponsor']
chatgpt_cfg_str += f" |-{index}: {key_stat[key]['used']}/{max} {sponsor}赞助{tag}\n" chatgpt_cfg_str += f" |-{index}: {key[-8:]} {key_stat[key]['used']}/{max} {sponsor}{tag}\n"
index += 1 index += 1
return True, f"⭐使用情况({str(gg_count)}个已用):\n{chatgpt_cfg_str}⏰全频道已用{total}tokens", "status" return True, f"⭐使用情况({str(gg_count)}个已用):\n{chatgpt_cfg_str}", "status"
def count(self): def count(self):
if self.provider is None: if self.provider is None:
return False, "未启动OpenAI ChatGPT语言模型.", "reset" return False, "未启动OpenAI ChatGPT语言模型", "reset"
guild_count, guild_msg_count, guild_direct_msg_count, session_count = self.provider.get_stat() guild_count, guild_msg_count, guild_direct_msg_count, session_count = self.provider.get_stat()
return True, f"当前会话数: {len(self.provider.session_dict)}\n共有频道数: {guild_count} \n共有消息数: {guild_msg_count}\n私信数: {guild_direct_msg_count}\n历史会话数: {session_count}", "count" return True, f"【本指令部分统计可能已经过时】\n当前会话数: {len(self.provider.session_dict)}\n共有频道数: {guild_count} \n共有消息数: {guild_msg_count}\n私信数: {guild_direct_msg_count}\n历史会话数: {session_count}", "count"
def key(self, message: str, user_name: str): def key(self, message: str):
if self.provider is None: if self.provider is None:
return False, "未启动OpenAI ChatGPT语言模型.", "reset" return False, "未启动OpenAI ChatGPT语言模型.", "reset"
l = message.split(" ") l = message.split(" ")
@@ -141,11 +149,41 @@ class CommandOpenAIOfficial(Command):
return True, msg, "key" return True, msg, "key"
key = l[1] key = l[1]
if self.provider.check_key(key): if self.provider.check_key(key):
self.provider.append_key(key, user_name) self.provider.append_key(key)
return True, f"*★,°*:.☆( ̄▽ ̄)/$:*.°★* 。\n该Key被验证为有效。感谢{user_name}赞助~" return True, f"*★,°*:.☆( ̄▽ ̄)/$:*.°★* 。\n该Key被验证为有效。感谢你的赞助~"
else: else:
return True, "该Key被验证为无效。也许是输入错误了或者重试。", "key" return True, "该Key被验证为无效。也许是输入错误了或者重试。", "key"
def switch(self, message: str):
'''
切换账号
'''
l = message.split(" ")
if len(l) == 1:
_, ret, _ = self.status()
curr_ = self.provider.get_curr_key()
if curr_ is None:
ret += "当前您未选择账号。输入/switch <账号序号>切换账号。"
else:
ret += f"当前您选择的账号为:{curr_[-8:]}。输入/switch <账号序号>切换账号。"
return True, ret, "switch"
elif len(l) == 2:
try:
key_stat = self.provider.get_key_stat()
index = int(l[1])
if index > len(key_stat) or index < 1:
return True, "账号序号不合法。", "switch"
else:
ret = self.provider.check_key(list(key_stat.keys())[index-1])
if ret:
return True, f"账号切换成功。", "switch"
else:
return True, f"账号切换失败,可能超额或超频。", "switch"
except BaseException as e:
return True, "未知错误: "+str(e), "switch"
else:
return True, "参数过多。", "switch"
def unset(self, session_id: str): def unset(self, session_id: str):
if self.provider is None: if self.provider is None:
return False, "未启动OpenAI ChatGPT语言模型.", "unset" return False, "未启动OpenAI ChatGPT语言模型.", "unset"
@@ -158,9 +196,9 @@ class CommandOpenAIOfficial(Command):
return False, "未启动OpenAI ChatGPT语言模型.", "set" return False, "未启动OpenAI ChatGPT语言模型.", "set"
l = message.split(" ") l = message.split(" ")
if len(l) == 1: if len(l) == 1:
return True, f"由Github项目QQChannelChatGPT支持】\n\n人格文本由PlexPt开源项目awesome-chatgpt-pr \ return True, f"【人格文本由PlexPt开源项目awesome-chatgpt-pr \
ompts-zh提供】\n\n这个是人格设置指令。\n设置人格: \n/set 人格名。例如/set 编剧\n人格列表: /set list\n人格详细信息: \ ompts-zh提供】\n设置人格: \n/set 人格名。例如/set 编剧\n人格列表: /set list\n人格详细信息: \
/set view 人格名\n自定义人格: /set 人格文本\n清除人格: /unset\n【当前人格】: {str(self.provider.now_personality)}", "set" /set view 人格名\n自定义人格: /set 人格文本\n重置会话(清除人格): /reset\n重置会话(保留人格): /reset p\n【当前人格】: {str(self.provider.now_personality)}", "set"
elif l[1] == "list": elif l[1] == "list":
msg = "人格列表:\n" msg = "人格列表:\n"
for key in personalities.keys(): for key in personalities.keys():
@@ -188,14 +226,20 @@ class CommandOpenAIOfficial(Command):
self.provider.session_dict[session_id] = [] self.provider.session_dict[session_id] = []
new_record = { new_record = {
"user": { "user": {
"role": "system", "role": "user",
"content": personalities[ps], "content": personalities[ps],
}, },
"AI": {
"role": "assistant",
"content": "好的,接下来我会扮演这个角色。"
},
'type': "personality",
'usage_tokens': 0, 'usage_tokens': 0,
'single-tokens': 0 'single-tokens': 0
} }
self.provider.session_dict[session_id].append(new_record) self.provider.session_dict[session_id].append(new_record)
return True, f"人格{ps}已设置.", "set" self.personality_str = message
return True, f"人格{ps}已设置。", "set"
else: else:
self.provider.now_personality = { self.provider.now_personality = {
'name': '自定义人格', 'name': '自定义人格',
@@ -203,14 +247,20 @@ class CommandOpenAIOfficial(Command):
} }
new_record = { new_record = {
"user": { "user": {
"role": "system", "role": "user",
"content": ps, "content": ps,
}, },
"AI": {
"role": "assistant",
"content": "好的,接下来我会扮演这个角色。"
},
'type': "personality",
'usage_tokens': 0, 'usage_tokens': 0,
'single-tokens': 0 'single-tokens': 0
} }
self.provider.session_dict[session_id] = [] self.provider.session_dict[session_id] = []
self.provider.session_dict[session_id].append(new_record) self.provider.session_dict[session_id].append(new_record)
self.personality_str = message
return True, f"自定义人格已设置。 \n人格信息: {ps}", "set" return True, f"自定义人格已设置。 \n人格信息: {ps}", "set"
def draw(self, message): def draw(self, message):

View File

@@ -1,43 +1,134 @@
from model.command.command import Command from model.command.command import Command
from model.provider.provider_rev_chatgpt import ProviderRevChatGPT from model.provider.provider_rev_chatgpt import ProviderRevChatGPT
from model.platform.qq import QQ from model.platform.qq import QQ
from cores.qqbot.personality import personalities
from cores.qqbot.global_object import GlobalObject
class CommandRevChatGPT(Command): class CommandRevChatGPT(Command):
def __init__(self, provider: ProviderRevChatGPT, global_object: dict): def __init__(self, provider: ProviderRevChatGPT, global_object: GlobalObject):
self.provider = provider self.provider = provider
self.cached_plugins = {} self.cached_plugins = {}
self.global_object = global_object self.global_object = global_object
self.personality_str = ""
super().__init__(provider, global_object)
def check_command(self, def check_command(self,
message: str, message: str,
session_id: str,
role: str, role: str,
platform: str, platform: str,
message_obj, message_obj):
cached_plugins: dict,
qq_platform: QQ):
self.platform = platform self.platform = platform
hit, res = super().check_command(message, role, platform, message_obj=message_obj, hit, res = super().check_command(
cached_plugins=cached_plugins, message,
qq_platform=qq_platform, session_id,
global_object=self.global_object) role,
platform,
message_obj
)
if hit: if hit:
return True, res return True, res
if self.command_start_with(message, "help", "帮助"): if self.command_start_with(message, "help", "帮助"):
return True, self.help(cached_plugins) return True, self.help()
elif self.command_start_with(message, "reset"): elif self.command_start_with(message, "reset"):
return True, self.reset() return True, self.reset(session_id, message)
elif self.command_start_with(message, "update"): elif self.command_start_with(message, "update"):
return True, self.update(message, role) return True, self.update(message, role)
elif self.command_start_with(message, "keyword"): elif self.command_start_with(message, "set"):
return True, self.keyword(message, role) return True, self.set(message, session_id)
elif self.command_start_with(message, "switch"):
if self.command_start_with(message, "/"): return True, self.switch(message, session_id)
return True, (False, "未知指令", "unknown_command")
return False, None return False, None
def reset(self, session_id, message: str):
l = message.split(" ")
if len(l) == 1:
self.provider.forget(session_id)
return True, "重置完毕。", "reset"
if len(l) == 2 and l[1] == "p":
self.provider.forget(session_id)
ret = self.provider.text_chat(self.personality_str)
return True, f"重置完毕(保留人格)。\n\n{ret}", "reset"
def reset(self): def set(self, message: str, session_id: str):
return False, "此功能暂未开放", "reset" l = message.split(" ")
if len(l) == 1:
return True, f"设置人格: \n/set 人格名或人格文本。例如/set 编剧\n人格列表: /set list\n人格详细信息: \
def help(self, cached_plugins: dict): /set view 人格名\n重置会话(清除人格): /reset\n重置会话(保留人格): /reset p", "set"
return True, super().help_messager(super().general_commands(), self.platform, cached_plugins), "help" elif l[1] == "list":
msg = "人格列表:\n"
for key in personalities.keys():
msg += f" |-{key}\n"
msg += '\n\n*输入/set view 人格名查看人格详细信息'
msg += '\n*不定时更新人格库,请及时更新本项目。'
return True, msg, "set"
elif l[1] == "view":
if len(l) == 2:
return True, "请输入/set view 人格名", "set"
ps = l[2].strip()
if ps in personalities:
msg = f"人格【{ps}】详细信息:\n"
msg += f"{personalities[ps]}\n"
else:
msg = f"人格【{ps}】不存在。"
return True, msg, "set"
else:
ps = l[1].strip()
if ps in personalities:
self.reset(session_id, "reset")
self.personality_str = personalities[ps]
ret = self.provider.text_chat(self.personality_str, session_id)
return True, f"人格【{ps}】已设置。\n\n{ret}", "set"
else:
self.reset(session_id, "reset")
self.personality_str = ps
ret = self.provider.text_chat(ps, session_id)
return True, f"人格信息已设置。\n\n{ret}", "set"
def switch(self, message: str, session_id: str):
'''
切换账号
'''
l = message.split(" ")
rev_chatgpt = self.provider.get_revchatgpt()
if len(l) == 1:
ret = "当前账号:\n"
index = 0
curr_ = None
for revstat in rev_chatgpt:
index += 1
ret += f"[{index}]. {revstat['id']}\n"
# if session_id in revstat['user']:
# curr_ = revstat['id']
for user in revstat['user']:
if session_id == user['id']:
curr_ = revstat['id']
break
if curr_ is None:
ret += "当前您未选择账号。输入/switch <账号序号>切换账号。"
else:
ret += f"当前您选择的账号为:{curr_}。输入/switch <账号序号>切换账号。"
return True, ret, "switch"
elif len(l) == 2:
try:
index = int(l[1])
if index > len(self.provider.rev_chatgpt) or index < 1:
return True, "账号序号不合法。", "switch"
else:
# pop
for revstat in self.provider.rev_chatgpt:
if session_id in revstat['user']:
revstat['user'].remove(session_id)
# append
self.provider.rev_chatgpt[index - 1]['user'].append(session_id)
return True, f"切换账号成功。当前账号为:{self.provider.rev_chatgpt[index - 1]['id']}", "switch"
except BaseException:
return True, "账号序号不合法。", "switch"
else:
return True, "参数过多。", "switch"
def help(self):
commands = super().general_commands()
commands['set'] = '设置人格'
return True, super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help"

View File

@@ -2,42 +2,43 @@ from model.command.command import Command
from model.provider.provider_rev_edgegpt import ProviderRevEdgeGPT from model.provider.provider_rev_edgegpt import ProviderRevEdgeGPT
import asyncio import asyncio
from model.platform.qq import QQ from model.platform.qq import QQ
from cores.qqbot.global_object import GlobalObject
class CommandRevEdgeGPT(Command): class CommandRevEdgeGPT(Command):
def __init__(self, provider: ProviderRevEdgeGPT, global_object: dict): def __init__(self, provider: ProviderRevEdgeGPT, global_object: GlobalObject):
self.provider = provider self.provider = provider
self.cached_plugins = {} self.cached_plugins = {}
self.global_object = global_object self.global_object = global_object
super().__init__(provider, global_object)
def check_command(self, def check_command(self,
message: str, message: str,
loop, session_id: str,
role: str, role: str,
platform: str, platform: str,
message_obj, message_obj):
cached_plugins: dict,
qq_platform: QQ):
self.platform = platform self.platform = platform
hit, res = super().check_command(message, role, platform, message_obj=message_obj,
cached_plugins=cached_plugins, hit, res = super().check_command(
qq_platform=qq_platform, message,
global_object=self.global_object) session_id,
role,
platform,
message_obj
)
if hit: if hit:
return True, res return True, res
if self.command_start_with(message, "reset"): if self.command_start_with(message, "reset"):
return True, self.reset(loop) return True, self.reset()
elif self.command_start_with(message, "help"): elif self.command_start_with(message, "help"):
return True, self.help(cached_plugins) return True, self.help()
elif self.command_start_with(message, "update"): elif self.command_start_with(message, "update"):
return True, self.update(message, role) return True, self.update(message, role)
elif self.command_start_with(message, "keyword"):
return True, self.keyword(message, role)
if self.command_start_with(message, "/"):
return True, (False, "未知指令", "unknown_command")
return False, None return False, None
def reset(self, loop): def reset(self, loop = None):
if self.provider is None: if self.provider is None:
return False, "未启动Bing语言模型.", "reset" return False, "未启动Bing语言模型.", "reset"
res = asyncio.run_coroutine_threadsafe(self.provider.forget(), loop).result() res = asyncio.run_coroutine_threadsafe(self.provider.forget(), loop).result()
@@ -47,6 +48,5 @@ class CommandRevEdgeGPT(Command):
else: else:
return res, "重置失败", "reset" return res, "重置失败", "reset"
def help(self, cached_plugins: dict): def help(self):
return True, super().help_messager(super().general_commands(), self.platform, cached_plugins), "help" return True, super().help_messager(super().general_commands(), self.platform, self.global_object.cached_plugins), "help"

View File

@@ -19,6 +19,8 @@ class QQ:
self.is_start = is_start self.is_start = is_start
self.gocq_loop = gocq_loop self.gocq_loop = gocq_loop
self.cc = cc self.cc = cc
self.waiting = {}
self.gocq_cnt = 0
def run_bot(self, gocq): def run_bot(self, gocq):
self.client: CQHTTP = gocq self.client: CQHTTP = gocq
@@ -26,12 +28,18 @@ class QQ:
def get_msg_loop(self): def get_msg_loop(self):
return self.gocq_loop return self.gocq_loop
def get_cnt(self):
return self.gocq_cnt
def set_cnt(self, cnt):
self.gocq_cnt = cnt
async def send_qq_msg(self, async def send_qq_msg(self,
source, source,
res, res,
image_mode: bool = False): image_mode=None):
self.gocq_cnt += 1
if not self.is_start: if not self.is_start:
raise Exception("管理员未启动GOCQ平台") raise Exception("管理员未启动GOCQ平台")
""" """
@@ -47,11 +55,13 @@ class QQ:
if isinstance(res, str): if isinstance(res, str):
res_str = res res_str = res
res = [] res = []
if source.type == "GroupMessage": if source.type == "GroupMessage" and not isinstance(source, FakeSource):
res.append(At(qq=source.user_id)) res.append(At(qq=source.user_id))
res.append(Plain(text=res_str)) res.append(Plain(text=res_str))
# if image mode, put all Plain texts into a new picture. # if image mode, put all Plain texts into a new picture.
if image_mode is None:
image_mode = self.cc.get('qq_pic_mode', False)
if image_mode and isinstance(res, list): if image_mode and isinstance(res, list):
plains = [] plains = []
news = [] news = []
@@ -60,10 +70,11 @@ class QQ:
plains.append(i.text) plains.append(i.text)
else: else:
news.append(i) news.append(i)
p = gu.create_markdown_image("".join(plains)) plains_str = "".join(plains).strip()
news.append(Image.fromFileSystem(p)) if plains_str != "" and len(plains_str) > 50:
res = news p = gu.create_markdown_image("".join(plains))
news.append(Image.fromFileSystem(p))
res = news
# 回复消息链 # 回复消息链
if isinstance(res, list) and len(res) > 0: if isinstance(res, list) and len(res) > 0:
@@ -89,10 +100,10 @@ class QQ:
res.remove(i) res.remove(i)
node = Node(res) node = Node(res)
# node.content = res # node.content = res
node.uin = source.self_id node.uin = 123456
node.name = f"To {source.sender.nickname}:" node.name = f"bot"
node.time = int(time.time()) node.time = int(time.time())
print(node) # print(node)
nodes=[node] nodes=[node]
await self.client.sendGroupForwardMessage(source.group_id, nodes) await self.client.sendGroupForwardMessage(source.group_id, nodes)
return return
@@ -102,13 +113,15 @@ class QQ:
def send(self, def send(self,
to, to,
res, res,
image_mode=False,
): ):
''' '''
提供给插件的发送QQ消息接口, 不用在外部await。 提供给插件的发送QQ消息接口, 不用在外部await。
参数说明第一个参数可以是消息对象也可以是QQ群号。第二个参数是消息内容消息内容可以是消息链列表也可以是纯文字信息 参数说明第一个参数可以是消息对象也可以是QQ群号。第二个参数是消息内容消息内容可以是消息链列表也可以是纯文字信息
第三个参数是是否开启图片模式,如果开启,那么所有纯文字信息都会被合并成一张图片。
''' '''
try: try:
asyncio.run_coroutine_threadsafe(self.send_qq_msg(to, res), self.gocq_loop).result() asyncio.run_coroutine_threadsafe(self.send_qq_msg(to, res, image_mode), self.gocq_loop).result()
except BaseException as e: except BaseException as e:
raise e raise e
@@ -141,3 +154,30 @@ class QQ:
return p return p
except Exception as e: except Exception as e:
raise e raise e
def wait_for_message(self, group_id):
'''
等待下一条消息
'''
self.waiting[group_id] = ''
while True:
if group_id in self.waiting and self.waiting[group_id] != '':
# 去掉
ret = self.waiting[group_id]
del self.waiting[group_id]
return ret
time.sleep(0.5)
def get_client(self):
return self.client
def nakuru_method_invoker(self, func, *args, **kwargs):
"""
返回一个方法调用器可以用来立即调用nakuru的方法。
"""
try:
ret = asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self.gocq_loop).result()
return ret
except BaseException as e:
raise e

View File

@@ -8,42 +8,123 @@ import requests
from cores.qqbot.personality import personalities from cores.qqbot.personality import personalities
from util import general_utils as gu from util import general_utils as gu
from nakuru.entities.components import Plain, At, Image from nakuru.entities.components import Plain, At, Image
from botpy.types.message import Reference
class NakuruGuildMember():
tiny_id: int
user_id: int
title: str
nickname: str
role: int
icon_url: str
class NakuruGuildMessage():
type: str = "GuildMessage"
self_id: int
self_tiny_id: int
sub_type: str
message_id: str
guild_id: int
channel_id: int
user_id: int
message: list
sender: NakuruGuildMember
raw_message: Message
def __str__(self) -> str:
return str(self.__dict__)
class QQChan(): class QQChan():
def __init__(self, cnt: dict = None) -> None:
self.qqchan_cnt = 0
def get_cnt(self):
return self.qqchan_cnt
def set_cnt(self, cnt):
self.qqchan_cnt = cnt
def run_bot(self, botclient, appid, token): def run_bot(self, botclient, appid, token):
intents = botpy.Intents(public_guild_messages=True, direct_message=True) intents = botpy.Intents(public_guild_messages=True, direct_message=True)
self.client = botclient self.client = botclient
self.client.run(appid=appid, token=token) self.client.run(appid=appid, token=token)
# gocq兼容层 # gocq-频道SDK兼容层
def gocq_compatible(self, gocq_message_chain: list): def gocq_compatible_send(self, gocq_message_chain: list):
plain_text = "" plain_text = ""
image_path = None # only one img supported image_path = None # only one img supported
for i in gocq_message_chain: for i in gocq_message_chain:
if isinstance(i, Plain): if isinstance(i, Plain):
plain_text += i.text plain_text += i.text
elif isinstance(i, Image) and image_path == None: elif isinstance(i, Image) and image_path == None:
image_path = i.path if i.path is not None:
image_path = i.path
else:
image_path = i.file
return plain_text, image_path return plain_text, image_path
# gocq-频道SDK兼容层
def gocq_compatible_receive(self, message: Message) -> NakuruGuildMessage:
ngm = NakuruGuildMessage()
try:
ngm.self_id = message.mentions[0].id
ngm.self_tiny_id = message.mentions[0].id
except:
ngm.self_id = 0
ngm.self_tiny_id = 0
ngm.sub_type = "normal"
ngm.message_id = message.id
ngm.guild_id = int(message.channel_id)
ngm.channel_id = int(message.channel_id)
ngm.user_id = int(message.author.id)
msg = []
plain_content = message.content.replace("<@!"+str(ngm.self_id)+">", "").strip()
msg.append(Plain(plain_content))
if message.attachments:
for i in message.attachments:
if i.content_type.startswith("image"):
url = i.url
if not url.startswith("http"):
url = "https://"+url
img = Image.fromURL(url)
msg.append(img)
ngm.message = msg
ngm.sender = NakuruGuildMember()
ngm.sender.tiny_id = int(message.author.id)
ngm.sender.user_id = int(message.author.id)
ngm.sender.title = ""
ngm.sender.nickname = message.author.username
ngm.sender.role = 0
ngm.sender.icon_url = message.author.avatar
ngm.raw_message = message
return ngm
def send_qq_msg(self, message: Message, res, msg_ref = None):
def send_qq_msg(self, message: NakuruGuildMessage, res):
gu.log("回复QQ频道消息: "+str(res), level=gu.LEVEL_INFO, tag="QQ频道", max_len=500) gu.log("回复QQ频道消息: "+str(res), level=gu.LEVEL_INFO, tag="QQ频道", max_len=500)
self.qqchan_cnt += 1
plain_text = "" plain_text = ""
image_path = None image_path = None
if isinstance(res, list): if isinstance(res, list):
# 兼容gocq # 兼容gocq
plain_text, image_path = self.gocq_compatible(res) plain_text, image_path = self.gocq_compatible_send(res)
elif isinstance(res, str): elif isinstance(res, str):
plain_text = res plain_text = res
print(plain_text, image_path) # print(plain_text, image_path)
msg_ref = Reference(message_id=message.raw_message.id, ignore_get_message_error=False)
if image_path is not None:
msg_ref = None
if image_path.startswith("http"):
pic_res = requests.get(image_path, stream = True)
if pic_res.status_code == 200:
image = PILImage.open(io.BytesIO(pic_res.content))
image_path = gu.save_temp_img(image)
try: try:
reply_res = asyncio.run_coroutine_threadsafe(message.reply(content=str(plain_text), message_reference = msg_ref, file_image=image_path), self.client.loop) reply_res = asyncio.run_coroutine_threadsafe(message.raw_message.reply(content=str(plain_text), message_reference = msg_ref, file_image=image_path), self.client.loop)
reply_res.result() reply_res.result()
except BaseException as e: except BaseException as e:
# 分割过长的消息 # 分割过长的消息
@@ -52,21 +133,21 @@ class QQChan():
split_res.append(plain_text[:len(plain_text)//2]) split_res.append(plain_text[:len(plain_text)//2])
split_res.append(plain_text[len(plain_text)//2:]) split_res.append(plain_text[len(plain_text)//2:])
for i in split_res: for i in split_res:
reply_res = asyncio.run_coroutine_threadsafe(message.reply(content=str(i), message_reference = msg_ref, file_image=image_path), self.client.loop) reply_res = asyncio.run_coroutine_threadsafe(message.raw_message.reply(content=str(i), message_reference = msg_ref, file_image=image_path), self.client.loop)
reply_res.result() reply_res.result()
else: else:
# 发送qq信息 # 发送qq信息
try: try:
# 防止被qq频道过滤消息 # 防止被qq频道过滤消息
plain_text = plain_text.replace(".", " . ") plain_text = plain_text.replace(".", " . ")
reply_res = asyncio.run_coroutine_threadsafe(message.reply(content=str(plain_text), message_reference = msg_ref, file_image=image_path), self.client.loop) reply_res = asyncio.run_coroutine_threadsafe(message.raw_message.reply(content=str(plain_text), message_reference = msg_ref, file_image=image_path), self.client.loop)
# 发送信息 # 发送信息
except BaseException as e: except BaseException as e:
print("QQ频道API错误: \n"+str(e)) print("QQ频道API错误: \n"+str(e))
try: try:
reply_res = asyncio.run_coroutine_threadsafe(message.reply(content=str(str.join(" ", plain_text)), message_reference = msg_ref, file_image=image_path), self.client.loop) reply_res = asyncio.run_coroutine_threadsafe(message.raw_message.reply(content=str(str.join(" ", plain_text)), message_reference = msg_ref, file_image=image_path), self.client.loop)
except BaseException as e: except BaseException as e:
plain_text = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE) plain_text = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE)
plain_text = plain_text.replace(".", "·") plain_text = plain_text.replace(".", "·")
asyncio.run_coroutine_threadsafe(message.reply(content=plain_text), self.client.loop).result() asyncio.run_coroutine_threadsafe(message.raw_message.reply(content=plain_text), self.client.loop).result()
# send(message, f"QQ频道API错误{str(e)}\n下面是格式化后的回答\n{f_res}") # send(message, f"QQ频道API错误{str(e)}\n下面是格式化后的回答\n{f_res}")

View File

@@ -3,16 +3,11 @@ import abc
class Provider: class Provider:
def __init__(self, cfg): def __init__(self, cfg):
pass pass
def text_chat(self, prompt): @abc.abstractmethod
pass def text_chat(self, prompt, session_id, image_url: None, function_call: None):
def image_chat(self, prompt):
pass
def memory(self):
pass pass
@abc.abstractmethod @abc.abstractmethod
def forget(self) -> bool: def forget(self, session_id = None) -> bool:
pass pass

View File

@@ -1,4 +1,5 @@
import openai from openai import OpenAI
from openai.types.chat.chat_completion import ChatCompletion
import json import json
import time import time
import os import os
@@ -7,26 +8,43 @@ from cores.database.conn import dbConn
from model.provider.provider import Provider from model.provider.provider import Provider
import threading import threading
from util import general_utils as gu from util import general_utils as gu
import traceback
import tiktoken
abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/' abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/'
key_record_path = abs_path+'chatgpt_key_record' key_record_path = abs_path + 'chatgpt_key_record'
class ProviderOpenAIOfficial(Provider): class ProviderOpenAIOfficial(Provider):
def __init__(self, cfg): def __init__(self, cfg):
self.key_list = [] self.key_list = []
if 'api_base' in cfg and cfg['api_base'] != 'none' and cfg['api_base'] != '': # 如果 cfg['key']中有长度为1的字符串那么是格式错误直接报错
openai.api_base = cfg['api_base'] for key in cfg['key']:
if len(key) == 1:
input("检查到了长度为 1 的Key。配置文件中的 openai.key 处的格式错误 (符号 - 的后面要加空格),请退出程序并检查配置文件,按回车跳过。")
raise BaseException("配置文件格式错误")
if cfg['key'] != '' and cfg['key'] != None: if cfg['key'] != '' and cfg['key'] != None:
gu.log("读取ChatGPT Key成功")
self.key_list = cfg['key'] self.key_list = cfg['key']
else: else:
input("[System] 请先去完善ChatGPT的Key。详情请前往https://beta.openai.com/account/api-keys") input("[System] 请先去完善ChatGPT的Key。详情请前往https://beta.openai.com/account/api-keys")
if len(self.key_list) == 0:
raise Exception("您打开了 OpenAI 模型服务,但是未填写 key。请前往填写。")
# init key record self.key_stat = {}
self.init_key_record() for k in self.key_list:
self.key_stat[k] = {'exceed': False, 'used': 0}
self.chatGPT_configs = cfg['chatGPTConfigs'] self.api_base = None
gu.log(f'加载ChatGPTConfigs: {self.chatGPT_configs}') if 'api_base' in cfg and cfg['api_base'] != 'none' and cfg['api_base'] != '':
self.api_base = cfg['api_base']
print(f"设置 api_base 为: {self.api_base}")
# openai client
self.client = OpenAI(
api_key=self.key_list[0],
base_url=self.api_base
)
self.openai_model_configs: dict = cfg['chatGPTConfigs']
gu.log(f'加载 OpenAI Chat Configs: {self.openai_model_configs}')
self.openai_configs = cfg self.openai_configs = cfg
# 会话缓存 # 会话缓存
self.session_dict = {} self.session_dict = {}
@@ -35,14 +53,16 @@ class ProviderOpenAIOfficial(Provider):
# 历史记录持久化间隔时间 # 历史记录持久化间隔时间
self.history_dump_interval = 20 self.history_dump_interval = 20
self.enc = tiktoken.get_encoding("cl100k_base")
# 读取历史记录 # 读取历史记录
try: try:
db1 = dbConn() db1 = dbConn()
for session in db1.get_all_session(): for session in db1.get_all_session():
self.session_dict[session[0]] = json.loads(session[1])['data'] self.session_dict[session[0]] = json.loads(session[1])['data']
gu.log("历史记录读取成功") gu.log("读取历史记录成功")
except BaseException as e: except BaseException as e:
gu.log("历史记录读取失败喵", level=gu.LEVEL_ERROR) gu.log("读取历史记录失败,但不影响使用。", level=gu.LEVEL_ERROR)
# 读取统计信息 # 读取统计信息
@@ -67,7 +87,7 @@ class ProviderOpenAIOfficial(Provider):
self.now_personality = {} self.now_personality = {}
# 转储历史记录的定时器~ Soulter # 转储历史记录
def dump_history(self): def dump_history(self):
time.sleep(10) time.sleep(10)
db = dbConn() db = dbConn()
@@ -90,10 +110,11 @@ class ProviderOpenAIOfficial(Provider):
# 每隔10分钟转储一次 # 每隔10分钟转储一次
time.sleep(10*self.history_dump_interval) time.sleep(10*self.history_dump_interval)
def text_chat(self, prompt, session_id = None): def text_chat(self, prompt, session_id = None, image_url = None, function_call=None):
if session_id is None: if session_id is None:
session_id = "unknown" session_id = "unknown"
del self.session_dict["unknown"] if "unknown" in self.session_dict:
del self.session_dict["unknown"]
# 会话机制 # 会话机制
if session_id not in self.session_dict: if session_id not in self.session_dict:
self.session_dict[session_id] = [] self.session_dict[session_id] = []
@@ -111,49 +132,93 @@ class ProviderOpenAIOfficial(Provider):
f.write(json.dumps(fjson)) f.write(json.dumps(fjson))
f.flush() f.flush()
f.close() f.close()
# 使用 tictoken 截断消息
_encoded_prompt = self.enc.encode(prompt)
prompt = self.enc.decode(_encoded_prompt[:self.openai_model_configs['max_tokens'] - 100])
gu.log(f"注意,有一部分 prompt 文本由于超出 token 限制而被截断。", level=gu.LEVEL_WARNING, max_len=300)
cache_data_list, new_record, req = self.wrap(prompt, session_id) cache_data_list, new_record, req = self.wrap(prompt, session_id, image_url)
gu.log(f"CACHE_DATA_: {str(cache_data_list)}", level=gu.LEVEL_DEBUG, max_len=99999)
gu.log(f"OPENAI REQUEST: {str(req)}", level=gu.LEVEL_DEBUG, max_len=9999)
retry = 0 retry = 0
response = None response = None
err = '' err = ''
while retry < 5:
# 截断倍率
truncate_rate = 0.75
use_gpt4v = False
for i in req:
if isinstance(i['content'], list):
use_gpt4v = True
break
if image_url is not None:
use_gpt4v = True
if use_gpt4v:
conf = self.openai_model_configs.copy()
conf['model'] = 'gpt-4-vision-preview'
else:
conf = self.openai_model_configs
print(req)
while retry < 10:
try: try:
response = openai.ChatCompletion.create( if function_call is None:
messages=req, response = self.client.chat.completions.create(
**self.chatGPT_configs messages=req,
) **conf
)
else:
response = self.client.chat.completions.create(
messages=req,
tools = function_call,
**conf
)
break break
except Exception as e: except Exception as e:
print(traceback.format_exc())
if 'Invalid content type. image_url is only supported by certain models.' in str(e):
raise e
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e): if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
gu.log("当前Key已超额或异常, 正在切换", level=gu.LEVEL_WARNING) gu.log("当前Key已超额或异常, 正在切换", level=gu.LEVEL_WARNING)
self.key_stat[openai.api_key]['exceed'] = True self.key_stat[self.client.api_key]['exceed'] = True
self.save_key_record() is_switched = self.handle_switch_key()
response, is_switched = self.handle_switch_key(req)
if not is_switched: if not is_switched:
# 所有Key都超额或不正常 # 所有Key都超额或不正常
raise e raise e
else: retry -= 1
break
elif 'maximum context length' in str(e): elif 'maximum context length' in str(e):
gu.log("token超限, 清空对应缓存") gu.log("token超限, 清空对应缓存,并进行消息截断")
self.session_dict[session_id] = [] self.session_dict[session_id] = []
prompt = prompt[:int(len(prompt)*truncate_rate)]
truncate_rate -= 0.05
cache_data_list, new_record, req = self.wrap(prompt, session_id) cache_data_list, new_record, req = self.wrap(prompt, session_id)
elif 'Limit: 3 / min. Please try again in 20s.' in str(e):
time.sleep(60) elif 'Limit: 3 / min. Please try again in 20s.' in str(e) or "OpenAI response error" in str(e):
time.sleep(30)
continue
else: else:
gu.log(str(e), level=gu.LEVEL_ERROR) gu.log(str(e), level=gu.LEVEL_ERROR)
time.sleep(2)
err = str(e) err = str(e)
retry+=1 retry += 1
if retry >= 5: if retry >= 10:
gu.log(r"如果报错, 且您的机器在中国大陆内, 请确保您的电脑已经设置好代理软件(梯子), 并在配置文件设置了系统代理地址。详见https://github.com/Soulter/QQChannelChatGPT/wiki/%E4%BA%8C%E3%80%81%E9%A1%B9%E7%9B%AE%E9%85%8D%E7%BD%AE%E6%96%87%E4%BB%B6%E9%85%8D%E7%BD%AE", max_len=999) gu.log(r"如果报错, 且您的机器在中国大陆内, 请确保您的电脑已经设置好代理软件(梯子), 并在配置文件设置了系统代理地址。详见https://github.com/Soulter/QQChannelChatGPT/wiki/%E4%BA%8C%E3%80%81%E9%A1%B9%E7%9B%AE%E9%85%8D%E7%BD%AE%E6%96%87%E4%BB%B6%E9%85%8D%E7%BD%AE", max_len=999)
raise BaseException("连接出错: "+str(err)) raise BaseException("连接出错: "+str(err))
assert isinstance(response, ChatCompletion)
self.key_stat[openai.api_key]['used'] += response['usage']['total_tokens'] gu.log(f"OPENAI RESPONSE: {response.usage}", level=gu.LEVEL_DEBUG, max_len=9999)
self.save_key_record()
# print("[ChatGPT] "+str(response["choices"][0]["message"]["content"])) # 结果分类
chatgpt_res = str(response["choices"][0]["message"]["content"]).strip() choice = response.choices[0]
current_usage_tokens = response['usage']['total_tokens'] if choice.message.content != None:
# 文本形式
chatgpt_res = str(choice.message.content).strip()
elif choice.message.tool_calls != None and len(choice.message.tool_calls) > 0:
# tools call (function calling)
return choice.message.tool_calls[0].function
self.key_stat[self.client.api_key]['used'] += response.usage.total_tokens
current_usage_tokens = response.usage.total_tokens
# 超过指定tokens 尽可能的保留最多的条目直到小于max_tokens # 超过指定tokens 尽可能的保留最多的条目直到小于max_tokens
if current_usage_tokens > self.max_tokens: if current_usage_tokens > self.max_tokens:
@@ -163,7 +228,7 @@ class ProviderOpenAIOfficial(Provider):
if index >= len(cache_data_list): if index >= len(cache_data_list):
break break
# 保留人格信息 # 保留人格信息
if 'user' in cache_data_list[index] and cache_data_list[index]['user']['role'] != 'system': if cache_data_list[index]['type'] != 'personality':
t -= int(cache_data_list[index]['single_tokens']) t -= int(cache_data_list[index]['single_tokens'])
del cache_data_list[index] del cache_data_list[index]
else: else:
@@ -182,6 +247,7 @@ class ProviderOpenAIOfficial(Provider):
new_record['single_tokens'] = current_usage_tokens - int(cache_data_list[-1]['usage_tokens']) new_record['single_tokens'] = current_usage_tokens - int(cache_data_list[-1]['usage_tokens'])
else: else:
new_record['single_tokens'] = current_usage_tokens new_record['single_tokens'] = current_usage_tokens
cache_data_list.append(new_record) cache_data_list.append(new_record)
self.session_dict[session_id] = cache_data_list self.session_dict[session_id] = cache_data_list
@@ -193,13 +259,11 @@ class ProviderOpenAIOfficial(Provider):
image_url = '' image_url = ''
while retry < 5: while retry < 5:
try: try:
# print("test1") response = self.client.images.generate(
response = openai.Image.create(
prompt=prompt, prompt=prompt,
n=img_num, n=img_num,
size=img_size size=img_size
) )
# print("test2")
image_url = [] image_url = []
for i in range(img_num): for i in range(img_num):
image_url.append(response['data'][i]['url']) image_url.append(response['data'][i]['url'])
@@ -208,23 +272,22 @@ class ProviderOpenAIOfficial(Provider):
gu.log(str(e), level=gu.LEVEL_ERROR) gu.log(str(e), level=gu.LEVEL_ERROR)
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str( if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(
e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e): e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
gu.log("当前Key已超额或者不正常, 正在切换", level=gu.LEVEL_WARNING) gu.log("当前 Key 已超额或者不正常, 正在切换", level=gu.LEVEL_WARNING)
self.key_stat[openai.api_key]['exceed'] = True self.key_stat[self.client.api_key]['exceed'] = True
self.save_key_record() is_switched = self.handle_switch_key()
response, is_switched = self.handle_switch_key(req)
if not is_switched: if not is_switched:
# 所有Key都超额或不正常 # 所有Key都超额或不正常
raise e raise e
else: else:
break retry += 1
retry += 1
if retry >= 5: if retry >= 5:
raise BaseException("连接超时") raise BaseException("连接超时")
return image_url return image_url
def forget(self, session_id) -> bool: def forget(self, session_id = None) -> bool:
if session_id is None:
return False
self.session_dict[session_id] = [] self.session_dict[session_id] = []
return True return True
@@ -285,7 +348,20 @@ class ProviderOpenAIOfficial(Provider):
return -1, -1, -1, -1 return -1, -1, -1, -1
# 包装信息 # 包装信息
def wrap(self, prompt, session_id): def wrap(self, prompt, session_id, image_url = None):
if image_url is not None:
prompt = [
{
"type": "text",
"text": prompt
},
{
"type": "image_url",
"image_url": {
"url": image_url
}
}
]
# 获得缓存信息 # 获得缓存信息
context = self.session_dict[session_id] context = self.session_dict[session_id]
new_record = { new_record = {
@@ -294,6 +370,7 @@ class ProviderOpenAIOfficial(Provider):
"content": prompt, "content": prompt,
}, },
"AI": {}, "AI": {},
'type': "common",
'usage_tokens': 0, 'usage_tokens': 0,
} }
req_list = [] req_list = []
@@ -305,105 +382,53 @@ class ProviderOpenAIOfficial(Provider):
req_list.append(new_record['user']) req_list.append(new_record['user'])
return context, new_record, req_list return context, new_record, req_list
def handle_switch_key(self, req): def handle_switch_key(self):
# messages = [{"role": "user", "content": prompt}] # messages = [{"role": "user", "content": prompt}]
while True: is_all_exceed = True
is_all_exceed = True for key in self.key_stat:
for key in self.key_stat: if key == None or self.key_stat[key]['exceed']:
if key == None: continue
continue is_all_exceed = False
if not self.key_stat[key]['exceed']: self.client.api_key = key
is_all_exceed = False gu.log(f"切换到Key: {key}, 已使用token: {self.key_stat[key]['used']}", level=gu.LEVEL_INFO)
openai.api_key = key break
gu.log(f"切换到Key: {key}, 已使用token: {self.key_stat[key]['used']}", level=gu.LEVEL_INFO) if is_all_exceed:
if len(req) > 0: gu.log("所有Key已超额", level=gu.LEVEL_CRITICAL)
try: return False
response = openai.ChatCompletion.create( return True
messages=req,
**self.chatGPT_configs def get_configs(self):
)
return response, True
except Exception as e:
if 'You exceeded' in str(e):
gu.log("当前Key已超额, 正在切换")
self.key_stat[openai.api_key]['exceed'] = True
self.save_key_record()
time.sleep(1)
continue
else:
gu.log(str(e), level=gu.LEVEL_ERROR)
else:
return True
if is_all_exceed:
gu.log("所有Key已超额", level=gu.LEVEL_CRITICAL)
return None, False
else:
gu.log("在切换key时程序异常。", level=gu.LEVEL_ERROR)
return None, False
def getConfigs(self):
return self.openai_configs return self.openai_configs
def save_key_record(self):
with open(key_record_path, 'w', encoding='utf-8') as f:
json.dump(self.key_stat, f)
def get_key_stat(self): def get_key_stat(self):
return self.key_stat return self.key_stat
def get_key_list(self): def get_key_list(self):
return self.key_list return self.key_list
def get_curr_key(self):
return self.client.api_key
# 添加key # 添加key
def append_key(self, key, sponsor): def append_key(self, key, sponsor):
self.key_list.append(key) self.key_list.append(key)
self.key_stat[key] = {'exceed': False, 'used': 0, 'sponsor': sponsor} self.key_stat[key] = {'exceed': False, 'used': 0, 'sponsor': sponsor}
self.save_key_record()
self.init_key_record()
# 检查key是否可用 # 检查key是否可用
def check_key(self, key): def check_key(self, key):
pre_key = openai.api_key client_ = OpenAI(
openai.api_key = key api_key=key,
messages = [{"role": "user", "content": "1"}] base_url=self.api_base
)
messages = [{"role": "user", "content": "please just echo `test`"}]
try: try:
response = openai.ChatCompletion.create( client_.chat.completions.create(
messages=messages, messages=messages,
**self.chatGPT_configs **self.openai_model_configs
) )
openai.api_key = pre_key
return True return True
except Exception as e: except Exception as e:
pass pass
openai.api_key = pre_key
return False return False
#将key_list的key转储到key_record中并记录相关数据
def init_key_record(self):
if not os.path.exists(key_record_path):
with open(key_record_path, 'w', encoding='utf-8') as f:
json.dump({}, f)
with open(key_record_path, 'r', encoding='utf-8') as keyfile:
try:
self.key_stat = json.load(keyfile)
except Exception as e:
gu.log(str(e), level=gu.LEVEL_ERROR)
self.key_stat = {}
finally:
for key in self.key_list:
if key not in self.key_stat:
self.key_stat[key] = {'exceed': False, 'used': 0}
# if openai.api_key is None:
# openai.api_key = key
else:
# if self.key_stat[key]['exceed']:
# print(f"Key: {key} 已超额")
# continue
# else:
# if openai.api_key is None:
# openai.api_key = key
# print(f"使用Key: {key}, 已使用token: {self.key_stat[key]['used']}")
pass
if openai.api_key == None:
self.handle_switch_key("")
self.save_key_record()

View File

@@ -7,8 +7,10 @@ import time
class ProviderRevChatGPT(Provider): class ProviderRevChatGPT(Provider):
def __init__(self, config): def __init__(self, config, base_url = None):
self.rev_chatgpt = [] if base_url == "":
base_url = None
self.rev_chatgpt: list[dict] = []
self.cc = cc.CmdConfig() self.cc = cc.CmdConfig()
for i in range(0, len(config['account'])): for i in range(0, len(config['account'])):
try: try:
@@ -28,26 +30,34 @@ class ProviderRevChatGPT(Provider):
rev_account_config['PUID'] = self.cc.get("rev_chatgpt_PUID") rev_account_config['PUID'] = self.cc.get("rev_chatgpt_PUID")
if len(self.cc.get("rev_chatgpt_unverified_plugin_domains")) > 0: if len(self.cc.get("rev_chatgpt_unverified_plugin_domains")) > 0:
rev_account_config['unverified_plugin_domains'] = self.cc.get("rev_chatgpt_unverified_plugin_domains") rev_account_config['unverified_plugin_domains'] = self.cc.get("rev_chatgpt_unverified_plugin_domains")
cb = Chatbot(config=rev_account_config) cb = Chatbot(config=rev_account_config, base_url=base_url)
# cb.captcha_solver = self.__captcha_solver # cb.captcha_solver = self.__captcha_solver
# 后八位c
g_id = rev_account_config['access_token'][-8:]
revstat = { revstat = {
'id': g_id,
'obj': cb, 'obj': cb,
'busy': False 'busy': False,
'user': []
} }
self.rev_chatgpt.append(revstat) self.rev_chatgpt.append(revstat)
except BaseException as e: except BaseException as e:
gu.log(f"创建逆向ChatGPT负载{str(i+1)}失败: {str(e)}", level=gu.LEVEL_ERROR, tag="RevChatGPT") gu.log(f"创建逆向ChatGPT负载{str(i+1)}失败: {str(e)}", level=gu.LEVEL_ERROR, tag="RevChatGPT")
def forget(self) -> bool: def forget(self, session_id = None) -> bool:
for i in self.rev_chatgpt:
for user in i['user']:
if session_id == user['id']:
try:
i['obj'].reset_chat()
return True
except BaseException as e:
gu.log(f"重置RevChatGPT失败。原因: {str(e)}", level=gu.LEVEL_ERROR, tag="RevChatGPT")
return False
return False return False
# def __captcha_solver(images: list[str], challenge_details: dict) -> int: def get_revchatgpt(self) -> list:
# # Create tempfile return self.rev_chatgpt
# print("Captcha solver called")
# print(images)
# print(challenge_details)
# input("Press Enter to continue...")
# return 0
def request_text(self, prompt: str, bot) -> str: def request_text(self, prompt: str, bot) -> str:
resp = '' resp = ''
@@ -66,7 +76,8 @@ class ProviderRevChatGPT(Provider):
raise e raise e
if e.code == typings.ErrorType.PROHIBITED_CONCURRENT_QUERY_ERROR: if e.code == typings.ErrorType.PROHIBITED_CONCURRENT_QUERY_ERROR:
raise e raise e
if "Your authentication token has expired. Please try signing in again." in str(e):
raise e
if "The message you submitted was too long" in str(e): if "The message you submitted was too long" in str(e):
raise e raise e
if "You've reached our limit of messages per hour." in str(e): if "You've reached our limit of messages per hour." in str(e):
@@ -90,28 +101,110 @@ class ProviderRevChatGPT(Provider):
# print("[RevChatGPT] "+str(resp)) # print("[RevChatGPT] "+str(resp))
return resp return resp
def text_chat(self, prompt) -> str: def text_chat(self, prompt, session_id = None, image_url = None, function_call=None) -> str:
# 选择一个人少的账号。
selected_revstat = None
min_revstat = None
min_ = None
new_user = False
conversation_id = ''
parent_id = ''
for revstat in self.rev_chatgpt:
for user in revstat['user']:
if session_id == user['id']:
selected_revstat = revstat
conversation_id = user['conversation_id']
parent_id = user['parent_id']
break
if min_ is None:
min_ = len(revstat['user'])
min_revstat = revstat
elif len(revstat['user']) < min_:
min_ = len(revstat['user'])
min_revstat = revstat
# if session_id in revstat['user']:
# selected_revstat = revstat
# break
if selected_revstat is None:
selected_revstat = min_revstat
selected_revstat['user'].append({
'id': session_id,
'conversation_id': '',
'parent_id': ''
})
new_user = True
gu.log(f"选择账号{str(selected_revstat)}", tag="RevChatGPT", level=gu.LEVEL_DEBUG)
while selected_revstat['busy']:
gu.log(f"账号忙碌,等待中...", tag="RevChatGPT", level=gu.LEVEL_DEBUG)
time.sleep(1)
selected_revstat['busy'] = True
if not new_user:
# 非新用户,则使用其专用的会话
selected_revstat['obj'].conversation_id = conversation_id
selected_revstat['obj'].parent_id = parent_id
else:
# 新用户,则使用新的会话
selected_revstat['obj'].reset_chat()
res = '' res = ''
err_msg = '' err_msg = ''
cursor = 0 err_cnt = 0
for revstat in self.rev_chatgpt: while err_cnt < 15:
cursor += 1 try:
if not revstat['busy']: res = self.request_text(prompt, selected_revstat['obj'])
try: selected_revstat['busy'] = False
revstat['busy'] = True # 记录新用户的会话
res = self.request_text(prompt, revstat['obj']) if new_user:
revstat['busy'] = False i = 0
return res.strip() for user in selected_revstat['user']:
# todo: 细化错误管理 if user['id'] == session_id:
except BaseException as e: selected_revstat['user'][i]['conversation_id'] = selected_revstat['obj'].conversation_id
revstat['busy'] = False selected_revstat['user'][i]['parent_id'] = selected_revstat['obj'].parent_id
gu.log(f"请求出现问题: {str(e)}", level=gu.LEVEL_WARNING, tag="RevChatGPT") break
err_msg += f"账号{cursor} - 错误原因: {str(e)}" i += 1
continue return res.strip()
else: except BaseException as e:
err_msg += f"账号{cursor} - 错误原因: 忙碌" if "Your authentication token has expired. Please try signing in again." in str(e):
continue raise Exception(f"此账号(access_token后8位为{selected_revstat['id']})的access_token已过期请重新获取或者切换账号。")
raise Exception(f'回复失败。错误跟踪:{err_msg}') if "The message you submitted was too long" in str(e):
raise Exception("发送的消息太长,请分段发送。")
if "You've reached our limit of messages per hour." in str(e):
raise Exception("触发RevChatGPT请求频率限制。请1小时后再试或者切换账号。")
gu.log(f"请求异常: {str(e)}", level=gu.LEVEL_WARNING, tag="RevChatGPT")
err_cnt += 1
time.sleep(3)
raise Exception(f'回复失败。原因:{err_msg}。如果您设置了多个账号,可以使用/switch指令切换账号。输入/switch查看详情。')
# while self.is_all_busy():
# time.sleep(1)
# res = ''
# err_msg = ''
# cursor = 0
# for revstat in self.rev_chatgpt:
# cursor += 1
# if not revstat['busy']:
# try:
# revstat['busy'] = True
# res = self.request_text(prompt, revstat['obj'])
# revstat['busy'] = False
# return res.strip()
# # todo: 细化错误管理
# except BaseException as e:
# revstat['busy'] = False
# gu.log(f"请求出现问题: {str(e)}", level=gu.LEVEL_WARNING, tag="RevChatGPT")
# err_msg += f"账号{cursor} - 错误原因: {str(e)}"
# continue
# else:
# err_msg += f"账号{cursor} - 错误原因: 忙碌"
# continue
# raise Exception(f'回复失败。错误跟踪:{err_msg}')
def is_all_busy(self) -> bool: def is_all_busy(self) -> bool:
for revstat in self.rev_chatgpt: for revstat in self.rev_chatgpt:

View File

@@ -1,10 +1,12 @@
from model.provider.provider import Provider from model.provider.provider import Provider
from EdgeGPT import Chatbot, ConversationStyle # from EdgeGPT import Chatbot, ConversationStyle
import json import json
import os import os
from util import general_utils as gu from util import general_utils as gu
from util.cmd_config import CmdConfig as cc from util.cmd_config import CmdConfig as cc
import time
from EdgeGPT.EdgeUtils import Query, Cookie
from EdgeGPT.EdgeGPT import Chatbot as EdgeChatbot, ConversationStyle, NotAllowedToAccess
class ProviderRevEdgeGPT(Provider): class ProviderRevEdgeGPT(Provider):
def __init__(self): def __init__(self):
@@ -15,21 +17,27 @@ class ProviderRevEdgeGPT(Provider):
proxy = cc.get("bing_proxy", None) proxy = cc.get("bing_proxy", None)
if proxy == "": if proxy == "":
proxy = None proxy = None
self.bot = Chatbot(cookies=cookies, proxy = proxy) # q = Query("Hello, bing!", cookie_files="./cookies.json")
# print(q)
self.bot = EdgeChatbot(cookies=cookies, proxy = "http://127.0.0.1:7890")
ret = self.bot.ask_stream("Hello, bing!", conversation_style=ConversationStyle.creative, wss_link="wss://ai.nothingnessvoid.tech/sydney/ChatHub")
# self.bot = Chatbot(cookies=cookies, proxy = proxy)
for i in ret:
print(i, flush=True)
def is_busy(self): def is_busy(self):
return self.busy return self.busy
async def forget(self): async def forget(self, session_id = None):
try: try:
await self.bot.reset() await self.bot.reset()
return True return True
except BaseException: except BaseException:
return False return False
async def text_chat(self, prompt, platform = 'none'): async def text_chat(self, prompt, platform = 'none', image_url=None, function_call=None):
if self.busy: while self.busy:
return time.sleep(1)
self.busy = True self.busy = True
resp = 'err' resp = 'err'
err_count = 0 err_count = 0

View File

@@ -1,11 +1,15 @@
pydantic~=1.10.4 pydantic~=1.10.4
requests~=2.28.1 requests~=2.28.1
openai~=0.27.4 openai~=1.2.3
qq-botpy~=1.1.2 qq-botpy
revChatGPT~=6.8.6
baidu-aip~=4.16.9
EdgeGPT~=0.1.22.1
chardet~=5.1.0 chardet~=5.1.0
Pillow~=9.4.0 Pillow~=9.4.0
GitPython~=3.1.31 GitPython~=3.1.31
nakuru-project nakuru-project
beautifulsoup4
googlesearch-python
tictoken
readability-lxml
EdgeGPT
revChatGPT~=6.8.6
baidu-aip~=4.16.9

View File

@@ -1,3 +0,0 @@
class PromptExceededError(Exception):
pass

View File

@@ -2,6 +2,7 @@
import json import json
import util.general_utils as gu import util.general_utils as gu
import time
class FuncCallJsonFormatError(Exception): class FuncCallJsonFormatError(Exception):
def __init__(self, msg): def __init__(self, msg):
self.msg = msg self.msg = msg
@@ -24,9 +25,18 @@ class FuncCall():
def add_func(self, name: str = None, func_args: list = None, desc: str = None, func_obj = None) -> None: def add_func(self, name: str = None, func_args: list = None, desc: str = None, func_obj = None) -> None:
if name == None or func_args == None or desc == None or func_obj == None: if name == None or func_args == None or desc == None or func_obj == None:
raise FuncCallJsonFormatError("name, func_args, desc must be provided.") raise FuncCallJsonFormatError("name, func_args, desc must be provided.")
params = {
"type": "object", # hardcore here
"properties": {}
}
for param in func_args:
params['properties'][param['name']] = {
"type": param['type'],
"description": param['description']
}
self._func = { self._func = {
"name": name, "name": name,
"args": func_args, "parameters": params,
"description": desc, "description": desc,
"func_obj": func_obj, "func_obj": func_obj,
} }
@@ -37,18 +47,30 @@ class FuncCall():
for f in self.func_list: for f in self.func_list:
_l.append({ _l.append({
"name": f["name"], "name": f["name"],
"args": f["args"], "parameters": f["parameters"],
"description": f["description"], "description": f["description"],
}) })
return json.dumps(_l, indent=intent, ensur_ascii=False)
return json.dumps(_l, indent=intent, ensure_ascii=False)
def func_call(self, question, func_definition, is_task = False, tasks = None, taskindex = -1, is_summary = True): def get_func(self) -> list:
_l = []
for f in self.func_list:
_l.append({
"type": "function",
"function": {
"name": f["name"],
"parameters": f["parameters"],
"description": f["description"],
}
})
return _l
def func_call(self, question, func_definition, is_task = False, tasks = None, taskindex = -1, is_summary = True, session_id = None):
funccall_prompt = """ funccall_prompt = """
我正实现function call功能该功能旨在让你变成给定的问题到给定的函数的解析器意味着你不是创造函数)。 我正实现function call功能该功能旨在让你变成给定的问题到给定的函数的解析器意味着你不是创造函数
下面会给你提供可能用到函数相关信息和一个问题,你需要将其转换成给定的函数调用。 下面会给你提供可能用到函数相关信息和一个问题,你需要将其转换成给定的函数调用。
- 你的返回信息只含json严格仿照以下内容(不含注释): - 你的返回信息只含json严格仿照以下内容(不含注释),必须含有`res`,`func_call`字段:
``` ```
{ {
"res": string // 如果没有找到对应的函数,那么你可以在这里正常输出内容。如果有,这里是空字符串。 "res": string // 如果没有找到对应的函数,那么你可以在这里正常输出内容。如果有,这里是空字符串。
@@ -95,7 +117,7 @@ class FuncCall():
_c = 0 _c = 0
while _c < 3: while _c < 3:
try: try:
res = self.provider.text_chat(prompt) res = self.provider.text_chat(prompt, session_id)
if res.find('```') != -1: if res.find('```') != -1:
res = res[res.find('```json') + 7: res.rfind('```')] res = res[res.find('```json') + 7: res.rfind('```')]
gu.log("REVGPT func_call json result", bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"]) gu.log("REVGPT func_call json result", bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
@@ -111,7 +133,7 @@ class FuncCall():
invoke_func_res = "" invoke_func_res = ""
if len(res["func_call"]) > 0: if "func_call" in res and len(res["func_call"]) > 0:
task_list = res["func_call"] task_list = res["func_call"]
invoke_func_res_list = [] invoke_func_res_list = []
@@ -140,7 +162,7 @@ class FuncCall():
# 生成返回结果 # 生成返回结果
after_prompt = """ after_prompt = """
函数返回以下内容:"""+invoke_func_res+""" 以下内容:"""+invoke_func_res+"""
请以AI助手的身份结合返回的内容对用户提问做详细全面的回答。 请以AI助手的身份结合返回的内容对用户提问做详细全面的回答。
用户的提问是: 用户的提问是:
```""" + question + """``` ```""" + question + """```
@@ -157,7 +179,7 @@ class FuncCall():
_c = 0 _c = 0
while _c < 5: while _c < 5:
try: try:
res = self.provider.text_chat(after_prompt) res = self.provider.text_chat(after_prompt, session_id)
# 截取```之间的内容 # 截取```之间的内容
gu.log("DEBUG BEGIN", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"]) gu.log("DEBUG BEGIN", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"])
print(res) print(res)
@@ -174,6 +196,7 @@ class FuncCall():
raise e raise e
if "The message you submitted was too long" in str(e): if "The message you submitted was too long" in str(e):
# 如果返回的内容太长了,那么就截取一部分 # 如果返回的内容太长了,那么就截取一部分
time.sleep(3)
invoke_func_res = invoke_func_res[:int(len(invoke_func_res) / 2)] invoke_func_res = invoke_func_res[:int(len(invoke_func_res) / 2)]
after_prompt = """ after_prompt = """
函数返回以下内容:"""+invoke_func_res+""" 函数返回以下内容:"""+invoke_func_res+"""

View File

@@ -33,11 +33,20 @@ BG_COLORS = {
"default": "49", "default": "49",
} }
LEVEL_DEBUG = "DEBUG"
LEVEL_INFO = "INFO" LEVEL_INFO = "INFO"
LEVEL_WARNING = "WARNING" LEVEL_WARNING = "WARNING"
LEVEL_ERROR = "ERROR" LEVEL_ERROR = "ERROR"
LEVEL_CRITICAL = "CRITICAL" LEVEL_CRITICAL = "CRITICAL"
level_codes = {
LEVEL_DEBUG: 0,
LEVEL_INFO: 1,
LEVEL_WARNING: 2,
LEVEL_ERROR: 3,
LEVEL_CRITICAL: 4
}
level_colors = { level_colors = {
"INFO": "green", "INFO": "green",
"WARNING": "yellow", "WARNING": "yellow",
@@ -51,10 +60,22 @@ def log(
tag: str = "System", tag: str = "System",
fg: str = None, fg: str = None,
bg: str = None, bg: str = None,
max_len: int = 300): max_len: int = 500,
err: Exception = None,):
""" """
日志记录函数 日志打印函数
""" """
_set_level_code = level_codes[LEVEL_INFO]
if 'LOG_LEVEL' in os.environ and os.environ['LOG_LEVEL'] in level_codes:
_set_level_code = level_codes[os.environ['LOG_LEVEL']]
if level in level_codes and level_codes[level] < _set_level_code:
return
if err is not None:
msg += "\n异常原因: " + str(err)
level = LEVEL_ERROR
if len(msg) > max_len: if len(msg) > max_len:
msg = msg[:max_len] + "..." msg = msg[:max_len] + "..."
now = datetime.datetime.now().strftime("%m-%d %H:%M:%S") now = datetime.datetime.now().strftime("%m-%d %H:%M:%S")

View File

@@ -7,33 +7,79 @@ from util.func_call import (
FuncCallJsonFormatError, FuncCallJsonFormatError,
FuncNotFoundError FuncNotFoundError
) )
from openai.types.chat.chat_completion_message_tool_call import Function
import traceback
from googlesearch import search, SearchResult
from model.provider.provider import Provider
import json
from readability import Document
def tidy_text(text: str) -> str: def tidy_text(text: str) -> str:
return text.strip().replace("\n", "").replace(" ", "").replace("\r", "") '''
清理文本,去除空格、换行符等
'''
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
def special_fetch_zhihu(link: str) -> str: def special_fetch_zhihu(link: str) -> str:
'''
function-calling 函数, 用于获取知乎文章的内容
'''
headers = { headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
} }
response = requests.get(link, headers=headers) response = requests.get(link, headers=headers)
response.encoding = "utf-8"
soup = BeautifulSoup(response.text, "html.parser") soup = BeautifulSoup(response.text, "html.parser")
r = soup.find(class_="List-item").find(class_="RichContent-inner")
if "zhuanlan.zhihu.com" in link:
r = soup.find(class_="Post-RichTextContainer")
else:
r = soup.find(class_="List-item").find(class_="RichContent-inner")
if r is None: if r is None:
print("debug: zhihu none") print("debug: zhihu none")
raise Exception("zhihu none") raise Exception("zhihu none")
return tidy_text(r.text) return tidy_text(r.text)
def google_web_search(keyword) -> str:
'''
获取 google 搜索结果, 得到 title、desc、link
'''
ret = ""
index = 1
try:
ls = search(keyword, advanced=True, num_results=4)
for i in ls:
desc = i.description
try:
desc = fetch_website_content(i.url)
except BaseException as e:
print(f"(google) fetch_website_content err: {str(e)}")
gu.log(f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n", level=gu.LEVEL_DEBUG, max_len=9999)
ret += f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n"
index += 1
except Exception as e:
print(f"google search err: {str(e)}")
return web_keyword_search_via_bing(keyword)
return ret
def web_keyword_search_via_bing(keyword) -> str: def web_keyword_search_via_bing(keyword) -> str:
'''
获取bing搜索结果, 得到 title、desc、link
'''
headers = { headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
} }
url = "https://cn.bing.com/search?q="+keyword url = "https://www.bing.com/search?q="+keyword
_cnt = 0 _cnt = 0
_detail_store = [] _detail_store = []
while _cnt < 5: while _cnt < 5:
try: try:
response = requests.get(url, headers=headers) response = requests.get(url, headers=headers)
response.encoding = "utf-8"
gu.log(f"bing response: {response.text}", tag="bing", level=gu.LEVEL_DEBUG, max_len=9999)
soup = BeautifulSoup(response.text, "html.parser") soup = BeautifulSoup(response.text, "html.parser")
res = [] res = []
ols = soup.find(id="b_results") ols = soup.find(id="b_results")
@@ -47,28 +93,39 @@ def web_keyword_search_via_bing(keyword) -> str:
"desc": desc, "desc": desc,
"link": link, "link": link,
}) })
if len(_detail_store) < 2 and "zhihu.com" in link:
try:
_detail_store.append(special_fetch_zhihu(link)[:800])
except BaseException as e:
print(f"zhihu parse err: {str(e)}")
if len(res) >= 5: # 限制5条 if len(res) >= 5: # 限制5条
break break
if len(_detail_store) >= 3:
continue
# 爬取前两条的网页内容
if "zhihu.com" in link:
try:
_detail_store.append(special_fetch_zhihu(link))
except BaseException as e:
print(f"zhihu parse err: {str(e)}")
else:
try:
_detail_store.append(fetch_website_content(link))
except BaseException as e:
print(f"fetch_website_content err: {str(e)}")
except Exception as e: except Exception as e:
print(f"bing parse err: {str(e)}") print(f"bing parse err: {str(e)}")
if len(res) == 0: if len(res) == 0:
break break
if len(_detail_store) > 0: if len(_detail_store) > 0:
ret = f"{str(res)} \n来源知乎的具体资料: {str(_detail_store)}" ret = f"{str(res)} \n具体网页内容: {str(_detail_store)}"
else: else:
ret = f"{str(res)}" ret = f"{str(res)}"
return str(ret) return str(ret)
except Exception as e: except Exception as e:
print(f"bing fetch err: {str(e)}") gu.log(f"bing fetch err: {str(e)}")
_cnt += 1 _cnt += 1
time.sleep(1) time.sleep(1)
print("fail to fetch bing info, using sougou.")
return web_keyword_search_via_sougou(keyword) gu.log("fail to fetch bing info, using sougou.")
return google_web_search(keyword)
def web_keyword_search_via_sougou(keyword) -> str: def web_keyword_search_via_sougou(keyword) -> str:
headers = { headers = {
@@ -92,59 +149,131 @@ def web_keyword_search_via_sougou(keyword) -> str:
"title": title, "title": title,
"link": link, "link": link,
}) })
except: if len(res) >= 5: # 限制5条
pass break
ret = f"{str(res)} \n全部内容: {tidy_text(soup.text)}" except Exception as e:
gu.log(f"sougou parse err: {str(e)}", tag="web_keyword_search_via_sougou", level=gu.LEVEL_ERROR)
# 爬取网页内容
_detail_store = []
for i in res:
if _detail_store >= 3:
break
try:
_detail_store.append(fetch_website_content(i["link"]))
except BaseException as e:
print(f"fetch_website_content err: {str(e)}")
ret = f"{str(res)}"
if len(_detail_store) > 0:
ret += f"\n网页内容: {str(_detail_store)}"
return ret return ret
def fetch_website_content(url): def fetch_website_content(url):
gu.log(f"fetch_website_content: {url}", tag="fetch_website_content", level=gu.LEVEL_DEBUG)
headers = { headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
} }
response = requests.get(url, headers=headers) response = requests.get(url, headers=headers, timeout=3)
soup = BeautifulSoup(response.text, "html.parser") response.encoding = "utf-8"
res = soup.text # soup = BeautifulSoup(response.text, "html.parser")
res = res.replace("\n", "") # # 如果有container / content / main等的话就只取这些部分
with open(f"temp_{time.time()}.html", "w", encoding="utf-8") as f: # has = False
f.write(res) # beleive_ls = ["container", "content", "main"]
return res # res = ""
# for cls in beleive_ls:
def web_search(question, provider): # for i in soup.find_all(class_=cls):
# has = True
# res += i.text
# if not has:
# res = soup.text
# res = res.replace("\n", "").replace(" ", " ").replace("\r", "").replace("\t", "")
# if not has:
# res = res[300:1100]
# else:
# res = res[100:800]
# # with open(f"temp_{time.time()}.html", "w", encoding="utf-8") as f:
# # f.write(res)
# gu.log(f"fetch_website_content: end", tag="fetch_website_content", level=gu.LEVEL_DEBUG)
# return res
doc = Document(response.content)
# print('title:', doc.title())
ret = doc.summary(html_partial=True)
soup = BeautifulSoup(ret, 'html.parser')
ret = tidy_text(soup.get_text())
return ret
def web_search(question, provider: Provider, session_id, official_fc=False):
'''
official_fc: 使用官方 function-calling
'''
new_func_call = FuncCall(provider) new_func_call = FuncCall(provider)
new_func_call.add_func("google_web_search", [{
new_func_call.add_func("web_keyword_search_via_bing", [{
"type": "string", "type": "string",
"name": "keyword", "name": "keyword",
"brief": "必应搜索的关键词(分词,尽量保留所有信息)" "description": "google search query (分词,尽量保留所有信息)"
}], }],
"在必应搜索引擎搜索给定的关键词,并且返回第一页的搜索结果列表(标题,简介和链接)", "通过搜索引擎搜索。如果问题需要在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
web_keyword_search_via_bing google_web_search
) )
new_func_call.add_func("fetch_website_content", [{
"type": "string",
"name": "url",
"description": "网址"
}],
"获取网页的内容。如果问题带有合法的网页链接(例如: `帮我总结一下https://github.com的内容`), 就调用此函数。如果没有,不要调用此函数。",
fetch_website_content
)
question1 = f"{question} \n> hint: 最多只能调用1个function, 并且存在不会调用任何function的可能性。"
has_func = False
function_invoked_ret = ""
if official_fc:
func = provider.text_chat(question1, session_id, function_call=new_func_call.get_func())
if isinstance(func, Function):
# arguments='{\n "keyword": "北京今天的天气"\n}', name='google_web_search'
# 执行对应的结果:
func_obj = None
for i in new_func_call.func_list:
if i["name"] == func.name:
func_obj = i["func_obj"]
break
if not func_obj:
gu.log("找不到返回的 func name " + func.name, level=gu.LEVEL_ERROR)
return provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
try:
args = json.loads(func.arguments)
function_invoked_ret = func_obj(**args)
has_func = True
except BaseException as e:
traceback.print_exc()
return provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
else:
# now func is a string
return func
else:
try:
function_invoked_ret, has_func = new_func_call.func_call(question1, new_func_call.func_dump(), is_task=False, is_summary=False)
except BaseException as e:
res = provider.text_chat(question) + "\n(网页搜索失败, 此为默认回复)"
return res
has_func = True
func_definition1 = new_func_call.func_dump()
question1 = f"{question} \n(只能调用一个函数。)"
res1, has_func = new_func_call.func_call(question1, func_definition1, is_task=False, is_summary=False)
has_func = True
if has_func: if has_func:
provider.forget() provider.forget(session_id)
question3 = f"""请你回答`{question}`问题。\n以下是相关材料,请直接拿此材料针对问题进行总结回答,然后再给出参考链接。不要提到任何函数调用的信息。```\n{res1}\n```\n""" question3 = f"""请你用可爱的语气回答`{question}`问题。\n以下是相关材料,请直接拿此材料针对问题进行总结回答,再给参考链接, 参考链接首末有空格。不要提到任何函数调用的信息。在总结的末尾加上1-2个相关的emoji。```\n{function_invoked_ret}\n```\n"""
print(question3) gu.log(f"web_search: {question3}", tag="web_search", level=gu.LEVEL_DEBUG, max_len=99999)
_c = 0 _c = 0
while _c < 5: while _c < 3:
try: try:
print('text chat') print('text chat')
res3 = provider.text_chat(question3) final_ret = provider.text_chat(question3)
break return final_ret
except Exception as e: except Exception as e:
print(e) print(e)
_c += 1 _c += 1
if _c == 5: if _c == 3: raise e
raise e
if "The message you submitted was too long" in str(e): if "The message you submitted was too long" in str(e):
res2 = res2[:int(len(res2) / 2)] provider.forget(session_id)
question3 = f"""请你回答`{question}`问题。\n以下是相关材料,请直接拿此材料针对问题进行回答,然后再给出参考链接。```\n{res1}\n{res2}\n```\n""" function_invoked_ret = function_invoked_ret[:int(len(function_invoked_ret) / 2)]
return res3 time.sleep(3)
else: question3 = f"""请回答`{question}`问题。\n以下是相关材料,请直接拿此材料针对问题进行回答,再给参考链接, 参考链接首末有空格。```\n{function_invoked_ret}\n```\n"""
return res1 return function_invoked_ret

View File

@@ -1,3 +1,6 @@
'''
插件工具函数
'''
import os import os
import inspect import inspect
@@ -7,16 +10,25 @@ def get_classes(p_name, arg):
clsmembers = inspect.getmembers(arg, inspect.isclass) clsmembers = inspect.getmembers(arg, inspect.isclass)
for (name, _) in clsmembers: for (name, _) in clsmembers:
# print(name, p_name) # print(name, p_name)
if p_name.lower() == name.lower()[:-6]: if p_name.lower() == name.lower()[:-6] or name.lower() == "main":
classes.append(name) classes.append(name)
break break
return classes return classes
# 获取一个文件夹下所有的模块 # 获取一个文件夹下所有的模块, 文件名和文件夹名相同
def get_modules(path): def get_modules(path):
modules = [] modules = []
for root, dirs, files in os.walk(path): for root, dirs, files in os.walk(path):
# 获得所在目录名
p_name = os.path.basename(root)
for file in files: for file in files:
if file.endswith(".py") and not file.startswith("__"): """
modules.append(file[:-3]) 与文件夹名不计大小写相同或者是main.py的都算启动模块
return modules """
if file.endswith(".py") and not file.startswith("__") and (p_name.lower() == file[:-3].lower() or file[:-3].lower() == "main"):
modules.append({
"pname": p_name,
"module": file[:-3],
})
return modules