/**
 * Custom LangChain ChatModel for Cloudflare Workers AI
 *
 * Cloudflare Workers AI has limitations:
 * - Does NOT support multi-turn tool conversations (error 3043)
 * - Some models (Mistral) output tools as text: [TOOL_CALLS][{...}]
 *
 * This class handles these quirks to provide a LangChain-compatible interface.
 */
import { BaseChatModel } from "@langchain/core/language_models/chat_models";
import { AIMessage } from "@langchain/core/messages";
import { getCloudflareAccountId, getCloudflareApiToken } from "./providerConfig.js";
import { toolSchemas } from "./langchainTools.js";
import { zodToJsonSchema } from "zod-to-json-schema";

/**
 * Suffixes appended (in order) to a JSON fragment that failed to parse, to try to
 * close a string/object left open by a truncated model response.
 */
const JSON_REPAIR_SUFFIXES = ['}', '"}', '}}', '"}}', ': null}}', '": null}}'];

/**
 * Try to repair and parse a potentially truncated JSON object.
 *
 * The fragment is first parsed as-is; any value that parses is returned unchanged.
 * Otherwise each suffix in JSON_REPAIR_SUFFIXES is appended in turn and the first
 * repaired parse that yields an object (and, by default, has a `name`) is returned.
 *
 * @param {string} jsonStr - JSON text, possibly cut off mid-object.
 * @param {object} [options]
 * @param {boolean} [options.requireName=true] - When true, a *repaired* parse is only
 *   accepted if it has a truthy `name` property (i.e. it looks like a tool call).
 * @returns {*} The parsed value, or null if no attempt succeeded.
 */
function tryRepairAndParse(jsonStr, { requireName = true } = {}) {
  try {
    return JSON.parse(jsonStr);
  } catch {
    // Not valid as-is; fall through to the repair attempts.
  }
  for (const suffix of JSON_REPAIR_SUFFIXES) {
    try {
      const parsed = JSON.parse(jsonStr + suffix);
      if (parsed !== null && typeof parsed === 'object' && (!requireName || parsed.name)) {
        return parsed;
      }
    } catch {
      // This suffix did not help; try the next one.
    }
  }
  return null;
}

/**
 * Build a tool-call id. Ids must be distinct within one response so that
 * ToolMessages can be matched back to the call that produced them.
 * @param {number} index - Position of the call within the response.
 * @returns {string}
 */
function makeToolCallId(index) {
  return `toolu_cf_${Date.now()}_${index}`;
}

/**
 * Normalize raw tool-call arguments into a plain object.
 * Arguments may arrive as an object or as a JSON-encoded string; LangChain tool
 * calls carry an object in `args`.
 * @param {*} rawArgs
 * @returns {object} The arguments object; `{}` when missing or unparseable.
 */
function normalizeArgs(rawArgs) {
  if (!rawArgs) return {};
  if (typeof rawArgs !== 'string') return rawArgs;
  const parsed = tryRepairAndParse(rawArgs, { requireName: false });
  if (parsed !== null && typeof parsed === 'object' && !Array.isArray(parsed)) {
    return parsed;
  }
  console.warn(`[ChatCloudflare] Could not parse tool arguments, using {}: ${rawArgs}`);
  return {};
}

/**
 * Convert a parsed `{ name, arguments }` object into a LangChain tool call.
 * @param {{name: string, arguments?: *}} call
 * @param {number} index - Position of the call within the response (used for the id).
 * @returns {{id: string, name: string, args: object}}
 */
function toLangChainToolCall(call, index) {
  return {
    id: makeToolCallId(index),
    name: call.name,
    args: normalizeArgs(call.arguments),
  };
}

/**
 * Parse a string that should be a JSON array.
 * @param {string} str
 * @returns {Array|null} The array, or null if `str` is not a valid JSON array.
 */
function parseJsonArray(str) {
  try {
    const parsed = JSON.parse(str);
    return Array.isArray(parsed) ? parsed : null;
  } catch {
    return null;
  }
}

/**
 * Parse one `[TOOL_CALLS]` payload line holding a single-element array such as
 * `[{"name": "x", "arguments": {...}}]`, repairing a truncated object if needed.
 * @param {string} line - Trimmed line starting with `[` and ending with `]`.
 * @returns {object|null} The first `{ name, arguments }` object, or null.
 */
function parseToolCallLine(line) {
  try {
    const parsed = JSON.parse(line);
    return Array.isArray(parsed) && parsed.length > 0 && parsed[0]?.name ? parsed[0] : null;
  } catch {
    const objMatch = line.match(/\[\s*(\{[\s\S]*)/);
    if (!objMatch) return null;
    const repaired = tryRepairAndParse(objMatch[1].replace(/\]\s*$/, ''));
    return repaired?.name ? repaired : null;
  }
}

/**
 * Split a (possibly truncated) `[TOOL_CALLS]` payload into one text segment per
 * `{"name": "...` object, with any trailing comma removed.
 * @param {string} payload
 * @returns {string[]}
 */
function splitToolCallSegments(payload) {
  const starts = [...payload.matchAll(/\{"name"\s*:\s*"/g)].map(m => m.index);
  return starts.map((start, i) =>
    payload.substring(start, starts[i + 1] ?? payload.length).replace(/,\s*$/, '')
  );
}

/**
 * Parse tool calls from text response
 * Handles: [TOOL_CALLS][{...}] and [Called tool: name({args})]
 *
 * @param {string} text - Raw model response text.
 * @returns {{cleanText: string, toolCalls: Array<{id: string, name: string, args: object}>}}
 *   `cleanText` is the response with recognized tool-call syntax removed.
 */
function parseTextToolCalls(text) {
  if (!text) return { cleanText: '', toolCalls: [] };

  let cleanText = text;

  // Format 1: [TOOL_CALLS][...] (Mistral native format)
  const toolCallMatch = text.match(/\[TOOL_CALLS\]\s*(\[[\s\S]*)/);
  if (toolCallMatch) {
    const payload = toolCallMatch[1];
    // Everything before the marker is the model's free-text reply.
    cleanText = text.slice(0, toolCallMatch.index).trim();

    // 1a. The whole payload is one well-formed JSON array.
    const wholeArray = parseJsonArray(payload);
    if (wholeArray) {
      const toolCalls = wholeArray
        .filter(call => call && call.name)
        .map((call, i) => toLangChainToolCall(call, i));
      return { cleanText, toolCalls };
    }

    // 1b. Several single-element arrays, one per line.
    const lineCalls = payload
      .split('\n')
      .map(line => line.trim())
      .filter(line => line.startsWith('[') && line.endsWith(']'))
      .map(line => parseToolCallLine(line))
      .filter(Boolean)
      .map((call, i) => toLangChainToolCall(call, i));
    if (lineCalls.length > 0) {
      return { cleanText, toolCalls: lineCalls };
    }

    // 1c. Truncated JSON: extract and repair each `{"name": ...}` object separately.
    const segmentCalls = splitToolCallSegments(payload)
      .map(segment => tryRepairAndParse(segment))
      .filter(parsed => parsed?.name)
      .map((call, i) => toLangChainToolCall(call, i));
    if (segmentCalls.length > 0) {
      return { cleanText, toolCalls: segmentCalls };
    }
  }

  // Format 2: [Called tool: name({args})] -- the text form messagesToCloudflare writes
  // into the history, which the model may echo back.
  const toolCalls = [];
  const calledToolPattern = /\[Called tool:\s*(\w+)\((\{[\s\S]*?\})\)\]/g;
  for (const [fullMatch, name, rawArgs] of text.matchAll(calledToolPattern)) {
    let args;
    try {
      args = JSON.parse(rawArgs);
    } catch {
      args = tryRepairAndParse(rawArgs);
    }
    if (!args) continue;
    toolCalls.push({ id: makeToolCallId(toolCalls.length), name, args });
    cleanText = cleanText.replace(fullMatch, '');
  }

  return { cleanText: cleanText.trim(), toolCalls };
}

/**
 * Flatten LangChain message content into a plain string.
 * Content may be a string or an array of content parts; only text parts are kept.
 * NOTE(review): assumes Cloudflare's chat endpoint takes string content -- verify per model.
 * @param {*} content
 * @returns {string}
 */
function contentToText(content) {
  if (content == null) return '';
  if (typeof content === 'string') return content;
  if (Array.isArray(content)) {
    return content
      .map(part => {
        if (typeof part === 'string') return part;
        return part?.type === 'text' && typeof part.text === 'string' ? part.text : '';
      })
      .join('');
  }
  return JSON.stringify(content);
}

/** Class-name fallback for message objects lacking LangChain's getType()/_getType(). */
const MESSAGE_TYPE_BY_CLASS = new Map([
  ['SystemMessage', 'system'],
  ['HumanMessage', 'human'],
  ['AIMessage', 'ai'],
  ['ToolMessage', 'tool'],
]);

/**
 * Resolve a LangChain message's type ("system", "human", "ai", "tool", ...).
 * Uses the message's own type accessor so subclasses such as AIMessageChunk are
 * recognized, and falls back to the constructor name.
 * @param {object} msg
 * @returns {string|undefined}
 */
function getMessageType(msg) {
  if (typeof msg?.getType === 'function') return msg.getType();
  if (typeof msg?._getType === 'function') return msg._getType();
  return MESSAGE_TYPE_BY_CLASS.get(msg?.constructor?.name);
}

/**
 * Convert LangChain messages to Cloudflare format
 * IMPORTANT: Converts tool history to text (Cloudflare limitation)
 *
 * - AI tool calls become `[Called tool: name({args})]` lines in an assistant message.
 * - Tool results become user messages of the form `[Tool Result (name): content]`.
 * - Message types without a Cloudflare equivalent are dropped.
 *
 * @param {Array<object>} messages - LangChain messages.
 * @returns {Array<{role: string, content: string}>}
 */
function messagesToCloudflare(messages) {
  const cfMessages = [];
  // Lets a ToolMessage without `name` be labelled with the tool that produced it.
  const toolNamesById = new Map();

  for (const msg of messages) {
    const content = contentToText(msg.content);

    switch (getMessageType(msg)) {
      case 'system':
        cfMessages.push({ role: "system", content });
        break;
      case 'human':
        cfMessages.push({ role: "user", content });
        break;
      case 'ai': {
        // Convert tool calls to text for Cloudflare
        const calls = msg.tool_calls ?? [];
        for (const tc of calls) {
          if (tc.id) toolNamesById.set(tc.id, tc.name);
        }
        const toolText = calls
          .map(tc => `[Called tool: ${tc.name}(${JSON.stringify(tc.args)})]`)
          .join('\n');
        const combined = [content, toolText].filter(Boolean).join('\n');
        // Empty assistant turns are skipped entirely.
        if (combined) {
          cfMessages.push({ role: "assistant", content: combined });
        }
        break;
      }
      case 'tool': {
        // Convert tool results to user messages
        const toolName = msg.name || toolNamesById.get(msg.tool_call_id) || 'unknown';
        cfMessages.push({ role: "user", content: `[Tool Result (${toolName}): ${content}]` });
        break;
      }
      default:
        // No Cloudflare equivalent for other message types; drop them.
        break;
    }
  }
  return cfMessages;
}

/**
 * Convert a tool's schema to JSON Schema for the Cloudflare `tools` payload.
 * Zod schemas (detected by `safeParse`) are converted; anything else is assumed to
 * already be JSON Schema. A missing schema becomes an empty object schema.
 * @param {*} schema
 * @returns {object}
 */
function toJsonSchema(schema) {
  if (schema && typeof schema.safeParse === 'function') {
    return zodToJsonSchema(schema, { target: "openApi3" });
  }
  return schema ?? { type: "object", properties: {} };
}

/**
 * Convert a native Cloudflare tool call into a LangChain tool call.
 * Accepts both a flat `{ name, arguments }` entry and an OpenAI-style
 * `{ id, function: { name, arguments } }` entry (which shape a given model
 * returns may vary -- confirm against the models in use).
 * @param {object} tc
 * @param {number} index
 * @returns {{id: string, name: string, args: object}}
 */
function fromNativeToolCall(tc, index) {
  const fn = tc.function ?? tc;
  return {
    id: tc.id ?? makeToolCallId(index),
    name: fn.name,
    args: normalizeArgs(fn.arguments),
  };
}

/**
 * Read and validate a Cloudflare REST API response.
 * @param {Response} response - fetch() response.
 * @returns {Promise<object>} The decoded JSON envelope (with `success: true`).
 * @throws {Error} If the body is not JSON or the API reports failure.
 */
async function readCloudflareResponse(response) {
  const rawBody = await response.text();
  let cfData;
  try {
    cfData = JSON.parse(rawBody);
  } catch (err) {
    throw new Error(
      `Cloudflare API error (HTTP ${response.status}): non-JSON response: ${rawBody.slice(0, 200)}`,
      { cause: err }
    );
  }
  if (!cfData?.success) {
    throw new Error(cfData?.errors?.[0]?.message || `Cloudflare API error (HTTP ${response.status})`);
  }
  return cfData;
}

/**
 * ChatCloudflare - LangChain ChatModel for Cloudflare Workers AI
 */
export class ChatCloudflare extends BaseChatModel {
  static lc_name() {
    return "ChatCloudflare";
  }

  /**
   * @param {object} [fields]
   * @param {string} [fields.model] - Workers AI model name.
   * @param {string} [fields.accountId] - Defaults to getCloudflareAccountId().
   * @param {string} [fields.apiToken] - Defaults to getCloudflareApiToken().
   * @param {number} [fields.maxTokens=1024] - Sent as `max_tokens`.
   */
  constructor(fields = {}) {
    super(fields);
    this.model = fields.model || "@cf/mistral/mistral-7b-instruct-v0.2";
    this.accountId = fields.accountId || getCloudflareAccountId();
    this.apiToken = fields.apiToken || getCloudflareApiToken();
    this.maxTokens = fields.maxTokens || 1024;
    this._tools = [];
  }

  _llmType() {
    return "cloudflare";
  }

  /**
   * Bind tools to this model instance
   * Returns a new ChatCloudflare with the same connection settings whose requests
   * include `tools` in OpenAI function format.
   * @param {Array<{name: string, description: string, schema: *}>} tools
   * @returns {ChatCloudflare}
   */
  bindTools(tools) {
    const bound = new ChatCloudflare({
      model: this.model,
      accountId: this.accountId,
      apiToken: this.apiToken,
      maxTokens: this.maxTokens,
    });
    // Convert tools to OpenAI format
    bound._tools = tools.map(tool => ({
      type: "function",
      function: {
        name: tool.name,
        description: tool.description,
        parameters: toJsonSchema(tool.schema),
      },
    }));
    return bound;
  }

  /**
   * Call the Workers AI run endpoint and wrap the result as a LangChain generation.
   * Tool calls come from the native `tool_calls` field when present, otherwise they
   * are parsed out of the response text.
   * @param {Array<object>} messages
   * @param {object} [options] - LangChain call options; `signal` aborts the request.
   * @param {object} [runManager] - Unused.
   */
  async _generate(messages, options, runManager) {
    const cfMessages = messagesToCloudflare(messages);
    const requestBody = {
      messages: cfMessages,
      max_tokens: this.maxTokens,
    };
    // Add tools if bound
    if (this._tools.length > 0) {
      requestBody.tools = this._tools;
    }

    const endpoint = `https://api.cloudflare.com/client/v4/accounts/${this.accountId}/ai/run/${this.model}`;
    console.log(`[ChatCloudflare] Sending request to: ${endpoint}`);
    console.log(`[ChatCloudflare] Messages: ${cfMessages.length}, Tools: ${this._tools.length}`);

    const response = await fetch(endpoint, {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        "Authorization": `Bearer ${this.apiToken}`,
      },
      body: JSON.stringify(requestBody),
      signal: options?.signal,
    });

    const cfData = await readCloudflareResponse(response);
    const result = cfData.result || cfData;

    const rawResponse = result.response;
    let textResponse;
    if (typeof rawResponse === 'string') {
      textResponse = rawResponse;
    } else {
      textResponse = rawResponse == null ? '' : JSON.stringify(rawResponse);
    }

    // Get tool calls from proper field or parse from text
    const nativeToolCalls = Array.isArray(result.tool_calls) ? result.tool_calls : [];
    let toolCalls = [];
    if (nativeToolCalls.length > 0) {
      // Convert native tool calls to LangChain format
      toolCalls = nativeToolCalls.map((tc, i) => fromNativeToolCall(tc, i));
    } else if (textResponse) {
      // Parse tool calls from text if not present natively
      const parsed = parseTextToolCalls(textResponse);
      if (parsed.toolCalls.length > 0) {
        toolCalls = parsed.toolCalls;
        textResponse = parsed.cleanText;
      }
    }

    const inputTokens = result.usage?.prompt_tokens ?? 0;
    const outputTokens = result.usage?.completion_tokens ?? 0;

    // Create AIMessage with tool calls
    const aiMessage = new AIMessage({
      content: textResponse,
      tool_calls: toolCalls.length > 0 ? toolCalls : undefined,
      usage_metadata: {
        input_tokens: inputTokens,
        output_tokens: outputTokens,
        total_tokens: inputTokens + outputTokens,
      },
    });

    return {
      generations: [{
        text: textResponse,
        message: aiMessage,
      }],
      llmOutput: {
        tokenUsage: {
          promptTokens: inputTokens,
          completionTokens: outputTokens,
          totalTokens: inputTokens + outputTokens,
        },
      },
    };
  }
}

/**
 * Get a Cloudflare model instance with tools bound
 * @param {string} modelId - The Cloudflare model name
 * @returns {ChatCloudflare} ChatCloudflare instance with tools
 */
export function getCloudflareModel(modelId) {
  const model = new ChatCloudflare({ model: modelId });
  const langchainTools = Object.values(toolSchemas).map(({ name, description, schema }) => ({
    name,
    description,
    schema,
  }));
  return model.bindTools(langchainTools);
}

export default ChatCloudflare;