Unverified Commit 047ae07e authored by Thuc Pham, committed by GitHub

feat: add local hugging face LLM (#854)

parent d8aa29a1
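
Example usage of the new local HuggingFaceLLM added in this commit: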

import { HuggingFaceLLM } from "llamaindex";

(async () => {
  const hf = new HuggingFaceLLM();
  const result = await hf.chat({
    messages: [
      { content: "You want to talk in rhymes.", role: "system" },
      {
        content:
          "How much wood would a woodchuck chuck if a woodchuck could chuck wood?",
        role: "user",
      },
    ],
  });
  console.log(result);
})();
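
The diff below extends the existing Hugging Face LLM module with the local HuggingFaceLLM implementation: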

@@ -2,6 +2,12 @@ import {
  HfInference,
  type Options as HfInferenceOptions,
} from "@huggingface/inference";
import type {
  PreTrainedModel,
  PreTrainedTokenizer,
  Tensor,
} from "@xenova/transformers";
import { lazyLoadTransformers } from "../internal/deps/transformers.js";
import { BaseLLM } from "./base.js";
import type {
  ChatMessage,
@@ -139,3 +145,116 @@ export class HuggingFaceInferenceAPI extends BaseLLM {
    }));
  }
}
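
// Default checkpoint, used for both the model and the tokenizer unless overridden via HFLLMConfig.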
const DEFAULT_HUGGINGFACE_MODEL = "stabilityai/stablelm-tuned-alpha-3b";

export interface HFLLMConfig {
  modelName?: string;
  tokenizerName?: string;
  temperature?: number;
  topP?: number;
  maxTokens?: number;
  contextWindow?: number;
}
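
/**
 * Runs a Hugging Face model locally in-process via @xenova/transformers,
 * unlike HuggingFaceInferenceAPI above, which calls the hosted Inference API.
 * Streaming is not supported yet (see streamChat below).
 */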
export class HuggingFaceLLM extends BaseLLM {
  modelName: string;
  tokenizerName: string;
  temperature: number;
  topP: number;
  maxTokens?: number;
  contextWindow: number;

  private tokenizer: PreTrainedTokenizer | null = null;
  private model: PreTrainedModel | null = null;

  constructor(init?: HFLLMConfig) {
    super();
    this.modelName = init?.modelName ?? DEFAULT_HUGGINGFACE_MODEL;
    this.tokenizerName = init?.tokenizerName ?? DEFAULT_HUGGINGFACE_MODEL;
    this.temperature = init?.temperature ?? DEFAULT_PARAMS.temperature;
    this.topP = init?.topP ?? DEFAULT_PARAMS.topP;
    this.maxTokens = init?.maxTokens ?? DEFAULT_PARAMS.maxTokens;
    this.contextWindow = init?.contextWindow ?? DEFAULT_PARAMS.contextWindow;
  }
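
  // Generation parameters; also spread into apply_chat_template and model.generate below.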
  get metadata(): LLMMetadata {
    return {
      model: this.modelName,
      temperature: this.temperature,
      topP: this.topP,
      maxTokens: this.maxTokens,
      contextWindow: this.contextWindow,
      tokenizer: undefined,
    };
  }
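
  // Lazily load and cache the tokenizer on first use.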
  async getTokenizer() {
    const { AutoTokenizer } = await lazyLoadTransformers();
    if (!this.tokenizer) {
      this.tokenizer = await AutoTokenizer.from_pretrained(this.tokenizerName);
    }
    return this.tokenizer;
  }
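
  // Lazily load and cache the model weights on first use.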
  async getModel() {
    const { AutoModelForCausalLM } = await lazyLoadTransformers();
    if (!this.model) {
      this.model = await AutoModelForCausalLM.from_pretrained(this.modelName);
    }
    return this.model;
  }
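
  // Overloads: streaming params resolve to an async iterable of chunks, non-streaming params to a single ChatResponse.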
  chat(
    params: LLMChatParamsStreaming,
  ): Promise<AsyncIterable<ChatResponseChunk>>;
  chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>;
  @wrapLLMEvent
  async chat(
    params: LLMChatParamsStreaming | LLMChatParamsNonStreaming,
  ): Promise<AsyncIterable<ChatResponseChunk> | ChatResponse<object>> {
    if (params.stream) return this.streamChat(params);
    return this.nonStreamChat(params);
  }
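
  // Applies the tokenizer's chat template, generates locally, and decodes the full output in one shot.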
  protected async nonStreamChat(
    params: LLMChatParamsNonStreaming,
  ): Promise<ChatResponse> {
    const tokenizer = await this.getTokenizer();
    const model = await this.getModel();
    const messageInputs = params.messages.map((msg) => ({
      role: msg.role,
      content: msg.content as string,
    }));
    const inputs = tokenizer.apply_chat_template(messageInputs, {
      add_generation_prompt: true,
      ...this.metadata,
    }) as Tensor;
    // TODO: the input for model.generate should be updated when using @xenova/transformers v3
    // We should add `stopping_criteria` also when it's supported in v3
    // See: https://github.com/xenova/transformers.js/blob/3260640b192b3e06a10a1f4dc004b1254fdf1b80/src/models.js#L1248C9-L1248C27
    const outputs = await model.generate(inputs, this.metadata);
    const outputText = tokenizer.batch_decode(outputs, {
      skip_special_tokens: false,
    });
    return {
      raw: outputs,
      message: {
        content: outputText.join(""),
        role: "assistant",
      },
    };
  }

  protected async *streamChat(
    params: LLMChatParamsStreaming,
  ): AsyncIterable<ChatResponseChunk> {
    // @xenova/transformers v2 doesn't support streaming generation yet
    // they are working on it in v3
    // See: https://github.com/xenova/transformers.js/blob/3260640b192b3e06a10a1f4dc004b1254fdf1b80/src/models.js#L1249
    throw new Error("Method not implemented.");
  }
}
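
The new class is also re-exported from the LLM index module: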

@@ -7,7 +7,7 @@ export {
export { FireworksLLM } from "./fireworks.js";
export { GEMINI_MODEL, Gemini, GeminiSession } from "./gemini.js";
export { Groq } from "./groq.js";
-export { HuggingFaceInferenceAPI } from "./huggingface.js";
+export { HuggingFaceInferenceAPI, HuggingFaceLLM } from "./huggingface.js";
export {
  ALL_AVAILABLE_MISTRAL_MODELS,
  MistralAI,
...
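
For reference, a minimal sketch of passing an HFLLMConfig to the local LLM; the parameter values below are illustrative, and the model name simply repeats the commit's default.

import { HuggingFaceLLM } from "llamaindex";

(async () => {
  // Sketch only: values are illustrative, not recommendations from the commit.
  const llm = new HuggingFaceLLM({
    modelName: "stabilityai/stablelm-tuned-alpha-3b",
    tokenizerName: "stabilityai/stablelm-tuned-alpha-3b",
    temperature: 0.5,
    maxTokens: 128,
  });
  // Non-streaming call: resolves to a single ChatResponse.
  const response = await llm.chat({
    messages: [{ content: "Say hello in one short rhyme.", role: "user" }],
  });
  console.log(response.message.content);
})();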