Files
AstrBot/model/provider/provider.py
2024-05-25 17:47:41 +08:00

59 lines
1.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from collections import defaultdict
class Provider:
def __init__(self) -> None:
self.model_stat = defaultdict(int) # 用于记录 LLM Model 使用数据
self.curr_model_name = "unknown"
def reset_model_stat(self):
self.model_stat.clear()
def set_curr_model(self, model_name: str):
self.curr_model_name = model_name
def get_curr_model(self):
'''
返回当前正在使用的 LLM
'''
return self.curr_model_name
def accu_model_stat(self, model: str = None):
if not model:
model = self.get_curr_model()
self.model_stat[model] += 1
async def text_chat(self,
prompt: str,
session_id: str,
image_url: None = None,
tools: None = None,
extra_conf: dict = None,
default_personality: dict = None,
**kwargs) -> str:
'''
[require]
prompt: 提示词
session_id: 会话id
[optional]
image_url: 图片url识图
tools: 函数调用工具
extra_conf: 额外配置
default_personality: 默认人格
'''
raise NotImplementedError()
async def image_generate(self, prompt, session_id, **kwargs) -> str:
'''
[require]
prompt: 提示词
session_id: 会话id
'''
raise NotImplementedError()
async def forget(self, session_id=None) -> bool:
'''
重置会话
'''
raise NotImplementedError()