Skip to content
Draft
23 changes: 21 additions & 2 deletions packages/ai-proxy/src/ai-client.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import type { McpServerLoadFailure } from './mcp-client';
import type { AiConfiguration } from './provider';
import type RemoteTool from './remote-tool';
import type { ToolProvider } from './tool-provider';
Expand Down Expand Up @@ -39,13 +40,31 @@ export class AiClient {
}

async loadRemoteTools(configs: Record<string, ToolConfig>): Promise<RemoteTool[]> {
return (await this.loadRemoteToolsWithFailures(configs)).tools;
}

// Same load as loadRemoteTools, but also returns the classified per-server failures providers
// surface (only MCP providers do today). The default loadRemoteTools drops them, so existing
// consumers are unaffected.
async loadRemoteToolsWithFailures(
configs: Record<string, ToolConfig>,
): Promise<{ tools: RemoteTool[]; failures: McpServerLoadFailure[] }> {
await this.disposeToolProviders('Error closing previous remote tool connection');

const providers = createToolProviders(configs, this.logger);
const toolsByProvider = await Promise.all(providers.map(p => p.loadTools()));
const resultsByProvider = await Promise.all(
providers.map(async provider =>
provider.loadToolsWithFailures
? provider.loadToolsWithFailures()
: { tools: await provider.loadTools(), failures: [] },
),
);
this.toolProviders = providers;

return toolsByProvider.flat();
return {
tools: resultsByProvider.flatMap(result => result.tools),
failures: resultsByProvider.flatMap(result => result.failures),
};
}

async closeConnections(): Promise<void> {
Expand Down
1 change: 1 addition & 0 deletions packages/ai-proxy/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ export * from './remote-tools';
export { default as RemoteTool } from './remote-tool';
export * from './router';
export * from './mcp-client';
export * from './mcp-auth-error';
export * from './oauth-token-injector';
export * from './errors';
export * from './tool-provider';
Expand Down
60 changes: 60 additions & 0 deletions packages/ai-proxy/src/mcp-auth-error.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Classifies errors surfaced while connecting to or calling an MCP server. Only 401 (the token was
// rejected) is a refreshable auth failure; 403 is a permission/scope problem a token refresh or
// re-consent cannot resolve, so it is left to surface as an ordinary failure. The MCP SDK / HTTP
// transport reports failures in several shapes (a numeric status field, or only a message string),
// so the checks walk the cause chain and inspect both structured status and the message text.
const AUTH_STATUSES = new Set([401]);
const AUTH_PATTERN = /\b401\b|unauthorized/i;
const CONNECTION_PATTERN =
/econnrefused|econnreset|etimedout|enotfound|eai_again|fetch failed|network|socket|timeout|connect/i;

export type McpLoadFailureKind = 'auth' | 'connection' | 'unknown';

function statusOf(value: unknown): number | undefined {
const candidate = value as { code?: unknown; status?: unknown; statusCode?: unknown };

for (const field of [candidate?.code, candidate?.status, candidate?.statusCode]) {
if (typeof field === 'number') return field;
}

return undefined;
}

function messageOf(value: unknown): string {
if (value instanceof Error) return value.message;
if (typeof value === 'string') return value;

return '';
}

function errorChain(error: unknown): unknown[] {
const links: unknown[] = [];
let current: unknown = error;

while (current && links.length < 10 && !links.includes(current)) {
links.push(current);
current = (current as { cause?: unknown }).cause;
}

return links;
}

export function isMcpAuthError(error: unknown): boolean {
return errorChain(error).some(link => {
const status = statusOf(link);

return (
(status !== undefined && AUTH_STATUSES.has(status)) || AUTH_PATTERN.test(messageOf(link))
);
});
}

export function classifyMcpLoadError(error: unknown): McpLoadFailureKind {
if (isMcpAuthError(error)) return 'auth';

const isConnectionFailure = errorChain(error).some(link =>
CONNECTION_PATTERN.test(messageOf(link)),
);

return isConnectionFailure ? 'connection' : 'unknown';
}
46 changes: 37 additions & 9 deletions packages/ai-proxy/src/mcp-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,28 @@ import type { Logger } from '@forestadmin/datasource-toolkit';
import { MultiServerMCPClient } from '@langchain/mcp-adapters';

import { McpConnectionError } from './errors';
import { type McpLoadFailureKind, classifyMcpLoadError } from './mcp-auth-error';
import McpServerRemoteTool from './mcp-server-remote-tool';

export type McpServers = MultiServerMCPClient['config']['mcpServers'];

export type McpServerConfig = MultiServerMCPClient['config']['mcpServers'][string] & {
id?: string;
// Executor-side routing hint served by the orchestrator; stripped before reaching the SDK.
authType?: string;
};

export type McpConfiguration = {
configs: Record<string, McpServerConfig>;
} & Omit<MultiServerMCPClient['config'], 'mcpServers'>;

export interface McpServerLoadFailure {
server: string;
mcpServerId?: string;
kind: McpLoadFailureKind;
error: Error;
}

export default class McpClient implements ToolProvider {
private readonly mcpClients: Record<string, MultiServerMCPClient> = {};
private readonly mcpServerIdsByName: Record<string, string | undefined> = {};
Expand All @@ -26,8 +36,11 @@ export default class McpClient implements ToolProvider {
// split the config into several clients to be more resilient
// if a mcp server is down, the others will still work
Object.entries(config.configs).forEach(([name, serverConfig]) => {
const { id: mcpServerId, ...rest } = serverConfig as McpServerConfig &
Record<string, unknown>;
const {
id: mcpServerId,
authType,
...rest
} = serverConfig as McpServerConfig & Record<string, unknown>;
this.mcpServerIdsByName[name] = mcpServerId;
this.mcpClients[name] = new MultiServerMCPClient({
mcpServers: { [name]: rest as McpServerConfig },
Expand All @@ -36,9 +49,15 @@ export default class McpClient implements ToolProvider {
});
}

async loadTools(): Promise<McpServerRemoteTool[]> {
// Exposes per-server failures classified by cause (auth vs connection) alongside the tools that
// did load, so a caller holding a per-user token can tell a revoked token (retry after refresh)
// from an unreachable server (genuine failure). loadTools() keeps its tools-only contract.
async loadToolsWithFailures(): Promise<{
tools: McpServerRemoteTool[];
failures: McpServerLoadFailure[];
}> {
const tools: McpServerRemoteTool[] = [];
const errors: Array<{ server: string; error: Error }> = [];
const failures: McpServerLoadFailure[] = [];

await Promise.all(
Object.entries(this.mcpClients).map(async ([name, client]) => {
Expand All @@ -55,22 +74,31 @@ export default class McpClient implements ToolProvider {
tools.push(...extendedTools);
} catch (error) {
this.logger?.('Error', `Error loading tools for ${name}`, error as Error);
errors.push({ server: name, error: error as Error });
failures.push({
server: name,
mcpServerId: this.mcpServerIdsByName[name],
kind: classifyMcpLoadError(error),
error: error as Error,
});
}
}),
);

// Surface partial failures to provide better feedback
if (errors.length > 0) {
const errorMessage = errors.map(e => `${e.server}: ${e.error.message}`).join('; ');
if (failures.length > 0) {
const errorMessage = failures.map(f => `${f.server}: ${f.error.message}`).join('; ');
this.logger?.(
'Error',
`Failed to load tools from ${errors.length}/${Object.keys(this.mcpClients).length} ` +
`Failed to load tools from ${failures.length}/${Object.keys(this.mcpClients).length} ` +
`MCP server(s): ${errorMessage}`,
);
}

return tools;
return { tools, failures };
}

async loadTools(): Promise<McpServerRemoteTool[]> {
return (await this.loadToolsWithFailures()).tools;
}

async checkConnection(): Promise<true> {
Expand Down
3 changes: 3 additions & 0 deletions packages/ai-proxy/src/tool-provider.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import type { McpServerLoadFailure } from './mcp-client';
import type RemoteTool from './remote-tool';

export interface ToolProvider {
loadTools(): Promise<RemoteTool[]>;
// Optional richer variant: providers that can classify per-server failures expose them here.
loadToolsWithFailures?(): Promise<{ tools: RemoteTool[]; failures: McpServerLoadFailure[] }>;
checkConnection(): Promise<true>;
dispose(): Promise<void>;
}
40 changes: 40 additions & 0 deletions packages/ai-proxy/test/ai-client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,46 @@ describe('loadRemoteTools', () => {
});
});

describe('loadRemoteToolsWithFailures', () => {
beforeEach(() => {
jest.clearAllMocks();
});

it('aggregates tools and classified failures from providers that expose them', async () => {
const mcpTool = { name: 'mcp-tool' };
const failure = {
server: 'slack',
mcpServerId: 'srv-a',
kind: 'auth' as const,
error: new Error('401'),
};
mockedCreateToolProviders.mockReturnValue([
mockProvider({
loadToolsWithFailures: jest
.fn()
.mockResolvedValue({ tools: [mcpTool], failures: [failure] }),
}),
]);

const result = await new AiClient({}).loadRemoteToolsWithFailures({} as never);

expect(result.tools).toEqual([mcpTool]);
expect(result.failures).toEqual([failure]);
});

it('falls back to loadTools with no failures for providers that do not classify', async () => {
const integrationTool = { name: 'zendesk-tool' };
mockedCreateToolProviders.mockReturnValue([
mockProvider({ loadTools: jest.fn().mockResolvedValue([integrationTool]) }),
]);

const result = await new AiClient({}).loadRemoteToolsWithFailures({} as never);

expect(result.tools).toEqual([integrationTool]);
expect(result.failures).toEqual([]);
});
});

describe('closeConnections', () => {
beforeEach(() => {
jest.clearAllMocks();
Expand Down
53 changes: 53 additions & 0 deletions packages/ai-proxy/test/mcp-auth-error.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import { classifyMcpLoadError, isMcpAuthError } from '../src/mcp-auth-error';

function withCause(message: string, cause: unknown): Error {
const error = new Error(message);
(error as { cause?: unknown }).cause = cause;

return error;
}

describe('isMcpAuthError', () => {
it('detects a 401 numeric status field', () => {
expect(isMcpAuthError({ code: 401 })).toBe(true);
});

it('detects 401 / unauthorized in the message', () => {
expect(isMcpAuthError(new Error('Request failed with status code 401'))).toBe(true);
expect(isMcpAuthError(new Error('Unauthorized'))).toBe(true);
});

it('walks the cause chain', () => {
expect(
isMcpAuthError(withCause('wrapper', Object.assign(new Error('inner'), { status: 401 }))),
).toBe(true);
});

it('returns false for 403 (forbidden), non-auth errors, and nullish input', () => {
expect(isMcpAuthError(Object.assign(new Error('denied'), { status: 403 }))).toBe(false);
expect(isMcpAuthError(new Error('403 Forbidden'))).toBe(false);
expect(isMcpAuthError(new Error('ECONNREFUSED'))).toBe(false);
expect(isMcpAuthError(new Error('500 Internal Server Error'))).toBe(false);
expect(isMcpAuthError(undefined)).toBe(false);
});
});

describe('classifyMcpLoadError', () => {
it("classifies a 401 as 'auth'", () => {
expect(classifyMcpLoadError(new Error('HTTP 401 Unauthorized'))).toBe('auth');
expect(classifyMcpLoadError({ status: 401 })).toBe('auth');
});

it("classifies network failures as 'connection'", () => {
expect(classifyMcpLoadError(new Error('connect ECONNREFUSED 127.0.0.1:3000'))).toBe(
'connection',
);
expect(classifyMcpLoadError(new Error('fetch failed'))).toBe('connection');
expect(classifyMcpLoadError(new Error('socket hang up'))).toBe('connection');
});

it("classifies a 403 (forbidden) and anything else as 'unknown'", () => {
expect(classifyMcpLoadError(new Error('HTTP 403 Forbidden'))).toBe('unknown');
expect(classifyMcpLoadError(new Error('tool schema invalid'))).toBe('unknown');
});
});
Loading
Loading