immersive2/server/services/langchainModels.js
Michael Mainguy 4ca98cf980 Add LangChain model wrappers and enhance diagram AI tools
- Migrate to LangChain for model abstraction (@langchain/anthropic, @langchain/ollama)
- Add custom ChatCloudflare class for Cloudflare Workers AI
- Simplify API routes using unified LangChain interface
- Add session preferences API for storing user settings
- Add connection label preference (ask user once, remember for session)
- Add shape modification support (change entity shapes via AI)
- Add template setter to DiagramObject for shape changes
- Improve entity inference with fuzzy matching
- Map colors to 16 toolbox palette colors
- Limit conversation history to last 6 messages
- Fix model switching to accept display names

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-14 10:17:15 -06:00

289 lines
10 KiB
JavaScript

/**
* LangChain Model Wrappers
*
* Provides unified model interfaces using LangChain.
* Handles tool binding and response format conversion.
*/
import { ChatAnthropic } from "@langchain/anthropic";
import { ChatOllama } from "@langchain/ollama";
import { HumanMessage, SystemMessage, AIMessage, ToolMessage } from "@langchain/core/messages";
import { toolSchemas } from "./langchainTools.js";
import { getSession, getConversationForAPI } from "./sessionStore.js";
import { getOllamaUrl } from "./providerConfig.js";
/**
* Convert tool schemas to LangChain tool format for bindTools()
* LangChain expects: { name, description, schema (Zod) }
*/
// Adapt every tool schema to the { name, description, schema } shape that
// LangChain's bindTools() expects (schema is a Zod object).
const langchainTools = Object.values(toolSchemas).map(({ name, description, schema }) => ({
  name,
  description,
  schema,
}));
/**
* Get a Claude model instance with tools bound
* @param {string} modelId - The Claude model ID
* @returns {object} ChatAnthropic instance with tools
*/
export function getClaudeModel(modelId) {
  // Build the Anthropic chat model and attach the shared tool schemas in one go.
  return new ChatAnthropic({
    modelName: modelId,
    anthropicApiKey: process.env.ANTHROPIC_API_KEY,
  }).bindTools(langchainTools);
}
/**
* Get an Ollama model instance with tools bound
* @param {string} modelId - The Ollama model name
* @returns {object} ChatOllama instance with tools
*/
export function getOllamaModel(modelId) {
  // Build the Ollama chat model (base URL comes from provider config) and
  // attach the shared tool schemas in one go.
  return new ChatOllama({
    model: modelId,
    baseUrl: getOllamaUrl(),
  }).bindTools(langchainTools);
}
/**
* Build camera context string
* @param {object} cameraPosition - Camera position and orientation data
* @returns {string} Camera context string
*/
/**
 * Build camera context string for the system prompt, describing the user's
 * position and view direction in world coordinates.
 * @param {object} cameraPosition - Camera data: { position, forward,
 *   groundForward, groundRight }, each an object with numeric x/y/z components.
 * @returns {string} Camera context string, or "" when no position is available
 */
export function buildCameraContext(cameraPosition) {
  if (!cameraPosition) {
    return "";
  }
  const { position, forward, groundForward, groundRight } = cameraPosition;
  if (!position) return "";

  // Format a coordinate to 2 decimal places. Fixes an inconsistency in the
  // previous version: `forward` components fell back to 0 when missing, but
  // position/ground components rendered as the literal string "undefined".
  const fmt = (v) => (typeof v === 'number' && Number.isFinite(v) ? v.toFixed(2) : '0.00');

  const p = position;
  // Sensible defaults: facing +Z with +X to the right.
  const gf = groundForward || { x: 0, z: 1 };
  const gr = groundRight || { x: 1, z: 0 };
  return `\n\n## User's Current View (World Coordinates)
Position: (${fmt(p.x)}, ${fmt(p.y)}, ${fmt(p.z)})
Looking: (${fmt(forward?.x)}, ${fmt(forward?.y)}, ${fmt(forward?.z)})
Ground Forward: (${fmt(gf.x)}, 0, ${fmt(gf.z)})
Ground Right: (${fmt(gr.x)}, 0, ${fmt(gr.z)})
To place entities relative to user:
- FORWARD: add groundForward * distance to position
- RIGHT: add groundRight * distance to position
- LEFT: subtract groundRight * distance from position
- BACK: subtract groundForward * distance from position`;
}
/**
* Build entity context string for the system prompt
* @param {Array} entities - Array of diagram entities
* @param {object} cameraPosition - Optional camera position data
* @returns {string} Entity context string
*/
export function buildEntityContext(entities, cameraPosition = null) {
  // Camera context always leads (empty string when no camera data is available).
  const cameraContext = buildCameraContext(cameraPosition);

  if (!entities || entities.length === 0) {
    return `${cameraContext}\n\nThe diagram is currently empty.`;
  }

  // One bullet per entity: "- label (shape, color) at (x, y, z)".
  const lines = [];
  for (const entity of entities) {
    const shape = entity.template?.replace('#', '').replace('-template', '') || 'unknown';
    const pos = entity.position || { x: 0, y: 0, z: 0 };
    const label = entity.text || '(no label)';
    const color = entity.color || 'unknown';
    lines.push(`- ${label} (${shape}, ${color}) at (${pos.x?.toFixed(1)}, ${pos.y?.toFixed(1)}, ${pos.z?.toFixed(1)})`);
  }

  return `${cameraContext}\n\n## Current Diagram State\nThe diagram currently contains ${entities.length} entities:\n${lines.join('\n')}`;
}
/**
* Convert Claude-format messages to LangChain message objects
* @param {Array} messages - Messages in Claude format
* @returns {Array} Array of LangChain message objects
*/
export function claudeMessagesToLangChain(messages) {
  const converted = [];
  // Maps tool_use id -> tool name so later tool_result blocks can be labeled.
  const toolNamesById = new Map();

  for (const msg of messages) {
    if (msg.role === 'user') {
      // Plain string content: a single human turn.
      if (!Array.isArray(msg.content)) {
        converted.push(new HumanMessage(msg.content));
        continue;
      }
      // Block-array content: text blocks become human turns, tool_result
      // blocks become tool messages tied back to their originating call.
      for (const block of msg.content) {
        if (block.type === 'text') {
          converted.push(new HumanMessage(block.text));
        } else if (block.type === 'tool_result') {
          converted.push(new ToolMessage({
            content: typeof block.content === 'string' ? block.content : JSON.stringify(block.content),
            tool_call_id: block.tool_use_id,
            name: toolNamesById.get(block.tool_use_id) || 'unknown'
          }));
        }
      }
    } else if (msg.role === 'assistant') {
      // Plain string content: a single AI turn.
      if (!Array.isArray(msg.content)) {
        converted.push(new AIMessage(msg.content));
        continue;
      }
      // Block-array content: concatenate text blocks and collect tool calls
      // into one AIMessage, remembering each tool-call id for later results.
      const textParts = [];
      const toolCalls = [];
      for (const block of msg.content) {
        if (block.type === 'text') {
          textParts.push(block.text);
        } else if (block.type === 'tool_use') {
          toolNamesById.set(block.id, block.name);
          toolCalls.push({ id: block.id, name: block.name, args: block.input });
        }
      }
      converted.push(new AIMessage({
        content: textParts.join(''),
        tool_calls: toolCalls.length > 0 ? toolCalls : undefined
      }));
    }
  }

  return converted;
}
/**
* Build LangChain messages from session and request
* @param {string} sessionId - Session ID
* @param {Array} requestMessages - Messages from the request
* @param {string} systemPrompt - Base system prompt
* @returns {Array} Array of LangChain messages
*/
// Maximum number of history messages to include (to limit token usage)
const MAX_HISTORY_MESSAGES = 6; // 3 exchanges (user + assistant pairs)

// True when a Claude-format message is a user turn with no tool_result blocks,
// i.e. a safe message to start a trimmed history with.
function isPlainUserMessage(msg) {
  if (msg.role !== 'user') return false;
  if (!Array.isArray(msg.content)) return true;
  return !msg.content.some(block => block.type === 'tool_result');
}

export function buildLangChainMessages(sessionId, requestMessages, systemPrompt) {
  const messages = [];
  let entityContext = '';

  if (sessionId) {
    const session = getSession(sessionId);
    if (session) {
      entityContext = buildEntityContext(session.entities, session.cameraPosition);
      // Get conversation history (limited to last few messages)
      const historyMessages = getConversationForAPI(sessionId);
      if (historyMessages.length > 0) {
        // Don't send the current request twice if it was already stored in history.
        const currentContent = requestMessages?.[requestMessages.length - 1]?.content;
        let filteredHistory = historyMessages.filter(msg => msg.content !== currentContent);

        // Limit to last N messages to reduce token usage
        if (filteredHistory.length > MAX_HISTORY_MESSAGES) {
          console.log(`[LangChain] Trimming history from ${filteredHistory.length} to ${MAX_HISTORY_MESSAGES} messages`);
          filteredHistory = filteredHistory.slice(-MAX_HISTORY_MESSAGES);
        }

        // Trimming can leave the history starting with an orphaned tool_result
        // or an assistant turn, which breaks the required user/assistant
        // alternation and tool_use/tool_result pairing. Drop leading messages
        // until the history starts with a plain user turn.
        while (filteredHistory.length > 0 && !isPlainUserMessage(filteredHistory[0])) {
          filteredHistory.shift();
        }

        // Convert history to LangChain format
        messages.push(...claudeMessagesToLangChain(filteredHistory));
      }
    }
  }

  // Add system message at the beginning
  if (systemPrompt || entityContext) {
    messages.unshift(new SystemMessage((systemPrompt || '') + entityContext));
  }

  // Add current request messages
  if (requestMessages && requestMessages.length > 0) {
    messages.push(...claudeMessagesToLangChain(requestMessages));
  }

  return messages;
}
/**
* Convert LangChain AIMessage to Claude API response format
* @param {AIMessage} aiMessage - LangChain AIMessage
* @param {string} model - Model name
* @returns {object} Response in Claude API format
*/
// Extract plain text from a LangChain message content value, which may be a
// string or an array of content blocks ({ type: 'text', text } among others).
// The previous `content.toString()` rendered block arrays as "[object Object]".
function extractTextContent(content) {
  if (!content) return '';
  if (typeof content === 'string') return content;
  if (Array.isArray(content)) {
    return content
      .filter(block => block?.type === 'text' && typeof block.text === 'string')
      .map(block => block.text)
      .join('');
  }
  return String(content);
}

export function aiMessageToClaudeResponse(aiMessage, model) {
  const content = [];

  // Add text content if present (handles both string and block-array content)
  const text = extractTextContent(aiMessage.content);
  if (text) {
    content.push({ type: "text", text });
  }

  // Add tool calls if present
  if (aiMessage.tool_calls && aiMessage.tool_calls.length > 0) {
    console.log('[LangChain] Tool calls in AIMessage:', JSON.stringify(aiMessage.tool_calls, null, 2));
    aiMessage.tool_calls.forEach((tc, i) => {
      console.log(`[LangChain] Tool call ${i}: name=${tc.name}, args=${JSON.stringify(tc.args)}`);
      content.push({
        type: "tool_use",
        // Some providers omit ids; synthesize one so tool_result pairing works.
        id: tc.id || `toolu_${Date.now()}_${i}`,
        name: tc.name,
        input: tc.args || {}
      });
    });
  }

  // Extract usage from response metadata (field names vary by provider)
  const usage = aiMessage.usage_metadata || aiMessage.response_metadata?.usage || {
    input_tokens: 0,
    output_tokens: 0
  };

  return {
    id: `msg_${Date.now()}`,
    type: "message",
    role: "assistant",
    content: content,
    model: model,
    stop_reason: aiMessage.tool_calls?.length > 0 ? "tool_use" : "end_turn",
    usage: {
      input_tokens: usage.input_tokens || usage.prompt_tokens || 0,
      output_tokens: usage.output_tokens || usage.completion_tokens || 0,
      cache_creation_input_tokens: usage.cache_creation_input_tokens || 0,
      cache_read_input_tokens: usage.cache_read_input_tokens || 0
    }
  };
}
// Default export aggregates the public API (plus the shared tool list) for
// callers that prefer importing a single module object over named imports.
export default {
  getClaudeModel,
  getOllamaModel,
  buildEntityContext,
  claudeMessagesToLangChain,
  buildLangChainMessages,
  aiMessageToClaudeResponse,
  langchainTools
};