From a5ee6121e22bb41d61104c03d96e063633b0dbf4 Mon Sep 17 00:00:00 2001
From: timothycarambat <rambat1010@gmail.com>
Date: Wed, 5 Feb 2025 11:34:03 -0800
Subject: [PATCH] Add patch for `o#` models on Azure connect #3023 Note:
 depends on the user naming the deployment correctly.

---
 server/utils/AiProviders/azureOpenAi/index.js | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/server/utils/AiProviders/azureOpenAi/index.js b/server/utils/AiProviders/azureOpenAi/index.js
index bfebbd4ad..6b726c88d 100644
--- a/server/utils/AiProviders/azureOpenAi/index.js
+++ b/server/utils/AiProviders/azureOpenAi/index.js
@@ -38,6 +38,16 @@ class AzureOpenAiLLM {
     );
   }
 
+  /**
+   * Check if the model is an o# type model.
+   * NOTE: This is HIGHLY dependent on whether the user named their deployment "o1" or "o3-mini" or something else that matches the model name.
+   * It cannot be determined by the model name alone since model deployments can be named arbitrarily.
+   * @returns {boolean}
+   */
+  get isOTypeModel() {
+    return this.model.startsWith("o");
+  }
+
   #log(text, ...args) {
     console.log(`\x1b[32m[AzureOpenAi]\x1b[0m ${text}`, ...args);
   }
@@ -55,6 +65,7 @@ class AzureOpenAiLLM {
   }
 
   streamingEnabled() {
+    if (this.isOTypeModel && this.model !== "o3-mini") return false;
     return "streamGetChatCompletion" in this;
   }
 
@@ -110,7 +121,7 @@ class AzureOpenAiLLM {
     attachments = [], // This is the specific attachment for only this prompt
   }) {
     const prompt = {
-      role: "system",
+      role: this.isOTypeModel ? "user" : "system",
       content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
     };
     return [
@@ -131,7 +142,7 @@ class AzureOpenAiLLM {
 
     const result = await LLMPerformanceMonitor.measureAsyncFunction(
       this.openai.getChatCompletions(this.model, messages, {
-        temperature,
+        ...(this.isOTypeModel ? {} : { temperature }),
       })
     );
 
@@ -161,7 +172,7 @@ class AzureOpenAiLLM {
 
     const measuredStreamRequest = await LLMPerformanceMonitor.measureStream(
       await this.openai.streamChatCompletions(this.model, messages, {
-        temperature,
+        ...(this.isOTypeModel ? {} : { temperature }),
         n: 1,
       }),
       messages
-- 
GitLab