From b8a6fb17201e6842104f4b537fdd33c46d550b08 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 25 Dec 2024 12:49:58 +0800 Subject: [PATCH] chore: update tests --- .github/workflows/coverage_test.yml | 8 +- astrbot/core/star/filter/command.py | 3 + packages/astrbot/main.py | 2 +- tests/test_main.py | 48 ++++++ tests/test_pipeline.py | 217 ++++++++++++++++++++++++++++ tests/test_plugin_manager.py | 93 ++++++++++++ 6 files changed, 363 insertions(+), 8 deletions(-) create mode 100644 tests/test_main.py create mode 100644 tests/test_pipeline.py create mode 100644 tests/test_plugin_manager.py diff --git a/.github/workflows/coverage_test.yml b/.github/workflows/coverage_test.yml index 175e5f56..bd4ae6be 100644 --- a/.github/workflows/coverage_test.yml +++ b/.github/workflows/coverage_test.yml @@ -21,16 +21,10 @@ jobs: python -m pip install --upgrade pip pip install -r requirements.txt pip install pytest pytest-cov pytest-asyncio - mkdir data - mkdir data/plugins - mkdir data/config - mkdir temp - name: Run tests run: | - export LLM_MODEL=${{ secrets.LLM_MODEL }} - export OPENAI_API_BASE=${{ secrets.OPENAI_API_BASE }} - export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }} + export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }} PYTHONPATH=./ pytest --cov=. tests/ -v - name: Upload results to Codecov diff --git a/astrbot/core/star/filter/command.py b/astrbot/core/star/filter/command.py index 72b8b160..dce76b04 100644 --- a/astrbot/core/star/filter/command.py +++ b/astrbot/core/star/filter/command.py @@ -51,6 +51,9 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin): ls = re.split(r"\s+", message_str) if self.command_name != ls[0]: return False + if len(self.handler_params) == 0 and len(ls) > 1: + # 一定程度避免 LLM 聊天时误判为指令 + return False # params_str = message_str[len(self.command_name):].strip() ls = ls[1:] # 去除空字符串 diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index 6df2208c..c8580ea8 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -289,7 +289,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo - 重置 LLM 会话(保留人格): /reset p 【当前人格】: {str(self.context.get_using_provider().curr_personality['prompt'])} -""")) +""").use_t2i(False)) elif l[1] == "list": msg = "人格列表:\n" for key in personalities.keys(): diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 00000000..d2201e44 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,48 @@ +import os +import sys +import pytest +from unittest import mock +from main import check_env, check_dashboard_files + +class _version_info(): + def __init__(self, major, minor): + self.major = major + self.minor = minor + +def test_check_env(monkeypatch): + version_info_correct = _version_info(3, 10) + version_info_wrong = _version_info(3, 9) + monkeypatch.setattr(sys, 'version_info', version_info_correct) + with mock.patch('os.makedirs') as mock_makedirs: + check_env() + mock_makedirs.assert_any_call("data/config", exist_ok=True) + mock_makedirs.assert_any_call("data/plugins", exist_ok=True) + mock_makedirs.assert_any_call("data/temp", exist_ok=True) + + monkeypatch.setattr(sys, 'version_info', version_info_wrong) + with pytest.raises(SystemExit): + check_env() + +@pytest.mark.asyncio +async def test_check_dashboard_files(monkeypatch): + monkeypatch.setattr(os.path, 'exists', lambda x: False) + async def mock_get(*args, **kwargs): + class MockResponse: + status = 200 + async def read(self): + return b'content' + return MockResponse() + + with mock.patch('aiohttp.ClientSession.get', new=mock_get): + with mock.patch('builtins.open', mock.mock_open()) as mock_file: + with mock.patch('zipfile.ZipFile.extractall') as mock_extractall: + async def mock_aenter(_): + await check_dashboard_files() + mock_file.assert_called_once_with("data/dashboard.zip", "wb") + mock_extractall.assert_called_once() + + async def mock_aexit(obj, exc_type, exc, tb): + return + + mock_extractall.__aenter__ = mock_aenter + mock_extractall.__aexit__ = mock_aexit \ No newline at end of file diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py new file mode 100644 index 00000000..a142d79a --- /dev/null +++ b/tests/test_pipeline.py @@ -0,0 +1,217 @@ +import pytest, logging, os +from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext +from astrbot.core.star import PluginManager +from astrbot.core.config.astrbot_config import AstrBotConfig +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember, MessageType +from astrbot.core.message.message_event_result import MessageChain, ResultContentType +from astrbot.core.message.components import Plain, At +from astrbot.core.platform.platform_metadata import PlatformMetadata +from astrbot.core.platform.manager import PlatformManager +from astrbot.core.provider.manager import ProviderManager +from astrbot.core.db.sqlite import SQLiteDatabase +from astrbot.core.star.context import Context +from astrbot.core import logger +from asyncio import Queue + +SESSION_ID_IN_WHITELIST = "test_sid_wl" +SESSION_ID_NOT_IN_WHITELIST = "test_sid" +TEST_LLM_PROVIDER = { + "id": "zhipu_default", + "type": "openai_chat_completion", + "enable": True, + "key": [os.getenv("ZHIPU_API_KEY")], + "api_base": "https://open.bigmodel.cn/api/paas/v4/", + "model_config": { + "model": "glm-4-flash", + }, +} + +TEST_COMMANDS = [ + ["help", "已注册的 AstrBot 内置指令"], + ["tool ls", "查看、激活、停用当前注册的函数工具"], + ["tool on websearch", "激活工具"], + ["tool off websearch", "停用工具"], + ["plugin", "已加载的插件"], + ["t2i", "文本转图片模式"], + ["sid", "此 ID 可用于设置会话白名单。"], + ["op test_op", "授权成功。"], + ["deop test_op", "取消授权成功。"], + ["wl test_platform:FriendMessage:test_sid_wl2", "添加白名单成功。"], + ["dwl test_platform:FriendMessage:test_sid_wl2", "删除白名单成功。"], + ["provider", "当前载入的 LLM 提供商"], + ["reset", "重置成功"], + # ["model", "查看、切换提供商模型列表"], + ["history", "历史记录:"], + ["key", "当前 Key"], + ["persona", "[Persona]"] +] + +class FakeAstrMessageEvent(AstrMessageEvent): + def __init__(self, abm: AstrBotMessage = None): + meta = PlatformMetadata("test_platform", "test") + super().__init__( + message_str=abm.message_str, + message_obj=abm, + platform_meta=meta, + session_id=abm.session_id + ) + + async def send(self, message: MessageChain): + await super().send(message) + + @staticmethod + def create_fake_event( + message_str: str, + session_id: str = "test_sid", + is_at: bool = False, + is_group: bool = False, + sender_id: str = "123456" + ): + abm = AstrBotMessage() + abm.message_str = message_str + abm.group_id = "test" + abm.message = [Plain(message_str)] + if is_at: + abm.message.append(At(qq="bot")) + abm.self_id = "bot" + abm.sender = MessageMember(sender_id, "mika") + abm.timestamp = 1234567890 + abm.message_id = "test" + abm.session_id = session_id + if is_group: + abm.type = MessageType.GROUP_MESSAGE + else: + abm.type = MessageType.FRIEND_MESSAGE + return FakeAstrMessageEvent(abm) + +@pytest.fixture(scope="module") +def event_queue(): + return Queue() + +@pytest.fixture(scope="module") +def config(): + cfg = AstrBotConfig() + cfg['platform_settings']['id_whitelist'] = ["test_platform:FriendMessage:test_sid_wl", "test_platform:GroupMessage:test_sid_wl"] + cfg['admins_id'] = ["123456"] + cfg['content_safety']['internal_keywords']['extra_keywords'] = ["^TEST_NEGATIVE"] + cfg['provider'] = [TEST_LLM_PROVIDER] + return cfg + +@pytest.fixture(scope="module") +def db(): + return SQLiteDatabase("data/data_v3.db") + +@pytest.fixture(scope="module") +def platform_manager(event_queue, config): + return PlatformManager(config, event_queue) + +@pytest.fixture(scope="module") +def provider_manager(config, db): + return ProviderManager(config, db) + +@pytest.fixture(scope="module") +def star_context(event_queue, config, db, platform_manager, provider_manager): + star_context = Context(event_queue, config, db, provider_manager, platform_manager) + return star_context + +@pytest.fixture(scope="module") +def plugin_manager(star_context, config): + plugin_manager = PluginManager(star_context, config) + plugin_manager.reload() + return plugin_manager + +@pytest.fixture(scope="module") +def pipeline_context(config, plugin_manager): + return PipelineContext(config, plugin_manager) + +@pytest.fixture(scope="module") +def pipeline_scheduler(pipeline_context): + return PipelineScheduler(pipeline_context) + +@pytest.mark.asyncio +async def test_platform_initialization(platform_manager: PlatformManager): + await platform_manager.initialize() + +@pytest.mark.asyncio +async def test_provider_initialization(provider_manager: ProviderManager): + await provider_manager.initialize() + +@pytest.mark.asyncio +async def test_pipeline_scheduler_initialization(pipeline_scheduler: PipelineScheduler): + await pipeline_scheduler.initialize() + +@pytest.mark.asyncio +async def test_pipeline_wakeup(pipeline_scheduler: PipelineScheduler, caplog): + '''测试唤醒''' + # 群聊无 @ 无指令 + mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True) + with caplog.at_level(logging.DEBUG): + await pipeline_scheduler.execute(mock_event) + assert any("执行阶段 WhitelistCheckStage" not in message for message in caplog.messages) + # 群聊有 @ 无指令 + mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True, is_at=True) + with caplog.at_level(logging.DEBUG): + await pipeline_scheduler.execute(mock_event) + assert any("执行阶段 WhitelistCheckStage" in message for message in caplog.messages) + # 群聊有指令 + mock_event = FakeAstrMessageEvent.create_fake_event("/help", is_group=True, session_id=SESSION_ID_IN_WHITELIST) + await pipeline_scheduler.execute(mock_event) + assert mock_event._has_send_oper is True + +@pytest.mark.asyncio +async def test_pipeline_wl(pipeline_scheduler: PipelineScheduler, config: AstrBotConfig, caplog): + mock_event = FakeAstrMessageEvent.create_fake_event("test", sender_id="123") + with caplog.at_level(logging.INFO): + await pipeline_scheduler.execute(mock_event) + assert any("不在会话白名单中,已终止事件传播。" in message for message in caplog.messages), "日志中未找到预期的消息" + + mock_event = FakeAstrMessageEvent.create_fake_event("test", SESSION_ID_IN_WHITELIST, sender_id="123") + await pipeline_scheduler.execute(mock_event) + assert any("不在会话白名单中,已终止事件传播。" not in message for message in caplog.messages), "日志中未找到预期的消息" + + +@pytest.mark.asyncio +async def test_pipeline_content_safety(pipeline_scheduler: PipelineScheduler, caplog): + # 测试默认屏蔽词 + mock_event = FakeAstrMessageEvent.create_fake_event("色情", session_id=SESSION_ID_IN_WHITELIST) # 测试需要。 + with caplog.at_level(logging.INFO): + await pipeline_scheduler.execute(mock_event) + assert any("内容安全检查不通过" in message for message in caplog.messages), "日志中未找到预期的消息" + # 测试额外屏蔽词 + mock_event = FakeAstrMessageEvent.create_fake_event("TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST) + with caplog.at_level(logging.INFO): + await pipeline_scheduler.execute(mock_event) + assert any("内容安全检查不通过" in message for message in caplog.messages), "日志中未找到预期的消息" + mock_event = FakeAstrMessageEvent.create_fake_event("_TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST) + with caplog.at_level(logging.INFO): + await pipeline_scheduler.execute(mock_event) + assert any("内容安全检查不通过" not in message for message in caplog.messages) + # TODO: 测试 百度AI 的内容安全检查 + + +@pytest.mark.asyncio +async def test_pipeline_llm(pipeline_scheduler: PipelineScheduler, caplog): + mock_event = FakeAstrMessageEvent.create_fake_event("just reply me `OK`", session_id=SESSION_ID_IN_WHITELIST) + with caplog.at_level(logging.DEBUG): + await pipeline_scheduler.execute(mock_event) + assert any("请求 LLM" in message for message in caplog.messages) + assert mock_event.get_result() is not None + assert mock_event.get_result().result_content_type == ResultContentType.LLM_RESULT + +@pytest.mark.asyncio +async def test_pipeline_websearch(pipeline_scheduler: PipelineScheduler, caplog): + mock_event = FakeAstrMessageEvent.create_fake_event("help me search the latest OpenAI news", session_id=SESSION_ID_IN_WHITELIST) + with caplog.at_level(logging.DEBUG): + await pipeline_scheduler.execute(mock_event) + assert any("请求 LLM" in message for message in caplog.messages) + assert any("web_searcher - search_from_search_engine" in message for message in caplog.messages) + +@pytest.mark.asyncio +async def test_commands(pipeline_scheduler: PipelineScheduler, caplog): + for command in TEST_COMMANDS: + mock_event = FakeAstrMessageEvent.create_fake_event(command[0], session_id=SESSION_ID_IN_WHITELIST) + with caplog.at_level(logging.DEBUG): + await pipeline_scheduler.execute(mock_event) + assert any("执行阶段 ProcessStage" in message for message in caplog.messages) + assert any(command[1] in message for message in caplog.messages) \ No newline at end of file diff --git a/tests/test_plugin_manager.py b/tests/test_plugin_manager.py new file mode 100644 index 00000000..8056a184 --- /dev/null +++ b/tests/test_plugin_manager.py @@ -0,0 +1,93 @@ +import pytest +import os +import shutil +from astrbot.core.star.star_manager import PluginManager +from astrbot.core.star.star_handler import star_handlers_registry +from astrbot.core.star.star import star_registry +from astrbot.core.star.context import Context +from astrbot.core.config.astrbot_config import AstrBotConfig +from astrbot.core.db.sqlite import SQLiteDatabase +from asyncio import Queue + +@pytest.fixture +def event_queue(): + return Queue() + +@pytest.fixture +def config(): + return AstrBotConfig() + +@pytest.fixture +def db(): + return SQLiteDatabase("data/data_v3.db") + +@pytest.fixture +def star_context(event_queue, config, db): + return Context(event_queue, config, db) + +@pytest.fixture +def plugin_manager_pm(star_context, config): + return PluginManager(star_context, config) + +def test_plugin_manager_initialization(plugin_manager_pm: PluginManager): + assert plugin_manager_pm is not None + assert plugin_manager_pm.context is not None + assert plugin_manager_pm.config is not None + +def test_plugin_manager_reload(plugin_manager_pm: PluginManager): + success, err_message = plugin_manager_pm.reload() + assert success is True + assert err_message is None + assert len(star_handlers_registry) > 0 # package + +@pytest.mark.asyncio +async def test_plugin_crud(plugin_manager_pm: PluginManager): + '''测试插件安装和重载''' + os.makedirs("data/plugins", exist_ok=True) + test_repo = "https://github.com/Soulter/astrbot_plugin_essential" + plugin_path = await plugin_manager_pm.install_plugin(test_repo) + exists = False + for md in star_registry: + if md.name == "astrbot_plugin_essential": + exists = True + break + assert plugin_path is not None + assert os.path.exists(plugin_path) + assert exists is True, "插件 astrbot_plugin_essential 未成功载入" + # shutil.rmtree(plugin_path) + + # install plugin which is not exists + with pytest.raises(Exception): + plugin_path = await plugin_manager_pm.install_plugin(test_repo + "haha") + + # update + await plugin_manager_pm.update_plugin("astrbot_plugin_essential") + + with pytest.raises(Exception): + await plugin_manager_pm.update_plugin("astrbot_plugin_essentialhaha") + + # uninstall + await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essential") + assert not os.path.exists(plugin_path) + exists = False + for md in star_registry: + if md.name == "astrbot_plugin_essential": + exists = True + break + assert exists is False, "插件 astrbot_plugin_essential 未成功卸载" + exists = False + for md in star_handlers_registry: + if "astrbot_plugin_essential" in md.handler_module_path: + exists = True + break + assert exists is False, "插件 astrbot_plugin_essential 未成功卸载" + + with pytest.raises(Exception): + await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essentialhaha") + + # TODO: file installation + + + + + \ No newline at end of file