diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py
index bd5005ee..b5f0c188 100644
--- a/astrbot/core/pipeline/process_stage/method/llm_request.py
+++ b/astrbot/core/pipeline/process_stage/method/llm_request.py
@@ -24,6 +24,7 @@ from astrbot.core.provider.entities import (
)
from astrbot.core.star.star_handler import EventType
from ..agent_runner.tool_loop_agent import ToolLoopAgent
+from astrbot.core.provider import Provider
class LLMRequestSubStage(Stage):
@@ -51,16 +52,27 @@ class LLMRequestSubStage(Stage):
self.conv_manager = ctx.plugin_manager.context.conversation_manager
+ def _select_provider(self, event: AstrMessageEvent) -> Provider | None:
+ """选择使用的 LLM 提供商"""
+ sel_provider = event.get_extra("selected_provider")
+ _ctx = self.ctx.plugin_manager.context
+ if sel_provider and isinstance(sel_provider, str):
+ provider = _ctx.get_provider_by_id(sel_provider)
+ if not provider:
+ logger.error(f"未找到指定的提供商: {sel_provider}。")
+ return provider
+
+ return _ctx.get_using_provider(umo=event.unified_msg_origin)
+
async def process(
self, event: AstrMessageEvent, _nested: bool = False
) -> Union[None, AsyncGenerator[None, None]]:
- req: ProviderRequest = None
+ req: ProviderRequest | None = None
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
logger.debug("未启用 LLM 能力,跳过处理。")
return
- umo = event.unified_msg_origin
- provider = self.ctx.plugin_manager.context.get_using_provider(umo=umo)
+ provider = self._select_provider(event)
if provider is None:
return
@@ -75,6 +87,8 @@ class LLMRequestSubStage(Stage):
else:
req = ProviderRequest(prompt="", image_urls=[])
+ if sel_model := event.get_extra("selected_model"):
+ req.model = sel_model
if self.provider_wake_prefix:
if not event.message_str.startswith(self.provider_wake_prefix):
return
@@ -168,7 +182,10 @@ class LLMRequestSubStage(Stage):
if self.streaming_response:
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
- if self.show_tool_use or event.get_platform_name() == "webchat":
+ if (
+ self.show_tool_use
+ or event.get_platform_name() == "webchat"
+ ):
resp.data["chain"].type = "tool_call"
await event.send(resp.data["chain"])
continue
diff --git a/astrbot/core/pipeline/waking_check/stage.py b/astrbot/core/pipeline/waking_check/stage.py
index 82e36ca2..0354260e 100644
--- a/astrbot/core/pipeline/waking_check/stage.py
+++ b/astrbot/core/pipeline/waking_check/stage.py
@@ -164,7 +164,7 @@ class WakingCheckStage(Stage):
"parsed_params"
)
- event.clear_extra()
+ event._extras.pop("parsed_params", None)
event.set_extra("activated_handlers", activated_handlers)
event.set_extra("handlers_parsed_params", handlers_parsed_params)
diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py
index 41d3e941..aaac8e28 100644
--- a/astrbot/core/platform/sources/webchat/webchat_adapter.py
+++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py
@@ -151,6 +151,10 @@ class WebChatAdapter(Platform):
session_id=message.session_id,
)
+ _, _, payload = message.raw_message # type: ignore
+ message_event.set_extra("selected_provider", payload.get("selected_provider"))
+ message_event.set_extra("selected_model", payload.get("selected_model"))
+
self.commit_event(message_event)
async def terminate(self):
diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py
index abb01960..2d120d7f 100644
--- a/astrbot/core/provider/entities.py
+++ b/astrbot/core/provider/entities.py
@@ -110,6 +110,9 @@ class ProviderRequest:
tool_calls_result: list[ToolCallsResult] | ToolCallsResult | None = None
"""附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls"""
+ model: str | None = None
+ """模型名称,为 None 时使用提供商的默认模型"""
+
def __repr__(self):
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py
index 1ecca353..98e8fab8 100644
--- a/astrbot/core/provider/provider.py
+++ b/astrbot/core/provider/provider.py
@@ -88,6 +88,7 @@ class Provider(AbstractProvider):
contexts: list = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
+ model: str | None = None,
**kwargs,
) -> LLMResponse:
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
@@ -116,6 +117,7 @@ class Provider(AbstractProvider):
contexts: list = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
+ model: str | None = None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py
index 4ea4c2e0..aaff177e 100644
--- a/astrbot/core/provider/sources/anthropic_source.py
+++ b/astrbot/core/provider/sources/anthropic_source.py
@@ -235,6 +235,7 @@ class ProviderAnthropic(Provider):
contexts=None,
system_prompt=None,
tool_calls_result=None,
+ model=None,
**kwargs,
) -> LLMResponse:
if contexts is None:
@@ -259,7 +260,7 @@ class ProviderAnthropic(Provider):
system_prompt, new_messages = self._prepare_payload(context_query)
model_config = self.provider_config.get("model_config", {})
- model_config["model"] = self.get_model()
+ model_config["model"] = model or self.get_model()
payloads = {"messages": new_messages, **model_config}
@@ -285,6 +286,7 @@ class ProviderAnthropic(Provider):
contexts=...,
system_prompt=None,
tool_calls_result=None,
+ model=None,
**kwargs,
):
if contexts is None:
@@ -309,7 +311,7 @@ class ProviderAnthropic(Provider):
system_prompt, new_messages = self._prepare_payload(context_query)
model_config = self.provider_config.get("model_config", {})
- model_config["model"] = self.get_model()
+ model_config["model"] = model or self.get_model()
payloads = {"messages": new_messages, **model_config}
diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py
index 3498f834..46b12726 100644
--- a/astrbot/core/provider/sources/dashscope_source.py
+++ b/astrbot/core/provider/sources/dashscope_source.py
@@ -67,6 +67,7 @@ class ProviderDashscope(ProviderOpenAIOfficial):
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
+ model=None,
**kwargs,
) -> LLMResponse:
if contexts is None:
@@ -163,6 +164,7 @@ class ProviderDashscope(ProviderOpenAIOfficial):
contexts=...,
system_prompt=None,
tool_calls_result=None,
+ model=None,
**kwargs,
):
# raise NotImplementedError("This method is not implemented yet.")
diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py
index b3a0cccc..cc3e8062 100644
--- a/astrbot/core/provider/sources/dify_source.py
+++ b/astrbot/core/provider/sources/dify_source.py
@@ -60,6 +60,8 @@ class ProviderDify(Provider):
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
+ tool_calls_result=None,
+ model=None,
**kwargs,
) -> LLMResponse:
if image_urls is None:
@@ -84,11 +86,13 @@ class ProviderDify(Provider):
f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。"
)
continue
- files_payload.append({
- "type": "image",
- "transfer_method": "local_file",
- "upload_file_id": file_response["id"],
- })
+ files_payload.append(
+ {
+ "type": "image",
+ "transfer_method": "local_file",
+ "upload_file_id": file_response["id"],
+ }
+ )
# 获得会话变量
payload_vars = self.variables.copy()
@@ -195,6 +199,7 @@ class ProviderDify(Provider):
contexts=...,
system_prompt=None,
tool_calls_result=None,
+ model=None,
**kwargs,
):
# raise NotImplementedError("This method is not implemented yet.")
diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py
index 573fe768..56526c12 100644
--- a/astrbot/core/provider/sources/gemini_source.py
+++ b/astrbot/core/provider/sources/gemini_source.py
@@ -259,10 +259,12 @@ class ProviderGoogleGenAI(Provider):
contents.append(content_cls(parts=part))
gemini_contents: list[types.Content] = []
- native_tool_enabled = any([
- self.provider_config.get("gm_native_coderunner", False),
- self.provider_config.get("gm_native_search", False),
- ])
+ native_tool_enabled = any(
+ [
+ self.provider_config.get("gm_native_coderunner", False),
+ self.provider_config.get("gm_native_search", False),
+ ]
+ )
for message in payloads["messages"]:
role, content = message["role"], message.get("content")
@@ -505,6 +507,7 @@ class ProviderGoogleGenAI(Provider):
contexts=None,
system_prompt=None,
tool_calls_result=None,
+ model=None,
**kwargs,
) -> LLMResponse:
if contexts is None:
@@ -527,7 +530,7 @@ class ProviderGoogleGenAI(Provider):
context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
- model_config["model"] = self.get_model()
+ model_config["model"] = model or self.get_model()
payloads = {"messages": context_query, **model_config}
@@ -551,6 +554,7 @@ class ProviderGoogleGenAI(Provider):
contexts=None,
system_prompt=None,
tool_calls_result=None,
+ model=None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
if contexts is None:
@@ -573,7 +577,7 @@ class ProviderGoogleGenAI(Provider):
context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
- model_config["model"] = self.get_model()
+ model_config["model"] = model or self.get_model()
payloads = {"messages": context_query, **model_config}
@@ -632,10 +636,12 @@ class ProviderGoogleGenAI(Provider):
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
- user_content["content"].append({
- "type": "image_url",
- "image_url": {"url": image_data},
- })
+ user_content["content"].append(
+ {
+ "type": "image_url",
+ "image_url": {"url": image_data},
+ }
+ )
return user_content
else:
return {"role": "user", "content": text}
diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py
index 936fc2e3..36a8579a 100644
--- a/astrbot/core/provider/sources/openai_source.py
+++ b/astrbot/core/provider/sources/openai_source.py
@@ -222,6 +222,7 @@ class ProviderOpenAIOfficial(Provider):
contexts: list | None = None,
system_prompt: str | None = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
+ model: str | None = None,
**kwargs,
) -> tuple:
"""准备聊天所需的有效载荷和上下文"""
@@ -245,7 +246,7 @@ class ProviderOpenAIOfficial(Provider):
context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
- model_config["model"] = self.get_model()
+ model_config["model"] = model or self.get_model()
payloads = {"messages": context_query, **model_config}
@@ -346,6 +347,7 @@ class ProviderOpenAIOfficial(Provider):
contexts=None,
system_prompt=None,
tool_calls_result=None,
+ model=None,
**kwargs,
) -> LLMResponse:
payloads, context_query = await self._prepare_chat_payload(
@@ -354,6 +356,7 @@ class ProviderOpenAIOfficial(Provider):
contexts,
system_prompt,
tool_calls_result,
+ model=model,
**kwargs,
)
@@ -413,6 +416,7 @@ class ProviderOpenAIOfficial(Provider):
contexts=[],
system_prompt=None,
tool_calls_result=None,
+ model=None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""流式对话,与服务商交互并逐步返回结果"""
@@ -422,6 +426,7 @@ class ProviderOpenAIOfficial(Provider):
contexts,
system_prompt,
tool_calls_result,
+ model=model,
**kwargs,
)
@@ -525,10 +530,12 @@ class ProviderOpenAIOfficial(Provider):
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
- user_content["content"].append({
- "type": "image_url",
- "image_url": {"url": image_data},
- })
+ user_content["content"].append(
+ {
+ "type": "image_url",
+ "image_url": {"url": image_data},
+ }
+ )
return user_content
else:
return {"role": "user", "content": text}
diff --git a/astrbot/core/provider/sources/zhipu_source.py b/astrbot/core/provider/sources/zhipu_source.py
index 428dee8f..cf52e95f 100644
--- a/astrbot/core/provider/sources/zhipu_source.py
+++ b/astrbot/core/provider/sources/zhipu_source.py
@@ -28,6 +28,7 @@ class ProviderZhipu(ProviderOpenAIOfficial):
func_tool: FuncCall = None,
contexts=None,
system_prompt=None,
+ model=None,
**kwargs,
) -> LLMResponse:
if contexts is None:
@@ -38,7 +39,7 @@ class ProviderZhipu(ProviderOpenAIOfficial):
context_query = [*contexts, new_record]
model_cfgs: dict = self.provider_config.get("model_config", {})
- model = self.get_model()
+ model = model or self.get_model()
# glm-4v-flash 只支持一张图片
if model.lower() == "glm-4v-flash" and image_urls and len(context_query) > 1:
logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片")
diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py
index a273bccd..e7b086cd 100644
--- a/astrbot/dashboard/routes/chat.py
+++ b/astrbot/dashboard/routes/chat.py
@@ -120,6 +120,8 @@ class ChatRoute(Route):
conversation_id = post_data["conversation_id"]
image_url = post_data.get("image_url")
audio_url = post_data.get("audio_url")
+ selected_provider = post_data.get("selected_provider")
+ selected_model = post_data.get("selected_model")
if not message and not image_url and not audio_url:
return (
Response()
@@ -202,6 +204,8 @@ class ChatRoute(Route):
"message": message,
"image_url": image_url, # list
"audio_url": audio_url,
+ "selected_provider": selected_provider,
+ "selected_model": selected_model,
},
)
)
diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py
index c225c762..b55f0b21 100644
--- a/astrbot/dashboard/routes/config.py
+++ b/astrbot/dashboard/routes/config.py
@@ -9,6 +9,7 @@ from astrbot.core.platform.register import platform_registry
from astrbot.core.provider.register import provider_registry
from astrbot.core.star.star import star_registry
from astrbot.core import logger
+from astrbot.core.provider import Provider
import asyncio
@@ -168,6 +169,7 @@ class ConfigRoute(Route):
"/config/llmtools": ("GET", self.get_llm_tools),
"/config/provider/check_status": ("GET", self.check_all_providers_status),
"/config/provider/list": ("GET", self.get_provider_config_list),
+ "/config/provider/model_list": ("GET", self.get_provider_model_list),
"/config/provider/get_session_seperate": (
"GET",
lambda: Response()
@@ -319,6 +321,28 @@ class ConfigRoute(Route):
provider_list.append(provider)
return Response().ok(provider_list).__dict__
+ async def get_provider_model_list(self):
+ """获取指定提供商的模型列表"""
+ provider_id = request.args.get("provider_id", None)
+ if not provider_id:
+ return Response().error("缺少参数 provider_id").__dict__
+
+ prov_mgr = self.core_lifecycle.provider_manager
+ provider: Provider | None = prov_mgr.inst_map.get(provider_id, None)
+ if not provider:
+ return Response().error(f"未找到 ID 为 {provider_id} 的提供商").__dict__
+
+ try:
+ models = await provider.get_models()
+ ret = {
+ "models": models,
+ "provider_id": provider_id,
+ }
+ return Response().ok(ret).__dict__
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return Response().error(str(e)).__dict__
+
async def post_astrbot_configs(self):
post_configs = await request.json
try:
diff --git a/astrbot/dashboard/routes/multi_user_chat.py b/astrbot/dashboard/routes/multi_user_chat.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/dashboard/src/components/chat/ProviderModelSelector.vue b/dashboard/src/components/chat/ProviderModelSelector.vue
new file mode 100644
index 00000000..7509b529
--- /dev/null
+++ b/dashboard/src/components/chat/ProviderModelSelector.vue
@@ -0,0 +1,353 @@
+
+
+
+
+ {{ selectedProviderId }} / {{ selectedModelName }}
+
+
+ 选择模型
+
+
+
+
+
+
+ 选择提供商和模型
+
+
+
+
+
+
+
+
+ {{ provider.id }}
+ {{ provider.api_base }}
+
+
+
+
+
+
+
+
+
+
+ {{ model }}
+ {{ model.description }}
+
+
+
+
+
+
+
+
+
+ 取消
+
+ 确认选择
+
+
+
+
+
+
+
+
+
+
diff --git a/dashboard/src/views/ChatPage.vue b/dashboard/src/views/ChatPage.vue
index 8c68f083..4995babf 100644
--- a/dashboard/src/views/ChatPage.vue
+++ b/dashboard/src/views/ChatPage.vue
@@ -50,10 +50,12 @@
-
-
+
@@ -100,8 +102,8 @@
-
+
{{ isDark ? 'mdi-weather-night' : 'mdi-white-balance-sunny' }}
@@ -188,14 +190,23 @@
-
@@ -246,12 +257,13 @@ import { ref } from 'vue';
import { useCustomizerStore } from '@/stores/customizer';
import { useI18n, useModuleI18n } from '@/i18n/composables';
import LanguageSwitcher from '@/components/shared/LanguageSwitcher.vue';
+import ProviderModelSelector from '@/components/chat/ProviderModelSelector.vue';
import hljs from 'highlight.js';
import 'highlight.js/styles/github.css';
marked.setOptions({
breaks: true,
- highlight: function(code, lang) {
+ highlight: function (code, lang) {
if (lang && hljs.getLanguage(lang)) {
try {
return hljs.highlight(code, { language: lang }).value;
@@ -266,7 +278,8 @@ marked.setOptions({
export default {
name: 'ChatPage',
components: {
- LanguageSwitcher
+ LanguageSwitcher,
+ ProviderModelSelector
},
props: {
chatboxMode: {
@@ -790,6 +803,11 @@ export default {
this.loadingChat = true
+ // 从ProviderModelSelector组件获取当前选择
+ const selection = this.$refs.providerModelSelector?.getCurrentSelection();
+ const selectedProviderId = selection?.providerId || '';
+ const selectedModelName = selection?.modelName || '';
+
try {
const response = await fetch('/api/chat/send', {
method: 'POST',
@@ -801,7 +819,9 @@ export default {
message: this.prompt.trim(), // 确保发送的消息已去除前后空格
conversation_id: this.currCid,
image_url: this.stagedImagesName,
- audio_url: this.stagedAudioUrl ? [this.stagedAudioUrl] : []
+ audio_url: this.stagedAudioUrl ? [this.stagedAudioUrl] : [],
+ selected_provider: selectedProviderId,
+ selected_model: selectedModelName
})
});
diff --git a/uv.lock b/uv.lock
index 6279381b..7a245997 100644
--- a/uv.lock
+++ b/uv.lock
@@ -204,7 +204,7 @@ wheels = [
[[package]]
name = "astrbot"
-version = "3.5.17"
+version = "3.5.18"
source = { editable = "." }
dependencies = [
{ name = "aiocqhttp" },