immersive2/server/services/ChatCloudflare.js
Michael Mainguy 4ca98cf980 Add LangChain model wrappers and enhance diagram AI tools
- Migrate to LangChain for model abstraction (@langchain/anthropic, @langchain/ollama)
- Add custom ChatCloudflare class for Cloudflare Workers AI
- Simplify API routes using unified LangChain interface
- Add session preferences API for storing user settings
- Add connection label preference (ask user once, remember for session)
- Add shape modification support (change entity shapes via AI)
- Add template setter to DiagramObject for shape changes
- Improve entity inference with fuzzy matching
- Map colors to 16 toolbox palette colors
- Limit conversation history to last 6 messages
- Fix model switching to accept display names

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-14 10:17:15 -06:00

368 lines
12 KiB
JavaScript

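/**
* Custom LangChain ChatModel for Cloudflare Workers AI
*
* Cloudflare Workers AI has limitations:
* - Does NOT support multi-turn tool conversations (error 3043)
* - Some models (e.g. Mistral) emit tool calls as plain text: [TOOL_CALLS][{...}]
*
* This module works around these quirks to provide a LangChain-compatible
* interface: earlier tool calls and tool results are flattened into plain
* text before each request, and tool calls written into the response text
* are parsed back into structured LangChain tool calls.
*/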
import { BaseChatModel } from "@langchain/core/language_models/chat_models";
import { AIMessage } from "@langchain/core/messages";
import { getCloudflareAccountId, getCloudflareApiToken } from "./providerConfig.js";
import { toolSchemas } from "./langchainTools.js";
import { zodToJsonSchema } from "zod-to-json-schema";
/**
 * Try to parse a JSON object that may have been cut off part-way, e.g.
 * because the model hit its token limit while writing a tool call.
 *
 * The input is first parsed as-is; if that succeeds the value is returned
 * unchanged, whatever its shape. Otherwise completions are derived from
 * where the text stops (inside a string, after a key, after `:` or `,`)
 * and from the `{` / `[` nesting still open, and the first completion that
 * parses to an object with a truthy `name` is returned.
 *
 * @param {string} jsonStr - Possibly truncated JSON text.
 * @returns {*} The parsed value, or null when no completion parses.
 */
function tryRepairAndParse(jsonStr) {
  try {
    return JSON.parse(jsonStr);
  } catch {
    // Fall through to the repair attempts.
  }
  if (typeof jsonStr !== 'string') {
    return null;
  }
  for (const attempt of repairCandidates(jsonStr)) {
    try {
      const parsed = JSON.parse(attempt);
      // Only accept a repair that still looks like a tool call.
      if (parsed?.name) {
        return parsed;
      }
    } catch {
      // Try the next completion.
    }
  }
  return null;
}

/**
 * List candidate completions for truncated JSON `text`, most likely first.
 * Brackets inside string literals (including escaped quotes) are ignored
 * when working out which `{` / `[` are still open.
 */
function repairCandidates(text) {
  const closers = [];
  let inString = false;
  let escaped = false;
  for (const ch of text) {
    if (inString) {
      if (escaped) {
        escaped = false;
      } else if (ch === '\\') {
        escaped = true;
      } else if (ch === '"') {
        inString = false;
      }
    } else if (ch === '"') {
      inString = true;
    } else if (ch === '{') {
      closers.push('}');
    } else if (ch === '[') {
      closers.push(']');
    } else if (ch === '}' || ch === ']') {
      closers.pop();
    }
  }
  const tail = closers.reverse().join('');

  if (inString) {
    // A dangling backslash would escape the quote we add, so drop it.
    const body = escaped ? text.slice(0, -1) : text;
    // Cut off in a value string: close it. Cut off in a key: close it and give it a null value.
    return [`${body}"${tail}`, `${body}": null${tail}`];
  }
  const trimmed = text.trimEnd();
  if (trimmed.endsWith(',')) {
    return [trimmed.slice(0, -1) + tail];
  }
  if (trimmed.endsWith(':')) {
    return [`${trimmed} null${tail}`];
  }
  if (trimmed.endsWith('"')) {
    // Either a complete string value, or a complete key still waiting for its value.
    return [trimmed + tail, `${trimmed}: null${tail}`];
  }
  return [trimmed + tail];
}
/**
 * Parse tool calls that a model wrote into its text response instead of
 * returning them in the structured `tool_calls` field.
 *
 * Recognised formats:
 *   1. `[TOOL_CALLS][{"name": ..., "arguments": {...}}, ...]` (Mistral's
 *      native format). The payload after the marker is tried, in order, as
 *      one JSON array; as one `[{...}]` array per line; and finally by
 *      locating each `{"name": "` and repairing the (possibly truncated)
 *      object that follows it.
 *   2. `[Called tool: name({...})]` — the format messagesToCloudflare()
 *      writes for earlier tool calls, which models sometimes echo back.
 *
 * @param {string} text - Raw response text from the model.
 * @returns {{cleanText: string, toolCalls: Array<{id: string, name: string, args: object}>}}
 *   `cleanText` is `text` with the recognised tool-call markup removed and
 *   trimmed; `toolCalls` is empty when nothing was recognised. Tool call ids
 *   are unique within one result.
 */
function parseTextToolCalls(text) {
  if (!text) return { cleanText: '', toolCalls: [] };

  // One timestamp per response plus a per-result counter keeps ids unique;
  // LangChain pairs ToolMessages with their calls by id.
  const stamp = Date.now();
  const makeToolCall = (call, index) => ({
    id: `toolu_cf_${stamp}_${index}`,
    name: call.name,
    args: call.arguments || {}
  });

  let cleanText = text;

  // Format 1: [TOOL_CALLS][...]
  const toolCallMatch = text.match(/\[TOOL_CALLS\]\s*(\[[\s\S]*)/);
  if (toolCallMatch) {
    const toolCallsJson = toolCallMatch[1];

    // 1a. The payload is one well-formed JSON array.
    let parsedCalls = null;
    try {
      parsedCalls = JSON.parse(toolCallsJson);
    } catch {
      // Not a single valid array; try the looser strategies below.
    }
    if (Array.isArray(parsedCalls)) {
      return {
        cleanText: text.replace(/\[TOOL_CALLS\]\s*\[[\s\S]*/, '').trim(),
        toolCalls: parsedCalls.filter(call => call && call.name).map(makeToolCall)
      };
    }

    // 1b. One single-element array per line, e.g. "[{...}]\n[{...}]".
    const lineCalls = [];
    for (const line of toolCallsJson.split('\n')) {
      const trimmed = line.trim();
      if (!(trimmed.startsWith('[') && trimmed.endsWith(']'))) continue;
      const call = parseToolCallLine(trimmed);
      if (call) lineCalls.push(makeToolCall(call, lineCalls.length));
    }
    if (lineCalls.length > 0) {
      return {
        cleanText: text.replace(/\[TOOL_CALLS\][\s\S]*/, '').trim(),
        toolCalls: lineCalls
      };
    }

    // 1c. Truncated or malformed payload: cut it at every `{"name": "` and repair each piece.
    const starts = Array.from(toolCallsJson.matchAll(/\{"name"\s*:\s*"/g), m => m.index);
    const scannedCalls = [];
    starts.forEach((start, i) => {
      const end = starts[i + 1] ?? toolCallsJson.length;
      const segment = toolCallsJson.substring(start, end).replace(/,\s*$/, '');
      const call = parseToolCallSegment(segment);
      if (call) scannedCalls.push(makeToolCall(call, i));
    });
    cleanText = text.replace(/\[TOOL_CALLS\]\s*\[[\s\S]*/, '').trim();
    if (scannedCalls.length > 0) {
      return { cleanText, toolCalls: scannedCalls };
    }
  }

  // Format 2: [Called tool: name({args})]
  const toolCalls = [];
  const calledToolPattern = /\[Called tool:\s*(\w+)\((\{[\s\S]*?\})\)\]/g;
  for (const [marker, toolName, rawArgs] of text.matchAll(calledToolPattern)) {
    const args = parseCalledToolArgs(toolName, rawArgs);
    if (args === undefined) continue;
    toolCalls.push({
      id: `toolu_cf_${stamp}_${toolCalls.length}`,
      name: toolName,
      args
    });
    cleanText = cleanText.replace(marker, '');
  }
  return { cleanText: cleanText.trim(), toolCalls };
}

/**
 * Parse one `[{...}]` line of a [TOOL_CALLS] payload.
 * Returns the first element if it has a `name`, otherwise null.
 */
function parseToolCallLine(line) {
  try {
    const parsed = JSON.parse(line);
    return Array.isArray(parsed) && parsed.length > 0 && parsed[0]?.name ? parsed[0] : null;
  } catch {
    // Possibly truncated: repair the object inside the brackets.
    const objMatch = line.match(/\[\s*(\{[\s\S]*)/);
    if (!objMatch) return null;
    const repaired = tryRepairAndParse(objMatch[1].replace(/\]\s*$/, ''));
    return repaired && repaired.name ? repaired : null;
  }
}

/**
 * Parse one `{"name": ...}` segment cut from a malformed [TOOL_CALLS]
 * payload. Returns the call object, or null if it has no usable `name`.
 */
function parseToolCallSegment(segment) {
  const call = tryRepairAndParse(segment);
  if (call && call.name) return call;
  // The last object of an otherwise complete array can still carry the
  // closing `]` and any trailing prose; retry without them.
  const stripped = segment.replace(/\][^{}]*$/, '');
  if (stripped === segment) return null;
  const retried = tryRepairAndParse(stripped);
  return retried && retried.name ? retried : null;
}

/**
 * Parse the `{...}` argument text of a `[Called tool: name({...})]` marker.
 * Returns the arguments, or undefined if they cannot be parsed even after repair.
 */
function parseCalledToolArgs(toolName, rawArgs) {
  try {
    return JSON.parse(rawArgs);
  } catch {
    // tryRepairAndParse only accepts a repaired object that has a `name`
    // (it is meant for whole tool calls), so wrap the arguments in one.
    const repaired = tryRepairAndParse(`{"name":${JSON.stringify(toolName)},"arguments":${rawArgs}`);
    return repaired ? repaired.arguments ?? {} : undefined;
  }
}
/** Fallback mapping from LangChain message class names to message types. */
const MESSAGE_TYPE_BY_CLASS = {
  SystemMessage: 'system',
  HumanMessage: 'human',
  AIMessage: 'ai',
  ToolMessage: 'tool'
};

/**
 * Return the LangChain message type ("system", "human", "ai", "tool", ...).
 * Uses the message's own type accessor, so chunk classes such as
 * AIMessageChunk resolve to the same type as their full counterparts.
 */
function messageType(msg) {
  if (typeof msg.getType === 'function') return msg.getType();
  if (typeof msg._getType === 'function') return msg._getType();
  // No accessor: fall back to the class name, ignoring a "Chunk" suffix.
  const className = (msg.constructor?.name ?? '').replace(/Chunk$/, '');
  return MESSAGE_TYPE_BY_CLASS[className];
}

/**
 * Reduce message content to plain text. String content is returned as-is;
 * for multi-part (array) content only string parts and `{type: "text"}`
 * parts are kept and concatenated.
 */
function contentToText(content) {
  if (typeof content === 'string') return content;
  if (content == null) return '';
  if (!Array.isArray(content)) return String(content);
  return content
    .filter(part => typeof part === 'string' || part?.type === 'text')
    .map(part => (typeof part === 'string' ? part : part.text ?? ''))
    .join('');
}

/**
 * Convert LangChain messages to Cloudflare's `{role, content}` format.
 *
 * IMPORTANT: Cloudflare rejects multi-turn tool conversations (error 3043),
 * so tool history is flattened into text:
 *   - AI tool calls become `[Called tool: name({...args})]` lines appended
 *     to the assistant text (the format parseTextToolCalls() reads back);
 *   - tool results become user messages `[Tool Result (name): content]`.
 * AI messages that end up with no text are dropped, as are message types
 * other than system, human, AI and tool.
 *
 * @param {Array<object>} messages - LangChain messages (chunk classes included).
 * @returns {Array<{role: string, content: string}>}
 */
function messagesToCloudflare(messages) {
  const cfMessages = [];
  for (const msg of messages) {
    const type = messageType(msg);
    if (type === 'system') {
      cfMessages.push({ role: "system", content: contentToText(msg.content) });
    } else if (type === 'human') {
      cfMessages.push({ role: "user", content: contentToText(msg.content) });
    } else if (type === 'ai') {
      let content = contentToText(msg.content);
      if (msg.tool_calls && msg.tool_calls.length > 0) {
        const toolText = msg.tool_calls
          .map(tc => `[Called tool: ${tc.name}(${JSON.stringify(tc.args)})]`)
          .join('\n');
        content = content ? `${content}\n${toolText}` : toolText;
      }
      if (content) {
        cfMessages.push({ role: "assistant", content });
      }
    } else if (type === 'tool') {
      // ToolMessage.name is optional; avoid rendering "undefined".
      cfMessages.push({
        role: "user",
        content: `[Tool Result (${msg.name || 'tool'}): ${contentToText(msg.content)}]`
      });
    }
  }
  return cfMessages;
}
/**
 * Convert a tool definition to the OpenAI-style function format sent to
 * Cloudflare. Definitions already in that format are returned unchanged;
 * otherwise `tool.schema` is treated as a Zod schema.
 */
function toCloudflareTool(tool) {
  if (tool?.type === "function" && tool.function) {
    return tool;
  }
  return {
    type: "function",
    function: {
      name: tool.name,
      description: tool.description,
      parameters: zodToJsonSchema(tool.schema, { target: "openApi3" })
    }
  };
}

/**
 * Convert the structured `tool_calls` of a Workers AI response into
 * LangChain tool calls.
 *
 * Entries may be `{ name, arguments }` or OpenAI-style
 * `{ id, function: { name, arguments } }`; `arguments` may be an object or
 * a JSON string. An entry whose string arguments are not valid JSON becomes
 * an invalid tool call instead of throwing and losing the whole response.
 *
 * @param {Array<object>} rawCalls - `result.tool_calls` from the API response.
 * @returns {{toolCalls: Array<object>, invalidToolCalls: Array<object>}}
 */
function normalizeNativeToolCalls(rawCalls) {
  const stamp = Date.now();
  const toolCalls = [];
  const invalidToolCalls = [];
  rawCalls.forEach((raw, index) => {
    const fn = raw?.function ?? raw ?? {};
    // Keep a provider-supplied id; otherwise generate one unique within this response.
    const id = raw?.id ?? `toolu_cf_${stamp}_${index}`;
    const rawArgs = fn.arguments;
    if (typeof rawArgs !== 'string') {
      toolCalls.push({ id, name: fn.name, args: rawArgs || {} });
      return;
    }
    if (rawArgs.trim() === '') {
      toolCalls.push({ id, name: fn.name, args: {} });
      return;
    }
    try {
      toolCalls.push({ id, name: fn.name, args: JSON.parse(rawArgs) });
    } catch (err) {
      invalidToolCalls.push({
        id,
        name: fn.name,
        args: rawArgs,
        error: `Tool arguments are not valid JSON: ${err.message}`,
        type: 'invalid_tool_call'
      });
    }
  });
  return { toolCalls, invalidToolCalls };
}

/**
 * ChatCloudflare - LangChain ChatModel for Cloudflare Workers AI.
 *
 * Each generation POSTs to the Workers AI REST endpoint
 * `/client/v4/accounts/{accountId}/ai/run/{model}`. Tool history is
 * flattened to text by messagesToCloudflare(). Tool calls in the reply are
 * taken from `result.tool_calls`, or parsed out of the response text when
 * that field is empty.
 */
export class ChatCloudflare extends BaseChatModel {
  static lc_name() {
    return "ChatCloudflare";
  }

  /**
   * @param {object} [fields]
   * @param {string} [fields.model] - Workers AI model name; defaults to "@cf/mistral/mistral-7b-instruct-v0.2".
   * @param {string} [fields.accountId] - Defaults to getCloudflareAccountId().
   * @param {string} [fields.apiToken] - Defaults to getCloudflareApiToken().
   * @param {number} [fields.maxTokens] - Sent as `max_tokens`; defaults to 1024.
   */
  constructor(fields = {}) {
    super(fields);
    this.model = fields.model || "@cf/mistral/mistral-7b-instruct-v0.2";
    this.accountId = fields.accountId || getCloudflareAccountId();
    this.apiToken = fields.apiToken || getCloudflareApiToken();
    this.maxTokens = fields.maxTokens || 1024;
    this._tools = [];
  }

  _llmType() {
    return "cloudflare";
  }

  /**
   * Return a new ChatCloudflare with the same model, credentials and
   * maxTokens, which sends `tools` with every request.
   *
   * @param {Array<object>} tools - `{name, description, schema}` objects with
   *   a Zod `schema`, or definitions already in `{type: "function", function}` form.
   * @returns {ChatCloudflare}
   */
  bindTools(tools) {
    const bound = new ChatCloudflare({
      model: this.model,
      accountId: this.accountId,
      apiToken: this.apiToken,
      maxTokens: this.maxTokens
    });
    bound._tools = tools.map(toCloudflareTool);
    return bound;
  }

  /**
   * Run one chat completion against Workers AI.
   *
   * @throws {Error} When the response body is not JSON, or the API reports
   *   `success: false` (the error carries the HTTP `status`).
   */
  async _generate(messages, options, runManager) {
    const cfMessages = messagesToCloudflare(messages);
    const requestBody = {
      messages: cfMessages,
      max_tokens: this.maxTokens
    };
    if (this._tools.length > 0) {
      requestBody.tools = this._tools;
    }

    const endpoint = `https://api.cloudflare.com/client/v4/accounts/${this.accountId}/ai/run/${this.model}`;
    console.log(`[ChatCloudflare] Sending request to: ${endpoint}`);
    console.log(`[ChatCloudflare] Messages: ${cfMessages.length}, Tools: ${this._tools.length}`);
    const response = await fetch(endpoint, {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        "Authorization": `Bearer ${this.apiToken}`,
      },
      body: JSON.stringify(requestBody),
      // Forward the caller's AbortSignal (LangChain call option `signal`) so the request can be cancelled.
      signal: options?.signal,
    });

    // Read as text first so a non-JSON body (e.g. an HTML error page) still yields a useful error.
    const rawBody = await response.text();
    let cfData;
    try {
      cfData = JSON.parse(rawBody);
    } catch (err) {
      throw new Error(
        `Cloudflare API error (HTTP ${response.status}): ${rawBody.slice(0, 200)}`,
        { cause: err }
      );
    }
    if (!cfData?.success) {
      const error = new Error(cfData?.errors?.[0]?.message || "Cloudflare API error");
      error.status = response.status;
      throw error;
    }
    const result = cfData.result || cfData;

    let textResponse = result.response || '';
    let toolCalls = [];
    let invalidToolCalls = [];
    const nativeCalls = result.tool_calls || [];
    if (nativeCalls.length > 0) {
      ({ toolCalls, invalidToolCalls } = normalizeNativeToolCalls(nativeCalls));
    } else if (textResponse) {
      // No structured tool calls: some models write them into the text instead.
      const parsed = parseTextToolCalls(textResponse);
      if (parsed.toolCalls.length > 0) {
        toolCalls = parsed.toolCalls;
        textResponse = parsed.cleanText;
      }
    }

    const promptTokens = result.usage?.prompt_tokens || 0;
    const completionTokens = result.usage?.completion_tokens || 0;
    const aiMessage = new AIMessage({
      content: textResponse,
      tool_calls: toolCalls.length > 0 ? toolCalls : undefined,
      invalid_tool_calls: invalidToolCalls.length > 0 ? invalidToolCalls : undefined,
      usage_metadata: {
        input_tokens: promptTokens,
        output_tokens: completionTokens,
        total_tokens: promptTokens + completionTokens
      }
    });
    return {
      generations: [{
        text: textResponse,
        message: aiMessage,
      }],
      llmOutput: {
        tokenUsage: {
          promptTokens,
          completionTokens,
          totalTokens: promptTokens + completionTokens,
        }
      }
    };
  }
}
/**
 * Create a ChatCloudflare model for `modelId` with every tool from
 * `toolSchemas` bound to it.
 *
 * Each tool schema is passed to bindTools() as `{ name, description, schema }`.
 *
 * @param {string} modelId - Cloudflare Workers AI model name; when falsy,
 *   ChatCloudflare's default model is used.
 * @returns {ChatCloudflare} A new model instance with the tools bound.
 */
export function getCloudflareModel(modelId) {
  const baseModel = new ChatCloudflare({ model: modelId });
  const toolDefinitions = [];
  for (const { name, description, schema } of Object.values(toolSchemas)) {
    toolDefinitions.push({ name, description, schema });
  }
  return baseModel.bindTools(toolDefinitions);
}
export default ChatCloudflare;