diff --git a/examples/huggingface/local.ts b/examples/huggingface/local.ts
new file mode 100644
index 0000000000000000000000000000000000000000..b2d3cd0d1ca48004f24fdaab28a96921052d38d6
--- /dev/null
+++ b/examples/huggingface/local.ts
@@ -0,0 +1,16 @@
+import { HuggingFaceLLM } from "llamaindex";
+
+(async () => {
+  const hf = new HuggingFaceLLM();
+  const result = await hf.chat({
+    messages: [
+      { content: "You want to talk in rhymes.", role: "system" },
+      {
+        content:
+          "How much wood would a woodchuck chuck if a woodchuck could chuck wood?",
+        role: "user",
+      },
+    ],
+  });
+  console.log(result);
+})();
diff --git a/packages/core/src/llm/huggingface.ts b/packages/core/src/llm/huggingface.ts
index 74e44449240315814ce2aa7478d5d34e7c738ebb..6428c0fe3d71d77c1644778dac08c91188ec05cf 100644
--- a/packages/core/src/llm/huggingface.ts
+++ b/packages/core/src/llm/huggingface.ts
@@ -2,6 +2,12 @@ import {
   HfInference,
   type Options as HfInferenceOptions,
 } from "@huggingface/inference";
+import type {
+  PreTrainedModel,
+  PreTrainedTokenizer,
+  Tensor,
+} from "@xenova/transformers";
+import { lazyLoadTransformers } from "../internal/deps/transformers.js";
 import { BaseLLM } from "./base.js";
 import type {
   ChatMessage,
@@ -139,3 +145,116 @@ export class HuggingFaceInferenceAPI extends BaseLLM {
     }));
   }
 }
+
+const DEFAULT_HUGGINGFACE_MODEL = "stabilityai/stablelm-tuned-alpha-3b";
+
+export interface HFLLMConfig {
+  modelName?: string;
+  tokenizerName?: string;
+  temperature?: number;
+  topP?: number;
+  maxTokens?: number;
+  contextWindow?: number;
+}
+
+export class HuggingFaceLLM extends BaseLLM {
+  modelName: string;
+  tokenizerName: string;
+  temperature: number;
+  topP: number;
+  maxTokens?: number;
+  contextWindow: number;
+
+  private tokenizer: PreTrainedTokenizer | null = null;
+  private model: PreTrainedModel | null = null;
+
+  constructor(init?: HFLLMConfig) {
+    super();
+    this.modelName = init?.modelName ?? DEFAULT_HUGGINGFACE_MODEL;
+    this.tokenizerName = init?.tokenizerName ?? DEFAULT_HUGGINGFACE_MODEL;
+    this.temperature = init?.temperature ?? DEFAULT_PARAMS.temperature;
+    this.topP = init?.topP ?? DEFAULT_PARAMS.topP;
+    this.maxTokens = init?.maxTokens ?? DEFAULT_PARAMS.maxTokens;
+    this.contextWindow = init?.contextWindow ?? DEFAULT_PARAMS.contextWindow;
+  }
+
+  get metadata(): LLMMetadata {
+    return {
+      model: this.modelName,
+      temperature: this.temperature,
+      topP: this.topP,
+      maxTokens: this.maxTokens,
+      contextWindow: this.contextWindow,
+      tokenizer: undefined,
+    };
+  }
+
+  async getTokenizer() {
+    const { AutoTokenizer } = await lazyLoadTransformers();
+    if (!this.tokenizer) {
+      this.tokenizer = await AutoTokenizer.from_pretrained(this.tokenizerName);
+    }
+    return this.tokenizer;
+  }
+
+  async getModel() {
+    const { AutoModelForCausalLM } = await lazyLoadTransformers();
+    if (!this.model) {
+      this.model = await AutoModelForCausalLM.from_pretrained(this.modelName);
+    }
+    return this.model;
+  }
+
+  chat(
+    params: LLMChatParamsStreaming,
+  ): Promise<AsyncIterable<ChatResponseChunk>>;
+  chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>;
+  @wrapLLMEvent
+  async chat(
+    params: LLMChatParamsStreaming | LLMChatParamsNonStreaming,
+  ): Promise<AsyncIterable<ChatResponseChunk> | ChatResponse<object>> {
+    if (params.stream) return this.streamChat(params);
+    return this.nonStreamChat(params);
+  }
+
+  protected async nonStreamChat(
+    params: LLMChatParamsNonStreaming,
+  ): Promise<ChatResponse> {
+    const tokenizer = await this.getTokenizer();
+    const model = await this.getModel();
+
+    const messageInputs = params.messages.map((msg) => ({
+      role: msg.role,
+      content: msg.content as string,
+    }));
+    const inputs = tokenizer.apply_chat_template(messageInputs, {
+      add_generation_prompt: true,
+      ...this.metadata,
+    }) as Tensor;
+
+    // TODO: the input for model.generate should be updated when using @xenova/transformers v3
+    // We should add `stopping_criteria` also when it's supported in v3
+    // See: https://github.com/xenova/transformers.js/blob/3260640b192b3e06a10a1f4dc004b1254fdf1b80/src/models.js#L1248C9-L1248C27
+    const outputs = await model.generate(inputs, this.metadata);
+    const outputText = tokenizer.batch_decode(outputs, {
+      skip_special_tokens: false,
+    });
+
+    return {
+      raw: outputs,
+      message: {
+        content: outputText.join(""),
+        role: "assistant",
+      },
+    };
+  }
+
+  protected async *streamChat(
+    params: LLMChatParamsStreaming,
+  ): AsyncIterable<ChatResponseChunk> {
+    // @xenova/transformers v2 doesn't support streaming generation yet
+    // they are working on it in v3
+    // See: https://github.com/xenova/transformers.js/blob/3260640b192b3e06a10a1f4dc004b1254fdf1b80/src/models.js#L1249
+    throw new Error("Method not implemented.");
+  }
+}
diff --git a/packages/core/src/llm/index.ts b/packages/core/src/llm/index.ts
index 3430bd7860a12c5212f0ca93aad142fb44a334a1..7574dc8ecd6c9308ddf0935b4bc66113c4db4fd1 100644
--- a/packages/core/src/llm/index.ts
+++ b/packages/core/src/llm/index.ts
@@ -7,7 +7,7 @@ export {
 export { FireworksLLM } from "./fireworks.js";
 export { GEMINI_MODEL, Gemini, GeminiSession } from "./gemini.js";
 export { Groq } from "./groq.js";
-export { HuggingFaceInferenceAPI } from "./huggingface.js";
+export { HuggingFaceInferenceAPI, HuggingFaceLLM } from "./huggingface.js";
 export {
   ALL_AVAILABLE_MISTRAL_MODELS,
   MistralAI,
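
Usage sketch (not part of the diff above): a minimal TypeScript example of passing the new HFLLMConfig options through the constructor, assuming the package is consumed as "llamaindex" just like examples/huggingface/local.ts. The parameter values are illustrative placeholders, and `stream: true` is deliberately not used because streamChat currently throws.

import { HuggingFaceLLM } from "llamaindex";

(async () => {
  // Every HFLLMConfig field is optional; omitted fields fall back to
  // DEFAULT_HUGGINGFACE_MODEL / DEFAULT_PARAMS in the constructor above.
  // The values here are placeholders, not settings used by this PR.
  const llm = new HuggingFaceLLM({
    modelName: "stabilityai/stablelm-tuned-alpha-3b",
    tokenizerName: "stabilityai/stablelm-tuned-alpha-3b",
    temperature: 0.5,
    maxTokens: 256,
    contextWindow: 2048,
  });

  // Non-streaming chat only: streamChat is not implemented in this PR.
  const response = await llm.chat({
    messages: [{ content: "Reply with one short rhyme.", role: "user" }],
  });
  console.log(response.message.content);
})();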