feat: 初步接入官方QQ群机器人API
This commit is contained in:
+53
-9
@@ -13,6 +13,13 @@ from cores.qqbot.personality import personalities
|
||||
from addons.baidu_aip_judge import BaiduJudge
|
||||
from model.platform.qqchan import QQChan, NakuruGuildMember, NakuruGuildMessage
|
||||
from model.platform.qq import QQ
|
||||
from model.platform.qqgroup import (
|
||||
UnofficialQQBotSDK,
|
||||
Event as QQEvent,
|
||||
Message as QQMessage,
|
||||
MessageChain,
|
||||
PlainText
|
||||
)
|
||||
from nakuru import (
|
||||
CQHTTP,
|
||||
GroupMessage,
|
||||
@@ -86,6 +93,9 @@ PLATFORM_QQCHAN = 'qqchan'
|
||||
qqchan_loop = None
|
||||
client = None
|
||||
|
||||
# QQ群机器人
|
||||
PLATFROM_QQBOT = 'qqbot'
|
||||
|
||||
# 配置
|
||||
cc.init_attributes(["qq_forward_threshold"], 200)
|
||||
cc.init_attributes(["qq_welcome"], "欢迎加入本群!\n欢迎给https://github.com/Soulter/QQChannelChatGPT项目一个Star😊~\n输入help查看帮助~\n")
|
||||
@@ -105,6 +115,8 @@ cc.init_attributes(["gocq_react_group_increase"], True)
|
||||
cc.init_attributes(["gocq_qqchan_admin"], "")
|
||||
cc.init_attributes(["other_admins"], [])
|
||||
cc.init_attributes(["CHATGPT_BASE_URL"], "")
|
||||
cc.init_attributes(["qqbot_appid"], "")
|
||||
cc.init_attributes(["qqbot_secret"], "")
|
||||
# cc.init_attributes(["qq_forward_mode"], False)
|
||||
|
||||
# QQ机器人
|
||||
@@ -115,8 +127,14 @@ gocq_app = CQHTTP(
|
||||
port=cc.get("gocq_websocket_port", 6700),
|
||||
http_port=cc.get("gocq_http_port", 5700),
|
||||
)
|
||||
qq_bot: UnofficialQQBotSDK = UnofficialQQBotSDK(
|
||||
cc.get("qqbot_appid", None),
|
||||
cc.get("qqbot_secret", None)
|
||||
)
|
||||
|
||||
gocq_loop: asyncio.AbstractEventLoop = None
|
||||
qqbot_loop: asyncio.AbstractEventLoop = None
|
||||
|
||||
gocq_loop = None
|
||||
|
||||
# 全局对象
|
||||
_global_object: GlobalObject = None
|
||||
@@ -350,9 +368,17 @@ def initBot(cfg, prov):
|
||||
_global_object.admin_qq = admin_qq
|
||||
_global_object.admin_qqchan = admin_qqchan
|
||||
|
||||
|
||||
global qq_bot, qqbot_loop
|
||||
qqbot_loop = asyncio.new_event_loop()
|
||||
if cc.get("qqbot_appid", None) is not None and cc.get("qqbot_secret", None) is not None:
|
||||
gu.log("- 启用QQ群机器人 -", gu.LEVEL_INFO)
|
||||
thread_inst = threading.Thread(target=run_qqbot, args=(qqbot_loop, qq_bot,), daemon=False)
|
||||
thread_inst.start()
|
||||
|
||||
|
||||
# GOCQ
|
||||
global gocq_bot
|
||||
|
||||
if 'gocqbot' in cfg and cfg['gocqbot']['enable']:
|
||||
gu.log("- 启用QQ机器人 -", gu.LEVEL_INFO)
|
||||
|
||||
@@ -431,6 +457,15 @@ def run_gocq_bot(loop, gocq_bot, gocq_app):
|
||||
except BaseException as e:
|
||||
input("启动QQ机器人出现错误"+str(e))
|
||||
|
||||
'''
|
||||
启动QQ群机器人(官方接口)
|
||||
'''
|
||||
def run_qqbot(loop: asyncio.AbstractEventLoop, qq_bot: UnofficialQQBotSDK):
|
||||
asyncio.set_event_loop(loop)
|
||||
QQBotClient()
|
||||
qq_bot.run_bot()
|
||||
|
||||
|
||||
'''
|
||||
检查发言频率
|
||||
'''
|
||||
@@ -469,6 +504,10 @@ async def send_message(platform, message, res, session_id = None):
|
||||
qqchannel_bot.send_qq_msg(message, res)
|
||||
if platform == PLATFORM_GOCQ:
|
||||
await gocq_bot.send_qq_msg(message, res)
|
||||
if platform == PLATFROM_QQBOT:
|
||||
message_chain = MessageChain()
|
||||
message_chain.parse_from_nakuru(res)
|
||||
await qq_bot.send(message, message_chain)
|
||||
|
||||
async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, NakuruGuildMessage],
|
||||
group: bool=False,
|
||||
@@ -493,13 +532,15 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
|
||||
|
||||
with_tag = False # 是否带有昵称
|
||||
|
||||
if platform == PLATFORM_GOCQ or platform == PLATFORM_QQCHAN:
|
||||
if platform == PLATFORM_QQCHAN or platform == PLATFROM_QQBOT:
|
||||
with_tag = True
|
||||
|
||||
if platform == PLATFORM_GOCQ or platform == PLATFORM_QQCHAN or platform == PLATFROM_QQBOT:
|
||||
_len = 0
|
||||
for i in message.message:
|
||||
if isinstance(i, Plain):
|
||||
if isinstance(i, Plain) or isinstance(i, PlainText):
|
||||
qq_msg += str(i.text).strip()
|
||||
if isinstance(i, At):
|
||||
# @机器人
|
||||
if message.type == "GuildMessage":
|
||||
if i.qq == message.user_id or i.qq == message.self_tiny_id:
|
||||
with_tag = True
|
||||
@@ -545,9 +586,6 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
|
||||
# 独立会话时,一个用户一个session
|
||||
session_id = sender_id
|
||||
|
||||
if platform == PLATFORM_QQCHAN:
|
||||
with_tag = True
|
||||
|
||||
if qq_msg == "":
|
||||
await send_message(platform, message, f"Hi~", session_id=session_id)
|
||||
return
|
||||
@@ -812,4 +850,10 @@ class gocqClient():
|
||||
if source.message[0].qq == source.self_tiny_id:
|
||||
new_sub_thread(oper_msg, (source, True, PLATFORM_GOCQ))
|
||||
else:
|
||||
return
|
||||
return
|
||||
|
||||
class QQBotClient():
|
||||
@qq_bot.on('GroupMessage')
|
||||
async def _(bot: UnofficialQQBotSDK, message: QQMessage):
|
||||
print(message)
|
||||
new_sub_thread(oper_msg, (message, True, PLATFROM_QQBOT))
|
||||
+119
-43
@@ -2,8 +2,64 @@ import requests
|
||||
import asyncio
|
||||
import websockets
|
||||
from websockets import WebSocketClientProtocol
|
||||
import threading
|
||||
import json
|
||||
import inspect
|
||||
from typing import Callable, Awaitable, Union
|
||||
from pydantic import BaseModel
|
||||
import datetime
|
||||
|
||||
class Event(BaseModel):
|
||||
GroupMessage: str = "GuildMessage"
|
||||
|
||||
class Sender(BaseModel):
|
||||
user_id: str
|
||||
member_openid: str
|
||||
|
||||
|
||||
class MessageComponent(BaseModel):
|
||||
type: str
|
||||
|
||||
class PlainText(MessageComponent):
|
||||
text: str
|
||||
|
||||
class Image(MessageComponent):
|
||||
path: str
|
||||
file: str
|
||||
url: str
|
||||
|
||||
class MessageChain(list):
|
||||
|
||||
def append(self, __object: MessageComponent) -> None:
|
||||
if not isinstance(__object, MessageComponent):
|
||||
raise TypeError("不受支持的消息链元素类型。回复的消息链必须是 MessageComponent 的子类。")
|
||||
return super().append(__object)
|
||||
|
||||
def insert(self, __index: int, __object: MessageComponent) -> None:
|
||||
if not isinstance(__object, MessageComponent):
|
||||
raise TypeError("不受支持的消息链元素类型。回复的消息链必须是 MessageComponent 的子类。")
|
||||
return super().insert(__index, __object)
|
||||
|
||||
def parse_from_nakuru(self, nakuru_message_chain: Union[list, str]) -> None:
|
||||
if isinstance(nakuru_message_chain, str):
|
||||
self.append(PlainText(type='Plain', text=nakuru_message_chain))
|
||||
else:
|
||||
for i in nakuru_message_chain:
|
||||
if i['type'] == 'Plain':
|
||||
self.append(PlainText(type='Plain', text=i['text']))
|
||||
elif i['type'] == 'Image':
|
||||
self.append(Image(path=i['path'], file=i['file'], url=i['url']))
|
||||
|
||||
class Message(BaseModel):
|
||||
type: str
|
||||
user_id: str
|
||||
member_openid: str
|
||||
message_id: str
|
||||
group_id: str
|
||||
group_openid: str
|
||||
content: str
|
||||
message: MessageChain
|
||||
time: int
|
||||
sender: Sender
|
||||
|
||||
class UnofficialQQBotSDK:
|
||||
|
||||
@@ -13,43 +69,42 @@ class UnofficialQQBotSDK:
|
||||
def __init__(self, appid: str, client_secret: str) -> None:
|
||||
self.appid = appid
|
||||
self.client_secret = client_secret
|
||||
self.get_access_token()
|
||||
self.get_wss_endpoint()
|
||||
asyncio.get_event_loop().run_until_complete(self.ws_client())
|
||||
self.events: dict[str, Awaitable] = {}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def get_access_token(self) -> None:
|
||||
def run_bot(self) -> None:
|
||||
self.__get_access_token()
|
||||
self.__get_wss_endpoint()
|
||||
asyncio.get_event_loop().run_until_complete(self.__ws_client())
|
||||
|
||||
def __get_access_token(self) -> None:
|
||||
res = requests.post(self.GET_APP_ACCESS_TOKEN_URL, json={
|
||||
"appId": self.appid,
|
||||
"clientSecret": self.client_secret
|
||||
}, headers={
|
||||
"Content-Type": "application/json"
|
||||
})
|
||||
print(res.text)
|
||||
self.access_token = 'QQBot ' + res.json()['access_token']
|
||||
print("access_token: " + self.access_token)
|
||||
res = res.json()
|
||||
code = res['code'] if 'code' in res else 1
|
||||
if 'access_token' not in res:
|
||||
raise Exception(f"获取 access_token 失败。原因:{res['message'] if 'message' in res else '未知'}")
|
||||
self.access_token = 'QQBot ' + res['access_token']
|
||||
|
||||
def auth_header(self) -> str:
|
||||
def __auth_header(self) -> str:
|
||||
return {
|
||||
'Authorization': self.access_token,
|
||||
'X-Union-Appid': self.appid,
|
||||
}
|
||||
|
||||
def get_wss_endpoint(self):
|
||||
# self.wss_endpoint = requests.get(self.OPENAPI_BASE_URL + "/gateway", headers=self.auth_header()).json()['url']
|
||||
res = requests.get(self.OPENAPI_BASE_URL + "/gateway", headers=self.auth_header())
|
||||
print(res.text)
|
||||
def __get_wss_endpoint(self):
|
||||
res = requests.get(self.OPENAPI_BASE_URL + "/gateway", headers=self.__auth_header())
|
||||
self.wss_endpoint = res.json()['url']
|
||||
print("wss_endpoint: " + self.wss_endpoint)
|
||||
|
||||
async def behav_heartbeat(self, ws: WebSocketClientProtocol, t: int):
|
||||
async def __behav_heartbeat(self, ws: WebSocketClientProtocol, t: int):
|
||||
while True:
|
||||
await asyncio.sleep(t - 1)
|
||||
try:
|
||||
print("heartbeat., s: " + str(self.s))
|
||||
await ws.send(json.dumps({
|
||||
"op": 1,
|
||||
"d": self.s
|
||||
@@ -57,12 +112,9 @@ class UnofficialQQBotSDK:
|
||||
except:
|
||||
print("heartbeat error.")
|
||||
|
||||
async def handle_msg(self, ws: WebSocketClientProtocol, msg: dict):
|
||||
async def __handle_msg(self, ws: WebSocketClientProtocol, msg: dict):
|
||||
if msg['op'] == 10:
|
||||
# hello
|
||||
# 创建心跳任务
|
||||
print("hello.")
|
||||
asyncio.get_event_loop().create_task(self.behav_heartbeat(ws, msg['d']['heartbeat_interval'] / 1000))
|
||||
asyncio.get_event_loop().create_task(self.__behav_heartbeat(ws, msg['d']['heartbeat_interval'] / 1000))
|
||||
# 鉴权,获得session
|
||||
await ws.send(json.dumps({
|
||||
"op": 2,
|
||||
@@ -77,36 +129,60 @@ class UnofficialQQBotSDK:
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
if msg['op'] == 0:
|
||||
# ready
|
||||
print("ready.")
|
||||
data = msg['d']
|
||||
print(data)
|
||||
event_typ: str = msg['t'] if 't' in msg else None
|
||||
if event_typ == 'GROUP_AT_MESSAGE_CREATE':
|
||||
if 'GroupMessage' in self.events:
|
||||
coro = self.events['GroupMessage']
|
||||
else:
|
||||
return
|
||||
message_chain = MessageChain()
|
||||
message_chain.append(PlainText(type="Plain", text=data['content']))
|
||||
group_message = Message(
|
||||
type='GroupMessage',
|
||||
user_id=data['author']['id'],
|
||||
member_openid=data['author']['member_openid'],
|
||||
message_id=data['id'],
|
||||
group_id=data['group_id'],
|
||||
group_openid=data['group_openid'],
|
||||
content=data['content'],
|
||||
# 2023-11-24T19:51:11+08:00
|
||||
time=int(datetime.datetime.strptime(data['timestamp'], "%Y-%m-%dT%H:%M:%S%z").timestamp()),
|
||||
sender=Sender(
|
||||
user_id=data['author']['id'],
|
||||
member_openid=data['author']['member_openid']
|
||||
),
|
||||
message=message_chain
|
||||
)
|
||||
await coro(self, group_message)
|
||||
|
||||
if 'group_openid' in data:
|
||||
group_openid = data['group_openid']
|
||||
message_str = data['content'].strip()
|
||||
message_id = data['id']
|
||||
# 发送消息
|
||||
requests.post(self.OPENAPI_BASE_URL + f"/v2/groups/{group_openid}/messages", headers=self.auth_header(), json={
|
||||
"content": message_str,
|
||||
"message_type": 0,
|
||||
"msg_id": message_id
|
||||
})
|
||||
async def send(self, message: Message, message_chain: MessageChain) -> None:
|
||||
# todo: 消息链转换支持更多类型。
|
||||
plain_text = ""
|
||||
for i in message_chain:
|
||||
if isinstance(i, PlainText):
|
||||
plain_text += i.text
|
||||
requests.post(self.OPENAPI_BASE_URL + f"/v2/groups/{message.group_openid}/messages", headers=self.__auth_header(), json={
|
||||
"content": plain_text,
|
||||
"message_type": 0,
|
||||
"msg_id": message.message_id
|
||||
})
|
||||
|
||||
async def ws_client(self):
|
||||
async def __ws_client(self):
|
||||
self.s = 0
|
||||
async with websockets.connect(self.wss_endpoint) as websocket:
|
||||
print("ws connected.")
|
||||
while True:
|
||||
msg = await websocket.recv()
|
||||
msg = json.loads(msg)
|
||||
if 's' in msg:
|
||||
self.s = msg['s']
|
||||
print("recv: " + str(msg))
|
||||
await self.handle_msg(websocket, msg)
|
||||
|
||||
await self.__handle_msg(websocket, msg)
|
||||
|
||||
if __name__ == "__main__":
|
||||
UnofficialQQBotSDK("", "")
|
||||
def on(self, event: str) -> None:
|
||||
def wrapper(func: Awaitable):
|
||||
if inspect.iscoroutinefunction(func) == False:
|
||||
raise TypeError("func must be a coroutine function")
|
||||
self.events[event] = func
|
||||
return wrapper
|
||||
Reference in New Issue
Block a user