diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index bd5005ee..b5f0c188 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -24,6 +24,7 @@ from astrbot.core.provider.entities import ( ) from astrbot.core.star.star_handler import EventType from ..agent_runner.tool_loop_agent import ToolLoopAgent +from astrbot.core.provider import Provider class LLMRequestSubStage(Stage): @@ -51,16 +52,27 @@ class LLMRequestSubStage(Stage): self.conv_manager = ctx.plugin_manager.context.conversation_manager + def _select_provider(self, event: AstrMessageEvent) -> Provider | None: + """选择使用的 LLM 提供商""" + sel_provider = event.get_extra("selected_provider") + _ctx = self.ctx.plugin_manager.context + if sel_provider and isinstance(sel_provider, str): + provider = _ctx.get_provider_by_id(sel_provider) + if not provider: + logger.error(f"未找到指定的提供商: {sel_provider}。") + return provider + + return _ctx.get_using_provider(umo=event.unified_msg_origin) + async def process( self, event: AstrMessageEvent, _nested: bool = False ) -> Union[None, AsyncGenerator[None, None]]: - req: ProviderRequest = None + req: ProviderRequest | None = None if not self.ctx.astrbot_config["provider_settings"]["enable"]: logger.debug("未启用 LLM 能力,跳过处理。") return - umo = event.unified_msg_origin - provider = self.ctx.plugin_manager.context.get_using_provider(umo=umo) + provider = self._select_provider(event) if provider is None: return @@ -75,6 +87,8 @@ class LLMRequestSubStage(Stage): else: req = ProviderRequest(prompt="", image_urls=[]) + if sel_model := event.get_extra("selected_model"): + req.model = sel_model if self.provider_wake_prefix: if not event.message_str.startswith(self.provider_wake_prefix): return @@ -168,7 +182,10 @@ class LLMRequestSubStage(Stage): if self.streaming_response: # 用来标记流式响应需要分节 yield MessageChain(chain=[], type="break") - if self.show_tool_use or event.get_platform_name() == "webchat": + if ( + self.show_tool_use + or event.get_platform_name() == "webchat" + ): resp.data["chain"].type = "tool_call" await event.send(resp.data["chain"]) continue diff --git a/astrbot/core/pipeline/waking_check/stage.py b/astrbot/core/pipeline/waking_check/stage.py index 82e36ca2..0354260e 100644 --- a/astrbot/core/pipeline/waking_check/stage.py +++ b/astrbot/core/pipeline/waking_check/stage.py @@ -164,7 +164,7 @@ class WakingCheckStage(Stage): "parsed_params" ) - event.clear_extra() + event._extras.pop("parsed_params", None) event.set_extra("activated_handlers", activated_handlers) event.set_extra("handlers_parsed_params", handlers_parsed_params) diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index 41d3e941..aaac8e28 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -151,6 +151,10 @@ class WebChatAdapter(Platform): session_id=message.session_id, ) + _, _, payload = message.raw_message # type: ignore + message_event.set_extra("selected_provider", payload.get("selected_provider")) + message_event.set_extra("selected_model", payload.get("selected_model")) + self.commit_event(message_event) async def terminate(self): diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index abb01960..2d120d7f 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -110,6 +110,9 @@ class ProviderRequest: tool_calls_result: list[ToolCallsResult] | ToolCallsResult | None = None """附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls""" + model: str | None = None + """模型名称,为 None 时使用提供商的默认模型""" + def __repr__(self): return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})" diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 1ecca353..98e8fab8 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -88,6 +88,7 @@ class Provider(AbstractProvider): contexts: list = None, system_prompt: str = None, tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None, + model: str | None = None, **kwargs, ) -> LLMResponse: """获得 LLM 的文本对话结果。会使用当前的模型进行对话。 @@ -116,6 +117,7 @@ class Provider(AbstractProvider): contexts: list = None, system_prompt: str = None, tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None, + model: str | None = None, **kwargs, ) -> AsyncGenerator[LLMResponse, None]: """获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。 diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index 4ea4c2e0..aaff177e 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -235,6 +235,7 @@ class ProviderAnthropic(Provider): contexts=None, system_prompt=None, tool_calls_result=None, + model=None, **kwargs, ) -> LLMResponse: if contexts is None: @@ -259,7 +260,7 @@ class ProviderAnthropic(Provider): system_prompt, new_messages = self._prepare_payload(context_query) model_config = self.provider_config.get("model_config", {}) - model_config["model"] = self.get_model() + model_config["model"] = model or self.get_model() payloads = {"messages": new_messages, **model_config} @@ -285,6 +286,7 @@ class ProviderAnthropic(Provider): contexts=..., system_prompt=None, tool_calls_result=None, + model=None, **kwargs, ): if contexts is None: @@ -309,7 +311,7 @@ class ProviderAnthropic(Provider): system_prompt, new_messages = self._prepare_payload(context_query) model_config = self.provider_config.get("model_config", {}) - model_config["model"] = self.get_model() + model_config["model"] = model or self.get_model() payloads = {"messages": new_messages, **model_config} diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py index 3498f834..46b12726 100644 --- a/astrbot/core/provider/sources/dashscope_source.py +++ b/astrbot/core/provider/sources/dashscope_source.py @@ -67,6 +67,7 @@ class ProviderDashscope(ProviderOpenAIOfficial): func_tool: FuncCall = None, contexts: List = None, system_prompt: str = None, + model=None, **kwargs, ) -> LLMResponse: if contexts is None: @@ -163,6 +164,7 @@ class ProviderDashscope(ProviderOpenAIOfficial): contexts=..., system_prompt=None, tool_calls_result=None, + model=None, **kwargs, ): # raise NotImplementedError("This method is not implemented yet.") diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py index b3a0cccc..cc3e8062 100644 --- a/astrbot/core/provider/sources/dify_source.py +++ b/astrbot/core/provider/sources/dify_source.py @@ -60,6 +60,8 @@ class ProviderDify(Provider): func_tool: FuncCall = None, contexts: List = None, system_prompt: str = None, + tool_calls_result=None, + model=None, **kwargs, ) -> LLMResponse: if image_urls is None: @@ -84,11 +86,13 @@ class ProviderDify(Provider): f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。" ) continue - files_payload.append({ - "type": "image", - "transfer_method": "local_file", - "upload_file_id": file_response["id"], - }) + files_payload.append( + { + "type": "image", + "transfer_method": "local_file", + "upload_file_id": file_response["id"], + } + ) # 获得会话变量 payload_vars = self.variables.copy() @@ -195,6 +199,7 @@ class ProviderDify(Provider): contexts=..., system_prompt=None, tool_calls_result=None, + model=None, **kwargs, ): # raise NotImplementedError("This method is not implemented yet.") diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 573fe768..56526c12 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -259,10 +259,12 @@ class ProviderGoogleGenAI(Provider): contents.append(content_cls(parts=part)) gemini_contents: list[types.Content] = [] - native_tool_enabled = any([ - self.provider_config.get("gm_native_coderunner", False), - self.provider_config.get("gm_native_search", False), - ]) + native_tool_enabled = any( + [ + self.provider_config.get("gm_native_coderunner", False), + self.provider_config.get("gm_native_search", False), + ] + ) for message in payloads["messages"]: role, content = message["role"], message.get("content") @@ -505,6 +507,7 @@ class ProviderGoogleGenAI(Provider): contexts=None, system_prompt=None, tool_calls_result=None, + model=None, **kwargs, ) -> LLMResponse: if contexts is None: @@ -527,7 +530,7 @@ class ProviderGoogleGenAI(Provider): context_query.extend(tcr.to_openai_messages()) model_config = self.provider_config.get("model_config", {}) - model_config["model"] = self.get_model() + model_config["model"] = model or self.get_model() payloads = {"messages": context_query, **model_config} @@ -551,6 +554,7 @@ class ProviderGoogleGenAI(Provider): contexts=None, system_prompt=None, tool_calls_result=None, + model=None, **kwargs, ) -> AsyncGenerator[LLMResponse, None]: if contexts is None: @@ -573,7 +577,7 @@ class ProviderGoogleGenAI(Provider): context_query.extend(tcr.to_openai_messages()) model_config = self.provider_config.get("model_config", {}) - model_config["model"] = self.get_model() + model_config["model"] = model or self.get_model() payloads = {"messages": context_query, **model_config} @@ -632,10 +636,12 @@ class ProviderGoogleGenAI(Provider): if not image_data: logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") continue - user_content["content"].append({ - "type": "image_url", - "image_url": {"url": image_data}, - }) + user_content["content"].append( + { + "type": "image_url", + "image_url": {"url": image_data}, + } + ) return user_content else: return {"role": "user", "content": text} diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 936fc2e3..36a8579a 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -222,6 +222,7 @@ class ProviderOpenAIOfficial(Provider): contexts: list | None = None, system_prompt: str | None = None, tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None, + model: str | None = None, **kwargs, ) -> tuple: """准备聊天所需的有效载荷和上下文""" @@ -245,7 +246,7 @@ class ProviderOpenAIOfficial(Provider): context_query.extend(tcr.to_openai_messages()) model_config = self.provider_config.get("model_config", {}) - model_config["model"] = self.get_model() + model_config["model"] = model or self.get_model() payloads = {"messages": context_query, **model_config} @@ -346,6 +347,7 @@ class ProviderOpenAIOfficial(Provider): contexts=None, system_prompt=None, tool_calls_result=None, + model=None, **kwargs, ) -> LLMResponse: payloads, context_query = await self._prepare_chat_payload( @@ -354,6 +356,7 @@ class ProviderOpenAIOfficial(Provider): contexts, system_prompt, tool_calls_result, + model=model, **kwargs, ) @@ -413,6 +416,7 @@ class ProviderOpenAIOfficial(Provider): contexts=[], system_prompt=None, tool_calls_result=None, + model=None, **kwargs, ) -> AsyncGenerator[LLMResponse, None]: """流式对话,与服务商交互并逐步返回结果""" @@ -422,6 +426,7 @@ class ProviderOpenAIOfficial(Provider): contexts, system_prompt, tool_calls_result, + model=model, **kwargs, ) @@ -525,10 +530,12 @@ class ProviderOpenAIOfficial(Provider): if not image_data: logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") continue - user_content["content"].append({ - "type": "image_url", - "image_url": {"url": image_data}, - }) + user_content["content"].append( + { + "type": "image_url", + "image_url": {"url": image_data}, + } + ) return user_content else: return {"role": "user", "content": text} diff --git a/astrbot/core/provider/sources/zhipu_source.py b/astrbot/core/provider/sources/zhipu_source.py index 428dee8f..cf52e95f 100644 --- a/astrbot/core/provider/sources/zhipu_source.py +++ b/astrbot/core/provider/sources/zhipu_source.py @@ -28,6 +28,7 @@ class ProviderZhipu(ProviderOpenAIOfficial): func_tool: FuncCall = None, contexts=None, system_prompt=None, + model=None, **kwargs, ) -> LLMResponse: if contexts is None: @@ -38,7 +39,7 @@ class ProviderZhipu(ProviderOpenAIOfficial): context_query = [*contexts, new_record] model_cfgs: dict = self.provider_config.get("model_config", {}) - model = self.get_model() + model = model or self.get_model() # glm-4v-flash 只支持一张图片 if model.lower() == "glm-4v-flash" and image_urls and len(context_query) > 1: logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片") diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index a273bccd..e7b086cd 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -120,6 +120,8 @@ class ChatRoute(Route): conversation_id = post_data["conversation_id"] image_url = post_data.get("image_url") audio_url = post_data.get("audio_url") + selected_provider = post_data.get("selected_provider") + selected_model = post_data.get("selected_model") if not message and not image_url and not audio_url: return ( Response() @@ -202,6 +204,8 @@ class ChatRoute(Route): "message": message, "image_url": image_url, # list "audio_url": audio_url, + "selected_provider": selected_provider, + "selected_model": selected_model, }, ) ) diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index c225c762..b55f0b21 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -9,6 +9,7 @@ from astrbot.core.platform.register import platform_registry from astrbot.core.provider.register import provider_registry from astrbot.core.star.star import star_registry from astrbot.core import logger +from astrbot.core.provider import Provider import asyncio @@ -168,6 +169,7 @@ class ConfigRoute(Route): "/config/llmtools": ("GET", self.get_llm_tools), "/config/provider/check_status": ("GET", self.check_all_providers_status), "/config/provider/list": ("GET", self.get_provider_config_list), + "/config/provider/model_list": ("GET", self.get_provider_model_list), "/config/provider/get_session_seperate": ( "GET", lambda: Response() @@ -319,6 +321,28 @@ class ConfigRoute(Route): provider_list.append(provider) return Response().ok(provider_list).__dict__ + async def get_provider_model_list(self): + """获取指定提供商的模型列表""" + provider_id = request.args.get("provider_id", None) + if not provider_id: + return Response().error("缺少参数 provider_id").__dict__ + + prov_mgr = self.core_lifecycle.provider_manager + provider: Provider | None = prov_mgr.inst_map.get(provider_id, None) + if not provider: + return Response().error(f"未找到 ID 为 {provider_id} 的提供商").__dict__ + + try: + models = await provider.get_models() + ret = { + "models": models, + "provider_id": provider_id, + } + return Response().ok(ret).__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(str(e)).__dict__ + async def post_astrbot_configs(self): post_configs = await request.json try: diff --git a/astrbot/dashboard/routes/multi_user_chat.py b/astrbot/dashboard/routes/multi_user_chat.py deleted file mode 100644 index e69de29b..00000000 diff --git a/dashboard/src/components/chat/ProviderModelSelector.vue b/dashboard/src/components/chat/ProviderModelSelector.vue new file mode 100644 index 00000000..7509b529 --- /dev/null +++ b/dashboard/src/components/chat/ProviderModelSelector.vue @@ -0,0 +1,353 @@ + + + + + diff --git a/dashboard/src/views/ChatPage.vue b/dashboard/src/views/ChatPage.vue index 8c68f083..4995babf 100644 --- a/dashboard/src/views/ChatPage.vue +++ b/dashboard/src/views/ChatPage.vue @@ -50,10 +50,12 @@ @@ -100,8 +102,8 @@ @@ -188,14 +190,23 @@ -
- - +
+
+ + +
+
+ + +
+
@@ -246,12 +257,13 @@ import { ref } from 'vue'; import { useCustomizerStore } from '@/stores/customizer'; import { useI18n, useModuleI18n } from '@/i18n/composables'; import LanguageSwitcher from '@/components/shared/LanguageSwitcher.vue'; +import ProviderModelSelector from '@/components/chat/ProviderModelSelector.vue'; import hljs from 'highlight.js'; import 'highlight.js/styles/github.css'; marked.setOptions({ breaks: true, - highlight: function(code, lang) { + highlight: function (code, lang) { if (lang && hljs.getLanguage(lang)) { try { return hljs.highlight(code, { language: lang }).value; @@ -266,7 +278,8 @@ marked.setOptions({ export default { name: 'ChatPage', components: { - LanguageSwitcher + LanguageSwitcher, + ProviderModelSelector }, props: { chatboxMode: { @@ -790,6 +803,11 @@ export default { this.loadingChat = true + // 从ProviderModelSelector组件获取当前选择 + const selection = this.$refs.providerModelSelector?.getCurrentSelection(); + const selectedProviderId = selection?.providerId || ''; + const selectedModelName = selection?.modelName || ''; + try { const response = await fetch('/api/chat/send', { method: 'POST', @@ -801,7 +819,9 @@ export default { message: this.prompt.trim(), // 确保发送的消息已去除前后空格 conversation_id: this.currCid, image_url: this.stagedImagesName, - audio_url: this.stagedAudioUrl ? [this.stagedAudioUrl] : [] + audio_url: this.stagedAudioUrl ? [this.stagedAudioUrl] : [], + selected_provider: selectedProviderId, + selected_model: selectedModelName }) }); diff --git a/uv.lock b/uv.lock index 6279381b..7a245997 100644 --- a/uv.lock +++ b/uv.lock @@ -204,7 +204,7 @@ wheels = [ [[package]] name = "astrbot" -version = "3.5.17" +version = "3.5.18" source = { editable = "." } dependencies = [ { name = "aiocqhttp" },