From 9449fcd7379594fdedde51d97d3fc1831b046506 Mon Sep 17 00:00:00 2001
From: Timothy Carambat <rambat1010@gmail.com>
Date: Wed, 17 Apr 2024 11:54:58 -0700
Subject: [PATCH] Add Anthropic agent support with new API and tool_calling
 (#1116)

* Add Anthropic agent support with new API and tool_calling

* patch useGetProviderModels hook to unset default models on provider change
---
 frontend/src/hooks/useGetProvidersModels.js   |   6 +-
 .../AgentConfig/AgentLLMSelection/index.jsx   |   5 +-
 .../agents/aibitat/providers/anthropic.js     | 243 +++++++++++-------
 server/utils/agents/index.js                  |   1 +
 4 files changed, 156 insertions(+), 99 deletions(-)

diff --git a/frontend/src/hooks/useGetProvidersModels.js b/frontend/src/hooks/useGetProvidersModels.js
index 513bfdbe1..5dc5cd2ed 100644
--- a/frontend/src/hooks/useGetProvidersModels.js
+++ b/frontend/src/hooks/useGetProvidersModels.js
@@ -47,8 +47,12 @@ export default function useGetProviderModels(provider = null) {
       if (
         PROVIDER_DEFAULT_MODELS.hasOwnProperty(provider) &&
         !groupedProviders.includes(provider)
-      )
+      ) {
         setDefaultModels(PROVIDER_DEFAULT_MODELS[provider]);
+      } else {
+        setDefaultModels([]);
+      }
+
       groupedProviders.includes(provider)
         ? setCustomModels(groupModels(models))
         : setCustomModels(models);
diff --git a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx
index af6ae7549..f1b997470 100644
--- a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx
+++ b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx
@@ -5,10 +5,7 @@ import { AVAILABLE_LLM_PROVIDERS } from "@/pages/GeneralSettings/LLMPreference";
 import { CaretUpDown, MagnifyingGlass, X } from "@phosphor-icons/react";
 import AgentModelSelection from "../AgentModelSelection";
 
-const ENABLED_PROVIDERS = [
-  "openai",
-  // "anthropic"
-];
+const ENABLED_PROVIDERS = ["openai", "anthropic"];
 
 const LLM_DEFAULT = {
   name: "Please make a selection",
diff --git a/server/utils/agents/aibitat/providers/anthropic.js b/server/utils/agents/aibitat/providers/anthropic.js
index 5189dc2ef..8d7e40ed7 100644
--- a/server/utils/agents/aibitat/providers/anthropic.js
+++ b/server/utils/agents/aibitat/providers/anthropic.js
@@ -25,6 +25,90 @@ class AnthropicProvider extends Provider {
     this.model = model;
   }
 
+  // Returns a copy of `chats` reduced to bare { role, content } objects. Message nodes
+  // may carry arbitrary extra properties, so stripping them keeps the `messages`
+  // payload sent to the Anthropic API in spec.
+  #sanitize(chats) {
+    const sanitized = [...chats];
+
+    // Anthropic aborts when the first message is not from the user, so drop leading
+    // messages from the copy until a "user" message is first (or nothing is left).
+    while (sanitized.length > 0 && sanitized[0].role !== "user")
+      sanitized.shift();
+
+    return sanitized.map((msg) => {
+      const { role, content } = msg;
+      return { role, content };
+    });
+  }
+
+  #normalizeChats(messages = []) {
+    if (!messages.length) return messages;
+    const normalized = [];
+
+    [...messages].forEach((msg, i) => {
+      if (msg.role !== "function") return normalized.push(msg);
+
+      // A role "function" message is aibitat's special node holding a tool's output.
+      // Anthropic has no such role, so it is replaced rather than forwarded as-is.
+      // To resolve the tool call, Anthropic needs the id of the matching "tool_use"
+      // block; it is read from the content array of the message just before this one
+      // (the assistant "thought" that complete() pushes when a tool is requested).
+      const functionCompletion = msg;
+      const toolCallId = messages[i - 1]?.content?.find(
+        (msg) => msg.type === "tool_use"
+      )?.id;
+
+      // Emit the tool output as a user "tool_result" block, which Anthropic accepts.
+      normalized.push({
+        role: "user",
+        content: [
+          {
+            type: "tool_result",
+            tool_use_id: toolCallId,
+            content: functionCompletion.content,
+          },
+        ],
+      });
+    });
+    return normalized;
+  }
+
+  // Anthropic takes the system prompt as a top-level property, not as a message. Returns
+  // [systemPrompt, chats]: the last "system" content (or a default) and the rest, normalized.
+  #parseSystemPrompt(messages = []) {
+    const chats = [];
+    let systemPrompt =
+      "You are a helpful ai assistant who can assist the user and use tools available to help answer the users prompts and questions.";
+    for (const msg of messages) {
+      if (msg.role === "system") {
+        systemPrompt = msg.content;
+        continue;
+      }
+      chats.push(msg);
+    }
+
+    return [systemPrompt, this.#normalizeChats(chats)];
+  }
+
+  // Anthropic does not use the regular schema for functions, so map each definition to
+  // its tool format: { name, description, input_schema: { type, properties, required } }.
+  #formatFunctions(functions = []) {
+    return functions.map((func) => {
+      const { name, description, parameters, required } = func;
+      const { type, properties } = parameters;
+      return {
+        name,
+        description,
+        input_schema: {
+          type,
+          properties,
+          required,
+        },
+      };
+    });
+  }
+
   /**
    * Create a completion based on the received messages.
    *
@@ -32,89 +116,78 @@ class AnthropicProvider extends Provider {
    * @param functions
    * @returns The completion.
    */
-  async complete(messages, functions) {
-    // clone messages to avoid mutating the original array
-    const promptMessages = [...messages];
-
-    if (functions) {
-      const functionPrompt = this.getFunctionPrompt(functions);
-
-      // add function prompt after the first message
-      promptMessages.splice(1, 0, {
-        content: functionPrompt,
-        role: "system",
-      });
-    }
-
-    const prompt = promptMessages
-      .map((message) => {
-        const { content, role } = message;
-
-        switch (role) {
-          case "system":
-            return content
-              ? `${Anthropic.HUMAN_PROMPT} <admin>${content}</admin>`
-              : "";
-
-          case "function":
-          case "user":
-            return `${Anthropic.HUMAN_PROMPT} ${content}`;
-
-          case "assistant":
-            return `${Anthropic.AI_PROMPT} ${content}`;
-
-          default:
-            return content;
-        }
-      })
-      .filter(Boolean)
-      .join("\n")
-      .concat(` ${Anthropic.AI_PROMPT}`);
-
+  async complete(messages, functions = null) {
     try {
-      const response = await this.client.completions.create({
-        model: this.model,
-        max_tokens_to_sample: 3000,
-        stream: false,
-        prompt,
-      });
-
-      const result = response.completion.trim();
-      // TODO: get cost from response
-      const cost = 0;
-
-      // Handle function calls if the model returns a function call
-      if (result.includes("function_name") && functions) {
-        let functionCall;
-        try {
-          functionCall = JSON.parse(result);
-        } catch (error) {
-          // call the complete function again in case it gets a json error
-          return await this.complete(
-            [
-              ...messages,
-              {
-                role: "function",
-                content: `You gave me this function call: ${result} but I couldn't parse it.
-                ${error?.message}
-                
-                Please try again.`,
-              },
-            ],
-            functions
-          );
-        }
-
+      const [systemPrompt, chats] = this.#parseSystemPrompt(messages);
+      const response = await this.client.messages.create(
+        {
+          model: this.model,
+          max_tokens: 4096,
+          system: systemPrompt,
+          messages: this.#sanitize(chats),
+          stream: false,
+          ...(Array.isArray(functions) && functions?.length > 0
+            ? { tools: this.#formatFunctions(functions) }
+            : {}),
+        },
+        { headers: { "anthropic-beta": "tools-2024-04-04" } } // Required so we can use tools.
+      );
+
+      // We know that we need to call a tool. So we are about to recurse through completions/handleExecution
+      // https://docs.anthropic.com/claude/docs/tool-use#how-tool-use-works
+      if (response.stop_reason === "tool_use") {
+        // Get the tool call explicitly.
+        const toolCall = response.content.find(
+          (res) => res.type === "tool_use"
+        );
+
+        // Keep whatever text the model produced alongside the call. The assistant turn we
+        // store must be a 2-item content array: a text block followed by the tool_use block.
+        // Anthropic rejects empty text blocks, yet the API can itself return empty text, so
+        // fall back to a canned sentence when no usable text came back.
+        // Text blocks carry their string on `.text` and have no `role`, so the role is
+        // always "assistant" here.
+        let thought = response.content.find((res) => res.type === "text");
+        thought =
+          thought?.text?.trim().length > 0
+            ? {
+                role: "assistant",
+                content: [
+                  { type: "text", text: thought.text },
+                  { ...toolCall },
+                ],
+              }
+            : {
+                role: "assistant",
+                content: [
+                  {
+                    type: "text",
+                    text: `Okay, im going to use ${toolCall.name} to help me.`,
+                  },
+                  { ...toolCall },
+                ],
+              };
+
+        // Deliberately mutate the caller's messages array by appending the assistant thought
+        // so the later tool_result can be paired with this tool_use by #normalizeChats.
+        messages.push(thought);
+
+        const functionArgs = toolCall.input;
         return {
           result: null,
-          functionCall,
-          cost,
+          functionCall: {
+            name: toolCall.name,
+            arguments: functionArgs,
+          },
+          cost: 0,
         };
       }
 
+      const completion = response.content.find((msg) => msg.type === "text");
       return {
-        result,
-        cost,
+        result:
+          completion?.text ?? "I could not generate a response from this.",
+        cost: 0,
       };
     } catch (error) {
       // If invalid Auth error we need to abort because no amount of waiting
@@ -132,24 +205,6 @@ class AnthropicProvider extends Provider {
       throw error;
     }
   }
-
-  getFunctionPrompt(functions = []) {
-    const functionPrompt = `<functions>You have been trained to directly call a Javascript function passing a JSON Schema parameter as a response to this chat. This function will return a string that you can use to keep chatting.
-  
-  Here is a list of functions available to you:
-  ${JSON.stringify(functions, null, 2)}
-  
-  When calling any of those function in order to complete your task, respond only this JSON format. Do not include any other information or any other stuff.
-  
-  Function call format:
-  {
-     function_name: "givenfunctionname",
-     parameters: {}
-  }
-  </functions>`;
-
-    return functionPrompt;
-  }
 }
 
 module.exports = AnthropicProvider;
diff --git a/server/utils/agents/index.js b/server/utils/agents/index.js
index dd42a6b99..ce80fff44 100644
--- a/server/utils/agents/index.js
+++ b/server/utils/agents/index.js
@@ -50,6 +50,7 @@ class AgentHandler {
             from: USER_AGENT.name,
             to: WORKSPACE_AGENT.name,
             content: chatLog.prompt,
+            state: "success",
           },
           {
             from: WORKSPACE_AGENT.name,
-- 
GitLab