diff --git a/.changeset/giant-gorillas-explain.md b/.changeset/giant-gorillas-explain.md new file mode 100644 index 0000000000000000000000000000000000000000..738275d5085f6e3f1f5dcd6abfb6ad5e5e9571cb --- /dev/null +++ b/.changeset/giant-gorillas-explain.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +feat: add support for managed identity for Azure OpenAI diff --git a/packages/core/src/embeddings/OpenAIEmbedding.ts b/packages/core/src/embeddings/OpenAIEmbedding.ts index a3460616ebe5abf32f9674a2feeafc5e3281708e..a2d173a3f1c6b54f7ccd81fd6cb1b8da46839d99 100644 --- a/packages/core/src/embeddings/OpenAIEmbedding.ts +++ b/packages/core/src/embeddings/OpenAIEmbedding.ts @@ -1,7 +1,6 @@ import type { ClientOptions as OpenAIClientOptions } from "openai"; import type { AzureOpenAIConfig } from "../llm/azure.js"; import { - getAzureBaseUrl, getAzureConfigFromEnv, getAzureModel, shouldUseAzure, @@ -67,28 +66,22 @@ export class OpenAIEmbedding extends BaseEmbedding { this.additionalSessionOptions = init?.additionalSessionOptions; if (init?.azure || shouldUseAzure()) { - const azureConfig = getAzureConfigFromEnv({ + const azureConfig = { + ...getAzureConfigFromEnv({ + model: getAzureModel(this.model), + }), ...init?.azure, - model: getAzureModel(this.model), - }); - - if (!azureConfig.apiKey) { - throw new Error( - "Azure API key is required for OpenAI Azure models. Please set the AZURE_OPENAI_KEY environment variable.", - ); - } + }; this.apiKey = azureConfig.apiKey; this.session = init?.session ?? getOpenAISession({ azure: true, - apiKey: this.apiKey, - baseURL: getAzureBaseUrl(azureConfig), maxRetries: this.maxRetries, timeout: this.timeout, - defaultQuery: { "api-version": azureConfig.apiVersion }, ...this.additionalSessionOptions, + ...azureConfig, }); } else { this.apiKey = init?.apiKey ?? undefined; diff --git a/packages/core/src/llm/azure.ts b/packages/core/src/llm/azure.ts index 964187463816e7eb8348d80307603a23bb4efcbf..bfbfb339fc45e4ec88690cdd4d298d78053bec22 100644 --- a/packages/core/src/llm/azure.ts +++ b/packages/core/src/llm/azure.ts @@ -1,9 +1,9 @@ import { getEnv } from "@llamaindex/env"; -export interface AzureOpenAIConfig { - apiKey?: string; - endpoint?: string; - apiVersion?: string; +import type { AzureClientOptions } from "openai"; + +export interface AzureOpenAIConfig extends AzureClientOptions { + /** @deprecated use "deployment" instead */ deploymentName?: string; } @@ -81,6 +81,12 @@ const DEFAULT_API_VERSION = "2023-05-15"; export function getAzureConfigFromEnv( init?: Partial<AzureOpenAIConfig> & { model?: string }, ): AzureOpenAIConfig { + const deployment = + init?.deploymentName ?? + init?.deployment ?? + getEnv("AZURE_OPENAI_DEPLOYMENT") ?? // From Azure docs + getEnv("AZURE_OPENAI_API_DEPLOYMENT_NAME") ?? // LCJS compatible + init?.model; // Fall back to model name, Python compatible return { apiKey: init?.apiKey ?? @@ -98,11 +104,8 @@ export function getAzureConfigFromEnv( getEnv("OPENAI_API_VERSION") ?? // Python compatible getEnv("AZURE_OPENAI_API_VERSION") ?? // LCJS compatible DEFAULT_API_VERSION, - deploymentName: - init?.deploymentName ?? - getEnv("AZURE_OPENAI_DEPLOYMENT") ?? // From Azure docs - getEnv("AZURE_OPENAI_API_DEPLOYMENT_NAME") ?? // LCJS compatible - init?.model, // Fall back to model name, Python compatible + deploymentName: deployment, // LCJS compatible + deployment, // For Azure OpenAI }; } diff --git a/packages/core/src/llm/openai.ts b/packages/core/src/llm/openai.ts index 1d1bd245b52dc026f12d9a294c2f94541217b512..2bc3604b4ae3dae26f5871545425604ac6bf107b 100644 --- a/packages/core/src/llm/openai.ts +++ b/packages/core/src/llm/openai.ts @@ -5,7 +5,7 @@ import type { ClientOptions, ClientOptions as OpenAIClientOptions, } from "openai"; -import { OpenAI as OrigOpenAI } from "openai"; +import { AzureOpenAI, OpenAI as OrigOpenAI } from "openai"; import type { ChatCompletionAssistantMessageParam, @@ -23,7 +23,6 @@ import { getCallbackManager } from "../internal/settings/CallbackManager.js"; import type { BaseTool } from "../types.js"; import type { AzureOpenAIConfig } from "./azure.js"; import { - getAzureBaseUrl, getAzureConfigFromEnv, getAzureModel, shouldUseAzure, @@ -43,30 +42,23 @@ import type { } from "./types.js"; import { extractText, wrapLLMEvent } from "./utils.js"; -export class AzureOpenAI extends OrigOpenAI { - protected override authHeaders() { - return { "api-key": this.apiKey }; - } -} - export class OpenAISession { openai: OrigOpenAI; constructor(options: ClientOptions & { azure?: boolean } = {}) { - if (!options.apiKey) { - options.apiKey = getEnv("OPENAI_API_KEY"); - } - - if (!options.apiKey) { - throw new Error("Set OpenAI Key in OPENAI_API_KEY env variable"); // Overriding OpenAI package's error message - } - if (options.azure) { - this.openai = new AzureOpenAI(options); + this.openai = new AzureOpenAI(options as AzureOpenAIConfig); } else { + if (!options.apiKey) { + options.apiKey = getEnv("OPENAI_API_KEY"); + } + + if (!options.apiKey) { + throw new Error("Set OpenAI Key in OPENAI_API_KEY env variable"); // Overriding OpenAI package's error message + } + this.openai = new OrigOpenAI({ ...options, - // defaultHeaders: { "OpenAI-Beta": "assistants=v1" }, }); } } @@ -195,28 +187,22 @@ export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> { this.additionalSessionOptions = init?.additionalSessionOptions; if (init?.azure || shouldUseAzure()) { - const azureConfig = getAzureConfigFromEnv({ + const azureConfig = { + ...getAzureConfigFromEnv({ + model: getAzureModel(this.model), + }), ...init?.azure, - model: getAzureModel(this.model), - }); - - if (!azureConfig.apiKey) { - throw new Error( - "Azure API key is required for OpenAI Azure models. Please set the AZURE_OPENAI_KEY environment variable.", - ); - } + }; this.apiKey = azureConfig.apiKey; this.session = init?.session ?? getOpenAISession({ azure: true, - apiKey: this.apiKey, - baseURL: getAzureBaseUrl(azureConfig), maxRetries: this.maxRetries, timeout: this.timeout, - defaultQuery: { "api-version": azureConfig.apiVersion }, ...this.additionalSessionOptions, + ...azureConfig, }); } else { this.apiKey = init?.apiKey ?? undefined;