From 81fd82e133c2de25b2f6558afdfc5c72b4e0aab0 Mon Sep 17 00:00:00 2001
From: Timothy Carambat <rambat1010@gmail.com>
Date: Wed, 17 Apr 2024 14:04:51 -0700
Subject: [PATCH] model-specific summarization (#1119)

* model-specific summarization

* update guard functions

* patch model picker and key inputs
---
 .../AgentConfig/AgentModelSelection/index.jsx |  2 +-
 server/utils/agents/aibitat/index.js          |  2 ++
 .../utils/agents/aibitat/plugins/summarize.js | 16 ++++++++-
 .../agents/aibitat/plugins/web-scraping.js    | 13 +++++--
 .../agents/aibitat/providers/ai-provider.js   | 34 +++++++++++++++++++
 .../agents/aibitat/providers/anthropic.js     |  3 +-
 .../utils/agents/aibitat/utils/summarize.js   | 20 ++++++++---
 7 files changed, 80 insertions(+), 10 deletions(-)

diff --git a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentModelSelection/index.jsx b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentModelSelection/index.jsx
index 60a6e940a..085f4ef93 100644
--- a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentModelSelection/index.jsx
+++ b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentModelSelection/index.jsx
@@ -97,7 +97,7 @@ export default function AgentModelSelection({
                     <option
                       key={model.id}
                       value={model.id}
-                      selected={workspace?.chatModel === model.id}
+                      selected={workspace?.agentModel === model.id}
                     >
                       {model.name}
                     </option>
diff --git a/server/utils/agents/aibitat/index.js b/server/utils/agents/aibitat/index.js
index 852baa65b..d1e9ae9c8 100644
--- a/server/utils/agents/aibitat/index.js
+++ b/server/utils/agents/aibitat/index.js
@@ -12,6 +12,7 @@ const { Telemetry } = require("../../../models/telemetry.js");
 class AIbitat {
   emitter = new EventEmitter();
 
+  provider = null;
   defaultProvider = null;
   defaultInterrupt;
   maxRounds;
@@ -39,6 +40,7 @@ class AIbitat {
       provider,
       ...rest,
     };
+    this.provider = this.defaultProvider.provider;
   }
 
   /**
diff --git a/server/utils/agents/aibitat/plugins/summarize.js b/server/utils/agents/aibitat/plugins/summarize.js
index 0c1a9f4e9..ac1ad82c5 100644
--- a/server/utils/agents/aibitat/plugins/summarize.js
+++ b/server/utils/agents/aibitat/plugins/summarize.js
@@ -2,6 +2,7 @@ const { Document } = require("../../../../models/documents");
 const { safeJsonParse } = require("../../../http");
 const { validate } = require("uuid");
 const { summarizeContent } = require("../utils/summarize");
+const Provider = require("../providers/ai-provider");
 
 const docSummarizer = {
   name: "document-summarizer",
@@ -95,7 +96,19 @@ const docSummarizer = {
                   document?.title ?? "a discovered file."
                 }`
               );
-              if (document?.content?.length < 8000) return content;
+
+              if (!document.content || document.content.length === 0) {
+                throw new Error(
+                  "This document has no readable content that could be found."
+                );
+              }
+
+              if (
+                document.content?.length <
+                Provider.contextLimit(this.super.provider)
+              ) {
+                return document.content;
+              }
 
               this.super.introspect(
                 `${this.caller}: Summarizing ${document?.title ?? ""}...`
@@ -109,6 +122,7 @@ const docSummarizer = {
               });
 
               return await summarizeContent(
+                this.super.provider,
                 this.controller.signal,
                 document.content
               );
diff --git a/server/utils/agents/aibitat/plugins/web-scraping.js b/server/utils/agents/aibitat/plugins/web-scraping.js
index 90e226c0f..1e614b6ba 100644
--- a/server/utils/agents/aibitat/plugins/web-scraping.js
+++ b/server/utils/agents/aibitat/plugins/web-scraping.js
@@ -1,4 +1,5 @@
 const { CollectorApi } = require("../../../collectorApi");
+const Provider = require("../providers/ai-provider");
 const { summarizeContent } = require("../utils/summarize");
 
 const webScraping = {
@@ -61,7 +62,11 @@ const webScraping = {
               );
             }
 
-            if (content?.length <= 8000) {
+            if (!content || content?.length === 0) {
+              throw new Error("There was no content to be collected or read.");
+            }
+
+            if (content.length < Provider.contextLimit(this.super.provider)) {
               return content;
             }
 
@@ -74,7 +79,11 @@ const webScraping = {
               );
               this.controller.abort();
             });
-            return summarizeContent(this.controller.signal, content);
+            return summarizeContent(
+              this.super.provider,
+              this.controller.signal,
+              content
+            );
           },
         });
       },
diff --git a/server/utils/agents/aibitat/providers/ai-provider.js b/server/utils/agents/aibitat/providers/ai-provider.js
index 3f9181bc4..ed7bd31c5 100644
--- a/server/utils/agents/aibitat/providers/ai-provider.js
+++ b/server/utils/agents/aibitat/providers/ai-provider.js
@@ -2,6 +2,9 @@
  * A service that provides an AI client to create a completion.
  */
 
+const { ChatOpenAI } = require("langchain/chat_models/openai");
+const { ChatAnthropic } = require("langchain/chat_models/anthropic");
+
 class Provider {
   _client;
   constructor(client) {
@@ -14,6 +17,37 @@ class Provider {
   get client() {
     return this._client;
   }
+
+  static LangChainChatModel(provider = "openai", config = {}) {
+    switch (provider) {
+      case "openai":
+        return new ChatOpenAI({
+          openAIApiKey: process.env.OPEN_AI_KEY,
+          ...config,
+        });
+      case "anthropic":
+        return new ChatAnthropic({
+          anthropicApiKey: process.env.ANTHROPIC_API_KEY,
+          ...config,
+        });
+      default:
+        return new ChatOpenAI({
+          openAIApiKey: process.env.OPEN_AI_KEY,
+          ...config,
+        });
+    }
+  }
+
+  static contextLimit(provider = "openai") {
+    switch (provider) {
+      case "openai":
+        return 8_000;
+      case "anthropic":
+        return 100_000;
+      default:
+        return 8_000;
+    }
+  }
 }
 
 module.exports = Provider;
diff --git a/server/utils/agents/aibitat/providers/anthropic.js b/server/utils/agents/aibitat/providers/anthropic.js
index 8d7e40ed7..307731ba7 100644
--- a/server/utils/agents/aibitat/providers/anthropic.js
+++ b/server/utils/agents/aibitat/providers/anthropic.js
@@ -186,7 +186,8 @@ class AnthropicProvider extends Provider {
       const completion = response.content.find((msg) => msg.type === "text");
       return {
         result:
-          completion?.text ?? "I could not generate a response from this.",
+          completion?.text ??
+          "The model failed to complete the task and return back a valid response.",
         cost: 0,
       };
     } catch (error) {
diff --git a/server/utils/agents/aibitat/utils/summarize.js b/server/utils/agents/aibitat/utils/summarize.js
index 26eae988a..2a61263b9 100644
--- a/server/utils/agents/aibitat/utils/summarize.js
+++ b/server/utils/agents/aibitat/utils/summarize.js
@@ -1,7 +1,7 @@
 const { loadSummarizationChain } = require("langchain/chains");
-const { ChatOpenAI } = require("langchain/chat_models/openai");
 const { PromptTemplate } = require("langchain/prompts");
 const { RecursiveCharacterTextSplitter } = require("langchain/text_splitter");
+const Provider = require("../providers/ai-provider");
 /**
- * Summarize content using OpenAI's GPT-3.5 model.
+ * Summarize content using the configured provider's summarization model.
  *
@@ -9,11 +9,21 @@ const { RecursiveCharacterTextSplitter } = require("langchain/text_splitter");
+ * @param provider The provider slug ("openai" or "anthropic") used to pick the summary model.
  * @param content The content to summarize.
  * @returns The summarized content.
  */
-async function summarizeContent(controllerSignal, content) {
-  const llm = new ChatOpenAI({
-    openAIApiKey: process.env.OPEN_AI_KEY,
+
+const SUMMARY_MODEL = {
+  anthropic: "claude-3-opus-20240229", // 200,000 tokens
+  openai: "gpt-3.5-turbo-1106", // 16,385 tokens
+};
+
+async function summarizeContent(
+  provider = "openai",
+  controllerSignal,
+  content
+) {
+  const llm = Provider.LangChainChatModel(provider, {
     temperature: 0,
-    modelName: "gpt-3.5-turbo-16k-0613",
+    modelName: SUMMARY_MODEL[provider],
   });
 
   const textSplitter = new RecursiveCharacterTextSplitter({
@@ -41,6 +51,7 @@ async function summarizeContent(controllerSignal, content) {
     combineMapPrompt: mapPromptTemplate,
     verbose: process.env.NODE_ENV === "development",
   });
+
   const res = await chain.call({
     ...(controllerSignal ? { signal: controllerSignal } : {}),
     input_documents: docs,
-- 
GitLab