From ef4f63d9f4fc7047fc48e8688d8272c92c3527f4 Mon Sep 17 00:00:00 2001
From: Thuc Pham <51660321+thucpn@users.noreply.github.com>
Date: Tue, 19 Nov 2024 02:39:46 +0700
Subject: [PATCH] refactor: move mockLLM to core (#1493)

Co-authored-by: Alex Yang <himself65@outlook.com>
---
 .changeset/perfect-turtles-mate.md         |  5 ++
 apps/next/src/app/api/chat/route.ts        |  4 +-
 .../components/demo/chat/rsc/ai-action.tsx |  4 +-
 apps/next/src/lib/utils.ts                 | 32 +------
 packages/core/src/utils/index.ts           |  1 +
 packages/core/src/utils/llms.ts            | 88 +++++++++++++++++++
 6 files changed, 100 insertions(+), 34 deletions(-)
 create mode 100644 .changeset/perfect-turtles-mate.md

diff --git a/.changeset/perfect-turtles-mate.md b/.changeset/perfect-turtles-mate.md
new file mode 100644
index 000000000..efba7bb25
--- /dev/null
+++ b/.changeset/perfect-turtles-mate.md
@@ -0,0 +1,5 @@
+---
+"@llamaindex/core": patch
+---
+
+refactor: move mockLLM to core
diff --git a/apps/next/src/app/api/chat/route.ts b/apps/next/src/app/api/chat/route.ts
index 49cbd9c01..0eb9e3133 100644
--- a/apps/next/src/app/api/chat/route.ts
+++ b/apps/next/src/app/api/chat/route.ts
@@ -1,9 +1,9 @@
-import { llm } from "@/lib/utils";
+import { MockLLM } from "@llamaindex/core/utils";
 import { LlamaIndexAdapter, type Message } from "ai";
 import { Settings, SimpleChatEngine, type ChatMessage } from "llamaindex";
 import { NextResponse, type NextRequest } from "next/server";
 
-Settings.llm = llm;
+Settings.llm = new MockLLM(); // config your LLM here
 
 export async function POST(request: NextRequest) {
   try {
diff --git a/apps/next/src/components/demo/chat/rsc/ai-action.tsx b/apps/next/src/components/demo/chat/rsc/ai-action.tsx
index dd74f5e0d..169ba827a 100644
--- a/apps/next/src/components/demo/chat/rsc/ai-action.tsx
+++ b/apps/next/src/components/demo/chat/rsc/ai-action.tsx
@@ -1,5 +1,5 @@
-import { llm } from "@/lib/utils";
 import { Markdown } from "@llamaindex/chat-ui/widgets";
+import { MockLLM } from "@llamaindex/core/utils";
 import { generateId, Message } from "ai";
 import { createAI, createStreamableUI, getMutableAIState } from "ai/rsc";
 import { type ChatMessage, Settings, SimpleChatEngine } from "llamaindex";
@@ -11,7 +11,7 @@ type Actions = {
   chat: (message: Message) => Promise<Message & { display: ReactNode }>;
 };
 
-Settings.llm = llm;
+Settings.llm = new MockLLM(); // config your LLM here
 
 export const AI = createAI<ServerState, FrontendState, Actions>({
   initialAIState: [],
diff --git a/apps/next/src/lib/utils.ts b/apps/next/src/lib/utils.ts
index e073bc945..bd0c391dd 100644
--- a/apps/next/src/lib/utils.ts
+++ b/apps/next/src/lib/utils.ts
@@ -1,34 +1,6 @@
-import { clsx, type ClassValue } from "clsx";
-import { LLM, LLMMetadata } from "llamaindex";
-import { twMerge } from "tailwind-merge";
+import { clsx, type ClassValue } from "clsx"
+import { twMerge } from "tailwind-merge"
 
 export function cn(...inputs: ClassValue[]) {
   return twMerge(clsx(inputs))
 }
-
-class MockLLM {
-  metadata: LLMMetadata = {
-    model: "MockLLM",
-    temperature: 0.5,
-    topP: 0.5,
-    contextWindow: 1024,
-    tokenizer: undefined,
-  };
-
-  chat() {
-    const mockResponse = "Hello! This is a mock response";
-    return Promise.resolve(
-      new ReadableStream({
-        async start(controller) {
-          for (const char of mockResponse) {
-            controller.enqueue({ delta: char });
-            await new Promise((resolve) => setTimeout(resolve, 20));
-          }
-          controller.close();
-        },
-      }),
-    );
-  }
-}
-
-export const llm = new MockLLM() as unknown as LLM;
\ No newline at end of file
diff --git a/packages/core/src/utils/index.ts b/packages/core/src/utils/index.ts
index d040f010a..a68273919 100644
--- a/packages/core/src/utils/index.ts
+++ b/packages/core/src/utils/index.ts
@@ -76,6 +76,7 @@ export {
   extractText,
   imageToDataUrl,
   messagesToHistory,
+  MockLLM,
   toToolDescriptions,
 } from "./llms";
 
diff --git a/packages/core/src/utils/llms.ts b/packages/core/src/utils/llms.ts
index 255b82b91..c08933667 100644
--- a/packages/core/src/utils/llms.ts
+++ b/packages/core/src/utils/llms.ts
@@ -2,6 +2,15 @@ import { fs } from "@llamaindex/env";
 import { filetypemime } from "magic-bytes.js";
 import type {
   ChatMessage,
+  ChatResponse,
+  ChatResponseChunk,
+  CompletionResponse,
+  LLM,
+  LLMChatParamsNonStreaming,
+  LLMChatParamsStreaming,
+  LLMCompletionParamsNonStreaming,
+  LLMCompletionParamsStreaming,
+  LLMMetadata,
   MessageContent,
   MessageContentDetail,
   MessageContentTextDetail,
@@ -143,3 +152,82 @@ export async function imageToDataUrl(
   }
   return await blobToDataUrl(input);
 }
+
+export class MockLLM implements LLM {
+  metadata: LLMMetadata;
+  options: {
+    timeBetweenToken: number;
+    responseMessage: string;
+  };
+
+  constructor(options?: {
+    timeBetweenToken?: number;
+    responseMessage?: string;
+    metadata?: LLMMetadata;
+  }) {
+    this.options = {
+      timeBetweenToken: options?.timeBetweenToken ?? 20,
+      responseMessage: options?.responseMessage ?? "This is a mock response",
+    };
+    this.metadata = options?.metadata ?? {
+      model: "MockLLM",
+      temperature: 0.5,
+      topP: 0.5,
+      contextWindow: 1024,
+      tokenizer: undefined,
+    };
+  }
+
+  chat(
+    params: LLMChatParamsStreaming<object, object>,
+  ): Promise<AsyncIterable<ChatResponseChunk>>;
+  chat(
+    params: LLMChatParamsNonStreaming<object, object>,
+  ): Promise<ChatResponse<object>>;
+  async chat(
+    params:
+      | LLMChatParamsStreaming<object, object>
+      | LLMChatParamsNonStreaming<object, object>,
+  ): Promise<AsyncIterable<ChatResponseChunk> | ChatResponse<object>> {
+    const responseMessage = this.options.responseMessage;
+    const timeBetweenToken = this.options.timeBetweenToken;
+
+    if (params.stream) {
+      return (async function* () {
+        for (const char of responseMessage) {
+          yield { delta: char, raw: {} };
+          await new Promise((resolve) => setTimeout(resolve, timeBetweenToken));
+        }
+      })();
+    }
+
+    return {
+      message: { content: responseMessage, role: "assistant" },
+      raw: {},
+    };
+  }
+
+  async complete(
+    params: LLMCompletionParamsStreaming,
+  ): Promise<AsyncIterable<CompletionResponse>>;
+  async complete(
+    params: LLMCompletionParamsNonStreaming,
+  ): Promise<CompletionResponse>;
+  async complete(
+    params: LLMCompletionParamsStreaming | LLMCompletionParamsNonStreaming,
+  ): Promise<AsyncIterable<CompletionResponse> | CompletionResponse> {
+    const responseMessage = this.options.responseMessage;
+    const timeBetweenToken = this.options.timeBetweenToken;
+
+    if (params.stream) {
+      return (async function* () {
+        for (const char of responseMessage) {
+          yield { delta: char, text: char, raw: {} };
+          await new Promise((resolve) => setTimeout(resolve, timeBetweenToken));
+        }
+      })();
+    }
+
+    return { text: responseMessage, raw: {} };
+  }
+}
--
GitLab
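Usage note: a minimal sketch of how the relocated MockLLM might be exercised after this patch, based only on the constructor options and chat() overloads added in packages/core/src/utils/llms.ts above; the main() wrapper, the custom option values, and the logging are illustrative assumptions, not part of the patch.

import { MockLLM } from "@llamaindex/core/utils";

async function main() {
  // Both options are optional; the patch defaults to 20 ms between tokens
  // and the message "This is a mock response".
  const llm = new MockLLM({ responseMessage: "Hi there!", timeBetweenToken: 5 });

  // Non-streaming overload: resolves to a single ChatResponse.
  const response = await llm.chat({
    messages: [{ role: "user", content: "Hello" }],
  });
  console.log(response.message.content); // "Hi there!"

  // Streaming overload: resolves to an AsyncIterable of ChatResponseChunk,
  // yielding one character per chunk as `delta`.
  const stream = await llm.chat({
    messages: [{ role: "user", content: "Hello" }],
    stream: true,
  });
  for await (const chunk of stream) {
    process.stdout.write(chunk.delta);
  }
}

main().catch(console.error);

Because the core class implements the LLM interface directly, the demo apps in this patch can assign Settings.llm = new MockLLM() as-is, dropping the `as unknown as LLM` cast that the old apps/next/src/lib/utils.ts helper needed.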