chore: clean code
This commit is contained in:
@@ -1,48 +0,0 @@
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import quart
|
||||
import base64
|
||||
|
||||
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
||||
from astrbot.api.message_components import Plain, Image, At, Record
|
||||
from astrbot.api import logger, sp
|
||||
from .downloader import GeweDownloader
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
|
||||
class WeComClient():
|
||||
def __init__(self, config: dict):
|
||||
self.base_url = base_url
|
||||
if self.base_url.endswith('/'):
|
||||
self.base_url = self.base_url[:-1]
|
||||
|
||||
self.download_base_url = self.base_url.split(':')[:-1] # 去掉端口
|
||||
self.download_base_url = ':'.join(self.download_base_url) + ":2532/download/"
|
||||
|
||||
self.base_url += "/v2/api"
|
||||
|
||||
logger.info(f"wecom API: {self.base_url}")
|
||||
logger.info(f"Gewechat 下载 API: {self.download_base_url}")
|
||||
|
||||
if isinstance(port, str):
|
||||
port = int(port)
|
||||
|
||||
self.token = None
|
||||
self.headers = {}
|
||||
self.nickname = nickname
|
||||
self.appid = sp.get(f"gewechat-appid-{nickname}", "")
|
||||
|
||||
self.server = quart.Quart(__name__)
|
||||
self.server.add_url_rule('/astrbot-gewechat/callback', view_func=self.callback, methods=['POST'])
|
||||
self.server.add_url_rule('/astrbot-gewechat/file/<file_id>', view_func=self.handle_file, methods=['GET'])
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback"
|
||||
self.file_server_url = f"http://{self.host}:{self.port}/astrbot-gewechat/file"
|
||||
|
||||
self.event_queue = event_queue
|
||||
|
||||
self.multimedia_downloader = None
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
import time
|
||||
import asyncio
|
||||
import uuid
|
||||
import os
|
||||
from typing import Awaitable, Any
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.message_components import Plain, Image, Record # noqa: F403
|
||||
from astrbot.api import logger
|
||||
from astrbot.core import web_chat_queue, web_chat_back_queue
|
||||
from .webchat_event import WebChatMessageEvent
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
|
||||
|
||||
@register_platform_adapter("wecom", "wecom")
|
||||
class WecomAdapter(Platform):
|
||||
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings['unique_session']
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
"wecom",
|
||||
"wecom",
|
||||
)
|
||||
|
||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
||||
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
async def convert_message(self, data: tuple) -> AstrBotMessage:
|
||||
username, cid, payload = data
|
||||
|
||||
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = "webchat"
|
||||
abm.tag = "webchat"
|
||||
abm.sender = MessageMember(username, username)
|
||||
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
|
||||
abm.session_id = f"webchat!{username}!{cid}"
|
||||
|
||||
abm.message_id = str(uuid.uuid4())
|
||||
abm.message = []
|
||||
|
||||
if payload['message']:
|
||||
abm.message.append(Plain(payload['message']))
|
||||
if payload['image_url']:
|
||||
if isinstance(payload['image_url'], list):
|
||||
for img in payload['image_url']:
|
||||
abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, img)))
|
||||
else:
|
||||
abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, payload['image_url'])))
|
||||
if payload['audio_url']:
|
||||
if isinstance(payload['audio_url'], list):
|
||||
for audio in payload['audio_url']:
|
||||
path = os.path.join(self.imgs_dir, audio)
|
||||
abm.message.append(Record(file=path, path=path))
|
||||
else:
|
||||
path = os.path.join(self.imgs_dir, payload['audio_url'])
|
||||
abm.message.append(Record(file=path, path=path))
|
||||
|
||||
logger.debug(f"WebChatAdapter: {abm.message}")
|
||||
|
||||
message_str = payload['message']
|
||||
abm.timestamp = int(time.time())
|
||||
abm.message_str = message_str
|
||||
abm.raw_message = data
|
||||
return abm
|
||||
|
||||
def run(self) -> Awaitable[Any]:
|
||||
pass
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return self.metadata
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
|
||||
message_event = WebChatMessageEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id
|
||||
)
|
||||
|
||||
self.commit_event(message_event)
|
||||
@@ -1,41 +0,0 @@
|
||||
import os
|
||||
import uuid
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
||||
from astrbot.core import web_chat_back_queue
|
||||
|
||||
class WebChatMessageEvent(AstrMessageEvent):
|
||||
def __init__(self, message_str, message_obj, platform_meta, session_id):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.imgs_dir = "data/webchat/imgs"
|
||||
os.makedirs(self.imgs_dir, exist_ok=True)
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
if not message:
|
||||
web_chat_back_queue.put_nowait(None)
|
||||
return
|
||||
|
||||
cid = self.session_id.split("!")[-1]
|
||||
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
web_chat_back_queue.put_nowait((comp.text, cid))
|
||||
elif isinstance(comp, Image):
|
||||
# save image to local
|
||||
filename = str(uuid.uuid4()) + ".jpg"
|
||||
path = os.path.join(self.imgs_dir, filename)
|
||||
if comp.file and comp.file.startswith("file:///"):
|
||||
ph = comp.file[8:]
|
||||
with open(path, "wb") as f:
|
||||
with open(ph, "rb") as f2:
|
||||
f.write(f2.read())
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
await download_image_by_url(comp.file, path=path)
|
||||
else:
|
||||
with open(path, "wb") as f:
|
||||
with open(comp.file, "rb") as f2:
|
||||
f.write(f2.read())
|
||||
web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid))
|
||||
web_chat_back_queue.put_nowait(None)
|
||||
await super().send(message)
|
||||
@@ -191,7 +191,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
else:
|
||||
logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
|
||||
raise e
|
||||
|
||||
|
||||
|
||||
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
|
||||
if llm_response.role == "assistant" and session_id:
|
||||
|
||||
@@ -109,7 +109,7 @@ async def download_file(url: str, path: str, show_progress: bool = False):
|
||||
'''
|
||||
try:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(url, timeout=120) as resp:
|
||||
async with session.get(url, timeout=1800) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"下载文件失败: {resp.status}")
|
||||
total_size = int(resp.headers.get('content-length', 0))
|
||||
|
||||
Reference in New Issue
Block a user