diff --git a/index.ts b/index.ts index 5131e9227a48d7d420236281919267d1a421f23a..ab80066569c681989a8cf3959a0ca9fbac7c9680 100644 --- a/index.ts +++ b/index.ts @@ -239,30 +239,32 @@ async function run(): Promise<void> { } } - if (!program.engine) { - if (ciInfo.isCI) { - program.engine = getPrefOrDefault("engine"); - } else { - const { engine } = await prompts( - { - type: "select", - name: "engine", - message: "Which chat engine would you like to use?", - choices: [ - { title: "SimpleChatEngine", value: "simple" }, - { title: "ContextChatEngine", value: "context" }, - ], - initial: 0, - }, - { - onCancel: () => { - console.error("Exiting."); - process.exit(1); + if (program.framework === "express" || program.framework === "nextjs") { + if (!program.engine) { + if (ciInfo.isCI) { + program.engine = getPrefOrDefault("engine"); + } else { + const { engine } = await prompts( + { + type: "select", + name: "engine", + message: "Which chat engine would you like to use?", + choices: [ + { title: "SimpleChatEngine", value: "simple" }, + { title: "ContextChatEngine", value: "context" }, + ], + initial: 0, }, - }, - ); - program.engine = engine; - preferences.engine = engine; + { + onCancel: () => { + console.error("Exiting."); + process.exit(1); + }, + }, + ); + program.engine = engine; + preferences.engine = engine; + } } } diff --git a/templates/index.ts b/templates/index.ts index 28eb7ef42edd5c312c19fecc0dd0421c47da2d51..92641318f17d2460bb7852556e8b244c0ae95cce 100644 --- a/templates/index.ts +++ b/templates/index.ts @@ -56,20 +56,30 @@ export const installTemplate = async ({ /** * Copy the selected chat engine files to the target directory and reference it. */ - console.log("\nUsing chat engine:", engine, "\n"); - const enginePath = path.join(__dirname, "engines", engine); - const engineDestPath = path.join(root, "app", "api", "chat", "engine"); - await copy("**", engineDestPath, { - parents: true, - cwd: enginePath, - }); - const routeFile = path.join(engineDestPath, "..", "route.ts"); - const routeFileContent = await fs.readFile(routeFile, "utf8"); - const newContent = routeFileContent.replace( - /^import { createChatEngine }.*$/m, - 'import { createChatEngine } from "./engine"\n', - ); - await fs.writeFile(routeFile, newContent); + let relativeEngineDestPath; + if (framework === "express" || framework === "nextjs") { + console.log("\nUsing chat engine:", engine, "\n"); + const enginePath = path.join(__dirname, "engines", engine); + relativeEngineDestPath = + framework === "nextjs" + ? path.join("app", "api", "chat") + : path.join("src", "controllers"); + await copy("**", path.join(root, relativeEngineDestPath, "engine"), { + parents: true, + cwd: enginePath, + }); + const routeFile = path.join( + root, + relativeEngineDestPath, + framework === "nextjs" ? "route.ts" : "llm.controller.ts", + ); + const routeFileContent = await fs.readFile(routeFile, "utf8"); + const newContent = routeFileContent.replace( + /^import { createChatEngine }.*$/m, + 'import { createChatEngine } from "./engine"\n', + ); + await fs.writeFile(routeFile, newContent); + } /** * Update the package.json scripts. @@ -86,11 +96,15 @@ export const installTemplate = async ({ llamaindex: version, }; - if (engine === "context") { + if (engine === "context" && relativeEngineDestPath) { // add generate script if using context engine packageJson.scripts = { ...packageJson.scripts, - generate: "node ./app/api/chat/engine/generate.mjs", + generate: `node ${path.join( + relativeEngineDestPath, + "engine", + "generate.mjs", + )}`, }; } diff --git a/templates/simple/express/src/controllers/llm.controller.ts b/templates/simple/express/src/controllers/llm.controller.ts index 54fa51b130c4cad11df9981d9461efc40364f252..7647b649baa7c24a72f80823779b9f7f6a9f8542 100644 --- a/templates/simple/express/src/controllers/llm.controller.ts +++ b/templates/simple/express/src/controllers/llm.controller.ts @@ -1,42 +1,41 @@ -import { ChatMessage, OpenAI, SimpleChatEngine } from "llamaindex"; import { NextFunction, Request, Response } from "express"; +import { ChatMessage, OpenAI } from "llamaindex"; +import { createChatEngine } from "../../../../engines/context"; export const chat = async (req: Request, res: Response, next: NextFunction) => { - try { - const { - message, - chatHistory, - }: { - message: string; - chatHistory: ChatMessage[]; - } = req.body; - if (!message || !chatHistory) { - return res.status(400).json({ - error: "message, chatHistory are required in the request body", - }); - } + try { + const { + message, + chatHistory, + }: { + message: string; + chatHistory: ChatMessage[]; + } = req.body; + if (!message || !chatHistory) { + return res.status(400).json({ + error: "message, chatHistory are required in the request body", + }); + } - const llm = new OpenAI({ - model: "gpt-3.5-turbo", - }); + const llm = new OpenAI({ + model: "gpt-3.5-turbo", + }); - const chatEngine = new SimpleChatEngine({ - llm, - }); + const chatEngine = await createChatEngine(llm); - const response = await chatEngine.chat(message, chatHistory); - const result: ChatMessage = { - role: "assistant", - content: response.response, - }; + const response = await chatEngine.chat(message, chatHistory); + const result: ChatMessage = { + role: "assistant", + content: response.response, + }; - return res.status(200).json({ - result, - }); - } catch (error) { - console.error("[LlamaIndex]", error); - return res.status(500).json({ - error: (error as Error).message, - }); - } + return res.status(200).json({ + result, + }); + } catch (error) { + console.error("[LlamaIndex]", error); + return res.status(500).json({ + error: (error as Error).message, + }); + } };