From ef4f63d9f4fc7047fc48e8688d8272c92c3527f4 Mon Sep 17 00:00:00 2001
From: Thuc Pham <51660321+thucpn@users.noreply.github.com>
Date: Tue, 19 Nov 2024 02:39:46 +0700
Subject: [PATCH] refactor: move mockLLM to core (#1493)

Co-authored-by: Alex Yang <himself65@outlook.com>
---
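Usage note: a minimal sketch of exercising the relocated MockLLM on its own, relying only
on the constructor options and the chat()/complete() overloads added below in
packages/core/src/utils/llms.ts; the "@llamaindex/core/utils" import path matches the one
used in the updated route.ts. The option values here are illustrative, not defaults.

    import { MockLLM } from "@llamaindex/core/utils";

    async function demo() {
      const llm = new MockLLM({
        timeBetweenToken: 0,                    // no artificial delay between chunks
        responseMessage: "Hello from MockLLM",  // fixed text returned by every call
      });

      // Streaming: chat() with stream: true resolves to an AsyncIterable<ChatResponseChunk>.
      const stream = await llm.chat({
        messages: [{ role: "user", content: "Hi" }],
        stream: true,
      });
      for await (const chunk of stream) {
        process.stdout.write(chunk.delta);
      }

      // Non-streaming: chat() resolves to a ChatResponse carrying the full message.
      const { message } = await llm.chat({
        messages: [{ role: "user", content: "Hi" }],
      });
      console.log(message.content);

      // complete() mirrors chat(): the non-streaming form resolves to { text, raw }.
      const { text } = await llm.complete({ prompt: "Hi" });
      console.log(text);
    }

    demo();
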
 .changeset/perfect-turtles-mate.md            |  5 ++
 apps/next/src/app/api/chat/route.ts           |  4 +-
 .../components/demo/chat/rsc/ai-action.tsx    |  4 +-
 apps/next/src/lib/utils.ts                    | 32 +------
 packages/core/src/utils/index.ts              |  1 +
 packages/core/src/utils/llms.ts               | 88 +++++++++++++++++++
 6 files changed, 100 insertions(+), 34 deletions(-)
 create mode 100644 .changeset/perfect-turtles-mate.md

diff --git a/.changeset/perfect-turtles-mate.md b/.changeset/perfect-turtles-mate.md
new file mode 100644
index 000000000..efba7bb25
--- /dev/null
+++ b/.changeset/perfect-turtles-mate.md
@@ -0,0 +1,5 @@
+---
+"@llamaindex/core": patch
+---
+
+refactor: move mockLLM to core
diff --git a/apps/next/src/app/api/chat/route.ts b/apps/next/src/app/api/chat/route.ts
index 49cbd9c01..0eb9e3133 100644
--- a/apps/next/src/app/api/chat/route.ts
+++ b/apps/next/src/app/api/chat/route.ts
@@ -1,9 +1,9 @@
-import { llm } from "@/lib/utils";
+import { MockLLM } from "@llamaindex/core/utils";
 import { LlamaIndexAdapter, type Message } from "ai";
 import { Settings, SimpleChatEngine, type ChatMessage } from "llamaindex";
 import { NextResponse, type NextRequest } from "next/server";
 
-Settings.llm = llm;
+Settings.llm = new MockLLM(); // configure your LLM here
 
 export async function POST(request: NextRequest) {
   try {
diff --git a/apps/next/src/components/demo/chat/rsc/ai-action.tsx b/apps/next/src/components/demo/chat/rsc/ai-action.tsx
index dd74f5e0d..169ba827a 100644
--- a/apps/next/src/components/demo/chat/rsc/ai-action.tsx
+++ b/apps/next/src/components/demo/chat/rsc/ai-action.tsx
@@ -1,5 +1,5 @@
-import { llm } from "@/lib/utils";
 import { Markdown } from "@llamaindex/chat-ui/widgets";
+import { MockLLM } from "@llamaindex/core/utils";
 import { generateId, Message } from "ai";
 import { createAI, createStreamableUI, getMutableAIState } from "ai/rsc";
 import { type ChatMessage, Settings, SimpleChatEngine } from "llamaindex";
@@ -11,7 +11,7 @@ type Actions = {
   chat: (message: Message) => Promise<Message & { display: ReactNode }>;
 };
 
-Settings.llm = llm;
+Settings.llm = new MockLLM(); // configure your LLM here
 
 export const AI = createAI<ServerState, FrontendState, Actions>({
   initialAIState: [],
diff --git a/apps/next/src/lib/utils.ts b/apps/next/src/lib/utils.ts
index e073bc945..bd0c391dd 100644
--- a/apps/next/src/lib/utils.ts
+++ b/apps/next/src/lib/utils.ts
@@ -1,34 +1,6 @@
-import { clsx, type ClassValue } from "clsx";
-import { LLM, LLMMetadata } from "llamaindex";
-import { twMerge } from "tailwind-merge";
+import { clsx, type ClassValue } from "clsx"
+import { twMerge } from "tailwind-merge"
 
 export function cn(...inputs: ClassValue[]) {
   return twMerge(clsx(inputs))
 }
-
-class MockLLM  {
-  metadata: LLMMetadata = {
-    model: "MockLLM",
-    temperature: 0.5,
-    topP: 0.5,
-    contextWindow: 1024,
-    tokenizer: undefined,
-  };
-
-  chat() {
-    const mockResponse = "Hello! This is a mock response";
-    return Promise.resolve(
-      new ReadableStream({
-        async start(controller) {
-          for (const char of mockResponse) {
-            controller.enqueue({ delta: char });
-            await new Promise((resolve) => setTimeout(resolve, 20));
-          }
-          controller.close();
-        },
-      }),
-    );
-  }
-}
-
-export const llm = new MockLLM() as unknown as LLM;
\ No newline at end of file
diff --git a/packages/core/src/utils/index.ts b/packages/core/src/utils/index.ts
index d040f010a..a68273919 100644
--- a/packages/core/src/utils/index.ts
+++ b/packages/core/src/utils/index.ts
@@ -76,6 +76,7 @@ export {
   extractText,
   imageToDataUrl,
   messagesToHistory,
+  MockLLM,
   toToolDescriptions,
 } from "./llms";
 
diff --git a/packages/core/src/utils/llms.ts b/packages/core/src/utils/llms.ts
index 255b82b91..c08933667 100644
--- a/packages/core/src/utils/llms.ts
+++ b/packages/core/src/utils/llms.ts
@@ -2,6 +2,15 @@ import { fs } from "@llamaindex/env";
 import { filetypemime } from "magic-bytes.js";
 import type {
   ChatMessage,
+  ChatResponse,
+  ChatResponseChunk,
+  CompletionResponse,
+  LLM,
+  LLMChatParamsNonStreaming,
+  LLMChatParamsStreaming,
+  LLMCompletionParamsNonStreaming,
+  LLMCompletionParamsStreaming,
+  LLMMetadata,
   MessageContent,
   MessageContentDetail,
   MessageContentTextDetail,
@@ -143,3 +152,82 @@ export async function imageToDataUrl(
   }
   return await blobToDataUrl(input);
 }
+
+export class MockLLM implements LLM {
+  metadata: LLMMetadata;
+  options: {
+    timeBetweenToken: number;
+    responseMessage: string;
+  };
+
+  constructor(options?: {
+    timeBetweenToken?: number;
+    responseMessage?: string;
+    metadata?: LLMMetadata;
+  }) {
+    this.options = {
+      timeBetweenToken: options?.timeBetweenToken ?? 20,
+      responseMessage: options?.responseMessage ?? "This is a mock response",
+    };
+    this.metadata = options?.metadata ?? {
+      model: "MockLLM",
+      temperature: 0.5,
+      topP: 0.5,
+      contextWindow: 1024,
+      tokenizer: undefined,
+    };
+  }
+
+  chat(
+    params: LLMChatParamsStreaming<object, object>,
+  ): Promise<AsyncIterable<ChatResponseChunk>>;
+  chat(
+    params: LLMChatParamsNonStreaming<object, object>,
+  ): Promise<ChatResponse<object>>;
+  async chat(
+    params:
+      | LLMChatParamsStreaming<object, object>
+      | LLMChatParamsNonStreaming<object, object>,
+  ): Promise<AsyncIterable<ChatResponseChunk> | ChatResponse<object>> {
+    const responseMessage = this.options.responseMessage;
+    const timeBetweenToken = this.options.timeBetweenToken;
+
+    if (params.stream) {
+      return (async function* () {
+        for (const char of responseMessage) {
+          yield { delta: char, raw: {} };
+          await new Promise((resolve) => setTimeout(resolve, timeBetweenToken));
+        }
+      })();
+    }
+
+    return {
+      message: { content: responseMessage, role: "assistant" },
+      raw: {},
+    };
+  }
+
+  async complete(
+    params: LLMCompletionParamsStreaming,
+  ): Promise<AsyncIterable<CompletionResponse>>;
+  async complete(
+    params: LLMCompletionParamsNonStreaming,
+  ): Promise<CompletionResponse>;
+  async complete(
+    params: LLMCompletionParamsStreaming | LLMCompletionParamsNonStreaming,
+  ): Promise<AsyncIterable<CompletionResponse> | CompletionResponse> {
+    const responseMessage = this.options.responseMessage;
+    const timeBetweenToken = this.options.timeBetweenToken;
+
+    if (params.stream) {
+      return (async function* () {
+        for (const char of responseMessage) {
+          yield { delta: char, text: char, raw: {} };
+          await new Promise((resolve) => setTimeout(resolve, timeBetweenToken));
+        }
+      })();
+    }
+
+    return { text: responseMessage, raw: {} };
+  }
+}
-- 
GitLab