From 81fd82e133c2de25b2f6558afdfc5c72b4e0aab0 Mon Sep 17 00:00:00 2001 From: Timothy Carambat <rambat1010@gmail.com> Date: Wed, 17 Apr 2024 14:04:51 -0700 Subject: [PATCH] model specific summarization (#1119) * model specific summarization * update guard functions * patch model picker and key inputs --- .../AgentConfig/AgentModelSelection/index.jsx | 2 +- server/utils/agents/aibitat/index.js | 2 ++ .../utils/agents/aibitat/plugins/summarize.js | 16 ++++++++- .../agents/aibitat/plugins/web-scraping.js | 13 +++++-- .../agents/aibitat/providers/ai-provider.js | 34 +++++++++++++++++++ .../agents/aibitat/providers/anthropic.js | 3 +- .../utils/agents/aibitat/utils/summarize.js | 20 ++++++++--- 7 files changed, 80 insertions(+), 10 deletions(-) diff --git a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentModelSelection/index.jsx b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentModelSelection/index.jsx index 60a6e940a..085f4ef93 100644 --- a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentModelSelection/index.jsx +++ b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentModelSelection/index.jsx @@ -97,7 +97,7 @@ export default function AgentModelSelection({ <option key={model.id} value={model.id} - selected={workspace?.chatModel === model.id} + selected={workspace?.agentModel === model.id} > {model.name} </option> diff --git a/server/utils/agents/aibitat/index.js b/server/utils/agents/aibitat/index.js index 852baa65b..d1e9ae9c8 100644 --- a/server/utils/agents/aibitat/index.js +++ b/server/utils/agents/aibitat/index.js @@ -12,6 +12,7 @@ const { Telemetry } = require("../../../models/telemetry.js"); class AIbitat { emitter = new EventEmitter(); + provider = null; defaultProvider = null; defaultInterrupt; maxRounds; @@ -39,6 +40,7 @@ class AIbitat { provider, ...rest, }; + this.provider = this.defaultProvider.provider; } /** diff --git a/server/utils/agents/aibitat/plugins/summarize.js b/server/utils/agents/aibitat/plugins/summarize.js index 0c1a9f4e9..ac1ad82c5 100644 --- a/server/utils/agents/aibitat/plugins/summarize.js +++ b/server/utils/agents/aibitat/plugins/summarize.js @@ -2,6 +2,7 @@ const { Document } = require("../../../../models/documents"); const { safeJsonParse } = require("../../../http"); const { validate } = require("uuid"); const { summarizeContent } = require("../utils/summarize"); +const Provider = require("../providers/ai-provider"); const docSummarizer = { name: "document-summarizer", @@ -95,7 +96,19 @@ const docSummarizer = { document?.title ?? "a discovered file." }` ); - if (document?.content?.length < 8000) return content; + + if (!document.content || document.content.length === 0) { + throw new Error( + "This document has no readable content that could be found." + ); + } + + if ( + document.content?.length < + Provider.contextLimit(this.super.provider) + ) { + return document.content; + } this.super.introspect( `${this.caller}: Summarizing ${document?.title ?? ""}...` @@ -109,6 +122,7 @@ const docSummarizer = { }); return await summarizeContent( + this.super.provider, this.controller.signal, document.content ); diff --git a/server/utils/agents/aibitat/plugins/web-scraping.js b/server/utils/agents/aibitat/plugins/web-scraping.js index 90e226c0f..1e614b6ba 100644 --- a/server/utils/agents/aibitat/plugins/web-scraping.js +++ b/server/utils/agents/aibitat/plugins/web-scraping.js @@ -1,4 +1,5 @@ const { CollectorApi } = require("../../../collectorApi"); +const Provider = require("../providers/ai-provider"); const { summarizeContent } = require("../utils/summarize"); const webScraping = { @@ -61,7 +62,11 @@ const webScraping = { ); } - if (content?.length <= 8000) { + if (!content || content?.length === 0) { + throw new Error("There was no content to be collected or read."); + } + + if (content.length < Provider.contextLimit(this.super.provider)) { return content; } @@ -74,7 +79,11 @@ const webScraping = { ); this.controller.abort(); }); - return summarizeContent(this.controller.signal, content); + return summarizeContent( + this.super.provider, + this.controller.signal, + content + ); }, }); }, diff --git a/server/utils/agents/aibitat/providers/ai-provider.js b/server/utils/agents/aibitat/providers/ai-provider.js index 3f9181bc4..ed7bd31c5 100644 --- a/server/utils/agents/aibitat/providers/ai-provider.js +++ b/server/utils/agents/aibitat/providers/ai-provider.js @@ -2,6 +2,9 @@ * A service that provides an AI client to create a completion. */ +const { ChatOpenAI } = require("langchain/chat_models/openai"); +const { ChatAnthropic } = require("langchain/chat_models/anthropic"); + class Provider { _client; constructor(client) { @@ -14,6 +17,37 @@ class Provider { get client() { return this._client; } + + static LangChainChatModel(provider = "openai", config = {}) { + switch (provider) { + case "openai": + return new ChatOpenAI({ + openAIApiKey: process.env.OPEN_AI_KEY, + ...config, + }); + case "anthropic": + return new ChatAnthropic({ + anthropicApiKey: process.env.ANTHROPIC_API_KEY, + ...config, + }); + default: + return new ChatOpenAI({ + openAIApiKey: process.env.OPEN_AI_KEY, + ...config, + }); + } + } + + static contextLimit(provider = "openai") { + switch (provider) { + case "openai": + return 8_000; + case "anthropic": + return 100_000; + default: + return 8_000; + } + } } module.exports = Provider; diff --git a/server/utils/agents/aibitat/providers/anthropic.js b/server/utils/agents/aibitat/providers/anthropic.js index 8d7e40ed7..307731ba7 100644 --- a/server/utils/agents/aibitat/providers/anthropic.js +++ b/server/utils/agents/aibitat/providers/anthropic.js @@ -186,7 +186,8 @@ class AnthropicProvider extends Provider { const completion = response.content.find((msg) => msg.type === "text"); return { result: - completion?.text ?? "I could not generate a response from this.", + completion?.text ?? + "The model failed to complete the task and return back a valid response.", cost: 0, }; } catch (error) { diff --git a/server/utils/agents/aibitat/utils/summarize.js b/server/utils/agents/aibitat/utils/summarize.js index 26eae988a..2a61263b9 100644 --- a/server/utils/agents/aibitat/utils/summarize.js +++ b/server/utils/agents/aibitat/utils/summarize.js @@ -1,7 +1,7 @@ const { loadSummarizationChain } = require("langchain/chains"); -const { ChatOpenAI } = require("langchain/chat_models/openai"); const { PromptTemplate } = require("langchain/prompts"); const { RecursiveCharacterTextSplitter } = require("langchain/text_splitter"); +const Provider = require("../providers/ai-provider"); /** * Summarize content using OpenAI's GPT-3.5 model. * @@ -9,11 +9,20 @@ const { RecursiveCharacterTextSplitter } = require("langchain/text_splitter"); * @param content The content to summarize. * @returns The summarized content. */ -async function summarizeContent(controllerSignal, content) { - const llm = new ChatOpenAI({ - openAIApiKey: process.env.OPEN_AI_KEY, + +const SUMMARY_MODEL = { + anthropic: "claude-3-opus-20240229", // 200,000 tokens + openai: "gpt-3.5-turbo-1106", // 16,385 tokens +}; + +async function summarizeContent( + provider = "openai", + controllerSignal, + content +) { + const llm = Provider.LangChainChatModel(provider, { temperature: 0, - modelName: "gpt-3.5-turbo-16k-0613", + modelName: SUMMARY_MODEL[provider], }); const textSplitter = new RecursiveCharacterTextSplitter({ @@ -41,6 +50,7 @@ async function summarizeContent(controllerSignal, content) { combineMapPrompt: mapPromptTemplate, verbose: process.env.NODE_ENV === "development", }); + const res = await chain.call({ ...(controllerSignal ? { signal: controllerSignal } : {}), input_documents: docs, -- GitLab