Compare commits
465 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ca44133e90 | |||
| b4810bb487 | |||
| dc0f9c5f08 | |||
| 595fd878a6 | |||
| 9d45991181 | |||
| cf2f2fd707 | |||
| d4b1db0407 | |||
| 8470e252d6 | |||
| 131444ac52 | |||
| ab3083f943 | |||
| 1e1d5c4a14 | |||
| c8ab0b9428 | |||
| 33ce41704d | |||
| 4eb3aa31ee | |||
| d1a9dfa3e6 | |||
| 0e5ebcfd00 | |||
| c4e0a6acfe | |||
| 2243bb2862 | |||
| 1f7d2fa93f | |||
| fb680ce764 | |||
| dc5bc64040 | |||
| 1c2ce7e0aa | |||
| a290ee7f39 | |||
| 79c697c34d | |||
| 76271cbf77 | |||
| 51dcdf94fb | |||
| 254051cf62 | |||
| 24d2e6e6ce | |||
| cf9bfce43c | |||
| 9e0ee24fd7 | |||
| 5eb2772d53 | |||
| f943f05cb1 | |||
| 96ce645064 | |||
| 1a972ac0e0 | |||
| 2e173631a0 | |||
| c457d4a868 | |||
| b74655651d | |||
| f27a481c3c | |||
| 4028b26c1d | |||
| 011b6f2df1 | |||
| 7b3b73d390 | |||
| 004d6d8201 | |||
| 7cf57adceb | |||
| 76bf78b810 | |||
| f4441e2a55 | |||
| 84f590ec7b | |||
| a5865cfd01 | |||
| 4e7c714ea2 | |||
| d2c4231458 | |||
| b5004e2a51 | |||
| e0c334b5ed | |||
| d482e661fb | |||
| 3ac1caca69 | |||
| 94c112c066 | |||
| 2e694a87f8 | |||
| 4ae30db53a | |||
| f4a6dd91cf | |||
| c08a570c27 | |||
| 9c318c9526 | |||
| 4cee09870a | |||
| 866e8e8734 | |||
| 80e1784777 | |||
| 0d760ffa2e | |||
| 88f7e6a854 | |||
| de37e2355d | |||
| f27b04c5b0 | |||
| a02b8f4609 | |||
| 7b90dfb46c | |||
| 26a9dba01a | |||
| a176814ad1 | |||
| ea51439aac | |||
| 162e33f478 | |||
| ee4c310725 | |||
| a000ff2a1a | |||
| 2f9576b2ae | |||
| 92554dd398 | |||
| 9473ddc762 | |||
| 5f469a71f3 | |||
| 87bac60afc | |||
| 704339e835 | |||
| c8ab7180ba | |||
| 11757546c3 | |||
| 420b9ec2f2 | |||
| 1c73271e33 | |||
| acdbe6b9ed | |||
| 6c201228d9 | |||
| 73b2a375ad | |||
| 89bb830b60 | |||
| 2399db4944 | |||
| 62774b34d3 | |||
| 654f19eaa9 | |||
| ce642f17d9 | |||
| d7bcd5a20e | |||
| 27903e7d9d | |||
| a8c0d0a684 | |||
| 5e33c89fe7 | |||
| 42849e4586 | |||
| 6a8544fb0e | |||
| 37f7042f0f | |||
| 65d066cbef | |||
| 504531d4d5 | |||
| d4b3428160 | |||
| cd881ceb34 | |||
| 68b37e66e9 | |||
| d6e7ed81ee | |||
| a9843b4128 | |||
| d4c6131fa3 | |||
| d2d5064eed | |||
| 8bec7640fa | |||
| fcf53f06ef | |||
| 2048f210e7 | |||
| 78eacccf6e | |||
| a436ab1d78 | |||
| 2aedbf5702 | |||
| b7e7174f3d | |||
| e7e5c0456f | |||
| 53e38ed1aa | |||
| f91e7da0a1 | |||
| 74db4c4646 | |||
| 1e4902b267 | |||
| 932b1d529a | |||
| 53046460ec | |||
| 38ac42af8c | |||
| 538291c03f | |||
| 142ad9e41e | |||
| 7250ce3514 | |||
| 02cf012671 | |||
| d11a2cd95c | |||
| 65ac3181a8 | |||
| 998e54246f | |||
| fcd8f7a26e | |||
| b991afd69a | |||
| d9d8bae2d6 | |||
| 422ba52093 | |||
| 51630f95fd | |||
| ac1cab60a3 | |||
| 23f61b0d62 | |||
| 759f8518b2 | |||
| 7bd6c92f43 | |||
| ff705d99b3 | |||
| 7ec17dc771 | |||
| 35883e8601 | |||
| 48b7bdb9ba | |||
| d2d5b4370c | |||
| 27c31d6e0c | |||
| 961ee22327 | |||
| 37b3c08baa | |||
| d8c3f601df | |||
| cff9068359 | |||
| cc871b7a72 | |||
| 5b98ef5b3d | |||
| 3428d15299 | |||
| 9ea3f0842c | |||
| 90242e2285 | |||
| 1616345261 | |||
| 0b818477ac | |||
| 027d6ea2b2 | |||
| c7d2588f1a | |||
| 8712e26c74 | |||
| 0c652e0ac4 | |||
| 3a3a5e6c8b | |||
| 06ab2822be | |||
| 1b7596ebe1 | |||
| e95219f2ec | |||
| bb0ec0a3ec | |||
| 483b4e090e | |||
| 4975c2d9e8 | |||
| 965d7d3008 | |||
| 5365fddec9 | |||
| 128385bfe0 | |||
| cfdeb124b9 | |||
| 8deaa6e4f6 | |||
| e401685449 | |||
| e195ad4a8f | |||
| b8a84f62ac | |||
| 20f5271682 | |||
| 5524571c80 | |||
| cd3031479c | |||
| 1df6e8c732 | |||
| ed2e01491e | |||
| 228ed474ce | |||
| 6829a03437 | |||
| cb922b67ad | |||
| b0213742f4 | |||
| 90264f6ec9 | |||
| 1ef6de1869 | |||
| e4ad5084cf | |||
| 0ef2725dfd | |||
| 4bd492f498 | |||
| ef7433c823 | |||
| b6765b48b5 | |||
| 7c45e42602 | |||
| 57a40f84b9 | |||
| e737f71932 | |||
| 9ec6e5f771 | |||
| 4647688613 | |||
| bc0f283278 | |||
| aadadf8353 | |||
| d0a0685fc1 | |||
| ba7d5f53e5 | |||
| 86dde5dc0f | |||
| a5ceceeca3 | |||
| fcfda90d5a | |||
| 2ccfde1ba4 | |||
| ae1c1409e1 | |||
| b46237296e | |||
| 8c06d2f706 | |||
| 1b705edb06 | |||
| 42435e8f76 | |||
| 3111979bb4 | |||
| 56580e3fac | |||
| 7084b8d429 | |||
| 8b0e8506c2 | |||
| 4d133d59ea | |||
| 35b885798b | |||
| ae9e12b276 | |||
| 8018ac1a97 | |||
| f429e3fc01 | |||
| 6c63146556 | |||
| 29242154d0 | |||
| ccc5e830d7 | |||
| adf10f6ea1 | |||
| 26a6ff871f | |||
| d1e85f964d | |||
| 8d041438fd | |||
| c6dc1810e9 | |||
| dabfb8dc0e | |||
| 4aa9c9f225 | |||
| fa394576bb | |||
| 5c1ac376e6 | |||
| 88e77aa116 | |||
| 4e9340f551 | |||
| 0648a1f567 | |||
| 4b1f7db506 | |||
| 0be2177937 | |||
| c52cc5a94f | |||
| 947695fdc7 | |||
| 75296babe3 | |||
| c9381d672e | |||
| 67fa5df611 | |||
| 52a980f751 | |||
| 3b7ab2aec8 | |||
| 122e4a10d0 | |||
| d41e239b89 | |||
| b82b16b5f6 | |||
| ebdd90b235 | |||
| 5c8e06ed94 | |||
| f4e4586fbc | |||
| fab1d29c83 | |||
| de9cb2fbdb | |||
| a419aed404 | |||
| cb47e8decd | |||
| 3c4bb72a82 | |||
| cab79ef185 | |||
| a87c06aab8 | |||
| c19659daa5 | |||
| bafde1c518 | |||
| 45961d2eda | |||
| e1a0dd6810 | |||
| a1d14b9292 | |||
| a7d6065b08 | |||
| 5dbd38721f | |||
| 39fcc04d78 | |||
| 5c7784622e | |||
| b85040f579 | |||
| 8bcd229849 | |||
| d12515ccb9 | |||
| 499cb52e28 | |||
| 05a318225c | |||
| caad0bc005 | |||
| 067ecb5e8e | |||
| 3088887e57 | |||
| 0f8cbeed11 | |||
| 2ed99c0cb8 | |||
| 0a149e3d9e | |||
| a3a26c69c5 | |||
| 2bafc53b25 | |||
| 14f14b75b0 | |||
| 77351b7691 | |||
| b28fadd02f | |||
| 63fa70863c | |||
| 09e9b95e08 | |||
| 11a76ae90f | |||
| bf2ffb7465 | |||
| 287c96ea2e | |||
| adacb8c638 | |||
| 7a3d08672a | |||
| 1973e4d290 | |||
| 7a169c424d | |||
| ec4d106a59 | |||
| 69bcb0e13e | |||
| 54386bf624 | |||
| fe6e65f263 | |||
| fe0c0fac1e | |||
| f05b884646 | |||
| 8e163b8f17 | |||
| caebaf5d46 | |||
| 6950b6f1e7 | |||
| 0e35224787 | |||
| e5a84a2e84 | |||
| 09da7113a0 | |||
| e6e43dbcca | |||
| e02f826707 | |||
| 781b01ee17 | |||
| 1f1086ed7b | |||
| 0a80fc5517 | |||
| 6d8edc95d9 | |||
| 4a4a1686d3 | |||
| a54b49cc30 | |||
| 37218eef4f | |||
| 3b34efd33a | |||
| cc650b58d3 | |||
| 183b46be9e | |||
| a847b74c32 | |||
| d1b339f71d | |||
| a3c638946e | |||
| a0193451a9 | |||
| ede2b75cd0 | |||
| 34ab01e0a1 | |||
| b493172090 | |||
| 6bcd941cc6 | |||
| 98ebfd12b3 | |||
| 305a454ffd | |||
| dfc593f2e1 | |||
| b50203f85d | |||
| e2a0792e2d | |||
| b7d97cca69 | |||
| 7cdc80c3e2 | |||
| 59b6cbc23c | |||
| 21c436d900 | |||
| 87f3628b49 | |||
| 27b315bcca | |||
| 1ec81f9e75 | |||
| 087e757f9f | |||
| ce955e3ee9 | |||
| 4ddada4de8 | |||
| 164386a337 | |||
| d4d2510834 | |||
| 46a5ea88f3 | |||
| 7ca9dcd2fb | |||
| 9c679ede20 | |||
| 60c85b651f | |||
| 73895b5f4b | |||
| 3e2acde9e2 | |||
| a1d8f3eb0f | |||
| 0aba7bad31 | |||
| 75660766db | |||
| 53a6c70eca | |||
| da18ff3d48 | |||
| 4a671a9bc2 | |||
| ae1839ac33 | |||
| 56dbe6b050 | |||
| f5a41e9c78 | |||
| f65149af19 | |||
| affef443b6 | |||
| e40e1d0b36 | |||
| 14638b7470 | |||
| b4a92cecc8 | |||
| 49e4667410 | |||
| 5d26bf15a3 | |||
| 7631d9d730 | |||
| 91b4d806cd | |||
| 1b8bb568b1 | |||
| 18da9a19fd | |||
| 7fdae0173c | |||
| c872707791 | |||
| a0cab3341e | |||
| 25c5d671dc | |||
| 035001f841 | |||
| f533c1a2ca | |||
| e5aa58722c | |||
| 8645fe4ab1 | |||
| 15f216b050 | |||
| b4df5bbb13 | |||
| a17a198912 | |||
| 939782ac4e | |||
| 7e369bef00 | |||
| c3adcf663f | |||
| 373b2fcd78 | |||
| c9d1e30f8b | |||
| 761b57a834 | |||
| 632871b2f8 | |||
| 52f00f08f2 | |||
| 1d5761b1fd | |||
| 42dbc6555c | |||
| b237d9d38d | |||
| 1e9a811065 | |||
| 3c12f9052e | |||
| 9425437480 | |||
| 2385fba695 | |||
| 634c478e18 | |||
| 335bf47dbd | |||
| 23c4117d6f | |||
| 4a5d3b31ab | |||
| 17a27f0d55 | |||
| 8fbb93b0bf | |||
| e09cd6b6d7 | |||
| 26ac9e3c2e | |||
| 958edc0017 | |||
| efa54f3435 | |||
| 88a2cd6659 | |||
| 4e26291a61 | |||
| f25142e597 | |||
| 462fc84240 | |||
| c97ad627d1 | |||
| 36307abc30 | |||
| 8d3dbcb5f8 | |||
| 9d9ae7ba4e | |||
| d682045655 | |||
| 381397ed31 | |||
| 4484f39525 | |||
| 82c08128b6 | |||
| 8cd40a471e | |||
| bd6428d473 | |||
| 87d9c7b410 | |||
| d7960140dc | |||
| d7052b547f | |||
| 6aaef9b7be | |||
| b246676257 | |||
| 0e4b1820e7 | |||
| 1833092998 | |||
| 67a6a6a445 | |||
| 00717126e5 | |||
| 3816076464 | |||
| 710592b053 | |||
| 828c22310d | |||
| f45b744318 | |||
| f49d3791b6 | |||
| ea62294bd8 | |||
| bfe2e87f59 | |||
| 4f8507036a | |||
| 6f6944d003 | |||
| 4216ffd0da | |||
| a32fad06a0 | |||
| 1a49972583 | |||
| a09c52424f | |||
| b869869e26 | |||
| acf2f4758f | |||
| c3b2af5a15 | |||
| 01ffd4c4ca | |||
| a5d4a01ad8 | |||
| abf368e558 | |||
| 4d266fddb1 | |||
| 20dd4794b0 | |||
| 7b96900726 | |||
| 043a4fb5ca | |||
| 82e144be4c | |||
| 9b22e1671f | |||
| 7999149901 | |||
| c70a5d63aa | |||
| d1067bb6b3 | |||
| b43b4b581e | |||
| 55645a75cc | |||
| 36f86ff2b9 | |||
| 0697c79daa | |||
| 3d6a82fb00 | |||
| 97e9e42173 | |||
| 5d8e706c0b | |||
| a8cd2e2eac | |||
| 1e615d69e1 | |||
| 63be1d8cf2 | |||
| f039aa253d | |||
| 5a7521e335 | |||
| 5dac1f5867 | |||
| 1d0fc26025 |
@@ -2,8 +2,8 @@ name: Auto I18N
|
|||||||
|
|
||||||
env:
|
env:
|
||||||
API_KEY: ${{ secrets.TRANSLATE_API_KEY }}
|
API_KEY: ${{ secrets.TRANSLATE_API_KEY }}
|
||||||
MODEL: ${{ vars.MODEL || 'deepseek/deepseek-v3.1'}}
|
MODEL: ${{ vars.AUTO_I18N_MODEL || 'deepseek/deepseek-v3.1'}}
|
||||||
BASE_URL: ${{ vars.BASE_URL || 'https://api.ppinfra.com/openai'}}
|
BASE_URL: ${{ vars.AUTO_I18N_BASE_URL || 'https://api.ppinfra.com/openai'}}
|
||||||
|
|
||||||
on:
|
on:
|
||||||
pull_request:
|
pull_request:
|
||||||
@@ -26,7 +26,7 @@ jobs:
|
|||||||
ref: ${{ github.event.pull_request.head.ref }}
|
ref: ${{ github.event.pull_request.head.ref }}
|
||||||
|
|
||||||
- name: 📦 Setting Node.js
|
- name: 📦 Setting Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v5
|
||||||
with:
|
with:
|
||||||
node-version: 20
|
node-version: 20
|
||||||
|
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v5
|
||||||
with:
|
with:
|
||||||
fetch-depth: 1
|
fetch-depth: 1
|
||||||
|
|
||||||
|
|||||||
@@ -16,10 +16,13 @@ on:
|
|||||||
jobs:
|
jobs:
|
||||||
translate:
|
translate:
|
||||||
if: |
|
if: |
|
||||||
(github.event_name == 'issues') ||
|
(github.event_name == 'issues')
|
||||||
(github.event_name == 'issue_comment' && github.event.sender.type != 'Bot') ||
|
|| (github.event_name == 'issue_comment' && github.event.sender.type != 'Bot')
|
||||||
(github.event_name == 'pull_request_review' && github.event.sender.type != 'Bot') ||
|
|| (
|
||||||
(github.event_name == 'pull_request_review_comment' && github.event.sender.type != 'Bot')
|
(github.event_name == 'pull_request_review' || github.event_name == 'pull_request_review_comment')
|
||||||
|
&& github.event.sender.type != 'Bot'
|
||||||
|
&& github.event.pull_request.head.repo.fork == false
|
||||||
|
)
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
permissions:
|
permissions:
|
||||||
contents: read
|
contents: read
|
||||||
@@ -29,7 +32,7 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v5
|
||||||
with:
|
with:
|
||||||
fetch-depth: 1
|
fetch-depth: 1
|
||||||
|
|
||||||
@@ -42,7 +45,7 @@ jobs:
|
|||||||
# See: https://github.com/anthropics/claude-code-action/blob/main/docs/security.md
|
# See: https://github.com/anthropics/claude-code-action/blob/main/docs/security.md
|
||||||
github_token: ${{ secrets.TOKEN_GITHUB_WRITE }}
|
github_token: ${{ secrets.TOKEN_GITHUB_WRITE }}
|
||||||
allowed_non_write_users: "*"
|
allowed_non_write_users: "*"
|
||||||
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
|
anthropic_api_key: ${{ secrets.CLAUDE_TRANSLATOR_APIKEY }}
|
||||||
claude_args: "--allowed-tools Bash(gh issue:*),Bash(gh api:repos/*/issues:*),Bash(gh api:repos/*/pulls/*/reviews/*),Bash(gh api:repos/*/pulls/comments/*)"
|
claude_args: "--allowed-tools Bash(gh issue:*),Bash(gh api:repos/*/issues:*),Bash(gh api:repos/*/pulls/*/reviews/*),Bash(gh api:repos/*/pulls/comments/*)"
|
||||||
prompt: |
|
prompt: |
|
||||||
你是一个多语言翻译助手。你需要响应 GitHub Webhooks 中的以下四种事件:
|
你是一个多语言翻译助手。你需要响应 GitHub Webhooks 中的以下四种事件:
|
||||||
@@ -105,3 +108,5 @@ jobs:
|
|||||||
|
|
||||||
使用以下命令获取完整信息:
|
使用以下命令获取完整信息:
|
||||||
gh issue view ${{ github.event.issue.number }} --json title,body,comments
|
gh issue view ${{ github.event.issue.number }} --json title,body,comments
|
||||||
|
env:
|
||||||
|
ANTHROPIC_BASE_URL: ${{ secrets.CLAUDE_TRANSLATOR_BASEURL }}
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ jobs:
|
|||||||
actions: read # Required for Claude to read CI results on PRs
|
actions: read # Required for Claude to read CI results on PRs
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v5
|
||||||
with:
|
with:
|
||||||
fetch-depth: 1
|
fetch-depth: 1
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,22 @@
|
|||||||
|
name: Delete merged branch
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
types:
|
||||||
|
- closed
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
delete-branch:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
if: github.event.pull_request.merged == true && github.event.pull_request.head.repo.full_name == github.repository
|
||||||
|
steps:
|
||||||
|
- name: Delete merged branch
|
||||||
|
uses: actions/github-script@v8
|
||||||
|
with:
|
||||||
|
script: |
|
||||||
|
github.rest.git.deleteRef({
|
||||||
|
owner: context.repo.owner,
|
||||||
|
repo: context.repo.repo,
|
||||||
|
ref: `heads/${context.payload.pull_request.head.ref}`,
|
||||||
|
})
|
||||||
@@ -56,7 +56,7 @@ jobs:
|
|||||||
ref: main
|
ref: main
|
||||||
|
|
||||||
- name: Install Node.js
|
- name: Install Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v5
|
||||||
with:
|
with:
|
||||||
node-version: 20
|
node-version: 20
|
||||||
|
|
||||||
@@ -98,10 +98,10 @@ jobs:
|
|||||||
env:
|
env:
|
||||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
NODE_OPTIONS: --max-old-space-size=8192
|
NODE_OPTIONS: --max-old-space-size=8192
|
||||||
MAIN_VITE_CHERRYIN_CLIENT_SECRET: ${{ secrets.MAIN_VITE_CHERRYIN_CLIENT_SECRET }}
|
MAIN_VITE_CHERRYAI_CLIENT_SECRET: ${{ secrets.MAIN_VITE_CHERRYAI_CLIENT_SECRET }}
|
||||||
MAIN_VITE_MINERU_API_KEY: ${{ vars.MAIN_VITE_MINERU_API_KEY }}
|
MAIN_VITE_MINERU_API_KEY: ${{ secrets.MAIN_VITE_MINERU_API_KEY }}
|
||||||
RENDERER_VITE_AIHUBMIX_SECRET: ${{ vars.RENDERER_VITE_AIHUBMIX_SECRET }}
|
RENDERER_VITE_AIHUBMIX_SECRET: ${{ secrets.RENDERER_VITE_AIHUBMIX_SECRET }}
|
||||||
RENDERER_VITE_PPIO_APP_SECRET: ${{ vars.RENDERER_VITE_PPIO_APP_SECRET }}
|
RENDERER_VITE_PPIO_APP_SECRET: ${{ secrets.RENDERER_VITE_PPIO_APP_SECRET }}
|
||||||
|
|
||||||
- name: Build Mac
|
- name: Build Mac
|
||||||
if: matrix.os == 'macos-latest'
|
if: matrix.os == 'macos-latest'
|
||||||
@@ -110,15 +110,15 @@ jobs:
|
|||||||
env:
|
env:
|
||||||
CSC_LINK: ${{ secrets.CSC_LINK }}
|
CSC_LINK: ${{ secrets.CSC_LINK }}
|
||||||
CSC_KEY_PASSWORD: ${{ secrets.CSC_KEY_PASSWORD }}
|
CSC_KEY_PASSWORD: ${{ secrets.CSC_KEY_PASSWORD }}
|
||||||
APPLE_ID: ${{ vars.APPLE_ID }}
|
APPLE_ID: ${{ secrets.APPLE_ID }}
|
||||||
APPLE_APP_SPECIFIC_PASSWORD: ${{ vars.APPLE_APP_SPECIFIC_PASSWORD }}
|
APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }}
|
||||||
APPLE_TEAM_ID: ${{ vars.APPLE_TEAM_ID }}
|
APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
|
||||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
NODE_OPTIONS: --max-old-space-size=8192
|
NODE_OPTIONS: --max-old-space-size=8192
|
||||||
MAIN_VITE_CHERRYIN_CLIENT_SECRET: ${{ secrets.MAIN_VITE_CHERRYIN_CLIENT_SECRET }}
|
MAIN_VITE_CHERRYAI_CLIENT_SECRET: ${{ secrets.MAIN_VITE_CHERRYAI_CLIENT_SECRET }}
|
||||||
MAIN_VITE_MINERU_API_KEY: ${{ vars.MAIN_VITE_MINERU_API_KEY }}
|
MAIN_VITE_MINERU_API_KEY: ${{ secrets.MAIN_VITE_MINERU_API_KEY }}
|
||||||
RENDERER_VITE_AIHUBMIX_SECRET: ${{ vars.RENDERER_VITE_AIHUBMIX_SECRET }}
|
RENDERER_VITE_AIHUBMIX_SECRET: ${{ secrets.RENDERER_VITE_AIHUBMIX_SECRET }}
|
||||||
RENDERER_VITE_PPIO_APP_SECRET: ${{ vars.RENDERER_VITE_PPIO_APP_SECRET }}
|
RENDERER_VITE_PPIO_APP_SECRET: ${{ secrets.RENDERER_VITE_PPIO_APP_SECRET }}
|
||||||
|
|
||||||
- name: Build Windows
|
- name: Build Windows
|
||||||
if: matrix.os == 'windows-latest'
|
if: matrix.os == 'windows-latest'
|
||||||
@@ -127,10 +127,10 @@ jobs:
|
|||||||
env:
|
env:
|
||||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
NODE_OPTIONS: --max-old-space-size=8192
|
NODE_OPTIONS: --max-old-space-size=8192
|
||||||
MAIN_VITE_CHERRYIN_CLIENT_SECRET: ${{ secrets.MAIN_VITE_CHERRYIN_CLIENT_SECRET }}
|
MAIN_VITE_CHERRYAI_CLIENT_SECRET: ${{ secrets.MAIN_VITE_CHERRYAI_CLIENT_SECRET }}
|
||||||
MAIN_VITE_MINERU_API_KEY: ${{ vars.MAIN_VITE_MINERU_API_KEY }}
|
MAIN_VITE_MINERU_API_KEY: ${{ secrets.MAIN_VITE_MINERU_API_KEY }}
|
||||||
RENDERER_VITE_AIHUBMIX_SECRET: ${{ vars.RENDERER_VITE_AIHUBMIX_SECRET }}
|
RENDERER_VITE_AIHUBMIX_SECRET: ${{ secrets.RENDERER_VITE_AIHUBMIX_SECRET }}
|
||||||
RENDERER_VITE_PPIO_APP_SECRET: ${{ vars.RENDERER_VITE_PPIO_APP_SECRET }}
|
RENDERER_VITE_PPIO_APP_SECRET: ${{ secrets.RENDERER_VITE_PPIO_APP_SECRET }}
|
||||||
|
|
||||||
- name: Rename artifacts with nightly format
|
- name: Rename artifacts with nightly format
|
||||||
shell: bash
|
shell: bash
|
||||||
|
|||||||
@@ -10,19 +10,21 @@ on:
|
|||||||
- main
|
- main
|
||||||
- develop
|
- develop
|
||||||
- v2
|
- v2
|
||||||
|
types: [ready_for_review, synchronize, opened]
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build:
|
build:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
env:
|
env:
|
||||||
PRCI: true
|
PRCI: true
|
||||||
|
if: github.event.pull_request.draft == false
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Check out Git repository
|
- name: Check out Git repository
|
||||||
uses: actions/checkout@v5
|
uses: actions/checkout@v5
|
||||||
|
|
||||||
- name: Install Node.js
|
- name: Install Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v5
|
||||||
with:
|
with:
|
||||||
node-version: 20
|
node-version: 20
|
||||||
|
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ jobs:
|
|||||||
npm version "$VERSION" --no-git-tag-version --allow-same-version
|
npm version "$VERSION" --no-git-tag-version --allow-same-version
|
||||||
|
|
||||||
- name: Install Node.js
|
- name: Install Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v5
|
||||||
with:
|
with:
|
||||||
node-version: 20
|
node-version: 20
|
||||||
|
|
||||||
@@ -85,10 +85,10 @@ jobs:
|
|||||||
env:
|
env:
|
||||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
NODE_OPTIONS: --max-old-space-size=8192
|
NODE_OPTIONS: --max-old-space-size=8192
|
||||||
MAIN_VITE_CHERRYIN_CLIENT_SECRET: ${{ secrets.MAIN_VITE_CHERRYIN_CLIENT_SECRET }}
|
MAIN_VITE_CHERRYAI_CLIENT_SECRET: ${{ secrets.MAIN_VITE_CHERRYAI_CLIENT_SECRET }}
|
||||||
MAIN_VITE_MINERU_API_KEY: ${{ vars.MAIN_VITE_MINERU_API_KEY }}
|
MAIN_VITE_MINERU_API_KEY: ${{ secrets.MAIN_VITE_MINERU_API_KEY }}
|
||||||
RENDERER_VITE_AIHUBMIX_SECRET: ${{ vars.RENDERER_VITE_AIHUBMIX_SECRET }}
|
RENDERER_VITE_AIHUBMIX_SECRET: ${{ secrets.RENDERER_VITE_AIHUBMIX_SECRET }}
|
||||||
RENDERER_VITE_PPIO_APP_SECRET: ${{ vars.RENDERER_VITE_PPIO_APP_SECRET }}
|
RENDERER_VITE_PPIO_APP_SECRET: ${{ secrets.RENDERER_VITE_PPIO_APP_SECRET }}
|
||||||
|
|
||||||
- name: Build Mac
|
- name: Build Mac
|
||||||
if: matrix.os == 'macos-latest'
|
if: matrix.os == 'macos-latest'
|
||||||
@@ -98,15 +98,15 @@ jobs:
|
|||||||
env:
|
env:
|
||||||
CSC_LINK: ${{ secrets.CSC_LINK }}
|
CSC_LINK: ${{ secrets.CSC_LINK }}
|
||||||
CSC_KEY_PASSWORD: ${{ secrets.CSC_KEY_PASSWORD }}
|
CSC_KEY_PASSWORD: ${{ secrets.CSC_KEY_PASSWORD }}
|
||||||
APPLE_ID: ${{ vars.APPLE_ID }}
|
APPLE_ID: ${{ secrets.APPLE_ID }}
|
||||||
APPLE_APP_SPECIFIC_PASSWORD: ${{ vars.APPLE_APP_SPECIFIC_PASSWORD }}
|
APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }}
|
||||||
APPLE_TEAM_ID: ${{ vars.APPLE_TEAM_ID }}
|
APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
|
||||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
NODE_OPTIONS: --max-old-space-size=8192
|
NODE_OPTIONS: --max-old-space-size=8192
|
||||||
MAIN_VITE_CHERRYIN_CLIENT_SECRET: ${{ secrets.MAIN_VITE_CHERRYIN_CLIENT_SECRET }}
|
MAIN_VITE_CHERRYAI_CLIENT_SECRET: ${{ secrets.MAIN_VITE_CHERRYAI_CLIENT_SECRET }}
|
||||||
MAIN_VITE_MINERU_API_KEY: ${{ vars.MAIN_VITE_MINERU_API_KEY }}
|
MAIN_VITE_MINERU_API_KEY: ${{ secrets.MAIN_VITE_MINERU_API_KEY }}
|
||||||
RENDERER_VITE_AIHUBMIX_SECRET: ${{ vars.RENDERER_VITE_AIHUBMIX_SECRET }}
|
RENDERER_VITE_AIHUBMIX_SECRET: ${{ secrets.RENDERER_VITE_AIHUBMIX_SECRET }}
|
||||||
RENDERER_VITE_PPIO_APP_SECRET: ${{ vars.RENDERER_VITE_PPIO_APP_SECRET }}
|
RENDERER_VITE_PPIO_APP_SECRET: ${{ secrets.RENDERER_VITE_PPIO_APP_SECRET }}
|
||||||
|
|
||||||
- name: Build Windows
|
- name: Build Windows
|
||||||
if: matrix.os == 'windows-latest'
|
if: matrix.os == 'windows-latest'
|
||||||
@@ -115,10 +115,10 @@ jobs:
|
|||||||
env:
|
env:
|
||||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
NODE_OPTIONS: --max-old-space-size=8192
|
NODE_OPTIONS: --max-old-space-size=8192
|
||||||
MAIN_VITE_CHERRYIN_CLIENT_SECRET: ${{ secrets.MAIN_VITE_CHERRYIN_CLIENT_SECRET }}
|
MAIN_VITE_CHERRYAI_CLIENT_SECRET: ${{ secrets.MAIN_VITE_CHERRYAI_CLIENT_SECRET }}
|
||||||
MAIN_VITE_MINERU_API_KEY: ${{ vars.MAIN_VITE_MINERU_API_KEY }}
|
MAIN_VITE_MINERU_API_KEY: ${{ secrets.MAIN_VITE_MINERU_API_KEY }}
|
||||||
RENDERER_VITE_AIHUBMIX_SECRET: ${{ vars.RENDERER_VITE_AIHUBMIX_SECRET }}
|
RENDERER_VITE_AIHUBMIX_SECRET: ${{ secrets.RENDERER_VITE_AIHUBMIX_SECRET }}
|
||||||
RENDERER_VITE_PPIO_APP_SECRET: ${{ vars.RENDERER_VITE_PPIO_APP_SECRET }}
|
RENDERER_VITE_PPIO_APP_SECRET: ${{ secrets.RENDERER_VITE_PPIO_APP_SECRET }}
|
||||||
|
|
||||||
- name: Release
|
- name: Release
|
||||||
uses: ncipollo/release-action@v1
|
uses: ncipollo/release-action@v1
|
||||||
|
|||||||
@@ -54,6 +54,8 @@ local
|
|||||||
.qwen/*
|
.qwen/*
|
||||||
.trae/*
|
.trae/*
|
||||||
.claude-code-router/*
|
.claude-code-router/*
|
||||||
|
.codebuddy/*
|
||||||
|
.zed/*
|
||||||
CLAUDE.local.md
|
CLAUDE.local.md
|
||||||
|
|
||||||
# vitest
|
# vitest
|
||||||
|
|||||||
+2
-2
@@ -15,7 +15,7 @@
|
|||||||
".gitignore",
|
".gitignore",
|
||||||
"scripts/cloudflare-worker.js",
|
"scripts/cloudflare-worker.js",
|
||||||
"src/main/integration/nutstore/sso/lib/**",
|
"src/main/integration/nutstore/sso/lib/**",
|
||||||
"src/main/integration/cherryin/index.js",
|
"src/main/integration/cherryai/index.js",
|
||||||
"src/main/integration/nutstore/sso/lib/**",
|
"src/main/integration/nutstore/sso/lib/**",
|
||||||
"src/renderer/src/ui/**",
|
"src/renderer/src/ui/**",
|
||||||
"packages/**/dist",
|
"packages/**/dist",
|
||||||
@@ -117,7 +117,7 @@
|
|||||||
"no-unused-expressions": "off", // this rule disallow us to use expression to call function, like `condition && fn()`
|
"no-unused-expressions": "off", // this rule disallow us to use expression to call function, like `condition && fn()`
|
||||||
"no-unused-labels": "error",
|
"no-unused-labels": "error",
|
||||||
"no-unused-private-class-members": "error",
|
"no-unused-private-class-members": "error",
|
||||||
"no-unused-vars": ["error", { "caughtErrors": "none" }],
|
"no-unused-vars": ["warn", { "caughtErrors": "none" }],
|
||||||
"no-useless-backreference": "error",
|
"no-useless-backreference": "error",
|
||||||
"no-useless-catch": "error",
|
"no-useless-catch": "error",
|
||||||
"no-useless-escape": "error",
|
"no-useless-escape": "error",
|
||||||
|
|||||||
Vendored
+5
-1
@@ -47,5 +47,9 @@
|
|||||||
"search.exclude": {
|
"search.exclude": {
|
||||||
"**/dist/**": true,
|
"**/dist/**": true,
|
||||||
".yarn/releases/**": true
|
".yarn/releases/**": true
|
||||||
}
|
},
|
||||||
|
"tailwindCSS.classAttributes": [
|
||||||
|
"className",
|
||||||
|
"classNames",
|
||||||
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,13 @@
|
|||||||
|
diff --git a/dist/index.mjs b/dist/index.mjs
|
||||||
|
index 69ab1599c76801dc1167551b6fa283dded123466..f0af43bba7ad1196fe05338817e65b4ebda40955 100644
|
||||||
|
--- a/dist/index.mjs
|
||||||
|
+++ b/dist/index.mjs
|
||||||
|
@@ -477,7 +477,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
|
||||||
|
|
||||||
|
// src/get-model-path.ts
|
||||||
|
function getModelPath(modelId) {
|
||||||
|
- return modelId.includes("/") ? modelId : `models/${modelId}`;
|
||||||
|
+ return modelId?.includes("models/") ? modelId : `models/${modelId}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
// src/google-generative-ai-options.ts
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
diff --git a/sdk.mjs b/sdk.mjs
|
||||||
|
index 461e9a2ba246778261108a682762ffcf26f7224e..44bd667d9f591969d36a105ba5eb8b478c738dd8 100644
|
||||||
|
--- a/sdk.mjs
|
||||||
|
+++ b/sdk.mjs
|
||||||
|
@@ -6215,7 +6215,7 @@ function createAbortController(maxListeners = DEFAULT_MAX_LISTENERS) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// ../src/transport/ProcessTransport.ts
|
||||||
|
-import { spawn } from "child_process";
|
||||||
|
+import { fork } from "child_process";
|
||||||
|
import { createInterface } from "readline";
|
||||||
|
|
||||||
|
// ../src/utils/fsOperations.ts
|
||||||
|
@@ -6473,14 +6473,11 @@ class ProcessTransport {
|
||||||
|
const errorMessage = isNativeBinary(pathToClaudeCodeExecutable) ? `Claude Code native binary not found at ${pathToClaudeCodeExecutable}. Please ensure Claude Code is installed via native installer or specify a valid path with options.pathToClaudeCodeExecutable.` : `Claude Code executable not found at ${pathToClaudeCodeExecutable}. Is options.pathToClaudeCodeExecutable set?`;
|
||||||
|
throw new ReferenceError(errorMessage);
|
||||||
|
}
|
||||||
|
- const isNative = isNativeBinary(pathToClaudeCodeExecutable);
|
||||||
|
- const spawnCommand = isNative ? pathToClaudeCodeExecutable : executable;
|
||||||
|
- const spawnArgs = isNative ? args : [...executableArgs, pathToClaudeCodeExecutable, ...args];
|
||||||
|
- this.logForDebugging(isNative ? `Spawning Claude Code native binary: ${pathToClaudeCodeExecutable} ${args.join(" ")}` : `Spawning Claude Code process: ${executable} ${[...executableArgs, pathToClaudeCodeExecutable, ...args].join(" ")}`);
|
||||||
|
+ this.logForDebugging(`Forking Claude Code Node.js process: ${pathToClaudeCodeExecutable} ${args.join(" ")}`);
|
||||||
|
const stderrMode = env.DEBUG || stderr ? "pipe" : "ignore";
|
||||||
|
- this.child = spawn(spawnCommand, spawnArgs, {
|
||||||
|
+ this.child = fork(pathToClaudeCodeExecutable, args, {
|
||||||
|
cwd,
|
||||||
|
- stdio: ["pipe", "pipe", stderrMode],
|
||||||
|
+ stdio: stderrMode === "pipe" ? ["pipe", "pipe", "pipe", "ipc"] : ["pipe", "pipe", "ignore", "ipc"],
|
||||||
|
signal: this.abortController.signal,
|
||||||
|
env
|
||||||
|
});
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
diff --git a/sdk.mjs b/sdk.mjs
|
|
||||||
index e2dbafb4e2faa1bf2b6b02f0009a2b9bbf57c757..3f07a1d5c2949a246fe5414e69ab45942fa605a2 100644
|
|
||||||
--- a/sdk.mjs
|
|
||||||
+++ b/sdk.mjs
|
|
||||||
@@ -6355,11 +6355,11 @@ class ProcessTransport {
|
|
||||||
prompt,
|
|
||||||
additionalDirectories = [],
|
|
||||||
cwd,
|
|
||||||
- executable = isRunningWithBun() ? "bun" : "node",
|
|
||||||
+ executable = process.execPath,
|
|
||||||
executableArgs = [],
|
|
||||||
extraArgs = {},
|
|
||||||
pathToClaudeCodeExecutable,
|
|
||||||
- env = { ...process.env },
|
|
||||||
+ env = { ...process.env, ELECTRON_RUN_AS_NODE: '1' },
|
|
||||||
stderr,
|
|
||||||
customSystemPrompt,
|
|
||||||
appendSystemPrompt,
|
|
||||||
+44
@@ -0,0 +1,44 @@
|
|||||||
|
diff --git a/dist/index.js b/dist/index.js
|
||||||
|
index 53f411e55a4c9a06fd29bb4ab8161c4ad15980cd..71b91f196c8b886ed90dd237dec5625d79d5677e 100644
|
||||||
|
--- a/dist/index.js
|
||||||
|
+++ b/dist/index.js
|
||||||
|
@@ -12676,10 +12676,13 @@ var OpenAIResponsesLanguageModel = class {
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} else if (value.item.type === "message") {
|
||||||
|
- controller.enqueue({
|
||||||
|
- type: "text-end",
|
||||||
|
- id: value.item.id
|
||||||
|
- });
|
||||||
|
+ // Fix for gpt-5-codex: use currentTextId to ensure text-end matches text-start
|
||||||
|
+ if (currentTextId) {
|
||||||
|
+ controller.enqueue({
|
||||||
|
+ type: "text-end",
|
||||||
|
+ id: currentTextId
|
||||||
|
+ });
|
||||||
|
+ }
|
||||||
|
currentTextId = null;
|
||||||
|
} else if (isResponseOutputItemDoneReasoningChunk(value)) {
|
||||||
|
const activeReasoningPart = activeReasoning[value.item.id];
|
||||||
|
diff --git a/dist/index.mjs b/dist/index.mjs
|
||||||
|
index 7719264da3c49a66c2626082f6ccaae6e3ef5e89..090fd8cf142674192a826148428ed6a0c4a54e35 100644
|
||||||
|
--- a/dist/index.mjs
|
||||||
|
+++ b/dist/index.mjs
|
||||||
|
@@ -12670,10 +12670,13 @@ var OpenAIResponsesLanguageModel = class {
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} else if (value.item.type === "message") {
|
||||||
|
- controller.enqueue({
|
||||||
|
- type: "text-end",
|
||||||
|
- id: value.item.id
|
||||||
|
- });
|
||||||
|
+ // Fix for gpt-5-codex: use currentTextId to ensure text-end matches text-start
|
||||||
|
+ if (currentTextId) {
|
||||||
|
+ controller.enqueue({
|
||||||
|
+ type: "text-end",
|
||||||
|
+ id: currentTextId
|
||||||
|
+ });
|
||||||
|
+ }
|
||||||
|
currentTextId = null;
|
||||||
|
} else if (isResponseOutputItemDoneReasoningChunk(value)) {
|
||||||
|
const activeReasoningPart = activeReasoning[value.item.id];
|
||||||
@@ -2,32 +2,31 @@
|
|||||||
|
|
||||||
This file provides guidance to AI coding assistants when working with code in this repository. Adherence to these guidelines is crucial for maintaining code quality and consistency.
|
This file provides guidance to AI coding assistants when working with code in this repository. Adherence to these guidelines is crucial for maintaining code quality and consistency.
|
||||||
|
|
||||||
## Guiding Principles
|
## Guiding Principles (MUST FOLLOW)
|
||||||
|
|
||||||
- **Clarity and Simplicity**: Write code that is easy to understand and maintain.
|
- **Keep it clear**: Write code that is easy to read, maintain, and explain.
|
||||||
- **Consistency**: Follow existing patterns and conventions in the codebase.
|
- **Match the house style**: Reuse existing patterns, naming, and conventions.
|
||||||
- **Correctness**: Ensure code is correct, well-tested, and robust.
|
- **Search smart**: Prefer `ast-grep` for semantic queries; fall back to `rg`/`grep` when needed.
|
||||||
- **Efficiency**: Write performant code and use resources judiciously.
|
- **Build with HeroUI**: Use HeroUI for every new UI component; never add `antd` or `styled-components`.
|
||||||
|
- **Log centrally**: Route all logging through `loggerService` with the right context—no `console.log`.
|
||||||
## MUST Follow Rules
|
- **Research via subagent**: Lean on `subagent` for external docs, APIs, news, and references.
|
||||||
|
- **Seek review**: Ask a human developer to review substantial changes before merging.
|
||||||
1. **Code Search**: Use `ast-grep` for semantic code pattern searches when available. Fallback to `rg` (ripgrep) or `grep` for text-based searches.
|
- **Commit in rhythm**: Keep commits small, conventional, and emoji-tagged.
|
||||||
2. **UI Framework**: Exclusively use **HeroUI** for all new UI components. The use of `antd` or `styled-components` is strictly **PROHIBITED**.
|
|
||||||
3. **Quality Assurance**: **Always** run `yarn build:check` before finalizing your work or making any commits. This ensures code quality (linting, testing, and type checking).
|
|
||||||
4. **Centralized Logging**: Use the `loggerService` exclusively for all application logging (info, warn, error levels) with proper context. Do not use `console.log`.
|
|
||||||
5. **External Research**: Leverage `subagent` for gathering external information, including latest documentation, API references, news, or web-based research. This keeps the main conversation focused on the task at hand.
|
|
||||||
6. **Code Reviews**: Always seek a code review from a human developer before merging significant changes. This ensures adherence to project standards and catches potential issues.
|
|
||||||
7. **Documentation**: Update or create documentation for any new features, modules, or significant changes to existing functionality.
|
|
||||||
|
|
||||||
## Development Commands
|
## Development Commands
|
||||||
|
|
||||||
- **Install**: `yarn install`
|
- **Install**: `yarn install` - Install all project dependencies
|
||||||
- **Development**: `yarn dev` - Runs Electron app in development mode
|
- **Development**: `yarn dev` - Runs Electron app in development mode with hot reload
|
||||||
- **Debug**: `yarn debug` - Starts with debugging enabled, use chrome://inspect
|
- **Debug**: `yarn debug` - Starts with debugging enabled, use `chrome://inspect` to attach debugger
|
||||||
- **Build Check**: `yarn build:check` - REQUIRED before commits (lint + test + typecheck)
|
- **Build Check**: `yarn build:check` - **REQUIRED** before commits (lint + test + typecheck)
|
||||||
- **Test**: `yarn test` - Run all tests (Vitest)
|
- If having i18n sort issues, run `yarn sync:i18n` first to sync template
|
||||||
- **Single Test**: `yarn test:main` or `yarn test:renderer`
|
- If having formatting issues, run `yarn format` first
|
||||||
- **Lint**: `yarn lint` - Fix linting issues and run typecheck
|
- **Test**: `yarn test` - Run all tests (Vitest) across main and renderer processes
|
||||||
|
- **Single Test**:
|
||||||
|
- `yarn test:main` - Run tests for main process only
|
||||||
|
- `yarn test:renderer` - Run tests for renderer process only
|
||||||
|
- **Lint**: `yarn lint` - Fix linting issues and run TypeScript type checking
|
||||||
|
- **Format**: `yarn format` - Auto-format code using Biome
|
||||||
|
|
||||||
## Project Architecture
|
## Project Architecture
|
||||||
|
|
||||||
|
|||||||
+56
-11
@@ -125,16 +125,61 @@ afterSign: scripts/notarize.js
|
|||||||
artifactBuildCompleted: scripts/artifact-build-completed.js
|
artifactBuildCompleted: scripts/artifact-build-completed.js
|
||||||
releaseInfo:
|
releaseInfo:
|
||||||
releaseNotes: |
|
releaseNotes: |
|
||||||
🐛 问题修复:
|
<!--LANG:en-->
|
||||||
- 修复 Anthropic API URL 处理,移除尾部斜杠并添加端点路径处理
|
What's New in v1.7.0-beta.2
|
||||||
- 修复 MessageEditor 缺少 QuickPanelProvider 包装的问题
|
|
||||||
- 修复 MiniWindow 高度问题
|
|
||||||
|
|
||||||
🚀 性能优化:
|
New Features:
|
||||||
- 优化输入栏提及模型状态缓存,在渲染间保持状态
|
- Session Settings: Manage session-specific settings and model configurations independently
|
||||||
- 重构网络搜索参数支持模型内置搜索,新增 OpenAI Chat 和 OpenRouter 支持
|
- Notes Full-Text Search: Search across all notes with match highlighting
|
||||||
|
- Built-in DiDi MCP Server: Integration with DiDi ride-hailing services (China only)
|
||||||
|
- Intel OV OCR: Hardware-accelerated OCR using Intel NPU
|
||||||
|
- Auto-start API Server: Automatically starts when agents exist
|
||||||
|
|
||||||
🔧 重构改进:
|
Improvements:
|
||||||
- 更新 HeroUIProvider 导入路径,改善上下文管理
|
- Agent model selection now requires explicit user choice
|
||||||
- 更新依赖项和 VSCode 开发环境配置
|
- Added Mistral AI provider support
|
||||||
- 升级 @cherrystudio/ai-core 到 v1.0.0-alpha.17
|
- Added NewAPI generic provider support
|
||||||
|
- Improved navbar layout consistency across different modes
|
||||||
|
- Enhanced chat component responsiveness
|
||||||
|
- Better code block display on small screens
|
||||||
|
- Updated OVMS to 2025.3 official release
|
||||||
|
- Added Greek language support
|
||||||
|
|
||||||
|
Bug Fixes:
|
||||||
|
- Fixed GitHub Copilot gpt-5-codex streaming issues
|
||||||
|
- Fixed assistant creation failures
|
||||||
|
- Fixed translate auto-copy functionality
|
||||||
|
- Fixed miniapps external link opening
|
||||||
|
- Fixed message layout and overflow issues
|
||||||
|
- Fixed API key parsing to preserve spaces
|
||||||
|
- Fixed agent display in different navbar layouts
|
||||||
|
|
||||||
|
<!--LANG:zh-CN-->
|
||||||
|
v1.7.0-beta.2 新特性
|
||||||
|
|
||||||
|
新功能:
|
||||||
|
- 会话设置:独立管理会话特定的设置和模型配置
|
||||||
|
- 笔记全文搜索:跨所有笔记搜索并高亮匹配内容
|
||||||
|
- 内置滴滴 MCP 服务器:集成滴滴打车服务(仅限中国地区)
|
||||||
|
- Intel OV OCR:使用 Intel NPU 的硬件加速 OCR
|
||||||
|
- 自动启动 API 服务器:当存在 Agent 时自动启动
|
||||||
|
|
||||||
|
改进:
|
||||||
|
- Agent 模型选择现在需要用户显式选择
|
||||||
|
- 添加 Mistral AI 提供商支持
|
||||||
|
- 添加 NewAPI 通用提供商支持
|
||||||
|
- 改进不同模式下的导航栏布局一致性
|
||||||
|
- 增强聊天组件响应式设计
|
||||||
|
- 优化小屏幕代码块显示
|
||||||
|
- 更新 OVMS 至 2025.3 正式版
|
||||||
|
- 添加希腊语支持
|
||||||
|
|
||||||
|
问题修复:
|
||||||
|
- 修复 GitHub Copilot gpt-5-codex 流式传输问题
|
||||||
|
- 修复助手创建失败
|
||||||
|
- 修复翻译自动复制功能
|
||||||
|
- 修复小程序外部链接打开
|
||||||
|
- 修复消息布局和溢出问题
|
||||||
|
- 修复 API 密钥解析以保留空格
|
||||||
|
- 修复不同导航栏布局中的 Agent 显示
|
||||||
|
<!--LANG:END-->
|
||||||
|
|||||||
@@ -34,6 +34,10 @@ export default defineConfig({
|
|||||||
output: {
|
output: {
|
||||||
manualChunks: undefined, // 彻底禁用代码分割 - 返回 null 强制单文件打包
|
manualChunks: undefined, // 彻底禁用代码分割 - 返回 null 强制单文件打包
|
||||||
inlineDynamicImports: true // 内联所有动态导入,这是关键配置
|
inlineDynamicImports: true // 内联所有动态导入,这是关键配置
|
||||||
|
},
|
||||||
|
onwarn(warning, warn) {
|
||||||
|
if (warning.code === 'COMMONJS_VARIABLE_IN_ESM') return
|
||||||
|
warn(warning)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
sourcemap: isDev
|
sourcemap: isDev
|
||||||
@@ -112,6 +116,10 @@ export default defineConfig({
|
|||||||
selectionToolbar: resolve(__dirname, 'src/renderer/selectionToolbar.html'),
|
selectionToolbar: resolve(__dirname, 'src/renderer/selectionToolbar.html'),
|
||||||
selectionAction: resolve(__dirname, 'src/renderer/selectionAction.html'),
|
selectionAction: resolve(__dirname, 'src/renderer/selectionAction.html'),
|
||||||
traceWindow: resolve(__dirname, 'src/renderer/traceWindow.html')
|
traceWindow: resolve(__dirname, 'src/renderer/traceWindow.html')
|
||||||
|
},
|
||||||
|
onwarn(warning, warn) {
|
||||||
|
if (warning.code === 'COMMONJS_VARIABLE_IN_ESM') return
|
||||||
|
warn(warning)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
+26
-23
@@ -2,6 +2,7 @@ import tseslint from '@electron-toolkit/eslint-config-ts'
|
|||||||
import eslint from '@eslint/js'
|
import eslint from '@eslint/js'
|
||||||
import eslintReact from '@eslint-react/eslint-plugin'
|
import eslintReact from '@eslint-react/eslint-plugin'
|
||||||
import { defineConfig } from 'eslint/config'
|
import { defineConfig } from 'eslint/config'
|
||||||
|
import importZod from 'eslint-plugin-import-zod'
|
||||||
import oxlint from 'eslint-plugin-oxlint'
|
import oxlint from 'eslint-plugin-oxlint'
|
||||||
import reactHooks from 'eslint-plugin-react-hooks'
|
import reactHooks from 'eslint-plugin-react-hooks'
|
||||||
import simpleImportSort from 'eslint-plugin-simple-import-sort'
|
import simpleImportSort from 'eslint-plugin-simple-import-sort'
|
||||||
@@ -11,11 +12,12 @@ export default defineConfig([
|
|||||||
eslint.configs.recommended,
|
eslint.configs.recommended,
|
||||||
tseslint.configs.recommended,
|
tseslint.configs.recommended,
|
||||||
eslintReact.configs['recommended-typescript'],
|
eslintReact.configs['recommended-typescript'],
|
||||||
reactHooks.configs['recommended-latest'],
|
reactHooks.configs.flat.recommended,
|
||||||
{
|
{
|
||||||
plugins: {
|
plugins: {
|
||||||
'simple-import-sort': simpleImportSort,
|
'simple-import-sort': simpleImportSort,
|
||||||
'unused-imports': unusedImports
|
'unused-imports': unusedImports,
|
||||||
|
'import-zod': importZod
|
||||||
},
|
},
|
||||||
rules: {
|
rules: {
|
||||||
'@typescript-eslint/explicit-function-return-type': 'off',
|
'@typescript-eslint/explicit-function-return-type': 'off',
|
||||||
@@ -25,6 +27,7 @@ export default defineConfig([
|
|||||||
'simple-import-sort/exports': 'error',
|
'simple-import-sort/exports': 'error',
|
||||||
'unused-imports/no-unused-imports': 'error',
|
'unused-imports/no-unused-imports': 'error',
|
||||||
'@eslint-react/no-prop-types': 'error',
|
'@eslint-react/no-prop-types': 'error',
|
||||||
|
'import-zod/prefer-zod-namespace': 'error'
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
// Configuration for ensuring compatibility with the original ESLint(8.x) rules
|
// Configuration for ensuring compatibility with the original ESLint(8.x) rules
|
||||||
@@ -48,6 +51,27 @@ export default defineConfig([
|
|||||||
'@eslint-react/no-children-to-array': 'off'
|
'@eslint-react/no-children-to-array': 'off'
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ignores: [
|
||||||
|
'node_modules/**',
|
||||||
|
'build/**',
|
||||||
|
'dist/**',
|
||||||
|
'out/**',
|
||||||
|
'local/**',
|
||||||
|
'.yarn/**',
|
||||||
|
'.gitignore',
|
||||||
|
'scripts/cloudflare-worker.js',
|
||||||
|
'src/main/integration/nutstore/sso/lib/**',
|
||||||
|
'src/main/integration/cherryai/index.js',
|
||||||
|
'src/main/integration/nutstore/sso/lib/**',
|
||||||
|
'src/renderer/src/ui/**',
|
||||||
|
'packages/**/dist'
|
||||||
|
]
|
||||||
|
},
|
||||||
|
// turn off oxlint supported rules.
|
||||||
|
...oxlint.configs['flat/eslint'],
|
||||||
|
...oxlint.configs['flat/typescript'],
|
||||||
|
...oxlint.configs['flat/unicorn'],
|
||||||
{
|
{
|
||||||
// LoggerService Custom Rules - only apply to src directory
|
// LoggerService Custom Rules - only apply to src directory
|
||||||
files: ['src/**/*.{ts,tsx,js,jsx}'],
|
files: ['src/**/*.{ts,tsx,js,jsx}'],
|
||||||
@@ -110,25 +134,4 @@ export default defineConfig([
|
|||||||
'i18n/no-template-in-t': 'warn'
|
'i18n/no-template-in-t': 'warn'
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
ignores: [
|
|
||||||
'node_modules/**',
|
|
||||||
'build/**',
|
|
||||||
'dist/**',
|
|
||||||
'out/**',
|
|
||||||
'local/**',
|
|
||||||
'.yarn/**',
|
|
||||||
'.gitignore',
|
|
||||||
'scripts/cloudflare-worker.js',
|
|
||||||
'src/main/integration/nutstore/sso/lib/**',
|
|
||||||
'src/main/integration/cherryin/index.js',
|
|
||||||
'src/main/integration/nutstore/sso/lib/**',
|
|
||||||
'src/renderer/src/ui/**',
|
|
||||||
'packages/**/dist'
|
|
||||||
]
|
|
||||||
},
|
|
||||||
// turn off oxlint supported rules.
|
|
||||||
...oxlint.configs['flat/eslint'],
|
|
||||||
...oxlint.configs['flat/typescript'],
|
|
||||||
...oxlint.configs['flat/unicorn']
|
|
||||||
])
|
])
|
||||||
|
|||||||
+22
-18
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "CherryStudio",
|
"name": "CherryStudio",
|
||||||
"version": "1.6.0-rc.2",
|
"version": "1.7.0-beta.2",
|
||||||
"private": true,
|
"private": true,
|
||||||
"description": "A powerful AI assistant for producer.",
|
"description": "A powerful AI assistant for producer.",
|
||||||
"main": "./out/main/index.js",
|
"main": "./out/main/index.js",
|
||||||
@@ -27,7 +27,6 @@
|
|||||||
"scripts": {
|
"scripts": {
|
||||||
"start": "electron-vite preview",
|
"start": "electron-vite preview",
|
||||||
"dev": "dotenv electron-vite dev",
|
"dev": "dotenv electron-vite dev",
|
||||||
"dev:main": "dotenv electron-vite dev --watch",
|
|
||||||
"debug": "electron-vite -- --inspect --sourcemap --remote-debugging-port=9222",
|
"debug": "electron-vite -- --inspect --sourcemap --remote-debugging-port=9222",
|
||||||
"build": "npm run typecheck && electron-vite build",
|
"build": "npm run typecheck && electron-vite build",
|
||||||
"build:check": "yarn lint && yarn test",
|
"build:check": "yarn lint && yarn test",
|
||||||
@@ -69,7 +68,7 @@
|
|||||||
"test:e2e": "yarn playwright test",
|
"test:e2e": "yarn playwright test",
|
||||||
"test:lint": "oxlint --deny-warnings && eslint . --ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --cache",
|
"test:lint": "oxlint --deny-warnings && eslint . --ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --cache",
|
||||||
"test:scripts": "vitest scripts",
|
"test:scripts": "vitest scripts",
|
||||||
"lint": "oxlint --fix && eslint . --ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --cache && yarn typecheck && yarn check:i18n",
|
"lint": "oxlint --fix && eslint . --ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --cache && yarn typecheck && yarn check:i18n && yarn format:check",
|
||||||
"format": "biome format --write && biome lint --write",
|
"format": "biome format --write && biome lint --write",
|
||||||
"format:check": "biome format && biome lint",
|
"format:check": "biome format && biome lint",
|
||||||
"prepare": "git config blame.ignoreRevsFile .git-blame-ignore-revs && husky",
|
"prepare": "git config blame.ignoreRevsFile .git-blame-ignore-revs && husky",
|
||||||
@@ -79,15 +78,12 @@
|
|||||||
"release:aicore": "yarn workspace @cherrystudio/ai-core version patch --immediate && yarn workspace @cherrystudio/ai-core npm publish --access public"
|
"release:aicore": "yarn workspace @cherrystudio/ai-core version patch --immediate && yarn workspace @cherrystudio/ai-core npm publish --access public"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@anthropic-ai/claude-code": "patch:@anthropic-ai/claude-code@npm%3A1.0.118#~/.yarn/patches/@anthropic-ai-claude-code-npm-1.0.118-bbf4e9e59f.patch",
|
"@anthropic-ai/claude-agent-sdk": "patch:@anthropic-ai/claude-agent-sdk@npm%3A0.1.1#~/.yarn/patches/@anthropic-ai-claude-agent-sdk-npm-0.1.1-d937b73fed.patch",
|
||||||
"@libsql/client": "0.14.0",
|
"@libsql/client": "0.14.0",
|
||||||
"@libsql/win32-x64-msvc": "^0.4.7",
|
"@libsql/win32-x64-msvc": "^0.4.7",
|
||||||
"@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch",
|
"@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch",
|
||||||
"@strongtz/win32-arm64-msvc": "^0.4.7",
|
"@strongtz/win32-arm64-msvc": "^0.4.7",
|
||||||
"@types/uuid": "^10.0.0",
|
|
||||||
"drizzle-orm": "^0.44.5",
|
|
||||||
"express": "^5.1.0",
|
"express": "^5.1.0",
|
||||||
"express-validator": "^7.2.1",
|
|
||||||
"font-list": "^2.0.0",
|
"font-list": "^2.0.0",
|
||||||
"graceful-fs": "^4.2.11",
|
"graceful-fs": "^4.2.11",
|
||||||
"jsdom": "26.1.0",
|
"jsdom": "26.1.0",
|
||||||
@@ -105,10 +101,10 @@
|
|||||||
"@agentic/exa": "^7.3.3",
|
"@agentic/exa": "^7.3.3",
|
||||||
"@agentic/searxng": "^7.3.3",
|
"@agentic/searxng": "^7.3.3",
|
||||||
"@agentic/tavily": "^7.3.3",
|
"@agentic/tavily": "^7.3.3",
|
||||||
"@ai-sdk/amazon-bedrock": "^3.0.21",
|
"@ai-sdk/amazon-bedrock": "^3.0.35",
|
||||||
"@ai-sdk/google-vertex": "^3.0.27",
|
"@ai-sdk/google-vertex": "^3.0.40",
|
||||||
"@ai-sdk/mistral": "^2.0.14",
|
"@ai-sdk/mistral": "^2.0.19",
|
||||||
"@ai-sdk/perplexity": "^2.0.9",
|
"@ai-sdk/perplexity": "^2.0.13",
|
||||||
"@ant-design/v5-patch-for-react-19": "^1.0.3",
|
"@ant-design/v5-patch-for-react-19": "^1.0.3",
|
||||||
"@anthropic-ai/sdk": "^0.41.0",
|
"@anthropic-ai/sdk": "^0.41.0",
|
||||||
"@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch",
|
"@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch",
|
||||||
@@ -158,6 +154,7 @@
|
|||||||
"@opentelemetry/sdk-trace-base": "^2.0.0",
|
"@opentelemetry/sdk-trace-base": "^2.0.0",
|
||||||
"@opentelemetry/sdk-trace-node": "^2.0.0",
|
"@opentelemetry/sdk-trace-node": "^2.0.0",
|
||||||
"@opentelemetry/sdk-trace-web": "^2.0.0",
|
"@opentelemetry/sdk-trace-web": "^2.0.0",
|
||||||
|
"@opeoginni/github-copilot-openai-compatible": "patch:@opeoginni/github-copilot-openai-compatible@npm%3A0.1.18#~/.yarn/patches/@opeoginni-github-copilot-openai-compatible-npm-0.1.18-3f65760532.patch",
|
||||||
"@playwright/test": "^1.52.0",
|
"@playwright/test": "^1.52.0",
|
||||||
"@radix-ui/react-context-menu": "^2.2.16",
|
"@radix-ui/react-context-menu": "^2.2.16",
|
||||||
"@reduxjs/toolkit": "^2.2.5",
|
"@reduxjs/toolkit": "^2.2.5",
|
||||||
@@ -211,6 +208,7 @@
|
|||||||
"@types/swagger-ui-express": "^4.1.8",
|
"@types/swagger-ui-express": "^4.1.8",
|
||||||
"@types/tinycolor2": "^1",
|
"@types/tinycolor2": "^1",
|
||||||
"@types/turndown": "^5.0.5",
|
"@types/turndown": "^5.0.5",
|
||||||
|
"@types/uuid": "^10.0.0",
|
||||||
"@types/word-extractor": "^1",
|
"@types/word-extractor": "^1",
|
||||||
"@typescript/native-preview": "latest",
|
"@typescript/native-preview": "latest",
|
||||||
"@uiw/codemirror-extensions-langs": "^4.25.1",
|
"@uiw/codemirror-extensions-langs": "^4.25.1",
|
||||||
@@ -224,7 +222,7 @@
|
|||||||
"@viz-js/lang-dot": "^1.0.5",
|
"@viz-js/lang-dot": "^1.0.5",
|
||||||
"@viz-js/viz": "^3.14.0",
|
"@viz-js/viz": "^3.14.0",
|
||||||
"@xyflow/react": "^12.4.4",
|
"@xyflow/react": "^12.4.4",
|
||||||
"ai": "^5.0.44",
|
"ai": "^5.0.68",
|
||||||
"antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch",
|
"antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch",
|
||||||
"archiver": "^7.0.1",
|
"archiver": "^7.0.1",
|
||||||
"async-mutex": "^0.5.0",
|
"async-mutex": "^0.5.0",
|
||||||
@@ -248,7 +246,8 @@
|
|||||||
"dompurify": "^3.2.6",
|
"dompurify": "^3.2.6",
|
||||||
"dotenv-cli": "^7.4.2",
|
"dotenv-cli": "^7.4.2",
|
||||||
"drizzle-kit": "^0.31.4",
|
"drizzle-kit": "^0.31.4",
|
||||||
"electron": "37.4.0",
|
"drizzle-orm": "^0.44.5",
|
||||||
|
"electron": "37.6.0",
|
||||||
"electron-builder": "26.0.15",
|
"electron-builder": "26.0.15",
|
||||||
"electron-devtools-installer": "^3.2.0",
|
"electron-devtools-installer": "^3.2.0",
|
||||||
"electron-reload": "^2.0.0-alpha.1",
|
"electron-reload": "^2.0.0-alpha.1",
|
||||||
@@ -260,10 +259,12 @@
|
|||||||
"emoji-picker-element": "^1.22.1",
|
"emoji-picker-element": "^1.22.1",
|
||||||
"epub": "patch:epub@npm%3A1.3.0#~/.yarn/patches/epub-npm-1.3.0-8325494ffe.patch",
|
"epub": "patch:epub@npm%3A1.3.0#~/.yarn/patches/epub-npm-1.3.0-8325494ffe.patch",
|
||||||
"eslint": "^9.22.0",
|
"eslint": "^9.22.0",
|
||||||
|
"eslint-plugin-import-zod": "^1.2.0",
|
||||||
"eslint-plugin-oxlint": "^1.15.0",
|
"eslint-plugin-oxlint": "^1.15.0",
|
||||||
"eslint-plugin-react-hooks": "^5.2.0",
|
"eslint-plugin-react-hooks": "^7.0.0",
|
||||||
"eslint-plugin-simple-import-sort": "^12.1.1",
|
"eslint-plugin-simple-import-sort": "^12.1.1",
|
||||||
"eslint-plugin-unused-imports": "^4.1.4",
|
"eslint-plugin-unused-imports": "^4.1.4",
|
||||||
|
"express-validator": "^7.2.1",
|
||||||
"fast-diff": "^1.3.0",
|
"fast-diff": "^1.3.0",
|
||||||
"fast-xml-parser": "^5.2.0",
|
"fast-xml-parser": "^5.2.0",
|
||||||
"fetch-socks": "1.3.2",
|
"fetch-socks": "1.3.2",
|
||||||
@@ -296,15 +297,15 @@
|
|||||||
"notion-helper": "^1.3.22",
|
"notion-helper": "^1.3.22",
|
||||||
"npx-scope-finder": "^1.2.0",
|
"npx-scope-finder": "^1.2.0",
|
||||||
"openai": "patch:openai@npm%3A5.12.2#~/.yarn/patches/openai-npm-5.12.2-30b075401c.patch",
|
"openai": "patch:openai@npm%3A5.12.2#~/.yarn/patches/openai-npm-5.12.2-30b075401c.patch",
|
||||||
"oxlint": "^1.15.0",
|
"oxlint": "^1.22.0",
|
||||||
"oxlint-tsgolint": "^0.2.0",
|
"oxlint-tsgolint": "^0.2.0",
|
||||||
"p-queue": "^8.1.0",
|
"p-queue": "^8.1.0",
|
||||||
"pdf-lib": "^1.17.1",
|
"pdf-lib": "^1.17.1",
|
||||||
"pdf-parse": "^1.1.1",
|
"pdf-parse": "^1.1.1",
|
||||||
"playwright": "^1.52.0",
|
"playwright": "^1.52.0",
|
||||||
"proxy-agent": "^6.5.0",
|
"proxy-agent": "^6.5.0",
|
||||||
"react": "^19.0.0",
|
"react": "^19.2.0",
|
||||||
"react-dom": "^19.0.0",
|
"react-dom": "^19.2.0",
|
||||||
"react-error-boundary": "^6.0.0",
|
"react-error-boundary": "^6.0.0",
|
||||||
"react-hotkeys-hook": "^4.6.1",
|
"react-hotkeys-hook": "^4.6.1",
|
||||||
"react-i18next": "^14.1.2",
|
"react-i18next": "^14.1.2",
|
||||||
@@ -371,6 +372,7 @@
|
|||||||
"app-builder-lib@npm:26.0.13": "patch:app-builder-lib@npm%3A26.0.13#~/.yarn/patches/app-builder-lib-npm-26.0.13-a064c9e1d0.patch",
|
"app-builder-lib@npm:26.0.13": "patch:app-builder-lib@npm%3A26.0.13#~/.yarn/patches/app-builder-lib-npm-26.0.13-a064c9e1d0.patch",
|
||||||
"app-builder-lib@npm:26.0.15": "patch:app-builder-lib@npm%3A26.0.15#~/.yarn/patches/app-builder-lib-npm-26.0.15-360e5b0476.patch",
|
"app-builder-lib@npm:26.0.15": "patch:app-builder-lib@npm%3A26.0.15#~/.yarn/patches/app-builder-lib-npm-26.0.15-360e5b0476.patch",
|
||||||
"atomically@npm:^1.7.0": "patch:atomically@npm%3A1.7.0#~/.yarn/patches/atomically-npm-1.7.0-e742e5293b.patch",
|
"atomically@npm:^1.7.0": "patch:atomically@npm%3A1.7.0#~/.yarn/patches/atomically-npm-1.7.0-e742e5293b.patch",
|
||||||
|
"esbuild": "^0.25.0",
|
||||||
"file-stream-rotator@npm:^0.6.1": "patch:file-stream-rotator@npm%3A0.6.1#~/.yarn/patches/file-stream-rotator-npm-0.6.1-eab45fb13d.patch",
|
"file-stream-rotator@npm:^0.6.1": "patch:file-stream-rotator@npm%3A0.6.1#~/.yarn/patches/file-stream-rotator-npm-0.6.1-eab45fb13d.patch",
|
||||||
"libsql@npm:^0.4.4": "patch:libsql@npm%3A0.4.7#~/.yarn/patches/libsql-npm-0.4.7-444e260fb1.patch",
|
"libsql@npm:^0.4.4": "patch:libsql@npm%3A0.4.7#~/.yarn/patches/libsql-npm-0.4.7-444e260fb1.patch",
|
||||||
"node-abi": "4.12.0",
|
"node-abi": "4.12.0",
|
||||||
@@ -378,9 +380,11 @@
|
|||||||
"openai@npm:^4.87.3": "patch:openai@npm%3A5.12.2#~/.yarn/patches/openai-npm-5.12.2-30b075401c.patch",
|
"openai@npm:^4.87.3": "patch:openai@npm%3A5.12.2#~/.yarn/patches/openai-npm-5.12.2-30b075401c.patch",
|
||||||
"pdf-parse@npm:1.1.1": "patch:pdf-parse@npm%3A1.1.1#~/.yarn/patches/pdf-parse-npm-1.1.1-04a6109b2a.patch",
|
"pdf-parse@npm:1.1.1": "patch:pdf-parse@npm%3A1.1.1#~/.yarn/patches/pdf-parse-npm-1.1.1-04a6109b2a.patch",
|
||||||
"pkce-challenge@npm:^4.1.0": "patch:pkce-challenge@npm%3A4.1.0#~/.yarn/patches/pkce-challenge-npm-4.1.0-fbc51695a3.patch",
|
"pkce-challenge@npm:^4.1.0": "patch:pkce-challenge@npm%3A4.1.0#~/.yarn/patches/pkce-challenge-npm-4.1.0-fbc51695a3.patch",
|
||||||
|
"tar-fs": "^2.1.4",
|
||||||
"undici": "6.21.2",
|
"undici": "6.21.2",
|
||||||
"vite": "npm:rolldown-vite@latest",
|
"vite": "npm:rolldown-vite@latest",
|
||||||
"tesseract.js@npm:*": "patch:tesseract.js@npm%3A6.0.1#~/.yarn/patches/tesseract.js-npm-6.0.1-2562a7e46d.patch"
|
"tesseract.js@npm:*": "patch:tesseract.js@npm%3A6.0.1#~/.yarn/patches/tesseract.js-npm-6.0.1-2562a7e46d.patch",
|
||||||
|
"@ai-sdk/google@npm:2.0.20": "patch:@ai-sdk/google@npm%3A2.0.20#~/.yarn/patches/@ai-sdk-google-npm-2.0.20-b9102f9d54.patch"
|
||||||
},
|
},
|
||||||
"packageManager": "yarn@4.9.1",
|
"packageManager": "yarn@4.9.1",
|
||||||
"lint-staged": {
|
"lint-staged": {
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@cherrystudio/ai-core",
|
"name": "@cherrystudio/ai-core",
|
||||||
"version": "1.0.0-alpha.18",
|
"version": "1.0.1",
|
||||||
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
|
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
|
||||||
"main": "dist/index.js",
|
"main": "dist/index.js",
|
||||||
"module": "dist/index.mjs",
|
"module": "dist/index.mjs",
|
||||||
@@ -36,15 +36,14 @@
|
|||||||
"ai": "^5.0.26"
|
"ai": "^5.0.26"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@ai-sdk/anthropic": "^2.0.17",
|
"@ai-sdk/anthropic": "^2.0.27",
|
||||||
"@ai-sdk/azure": "^2.0.30",
|
"@ai-sdk/azure": "^2.0.49",
|
||||||
"@ai-sdk/deepseek": "^1.0.17",
|
"@ai-sdk/deepseek": "^1.0.23",
|
||||||
"@ai-sdk/google": "^2.0.14",
|
"@ai-sdk/openai": "^2.0.48",
|
||||||
"@ai-sdk/openai": "^2.0.30",
|
"@ai-sdk/openai-compatible": "^1.0.22",
|
||||||
"@ai-sdk/openai-compatible": "^1.0.17",
|
|
||||||
"@ai-sdk/provider": "^2.0.0",
|
"@ai-sdk/provider": "^2.0.0",
|
||||||
"@ai-sdk/provider-utils": "^3.0.9",
|
"@ai-sdk/provider-utils": "^3.0.12",
|
||||||
"@ai-sdk/xai": "^2.0.18",
|
"@ai-sdk/xai": "^2.0.26",
|
||||||
"zod": "^4.1.5"
|
"zod": "^4.1.5"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
|
|||||||
@@ -261,22 +261,39 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
|||||||
return params
|
return params
|
||||||
}
|
}
|
||||||
|
|
||||||
context.mcpTools = params.tools
|
// 分离 provider-defined 和其他类型的工具
|
||||||
|
const providerDefinedTools: ToolSet = {}
|
||||||
|
const promptTools: ToolSet = {}
|
||||||
|
|
||||||
// 构建系统提示符
|
for (const [toolName, tool] of Object.entries(params.tools as ToolSet)) {
|
||||||
|
if (tool.type === 'provider-defined') {
|
||||||
|
// provider-defined 类型的工具保留在 tools 参数中
|
||||||
|
providerDefinedTools[toolName] = tool
|
||||||
|
} else {
|
||||||
|
// 其他工具转换为 prompt 模式
|
||||||
|
promptTools[toolName] = tool
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 只有当有非 provider-defined 工具时才保存到 context
|
||||||
|
if (Object.keys(promptTools).length > 0) {
|
||||||
|
context.mcpTools = promptTools
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建系统提示符(只包含非 provider-defined 工具)
|
||||||
const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
|
const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
|
||||||
const systemPrompt = buildSystemPrompt(userSystemPrompt, params.tools)
|
const systemPrompt = buildSystemPrompt(userSystemPrompt, promptTools)
|
||||||
let systemMessage: string | null = systemPrompt
|
let systemMessage: string | null = systemPrompt
|
||||||
if (config.createSystemMessage) {
|
if (config.createSystemMessage) {
|
||||||
// 🎯 如果用户提供了自定义处理函数,使用它
|
// 🎯 如果用户提供了自定义处理函数,使用它
|
||||||
systemMessage = config.createSystemMessage(systemPrompt, params, context)
|
systemMessage = config.createSystemMessage(systemPrompt, params, context)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 移除 tools,改为 prompt 模式
|
// 保留 provider-defined tools,移除其他 tools
|
||||||
const transformedParams = {
|
const transformedParams = {
|
||||||
...params,
|
...params,
|
||||||
...(systemMessage ? { system: systemMessage } : {}),
|
...(systemMessage ? { system: systemMessage } : {}),
|
||||||
tools: undefined
|
tools: Object.keys(providerDefinedTools).length > 0 ? providerDefinedTools : undefined
|
||||||
}
|
}
|
||||||
context.originalParams = transformedParams
|
context.originalParams = transformedParams
|
||||||
return transformedParams
|
return transformedParams
|
||||||
@@ -285,8 +302,9 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
|||||||
let textBuffer = ''
|
let textBuffer = ''
|
||||||
// let stepId = ''
|
// let stepId = ''
|
||||||
|
|
||||||
|
// 如果没有需要 prompt 模式处理的工具,直接返回原始流
|
||||||
if (!context.mcpTools) {
|
if (!context.mcpTools) {
|
||||||
throw new Error('No tools available')
|
return new TransformStream()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 从 context 中获取或初始化 usage 累加器
|
// 从 context 中获取或初始化 usage 累加器
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import { anthropic } from '@ai-sdk/anthropic'
|
import { anthropic } from '@ai-sdk/anthropic'
|
||||||
import { google } from '@ai-sdk/google'
|
import { google } from '@ai-sdk/google'
|
||||||
import { openai } from '@ai-sdk/openai'
|
import { openai } from '@ai-sdk/openai'
|
||||||
|
import { InferToolInput, InferToolOutput, type Tool } from 'ai'
|
||||||
|
|
||||||
import { ProviderOptionsMap } from '../../../options/types'
|
import { ProviderOptionsMap } from '../../../options/types'
|
||||||
import { OpenRouterSearchConfig } from './openrouter'
|
import { OpenRouterSearchConfig } from './openrouter'
|
||||||
@@ -14,6 +15,13 @@ export type AnthropicSearchConfig = NonNullable<Parameters<typeof anthropic.tool
|
|||||||
export type GoogleSearchConfig = NonNullable<Parameters<typeof google.tools.googleSearch>[0]>
|
export type GoogleSearchConfig = NonNullable<Parameters<typeof google.tools.googleSearch>[0]>
|
||||||
export type XAISearchConfig = NonNullable<ProviderOptionsMap['xai']['searchParameters']>
|
export type XAISearchConfig = NonNullable<ProviderOptionsMap['xai']['searchParameters']>
|
||||||
|
|
||||||
|
type NormalizeTool<T> = T extends Tool<infer INPUT, infer OUTPUT> ? Tool<INPUT, OUTPUT> : Tool<any, any>
|
||||||
|
|
||||||
|
type AnthropicWebSearchTool = NormalizeTool<ReturnType<typeof anthropic.tools.webSearch_20250305>>
|
||||||
|
type OpenAIWebSearchTool = NormalizeTool<ReturnType<typeof openai.tools.webSearch>>
|
||||||
|
type OpenAIChatWebSearchTool = NormalizeTool<ReturnType<typeof openai.tools.webSearchPreview>>
|
||||||
|
type GoogleWebSearchTool = NormalizeTool<ReturnType<typeof google.tools.googleSearch>>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 插件初始化时接收的完整配置对象
|
* 插件初始化时接收的完整配置对象
|
||||||
*
|
*
|
||||||
@@ -58,24 +66,31 @@ export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = {
|
|||||||
|
|
||||||
export type WebSearchToolOutputSchema = {
|
export type WebSearchToolOutputSchema = {
|
||||||
// Anthropic 工具 - 手动定义
|
// Anthropic 工具 - 手动定义
|
||||||
anthropicWebSearch: Array<{
|
anthropic: InferToolOutput<AnthropicWebSearchTool>
|
||||||
url: string
|
|
||||||
title: string
|
|
||||||
pageAge: string | null
|
|
||||||
encryptedContent: string
|
|
||||||
type: string
|
|
||||||
}>
|
|
||||||
|
|
||||||
// OpenAI 工具 - 基于实际输出
|
// OpenAI 工具 - 基于实际输出
|
||||||
openaiWebSearch: {
|
// TODO: 上游定义不规范,是unknown
|
||||||
|
// openai: InferToolOutput<ReturnType<typeof openai.tools.webSearch>>
|
||||||
|
openai: {
|
||||||
|
status: 'completed' | 'failed'
|
||||||
|
}
|
||||||
|
'openai-chat': {
|
||||||
status: 'completed' | 'failed'
|
status: 'completed' | 'failed'
|
||||||
}
|
}
|
||||||
|
|
||||||
// Google 工具
|
// Google 工具
|
||||||
googleSearch: {
|
// TODO: 上游定义不规范,是unknown
|
||||||
|
// google: InferToolOutput<ReturnType<typeof google.tools.googleSearch>>
|
||||||
|
google: {
|
||||||
webSearchQueries?: string[]
|
webSearchQueries?: string[]
|
||||||
groundingChunks?: Array<{
|
groundingChunks?: Array<{
|
||||||
web?: { uri: string; title: string }
|
web?: { uri: string; title: string }
|
||||||
}>
|
}>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export type WebSearchToolInputSchema = {
|
||||||
|
anthropic: InferToolInput<AnthropicWebSearchTool>
|
||||||
|
openai: InferToolInput<OpenAIWebSearchTool>
|
||||||
|
google: InferToolInput<GoogleWebSearchTool>
|
||||||
|
'openai-chat': InferToolInput<OpenAIChatWebSearchTool>
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import { LanguageModelV2 } from '@ai-sdk/provider'
|
|||||||
import { createXai } from '@ai-sdk/xai'
|
import { createXai } from '@ai-sdk/xai'
|
||||||
import { createOpenRouter } from '@openrouter/ai-sdk-provider'
|
import { createOpenRouter } from '@openrouter/ai-sdk-provider'
|
||||||
import { customProvider, Provider } from 'ai'
|
import { customProvider, Provider } from 'ai'
|
||||||
import { z } from 'zod'
|
import * as z from 'zod'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 基础 Provider IDs
|
* 基础 Provider IDs
|
||||||
|
|||||||
@@ -5,9 +5,10 @@ export enum IpcChannel {
|
|||||||
App_SetLanguage = 'app:set-language',
|
App_SetLanguage = 'app:set-language',
|
||||||
App_SetEnableSpellCheck = 'app:set-enable-spell-check',
|
App_SetEnableSpellCheck = 'app:set-enable-spell-check',
|
||||||
App_SetSpellCheckLanguages = 'app:set-spell-check-languages',
|
App_SetSpellCheckLanguages = 'app:set-spell-check-languages',
|
||||||
App_ShowUpdateDialog = 'app:show-update-dialog',
|
|
||||||
App_CheckForUpdate = 'app:check-for-update',
|
App_CheckForUpdate = 'app:check-for-update',
|
||||||
|
App_QuitAndInstall = 'app:quit-and-install',
|
||||||
App_Reload = 'app:reload',
|
App_Reload = 'app:reload',
|
||||||
|
App_Quit = 'app:quit',
|
||||||
App_Info = 'app:info',
|
App_Info = 'app:info',
|
||||||
App_Proxy = 'app:proxy',
|
App_Proxy = 'app:proxy',
|
||||||
App_SetLaunchToTray = 'app:set-launch-to-tray',
|
App_SetLaunchToTray = 'app:set-launch-to-tray',
|
||||||
@@ -33,6 +34,7 @@ export enum IpcChannel {
|
|||||||
App_GetBinaryPath = 'app:get-binary-path',
|
App_GetBinaryPath = 'app:get-binary-path',
|
||||||
App_InstallUvBinary = 'app:install-uv-binary',
|
App_InstallUvBinary = 'app:install-uv-binary',
|
||||||
App_InstallBunBinary = 'app:install-bun-binary',
|
App_InstallBunBinary = 'app:install-bun-binary',
|
||||||
|
App_InstallOvmsBinary = 'app:install-ovms-binary',
|
||||||
App_LogToMain = 'app:log-to-main',
|
App_LogToMain = 'app:log-to-main',
|
||||||
App_SaveData = 'app:save-data',
|
App_SaveData = 'app:save-data',
|
||||||
App_GetDiskInfo = 'app:get-disk-info',
|
App_GetDiskInfo = 'app:get-disk-info',
|
||||||
@@ -51,6 +53,7 @@ export enum IpcChannel {
|
|||||||
|
|
||||||
Webview_SetOpenLinkExternal = 'webview:set-open-link-external',
|
Webview_SetOpenLinkExternal = 'webview:set-open-link-external',
|
||||||
Webview_SetSpellCheckEnabled = 'webview:set-spell-check-enabled',
|
Webview_SetSpellCheckEnabled = 'webview:set-spell-check-enabled',
|
||||||
|
Webview_SearchHotkey = 'webview:search-hotkey',
|
||||||
|
|
||||||
// Open
|
// Open
|
||||||
Open_Path = 'open:path',
|
Open_Path = 'open:path',
|
||||||
@@ -91,6 +94,7 @@ export enum IpcChannel {
|
|||||||
|
|
||||||
// agent messages
|
// agent messages
|
||||||
AgentMessage_PersistExchange = 'agent-message:persist-exchange',
|
AgentMessage_PersistExchange = 'agent-message:persist-exchange',
|
||||||
|
AgentMessage_GetHistory = 'agent-message:get-history',
|
||||||
|
|
||||||
//copilot
|
//copilot
|
||||||
Copilot_GetAuthMessage = 'copilot:get-auth-message',
|
Copilot_GetAuthMessage = 'copilot:get-auth-message',
|
||||||
@@ -185,6 +189,7 @@ export enum IpcChannel {
|
|||||||
File_ValidateNotesDirectory = 'file:validateNotesDirectory',
|
File_ValidateNotesDirectory = 'file:validateNotesDirectory',
|
||||||
File_StartWatcher = 'file:startWatcher',
|
File_StartWatcher = 'file:startWatcher',
|
||||||
File_StopWatcher = 'file:stopWatcher',
|
File_StopWatcher = 'file:stopWatcher',
|
||||||
|
File_ShowInFolder = 'file:showInFolder',
|
||||||
|
|
||||||
// file service
|
// file service
|
||||||
FileService_Upload = 'file-service:upload',
|
FileService_Upload = 'file-service:upload',
|
||||||
@@ -222,6 +227,7 @@ export enum IpcChannel {
|
|||||||
// system
|
// system
|
||||||
System_GetDeviceType = 'system:getDeviceType',
|
System_GetDeviceType = 'system:getDeviceType',
|
||||||
System_GetHostname = 'system:getHostname',
|
System_GetHostname = 'system:getHostname',
|
||||||
|
System_GetCpuName = 'system:getCpuName',
|
||||||
|
|
||||||
// DevTools
|
// DevTools
|
||||||
System_ToggleDevTools = 'system:toggleDevTools',
|
System_ToggleDevTools = 'system:toggleDevTools',
|
||||||
@@ -229,7 +235,6 @@ export enum IpcChannel {
|
|||||||
// events
|
// events
|
||||||
BackupProgress = 'backup-progress',
|
BackupProgress = 'backup-progress',
|
||||||
ThemeUpdated = 'theme:updated',
|
ThemeUpdated = 'theme:updated',
|
||||||
UpdateDownloadedCancelled = 'update-downloaded-cancelled',
|
|
||||||
RestoreProgress = 'restore-progress',
|
RestoreProgress = 'restore-progress',
|
||||||
UpdateError = 'update-error',
|
UpdateError = 'update-error',
|
||||||
UpdateAvailable = 'update-available',
|
UpdateAvailable = 'update-available',
|
||||||
@@ -312,6 +317,7 @@ export enum IpcChannel {
|
|||||||
ApiServer_Stop = 'api-server:stop',
|
ApiServer_Stop = 'api-server:stop',
|
||||||
ApiServer_Restart = 'api-server:restart',
|
ApiServer_Restart = 'api-server:restart',
|
||||||
ApiServer_GetStatus = 'api-server:get-status',
|
ApiServer_GetStatus = 'api-server:get-status',
|
||||||
|
// NOTE: This api is not be used.
|
||||||
ApiServer_GetConfig = 'api-server:get-config',
|
ApiServer_GetConfig = 'api-server:get-config',
|
||||||
|
|
||||||
// Anthropic OAuth
|
// Anthropic OAuth
|
||||||
@@ -324,10 +330,24 @@ export enum IpcChannel {
|
|||||||
|
|
||||||
// CodeTools
|
// CodeTools
|
||||||
CodeTools_Run = 'code-tools:run',
|
CodeTools_Run = 'code-tools:run',
|
||||||
|
CodeTools_GetAvailableTerminals = 'code-tools:get-available-terminals',
|
||||||
|
CodeTools_SetCustomTerminalPath = 'code-tools:set-custom-terminal-path',
|
||||||
|
CodeTools_GetCustomTerminalPath = 'code-tools:get-custom-terminal-path',
|
||||||
|
CodeTools_RemoveCustomTerminalPath = 'code-tools:remove-custom-terminal-path',
|
||||||
|
|
||||||
// OCR
|
// OCR
|
||||||
OCR_ocr = 'ocr:ocr',
|
OCR_ocr = 'ocr:ocr',
|
||||||
|
OCR_ListProviders = 'ocr:list-providers',
|
||||||
|
|
||||||
// Cherryin
|
// OVMS
|
||||||
Cherryin_GetSignature = 'cherryin:get-signature'
|
Ovms_AddModel = 'ovms:add-model',
|
||||||
|
Ovms_StopAddModel = 'ovms:stop-addmodel',
|
||||||
|
Ovms_GetModels = 'ovms:get-models',
|
||||||
|
Ovms_IsRunning = 'ovms:is-running',
|
||||||
|
Ovms_GetStatus = 'ovms:get-status',
|
||||||
|
Ovms_RunOVMS = 'ovms:run-ovms',
|
||||||
|
Ovms_StopOVMS = 'ovms:stop-ovms',
|
||||||
|
|
||||||
|
// CherryAI
|
||||||
|
Cherryai_GetSignature = 'cherryai:get-signature'
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,12 @@
|
|||||||
|
import type { SDKMessage } from '@anthropic-ai/claude-agent-sdk'
|
||||||
|
import type { ContentBlockParam } from '@anthropic-ai/sdk/resources/messages'
|
||||||
|
|
||||||
|
export type ClaudeCodeRawValue =
|
||||||
|
| {
|
||||||
|
type: string
|
||||||
|
session_id: string
|
||||||
|
slash_commands: string[]
|
||||||
|
tools: string[]
|
||||||
|
raw: Extract<SDKMessage, { type: 'system' }>
|
||||||
|
}
|
||||||
|
| ContentBlockParam
|
||||||
@@ -0,0 +1,170 @@
|
|||||||
|
/**
|
||||||
|
* @fileoverview Shared Anthropic AI client utilities for Cherry Studio
|
||||||
|
*
|
||||||
|
* This module provides functions for creating Anthropic SDK clients with different
|
||||||
|
* authentication methods (OAuth, API key) and building Claude Code system messages.
|
||||||
|
* It supports both standard Anthropic API and Anthropic Vertex AI endpoints.
|
||||||
|
*
|
||||||
|
* This shared module can be used by both main and renderer processes.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import Anthropic from '@anthropic-ai/sdk'
|
||||||
|
import { TextBlockParam } from '@anthropic-ai/sdk/resources'
|
||||||
|
import { loggerService } from '@logger'
|
||||||
|
import { Provider } from '@types'
|
||||||
|
import type { ModelMessage } from 'ai'
|
||||||
|
|
||||||
|
const logger = loggerService.withContext('anthropic-sdk')
|
||||||
|
|
||||||
|
const defaultClaudeCodeSystemPrompt = `You are Claude Code, Anthropic's official CLI for Claude.`
|
||||||
|
|
||||||
|
const defaultClaudeCodeSystem: Array<TextBlockParam> = [
|
||||||
|
{
|
||||||
|
type: 'text',
|
||||||
|
text: defaultClaudeCodeSystemPrompt
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates and configures an Anthropic SDK client based on the provider configuration.
|
||||||
|
*
|
||||||
|
* This function supports two authentication methods:
|
||||||
|
* 1. OAuth: Uses OAuth tokens passed as parameter
|
||||||
|
* 2. API Key: Uses traditional API key authentication
|
||||||
|
*
|
||||||
|
* For OAuth authentication, it includes Claude Code specific headers and beta features.
|
||||||
|
* For API key authentication, it uses the provider's configuration with custom headers.
|
||||||
|
*
|
||||||
|
* @param provider - The provider configuration containing authentication details
|
||||||
|
* @param oauthToken - Optional OAuth token for OAuth authentication
|
||||||
|
* @returns An initialized Anthropic or AnthropicVertex client
|
||||||
|
* @throws Error when OAuth token is not available for OAuth authentication
|
||||||
|
*
|
||||||
|
* @example
|
||||||
|
* ```typescript
|
||||||
|
* // OAuth authentication
|
||||||
|
* const oauthProvider = { authType: 'oauth' };
|
||||||
|
* const oauthClient = getSdkClient(oauthProvider, 'oauth-token-here');
|
||||||
|
*
|
||||||
|
* // API key authentication
|
||||||
|
* const apiKeyProvider = {
|
||||||
|
* authType: 'apikey',
|
||||||
|
* apiKey: 'your-api-key',
|
||||||
|
* apiHost: 'https://api.anthropic.com'
|
||||||
|
* };
|
||||||
|
* const apiKeyClient = getSdkClient(apiKeyProvider);
|
||||||
|
* ```
|
||||||
|
*/
|
||||||
|
export function getSdkClient(
|
||||||
|
provider: Provider,
|
||||||
|
oauthToken?: string | null,
|
||||||
|
extraHeaders?: Record<string, string | string[]>
|
||||||
|
): Anthropic {
|
||||||
|
if (provider.authType === 'oauth') {
|
||||||
|
if (!oauthToken) {
|
||||||
|
throw new Error('OAuth token is not available')
|
||||||
|
}
|
||||||
|
return new Anthropic({
|
||||||
|
authToken: oauthToken,
|
||||||
|
baseURL: 'https://api.anthropic.com',
|
||||||
|
dangerouslyAllowBrowser: true,
|
||||||
|
defaultHeaders: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
'anthropic-version': '2023-06-01',
|
||||||
|
'anthropic-beta':
|
||||||
|
'oauth-2025-04-20,claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14',
|
||||||
|
'anthropic-dangerous-direct-browser-access': 'true',
|
||||||
|
'user-agent': 'claude-cli/1.0.118 (external, sdk-ts)',
|
||||||
|
'x-app': 'cli',
|
||||||
|
'x-stainless-retry-count': '0',
|
||||||
|
'x-stainless-timeout': '600',
|
||||||
|
'x-stainless-lang': 'js',
|
||||||
|
'x-stainless-package-version': '0.60.0',
|
||||||
|
'x-stainless-os': 'MacOS',
|
||||||
|
'x-stainless-arch': 'arm64',
|
||||||
|
'x-stainless-runtime': 'node',
|
||||||
|
'x-stainless-runtime-version': 'v22.18.0',
|
||||||
|
...extraHeaders
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
const baseURL =
|
||||||
|
provider.type === 'anthropic'
|
||||||
|
? provider.apiHost
|
||||||
|
: (provider.anthropicApiHost && provider.anthropicApiHost.trim()) || provider.apiHost
|
||||||
|
|
||||||
|
logger.debug('Anthropic API baseURL', { baseURL, providerId: provider.id })
|
||||||
|
|
||||||
|
if (provider.id === 'aihubmix') {
|
||||||
|
return new Anthropic({
|
||||||
|
apiKey: provider.apiKey,
|
||||||
|
baseURL,
|
||||||
|
dangerouslyAllowBrowser: true,
|
||||||
|
defaultHeaders: {
|
||||||
|
'anthropic-beta': 'output-128k-2025-02-19',
|
||||||
|
'APP-Code': 'MLTG2087',
|
||||||
|
...provider.extra_headers,
|
||||||
|
...extraHeaders
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return new Anthropic({
|
||||||
|
apiKey: provider.apiKey,
|
||||||
|
authToken: provider.apiKey,
|
||||||
|
baseURL,
|
||||||
|
dangerouslyAllowBrowser: true,
|
||||||
|
defaultHeaders: {
|
||||||
|
'anthropic-beta': 'output-128k-2025-02-19',
|
||||||
|
...provider.extra_headers
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds and prepends the Claude Code system message to user-provided system messages.
|
||||||
|
*
|
||||||
|
* This function ensures that all interactions with Claude include the official Claude Code
|
||||||
|
* system prompt, which identifies the assistant as "Claude Code, Anthropic's official CLI for Claude."
|
||||||
|
*
|
||||||
|
* The function handles three cases:
|
||||||
|
* 1. No system message provided: Returns only the default Claude Code system message
|
||||||
|
* 2. String system message: Converts to array format and prepends Claude Code message
|
||||||
|
* 3. Array system message: Checks if Claude Code message exists and prepends if missing
|
||||||
|
*
|
||||||
|
* @param system - Optional user-provided system message (string or TextBlockParam array)
|
||||||
|
* @returns Combined system message with Claude Code prompt prepended
|
||||||
|
*
|
||||||
|
* ```
|
||||||
|
*/
|
||||||
|
export function buildClaudeCodeSystemMessage(system?: string | Array<TextBlockParam>): Array<TextBlockParam> {
|
||||||
|
if (!system) {
|
||||||
|
return defaultClaudeCodeSystem
|
||||||
|
}
|
||||||
|
|
||||||
|
if (typeof system === 'string') {
|
||||||
|
if (system.trim() === defaultClaudeCodeSystemPrompt || system.trim() === '') {
|
||||||
|
return defaultClaudeCodeSystem
|
||||||
|
} else {
|
||||||
|
return [...defaultClaudeCodeSystem, { type: 'text', text: system }]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (Array.isArray(system)) {
|
||||||
|
const firstSystem = system[0]
|
||||||
|
if (firstSystem.type === 'text' && firstSystem.text.trim() === defaultClaudeCodeSystemPrompt) {
|
||||||
|
return system
|
||||||
|
} else {
|
||||||
|
return [...defaultClaudeCodeSystem, ...system]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return defaultClaudeCodeSystem
|
||||||
|
}
|
||||||
|
|
||||||
|
export function buildClaudeCodeSystemModelMessage(system?: string | Array<TextBlockParam>): Array<ModelMessage> {
|
||||||
|
const textBlocks = buildClaudeCodeSystemMessage(system)
|
||||||
|
return textBlocks.map((block) => ({
|
||||||
|
role: 'system',
|
||||||
|
content: block.text
|
||||||
|
}))
|
||||||
|
}
|
||||||
@@ -217,5 +217,256 @@ export enum codeTools {
|
|||||||
claudeCode = 'claude-code',
|
claudeCode = 'claude-code',
|
||||||
geminiCli = 'gemini-cli',
|
geminiCli = 'gemini-cli',
|
||||||
openaiCodex = 'openai-codex',
|
openaiCodex = 'openai-codex',
|
||||||
iFlowCli = 'iflow-cli'
|
iFlowCli = 'iflow-cli',
|
||||||
|
githubCopilotCli = 'github-copilot-cli'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export enum terminalApps {
|
||||||
|
systemDefault = 'Terminal',
|
||||||
|
iterm2 = 'iTerm2',
|
||||||
|
kitty = 'kitty',
|
||||||
|
alacritty = 'Alacritty',
|
||||||
|
wezterm = 'WezTerm',
|
||||||
|
ghostty = 'Ghostty',
|
||||||
|
tabby = 'Tabby',
|
||||||
|
// Windows terminals
|
||||||
|
windowsTerminal = 'WindowsTerminal',
|
||||||
|
powershell = 'PowerShell',
|
||||||
|
cmd = 'CMD',
|
||||||
|
wsl = 'WSL'
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface TerminalConfig {
|
||||||
|
id: string
|
||||||
|
name: string
|
||||||
|
bundleId?: string
|
||||||
|
customPath?: string // For user-configured terminal paths on Windows
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface TerminalConfigWithCommand extends TerminalConfig {
|
||||||
|
command: (directory: string, fullCommand: string) => { command: string; args: string[] }
|
||||||
|
}
|
||||||
|
|
||||||
|
export const MACOS_TERMINALS: TerminalConfig[] = [
|
||||||
|
{
|
||||||
|
id: terminalApps.systemDefault,
|
||||||
|
name: 'Terminal',
|
||||||
|
bundleId: 'com.apple.Terminal'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: terminalApps.iterm2,
|
||||||
|
name: 'iTerm2',
|
||||||
|
bundleId: 'com.googlecode.iterm2'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: terminalApps.kitty,
|
||||||
|
name: 'kitty',
|
||||||
|
bundleId: 'net.kovidgoyal.kitty'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: terminalApps.alacritty,
|
||||||
|
name: 'Alacritty',
|
||||||
|
bundleId: 'org.alacritty'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: terminalApps.wezterm,
|
||||||
|
name: 'WezTerm',
|
||||||
|
bundleId: 'com.github.wez.wezterm'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: terminalApps.ghostty,
|
||||||
|
name: 'Ghostty',
|
||||||
|
bundleId: 'com.mitchellh.ghostty'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: terminalApps.tabby,
|
||||||
|
name: 'Tabby',
|
||||||
|
bundleId: 'org.tabby'
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
export const WINDOWS_TERMINALS: TerminalConfig[] = [
|
||||||
|
{
|
||||||
|
id: terminalApps.cmd,
|
||||||
|
name: 'Command Prompt'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: terminalApps.powershell,
|
||||||
|
name: 'PowerShell'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: terminalApps.windowsTerminal,
|
||||||
|
name: 'Windows Terminal'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: terminalApps.wsl,
|
||||||
|
name: 'WSL (Ubuntu/Debian)'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: terminalApps.alacritty,
|
||||||
|
name: 'Alacritty'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: terminalApps.wezterm,
|
||||||
|
name: 'WezTerm'
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
export const WINDOWS_TERMINALS_WITH_COMMANDS: TerminalConfigWithCommand[] = [
|
||||||
|
{
|
||||||
|
id: terminalApps.cmd,
|
||||||
|
name: 'Command Prompt',
|
||||||
|
command: (_: string, fullCommand: string) => ({
|
||||||
|
command: 'cmd',
|
||||||
|
args: ['/c', 'start', 'cmd', '/k', fullCommand]
|
||||||
|
})
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: terminalApps.powershell,
|
||||||
|
name: 'PowerShell',
|
||||||
|
command: (_: string, fullCommand: string) => ({
|
||||||
|
command: 'cmd',
|
||||||
|
args: ['/c', 'start', 'powershell', '-NoExit', '-Command', `& '${fullCommand}'`]
|
||||||
|
})
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: terminalApps.windowsTerminal,
|
||||||
|
name: 'Windows Terminal',
|
||||||
|
command: (_: string, fullCommand: string) => ({
|
||||||
|
command: 'wt',
|
||||||
|
args: ['cmd', '/k', fullCommand]
|
||||||
|
})
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: terminalApps.wsl,
|
||||||
|
name: 'WSL (Ubuntu/Debian)',
|
||||||
|
command: (_: string, fullCommand: string) => {
|
||||||
|
// Start WSL in a new window and execute the batch file from within WSL using cmd.exe
|
||||||
|
// The batch file will run in Windows context but output will be in WSL terminal
|
||||||
|
return {
|
||||||
|
command: 'cmd',
|
||||||
|
args: ['/c', 'start', 'wsl', '-e', 'bash', '-c', `cmd.exe /c '${fullCommand}' ; exec bash`]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: terminalApps.alacritty,
|
||||||
|
name: 'Alacritty',
|
||||||
|
customPath: '', // Will be set by user in settings
|
||||||
|
command: (_: string, fullCommand: string) => ({
|
||||||
|
command: 'alacritty', // Will be replaced with customPath if set
|
||||||
|
args: ['-e', 'cmd', '/k', fullCommand]
|
||||||
|
})
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: terminalApps.wezterm,
|
||||||
|
name: 'WezTerm',
|
||||||
|
customPath: '', // Will be set by user in settings
|
||||||
|
command: (_: string, fullCommand: string) => ({
|
||||||
|
command: 'wezterm', // Will be replaced with customPath if set
|
||||||
|
args: ['start', 'cmd', '/k', fullCommand]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
// Helper function to escape strings for AppleScript
|
||||||
|
const escapeForAppleScript = (str: string): string => {
|
||||||
|
// In AppleScript strings, backslashes and double quotes need to be escaped
|
||||||
|
// When passed through osascript -e with single quotes, we need:
|
||||||
|
// 1. Backslash: \ -> \\
|
||||||
|
// 2. Double quote: " -> \"
|
||||||
|
return str
|
||||||
|
.replace(/\\/g, '\\\\') // Escape backslashes first
|
||||||
|
.replace(/"/g, '\\"') // Then escape double quotes
|
||||||
|
}
|
||||||
|
|
||||||
|
export const MACOS_TERMINALS_WITH_COMMANDS: TerminalConfigWithCommand[] = [
|
||||||
|
{
|
||||||
|
id: terminalApps.systemDefault,
|
||||||
|
name: 'Terminal',
|
||||||
|
bundleId: 'com.apple.Terminal',
|
||||||
|
command: (_directory: string, fullCommand: string) => ({
|
||||||
|
command: 'sh',
|
||||||
|
args: [
|
||||||
|
'-c',
|
||||||
|
`open -na Terminal && sleep 0.5 && osascript -e 'tell application "Terminal" to activate' -e 'tell application "Terminal" to do script "${escapeForAppleScript(fullCommand)}" in front window'`
|
||||||
|
]
|
||||||
|
})
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: terminalApps.iterm2,
|
||||||
|
name: 'iTerm2',
|
||||||
|
bundleId: 'com.googlecode.iterm2',
|
||||||
|
command: (_directory: string, fullCommand: string) => ({
|
||||||
|
command: 'sh',
|
||||||
|
args: [
|
||||||
|
'-c',
|
||||||
|
`open -na iTerm && sleep 0.8 && osascript -e 'on waitUntilRunning()\n repeat 50 times\n tell application "System Events"\n if (exists process "iTerm2") then exit repeat\n end tell\n delay 0.1\n end repeat\nend waitUntilRunning\n\nwaitUntilRunning()\n\ntell application "iTerm2"\n if (count of windows) = 0 then\n create window with default profile\n delay 0.3\n else\n tell current window\n create tab with default profile\n end tell\n delay 0.3\n end if\n tell current session of current window to write text "${escapeForAppleScript(fullCommand)}"\n activate\nend tell'`
|
||||||
|
]
|
||||||
|
})
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: terminalApps.kitty,
|
||||||
|
name: 'kitty',
|
||||||
|
bundleId: 'net.kovidgoyal.kitty',
|
||||||
|
command: (_directory: string, fullCommand: string) => ({
|
||||||
|
command: 'sh',
|
||||||
|
args: [
|
||||||
|
'-c',
|
||||||
|
`cd "${_directory}" && open -na kitty --args --directory="${_directory}" sh -c "${fullCommand.replace(/\\/g, '\\\\').replace(/"/g, '\\"')}; exec \\$SHELL" && sleep 0.5 && osascript -e 'tell application "kitty" to activate'`
|
||||||
|
]
|
||||||
|
})
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: terminalApps.alacritty,
|
||||||
|
name: 'Alacritty',
|
||||||
|
bundleId: 'org.alacritty',
|
||||||
|
command: (_directory: string, fullCommand: string) => ({
|
||||||
|
command: 'sh',
|
||||||
|
args: [
|
||||||
|
'-c',
|
||||||
|
`open -na Alacritty --args --working-directory "${_directory}" -e sh -c "${fullCommand.replace(/\\/g, '\\\\').replace(/"/g, '\\"')}; exec \\$SHELL" && sleep 0.5 && osascript -e 'tell application "Alacritty" to activate'`
|
||||||
|
]
|
||||||
|
})
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: terminalApps.wezterm,
|
||||||
|
name: 'WezTerm',
|
||||||
|
bundleId: 'com.github.wez.wezterm',
|
||||||
|
command: (_directory: string, fullCommand: string) => ({
|
||||||
|
command: 'sh',
|
||||||
|
args: [
|
||||||
|
'-c',
|
||||||
|
`open -na WezTerm --args start --new-tab --cwd "${_directory}" -- sh -c "${fullCommand.replace(/\\/g, '\\\\').replace(/"/g, '\\"')}; exec \\$SHELL" && sleep 0.5 && osascript -e 'tell application "WezTerm" to activate'`
|
||||||
|
]
|
||||||
|
})
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: terminalApps.ghostty,
|
||||||
|
name: 'Ghostty',
|
||||||
|
bundleId: 'com.mitchellh.ghostty',
|
||||||
|
command: (_directory: string, fullCommand: string) => ({
|
||||||
|
command: 'sh',
|
||||||
|
args: [
|
||||||
|
'-c',
|
||||||
|
`cd "${_directory}" && open -na Ghostty --args --working-directory="${_directory}" -e sh -c "${fullCommand.replace(/\\/g, '\\\\').replace(/"/g, '\\"')}; exec \\$SHELL" && sleep 0.5 && osascript -e 'tell application "Ghostty" to activate'`
|
||||||
|
]
|
||||||
|
})
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: terminalApps.tabby,
|
||||||
|
name: 'Tabby',
|
||||||
|
bundleId: 'org.tabby',
|
||||||
|
command: (_directory: string, fullCommand: string) => ({
|
||||||
|
command: 'sh',
|
||||||
|
args: [
|
||||||
|
'-c',
|
||||||
|
`if pgrep -x "Tabby" > /dev/null; then
|
||||||
|
open -na Tabby --args open && sleep 0.3
|
||||||
|
else
|
||||||
|
open -na Tabby --args open && sleep 2
|
||||||
|
fi && osascript -e 'tell application "Tabby" to activate' -e 'set the clipboard to "${escapeForAppleScript(fullCommand)}"' -e 'tell application "System Events" to tell process "Tabby" to keystroke "v" using {command down}' -e 'tell application "System Events" to key code 36'`
|
||||||
|
]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|||||||
@@ -22,3 +22,12 @@ export type MCPProgressEvent = {
|
|||||||
callId: string
|
callId: string
|
||||||
progress: number // 0-1 range
|
progress: number // 0-1 range
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export type WebviewKeyEvent = {
|
||||||
|
webviewId: number
|
||||||
|
key: string
|
||||||
|
control: boolean
|
||||||
|
meta: boolean
|
||||||
|
shift: boolean
|
||||||
|
alt: boolean
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,252 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>Privacy Policy</title>
|
||||||
|
<style>
|
||||||
|
* {
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
box-sizing: border-box;
|
||||||
|
}
|
||||||
|
|
||||||
|
body {
|
||||||
|
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Helvetica Neue', Arial, sans-serif;
|
||||||
|
line-height: 1.6;
|
||||||
|
color: #333;
|
||||||
|
background: transparent;
|
||||||
|
margin: 0 auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
body.dark {
|
||||||
|
background: transparent;
|
||||||
|
color: rgba(255, 255, 255, 0.85);
|
||||||
|
}
|
||||||
|
|
||||||
|
h1 {
|
||||||
|
font-size: 24px;
|
||||||
|
font-weight: 600;
|
||||||
|
margin-bottom: 20px;
|
||||||
|
color: #1a1a1a;
|
||||||
|
}
|
||||||
|
|
||||||
|
body.dark h1 {
|
||||||
|
color: rgba(255, 255, 255, 0.95);
|
||||||
|
}
|
||||||
|
|
||||||
|
h2 {
|
||||||
|
font-size: 18px;
|
||||||
|
font-weight: 600;
|
||||||
|
margin-top: 24px;
|
||||||
|
margin-bottom: 12px;
|
||||||
|
color: #2c2c2c;
|
||||||
|
}
|
||||||
|
|
||||||
|
body.dark h2 {
|
||||||
|
color: rgba(255, 255, 255, 0.9);
|
||||||
|
}
|
||||||
|
|
||||||
|
p {
|
||||||
|
margin: 12px 0;
|
||||||
|
line-height: 1.8;
|
||||||
|
}
|
||||||
|
|
||||||
|
body.dark p {
|
||||||
|
color: rgba(255, 255, 255, 0.8);
|
||||||
|
}
|
||||||
|
|
||||||
|
ul {
|
||||||
|
margin: 12px 0;
|
||||||
|
padding-left: 24px;
|
||||||
|
}
|
||||||
|
|
||||||
|
li {
|
||||||
|
margin: 6px 0;
|
||||||
|
line-height: 1.6;
|
||||||
|
}
|
||||||
|
|
||||||
|
body.dark li {
|
||||||
|
color: rgba(255, 255, 255, 0.75);
|
||||||
|
}
|
||||||
|
|
||||||
|
a {
|
||||||
|
color: #0066cc;
|
||||||
|
text-decoration: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
a:hover {
|
||||||
|
text-decoration: underline;
|
||||||
|
}
|
||||||
|
|
||||||
|
body.dark a {
|
||||||
|
color: #4da6ff;
|
||||||
|
}
|
||||||
|
|
||||||
|
.footer {
|
||||||
|
margin-top: 40px;
|
||||||
|
padding-top: 20px;
|
||||||
|
border-top: 1px solid #e0e0e0;
|
||||||
|
font-size: 13px;
|
||||||
|
color: #666;
|
||||||
|
}
|
||||||
|
|
||||||
|
body.dark .footer {
|
||||||
|
border-top-color: rgba(255, 255, 255, 0.1);
|
||||||
|
color: rgba(255, 255, 255, 0.5);
|
||||||
|
}
|
||||||
|
|
||||||
|
.content-wrapper {
|
||||||
|
max-height: calc(100vh - 40px);
|
||||||
|
overflow-y: auto;
|
||||||
|
padding-right: 10px;
|
||||||
|
background: transparent;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Scrollbar styles - Light mode */
|
||||||
|
::-webkit-scrollbar {
|
||||||
|
width: 8px;
|
||||||
|
height: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
::-webkit-scrollbar-track {
|
||||||
|
background: rgba(0, 0, 0, 0.05);
|
||||||
|
border-radius: 4px;
|
||||||
|
}
|
||||||
|
|
||||||
|
::-webkit-scrollbar-thumb {
|
||||||
|
background: rgba(0, 0, 0, 0.2);
|
||||||
|
border-radius: 4px;
|
||||||
|
}
|
||||||
|
|
||||||
|
::-webkit-scrollbar-thumb:hover {
|
||||||
|
background: rgba(0, 0, 0, 0.3);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Scrollbar styles - Dark mode */
|
||||||
|
body.dark ::-webkit-scrollbar-track {
|
||||||
|
background: rgba(255, 255, 255, 0.05);
|
||||||
|
}
|
||||||
|
|
||||||
|
body.dark ::-webkit-scrollbar-thumb {
|
||||||
|
background: rgba(255, 255, 255, 0.2);
|
||||||
|
}
|
||||||
|
|
||||||
|
body.dark ::-webkit-scrollbar-thumb:hover {
|
||||||
|
background: rgba(255, 255, 255, 0.3);
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
<script>
|
||||||
|
// Detect theme
|
||||||
|
document.addEventListener('DOMContentLoaded', function () {
|
||||||
|
const urlParams = new URLSearchParams(window.location.search);
|
||||||
|
const theme = urlParams.get('theme');
|
||||||
|
if (theme === 'dark') {
|
||||||
|
document.documentElement.classList.add('dark');
|
||||||
|
document.body.classList.add('dark');
|
||||||
|
}
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
</head>
|
||||||
|
|
||||||
|
<body>
|
||||||
|
<div class="content-wrapper">
|
||||||
|
<h1>Privacy Policy</h1>
|
||||||
|
|
||||||
|
<p>
|
||||||
|
Welcome to Cherry Studio (hereinafter referred to as "the Software" or "we"). We highly value your privacy
|
||||||
|
protection. This Privacy Policy explains how we process and protect your personal information and data.
|
||||||
|
Please read and understand this policy carefully before using the Software:
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<h2>1. Information We Collect</h2>
|
||||||
|
<p>To optimize user experience and improve software quality, we may only collect the following anonymous,
|
||||||
|
non-personal information:</p>
|
||||||
|
<ul>
|
||||||
|
<li>Software version information</li>
|
||||||
|
<li>Activity and usage frequency of software features</li>
|
||||||
|
<li>Anonymous crash and error log information</li>
|
||||||
|
</ul>
|
||||||
|
<p>The above information is completely anonymous, does not involve any personal identity data, and cannot be
|
||||||
|
linked to your personal information.</p>
|
||||||
|
|
||||||
|
<h2>2. Information We Do Not Collect</h2>
|
||||||
|
<p>To maximize the protection of your privacy and security, we explicitly commit that we:</p>
|
||||||
|
<ul>
|
||||||
|
<li>Will not collect, save, transmit, or process model service API Key information you enter into the
|
||||||
|
Software</li>
|
||||||
|
<li>Will not collect, save, transmit, or process any conversation data generated during your use of the
|
||||||
|
Software, including but not limited to chat content, instruction information, knowledge base
|
||||||
|
information, vector data, and other custom content</li>
|
||||||
|
<li>Will not collect, save, transmit, or process any sensitive information that can identify personal
|
||||||
|
identity</li>
|
||||||
|
</ul>
|
||||||
|
|
||||||
|
<h2>3. Data Interaction Description</h2>
|
||||||
|
<p>
|
||||||
|
The Software uses API Keys from third-party model service providers that you apply for and configure
|
||||||
|
yourself to complete model calls and conversation functions. The model services you use (such as large
|
||||||
|
models, API interfaces, etc.) are directly provided by third-party providers of your choice. We do not
|
||||||
|
intervene, monitor, or interfere with the data transmission process.
|
||||||
|
</p>
|
||||||
|
<p>
|
||||||
|
Data interactions between you and third-party model services are governed by the privacy policies and user
|
||||||
|
agreements of third-party service providers. We recommend that you fully understand the privacy terms of
|
||||||
|
relevant service providers before use.
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<h2>4. Local Data Security Protection</h2>
|
||||||
|
<p>The Software is a localized application, and all data is stored on your local device by default. We have
|
||||||
|
taken the following measures to ensure data security:</p>
|
||||||
|
<ul>
|
||||||
|
<li>Conversation records, configuration information, and other data are only saved on your local device</li>
|
||||||
|
<li>Data import/export functions are provided to facilitate your independent management and backup of data
|
||||||
|
</li>
|
||||||
|
<li>Your local data will not be uploaded to any server or cloud storage</li>
|
||||||
|
</ul>
|
||||||
|
|
||||||
|
<h2>5. Third-Party Services</h2>
|
||||||
|
<p>
|
||||||
|
When using the Software, you may access third-party services (such as AI model APIs, translation services,
|
||||||
|
etc.). The use of these third-party services is governed by their respective terms of service and privacy
|
||||||
|
policies. We strongly recommend that you carefully read and understand the relevant terms before use.
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<h2>6. User Rights</h2>
|
||||||
|
<p>You have complete control over your data:</p>
|
||||||
|
<ul>
|
||||||
|
<li>You can view, modify, and delete all locally stored data at any time</li>
|
||||||
|
<li>You can choose whether to enable specific features or services</li>
|
||||||
|
<li>You can stop using the Software and delete all related data at any time</li>
|
||||||
|
</ul>
|
||||||
|
|
||||||
|
<h2>7. Children's Privacy Protection</h2>
|
||||||
|
<p>The Software is not intended for minors under 18 years of age. If you are a minor, please use the Software
|
||||||
|
under the guidance of a guardian.</p>
|
||||||
|
|
||||||
|
<h2>8. Privacy Policy Updates</h2>
|
||||||
|
<p>
|
||||||
|
We may update this Privacy Policy based on legal requirements or changes in product features. The updated
|
||||||
|
policy will be published in the Software and you will be notified before it takes effect. If you do not
|
||||||
|
agree with the updated terms, you can choose to stop using the Software.
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<h2>9. Contact Us</h2>
|
||||||
|
<p>If you have any questions, suggestions, or complaints about this Privacy Policy, please contact us through
|
||||||
|
the following methods:</p>
|
||||||
|
<ul>
|
||||||
|
<li>
|
||||||
|
GitHub: <a href="https://github.com/CherryHQ/cherry-studio" target="_blank"
|
||||||
|
rel="noopener noreferrer">https://github.com/CherryHQ/cherry-studio</a>
|
||||||
|
</li>
|
||||||
|
<li>Email: support@cherry-ai.com</li>
|
||||||
|
</ul>
|
||||||
|
|
||||||
|
<div class="footer">
|
||||||
|
Last Updated: December 2024
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
</html>
|
||||||
@@ -0,0 +1,230 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="zh">
|
||||||
|
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>隐私协议</title>
|
||||||
|
<style>
|
||||||
|
* {
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
box-sizing: border-box;
|
||||||
|
}
|
||||||
|
|
||||||
|
body {
|
||||||
|
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Helvetica Neue', Arial, sans-serif;
|
||||||
|
line-height: 1.6;
|
||||||
|
color: #333;
|
||||||
|
background: transparent;
|
||||||
|
margin: 0 auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
body.dark {
|
||||||
|
background: transparent;
|
||||||
|
color: rgba(255, 255, 255, 0.85);
|
||||||
|
}
|
||||||
|
|
||||||
|
h1 {
|
||||||
|
font-size: 24px;
|
||||||
|
font-weight: 600;
|
||||||
|
margin-bottom: 20px;
|
||||||
|
color: #1a1a1a;
|
||||||
|
}
|
||||||
|
|
||||||
|
body.dark h1 {
|
||||||
|
color: rgba(255, 255, 255, 0.95);
|
||||||
|
}
|
||||||
|
|
||||||
|
h2 {
|
||||||
|
font-size: 18px;
|
||||||
|
font-weight: 600;
|
||||||
|
margin-top: 24px;
|
||||||
|
margin-bottom: 12px;
|
||||||
|
color: #2c2c2c;
|
||||||
|
}
|
||||||
|
|
||||||
|
body.dark h2 {
|
||||||
|
color: rgba(255, 255, 255, 0.9);
|
||||||
|
}
|
||||||
|
|
||||||
|
p {
|
||||||
|
margin: 12px 0;
|
||||||
|
line-height: 1.8;
|
||||||
|
}
|
||||||
|
|
||||||
|
body.dark p {
|
||||||
|
color: rgba(255, 255, 255, 0.8);
|
||||||
|
}
|
||||||
|
|
||||||
|
ul {
|
||||||
|
margin: 12px 0;
|
||||||
|
padding-left: 24px;
|
||||||
|
}
|
||||||
|
|
||||||
|
li {
|
||||||
|
margin: 6px 0;
|
||||||
|
line-height: 1.6;
|
||||||
|
}
|
||||||
|
|
||||||
|
body.dark li {
|
||||||
|
color: rgba(255, 255, 255, 0.75);
|
||||||
|
}
|
||||||
|
|
||||||
|
a {
|
||||||
|
color: #0066cc;
|
||||||
|
text-decoration: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
a:hover {
|
||||||
|
text-decoration: underline;
|
||||||
|
}
|
||||||
|
|
||||||
|
body.dark a {
|
||||||
|
color: #4da6ff;
|
||||||
|
}
|
||||||
|
|
||||||
|
.footer {
|
||||||
|
margin-top: 40px;
|
||||||
|
padding-top: 20px;
|
||||||
|
border-top: 1px solid #e0e0e0;
|
||||||
|
font-size: 13px;
|
||||||
|
color: #666;
|
||||||
|
}
|
||||||
|
|
||||||
|
body.dark .footer {
|
||||||
|
border-top-color: rgba(255, 255, 255, 0.1);
|
||||||
|
color: rgba(255, 255, 255, 0.5);
|
||||||
|
}
|
||||||
|
|
||||||
|
.content-wrapper {
|
||||||
|
overflow-y: auto;
|
||||||
|
padding-right: 10px;
|
||||||
|
background: transparent;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 滚动条样式 - 亮色模式 */
|
||||||
|
::-webkit-scrollbar {
|
||||||
|
width: 8px;
|
||||||
|
height: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
::-webkit-scrollbar-track {
|
||||||
|
background: rgba(0, 0, 0, 0.05);
|
||||||
|
border-radius: 4px;
|
||||||
|
}
|
||||||
|
|
||||||
|
::-webkit-scrollbar-thumb {
|
||||||
|
background: rgba(0, 0, 0, 0.2);
|
||||||
|
border-radius: 4px;
|
||||||
|
}
|
||||||
|
|
||||||
|
::-webkit-scrollbar-thumb:hover {
|
||||||
|
background: rgba(0, 0, 0, 0.3);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 滚动条样式 - 暗色模式 */
|
||||||
|
body.dark ::-webkit-scrollbar-track {
|
||||||
|
background: rgba(255, 255, 255, 0.05);
|
||||||
|
}
|
||||||
|
|
||||||
|
body.dark ::-webkit-scrollbar-thumb {
|
||||||
|
background: rgba(255, 255, 255, 0.2);
|
||||||
|
}
|
||||||
|
|
||||||
|
body.dark ::-webkit-scrollbar-thumb:hover {
|
||||||
|
background: rgba(255, 255, 255, 0.3);
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
<script>
|
||||||
|
// 检测主题
|
||||||
|
document.addEventListener('DOMContentLoaded', function () {
|
||||||
|
const urlParams = new URLSearchParams(window.location.search);
|
||||||
|
const theme = urlParams.get('theme');
|
||||||
|
if (theme === 'dark') {
|
||||||
|
document.documentElement.classList.add('dark');
|
||||||
|
document.body.classList.add('dark');
|
||||||
|
}
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
</head>
|
||||||
|
|
||||||
|
<body>
|
||||||
|
<div class="content-wrapper">
|
||||||
|
<h1>隐私协议</h1>
|
||||||
|
|
||||||
|
<p>
|
||||||
|
欢迎使用 Cherry Studio(以下简称"本软件"或"我们")。我们高度重视您的隐私保护,本隐私协议将说明我们如何处理与保护您的个人信息和数据。请在使用本软件前仔细阅读并理解本协议:
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<h2>一、我们收集的信息范围</h2>
|
||||||
|
<p>为了优化用户体验和提升软件质量,我们仅可能会匿名收集以下非个人化信息:</p>
|
||||||
|
<ul>
|
||||||
|
<li>软件版本信息;</li>
|
||||||
|
<li>软件功能的活跃度、使用频次;</li>
|
||||||
|
<li>匿名的崩溃、错误日志信息;</li>
|
||||||
|
</ul>
|
||||||
|
<p>上述信息完全匿名,不会涉及任何个人身份数据,也无法关联到您的个人信息。</p>
|
||||||
|
|
||||||
|
<h2>二、我们不会收集的任何信息</h2>
|
||||||
|
<p>为了最大限度保护您的隐私安全,我们明确承诺:</p>
|
||||||
|
<ul>
|
||||||
|
<li>不会收集、保存、传输或处理您输入到本软件中的模型服务 API Key 信息;</li>
|
||||||
|
<li>不会收集、保存、传输或处理您在使用本软件过程中产生的任何对话数据,包括但不限于聊天内容、指令信息、知识库信息、向量数据及其他自定义内容;</li>
|
||||||
|
<li>不会收集、保存、传输或处理任何可识别个人身份的敏感信息。</li>
|
||||||
|
</ul>
|
||||||
|
|
||||||
|
<h2>三、数据交互说明</h2>
|
||||||
|
<p>
|
||||||
|
本软件采用您自行申请并配置的第三方模型服务提供商的 API Key,以完成相关模型的调用与对话功能。您使用的模型服务(例如大模型、API 接口等)由您选择的第三方提供商直接提供,我们不会介入、监控或干扰数据传输过程。
|
||||||
|
</p>
|
||||||
|
<p>
|
||||||
|
您与第三方模型服务之间的数据交互受第三方服务提供商的隐私政策和用户协议约束,我们建议您在使用前充分了解相关服务商的隐私条款。
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<h2>四、本地数据的安全保护</h2>
|
||||||
|
<p>本软件为本地化应用程序,所有数据默认存储在您的本地设备上。我们采取了以下措施保障数据安全:</p>
|
||||||
|
<ul>
|
||||||
|
<li>对话记录、配置信息等数据仅保存在您的本地设备中;</li>
|
||||||
|
<li>提供数据导入/导出功能,方便您自主管理和备份数据;</li>
|
||||||
|
<li>不会将您的本地数据上传至任何服务器或云端存储。</li>
|
||||||
|
</ul>
|
||||||
|
|
||||||
|
<h2>五、第三方服务</h2>
|
||||||
|
<p>
|
||||||
|
在使用本软件过程中,您可能会接入第三方服务(如 AI 模型 API、翻译服务等)。这些第三方服务的使用受其各自的服务条款和隐私政策约束。我们强烈建议您在使用前仔细阅读并理解相关条款。
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<h2>六、用户权利</h2>
|
||||||
|
<p>您对自己的数据拥有完全的控制权:</p>
|
||||||
|
<ul>
|
||||||
|
<li>您可以随时查看、修改、删除本地存储的所有数据;</li>
|
||||||
|
<li>您可以选择是否启用特定功能或服务;</li>
|
||||||
|
<li>您可以随时停止使用本软件并删除所有相关数据。</li>
|
||||||
|
</ul>
|
||||||
|
|
||||||
|
<h2>七、儿童隐私保护</h2>
|
||||||
|
<p>本软件不面向 18 岁以下的未成年人提供服务。如果您是未成年人,请在监护人的指导下使用本软件。</p>
|
||||||
|
|
||||||
|
<h2>八、隐私政策的更新</h2>
|
||||||
|
<p>
|
||||||
|
我们可能会根据法律法规要求或产品功能的变化更新本隐私协议。更新后的协议将在软件中发布,并在生效前通知您。如果您不同意更新后的条款,您可以选择停止使用本软件。
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<h2>九、联系我们</h2>
|
||||||
|
<p>如果您对本隐私协议有任何疑问、建议或投诉,请通过以下方式联系我们:</p>
|
||||||
|
<ul>
|
||||||
|
<li>
|
||||||
|
GitHub: <a href="https://github.com/CherryHQ/cherry-studio" target="_blank"
|
||||||
|
rel="noopener noreferrer">https://github.com/CherryHQ/cherry-studio</a>
|
||||||
|
</li>
|
||||||
|
<li>Email: support@cherry-ai.com</li>
|
||||||
|
</ul>
|
||||||
|
|
||||||
|
<div class="footer">
|
||||||
|
最后更新日期:2024年12月
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
</html>
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
const https = require('https')
|
const https = require('https')
|
||||||
const fs = require('fs')
|
const fs = require('fs')
|
||||||
|
const path = require('path')
|
||||||
|
const { execSync } = require('child_process')
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Downloads a file from a URL with redirect handling
|
* Downloads a file from a URL with redirect handling
|
||||||
@@ -32,4 +34,39 @@ async function downloadWithRedirects(url, destinationPath) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
module.exports = { downloadWithRedirects }
|
/**
|
||||||
|
* Downloads a file using PowerShell Invoke-WebRequest command
|
||||||
|
* @param {string} url The URL to download from
|
||||||
|
* @param {string} destinationPath The path to save the file to
|
||||||
|
* @returns {Promise<boolean>} Promise that resolves to true if download succeeds
|
||||||
|
*/
|
||||||
|
async function downloadWithPowerShell(url, destinationPath) {
|
||||||
|
return new Promise((resolve, reject) => {
|
||||||
|
try {
|
||||||
|
// Only support windows platform for PowerShell download
|
||||||
|
if (process.platform !== 'win32') {
|
||||||
|
return reject(new Error('PowerShell download is only supported on Windows'))
|
||||||
|
}
|
||||||
|
|
||||||
|
const outputDir = path.dirname(destinationPath)
|
||||||
|
fs.mkdirSync(outputDir, { recursive: true })
|
||||||
|
|
||||||
|
// PowerShell command to download the file with progress disabled for faster download
|
||||||
|
const psCommand = `powershell -Command "$ProgressPreference = 'SilentlyContinue'; Invoke-WebRequest '${url}' -OutFile '${destinationPath}'"`
|
||||||
|
|
||||||
|
console.log(`Downloading with PowerShell: ${url}`)
|
||||||
|
execSync(psCommand, { stdio: 'inherit' })
|
||||||
|
|
||||||
|
if (fs.existsSync(destinationPath)) {
|
||||||
|
console.log(`Download completed: ${destinationPath}`)
|
||||||
|
resolve(true)
|
||||||
|
} else {
|
||||||
|
reject(new Error('Download failed: File not found after download'))
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
reject(new Error(`PowerShell download failed: ${error.message}`))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
module.exports = { downloadWithRedirects, downloadWithPowerShell }
|
||||||
|
|||||||
@@ -0,0 +1,263 @@
|
|||||||
|
const fs = require('fs')
|
||||||
|
const path = require('path')
|
||||||
|
const os = require('os')
|
||||||
|
const { execSync } = require('child_process')
|
||||||
|
const { downloadWithPowerShell } = require('./download')
|
||||||
|
|
||||||
|
// Base URL for downloading OVMS binaries
|
||||||
|
const OVMS_RELEASE_BASE_URL =
|
||||||
|
'https://storage.openvinotoolkit.org/repositories/openvino_model_server/packages/2025.3.0/ovms_windows_python_on.zip'
|
||||||
|
const OVMS_EX_URL = 'https://gitcode.com/gcw_ggDjjkY3/kjfile/releases/download/download/ovms_25.3_ex.zip'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* error code:
|
||||||
|
* 101: Unsupported CPU (not Intel Ultra)
|
||||||
|
* 102: Unsupported platform (not Windows)
|
||||||
|
* 103: Download failed
|
||||||
|
* 104: Installation failed
|
||||||
|
* 105: Failed to create ovdnd.exe
|
||||||
|
* 106: Failed to create run.bat
|
||||||
|
* 110: Cleanup of old installation failed
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clean old OVMS installation if it exists
|
||||||
|
*/
|
||||||
|
function cleanOldOvmsInstallation() {
|
||||||
|
console.log('Cleaning the existing OVMS installation...')
|
||||||
|
const csDir = path.join(os.homedir(), '.cherrystudio')
|
||||||
|
const csOvmsDir = path.join(csDir, 'ovms')
|
||||||
|
if (fs.existsSync(csOvmsDir)) {
|
||||||
|
try {
|
||||||
|
fs.rmSync(csOvmsDir, { recursive: true })
|
||||||
|
} catch (error) {
|
||||||
|
console.warn(`Failed to clean up old OVMS installation: ${error.message}`)
|
||||||
|
return 110
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Install OVMS Base package
|
||||||
|
*/
|
||||||
|
async function installOvmsBase() {
|
||||||
|
// Download the base package
|
||||||
|
const tempdir = os.tmpdir()
|
||||||
|
const tempFilename = path.join(tempdir, 'ovms.zip')
|
||||||
|
|
||||||
|
try {
|
||||||
|
console.log(`Downloading OVMS Base Package from ${OVMS_RELEASE_BASE_URL} to ${tempFilename}...`)
|
||||||
|
|
||||||
|
// Try PowerShell download first, fallback to Node.js download if it fails
|
||||||
|
await downloadWithPowerShell(OVMS_RELEASE_BASE_URL, tempFilename)
|
||||||
|
console.log(`Successfully downloaded from: ${OVMS_RELEASE_BASE_URL}`)
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Download OVMS Base failed: ${error.message}`)
|
||||||
|
fs.unlinkSync(tempFilename)
|
||||||
|
return 103
|
||||||
|
}
|
||||||
|
|
||||||
|
// unzip the base package to the target directory
|
||||||
|
const csDir = path.join(os.homedir(), '.cherrystudio')
|
||||||
|
const csOvmsDir = path.join(csDir, 'ovms')
|
||||||
|
fs.mkdirSync(csOvmsDir, { recursive: true })
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Use tar.exe to extract the ZIP file
|
||||||
|
console.log(`Extracting OVMS Base to ${csOvmsDir}...`)
|
||||||
|
execSync(`tar -xf ${tempFilename} -C ${csOvmsDir}`, { stdio: 'inherit' })
|
||||||
|
console.log(`OVMS extracted to ${csOvmsDir}`)
|
||||||
|
|
||||||
|
// Clean up temporary file
|
||||||
|
fs.unlinkSync(tempFilename)
|
||||||
|
console.log(`Installation directory: ${csDir}`)
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Error installing OVMS: ${error.message}`)
|
||||||
|
fs.unlinkSync(tempFilename)
|
||||||
|
return 104
|
||||||
|
}
|
||||||
|
|
||||||
|
const csOvmsBinDir = path.join(csOvmsDir, 'ovms')
|
||||||
|
// copy ovms.exe to ovdnd.exe
|
||||||
|
try {
|
||||||
|
fs.copyFileSync(path.join(csOvmsBinDir, 'ovms.exe'), path.join(csOvmsBinDir, 'ovdnd.exe'))
|
||||||
|
console.log('Copied ovms.exe to ovdnd.exe')
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Error copying ovms.exe to ovdnd.exe: ${error.message}`)
|
||||||
|
return 105
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy {csOvmsBinDir}/setupvars.bat to {csOvmsBinDir}/run.bat, and append the following lines to run.bat:
|
||||||
|
// del %USERPROFILE%\.cherrystudio\ovms_log.log
|
||||||
|
// ovms.exe --config_path models/config.json --rest_port 8000 --log_level DEBUG --log_path %USERPROFILE%\.cherrystudio\ovms_log.log
|
||||||
|
const runBatPath = path.join(csOvmsBinDir, 'run.bat')
|
||||||
|
try {
|
||||||
|
fs.copyFileSync(path.join(csOvmsBinDir, 'setupvars.bat'), runBatPath)
|
||||||
|
fs.appendFileSync(runBatPath, '\r\n')
|
||||||
|
fs.appendFileSync(runBatPath, 'del %USERPROFILE%\\.cherrystudio\\ovms_log.log\r\n')
|
||||||
|
fs.appendFileSync(
|
||||||
|
runBatPath,
|
||||||
|
'ovms.exe --config_path models/config.json --rest_port 8000 --log_level DEBUG --log_path %USERPROFILE%\\.cherrystudio\\ovms_log.log\r\n'
|
||||||
|
)
|
||||||
|
console.log(`Created run.bat at: ${runBatPath}`)
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Error creating run.bat: ${error.message}`)
|
||||||
|
return 106
|
||||||
|
}
|
||||||
|
|
||||||
|
// create {csOvmsBinDir}/models/config.json with content '{"model_config_list": []}'
|
||||||
|
const configJsonPath = path.join(csOvmsBinDir, 'models', 'config.json')
|
||||||
|
fs.mkdirSync(path.dirname(configJsonPath), { recursive: true })
|
||||||
|
fs.writeFileSync(configJsonPath, '{"mediapipe_config_list":[],"model_config_list":[]}')
|
||||||
|
console.log(`Created config file: ${configJsonPath}`)
|
||||||
|
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Install OVMS Extra package
|
||||||
|
*/
|
||||||
|
async function installOvmsExtra() {
|
||||||
|
// Download the extra package
|
||||||
|
const tempdir = os.tmpdir()
|
||||||
|
const tempFilename = path.join(tempdir, 'ovms_ex.zip')
|
||||||
|
|
||||||
|
try {
|
||||||
|
console.log(`Downloading OVMS Extra Package from ${OVMS_EX_URL} to ${tempFilename}...`)
|
||||||
|
|
||||||
|
// Try PowerShell download first, fallback to Node.js download if it fails
|
||||||
|
await downloadWithPowerShell(OVMS_EX_URL, tempFilename)
|
||||||
|
console.log(`Successfully downloaded from: ${OVMS_EX_URL}`)
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Download OVMS Extra failed: ${error.message}`)
|
||||||
|
fs.unlinkSync(tempFilename)
|
||||||
|
return 103
|
||||||
|
}
|
||||||
|
|
||||||
|
// unzip the extra package to the target directory
|
||||||
|
const csDir = path.join(os.homedir(), '.cherrystudio')
|
||||||
|
const csOvmsDir = path.join(csDir, 'ovms')
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Use tar.exe to extract the ZIP file
|
||||||
|
console.log(`Extracting OVMS Extra to ${csOvmsDir}...`)
|
||||||
|
execSync(`tar -xf ${tempFilename} -C ${csOvmsDir}`, { stdio: 'inherit' })
|
||||||
|
console.log(`OVMS extracted to ${csOvmsDir}`)
|
||||||
|
|
||||||
|
// Clean up temporary file
|
||||||
|
fs.unlinkSync(tempFilename)
|
||||||
|
console.log(`Installation directory: ${csDir}`)
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Error installing OVMS Extra: ${error.message}`)
|
||||||
|
fs.unlinkSync(tempFilename)
|
||||||
|
return 104
|
||||||
|
}
|
||||||
|
|
||||||
|
// apply ovms patch, copy all files in {csOvmsDir}/patch/ovms to {csOvmsDir}/ovms with overwrite mode
|
||||||
|
const patchDir = path.join(csOvmsDir, 'patch', 'ovms')
|
||||||
|
const csOvmsBinDir = path.join(csOvmsDir, 'ovms')
|
||||||
|
try {
|
||||||
|
const files = fs.readdirSync(patchDir)
|
||||||
|
files.forEach((file) => {
|
||||||
|
const srcPath = path.join(patchDir, file)
|
||||||
|
const destPath = path.join(csOvmsBinDir, file)
|
||||||
|
fs.copyFileSync(srcPath, destPath)
|
||||||
|
console.log(`Applied patch file: ${file}`)
|
||||||
|
})
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Error applying OVMS patch: ${error.message}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the CPU Name and ID
|
||||||
|
*/
|
||||||
|
function getCpuInfo() {
|
||||||
|
const cpuInfo = {
|
||||||
|
name: '',
|
||||||
|
id: ''
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use PowerShell to get CPU information
|
||||||
|
try {
|
||||||
|
const psCommand = `powershell -Command "Get-CimInstance -ClassName Win32_Processor | Select-Object Name, DeviceID | ConvertTo-Json"`
|
||||||
|
const psOutput = execSync(psCommand).toString()
|
||||||
|
const cpuData = JSON.parse(psOutput)
|
||||||
|
|
||||||
|
if (Array.isArray(cpuData)) {
|
||||||
|
cpuInfo.name = cpuData[0].Name || ''
|
||||||
|
cpuInfo.id = cpuData[0].DeviceID || ''
|
||||||
|
} else {
|
||||||
|
cpuInfo.name = cpuData.Name || ''
|
||||||
|
cpuInfo.id = cpuData.DeviceID || ''
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Failed to get CPU info: ${error.message}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cpuInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Main function to install OVMS
|
||||||
|
*/
|
||||||
|
async function installOvms() {
|
||||||
|
const platform = os.platform()
|
||||||
|
console.log(`Detected platform: ${platform}`)
|
||||||
|
|
||||||
|
const cpuName = getCpuInfo().name
|
||||||
|
console.log(`CPU Name: ${cpuName}`)
|
||||||
|
|
||||||
|
// Check if CPU name contains "Ultra"
|
||||||
|
if (!cpuName.toLowerCase().includes('intel') || !cpuName.toLowerCase().includes('ultra')) {
|
||||||
|
console.error('OVMS installation requires an Intel(R) Core(TM) Ultra CPU.')
|
||||||
|
return 101
|
||||||
|
}
|
||||||
|
|
||||||
|
// only support windows
|
||||||
|
if (platform !== 'win32') {
|
||||||
|
console.error('OVMS installation is only supported on Windows.')
|
||||||
|
return 102
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean old installation if it exists
|
||||||
|
const cleanupCode = cleanOldOvmsInstallation()
|
||||||
|
if (cleanupCode !== 0) {
|
||||||
|
console.error(`OVMS cleanup failed with code: ${cleanupCode}`)
|
||||||
|
return cleanupCode
|
||||||
|
}
|
||||||
|
|
||||||
|
const installBaseCode = await installOvmsBase()
|
||||||
|
if (installBaseCode !== 0) {
|
||||||
|
console.error(`OVMS Base installation failed with code: ${installBaseCode}`)
|
||||||
|
cleanOldOvmsInstallation()
|
||||||
|
return installBaseCode
|
||||||
|
}
|
||||||
|
|
||||||
|
const installExtraCode = await installOvmsExtra()
|
||||||
|
if (installExtraCode !== 0) {
|
||||||
|
console.error(`OVMS Extra installation failed with code: ${installExtraCode}`)
|
||||||
|
return installExtraCode
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the installation
|
||||||
|
installOvms()
|
||||||
|
.then((retcode) => {
|
||||||
|
if (retcode === 0) {
|
||||||
|
console.log('OVMS installation successful')
|
||||||
|
} else {
|
||||||
|
console.error('OVMS installation failed')
|
||||||
|
}
|
||||||
|
process.exit(retcode)
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
console.error('OVMS installation failed:', error)
|
||||||
|
process.exit(100)
|
||||||
|
})
|
||||||
@@ -35,7 +35,7 @@ const allX64 = {
|
|||||||
'@napi-rs/system-ocr-win32-x64-msvc': '1.0.2'
|
'@napi-rs/system-ocr-win32-x64-msvc': '1.0.2'
|
||||||
}
|
}
|
||||||
|
|
||||||
const claudeCodeVenderPath = '@anthropic-ai/claude-code/vendor'
|
const claudeCodeVenderPath = '@anthropic-ai/claude-agent-sdk/vendor'
|
||||||
const claudeCodeVenders = ['arm64-darwin', 'arm64-linux', 'x64-darwin', 'x64-linux', 'x64-win32']
|
const claudeCodeVenders = ['arm64-darwin', 'arm64-linux', 'x64-darwin', 'x64-linux', 'x64-win32']
|
||||||
|
|
||||||
const platformToArch = {
|
const platformToArch = {
|
||||||
@@ -88,7 +88,7 @@ exports.default = async function (context) {
|
|||||||
const excludeClaudeCodeJBPlutins = ['!node_modules/' + claudeCodeVenderPath + '/' + 'claude-code-jetbrains-plugin']
|
const excludeClaudeCodeJBPlutins = ['!node_modules/' + claudeCodeVenderPath + '/' + 'claude-code-jetbrains-plugin']
|
||||||
|
|
||||||
const includeClaudeCodeFilters = [
|
const includeClaudeCodeFilters = [
|
||||||
'!node_modules/' + claudeCodeVenderPath + '/' + `${archType}-${platformToArch[platform]}/**`
|
'!node_modules/' + claudeCodeVenderPath + '/ripgrep/' + `${archType}-${platformToArch[platform]}/**`
|
||||||
]
|
]
|
||||||
|
|
||||||
if (arch === Arch.arm64) {
|
if (arch === Arch.arm64) {
|
||||||
|
|||||||
@@ -3,25 +3,42 @@ import cors from 'cors'
|
|||||||
import express from 'express'
|
import express from 'express'
|
||||||
import { v4 as uuidv4 } from 'uuid'
|
import { v4 as uuidv4 } from 'uuid'
|
||||||
|
|
||||||
|
import { LONG_POLL_TIMEOUT_MS } from './config/timeouts'
|
||||||
import { authMiddleware } from './middleware/auth'
|
import { authMiddleware } from './middleware/auth'
|
||||||
import { errorHandler } from './middleware/error'
|
import { errorHandler } from './middleware/error'
|
||||||
import { setupOpenAPIDocumentation } from './middleware/openapi'
|
import { setupOpenAPIDocumentation } from './middleware/openapi'
|
||||||
import { agentsRoutes } from './routes/agents'
|
import { agentsRoutes } from './routes/agents'
|
||||||
import { chatRoutes } from './routes/chat'
|
import { chatRoutes } from './routes/chat'
|
||||||
import { mcpRoutes } from './routes/mcp'
|
import { mcpRoutes } from './routes/mcp'
|
||||||
import { messagesRoutes } from './routes/messages'
|
import { messagesProviderRoutes, messagesRoutes } from './routes/messages'
|
||||||
import { modelsRoutes } from './routes/models'
|
import { modelsRoutes } from './routes/models'
|
||||||
|
|
||||||
const logger = loggerService.withContext('ApiServer')
|
const logger = loggerService.withContext('ApiServer')
|
||||||
|
|
||||||
|
const extendMessagesTimeout: express.RequestHandler = (req, res, next) => {
|
||||||
|
req.setTimeout(LONG_POLL_TIMEOUT_MS)
|
||||||
|
res.setTimeout(LONG_POLL_TIMEOUT_MS)
|
||||||
|
next()
|
||||||
|
}
|
||||||
|
|
||||||
const app = express()
|
const app = express()
|
||||||
|
app.use(
|
||||||
|
express.json({
|
||||||
|
limit: '50mb'
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
// Global middleware
|
// Global middleware
|
||||||
app.use((req, res, next) => {
|
app.use((req, res, next) => {
|
||||||
const start = Date.now()
|
const start = Date.now()
|
||||||
res.on('finish', () => {
|
res.on('finish', () => {
|
||||||
const duration = Date.now() - start
|
const duration = Date.now() - start
|
||||||
logger.info(`${req.method} ${req.path} - ${res.statusCode} - ${duration}ms`)
|
logger.info('API request completed', {
|
||||||
|
method: req.method,
|
||||||
|
path: req.path,
|
||||||
|
statusCode: res.statusCode,
|
||||||
|
durationMs: duration
|
||||||
|
})
|
||||||
})
|
})
|
||||||
next()
|
next()
|
||||||
})
|
})
|
||||||
@@ -108,21 +125,23 @@ app.get('/', (_req, res) => {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Setup OpenAPI documentation before protected routes so docs remain public
|
||||||
|
setupOpenAPIDocumentation(app)
|
||||||
|
|
||||||
|
// Provider-specific messages route requires authentication
|
||||||
|
app.use('/:provider/v1/messages', authMiddleware, extendMessagesTimeout, messagesProviderRoutes)
|
||||||
|
|
||||||
// API v1 routes with auth
|
// API v1 routes with auth
|
||||||
const apiRouter = express.Router()
|
const apiRouter = express.Router()
|
||||||
apiRouter.use(authMiddleware)
|
apiRouter.use(authMiddleware)
|
||||||
apiRouter.use(express.json())
|
|
||||||
// Mount routes
|
// Mount routes
|
||||||
apiRouter.use('/chat', chatRoutes)
|
apiRouter.use('/chat', chatRoutes)
|
||||||
apiRouter.use('/mcps', mcpRoutes)
|
apiRouter.use('/mcps', mcpRoutes)
|
||||||
apiRouter.use('/messages', messagesRoutes)
|
apiRouter.use('/messages', extendMessagesTimeout, messagesRoutes)
|
||||||
apiRouter.use('/models', modelsRoutes)
|
apiRouter.use('/models', modelsRoutes)
|
||||||
apiRouter.use('/agents', agentsRoutes)
|
apiRouter.use('/agents', agentsRoutes)
|
||||||
app.use('/v1', apiRouter)
|
app.use('/v1', apiRouter)
|
||||||
|
|
||||||
// Setup OpenAPI documentation
|
|
||||||
setupOpenAPIDocumentation(app)
|
|
||||||
|
|
||||||
// Error handling (must be last)
|
// Error handling (must be last)
|
||||||
app.use(errorHandler)
|
app.use(errorHandler)
|
||||||
|
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ class ConfigManager {
|
|||||||
}
|
}
|
||||||
return this._config
|
return this._config
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.warn('Failed to load config from Redux, using defaults:', error)
|
logger.warn('Failed to load config from Redux, using defaults', { error })
|
||||||
this._config = {
|
this._config = {
|
||||||
enabled: false,
|
enabled: false,
|
||||||
port: defaultPort,
|
port: defaultPort,
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
export const LONG_POLL_TIMEOUT_MS = 120 * 60_000 // 120 minutes
|
||||||
|
|
||||||
|
export const MESSAGE_STREAM_TIMEOUT_MS = LONG_POLL_TIMEOUT_MS
|
||||||
@@ -0,0 +1,368 @@
|
|||||||
|
import type { NextFunction, Request, Response } from 'express'
|
||||||
|
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
|
import { config } from '../../config'
|
||||||
|
import { authMiddleware } from '../auth'
|
||||||
|
|
||||||
|
// Mock the config module
|
||||||
|
vi.mock('../../config', () => ({
|
||||||
|
config: {
|
||||||
|
get: vi.fn()
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Mock the logger
|
||||||
|
vi.mock('@logger', () => ({
|
||||||
|
loggerService: {
|
||||||
|
withContext: vi.fn(() => ({
|
||||||
|
debug: vi.fn()
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
const mockConfig = config as any
|
||||||
|
|
||||||
|
describe('authMiddleware', () => {
|
||||||
|
let req: Partial<Request>
|
||||||
|
let res: Partial<Response>
|
||||||
|
let next: NextFunction
|
||||||
|
let jsonMock: ReturnType<typeof vi.fn>
|
||||||
|
let statusMock: ReturnType<typeof vi.fn>
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
jsonMock = vi.fn()
|
||||||
|
statusMock = vi.fn(() => ({ json: jsonMock }))
|
||||||
|
|
||||||
|
req = {
|
||||||
|
header: vi.fn()
|
||||||
|
}
|
||||||
|
res = {
|
||||||
|
status: statusMock
|
||||||
|
}
|
||||||
|
next = vi.fn()
|
||||||
|
|
||||||
|
vi.clearAllMocks()
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Missing credentials', () => {
|
||||||
|
it('should return 401 when both auth headers are missing', async () => {
|
||||||
|
;(req.header as any).mockReturnValue('')
|
||||||
|
|
||||||
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
|
expect(statusMock).toHaveBeenCalledWith(401)
|
||||||
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: missing credentials' })
|
||||||
|
expect(next).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return 401 when both auth headers are empty strings', async () => {
|
||||||
|
;(req.header as any).mockImplementation((header: string) => {
|
||||||
|
if (header === 'authorization') return ''
|
||||||
|
if (header === 'x-api-key') return ''
|
||||||
|
return ''
|
||||||
|
})
|
||||||
|
|
||||||
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
|
expect(statusMock).toHaveBeenCalledWith(401)
|
||||||
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: missing credentials' })
|
||||||
|
expect(next).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Server configuration', () => {
|
||||||
|
it('should return 403 when API key is not configured', async () => {
|
||||||
|
;(req.header as any).mockImplementation((header: string) => {
|
||||||
|
if (header === 'x-api-key') return 'some-key'
|
||||||
|
return ''
|
||||||
|
})
|
||||||
|
|
||||||
|
mockConfig.get.mockResolvedValue({ apiKey: '' })
|
||||||
|
|
||||||
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
|
expect(statusMock).toHaveBeenCalledWith(403)
|
||||||
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||||
|
expect(next).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return 403 when API key is null', async () => {
|
||||||
|
;(req.header as any).mockImplementation((header: string) => {
|
||||||
|
if (header === 'x-api-key') return 'some-key'
|
||||||
|
return ''
|
||||||
|
})
|
||||||
|
|
||||||
|
mockConfig.get.mockResolvedValue({ apiKey: null })
|
||||||
|
|
||||||
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
|
expect(statusMock).toHaveBeenCalledWith(403)
|
||||||
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||||
|
expect(next).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('API Key authentication (priority)', () => {
|
||||||
|
const validApiKey = 'valid-api-key-123'
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
mockConfig.get.mockResolvedValue({ apiKey: validApiKey })
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should authenticate successfully with valid API key', async () => {
|
||||||
|
;(req.header as any).mockImplementation((header: string) => {
|
||||||
|
if (header === 'x-api-key') return validApiKey
|
||||||
|
return ''
|
||||||
|
})
|
||||||
|
|
||||||
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
|
expect(next).toHaveBeenCalled()
|
||||||
|
expect(statusMock).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return 403 with invalid API key', async () => {
|
||||||
|
;(req.header as any).mockImplementation((header: string) => {
|
||||||
|
if (header === 'x-api-key') return 'invalid-key'
|
||||||
|
return ''
|
||||||
|
})
|
||||||
|
|
||||||
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
|
expect(statusMock).toHaveBeenCalledWith(403)
|
||||||
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||||
|
expect(next).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return 401 with empty API key', async () => {
|
||||||
|
;(req.header as any).mockImplementation((header: string) => {
|
||||||
|
if (header === 'x-api-key') return ' '
|
||||||
|
return ''
|
||||||
|
})
|
||||||
|
|
||||||
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
|
expect(statusMock).toHaveBeenCalledWith(401)
|
||||||
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: empty x-api-key' })
|
||||||
|
expect(next).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle API key with whitespace', async () => {
|
||||||
|
;(req.header as any).mockImplementation((header: string) => {
|
||||||
|
if (header === 'x-api-key') return ` ${validApiKey} `
|
||||||
|
return ''
|
||||||
|
})
|
||||||
|
|
||||||
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
|
expect(next).toHaveBeenCalled()
|
||||||
|
expect(statusMock).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should prioritize API key over Bearer token when both are present', async () => {
|
||||||
|
;(req.header as any).mockImplementation((header: string) => {
|
||||||
|
if (header === 'x-api-key') return validApiKey
|
||||||
|
if (header === 'authorization') return 'Bearer invalid-token'
|
||||||
|
return ''
|
||||||
|
})
|
||||||
|
|
||||||
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
|
expect(next).toHaveBeenCalled()
|
||||||
|
expect(statusMock).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return 403 when API key is invalid even if Bearer token is valid', async () => {
|
||||||
|
;(req.header as any).mockImplementation((header: string) => {
|
||||||
|
if (header === 'x-api-key') return 'invalid-key'
|
||||||
|
if (header === 'authorization') return `Bearer ${validApiKey}`
|
||||||
|
return ''
|
||||||
|
})
|
||||||
|
|
||||||
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
|
expect(statusMock).toHaveBeenCalledWith(403)
|
||||||
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||||
|
expect(next).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Bearer token authentication (fallback)', () => {
|
||||||
|
const validApiKey = 'valid-api-key-123'
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
mockConfig.get.mockResolvedValue({ apiKey: validApiKey })
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should authenticate successfully with valid Bearer token when no API key', async () => {
|
||||||
|
;(req.header as any).mockImplementation((header: string) => {
|
||||||
|
if (header === 'authorization') return `Bearer ${validApiKey}`
|
||||||
|
return ''
|
||||||
|
})
|
||||||
|
|
||||||
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
|
expect(next).toHaveBeenCalled()
|
||||||
|
expect(statusMock).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return 403 with invalid Bearer token', async () => {
|
||||||
|
;(req.header as any).mockImplementation((header: string) => {
|
||||||
|
if (header === 'authorization') return 'Bearer invalid-token'
|
||||||
|
return ''
|
||||||
|
})
|
||||||
|
|
||||||
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
|
expect(statusMock).toHaveBeenCalledWith(403)
|
||||||
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||||
|
expect(next).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return 401 with malformed authorization header', async () => {
|
||||||
|
;(req.header as any).mockImplementation((header: string) => {
|
||||||
|
if (header === 'authorization') return 'Basic sometoken'
|
||||||
|
return ''
|
||||||
|
})
|
||||||
|
|
||||||
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
|
expect(statusMock).toHaveBeenCalledWith(401)
|
||||||
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: invalid authorization format' })
|
||||||
|
expect(next).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return 401 with Bearer without space', async () => {
|
||||||
|
;(req.header as any).mockImplementation((header: string) => {
|
||||||
|
if (header === 'authorization') return 'Bearer'
|
||||||
|
return ''
|
||||||
|
})
|
||||||
|
|
||||||
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
|
expect(statusMock).toHaveBeenCalledWith(401)
|
||||||
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: invalid authorization format' })
|
||||||
|
expect(next).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle Bearer token with only trailing spaces (edge case)', async () => {
|
||||||
|
;(req.header as any).mockImplementation((header: string) => {
|
||||||
|
if (header === 'authorization') return 'Bearer ' // This will be trimmed to "Bearer" and fail format check
|
||||||
|
return ''
|
||||||
|
})
|
||||||
|
|
||||||
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
|
expect(statusMock).toHaveBeenCalledWith(401)
|
||||||
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: invalid authorization format' })
|
||||||
|
expect(next).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle Bearer token with case insensitive prefix', async () => {
|
||||||
|
;(req.header as any).mockImplementation((header: string) => {
|
||||||
|
if (header === 'authorization') return `bearer ${validApiKey}`
|
||||||
|
return ''
|
||||||
|
})
|
||||||
|
|
||||||
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
|
expect(next).toHaveBeenCalled()
|
||||||
|
expect(statusMock).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle Bearer token with whitespace', async () => {
|
||||||
|
;(req.header as any).mockImplementation((header: string) => {
|
||||||
|
if (header === 'authorization') return ` Bearer ${validApiKey} `
|
||||||
|
return ''
|
||||||
|
})
|
||||||
|
|
||||||
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
|
expect(next).toHaveBeenCalled()
|
||||||
|
expect(statusMock).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Edge cases', () => {
|
||||||
|
const validApiKey = 'valid-api-key-123'
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
mockConfig.get.mockResolvedValue({ apiKey: validApiKey })
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle config.get() rejection', async () => {
|
||||||
|
;(req.header as any).mockImplementation((header: string) => {
|
||||||
|
if (header === 'x-api-key') return validApiKey
|
||||||
|
return ''
|
||||||
|
})
|
||||||
|
|
||||||
|
mockConfig.get.mockRejectedValue(new Error('Config error'))
|
||||||
|
|
||||||
|
await expect(authMiddleware(req as Request, res as Response, next)).rejects.toThrow('Config error')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should use timing-safe comparison for different length tokens', async () => {
|
||||||
|
;(req.header as any).mockImplementation((header: string) => {
|
||||||
|
if (header === 'x-api-key') return 'short'
|
||||||
|
return ''
|
||||||
|
})
|
||||||
|
|
||||||
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
|
expect(statusMock).toHaveBeenCalledWith(403)
|
||||||
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||||
|
expect(next).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return 401 when neither credential format is valid', async () => {
|
||||||
|
;(req.header as any).mockImplementation((header: string) => {
|
||||||
|
if (header === 'authorization') return 'Invalid format'
|
||||||
|
return ''
|
||||||
|
})
|
||||||
|
|
||||||
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
|
expect(statusMock).toHaveBeenCalledWith(401)
|
||||||
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: invalid authorization format' })
|
||||||
|
expect(next).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Timing attack protection', () => {
|
||||||
|
const validApiKey = 'valid-api-key-123'
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
mockConfig.get.mockResolvedValue({ apiKey: validApiKey })
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle similar length but different API keys securely', async () => {
|
||||||
|
const similarKey = 'valid-api-key-124' // Same length, different last char
|
||||||
|
|
||||||
|
;(req.header as any).mockImplementation((header: string) => {
|
||||||
|
if (header === 'x-api-key') return similarKey
|
||||||
|
return ''
|
||||||
|
})
|
||||||
|
|
||||||
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
|
expect(statusMock).toHaveBeenCalledWith(403)
|
||||||
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||||
|
expect(next).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle similar length but different Bearer tokens securely', async () => {
|
||||||
|
const similarKey = 'valid-api-key-124' // Same length, different last char
|
||||||
|
|
||||||
|
;(req.header as any).mockImplementation((header: string) => {
|
||||||
|
if (header === 'authorization') return `Bearer ${similarKey}`
|
||||||
|
return ''
|
||||||
|
})
|
||||||
|
|
||||||
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
|
expect(statusMock).toHaveBeenCalledWith(403)
|
||||||
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||||
|
expect(next).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -3,8 +3,17 @@ import { NextFunction, Request, Response } from 'express'
|
|||||||
|
|
||||||
import { config } from '../config'
|
import { config } from '../config'
|
||||||
|
|
||||||
|
const isValidToken = (token: string, apiKey: string): boolean => {
|
||||||
|
if (token.length !== apiKey.length) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
const tokenBuf = Buffer.from(token)
|
||||||
|
const keyBuf = Buffer.from(apiKey)
|
||||||
|
return crypto.timingSafeEqual(tokenBuf, keyBuf)
|
||||||
|
}
|
||||||
|
|
||||||
export const authMiddleware = async (req: Request, res: Response, next: NextFunction) => {
|
export const authMiddleware = async (req: Request, res: Response, next: NextFunction) => {
|
||||||
const auth = req.header('Authorization') || ''
|
const auth = req.header('authorization') || ''
|
||||||
const xApiKey = req.header('x-api-key') || ''
|
const xApiKey = req.header('x-api-key') || ''
|
||||||
|
|
||||||
// Fast rejection if neither credential header provided
|
// Fast rejection if neither credential header provided
|
||||||
@@ -12,51 +21,46 @@ export const authMiddleware = async (req: Request, res: Response, next: NextFunc
|
|||||||
return res.status(401).json({ error: 'Unauthorized: missing credentials' })
|
return res.status(401).json({ error: 'Unauthorized: missing credentials' })
|
||||||
}
|
}
|
||||||
|
|
||||||
let token: string | undefined
|
|
||||||
|
|
||||||
// Prefer Bearer if well‑formed
|
|
||||||
if (auth) {
|
|
||||||
const trimmed = auth.trim()
|
|
||||||
const bearerPrefix = /^Bearer\s+/i
|
|
||||||
if (bearerPrefix.test(trimmed)) {
|
|
||||||
const candidate = trimmed.replace(bearerPrefix, '').trim()
|
|
||||||
if (!candidate) {
|
|
||||||
return res.status(401).json({ error: 'Unauthorized: empty bearer token' })
|
|
||||||
}
|
|
||||||
token = candidate
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback to x-api-key if token still not resolved
|
|
||||||
if (!token && xApiKey) {
|
|
||||||
if (!xApiKey.trim()) {
|
|
||||||
return res.status(401).json({ error: 'Unauthorized: empty x-api-key' })
|
|
||||||
}
|
|
||||||
token = xApiKey.trim()
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!token) {
|
|
||||||
// At this point we had at least one header, but none yielded a usable token
|
|
||||||
return res.status(401).json({ error: 'Unauthorized: invalid credentials format' })
|
|
||||||
}
|
|
||||||
|
|
||||||
const { apiKey } = await config.get()
|
const { apiKey } = await config.get()
|
||||||
|
|
||||||
if (!apiKey) {
|
if (!apiKey) {
|
||||||
// If server not configured, treat as forbidden (or could be 500). Choose 403 to avoid leaking config state.
|
|
||||||
return res.status(403).json({ error: 'Forbidden' })
|
return res.status(403).json({ error: 'Forbidden' })
|
||||||
}
|
}
|
||||||
|
|
||||||
// Timing-safe compare when lengths match, else immediate forbidden
|
// Check API key first (priority)
|
||||||
if (token.length !== apiKey.length) {
|
if (xApiKey) {
|
||||||
return res.status(403).json({ error: 'Forbidden' })
|
const trimmedApiKey = xApiKey.trim()
|
||||||
|
if (!trimmedApiKey) {
|
||||||
|
return res.status(401).json({ error: 'Unauthorized: empty x-api-key' })
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isValidToken(trimmedApiKey, apiKey)) {
|
||||||
|
return next()
|
||||||
|
} else {
|
||||||
|
return res.status(403).json({ error: 'Forbidden' })
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const tokenBuf = Buffer.from(token)
|
// Fallback to Bearer token
|
||||||
const keyBuf = Buffer.from(apiKey)
|
if (auth) {
|
||||||
if (!crypto.timingSafeEqual(tokenBuf, keyBuf)) {
|
const trimmed = auth.trim()
|
||||||
return res.status(403).json({ error: 'Forbidden' })
|
const bearerPrefix = /^Bearer\s+/i
|
||||||
|
|
||||||
|
if (!bearerPrefix.test(trimmed)) {
|
||||||
|
return res.status(401).json({ error: 'Unauthorized: invalid authorization format' })
|
||||||
|
}
|
||||||
|
|
||||||
|
const token = trimmed.replace(bearerPrefix, '').trim()
|
||||||
|
if (!token) {
|
||||||
|
return res.status(401).json({ error: 'Unauthorized: empty bearer token' })
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isValidToken(token, apiKey)) {
|
||||||
|
return next()
|
||||||
|
} else {
|
||||||
|
return res.status(403).json({ error: 'Forbidden' })
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return next()
|
return res.status(401).json({ error: 'Unauthorized: invalid credentials format' })
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ const logger = loggerService.withContext('ApiServerErrorHandler')
|
|||||||
|
|
||||||
// oxlint-disable-next-line @typescript-eslint/no-unused-vars
|
// oxlint-disable-next-line @typescript-eslint/no-unused-vars
|
||||||
export const errorHandler = (err: Error, _req: Request, res: Response, _next: NextFunction) => {
|
export const errorHandler = (err: Error, _req: Request, res: Response, _next: NextFunction) => {
|
||||||
logger.error('API Server Error:', err)
|
logger.error('API server error', { error: err })
|
||||||
|
|
||||||
// Don't expose internal errors in production
|
// Don't expose internal errors in production
|
||||||
const isDev = process.env.NODE_ENV === 'development'
|
const isDev = process.env.NODE_ENV === 'development'
|
||||||
|
|||||||
@@ -197,10 +197,11 @@ export function setupOpenAPIDocumentation(app: Express) {
|
|||||||
})
|
})
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info('OpenAPI documentation setup complete')
|
logger.info('OpenAPI documentation ready', {
|
||||||
logger.info('Documentation available at /api-docs')
|
docsPath: '/api-docs',
|
||||||
logger.info('OpenAPI spec available at /api-docs.json')
|
specPath: '/api-docs.json'
|
||||||
|
})
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('Failed to setup OpenAPI documentation:', error as Error)
|
logger.error('Failed to setup OpenAPI documentation', { error })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { AgentModelValidationError, agentService } from '@main/services/agents'
|
import { AgentModelValidationError, agentService, sessionService } from '@main/services/agents'
|
||||||
import { ListAgentsResponse,type ReplaceAgentRequest, type UpdateAgentRequest } from '@types'
|
import { ListAgentsResponse, type ReplaceAgentRequest, type UpdateAgentRequest } from '@types'
|
||||||
import { Request, Response } from 'express'
|
import { Request, Response } from 'express'
|
||||||
|
|
||||||
import type { ValidationRequest } from '../validators/zodValidator'
|
import type { ValidationRequest } from '../validators/zodValidator'
|
||||||
@@ -20,7 +20,8 @@ const modelValidationErrorBody = (error: AgentModelValidationError) => ({
|
|||||||
* /v1/agents:
|
* /v1/agents:
|
||||||
* post:
|
* post:
|
||||||
* summary: Create a new agent
|
* summary: Create a new agent
|
||||||
* description: Creates a new autonomous agent with the specified configuration
|
* description: Creates a new autonomous agent with the specified configuration and automatically
|
||||||
|
* provisions an initial session that mirrors the agent's settings.
|
||||||
* tags: [Agents]
|
* tags: [Agents]
|
||||||
* requestBody:
|
* requestBody:
|
||||||
* required: true
|
* required: true
|
||||||
@@ -50,16 +51,45 @@ const modelValidationErrorBody = (error: AgentModelValidationError) => ({
|
|||||||
*/
|
*/
|
||||||
export const createAgent = async (req: Request, res: Response): Promise<Response> => {
|
export const createAgent = async (req: Request, res: Response): Promise<Response> => {
|
||||||
try {
|
try {
|
||||||
logger.info('Creating new agent')
|
logger.debug('Creating agent')
|
||||||
logger.debug('Agent data:', req.body)
|
logger.debug('Agent payload', { body: req.body })
|
||||||
|
|
||||||
const agent = await agentService.createAgent(req.body)
|
const agent = await agentService.createAgent(req.body)
|
||||||
|
|
||||||
logger.info(`Agent created successfully: ${agent.id}`)
|
try {
|
||||||
return res.status(201).json(agent)
|
logger.info('Agent created', { agentId: agent.id })
|
||||||
|
logger.debug('Creating default session for agent', { agentId: agent.id })
|
||||||
|
|
||||||
|
await sessionService.createSession(agent.id, {})
|
||||||
|
|
||||||
|
logger.info('Default session created for agent', { agentId: agent.id })
|
||||||
|
return res.status(201).json(agent)
|
||||||
|
} catch (sessionError: any) {
|
||||||
|
logger.error('Failed to create default session for new agent, rolling back agent creation', {
|
||||||
|
agentId: agent.id,
|
||||||
|
error: sessionError
|
||||||
|
})
|
||||||
|
|
||||||
|
try {
|
||||||
|
await agentService.deleteAgent(agent.id)
|
||||||
|
} catch (rollbackError: any) {
|
||||||
|
logger.error('Failed to roll back agent after session creation failure', {
|
||||||
|
agentId: agent.id,
|
||||||
|
error: rollbackError
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return res.status(500).json({
|
||||||
|
error: {
|
||||||
|
message: `Failed to create default session for agent: ${sessionError.message}`,
|
||||||
|
type: 'internal_error',
|
||||||
|
code: 'agent_session_creation_failed'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
if (error instanceof AgentModelValidationError) {
|
if (error instanceof AgentModelValidationError) {
|
||||||
logger.warn('Agent model validation error during create:', {
|
logger.warn('Agent model validation error during create', {
|
||||||
agentType: error.context.agentType,
|
agentType: error.context.agentType,
|
||||||
field: error.context.field,
|
field: error.context.field,
|
||||||
model: error.context.model,
|
model: error.context.model,
|
||||||
@@ -68,7 +98,7 @@ export const createAgent = async (req: Request, res: Response): Promise<Response
|
|||||||
return res.status(400).json(modelValidationErrorBody(error))
|
return res.status(400).json(modelValidationErrorBody(error))
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.error('Error creating agent:', error)
|
logger.error('Error creating agent', { error })
|
||||||
return res.status(500).json({
|
return res.status(500).json({
|
||||||
error: {
|
error: {
|
||||||
message: `Failed to create agent: ${error.message}`,
|
message: `Failed to create agent: ${error.message}`,
|
||||||
@@ -141,11 +171,16 @@ export const listAgents = async (req: Request, res: Response): Promise<Response>
|
|||||||
const limit = req.query.limit ? parseInt(req.query.limit as string) : 20
|
const limit = req.query.limit ? parseInt(req.query.limit as string) : 20
|
||||||
const offset = req.query.offset ? parseInt(req.query.offset as string) : 0
|
const offset = req.query.offset ? parseInt(req.query.offset as string) : 0
|
||||||
|
|
||||||
logger.info(`Listing agents with limit=${limit}, offset=${offset}`)
|
logger.debug('Listing agents', { limit, offset })
|
||||||
|
|
||||||
const result = await agentService.listAgents({ limit, offset })
|
const result = await agentService.listAgents({ limit, offset })
|
||||||
|
|
||||||
logger.info(`Retrieved ${result.agents.length} agents (total: ${result.total})`)
|
logger.info('Agents listed', {
|
||||||
|
returned: result.agents.length,
|
||||||
|
total: result.total,
|
||||||
|
limit,
|
||||||
|
offset
|
||||||
|
})
|
||||||
return res.json({
|
return res.json({
|
||||||
data: result.agents,
|
data: result.agents,
|
||||||
total: result.total,
|
total: result.total,
|
||||||
@@ -153,7 +188,7 @@ export const listAgents = async (req: Request, res: Response): Promise<Response>
|
|||||||
offset
|
offset
|
||||||
} satisfies ListAgentsResponse)
|
} satisfies ListAgentsResponse)
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Error listing agents:', error)
|
logger.error('Error listing agents', { error })
|
||||||
return res.status(500).json({
|
return res.status(500).json({
|
||||||
error: {
|
error: {
|
||||||
message: 'Failed to list agents',
|
message: 'Failed to list agents',
|
||||||
@@ -201,12 +236,12 @@ export const listAgents = async (req: Request, res: Response): Promise<Response>
|
|||||||
export const getAgent = async (req: Request, res: Response): Promise<Response> => {
|
export const getAgent = async (req: Request, res: Response): Promise<Response> => {
|
||||||
try {
|
try {
|
||||||
const { agentId } = req.params
|
const { agentId } = req.params
|
||||||
logger.info(`Getting agent: ${agentId}`)
|
logger.debug('Getting agent', { agentId })
|
||||||
|
|
||||||
const agent = await agentService.getAgent(agentId)
|
const agent = await agentService.getAgent(agentId)
|
||||||
|
|
||||||
if (!agent) {
|
if (!agent) {
|
||||||
logger.warn(`Agent not found: ${agentId}`)
|
logger.warn('Agent not found', { agentId })
|
||||||
return res.status(404).json({
|
return res.status(404).json({
|
||||||
error: {
|
error: {
|
||||||
message: 'Agent not found',
|
message: 'Agent not found',
|
||||||
@@ -216,10 +251,10 @@ export const getAgent = async (req: Request, res: Response): Promise<Response> =
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(`Agent retrieved successfully: ${agentId}`)
|
logger.info('Agent retrieved', { agentId })
|
||||||
return res.json(agent)
|
return res.json(agent)
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Error getting agent:', error)
|
logger.error('Error getting agent', { error, agentId: req.params.agentId })
|
||||||
return res.status(500).json({
|
return res.status(500).json({
|
||||||
error: {
|
error: {
|
||||||
message: 'Failed to get agent',
|
message: 'Failed to get agent',
|
||||||
@@ -279,8 +314,8 @@ export const getAgent = async (req: Request, res: Response): Promise<Response> =
|
|||||||
export const updateAgent = async (req: Request, res: Response): Promise<Response> => {
|
export const updateAgent = async (req: Request, res: Response): Promise<Response> => {
|
||||||
const { agentId } = req.params
|
const { agentId } = req.params
|
||||||
try {
|
try {
|
||||||
logger.info(`Updating agent: ${agentId}`)
|
logger.debug('Updating agent', { agentId })
|
||||||
logger.debug('Update data:', req.body)
|
logger.debug('Replace payload', { body: req.body })
|
||||||
|
|
||||||
const { validatedBody } = req as ValidationRequest
|
const { validatedBody } = req as ValidationRequest
|
||||||
const replacePayload = (validatedBody ?? {}) as ReplaceAgentRequest
|
const replacePayload = (validatedBody ?? {}) as ReplaceAgentRequest
|
||||||
@@ -288,7 +323,7 @@ export const updateAgent = async (req: Request, res: Response): Promise<Response
|
|||||||
const agent = await agentService.updateAgent(agentId, replacePayload, { replace: true })
|
const agent = await agentService.updateAgent(agentId, replacePayload, { replace: true })
|
||||||
|
|
||||||
if (!agent) {
|
if (!agent) {
|
||||||
logger.warn(`Agent not found for update: ${agentId}`)
|
logger.warn('Agent not found for update', { agentId })
|
||||||
return res.status(404).json({
|
return res.status(404).json({
|
||||||
error: {
|
error: {
|
||||||
message: 'Agent not found',
|
message: 'Agent not found',
|
||||||
@@ -298,11 +333,11 @@ export const updateAgent = async (req: Request, res: Response): Promise<Response
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(`Agent updated successfully: ${agentId}`)
|
logger.info('Agent updated', { agentId })
|
||||||
return res.json(agent)
|
return res.json(agent)
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
if (error instanceof AgentModelValidationError) {
|
if (error instanceof AgentModelValidationError) {
|
||||||
logger.warn('Agent model validation error during update:', {
|
logger.warn('Agent model validation error during update', {
|
||||||
agentId,
|
agentId,
|
||||||
agentType: error.context.agentType,
|
agentType: error.context.agentType,
|
||||||
field: error.context.field,
|
field: error.context.field,
|
||||||
@@ -312,7 +347,7 @@ export const updateAgent = async (req: Request, res: Response): Promise<Response
|
|||||||
return res.status(400).json(modelValidationErrorBody(error))
|
return res.status(400).json(modelValidationErrorBody(error))
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.error('Error updating agent:', error)
|
logger.error('Error updating agent', { error, agentId })
|
||||||
return res.status(500).json({
|
return res.status(500).json({
|
||||||
error: {
|
error: {
|
||||||
message: 'Failed to update agent: ' + error.message,
|
message: 'Failed to update agent: ' + error.message,
|
||||||
@@ -365,11 +400,11 @@ export const updateAgent = async (req: Request, res: Response): Promise<Response
|
|||||||
* small_model:
|
* small_model:
|
||||||
* type: string
|
* type: string
|
||||||
* description: Optional small/fast model ID
|
* description: Optional small/fast model ID
|
||||||
* built_in_tools:
|
* tools:
|
||||||
* type: array
|
* type: array
|
||||||
* items:
|
* items:
|
||||||
* type: string
|
* type: string
|
||||||
* description: Built-in tool IDs
|
* description: Tools
|
||||||
* mcps:
|
* mcps:
|
||||||
* type: array
|
* type: array
|
||||||
* items:
|
* items:
|
||||||
@@ -425,8 +460,8 @@ export const updateAgent = async (req: Request, res: Response): Promise<Response
|
|||||||
export const patchAgent = async (req: Request, res: Response): Promise<Response> => {
|
export const patchAgent = async (req: Request, res: Response): Promise<Response> => {
|
||||||
const { agentId } = req.params
|
const { agentId } = req.params
|
||||||
try {
|
try {
|
||||||
logger.info(`Partially updating agent: ${agentId}`)
|
logger.debug('Partially updating agent', { agentId })
|
||||||
logger.debug('Partial update data:', req.body)
|
logger.debug('Patch payload', { body: req.body })
|
||||||
|
|
||||||
const { validatedBody } = req as ValidationRequest
|
const { validatedBody } = req as ValidationRequest
|
||||||
const updatePayload = (validatedBody ?? {}) as UpdateAgentRequest
|
const updatePayload = (validatedBody ?? {}) as UpdateAgentRequest
|
||||||
@@ -434,7 +469,7 @@ export const patchAgent = async (req: Request, res: Response): Promise<Response>
|
|||||||
const agent = await agentService.updateAgent(agentId, updatePayload)
|
const agent = await agentService.updateAgent(agentId, updatePayload)
|
||||||
|
|
||||||
if (!agent) {
|
if (!agent) {
|
||||||
logger.warn(`Agent not found for partial update: ${agentId}`)
|
logger.warn('Agent not found for partial update', { agentId })
|
||||||
return res.status(404).json({
|
return res.status(404).json({
|
||||||
error: {
|
error: {
|
||||||
message: 'Agent not found',
|
message: 'Agent not found',
|
||||||
@@ -444,11 +479,11 @@ export const patchAgent = async (req: Request, res: Response): Promise<Response>
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(`Agent partially updated successfully: ${agentId}`)
|
logger.info('Agent patched', { agentId })
|
||||||
return res.json(agent)
|
return res.json(agent)
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
if (error instanceof AgentModelValidationError) {
|
if (error instanceof AgentModelValidationError) {
|
||||||
logger.warn('Agent model validation error during partial update:', {
|
logger.warn('Agent model validation error during partial update', {
|
||||||
agentId,
|
agentId,
|
||||||
agentType: error.context.agentType,
|
agentType: error.context.agentType,
|
||||||
field: error.context.field,
|
field: error.context.field,
|
||||||
@@ -458,7 +493,7 @@ export const patchAgent = async (req: Request, res: Response): Promise<Response>
|
|||||||
return res.status(400).json(modelValidationErrorBody(error))
|
return res.status(400).json(modelValidationErrorBody(error))
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.error('Error partially updating agent:', error)
|
logger.error('Error partially updating agent', { error, agentId })
|
||||||
return res.status(500).json({
|
return res.status(500).json({
|
||||||
error: {
|
error: {
|
||||||
message: `Failed to partially update agent: ${error.message}`,
|
message: `Failed to partially update agent: ${error.message}`,
|
||||||
@@ -502,12 +537,12 @@ export const patchAgent = async (req: Request, res: Response): Promise<Response>
|
|||||||
export const deleteAgent = async (req: Request, res: Response): Promise<Response> => {
|
export const deleteAgent = async (req: Request, res: Response): Promise<Response> => {
|
||||||
try {
|
try {
|
||||||
const { agentId } = req.params
|
const { agentId } = req.params
|
||||||
logger.info(`Deleting agent: ${agentId}`)
|
logger.debug('Deleting agent', { agentId })
|
||||||
|
|
||||||
const deleted = await agentService.deleteAgent(agentId)
|
const deleted = await agentService.deleteAgent(agentId)
|
||||||
|
|
||||||
if (!deleted) {
|
if (!deleted) {
|
||||||
logger.warn(`Agent not found for deletion: ${agentId}`)
|
logger.warn('Agent not found for deletion', { agentId })
|
||||||
return res.status(404).json({
|
return res.status(404).json({
|
||||||
error: {
|
error: {
|
||||||
message: 'Agent not found',
|
message: 'Agent not found',
|
||||||
@@ -517,10 +552,10 @@ export const deleteAgent = async (req: Request, res: Response): Promise<Response
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(`Agent deleted successfully: ${agentId}`)
|
logger.info('Agent deleted', { agentId })
|
||||||
return res.status(204).send()
|
return res.status(204).send()
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Error deleting agent:', error)
|
logger.error('Error deleting agent', { error, agentId: req.params.agentId })
|
||||||
return res.status(500).json({
|
return res.status(500).json({
|
||||||
error: {
|
error: {
|
||||||
message: 'Failed to delete agent',
|
message: 'Failed to delete agent',
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
|
import { MESSAGE_STREAM_TIMEOUT_MS } from '@main/apiServer/config/timeouts'
|
||||||
|
import { createStreamAbortController, STREAM_TIMEOUT_REASON } from '@main/apiServer/utils/createStreamAbortController'
|
||||||
|
import { agentService, sessionMessageService, sessionService } from '@main/services/agents'
|
||||||
import { Request, Response } from 'express'
|
import { Request, Response } from 'express'
|
||||||
|
|
||||||
import { agentService, sessionMessageService, sessionService } from '../../../../services/agents'
|
|
||||||
|
|
||||||
const logger = loggerService.withContext('ApiServerMessagesHandlers')
|
const logger = loggerService.withContext('ApiServerMessagesHandlers')
|
||||||
|
|
||||||
// Helper function to verify agent and session exist and belong together
|
// Helper function to verify agent and session exist and belong together
|
||||||
@@ -25,6 +26,8 @@ const verifyAgentAndSession = async (agentId: string, sessionId: string) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export const createMessage = async (req: Request, res: Response): Promise<void> => {
|
export const createMessage = async (req: Request, res: Response): Promise<void> => {
|
||||||
|
let clearAbortTimeout: (() => void) | undefined
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const { agentId, sessionId } = req.params
|
const { agentId, sessionId } = req.params
|
||||||
|
|
||||||
@@ -32,8 +35,8 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
|||||||
|
|
||||||
const messageData = req.body
|
const messageData = req.body
|
||||||
|
|
||||||
logger.info(`Creating streaming message for session: ${sessionId}`)
|
logger.info('Creating streaming message', { agentId, sessionId })
|
||||||
logger.debug('Streaming message data:', messageData)
|
logger.debug('Streaming message payload', { messageData })
|
||||||
|
|
||||||
// Set SSE headers
|
// Set SSE headers
|
||||||
res.setHeader('Content-Type', 'text/event-stream')
|
res.setHeader('Content-Type', 'text/event-stream')
|
||||||
@@ -42,7 +45,14 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
|||||||
res.setHeader('Access-Control-Allow-Origin', '*')
|
res.setHeader('Access-Control-Allow-Origin', '*')
|
||||||
res.setHeader('Access-Control-Allow-Headers', 'Cache-Control')
|
res.setHeader('Access-Control-Allow-Headers', 'Cache-Control')
|
||||||
|
|
||||||
const abortController = new AbortController()
|
const {
|
||||||
|
abortController,
|
||||||
|
registerAbortHandler,
|
||||||
|
clearAbortTimeout: helperClearAbortTimeout
|
||||||
|
} = createStreamAbortController({
|
||||||
|
timeoutMs: MESSAGE_STREAM_TIMEOUT_MS
|
||||||
|
})
|
||||||
|
clearAbortTimeout = helperClearAbortTimeout
|
||||||
const { stream, completion } = await sessionMessageService.createSessionMessage(
|
const { stream, completion } = await sessionMessageService.createSessionMessage(
|
||||||
session,
|
session,
|
||||||
messageData,
|
messageData,
|
||||||
@@ -54,6 +64,10 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
|||||||
let responseEnded = false
|
let responseEnded = false
|
||||||
let streamFinished = false
|
let streamFinished = false
|
||||||
|
|
||||||
|
const cleanupAbortTimeout = () => {
|
||||||
|
clearAbortTimeout?.()
|
||||||
|
}
|
||||||
|
|
||||||
const finalizeResponse = () => {
|
const finalizeResponse = () => {
|
||||||
if (responseEnded) {
|
if (responseEnded) {
|
||||||
return
|
return
|
||||||
@@ -64,11 +78,12 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
|||||||
}
|
}
|
||||||
|
|
||||||
responseEnded = true
|
responseEnded = true
|
||||||
|
cleanupAbortTimeout()
|
||||||
try {
|
try {
|
||||||
// res.write('data: {"type":"finish"}\n\n')
|
// res.write('data: {"type":"finish"}\n\n')
|
||||||
res.write('data: [DONE]\n\n')
|
res.write('data: [DONE]\n\n')
|
||||||
} catch (writeError) {
|
} catch (writeError) {
|
||||||
logger.error('Error writing final sentinel to SSE stream:', { error: writeError as Error })
|
logger.error('Error writing final sentinel to SSE stream', { error: writeError as Error })
|
||||||
}
|
}
|
||||||
res.end()
|
res.end()
|
||||||
}
|
}
|
||||||
@@ -92,12 +107,51 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
|||||||
* - Clean up event listeners to prevent memory leaks
|
* - Clean up event listeners to prevent memory leaks
|
||||||
* - Mark the response as ended to prevent further writes
|
* - Mark the response as ended to prevent further writes
|
||||||
*/
|
*/
|
||||||
const handleDisconnect = () => {
|
registerAbortHandler((abortReason) => {
|
||||||
|
cleanupAbortTimeout()
|
||||||
|
|
||||||
if (responseEnded) return
|
if (responseEnded) return
|
||||||
logger.info(`Client disconnected from streaming message for session: ${sessionId}`)
|
|
||||||
responseEnded = true
|
responseEnded = true
|
||||||
|
|
||||||
|
if (abortReason === STREAM_TIMEOUT_REASON) {
|
||||||
|
logger.error('Streaming message timeout', { agentId, sessionId })
|
||||||
|
try {
|
||||||
|
res.write(
|
||||||
|
`data: ${JSON.stringify({
|
||||||
|
type: 'error',
|
||||||
|
error: {
|
||||||
|
message: 'Stream timeout',
|
||||||
|
type: 'timeout_error',
|
||||||
|
code: 'stream_timeout'
|
||||||
|
}
|
||||||
|
})}\n\n`
|
||||||
|
)
|
||||||
|
} catch (writeError) {
|
||||||
|
logger.error('Error writing timeout to SSE stream', { error: writeError })
|
||||||
|
}
|
||||||
|
} else if (abortReason === 'Client disconnected') {
|
||||||
|
logger.info('Streaming client disconnected', { agentId, sessionId })
|
||||||
|
} else {
|
||||||
|
logger.warn('Streaming aborted', { agentId, sessionId, reason: abortReason })
|
||||||
|
}
|
||||||
|
|
||||||
|
reader.cancel(abortReason ?? 'stream aborted').catch(() => {})
|
||||||
|
|
||||||
|
if (!res.headersSent) {
|
||||||
|
res.setHeader('Content-Type', 'text/event-stream')
|
||||||
|
res.setHeader('Cache-Control', 'no-cache')
|
||||||
|
res.setHeader('Connection', 'keep-alive')
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!res.writableEnded) {
|
||||||
|
res.end()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
const handleDisconnect = () => {
|
||||||
|
if (abortController.signal.aborted) return
|
||||||
abortController.abort('Client disconnected')
|
abortController.abort('Client disconnected')
|
||||||
reader.cancel('Client disconnected').catch(() => {})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
req.on('close', handleDisconnect)
|
req.on('close', handleDisconnect)
|
||||||
@@ -119,7 +173,7 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
|||||||
finalizeResponse()
|
finalizeResponse()
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
if (responseEnded) return
|
if (responseEnded) return
|
||||||
logger.error('Error reading agent stream:', { error })
|
logger.error('Error reading agent stream', { error })
|
||||||
try {
|
try {
|
||||||
res.write(
|
res.write(
|
||||||
`data: ${JSON.stringify({
|
`data: ${JSON.stringify({
|
||||||
@@ -132,15 +186,16 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
|||||||
})}\n\n`
|
})}\n\n`
|
||||||
)
|
)
|
||||||
} catch (writeError) {
|
} catch (writeError) {
|
||||||
logger.error('Error writing stream error to SSE:', { error: writeError })
|
logger.error('Error writing stream error to SSE', { error: writeError })
|
||||||
}
|
}
|
||||||
responseEnded = true
|
responseEnded = true
|
||||||
|
cleanupAbortTimeout()
|
||||||
res.end()
|
res.end()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pumpStream().catch((error) => {
|
pumpStream().catch((error) => {
|
||||||
logger.error('Pump stream failure:', { error })
|
logger.error('Pump stream failure', { error })
|
||||||
})
|
})
|
||||||
|
|
||||||
completion
|
completion
|
||||||
@@ -150,7 +205,7 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
|||||||
})
|
})
|
||||||
.catch((error) => {
|
.catch((error) => {
|
||||||
if (responseEnded) return
|
if (responseEnded) return
|
||||||
logger.error(`Streaming message error for session: ${sessionId}:`, error)
|
logger.error('Streaming message error', { agentId, sessionId, error })
|
||||||
try {
|
try {
|
||||||
res.write(
|
res.write(
|
||||||
`data: ${JSON.stringify({
|
`data: ${JSON.stringify({
|
||||||
@@ -163,45 +218,22 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
|||||||
})}\n\n`
|
})}\n\n`
|
||||||
)
|
)
|
||||||
} catch (writeError) {
|
} catch (writeError) {
|
||||||
logger.error('Error writing completion error to SSE stream:', { error: writeError })
|
logger.error('Error writing completion error to SSE stream', { error: writeError })
|
||||||
}
|
}
|
||||||
responseEnded = true
|
responseEnded = true
|
||||||
|
cleanupAbortTimeout()
|
||||||
res.end()
|
res.end()
|
||||||
})
|
})
|
||||||
|
|
||||||
// Set a timeout to prevent hanging indefinitely
|
|
||||||
const timeout = setTimeout(
|
|
||||||
() => {
|
|
||||||
if (!responseEnded) {
|
|
||||||
logger.error(`Streaming message timeout for session: ${sessionId}`)
|
|
||||||
try {
|
|
||||||
res.write(
|
|
||||||
`data: ${JSON.stringify({
|
|
||||||
type: 'error',
|
|
||||||
error: {
|
|
||||||
message: 'Stream timeout',
|
|
||||||
type: 'timeout_error',
|
|
||||||
code: 'stream_timeout'
|
|
||||||
}
|
|
||||||
})}\n\n`
|
|
||||||
)
|
|
||||||
} catch (writeError) {
|
|
||||||
logger.error('Error writing timeout to SSE stream:', { error: writeError })
|
|
||||||
}
|
|
||||||
abortController.abort('stream timeout')
|
|
||||||
reader.cancel('stream timeout').catch(() => {})
|
|
||||||
responseEnded = true
|
|
||||||
res.end()
|
|
||||||
}
|
|
||||||
},
|
|
||||||
10 * 60 * 1000
|
|
||||||
) // 10 minutes timeout
|
|
||||||
|
|
||||||
// Clear timeout when response ends
|
// Clear timeout when response ends
|
||||||
res.on('close', () => clearTimeout(timeout))
|
res.on('close', cleanupAbortTimeout)
|
||||||
res.on('finish', () => clearTimeout(timeout))
|
res.on('finish', cleanupAbortTimeout)
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Error in streaming message handler:', error)
|
clearAbortTimeout?.()
|
||||||
|
logger.error('Error in streaming message handler', {
|
||||||
|
error,
|
||||||
|
agentId: req.params.agentId,
|
||||||
|
sessionId: req.params.sessionId
|
||||||
|
})
|
||||||
|
|
||||||
// Send error as SSE if possible
|
// Send error as SSE if possible
|
||||||
if (!res.headersSent) {
|
if (!res.headersSent) {
|
||||||
@@ -222,9 +254,64 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
|||||||
|
|
||||||
res.write(`data: ${JSON.stringify(errorResponse)}\n\n`)
|
res.write(`data: ${JSON.stringify(errorResponse)}\n\n`)
|
||||||
} catch (writeError) {
|
} catch (writeError) {
|
||||||
logger.error('Error writing initial error to SSE stream:', { error: writeError })
|
logger.error('Error writing initial error to SSE stream', { error: writeError })
|
||||||
}
|
}
|
||||||
|
|
||||||
res.end()
|
res.end()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const deleteMessage = async (req: Request, res: Response): Promise<Response> => {
|
||||||
|
try {
|
||||||
|
const { agentId, sessionId, messageId: messageIdParam } = req.params
|
||||||
|
const messageId = Number(messageIdParam)
|
||||||
|
|
||||||
|
await verifyAgentAndSession(agentId, sessionId)
|
||||||
|
|
||||||
|
const deleted = await sessionMessageService.deleteSessionMessage(sessionId, messageId)
|
||||||
|
|
||||||
|
if (!deleted) {
|
||||||
|
logger.warn('Session message not found', { agentId, sessionId, messageId })
|
||||||
|
return res.status(404).json({
|
||||||
|
error: {
|
||||||
|
message: 'Message not found for this session',
|
||||||
|
type: 'not_found',
|
||||||
|
code: 'session_message_not_found'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info('Session message deleted', { agentId, sessionId, messageId })
|
||||||
|
return res.status(204).send()
|
||||||
|
} catch (error: any) {
|
||||||
|
if (error?.status === 404) {
|
||||||
|
logger.warn('Delete message failed - missing resource', {
|
||||||
|
agentId: req.params.agentId,
|
||||||
|
sessionId: req.params.sessionId,
|
||||||
|
messageId: req.params.messageId,
|
||||||
|
error
|
||||||
|
})
|
||||||
|
return res.status(404).json({
|
||||||
|
error: {
|
||||||
|
message: error.message,
|
||||||
|
type: 'not_found',
|
||||||
|
code: error.code ?? 'session_message_not_found'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.error('Error deleting session message', {
|
||||||
|
error,
|
||||||
|
agentId: req.params.agentId,
|
||||||
|
sessionId: req.params.sessionId,
|
||||||
|
messageId: Number(req.params.messageId)
|
||||||
|
})
|
||||||
|
return res.status(500).json({
|
||||||
|
error: {
|
||||||
|
message: 'Failed to delete session message',
|
||||||
|
type: 'internal_error',
|
||||||
|
code: 'session_message_delete_failed'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,15 +1,6 @@
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import {
|
import { AgentModelValidationError, sessionMessageService, sessionService } from '@main/services/agents'
|
||||||
AgentModelValidationError,
|
import { ListAgentSessionsResponse, type ReplaceSessionRequest, UpdateSessionResponse } from '@types'
|
||||||
sessionMessageService,
|
|
||||||
sessionService
|
|
||||||
} from '@main/services/agents'
|
|
||||||
import {
|
|
||||||
CreateSessionResponse,
|
|
||||||
ListAgentSessionsResponse,
|
|
||||||
type ReplaceSessionRequest,
|
|
||||||
UpdateSessionResponse
|
|
||||||
} from '@types'
|
|
||||||
import { Request, Response } from 'express'
|
import { Request, Response } from 'express'
|
||||||
|
|
||||||
import type { ValidationRequest } from '../validators/zodValidator'
|
import type { ValidationRequest } from '../validators/zodValidator'
|
||||||
@@ -29,16 +20,16 @@ export const createSession = async (req: Request, res: Response): Promise<Respon
|
|||||||
try {
|
try {
|
||||||
const sessionData = req.body
|
const sessionData = req.body
|
||||||
|
|
||||||
logger.info(`Creating new session for agent: ${agentId}`)
|
logger.debug('Creating new session', { agentId })
|
||||||
logger.debug('Session data:', sessionData)
|
logger.debug('Session payload', { sessionData })
|
||||||
|
|
||||||
const session = (await sessionService.createSession(agentId, sessionData)) satisfies CreateSessionResponse
|
const session = await sessionService.createSession(agentId, sessionData)
|
||||||
|
|
||||||
logger.info(`Session created successfully: ${session.id}`)
|
logger.info('Session created', { agentId, sessionId: session?.id })
|
||||||
return res.status(201).json(session)
|
return res.status(201).json(session)
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
if (error instanceof AgentModelValidationError) {
|
if (error instanceof AgentModelValidationError) {
|
||||||
logger.warn('Session model validation error during create:', {
|
logger.warn('Session model validation error during create', {
|
||||||
agentId,
|
agentId,
|
||||||
agentType: error.context.agentType,
|
agentType: error.context.agentType,
|
||||||
field: error.context.field,
|
field: error.context.field,
|
||||||
@@ -48,7 +39,7 @@ export const createSession = async (req: Request, res: Response): Promise<Respon
|
|||||||
return res.status(400).json(modelValidationErrorBody(error))
|
return res.status(400).json(modelValidationErrorBody(error))
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.error('Error creating session:', error)
|
logger.error('Error creating session', { error, agentId })
|
||||||
return res.status(500).json({
|
return res.status(500).json({
|
||||||
error: {
|
error: {
|
||||||
message: `Failed to create session: ${error.message}`,
|
message: `Failed to create session: ${error.message}`,
|
||||||
@@ -60,17 +51,23 @@ export const createSession = async (req: Request, res: Response): Promise<Respon
|
|||||||
}
|
}
|
||||||
|
|
||||||
export const listSessions = async (req: Request, res: Response): Promise<Response> => {
|
export const listSessions = async (req: Request, res: Response): Promise<Response> => {
|
||||||
|
const { agentId } = req.params
|
||||||
try {
|
try {
|
||||||
const { agentId } = req.params
|
|
||||||
const limit = req.query.limit ? parseInt(req.query.limit as string) : 20
|
const limit = req.query.limit ? parseInt(req.query.limit as string) : 20
|
||||||
const offset = req.query.offset ? parseInt(req.query.offset as string) : 0
|
const offset = req.query.offset ? parseInt(req.query.offset as string) : 0
|
||||||
const status = req.query.status as any
|
const status = req.query.status as any
|
||||||
|
|
||||||
logger.info(`Listing sessions for agent: ${agentId} with limit=${limit}, offset=${offset}, status=${status}`)
|
logger.debug('Listing agent sessions', { agentId, limit, offset, status })
|
||||||
|
|
||||||
const result = await sessionService.listSessions(agentId, { limit, offset })
|
const result = await sessionService.listSessions(agentId, { limit, offset })
|
||||||
|
|
||||||
logger.info(`Retrieved ${result.sessions.length} sessions (total: ${result.total}) for agent: ${agentId}`)
|
logger.info('Agent sessions listed', {
|
||||||
|
agentId,
|
||||||
|
returned: result.sessions.length,
|
||||||
|
total: result.total,
|
||||||
|
limit,
|
||||||
|
offset
|
||||||
|
})
|
||||||
return res.json({
|
return res.json({
|
||||||
data: result.sessions,
|
data: result.sessions,
|
||||||
total: result.total,
|
total: result.total,
|
||||||
@@ -78,7 +75,7 @@ export const listSessions = async (req: Request, res: Response): Promise<Respons
|
|||||||
offset
|
offset
|
||||||
})
|
})
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Error listing sessions:', error)
|
logger.error('Error listing sessions', { error, agentId })
|
||||||
return res.status(500).json({
|
return res.status(500).json({
|
||||||
error: {
|
error: {
|
||||||
message: 'Failed to list sessions',
|
message: 'Failed to list sessions',
|
||||||
@@ -92,12 +89,12 @@ export const listSessions = async (req: Request, res: Response): Promise<Respons
|
|||||||
export const getSession = async (req: Request, res: Response): Promise<Response> => {
|
export const getSession = async (req: Request, res: Response): Promise<Response> => {
|
||||||
try {
|
try {
|
||||||
const { agentId, sessionId } = req.params
|
const { agentId, sessionId } = req.params
|
||||||
logger.info(`Getting session: ${sessionId} for agent: ${agentId}`)
|
logger.debug('Getting session', { agentId, sessionId })
|
||||||
|
|
||||||
const session = await sessionService.getSession(agentId, sessionId)
|
const session = await sessionService.getSession(agentId, sessionId)
|
||||||
|
|
||||||
if (!session) {
|
if (!session) {
|
||||||
logger.warn(`Session not found: ${sessionId}`)
|
logger.warn('Session not found', { agentId, sessionId })
|
||||||
return res.status(404).json({
|
return res.status(404).json({
|
||||||
error: {
|
error: {
|
||||||
message: 'Session not found',
|
message: 'Session not found',
|
||||||
@@ -119,7 +116,7 @@ export const getSession = async (req: Request, res: Response): Promise<Response>
|
|||||||
// }
|
// }
|
||||||
|
|
||||||
// Fetch session messages
|
// Fetch session messages
|
||||||
logger.info(`Fetching messages for session: ${sessionId}`)
|
logger.debug('Fetching session messages', { sessionId })
|
||||||
const { messages } = await sessionMessageService.listSessionMessages(sessionId)
|
const { messages } = await sessionMessageService.listSessionMessages(sessionId)
|
||||||
|
|
||||||
// Add messages to session
|
// Add messages to session
|
||||||
@@ -128,10 +125,10 @@ export const getSession = async (req: Request, res: Response): Promise<Response>
|
|||||||
messages: messages
|
messages: messages
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(`Session retrieved successfully: ${sessionId} with ${messages.length} messages`)
|
logger.info('Session retrieved', { agentId, sessionId, messageCount: messages.length })
|
||||||
return res.json(sessionWithMessages)
|
return res.json(sessionWithMessages)
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Error getting session:', error)
|
logger.error('Error getting session', { error, agentId: req.params.agentId, sessionId: req.params.sessionId })
|
||||||
return res.status(500).json({
|
return res.status(500).json({
|
||||||
error: {
|
error: {
|
||||||
message: 'Failed to get session',
|
message: 'Failed to get session',
|
||||||
@@ -145,13 +142,13 @@ export const getSession = async (req: Request, res: Response): Promise<Response>
|
|||||||
export const updateSession = async (req: Request, res: Response): Promise<Response> => {
|
export const updateSession = async (req: Request, res: Response): Promise<Response> => {
|
||||||
const { agentId, sessionId } = req.params
|
const { agentId, sessionId } = req.params
|
||||||
try {
|
try {
|
||||||
logger.info(`Updating session: ${sessionId} for agent: ${agentId}`)
|
logger.debug('Updating session', { agentId, sessionId })
|
||||||
logger.debug('Update data:', req.body)
|
logger.debug('Replace payload', { body: req.body })
|
||||||
|
|
||||||
// First check if session exists and belongs to agent
|
// First check if session exists and belongs to agent
|
||||||
const existingSession = await sessionService.getSession(agentId, sessionId)
|
const existingSession = await sessionService.getSession(agentId, sessionId)
|
||||||
if (!existingSession || existingSession.agent_id !== agentId) {
|
if (!existingSession || existingSession.agent_id !== agentId) {
|
||||||
logger.warn(`Session ${sessionId} not found for agent ${agentId}`)
|
logger.warn('Session not found for update', { agentId, sessionId })
|
||||||
return res.status(404).json({
|
return res.status(404).json({
|
||||||
error: {
|
error: {
|
||||||
message: 'Session not found for this agent',
|
message: 'Session not found for this agent',
|
||||||
@@ -167,7 +164,7 @@ export const updateSession = async (req: Request, res: Response): Promise<Respon
|
|||||||
const session = await sessionService.updateSession(agentId, sessionId, replacePayload)
|
const session = await sessionService.updateSession(agentId, sessionId, replacePayload)
|
||||||
|
|
||||||
if (!session) {
|
if (!session) {
|
||||||
logger.warn(`Session not found for update: ${sessionId}`)
|
logger.warn('Session missing during update', { agentId, sessionId })
|
||||||
return res.status(404).json({
|
return res.status(404).json({
|
||||||
error: {
|
error: {
|
||||||
message: 'Session not found',
|
message: 'Session not found',
|
||||||
@@ -177,11 +174,11 @@ export const updateSession = async (req: Request, res: Response): Promise<Respon
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(`Session updated successfully: ${sessionId}`)
|
logger.info('Session updated', { agentId, sessionId })
|
||||||
return res.json(session satisfies UpdateSessionResponse)
|
return res.json(session satisfies UpdateSessionResponse)
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
if (error instanceof AgentModelValidationError) {
|
if (error instanceof AgentModelValidationError) {
|
||||||
logger.warn('Session model validation error during update:', {
|
logger.warn('Session model validation error during update', {
|
||||||
agentId,
|
agentId,
|
||||||
sessionId,
|
sessionId,
|
||||||
agentType: error.context.agentType,
|
agentType: error.context.agentType,
|
||||||
@@ -192,7 +189,7 @@ export const updateSession = async (req: Request, res: Response): Promise<Respon
|
|||||||
return res.status(400).json(modelValidationErrorBody(error))
|
return res.status(400).json(modelValidationErrorBody(error))
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.error('Error updating session:', error)
|
logger.error('Error updating session', { error, agentId, sessionId })
|
||||||
return res.status(500).json({
|
return res.status(500).json({
|
||||||
error: {
|
error: {
|
||||||
message: `Failed to update session: ${error.message}`,
|
message: `Failed to update session: ${error.message}`,
|
||||||
@@ -206,13 +203,13 @@ export const updateSession = async (req: Request, res: Response): Promise<Respon
|
|||||||
export const patchSession = async (req: Request, res: Response): Promise<Response> => {
|
export const patchSession = async (req: Request, res: Response): Promise<Response> => {
|
||||||
const { agentId, sessionId } = req.params
|
const { agentId, sessionId } = req.params
|
||||||
try {
|
try {
|
||||||
logger.info(`Patching session: ${sessionId} for agent: ${agentId}`)
|
logger.debug('Patching session', { agentId, sessionId })
|
||||||
logger.debug('Patch data:', req.body)
|
logger.debug('Patch payload', { body: req.body })
|
||||||
|
|
||||||
// First check if session exists and belongs to agent
|
// First check if session exists and belongs to agent
|
||||||
const existingSession = await sessionService.getSession(agentId, sessionId)
|
const existingSession = await sessionService.getSession(agentId, sessionId)
|
||||||
if (!existingSession || existingSession.agent_id !== agentId) {
|
if (!existingSession || existingSession.agent_id !== agentId) {
|
||||||
logger.warn(`Session ${sessionId} not found for agent ${agentId}`)
|
logger.warn('Session not found for patch', { agentId, sessionId })
|
||||||
return res.status(404).json({
|
return res.status(404).json({
|
||||||
error: {
|
error: {
|
||||||
message: 'Session not found for this agent',
|
message: 'Session not found for this agent',
|
||||||
@@ -226,7 +223,7 @@ export const patchSession = async (req: Request, res: Response): Promise<Respons
|
|||||||
const session = await sessionService.updateSession(agentId, sessionId, updateSession)
|
const session = await sessionService.updateSession(agentId, sessionId, updateSession)
|
||||||
|
|
||||||
if (!session) {
|
if (!session) {
|
||||||
logger.warn(`Session not found for patch: ${sessionId}`)
|
logger.warn('Session missing while patching', { agentId, sessionId })
|
||||||
return res.status(404).json({
|
return res.status(404).json({
|
||||||
error: {
|
error: {
|
||||||
message: 'Session not found',
|
message: 'Session not found',
|
||||||
@@ -236,11 +233,11 @@ export const patchSession = async (req: Request, res: Response): Promise<Respons
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(`Session patched successfully: ${sessionId}`)
|
logger.info('Session patched', { agentId, sessionId })
|
||||||
return res.json(session)
|
return res.json(session)
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
if (error instanceof AgentModelValidationError) {
|
if (error instanceof AgentModelValidationError) {
|
||||||
logger.warn('Session model validation error during patch:', {
|
logger.warn('Session model validation error during patch', {
|
||||||
agentId,
|
agentId,
|
||||||
sessionId,
|
sessionId,
|
||||||
agentType: error.context.agentType,
|
agentType: error.context.agentType,
|
||||||
@@ -251,7 +248,7 @@ export const patchSession = async (req: Request, res: Response): Promise<Respons
|
|||||||
return res.status(400).json(modelValidationErrorBody(error))
|
return res.status(400).json(modelValidationErrorBody(error))
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.error('Error patching session:', error)
|
logger.error('Error patching session', { error, agentId, sessionId })
|
||||||
return res.status(500).json({
|
return res.status(500).json({
|
||||||
error: {
|
error: {
|
||||||
message: `Failed to patch session, ${error.message}`,
|
message: `Failed to patch session, ${error.message}`,
|
||||||
@@ -265,12 +262,12 @@ export const patchSession = async (req: Request, res: Response): Promise<Respons
|
|||||||
export const deleteSession = async (req: Request, res: Response): Promise<Response> => {
|
export const deleteSession = async (req: Request, res: Response): Promise<Response> => {
|
||||||
try {
|
try {
|
||||||
const { agentId, sessionId } = req.params
|
const { agentId, sessionId } = req.params
|
||||||
logger.info(`Deleting session: ${sessionId} for agent: ${agentId}`)
|
logger.debug('Deleting session', { agentId, sessionId })
|
||||||
|
|
||||||
// First check if session exists and belongs to agent
|
// First check if session exists and belongs to agent
|
||||||
const existingSession = await sessionService.getSession(agentId, sessionId)
|
const existingSession = await sessionService.getSession(agentId, sessionId)
|
||||||
if (!existingSession || existingSession.agent_id !== agentId) {
|
if (!existingSession || existingSession.agent_id !== agentId) {
|
||||||
logger.warn(`Session ${sessionId} not found for agent ${agentId}`)
|
logger.warn('Session not found for deletion', { agentId, sessionId })
|
||||||
return res.status(404).json({
|
return res.status(404).json({
|
||||||
error: {
|
error: {
|
||||||
message: 'Session not found for this agent',
|
message: 'Session not found for this agent',
|
||||||
@@ -283,7 +280,7 @@ export const deleteSession = async (req: Request, res: Response): Promise<Respon
|
|||||||
const deleted = await sessionService.deleteSession(agentId, sessionId)
|
const deleted = await sessionService.deleteSession(agentId, sessionId)
|
||||||
|
|
||||||
if (!deleted) {
|
if (!deleted) {
|
||||||
logger.warn(`Session not found for deletion: ${sessionId}`)
|
logger.warn('Session missing during delete', { agentId, sessionId })
|
||||||
return res.status(404).json({
|
return res.status(404).json({
|
||||||
error: {
|
error: {
|
||||||
message: 'Session not found',
|
message: 'Session not found',
|
||||||
@@ -293,10 +290,36 @@ export const deleteSession = async (req: Request, res: Response): Promise<Respon
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(`Session deleted successfully: ${sessionId}`)
|
logger.info('Session deleted', { agentId, sessionId })
|
||||||
|
|
||||||
|
const { total } = await sessionService.listSessions(agentId, { limit: 1 })
|
||||||
|
|
||||||
|
if (total === 0) {
|
||||||
|
logger.info('No remaining sessions, creating default', { agentId })
|
||||||
|
try {
|
||||||
|
const fallbackSession = await sessionService.createSession(agentId, {})
|
||||||
|
logger.info('Default session created after delete', {
|
||||||
|
agentId,
|
||||||
|
sessionId: fallbackSession?.id
|
||||||
|
})
|
||||||
|
} catch (recoveryError: any) {
|
||||||
|
logger.error('Failed to recreate session after deleting last session', {
|
||||||
|
agentId,
|
||||||
|
error: recoveryError
|
||||||
|
})
|
||||||
|
return res.status(500).json({
|
||||||
|
error: {
|
||||||
|
message: `Failed to recreate session after deletion: ${recoveryError.message}`,
|
||||||
|
type: 'internal_error',
|
||||||
|
code: 'session_recovery_failed'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return res.status(204).send()
|
return res.status(204).send()
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Error deleting session:', error)
|
logger.error('Error deleting session', { error, agentId: req.params.agentId, sessionId: req.params.sessionId })
|
||||||
return res.status(500).json({
|
return res.status(500).json({
|
||||||
error: {
|
error: {
|
||||||
message: 'Failed to delete session',
|
message: 'Failed to delete session',
|
||||||
@@ -314,11 +337,16 @@ export const listAllSessions = async (req: Request, res: Response): Promise<Resp
|
|||||||
const offset = req.query.offset ? parseInt(req.query.offset as string) : 0
|
const offset = req.query.offset ? parseInt(req.query.offset as string) : 0
|
||||||
const status = req.query.status as any
|
const status = req.query.status as any
|
||||||
|
|
||||||
logger.info(`Listing all sessions with limit=${limit}, offset=${offset}, status=${status}`)
|
logger.debug('Listing all sessions', { limit, offset, status })
|
||||||
|
|
||||||
const result = await sessionService.listSessions(undefined, { limit, offset })
|
const result = await sessionService.listSessions(undefined, { limit, offset })
|
||||||
|
|
||||||
logger.info(`Retrieved ${result.sessions.length} sessions (total: ${result.total})`)
|
logger.info('Sessions listed', {
|
||||||
|
returned: result.sessions.length,
|
||||||
|
total: result.total,
|
||||||
|
limit,
|
||||||
|
offset
|
||||||
|
})
|
||||||
return res.json({
|
return res.json({
|
||||||
data: result.sessions,
|
data: result.sessions,
|
||||||
total: result.total,
|
total: result.total,
|
||||||
@@ -326,7 +354,7 @@ export const listAllSessions = async (req: Request, res: Response): Promise<Resp
|
|||||||
offset
|
offset
|
||||||
} satisfies ListAgentSessionsResponse)
|
} satisfies ListAgentSessionsResponse)
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Error listing all sessions:', error)
|
logger.error('Error listing all sessions', { error })
|
||||||
return res.status(500).json({
|
return res.status(500).json({
|
||||||
error: {
|
error: {
|
||||||
message: 'Failed to list sessions',
|
message: 'Failed to list sessions',
|
||||||
@@ -336,35 +364,3 @@ export const listAllSessions = async (req: Request, res: Response): Promise<Resp
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export const getSessionById = async (req: Request, res: Response): Promise<Response> => {
|
|
||||||
try {
|
|
||||||
const { sessionId } = req.params
|
|
||||||
logger.info(`Getting session: ${sessionId}`)
|
|
||||||
|
|
||||||
const session = await sessionService.getSessionById(sessionId)
|
|
||||||
|
|
||||||
if (!session) {
|
|
||||||
logger.warn(`Session not found: ${sessionId}`)
|
|
||||||
return res.status(404).json({
|
|
||||||
error: {
|
|
||||||
message: 'Session not found',
|
|
||||||
type: 'not_found',
|
|
||||||
code: 'session_not_found'
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info(`Session retrieved successfully: ${sessionId}`)
|
|
||||||
return res.json(session)
|
|
||||||
} catch (error: any) {
|
|
||||||
logger.error('Error getting session:', error)
|
|
||||||
return res.status(500).json({
|
|
||||||
error: {
|
|
||||||
message: 'Failed to get session',
|
|
||||||
type: 'internal_error',
|
|
||||||
code: 'session_get_failed'
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import {
|
|||||||
validateSession,
|
validateSession,
|
||||||
validateSessionId,
|
validateSessionId,
|
||||||
validateSessionMessage,
|
validateSessionMessage,
|
||||||
|
validateSessionMessageId,
|
||||||
validateSessionReplace,
|
validateSessionReplace,
|
||||||
validateSessionUpdate
|
validateSessionUpdate
|
||||||
} from './validators'
|
} from './validators'
|
||||||
@@ -362,7 +363,7 @@ const agentsRouter = express.Router()
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* @swagger
|
* @swagger
|
||||||
* /api/agents:
|
* /agents:
|
||||||
* post:
|
* post:
|
||||||
* summary: Create a new agent
|
* summary: Create a new agent
|
||||||
* tags: [Agents]
|
* tags: [Agents]
|
||||||
@@ -391,7 +392,7 @@ agentsRouter.post('/', validateAgent, handleValidationErrors, agentHandlers.crea
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* @swagger
|
* @swagger
|
||||||
* /api/agents:
|
* /agents:
|
||||||
* get:
|
* get:
|
||||||
* summary: List all agents with pagination
|
* summary: List all agents with pagination
|
||||||
* tags: [Agents]
|
* tags: [Agents]
|
||||||
@@ -429,7 +430,7 @@ agentsRouter.get('/', validatePagination, handleValidationErrors, agentHandlers.
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* @swagger
|
* @swagger
|
||||||
* /api/agents/{agentId}:
|
* /agents/{agentId}:
|
||||||
* get:
|
* get:
|
||||||
* summary: Get agent by ID
|
* summary: Get agent by ID
|
||||||
* tags: [Agents]
|
* tags: [Agents]
|
||||||
@@ -457,7 +458,7 @@ agentsRouter.get('/', validatePagination, handleValidationErrors, agentHandlers.
|
|||||||
agentsRouter.get('/:agentId', validateAgentId, handleValidationErrors, agentHandlers.getAgent)
|
agentsRouter.get('/:agentId', validateAgentId, handleValidationErrors, agentHandlers.getAgent)
|
||||||
/**
|
/**
|
||||||
* @swagger
|
* @swagger
|
||||||
* /api/agents/{agentId}:
|
* /agents/{agentId}:
|
||||||
* put:
|
* put:
|
||||||
* summary: Replace agent (full update)
|
* summary: Replace agent (full update)
|
||||||
* tags: [Agents]
|
* tags: [Agents]
|
||||||
@@ -497,7 +498,7 @@ agentsRouter.get('/:agentId', validateAgentId, handleValidationErrors, agentHand
|
|||||||
agentsRouter.put('/:agentId', validateAgentId, validateAgentReplace, handleValidationErrors, agentHandlers.updateAgent)
|
agentsRouter.put('/:agentId', validateAgentId, validateAgentReplace, handleValidationErrors, agentHandlers.updateAgent)
|
||||||
/**
|
/**
|
||||||
* @swagger
|
* @swagger
|
||||||
* /api/agents/{agentId}:
|
* /agents/{agentId}:
|
||||||
* patch:
|
* patch:
|
||||||
* summary: Update agent (partial update)
|
* summary: Update agent (partial update)
|
||||||
* tags: [Agents]
|
* tags: [Agents]
|
||||||
@@ -537,7 +538,7 @@ agentsRouter.put('/:agentId', validateAgentId, validateAgentReplace, handleValid
|
|||||||
agentsRouter.patch('/:agentId', validateAgentId, validateAgentUpdate, handleValidationErrors, agentHandlers.patchAgent)
|
agentsRouter.patch('/:agentId', validateAgentId, validateAgentUpdate, handleValidationErrors, agentHandlers.patchAgent)
|
||||||
/**
|
/**
|
||||||
* @swagger
|
* @swagger
|
||||||
* /api/agents/{agentId}:
|
* /agents/{agentId}:
|
||||||
* delete:
|
* delete:
|
||||||
* summary: Delete agent
|
* summary: Delete agent
|
||||||
* tags: [Agents]
|
* tags: [Agents]
|
||||||
@@ -567,7 +568,7 @@ const createSessionsRouter = (): express.Router => {
|
|||||||
// Session CRUD routes (nested under agent)
|
// Session CRUD routes (nested under agent)
|
||||||
/**
|
/**
|
||||||
* @swagger
|
* @swagger
|
||||||
* /api/agents/{agentId}/sessions:
|
* /agents/{agentId}/sessions:
|
||||||
* post:
|
* post:
|
||||||
* summary: Create a new session for an agent
|
* summary: Create a new session for an agent
|
||||||
* tags: [Sessions]
|
* tags: [Sessions]
|
||||||
@@ -608,7 +609,7 @@ const createSessionsRouter = (): express.Router => {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* @swagger
|
* @swagger
|
||||||
* /api/agents/{agentId}/sessions:
|
* /agents/{agentId}/sessions:
|
||||||
* get:
|
* get:
|
||||||
* summary: List sessions for an agent
|
* summary: List sessions for an agent
|
||||||
* tags: [Sessions]
|
* tags: [Sessions]
|
||||||
@@ -657,7 +658,7 @@ const createSessionsRouter = (): express.Router => {
|
|||||||
sessionsRouter.get('/', validatePagination, handleValidationErrors, sessionHandlers.listSessions)
|
sessionsRouter.get('/', validatePagination, handleValidationErrors, sessionHandlers.listSessions)
|
||||||
/**
|
/**
|
||||||
* @swagger
|
* @swagger
|
||||||
* /api/agents/{agentId}/sessions/{sessionId}:
|
* /agents/{agentId}/sessions/{sessionId}:
|
||||||
* get:
|
* get:
|
||||||
* summary: Get session by ID
|
* summary: Get session by ID
|
||||||
* tags: [Sessions]
|
* tags: [Sessions]
|
||||||
@@ -691,7 +692,7 @@ const createSessionsRouter = (): express.Router => {
|
|||||||
sessionsRouter.get('/:sessionId', validateSessionId, handleValidationErrors, sessionHandlers.getSession)
|
sessionsRouter.get('/:sessionId', validateSessionId, handleValidationErrors, sessionHandlers.getSession)
|
||||||
/**
|
/**
|
||||||
* @swagger
|
* @swagger
|
||||||
* /api/agents/{agentId}/sessions/{sessionId}:
|
* /agents/{agentId}/sessions/{sessionId}:
|
||||||
* put:
|
* put:
|
||||||
* summary: Replace session (full update)
|
* summary: Replace session (full update)
|
||||||
* tags: [Sessions]
|
* tags: [Sessions]
|
||||||
@@ -743,7 +744,7 @@ const createSessionsRouter = (): express.Router => {
|
|||||||
)
|
)
|
||||||
/**
|
/**
|
||||||
* @swagger
|
* @swagger
|
||||||
* /api/agents/{agentId}/sessions/{sessionId}:
|
* /agents/{agentId}/sessions/{sessionId}:
|
||||||
* patch:
|
* patch:
|
||||||
* summary: Update session (partial update)
|
* summary: Update session (partial update)
|
||||||
* tags: [Sessions]
|
* tags: [Sessions]
|
||||||
@@ -795,7 +796,7 @@ const createSessionsRouter = (): express.Router => {
|
|||||||
)
|
)
|
||||||
/**
|
/**
|
||||||
* @swagger
|
* @swagger
|
||||||
* /api/agents/{agentId}/sessions/{sessionId}:
|
* /agents/{agentId}/sessions/{sessionId}:
|
||||||
* delete:
|
* delete:
|
||||||
* summary: Delete session
|
* summary: Delete session
|
||||||
* tags: [Sessions]
|
* tags: [Sessions]
|
||||||
@@ -834,7 +835,7 @@ const createMessagesRouter = (): express.Router => {
|
|||||||
// Message CRUD routes (nested under agent/session)
|
// Message CRUD routes (nested under agent/session)
|
||||||
/**
|
/**
|
||||||
* @swagger
|
* @swagger
|
||||||
* /api/agents/{agentId}/sessions/{sessionId}/messages:
|
* /agents/{agentId}/sessions/{sessionId}/messages:
|
||||||
* post:
|
* post:
|
||||||
* summary: Create a new message in a session
|
* summary: Create a new message in a session
|
||||||
* tags: [Messages]
|
* tags: [Messages]
|
||||||
@@ -904,6 +905,43 @@ const createMessagesRouter = (): express.Router => {
|
|||||||
* $ref: '#/components/schemas/ErrorResponse'
|
* $ref: '#/components/schemas/ErrorResponse'
|
||||||
*/
|
*/
|
||||||
messagesRouter.post('/', validateSessionMessage, handleValidationErrors, messageHandlers.createMessage)
|
messagesRouter.post('/', validateSessionMessage, handleValidationErrors, messageHandlers.createMessage)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @swagger
|
||||||
|
* /agents/{agentId}/sessions/{sessionId}/messages/{messageId}:
|
||||||
|
* delete:
|
||||||
|
* summary: Delete a message from a session
|
||||||
|
* tags: [Messages]
|
||||||
|
* parameters:
|
||||||
|
* - in: path
|
||||||
|
* name: agentId
|
||||||
|
* required: true
|
||||||
|
* schema:
|
||||||
|
* type: string
|
||||||
|
* description: Agent ID
|
||||||
|
* - in: path
|
||||||
|
* name: sessionId
|
||||||
|
* required: true
|
||||||
|
* schema:
|
||||||
|
* type: string
|
||||||
|
* description: Session ID
|
||||||
|
* - in: path
|
||||||
|
* name: messageId
|
||||||
|
* required: true
|
||||||
|
* schema:
|
||||||
|
* type: integer
|
||||||
|
* description: Message ID
|
||||||
|
* responses:
|
||||||
|
* 204:
|
||||||
|
* description: Message deleted successfully
|
||||||
|
* 404:
|
||||||
|
* description: Agent, session, or message not found
|
||||||
|
* content:
|
||||||
|
* application/json:
|
||||||
|
* schema:
|
||||||
|
* $ref: '#/components/schemas/ErrorResponse'
|
||||||
|
*/
|
||||||
|
messagesRouter.delete('/:messageId', validateSessionMessageId, handleValidationErrors, messageHandlers.deleteMessage)
|
||||||
return messagesRouter
|
return messagesRouter
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,10 @@ export const checkAgentExists = async (req: Request, res: Response, next: any):
|
|||||||
|
|
||||||
next()
|
next()
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('Error checking agent existence:', error as Error)
|
logger.error('Error checking agent existence', {
|
||||||
|
error: error as Error,
|
||||||
|
agentId: req.params.agentId
|
||||||
|
})
|
||||||
res.status(500).json({
|
res.status(500).json({
|
||||||
error: {
|
error: {
|
||||||
message: 'Failed to validate agent',
|
message: 'Failed to validate agent',
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
import { CreateSessionMessageRequestSchema } from '@types'
|
import { CreateSessionMessageRequestSchema, SessionMessageIdParamSchema } from '@types'
|
||||||
|
|
||||||
import { createZodValidator } from './zodValidator'
|
import { createZodValidator } from './zodValidator'
|
||||||
|
|
||||||
export const validateSessionMessage = createZodValidator({
|
export const validateSessionMessage = createZodValidator({
|
||||||
body: CreateSessionMessageRequestSchema
|
body: CreateSessionMessageRequestSchema
|
||||||
})
|
})
|
||||||
|
|
||||||
|
export const validateSessionMessageId = createZodValidator({
|
||||||
|
params: SessionMessageIdParamSchema
|
||||||
|
})
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { NextFunction,Request, Response } from 'express'
|
import { NextFunction, Request, Response } from 'express'
|
||||||
import { ZodError, ZodType } from 'zod'
|
import { ZodError, ZodType } from 'zod'
|
||||||
|
|
||||||
export interface ValidationRequest extends Request {
|
export interface ValidationRequest extends Request {
|
||||||
@@ -35,7 +35,7 @@ export const createZodValidator = (config: ZodValidationConfig) => {
|
|||||||
type: 'field',
|
type: 'field',
|
||||||
value: err.input,
|
value: err.input,
|
||||||
msg: err.message,
|
msg: err.message,
|
||||||
path: err.path.map(p => String(p)).join('.'),
|
path: err.path.map((p) => String(p)).join('.'),
|
||||||
location: getLocationFromPath(err.path, config)
|
location: getLocationFromPath(err.path, config)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ interface ErrorResponseBody {
|
|||||||
|
|
||||||
const mapChatCompletionError = (error: unknown): { status: number; body: ErrorResponseBody } => {
|
const mapChatCompletionError = (error: unknown): { status: number; body: ErrorResponseBody } => {
|
||||||
if (error instanceof ChatCompletionValidationError) {
|
if (error instanceof ChatCompletionValidationError) {
|
||||||
logger.warn('Chat completion validation error:', {
|
logger.warn('Chat completion validation error', {
|
||||||
errors: error.errors
|
errors: error.errors
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -39,7 +39,7 @@ const mapChatCompletionError = (error: unknown): { status: number; body: ErrorRe
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (error instanceof ChatCompletionModelError) {
|
if (error instanceof ChatCompletionModelError) {
|
||||||
logger.warn('Chat completion model error:', error.error)
|
logger.warn('Chat completion model error', error.error)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
status: 400,
|
status: 400,
|
||||||
@@ -72,7 +72,7 @@ const mapChatCompletionError = (error: unknown): { status: number; body: ErrorRe
|
|||||||
errorCode = 'upstream_error'
|
errorCode = 'upstream_error'
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.error('Chat completion error:', { error })
|
logger.error('Chat completion error', { error })
|
||||||
|
|
||||||
return {
|
return {
|
||||||
status: statusCode,
|
status: statusCode,
|
||||||
@@ -86,7 +86,7 @@ const mapChatCompletionError = (error: unknown): { status: number; body: ErrorRe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.error('Chat completion unknown error:', { error })
|
logger.error('Chat completion unknown error', { error })
|
||||||
|
|
||||||
return {
|
return {
|
||||||
status: 500,
|
status: 500,
|
||||||
@@ -193,7 +193,7 @@ router.post('/completions', async (req: Request, res: Response) => {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info('Chat completion request:', {
|
logger.debug('Chat completion request', {
|
||||||
model: request.model,
|
model: request.model,
|
||||||
messageCount: request.messages?.length || 0,
|
messageCount: request.messages?.length || 0,
|
||||||
stream: request.stream,
|
stream: request.stream,
|
||||||
@@ -217,7 +217,7 @@ router.post('/completions', async (req: Request, res: Response) => {
|
|||||||
}
|
}
|
||||||
res.write('data: [DONE]\n\n')
|
res.write('data: [DONE]\n\n')
|
||||||
} catch (streamError: any) {
|
} catch (streamError: any) {
|
||||||
logger.error('Stream error:', streamError)
|
logger.error('Stream error', { error: streamError })
|
||||||
res.write(
|
res.write(
|
||||||
`data: ${JSON.stringify({
|
`data: ${JSON.stringify({
|
||||||
error: {
|
error: {
|
||||||
|
|||||||
@@ -43,14 +43,14 @@ const router = express.Router()
|
|||||||
*/
|
*/
|
||||||
router.get('/', async (req: Request, res: Response) => {
|
router.get('/', async (req: Request, res: Response) => {
|
||||||
try {
|
try {
|
||||||
logger.info('Get all MCP servers request received')
|
logger.debug('Listing MCP servers')
|
||||||
const servers = await mcpApiService.getAllServers(req)
|
const servers = await mcpApiService.getAllServers(req)
|
||||||
return res.json({
|
return res.json({
|
||||||
success: true,
|
success: true,
|
||||||
data: servers
|
data: servers
|
||||||
})
|
})
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Error fetching MCP servers:', error)
|
logger.error('Error fetching MCP servers', { error })
|
||||||
return res.status(503).json({
|
return res.status(503).json({
|
||||||
success: false,
|
success: false,
|
||||||
error: {
|
error: {
|
||||||
@@ -103,10 +103,12 @@ router.get('/', async (req: Request, res: Response) => {
|
|||||||
*/
|
*/
|
||||||
router.get('/:server_id', async (req: Request, res: Response) => {
|
router.get('/:server_id', async (req: Request, res: Response) => {
|
||||||
try {
|
try {
|
||||||
logger.info('Get MCP server info request received')
|
logger.debug('Get MCP server info request received', {
|
||||||
|
serverId: req.params.server_id
|
||||||
|
})
|
||||||
const server = await mcpApiService.getServerInfo(req.params.server_id)
|
const server = await mcpApiService.getServerInfo(req.params.server_id)
|
||||||
if (!server) {
|
if (!server) {
|
||||||
logger.warn('MCP server not found')
|
logger.warn('MCP server not found', { serverId: req.params.server_id })
|
||||||
return res.status(404).json({
|
return res.status(404).json({
|
||||||
success: false,
|
success: false,
|
||||||
error: {
|
error: {
|
||||||
@@ -121,7 +123,7 @@ router.get('/:server_id', async (req: Request, res: Response) => {
|
|||||||
data: server
|
data: server
|
||||||
})
|
})
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Error fetching MCP server info:', error)
|
logger.error('Error fetching MCP server info', { error, serverId: req.params.server_id })
|
||||||
return res.status(503).json({
|
return res.status(503).json({
|
||||||
success: false,
|
success: false,
|
||||||
error: {
|
error: {
|
||||||
@@ -137,7 +139,7 @@ router.get('/:server_id', async (req: Request, res: Response) => {
|
|||||||
router.all('/:server_id/mcp', async (req: Request, res: Response) => {
|
router.all('/:server_id/mcp', async (req: Request, res: Response) => {
|
||||||
const server = await mcpApiService.getServerById(req.params.server_id)
|
const server = await mcpApiService.getServerById(req.params.server_id)
|
||||||
if (!server) {
|
if (!server) {
|
||||||
logger.warn('MCP server not found')
|
logger.warn('MCP server not found', { serverId: req.params.server_id })
|
||||||
return res.status(404).json({
|
return res.status(404).json({
|
||||||
success: false,
|
success: false,
|
||||||
error: {
|
error: {
|
||||||
|
|||||||
@@ -1,13 +1,85 @@
|
|||||||
import { MessageCreateParams } from '@anthropic-ai/sdk/resources'
|
import { MessageCreateParams } from '@anthropic-ai/sdk/resources'
|
||||||
|
import { loggerService } from '@logger'
|
||||||
|
import { Provider } from '@types'
|
||||||
import express, { Request, Response } from 'express'
|
import express, { Request, Response } from 'express'
|
||||||
|
|
||||||
import { loggerService } from '../../services/LoggerService'
|
|
||||||
import { messagesService } from '../services/messages'
|
import { messagesService } from '../services/messages'
|
||||||
import { validateModelId } from '../utils'
|
import { getProviderById, validateModelId } from '../utils'
|
||||||
|
|
||||||
const logger = loggerService.withContext('ApiServerMessagesRoutes')
|
const logger = loggerService.withContext('ApiServerMessagesRoutes')
|
||||||
|
|
||||||
const router = express.Router()
|
const router = express.Router()
|
||||||
|
const providerRouter = express.Router({ mergeParams: true })
|
||||||
|
|
||||||
|
// Helper function for basic request validation
|
||||||
|
async function validateRequestBody(req: Request): Promise<{ valid: boolean; error?: any }> {
|
||||||
|
const request: MessageCreateParams = req.body
|
||||||
|
|
||||||
|
if (!request) {
|
||||||
|
return {
|
||||||
|
valid: false,
|
||||||
|
error: {
|
||||||
|
type: 'error',
|
||||||
|
error: {
|
||||||
|
type: 'invalid_request_error',
|
||||||
|
message: 'Request body is required'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return { valid: true }
|
||||||
|
}
|
||||||
|
|
||||||
|
interface HandleMessageProcessingOptions {
|
||||||
|
req: Request
|
||||||
|
res: Response
|
||||||
|
provider: Provider
|
||||||
|
request: MessageCreateParams
|
||||||
|
modelId?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
async function handleMessageProcessing({
|
||||||
|
req,
|
||||||
|
res,
|
||||||
|
provider,
|
||||||
|
request,
|
||||||
|
modelId
|
||||||
|
}: HandleMessageProcessingOptions): Promise<void> {
|
||||||
|
try {
|
||||||
|
const validation = messagesService.validateRequest(request)
|
||||||
|
if (!validation.isValid) {
|
||||||
|
res.status(400).json({
|
||||||
|
type: 'error',
|
||||||
|
error: {
|
||||||
|
type: 'invalid_request_error',
|
||||||
|
message: validation.errors.join('; ')
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const extraHeaders = messagesService.prepareHeaders(req.headers)
|
||||||
|
const { client, anthropicRequest } = await messagesService.processMessage({
|
||||||
|
provider,
|
||||||
|
request,
|
||||||
|
extraHeaders,
|
||||||
|
modelId
|
||||||
|
})
|
||||||
|
|
||||||
|
if (request.stream) {
|
||||||
|
await messagesService.handleStreaming(client, anthropicRequest, { response: res }, provider)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const response = await client.messages.create(anthropicRequest)
|
||||||
|
res.json(response)
|
||||||
|
} catch (error: any) {
|
||||||
|
logger.error('Message processing error', { error })
|
||||||
|
const { statusCode, errorResponse } = messagesService.transformError(error)
|
||||||
|
res.status(statusCode).json(errorResponse)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @swagger
|
* @swagger
|
||||||
@@ -133,32 +205,23 @@ const router = express.Router()
|
|||||||
* description: Internal server error
|
* description: Internal server error
|
||||||
*/
|
*/
|
||||||
router.post('/', async (req: Request, res: Response) => {
|
router.post('/', async (req: Request, res: Response) => {
|
||||||
|
// Validate request body
|
||||||
|
const bodyValidation = await validateRequestBody(req)
|
||||||
|
if (!bodyValidation.valid) {
|
||||||
|
return res.status(400).json(bodyValidation.error)
|
||||||
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const request: MessageCreateParams = req.body
|
const request: MessageCreateParams = req.body
|
||||||
|
|
||||||
if (!request) {
|
|
||||||
return res.status(400).json({
|
|
||||||
type: 'error',
|
|
||||||
error: {
|
|
||||||
type: 'invalid_request_error',
|
|
||||||
message: 'Request body is required'
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info('Anthropic message request:', {
|
|
||||||
model: request.model,
|
|
||||||
messageCount: request.messages?.length || 0,
|
|
||||||
stream: request.stream,
|
|
||||||
max_tokens: request.max_tokens,
|
|
||||||
temperature: request.temperature
|
|
||||||
})
|
|
||||||
|
|
||||||
// Validate model ID and get provider
|
// Validate model ID and get provider
|
||||||
const modelValidation = await validateModelId(request.model)
|
const modelValidation = await validateModelId(request.model)
|
||||||
if (!modelValidation.valid) {
|
if (!modelValidation.valid) {
|
||||||
const error = modelValidation.error!
|
const error = modelValidation.error!
|
||||||
logger.warn(`Model validation failed for '${request.model}':`, error)
|
logger.warn('Model validation failed', {
|
||||||
|
model: request.model,
|
||||||
|
error
|
||||||
|
})
|
||||||
return res.status(400).json({
|
return res.status(400).json({
|
||||||
type: 'error',
|
type: 'error',
|
||||||
error: {
|
error: {
|
||||||
@@ -169,122 +232,172 @@ router.post('/', async (req: Request, res: Response) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const provider = modelValidation.provider!
|
const provider = modelValidation.provider!
|
||||||
|
|
||||||
// Ensure provider is Anthropic type
|
|
||||||
if (provider.type !== 'anthropic') {
|
|
||||||
return res.status(400).json({
|
|
||||||
type: 'error',
|
|
||||||
error: {
|
|
||||||
type: 'invalid_request_error',
|
|
||||||
message: `Invalid provider type '${provider.type}' for messages endpoint. Expected 'anthropic' provider.`
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
const modelId = modelValidation.modelId!
|
const modelId = modelValidation.modelId!
|
||||||
request.model = modelId
|
|
||||||
|
|
||||||
logger.info('Model validation successful:', {
|
return handleMessageProcessing({ req, res, provider, request, modelId })
|
||||||
provider: provider.id,
|
|
||||||
providerType: provider.type,
|
|
||||||
modelId: modelId,
|
|
||||||
fullModelId: request.model
|
|
||||||
})
|
|
||||||
|
|
||||||
// Validate request
|
|
||||||
const validation = messagesService.validateRequest(request)
|
|
||||||
if (!validation.isValid) {
|
|
||||||
return res.status(400).json({
|
|
||||||
type: 'error',
|
|
||||||
error: {
|
|
||||||
type: 'invalid_request_error',
|
|
||||||
message: validation.errors.join('; ')
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle streaming
|
|
||||||
if (request.stream) {
|
|
||||||
res.setHeader('Content-Type', 'text/event-stream; charset=utf-8')
|
|
||||||
res.setHeader('Cache-Control', 'no-cache, no-transform')
|
|
||||||
res.setHeader('Connection', 'keep-alive')
|
|
||||||
res.setHeader('X-Accel-Buffering', 'no')
|
|
||||||
res.flushHeaders()
|
|
||||||
|
|
||||||
try {
|
|
||||||
for await (const chunk of messagesService.processStreamingMessage(request, provider)) {
|
|
||||||
res.write(`data: ${JSON.stringify(chunk)}\n\n`)
|
|
||||||
}
|
|
||||||
res.write('data: [DONE]\n\n')
|
|
||||||
} catch (streamError: any) {
|
|
||||||
logger.error('Stream error:', streamError)
|
|
||||||
res.write(
|
|
||||||
`data: ${JSON.stringify({
|
|
||||||
type: 'error',
|
|
||||||
error: {
|
|
||||||
type: 'api_error',
|
|
||||||
message: 'Stream processing error'
|
|
||||||
}
|
|
||||||
})}\n\n`
|
|
||||||
)
|
|
||||||
} finally {
|
|
||||||
res.end()
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle non-streaming
|
|
||||||
const response = await messagesService.processMessage(request, provider)
|
|
||||||
return res.json(response)
|
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Anthropic message error:', error)
|
logger.error('Message processing error', { error })
|
||||||
|
const { statusCode, errorResponse } = messagesService.transformError(error)
|
||||||
let statusCode = 500
|
return res.status(statusCode).json(errorResponse)
|
||||||
let errorType = 'api_error'
|
|
||||||
let errorMessage = 'Internal server error'
|
|
||||||
|
|
||||||
const anthropicStatus = typeof error?.status === 'number' ? error.status : undefined
|
|
||||||
const anthropicError = error?.error
|
|
||||||
|
|
||||||
if (anthropicStatus) {
|
|
||||||
statusCode = anthropicStatus
|
|
||||||
}
|
|
||||||
|
|
||||||
if (anthropicError?.type) {
|
|
||||||
errorType = anthropicError.type
|
|
||||||
}
|
|
||||||
|
|
||||||
if (anthropicError?.message) {
|
|
||||||
errorMessage = anthropicError.message
|
|
||||||
} else if (error instanceof Error && error.message) {
|
|
||||||
errorMessage = error.message
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!anthropicStatus && error instanceof Error) {
|
|
||||||
if (error.message.includes('API key') || error.message.includes('authentication')) {
|
|
||||||
statusCode = 401
|
|
||||||
errorType = 'authentication_error'
|
|
||||||
} else if (error.message.includes('rate limit') || error.message.includes('quota')) {
|
|
||||||
statusCode = 429
|
|
||||||
errorType = 'rate_limit_error'
|
|
||||||
} else if (error.message.includes('timeout') || error.message.includes('connection')) {
|
|
||||||
statusCode = 502
|
|
||||||
errorType = 'api_error'
|
|
||||||
} else if (error.message.includes('validation') || error.message.includes('invalid')) {
|
|
||||||
statusCode = 400
|
|
||||||
errorType = 'invalid_request_error'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return res.status(statusCode).json({
|
|
||||||
type: 'error',
|
|
||||||
error: {
|
|
||||||
type: errorType,
|
|
||||||
message: errorMessage,
|
|
||||||
requestId: error?.request_id
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
export { router as messagesRoutes }
|
/**
|
||||||
|
* @swagger
|
||||||
|
* /{provider_id}/v1/messages:
|
||||||
|
* post:
|
||||||
|
* summary: Create message with provider in path
|
||||||
|
* description: Create a message response using provider ID from URL path
|
||||||
|
* tags: [Messages]
|
||||||
|
* parameters:
|
||||||
|
* - in: path
|
||||||
|
* name: provider_id
|
||||||
|
* required: true
|
||||||
|
* schema:
|
||||||
|
* type: string
|
||||||
|
* description: Provider ID (e.g., "my-anthropic")
|
||||||
|
* example: "my-anthropic"
|
||||||
|
* requestBody:
|
||||||
|
* required: true
|
||||||
|
* content:
|
||||||
|
* application/json:
|
||||||
|
* schema:
|
||||||
|
* type: object
|
||||||
|
* required:
|
||||||
|
* - model
|
||||||
|
* - max_tokens
|
||||||
|
* - messages
|
||||||
|
* properties:
|
||||||
|
* model:
|
||||||
|
* type: string
|
||||||
|
* description: Model ID without provider prefix
|
||||||
|
* example: "claude-3-5-sonnet-20241022"
|
||||||
|
* max_tokens:
|
||||||
|
* type: integer
|
||||||
|
* minimum: 1
|
||||||
|
* description: Maximum number of tokens to generate
|
||||||
|
* example: 1024
|
||||||
|
* messages:
|
||||||
|
* type: array
|
||||||
|
* items:
|
||||||
|
* type: object
|
||||||
|
* properties:
|
||||||
|
* role:
|
||||||
|
* type: string
|
||||||
|
* enum: [user, assistant]
|
||||||
|
* content:
|
||||||
|
* oneOf:
|
||||||
|
* - type: string
|
||||||
|
* - type: array
|
||||||
|
* system:
|
||||||
|
* type: string
|
||||||
|
* description: System message
|
||||||
|
* temperature:
|
||||||
|
* type: number
|
||||||
|
* minimum: 0
|
||||||
|
* maximum: 1
|
||||||
|
* description: Sampling temperature
|
||||||
|
* top_p:
|
||||||
|
* type: number
|
||||||
|
* minimum: 0
|
||||||
|
* maximum: 1
|
||||||
|
* description: Nucleus sampling
|
||||||
|
* top_k:
|
||||||
|
* type: integer
|
||||||
|
* minimum: 0
|
||||||
|
* description: Top-k sampling
|
||||||
|
* stream:
|
||||||
|
* type: boolean
|
||||||
|
* description: Whether to stream the response
|
||||||
|
* tools:
|
||||||
|
* type: array
|
||||||
|
* description: Available tools for the model
|
||||||
|
* responses:
|
||||||
|
* 200:
|
||||||
|
* description: Message response
|
||||||
|
* content:
|
||||||
|
* application/json:
|
||||||
|
* schema:
|
||||||
|
* type: object
|
||||||
|
* properties:
|
||||||
|
* id:
|
||||||
|
* type: string
|
||||||
|
* type:
|
||||||
|
* type: string
|
||||||
|
* example: message
|
||||||
|
* role:
|
||||||
|
* type: string
|
||||||
|
* example: assistant
|
||||||
|
* content:
|
||||||
|
* type: array
|
||||||
|
* items:
|
||||||
|
* type: object
|
||||||
|
* model:
|
||||||
|
* type: string
|
||||||
|
* stop_reason:
|
||||||
|
* type: string
|
||||||
|
* stop_sequence:
|
||||||
|
* type: string
|
||||||
|
* usage:
|
||||||
|
* type: object
|
||||||
|
* properties:
|
||||||
|
* input_tokens:
|
||||||
|
* type: integer
|
||||||
|
* output_tokens:
|
||||||
|
* type: integer
|
||||||
|
* text/event-stream:
|
||||||
|
* schema:
|
||||||
|
* type: string
|
||||||
|
* description: Server-sent events stream (when stream=true)
|
||||||
|
* 400:
|
||||||
|
* description: Bad request
|
||||||
|
* 401:
|
||||||
|
* description: Unauthorized
|
||||||
|
* 429:
|
||||||
|
* description: Rate limit exceeded
|
||||||
|
* 500:
|
||||||
|
* description: Internal server error
|
||||||
|
*/
|
||||||
|
providerRouter.post('/', async (req: Request, res: Response) => {
|
||||||
|
// Validate request body
|
||||||
|
const bodyValidation = await validateRequestBody(req)
|
||||||
|
if (!bodyValidation.valid) {
|
||||||
|
return res.status(400).json(bodyValidation.error)
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const providerId = req.params.provider
|
||||||
|
|
||||||
|
if (!providerId) {
|
||||||
|
return res.status(400).json({
|
||||||
|
type: 'error',
|
||||||
|
error: {
|
||||||
|
type: 'invalid_request_error',
|
||||||
|
message: 'Provider ID is required in URL path'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get provider directly by ID from URL path
|
||||||
|
const provider = await getProviderById(providerId)
|
||||||
|
if (!provider) {
|
||||||
|
return res.status(400).json({
|
||||||
|
type: 'error',
|
||||||
|
error: {
|
||||||
|
type: 'invalid_request_error',
|
||||||
|
message: `Provider '${providerId}' not found or not enabled`
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const request: MessageCreateParams = req.body
|
||||||
|
|
||||||
|
return handleMessageProcessing({ req, res, provider, request })
|
||||||
|
} catch (error: any) {
|
||||||
|
logger.error('Message processing error', { error })
|
||||||
|
const { statusCode, errorResponse } = messagesService.transformError(error)
|
||||||
|
return res.status(statusCode).json(errorResponse)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
export { providerRouter as messagesProviderRoutes, router as messagesRoutes }
|
||||||
|
|||||||
@@ -75,13 +75,13 @@ const router = express
|
|||||||
*/
|
*/
|
||||||
.get('/', async (req: Request, res: Response) => {
|
.get('/', async (req: Request, res: Response) => {
|
||||||
try {
|
try {
|
||||||
logger.info('Models list request received', { query: req.query })
|
logger.debug('Models list request received', { query: req.query })
|
||||||
|
|
||||||
// Validate query parameters using Zod schema
|
// Validate query parameters using Zod schema
|
||||||
const filterResult = ApiModelsFilterSchema.safeParse(req.query)
|
const filterResult = ApiModelsFilterSchema.safeParse(req.query)
|
||||||
|
|
||||||
if (!filterResult.success) {
|
if (!filterResult.success) {
|
||||||
logger.warn('Invalid query parameters:', filterResult.error.issues)
|
logger.warn('Invalid model query parameters', { issues: filterResult.error.issues })
|
||||||
return res.status(400).json({
|
return res.status(400).json({
|
||||||
error: {
|
error: {
|
||||||
message: 'Invalid query parameters',
|
message: 'Invalid query parameters',
|
||||||
@@ -99,24 +99,18 @@ const router = express
|
|||||||
const response = await modelsService.getModels(filter)
|
const response = await modelsService.getModels(filter)
|
||||||
|
|
||||||
if (response.data.length === 0) {
|
if (response.data.length === 0) {
|
||||||
logger.warn(
|
logger.warn('No models available from providers', { filter })
|
||||||
'No models available from providers. This may be because no OpenAI/Anthropic providers are configured or enabled.',
|
|
||||||
{ filter }
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(`Returning ${response.data.length} models`, {
|
logger.info('Models response ready', {
|
||||||
filter,
|
filter,
|
||||||
total: response.total
|
total: response.total,
|
||||||
|
modelIds: response.data.map((m) => m.id)
|
||||||
})
|
})
|
||||||
logger.debug(
|
|
||||||
'Model IDs:',
|
|
||||||
response.data.map((m) => m.id)
|
|
||||||
)
|
|
||||||
|
|
||||||
return res.json(response satisfies ApiModelsResponse)
|
return res.json(response satisfies ApiModelsResponse)
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Error fetching models:', error)
|
logger.error('Error fetching models', { error })
|
||||||
return res.status(503).json({
|
return res.status(503).json({
|
||||||
error: {
|
error: {
|
||||||
message: 'Failed to retrieve models from available providers',
|
message: 'Failed to retrieve models from available providers',
|
||||||
|
|||||||
@@ -1,50 +1,72 @@
|
|||||||
import { createServer } from 'node:http'
|
import { createServer } from 'node:http'
|
||||||
|
|
||||||
|
import { loggerService } from '@logger'
|
||||||
|
|
||||||
import { agentService } from '../services/agents'
|
import { agentService } from '../services/agents'
|
||||||
import { loggerService } from '../services/LoggerService'
|
|
||||||
import { app } from './app'
|
import { app } from './app'
|
||||||
import { config } from './config'
|
import { config } from './config'
|
||||||
|
|
||||||
const logger = loggerService.withContext('ApiServer')
|
const logger = loggerService.withContext('ApiServer')
|
||||||
|
|
||||||
|
const GLOBAL_REQUEST_TIMEOUT_MS = 5 * 60_000
|
||||||
|
const GLOBAL_HEADERS_TIMEOUT_MS = GLOBAL_REQUEST_TIMEOUT_MS + 5_000
|
||||||
|
const GLOBAL_KEEPALIVE_TIMEOUT_MS = 60_000
|
||||||
|
|
||||||
export class ApiServer {
|
export class ApiServer {
|
||||||
private server: ReturnType<typeof createServer> | null = null
|
private server: ReturnType<typeof createServer> | null = null
|
||||||
|
|
||||||
async start(): Promise<void> {
|
async start(): Promise<void> {
|
||||||
if (this.server) {
|
if (this.server && this.server.listening) {
|
||||||
logger.warn('Server already running')
|
logger.warn('Server already running')
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clean up any failed server instance
|
||||||
|
if (this.server && !this.server.listening) {
|
||||||
|
logger.warn('Cleaning up failed server instance')
|
||||||
|
this.server = null
|
||||||
|
}
|
||||||
|
|
||||||
// Load config
|
// Load config
|
||||||
const { port, host, apiKey } = await config.load()
|
const { port, host } = await config.load()
|
||||||
|
|
||||||
// Initialize AgentService
|
// Initialize AgentService
|
||||||
logger.info('Initializing AgentService...')
|
logger.info('Initializing AgentService')
|
||||||
await agentService.initialize()
|
await agentService.initialize()
|
||||||
logger.info('AgentService initialized successfully')
|
logger.info('AgentService initialized')
|
||||||
|
|
||||||
// Create server with Express app
|
// Create server with Express app
|
||||||
this.server = createServer(app)
|
this.server = createServer(app)
|
||||||
|
this.applyServerTimeouts(this.server)
|
||||||
|
|
||||||
// Start server
|
// Start server
|
||||||
return new Promise((resolve, reject) => {
|
return new Promise((resolve, reject) => {
|
||||||
this.server!.listen(port, host, () => {
|
this.server!.listen(port, host, () => {
|
||||||
logger.info(`API Server started at http://${host}:${port}`)
|
logger.info('API server started', { host, port })
|
||||||
logger.info(`API Key: ${apiKey}`)
|
|
||||||
resolve()
|
resolve()
|
||||||
})
|
})
|
||||||
|
|
||||||
this.server!.on('error', reject)
|
this.server!.on('error', (error) => {
|
||||||
|
// Clean up the server instance if listen fails
|
||||||
|
this.server = null
|
||||||
|
reject(error)
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private applyServerTimeouts(server: ReturnType<typeof createServer>): void {
|
||||||
|
server.requestTimeout = GLOBAL_REQUEST_TIMEOUT_MS
|
||||||
|
server.headersTimeout = Math.max(GLOBAL_HEADERS_TIMEOUT_MS, server.requestTimeout + 1_000)
|
||||||
|
server.keepAliveTimeout = GLOBAL_KEEPALIVE_TIMEOUT_MS
|
||||||
|
server.setTimeout(0)
|
||||||
|
}
|
||||||
|
|
||||||
async stop(): Promise<void> {
|
async stop(): Promise<void> {
|
||||||
if (!this.server) return
|
if (!this.server) return
|
||||||
|
|
||||||
return new Promise((resolve) => {
|
return new Promise((resolve) => {
|
||||||
this.server!.close(() => {
|
this.server!.close(() => {
|
||||||
logger.info('API Server stopped')
|
logger.info('API server stopped')
|
||||||
this.server = null
|
this.server = null
|
||||||
resolve()
|
resolve()
|
||||||
})
|
})
|
||||||
@@ -62,7 +84,7 @@ export class ApiServer {
|
|||||||
const isListening = this.server?.listening || false
|
const isListening = this.server?.listening || false
|
||||||
const result = hasServer && isListening
|
const result = hasServer && isListening
|
||||||
|
|
||||||
logger.debug('isRunning check:', { hasServer, isListening, result })
|
logger.debug('isRunning check', { hasServer, isListening, result })
|
||||||
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,9 +38,10 @@ export type PrepareRequestResult =
|
|||||||
}
|
}
|
||||||
|
|
||||||
export class ChatCompletionService {
|
export class ChatCompletionService {
|
||||||
async resolveProviderContext(model: string): Promise<
|
async resolveProviderContext(
|
||||||
| { ok: false; error: ModelValidationError }
|
model: string
|
||||||
| { ok: true; provider: Provider; modelId: string; client: OpenAI }
|
): Promise<
|
||||||
|
{ ok: false; error: ModelValidationError } | { ok: true; provider: Provider; modelId: string; client: OpenAI }
|
||||||
> {
|
> {
|
||||||
const modelValidation = await validateModelId(model)
|
const modelValidation = await validateModelId(model)
|
||||||
if (!modelValidation.valid) {
|
if (!modelValidation.valid) {
|
||||||
@@ -97,7 +98,7 @@ export class ChatCompletionService {
|
|||||||
|
|
||||||
const { provider, modelId, client } = providerContext
|
const { provider, modelId, client } = providerContext
|
||||||
|
|
||||||
logger.info('Model validation successful:', {
|
logger.debug('Model validation successful', {
|
||||||
provider: provider.id,
|
provider: provider.id,
|
||||||
providerType: provider.type,
|
providerType: provider.type,
|
||||||
modelId,
|
modelId,
|
||||||
@@ -159,7 +160,7 @@ export class ChatCompletionService {
|
|||||||
response: OpenAI.Chat.Completions.ChatCompletion
|
response: OpenAI.Chat.Completions.ChatCompletion
|
||||||
}> {
|
}> {
|
||||||
try {
|
try {
|
||||||
logger.info('Processing chat completion request:', {
|
logger.debug('Processing chat completion request', {
|
||||||
model: request.model,
|
model: request.model,
|
||||||
messageCount: request.messages.length,
|
messageCount: request.messages.length,
|
||||||
stream: request.stream
|
stream: request.stream
|
||||||
@@ -176,7 +177,7 @@ export class ChatCompletionService {
|
|||||||
|
|
||||||
const { provider, modelId, client, providerRequest } = preparation
|
const { provider, modelId, client, providerRequest } = preparation
|
||||||
|
|
||||||
logger.debug('Sending request to provider:', {
|
logger.debug('Sending request to provider', {
|
||||||
provider: provider.id,
|
provider: provider.id,
|
||||||
model: modelId,
|
model: modelId,
|
||||||
apiHost: provider.apiHost
|
apiHost: provider.apiHost
|
||||||
@@ -184,27 +185,31 @@ export class ChatCompletionService {
|
|||||||
|
|
||||||
const response = (await client.chat.completions.create(providerRequest)) as OpenAI.Chat.Completions.ChatCompletion
|
const response = (await client.chat.completions.create(providerRequest)) as OpenAI.Chat.Completions.ChatCompletion
|
||||||
|
|
||||||
logger.info('Successfully processed chat completion')
|
logger.info('Chat completion processed', {
|
||||||
|
modelId,
|
||||||
|
provider: provider.id
|
||||||
|
})
|
||||||
return {
|
return {
|
||||||
provider,
|
provider,
|
||||||
modelId,
|
modelId,
|
||||||
response
|
response
|
||||||
}
|
}
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Error processing chat completion:', error)
|
logger.error('Error processing chat completion', {
|
||||||
|
error,
|
||||||
|
model: request.model
|
||||||
|
})
|
||||||
throw error
|
throw error
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async processStreamingCompletion(
|
async processStreamingCompletion(request: ChatCompletionCreateParams): Promise<{
|
||||||
request: ChatCompletionCreateParams
|
|
||||||
): Promise<{
|
|
||||||
provider: Provider
|
provider: Provider
|
||||||
modelId: string
|
modelId: string
|
||||||
stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>
|
stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>
|
||||||
}> {
|
}> {
|
||||||
try {
|
try {
|
||||||
logger.info('Processing streaming chat completion request:', {
|
logger.debug('Processing streaming chat completion request', {
|
||||||
model: request.model,
|
model: request.model,
|
||||||
messageCount: request.messages.length
|
messageCount: request.messages.length
|
||||||
})
|
})
|
||||||
@@ -220,25 +225,31 @@ export class ChatCompletionService {
|
|||||||
|
|
||||||
const { provider, modelId, client, providerRequest } = preparation
|
const { provider, modelId, client, providerRequest } = preparation
|
||||||
|
|
||||||
logger.debug('Sending streaming request to provider:', {
|
logger.debug('Sending streaming request to provider', {
|
||||||
provider: provider.id,
|
provider: provider.id,
|
||||||
model: modelId,
|
model: modelId,
|
||||||
apiHost: provider.apiHost
|
apiHost: provider.apiHost
|
||||||
})
|
})
|
||||||
|
|
||||||
const streamRequest = providerRequest as ChatCompletionCreateParamsStreaming
|
const streamRequest = providerRequest as ChatCompletionCreateParamsStreaming
|
||||||
const stream = (await client.chat.completions.create(streamRequest)) as AsyncIterable<
|
const stream = (await client.chat.completions.create(
|
||||||
OpenAI.Chat.Completions.ChatCompletionChunk
|
streamRequest
|
||||||
>
|
)) as AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>
|
||||||
|
|
||||||
logger.info('Successfully started streaming chat completion')
|
logger.info('Streaming chat completion started', {
|
||||||
|
modelId,
|
||||||
|
provider: provider.id
|
||||||
|
})
|
||||||
return {
|
return {
|
||||||
provider,
|
provider,
|
||||||
modelId,
|
modelId,
|
||||||
stream
|
stream
|
||||||
}
|
}
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Error processing streaming chat completion:', error)
|
logger.error('Error processing streaming chat completion', {
|
||||||
|
error,
|
||||||
|
model: request.model
|
||||||
|
})
|
||||||
throw error
|
throw error
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ class MCPApiService extends EventEmitter {
|
|||||||
constructor() {
|
constructor() {
|
||||||
super()
|
super()
|
||||||
this.initMcpServer()
|
this.initMcpServer()
|
||||||
logger.silly('MCPApiService initialized')
|
logger.debug('MCPApiService initialized')
|
||||||
}
|
}
|
||||||
|
|
||||||
private initMcpServer() {
|
private initMcpServer() {
|
||||||
@@ -60,7 +60,7 @@ class MCPApiService extends EventEmitter {
|
|||||||
async getAllServers(req: Request): Promise<McpServersResp> {
|
async getAllServers(req: Request): Promise<McpServersResp> {
|
||||||
try {
|
try {
|
||||||
const servers = await getMCPServersFromRedux()
|
const servers = await getMCPServersFromRedux()
|
||||||
logger.silly(`Returning ${servers.length} servers`)
|
logger.debug('Returning servers from Redux', { count: servers.length })
|
||||||
const resp: McpServersResp = {
|
const resp: McpServersResp = {
|
||||||
servers: {}
|
servers: {}
|
||||||
}
|
}
|
||||||
@@ -77,7 +77,7 @@ class MCPApiService extends EventEmitter {
|
|||||||
}
|
}
|
||||||
return resp
|
return resp
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Failed to get all servers:', error)
|
logger.error('Failed to get all servers', { error })
|
||||||
throw new Error('Failed to retrieve servers')
|
throw new Error('Failed to retrieve servers')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -85,87 +85,47 @@ class MCPApiService extends EventEmitter {
|
|||||||
// get server by id
|
// get server by id
|
||||||
async getServerById(id: string): Promise<MCPServer | null> {
|
async getServerById(id: string): Promise<MCPServer | null> {
|
||||||
try {
|
try {
|
||||||
logger.silly(`getServerById called with id: ${id}`)
|
logger.debug('getServerById called', { id })
|
||||||
const servers = await getMCPServersFromRedux()
|
const servers = await getMCPServersFromRedux()
|
||||||
const server = servers.find((s) => s.id === id)
|
const server = servers.find((s) => s.id === id)
|
||||||
if (!server) {
|
if (!server) {
|
||||||
logger.warn(`Server with id ${id} not found`)
|
logger.warn('Server not found', { id })
|
||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
logger.silly(`Returning server with id ${id}`)
|
logger.debug('Returning server', { id })
|
||||||
return server
|
return server
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error(`Failed to get server with id ${id}:`, error)
|
logger.error('Failed to get server', { id, error })
|
||||||
throw new Error('Failed to retrieve server')
|
throw new Error('Failed to retrieve server')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async getServerInfo(id: string): Promise<any> {
|
async getServerInfo(id: string): Promise<any> {
|
||||||
try {
|
try {
|
||||||
logger.silly(`getServerInfo called with id: ${id}`)
|
|
||||||
const server = await this.getServerById(id)
|
const server = await this.getServerById(id)
|
||||||
if (!server) {
|
if (!server) {
|
||||||
logger.warn(`Server with id ${id} not found`)
|
logger.warn('Server not found while fetching info', { id })
|
||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
logger.silly(`Returning server info for id ${id}`)
|
|
||||||
|
|
||||||
const client = await mcpService.initClient(server)
|
const client = await mcpService.initClient(server)
|
||||||
const tools = await client.listTools()
|
const tools = await client.listTools()
|
||||||
|
|
||||||
logger.info(`Server with id ${id} info:`, { tools: JSON.stringify(tools) })
|
|
||||||
|
|
||||||
// const [version, tools, prompts, resources] = await Promise.all([
|
|
||||||
// () => {
|
|
||||||
// try {
|
|
||||||
// return client.getServerVersion()
|
|
||||||
// } catch (error) {
|
|
||||||
// logger.error(`Failed to get server version for id ${id}:`, { error: error })
|
|
||||||
// return '1.0.0'
|
|
||||||
// }
|
|
||||||
// },
|
|
||||||
// (() => {
|
|
||||||
// try {
|
|
||||||
// return client.listTools()
|
|
||||||
// } catch (error) {
|
|
||||||
// logger.error(`Failed to list tools for id ${id}:`, { error: error })
|
|
||||||
// return []
|
|
||||||
// }
|
|
||||||
// })(),
|
|
||||||
// (() => {
|
|
||||||
// try {
|
|
||||||
// return client.listPrompts()
|
|
||||||
// } catch (error) {
|
|
||||||
// logger.error(`Failed to list prompts for id ${id}:`, { error: error })
|
|
||||||
// return []
|
|
||||||
// }
|
|
||||||
// })(),
|
|
||||||
// (() => {
|
|
||||||
// try {
|
|
||||||
// return client.listResources()
|
|
||||||
// } catch (error) {
|
|
||||||
// logger.error(`Failed to list resources for id ${id}:`, { error: error })
|
|
||||||
// return []
|
|
||||||
// }
|
|
||||||
// })()
|
|
||||||
// ])
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
id: server.id,
|
id: server.id,
|
||||||
name: server.name,
|
name: server.name,
|
||||||
type: server.type,
|
type: server.type,
|
||||||
description: server.description,
|
description: server.description,
|
||||||
tools
|
tools: tools.tools
|
||||||
}
|
}
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error(`Failed to get server info with id ${id}:`, error)
|
logger.error('Failed to get server info', { id, error })
|
||||||
throw new Error('Failed to retrieve server info')
|
throw new Error('Failed to retrieve server info')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async handleRequest(req: Request, res: Response, server: MCPServer) {
|
async handleRequest(req: Request, res: Response, server: MCPServer) {
|
||||||
const sessionId = req.headers['mcp-session-id'] as string | undefined
|
const sessionId = req.headers['mcp-session-id'] as string | undefined
|
||||||
logger.silly(`Handling request for server with sessionId ${sessionId}`)
|
logger.debug('Handling MCP request', { sessionId, serverId: server.id })
|
||||||
let transport: StreamableHTTPServerTransport
|
let transport: StreamableHTTPServerTransport
|
||||||
if (sessionId && transports[sessionId]) {
|
if (sessionId && transports[sessionId]) {
|
||||||
transport = transports[sessionId]
|
transport = transports[sessionId]
|
||||||
@@ -178,7 +138,7 @@ class MCPApiService extends EventEmitter {
|
|||||||
})
|
})
|
||||||
|
|
||||||
transport.onclose = () => {
|
transport.onclose = () => {
|
||||||
logger.info(`Transport for sessionId ${sessionId} closed`)
|
logger.info('Transport closed', { sessionId })
|
||||||
if (transport.sessionId) {
|
if (transport.sessionId) {
|
||||||
delete transports[transport.sessionId]
|
delete transports[transport.sessionId]
|
||||||
}
|
}
|
||||||
@@ -213,12 +173,15 @@ class MCPApiService extends EventEmitter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(`Request body`, { rawBody: req.body, messages: JSON.stringify(messages) })
|
logger.debug('Dispatching MCP request', {
|
||||||
|
sessionId: transport.sessionId ?? sessionId,
|
||||||
|
messageCount: messages.length
|
||||||
|
})
|
||||||
await transport.handleRequest(req as IncomingMessage, res as ServerResponse, messages)
|
await transport.handleRequest(req as IncomingMessage, res as ServerResponse, messages)
|
||||||
}
|
}
|
||||||
|
|
||||||
private onMessage(message: JSONRPCMessage, extra?: MessageExtraInfo) {
|
private onMessage(message: JSONRPCMessage, extra?: MessageExtraInfo) {
|
||||||
logger.info(`Received message: ${JSON.stringify(message)}`, extra)
|
logger.debug('Received MCP message', { message, extra })
|
||||||
// Handle message here
|
// Handle message here
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,32 +1,93 @@
|
|||||||
import Anthropic from '@anthropic-ai/sdk'
|
import Anthropic from '@anthropic-ai/sdk'
|
||||||
import { Message, MessageCreateParams, RawMessageStreamEvent } from '@anthropic-ai/sdk/resources'
|
import { MessageCreateParams, MessageStreamEvent } from '@anthropic-ai/sdk/resources'
|
||||||
|
import { loggerService } from '@logger'
|
||||||
|
import anthropicService from '@main/services/AnthropicService'
|
||||||
|
import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic'
|
||||||
import { Provider } from '@types'
|
import { Provider } from '@types'
|
||||||
|
import { Response } from 'express'
|
||||||
import { loggerService } from '../../services/LoggerService'
|
|
||||||
|
|
||||||
const logger = loggerService.withContext('MessagesService')
|
const logger = loggerService.withContext('MessagesService')
|
||||||
|
const EXCLUDED_FORWARD_HEADERS: ReadonlySet<string> = new Set([
|
||||||
|
'host',
|
||||||
|
'x-api-key',
|
||||||
|
'authorization',
|
||||||
|
'sentry-trace',
|
||||||
|
'baggage',
|
||||||
|
'content-length',
|
||||||
|
'connection'
|
||||||
|
])
|
||||||
|
|
||||||
export interface ValidationResult {
|
export interface ValidationResult {
|
||||||
isValid: boolean
|
isValid: boolean
|
||||||
errors: string[]
|
errors: string[]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface ErrorResponse {
|
||||||
|
type: 'error'
|
||||||
|
error: {
|
||||||
|
type: string
|
||||||
|
message: string
|
||||||
|
requestId?: string
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface StreamConfig {
|
||||||
|
response: Response
|
||||||
|
onChunk?: (chunk: MessageStreamEvent) => void
|
||||||
|
onError?: (error: any) => void
|
||||||
|
onComplete?: () => void
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ProcessMessageOptions {
|
||||||
|
provider: Provider
|
||||||
|
request: MessageCreateParams
|
||||||
|
extraHeaders?: Record<string, string | string[]>
|
||||||
|
modelId?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ProcessMessageResult {
|
||||||
|
client: Anthropic
|
||||||
|
anthropicRequest: MessageCreateParams
|
||||||
|
}
|
||||||
|
|
||||||
export class MessagesService {
|
export class MessagesService {
|
||||||
// oxlint-disable-next-line no-unused-vars
|
|
||||||
validateRequest(request: MessageCreateParams): ValidationResult {
|
validateRequest(request: MessageCreateParams): ValidationResult {
|
||||||
// TODO: Implement comprehensive request validation
|
// TODO: Implement comprehensive request validation
|
||||||
const errors: string[] = []
|
const errors: string[] = []
|
||||||
|
|
||||||
if (!request.model) {
|
if (!request.model || typeof request.model !== 'string') {
|
||||||
errors.push('Model is required')
|
errors.push('Model is required')
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!request.max_tokens || request.max_tokens < 1) {
|
if (typeof request.max_tokens !== 'number' || !Number.isFinite(request.max_tokens) || request.max_tokens < 1) {
|
||||||
errors.push('max_tokens is required and must be at least 1')
|
errors.push('max_tokens is required and must be a positive number')
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!request.messages || !Array.isArray(request.messages) || request.messages.length === 0) {
|
if (!request.messages || !Array.isArray(request.messages) || request.messages.length === 0) {
|
||||||
errors.push('messages is required and must be a non-empty array')
|
errors.push('messages is required and must be a non-empty array')
|
||||||
|
} else {
|
||||||
|
request.messages.forEach((message, index) => {
|
||||||
|
if (!message || typeof message !== 'object') {
|
||||||
|
errors.push(`messages[${index}] must be an object`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!('role' in message) || typeof message.role !== 'string' || message.role.trim().length === 0) {
|
||||||
|
errors.push(`messages[${index}].role is required`)
|
||||||
|
}
|
||||||
|
|
||||||
|
const content: unknown = message.content
|
||||||
|
if (content === undefined || content === null) {
|
||||||
|
errors.push(`messages[${index}].content is required`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if (typeof content === 'string' && content.trim().length === 0) {
|
||||||
|
errors.push(`messages[${index}].content cannot be empty`)
|
||||||
|
} else if (Array.isArray(content) && content.length === 0) {
|
||||||
|
errors.push(`messages[${index}].content must include at least one item when using an array`)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -35,70 +96,224 @@ export class MessagesService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async processMessage(request: MessageCreateParams, provider: Provider): Promise<Message> {
|
async getClient(provider: Provider, extraHeaders?: Record<string, string | string[]>): Promise<Anthropic> {
|
||||||
logger.info('Processing Anthropic message request:', {
|
|
||||||
model: request.model,
|
|
||||||
messageCount: request.messages.length,
|
|
||||||
stream: request.stream,
|
|
||||||
max_tokens: request.max_tokens
|
|
||||||
})
|
|
||||||
|
|
||||||
// Create Anthropic client for the provider
|
// Create Anthropic client for the provider
|
||||||
const client = new Anthropic({
|
if (provider.authType === 'oauth') {
|
||||||
baseURL: provider.apiHost,
|
const oauthToken = await anthropicService.getValidAccessToken()
|
||||||
apiKey: provider.apiKey
|
return getSdkClient(provider, oauthToken, extraHeaders)
|
||||||
})
|
|
||||||
|
|
||||||
// Prepare request with the actual model ID
|
|
||||||
const anthropicRequest: MessageCreateParams = {
|
|
||||||
...request,
|
|
||||||
stream: false
|
|
||||||
}
|
}
|
||||||
|
return getSdkClient(provider, null, extraHeaders)
|
||||||
logger.debug('Sending request to Anthropic provider:', {
|
|
||||||
provider: provider.id,
|
|
||||||
apiHost: provider.apiHost
|
|
||||||
})
|
|
||||||
|
|
||||||
const response = await client.messages.create(anthropicRequest)
|
|
||||||
|
|
||||||
logger.info('Successfully processed Anthropic message')
|
|
||||||
return response
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async *processStreamingMessage(
|
prepareHeaders(headers: Record<string, string | string[] | undefined>): Record<string, string | string[]> {
|
||||||
request: MessageCreateParams,
|
const extraHeaders: Record<string, string | string[]> = {}
|
||||||
provider: Provider
|
|
||||||
): AsyncIterable<RawMessageStreamEvent> {
|
|
||||||
logger.info('Processing streaming Anthropic message request:', {
|
|
||||||
model: request.model,
|
|
||||||
messageCount: request.messages.length
|
|
||||||
})
|
|
||||||
|
|
||||||
// Create Anthropic client for the provider
|
for (const [key, value] of Object.entries(headers)) {
|
||||||
const client = new Anthropic({
|
if (value === undefined) {
|
||||||
baseURL: provider.apiHost,
|
continue
|
||||||
apiKey: provider.apiKey
|
}
|
||||||
})
|
|
||||||
|
|
||||||
// Prepare streaming request
|
const normalizedKey = key.toLowerCase()
|
||||||
const streamingRequest: MessageCreateParams = {
|
if (EXCLUDED_FORWARD_HEADERS.has(normalizedKey)) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
extraHeaders[normalizedKey] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
return extraHeaders
|
||||||
|
}
|
||||||
|
|
||||||
|
createAnthropicRequest(request: MessageCreateParams, provider: Provider, modelId?: string): MessageCreateParams {
|
||||||
|
const anthropicRequest: MessageCreateParams = {
|
||||||
...request,
|
...request,
|
||||||
stream: true
|
stream: !!request.stream
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.debug('Sending streaming request to Anthropic provider:', {
|
// Override model if provided
|
||||||
|
if (modelId) {
|
||||||
|
anthropicRequest.model = modelId
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add Claude Code system message for OAuth providers
|
||||||
|
if (provider.type === 'anthropic' && provider.authType === 'oauth') {
|
||||||
|
anthropicRequest.system = buildClaudeCodeSystemMessage(request.system)
|
||||||
|
}
|
||||||
|
|
||||||
|
return anthropicRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
async handleStreaming(
|
||||||
|
client: Anthropic,
|
||||||
|
request: MessageCreateParams,
|
||||||
|
config: StreamConfig,
|
||||||
|
provider: Provider
|
||||||
|
): Promise<void> {
|
||||||
|
const { response, onChunk, onError, onComplete } = config
|
||||||
|
|
||||||
|
// Set streaming headers
|
||||||
|
response.setHeader('Content-Type', 'text/event-stream; charset=utf-8')
|
||||||
|
response.setHeader('Cache-Control', 'no-cache, no-transform')
|
||||||
|
response.setHeader('Connection', 'keep-alive')
|
||||||
|
response.setHeader('X-Accel-Buffering', 'no')
|
||||||
|
response.flushHeaders()
|
||||||
|
|
||||||
|
const flushableResponse = response as Response & { flush?: () => void }
|
||||||
|
const flushStream = () => {
|
||||||
|
if (typeof flushableResponse.flush !== 'function') {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
flushableResponse.flush()
|
||||||
|
} catch (flushError: unknown) {
|
||||||
|
logger.warn('Failed to flush streaming response', { error: flushError })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const writeSse = (eventType: string | undefined, payload: unknown) => {
|
||||||
|
if (response.writableEnded || response.destroyed) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if (eventType) {
|
||||||
|
response.write(`event: ${eventType}\n`)
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = typeof payload === 'string' ? payload : JSON.stringify(payload)
|
||||||
|
response.write(`data: ${data}\n\n`)
|
||||||
|
flushStream()
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const stream = client.messages.stream(request)
|
||||||
|
for await (const chunk of stream) {
|
||||||
|
if (response.writableEnded || response.destroyed) {
|
||||||
|
logger.warn('Streaming response ended before stream completion', {
|
||||||
|
provider: provider.id,
|
||||||
|
model: request.model
|
||||||
|
})
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
writeSse(chunk.type, chunk)
|
||||||
|
|
||||||
|
if (onChunk) {
|
||||||
|
onChunk(chunk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
writeSse(undefined, '[DONE]')
|
||||||
|
|
||||||
|
if (onComplete) {
|
||||||
|
onComplete()
|
||||||
|
}
|
||||||
|
} catch (streamError: any) {
|
||||||
|
logger.error('Stream error', {
|
||||||
|
error: streamError,
|
||||||
|
provider: provider.id,
|
||||||
|
model: request.model,
|
||||||
|
apiHost: provider.apiHost,
|
||||||
|
anthropicApiHost: provider.anthropicApiHost
|
||||||
|
})
|
||||||
|
writeSse(undefined, {
|
||||||
|
type: 'error',
|
||||||
|
error: {
|
||||||
|
type: 'api_error',
|
||||||
|
message: 'Stream processing error'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if (onError) {
|
||||||
|
onError(streamError)
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
if (!response.writableEnded) {
|
||||||
|
response.end()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
transformError(error: any): { statusCode: number; errorResponse: ErrorResponse } {
|
||||||
|
let statusCode = 500
|
||||||
|
let errorType = 'api_error'
|
||||||
|
let errorMessage = 'Internal server error'
|
||||||
|
|
||||||
|
const anthropicStatus = typeof error?.status === 'number' ? error.status : undefined
|
||||||
|
const anthropicError = error?.error
|
||||||
|
|
||||||
|
if (anthropicStatus) {
|
||||||
|
statusCode = anthropicStatus
|
||||||
|
}
|
||||||
|
|
||||||
|
if (anthropicError?.type) {
|
||||||
|
errorType = anthropicError.type
|
||||||
|
}
|
||||||
|
|
||||||
|
if (anthropicError?.message) {
|
||||||
|
errorMessage = anthropicError.message
|
||||||
|
} else if (error instanceof Error && error.message) {
|
||||||
|
errorMessage = error.message
|
||||||
|
}
|
||||||
|
|
||||||
|
// Infer error type from message if not from Anthropic API
|
||||||
|
if (!anthropicStatus && error instanceof Error) {
|
||||||
|
const errorMessageText = error.message ?? ''
|
||||||
|
|
||||||
|
if (errorMessageText.includes('API key') || errorMessageText.includes('authentication')) {
|
||||||
|
statusCode = 401
|
||||||
|
errorType = 'authentication_error'
|
||||||
|
} else if (errorMessageText.includes('rate limit') || errorMessageText.includes('quota')) {
|
||||||
|
statusCode = 429
|
||||||
|
errorType = 'rate_limit_error'
|
||||||
|
} else if (errorMessageText.includes('timeout') || errorMessageText.includes('connection')) {
|
||||||
|
statusCode = 502
|
||||||
|
errorType = 'api_error'
|
||||||
|
} else if (errorMessageText.includes('validation') || errorMessageText.includes('invalid')) {
|
||||||
|
statusCode = 400
|
||||||
|
errorType = 'invalid_request_error'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const safeErrorMessage =
|
||||||
|
typeof errorMessage === 'string' && errorMessage.length > 0 ? errorMessage : 'Internal server error'
|
||||||
|
|
||||||
|
return {
|
||||||
|
statusCode,
|
||||||
|
errorResponse: {
|
||||||
|
type: 'error',
|
||||||
|
error: {
|
||||||
|
type: errorType,
|
||||||
|
message: safeErrorMessage,
|
||||||
|
requestId: error?.request_id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async processMessage(options: ProcessMessageOptions): Promise<ProcessMessageResult> {
|
||||||
|
const { provider, request, extraHeaders, modelId } = options
|
||||||
|
|
||||||
|
const client = await this.getClient(provider, extraHeaders)
|
||||||
|
const anthropicRequest = this.createAnthropicRequest(request, provider, modelId)
|
||||||
|
|
||||||
|
const messageCount = Array.isArray(request.messages) ? request.messages.length : 0
|
||||||
|
|
||||||
|
logger.info('Processing anthropic messages request', {
|
||||||
provider: provider.id,
|
provider: provider.id,
|
||||||
apiHost: provider.apiHost
|
apiHost: provider.apiHost,
|
||||||
|
anthropicApiHost: provider.anthropicApiHost,
|
||||||
|
model: anthropicRequest.model,
|
||||||
|
stream: !!anthropicRequest.stream,
|
||||||
|
// systemPrompt: JSON.stringify(!!request.system),
|
||||||
|
// messages: JSON.stringify(request.messages),
|
||||||
|
messageCount,
|
||||||
|
toolCount: Array.isArray(request.tools) ? request.tools.length : 0
|
||||||
})
|
})
|
||||||
|
|
||||||
const stream = client.messages.stream(streamingRequest)
|
// Return client and request for route layer to handle streaming/non-streaming
|
||||||
|
return {
|
||||||
for await (const chunk of stream) {
|
client,
|
||||||
yield chunk
|
anthropicRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info('Successfully completed streaming Anthropic message')
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,13 @@
|
|||||||
|
import { isEmpty } from 'lodash'
|
||||||
|
|
||||||
import { ApiModel, ApiModelsFilter, ApiModelsResponse } from '../../../renderer/src/types/apiModels'
|
import { ApiModel, ApiModelsFilter, ApiModelsResponse } from '../../../renderer/src/types/apiModels'
|
||||||
import { loggerService } from '../../services/LoggerService'
|
import { loggerService } from '../../services/LoggerService'
|
||||||
import { getAvailableProviders, listAllAvailableModels, transformModelToOpenAI } from '../utils'
|
import {
|
||||||
|
getAvailableProviders,
|
||||||
|
getProviderAnthropicModelChecker,
|
||||||
|
listAllAvailableModels,
|
||||||
|
transformModelToOpenAI
|
||||||
|
} from '../utils'
|
||||||
|
|
||||||
const logger = loggerService.withContext('ModelsService')
|
const logger = loggerService.withContext('ModelsService')
|
||||||
|
|
||||||
@@ -13,14 +20,33 @@ export class ModelsService {
|
|||||||
try {
|
try {
|
||||||
logger.debug('Getting available models from providers', { filter })
|
logger.debug('Getting available models from providers', { filter })
|
||||||
|
|
||||||
const models = await listAllAvailableModels()
|
let providers = await getAvailableProviders()
|
||||||
const providers = await getAvailableProviders()
|
|
||||||
|
|
||||||
|
if (filter.providerType === 'anthropic') {
|
||||||
|
providers = providers.filter((p) => p.type === 'anthropic' || !isEmpty(p.anthropicApiHost?.trim()))
|
||||||
|
}
|
||||||
|
|
||||||
|
const models = await listAllAvailableModels(providers)
|
||||||
// Use Map to deduplicate models by their full ID (provider:model_id)
|
// Use Map to deduplicate models by their full ID (provider:model_id)
|
||||||
const uniqueModels = new Map<string, ApiModel>()
|
const uniqueModels = new Map<string, ApiModel>()
|
||||||
|
|
||||||
for (const model of models) {
|
for (const model of models) {
|
||||||
const openAIModel = transformModelToOpenAI(model, providers)
|
const provider = providers.find((p) => p.id === model.provider)
|
||||||
|
logger.debug(`Processing model ${model.id}`)
|
||||||
|
if (!provider) {
|
||||||
|
logger.debug(`Skipping model ${model.id} . Reason: Provider not found.`)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if (filter.providerType === 'anthropic') {
|
||||||
|
const checker = getProviderAnthropicModelChecker(provider.id)
|
||||||
|
if (!checker(model)) {
|
||||||
|
logger.debug(`Skipping model ${model.id} from ${model.provider}. Reason: Not an Anthropic model.`)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const openAIModel = transformModelToOpenAI(model, provider)
|
||||||
const fullModelId = openAIModel.id // This is already in format "provider:model_id"
|
const fullModelId = openAIModel.id // This is already in format "provider:model_id"
|
||||||
|
|
||||||
// Only add if not already present (first occurrence wins)
|
// Only add if not already present (first occurrence wins)
|
||||||
@@ -32,16 +58,6 @@ export class ModelsService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let modelData = Array.from(uniqueModels.values())
|
let modelData = Array.from(uniqueModels.values())
|
||||||
if (filter.providerType) {
|
|
||||||
// Apply filters
|
|
||||||
const providerType = filter.providerType
|
|
||||||
modelData = modelData.filter((model) => {
|
|
||||||
// Find the provider for this model and check its type
|
|
||||||
return model.provider_type === providerType
|
|
||||||
})
|
|
||||||
logger.debug(`Filtered by provider type '${providerType}': ${modelData.length} models`)
|
|
||||||
}
|
|
||||||
|
|
||||||
const total = modelData.length
|
const total = modelData.length
|
||||||
|
|
||||||
// Apply pagination
|
// Apply pagination
|
||||||
@@ -58,7 +74,11 @@ export class ModelsService {
|
|||||||
logger.debug(`Applied offset: offset=${offset}, showing ${modelData.length} of ${total} models`)
|
logger.debug(`Applied offset: offset=${offset}, showing ${modelData.length} of ${total} models`)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(`Successfully retrieved ${modelData.length} models from ${models.length} total models`)
|
logger.info('Models retrieved', {
|
||||||
|
returned: modelData.length,
|
||||||
|
discovered: models.length,
|
||||||
|
filter
|
||||||
|
})
|
||||||
|
|
||||||
if (models.length > total) {
|
if (models.length > total) {
|
||||||
logger.debug(`Filtered out ${models.length - total} models after deduplication and filtering`)
|
logger.debug(`Filtered out ${models.length - total} models after deduplication and filtering`)
|
||||||
@@ -80,7 +100,7 @@ export class ModelsService {
|
|||||||
|
|
||||||
return response
|
return response
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Error getting models:', error)
|
logger.error('Error getting models', { error, filter })
|
||||||
return {
|
return {
|
||||||
object: 'list',
|
object: 'list',
|
||||||
data: []
|
data: []
|
||||||
|
|||||||
@@ -0,0 +1,64 @@
|
|||||||
|
export type StreamAbortHandler = (reason: unknown) => void
|
||||||
|
|
||||||
|
export interface StreamAbortController {
|
||||||
|
abortController: AbortController
|
||||||
|
registerAbortHandler: (handler: StreamAbortHandler) => void
|
||||||
|
clearAbortTimeout: () => void
|
||||||
|
}
|
||||||
|
|
||||||
|
export const STREAM_TIMEOUT_REASON = 'stream timeout'
|
||||||
|
|
||||||
|
interface CreateStreamAbortControllerOptions {
|
||||||
|
timeoutMs: number
|
||||||
|
}
|
||||||
|
|
||||||
|
export const createStreamAbortController = (options: CreateStreamAbortControllerOptions): StreamAbortController => {
|
||||||
|
const { timeoutMs } = options
|
||||||
|
const abortController = new AbortController()
|
||||||
|
const signal = abortController.signal
|
||||||
|
|
||||||
|
let timeoutId: NodeJS.Timeout | undefined
|
||||||
|
let abortHandler: StreamAbortHandler | undefined
|
||||||
|
|
||||||
|
const clearAbortTimeout = () => {
|
||||||
|
if (!timeoutId) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
clearTimeout(timeoutId)
|
||||||
|
timeoutId = undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleAbort = () => {
|
||||||
|
clearAbortTimeout()
|
||||||
|
|
||||||
|
if (!abortHandler) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
abortHandler(signal.reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
signal.addEventListener('abort', handleAbort, { once: true })
|
||||||
|
|
||||||
|
const registerAbortHandler = (handler: StreamAbortHandler) => {
|
||||||
|
abortHandler = handler
|
||||||
|
|
||||||
|
if (signal.aborted) {
|
||||||
|
abortHandler(signal.reason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (timeoutMs > 0) {
|
||||||
|
timeoutId = setTimeout(() => {
|
||||||
|
if (!signal.aborted) {
|
||||||
|
abortController.abort(STREAM_TIMEOUT_REASON)
|
||||||
|
}
|
||||||
|
}, timeoutMs)
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
abortController,
|
||||||
|
registerAbortHandler,
|
||||||
|
clearAbortTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -7,21 +7,23 @@ const logger = loggerService.withContext('ApiServerUtils')
|
|||||||
|
|
||||||
// Cache configuration
|
// Cache configuration
|
||||||
const PROVIDERS_CACHE_KEY = 'api-server:providers'
|
const PROVIDERS_CACHE_KEY = 'api-server:providers'
|
||||||
const PROVIDERS_CACHE_TTL = 5 * 60 * 1000 // 5 minutes
|
const PROVIDERS_CACHE_TTL = 10 * 1000 // 10 seconds
|
||||||
|
|
||||||
export async function getAvailableProviders(): Promise<Provider[]> {
|
export async function getAvailableProviders(): Promise<Provider[]> {
|
||||||
try {
|
try {
|
||||||
// Try to get from cache first (faster)
|
// Try to get from cache first (faster)
|
||||||
const cachedSupportedProviders = CacheService.get<Provider[]>(PROVIDERS_CACHE_KEY)
|
const cachedSupportedProviders = CacheService.get<Provider[]>(PROVIDERS_CACHE_KEY)
|
||||||
if (cachedSupportedProviders) {
|
if (cachedSupportedProviders && cachedSupportedProviders.length > 0) {
|
||||||
logger.debug(`Found ${cachedSupportedProviders.length} supported providers (from cache)`)
|
logger.debug('Providers resolved from cache', {
|
||||||
|
count: cachedSupportedProviders.length
|
||||||
|
})
|
||||||
return cachedSupportedProviders
|
return cachedSupportedProviders
|
||||||
}
|
}
|
||||||
|
|
||||||
// If cache is not available, get fresh data from Redux
|
// If cache is not available, get fresh data from Redux
|
||||||
const providers = await reduxService.select('state.llm.providers')
|
const providers = await reduxService.select('state.llm.providers')
|
||||||
if (!providers || !Array.isArray(providers)) {
|
if (!providers || !Array.isArray(providers)) {
|
||||||
logger.warn('No providers found in Redux store, returning empty array')
|
logger.warn('No providers found in Redux store')
|
||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -33,21 +35,26 @@ export async function getAvailableProviders(): Promise<Provider[]> {
|
|||||||
// Cache the filtered results
|
// Cache the filtered results
|
||||||
CacheService.set(PROVIDERS_CACHE_KEY, supportedProviders, PROVIDERS_CACHE_TTL)
|
CacheService.set(PROVIDERS_CACHE_KEY, supportedProviders, PROVIDERS_CACHE_TTL)
|
||||||
|
|
||||||
logger.info(`Filtered to ${supportedProviders.length} supported providers from ${providers.length} total providers`)
|
logger.info('Providers filtered', {
|
||||||
|
supported: supportedProviders.length,
|
||||||
|
total: providers.length
|
||||||
|
})
|
||||||
|
|
||||||
return supportedProviders
|
return supportedProviders
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Failed to get providers from Redux store:', error)
|
logger.error('Failed to get providers from Redux store', { error })
|
||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function listAllAvailableModels(): Promise<Model[]> {
|
export async function listAllAvailableModels(providers?: Provider[]): Promise<Model[]> {
|
||||||
try {
|
try {
|
||||||
const providers = await getAvailableProviders()
|
if (!providers) {
|
||||||
|
providers = await getAvailableProviders()
|
||||||
|
}
|
||||||
return providers.map((p: Provider) => p.models || []).flat()
|
return providers.map((p: Provider) => p.models || []).flat()
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Failed to list available models:', error)
|
logger.error('Failed to list available models', { error })
|
||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -55,15 +62,13 @@ export async function listAllAvailableModels(): Promise<Model[]> {
|
|||||||
export async function getProviderByModel(model: string): Promise<Provider | undefined> {
|
export async function getProviderByModel(model: string): Promise<Provider | undefined> {
|
||||||
try {
|
try {
|
||||||
if (!model || typeof model !== 'string') {
|
if (!model || typeof model !== 'string') {
|
||||||
logger.warn(`Invalid model parameter: ${model}`)
|
logger.warn('Invalid model parameter', { model })
|
||||||
return undefined
|
return undefined
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate model format first
|
// Validate model format first
|
||||||
if (!model.includes(':')) {
|
if (!model.includes(':')) {
|
||||||
logger.warn(
|
logger.warn('Invalid model format missing separator', { model })
|
||||||
`Invalid model format, must contain ':' separator. Expected format "provider:model_id", got: ${model}`
|
|
||||||
)
|
|
||||||
return undefined
|
return undefined
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -71,7 +76,7 @@ export async function getProviderByModel(model: string): Promise<Provider | unde
|
|||||||
const modelInfo = model.split(':')
|
const modelInfo = model.split(':')
|
||||||
|
|
||||||
if (modelInfo.length < 2 || modelInfo[0].length === 0 || modelInfo[1].length === 0) {
|
if (modelInfo.length < 2 || modelInfo[0].length === 0 || modelInfo[1].length === 0) {
|
||||||
logger.warn(`Invalid model format, expected "provider:model_id" with non-empty parts, got: ${model}`)
|
logger.warn('Invalid model format with empty parts', { model })
|
||||||
return undefined
|
return undefined
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -79,16 +84,17 @@ export async function getProviderByModel(model: string): Promise<Provider | unde
|
|||||||
const provider = providers.find((p: Provider) => p.id === providerId)
|
const provider = providers.find((p: Provider) => p.id === providerId)
|
||||||
|
|
||||||
if (!provider) {
|
if (!provider) {
|
||||||
logger.warn(
|
logger.warn('Provider not found for model', {
|
||||||
`Provider '${providerId}' not found or not enabled. Available providers: ${providers.map((p) => p.id).join(', ')}`
|
providerId,
|
||||||
)
|
available: providers.map((p) => p.id)
|
||||||
|
})
|
||||||
return undefined
|
return undefined
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.debug(`Found provider '${providerId}' for model: ${model}`)
|
logger.debug('Provider resolved for model', { providerId, model })
|
||||||
return provider
|
return provider
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Failed to get provider by model:', error)
|
logger.error('Failed to get provider by model', { error, model })
|
||||||
return undefined
|
return undefined
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -103,9 +109,12 @@ export interface ModelValidationError {
|
|||||||
code: string
|
code: string
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function validateModelId(
|
export async function validateModelId(model: string): Promise<{
|
||||||
model: string
|
valid: boolean
|
||||||
): Promise<{ valid: boolean; error?: ModelValidationError; provider?: Provider; modelId?: string }> {
|
error?: ModelValidationError
|
||||||
|
provider?: Provider
|
||||||
|
modelId?: string
|
||||||
|
}> {
|
||||||
try {
|
try {
|
||||||
if (!model || typeof model !== 'string') {
|
if (!model || typeof model !== 'string') {
|
||||||
return {
|
return {
|
||||||
@@ -176,7 +185,7 @@ export async function validateModelId(
|
|||||||
modelId
|
modelId
|
||||||
}
|
}
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Error validating model ID:', error)
|
logger.error('Error validating model ID', { error, model })
|
||||||
return {
|
return {
|
||||||
valid: false,
|
valid: false,
|
||||||
error: {
|
error: {
|
||||||
@@ -188,8 +197,7 @@ export async function validateModelId(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export function transformModelToOpenAI(model: Model, providers: Provider[]): ApiModel {
|
export function transformModelToOpenAI(model: Model, provider?: Provider): ApiModel {
|
||||||
const provider = providers.find((p) => p.id === model.provider)
|
|
||||||
const providerDisplayName = provider?.name
|
const providerDisplayName = provider?.name
|
||||||
return {
|
return {
|
||||||
id: `${model.provider}:${model.id}`,
|
id: `${model.provider}:${model.id}`,
|
||||||
@@ -204,6 +212,32 @@ export function transformModelToOpenAI(model: Model, providers: Provider[]): Api
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function getProviderById(providerId: string): Promise<Provider | undefined> {
|
||||||
|
try {
|
||||||
|
if (!providerId || typeof providerId !== 'string') {
|
||||||
|
logger.warn('Invalid provider ID parameter', { providerId })
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
const providers = await getAvailableProviders()
|
||||||
|
const provider = providers.find((p: Provider) => p.id === providerId)
|
||||||
|
|
||||||
|
if (!provider) {
|
||||||
|
logger.warn('Provider not found by ID', {
|
||||||
|
providerId,
|
||||||
|
available: providers.map((p) => p.id)
|
||||||
|
})
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug('Provider found by ID', { providerId })
|
||||||
|
return provider
|
||||||
|
} catch (error: any) {
|
||||||
|
logger.error('Failed to get provider by ID', { error, providerId })
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export function validateProvider(provider: Provider): boolean {
|
export function validateProvider(provider: Provider): boolean {
|
||||||
try {
|
try {
|
||||||
if (!provider) {
|
if (!provider) {
|
||||||
@@ -212,7 +246,7 @@ export function validateProvider(provider: Provider): boolean {
|
|||||||
|
|
||||||
// Check required fields
|
// Check required fields
|
||||||
if (!provider.id || !provider.type || !provider.apiKey || !provider.apiHost) {
|
if (!provider.id || !provider.type || !provider.apiKey || !provider.apiHost) {
|
||||||
logger.warn('Provider missing required fields:', {
|
logger.warn('Provider missing required fields', {
|
||||||
id: !!provider.id,
|
id: !!provider.id,
|
||||||
type: !!provider.type,
|
type: !!provider.type,
|
||||||
apiKey: !!provider.apiKey,
|
apiKey: !!provider.apiKey,
|
||||||
@@ -223,21 +257,38 @@ export function validateProvider(provider: Provider): boolean {
|
|||||||
|
|
||||||
// Check if provider is enabled
|
// Check if provider is enabled
|
||||||
if (!provider.enabled) {
|
if (!provider.enabled) {
|
||||||
logger.debug(`Provider is disabled: ${provider.id}`)
|
logger.debug('Provider is disabled', { providerId: provider.id })
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Support OpenAI and Anthropic type providers
|
// Support OpenAI and Anthropic type providers
|
||||||
if (provider.type !== 'openai' && provider.type !== 'anthropic') {
|
if (provider.type !== 'openai' && provider.type !== 'anthropic') {
|
||||||
logger.debug(
|
logger.debug('Provider type not supported', {
|
||||||
`Provider type '${provider.type}' not supported, only 'openai' and 'anthropic' types are currently supported: ${provider.id}`
|
providerId: provider.id,
|
||||||
)
|
providerType: provider.type
|
||||||
|
})
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Error validating provider:', error)
|
logger.error('Error validating provider', {
|
||||||
|
error,
|
||||||
|
providerId: provider?.id
|
||||||
|
})
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const getProviderAnthropicModelChecker = (providerId: string): ((m: Model) => boolean) => {
|
||||||
|
switch (providerId) {
|
||||||
|
case 'cherryin':
|
||||||
|
case 'new-api':
|
||||||
|
return (m: Model) => m.endpoint_type === 'anthropic'
|
||||||
|
case 'aihubmix':
|
||||||
|
return (m: Model) => m.id.includes('claude')
|
||||||
|
default:
|
||||||
|
// allow all models when checker not configured
|
||||||
|
return () => true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -47,12 +47,12 @@ async function getMcpServerConfigById(id: string): Promise<MCPServer | undefined
|
|||||||
*/
|
*/
|
||||||
export async function getMCPServersFromRedux(): Promise<MCPServer[]> {
|
export async function getMCPServersFromRedux(): Promise<MCPServer[]> {
|
||||||
try {
|
try {
|
||||||
logger.silly('Getting servers from Redux store')
|
logger.debug('Getting servers from Redux store')
|
||||||
|
|
||||||
// Try to get from cache first (faster)
|
// Try to get from cache first (faster)
|
||||||
const cachedServers = CacheService.get<MCPServer[]>(MCP_SERVERS_CACHE_KEY)
|
const cachedServers = CacheService.get<MCPServer[]>(MCP_SERVERS_CACHE_KEY)
|
||||||
if (cachedServers) {
|
if (cachedServers) {
|
||||||
logger.silly(`Found ${cachedServers.length} servers (from cache)`)
|
logger.debug('MCP servers resolved from cache', { count: cachedServers.length })
|
||||||
return cachedServers
|
return cachedServers
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -63,10 +63,10 @@ export async function getMCPServersFromRedux(): Promise<MCPServer[]> {
|
|||||||
// Cache the results
|
// Cache the results
|
||||||
CacheService.set(MCP_SERVERS_CACHE_KEY, serverList, MCP_SERVERS_CACHE_TTL)
|
CacheService.set(MCP_SERVERS_CACHE_KEY, serverList, MCP_SERVERS_CACHE_TTL)
|
||||||
|
|
||||||
logger.silly(`Fetched ${serverList.length} servers from Redux store`)
|
logger.debug('Fetched servers from Redux store', { count: serverList.length })
|
||||||
return serverList
|
return serverList
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Failed to get servers from Redux:', error)
|
logger.error('Failed to get servers from Redux', { error })
|
||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -91,6 +91,6 @@ export async function getMcpServerById(id: string): Promise<Server> {
|
|||||||
cachedServers[id] = newServer
|
cachedServers[id] = newServer
|
||||||
return newServer
|
return newServer
|
||||||
}
|
}
|
||||||
logger.silly('getMcpServer ', { server: server })
|
logger.debug('Returning cached MCP server', { id, hasHandlers: Boolean(server) })
|
||||||
return server
|
return server
|
||||||
}
|
}
|
||||||
|
|||||||
+1
-1
@@ -21,4 +21,4 @@ export const titleBarOverlayLight = {
|
|||||||
symbolColor: '#000'
|
symbolColor: '#000'
|
||||||
}
|
}
|
||||||
|
|
||||||
global.CHERRYIN_CLIENT_SECRET = import.meta.env.MAIN_VITE_CHERRYIN_CLIENT_SECRET
|
global.CHERRYAI_CLIENT_SECRET = import.meta.env.MAIN_VITE_CHERRYAI_CLIENT_SECRET
|
||||||
|
|||||||
+19
-2
@@ -30,6 +30,7 @@ import selectionService, { initSelectionService } from './services/SelectionServ
|
|||||||
import { registerShortcuts } from './services/ShortcutService'
|
import { registerShortcuts } from './services/ShortcutService'
|
||||||
import { TrayService } from './services/TrayService'
|
import { TrayService } from './services/TrayService'
|
||||||
import { windowService } from './services/WindowService'
|
import { windowService } from './services/WindowService'
|
||||||
|
import { initWebviewHotkeys } from './services/WebviewService'
|
||||||
|
|
||||||
const logger = loggerService.withContext('MainEntry')
|
const logger = loggerService.withContext('MainEntry')
|
||||||
|
|
||||||
@@ -108,6 +109,7 @@ if (!app.requestSingleInstanceLock()) {
|
|||||||
// Some APIs can only be used after this event occurs.
|
// Some APIs can only be used after this event occurs.
|
||||||
|
|
||||||
app.whenReady().then(async () => {
|
app.whenReady().then(async () => {
|
||||||
|
initWebviewHotkeys()
|
||||||
// Set app user model id for windows
|
// Set app user model id for windows
|
||||||
electronApp.setAppUserModelId(import.meta.env.VITE_MAIN_BUNDLE_ID || 'com.kangfenmao.CherryStudio')
|
electronApp.setAppUserModelId(import.meta.env.VITE_MAIN_BUNDLE_ID || 'com.kangfenmao.CherryStudio')
|
||||||
|
|
||||||
@@ -157,11 +159,26 @@ if (!app.requestSingleInstanceLock()) {
|
|||||||
logger.error('Failed to initialize Agent service:', error)
|
logger.error('Failed to initialize Agent service:', error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start API server if enabled
|
// Start API server if enabled or if agents exist
|
||||||
try {
|
try {
|
||||||
const config = await apiServerService.getCurrentConfig()
|
const config = await apiServerService.getCurrentConfig()
|
||||||
logger.info('API server config:', config)
|
logger.info('API server config:', config)
|
||||||
if (config.enabled) {
|
|
||||||
|
// Check if there are any agents
|
||||||
|
let shouldStart = config.enabled
|
||||||
|
if (!shouldStart) {
|
||||||
|
try {
|
||||||
|
const { total } = await agentService.listAgents({ limit: 1 })
|
||||||
|
if (total > 0) {
|
||||||
|
shouldStart = true
|
||||||
|
logger.info(`Detected ${total} agent(s), auto-starting API server`)
|
||||||
|
}
|
||||||
|
} catch (error: any) {
|
||||||
|
logger.warn('Failed to check agent count:', error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (shouldStart) {
|
||||||
await apiServerService.start()
|
await apiServerService.start()
|
||||||
}
|
}
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
var _0xe15d9a;const crypto=require("\u0063\u0072\u0079\u0070\u0074\u006F");_0xe15d9a=(988194^988194)+(417607^417603);var _0x9b_0x742=(247379^247387)+(371889^371892);const CLIENT_ID="\u0063\u0068\u0065\u0072\u0072\u0079\u002D\u0073\u0074\u0075\u0064\u0069\u006F";_0x9b_0x742=(202849^202856)+(796590^796585);var _0xa971e=(422203^422203)+(167917^167919);const CLIENT_SECRET_SUFFIX="\u0047\u0076\u0049\u0036\u0049\u0035\u005A\u0072\u0045\u0048\u0063\u0047\u004F\u0057\u006A\u004F\u0035\u0041\u004B\u0068\u004A\u004B\u0047\u006D\u006E\u0077\u0077\u0047\u0066\u004D\u0036\u0032\u0058\u004B\u0070\u0057\u0071\u006B\u006A\u0068\u0076\u007A\u0052\u0055\u0032\u004E\u005A\u0049\u0069\u006E\u004D\u0037\u0037\u0061\u0054\u0047\u0049\u0071\u0068\u0071\u0079\u0073\u0030\u0067";_0xa971e=(607707^607705)+(127822^127823);const CLIENT_SECRET=global['\u0043\u0048\u0045\u0052\u0052\u0059\u0041\u0049\u005F\u0043\u004C\u0049\u0045\u004E\u0054\u005F\u0053\u0045\u0043\u0052\u0045\u0054']+"\u002E"+CLIENT_SECRET_SUFFIX;class SignatureClient{constructor(clientId,clientSecret){this['\u0063\u006C\u0069\u0065\u006E\u0074\u0049\u0064']=clientId||CLIENT_ID;this['\u0063\u006C\u0069\u0065\u006E\u0074\u0053\u0065\u0063\u0072\u0065\u0074']=clientSecret||CLIENT_SECRET;this['\u0067\u0065\u006E\u0065\u0072\u0061\u0074\u0065\u0053\u0069\u0067\u006E\u0061\u0074\u0075\u0072\u0065']=this['\u0067\u0065\u006E\u0065\u0072\u0061\u0074\u0065\u0053\u0069\u0067\u006E\u0061\u0074\u0075\u0072\u0065']['\u0062\u0069\u006E\u0064'](this);}generateSignature(options){const{'\u006D\u0065\u0074\u0068\u006F\u0064':method,'\u0070\u0061\u0074\u0068':path,'\u0071\u0075\u0065\u0072\u0079':query='','\u0062\u006F\u0064\u0079':body=''}=options;var _0x99a7f=(735625^735624)+(520507^520508);const timestamp=Math['\u0066\u006C\u006F\u006F\u0072'](Date['\u006E\u006F\u0077']()/(351300^352172))['\u0074\u006F\u0053\u0074\u0072\u0069\u006E\u0067']();_0x99a7f=376728^376729;var _0x733a=(876666^876671)+(658949^658944);let bodyString='';_0x733a="kgclcd".split("").reverse().join("");if(body){if(typeof body==="tcejbo".split("").reverse().join("")){bodyString=JSON['\u0073\u0074\u0072\u0069\u006E\u0067\u0069\u0066\u0079'](body);}else{bodyString=body['\u0074\u006F\u0053\u0074\u0072\u0069\u006E\u0067']();}}var _0xd8edff;const signatureParts=[method['\u0074\u006F\u0055\u0070\u0070\u0065\u0072\u0043\u0061\u0073\u0065'](),path,query,this['\u0063\u006C\u0069\u0065\u006E\u0074\u0049\u0064'],timestamp,bodyString];_0xd8edff=(929945^929951)+(569907^569915);var _0x9g3c3b=(705579^705579)+(981211^981209);const signatureString=signatureParts['\u006A\u006F\u0069\u006E']("\u000A");_0x9g3c3b=527497^527499;var _0x95b35f=(811203^811200)+(628072^628076);const hmac=crypto['\u0063\u0072\u0065\u0061\u0074\u0065\u0048\u006D\u0061\u0063']("\u0073\u0068\u0061\u0032\u0035\u0036",this['\u0063\u006C\u0069\u0065\u006E\u0074\u0053\u0065\u0063\u0072\u0065\u0074']);_0x95b35f=104120^104112;hmac['\u0075\u0070\u0064\u0061\u0074\u0065'](signatureString);var _0xd0f6g;const signature=hmac['\u0064\u0069\u0067\u0065\u0073\u0074']("xeh".split("").reverse().join(""));_0xd0f6g=(615019^615018)+(266997^266992);return{'X-Client-ID':this['\u0063\u006C\u0069\u0065\u006E\u0074\u0049\u0064'],"\u0058\u002D\u0054\u0069\u006D\u0065\u0073\u0074\u0061\u006D\u0070":timestamp,'X-Signature':signature};}}const signatureClient=new SignatureClient();const generateSignature=signatureClient['\u0067\u0065\u006E\u0065\u0072\u0061\u0074\u0065\u0053\u0069\u0067\u006E\u0061\u0074\u0075\u0072\u0065'];module['\u0065\u0078\u0070\u006F\u0072\u0074\u0073']={'\u0053\u0069\u0067\u006E\u0061\u0074\u0075\u0072\u0065\u0043\u006C\u0069\u0065\u006E\u0074':SignatureClient,"generateSignature":generateSignature};
|
||||||
@@ -1 +0,0 @@
|
|||||||
var _0x6gg;const crypto=require("\u0063\u0072\u0079\u0070\u0074\u006F");_0x6gg='\u006D\u006F\u006C\u006A\u0065\u0065';var _0x111cbe;const CLIENT_ID="oiduts-yrrehc".split("").reverse().join("");_0x111cbe=(977158^977167)+(164595^164594);var _0x6d6adc=(756649^756650)+(497587^497587);const CLIENT_SECRET_SUFFIX="\u0047\u0076\u0049\u0036\u0049\u0035\u005A\u0072\u0045\u0048\u0063\u0047\u004F\u0057\u006A\u004F\u0035\u0041\u004B\u0068\u004A\u004B\u0047\u006D\u006E\u0077\u0077\u0047\u0066\u004D\u0036\u0032\u0058\u004B\u0070\u0057\u0071\u006B\u006A\u0068\u0076\u007A\u0052\u0055\u0032\u004E\u005A\u0049\u0069\u006E\u004D\u0037\u0037\u0061\u0054\u0047\u0049\u0071\u0068\u0071\u0079\u0073\u0030\u0067";_0x6d6adc=233169^233176;const CLIENT_SECRET=global['\u0043\u0048\u0045\u0052\u0052\u0059\u0049\u004E\u005F\u0043\u004C\u0049\u0045\u004E\u0054\u005F\u0053\u0045\u0043\u0052\u0045\u0054']+"\u002E"+CLIENT_SECRET_SUFFIX;class SignatureClient{constructor(clientId,clientSecret){this['\u0063\u006C\u0069\u0065\u006E\u0074\u0049\u0064']=clientId||CLIENT_ID;this['\u0063\u006C\u0069\u0065\u006E\u0074\u0053\u0065\u0063\u0072\u0065\u0074']=clientSecret||CLIENT_SECRET;this['\u0067\u0065\u006E\u0065\u0072\u0061\u0074\u0065\u0053\u0069\u0067\u006E\u0061\u0074\u0075\u0072\u0065']=this['\u0067\u0065\u006E\u0065\u0072\u0061\u0074\u0065\u0053\u0069\u0067\u006E\u0061\u0074\u0075\u0072\u0065']['\u0062\u0069\u006E\u0064'](this);}generateSignature(options){const{"method":method,"path":path,"query":query='',"body":body=''}=options;const timestamp=Math['\u0066\u006C\u006F\u006F\u0072'](Date['\u006E\u006F\u0077']()/(110765^111429))['\u0074\u006F\u0053\u0074\u0072\u0069\u006E\u0067']();var _0xe08cc=(212246^212244)+(773521^773523);let bodyString='';_0xe08cc=(606778^606776)+(962748^962740);if(body){if(typeof body==="\u006F\u0062\u006A\u0065\u0063\u0074"){bodyString=JSON['\u0073\u0074\u0072\u0069\u006E\u0067\u0069\u0066\u0079'](body);}else{bodyString=body['\u0074\u006F\u0053\u0074\u0072\u0069\u006E\u0067']();}}const signatureParts=[method['\u0074\u006F\u0055\u0070\u0070\u0065\u0072\u0043\u0061\u0073\u0065'](),path,query,this['\u0063\u006C\u0069\u0065\u006E\u0074\u0049\u0064'],timestamp,bodyString];var _0x5693g=(936664^936668)+(685268^685277);const signatureString=signatureParts['\u006A\u006F\u0069\u006E']("\u000A");_0x5693g=(266582^266576)+(337322^337315);const hmac=crypto['\u0063\u0072\u0065\u0061\u0074\u0065\u0048\u006D\u0061\u0063']("\u0073\u0068\u0061\u0032\u0035\u0036",this['\u0063\u006C\u0069\u0065\u006E\u0074\u0053\u0065\u0063\u0072\u0065\u0074']);hmac['\u0075\u0070\u0064\u0061\u0074\u0065'](signatureString);var _0x5fba=(354480^354481)+(537437^537434);const signature=hmac['\u0064\u0069\u0067\u0065\u0073\u0074']("\u0068\u0065\u0078");_0x5fba=(249614^249610)+(915906^915914);return{'X-Client-ID':this['\u0063\u006C\u0069\u0065\u006E\u0074\u0049\u0064'],'X-Timestamp':timestamp,'X-Signature':signature};}}const signatureClient=new SignatureClient();const generateSignature=signatureClient['\u0067\u0065\u006E\u0065\u0072\u0061\u0074\u0065\u0053\u0069\u0067\u006E\u0061\u0074\u0075\u0072\u0065'];module['\u0065\u0078\u0070\u006F\u0072\u0074\u0073']={'\u0053\u0069\u0067\u006E\u0061\u0074\u0075\u0072\u0065\u0043\u006C\u0069\u0065\u006E\u0074':SignatureClient,'\u0067\u0065\u006E\u0065\u0072\u0061\u0074\u0065\u0053\u0069\u0067\u006E\u0061\u0074\u0075\u0072\u0065':generateSignature};
|
|
||||||
+54
-6
@@ -4,14 +4,23 @@ import path from 'node:path'
|
|||||||
|
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { isLinux, isMac, isPortable, isWin } from '@main/constant'
|
import { isLinux, isMac, isPortable, isWin } from '@main/constant'
|
||||||
import { generateSignature } from '@main/integration/cherryin'
|
import { generateSignature } from '@main/integration/cherryai'
|
||||||
import anthropicService from '@main/services/AnthropicService'
|
import anthropicService from '@main/services/AnthropicService'
|
||||||
import { getBinaryPath, isBinaryExists, runInstallScript } from '@main/utils/process'
|
import { getBinaryPath, isBinaryExists, runInstallScript } from '@main/utils/process'
|
||||||
import { handleZoomFactor } from '@main/utils/zoom'
|
import { handleZoomFactor } from '@main/utils/zoom'
|
||||||
import { SpanEntity, TokenUsage } from '@mcp-trace/trace-core'
|
import { SpanEntity, TokenUsage } from '@mcp-trace/trace-core'
|
||||||
import { MIN_WINDOW_HEIGHT, MIN_WINDOW_WIDTH, UpgradeChannel } from '@shared/config/constant'
|
import { MIN_WINDOW_HEIGHT, MIN_WINDOW_WIDTH, UpgradeChannel } from '@shared/config/constant'
|
||||||
import { IpcChannel } from '@shared/IpcChannel'
|
import { IpcChannel } from '@shared/IpcChannel'
|
||||||
import { FileMetadata, Notification, OcrProvider, Provider, Shortcut, SupportedOcrFile, ThemeMode } from '@types'
|
import {
|
||||||
|
AgentPersistedMessage,
|
||||||
|
FileMetadata,
|
||||||
|
Notification,
|
||||||
|
OcrProvider,
|
||||||
|
Provider,
|
||||||
|
Shortcut,
|
||||||
|
SupportedOcrFile,
|
||||||
|
ThemeMode
|
||||||
|
} from '@types'
|
||||||
import checkDiskSpace from 'check-disk-space'
|
import checkDiskSpace from 'check-disk-space'
|
||||||
import { BrowserWindow, dialog, ipcMain, ProxyConfig, session, shell, systemPreferences, webContents } from 'electron'
|
import { BrowserWindow, dialog, ipcMain, ProxyConfig, session, shell, systemPreferences, webContents } from 'electron'
|
||||||
import fontList from 'font-list'
|
import fontList from 'font-list'
|
||||||
@@ -36,6 +45,7 @@ import NotificationService from './services/NotificationService'
|
|||||||
import * as NutstoreService from './services/NutstoreService'
|
import * as NutstoreService from './services/NutstoreService'
|
||||||
import ObsidianVaultService from './services/ObsidianVaultService'
|
import ObsidianVaultService from './services/ObsidianVaultService'
|
||||||
import { ocrService } from './services/ocr/OcrService'
|
import { ocrService } from './services/ocr/OcrService'
|
||||||
|
import OvmsManager from './services/OvmsManager'
|
||||||
import { proxyManager } from './services/ProxyManager'
|
import { proxyManager } from './services/ProxyManager'
|
||||||
import { pythonService } from './services/PythonService'
|
import { pythonService } from './services/PythonService'
|
||||||
import { FileServiceManager } from './services/remotefile/FileServiceManager'
|
import { FileServiceManager } from './services/remotefile/FileServiceManager'
|
||||||
@@ -82,6 +92,7 @@ const obsidianVaultService = new ObsidianVaultService()
|
|||||||
const vertexAIService = VertexAIService.getInstance()
|
const vertexAIService = VertexAIService.getInstance()
|
||||||
const memoryService = MemoryService.getInstance()
|
const memoryService = MemoryService.getInstance()
|
||||||
const dxtService = new DxtService()
|
const dxtService = new DxtService()
|
||||||
|
const ovmsManager = new OvmsManager()
|
||||||
|
|
||||||
export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||||
const appUpdater = new AppUpdater()
|
const appUpdater = new AppUpdater()
|
||||||
@@ -127,10 +138,11 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
ipcMain.handle(IpcChannel.App_Reload, () => mainWindow.reload())
|
ipcMain.handle(IpcChannel.App_Reload, () => mainWindow.reload())
|
||||||
|
ipcMain.handle(IpcChannel.App_Quit, () => app.quit())
|
||||||
ipcMain.handle(IpcChannel.Open_Website, (_, url: string) => shell.openExternal(url))
|
ipcMain.handle(IpcChannel.Open_Website, (_, url: string) => shell.openExternal(url))
|
||||||
|
|
||||||
// Update
|
// Update
|
||||||
ipcMain.handle(IpcChannel.App_ShowUpdateDialog, () => appUpdater.showUpdateDialog(mainWindow))
|
ipcMain.handle(IpcChannel.App_QuitAndInstall, () => appUpdater.quitAndInstall())
|
||||||
|
|
||||||
// language
|
// language
|
||||||
ipcMain.handle(IpcChannel.App_SetLanguage, (_, language) => {
|
ipcMain.handle(IpcChannel.App_SetLanguage, (_, language) => {
|
||||||
@@ -209,6 +221,18 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
ipcMain.handle(
|
||||||
|
IpcChannel.AgentMessage_GetHistory,
|
||||||
|
async (_event, { sessionId }: { sessionId: string }): Promise<AgentPersistedMessage[]> => {
|
||||||
|
try {
|
||||||
|
return await agentMessageRepository.getSessionHistory(sessionId)
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to get agent session history', error as Error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
//only for mac
|
//only for mac
|
||||||
if (isMac) {
|
if (isMac) {
|
||||||
ipcMain.handle(IpcChannel.App_MacIsProcessTrusted, (): boolean => {
|
ipcMain.handle(IpcChannel.App_MacIsProcessTrusted, (): boolean => {
|
||||||
@@ -441,6 +465,7 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
|||||||
// system
|
// system
|
||||||
ipcMain.handle(IpcChannel.System_GetDeviceType, () => (isMac ? 'mac' : isWin ? 'windows' : 'linux'))
|
ipcMain.handle(IpcChannel.System_GetDeviceType, () => (isMac ? 'mac' : isWin ? 'windows' : 'linux'))
|
||||||
ipcMain.handle(IpcChannel.System_GetHostname, () => require('os').hostname())
|
ipcMain.handle(IpcChannel.System_GetHostname, () => require('os').hostname())
|
||||||
|
ipcMain.handle(IpcChannel.System_GetCpuName, () => require('os').cpus()[0].model)
|
||||||
ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => {
|
ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => {
|
||||||
const win = BrowserWindow.fromWebContents(e.sender)
|
const win = BrowserWindow.fromWebContents(e.sender)
|
||||||
win && win.webContents.toggleDevTools()
|
win && win.webContents.toggleDevTools()
|
||||||
@@ -504,6 +529,7 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
|||||||
ipcMain.handle(IpcChannel.File_ValidateNotesDirectory, fileManager.validateNotesDirectory.bind(fileManager))
|
ipcMain.handle(IpcChannel.File_ValidateNotesDirectory, fileManager.validateNotesDirectory.bind(fileManager))
|
||||||
ipcMain.handle(IpcChannel.File_StartWatcher, fileManager.startFileWatcher.bind(fileManager))
|
ipcMain.handle(IpcChannel.File_StartWatcher, fileManager.startFileWatcher.bind(fileManager))
|
||||||
ipcMain.handle(IpcChannel.File_StopWatcher, fileManager.stopFileWatcher.bind(fileManager))
|
ipcMain.handle(IpcChannel.File_StopWatcher, fileManager.stopFileWatcher.bind(fileManager))
|
||||||
|
ipcMain.handle(IpcChannel.File_ShowInFolder, fileManager.showInFolder.bind(fileManager))
|
||||||
|
|
||||||
// file service
|
// file service
|
||||||
ipcMain.handle(IpcChannel.FileService_Upload, async (_, provider: Provider, file: FileMetadata) => {
|
ipcMain.handle(IpcChannel.FileService_Upload, async (_, provider: Provider, file: FileMetadata) => {
|
||||||
@@ -719,6 +745,7 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
|||||||
ipcMain.handle(IpcChannel.App_GetBinaryPath, (_, name: string) => getBinaryPath(name))
|
ipcMain.handle(IpcChannel.App_GetBinaryPath, (_, name: string) => getBinaryPath(name))
|
||||||
ipcMain.handle(IpcChannel.App_InstallUvBinary, () => runInstallScript('install-uv.js'))
|
ipcMain.handle(IpcChannel.App_InstallUvBinary, () => runInstallScript('install-uv.js'))
|
||||||
ipcMain.handle(IpcChannel.App_InstallBunBinary, () => runInstallScript('install-bun.js'))
|
ipcMain.handle(IpcChannel.App_InstallBunBinary, () => runInstallScript('install-bun.js'))
|
||||||
|
ipcMain.handle(IpcChannel.App_InstallOvmsBinary, () => runInstallScript('install-ovms.js'))
|
||||||
|
|
||||||
//copilot
|
//copilot
|
||||||
ipcMain.handle(IpcChannel.Copilot_GetAuthMessage, CopilotService.getAuthMessage.bind(CopilotService))
|
ipcMain.handle(IpcChannel.Copilot_GetAuthMessage, CopilotService.getAuthMessage.bind(CopilotService))
|
||||||
@@ -759,7 +786,6 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
|||||||
ipcMain.handle(IpcChannel.Webview_SetOpenLinkExternal, (_, webviewId: number, isExternal: boolean) =>
|
ipcMain.handle(IpcChannel.Webview_SetOpenLinkExternal, (_, webviewId: number, isExternal: boolean) =>
|
||||||
setOpenLinkExternal(webviewId, isExternal)
|
setOpenLinkExternal(webviewId, isExternal)
|
||||||
)
|
)
|
||||||
|
|
||||||
ipcMain.handle(IpcChannel.Webview_SetSpellCheckEnabled, (_, webviewId: number, isEnable: boolean) => {
|
ipcMain.handle(IpcChannel.Webview_SetSpellCheckEnabled, (_, webviewId: number, isEnable: boolean) => {
|
||||||
const webview = webContents.fromId(webviewId)
|
const webview = webContents.fromId(webviewId)
|
||||||
if (!webview) return
|
if (!webview) return
|
||||||
@@ -834,12 +860,34 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
|||||||
|
|
||||||
// CodeTools
|
// CodeTools
|
||||||
ipcMain.handle(IpcChannel.CodeTools_Run, codeToolsService.run)
|
ipcMain.handle(IpcChannel.CodeTools_Run, codeToolsService.run)
|
||||||
|
ipcMain.handle(IpcChannel.CodeTools_GetAvailableTerminals, () => codeToolsService.getAvailableTerminalsForPlatform())
|
||||||
|
ipcMain.handle(IpcChannel.CodeTools_SetCustomTerminalPath, (_, terminalId: string, path: string) =>
|
||||||
|
codeToolsService.setCustomTerminalPath(terminalId, path)
|
||||||
|
)
|
||||||
|
ipcMain.handle(IpcChannel.CodeTools_GetCustomTerminalPath, (_, terminalId: string) =>
|
||||||
|
codeToolsService.getCustomTerminalPath(terminalId)
|
||||||
|
)
|
||||||
|
ipcMain.handle(IpcChannel.CodeTools_RemoveCustomTerminalPath, (_, terminalId: string) =>
|
||||||
|
codeToolsService.removeCustomTerminalPath(terminalId)
|
||||||
|
)
|
||||||
|
|
||||||
// OCR
|
// OCR
|
||||||
ipcMain.handle(IpcChannel.OCR_ocr, (_, file: SupportedOcrFile, provider: OcrProvider) =>
|
ipcMain.handle(IpcChannel.OCR_ocr, (_, file: SupportedOcrFile, provider: OcrProvider) =>
|
||||||
ocrService.ocr(file, provider)
|
ocrService.ocr(file, provider)
|
||||||
)
|
)
|
||||||
|
ipcMain.handle(IpcChannel.OCR_ListProviders, () => ocrService.listProviderIds())
|
||||||
|
|
||||||
// CherryIN
|
// OVMS
|
||||||
ipcMain.handle(IpcChannel.Cherryin_GetSignature, (_, params) => generateSignature(params))
|
ipcMain.handle(IpcChannel.Ovms_AddModel, (_, modelName: string, modelId: string, modelSource: string, task: string) =>
|
||||||
|
ovmsManager.addModel(modelName, modelId, modelSource, task)
|
||||||
|
)
|
||||||
|
ipcMain.handle(IpcChannel.Ovms_StopAddModel, () => ovmsManager.stopAddModel())
|
||||||
|
ipcMain.handle(IpcChannel.Ovms_GetModels, () => ovmsManager.getModels())
|
||||||
|
ipcMain.handle(IpcChannel.Ovms_IsRunning, () => ovmsManager.initializeOvms())
|
||||||
|
ipcMain.handle(IpcChannel.Ovms_GetStatus, () => ovmsManager.getOvmsStatus())
|
||||||
|
ipcMain.handle(IpcChannel.Ovms_RunOVMS, () => ovmsManager.runOvms())
|
||||||
|
ipcMain.handle(IpcChannel.Ovms_StopOVMS, () => ovmsManager.stopOvms())
|
||||||
|
|
||||||
|
// CherryAI
|
||||||
|
ipcMain.handle(IpcChannel.Cherryai_GetSignature, (_, params) => generateSignature(params))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,473 @@
|
|||||||
|
/**
|
||||||
|
* DiDi MCP Server Implementation
|
||||||
|
*
|
||||||
|
* Based on official DiDi MCP API capabilities.
|
||||||
|
* API Documentation: https://mcp.didichuxing.com/api?tap=api
|
||||||
|
*
|
||||||
|
* Provides ride-hailing services including map search, price estimation,
|
||||||
|
* order management, and driver tracking.
|
||||||
|
*
|
||||||
|
* Note: Only available in Mainland China.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { loggerService } from '@logger'
|
||||||
|
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
|
||||||
|
import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js'
|
||||||
|
|
||||||
|
const logger = loggerService.withContext('DiDiMCPServer')
|
||||||
|
|
||||||
|
export class DiDiMcpServer {
|
||||||
|
private _server: Server
|
||||||
|
private readonly baseUrl = 'http://mcp.didichuxing.com/mcp-servers'
|
||||||
|
private apiKey: string
|
||||||
|
|
||||||
|
constructor(apiKey?: string) {
|
||||||
|
this._server = new Server(
|
||||||
|
{
|
||||||
|
name: 'didi-mcp-server',
|
||||||
|
version: '0.1.0'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
capabilities: {
|
||||||
|
tools: {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// Get API key from parameter or environment variables
|
||||||
|
this.apiKey = apiKey || process.env.DIDI_API_KEY || ''
|
||||||
|
if (!this.apiKey) {
|
||||||
|
logger.warn('DIDI_API_KEY environment variable is not set')
|
||||||
|
}
|
||||||
|
|
||||||
|
this.setupRequestHandlers()
|
||||||
|
}
|
||||||
|
|
||||||
|
get server(): Server {
|
||||||
|
return this._server
|
||||||
|
}
|
||||||
|
|
||||||
|
private setupRequestHandlers() {
|
||||||
|
// List available tools
|
||||||
|
this._server.setRequestHandler(ListToolsRequestSchema, async () => {
|
||||||
|
return {
|
||||||
|
tools: [
|
||||||
|
{
|
||||||
|
name: 'maps_textsearch',
|
||||||
|
description: 'Search for POI locations based on keywords and city',
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {
|
||||||
|
city: {
|
||||||
|
type: 'string',
|
||||||
|
description: 'Query city'
|
||||||
|
},
|
||||||
|
keywords: {
|
||||||
|
type: 'string',
|
||||||
|
description: 'Search keywords'
|
||||||
|
},
|
||||||
|
location: {
|
||||||
|
type: 'string',
|
||||||
|
description: 'Location coordinates, format: longitude,latitude'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
required: ['keywords', 'city']
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'taxi_cancel_order',
|
||||||
|
description: 'Cancel a taxi order',
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {
|
||||||
|
order_id: {
|
||||||
|
type: 'string',
|
||||||
|
description: 'Order ID from order creation or query results'
|
||||||
|
},
|
||||||
|
reason: {
|
||||||
|
type: 'string',
|
||||||
|
description:
|
||||||
|
'Cancellation reason (optional). Examples: no longer needed, waiting too long, urgent matter'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
required: ['order_id']
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'taxi_create_order',
|
||||||
|
description: 'Create taxi order directly via API without opening any app interface',
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {
|
||||||
|
caller_car_phone: {
|
||||||
|
type: 'string',
|
||||||
|
description: 'Caller phone number (optional)'
|
||||||
|
},
|
||||||
|
estimate_trace_id: {
|
||||||
|
type: 'string',
|
||||||
|
description: 'Estimation trace ID from estimation results'
|
||||||
|
},
|
||||||
|
product_category: {
|
||||||
|
type: 'string',
|
||||||
|
description: 'Vehicle category ID from estimation results, comma-separated for multiple types'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
required: ['product_category', 'estimate_trace_id']
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'taxi_estimate',
|
||||||
|
description: 'Get available ride-hailing vehicle types and fare estimates',
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {
|
||||||
|
from_lat: {
|
||||||
|
type: 'string',
|
||||||
|
description: 'Departure latitude, must be from map tools'
|
||||||
|
},
|
||||||
|
from_lng: {
|
||||||
|
type: 'string',
|
||||||
|
description: 'Departure longitude, must be from map tools'
|
||||||
|
},
|
||||||
|
from_name: {
|
||||||
|
type: 'string',
|
||||||
|
description: 'Departure location name'
|
||||||
|
},
|
||||||
|
to_lat: {
|
||||||
|
type: 'string',
|
||||||
|
description: 'Destination latitude, must be from map tools'
|
||||||
|
},
|
||||||
|
to_lng: {
|
||||||
|
type: 'string',
|
||||||
|
description: 'Destination longitude, must be from map tools'
|
||||||
|
},
|
||||||
|
to_name: {
|
||||||
|
type: 'string',
|
||||||
|
description: 'Destination name'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
required: ['from_lng', 'from_lat', 'from_name', 'to_lng', 'to_lat', 'to_name']
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'taxi_generate_ride_app_link',
|
||||||
|
description: 'Generate deep links to open ride-hailing apps based on origin, destination and vehicle type',
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {
|
||||||
|
from_lat: {
|
||||||
|
type: 'string',
|
||||||
|
description: 'Departure latitude, must be from map tools'
|
||||||
|
},
|
||||||
|
from_lng: {
|
||||||
|
type: 'string',
|
||||||
|
description: 'Departure longitude, must be from map tools'
|
||||||
|
},
|
||||||
|
product_category: {
|
||||||
|
type: 'string',
|
||||||
|
description: 'Vehicle category IDs from estimation results, comma-separated for multiple types'
|
||||||
|
},
|
||||||
|
to_lat: {
|
||||||
|
type: 'string',
|
||||||
|
description: 'Destination latitude, must be from map tools'
|
||||||
|
},
|
||||||
|
to_lng: {
|
||||||
|
type: 'string',
|
||||||
|
description: 'Destination longitude, must be from map tools'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
required: ['from_lng', 'from_lat', 'to_lng', 'to_lat']
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'taxi_get_driver_location',
|
||||||
|
description: 'Get real-time driver location for a taxi order',
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {
|
||||||
|
order_id: {
|
||||||
|
type: 'string',
|
||||||
|
description: 'Taxi order ID'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
required: ['order_id']
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'taxi_query_order',
|
||||||
|
description: 'Query taxi order status and information such as driver contact, license plate, ETA',
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {
|
||||||
|
order_id: {
|
||||||
|
type: 'string',
|
||||||
|
description: 'Order ID from order creation results, if available; otherwise queries incomplete orders'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Handle tool calls
|
||||||
|
this._server.setRequestHandler(CallToolRequestSchema, async (request) => {
|
||||||
|
const { name, arguments: args } = request.params
|
||||||
|
|
||||||
|
try {
|
||||||
|
switch (name) {
|
||||||
|
case 'maps_textsearch':
|
||||||
|
return await this.handleMapsTextSearch(args)
|
||||||
|
case 'taxi_cancel_order':
|
||||||
|
return await this.handleTaxiCancelOrder(args)
|
||||||
|
case 'taxi_create_order':
|
||||||
|
return await this.handleTaxiCreateOrder(args)
|
||||||
|
case 'taxi_estimate':
|
||||||
|
return await this.handleTaxiEstimate(args)
|
||||||
|
case 'taxi_generate_ride_app_link':
|
||||||
|
return await this.handleTaxiGenerateRideAppLink(args)
|
||||||
|
case 'taxi_get_driver_location':
|
||||||
|
return await this.handleTaxiGetDriverLocation(args)
|
||||||
|
case 'taxi_query_order':
|
||||||
|
return await this.handleTaxiQueryOrder(args)
|
||||||
|
default:
|
||||||
|
throw new Error(`Unknown tool: ${name}`)
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(`Error calling tool ${name}:`, error as Error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
private async handleMapsTextSearch(args: any) {
|
||||||
|
const { city, keywords, location } = args
|
||||||
|
|
||||||
|
const params = {
|
||||||
|
name: 'maps_textsearch',
|
||||||
|
arguments: {
|
||||||
|
keywords,
|
||||||
|
city,
|
||||||
|
...(location && { location })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await this.makeRequest('tools/call', params)
|
||||||
|
return {
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'text',
|
||||||
|
text: JSON.stringify(response, null, 2)
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Maps text search error:', error as Error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async handleTaxiCancelOrder(args: any) {
|
||||||
|
const { order_id, reason } = args
|
||||||
|
|
||||||
|
const params = {
|
||||||
|
name: 'taxi_cancel_order',
|
||||||
|
arguments: {
|
||||||
|
order_id,
|
||||||
|
...(reason && { reason })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await this.makeRequest('tools/call', params)
|
||||||
|
return {
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'text',
|
||||||
|
text: JSON.stringify(response, null, 2)
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Taxi cancel order error:', error as Error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async handleTaxiCreateOrder(args: any) {
|
||||||
|
const { caller_car_phone, estimate_trace_id, product_category } = args
|
||||||
|
|
||||||
|
const params = {
|
||||||
|
name: 'taxi_create_order',
|
||||||
|
arguments: {
|
||||||
|
product_category,
|
||||||
|
estimate_trace_id,
|
||||||
|
...(caller_car_phone && { caller_car_phone })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await this.makeRequest('tools/call', params)
|
||||||
|
return {
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'text',
|
||||||
|
text: JSON.stringify(response, null, 2)
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Taxi create order error:', error as Error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async handleTaxiEstimate(args: any) {
|
||||||
|
const { from_lng, from_lat, from_name, to_lng, to_lat, to_name } = args
|
||||||
|
|
||||||
|
const params = {
|
||||||
|
name: 'taxi_estimate',
|
||||||
|
arguments: {
|
||||||
|
from_lng,
|
||||||
|
from_lat,
|
||||||
|
from_name,
|
||||||
|
to_lng,
|
||||||
|
to_lat,
|
||||||
|
to_name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await this.makeRequest('tools/call', params)
|
||||||
|
return {
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'text',
|
||||||
|
text: JSON.stringify(response, null, 2)
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Taxi estimate error:', error as Error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async handleTaxiGenerateRideAppLink(args: any) {
|
||||||
|
const { from_lng, from_lat, to_lng, to_lat, product_category } = args
|
||||||
|
|
||||||
|
const params = {
|
||||||
|
name: 'taxi_generate_ride_app_link',
|
||||||
|
arguments: {
|
||||||
|
from_lng,
|
||||||
|
from_lat,
|
||||||
|
to_lng,
|
||||||
|
to_lat,
|
||||||
|
...(product_category && { product_category })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await this.makeRequest('tools/call', params)
|
||||||
|
return {
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'text',
|
||||||
|
text: JSON.stringify(response, null, 2)
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Taxi generate ride app link error:', error as Error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async handleTaxiGetDriverLocation(args: any) {
|
||||||
|
const { order_id } = args
|
||||||
|
|
||||||
|
const params = {
|
||||||
|
name: 'taxi_get_driver_location',
|
||||||
|
arguments: {
|
||||||
|
order_id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await this.makeRequest('tools/call', params)
|
||||||
|
return {
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'text',
|
||||||
|
text: JSON.stringify(response, null, 2)
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Taxi get driver location error:', error as Error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async handleTaxiQueryOrder(args: any) {
|
||||||
|
const { order_id } = args
|
||||||
|
|
||||||
|
const params = {
|
||||||
|
name: 'taxi_query_order',
|
||||||
|
arguments: {
|
||||||
|
...(order_id && { order_id })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await this.makeRequest('tools/call', params)
|
||||||
|
return {
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'text',
|
||||||
|
text: JSON.stringify(response, null, 2)
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Taxi query order error:', error as Error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async makeRequest(method: string, params: any): Promise<any> {
|
||||||
|
const requestData = {
|
||||||
|
jsonrpc: '2.0',
|
||||||
|
method: method,
|
||||||
|
id: Date.now(),
|
||||||
|
...(Object.keys(params).length > 0 && { params })
|
||||||
|
}
|
||||||
|
|
||||||
|
// API key is passed as URL parameter
|
||||||
|
const url = `${this.baseUrl}?key=${this.apiKey}`
|
||||||
|
|
||||||
|
const response = await fetch(url, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
},
|
||||||
|
body: JSON.stringify(requestData)
|
||||||
|
})
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
const errorText = await response.text()
|
||||||
|
throw new Error(`HTTP ${response.status}: ${errorText}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = await response.json()
|
||||||
|
|
||||||
|
if (data.error) {
|
||||||
|
throw new Error(`API Error: ${JSON.stringify(data.error)}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
return data.result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export default DiDiMcpServer
|
||||||
@@ -3,7 +3,7 @@ import { loggerService } from '@logger'
|
|||||||
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
|
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
|
||||||
import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js'
|
import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js'
|
||||||
import { net } from 'electron'
|
import { net } from 'electron'
|
||||||
import { z } from 'zod'
|
import * as z from 'zod'
|
||||||
|
|
||||||
const logger = loggerService.withContext('DifyKnowledgeServer')
|
const logger = loggerService.withContext('DifyKnowledgeServer')
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import { Server } from '@modelcontextprotocol/sdk/server/index.js'
|
|||||||
import { BuiltinMCPServerName, BuiltinMCPServerNames } from '@types'
|
import { BuiltinMCPServerName, BuiltinMCPServerNames } from '@types'
|
||||||
|
|
||||||
import BraveSearchServer from './brave-search'
|
import BraveSearchServer from './brave-search'
|
||||||
|
import DiDiMcpServer from './didi-mcp'
|
||||||
import DifyKnowledgeServer from './dify-knowledge'
|
import DifyKnowledgeServer from './dify-knowledge'
|
||||||
import FetchServer from './fetch'
|
import FetchServer from './fetch'
|
||||||
import FileSystemServer from './filesystem'
|
import FileSystemServer from './filesystem'
|
||||||
@@ -42,6 +43,10 @@ export function createInMemoryMCPServer(
|
|||||||
case BuiltinMCPServerNames.python: {
|
case BuiltinMCPServerNames.python: {
|
||||||
return new PythonServer().server
|
return new PythonServer().server
|
||||||
}
|
}
|
||||||
|
case BuiltinMCPServerNames.didiMCP: {
|
||||||
|
const apiKey = envs.DIDI_API_KEY
|
||||||
|
return new DiDiMcpServer(apiKey).server
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
throw new Error(`Unknown in-memory MCP server: ${name}`)
|
throw new Error(`Unknown in-memory MCP server: ${name}`)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprot
|
|||||||
import { net } from 'electron'
|
import { net } from 'electron'
|
||||||
import { JSDOM } from 'jsdom'
|
import { JSDOM } from 'jsdom'
|
||||||
import TurndownService from 'turndown'
|
import TurndownService from 'turndown'
|
||||||
import { z } from 'zod'
|
import * as z from 'zod'
|
||||||
|
|
||||||
export const RequestPayloadSchema = z.object({
|
export const RequestPayloadSchema = z.object({
|
||||||
url: z.url(),
|
url: z.url(),
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import fs from 'fs/promises'
|
|||||||
import { minimatch } from 'minimatch'
|
import { minimatch } from 'minimatch'
|
||||||
import os from 'os'
|
import os from 'os'
|
||||||
import path from 'path'
|
import path from 'path'
|
||||||
import { z } from 'zod'
|
import * as z from 'zod'
|
||||||
|
|
||||||
const logger = loggerService.withContext('MCP:FileSystemServer')
|
const logger = loggerService.withContext('MCP:FileSystemServer')
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,11 @@
|
|||||||
import { IpcChannel } from '@shared/IpcChannel'
|
import { IpcChannel } from '@shared/IpcChannel'
|
||||||
import { ApiServerConfig } from '@types'
|
import {
|
||||||
|
ApiServerConfig,
|
||||||
|
GetApiServerStatusResult,
|
||||||
|
RestartApiServerStatusResult,
|
||||||
|
StartApiServerStatusResult,
|
||||||
|
StopApiServerStatusResult
|
||||||
|
} from '@types'
|
||||||
import { ipcMain } from 'electron'
|
import { ipcMain } from 'electron'
|
||||||
|
|
||||||
import { apiServer } from '../apiServer'
|
import { apiServer } from '../apiServer'
|
||||||
@@ -52,7 +58,7 @@ export class ApiServerService {
|
|||||||
|
|
||||||
registerIpcHandlers(): void {
|
registerIpcHandlers(): void {
|
||||||
// API Server
|
// API Server
|
||||||
ipcMain.handle(IpcChannel.ApiServer_Start, async () => {
|
ipcMain.handle(IpcChannel.ApiServer_Start, async (): Promise<StartApiServerStatusResult> => {
|
||||||
try {
|
try {
|
||||||
await this.start()
|
await this.start()
|
||||||
return { success: true }
|
return { success: true }
|
||||||
@@ -61,7 +67,7 @@ export class ApiServerService {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
ipcMain.handle(IpcChannel.ApiServer_Stop, async () => {
|
ipcMain.handle(IpcChannel.ApiServer_Stop, async (): Promise<StopApiServerStatusResult> => {
|
||||||
try {
|
try {
|
||||||
await this.stop()
|
await this.stop()
|
||||||
return { success: true }
|
return { success: true }
|
||||||
@@ -70,7 +76,7 @@ export class ApiServerService {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
ipcMain.handle(IpcChannel.ApiServer_Restart, async () => {
|
ipcMain.handle(IpcChannel.ApiServer_Restart, async (): Promise<RestartApiServerStatusResult> => {
|
||||||
try {
|
try {
|
||||||
await this.restart()
|
await this.restart()
|
||||||
return { success: true }
|
return { success: true }
|
||||||
@@ -79,7 +85,7 @@ export class ApiServerService {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
ipcMain.handle(IpcChannel.ApiServer_GetStatus, async () => {
|
ipcMain.handle(IpcChannel.ApiServer_GetStatus, async (): Promise<GetApiServerStatusResult> => {
|
||||||
try {
|
try {
|
||||||
const config = await this.getCurrentConfig()
|
const config = await this.getCurrentConfig()
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -1,25 +1,29 @@
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { isWin } from '@main/constant'
|
import { isWin } from '@main/constant'
|
||||||
import { getIpCountry } from '@main/utils/ipService'
|
import { getIpCountry } from '@main/utils/ipService'
|
||||||
import { locales } from '@main/utils/locales'
|
|
||||||
import { generateUserAgent } from '@main/utils/systemInfo'
|
import { generateUserAgent } from '@main/utils/systemInfo'
|
||||||
import { FeedUrl, UpgradeChannel } from '@shared/config/constant'
|
import { FeedUrl, UpgradeChannel } from '@shared/config/constant'
|
||||||
import { IpcChannel } from '@shared/IpcChannel'
|
import { IpcChannel } from '@shared/IpcChannel'
|
||||||
import { CancellationToken, UpdateInfo } from 'builder-util-runtime'
|
import { CancellationToken, UpdateInfo } from 'builder-util-runtime'
|
||||||
import { app, BrowserWindow, dialog, net } from 'electron'
|
import { app, net } from 'electron'
|
||||||
import { AppUpdater as _AppUpdater, autoUpdater, Logger, NsisUpdater, UpdateCheckResult } from 'electron-updater'
|
import { AppUpdater as _AppUpdater, autoUpdater, Logger, NsisUpdater, UpdateCheckResult } from 'electron-updater'
|
||||||
import path from 'path'
|
import path from 'path'
|
||||||
import semver from 'semver'
|
import semver from 'semver'
|
||||||
|
|
||||||
import icon from '../../../build/icon.png?asset'
|
|
||||||
import { configManager } from './ConfigManager'
|
import { configManager } from './ConfigManager'
|
||||||
import { windowService } from './WindowService'
|
import { windowService } from './WindowService'
|
||||||
|
|
||||||
const logger = loggerService.withContext('AppUpdater')
|
const logger = loggerService.withContext('AppUpdater')
|
||||||
|
|
||||||
|
// Language markers constants for multi-language release notes
|
||||||
|
const LANG_MARKERS = {
|
||||||
|
EN_START: '<!--LANG:en-->',
|
||||||
|
ZH_CN_START: '<!--LANG:zh-CN-->',
|
||||||
|
END: '<!--LANG:END-->'
|
||||||
|
} as const
|
||||||
|
|
||||||
export default class AppUpdater {
|
export default class AppUpdater {
|
||||||
autoUpdater: _AppUpdater = autoUpdater
|
autoUpdater: _AppUpdater = autoUpdater
|
||||||
private releaseInfo: UpdateInfo | undefined
|
|
||||||
private cancellationToken: CancellationToken = new CancellationToken()
|
private cancellationToken: CancellationToken = new CancellationToken()
|
||||||
private updateCheckResult: UpdateCheckResult | null = null
|
private updateCheckResult: UpdateCheckResult | null = null
|
||||||
|
|
||||||
@@ -30,7 +34,8 @@ export default class AppUpdater {
|
|||||||
autoUpdater.autoInstallOnAppQuit = configManager.getAutoUpdate()
|
autoUpdater.autoInstallOnAppQuit = configManager.getAutoUpdate()
|
||||||
autoUpdater.requestHeaders = {
|
autoUpdater.requestHeaders = {
|
||||||
...autoUpdater.requestHeaders,
|
...autoUpdater.requestHeaders,
|
||||||
'User-Agent': generateUserAgent()
|
'User-Agent': generateUserAgent(),
|
||||||
|
'X-Client-Id': configManager.getClientId()
|
||||||
}
|
}
|
||||||
|
|
||||||
autoUpdater.on('error', (error) => {
|
autoUpdater.on('error', (error) => {
|
||||||
@@ -40,7 +45,8 @@ export default class AppUpdater {
|
|||||||
|
|
||||||
autoUpdater.on('update-available', (releaseInfo: UpdateInfo) => {
|
autoUpdater.on('update-available', (releaseInfo: UpdateInfo) => {
|
||||||
logger.info('update available', releaseInfo)
|
logger.info('update available', releaseInfo)
|
||||||
windowService.getMainWindow()?.webContents.send(IpcChannel.UpdateAvailable, releaseInfo)
|
const processedReleaseInfo = this.processReleaseInfo(releaseInfo)
|
||||||
|
windowService.getMainWindow()?.webContents.send(IpcChannel.UpdateAvailable, processedReleaseInfo)
|
||||||
})
|
})
|
||||||
|
|
||||||
// 检测到不需要更新时
|
// 检测到不需要更新时
|
||||||
@@ -55,9 +61,9 @@ export default class AppUpdater {
|
|||||||
|
|
||||||
// 当需要更新的内容下载完成后
|
// 当需要更新的内容下载完成后
|
||||||
autoUpdater.on('update-downloaded', (releaseInfo: UpdateInfo) => {
|
autoUpdater.on('update-downloaded', (releaseInfo: UpdateInfo) => {
|
||||||
windowService.getMainWindow()?.webContents.send(IpcChannel.UpdateDownloaded, releaseInfo)
|
const processedReleaseInfo = this.processReleaseInfo(releaseInfo)
|
||||||
this.releaseInfo = releaseInfo
|
windowService.getMainWindow()?.webContents.send(IpcChannel.UpdateDownloaded, processedReleaseInfo)
|
||||||
logger.info('update downloaded', releaseInfo)
|
logger.info('update downloaded', processedReleaseInfo)
|
||||||
})
|
})
|
||||||
|
|
||||||
if (isWin) {
|
if (isWin) {
|
||||||
@@ -237,49 +243,79 @@ export default class AppUpdater {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public async showUpdateDialog(mainWindow: BrowserWindow) {
|
public quitAndInstall() {
|
||||||
if (!this.releaseInfo) {
|
app.isQuitting = true
|
||||||
return
|
setImmediate(() => autoUpdater.quitAndInstall())
|
||||||
}
|
|
||||||
const locale = locales[configManager.getLanguage()]
|
|
||||||
const { update: updateLocale } = locale.translation
|
|
||||||
|
|
||||||
let detail = this.formatReleaseNotes(this.releaseInfo.releaseNotes)
|
|
||||||
if (detail === '') {
|
|
||||||
detail = updateLocale.noReleaseNotes
|
|
||||||
}
|
|
||||||
|
|
||||||
dialog
|
|
||||||
.showMessageBox({
|
|
||||||
type: 'info',
|
|
||||||
title: updateLocale.title,
|
|
||||||
icon,
|
|
||||||
message: updateLocale.message.replace('{{version}}', this.releaseInfo.version),
|
|
||||||
detail,
|
|
||||||
buttons: [updateLocale.later, updateLocale.install],
|
|
||||||
defaultId: 1,
|
|
||||||
cancelId: 0
|
|
||||||
})
|
|
||||||
.then(({ response }) => {
|
|
||||||
if (response === 1) {
|
|
||||||
app.isQuitting = true
|
|
||||||
setImmediate(() => autoUpdater.quitAndInstall())
|
|
||||||
} else {
|
|
||||||
mainWindow.webContents.send(IpcChannel.UpdateDownloadedCancelled)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private formatReleaseNotes(releaseNotes: string | ReleaseNoteInfo[] | null | undefined): string {
|
/**
|
||||||
if (!releaseNotes) {
|
* Check if release notes contain multi-language markers
|
||||||
return ''
|
*/
|
||||||
}
|
private hasMultiLanguageMarkers(releaseNotes: string): boolean {
|
||||||
|
return releaseNotes.includes(LANG_MARKERS.EN_START)
|
||||||
|
}
|
||||||
|
|
||||||
if (typeof releaseNotes === 'string') {
|
/**
|
||||||
|
* Parse multi-language release notes and return the appropriate language version
|
||||||
|
* @param releaseNotes - Release notes string with language markers
|
||||||
|
* @returns Parsed release notes for the user's language
|
||||||
|
*
|
||||||
|
* Expected format:
|
||||||
|
* <!--LANG:en-->English content<!--LANG:zh-CN-->Chinese content<!--LANG:END-->
|
||||||
|
*/
|
||||||
|
private parseMultiLangReleaseNotes(releaseNotes: string): string {
|
||||||
|
try {
|
||||||
|
const language = configManager.getLanguage()
|
||||||
|
const isChineseUser = language === 'zh-CN' || language === 'zh-TW'
|
||||||
|
|
||||||
|
// Create regex patterns using constants
|
||||||
|
const enPattern = new RegExp(
|
||||||
|
`${LANG_MARKERS.EN_START.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')}([\\s\\S]*?)${LANG_MARKERS.ZH_CN_START.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')}`
|
||||||
|
)
|
||||||
|
const zhPattern = new RegExp(
|
||||||
|
`${LANG_MARKERS.ZH_CN_START.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')}([\\s\\S]*?)${LANG_MARKERS.END.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')}`
|
||||||
|
)
|
||||||
|
|
||||||
|
// Extract language sections
|
||||||
|
const enMatch = releaseNotes.match(enPattern)
|
||||||
|
const zhMatch = releaseNotes.match(zhPattern)
|
||||||
|
|
||||||
|
// Return appropriate language version with proper fallback
|
||||||
|
if (isChineseUser && zhMatch) {
|
||||||
|
return zhMatch[1].trim()
|
||||||
|
} else if (enMatch) {
|
||||||
|
return enMatch[1].trim()
|
||||||
|
} else {
|
||||||
|
// Clean fallback: remove all language markers
|
||||||
|
logger.warn('Failed to extract language-specific release notes, using cleaned fallback')
|
||||||
|
return releaseNotes
|
||||||
|
.replace(new RegExp(`${LANG_MARKERS.EN_START}|${LANG_MARKERS.ZH_CN_START}|${LANG_MARKERS.END}`, 'g'), '')
|
||||||
|
.trim()
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to parse multi-language release notes', error as Error)
|
||||||
|
// Return original notes as safe fallback
|
||||||
return releaseNotes
|
return releaseNotes
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return releaseNotes.map((note) => note.note).join('\n')
|
/**
|
||||||
|
* Process release info to handle multi-language release notes
|
||||||
|
* @param releaseInfo - Original release info from updater
|
||||||
|
* @returns Processed release info with localized release notes
|
||||||
|
*/
|
||||||
|
private processReleaseInfo(releaseInfo: UpdateInfo): UpdateInfo {
|
||||||
|
const processedInfo = { ...releaseInfo }
|
||||||
|
|
||||||
|
// Handle multi-language release notes in string format
|
||||||
|
if (releaseInfo.releaseNotes && typeof releaseInfo.releaseNotes === 'string') {
|
||||||
|
// Check if it contains multi-language markers
|
||||||
|
if (this.hasMultiLanguageMarkers(releaseInfo.releaseNotes)) {
|
||||||
|
processedInfo.releaseNotes = this.parseMultiLangReleaseNotes(releaseInfo.releaseNotes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return processedInfo
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
interface GithubReleaseInfo {
|
interface GithubReleaseInfo {
|
||||||
@@ -287,7 +323,3 @@ interface GithubReleaseInfo {
|
|||||||
prerelease: boolean
|
prerelease: boolean
|
||||||
tag_name: string
|
tag_name: string
|
||||||
}
|
}
|
||||||
interface ReleaseNoteInfo {
|
|
||||||
readonly version: string
|
|
||||||
readonly note: string | null
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -3,11 +3,20 @@ import os from 'node:os'
|
|||||||
import path from 'node:path'
|
import path from 'node:path'
|
||||||
|
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { isWin } from '@main/constant'
|
import { isMac, isWin } from '@main/constant'
|
||||||
import { removeEnvProxy } from '@main/utils'
|
import { removeEnvProxy } from '@main/utils'
|
||||||
import { isUserInChina } from '@main/utils/ipService'
|
import { isUserInChina } from '@main/utils/ipService'
|
||||||
import { getBinaryName } from '@main/utils/process'
|
import { getBinaryName } from '@main/utils/process'
|
||||||
import { codeTools } from '@shared/config/constant'
|
import {
|
||||||
|
codeTools,
|
||||||
|
MACOS_TERMINALS,
|
||||||
|
MACOS_TERMINALS_WITH_COMMANDS,
|
||||||
|
terminalApps,
|
||||||
|
TerminalConfig,
|
||||||
|
TerminalConfigWithCommand,
|
||||||
|
WINDOWS_TERMINALS,
|
||||||
|
WINDOWS_TERMINALS_WITH_COMMANDS
|
||||||
|
} from '@shared/config/constant'
|
||||||
import { spawn } from 'child_process'
|
import { spawn } from 'child_process'
|
||||||
import { promisify } from 'util'
|
import { promisify } from 'util'
|
||||||
|
|
||||||
@@ -22,7 +31,13 @@ interface VersionInfo {
|
|||||||
|
|
||||||
class CodeToolsService {
|
class CodeToolsService {
|
||||||
private versionCache: Map<string, { version: string; timestamp: number }> = new Map()
|
private versionCache: Map<string, { version: string; timestamp: number }> = new Map()
|
||||||
|
private terminalsCache: {
|
||||||
|
terminals: TerminalConfig[]
|
||||||
|
timestamp: number
|
||||||
|
} | null = null
|
||||||
|
private customTerminalPaths: Map<string, string> = new Map() // Store user-configured terminal paths
|
||||||
private readonly CACHE_DURATION = 1000 * 60 * 30 // 30 minutes cache
|
private readonly CACHE_DURATION = 1000 * 60 * 30 // 30 minutes cache
|
||||||
|
private readonly TERMINALS_CACHE_DURATION = 1000 * 60 * 5 // 5 minutes cache for terminals
|
||||||
|
|
||||||
constructor() {
|
constructor() {
|
||||||
this.getBunPath = this.getBunPath.bind(this)
|
this.getBunPath = this.getBunPath.bind(this)
|
||||||
@@ -32,6 +47,23 @@ class CodeToolsService {
|
|||||||
this.getVersionInfo = this.getVersionInfo.bind(this)
|
this.getVersionInfo = this.getVersionInfo.bind(this)
|
||||||
this.updatePackage = this.updatePackage.bind(this)
|
this.updatePackage = this.updatePackage.bind(this)
|
||||||
this.run = this.run.bind(this)
|
this.run = this.run.bind(this)
|
||||||
|
|
||||||
|
if (isMac || isWin) {
|
||||||
|
this.preloadTerminals()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Preload available terminals in background
|
||||||
|
*/
|
||||||
|
private async preloadTerminals(): Promise<void> {
|
||||||
|
try {
|
||||||
|
logger.info('Preloading available terminals...')
|
||||||
|
await this.getAvailableTerminals()
|
||||||
|
logger.info('Terminal preloading completed')
|
||||||
|
} catch (error) {
|
||||||
|
logger.warn('Terminal preloading failed:', error as Error)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public async getBunPath() {
|
public async getBunPath() {
|
||||||
@@ -53,6 +85,8 @@ class CodeToolsService {
|
|||||||
return '@qwen-code/qwen-code'
|
return '@qwen-code/qwen-code'
|
||||||
case codeTools.iFlowCli:
|
case codeTools.iFlowCli:
|
||||||
return '@iflow-ai/iflow-cli'
|
return '@iflow-ai/iflow-cli'
|
||||||
|
case codeTools.githubCopilotCli:
|
||||||
|
return '@github/copilot'
|
||||||
default:
|
default:
|
||||||
throw new Error(`Unsupported CLI tool: ${cliTool}`)
|
throw new Error(`Unsupported CLI tool: ${cliTool}`)
|
||||||
}
|
}
|
||||||
@@ -70,15 +104,267 @@ class CodeToolsService {
|
|||||||
return 'qwen'
|
return 'qwen'
|
||||||
case codeTools.iFlowCli:
|
case codeTools.iFlowCli:
|
||||||
return 'iflow'
|
return 'iflow'
|
||||||
|
case codeTools.githubCopilotCli:
|
||||||
|
return 'copilot'
|
||||||
default:
|
default:
|
||||||
throw new Error(`Unsupported CLI tool: ${cliTool}`)
|
throw new Error(`Unsupported CLI tool: ${cliTool}`)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if a single terminal is available
|
||||||
|
*/
|
||||||
|
private async checkTerminalAvailability(terminal: TerminalConfig): Promise<TerminalConfig | null> {
|
||||||
|
try {
|
||||||
|
if (isMac && terminal.bundleId) {
|
||||||
|
// macOS: Check if application is installed via bundle ID with timeout
|
||||||
|
const { stdout } = await execAsync(`mdfind "kMDItemCFBundleIdentifier == '${terminal.bundleId}'"`, {
|
||||||
|
timeout: 3000
|
||||||
|
})
|
||||||
|
if (stdout.trim()) {
|
||||||
|
return terminal
|
||||||
|
}
|
||||||
|
} else if (isWin) {
|
||||||
|
// Windows: Check terminal availability
|
||||||
|
return await this.checkWindowsTerminalAvailability(terminal)
|
||||||
|
} else {
|
||||||
|
// TODO: Check if terminal is available in linux
|
||||||
|
await execAsync(`which ${terminal.id}`, { timeout: 2000 })
|
||||||
|
return terminal
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.debug(`Terminal ${terminal.id} not available:`, error as Error)
|
||||||
|
}
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check Windows terminal availability (simplified - user configured paths)
|
||||||
|
*/
|
||||||
|
private async checkWindowsTerminalAvailability(terminal: TerminalConfig): Promise<TerminalConfig | null> {
|
||||||
|
try {
|
||||||
|
switch (terminal.id) {
|
||||||
|
case terminalApps.cmd:
|
||||||
|
// CMD is always available on Windows
|
||||||
|
return terminal
|
||||||
|
|
||||||
|
case terminalApps.powershell:
|
||||||
|
// Check for PowerShell in PATH
|
||||||
|
try {
|
||||||
|
await execAsync('powershell -Command "Get-Host"', {
|
||||||
|
timeout: 3000
|
||||||
|
})
|
||||||
|
return terminal
|
||||||
|
} catch {
|
||||||
|
try {
|
||||||
|
await execAsync('pwsh -Command "Get-Host"', { timeout: 3000 })
|
||||||
|
return terminal
|
||||||
|
} catch {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case terminalApps.windowsTerminal:
|
||||||
|
// Check for Windows Terminal via where command (doesn't launch the terminal)
|
||||||
|
try {
|
||||||
|
await execAsync('where wt', { timeout: 3000 })
|
||||||
|
return terminal
|
||||||
|
} catch {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
case terminalApps.wsl:
|
||||||
|
// Check for WSL
|
||||||
|
try {
|
||||||
|
await execAsync('wsl --status', { timeout: 3000 })
|
||||||
|
return terminal
|
||||||
|
} catch {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
// For other terminals (Alacritty, WezTerm), check if user has configured custom path
|
||||||
|
return await this.checkCustomTerminalPath(terminal)
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.debug(`Windows terminal ${terminal.id} not available:`, error as Error)
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if user has configured custom path for terminal
|
||||||
|
*/
|
||||||
|
private async checkCustomTerminalPath(terminal: TerminalConfig): Promise<TerminalConfig | null> {
|
||||||
|
// Check if user has configured custom path
|
||||||
|
const customPath = this.customTerminalPaths.get(terminal.id)
|
||||||
|
if (customPath && fs.existsSync(customPath)) {
|
||||||
|
try {
|
||||||
|
await execAsync(`"${customPath}" --version`, { timeout: 3000 })
|
||||||
|
return { ...terminal, customPath }
|
||||||
|
} catch {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to PATH check
|
||||||
|
try {
|
||||||
|
const command = terminal.id === terminalApps.alacritty ? 'alacritty' : 'wezterm'
|
||||||
|
await execAsync(`${command} --version`, { timeout: 3000 })
|
||||||
|
return terminal
|
||||||
|
} catch {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set custom path for a terminal (called from settings UI)
|
||||||
|
*/
|
||||||
|
public setCustomTerminalPath(terminalId: string, path: string): void {
|
||||||
|
logger.info(`Setting custom path for terminal ${terminalId}: ${path}`)
|
||||||
|
this.customTerminalPaths.set(terminalId, path)
|
||||||
|
// Clear terminals cache to force refresh
|
||||||
|
this.terminalsCache = null
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get custom path for a terminal
|
||||||
|
*/
|
||||||
|
public getCustomTerminalPath(terminalId: string): string | undefined {
|
||||||
|
return this.customTerminalPaths.get(terminalId)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Remove custom path for a terminal
|
||||||
|
*/
|
||||||
|
public removeCustomTerminalPath(terminalId: string): void {
|
||||||
|
logger.info(`Removing custom path for terminal ${terminalId}`)
|
||||||
|
this.customTerminalPaths.delete(terminalId)
|
||||||
|
// Clear terminals cache to force refresh
|
||||||
|
this.terminalsCache = null
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get available terminals (with caching and parallel checking)
|
||||||
|
*/
|
||||||
|
private async getAvailableTerminals(): Promise<TerminalConfig[]> {
|
||||||
|
const now = Date.now()
|
||||||
|
|
||||||
|
// Check cache first
|
||||||
|
if (this.terminalsCache && now - this.terminalsCache.timestamp < this.TERMINALS_CACHE_DURATION) {
|
||||||
|
logger.info(`Using cached terminals list (${this.terminalsCache.terminals.length} terminals)`)
|
||||||
|
return this.terminalsCache.terminals
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info('Checking available terminals in parallel...')
|
||||||
|
const startTime = Date.now()
|
||||||
|
|
||||||
|
// Get terminal list based on platform
|
||||||
|
const terminalList = isWin ? WINDOWS_TERMINALS : MACOS_TERMINALS
|
||||||
|
|
||||||
|
// Check all terminals in parallel
|
||||||
|
const terminalPromises = terminalList.map((terminal) => this.checkTerminalAvailability(terminal))
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Wait for all checks to complete with a global timeout
|
||||||
|
const results = await Promise.allSettled(
|
||||||
|
terminalPromises.map((p) =>
|
||||||
|
Promise.race([p, new Promise((_, reject) => setTimeout(() => reject(new Error('timeout')), 5000))])
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
const availableTerminals: TerminalConfig[] = []
|
||||||
|
results.forEach((result, index) => {
|
||||||
|
if (result.status === 'fulfilled' && result.value) {
|
||||||
|
availableTerminals.push(result.value as TerminalConfig)
|
||||||
|
} else if (result.status === 'rejected') {
|
||||||
|
logger.debug(`Terminal check failed for ${MACOS_TERMINALS[index].id}:`, result.reason)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
const endTime = Date.now()
|
||||||
|
logger.info(
|
||||||
|
`Terminal availability check completed in ${endTime - startTime}ms, found ${availableTerminals.length} terminals`
|
||||||
|
)
|
||||||
|
|
||||||
|
// Cache the results
|
||||||
|
this.terminalsCache = {
|
||||||
|
terminals: availableTerminals,
|
||||||
|
timestamp: now
|
||||||
|
}
|
||||||
|
|
||||||
|
return availableTerminals
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Error checking terminal availability:', error as Error)
|
||||||
|
// Return cached result if available, otherwise empty array
|
||||||
|
return this.terminalsCache?.terminals || []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get terminal config by ID, fallback to system default
|
||||||
|
*/
|
||||||
|
private async getTerminalConfig(terminalId?: string): Promise<TerminalConfigWithCommand> {
|
||||||
|
const availableTerminals = await this.getAvailableTerminals()
|
||||||
|
const terminalCommands = isWin ? WINDOWS_TERMINALS_WITH_COMMANDS : MACOS_TERMINALS_WITH_COMMANDS
|
||||||
|
const defaultTerminal = isWin ? terminalApps.cmd : terminalApps.systemDefault
|
||||||
|
|
||||||
|
if (terminalId) {
|
||||||
|
let requestedTerminal = terminalCommands.find(
|
||||||
|
(t) => t.id === terminalId && availableTerminals.some((at) => at.id === t.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
if (requestedTerminal) {
|
||||||
|
// Apply custom path if configured
|
||||||
|
const customPath = this.customTerminalPaths.get(terminalId)
|
||||||
|
if (customPath && isWin) {
|
||||||
|
requestedTerminal = this.applyCustomPath(requestedTerminal, customPath)
|
||||||
|
}
|
||||||
|
return requestedTerminal
|
||||||
|
} else {
|
||||||
|
logger.warn(`Requested terminal ${terminalId} not available, falling back to system default`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to system default Terminal
|
||||||
|
const systemTerminal = terminalCommands.find(
|
||||||
|
(t) => t.id === defaultTerminal && availableTerminals.some((at) => at.id === t.id)
|
||||||
|
)
|
||||||
|
if (systemTerminal) {
|
||||||
|
return systemTerminal
|
||||||
|
}
|
||||||
|
|
||||||
|
// If even system Terminal is not found, return the first available
|
||||||
|
const firstAvailable = terminalCommands.find((t) => availableTerminals.some((at) => at.id === t.id))
|
||||||
|
if (firstAvailable) {
|
||||||
|
return firstAvailable
|
||||||
|
}
|
||||||
|
|
||||||
|
// Last resort fallback
|
||||||
|
return terminalCommands.find((t) => t.id === defaultTerminal)!
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Apply custom path to terminal configuration
|
||||||
|
*/
|
||||||
|
private applyCustomPath(terminal: TerminalConfigWithCommand, customPath: string): TerminalConfigWithCommand {
|
||||||
|
return {
|
||||||
|
...terminal,
|
||||||
|
customPath,
|
||||||
|
command: (directory: string, fullCommand: string) => {
|
||||||
|
const originalCommand = terminal.command(directory, fullCommand)
|
||||||
|
return {
|
||||||
|
...originalCommand,
|
||||||
|
command: customPath // Replace command with custom path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private async isPackageInstalled(cliTool: string): Promise<boolean> {
|
private async isPackageInstalled(cliTool: string): Promise<boolean> {
|
||||||
const executableName = await this.getCliExecutableName(cliTool)
|
const executableName = await this.getCliExecutableName(cliTool)
|
||||||
const binDir = path.join(os.homedir(), '.cherrystudio', 'bin')
|
const binDir = path.join(os.homedir(), '.cherrystudio', 'bin')
|
||||||
const executablePath = path.join(binDir, executableName + (process.platform === 'win32' ? '.exe' : ''))
|
const executablePath = path.join(binDir, executableName + (isWin ? '.exe' : ''))
|
||||||
|
|
||||||
// Ensure bin directory exists
|
// Ensure bin directory exists
|
||||||
if (!fs.existsSync(binDir)) {
|
if (!fs.existsSync(binDir)) {
|
||||||
@@ -105,9 +391,11 @@ class CodeToolsService {
|
|||||||
try {
|
try {
|
||||||
const executableName = await this.getCliExecutableName(cliTool)
|
const executableName = await this.getCliExecutableName(cliTool)
|
||||||
const binDir = path.join(os.homedir(), '.cherrystudio', 'bin')
|
const binDir = path.join(os.homedir(), '.cherrystudio', 'bin')
|
||||||
const executablePath = path.join(binDir, executableName + (process.platform === 'win32' ? '.exe' : ''))
|
const executablePath = path.join(binDir, executableName + (isWin ? '.exe' : ''))
|
||||||
|
|
||||||
const { stdout } = await execAsync(`"${executablePath}" --version`, { timeout: 10000 })
|
const { stdout } = await execAsync(`"${executablePath}" --version`, {
|
||||||
|
timeout: 10000
|
||||||
|
})
|
||||||
// Extract version number from output (format may vary by tool)
|
// Extract version number from output (format may vary by tool)
|
||||||
const versionMatch = stdout.trim().match(/\d+\.\d+\.\d+/)
|
const versionMatch = stdout.trim().match(/\d+\.\d+\.\d+/)
|
||||||
installedVersion = versionMatch ? versionMatch[0] : stdout.trim().split(' ')[0]
|
installedVersion = versionMatch ? versionMatch[0] : stdout.trim().split(' ')[0]
|
||||||
@@ -148,7 +436,10 @@ class CodeToolsService {
|
|||||||
logger.info(`${packageName} latest version: ${latestVersion}`)
|
logger.info(`${packageName} latest version: ${latestVersion}`)
|
||||||
|
|
||||||
// Cache the result
|
// Cache the result
|
||||||
this.versionCache.set(cacheKey, { version: latestVersion!, timestamp: now })
|
this.versionCache.set(cacheKey, {
|
||||||
|
version: latestVersion!,
|
||||||
|
timestamp: now
|
||||||
|
})
|
||||||
logger.debug(`Cached latest version for ${packageName}`)
|
logger.debug(`Cached latest version for ${packageName}`)
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.warn(`Failed to get latest version for ${packageName}:`, error as Error)
|
logger.warn(`Failed to get latest version for ${packageName}:`, error as Error)
|
||||||
@@ -191,6 +482,17 @@ class CodeToolsService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get available terminals for the current platform
|
||||||
|
*/
|
||||||
|
public async getAvailableTerminalsForPlatform(): Promise<TerminalConfig[]> {
|
||||||
|
if (isMac || isWin) {
|
||||||
|
return this.getAvailableTerminals()
|
||||||
|
}
|
||||||
|
// For other platforms, return empty array for now
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Update a CLI tool to the latest version
|
* Update a CLI tool to the latest version
|
||||||
*/
|
*/
|
||||||
@@ -202,10 +504,9 @@ class CodeToolsService {
|
|||||||
const bunInstallPath = path.join(os.homedir(), '.cherrystudio')
|
const bunInstallPath = path.join(os.homedir(), '.cherrystudio')
|
||||||
const registryUrl = await this.getNpmRegistryUrl()
|
const registryUrl = await this.getNpmRegistryUrl()
|
||||||
|
|
||||||
const installEnvPrefix =
|
const installEnvPrefix = isWin
|
||||||
process.platform === 'win32'
|
? `set "BUN_INSTALL=${bunInstallPath}" && set "NPM_CONFIG_REGISTRY=${registryUrl}" &&`
|
||||||
? `set "BUN_INSTALL=${bunInstallPath}" && set "NPM_CONFIG_REGISTRY=${registryUrl}" &&`
|
: `export BUN_INSTALL="${bunInstallPath}" && export NPM_CONFIG_REGISTRY="${registryUrl}" &&`
|
||||||
: `export BUN_INSTALL="${bunInstallPath}" && export NPM_CONFIG_REGISTRY="${registryUrl}" &&`
|
|
||||||
|
|
||||||
const updateCommand = `${installEnvPrefix} "${bunPath}" install -g ${packageName}`
|
const updateCommand = `${installEnvPrefix} "${bunPath}" install -g ${packageName}`
|
||||||
logger.info(`Executing update command: ${updateCommand}`)
|
logger.info(`Executing update command: ${updateCommand}`)
|
||||||
@@ -241,7 +542,7 @@ class CodeToolsService {
|
|||||||
_model: string,
|
_model: string,
|
||||||
directory: string,
|
directory: string,
|
||||||
env: Record<string, string>,
|
env: Record<string, string>,
|
||||||
options: { autoUpdateToLatest?: boolean } = {}
|
options: { autoUpdateToLatest?: boolean; terminal?: string } = {}
|
||||||
) {
|
) {
|
||||||
logger.info(`Starting CLI tool launch: ${cliTool} in directory: ${directory}`)
|
logger.info(`Starting CLI tool launch: ${cliTool} in directory: ${directory}`)
|
||||||
logger.debug(`Environment variables:`, Object.keys(env))
|
logger.debug(`Environment variables:`, Object.keys(env))
|
||||||
@@ -251,7 +552,7 @@ class CodeToolsService {
|
|||||||
const bunPath = await this.getBunPath()
|
const bunPath = await this.getBunPath()
|
||||||
const executableName = await this.getCliExecutableName(cliTool)
|
const executableName = await this.getCliExecutableName(cliTool)
|
||||||
const binDir = path.join(os.homedir(), '.cherrystudio', 'bin')
|
const binDir = path.join(os.homedir(), '.cherrystudio', 'bin')
|
||||||
const executablePath = path.join(binDir, executableName + (process.platform === 'win32' ? '.exe' : ''))
|
const executablePath = path.join(binDir, executableName + (isWin ? '.exe' : ''))
|
||||||
|
|
||||||
logger.debug(`Package name: ${packageName}`)
|
logger.debug(`Package name: ${packageName}`)
|
||||||
logger.debug(`Bun path: ${bunPath}`)
|
logger.debug(`Bun path: ${bunPath}`)
|
||||||
@@ -295,7 +596,13 @@ class CodeToolsService {
|
|||||||
|
|
||||||
// Build environment variable prefix (based on platform)
|
// Build environment variable prefix (based on platform)
|
||||||
const buildEnvPrefix = (isWindows: boolean) => {
|
const buildEnvPrefix = (isWindows: boolean) => {
|
||||||
if (Object.keys(env).length === 0) return ''
|
if (Object.keys(env).length === 0) {
|
||||||
|
logger.info('No environment variables to set')
|
||||||
|
return ''
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info('Setting environment variables:', Object.keys(env))
|
||||||
|
logger.info('Environment variable values:', env)
|
||||||
|
|
||||||
if (isWindows) {
|
if (isWindows) {
|
||||||
// Windows uses set command
|
// Windows uses set command
|
||||||
@@ -304,13 +611,29 @@ class CodeToolsService {
|
|||||||
.join(' && ')
|
.join(' && ')
|
||||||
} else {
|
} else {
|
||||||
// Unix-like systems use export command
|
// Unix-like systems use export command
|
||||||
return Object.entries(env)
|
const validEntries = Object.entries(env).filter(([key, value]) => {
|
||||||
.map(([key, value]) => `export ${key}="${value.replace(/"/g, '\\"')}"`)
|
if (!key || key.trim() === '') {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if (value === undefined || value === null) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
const envCommands = validEntries
|
||||||
|
.map(([key, value]) => {
|
||||||
|
const sanitizedValue = String(value).replace(/\\/g, '\\\\').replace(/"/g, '\\"')
|
||||||
|
const exportCmd = `export ${key}="${sanitizedValue}"`
|
||||||
|
logger.info(`Setting env var: ${key}="${sanitizedValue}"`)
|
||||||
|
logger.info(`Export command: ${exportCmd}`)
|
||||||
|
return exportCmd
|
||||||
|
})
|
||||||
.join(' && ')
|
.join(' && ')
|
||||||
|
return envCommands
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build command to execute
|
|
||||||
let baseCommand = isWin ? `"${executablePath}"` : `"${bunPath}" "${executablePath}"`
|
let baseCommand = isWin ? `"${executablePath}"` : `"${bunPath}" "${executablePath}"`
|
||||||
|
|
||||||
// Add configuration parameters for OpenAI Codex
|
// Add configuration parameters for OpenAI Codex
|
||||||
@@ -351,20 +674,20 @@ class CodeToolsService {
|
|||||||
|
|
||||||
switch (platform) {
|
switch (platform) {
|
||||||
case 'darwin': {
|
case 'darwin': {
|
||||||
// macOS - Use osascript to launch terminal and execute command directly, without showing startup command
|
// macOS - Support multiple terminals
|
||||||
const envPrefix = buildEnvPrefix(false)
|
const envPrefix = buildEnvPrefix(false)
|
||||||
const command = envPrefix ? `${envPrefix} && ${baseCommand}` : baseCommand
|
|
||||||
// Combine directory change with the main command to ensure they execute in the same shell session
|
|
||||||
const fullCommand = `cd '${directory.replace(/'/g, "\\'")}' && clear && ${command}`
|
|
||||||
|
|
||||||
terminalCommand = 'osascript'
|
const command = envPrefix ? `${envPrefix} && ${baseCommand}` : baseCommand
|
||||||
terminalArgs = [
|
|
||||||
'-e',
|
// Combine directory change with the main command to ensure they execute in the same shell session
|
||||||
`tell application "Terminal"
|
const fullCommand = `cd "${directory.replace(/"/g, '\\"')}" && clear && ${command}`
|
||||||
do script "${fullCommand.replace(/"/g, '\\"')}"
|
|
||||||
activate
|
const terminalConfig = await this.getTerminalConfig(options.terminal)
|
||||||
end tell`
|
logger.info(`Using terminal: ${terminalConfig.name} (${terminalConfig.id})`)
|
||||||
]
|
|
||||||
|
const { command: cmd, args } = terminalConfig.command(directory, fullCommand)
|
||||||
|
terminalCommand = cmd
|
||||||
|
terminalArgs = args
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
case 'win32': {
|
case 'win32': {
|
||||||
@@ -424,9 +747,23 @@ end tell`
|
|||||||
throw new Error(`Failed to create launch script: ${error}`)
|
throw new Error(`Failed to create launch script: ${error}`)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Launch bat file - Use safest start syntax, no title parameter
|
// Use selected terminal configuration
|
||||||
terminalCommand = 'cmd'
|
const terminalConfig = await this.getTerminalConfig(options.terminal)
|
||||||
terminalArgs = ['/c', 'start', batFilePath]
|
logger.info(`Using terminal: ${terminalConfig.name} (${terminalConfig.id})`)
|
||||||
|
|
||||||
|
// Get command and args from terminal configuration
|
||||||
|
// Pass the bat file path as the command to execute
|
||||||
|
const fullCommand = batFilePath
|
||||||
|
const { command: cmd, args } = terminalConfig.command(directory, fullCommand)
|
||||||
|
|
||||||
|
// Override if it's a custom terminal with a custom path
|
||||||
|
if (terminalConfig.customPath) {
|
||||||
|
terminalCommand = terminalConfig.customPath
|
||||||
|
terminalArgs = args
|
||||||
|
} else {
|
||||||
|
terminalCommand = cmd
|
||||||
|
terminalArgs = args
|
||||||
|
}
|
||||||
|
|
||||||
// Set cleanup task (delete temp file after 5 minutes)
|
// Set cleanup task (delete temp file after 5 minutes)
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import { defaultLanguage, UpgradeChannel, ZOOM_SHORTCUTS } from '@shared/config/
|
|||||||
import { LanguageVarious, Shortcut, ThemeMode } from '@types'
|
import { LanguageVarious, Shortcut, ThemeMode } from '@types'
|
||||||
import { app } from 'electron'
|
import { app } from 'electron'
|
||||||
import Store from 'electron-store'
|
import Store from 'electron-store'
|
||||||
|
import { v4 as uuidv4 } from 'uuid'
|
||||||
|
|
||||||
import { locales } from '../utils/locales'
|
import { locales } from '../utils/locales'
|
||||||
|
|
||||||
@@ -27,7 +28,8 @@ export enum ConfigKeys {
|
|||||||
SelectionAssistantFilterList = 'selectionAssistantFilterList',
|
SelectionAssistantFilterList = 'selectionAssistantFilterList',
|
||||||
DisableHardwareAcceleration = 'disableHardwareAcceleration',
|
DisableHardwareAcceleration = 'disableHardwareAcceleration',
|
||||||
Proxy = 'proxy',
|
Proxy = 'proxy',
|
||||||
EnableDeveloperMode = 'enableDeveloperMode'
|
EnableDeveloperMode = 'enableDeveloperMode',
|
||||||
|
ClientId = 'clientId'
|
||||||
}
|
}
|
||||||
|
|
||||||
export class ConfigManager {
|
export class ConfigManager {
|
||||||
@@ -241,6 +243,17 @@ export class ConfigManager {
|
|||||||
this.set(ConfigKeys.EnableDeveloperMode, value)
|
this.set(ConfigKeys.EnableDeveloperMode, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getClientId(): string {
|
||||||
|
let clientId = this.get<string>(ConfigKeys.ClientId)
|
||||||
|
|
||||||
|
if (!clientId) {
|
||||||
|
clientId = uuidv4()
|
||||||
|
this.set(ConfigKeys.ClientId, clientId)
|
||||||
|
}
|
||||||
|
|
||||||
|
return clientId
|
||||||
|
}
|
||||||
|
|
||||||
set(key: string, value: unknown, isNotify: boolean = false) {
|
set(key: string, value: unknown, isNotify: boolean = false) {
|
||||||
this.store.set(key, value)
|
this.store.set(key, value)
|
||||||
isNotify && this.notifySubscribers(key, value)
|
isNotify && this.notifySubscribers(key, value)
|
||||||
|
|||||||
@@ -725,7 +725,10 @@ class FileStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public openPath = async (_: Electron.IpcMainInvokeEvent, path: string): Promise<void> => {
|
public openPath = async (_: Electron.IpcMainInvokeEvent, path: string): Promise<void> => {
|
||||||
shell.openPath(path).catch((err) => logger.error('[IPC - Error] Failed to open file:', err))
|
const resolved = await shell.openPath(path)
|
||||||
|
if (resolved !== '') {
|
||||||
|
throw new Error(resolved)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -1229,6 +1232,19 @@ class FileStorage {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public showInFolder = async (_: Electron.IpcMainInvokeEvent, path: string): Promise<void> => {
|
||||||
|
if (!fs.existsSync(path)) {
|
||||||
|
const msg = `File or folder does not exist: ${path}`
|
||||||
|
logger.error(msg)
|
||||||
|
throw new Error(msg)
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
shell.showItemInFolder(path)
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to show item in folder:', error as Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export const fileStorage = new FileStorage()
|
export const fileStorage = new FileStorage()
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ import Reranker from '@main/knowledge/reranker/Reranker'
|
|||||||
import { fileStorage } from '@main/services/FileStorage'
|
import { fileStorage } from '@main/services/FileStorage'
|
||||||
import { windowService } from '@main/services/WindowService'
|
import { windowService } from '@main/services/WindowService'
|
||||||
import { getDataPath } from '@main/utils'
|
import { getDataPath } from '@main/utils'
|
||||||
import { getAllFiles } from '@main/utils/file'
|
import { getAllFiles, sanitizeFilename } from '@main/utils/file'
|
||||||
import { TraceMethod } from '@mcp-trace/trace-core'
|
import { TraceMethod } from '@mcp-trace/trace-core'
|
||||||
import { MB } from '@shared/config/constant'
|
import { MB } from '@shared/config/constant'
|
||||||
import type { LoaderReturn } from '@shared/config/types'
|
import type { LoaderReturn } from '@shared/config/types'
|
||||||
@@ -147,11 +147,16 @@ class KnowledgeService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private getDbPath = (id: string): string => {
|
||||||
|
// 消除网络搜索requestI d中的特殊字符
|
||||||
|
return path.join(this.storageDir, sanitizeFilename(id, '_'))
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Delete knowledge base file
|
* Delete knowledge base file
|
||||||
*/
|
*/
|
||||||
private deleteKnowledgeFile = (id: string): boolean => {
|
private deleteKnowledgeFile = (id: string): boolean => {
|
||||||
const dbPath = path.join(this.storageDir, id)
|
const dbPath = this.getDbPath(id)
|
||||||
if (fs.existsSync(dbPath)) {
|
if (fs.existsSync(dbPath)) {
|
||||||
try {
|
try {
|
||||||
fs.rmSync(dbPath, { recursive: true })
|
fs.rmSync(dbPath, { recursive: true })
|
||||||
@@ -244,7 +249,8 @@ class KnowledgeService {
|
|||||||
dimensions
|
dimensions
|
||||||
})
|
})
|
||||||
try {
|
try {
|
||||||
const libSqlDb = new LibSqlDb({ path: path.join(this.storageDir, id) })
|
const dbPath = this.getDbPath(id)
|
||||||
|
const libSqlDb = new LibSqlDb({ path: dbPath })
|
||||||
// Save database instance for later closing
|
// Save database instance for later closing
|
||||||
this.dbInstances.set(id, libSqlDb)
|
this.dbInstances.set(id, libSqlDb)
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import { createInMemoryMCPServer } from '@main/mcpServers/factory'
|
|||||||
import { makeSureDirExists, removeEnvProxy } from '@main/utils'
|
import { makeSureDirExists, removeEnvProxy } from '@main/utils'
|
||||||
import { buildFunctionCallToolName } from '@main/utils/mcp'
|
import { buildFunctionCallToolName } from '@main/utils/mcp'
|
||||||
import { getBinaryName, getBinaryPath } from '@main/utils/process'
|
import { getBinaryName, getBinaryPath } from '@main/utils/process'
|
||||||
|
import getLoginShellEnvironment from '@main/utils/shell-env'
|
||||||
import { TraceMethod, withSpanFunc } from '@mcp-trace/trace-core'
|
import { TraceMethod, withSpanFunc } from '@mcp-trace/trace-core'
|
||||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
|
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
|
||||||
import { SSEClientTransport, SSEClientTransportOptions } from '@modelcontextprotocol/sdk/client/sse.js'
|
import { SSEClientTransport, SSEClientTransportOptions } from '@modelcontextprotocol/sdk/client/sse.js'
|
||||||
@@ -43,14 +44,12 @@ import {
|
|||||||
} from '@types'
|
} from '@types'
|
||||||
import { app, net } from 'electron'
|
import { app, net } from 'electron'
|
||||||
import { EventEmitter } from 'events'
|
import { EventEmitter } from 'events'
|
||||||
import { memoize } from 'lodash'
|
|
||||||
import { v4 as uuidv4 } from 'uuid'
|
import { v4 as uuidv4 } from 'uuid'
|
||||||
|
|
||||||
import { CacheService } from './CacheService'
|
import { CacheService } from './CacheService'
|
||||||
import DxtService from './DxtService'
|
import DxtService from './DxtService'
|
||||||
import { CallBackServer } from './mcp/oauth/callback'
|
import { CallBackServer } from './mcp/oauth/callback'
|
||||||
import { McpOAuthClientProvider } from './mcp/oauth/provider'
|
import { McpOAuthClientProvider } from './mcp/oauth/provider'
|
||||||
import getLoginShellEnvironment from './mcp/shell-env'
|
|
||||||
import { windowService } from './WindowService'
|
import { windowService } from './WindowService'
|
||||||
|
|
||||||
// Generic type for caching wrapped functions
|
// Generic type for caching wrapped functions
|
||||||
@@ -335,7 +334,7 @@ class McpService {
|
|||||||
|
|
||||||
getServerLogger(server).debug(`Starting server`, { command: cmd, args })
|
getServerLogger(server).debug(`Starting server`, { command: cmd, args })
|
||||||
// Logger.info(`[MCP] Environment variables for server:`, server.env)
|
// Logger.info(`[MCP] Environment variables for server:`, server.env)
|
||||||
const loginShellEnv = await this.getLoginShellEnv()
|
const loginShellEnv = await getLoginShellEnvironment()
|
||||||
|
|
||||||
// Bun not support proxy https://github.com/oven-sh/bun/issues/16812
|
// Bun not support proxy https://github.com/oven-sh/bun/issues/16812
|
||||||
if (cmd.includes('bun')) {
|
if (cmd.includes('bun')) {
|
||||||
@@ -878,20 +877,6 @@ class McpService {
|
|||||||
return await cachedGetResource(server, uri)
|
return await cachedGetResource(server, uri)
|
||||||
}
|
}
|
||||||
|
|
||||||
private getLoginShellEnv = memoize(async (): Promise<Record<string, string>> => {
|
|
||||||
try {
|
|
||||||
const loginEnv = await getLoginShellEnvironment()
|
|
||||||
const pathSeparator = process.platform === 'win32' ? ';' : ':'
|
|
||||||
const cherryBinPath = path.join(os.homedir(), '.cherrystudio', 'bin')
|
|
||||||
loginEnv.PATH = `${loginEnv.PATH}${pathSeparator}${cherryBinPath}`
|
|
||||||
logger.debug('Successfully fetched login shell environment variables:')
|
|
||||||
return loginEnv
|
|
||||||
} catch (error) {
|
|
||||||
logger.error('Failed to fetch login shell environment variables:', error as Error)
|
|
||||||
return {}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// 实现 abortTool 方法
|
// 实现 abortTool 方法
|
||||||
public async abortTool(_: Electron.IpcMainInvokeEvent, callId: string) {
|
public async abortTool(_: Electron.IpcMainInvokeEvent, callId: string) {
|
||||||
const activeToolCall = this.activeToolCalls.get(callId)
|
const activeToolCall = this.activeToolCalls.get(callId)
|
||||||
|
|||||||
@@ -0,0 +1,586 @@
|
|||||||
|
import { exec } from 'node:child_process'
|
||||||
|
import { homedir } from 'node:os'
|
||||||
|
import { promisify } from 'node:util'
|
||||||
|
|
||||||
|
import { loggerService } from '@logger'
|
||||||
|
import * as fs from 'fs-extra'
|
||||||
|
import * as path from 'path'
|
||||||
|
|
||||||
|
const logger = loggerService.withContext('OvmsManager')
|
||||||
|
|
||||||
|
const execAsync = promisify(exec)
|
||||||
|
|
||||||
|
interface OvmsProcess {
|
||||||
|
pid: number
|
||||||
|
path: string
|
||||||
|
workingDirectory: string
|
||||||
|
}
|
||||||
|
|
||||||
|
interface ModelConfig {
|
||||||
|
name: string
|
||||||
|
base_path: string
|
||||||
|
}
|
||||||
|
|
||||||
|
interface OvmsConfig {
|
||||||
|
mediapipe_config_list: ModelConfig[]
|
||||||
|
}
|
||||||
|
|
||||||
|
class OvmsManager {
|
||||||
|
private ovms: OvmsProcess | null = null
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Recursively terminate a process and all its child processes
|
||||||
|
* @param pid Process ID to terminate
|
||||||
|
* @returns Promise<{ success: boolean; message?: string }>
|
||||||
|
*/
|
||||||
|
private async terminalProcess(pid: number): Promise<{ success: boolean; message?: string }> {
|
||||||
|
try {
|
||||||
|
// Check if the process is running
|
||||||
|
const processCheckCommand = `Get-Process -Id ${pid} -ErrorAction SilentlyContinue | Select-Object Id | ConvertTo-Json`
|
||||||
|
const { stdout: processStdout } = await execAsync(`powershell -Command "${processCheckCommand}"`)
|
||||||
|
|
||||||
|
if (!processStdout.trim()) {
|
||||||
|
logger.info(`Process with PID ${pid} is not running`)
|
||||||
|
return { success: true, message: `Process with PID ${pid} is not running` }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find child processes
|
||||||
|
const childProcessCommand = `Get-WmiObject -Class Win32_Process | Where-Object { $_.ParentProcessId -eq ${pid} } | Select-Object ProcessId | ConvertTo-Json`
|
||||||
|
const { stdout: childStdout } = await execAsync(`powershell -Command "${childProcessCommand}"`)
|
||||||
|
|
||||||
|
// If there are child processes, terminate them first
|
||||||
|
if (childStdout.trim()) {
|
||||||
|
const childProcesses = JSON.parse(childStdout)
|
||||||
|
const childList = Array.isArray(childProcesses) ? childProcesses : [childProcesses]
|
||||||
|
|
||||||
|
logger.info(`Found ${childList.length} child processes for PID ${pid}`)
|
||||||
|
|
||||||
|
// Recursively terminate each child process
|
||||||
|
for (const childProcess of childList) {
|
||||||
|
const childPid = childProcess.ProcessId
|
||||||
|
logger.info(`Terminating child process PID: ${childPid}`)
|
||||||
|
await this.terminalProcess(childPid)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.info(`No child processes found for PID ${pid}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, terminate the parent process
|
||||||
|
const killCommand = `Stop-Process -Id ${pid} -Force -ErrorAction SilentlyContinue`
|
||||||
|
await execAsync(`powershell -Command "${killCommand}"`)
|
||||||
|
logger.info(`Terminated process with PID: ${pid}`)
|
||||||
|
|
||||||
|
// Wait for the process to disappear with 5-second timeout
|
||||||
|
const timeout = 5000 // 5 seconds
|
||||||
|
const startTime = Date.now()
|
||||||
|
|
||||||
|
while (Date.now() - startTime < timeout) {
|
||||||
|
const checkCommand = `Get-Process -Id ${pid} -ErrorAction SilentlyContinue | Select-Object Id | ConvertTo-Json`
|
||||||
|
const { stdout: checkStdout } = await execAsync(`powershell -Command "${checkCommand}"`)
|
||||||
|
|
||||||
|
if (!checkStdout.trim()) {
|
||||||
|
logger.info(`Process with PID ${pid} has disappeared`)
|
||||||
|
return { success: true, message: `Process ${pid} and all child processes terminated successfully` }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait 300ms before checking again
|
||||||
|
await new Promise((resolve) => setTimeout(resolve, 300))
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.warn(`Process with PID ${pid} did not disappear within timeout`)
|
||||||
|
return { success: false, message: `Process ${pid} did not disappear within 5 seconds` }
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(`Failed to terminate process ${pid}:`, error as Error)
|
||||||
|
return { success: false, message: `Failed to terminate process ${pid}` }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stop OVMS process if it's running
|
||||||
|
* @returns Promise<{ success: boolean; message?: string }>
|
||||||
|
*/
|
||||||
|
public async stopOvms(): Promise<{ success: boolean; message?: string }> {
|
||||||
|
try {
|
||||||
|
// Check if OVMS process is running
|
||||||
|
const psCommand = `Get-Process -Name "ovms" -ErrorAction SilentlyContinue | Select-Object Id, Path | ConvertTo-Json`
|
||||||
|
const { stdout } = await execAsync(`powershell -Command "${psCommand}"`)
|
||||||
|
|
||||||
|
if (!stdout.trim()) {
|
||||||
|
logger.info('OVMS process is not running')
|
||||||
|
return { success: true, message: 'OVMS process is not running' }
|
||||||
|
}
|
||||||
|
|
||||||
|
const processes = JSON.parse(stdout)
|
||||||
|
const processList = Array.isArray(processes) ? processes : [processes]
|
||||||
|
|
||||||
|
if (processList.length === 0) {
|
||||||
|
logger.info('OVMS process is not running')
|
||||||
|
return { success: true, message: 'OVMS process is not running' }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Terminate all OVMS processes using terminalProcess
|
||||||
|
for (const process of processList) {
|
||||||
|
const result = await this.terminalProcess(process.Id)
|
||||||
|
if (!result.success) {
|
||||||
|
logger.error(`Failed to terminate OVMS process with PID: ${process.Id}, ${result.message}`)
|
||||||
|
return { success: false, message: `Failed to terminate OVMS process: ${result.message}` }
|
||||||
|
}
|
||||||
|
logger.info(`Terminated OVMS process with PID: ${process.Id}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset the ovms instance
|
||||||
|
this.ovms = null
|
||||||
|
|
||||||
|
logger.info('OVMS process stopped successfully')
|
||||||
|
return { success: true, message: 'OVMS process stopped successfully' }
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(`Failed to stop OVMS process: ${error}`)
|
||||||
|
return { success: false, message: 'Failed to stop OVMS process' }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Run OVMS by ensuring config.json exists and executing run.bat
|
||||||
|
* @returns Promise<{ success: boolean; message?: string }>
|
||||||
|
*/
|
||||||
|
public async runOvms(): Promise<{ success: boolean; message?: string }> {
|
||||||
|
const homeDir = homedir()
|
||||||
|
const ovmsDir = path.join(homeDir, '.cherrystudio', 'ovms', 'ovms')
|
||||||
|
const configPath = path.join(ovmsDir, 'models', 'config.json')
|
||||||
|
const runBatPath = path.join(ovmsDir, 'run.bat')
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Check if config.json exists, if not create it with default content
|
||||||
|
if (!(await fs.pathExists(configPath))) {
|
||||||
|
logger.info(`Config file does not exist, creating: ${configPath}`)
|
||||||
|
|
||||||
|
// Ensure the models directory exists
|
||||||
|
await fs.ensureDir(path.dirname(configPath))
|
||||||
|
|
||||||
|
// Create config.json with default content
|
||||||
|
const defaultConfig = {
|
||||||
|
mediapipe_config_list: [],
|
||||||
|
model_config_list: []
|
||||||
|
}
|
||||||
|
|
||||||
|
await fs.writeJson(configPath, defaultConfig, { spaces: 2 })
|
||||||
|
logger.info(`Config file created: ${configPath}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if run.bat exists
|
||||||
|
if (!(await fs.pathExists(runBatPath))) {
|
||||||
|
logger.error(`run.bat not found at: ${runBatPath}`)
|
||||||
|
return { success: false, message: 'run.bat not found' }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run run.bat without waiting for it to complete
|
||||||
|
logger.info(`Starting OVMS with run.bat: ${runBatPath}`)
|
||||||
|
exec(`"${runBatPath}"`, { cwd: ovmsDir }, (error) => {
|
||||||
|
if (error) {
|
||||||
|
logger.error(`Error running run.bat: ${error}`)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.info('OVMS started successfully')
|
||||||
|
return { success: true }
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(`Failed to run OVMS: ${error}`)
|
||||||
|
return { success: false, message: 'Failed to run OVMS' }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get OVMS status - checks installation and running status
|
||||||
|
* @returns 'not-installed' | 'not-running' | 'running'
|
||||||
|
*/
|
||||||
|
public async getOvmsStatus(): Promise<'not-installed' | 'not-running' | 'running'> {
|
||||||
|
const homeDir = homedir()
|
||||||
|
const ovmsPath = path.join(homeDir, '.cherrystudio', 'ovms', 'ovms', 'ovms.exe')
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Check if OVMS executable exists
|
||||||
|
if (!(await fs.pathExists(ovmsPath))) {
|
||||||
|
logger.info(`OVMS executable not found at: ${ovmsPath}`)
|
||||||
|
return 'not-installed'
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if OVMS process is running
|
||||||
|
//const psCommand = `Get-Process -Name "ovms" -ErrorAction SilentlyContinue | Where-Object { $_.Path -eq "${ovmsPath.replace(/\\/g, '\\\\')}" } | Select-Object Id | ConvertTo-Json`;
|
||||||
|
//const { stdout } = await execAsync(`powershell -Command "${psCommand}"`);
|
||||||
|
const psCommand = `Get-Process -Name "ovms" -ErrorAction SilentlyContinue | Select-Object Id, Path | ConvertTo-Json`
|
||||||
|
const { stdout } = await execAsync(`powershell -Command "${psCommand}"`)
|
||||||
|
|
||||||
|
if (!stdout.trim()) {
|
||||||
|
logger.info('OVMS process not running')
|
||||||
|
return 'not-running'
|
||||||
|
}
|
||||||
|
|
||||||
|
const processes = JSON.parse(stdout)
|
||||||
|
const processList = Array.isArray(processes) ? processes : [processes]
|
||||||
|
|
||||||
|
if (processList.length > 0) {
|
||||||
|
logger.info('OVMS process is running')
|
||||||
|
return 'running'
|
||||||
|
} else {
|
||||||
|
logger.info('OVMS process not running')
|
||||||
|
return 'not-running'
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.info(`Failed to check OVMS status: ${error}`)
|
||||||
|
return 'not-running'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initialize OVMS by finding the executable path and working directory
|
||||||
|
*/
|
||||||
|
public async initializeOvms(): Promise<boolean> {
|
||||||
|
// Use PowerShell to find ovms.exe processes with their paths
|
||||||
|
const psCommand = `Get-Process -Name "ovms" -ErrorAction SilentlyContinue | Select-Object Id, Path | ConvertTo-Json`
|
||||||
|
const { stdout } = await execAsync(`powershell -Command "${psCommand}"`)
|
||||||
|
|
||||||
|
if (!stdout.trim()) {
|
||||||
|
logger.error('Command to find OVMS process returned no output')
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
logger.debug(`OVMS process output: ${stdout}`)
|
||||||
|
|
||||||
|
const processes = JSON.parse(stdout)
|
||||||
|
const processList = Array.isArray(processes) ? processes : [processes]
|
||||||
|
|
||||||
|
// Find the first process with a valid path
|
||||||
|
for (const process of processList) {
|
||||||
|
this.ovms = {
|
||||||
|
pid: process.Id,
|
||||||
|
path: process.Path,
|
||||||
|
workingDirectory: path.dirname(process.Path)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return this.ovms !== null
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if the Model Name and ID are valid, they are valid only if they are not used in the config.json
|
||||||
|
* @param modelName Name of the model to check
|
||||||
|
* @param modelId ID of the model to check
|
||||||
|
*/
|
||||||
|
public async isNameAndIDAvalid(modelName: string, modelId: string): Promise<boolean> {
|
||||||
|
if (!modelName || !modelId) {
|
||||||
|
logger.error('Model name and ID cannot be empty')
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
const homeDir = homedir()
|
||||||
|
const configPath = path.join(homeDir, '.cherrystudio', 'ovms', 'ovms', 'models', 'config.json')
|
||||||
|
try {
|
||||||
|
if (!(await fs.pathExists(configPath))) {
|
||||||
|
logger.warn(`Config file does not exist: ${configPath}`)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
const config: OvmsConfig = await fs.readJson(configPath)
|
||||||
|
if (!config.mediapipe_config_list) {
|
||||||
|
logger.warn(`No mediapipe_config_list found in config: ${configPath}`)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the model name or ID already exists in the config
|
||||||
|
const exists = config.mediapipe_config_list.some(
|
||||||
|
(model) => model.name === modelName || model.base_path === modelId
|
||||||
|
)
|
||||||
|
if (exists) {
|
||||||
|
logger.warn(`Model with name "${modelName}" or ID "${modelId}" already exists in the config`)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(`Failed to check model existence: ${error}`)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
private async applyModelPath(modelDirPath: string): Promise<boolean> {
|
||||||
|
const homeDir = homedir()
|
||||||
|
const patchDir = path.join(homeDir, '.cherrystudio', 'ovms', 'patch')
|
||||||
|
if (!(await fs.pathExists(patchDir))) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
const modelId = path.basename(modelDirPath)
|
||||||
|
|
||||||
|
// get all sub directories in patchDir
|
||||||
|
const patchs = await fs.readdir(patchDir)
|
||||||
|
for (const patch of patchs) {
|
||||||
|
const fullPatchPath = path.join(patchDir, patch)
|
||||||
|
|
||||||
|
if (fs.lstatSync(fullPatchPath).isDirectory()) {
|
||||||
|
if (modelId.toLowerCase().includes(patch.toLowerCase())) {
|
||||||
|
// copy all files from fullPath to modelDirPath
|
||||||
|
try {
|
||||||
|
const files = await fs.readdir(fullPatchPath)
|
||||||
|
for (const file of files) {
|
||||||
|
const srcFile = path.join(fullPatchPath, file)
|
||||||
|
const destFile = path.join(modelDirPath, file)
|
||||||
|
await fs.copyFile(srcFile, destFile)
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(`Failed to copy files from ${fullPatchPath} to ${modelDirPath}: ${error}`)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
logger.info(`Applied patchs for model ${modelId}`)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add a model to OVMS by downloading it
|
||||||
|
* @param modelName Name of the model to add
|
||||||
|
* @param modelId ID of the model to download
|
||||||
|
* @param modelSource Model Source: huggingface, hf-mirror and modelscope, default is huggingface
|
||||||
|
* @param task Task type: text_generation, embedding, rerank, image_generation
|
||||||
|
*/
|
||||||
|
public async addModel(
|
||||||
|
modelName: string,
|
||||||
|
modelId: string,
|
||||||
|
modelSource: string,
|
||||||
|
task: string = 'text_generation'
|
||||||
|
): Promise<{ success: boolean; message?: string }> {
|
||||||
|
logger.info(`Adding model: ${modelName} with ID: ${modelId}, Source: ${modelSource}, Task: ${task}`)
|
||||||
|
|
||||||
|
const homeDir = homedir()
|
||||||
|
const ovdndDir = path.join(homeDir, '.cherrystudio', 'ovms', 'ovms')
|
||||||
|
const pathModel = path.join(ovdndDir, 'models', modelId)
|
||||||
|
|
||||||
|
try {
|
||||||
|
// check the ovdnDir+'models'+modelId exist or not
|
||||||
|
if (await fs.pathExists(pathModel)) {
|
||||||
|
logger.error(`Model with ID ${modelId} already exists`)
|
||||||
|
return { success: false, message: 'Model ID already exists!' }
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove the model directory if it exists
|
||||||
|
if (await fs.pathExists(pathModel)) {
|
||||||
|
logger.info(`Removing existing model directory: ${pathModel}`)
|
||||||
|
await fs.remove(pathModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use ovdnd.exe for downloading instead of ovms.exe
|
||||||
|
const ovdndPath = path.join(ovdndDir, 'ovdnd.exe')
|
||||||
|
const command =
|
||||||
|
`"${ovdndPath}" --pull ` +
|
||||||
|
`--model_repository_path "${ovdndDir}/models" ` +
|
||||||
|
`--source_model "${modelId}" ` +
|
||||||
|
`--model_name "${modelName}" ` +
|
||||||
|
`--target_device GPU ` +
|
||||||
|
`--task ${task} ` +
|
||||||
|
`--overwrite_models`
|
||||||
|
|
||||||
|
const env: Record<string, string | undefined> = {
|
||||||
|
...process.env,
|
||||||
|
OVMS_DIR: ovdndDir,
|
||||||
|
PYTHONHOME: path.join(ovdndDir, 'python'),
|
||||||
|
PATH: `${process.env.PATH};${ovdndDir};${path.join(ovdndDir, 'python')}`
|
||||||
|
}
|
||||||
|
|
||||||
|
if (modelSource) {
|
||||||
|
env.HF_ENDPOINT = modelSource
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(`Running command: ${command} from ${modelSource}`)
|
||||||
|
const { stdout } = await execAsync(command, { env: env, cwd: ovdndDir })
|
||||||
|
|
||||||
|
logger.info('Model download completed')
|
||||||
|
logger.debug(`Command output: ${stdout}`)
|
||||||
|
} catch (error) {
|
||||||
|
// remove ovdnDir+'models'+modelId if it exists
|
||||||
|
if (await fs.pathExists(pathModel)) {
|
||||||
|
logger.info(`Removing failed model directory: ${pathModel}`)
|
||||||
|
await fs.remove(pathModel)
|
||||||
|
}
|
||||||
|
logger.error(`Failed to add model: ${error}`)
|
||||||
|
return {
|
||||||
|
success: false,
|
||||||
|
message: `Download model ${modelId} failed, please check following items and try it again:<p>- the model id</p><p>- network connection and proxy</p>`
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update config file
|
||||||
|
if (!(await this.updateModelConfig(modelName, modelId))) {
|
||||||
|
logger.error('Failed to update model config')
|
||||||
|
return { success: false, message: 'Failed to update model config' }
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!(await this.applyModelPath(pathModel))) {
|
||||||
|
logger.error('Failed to apply model patchs')
|
||||||
|
return { success: false, message: 'Failed to apply model patchs' }
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(`Model ${modelName} added successfully with ID ${modelId}`)
|
||||||
|
return { success: true }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stop the model download process if it's running
|
||||||
|
* @returns Promise<{ success: boolean; message?: string }>
|
||||||
|
*/
|
||||||
|
public async stopAddModel(): Promise<{ success: boolean; message?: string }> {
|
||||||
|
try {
|
||||||
|
// Check if ovdnd.exe process is running
|
||||||
|
const psCommand = `Get-Process -Name "ovdnd" -ErrorAction SilentlyContinue | Select-Object Id, Path | ConvertTo-Json`
|
||||||
|
const { stdout } = await execAsync(`powershell -Command "${psCommand}"`)
|
||||||
|
|
||||||
|
if (!stdout.trim()) {
|
||||||
|
logger.info('ovdnd process is not running')
|
||||||
|
return { success: true, message: 'Model download process is not running' }
|
||||||
|
}
|
||||||
|
|
||||||
|
const processes = JSON.parse(stdout)
|
||||||
|
const processList = Array.isArray(processes) ? processes : [processes]
|
||||||
|
|
||||||
|
if (processList.length === 0) {
|
||||||
|
logger.info('ovdnd process is not running')
|
||||||
|
return { success: true, message: 'Model download process is not running' }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Terminate all ovdnd processes
|
||||||
|
for (const process of processList) {
|
||||||
|
this.terminalProcess(process.Id)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info('Model download process stopped successfully')
|
||||||
|
return { success: true, message: 'Model download process stopped successfully' }
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(`Failed to stop model download process: ${error}`)
|
||||||
|
return { success: false, message: 'Failed to stop model download process' }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* check if the model id exists in the OVMS configuration
|
||||||
|
* @param modelId ID of the model to check
|
||||||
|
*/
|
||||||
|
public async checkModelExists(modelId: string): Promise<boolean> {
|
||||||
|
const homeDir = homedir()
|
||||||
|
const ovmsDir = path.join(homeDir, '.cherrystudio', 'ovms', 'ovms')
|
||||||
|
const configPath = path.join(ovmsDir, 'models', 'config.json')
|
||||||
|
|
||||||
|
try {
|
||||||
|
if (!(await fs.pathExists(configPath))) {
|
||||||
|
logger.warn(`Config file does not exist: ${configPath}`)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
const config: OvmsConfig = await fs.readJson(configPath)
|
||||||
|
if (!config.mediapipe_config_list) {
|
||||||
|
logger.warn('No mediapipe_config_list found in config')
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return config.mediapipe_config_list.some((model) => model.base_path === modelId)
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(`Failed to check model existence: ${error}`)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Update the model configuration file
|
||||||
|
*/
|
||||||
|
public async updateModelConfig(modelName: string, modelId: string): Promise<boolean> {
|
||||||
|
const homeDir = homedir()
|
||||||
|
const ovmsDir = path.join(homeDir, '.cherrystudio', 'ovms', 'ovms')
|
||||||
|
const configPath = path.join(ovmsDir, 'models', 'config.json')
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Ensure the models directory exists
|
||||||
|
await fs.ensureDir(path.dirname(configPath))
|
||||||
|
let config: OvmsConfig
|
||||||
|
|
||||||
|
// Read existing config or create new one
|
||||||
|
if (await fs.pathExists(configPath)) {
|
||||||
|
config = await fs.readJson(configPath)
|
||||||
|
} else {
|
||||||
|
config = { mediapipe_config_list: [] }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure mediapipe_config_list exists
|
||||||
|
if (!config.mediapipe_config_list) {
|
||||||
|
config.mediapipe_config_list = []
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new model config
|
||||||
|
const newModelConfig: ModelConfig = {
|
||||||
|
name: modelName,
|
||||||
|
base_path: modelId
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if model already exists, if so, update it
|
||||||
|
const existingIndex = config.mediapipe_config_list.findIndex((model) => model.base_path === modelId)
|
||||||
|
|
||||||
|
if (existingIndex >= 0) {
|
||||||
|
config.mediapipe_config_list[existingIndex] = newModelConfig
|
||||||
|
logger.info(`Updated existing model config: ${modelName}`)
|
||||||
|
} else {
|
||||||
|
config.mediapipe_config_list.push(newModelConfig)
|
||||||
|
logger.info(`Added new model config: ${modelName}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write config back to file
|
||||||
|
await fs.writeJson(configPath, config, { spaces: 2 })
|
||||||
|
logger.info(`Config file updated: ${configPath}`)
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(`Failed to update model config: ${error}`)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get all models from OVMS config, filtered for image generation models
|
||||||
|
* @returns Array of model configurations
|
||||||
|
*/
|
||||||
|
public async getModels(): Promise<ModelConfig[]> {
|
||||||
|
const homeDir = homedir()
|
||||||
|
const ovmsDir = path.join(homeDir, '.cherrystudio', 'ovms', 'ovms')
|
||||||
|
const configPath = path.join(ovmsDir, 'models', 'config.json')
|
||||||
|
|
||||||
|
try {
|
||||||
|
if (!(await fs.pathExists(configPath))) {
|
||||||
|
logger.warn(`Config file does not exist: ${configPath}`)
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
|
||||||
|
const config: OvmsConfig = await fs.readJson(configPath)
|
||||||
|
if (!config.mediapipe_config_list) {
|
||||||
|
logger.warn('No mediapipe_config_list found in config')
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter models for image generation (SD, Stable-Diffusion, Stable Diffusion, FLUX)
|
||||||
|
const imageGenerationModels = config.mediapipe_config_list.filter((model) => {
|
||||||
|
const modelName = model.name.toLowerCase()
|
||||||
|
return (
|
||||||
|
modelName.startsWith('sd') ||
|
||||||
|
modelName.startsWith('stable-diffusion') ||
|
||||||
|
modelName.startsWith('stable diffusion') ||
|
||||||
|
modelName.startsWith('flux')
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.info(`Found ${imageGenerationModels.length} image generation models`)
|
||||||
|
return imageGenerationModels
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(`Failed to get models: ${error}`)
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export default OvmsManager
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import { session, shell, webContents } from 'electron'
|
import { IpcChannel } from '@shared/IpcChannel'
|
||||||
|
import { app, session, shell, webContents } from 'electron'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* init the useragent of the webview session
|
* init the useragent of the webview session
|
||||||
@@ -36,3 +37,66 @@ export function setOpenLinkExternal(webviewId: number, isExternal: boolean) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const attachKeyboardHandler = (contents: Electron.WebContents) => {
|
||||||
|
if (contents.getType?.() !== 'webview') {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleBeforeInput = (event: Electron.Event, input: Electron.Input) => {
|
||||||
|
if (!input) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const key = input.key?.toLowerCase()
|
||||||
|
if (!key) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const isFindShortcut = (input.control || input.meta) && key === 'f'
|
||||||
|
const isEscape = key === 'escape'
|
||||||
|
const isEnter = key === 'enter'
|
||||||
|
|
||||||
|
if (!isFindShortcut && !isEscape && !isEnter) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const host = contents.hostWebContents
|
||||||
|
if (!host || host.isDestroyed()) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Always prevent Cmd/Ctrl+F to override the guest page's native find dialog
|
||||||
|
if (isFindShortcut) {
|
||||||
|
event.preventDefault()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the hotkey event to the renderer
|
||||||
|
// The renderer will decide whether to preventDefault for Escape and Enter
|
||||||
|
// based on whether the search bar is visible
|
||||||
|
host.send(IpcChannel.Webview_SearchHotkey, {
|
||||||
|
webviewId: contents.id,
|
||||||
|
key,
|
||||||
|
control: Boolean(input.control),
|
||||||
|
meta: Boolean(input.meta),
|
||||||
|
shift: Boolean(input.shift),
|
||||||
|
alt: Boolean(input.alt)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
contents.on('before-input-event', handleBeforeInput)
|
||||||
|
contents.once('destroyed', () => {
|
||||||
|
contents.removeListener('before-input-event', handleBeforeInput)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
export function initWebviewHotkeys() {
|
||||||
|
webContents.getAllWebContents().forEach((contents) => {
|
||||||
|
if (contents.isDestroyed()) return
|
||||||
|
attachKeyboardHandler(contents)
|
||||||
|
})
|
||||||
|
|
||||||
|
app.on('web-contents-created', (_, contents) => {
|
||||||
|
attachKeyboardHandler(contents)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -256,7 +256,7 @@ export class WindowService {
|
|||||||
|
|
||||||
private setupWebContentsHandlers(mainWindow: BrowserWindow) {
|
private setupWebContentsHandlers(mainWindow: BrowserWindow) {
|
||||||
mainWindow.webContents.on('will-navigate', (event, url) => {
|
mainWindow.webContents.on('will-navigate', (event, url) => {
|
||||||
if (url.includes('localhost:5173')) {
|
if (url.includes('localhost:517')) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -275,7 +275,8 @@ export class WindowService {
|
|||||||
'https://aihubmix.com/topup',
|
'https://aihubmix.com/topup',
|
||||||
'https://aihubmix.com/statistics',
|
'https://aihubmix.com/statistics',
|
||||||
'https://dash.302.ai/sso/login',
|
'https://dash.302.ai/sso/login',
|
||||||
'https://dash.302.ai/charge'
|
'https://dash.302.ai/charge',
|
||||||
|
'https://www.aiionly.com/login'
|
||||||
]
|
]
|
||||||
|
|
||||||
if (oauthProviderUrls.some((link) => url.startsWith(link))) {
|
if (oauthProviderUrls.some((link) => url.startsWith(link))) {
|
||||||
|
|||||||
@@ -0,0 +1,277 @@
|
|||||||
|
import { UpdateInfo } from 'builder-util-runtime'
|
||||||
|
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
|
// Mock dependencies
|
||||||
|
vi.mock('@logger', () => ({
|
||||||
|
loggerService: {
|
||||||
|
withContext: () => ({
|
||||||
|
info: vi.fn(),
|
||||||
|
error: vi.fn(),
|
||||||
|
warn: vi.fn()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('../ConfigManager', () => ({
|
||||||
|
configManager: {
|
||||||
|
getLanguage: vi.fn(),
|
||||||
|
getAutoUpdate: vi.fn(() => false),
|
||||||
|
getTestPlan: vi.fn(() => false),
|
||||||
|
getTestChannel: vi.fn(),
|
||||||
|
getClientId: vi.fn(() => 'test-client-id')
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('../WindowService', () => ({
|
||||||
|
windowService: {
|
||||||
|
getMainWindow: vi.fn()
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@main/constant', () => ({
|
||||||
|
isWin: false
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@main/utils/ipService', () => ({
|
||||||
|
getIpCountry: vi.fn(() => 'US')
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@main/utils/locales', () => ({
|
||||||
|
locales: {
|
||||||
|
en: { translation: { update: {} } },
|
||||||
|
'zh-CN': { translation: { update: {} } }
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@main/utils/systemInfo', () => ({
|
||||||
|
generateUserAgent: vi.fn(() => 'test-user-agent')
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('electron', () => ({
|
||||||
|
app: {
|
||||||
|
isPackaged: true,
|
||||||
|
getVersion: vi.fn(() => '1.0.0'),
|
||||||
|
getPath: vi.fn(() => '/test/path')
|
||||||
|
},
|
||||||
|
dialog: {
|
||||||
|
showMessageBox: vi.fn()
|
||||||
|
},
|
||||||
|
BrowserWindow: vi.fn(),
|
||||||
|
net: {
|
||||||
|
fetch: vi.fn()
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('electron-updater', () => ({
|
||||||
|
autoUpdater: {
|
||||||
|
logger: null,
|
||||||
|
forceDevUpdateConfig: false,
|
||||||
|
autoDownload: false,
|
||||||
|
autoInstallOnAppQuit: false,
|
||||||
|
requestHeaders: {},
|
||||||
|
on: vi.fn(),
|
||||||
|
setFeedURL: vi.fn(),
|
||||||
|
checkForUpdates: vi.fn(),
|
||||||
|
downloadUpdate: vi.fn(),
|
||||||
|
quitAndInstall: vi.fn(),
|
||||||
|
channel: '',
|
||||||
|
allowDowngrade: false,
|
||||||
|
disableDifferentialDownload: false,
|
||||||
|
currentVersion: '1.0.0'
|
||||||
|
},
|
||||||
|
Logger: vi.fn(),
|
||||||
|
NsisUpdater: vi.fn(),
|
||||||
|
AppUpdater: vi.fn()
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Import after mocks
|
||||||
|
import AppUpdater from '../AppUpdater'
|
||||||
|
import { configManager } from '../ConfigManager'
|
||||||
|
|
||||||
|
describe('AppUpdater', () => {
|
||||||
|
let appUpdater: AppUpdater
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
appUpdater = new AppUpdater()
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('parseMultiLangReleaseNotes', () => {
|
||||||
|
const sampleReleaseNotes = `<!--LANG:en-->
|
||||||
|
🚀 New Features:
|
||||||
|
- Feature A
|
||||||
|
- Feature B
|
||||||
|
|
||||||
|
🎨 UI Improvements:
|
||||||
|
- Improvement A
|
||||||
|
<!--LANG:zh-CN-->
|
||||||
|
🚀 新功能:
|
||||||
|
- 功能 A
|
||||||
|
- 功能 B
|
||||||
|
|
||||||
|
🎨 界面改进:
|
||||||
|
- 改进 A
|
||||||
|
<!--LANG:END-->`
|
||||||
|
|
||||||
|
it('should return Chinese notes for zh-CN users', () => {
|
||||||
|
vi.mocked(configManager.getLanguage).mockReturnValue('zh-CN')
|
||||||
|
|
||||||
|
const result = (appUpdater as any).parseMultiLangReleaseNotes(sampleReleaseNotes)
|
||||||
|
|
||||||
|
expect(result).toContain('新功能')
|
||||||
|
expect(result).toContain('功能 A')
|
||||||
|
expect(result).not.toContain('New Features')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return Chinese notes for zh-TW users', () => {
|
||||||
|
vi.mocked(configManager.getLanguage).mockReturnValue('zh-TW')
|
||||||
|
|
||||||
|
const result = (appUpdater as any).parseMultiLangReleaseNotes(sampleReleaseNotes)
|
||||||
|
|
||||||
|
expect(result).toContain('新功能')
|
||||||
|
expect(result).toContain('功能 A')
|
||||||
|
expect(result).not.toContain('New Features')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return English notes for non-Chinese users', () => {
|
||||||
|
vi.mocked(configManager.getLanguage).mockReturnValue('en-US')
|
||||||
|
|
||||||
|
const result = (appUpdater as any).parseMultiLangReleaseNotes(sampleReleaseNotes)
|
||||||
|
|
||||||
|
expect(result).toContain('New Features')
|
||||||
|
expect(result).toContain('Feature A')
|
||||||
|
expect(result).not.toContain('新功能')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return English notes for other language users', () => {
|
||||||
|
vi.mocked(configManager.getLanguage).mockReturnValue('ru-RU')
|
||||||
|
|
||||||
|
const result = (appUpdater as any).parseMultiLangReleaseNotes(sampleReleaseNotes)
|
||||||
|
|
||||||
|
expect(result).toContain('New Features')
|
||||||
|
expect(result).not.toContain('新功能')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle missing language sections gracefully', () => {
|
||||||
|
const malformedNotes = 'Simple release notes without markers'
|
||||||
|
|
||||||
|
const result = (appUpdater as any).parseMultiLangReleaseNotes(malformedNotes)
|
||||||
|
|
||||||
|
expect(result).toBe('Simple release notes without markers')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle malformed markers', () => {
|
||||||
|
const malformedNotes = `<!--LANG:en-->English only`
|
||||||
|
vi.mocked(configManager.getLanguage).mockReturnValue('zh-CN')
|
||||||
|
|
||||||
|
const result = (appUpdater as any).parseMultiLangReleaseNotes(malformedNotes)
|
||||||
|
|
||||||
|
// Should clean up markers and return cleaned content
|
||||||
|
expect(result).toContain('English only')
|
||||||
|
expect(result).not.toContain('<!--LANG:')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle empty release notes', () => {
|
||||||
|
const result = (appUpdater as any).parseMultiLangReleaseNotes('')
|
||||||
|
|
||||||
|
expect(result).toBe('')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle errors gracefully', () => {
|
||||||
|
// Force an error by mocking configManager to throw
|
||||||
|
vi.mocked(configManager.getLanguage).mockImplementation(() => {
|
||||||
|
throw new Error('Test error')
|
||||||
|
})
|
||||||
|
|
||||||
|
const result = (appUpdater as any).parseMultiLangReleaseNotes(sampleReleaseNotes)
|
||||||
|
|
||||||
|
// Should return original notes as fallback
|
||||||
|
expect(result).toBe(sampleReleaseNotes)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('hasMultiLanguageMarkers', () => {
|
||||||
|
it('should return true when markers are present', () => {
|
||||||
|
const notes = '<!--LANG:en-->Test'
|
||||||
|
|
||||||
|
const result = (appUpdater as any).hasMultiLanguageMarkers(notes)
|
||||||
|
|
||||||
|
expect(result).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return false when no markers are present', () => {
|
||||||
|
const notes = 'Simple text without markers'
|
||||||
|
|
||||||
|
const result = (appUpdater as any).hasMultiLanguageMarkers(notes)
|
||||||
|
|
||||||
|
expect(result).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('processReleaseInfo', () => {
|
||||||
|
it('should process multi-language release notes in string format', () => {
|
||||||
|
vi.mocked(configManager.getLanguage).mockReturnValue('zh-CN')
|
||||||
|
|
||||||
|
const releaseInfo = {
|
||||||
|
version: '1.0.0',
|
||||||
|
files: [],
|
||||||
|
path: '',
|
||||||
|
sha512: '',
|
||||||
|
releaseDate: new Date().toISOString(),
|
||||||
|
releaseNotes: `<!--LANG:en-->English notes<!--LANG:zh-CN-->中文说明<!--LANG:END-->`
|
||||||
|
} as UpdateInfo
|
||||||
|
|
||||||
|
const result = (appUpdater as any).processReleaseInfo(releaseInfo)
|
||||||
|
|
||||||
|
expect(result.releaseNotes).toBe('中文说明')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should not process release notes without markers', () => {
|
||||||
|
const releaseInfo = {
|
||||||
|
version: '1.0.0',
|
||||||
|
files: [],
|
||||||
|
path: '',
|
||||||
|
sha512: '',
|
||||||
|
releaseDate: new Date().toISOString(),
|
||||||
|
releaseNotes: 'Simple release notes'
|
||||||
|
} as UpdateInfo
|
||||||
|
|
||||||
|
const result = (appUpdater as any).processReleaseInfo(releaseInfo)
|
||||||
|
|
||||||
|
expect(result.releaseNotes).toBe('Simple release notes')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle array format release notes', () => {
|
||||||
|
const releaseInfo = {
|
||||||
|
version: '1.0.0',
|
||||||
|
files: [],
|
||||||
|
path: '',
|
||||||
|
sha512: '',
|
||||||
|
releaseDate: new Date().toISOString(),
|
||||||
|
releaseNotes: [
|
||||||
|
{ version: '1.0.0', note: 'Note 1' },
|
||||||
|
{ version: '1.0.1', note: 'Note 2' }
|
||||||
|
]
|
||||||
|
} as UpdateInfo
|
||||||
|
|
||||||
|
const result = (appUpdater as any).processReleaseInfo(releaseInfo)
|
||||||
|
|
||||||
|
expect(result.releaseNotes).toEqual(releaseInfo.releaseNotes)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle null release notes', () => {
|
||||||
|
const releaseInfo = {
|
||||||
|
version: '1.0.0',
|
||||||
|
files: [],
|
||||||
|
path: '',
|
||||||
|
sha512: '',
|
||||||
|
releaseDate: new Date().toISOString(),
|
||||||
|
releaseNotes: null
|
||||||
|
} as UpdateInfo
|
||||||
|
|
||||||
|
const result = (appUpdater as any).processReleaseInfo(releaseInfo)
|
||||||
|
|
||||||
|
expect(result.releaseNotes).toBeNull()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -1,341 +0,0 @@
|
|||||||
# Agent Message Architecture Design Document
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
This document describes the architecture for handling agent messages in Cherry Studio, including how agent-specific messages are generated, transformed to AI SDK format, stored, and sent to the UI. The system is designed to be agent-agnostic, allowing multiple agent types (Claude Code, OpenAI, etc.) to integrate seamlessly.
|
|
||||||
|
|
||||||
## Core Design Principles
|
|
||||||
|
|
||||||
1. **Agent Agnosticism**: The core message handling system should work with any agent type without modification
|
|
||||||
2. **Data Preservation**: All raw agent data must be preserved alongside transformed UI-friendly formats
|
|
||||||
3. **Streaming First**: Support real-time streaming of agent responses to the UI
|
|
||||||
4. **Type Safety**: Strong TypeScript interfaces ensure consistency across the pipeline
|
|
||||||
|
|
||||||
## Architecture Components
|
|
||||||
|
|
||||||
### 1. Agent Service Layer
|
|
||||||
|
|
||||||
Each agent (e.g., ClaudeCodeService) implements the `AgentServiceInterface`:
|
|
||||||
|
|
||||||
```typescript
|
|
||||||
interface AgentServiceInterface {
|
|
||||||
invoke(prompt: string, cwd: string, sessionId?: string, options?: any): AgentStream
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Responsibilities:
|
|
||||||
- Spawn and manage agent-specific processes (e.g., Claude Code CLI)
|
|
||||||
- Parse agent-specific output formats (e.g., SDKMessage for Claude Code)
|
|
||||||
- Transform agent messages to AI SDK format
|
|
||||||
- Emit standardized `AgentStreamEvent` objects
|
|
||||||
|
|
||||||
### 2. Agent Stream Events
|
|
||||||
|
|
||||||
The standardized event interface that all agents emit:
|
|
||||||
|
|
||||||
```typescript
|
|
||||||
interface AgentStreamEvent {
|
|
||||||
type: 'chunk' | 'error' | 'complete'
|
|
||||||
chunk?: UIMessageChunk // AI SDK format for UI
|
|
||||||
rawAgentMessage?: any // Agent-specific raw message
|
|
||||||
error?: Error
|
|
||||||
agentResult?: any // Complete agent-specific result
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. Session Message Service
|
|
||||||
|
|
||||||
The `SessionMessageService` acts as the orchestration layer:
|
|
||||||
|
|
||||||
#### Responsibilities:
|
|
||||||
- Manages session lifecycle and persistence
|
|
||||||
- Collects streaming chunks and raw agent messages
|
|
||||||
- Stores structured data in the database
|
|
||||||
- Forwards events to the API layer
|
|
||||||
|
|
||||||
### 4. Database Storage
|
|
||||||
|
|
||||||
Session messages are stored with complete structured data:
|
|
||||||
|
|
||||||
```typescript
|
|
||||||
interface SessionMessageContent {
|
|
||||||
aiSDKChunks: UIMessageChunk[] // UI-friendly format
|
|
||||||
rawAgentMessages: any[] // Original agent messages
|
|
||||||
agentResult?: any // Complete agent result
|
|
||||||
agentType: string // Agent identifier
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Data Flow
|
|
||||||
|
|
||||||
```mermaid
|
|
||||||
graph TD
|
|
||||||
A[User Input] --> B[API Handler]
|
|
||||||
B --> C[SessionMessageService]
|
|
||||||
C --> D[Agent Service]
|
|
||||||
D --> E[Agent Process]
|
|
||||||
E --> F[Raw Agent Output]
|
|
||||||
F --> G[Transform to AI SDK]
|
|
||||||
G --> H[Emit AgentStreamEvent]
|
|
||||||
H --> I[SessionMessageService]
|
|
||||||
I --> J[Store in Database]
|
|
||||||
I --> K[Forward to Client]
|
|
||||||
K --> L[UI Rendering]
|
|
||||||
```
|
|
||||||
|
|
||||||
## Message Transformation Process
|
|
||||||
|
|
||||||
### Step 1: Raw Agent Message Generation
|
|
||||||
|
|
||||||
Each agent generates messages in its native format:
|
|
||||||
|
|
||||||
**Claude Code Example:**
|
|
||||||
```typescript
|
|
||||||
// SDKMessage from Claude Code CLI
|
|
||||||
{
|
|
||||||
type: 'assistant',
|
|
||||||
uuid: 'msg_123',
|
|
||||||
session_id: 'session_456',
|
|
||||||
message: {
|
|
||||||
role: 'assistant',
|
|
||||||
content: [
|
|
||||||
{ type: 'text', text: 'Hello, I can help...' },
|
|
||||||
{ type: 'tool_use', id: 'tool_1', name: 'read_file', input: {...} }
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### Step 2: Transformation to AI SDK Format
|
|
||||||
|
|
||||||
The agent service transforms native messages to AI SDK `UIMessageChunk`:
|
|
||||||
|
|
||||||
```typescript
|
|
||||||
// In ClaudeCodeService
|
|
||||||
const emitChunks = (sdkMessage: SDKMessage) => {
|
|
||||||
// Transform to AI SDK format
|
|
||||||
const chunks = transformSDKMessageToUIChunk(sdkMessage)
|
|
||||||
|
|
||||||
for (const chunk of chunks) {
|
|
||||||
stream.emit('data', {
|
|
||||||
type: 'chunk',
|
|
||||||
chunk, // AI SDK format
|
|
||||||
rawAgentMessage: sdkMessage // Preserve original
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
**Transformed AI SDK Chunk:**
|
|
||||||
```typescript
|
|
||||||
{
|
|
||||||
type: 'text-delta',
|
|
||||||
id: 'msg_123',
|
|
||||||
delta: 'Hello, I can help...',
|
|
||||||
providerMetadata: {
|
|
||||||
claudeCode: {
|
|
||||||
originalSDKMessage: {...},
|
|
||||||
uuid: 'msg_123',
|
|
||||||
session_id: 'session_456'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### Step 3: Session Message Processing
|
|
||||||
|
|
||||||
The SessionMessageService collects and processes events:
|
|
||||||
|
|
||||||
```typescript
|
|
||||||
// Collect streaming data
|
|
||||||
const streamedChunks: UIMessageChunk[] = []
|
|
||||||
const rawAgentMessages: any[] = []
|
|
||||||
|
|
||||||
claudeStream.on('data', async (event: AgentStreamEvent) => {
|
|
||||||
switch (event.type) {
|
|
||||||
case 'chunk':
|
|
||||||
streamedChunks.push(event.chunk)
|
|
||||||
if (event.rawAgentMessage) {
|
|
||||||
rawAgentMessages.push(event.rawAgentMessage)
|
|
||||||
}
|
|
||||||
// Forward to client
|
|
||||||
sessionStream.emit('data', { type: 'chunk', chunk: event.chunk })
|
|
||||||
break
|
|
||||||
|
|
||||||
case 'complete':
|
|
||||||
// Store complete structured data
|
|
||||||
const content = {
|
|
||||||
aiSDKChunks: streamedChunks,
|
|
||||||
rawAgentMessages: rawAgentMessages,
|
|
||||||
agentResult: event.agentResult,
|
|
||||||
agentType: event.agentResult?.agentType || 'unknown'
|
|
||||||
}
|
|
||||||
// Save to database...
|
|
||||||
break
|
|
||||||
}
|
|
||||||
})
|
|
||||||
```
|
|
||||||
|
|
||||||
### Step 4: Client Streaming
|
|
||||||
|
|
||||||
The API handler converts events to Server-Sent Events (SSE):
|
|
||||||
|
|
||||||
```typescript
|
|
||||||
// In API handler
|
|
||||||
messageStream.on('data', (event: any) => {
|
|
||||||
switch (event.type) {
|
|
||||||
case 'chunk':
|
|
||||||
// Send AI SDK chunk as SSE
|
|
||||||
res.write(`data: ${JSON.stringify(event.chunk)}\n\n`)
|
|
||||||
break
|
|
||||||
case 'complete':
|
|
||||||
res.write('data: [DONE]\n\n')
|
|
||||||
res.end()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
})
|
|
||||||
```
|
|
||||||
|
|
||||||
## Adding New Agent Types
|
|
||||||
|
|
||||||
To add support for a new agent (e.g., OpenAI):
|
|
||||||
|
|
||||||
### 1. Create Agent Service
|
|
||||||
|
|
||||||
```typescript
|
|
||||||
class OpenAIService implements AgentServiceInterface {
|
|
||||||
invokeStream(prompt: string, cwd: string, sessionId?: string, options?: any): AgentStream {
|
|
||||||
const stream = new OpenAIStream()
|
|
||||||
|
|
||||||
// Call OpenAI API
|
|
||||||
const openaiResponse = await openai.chat.completions.create({
|
|
||||||
messages: [{ role: 'user', content: prompt }],
|
|
||||||
stream: true
|
|
||||||
})
|
|
||||||
|
|
||||||
// Transform OpenAI format to AI SDK
|
|
||||||
for await (const chunk of openaiResponse) {
|
|
||||||
const aiSDKChunk = transformOpenAIToAISDK(chunk)
|
|
||||||
stream.emit('data', {
|
|
||||||
type: 'chunk',
|
|
||||||
chunk: aiSDKChunk,
|
|
||||||
rawAgentMessage: chunk // Preserve OpenAI format
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return stream
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. Create Transform Function
|
|
||||||
|
|
||||||
```typescript
|
|
||||||
function transformOpenAIToAISDK(openaiChunk: OpenAIChunk): UIMessageChunk {
|
|
||||||
return {
|
|
||||||
type: 'text-delta',
|
|
||||||
id: openaiChunk.id,
|
|
||||||
delta: openaiChunk.choices[0].delta.content,
|
|
||||||
providerMetadata: {
|
|
||||||
openai: {
|
|
||||||
original: openaiChunk,
|
|
||||||
model: openaiChunk.model
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. Register Agent Type
|
|
||||||
|
|
||||||
Update the agent type enum and factory:
|
|
||||||
|
|
||||||
```typescript
|
|
||||||
export type AgentType = 'claude-code' | 'openai' | 'anthropic-api'
|
|
||||||
|
|
||||||
function createAgentService(type: AgentType): AgentServiceInterface {
|
|
||||||
switch (type) {
|
|
||||||
case 'claude-code':
|
|
||||||
return new ClaudeCodeService()
|
|
||||||
case 'openai':
|
|
||||||
return new OpenAIService()
|
|
||||||
// ...
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Benefits of This Architecture
|
|
||||||
|
|
||||||
1. **Extensibility**: Easy to add new agent types without modifying core logic
|
|
||||||
2. **Data Integrity**: Raw agent data is never lost during transformation
|
|
||||||
3. **Debugging**: Complete message history available for troubleshooting
|
|
||||||
4. **Performance**: Streaming support for real-time responses
|
|
||||||
5. **Type Safety**: Strong interfaces prevent runtime errors
|
|
||||||
6. **UI Consistency**: All agents provide data in standard AI SDK format
|
|
||||||
|
|
||||||
## Key Interfaces Reference
|
|
||||||
|
|
||||||
### AgentStreamEvent
|
|
||||||
```typescript
|
|
||||||
interface AgentStreamEvent {
|
|
||||||
type: 'chunk' | 'error' | 'complete'
|
|
||||||
chunk?: UIMessageChunk
|
|
||||||
rawAgentMessage?: any
|
|
||||||
error?: Error
|
|
||||||
agentResult?: any
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### SessionMessageEntity
|
|
||||||
```typescript
|
|
||||||
interface SessionMessageEntity {
|
|
||||||
id: number
|
|
||||||
session_id: string
|
|
||||||
parent_id?: number
|
|
||||||
role: 'user' | 'assistant' | 'system' | 'tool'
|
|
||||||
type: string
|
|
||||||
content: string | SessionMessageContent
|
|
||||||
metadata?: Record<string, any>
|
|
||||||
created_at: string
|
|
||||||
updated_at: string
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### SessionMessageContent
|
|
||||||
```typescript
|
|
||||||
interface SessionMessageContent {
|
|
||||||
aiSDKChunks: UIMessageChunk[]
|
|
||||||
rawAgentMessages: any[]
|
|
||||||
agentResult?: any
|
|
||||||
agentType: string
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Testing Strategy
|
|
||||||
|
|
||||||
### Unit Tests
|
|
||||||
- Test each transform function independently
|
|
||||||
- Verify event emission sequences
|
|
||||||
- Validate data structure preservation
|
|
||||||
|
|
||||||
### Integration Tests
|
|
||||||
- Test complete flow from input to database
|
|
||||||
- Verify streaming behavior
|
|
||||||
- Test error handling and recovery
|
|
||||||
|
|
||||||
### Agent-Specific Tests
|
|
||||||
- Validate agent-specific transformations
|
|
||||||
- Test edge cases for each agent type
|
|
||||||
- Verify metadata preservation
|
|
||||||
|
|
||||||
## Future Enhancements
|
|
||||||
|
|
||||||
1. **Message Replay**: Ability to replay sessions from stored raw messages
|
|
||||||
2. **Format Migration**: Tools to migrate between agent formats
|
|
||||||
3. **Analytics**: Aggregate metrics from raw agent data
|
|
||||||
4. **Caching**: Cache transformed chunks for performance
|
|
||||||
5. **Compression**: Compress raw messages for storage efficiency
|
|
||||||
|
|
||||||
## Conclusion
|
|
||||||
|
|
||||||
This architecture provides a robust, extensible foundation for handling messages from multiple AI agents while maintaining data integrity and providing a consistent interface for the UI. The separation of concerns between agent-specific logic and core message handling ensures the system can evolve to support new agents and features without breaking existing functionality.
|
|
||||||
@@ -1,7 +1,8 @@
|
|||||||
import { type Client, createClient } from '@libsql/client'
|
import { type Client, createClient } from '@libsql/client'
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
|
import { mcpApiService } from '@main/apiServer/services/mcp'
|
||||||
import { ModelValidationError, validateModelId } from '@main/apiServer/utils'
|
import { ModelValidationError, validateModelId } from '@main/apiServer/utils'
|
||||||
import { AgentType, objectKeys, Provider } from '@types'
|
import { AgentType, MCPTool, objectKeys, SlashCommand, Tool } from '@types'
|
||||||
import { drizzle, type LibSQLDatabase } from 'drizzle-orm/libsql'
|
import { drizzle, type LibSQLDatabase } from 'drizzle-orm/libsql'
|
||||||
import fs from 'fs'
|
import fs from 'fs'
|
||||||
import path from 'path'
|
import path from 'path'
|
||||||
@@ -10,6 +11,8 @@ import { MigrationService } from './database/MigrationService'
|
|||||||
import * as schema from './database/schema'
|
import * as schema from './database/schema'
|
||||||
import { dbPath } from './drizzle.config'
|
import { dbPath } from './drizzle.config'
|
||||||
import { AgentModelField, AgentModelValidationError } from './errors'
|
import { AgentModelField, AgentModelValidationError } from './errors'
|
||||||
|
import { builtinSlashCommands } from './services/claudecode/commands'
|
||||||
|
import { builtinTools } from './services/claudecode/tools'
|
||||||
|
|
||||||
const logger = loggerService.withContext('BaseService')
|
const logger = loggerService.withContext('BaseService')
|
||||||
|
|
||||||
@@ -30,7 +33,7 @@ export abstract class BaseService {
|
|||||||
protected static db: LibSQLDatabase<typeof schema> | null = null
|
protected static db: LibSQLDatabase<typeof schema> | null = null
|
||||||
protected static isInitialized = false
|
protected static isInitialized = false
|
||||||
protected static initializationPromise: Promise<void> | null = null
|
protected static initializationPromise: Promise<void> | null = null
|
||||||
protected jsonFields: string[] = ['built_in_tools', 'mcps', 'configuration', 'accessible_paths', 'allowed_tools']
|
protected jsonFields: string[] = ['tools', 'mcps', 'configuration', 'accessible_paths', 'allowed_tools']
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Initialize database with retry logic and proper error handling
|
* Initialize database with retry logic and proper error handling
|
||||||
@@ -49,6 +52,45 @@ export abstract class BaseService {
|
|||||||
return BaseService.initializationPromise
|
return BaseService.initializationPromise
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public async listMcpTools(agentType: AgentType, ids?: string[]): Promise<Tool[]> {
|
||||||
|
const tools: Tool[] = []
|
||||||
|
if (agentType === 'claude-code') {
|
||||||
|
tools.push(...builtinTools)
|
||||||
|
}
|
||||||
|
if (ids && ids.length > 0) {
|
||||||
|
for (const id of ids) {
|
||||||
|
try {
|
||||||
|
const server = await mcpApiService.getServerInfo(id)
|
||||||
|
if (server) {
|
||||||
|
server.tools.forEach((tool: MCPTool) => {
|
||||||
|
tools.push({
|
||||||
|
id: `mcp_${id}_${tool.name}`,
|
||||||
|
name: tool.name,
|
||||||
|
type: 'mcp',
|
||||||
|
description: tool.description || '',
|
||||||
|
requirePermissions: true
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.warn('Failed to list MCP tools', {
|
||||||
|
id,
|
||||||
|
error: error as Error
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tools
|
||||||
|
}
|
||||||
|
|
||||||
|
public async listSlashCommands(agentType: AgentType): Promise<SlashCommand[]> {
|
||||||
|
if (agentType === 'claude-code') {
|
||||||
|
return builtinSlashCommands
|
||||||
|
}
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
|
||||||
private static async performInitialization(): Promise<void> {
|
private static async performInitialization(): Promise<void> {
|
||||||
const maxRetries = 3
|
const maxRetries = 3
|
||||||
let lastError: Error
|
let lastError: Error
|
||||||
@@ -271,23 +313,6 @@ export abstract class BaseService {
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// different agent types may have different provider requirements
|
|
||||||
const agentTypeProviderRequirements: Record<AgentType, Provider['type']> = {
|
|
||||||
'claude-code': 'anthropic'
|
|
||||||
}
|
|
||||||
for (const [ak, pk] of Object.entries(agentTypeProviderRequirements)) {
|
|
||||||
if (agentType === ak && validation.provider.type !== pk) {
|
|
||||||
throw new AgentModelValidationError(
|
|
||||||
{ agentType, field, model: modelValue },
|
|
||||||
{
|
|
||||||
type: 'unsupported_provider_type',
|
|
||||||
message: `Provider type '${validation.provider.type}' is not supported for agent type '${agentType}'. Expected '${pk}'`,
|
|
||||||
code: 'unsupported_provider_type'
|
|
||||||
}
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,35 +0,0 @@
|
|||||||
# Agents Service Refactor TODO (interface-level)
|
|
||||||
|
|
||||||
- [x] **SessionMessageService.createSessionMessage**
|
|
||||||
- Replace the current `EventEmitter` that emits `UIMessageChunk` with a readable stream of `TextStreamPart` objects (same shape produced by `/api/messages` in `messageThunk`).
|
|
||||||
- Update `startSessionMessageStream` to call a new adapter (`claudeToTextStreamPart(chunk)`) that maps Claude Code chunk payloads to `{ type: 'text-delta' | 'tool-call' | ... }` parts used by `AiSdkToChunkAdapter`.
|
|
||||||
- Add a secondary return value (promise) resolving to the persisted `ModelMessage[]` once streaming completes, so the renderer thunk can await save confirmation.
|
|
||||||
|
|
||||||
- [x] **main -> renderer transport**
|
|
||||||
- Update the existing SSE handler in `src/main/apiServer/routes/agents/handlers/messages.ts` (e.g., `createMessage`) to forward the new `TextStreamPart` stream over HTTP, preserving the current agent endpoint contract.
|
|
||||||
- Keep abort handling compatible with the current HTTP server (honor `AbortController` on the request to terminate the stream).
|
|
||||||
|
|
||||||
- [x] **renderer thunk integration**
|
|
||||||
- Introduce a thin IPC contract (e.g., `AgentMessagePersistence`) surfaced by `src/main/services/agents/database/index.ts` so the renderer thunk can request session-message writes without going through `SessionMessageService`.
|
|
||||||
- Define explicit entry points on the main side:
|
|
||||||
- `persistUserMessage({ sessionId, agentSessionId, payload, createdAt?, metadata? })`
|
|
||||||
- `persistAssistantMessage({ sessionId, agentSessionId, payload, createdAt?, metadata? })`
|
|
||||||
- `persistExchange({ sessionId, agentSessionId, user, assistant })` which runs the above in a single transaction and returns both records.
|
|
||||||
- Export these helpers via an `agentMessageRepository` object so both IPC handlers and legacy services share the same persistence path.
|
|
||||||
- Normalize persisted payloads to `{ message, blocks }` matching the renderer schema instead of AI-SDK `ModelMessage` chunks.
|
|
||||||
- Extend `messageThunk.sendMessage` to call the agent transport when the topic corresponds to a session, pipe chunks through `createStreamProcessor` + `AiSdkToChunkAdapter`, and invoke the new persistence interface once streaming resolves.
|
|
||||||
- Replace `useSession().createSessionMessage` optimistic insert with dispatching the thunk so Redux/Dexie persistence happens via the shared save helpers.
|
|
||||||
|
|
||||||
- [x] **persistence alignment**
|
|
||||||
- Remove `persistUserMessage` / `persistAssistantMessage` calls from `SessionMessageService`; instead expose a `SessionMessageRepository` in `main` that the thunk invokes via existing Dexie helpers.
|
|
||||||
- On renderer side, persist agent exchanges via IPC after streaming completes, storing `{ message, blocks }` payloads while skipping Dexie writes for agent sessions so the single source of truth remains `session_messages`.
|
|
||||||
|
|
||||||
- [x] **Blocks renderer**
|
|
||||||
- Replace `AgentSessionMessages` simple `<div>` render with the shared `Blocks` component (`src/renderer/src/pages/home/Messages/Blocks`) wired to the Redux store.
|
|
||||||
- Adjust `useSession` to only fetch metadata (e.g., session info) and rely on store selectors for message list.
|
|
||||||
|
|
||||||
- [x] **API client clean-up**
|
|
||||||
- Remove `AgentApiClient.createMessage` direct POST once thunk is in place; calls should go through renderer thunk -> stream -> final persistence.
|
|
||||||
|
|
||||||
- [ ] **Regression tests**
|
|
||||||
- Add integration test to assert agent sessions render incremental text the same way as standard assistant messages.
|
|
||||||
@@ -7,9 +7,10 @@ import type {
|
|||||||
AgentPersistedMessage,
|
AgentPersistedMessage,
|
||||||
AgentSessionMessageEntity
|
AgentSessionMessageEntity
|
||||||
} from '@types'
|
} from '@types'
|
||||||
|
import { and, asc, eq } from 'drizzle-orm'
|
||||||
|
|
||||||
import { BaseService } from '../BaseService'
|
import { BaseService } from '../BaseService'
|
||||||
import type { InsertSessionMessageRow } from './schema'
|
import type { InsertSessionMessageRow, SessionMessageRow } from './schema'
|
||||||
import { sessionMessagesTable } from './schema'
|
import { sessionMessagesTable } from './schema'
|
||||||
|
|
||||||
const logger = loggerService.withContext('AgentMessageRepository')
|
const logger = loggerService.withContext('AgentMessageRepository')
|
||||||
@@ -90,19 +91,86 @@ class AgentMessageRepository extends BaseService {
|
|||||||
return tx ?? this.database
|
return tx ?? this.database
|
||||||
}
|
}
|
||||||
|
|
||||||
async persistUserMessage(params: PersistUserMessageParams): Promise<AgentSessionMessageEntity> {
|
private async findExistingMessageRow(
|
||||||
|
writer: TxClient,
|
||||||
|
sessionId: string,
|
||||||
|
role: string,
|
||||||
|
messageId: string
|
||||||
|
): Promise<SessionMessageRow | null> {
|
||||||
|
const candidateRows: SessionMessageRow[] = await writer
|
||||||
|
.select()
|
||||||
|
.from(sessionMessagesTable)
|
||||||
|
.where(and(eq(sessionMessagesTable.session_id, sessionId), eq(sessionMessagesTable.role, role)))
|
||||||
|
.orderBy(asc(sessionMessagesTable.created_at))
|
||||||
|
|
||||||
|
for (const row of candidateRows) {
|
||||||
|
if (!row?.content) continue
|
||||||
|
|
||||||
|
try {
|
||||||
|
const parsed = JSON.parse(row.content) as AgentPersistedMessage | undefined
|
||||||
|
if (parsed?.message?.id === messageId) {
|
||||||
|
return row
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.warn('Failed to parse session message content JSON during lookup', error as Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
private async upsertMessage(
|
||||||
|
params: PersistUserMessageParams | PersistAssistantMessageParams
|
||||||
|
): Promise<AgentSessionMessageEntity> {
|
||||||
await AgentMessageRepository.initialize()
|
await AgentMessageRepository.initialize()
|
||||||
this.ensureInitialized()
|
this.ensureInitialized()
|
||||||
|
|
||||||
const writer = this.getWriter(params.tx)
|
const { sessionId, agentSessionId = '', payload, metadata, createdAt, tx } = params
|
||||||
const now = params.createdAt ?? params.payload.message.createdAt ?? new Date().toISOString()
|
|
||||||
|
if (!payload?.message?.role) {
|
||||||
|
throw new Error('Message payload missing role')
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!payload.message.id) {
|
||||||
|
throw new Error('Message payload missing id')
|
||||||
|
}
|
||||||
|
|
||||||
|
const writer = this.getWriter(tx)
|
||||||
|
const now = createdAt ?? payload.message.createdAt ?? new Date().toISOString()
|
||||||
|
const serializedPayload = this.serializeMessage(payload)
|
||||||
|
const serializedMetadata = this.serializeMetadata(metadata)
|
||||||
|
|
||||||
|
const existingRow = await this.findExistingMessageRow(writer, sessionId, payload.message.role, payload.message.id)
|
||||||
|
|
||||||
|
if (existingRow) {
|
||||||
|
const metadataToPersist = serializedMetadata ?? existingRow.metadata ?? undefined
|
||||||
|
const agentSessionToPersist = agentSessionId || existingRow.agent_session_id || ''
|
||||||
|
|
||||||
|
await writer
|
||||||
|
.update(sessionMessagesTable)
|
||||||
|
.set({
|
||||||
|
content: serializedPayload,
|
||||||
|
metadata: metadataToPersist,
|
||||||
|
agent_session_id: agentSessionToPersist,
|
||||||
|
updated_at: now
|
||||||
|
})
|
||||||
|
.where(eq(sessionMessagesTable.id, existingRow.id))
|
||||||
|
|
||||||
|
return this.deserialize({
|
||||||
|
...existingRow,
|
||||||
|
content: serializedPayload,
|
||||||
|
metadata: metadataToPersist,
|
||||||
|
agent_session_id: agentSessionToPersist,
|
||||||
|
updated_at: now
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
const insertData: InsertSessionMessageRow = {
|
const insertData: InsertSessionMessageRow = {
|
||||||
session_id: params.sessionId,
|
session_id: sessionId,
|
||||||
role: params.payload.message.role,
|
role: payload.message.role,
|
||||||
content: this.serializeMessage(params.payload),
|
content: serializedPayload,
|
||||||
agent_session_id: params.agentSessionId ?? '',
|
agent_session_id: agentSessionId,
|
||||||
metadata: this.serializeMetadata(params.metadata),
|
metadata: serializedMetadata,
|
||||||
created_at: now,
|
created_at: now,
|
||||||
updated_at: now
|
updated_at: now
|
||||||
}
|
}
|
||||||
@@ -112,26 +180,12 @@ class AgentMessageRepository extends BaseService {
|
|||||||
return this.deserialize(saved)
|
return this.deserialize(saved)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async persistUserMessage(params: PersistUserMessageParams): Promise<AgentSessionMessageEntity> {
|
||||||
|
return this.upsertMessage({ ...params, agentSessionId: params.agentSessionId ?? '' })
|
||||||
|
}
|
||||||
|
|
||||||
async persistAssistantMessage(params: PersistAssistantMessageParams): Promise<AgentSessionMessageEntity> {
|
async persistAssistantMessage(params: PersistAssistantMessageParams): Promise<AgentSessionMessageEntity> {
|
||||||
await AgentMessageRepository.initialize()
|
return this.upsertMessage(params)
|
||||||
this.ensureInitialized()
|
|
||||||
|
|
||||||
const writer = this.getWriter(params.tx)
|
|
||||||
const now = params.createdAt ?? params.payload.message.createdAt ?? new Date().toISOString()
|
|
||||||
|
|
||||||
const insertData: InsertSessionMessageRow = {
|
|
||||||
session_id: params.sessionId,
|
|
||||||
role: params.payload.message.role,
|
|
||||||
content: this.serializeMessage(params.payload),
|
|
||||||
agent_session_id: params.agentSessionId,
|
|
||||||
metadata: this.serializeMetadata(params.metadata),
|
|
||||||
created_at: now,
|
|
||||||
updated_at: now
|
|
||||||
}
|
|
||||||
|
|
||||||
const [saved] = await writer.insert(sessionMessagesTable).values(insertData).returning()
|
|
||||||
|
|
||||||
return this.deserialize(saved)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async persistExchange(params: PersistExchangeParams): Promise<PersistExchangeResult> {
|
async persistExchange(params: PersistExchangeParams): Promise<PersistExchangeResult> {
|
||||||
@@ -144,9 +198,6 @@ class AgentMessageRepository extends BaseService {
|
|||||||
const exchangeResult: PersistExchangeResult = {}
|
const exchangeResult: PersistExchangeResult = {}
|
||||||
|
|
||||||
if (user?.payload) {
|
if (user?.payload) {
|
||||||
if (!user.payload.message?.role) {
|
|
||||||
throw new Error('User message payload missing role')
|
|
||||||
}
|
|
||||||
exchangeResult.userMessage = await this.persistUserMessage({
|
exchangeResult.userMessage = await this.persistUserMessage({
|
||||||
sessionId,
|
sessionId,
|
||||||
agentSessionId,
|
agentSessionId,
|
||||||
@@ -158,9 +209,6 @@ class AgentMessageRepository extends BaseService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (assistant?.payload) {
|
if (assistant?.payload) {
|
||||||
if (!assistant.payload.message?.role) {
|
|
||||||
throw new Error('Assistant message payload missing role')
|
|
||||||
}
|
|
||||||
exchangeResult.assistantMessage = await this.persistAssistantMessage({
|
exchangeResult.assistantMessage = await this.persistAssistantMessage({
|
||||||
sessionId,
|
sessionId,
|
||||||
agentSessionId,
|
agentSessionId,
|
||||||
@@ -176,6 +224,34 @@ class AgentMessageRepository extends BaseService {
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async getSessionHistory(sessionId: string): Promise<AgentPersistedMessage[]> {
|
||||||
|
await AgentMessageRepository.initialize()
|
||||||
|
this.ensureInitialized()
|
||||||
|
|
||||||
|
try {
|
||||||
|
const rows = await this.database
|
||||||
|
.select()
|
||||||
|
.from(sessionMessagesTable)
|
||||||
|
.where(eq(sessionMessagesTable.session_id, sessionId))
|
||||||
|
.orderBy(asc(sessionMessagesTable.created_at))
|
||||||
|
|
||||||
|
const messages: AgentPersistedMessage[] = []
|
||||||
|
|
||||||
|
for (const row of rows) {
|
||||||
|
const deserialized = this.deserialize(row)
|
||||||
|
if (deserialized?.content) {
|
||||||
|
messages.push(deserialized.content as AgentPersistedMessage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(`Loaded ${messages.length} messages for session ${sessionId}`)
|
||||||
|
return messages
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to load session history', error as Error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export const agentMessageRepository = AgentMessageRepository.getInstance()
|
export const agentMessageRepository = AgentMessageRepository.getInstance()
|
||||||
|
|||||||
@@ -20,4 +20,3 @@ export class AgentModelValidationError extends Error {
|
|||||||
this.detail = detail
|
this.detail = detail
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ import { count, eq } from 'drizzle-orm'
|
|||||||
import { BaseService } from '../BaseService'
|
import { BaseService } from '../BaseService'
|
||||||
import { type AgentRow, agentsTable, type InsertAgentRow } from '../database/schema'
|
import { type AgentRow, agentsTable, type InsertAgentRow } from '../database/schema'
|
||||||
import { AgentModelField } from '../errors'
|
import { AgentModelField } from '../errors'
|
||||||
import { builtinTools } from './claudecode/tools'
|
|
||||||
|
|
||||||
export class AgentService extends BaseService {
|
export class AgentService extends BaseService {
|
||||||
private static instance: AgentService | null = null
|
private static instance: AgentService | null = null
|
||||||
@@ -92,10 +91,7 @@ export class AgentService extends BaseService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const agent = this.deserializeJsonFields(result[0]) as GetAgentResponse
|
const agent = this.deserializeJsonFields(result[0]) as GetAgentResponse
|
||||||
if (agent.type === 'claude-code') {
|
agent.tools = await this.listMcpTools(agent.type, agent.mcps)
|
||||||
agent.built_in_tools = builtinTools
|
|
||||||
}
|
|
||||||
|
|
||||||
return agent
|
return agent
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -115,11 +111,9 @@ export class AgentService extends BaseService {
|
|||||||
|
|
||||||
const agents = result.map((row) => this.deserializeJsonFields(row)) as GetAgentResponse[]
|
const agents = result.map((row) => this.deserializeJsonFields(row)) as GetAgentResponse[]
|
||||||
|
|
||||||
agents.forEach((agent) => {
|
for (const agent of agents) {
|
||||||
if (agent.type === 'claude-code') {
|
agent.tools = await this.listMcpTools(agent.type, agent.mcps)
|
||||||
agent.built_in_tools = builtinTools
|
}
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
return { agents, total: totalResult[0].count }
|
return { agents, total: totalResult[0].count }
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ import type {
|
|||||||
GetAgentSessionResponse,
|
GetAgentSessionResponse,
|
||||||
ListOptions
|
ListOptions
|
||||||
} from '@types'
|
} from '@types'
|
||||||
import { ModelMessage, TextStreamPart } from 'ai'
|
import { TextStreamPart } from 'ai'
|
||||||
import { desc, eq } from 'drizzle-orm'
|
import { and, desc, eq, not } from 'drizzle-orm'
|
||||||
|
|
||||||
import { BaseService } from '../BaseService'
|
import { BaseService } from '../BaseService'
|
||||||
import { sessionMessagesTable } from '../database/schema'
|
import { sessionMessagesTable } from '../database/schema'
|
||||||
@@ -68,43 +68,29 @@ class TextStreamAccumulator {
|
|||||||
}
|
}
|
||||||
case 'tool-call':
|
case 'tool-call':
|
||||||
if (part.toolCallId) {
|
if (part.toolCallId) {
|
||||||
|
const legacyPart = part as typeof part & {
|
||||||
|
args?: unknown
|
||||||
|
providerMetadata?: { raw?: { input?: unknown } }
|
||||||
|
}
|
||||||
this.toolCalls.set(part.toolCallId, {
|
this.toolCalls.set(part.toolCallId, {
|
||||||
toolName: part.toolName,
|
toolName: part.toolName,
|
||||||
input: part.input ?? part.args ?? part.providerMetadata?.raw?.input
|
input: part.input ?? legacyPart.args ?? legacyPart.providerMetadata?.raw?.input
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
case 'tool-result':
|
case 'tool-result':
|
||||||
if (part.toolCallId) {
|
if (part.toolCallId) {
|
||||||
this.toolResults.set(part.toolCallId, part.output ?? part.result ?? part.providerMetadata?.raw)
|
const legacyPart = part as typeof part & {
|
||||||
|
result?: unknown
|
||||||
|
providerMetadata?: { raw?: unknown }
|
||||||
|
}
|
||||||
|
this.toolResults.set(part.toolCallId, part.output ?? legacyPart.result ?? legacyPart.providerMetadata?.raw)
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
default:
|
default:
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
toModelMessage(role: ModelMessage['role'] = 'assistant'): ModelMessage {
|
|
||||||
const content = this.totalText || this.textBuffer || ''
|
|
||||||
|
|
||||||
const toolInvocations = Array.from(this.toolCalls.entries()).map(([toolCallId, info]) => ({
|
|
||||||
toolCallId,
|
|
||||||
toolName: info.toolName,
|
|
||||||
args: info.input,
|
|
||||||
result: this.toolResults.get(toolCallId)
|
|
||||||
}))
|
|
||||||
|
|
||||||
const message: Record<string, unknown> = {
|
|
||||||
role,
|
|
||||||
content
|
|
||||||
}
|
|
||||||
|
|
||||||
if (toolInvocations.length > 0) {
|
|
||||||
message.toolInvocations = toolInvocations
|
|
||||||
}
|
|
||||||
|
|
||||||
return message as ModelMessage
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export class SessionMessageService extends BaseService {
|
export class SessionMessageService extends BaseService {
|
||||||
@@ -159,6 +145,16 @@ export class SessionMessageService extends BaseService {
|
|||||||
return { messages }
|
return { messages }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async deleteSessionMessage(sessionId: string, messageId: number): Promise<boolean> {
|
||||||
|
this.ensureInitialized()
|
||||||
|
|
||||||
|
const result = await this.database
|
||||||
|
.delete(sessionMessagesTable)
|
||||||
|
.where(and(eq(sessionMessagesTable.id, messageId), eq(sessionMessagesTable.session_id, sessionId)))
|
||||||
|
|
||||||
|
return result.rowsAffected > 0
|
||||||
|
}
|
||||||
|
|
||||||
async createSessionMessage(
|
async createSessionMessage(
|
||||||
session: GetAgentSessionResponse,
|
session: GetAgentSessionResponse,
|
||||||
messageData: CreateSessionMessageRequest,
|
messageData: CreateSessionMessageRequest,
|
||||||
@@ -175,7 +171,6 @@ export class SessionMessageService extends BaseService {
|
|||||||
abortController: AbortController
|
abortController: AbortController
|
||||||
): Promise<SessionStreamResult> {
|
): Promise<SessionStreamResult> {
|
||||||
const agentSessionId = await this.getLastAgentSessionId(session.id)
|
const agentSessionId = await this.getLastAgentSessionId(session.id)
|
||||||
let newAgentSessionId = ''
|
|
||||||
logger.debug('Session Message stream message data:', { message: req, session_id: agentSessionId })
|
logger.debug('Session Message stream message data:', { message: req, session_id: agentSessionId })
|
||||||
|
|
||||||
if (session.agent_type !== 'claude-code') {
|
if (session.agent_type !== 'claude-code') {
|
||||||
@@ -222,10 +217,6 @@ export class SessionMessageService extends BaseService {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if (chunk.type === 'start' && chunk.messageId) {
|
|
||||||
newAgentSessionId = chunk.messageId
|
|
||||||
}
|
|
||||||
|
|
||||||
accumulator.add(chunk)
|
accumulator.add(chunk)
|
||||||
controller.enqueue(chunk)
|
controller.enqueue(chunk)
|
||||||
break
|
break
|
||||||
@@ -285,10 +276,11 @@ export class SessionMessageService extends BaseService {
|
|||||||
const result = await this.database
|
const result = await this.database
|
||||||
.select({ agent_session_id: sessionMessagesTable.agent_session_id })
|
.select({ agent_session_id: sessionMessagesTable.agent_session_id })
|
||||||
.from(sessionMessagesTable)
|
.from(sessionMessagesTable)
|
||||||
.where(eq(sessionMessagesTable.session_id, sessionId))
|
.where(and(eq(sessionMessagesTable.session_id, sessionId), not(eq(sessionMessagesTable.agent_session_id, ''))))
|
||||||
.orderBy(desc(sessionMessagesTable.created_at))
|
.orderBy(desc(sessionMessagesTable.created_at))
|
||||||
.limit(1)
|
.limit(1)
|
||||||
|
|
||||||
|
logger.silly('Last agent session ID result:', { agentSessionId: result[0]?.agent_session_id, sessionId })
|
||||||
return result[0]?.agent_session_id || ''
|
return result[0]?.agent_session_id || ''
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('Failed to get last agent session ID', {
|
logger.error('Failed to get last agent session ID', {
|
||||||
|
|||||||
@@ -3,13 +3,12 @@ import {
|
|||||||
type AgentEntity,
|
type AgentEntity,
|
||||||
type AgentSessionEntity,
|
type AgentSessionEntity,
|
||||||
type CreateSessionRequest,
|
type CreateSessionRequest,
|
||||||
type CreateSessionResponse,
|
|
||||||
type GetAgentSessionResponse,
|
type GetAgentSessionResponse,
|
||||||
type ListOptions,
|
type ListOptions,
|
||||||
type UpdateSessionRequest,
|
type UpdateSessionRequest,
|
||||||
UpdateSessionResponse
|
UpdateSessionResponse
|
||||||
} from '@types'
|
} from '@types'
|
||||||
import { and, count, eq, type SQL } from 'drizzle-orm'
|
import { and, count, desc, eq, type SQL } from 'drizzle-orm'
|
||||||
|
|
||||||
import { BaseService } from '../BaseService'
|
import { BaseService } from '../BaseService'
|
||||||
import { agentsTable, type InsertSessionRow, type SessionRow, sessionsTable } from '../database/schema'
|
import { agentsTable, type InsertSessionRow, type SessionRow, sessionsTable } from '../database/schema'
|
||||||
@@ -30,7 +29,10 @@ export class SessionService extends BaseService {
|
|||||||
await BaseService.initialize()
|
await BaseService.initialize()
|
||||||
}
|
}
|
||||||
|
|
||||||
async createSession(agentId: string, req: CreateSessionRequest): Promise<CreateSessionResponse> {
|
async createSession(
|
||||||
|
agentId: string,
|
||||||
|
req: Partial<CreateSessionRequest> = {}
|
||||||
|
): Promise<GetAgentSessionResponse | null> {
|
||||||
this.ensureInitialized()
|
this.ensureInitialized()
|
||||||
|
|
||||||
// Validate agent exists - we'll need to import AgentService for this check
|
// Validate agent exists - we'll need to import AgentService for this check
|
||||||
@@ -89,7 +91,8 @@ export class SessionService extends BaseService {
|
|||||||
throw new Error('Failed to create session')
|
throw new Error('Failed to create session')
|
||||||
}
|
}
|
||||||
|
|
||||||
return this.deserializeJsonFields(result[0]) as AgentSessionEntity
|
const session = this.deserializeJsonFields(result[0])
|
||||||
|
return await this.getSession(agentId, session.id)
|
||||||
}
|
}
|
||||||
|
|
||||||
async getSession(agentId: string, id: string): Promise<GetAgentSessionResponse | null> {
|
async getSession(agentId: string, id: string): Promise<GetAgentSessionResponse | null> {
|
||||||
@@ -106,21 +109,8 @@ export class SessionService extends BaseService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const session = this.deserializeJsonFields(result[0]) as GetAgentSessionResponse
|
const session = this.deserializeJsonFields(result[0]) as GetAgentSessionResponse
|
||||||
|
session.tools = await this.listMcpTools(session.agent_type, session.mcps)
|
||||||
return session
|
session.slash_commands = await this.listSlashCommands(session.agent_type)
|
||||||
}
|
|
||||||
|
|
||||||
async getSessionById(id: string): Promise<GetAgentSessionResponse | null> {
|
|
||||||
this.ensureInitialized()
|
|
||||||
|
|
||||||
const result = await this.database.select().from(sessionsTable).where(eq(sessionsTable.id, id)).limit(1)
|
|
||||||
|
|
||||||
if (!result[0]) {
|
|
||||||
return null
|
|
||||||
}
|
|
||||||
|
|
||||||
const session = this.deserializeJsonFields(result[0]) as GetAgentSessionResponse
|
|
||||||
|
|
||||||
return session
|
return session
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -148,8 +138,12 @@ export class SessionService extends BaseService {
|
|||||||
|
|
||||||
const total = totalResult[0].count
|
const total = totalResult[0].count
|
||||||
|
|
||||||
// Build list query with pagination
|
// Build list query with pagination - sort by updated_at descending (latest first)
|
||||||
const baseQuery = this.database.select().from(sessionsTable).where(whereClause).orderBy(sessionsTable.created_at)
|
const baseQuery = this.database
|
||||||
|
.select()
|
||||||
|
.from(sessionsTable)
|
||||||
|
.where(whereClause)
|
||||||
|
.orderBy(desc(sessionsTable.updated_at))
|
||||||
|
|
||||||
const result =
|
const result =
|
||||||
options.limit !== undefined
|
options.limit !== undefined
|
||||||
|
|||||||
@@ -0,0 +1,290 @@
|
|||||||
|
import type { SDKMessage } from '@anthropic-ai/claude-agent-sdk'
|
||||||
|
import { describe, expect, it } from 'vitest'
|
||||||
|
|
||||||
|
import { ClaudeStreamState, transformSDKMessageToStreamParts } from '../transform'
|
||||||
|
|
||||||
|
const baseStreamMetadata = {
|
||||||
|
parent_tool_use_id: null,
|
||||||
|
session_id: 'session-123'
|
||||||
|
}
|
||||||
|
|
||||||
|
const uuid = (n: number) => `00000000-0000-0000-0000-${n.toString().padStart(12, '0')}`
|
||||||
|
|
||||||
|
describe('Claude → AiSDK transform', () => {
|
||||||
|
it('handles tool call streaming lifecycle', () => {
|
||||||
|
const state = new ClaudeStreamState()
|
||||||
|
const parts: ReturnType<typeof transformSDKMessageToStreamParts>[number][] = []
|
||||||
|
|
||||||
|
const messages: SDKMessage[] = [
|
||||||
|
{
|
||||||
|
...baseStreamMetadata,
|
||||||
|
type: 'stream_event',
|
||||||
|
uuid: uuid(1),
|
||||||
|
event: {
|
||||||
|
type: 'message_start',
|
||||||
|
message: {
|
||||||
|
id: 'msg-start',
|
||||||
|
type: 'message',
|
||||||
|
role: 'assistant',
|
||||||
|
model: 'claude-test',
|
||||||
|
content: [],
|
||||||
|
stop_reason: null,
|
||||||
|
stop_sequence: null,
|
||||||
|
usage: {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} as unknown as SDKMessage,
|
||||||
|
{
|
||||||
|
...baseStreamMetadata,
|
||||||
|
type: 'stream_event',
|
||||||
|
uuid: uuid(2),
|
||||||
|
event: {
|
||||||
|
type: 'content_block_start',
|
||||||
|
index: 0,
|
||||||
|
content_block: {
|
||||||
|
type: 'tool_use',
|
||||||
|
id: 'tool-1',
|
||||||
|
name: 'Bash',
|
||||||
|
input: {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} as unknown as SDKMessage,
|
||||||
|
{
|
||||||
|
...baseStreamMetadata,
|
||||||
|
type: 'stream_event',
|
||||||
|
uuid: uuid(3),
|
||||||
|
event: {
|
||||||
|
type: 'content_block_delta',
|
||||||
|
index: 0,
|
||||||
|
delta: {
|
||||||
|
type: 'input_json_delta',
|
||||||
|
partial_json: '{"command":"ls"}'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} as unknown as SDKMessage,
|
||||||
|
{
|
||||||
|
...baseStreamMetadata,
|
||||||
|
type: 'assistant',
|
||||||
|
uuid: uuid(4),
|
||||||
|
message: {
|
||||||
|
id: 'msg-tool',
|
||||||
|
type: 'message',
|
||||||
|
role: 'assistant',
|
||||||
|
model: 'claude-test',
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'tool_use',
|
||||||
|
id: 'tool-1',
|
||||||
|
name: 'Bash',
|
||||||
|
input: {
|
||||||
|
command: 'ls'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
stop_reason: 'tool_use',
|
||||||
|
stop_sequence: null,
|
||||||
|
usage: {
|
||||||
|
input_tokens: 1,
|
||||||
|
output_tokens: 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} as unknown as SDKMessage,
|
||||||
|
{
|
||||||
|
...baseStreamMetadata,
|
||||||
|
type: 'stream_event',
|
||||||
|
uuid: uuid(5),
|
||||||
|
event: {
|
||||||
|
type: 'content_block_stop',
|
||||||
|
index: 0
|
||||||
|
}
|
||||||
|
} as unknown as SDKMessage,
|
||||||
|
{
|
||||||
|
...baseStreamMetadata,
|
||||||
|
type: 'stream_event',
|
||||||
|
uuid: uuid(6),
|
||||||
|
event: {
|
||||||
|
type: 'message_delta',
|
||||||
|
delta: {
|
||||||
|
stop_reason: 'tool_use',
|
||||||
|
stop_sequence: null
|
||||||
|
},
|
||||||
|
usage: {
|
||||||
|
input_tokens: 1,
|
||||||
|
output_tokens: 5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} as unknown as SDKMessage,
|
||||||
|
{
|
||||||
|
...baseStreamMetadata,
|
||||||
|
type: 'stream_event',
|
||||||
|
uuid: uuid(7),
|
||||||
|
event: {
|
||||||
|
type: 'message_stop'
|
||||||
|
}
|
||||||
|
} as unknown as SDKMessage,
|
||||||
|
{
|
||||||
|
...baseStreamMetadata,
|
||||||
|
type: 'user',
|
||||||
|
uuid: uuid(8),
|
||||||
|
message: {
|
||||||
|
role: 'user',
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'tool_result',
|
||||||
|
tool_use_id: 'tool-1',
|
||||||
|
content: 'ok',
|
||||||
|
is_error: false
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
} as SDKMessage
|
||||||
|
]
|
||||||
|
|
||||||
|
for (const message of messages) {
|
||||||
|
const transformed = transformSDKMessageToStreamParts(message, state)
|
||||||
|
for (const part of transformed) {
|
||||||
|
parts.push(part)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const types = parts.map((part) => part.type)
|
||||||
|
expect(types).toEqual([
|
||||||
|
'start-step',
|
||||||
|
'tool-input-start',
|
||||||
|
'tool-input-delta',
|
||||||
|
'tool-call',
|
||||||
|
'tool-input-end',
|
||||||
|
'finish-step',
|
||||||
|
'tool-result'
|
||||||
|
])
|
||||||
|
|
||||||
|
const finishStep = parts.find((part) => part.type === 'finish-step') as Extract<
|
||||||
|
(typeof parts)[number],
|
||||||
|
{ type: 'finish-step' }
|
||||||
|
>
|
||||||
|
expect(finishStep.finishReason).toBe('tool-calls')
|
||||||
|
expect(finishStep.usage).toEqual({ inputTokens: 1, outputTokens: 5, totalTokens: 6 })
|
||||||
|
|
||||||
|
const toolResult = parts.find((part) => part.type === 'tool-result') as Extract<
|
||||||
|
(typeof parts)[number],
|
||||||
|
{ type: 'tool-result' }
|
||||||
|
>
|
||||||
|
expect(toolResult.toolCallId).toBe('tool-1')
|
||||||
|
expect(toolResult.toolName).toBe('Bash')
|
||||||
|
expect(toolResult.input).toEqual({ command: 'ls' })
|
||||||
|
expect(toolResult.output).toBe('ok')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('handles streaming text completion', () => {
|
||||||
|
const state = new ClaudeStreamState()
|
||||||
|
const parts: ReturnType<typeof transformSDKMessageToStreamParts>[number][] = []
|
||||||
|
|
||||||
|
const messages: SDKMessage[] = [
|
||||||
|
{
|
||||||
|
...baseStreamMetadata,
|
||||||
|
type: 'stream_event',
|
||||||
|
uuid: uuid(9),
|
||||||
|
event: {
|
||||||
|
type: 'message_start',
|
||||||
|
message: {
|
||||||
|
id: 'msg-text',
|
||||||
|
type: 'message',
|
||||||
|
role: 'assistant',
|
||||||
|
model: 'claude-text',
|
||||||
|
content: [],
|
||||||
|
stop_reason: null,
|
||||||
|
stop_sequence: null,
|
||||||
|
usage: {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} as unknown as SDKMessage,
|
||||||
|
{
|
||||||
|
...baseStreamMetadata,
|
||||||
|
type: 'stream_event',
|
||||||
|
uuid: uuid(10),
|
||||||
|
event: {
|
||||||
|
type: 'content_block_start',
|
||||||
|
index: 0,
|
||||||
|
content_block: {
|
||||||
|
type: 'text',
|
||||||
|
text: ''
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} as unknown as SDKMessage,
|
||||||
|
{
|
||||||
|
...baseStreamMetadata,
|
||||||
|
type: 'stream_event',
|
||||||
|
uuid: uuid(11),
|
||||||
|
event: {
|
||||||
|
type: 'content_block_delta',
|
||||||
|
index: 0,
|
||||||
|
delta: {
|
||||||
|
type: 'text_delta',
|
||||||
|
text: 'Hello'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} as unknown as SDKMessage,
|
||||||
|
{
|
||||||
|
...baseStreamMetadata,
|
||||||
|
type: 'stream_event',
|
||||||
|
uuid: uuid(12),
|
||||||
|
event: {
|
||||||
|
type: 'content_block_delta',
|
||||||
|
index: 0,
|
||||||
|
delta: {
|
||||||
|
type: 'text_delta',
|
||||||
|
text: ' world'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} as unknown as SDKMessage,
|
||||||
|
{
|
||||||
|
...baseStreamMetadata,
|
||||||
|
type: 'stream_event',
|
||||||
|
uuid: uuid(13),
|
||||||
|
event: {
|
||||||
|
type: 'content_block_stop',
|
||||||
|
index: 0
|
||||||
|
}
|
||||||
|
} as unknown as SDKMessage,
|
||||||
|
{
|
||||||
|
...baseStreamMetadata,
|
||||||
|
type: 'stream_event',
|
||||||
|
uuid: uuid(14),
|
||||||
|
event: {
|
||||||
|
type: 'message_delta',
|
||||||
|
delta: {
|
||||||
|
stop_reason: 'end_turn',
|
||||||
|
stop_sequence: null
|
||||||
|
},
|
||||||
|
usage: {
|
||||||
|
input_tokens: 2,
|
||||||
|
output_tokens: 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} as unknown as SDKMessage,
|
||||||
|
{
|
||||||
|
...baseStreamMetadata,
|
||||||
|
type: 'stream_event',
|
||||||
|
uuid: uuid(15),
|
||||||
|
event: {
|
||||||
|
type: 'message_stop'
|
||||||
|
}
|
||||||
|
} as SDKMessage
|
||||||
|
]
|
||||||
|
|
||||||
|
for (const message of messages) {
|
||||||
|
const transformed = transformSDKMessageToStreamParts(message, state)
|
||||||
|
parts.push(...transformed)
|
||||||
|
}
|
||||||
|
|
||||||
|
const types = parts.map((part) => part.type)
|
||||||
|
expect(types).toEqual(['start-step', 'text-start', 'text-delta', 'text-delta', 'text-end', 'finish-step'])
|
||||||
|
|
||||||
|
const finishStep = parts.find((part) => part.type === 'finish-step') as Extract<
|
||||||
|
(typeof parts)[number],
|
||||||
|
{ type: 'finish-step' }
|
||||||
|
>
|
||||||
|
expect(finishStep.finishReason).toBe('stop')
|
||||||
|
expect(finishStep.usage).toEqual({ inputTokens: 2, outputTokens: 4, totalTokens: 6 })
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -0,0 +1,241 @@
|
|||||||
|
/**
|
||||||
|
* Lightweight state container shared by the Claude → AiSDK transformer. Anthropic does not send
|
||||||
|
* deterministic identifiers for intermediate content blocks, so we stitch one together by tracking
|
||||||
|
* block indices and associated AiSDK ids. This class also keeps:
|
||||||
|
* • incremental text / reasoning buffers so we can emit only deltas while retaining the full
|
||||||
|
* aggregate for later tool-call emission;
|
||||||
|
* • a reverse lookup for tool calls so `tool_result` snapshots can recover their metadata;
|
||||||
|
* • pending usage + finish reason from `message_delta` events until the corresponding
|
||||||
|
* `message_stop` arrives.
|
||||||
|
* Every Claude turn gets its own instance. `resetStep` should be invoked once the finish event has
|
||||||
|
* been emitted to avoid leaking state into the next turn.
|
||||||
|
*/
|
||||||
|
import type { FinishReason, LanguageModelUsage, ProviderMetadata } from 'ai'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Shared fields for every block that Claude can stream (text, reasoning, tool).
|
||||||
|
*/
|
||||||
|
type BaseBlockState = {
|
||||||
|
id: string
|
||||||
|
index: number
|
||||||
|
}
|
||||||
|
|
||||||
|
type TextBlockState = BaseBlockState & {
|
||||||
|
kind: 'text'
|
||||||
|
text: string
|
||||||
|
}
|
||||||
|
|
||||||
|
type ReasoningBlockState = BaseBlockState & {
|
||||||
|
kind: 'reasoning'
|
||||||
|
text: string
|
||||||
|
redacted: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolBlockState = BaseBlockState & {
|
||||||
|
kind: 'tool'
|
||||||
|
toolCallId: string
|
||||||
|
toolName: string
|
||||||
|
inputBuffer: string
|
||||||
|
providerMetadata?: ProviderMetadata
|
||||||
|
resolvedInput?: unknown
|
||||||
|
}
|
||||||
|
|
||||||
|
export type BlockState = TextBlockState | ReasoningBlockState | ToolBlockState
|
||||||
|
|
||||||
|
type PendingUsageState = {
|
||||||
|
usage?: LanguageModelUsage
|
||||||
|
finishReason?: FinishReason
|
||||||
|
}
|
||||||
|
|
||||||
|
type PendingToolCall = {
|
||||||
|
toolCallId: string
|
||||||
|
toolName: string
|
||||||
|
input: unknown
|
||||||
|
providerMetadata?: ProviderMetadata
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tracks the lifecycle of Claude streaming blocks (text, thinking, tool calls)
|
||||||
|
* across individual websocket events. The transformer relies on this class to
|
||||||
|
* stitch together deltas, manage pending tool inputs/results, and propagate
|
||||||
|
* usage/finish metadata once Anthropic closes a message.
|
||||||
|
*/
|
||||||
|
export class ClaudeStreamState {
|
||||||
|
private blocksByIndex = new Map<number, BlockState>()
|
||||||
|
private toolIndexById = new Map<string, number>()
|
||||||
|
private pendingUsage: PendingUsageState = {}
|
||||||
|
private pendingToolCalls = new Map<string, PendingToolCall>()
|
||||||
|
private stepActive = false
|
||||||
|
|
||||||
|
/** Marks the beginning of a new AiSDK step. */
|
||||||
|
beginStep(): void {
|
||||||
|
this.stepActive = true
|
||||||
|
}
|
||||||
|
|
||||||
|
hasActiveStep(): boolean {
|
||||||
|
return this.stepActive
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Creates a text block placeholder so future deltas can accumulate into it. */
|
||||||
|
openTextBlock(index: number, id: string): TextBlockState {
|
||||||
|
const block: TextBlockState = {
|
||||||
|
kind: 'text',
|
||||||
|
id,
|
||||||
|
index,
|
||||||
|
text: ''
|
||||||
|
}
|
||||||
|
this.blocksByIndex.set(index, block)
|
||||||
|
return block
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Starts tracking an Anthropic "thinking" block, optionally flagged as redacted. */
|
||||||
|
openReasoningBlock(index: number, id: string, redacted: boolean): ReasoningBlockState {
|
||||||
|
const block: ReasoningBlockState = {
|
||||||
|
kind: 'reasoning',
|
||||||
|
id,
|
||||||
|
index,
|
||||||
|
redacted,
|
||||||
|
text: ''
|
||||||
|
}
|
||||||
|
this.blocksByIndex.set(index, block)
|
||||||
|
return block
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Caches tool metadata so subsequent input deltas and results can find it. */
|
||||||
|
openToolBlock(
|
||||||
|
index: number,
|
||||||
|
params: { toolCallId: string; toolName: string; providerMetadata?: ProviderMetadata }
|
||||||
|
): ToolBlockState {
|
||||||
|
const block: ToolBlockState = {
|
||||||
|
kind: 'tool',
|
||||||
|
id: params.toolCallId,
|
||||||
|
index,
|
||||||
|
toolCallId: params.toolCallId,
|
||||||
|
toolName: params.toolName,
|
||||||
|
inputBuffer: '',
|
||||||
|
providerMetadata: params.providerMetadata
|
||||||
|
}
|
||||||
|
this.blocksByIndex.set(index, block)
|
||||||
|
this.toolIndexById.set(params.toolCallId, index)
|
||||||
|
return block
|
||||||
|
}
|
||||||
|
|
||||||
|
getBlock(index: number): BlockState | undefined {
|
||||||
|
return this.blocksByIndex.get(index)
|
||||||
|
}
|
||||||
|
|
||||||
|
getToolBlockById(toolCallId: string): ToolBlockState | undefined {
|
||||||
|
const index = this.toolIndexById.get(toolCallId)
|
||||||
|
if (index === undefined) return undefined
|
||||||
|
const block = this.blocksByIndex.get(index)
|
||||||
|
if (!block || block.kind !== 'tool') return undefined
|
||||||
|
return block
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Appends streamed text to a text block, returning the updated state when present. */
|
||||||
|
appendTextDelta(index: number, text: string): TextBlockState | undefined {
|
||||||
|
const block = this.blocksByIndex.get(index)
|
||||||
|
if (!block || block.kind !== 'text') return undefined
|
||||||
|
block.text += text
|
||||||
|
return block
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Appends streamed "thinking" content to the tracked reasoning block. */
|
||||||
|
appendReasoningDelta(index: number, text: string): ReasoningBlockState | undefined {
|
||||||
|
const block = this.blocksByIndex.get(index)
|
||||||
|
if (!block || block.kind !== 'reasoning') return undefined
|
||||||
|
block.text += text
|
||||||
|
return block
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Concatenates incremental JSON payloads for tool input blocks. */
|
||||||
|
appendToolInputDelta(index: number, jsonDelta: string): ToolBlockState | undefined {
|
||||||
|
const block = this.blocksByIndex.get(index)
|
||||||
|
if (!block || block.kind !== 'tool') return undefined
|
||||||
|
block.inputBuffer += jsonDelta
|
||||||
|
return block
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Records a tool call to be consumed once its result arrives from the user. */
|
||||||
|
registerToolCall(
|
||||||
|
toolCallId: string,
|
||||||
|
payload: { toolName: string; input: unknown; providerMetadata?: ProviderMetadata }
|
||||||
|
): void {
|
||||||
|
this.pendingToolCalls.set(toolCallId, {
|
||||||
|
toolCallId,
|
||||||
|
toolName: payload.toolName,
|
||||||
|
input: payload.input,
|
||||||
|
providerMetadata: payload.providerMetadata
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Retrieves and clears the buffered tool call metadata for the given id. */
|
||||||
|
consumePendingToolCall(toolCallId: string): PendingToolCall | undefined {
|
||||||
|
const entry = this.pendingToolCalls.get(toolCallId)
|
||||||
|
if (entry) {
|
||||||
|
this.pendingToolCalls.delete(toolCallId)
|
||||||
|
}
|
||||||
|
return entry
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Persists the final input payload for a tool block once the provider signals
|
||||||
|
* completion so that downstream tool results can reference the original call.
|
||||||
|
*/
|
||||||
|
completeToolBlock(toolCallId: string, input: unknown, providerMetadata?: ProviderMetadata): void {
|
||||||
|
this.registerToolCall(toolCallId, {
|
||||||
|
toolName: this.getToolBlockById(toolCallId)?.toolName ?? 'unknown',
|
||||||
|
input,
|
||||||
|
providerMetadata
|
||||||
|
})
|
||||||
|
const block = this.getToolBlockById(toolCallId)
|
||||||
|
if (block) {
|
||||||
|
block.resolvedInput = input
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Removes a block from the active index map when Claude signals it is done. */
|
||||||
|
closeBlock(index: number): BlockState | undefined {
|
||||||
|
const block = this.blocksByIndex.get(index)
|
||||||
|
if (!block) return undefined
|
||||||
|
this.blocksByIndex.delete(index)
|
||||||
|
if (block.kind === 'tool') {
|
||||||
|
this.toolIndexById.delete(block.toolCallId)
|
||||||
|
}
|
||||||
|
return block
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Stores interim usage metrics so they can be emitted with the `finish-step`. */
|
||||||
|
setPendingUsage(usage?: LanguageModelUsage, finishReason?: FinishReason): void {
|
||||||
|
if (usage) {
|
||||||
|
this.pendingUsage.usage = usage
|
||||||
|
}
|
||||||
|
if (finishReason) {
|
||||||
|
this.pendingUsage.finishReason = finishReason
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
getPendingUsage(): PendingUsageState {
|
||||||
|
return { ...this.pendingUsage }
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Clears any accumulated usage values for the next streamed message. */
|
||||||
|
resetPendingUsage(): void {
|
||||||
|
this.pendingUsage = {}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Drops cached block metadata for the currently active message. */
|
||||||
|
resetBlocks(): void {
|
||||||
|
this.blocksByIndex.clear()
|
||||||
|
this.toolIndexById.clear()
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Resets the entire step lifecycle after emitting a terminal frame. */
|
||||||
|
resetStep(): void {
|
||||||
|
this.resetBlocks()
|
||||||
|
this.resetPendingUsage()
|
||||||
|
this.stepActive = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export type { PendingToolCall }
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
import { SlashCommand } from '@types'
|
||||||
|
|
||||||
|
export const builtinSlashCommands: SlashCommand[] = [
|
||||||
|
{ command: '/add-dir', description: 'Add additional working directories' },
|
||||||
|
{ command: '/agents', description: 'Manage custom AI subagents for specialized tasks' },
|
||||||
|
{ command: '/bug', description: 'Report bugs (sends conversation to Anthropic)' },
|
||||||
|
{ command: '/clear', description: 'Clear conversation history' },
|
||||||
|
{ command: '/compact', description: 'Compact conversation with optional focus instructions' },
|
||||||
|
{ command: '/config', description: 'View/modify configuration' },
|
||||||
|
{ command: '/cost', description: 'Show token usage statistics' },
|
||||||
|
{ command: '/doctor', description: 'Checks the health of your Claude Code installation' },
|
||||||
|
{ command: '/help', description: 'Get usage help' },
|
||||||
|
{ command: '/init', description: 'Initialize project with CLAUDE.md guide' },
|
||||||
|
{ command: '/login', description: 'Switch Anthropic accounts' },
|
||||||
|
{ command: '/logout', description: 'Sign out from your Anthropic account' },
|
||||||
|
{ command: '/mcp', description: 'Manage MCP server connections and OAuth authentication' },
|
||||||
|
{ command: '/memory', description: 'Edit CLAUDE.md memory files' },
|
||||||
|
{ command: '/model', description: 'Select or change the AI model' },
|
||||||
|
{ command: '/permissions', description: 'View or update permissions' },
|
||||||
|
{ command: '/pr_comments', description: 'View pull request comments' },
|
||||||
|
{ command: '/review', description: 'Request code review' },
|
||||||
|
{ command: '/status', description: 'View account and system statuses' },
|
||||||
|
{ command: '/terminal-setup', description: 'Install Shift+Enter key binding for newlines (iTerm2 and VSCode only)' },
|
||||||
|
{ command: '/vim', description: 'Enter vim mode for alternating insert and command modes' }
|
||||||
|
]
|
||||||
@@ -2,14 +2,16 @@
|
|||||||
import { EventEmitter } from 'node:events'
|
import { EventEmitter } from 'node:events'
|
||||||
import { createRequire } from 'node:module'
|
import { createRequire } from 'node:module'
|
||||||
|
|
||||||
import { McpHttpServerConfig, Options, query, SDKMessage } from '@anthropic-ai/claude-code'
|
import { McpHttpServerConfig, Options, query, SDKMessage } from '@anthropic-ai/claude-agent-sdk'
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { config as apiConfigService } from '@main/apiServer/config'
|
import { config as apiConfigService } from '@main/apiServer/config'
|
||||||
import { validateModelId } from '@main/apiServer/utils'
|
import { validateModelId } from '@main/apiServer/utils'
|
||||||
|
import getLoginShellEnvironment from '@main/utils/shell-env'
|
||||||
|
import { app } from 'electron'
|
||||||
|
|
||||||
import { GetAgentSessionResponse } from '../..'
|
import { GetAgentSessionResponse } from '../..'
|
||||||
import { AgentServiceInterface, AgentStream, AgentStreamEvent } from '../../interfaces/AgentStreamInterface'
|
import { AgentServiceInterface, AgentStream, AgentStreamEvent } from '../../interfaces/AgentStreamInterface'
|
||||||
import { transformSDKMessageToStreamParts } from './transform'
|
import { ClaudeStreamState, transformSDKMessageToStreamParts } from './transform'
|
||||||
|
|
||||||
const require_ = createRequire(import.meta.url)
|
const require_ = createRequire(import.meta.url)
|
||||||
const logger = loggerService.withContext('ClaudeCodeService')
|
const logger = loggerService.withContext('ClaudeCodeService')
|
||||||
@@ -25,7 +27,10 @@ class ClaudeCodeService implements AgentServiceInterface {
|
|||||||
|
|
||||||
constructor() {
|
constructor() {
|
||||||
// Resolve Claude Code CLI robustly (works in dev and in asar)
|
// Resolve Claude Code CLI robustly (works in dev and in asar)
|
||||||
this.claudeExecutablePath = require_.resolve('@anthropic-ai/claude-code/cli.js')
|
this.claudeExecutablePath = require_.resolve('@anthropic-ai/claude-agent-sdk/cli.js')
|
||||||
|
if (app.isPackaged) {
|
||||||
|
this.claudeExecutablePath = this.claudeExecutablePath.replace(/\.asar([\\/])/, '.asar.unpacked$1')
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async invoke(
|
async invoke(
|
||||||
@@ -55,7 +60,15 @@ class ClaudeCodeService implements AgentServiceInterface {
|
|||||||
})
|
})
|
||||||
return aiStream
|
return aiStream
|
||||||
}
|
}
|
||||||
if (modelInfo.provider?.type !== 'anthropic' || modelInfo.provider.apiKey === '') {
|
if (
|
||||||
|
(modelInfo.provider?.type !== 'anthropic' &&
|
||||||
|
(modelInfo.provider?.anthropicApiHost === undefined || modelInfo.provider.anthropicApiHost.trim() === '')) ||
|
||||||
|
modelInfo.provider.apiKey === ''
|
||||||
|
) {
|
||||||
|
logger.error('Anthropic provider configuration is missing', {
|
||||||
|
modelInfo
|
||||||
|
})
|
||||||
|
|
||||||
aiStream.emit('data', {
|
aiStream.emit('data', {
|
||||||
type: 'error',
|
type: 'error',
|
||||||
error: new Error(`Invalid provider type '${modelInfo.provider?.type}'. Expected 'anthropic' provider type.`)
|
error: new Error(`Invalid provider type '${modelInfo.provider?.type}'. Expected 'anthropic' provider type.`)
|
||||||
@@ -63,24 +76,52 @@ class ClaudeCodeService implements AgentServiceInterface {
|
|||||||
return aiStream
|
return aiStream
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: use cherry studio api server config instead of direct provider config to provide more flexibility (e.g. custom headers, proxy, statistics, etc).
|
|
||||||
const apiConfig = await apiConfigService.get()
|
const apiConfig = await apiConfigService.get()
|
||||||
// process.env.ANTHROPIC_AUTH_TOKEN = apiConfig.apiKey
|
const loginShellEnv = await getLoginShellEnvironment()
|
||||||
// process.env.ANTHROPIC_BASE_URL = `http://${apiConfig.host}:${apiConfig.port}`
|
const loginShellEnvWithoutProxies = Object.fromEntries(
|
||||||
process.env.ANTHROPIC_AUTH_TOKEN = modelInfo.provider.apiKey
|
Object.entries(loginShellEnv).filter(([key]) => !key.toLowerCase().endsWith('_proxy'))
|
||||||
process.env.ANTHROPIC_BASE_URL = modelInfo.provider.apiHost
|
) as Record<string, string>
|
||||||
|
|
||||||
|
const env = {
|
||||||
|
...loginShellEnvWithoutProxies,
|
||||||
|
// TODO: fix the proxy api server
|
||||||
|
// ANTHROPIC_API_KEY: apiConfig.apiKey,
|
||||||
|
// ANTHROPIC_AUTH_TOKEN: apiConfig.apiKey,
|
||||||
|
// ANTHROPIC_BASE_URL: `http://${apiConfig.host}:${apiConfig.port}/${modelInfo.provider.id}`,
|
||||||
|
ANTHROPIC_API_KEY: modelInfo.provider.apiKey,
|
||||||
|
ANTHROPIC_AUTH_TOKEN: modelInfo.provider.apiKey,
|
||||||
|
ANTHROPIC_BASE_URL: modelInfo.provider.anthropicApiHost?.trim() || modelInfo.provider.apiHost,
|
||||||
|
ANTHROPIC_MODEL: modelInfo.modelId,
|
||||||
|
ANTHROPIC_SMALL_FAST_MODEL: modelInfo.modelId,
|
||||||
|
ELECTRON_RUN_AS_NODE: '1',
|
||||||
|
ELECTRON_NO_ATTACH_CONSOLE: '1'
|
||||||
|
}
|
||||||
|
|
||||||
|
const errorChunks: string[] = []
|
||||||
|
|
||||||
// Build SDK options from parameters
|
// Build SDK options from parameters
|
||||||
const options: Options = {
|
const options: Options = {
|
||||||
abortController,
|
abortController,
|
||||||
cwd,
|
cwd,
|
||||||
|
env,
|
||||||
|
// model: modelInfo.modelId,
|
||||||
pathToClaudeCodeExecutable: this.claudeExecutablePath,
|
pathToClaudeCodeExecutable: this.claudeExecutablePath,
|
||||||
stderr: (chunk: string) => {
|
stderr: (chunk: string) => {
|
||||||
logger.info('claude stderr', { chunk })
|
logger.warn('claude stderr', { chunk })
|
||||||
|
errorChunks.push(chunk)
|
||||||
},
|
},
|
||||||
appendSystemPrompt: session.instructions,
|
systemPrompt: session.instructions
|
||||||
|
? {
|
||||||
|
type: 'preset',
|
||||||
|
preset: 'claude_code',
|
||||||
|
append: session.instructions
|
||||||
|
}
|
||||||
|
: { type: 'preset', preset: 'claude_code' },
|
||||||
|
settingSources: ['project'],
|
||||||
|
includePartialMessages: true,
|
||||||
permissionMode: session.configuration?.permission_mode,
|
permissionMode: session.configuration?.permission_mode,
|
||||||
maxTurns: session.configuration?.max_turns
|
maxTurns: session.configuration?.max_turns,
|
||||||
|
allowedTools: session.allowed_tools
|
||||||
}
|
}
|
||||||
|
|
||||||
if (session.accessible_paths.length > 1) {
|
if (session.accessible_paths.length > 1) {
|
||||||
@@ -105,15 +146,32 @@ class ClaudeCodeService implements AgentServiceInterface {
|
|||||||
|
|
||||||
if (lastAgentSessionId) {
|
if (lastAgentSessionId) {
|
||||||
options.resume = lastAgentSessionId
|
options.resume = lastAgentSessionId
|
||||||
|
// TODO: use fork session when we support branching sessions
|
||||||
|
// options.forkSession = true
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info('Starting Claude Code SDK query', {
|
logger.info('Starting Claude Code SDK query', {
|
||||||
prompt,
|
prompt,
|
||||||
options
|
cwd: options.cwd,
|
||||||
|
model: options.model,
|
||||||
|
permissionMode: options.permissionMode,
|
||||||
|
maxTurns: options.maxTurns,
|
||||||
|
allowedTools: options.allowedTools,
|
||||||
|
resume: options.resume
|
||||||
})
|
})
|
||||||
|
|
||||||
// Start async processing
|
// Start async processing on the next tick so listeners can subscribe first
|
||||||
this.processSDKQuery(prompt, options, aiStream)
|
setImmediate(() => {
|
||||||
|
this.processSDKQuery(prompt, options, aiStream, errorChunks).catch((error) => {
|
||||||
|
logger.error('Unhandled Claude Code stream error', {
|
||||||
|
error: error instanceof Error ? { name: error.name, message: error.message } : String(error)
|
||||||
|
})
|
||||||
|
aiStream.emit('data', {
|
||||||
|
type: 'error',
|
||||||
|
error: error instanceof Error ? error : new Error(String(error))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
return aiStream
|
return aiStream
|
||||||
}
|
}
|
||||||
@@ -135,11 +193,17 @@ class ClaudeCodeService implements AgentServiceInterface {
|
|||||||
/**
|
/**
|
||||||
* Process SDK query and emit stream events
|
* Process SDK query and emit stream events
|
||||||
*/
|
*/
|
||||||
private async processSDKQuery(prompt: string, options: Options, stream: ClaudeCodeStream): Promise<void> {
|
private async processSDKQuery(
|
||||||
|
prompt: string,
|
||||||
|
options: Options,
|
||||||
|
stream: ClaudeCodeStream,
|
||||||
|
errorChunks: string[]
|
||||||
|
): Promise<void> {
|
||||||
const jsonOutput: SDKMessage[] = []
|
const jsonOutput: SDKMessage[] = []
|
||||||
let hasCompleted = false
|
let hasCompleted = false
|
||||||
const startTime = Date.now()
|
const startTime = Date.now()
|
||||||
|
|
||||||
|
const streamState = new ClaudeStreamState()
|
||||||
try {
|
try {
|
||||||
// Process streaming responses using SDK query
|
// Process streaming responses using SDK query
|
||||||
for await (const message of query({
|
for await (const message of query({
|
||||||
@@ -149,15 +213,26 @@ class ClaudeCodeService implements AgentServiceInterface {
|
|||||||
if (hasCompleted) break
|
if (hasCompleted) break
|
||||||
|
|
||||||
jsonOutput.push(message)
|
jsonOutput.push(message)
|
||||||
logger.silly('claude response', { message })
|
|
||||||
if (message.type === 'assistant' || message.type === 'user') {
|
if (message.type === 'assistant' || message.type === 'user') {
|
||||||
logger.silly('message content', {
|
logger.silly('claude response', {
|
||||||
message: JSON.stringify({ role: message.message.role, content: message.message.content })
|
message,
|
||||||
|
content: JSON.stringify(message.message.content)
|
||||||
|
})
|
||||||
|
} else if (message.type === 'stream_event') {
|
||||||
|
logger.silly('Claude stream event', {
|
||||||
|
message,
|
||||||
|
event: JSON.stringify(message.event)
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
logger.silly('Claude response', {
|
||||||
|
message,
|
||||||
|
event: JSON.stringify(message)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Transform SDKMessage to UIMessageChunks
|
// Transform SDKMessage to UIMessageChunks
|
||||||
const chunks = transformSDKMessageToStreamParts(message)
|
const chunks = transformSDKMessageToStreamParts(message, streamState)
|
||||||
for (const chunk of chunks) {
|
for (const chunk of chunks) {
|
||||||
stream.emit('data', {
|
stream.emit('data', {
|
||||||
type: 'chunk',
|
type: 'chunk',
|
||||||
@@ -202,17 +277,17 @@ class ClaudeCodeService implements AgentServiceInterface {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Original error handling for non-abort errors
|
errorChunks.push(errorObj instanceof Error ? errorObj.message : String(errorObj))
|
||||||
logger.error('SDK query error:', {
|
const errorMessage = errorChunks.join('\n\n')
|
||||||
error: errorObj instanceof Error ? errorObj.message : String(errorObj),
|
logger.error('SDK query failed', {
|
||||||
duration,
|
duration,
|
||||||
messageCount: jsonOutput.length
|
error: errorObj instanceof Error ? { name: errorObj.name, message: errorObj.message } : String(errorObj),
|
||||||
|
stderr: errorChunks
|
||||||
})
|
})
|
||||||
|
|
||||||
// Emit error event
|
// Emit error event
|
||||||
stream.emit('data', {
|
stream.emit('data', {
|
||||||
type: 'error',
|
type: 'error',
|
||||||
error: errorObj instanceof Error ? errorObj : new Error(String(errorObj))
|
error: new Error(errorMessage)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,47 +2,83 @@ import { Tool } from '@types'
|
|||||||
|
|
||||||
// https://docs.anthropic.com/en/docs/claude-code/settings#tools-available-to-claude
|
// https://docs.anthropic.com/en/docs/claude-code/settings#tools-available-to-claude
|
||||||
export const builtinTools: Tool[] = [
|
export const builtinTools: Tool[] = [
|
||||||
{ id: 'Bash', name: 'Bash', description: 'Executes shell commands in your environment', requirePermissions: true },
|
{
|
||||||
{ id: 'Edit', name: 'Edit', description: 'Makes targeted edits to specific files', requirePermissions: true },
|
id: 'Bash',
|
||||||
{ id: 'Glob', name: 'Glob', description: 'Finds files based on pattern matching', requirePermissions: false },
|
name: 'Bash',
|
||||||
{ id: 'Grep', name: 'Grep', description: 'Searches for patterns in file contents', requirePermissions: false },
|
description: 'Executes shell commands in your environment',
|
||||||
|
requirePermissions: true,
|
||||||
|
type: 'builtin'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'Edit',
|
||||||
|
name: 'Edit',
|
||||||
|
description: 'Makes targeted edits to specific files',
|
||||||
|
requirePermissions: true,
|
||||||
|
type: 'builtin'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'Glob',
|
||||||
|
name: 'Glob',
|
||||||
|
description: 'Finds files based on pattern matching',
|
||||||
|
requirePermissions: false,
|
||||||
|
type: 'builtin'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'Grep',
|
||||||
|
name: 'Grep',
|
||||||
|
description: 'Searches for patterns in file contents',
|
||||||
|
requirePermissions: false,
|
||||||
|
type: 'builtin'
|
||||||
|
},
|
||||||
{
|
{
|
||||||
id: 'MultiEdit',
|
id: 'MultiEdit',
|
||||||
name: 'MultiEdit',
|
name: 'MultiEdit',
|
||||||
description: 'Performs multiple edits on a single file atomically',
|
description: 'Performs multiple edits on a single file atomically',
|
||||||
requirePermissions: true
|
requirePermissions: true,
|
||||||
|
type: 'builtin'
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'NotebookEdit',
|
id: 'NotebookEdit',
|
||||||
name: 'NotebookEdit',
|
name: 'NotebookEdit',
|
||||||
description: 'Modifies Jupyter notebook cells',
|
description: 'Modifies Jupyter notebook cells',
|
||||||
requirePermissions: true
|
requirePermissions: true,
|
||||||
|
type: 'builtin'
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'NotebookRead',
|
id: 'NotebookRead',
|
||||||
name: 'NotebookRead',
|
name: 'NotebookRead',
|
||||||
description: 'Reads and displays Jupyter notebook contents',
|
description: 'Reads and displays Jupyter notebook contents',
|
||||||
requirePermissions: false
|
requirePermissions: false,
|
||||||
|
type: 'builtin'
|
||||||
},
|
},
|
||||||
{ id: 'Read', name: 'Read', description: 'Reads the contents of files', requirePermissions: false },
|
{ id: 'Read', name: 'Read', description: 'Reads the contents of files', requirePermissions: false, type: 'builtin' },
|
||||||
{
|
{
|
||||||
id: 'Task',
|
id: 'Task',
|
||||||
name: 'Task',
|
name: 'Task',
|
||||||
description: 'Runs a sub-agent to handle complex, multi-step tasks',
|
description: 'Runs a sub-agent to handle complex, multi-step tasks',
|
||||||
requirePermissions: false
|
requirePermissions: false,
|
||||||
|
type: 'builtin'
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'TodoWrite',
|
id: 'TodoWrite',
|
||||||
name: 'TodoWrite',
|
name: 'TodoWrite',
|
||||||
description: 'Creates and manages structured task lists',
|
description: 'Creates and manages structured task lists',
|
||||||
requirePermissions: false
|
requirePermissions: false,
|
||||||
|
type: 'builtin'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'WebFetch',
|
||||||
|
name: 'WebFetch',
|
||||||
|
description: 'Fetches content from a specified URL',
|
||||||
|
requirePermissions: true,
|
||||||
|
type: 'builtin'
|
||||||
},
|
},
|
||||||
{ id: 'WebFetch', name: 'WebFetch', description: 'Fetches content from a specified URL', requirePermissions: true },
|
|
||||||
{
|
{
|
||||||
id: 'WebSearch',
|
id: 'WebSearch',
|
||||||
name: 'WebSearch',
|
name: 'WebSearch',
|
||||||
description: 'Performs web searches with domain filtering',
|
description: 'Performs web searches with domain filtering',
|
||||||
requirePermissions: true
|
requirePermissions: true,
|
||||||
|
type: 'builtin'
|
||||||
},
|
},
|
||||||
{ id: 'Write', name: 'Write', description: 'Creates or overwrites files', requirePermissions: true }
|
{ id: 'Write', name: 'Write', description: 'Creates or overwrites files', requirePermissions: true, type: 'builtin' }
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,60 +1,82 @@
|
|||||||
// This file is used to transform claude code json response to aisdk streaming format
|
/**
|
||||||
|
* Translates Anthropic Claude Code streaming messages into the generic AiSDK stream
|
||||||
|
* parts that the agent runtime understands. The transformer coordinates batched
|
||||||
|
* text/tool payloads, keeps per-message state using {@link ClaudeStreamState},
|
||||||
|
* and normalises usage metadata and finish reasons so downstream consumers do
|
||||||
|
* not need to reason about Anthropic-specific payload shapes.
|
||||||
|
*
|
||||||
|
* Stream lifecycle cheatsheet (per Claude turn):
|
||||||
|
* 1. `stream_event.message_start` → emit `start-step` and mark the state as active.
|
||||||
|
* 2. `content_block_start` (by index) → open a stateful block; emits one of
|
||||||
|
* `text-start` | `reasoning-start` | `tool-input-start`.
|
||||||
|
* 3. `content_block_delta` → append incremental text / reasoning / tool JSON,
|
||||||
|
* emitting only the delta to minimise UI churn.
|
||||||
|
* 4. `content_block_stop` → emit the matching `*-end` event and release the block.
|
||||||
|
* 5. `message_delta` → capture usage + stop reason but defer emission.
|
||||||
|
* 6. `message_stop` → emit `finish-step` with cached usage & reason, then reset.
|
||||||
|
* 7. Assistant snapshots with `tool_use` finalise the tool block (`tool-call`).
|
||||||
|
* 8. User snapshots with `tool_result` emit `tool-result`/`tool-error` using the cached payload.
|
||||||
|
* 9. Assistant snapshots with plain text (when no stream events were provided) fall back to
|
||||||
|
* emitting `text-*` parts and a synthetic `finish-step`.
|
||||||
|
*/
|
||||||
|
|
||||||
import type { LanguageModelV2Usage } from '@ai-sdk/provider'
|
import { SDKMessage } from '@anthropic-ai/claude-agent-sdk'
|
||||||
import { SDKMessage } from '@anthropic-ai/claude-code'
|
import type { BetaStopReason } from '@anthropic-ai/sdk/resources/beta/messages/messages.mjs'
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import type { ProviderMetadata, TextStreamPart } from 'ai'
|
import type { FinishReason, LanguageModelUsage, ProviderMetadata, TextStreamPart } from 'ai'
|
||||||
import { v4 as uuidv4 } from 'uuid'
|
import { v4 as uuidv4 } from 'uuid'
|
||||||
|
|
||||||
|
import { ClaudeStreamState } from './claude-stream-state'
|
||||||
import { mapClaudeCodeFinishReason } from './map-claude-code-finish-reason'
|
import { mapClaudeCodeFinishReason } from './map-claude-code-finish-reason'
|
||||||
|
|
||||||
const logger = loggerService.withContext('ClaudeCodeTransform')
|
const logger = loggerService.withContext('ClaudeCodeTransform')
|
||||||
|
|
||||||
type AgentStreamPart = TextStreamPart<Record<string, any>>
|
type AgentStreamPart = TextStreamPart<Record<string, any>>
|
||||||
|
|
||||||
const contentBlockState = new Map<
|
type ToolUseContent = {
|
||||||
string,
|
type: 'tool_use'
|
||||||
{
|
id: string
|
||||||
type: 'text' | 'tool-call'
|
name: string
|
||||||
toolCallId?: string
|
input: unknown
|
||||||
toolName?: string
|
|
||||||
input?: string
|
|
||||||
}
|
|
||||||
>()
|
|
||||||
|
|
||||||
// Helper function to generate unique IDs for text blocks
|
|
||||||
const generateMessageId = (): string => `msg_${uuidv4().replace(/-/g, '')}`
|
|
||||||
|
|
||||||
// Main transform function
|
|
||||||
export function transformSDKMessageToStreamParts(sdkMessage: SDKMessage): AgentStreamPart[] {
|
|
||||||
const chunks: AgentStreamPart[] = []
|
|
||||||
logger.debug('Transforming SDKMessage to stream parts', sdkMessage)
|
|
||||||
switch (sdkMessage.type) {
|
|
||||||
case 'assistant':
|
|
||||||
case 'user':
|
|
||||||
chunks.push(...handleUserOrAssistantMessage(sdkMessage))
|
|
||||||
break
|
|
||||||
|
|
||||||
case 'stream_event':
|
|
||||||
chunks.push(...handleStreamEvent(sdkMessage))
|
|
||||||
break
|
|
||||||
|
|
||||||
case 'system':
|
|
||||||
chunks.push(...handleSystemMessage(sdkMessage))
|
|
||||||
break
|
|
||||||
|
|
||||||
case 'result':
|
|
||||||
chunks.push(...handleResultMessage(sdkMessage))
|
|
||||||
break
|
|
||||||
|
|
||||||
default:
|
|
||||||
logger.warn('Unknown SDKMessage type:', { type: (sdkMessage as any).type })
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
return chunks
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ToolResultContent = {
|
||||||
|
type: 'tool_result'
|
||||||
|
tool_use_id: string
|
||||||
|
content: unknown
|
||||||
|
is_error?: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maps Anthropic stop reasons to the AiSDK equivalents so higher level
|
||||||
|
* consumers can treat completion states uniformly across providers.
|
||||||
|
*/
|
||||||
|
const finishReasonMapping: Record<BetaStopReason, FinishReason> = {
|
||||||
|
end_turn: 'stop',
|
||||||
|
max_tokens: 'length',
|
||||||
|
stop_sequence: 'stop',
|
||||||
|
tool_use: 'tool-calls',
|
||||||
|
pause_turn: 'unknown',
|
||||||
|
refusal: 'content-filter'
|
||||||
|
}
|
||||||
|
|
||||||
|
const emptyUsage: LanguageModelUsage = {
|
||||||
|
inputTokens: 0,
|
||||||
|
outputTokens: 0,
|
||||||
|
totalTokens: 0
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generates deterministic-ish message identifiers that are compatible with the
|
||||||
|
* AiSDK text stream contract. Anthropic deltas sometimes omit ids, so we create
|
||||||
|
* our own to ensure the downstream renderer can stitch chunks together.
|
||||||
|
*/
|
||||||
|
const generateMessageId = (): string => `msg_${uuidv4().replace(/-/g, '')}`
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extracts provider metadata from the raw Claude message so we can surface it
|
||||||
|
* on every emitted stream part for observability and debugging purposes.
|
||||||
|
*/
|
||||||
const sdkMessageToProviderMetadata = (message: SDKMessage): ProviderMetadata => {
|
const sdkMessageToProviderMetadata = (message: SDKMessage): ProviderMetadata => {
|
||||||
return {
|
return {
|
||||||
anthropic: {
|
anthropic: {
|
||||||
@@ -65,252 +87,544 @@ const sdkMessageToProviderMetadata = (message: SDKMessage): ProviderMetadata =>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function generateTextChunks(id: string, text: string, message: SDKMessage): AgentStreamPart[] {
|
/**
|
||||||
|
* Central entrypoint that receives Claude Code websocket events and converts
|
||||||
|
* them into AiSDK `TextStreamPart`s. The state machine tracks outstanding
|
||||||
|
* blocks across calls so that incremental deltas can be correlated correctly.
|
||||||
|
*/
|
||||||
|
export function transformSDKMessageToStreamParts(sdkMessage: SDKMessage, state: ClaudeStreamState): AgentStreamPart[] {
|
||||||
|
switch (sdkMessage.type) {
|
||||||
|
case 'assistant':
|
||||||
|
return handleAssistantMessage(sdkMessage, state)
|
||||||
|
case 'user':
|
||||||
|
return handleUserMessage(sdkMessage, state)
|
||||||
|
case 'stream_event':
|
||||||
|
return handleStreamEvent(sdkMessage, state)
|
||||||
|
case 'system':
|
||||||
|
return handleSystemMessage(sdkMessage)
|
||||||
|
case 'result':
|
||||||
|
return handleResultMessage(sdkMessage)
|
||||||
|
default:
|
||||||
|
logger.warn('Unknown SDKMessage type', { type: (sdkMessage as any).type })
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handles aggregated assistant messages that arrive outside of the streaming
|
||||||
|
* protocol (e.g. after a tool call finishes). We emit the appropriate
|
||||||
|
* text/tool events and close the active step once the payload is fully
|
||||||
|
* processed.
|
||||||
|
*/
|
||||||
|
function handleAssistantMessage(
|
||||||
|
message: Extract<SDKMessage, { type: 'assistant' }>,
|
||||||
|
state: ClaudeStreamState
|
||||||
|
): AgentStreamPart[] {
|
||||||
|
const chunks: AgentStreamPart[] = []
|
||||||
const providerMetadata = sdkMessageToProviderMetadata(message)
|
const providerMetadata = sdkMessageToProviderMetadata(message)
|
||||||
return [
|
const content = message.message.content
|
||||||
{
|
const isStreamingActive = state.hasActiveStep()
|
||||||
|
|
||||||
|
if (typeof content === 'string') {
|
||||||
|
if (!content) {
|
||||||
|
return chunks
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isStreamingActive) {
|
||||||
|
state.beginStep()
|
||||||
|
chunks.push({
|
||||||
|
type: 'start-step',
|
||||||
|
request: { body: '' },
|
||||||
|
warnings: []
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const textId = message.uuid?.toString() || generateMessageId()
|
||||||
|
chunks.push({
|
||||||
|
type: 'text-start',
|
||||||
|
id: textId,
|
||||||
|
providerMetadata
|
||||||
|
})
|
||||||
|
chunks.push({
|
||||||
|
type: 'text-delta',
|
||||||
|
id: textId,
|
||||||
|
text: content,
|
||||||
|
providerMetadata
|
||||||
|
})
|
||||||
|
chunks.push({
|
||||||
|
type: 'text-end',
|
||||||
|
id: textId,
|
||||||
|
providerMetadata
|
||||||
|
})
|
||||||
|
return finalizeNonStreamingStep(message, state, chunks)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!Array.isArray(content)) {
|
||||||
|
return chunks
|
||||||
|
}
|
||||||
|
|
||||||
|
const textBlocks: string[] = []
|
||||||
|
|
||||||
|
for (const block of content) {
|
||||||
|
switch (block.type) {
|
||||||
|
case 'text':
|
||||||
|
if (!isStreamingActive) {
|
||||||
|
textBlocks.push(block.text)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
case 'tool_use':
|
||||||
|
handleAssistantToolUse(block as ToolUseContent, providerMetadata, state, chunks)
|
||||||
|
break
|
||||||
|
default:
|
||||||
|
logger.warn('Unhandled assistant content block', { type: (block as any).type })
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isStreamingActive && textBlocks.length > 0) {
|
||||||
|
const id = message.uuid?.toString() || generateMessageId()
|
||||||
|
state.beginStep()
|
||||||
|
chunks.push({
|
||||||
|
type: 'start-step',
|
||||||
|
request: { body: '' },
|
||||||
|
warnings: []
|
||||||
|
})
|
||||||
|
chunks.push({
|
||||||
type: 'text-start',
|
type: 'text-start',
|
||||||
id,
|
id,
|
||||||
providerMetadata
|
providerMetadata
|
||||||
},
|
})
|
||||||
{
|
chunks.push({
|
||||||
type: 'text-delta',
|
type: 'text-delta',
|
||||||
id,
|
id,
|
||||||
text,
|
text: textBlocks.join(''),
|
||||||
providerMetadata
|
providerMetadata
|
||||||
},
|
})
|
||||||
{
|
chunks.push({
|
||||||
type: 'text-end',
|
type: 'text-end',
|
||||||
id,
|
id,
|
||||||
providerMetadata: {
|
providerMetadata
|
||||||
...providerMetadata,
|
})
|
||||||
text: {
|
return finalizeNonStreamingStep(message, state, chunks)
|
||||||
value: text
|
}
|
||||||
}
|
|
||||||
}
|
return chunks
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleUserOrAssistantMessage(message: Extract<SDKMessage, { type: 'assistant' | 'user' }>): AgentStreamPart[] {
|
/**
|
||||||
const chunks: AgentStreamPart[] = []
|
* Registers tool invocations with the stream state so that later tool results
|
||||||
const messageId = message.uuid?.toString() || generateMessageId()
|
* can be matched with the originating call.
|
||||||
|
*/
|
||||||
|
function handleAssistantToolUse(
|
||||||
|
block: ToolUseContent,
|
||||||
|
providerMetadata: ProviderMetadata,
|
||||||
|
state: ClaudeStreamState,
|
||||||
|
chunks: AgentStreamPart[]
|
||||||
|
): void {
|
||||||
|
chunks.push({
|
||||||
|
type: 'tool-call',
|
||||||
|
toolCallId: block.id,
|
||||||
|
toolName: block.name,
|
||||||
|
input: block.input,
|
||||||
|
providerExecuted: true,
|
||||||
|
providerMetadata
|
||||||
|
})
|
||||||
|
state.completeToolBlock(block.id, block.input, providerMetadata)
|
||||||
|
}
|
||||||
|
|
||||||
// handle normal text content
|
/**
|
||||||
if (typeof message.message.content === 'string') {
|
* Emits the terminating `finish-step` frame for non-streamed responses and
|
||||||
const textContent = message.message.content
|
* clears the currently active step in the state tracker.
|
||||||
if (textContent) {
|
*/
|
||||||
chunks.push(...generateTextChunks(messageId, textContent, message))
|
function finalizeNonStreamingStep(
|
||||||
|
message: Extract<SDKMessage, { type: 'assistant' }>,
|
||||||
|
state: ClaudeStreamState,
|
||||||
|
chunks: AgentStreamPart[]
|
||||||
|
): AgentStreamPart[] {
|
||||||
|
const usage = calculateUsageFromMessage(message)
|
||||||
|
const finishReason = inferFinishReason(message.message.stop_reason)
|
||||||
|
chunks.push({
|
||||||
|
type: 'finish-step',
|
||||||
|
response: {
|
||||||
|
id: message.uuid,
|
||||||
|
timestamp: new Date(),
|
||||||
|
modelId: message.message.model ?? ''
|
||||||
|
},
|
||||||
|
usage: usage ?? emptyUsage,
|
||||||
|
finishReason,
|
||||||
|
providerMetadata: sdkMessageToProviderMetadata(message)
|
||||||
|
})
|
||||||
|
state.resetStep()
|
||||||
|
return chunks
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts user-originated websocket frames (text, tool results, etc.) into
|
||||||
|
* the AiSDK format. Tool results are matched back to pending tool calls via the
|
||||||
|
* shared `ClaudeStreamState` instance.
|
||||||
|
*/
|
||||||
|
function handleUserMessage(
|
||||||
|
message: Extract<SDKMessage, { type: 'user' }>,
|
||||||
|
state: ClaudeStreamState
|
||||||
|
): AgentStreamPart[] {
|
||||||
|
const chunks: AgentStreamPart[] = []
|
||||||
|
const providerMetadata = sdkMessageToProviderMetadata(message)
|
||||||
|
const content = message.message.content
|
||||||
|
|
||||||
|
if (typeof content === 'string') {
|
||||||
|
if (!content) {
|
||||||
|
return chunks
|
||||||
}
|
}
|
||||||
} else if (Array.isArray(message.message.content)) {
|
|
||||||
for (const block of message.message.content) {
|
const id = message.uuid?.toString() || generateMessageId()
|
||||||
switch (block.type) {
|
chunks.push({
|
||||||
case 'text':
|
type: 'text-start',
|
||||||
chunks.push(...generateTextChunks(messageId, block.text, message))
|
id,
|
||||||
break
|
providerMetadata
|
||||||
case 'tool_use':
|
})
|
||||||
chunks.push({
|
chunks.push({
|
||||||
type: 'tool-call',
|
type: 'text-delta',
|
||||||
toolCallId: block.id,
|
id,
|
||||||
toolName: block.name,
|
text: content,
|
||||||
input: block.input,
|
providerMetadata
|
||||||
providerExecuted: true,
|
})
|
||||||
providerMetadata: sdkMessageToProviderMetadata(message)
|
chunks.push({
|
||||||
})
|
type: 'text-end',
|
||||||
break
|
id,
|
||||||
case 'tool_result':
|
providerMetadata
|
||||||
// chunks.push({
|
})
|
||||||
// type: 'tool-result',
|
return chunks
|
||||||
// toolCallId: block.tool_use_id,
|
}
|
||||||
// output: block.content,
|
|
||||||
// providerMetadata: sdkMessageToProviderMetadata(message)
|
if (!Array.isArray(content)) {
|
||||||
// })
|
return chunks
|
||||||
break
|
}
|
||||||
default:
|
|
||||||
logger.warn('Unknown content block type in user/assistant message:', {
|
for (const block of content) {
|
||||||
type: block.type
|
if (block.type === 'tool_result') {
|
||||||
})
|
const toolResult = block as ToolResultContent
|
||||||
break
|
const pendingCall = state.consumePendingToolCall(toolResult.tool_use_id)
|
||||||
|
if (toolResult.is_error) {
|
||||||
|
chunks.push({
|
||||||
|
type: 'tool-error',
|
||||||
|
toolCallId: toolResult.tool_use_id,
|
||||||
|
toolName: pendingCall?.toolName ?? 'unknown',
|
||||||
|
input: pendingCall?.input,
|
||||||
|
error: toolResult.content,
|
||||||
|
providerExecuted: true
|
||||||
|
} as AgentStreamPart)
|
||||||
|
} else {
|
||||||
|
chunks.push({
|
||||||
|
type: 'tool-result',
|
||||||
|
toolCallId: toolResult.tool_use_id,
|
||||||
|
toolName: pendingCall?.toolName ?? 'unknown',
|
||||||
|
input: pendingCall?.input,
|
||||||
|
output: toolResult.content,
|
||||||
|
providerExecuted: true
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
} else if (block.type === 'text') {
|
||||||
|
const id = message.uuid?.toString() || generateMessageId()
|
||||||
|
chunks.push({
|
||||||
|
type: 'text-start',
|
||||||
|
id,
|
||||||
|
providerMetadata
|
||||||
|
})
|
||||||
|
chunks.push({
|
||||||
|
type: 'text-delta',
|
||||||
|
id,
|
||||||
|
text: (block as { text: string }).text,
|
||||||
|
providerMetadata
|
||||||
|
})
|
||||||
|
chunks.push({
|
||||||
|
type: 'text-end',
|
||||||
|
id,
|
||||||
|
providerMetadata
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
logger.warn('Unhandled user content block', { type: (block as any).type })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle stream events (real-time streaming)
|
/**
|
||||||
function handleStreamEvent(message: Extract<SDKMessage, { type: 'stream_event' }>): AgentStreamPart[] {
|
* Handles the fine-grained real-time streaming protocol where Anthropic emits
|
||||||
|
* discrete events for message lifecycle, content blocks, and usage deltas.
|
||||||
|
*/
|
||||||
|
function handleStreamEvent(
|
||||||
|
message: Extract<SDKMessage, { type: 'stream_event' }>,
|
||||||
|
state: ClaudeStreamState
|
||||||
|
): AgentStreamPart[] {
|
||||||
const chunks: AgentStreamPart[] = []
|
const chunks: AgentStreamPart[] = []
|
||||||
const event = message.event
|
const providerMetadata = sdkMessageToProviderMetadata(message)
|
||||||
const blockKey = `${message.uuid ?? message.session_id ?? 'session'}:${event.index}`
|
const { event } = message
|
||||||
|
|
||||||
switch (event.type) {
|
switch (event.type) {
|
||||||
case 'message_start':
|
case 'message_start':
|
||||||
// No specific UI chunk needed for message start in this protocol
|
state.beginStep()
|
||||||
|
chunks.push({
|
||||||
|
type: 'start-step',
|
||||||
|
request: { body: '' },
|
||||||
|
warnings: []
|
||||||
|
})
|
||||||
break
|
break
|
||||||
|
|
||||||
case 'content_block_start':
|
case 'content_block_start':
|
||||||
const contentBlockType = event.content_block.type
|
handleContentBlockStart(event.index, event.content_block, providerMetadata, state, chunks)
|
||||||
switch (contentBlockType) {
|
|
||||||
case 'text': {
|
|
||||||
contentBlockState.set(blockKey, { type: 'text' })
|
|
||||||
chunks.push({
|
|
||||||
type: 'text-start',
|
|
||||||
id: String(event.index),
|
|
||||||
providerMetadata: {
|
|
||||||
...sdkMessageToProviderMetadata(message),
|
|
||||||
anthropic: {
|
|
||||||
uuid: message.uuid,
|
|
||||||
session_id: message.session_id,
|
|
||||||
content_block_index: event.index
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
break
|
|
||||||
}
|
|
||||||
case 'tool_use': {
|
|
||||||
contentBlockState.set(blockKey, {
|
|
||||||
type: 'tool-call',
|
|
||||||
toolCallId: event.content_block.id,
|
|
||||||
toolName: event.content_block.name,
|
|
||||||
input: ''
|
|
||||||
})
|
|
||||||
chunks.push({
|
|
||||||
type: 'tool-call',
|
|
||||||
toolCallId: event.content_block.id,
|
|
||||||
toolName: event.content_block.name,
|
|
||||||
input: event.content_block.input,
|
|
||||||
providerExecuted: true,
|
|
||||||
providerMetadata: sdkMessageToProviderMetadata(message)
|
|
||||||
})
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break
|
break
|
||||||
|
|
||||||
case 'content_block_delta':
|
case 'content_block_delta':
|
||||||
switch (event.delta.type) {
|
handleContentBlockDelta(event.index, event.delta, providerMetadata, state, chunks)
|
||||||
case 'text_delta': {
|
|
||||||
chunks.push({
|
|
||||||
type: 'text-delta',
|
|
||||||
id: String(event.index),
|
|
||||||
text: event.delta.text,
|
|
||||||
providerMetadata: {
|
|
||||||
...sdkMessageToProviderMetadata(message),
|
|
||||||
anthropic: {
|
|
||||||
uuid: message.uuid,
|
|
||||||
session_id: message.session_id,
|
|
||||||
content_block_index: event.index
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
break
|
|
||||||
}
|
|
||||||
// case 'thinking_delta': {
|
|
||||||
// chunks.push({
|
|
||||||
// type: 'reasoning-delta',
|
|
||||||
// id: String(event.index),
|
|
||||||
// text: event.delta.thinking,
|
|
||||||
// });
|
|
||||||
// break
|
|
||||||
// }
|
|
||||||
// case 'signature_delta': {
|
|
||||||
// if (blockType === 'thinking') {
|
|
||||||
// chunks.push({
|
|
||||||
// type: 'reasoning-delta',
|
|
||||||
// id: String(event.index),
|
|
||||||
// text: '',
|
|
||||||
// providerMetadata: {
|
|
||||||
// ...sdkMessageToProviderMetadata(message),
|
|
||||||
// anthropic: {
|
|
||||||
// uuid: message.uuid,
|
|
||||||
// session_id: message.session_id,
|
|
||||||
// content_block_index: event.index,
|
|
||||||
// signature: event.delta.signature
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// })
|
|
||||||
// }
|
|
||||||
// break
|
|
||||||
// }
|
|
||||||
case 'input_json_delta': {
|
|
||||||
const contentBlock = contentBlockState.get(blockKey)
|
|
||||||
if (contentBlock && contentBlock.type === 'tool-call') {
|
|
||||||
contentBlockState.set(blockKey, {
|
|
||||||
...contentBlock,
|
|
||||||
input: `${contentBlock.input ?? ''}${event.delta.partial_json ?? ''}`
|
|
||||||
})
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break
|
break
|
||||||
|
|
||||||
case 'content_block_stop': {
|
case 'content_block_stop': {
|
||||||
const contentBlock = contentBlockState.get(blockKey)
|
const block = state.closeBlock(event.index)
|
||||||
if (contentBlock?.type === 'text') {
|
if (!block) {
|
||||||
chunks.push({
|
logger.warn('Received content_block_stop for unknown index', { index: event.index })
|
||||||
type: 'text-end',
|
break
|
||||||
id: String(event.index)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
contentBlockState.delete(blockKey)
|
|
||||||
|
switch (block.kind) {
|
||||||
|
case 'text':
|
||||||
|
chunks.push({
|
||||||
|
type: 'text-end',
|
||||||
|
id: block.id,
|
||||||
|
providerMetadata
|
||||||
|
})
|
||||||
|
break
|
||||||
|
case 'reasoning':
|
||||||
|
chunks.push({
|
||||||
|
type: 'reasoning-end',
|
||||||
|
id: block.id,
|
||||||
|
providerMetadata
|
||||||
|
})
|
||||||
|
break
|
||||||
|
case 'tool':
|
||||||
|
chunks.push({
|
||||||
|
type: 'tool-input-end',
|
||||||
|
id: block.toolCallId,
|
||||||
|
providerMetadata
|
||||||
|
})
|
||||||
|
break
|
||||||
|
default:
|
||||||
|
break
|
||||||
|
}
|
||||||
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
case 'message_delta':
|
case 'message_delta': {
|
||||||
// Handle usage updates or other message-level deltas
|
const finishReason = event.delta.stop_reason
|
||||||
|
? mapStopReason(event.delta.stop_reason as BetaStopReason)
|
||||||
|
: undefined
|
||||||
|
const usage = convertUsage(event.usage)
|
||||||
|
state.setPendingUsage(usage, finishReason)
|
||||||
break
|
break
|
||||||
|
}
|
||||||
|
|
||||||
case 'message_stop':
|
case 'message_stop': {
|
||||||
// This could signal the end of the message
|
const pending = state.getPendingUsage()
|
||||||
|
chunks.push({
|
||||||
|
type: 'finish-step',
|
||||||
|
response: {
|
||||||
|
id: message.uuid,
|
||||||
|
timestamp: new Date(),
|
||||||
|
modelId: ''
|
||||||
|
},
|
||||||
|
usage: pending.usage ?? emptyUsage,
|
||||||
|
finishReason: pending.finishReason ?? 'stop',
|
||||||
|
providerMetadata
|
||||||
|
})
|
||||||
|
state.resetStep()
|
||||||
break
|
break
|
||||||
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
logger.warn('Unknown stream event type:', { type: (event as any).type })
|
logger.warn('Unknown stream event type', { type: (event as any).type })
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle system messages
|
/**
|
||||||
function handleSystemMessage(message: Extract<SDKMessage, { type: 'system' }>): AgentStreamPart[] {
|
* Opens the appropriate block type when Claude starts streaming a new content
|
||||||
const chunks: AgentStreamPart[] = []
|
* section so later deltas know which logical entity to append to.
|
||||||
logger.debug('Received system message', {
|
*/
|
||||||
subtype: message.subtype
|
function handleContentBlockStart(
|
||||||
})
|
index: number,
|
||||||
switch (message.subtype) {
|
contentBlock: any,
|
||||||
case 'init': {
|
providerMetadata: ProviderMetadata,
|
||||||
|
state: ClaudeStreamState,
|
||||||
|
chunks: AgentStreamPart[]
|
||||||
|
): void {
|
||||||
|
switch (contentBlock.type) {
|
||||||
|
case 'text': {
|
||||||
|
const block = state.openTextBlock(index, generateMessageId())
|
||||||
chunks.push({
|
chunks.push({
|
||||||
type: 'start'
|
type: 'text-start',
|
||||||
|
id: block.id,
|
||||||
|
providerMetadata
|
||||||
})
|
})
|
||||||
|
break
|
||||||
}
|
}
|
||||||
|
case 'thinking':
|
||||||
|
case 'redacted_thinking': {
|
||||||
|
const block = state.openReasoningBlock(index, generateMessageId(), contentBlock.type === 'redacted_thinking')
|
||||||
|
chunks.push({
|
||||||
|
type: 'reasoning-start',
|
||||||
|
id: block.id,
|
||||||
|
providerMetadata
|
||||||
|
})
|
||||||
|
break
|
||||||
|
}
|
||||||
|
case 'tool_use': {
|
||||||
|
const block = state.openToolBlock(index, {
|
||||||
|
toolCallId: contentBlock.id,
|
||||||
|
toolName: contentBlock.name,
|
||||||
|
providerMetadata
|
||||||
|
})
|
||||||
|
chunks.push({
|
||||||
|
type: 'tool-input-start',
|
||||||
|
id: block.toolCallId,
|
||||||
|
toolName: block.toolName,
|
||||||
|
providerMetadata
|
||||||
|
})
|
||||||
|
break
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
logger.warn('Unhandled content_block_start type', { type: contentBlock.type })
|
||||||
|
break
|
||||||
}
|
}
|
||||||
return []
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle result messages (completion with usage stats)
|
/**
|
||||||
|
* Applies incremental deltas to the active block (text, thinking, tool input)
|
||||||
|
* and emits the translated AiSDK chunk immediately.
|
||||||
|
*/
|
||||||
|
function handleContentBlockDelta(
|
||||||
|
index: number,
|
||||||
|
delta: any,
|
||||||
|
providerMetadata: ProviderMetadata,
|
||||||
|
state: ClaudeStreamState,
|
||||||
|
chunks: AgentStreamPart[]
|
||||||
|
): void {
|
||||||
|
switch (delta.type) {
|
||||||
|
case 'text_delta': {
|
||||||
|
const block = state.appendTextDelta(index, delta.text)
|
||||||
|
if (!block) {
|
||||||
|
logger.warn('Received text_delta for unknown block', { index })
|
||||||
|
return
|
||||||
|
}
|
||||||
|
chunks.push({
|
||||||
|
type: 'text-delta',
|
||||||
|
id: block.id,
|
||||||
|
text: block.text,
|
||||||
|
providerMetadata
|
||||||
|
})
|
||||||
|
break
|
||||||
|
}
|
||||||
|
case 'thinking_delta': {
|
||||||
|
const block = state.appendReasoningDelta(index, delta.thinking)
|
||||||
|
if (!block) {
|
||||||
|
logger.warn('Received thinking_delta for unknown block', { index })
|
||||||
|
return
|
||||||
|
}
|
||||||
|
chunks.push({
|
||||||
|
type: 'reasoning-delta',
|
||||||
|
id: block.id,
|
||||||
|
text: delta.thinking,
|
||||||
|
providerMetadata
|
||||||
|
})
|
||||||
|
break
|
||||||
|
}
|
||||||
|
case 'signature_delta': {
|
||||||
|
const block = state.getBlock(index)
|
||||||
|
if (block && block.kind === 'reasoning') {
|
||||||
|
chunks.push({
|
||||||
|
type: 'reasoning-delta',
|
||||||
|
id: block.id,
|
||||||
|
text: '',
|
||||||
|
providerMetadata
|
||||||
|
})
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
case 'input_json_delta': {
|
||||||
|
const block = state.appendToolInputDelta(index, delta.partial_json)
|
||||||
|
if (!block) {
|
||||||
|
logger.warn('Received input_json_delta for unknown block', { index })
|
||||||
|
return
|
||||||
|
}
|
||||||
|
chunks.push({
|
||||||
|
type: 'tool-input-delta',
|
||||||
|
id: block.toolCallId,
|
||||||
|
delta: block.inputBuffer,
|
||||||
|
providerMetadata
|
||||||
|
})
|
||||||
|
break
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
logger.warn('Unhandled content_block_delta type', { type: delta.type })
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* System messages currently only deliver the session bootstrap payload. We
|
||||||
|
* forward it as both a `start` marker and a raw snapshot for diagnostics.
|
||||||
|
*/
|
||||||
|
function handleSystemMessage(message: Extract<SDKMessage, { type: 'system' }>): AgentStreamPart[] {
|
||||||
|
const chunks: AgentStreamPart[] = []
|
||||||
|
if (message.subtype === 'init') {
|
||||||
|
chunks.push({
|
||||||
|
type: 'start'
|
||||||
|
})
|
||||||
|
chunks.push({
|
||||||
|
type: 'raw',
|
||||||
|
rawValue: {
|
||||||
|
type: 'init',
|
||||||
|
session_id: message.session_id,
|
||||||
|
slash_commands: message.slash_commands,
|
||||||
|
tools: message.tools,
|
||||||
|
raw: message
|
||||||
|
}
|
||||||
|
})
|
||||||
|
} else if (message.subtype === 'compact_boundary') {
|
||||||
|
chunks.push({
|
||||||
|
type: 'raw',
|
||||||
|
rawValue: {
|
||||||
|
type: 'compact',
|
||||||
|
session_id: message.session_id,
|
||||||
|
raw: message
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return chunks
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Terminal result messages arrive once the Claude Code session concludes.
|
||||||
|
* Successful runs yield a `finish` frame with aggregated usage metrics, while
|
||||||
|
* failures are surfaced as `error` frames.
|
||||||
|
*/
|
||||||
function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>): AgentStreamPart[] {
|
function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>): AgentStreamPart[] {
|
||||||
const chunks: AgentStreamPart[] = []
|
const chunks: AgentStreamPart[] = []
|
||||||
|
|
||||||
let usage: LanguageModelV2Usage | undefined
|
let usage: LanguageModelUsage | undefined
|
||||||
if ('usage' in message) {
|
if ('usage' in message) {
|
||||||
usage = {
|
usage = {
|
||||||
inputTokens:
|
inputTokens: message.usage.input_tokens ?? 0,
|
||||||
(message.usage.cache_creation_input_tokens ?? 0) +
|
|
||||||
(message.usage.cache_read_input_tokens ?? 0) +
|
|
||||||
(message.usage.input_tokens ?? 0),
|
|
||||||
outputTokens: message.usage.output_tokens ?? 0,
|
outputTokens: message.usage.output_tokens ?? 0,
|
||||||
totalTokens:
|
totalTokens: (message.usage.input_tokens ?? 0) + (message.usage.output_tokens ?? 0)
|
||||||
(message.usage.cache_creation_input_tokens ?? 0) +
|
|
||||||
(message.usage.cache_read_input_tokens ?? 0) +
|
|
||||||
(message.usage.input_tokens ?? 0) +
|
|
||||||
(message.usage.output_tokens ?? 0)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (message.subtype === 'success') {
|
if (message.subtype === 'success') {
|
||||||
chunks.push({
|
chunks.push({
|
||||||
type: 'finish',
|
type: 'finish',
|
||||||
totalUsage: usage,
|
totalUsage: usage ?? emptyUsage,
|
||||||
finishReason: mapClaudeCodeFinishReason(message.subtype),
|
finishReason: mapClaudeCodeFinishReason(message.subtype),
|
||||||
providerMetadata: {
|
providerMetadata: {
|
||||||
...sdkMessageToProviderMetadata(message),
|
...sdkMessageToProviderMetadata(message),
|
||||||
@@ -331,24 +645,59 @@ function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>):
|
|||||||
return chunks
|
return chunks
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convenience function to transform a stream of SDKMessages
|
/**
|
||||||
export function* transformSDKMessageStream(sdkMessages: SDKMessage[]): Generator<AgentStreamPart> {
|
* Normalises usage payloads so the caller always receives numeric values even
|
||||||
for (const sdkMessage of sdkMessages) {
|
* when the provider omits certain fields.
|
||||||
const chunks = transformSDKMessageToStreamParts(sdkMessage)
|
*/
|
||||||
for (const chunk of chunks) {
|
function convertUsage(
|
||||||
yield chunk
|
usage?: {
|
||||||
}
|
input_tokens?: number | null
|
||||||
|
output_tokens?: number | null
|
||||||
|
} | null
|
||||||
|
): LanguageModelUsage | undefined {
|
||||||
|
if (!usage) {
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
const inputTokens = usage.input_tokens ?? 0
|
||||||
|
const outputTokens = usage.output_tokens ?? 0
|
||||||
|
return {
|
||||||
|
inputTokens,
|
||||||
|
outputTokens,
|
||||||
|
totalTokens: inputTokens + outputTokens
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Async version for async iterables
|
/**
|
||||||
export async function* transformSDKMessageStreamAsync(
|
* Anthropic-only wrapper around {@link finishReasonMapping} that defaults to
|
||||||
sdkMessages: AsyncIterable<SDKMessage>
|
* `unknown` to avoid surprising downstream consumers when new stop reasons are
|
||||||
): AsyncGenerator<AgentStreamPart> {
|
* introduced.
|
||||||
for await (const sdkMessage of sdkMessages) {
|
*/
|
||||||
const chunks = transformSDKMessageToStreamParts(sdkMessage)
|
function mapStopReason(reason: BetaStopReason): FinishReason {
|
||||||
for (const chunk of chunks) {
|
return finishReasonMapping[reason] ?? 'unknown'
|
||||||
yield chunk
|
}
|
||||||
}
|
|
||||||
|
/**
|
||||||
|
* Extracts token accounting details from an assistant message, if available.
|
||||||
|
*/
|
||||||
|
function calculateUsageFromMessage(
|
||||||
|
message: Extract<SDKMessage, { type: 'assistant' }>
|
||||||
|
): LanguageModelUsage | undefined {
|
||||||
|
const usage = message.message.usage
|
||||||
|
if (!usage) return undefined
|
||||||
|
return {
|
||||||
|
inputTokens: usage.input_tokens ?? 0,
|
||||||
|
outputTokens: usage.output_tokens ?? 0,
|
||||||
|
totalTokens: (usage.input_tokens ?? 0) + (usage.output_tokens ?? 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts Anthropic stop reasons into AiSDK finish reasons, falling back to a
|
||||||
|
* generic `stop` if the provider omits the detail entirely.
|
||||||
|
*/
|
||||||
|
function inferFinishReason(stopReason: BetaStopReason | null | undefined): FinishReason {
|
||||||
|
if (!stopReason) return 'stop'
|
||||||
|
return mapStopReason(stopReason)
|
||||||
|
}
|
||||||
|
|
||||||
|
export { ClaudeStreamState }
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import {
|
|||||||
OAuthTokens
|
OAuthTokens
|
||||||
} from '@modelcontextprotocol/sdk/shared/auth.js'
|
} from '@modelcontextprotocol/sdk/shared/auth.js'
|
||||||
import EventEmitter from 'events'
|
import EventEmitter from 'events'
|
||||||
import { z } from 'zod'
|
import * as z from 'zod'
|
||||||
|
|
||||||
export interface OAuthStorageData {
|
export interface OAuthStorageData {
|
||||||
clientInfo?: OAuthClientInformation
|
clientInfo?: OAuthClientInformation
|
||||||
|
|||||||
@@ -1,122 +0,0 @@
|
|||||||
import { loggerService } from '@logger'
|
|
||||||
import { spawn } from 'child_process'
|
|
||||||
import os from 'os'
|
|
||||||
|
|
||||||
const logger = loggerService.withContext('ShellEnv')
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Spawns a login shell in the user's home directory to capture its environment variables.
|
|
||||||
* @returns {Promise<Object>} A promise that resolves with an object containing
|
|
||||||
* the environment variables, or rejects with an error.
|
|
||||||
*/
|
|
||||||
function getLoginShellEnvironment(): Promise<Record<string, string>> {
|
|
||||||
return new Promise((resolve, reject) => {
|
|
||||||
const homeDirectory = os.homedir()
|
|
||||||
if (!homeDirectory) {
|
|
||||||
return reject(new Error("Could not determine user's home directory."))
|
|
||||||
}
|
|
||||||
|
|
||||||
let shellPath = process.env.SHELL
|
|
||||||
let commandArgs
|
|
||||||
let shellCommandToGetEnv
|
|
||||||
|
|
||||||
const platform = os.platform()
|
|
||||||
|
|
||||||
if (platform === 'win32') {
|
|
||||||
// On Windows, 'cmd.exe' is the common shell.
|
|
||||||
// The 'set' command lists environment variables.
|
|
||||||
// We don't typically talk about "login shells" in the same way,
|
|
||||||
// but cmd will load the user's environment.
|
|
||||||
shellPath = process.env.COMSPEC || 'cmd.exe'
|
|
||||||
shellCommandToGetEnv = 'set'
|
|
||||||
commandArgs = ['/c', shellCommandToGetEnv] // /c Carries out the command specified by string and then terminates
|
|
||||||
} else {
|
|
||||||
// For POSIX systems (Linux, macOS)
|
|
||||||
if (!shellPath) {
|
|
||||||
// Fallback if process.env.SHELL is not set (less common for interactive users)
|
|
||||||
// Defaulting to bash, but this might not be the user's actual login shell.
|
|
||||||
// A more robust solution might involve checking /etc/passwd or similar,
|
|
||||||
// but that's more complex and often requires higher privileges or native modules.
|
|
||||||
logger.warn("process.env.SHELL is not set. Defaulting to /bin/bash. This might not be the user's login shell.")
|
|
||||||
shellPath = '/bin/bash' // A common default
|
|
||||||
}
|
|
||||||
// -l: Make it a login shell. This sources profile files like .profile, .bash_profile, .zprofile etc.
|
|
||||||
// -i: Make it interactive. Some shells or profile scripts behave differently.
|
|
||||||
// 'env': The command to print environment variables.
|
|
||||||
// Using 'env -0' would be more robust for parsing if values contain newlines,
|
|
||||||
// but requires splitting by null character. For simplicity, we'll use 'env'.
|
|
||||||
shellCommandToGetEnv = 'env'
|
|
||||||
commandArgs = ['-ilc', shellCommandToGetEnv] // -i for interactive, -l for login, -c to execute command
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.debug(`Spawning shell: ${shellPath} with args: ${commandArgs.join(' ')} in ${homeDirectory}`)
|
|
||||||
|
|
||||||
const child = spawn(shellPath, commandArgs, {
|
|
||||||
cwd: homeDirectory, // Run the command in the user's home directory
|
|
||||||
detached: true, // Allows the parent to exit independently of the child
|
|
||||||
stdio: ['ignore', 'pipe', 'pipe'], // stdin, stdout, stderr
|
|
||||||
shell: false // We are specifying the shell command directly
|
|
||||||
})
|
|
||||||
|
|
||||||
let output = ''
|
|
||||||
let errorOutput = ''
|
|
||||||
|
|
||||||
child.stdout.on('data', (data) => {
|
|
||||||
output += data.toString()
|
|
||||||
})
|
|
||||||
|
|
||||||
child.stderr.on('data', (data) => {
|
|
||||||
errorOutput += data.toString()
|
|
||||||
})
|
|
||||||
|
|
||||||
child.on('error', (error) => {
|
|
||||||
logger.error(`Failed to start shell process: ${shellPath}`, error)
|
|
||||||
reject(new Error(`Failed to start shell: ${error.message}`))
|
|
||||||
})
|
|
||||||
|
|
||||||
child.on('close', (code) => {
|
|
||||||
if (code !== 0) {
|
|
||||||
const errorMessage = `Shell process exited with code ${code}. Shell: ${shellPath}. Args: ${commandArgs.join(' ')}. CWD: ${homeDirectory}. Stderr: ${errorOutput.trim()}`
|
|
||||||
logger.error(errorMessage)
|
|
||||||
return reject(new Error(errorMessage))
|
|
||||||
}
|
|
||||||
|
|
||||||
if (errorOutput.trim()) {
|
|
||||||
// Some shells might output warnings or non-fatal errors to stderr
|
|
||||||
// during profile loading. Log it, but proceed if exit code is 0.
|
|
||||||
logger.warn(`Shell process stderr output (even with exit code 0):\n${errorOutput.trim()}`)
|
|
||||||
}
|
|
||||||
|
|
||||||
const env: Record<string, string> = {}
|
|
||||||
const lines = output.split('\n')
|
|
||||||
|
|
||||||
lines.forEach((line) => {
|
|
||||||
const trimmedLine = line.trim()
|
|
||||||
if (trimmedLine) {
|
|
||||||
const separatorIndex = trimmedLine.indexOf('=')
|
|
||||||
if (separatorIndex > 0) {
|
|
||||||
// Ensure '=' is present and it's not the first character
|
|
||||||
const key = trimmedLine.substring(0, separatorIndex)
|
|
||||||
const value = trimmedLine.substring(separatorIndex + 1)
|
|
||||||
env[key] = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
if (Object.keys(env).length === 0 && output.length < 100) {
|
|
||||||
// Arbitrary small length check
|
|
||||||
// This might indicate an issue if no env vars were parsed or output was minimal
|
|
||||||
logger.warn(
|
|
||||||
'Parsed environment is empty or output was very short. This might indicate an issue with shell execution or environment variable retrieval.'
|
|
||||||
)
|
|
||||||
logger.warn(`Raw output from shell:\n${output}`)
|
|
||||||
}
|
|
||||||
|
|
||||||
env.PATH = env.Path || env.PATH || ''
|
|
||||||
|
|
||||||
resolve(env)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
export default getLoginShellEnvironment
|
|
||||||
@@ -2,6 +2,7 @@ import { loggerService } from '@logger'
|
|||||||
import { isLinux } from '@main/constant'
|
import { isLinux } from '@main/constant'
|
||||||
import { BuiltinOcrProviderIds, OcrHandler, OcrProvider, OcrResult, SupportedOcrFile } from '@types'
|
import { BuiltinOcrProviderIds, OcrHandler, OcrProvider, OcrResult, SupportedOcrFile } from '@types'
|
||||||
|
|
||||||
|
import { ovOcrService } from './builtin/OvOcrService'
|
||||||
import { ppocrService } from './builtin/PpocrService'
|
import { ppocrService } from './builtin/PpocrService'
|
||||||
import { systemOcrService } from './builtin/SystemOcrService'
|
import { systemOcrService } from './builtin/SystemOcrService'
|
||||||
import { tesseractService } from './builtin/TesseractService'
|
import { tesseractService } from './builtin/TesseractService'
|
||||||
@@ -22,6 +23,10 @@ export class OcrService {
|
|||||||
this.registry.delete(providerId)
|
this.registry.delete(providerId)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public listProviderIds(): string[] {
|
||||||
|
return Array.from(this.registry.keys())
|
||||||
|
}
|
||||||
|
|
||||||
public async ocr(file: SupportedOcrFile, provider: OcrProvider): Promise<OcrResult> {
|
public async ocr(file: SupportedOcrFile, provider: OcrProvider): Promise<OcrResult> {
|
||||||
const handler = this.registry.get(provider.id)
|
const handler = this.registry.get(provider.id)
|
||||||
if (!handler) {
|
if (!handler) {
|
||||||
@@ -39,3 +44,5 @@ ocrService.register(BuiltinOcrProviderIds.tesseract, tesseractService.ocr.bind(t
|
|||||||
!isLinux && ocrService.register(BuiltinOcrProviderIds.system, systemOcrService.ocr.bind(systemOcrService))
|
!isLinux && ocrService.register(BuiltinOcrProviderIds.system, systemOcrService.ocr.bind(systemOcrService))
|
||||||
|
|
||||||
ocrService.register(BuiltinOcrProviderIds.paddleocr, ppocrService.ocr.bind(ppocrService))
|
ocrService.register(BuiltinOcrProviderIds.paddleocr, ppocrService.ocr.bind(ppocrService))
|
||||||
|
|
||||||
|
ovOcrService.isAvailable() && ocrService.register(BuiltinOcrProviderIds.ovocr, ovOcrService.ocr.bind(ovOcrService))
|
||||||
|
|||||||
@@ -0,0 +1,128 @@
|
|||||||
|
import { loggerService } from '@logger'
|
||||||
|
import { isWin } from '@main/constant'
|
||||||
|
import { isImageFileMetadata, OcrOvConfig, OcrResult, SupportedOcrFile } from '@types'
|
||||||
|
import { exec } from 'child_process'
|
||||||
|
import * as fs from 'fs'
|
||||||
|
import * as os from 'os'
|
||||||
|
import * as path from 'path'
|
||||||
|
import { promisify } from 'util'
|
||||||
|
|
||||||
|
import { OcrBaseService } from './OcrBaseService'
|
||||||
|
|
||||||
|
const logger = loggerService.withContext('OvOcrService')
|
||||||
|
const execAsync = promisify(exec)
|
||||||
|
|
||||||
|
const PATH_BAT_FILE = path.join(os.homedir(), '.cherrystudio', 'ovms', 'ovocr', 'run.npu.bat')
|
||||||
|
|
||||||
|
export class OvOcrService extends OcrBaseService {
|
||||||
|
constructor() {
|
||||||
|
super()
|
||||||
|
}
|
||||||
|
|
||||||
|
public isAvailable(): boolean {
|
||||||
|
return (
|
||||||
|
isWin &&
|
||||||
|
os.cpus()[0].model.toLowerCase().includes('intel') &&
|
||||||
|
os.cpus()[0].model.toLowerCase().includes('ultra') &&
|
||||||
|
fs.existsSync(PATH_BAT_FILE)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
private getOvOcrPath(): string {
|
||||||
|
return path.join(os.homedir(), '.cherrystudio', 'ovms', 'ovocr')
|
||||||
|
}
|
||||||
|
|
||||||
|
private getImgDir(): string {
|
||||||
|
return path.join(this.getOvOcrPath(), 'img')
|
||||||
|
}
|
||||||
|
|
||||||
|
private getOutputDir(): string {
|
||||||
|
return path.join(this.getOvOcrPath(), 'output')
|
||||||
|
}
|
||||||
|
|
||||||
|
private async clearDirectory(dirPath: string): Promise<void> {
|
||||||
|
if (fs.existsSync(dirPath)) {
|
||||||
|
const files = await fs.promises.readdir(dirPath)
|
||||||
|
for (const file of files) {
|
||||||
|
const filePath = path.join(dirPath, file)
|
||||||
|
const stats = await fs.promises.stat(filePath)
|
||||||
|
if (stats.isDirectory()) {
|
||||||
|
await this.clearDirectory(filePath)
|
||||||
|
await fs.promises.rmdir(filePath)
|
||||||
|
} else {
|
||||||
|
await fs.promises.unlink(filePath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// If the directory does not exist, create it
|
||||||
|
await fs.promises.mkdir(dirPath, { recursive: true })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async copyFileToImgDir(sourceFilePath: string, targetFileName: string): Promise<void> {
|
||||||
|
const imgDir = this.getImgDir()
|
||||||
|
const targetFilePath = path.join(imgDir, targetFileName)
|
||||||
|
await fs.promises.copyFile(sourceFilePath, targetFilePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
private async runOcrBatch(): Promise<void> {
|
||||||
|
const ovOcrPath = this.getOvOcrPath()
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Execute run.bat in the ov-ocr directory
|
||||||
|
await execAsync(`"${PATH_BAT_FILE}"`, {
|
||||||
|
cwd: ovOcrPath,
|
||||||
|
timeout: 60000 // 60 second timeout
|
||||||
|
})
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(`Error running ovocr batch: ${error}`)
|
||||||
|
throw new Error(`Failed to run OCR batch: ${error}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async ocrImage(filePath: string, options?: OcrOvConfig): Promise<OcrResult> {
|
||||||
|
logger.info(`OV OCR called on ${filePath} with options ${JSON.stringify(options)}`)
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 1. Clear img directory and output directory
|
||||||
|
await this.clearDirectory(this.getImgDir())
|
||||||
|
await this.clearDirectory(this.getOutputDir())
|
||||||
|
|
||||||
|
// 2. Copy file to img directory
|
||||||
|
const fileName = path.basename(filePath)
|
||||||
|
await this.copyFileToImgDir(filePath, fileName)
|
||||||
|
logger.info(`File copied to img directory: ${fileName}`)
|
||||||
|
|
||||||
|
// 3. Run run.bat
|
||||||
|
logger.info('Running OV OCR batch process...')
|
||||||
|
await this.runOcrBatch()
|
||||||
|
|
||||||
|
// 4. Check that output/[basename].txt file exists
|
||||||
|
const baseNameWithoutExt = path.basename(fileName, path.extname(fileName))
|
||||||
|
const outputFilePath = path.join(this.getOutputDir(), `${baseNameWithoutExt}.txt`)
|
||||||
|
if (!fs.existsSync(outputFilePath)) {
|
||||||
|
throw new Error(`OV OCR output file not found at: ${outputFilePath}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Read output/[basename].txt file content
|
||||||
|
const ocrText = await fs.promises.readFile(outputFilePath, 'utf-8')
|
||||||
|
logger.info(`OV OCR text extracted: ${ocrText.substring(0, 100)}...`)
|
||||||
|
|
||||||
|
// 6. Return result
|
||||||
|
return { text: ocrText }
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(`Error during OV OCR process: ${error}`)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public ocr = async (file: SupportedOcrFile, options?: OcrOvConfig): Promise<OcrResult> => {
|
||||||
|
if (isImageFileMetadata(file)) {
|
||||||
|
return this.ocrImage(file.path, options)
|
||||||
|
} else {
|
||||||
|
throw new Error('Unsupported file type, currently only image files are supported')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export const ovOcrService = new OvOcrService()
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import { loadOcrImage } from '@main/utils/ocr'
|
import { loadOcrImage } from '@main/utils/ocr'
|
||||||
import { ImageFileMetadata, isImageFileMetadata, OcrPpocrConfig, OcrResult, SupportedOcrFile } from '@types'
|
import { ImageFileMetadata, isImageFileMetadata, OcrPpocrConfig, OcrResult, SupportedOcrFile } from '@types'
|
||||||
import { net } from 'electron'
|
import { net } from 'electron'
|
||||||
import { z } from 'zod'
|
import * as z from 'zod'
|
||||||
|
|
||||||
import { OcrBaseService } from './OcrBaseService'
|
import { OcrBaseService } from './OcrBaseService'
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user