diff --git a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentModelSelection/index.jsx b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentModelSelection/index.jsx
index 60a6e940ae861b17b17a324dc9dfb6d7f643a7a6..085f4ef93a60ba3337cbdc619a9d341cc38adf9e 100644
--- a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentModelSelection/index.jsx
+++ b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentModelSelection/index.jsx
@@ -97,7 +97,7 @@ export default function AgentModelSelection({
             <option
               key={model.id}
               value={model.id}
-              selected={workspace?.chatModel === model.id}
+              selected={workspace?.agentModel === model.id}
             >
               {model.name}
             </option>
diff --git a/server/utils/agents/aibitat/index.js b/server/utils/agents/aibitat/index.js
index 852baa65b2f1bcf782a4e593aea53bf35fcd8fbd..d1e9ae9c8a97ce6badba357cd6bc8653cdd10610 100644
--- a/server/utils/agents/aibitat/index.js
+++ b/server/utils/agents/aibitat/index.js
@@ -12,6 +12,7 @@ const { Telemetry } = require("../../../models/telemetry.js");
 
 class AIbitat {
   emitter = new EventEmitter();
+  provider = null;
   defaultProvider = null;
   defaultInterrupt;
   maxRounds;
@@ -39,6 +40,7 @@ class AIbitat {
       provider,
       ...rest,
     };
+    this.provider = this.defaultProvider.provider;
   }
 
   /**
diff --git a/server/utils/agents/aibitat/plugins/summarize.js b/server/utils/agents/aibitat/plugins/summarize.js
index 0c1a9f4e9cd6b9b5d419766012c4cc03643bcce6..ac1ad82c50ec090381c6771562f6a2318e5925d1 100644
--- a/server/utils/agents/aibitat/plugins/summarize.js
+++ b/server/utils/agents/aibitat/plugins/summarize.js
@@ -2,6 +2,7 @@ const { Document } = require("../../../../models/documents");
 const { safeJsonParse } = require("../../../http");
 const { validate } = require("uuid");
 const { summarizeContent } = require("../utils/summarize");
+const Provider = require("../providers/ai-provider");
 
 const docSummarizer = {
   name: "document-summarizer",
@@ -95,7 +96,19 @@ const docSummarizer = {
               document?.title ?? "a discovered file."
             }`
           );
-          if (document?.content?.length < 8000) return content;
+
+          if (!document.content || document.content.length === 0) {
+            throw new Error(
+              "This document has no readable content that could be found."
+            );
+          }
+
+          if (
+            document.content?.length <
+            Provider.contextLimit(this.super.provider)
+          ) {
+            return document.content;
+          }
           this.super.introspect(
             `${this.caller}: Summarizing ${document?.title ?? ""}...`
           );
@@ -109,6 +122,7 @@ const docSummarizer = {
           });
 
           return await summarizeContent(
+            this.super.provider,
             this.controller.signal,
             document.content
           );
diff --git a/server/utils/agents/aibitat/plugins/web-scraping.js b/server/utils/agents/aibitat/plugins/web-scraping.js
index 90e226c0f26e87d43776b129458d15b3552c0bc3..1e614b6bac2a32d4dcd5aa836e017afe5f3ee954 100644
--- a/server/utils/agents/aibitat/plugins/web-scraping.js
+++ b/server/utils/agents/aibitat/plugins/web-scraping.js
@@ -1,4 +1,5 @@
 const { CollectorApi } = require("../../../collectorApi");
+const Provider = require("../providers/ai-provider");
 const { summarizeContent } = require("../utils/summarize");
 
 const webScraping = {
@@ -61,7 +62,11 @@ const webScraping = {
           );
         }
 
-        if (content?.length <= 8000) {
+        if (!content || content?.length === 0) {
+          throw new Error("There was no content to be collected or read.");
+        }
+
+        if (content.length < Provider.contextLimit(this.super.provider)) {
           return content;
         }
 
@@ -74,7 +79,11 @@ const webScraping = {
           );
           this.controller.abort();
         });
-        return summarizeContent(this.controller.signal, content);
+        return summarizeContent(
+          this.super.provider,
+          this.controller.signal,
+          content
+        );
       },
     });
   },
diff --git a/server/utils/agents/aibitat/providers/ai-provider.js b/server/utils/agents/aibitat/providers/ai-provider.js
index 3f9181bc4b85ece53f20e46a8f318b2764538d9f..ed7bd31c5ee9c778f58f2cb155f474d2403c4a36 100644
--- a/server/utils/agents/aibitat/providers/ai-provider.js
+++ b/server/utils/agents/aibitat/providers/ai-provider.js
@@ -2,6 +2,9 @@
  * A service that provides an AI client to create a completion.
  */
 
+const { ChatOpenAI } = require("langchain/chat_models/openai");
+const { ChatAnthropic } = require("langchain/chat_models/anthropic");
+
 class Provider {
   _client;
   constructor(client) {
@@ -14,6 +17,37 @@ class Provider {
   get client() {
     return this._client;
   }
+
+  static LangChainChatModel(provider = "openai", config = {}) {
+    switch (provider) {
+      case "openai":
+        return new ChatOpenAI({
+          openAIApiKey: process.env.OPEN_AI_KEY,
+          ...config,
+        });
+      case "anthropic":
+        return new ChatAnthropic({
+          anthropicApiKey: process.env.ANTHROPIC_API_KEY,
+          ...config,
+        });
+      default:
+        return new ChatOpenAI({
+          openAIApiKey: process.env.OPEN_AI_KEY,
+          ...config,
+        });
+    }
+  }
+
+  static contextLimit(provider = "openai") {
+    switch (provider) {
+      case "openai":
+        return 8_000;
+      case "anthropic":
+        return 100_000;
+      default:
+        return 8_000;
+    }
+  }
 }
 
 module.exports = Provider;
diff --git a/server/utils/agents/aibitat/providers/anthropic.js b/server/utils/agents/aibitat/providers/anthropic.js
index 8d7e40ed75dd18bdd1ca5484f1f4bcb138b59c02..307731ba7a4e8446e05772a0910ec141f985d372 100644
--- a/server/utils/agents/aibitat/providers/anthropic.js
+++ b/server/utils/agents/aibitat/providers/anthropic.js
@@ -186,7 +186,8 @@ class AnthropicProvider extends Provider {
       const completion = response.content.find((msg) => msg.type === "text");
       return {
         result:
-          completion?.text ?? "I could not generate a response from this.",
+          completion?.text ??
+          "The model failed to complete the task and return back a valid response.",
         cost: 0,
       };
     } catch (error) {
diff --git a/server/utils/agents/aibitat/utils/summarize.js b/server/utils/agents/aibitat/utils/summarize.js
index 26eae988a295b173e98c800faebd247a155ce164..2a61263b96210986851aa33668e8e2e32583ea8e 100644
--- a/server/utils/agents/aibitat/utils/summarize.js
+++ b/server/utils/agents/aibitat/utils/summarize.js
@@ -1,7 +1,7 @@
 const { loadSummarizationChain } = require("langchain/chains");
-const { ChatOpenAI } = require("langchain/chat_models/openai");
 const { PromptTemplate } = require("langchain/prompts");
 const { RecursiveCharacterTextSplitter } = require("langchain/text_splitter");
+const Provider = require("../providers/ai-provider");
 /**
  * Summarize content using OpenAI's GPT-3.5 model.
  *
@@ -9,11 +9,20 @@ const { RecursiveCharacterTextSplitter } = require("langchain/text_splitter");
  * @param content The content to summarize.
  * @returns The summarized content.
  */
-async function summarizeContent(controllerSignal, content) {
-  const llm = new ChatOpenAI({
-    openAIApiKey: process.env.OPEN_AI_KEY,
+
+const SUMMARY_MODEL = {
+  anthropic: "claude-3-opus-20240229", // 200,000 tokens
+  openai: "gpt-3.5-turbo-1106", // 16,385 tokens
+};
+
+async function summarizeContent(
+  provider = "openai",
+  controllerSignal,
+  content
+) {
+  const llm = Provider.LangChainChatModel(provider, {
     temperature: 0,
-    modelName: "gpt-3.5-turbo-16k-0613",
+    modelName: SUMMARY_MODEL[provider],
   });
 
   const textSplitter = new RecursiveCharacterTextSplitter({
@@ -41,6 +50,7 @@ async function summarizeContent(controllerSignal, content) {
     combineMapPrompt: mapPromptTemplate,
     verbose: process.env.NODE_ENV === "development",
   });
+
   const res = await chain.call({
     ...(controllerSignal ? { signal: controllerSignal } : {}),
     input_documents: docs,