59 lines
1.6 KiB
Python
59 lines
1.6 KiB
Python
from collections import defaultdict
from typing import Any, DefaultDict, Optional
|
||
|
||
class Provider:
    """Base class for LLM providers.

    Keeps per-model usage counters and the name of the model currently in
    use. The chat, image-generation and session-reset coroutines raise
    ``NotImplementedError`` here and are meant to be overridden by concrete
    providers.
    """

    def __init__(self) -> None:
        # Records LLM model usage: model name -> number of times
        # accu_model_stat() counted it. defaultdict(int) lets unseen models
        # start at 0 without a membership check.
        self.model_stat: DefaultDict[str, int] = defaultdict(int)
        # Name of the model currently in use; "unknown" until set_curr_model().
        self.curr_model_name: str = "unknown"

    def reset_model_stat(self) -> None:
        """Clear all recorded model usage counts."""
        self.model_stat.clear()

    def set_curr_model(self, model_name: str) -> None:
        """Set the name of the LLM currently in use."""
        self.curr_model_name = model_name

    def get_curr_model(self) -> str:
        """Return the name of the LLM currently in use."""
        return self.curr_model_name

    def accu_model_stat(self, model: Optional[str] = None) -> None:
        """Increment the usage count for ``model`` by one.

        Args:
            model: Model name to count. Any falsy value (``None`` or an empty
                string) falls back to the current model from
                ``get_curr_model()``.
        """
        # Deliberately a truthiness test, not ``is None``: an empty string
        # also falls back to the current model.
        if not model:
            model = self.get_curr_model()
        self.model_stat[model] += 1

    async def text_chat(self,
                        prompt: str,
                        session_id: str,
                        image_url: Optional[str] = None,
                        tools: Any = None,
                        extra_conf: Optional[dict] = None,
                        default_personality: Optional[dict] = None,
                        **kwargs: Any) -> str:
        """Text chat entry point; must be implemented by subclasses.

        Required:
            prompt: The prompt text.
            session_id: Session id.

        Optional:
            image_url: Image URL (for image recognition).
            tools: Function-calling tools. The expected structure is defined
                by the concrete provider.
            extra_conf: Extra configuration.
            default_personality: Default personality.

        Returns:
            str, as produced by the concrete provider.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    async def image_generate(self, prompt: str, session_id: str,
                             **kwargs: Any) -> str:
        """Image generation entry point; must be implemented by subclasses.

        Required:
            prompt: The prompt text.
            session_id: Session id.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    async def forget(self, session_id: Optional[str] = None) -> bool:
        """Reset the conversation session; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
|