fix: enforce message context isolation

This commit is contained in:
Peter Steinberger
2026-01-13 01:03:23 +00:00
parent 0edbdb1948
commit ffc465394e
6 changed files with 164 additions and 5 deletions

View File

@@ -9,6 +9,7 @@
- Update: run `clawdbot doctor --non-interactive` during updates to avoid TTY hangs. (#781 — thanks @ronyrus)
- Tools: allow Claude/Gemini tool param aliases (`file_path`, `old_string`, `new_string`) while enforcing required params at runtime. (#793 — thanks @hsrvc)
- Gemini: downgrade tool-call history missing `thought_signature` to avoid INVALID_ARGUMENT errors. (#793 — thanks @hsrvc)
- Messaging: enforce context isolation for message tool sends across providers (normalized targets + tests). (#793 — thanks @hsrvc)
## 2026.1.12-3

View File

@@ -179,6 +179,7 @@ Core actions:
Notes:
- `send` routes WhatsApp via the Gateway; other providers go direct.
- `poll` uses the Gateway for WhatsApp and MS Teams; Discord polls go direct.
- When a message tool call is bound to an active chat session, sends are constrained to that sessions target to avoid cross-context leaks.
### `cron`
Manage Gateway cron jobs and wakeups.

View File

@@ -2,6 +2,10 @@ import { Type } from "@sinclair/typebox";
import type { ClawdbotConfig } from "../../config/config.js";
import { loadConfig } from "../../config/config.js";
import {
GATEWAY_CLIENT_IDS,
GATEWAY_CLIENT_MODES,
} from "../../gateway/protocol/client-info.js";
import { runMessageAction } from "../../infra/outbound/message-action-runner.js";
import {
listProviderMessageActions,
@@ -12,10 +16,6 @@ import {
type ProviderMessageActionName,
} from "../../providers/plugins/types.js";
import { normalizeAccountId } from "../../routing/session-key.js";
import {
GATEWAY_CLIENT_MODES,
GATEWAY_CLIENT_NAMES,
} from "../../utils/message-provider.js";
import type { AnyAgentTool } from "./common.js";
import { jsonResult, readNumberParam, readStringParam } from "./common.js";
@@ -184,7 +184,7 @@ export function createMessageTool(options?: MessageToolOptions): AnyAgentTool {
url: readStringParam(params, "gatewayUrl", { trim: false }),
token: readStringParam(params, "gatewayToken", { trim: false }),
timeoutMs: readNumberParam(params, "timeoutMs"),
clientName: GATEWAY_CLIENT_NAMES.GATEWAY_CLIENT,
clientName: GATEWAY_CLIENT_IDS.GATEWAY_CLIENT,
clientDisplayName: "agent",
mode: GATEWAY_CLIENT_MODES.BACKEND,
};

View File

@@ -0,0 +1,61 @@
import { describe, expect, it } from "vitest";
import type { ClawdbotConfig } from "../../config/config.js";
import { runMessageAction } from "./message-action-runner.js";
const slackConfig = {
slack: {
botToken: "xoxb-test",
appToken: "xapp-test",
},
} as ClawdbotConfig;
describe("runMessageAction context isolation", () => {
it("allows send when target matches current channel", async () => {
const result = await runMessageAction({
cfg: slackConfig,
action: "send",
params: {
provider: "slack",
to: "#C123",
message: "hi",
},
toolContext: { currentChannelId: "C123" },
dryRun: true,
});
expect(result.kind).toBe("send");
});
it("blocks send when target differs from current channel", async () => {
await expect(
runMessageAction({
cfg: slackConfig,
action: "send",
params: {
provider: "slack",
to: "channel:C999",
message: "hi",
},
toolContext: { currentChannelId: "C123" },
dryRun: true,
}),
).rejects.toThrow(/Cross-context messaging denied/);
});
it("blocks thread-reply when channelId differs from current channel", async () => {
await expect(
runMessageAction({
cfg: slackConfig,
action: "thread-reply",
params: {
provider: "slack",
channelId: "C999",
message: "hi",
},
toolContext: { currentChannelId: "C123" },
dryRun: true,
}),
).rejects.toThrow(/Cross-context messaging denied/);
});
});

View File

@@ -1,4 +1,5 @@
import type { AgentToolResult } from "@mariozechner/pi-agent-core";
import { normalizeTargetForProvider } from "../../agents/pi-embedded-messaging.js";
import {
readNumberParam,
readStringArrayParam,
@@ -125,6 +126,56 @@ function parseButtonsParam(params: Record<string, unknown>): void {
}
}
const CONTEXT_GUARDED_ACTIONS = new Set<ProviderMessageActionName>([
"send",
"poll",
"thread-create",
"thread-reply",
"sticker",
]);
function resolveContextGuardTarget(
action: ProviderMessageActionName,
params: Record<string, unknown>,
): string | undefined {
if (!CONTEXT_GUARDED_ACTIONS.has(action)) return undefined;
if (action === "thread-reply" || action === "thread-create") {
return (
readStringParam(params, "channelId") ?? readStringParam(params, "to")
);
}
return readStringParam(params, "to") ?? readStringParam(params, "channelId");
}
function enforceContextIsolation(params: {
provider: ProviderId;
action: ProviderMessageActionName;
params: Record<string, unknown>;
toolContext?: ProviderThreadingToolContext;
}): void {
const currentTarget = params.toolContext?.currentChannelId?.trim();
if (!currentTarget) return;
if (!CONTEXT_GUARDED_ACTIONS.has(params.action)) return;
const target = resolveContextGuardTarget(params.action, params.params);
if (!target) return;
const normalizedTarget =
normalizeTargetForProvider(params.provider, target) ?? target.toLowerCase();
const normalizedCurrent =
normalizeTargetForProvider(params.provider, currentTarget) ??
currentTarget.toLowerCase();
if (!normalizedTarget || !normalizedCurrent) return;
if (normalizedTarget === normalizedCurrent) return;
throw new Error(
`Cross-context messaging denied: action=${params.action} target="${target}" while bound to "${currentTarget}" (provider=${params.provider}).`,
);
}
async function resolveProvider(
cfg: ClawdbotConfig,
params: Record<string, unknown>,
@@ -150,6 +201,13 @@ export async function runMessageAction(
readStringParam(params, "accountId") ?? input.defaultAccountId;
const dryRun = Boolean(input.dryRun ?? readBooleanParam(params, "dryRun"));
enforceContextIsolation({
provider,
action,
params,
toolContext: input.toolContext,
});
const gateway = input.gateway
? {
url: input.gateway.url,

View File

@@ -102,6 +102,11 @@ const DOCKS: Record<ProviderId, ProviderDock> = {
},
threading: {
resolveReplyToMode: ({ cfg }) => cfg.telegram?.replyToMode ?? "first",
buildToolContext: ({ context, hasRepliedRef }) => ({
currentChannelId: context.To?.trim() || undefined,
currentThreadTs: context.ReplyToId,
hasRepliedRef,
}),
},
},
whatsapp: {
@@ -142,6 +147,13 @@ const DOCKS: Record<ProviderId, ProviderDock> = {
return [escaped, `@${escaped}`];
},
},
threading: {
buildToolContext: ({ context, hasRepliedRef }) => ({
currentChannelId: context.To?.trim() || undefined,
currentThreadTs: context.ReplyToId,
hasRepliedRef,
}),
},
},
discord: {
id: "discord",
@@ -175,6 +187,11 @@ const DOCKS: Record<ProviderId, ProviderDock> = {
},
threading: {
resolveReplyToMode: ({ cfg }) => cfg.discord?.replyToMode ?? "off",
buildToolContext: ({ context, hasRepliedRef }) => ({
currentChannelId: context.To?.trim() || undefined,
currentThreadTs: context.ReplyToId,
hasRepliedRef,
}),
},
},
slack: {
@@ -246,6 +263,13 @@ const DOCKS: Record<ProviderId, ProviderDock> = {
)
.filter(Boolean),
},
threading: {
buildToolContext: ({ context, hasRepliedRef }) => ({
currentChannelId: context.To?.trim() || undefined,
currentThreadTs: context.ReplyToId,
hasRepliedRef,
}),
},
},
imessage: {
id: "imessage",
@@ -266,6 +290,13 @@ const DOCKS: Record<ProviderId, ProviderDock> = {
groups: {
resolveRequireMention: resolveIMessageGroupRequireMention,
},
threading: {
buildToolContext: ({ context, hasRepliedRef }) => ({
currentChannelId: context.To?.trim() || undefined,
currentThreadTs: context.ReplyToId,
hasRepliedRef,
}),
},
},
msteams: {
id: "msteams",
@@ -280,6 +311,13 @@ const DOCKS: Record<ProviderId, ProviderDock> = {
resolveAllowFrom: ({ cfg }) => cfg.msteams?.allowFrom ?? [],
formatAllowFrom: ({ allowFrom }) => formatLower(allowFrom),
},
threading: {
buildToolContext: ({ context, hasRepliedRef }) => ({
currentChannelId: context.To?.trim() || undefined,
currentThreadTs: context.ReplyToId,
hasRepliedRef,
}),
},
},
};