122 lines
4.5 KiB
Python
122 lines
4.5 KiB
Python
import json
from typing import Any, AsyncGenerator, Dict, List, Optional

from aiohttp import ClientSession

from astrbot.core import logger
|
||
|
||
|
||
class DifyAPIClient:
    """Small async client for the Dify HTTP API.

    All requests go through one ``aiohttp.ClientSession`` created in
    ``__init__``; call :meth:`close` to release it.
    """

    def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"):
        """Create the client.

        Args:
            api_key: Sent as ``Authorization: Bearer <api_key>`` on every request.
            api_base: Base URL; endpoint paths such as ``/chat-messages`` are
                appended to it.
        """
        self.api_key = api_key
        self.api_base = api_base
        # NOTE(review): the session is created eagerly; aiohttp prefers sessions
        # created inside a running event loop -- confirm callers build this
        # client from async code.
        self.session = ClientSession()
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
        }

    async def chat_messages(
        self,
        inputs: Dict,
        query: str,
        user: str,
        response_mode: str = "streaming",
        conversation_id: str = "",
        files: Optional[List[Dict[str, Any]]] = None,
        timeout: float = 60,
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """POST to ``/chat-messages`` and yield each decoded ``data:`` event.

        Args:
            inputs: Sent as the ``inputs`` field.
            query: Sent as the ``query`` field.
            user: Sent as the ``user`` field.
            response_mode: Sent as-is. The response body is always parsed as a
                ``data:`` event stream, so a non-streaming body presumably
                yields nothing -- NOTE(review): confirm only "streaming" is used.
            conversation_id: Sent as-is; defaults to an empty string.
            files: Sent as the ``files`` field; ``None`` means an empty list.
            timeout: Passed to aiohttp as the request timeout in seconds.
                aiohttp treats a plain number as the *total* timeout, which
                also bounds how long the streamed body may take.

        Yields:
            One dict per ``data:`` event that contains valid JSON.

        Raises:
            Exception: If the HTTP status is not 200.
        """
        url = f"{self.api_base}/chat-messages"
        # Build the body explicitly so only API fields are sent (a locals()
        # snapshot would also include unrelated locals such as ``url``).
        payload = {
            "inputs": inputs,
            "query": query,
            "user": user,
            "response_mode": response_mode,
            "conversation_id": conversation_id,
            "files": files if files is not None else [],
        }
        logger.info(f"chat_messages payload: {payload}")
        async with self.session.post(
            url, json=payload, headers=self.headers, timeout=timeout
        ) as resp:
            if resp.status != 200:
                text = await resp.text()
                raise Exception(f"chat_messages 请求失败:{resp.status}. {text}")
            async for event in self._iter_sse_events(resp):
                yield event

    async def workflow_run(
        self,
        inputs: Dict,
        user: str,
        response_mode: str = "streaming",
        files: Optional[List[Dict[str, Any]]] = None,
        timeout: float = 60,
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """POST to ``/workflows/run`` and yield each decoded ``data:`` event.

        Args:
            inputs: Sent as the ``inputs`` field.
            user: Sent as the ``user`` field.
            response_mode: Sent as-is; the body is always parsed as a
                ``data:`` event stream (see ``chat_messages``).
            files: Sent as the ``files`` field; ``None`` means an empty list.
            timeout: Passed to aiohttp as the (total) request timeout in seconds.

        Yields:
            One dict per ``data:`` event that contains valid JSON.

        Raises:
            Exception: If the HTTP status is not 200.
        """
        url = f"{self.api_base}/workflows/run"
        payload = {
            "inputs": inputs,
            "user": user,
            "response_mode": response_mode,
            "files": files if files is not None else [],
        }
        logger.info(f"workflow_run payload: {payload}")
        async with self.session.post(
            url, json=payload, headers=self.headers, timeout=timeout
        ) as resp:
            if resp.status != 200:
                text = await resp.text()
                raise Exception(f"workflow_run 请求失败:{resp.status}. {text}")
            async for event in self._iter_sse_events(resp):
                yield event

    async def _iter_sse_events(self, resp) -> AsyncGenerator[Dict[str, Any], None]:
        """Parse a response body made of blank-line-separated ``data:`` blocks.

        Blocks not starting with ``data:`` are skipped. Blocks whose payload is
        not valid JSON are logged and skipped.
        """
        # Buffer raw bytes and split before decoding: a fixed-size chunk can end
        # in the middle of a multi-byte UTF-8 character, but b"\n\n" can never
        # occur inside one, so every complete block decodes cleanly.
        buffer = b""
        eof = False
        while not eof:
            # Keep the original 8192-byte read limit to prevent oversized reads
            # from triggering high-water-mark errors.
            chunk = await resp.content.read(8192)
            if chunk:
                buffer += chunk
                # The last piece may be an incomplete block; keep it buffered.
                *blocks, buffer = buffer.split(b"\n\n")
            else:
                # End of stream: flush the remainder so a final event without a
                # trailing blank line is not lost.
                eof = True
                blocks, buffer = [buffer], b""

            for block in blocks:
                if not block.startswith(b"data:"):
                    continue
                # Strip the "data:" prefix. json.loads tolerates the leading
                # space. "replace" keeps one corrupt byte from ending the stream.
                json_str = block[5:].decode("utf-8", errors="replace")
                try:
                    event = json.loads(json_str)
                except json.JSONDecodeError as e:
                    logger.error(f"JSON解析错误: {str(e)}")
                    logger.error(f"原始数据块: {json_str}")
                    continue
                yield event

    async def file_upload(
        self,
        file_path: str,
        user: str,
    ) -> Dict[str, Any]:
        """Upload a local file to ``/files/upload`` as form data.

        Args:
            file_path: Path of the file to upload; opened in binary mode.
            user: Sent as the ``user`` form field.

        Returns:
            The decoded JSON response body. It is returned whatever the HTTP
            status, so callers must inspect it (e.g. for an ``id`` key).
        """
        url = f"{self.api_base}/files/upload"
        # Keep the file open only for the duration of the request so the
        # handle is always closed, even if the request fails.
        with open(file_path, "rb") as f:
            payload = {
                "user": user,
                "file": f,
            }
            async with self.session.post(
                url, data=payload, headers=self.headers
            ) as resp:
                return await resp.json()  # {"id": "xxx", ...}

    async def close(self):
        """Close the shared aiohttp session."""
        await self.session.close()