diff --git a/packages/core/src/global/settings.ts b/packages/core/src/global/settings.ts index 57ea265a04f889cd94e5e3c2ed130265699d43cd..e01901d458f9265bd5de28ecee4fb33092027e68 100644 --- a/packages/core/src/global/settings.ts +++ b/packages/core/src/global/settings.ts @@ -1,4 +1,5 @@ import type { Tokenizer } from "@llamaindex/env"; +import type { LLM } from "../llms"; import { type CallbackManager, getCallbackManager, @@ -10,6 +11,7 @@ import { setChunkSize, withChunkSize, } from "./settings/chunk-size"; +import { getLLM, setLLM, withLLM } from "./settings/llm"; import { getTokenizer, setTokenizer, @@ -17,6 +19,15 @@ import { } from "./settings/tokenizer"; export const Settings = { + get llm() { + return getLLM(); + }, + set llm(llm) { + setLLM(llm); + }, + withLLM<Result>(llm: LLM, fn: () => Result): Result { + return withLLM(llm, fn); + }, get tokenizer() { return getTokenizer(); }, diff --git a/packages/core/src/global/settings/llm.ts b/packages/core/src/global/settings/llm.ts new file mode 100644 index 0000000000000000000000000000000000000000..9309b3255c6b2f4e1d35abc93dc2ec4f58768fb9 --- /dev/null +++ b/packages/core/src/global/settings/llm.ts @@ -0,0 +1,23 @@ +import { AsyncLocalStorage } from "@llamaindex/env"; +import type { LLM } from "../../llms"; + +const llmAsyncLocalStorage = new AsyncLocalStorage<LLM>(); +let globalLLM: LLM | undefined; + +export function getLLM(): LLM { + const currentLLM = globalLLM ?? llmAsyncLocalStorage.getStore(); + if (!currentLLM) { + throw new Error( + "Cannot find LLM, please set `Settings.llm = ...` on the top of your code", + ); + } + return currentLLM; +} + +export function setLLM(llm: LLM): void { + globalLLM = llm; +} + +export function withLLM<Result>(llm: LLM, fn: () => Result): Result { + return llmAsyncLocalStorage.run(llm, fn); +} diff --git a/packages/llamaindex/src/Settings.ts b/packages/llamaindex/src/Settings.ts index 778a6fee42c4a1a9c1f41fa60801c12c69401bb6..628271c409c065d8d4dab601ba71397f50465d46 100644 --- a/packages/llamaindex/src/Settings.ts +++ b/packages/llamaindex/src/Settings.ts @@ -27,7 +27,6 @@ export type PromptConfig = { export interface Config { prompt: PromptConfig; - llm: LLM | null; promptHelper: PromptHelper | null; embedModel: BaseEmbedding | null; nodeParser: NodeParser | null; @@ -41,12 +40,10 @@ export interface Config { */ class GlobalSettings implements Config { #prompt: PromptConfig = {}; - #llm: LLM | null = null; #promptHelper: PromptHelper | null = null; #nodeParser: NodeParser | null = null; #chunkOverlap?: number; - #llmAsyncLocalStorage = new AsyncLocalStorage<LLM>(); #promptHelperAsyncLocalStorage = new AsyncLocalStorage<PromptHelper>(); #nodeParserAsyncLocalStorage = new AsyncLocalStorage<NodeParser>(); #chunkOverlapAsyncLocalStorage = new AsyncLocalStorage<number>(); @@ -62,19 +59,19 @@ class GlobalSettings implements Config { } get llm(): LLM { - if (this.#llm === null) { - this.#llm = new OpenAI(); + if (CoreSettings.llm === null) { + CoreSettings.llm = new OpenAI(); } - return this.#llmAsyncLocalStorage.getStore() ?? this.#llm; + return CoreSettings.llm; } set llm(llm: LLM) { - this.#llm = llm; + CoreSettings.llm = llm; } withLLM<Result>(llm: LLM, fn: () => Result): Result { - return this.#llmAsyncLocalStorage.run(llm, fn); + return CoreSettings.withLLM(llm, fn); } get promptHelper(): PromptHelper {