From f651ca862859daa99dd482c51c25233d92d5fd34 Mon Sep 17 00:00:00 2001
From: Sean Hatfield <seanhatfield5@gmail.com>
Date: Sat, 14 Dec 2024 06:18:02 +0800
Subject: [PATCH] APIPie LLM provider improvements (#2695)

* fix apipie streaming/sort by chat models

* lint

* linting

---------

Co-authored-by: timothycarambat <rambat1010@gmail.com>
---
 server/utils/AiProviders/apipie/index.js | 61 ++++++++++++++----------
 server/utils/helpers/customModels.js     | 22 ++++++---
 2 files changed, 50 insertions(+), 33 deletions(-)

diff --git a/server/utils/AiProviders/apipie/index.js b/server/utils/AiProviders/apipie/index.js
index acfd2b1e6..47b3aabc8 100644
--- a/server/utils/AiProviders/apipie/index.js
+++ b/server/utils/AiProviders/apipie/index.js
@@ -1,8 +1,4 @@
 const { NativeEmbedder } = require("../../EmbeddingEngines/native");
-const {
-  handleDefaultStreamResponseV2,
-} = require("../../helpers/chat/responses");
-
 const { v4: uuidv4 } = require("uuid");
 const {
   writeResponseChunk,
@@ -98,6 +94,24 @@ class ApiPieLLM {
     );
   }
 
+  chatModels() {
+    const allModels = this.models();
+    return Object.entries(allModels).reduce(
+      (chatModels, [modelId, modelInfo]) => {
+        // Filter for chat models
+        if (
+          modelInfo.subtype &&
+          (modelInfo.subtype.includes("chat") ||
+            modelInfo.subtype.includes("chatx"))
+        ) {
+          chatModels[modelId] = modelInfo;
+        }
+        return chatModels;
+      },
+      {}
+    );
+  }
+
   streamingEnabled() {
     return "streamGetChatCompletion" in this;
   }
@@ -114,13 +128,13 @@ class ApiPieLLM {
   }
 
   promptWindowLimit() {
-    const availableModels = this.models();
+    const availableModels = this.chatModels();
     return availableModels[this.model]?.maxLength || 4096;
   }
 
   async isValidChatCompletionModel(model = "") {
     await this.#syncModels();
-    const availableModels = this.models();
+    const availableModels = this.chatModels();
     return availableModels.hasOwnProperty(model);
   }
 
@@ -189,22 +203,20 @@ class ApiPieLLM {
     return result.choices[0].message.content;
   }
 
-  // APIPie says it supports streaming, but it does not work across all models and providers.
-  // Notably, it is not working for OpenRouter models at all.
-  // async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
-  //   if (!(await this.isValidChatCompletionModel(this.model)))
-  //     throw new Error(
-  //       `ApiPie chat: ${this.model} is not valid for chat completion!`
-  //     );
-
-  //   const streamRequest = await this.openai.chat.completions.create({
-  //     model: this.model,
-  //     stream: true,
-  //     messages,
-  //     temperature,
-  //   });
-  //   return streamRequest;
-  // }
+  async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
+    if (!(await this.isValidChatCompletionModel(this.model)))
+      throw new Error(
+        `ApiPie chat: ${this.model} is not valid for chat completion!`
+      );
+
+    const streamRequest = await this.openai.chat.completions.create({
+      model: this.model,
+      stream: true,
+      messages,
+      temperature,
+    });
+    return streamRequest;
+  }
 
   handleStream(response, stream, responseProps) {
     const { uuid = uuidv4(), sources = [] } = responseProps;
@@ -264,10 +276,6 @@ class ApiPieLLM {
     });
   }
 
-  // handleStream(response, stream, responseProps) {
-  //   return handleDefaultStreamResponseV2(response, stream, responseProps);
-  // }
-
   // Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
   async embedTextInput(textInput) {
     return await this.embedder.embedTextInput(textInput);
@@ -300,6 +308,7 @@ async function fetchApiPieModels(providedApiKey = null) {
         id: `${model.provider}/${model.model}`,
         name: `${model.provider}/${model.model}`,
         organization: model.provider,
+        subtype: model.subtype,
         maxLength: model.max_tokens,
       };
     });
diff --git a/server/utils/helpers/customModels.js b/server/utils/helpers/customModels.js
index 1361c5271..a763635fb 100644
--- a/server/utils/helpers/customModels.js
+++ b/server/utils/helpers/customModels.js
@@ -401,13 +401,21 @@ async function getAPIPieModels(apiKey = null) {
   if (!Object.keys(knownModels).length === 0)
     return { models: [], error: null };
 
-  const models = Object.values(knownModels).map((model) => {
-    return {
-      id: model.id,
-      organization: model.organization,
-      name: model.name,
-    };
-  });
+  const models = Object.values(knownModels)
+    .filter((model) => {
+      // Filter for chat models
+      return (
+        model.subtype &&
+        (model.subtype.includes("chat") || model.subtype.includes("chatx"))
+      );
+    })
+    .map((model) => {
+      return {
+        id: model.id,
+        organization: model.organization,
+        name: model.name,
+      };
+    });
 
   return { models, error: null };
 }
-- 
GitLab
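
Reviewer note: a minimal standalone sketch of the chat-model filter this patch adds (the same subtype predicate appears in chatModels() and getAPIPieModels()). The model map below is hypothetical; real entries are built by fetchApiPieModels() from APIPie's models endpoint, and actual ids and subtypes may differ.

// Sketch: filter an APIPie model map down to chat-capable models, mirroring
// the chatModels() method added in the patch above. Sample entries are
// made up for illustration only.
const allModels = {
  "openai/gpt-4o": { subtype: "chat", maxLength: 128000 }, // hypothetical
  "openai/dall-e-3": { subtype: "image", maxLength: 4096 }, // hypothetical
  "mistralai/mixtral-8x7b": { subtype: "chatx", maxLength: 32768 }, // hypothetical
};

const chatModels = Object.entries(allModels).reduce(
  (chatModels, [modelId, modelInfo]) => {
    // Keep only models whose subtype marks them as chat-capable.
    if (
      modelInfo.subtype &&
      (modelInfo.subtype.includes("chat") ||
        modelInfo.subtype.includes("chatx"))
    ) {
      chatModels[modelId] = modelInfo;
    }
    return chatModels;
  },
  {}
);

console.log(Object.keys(chatModels)); // [ 'openai/gpt-4o', 'mistralai/mixtral-8x7b' ]

Design note: .includes("chat") already matches any subtype containing "chatx", so the second clause is redundant but harmless. Because promptWindowLimit() and isValidChatCompletionModel() now read from the filtered map, non-chat models fail validation up front rather than reaching the completions endpoint.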
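
A second sketch covers the re-enabled streaming path. It assumes llm is a constructed ApiPieLLM instance (hypothetical here); with stream: true, the OpenAI-compatible SDK client that ApiPieLLM wraps returns an async-iterable stream of deltas, and the custom handleStream() in the patch above does equivalent work, forwarding each delta to the client via writeResponseChunk().

// Sketch: iterate the raw stream returned by streamGetChatCompletion().
// `llm` is an assumed ApiPieLLM instance, not constructed here.
async function printStreamedReply(llm) {
  const messages = [{ role: "user", content: "Hello!" }];
  const stream = await llm.streamGetChatCompletion(messages, {
    temperature: 0.7,
  });
  // Each chunk carries an incremental completion delta.
  for await (const chunk of stream) {
    process.stdout.write(chunk.choices?.[0]?.delta?.content || "");
  }
}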