From 09004136892279421406a49031c5fdfcfba7a833 Mon Sep 17 00:00:00 2001
From: Huu Le <39040748+leehuwuj@users.noreply.github.com>
Date: Tue, 23 Jul 2024 17:04:53 +0700
Subject: [PATCH] Add next questions suggestion to the user (#170)

---------

Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de>
---
 .changeset/tall-pans-bake.md                  |  5 ++
 .../llamaindex/typescript/streaming/stream.ts | 23 ++++++--
 .../typescript/streaming/suggestion.ts        | 54 +++++++++++++++++++
 .../src/controllers/chat.controller.ts        |  6 ++-
 .../streaming/fastapi/app/api/routers/chat.py |  2 +-
 .../fastapi/app/api/routers/models.py         |  4 +-
 .../app/api/routers/vercel_response.py        | 45 ++++++++++++----
 .../fastapi/app/api/services/suggestion.py    | 48 +++++++++++++++++
 .../streaming/nextjs/app/api/chat/route.ts    |  6 ++-
 .../chat-message/chat-suggestedQuestions.tsx  | 32 +++++++++++
 .../components/ui/chat/chat-message/index.tsx | 26 ++++++++-
 .../app/components/ui/chat/chat-messages.tsx  |  1 +
 .../nextjs/app/components/ui/chat/index.ts    |  6 ++-
 13 files changed, 237 insertions(+), 21 deletions(-)
 create mode 100644 .changeset/tall-pans-bake.md
 create mode 100644 templates/components/llamaindex/typescript/streaming/suggestion.ts
 create mode 100644 templates/types/streaming/fastapi/app/api/services/suggestion.py
 create mode 100644 templates/types/streaming/nextjs/app/components/ui/chat/chat-message/chat-suggestedQuestions.tsx

diff --git a/.changeset/tall-pans-bake.md b/.changeset/tall-pans-bake.md
new file mode 100644
index 00000000..3ff4c6a8
--- /dev/null
+++ b/.changeset/tall-pans-bake.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Add suggestions for next questions.
diff --git a/templates/components/llamaindex/typescript/streaming/stream.ts b/templates/components/llamaindex/typescript/streaming/stream.ts
index 7844c1cf..99cbf7f3 100644
--- a/templates/components/llamaindex/typescript/streaming/stream.ts
+++ b/templates/components/llamaindex/typescript/streaming/stream.ts
@@ -5,34 +5,51 @@ import {
   trimStartOfStreamHelper,
   type AIStreamCallbacksAndOptions,
 } from "ai";
-import { EngineResponse } from "llamaindex";
+import { ChatMessage, EngineResponse } from "llamaindex";
+import { generateNextQuestions } from "./suggestion";
 
 export function LlamaIndexStream(
   response: AsyncIterable<EngineResponse>,
   data: StreamData,
+  chatHistory: ChatMessage[],
   opts?: {
     callbacks?: AIStreamCallbacksAndOptions;
   },
 ): ReadableStream<Uint8Array> {
-  return createParser(response, data)
+  return createParser(response, data, chatHistory)
     .pipeThrough(createCallbacksTransformer(opts?.callbacks))
     .pipeThrough(createStreamDataTransformer());
 }
 
-function createParser(res: AsyncIterable<EngineResponse>, data: StreamData) {
+function createParser(
+  res: AsyncIterable<EngineResponse>,
+  data: StreamData,
+  chatHistory: ChatMessage[],
+) {
   const it = res[Symbol.asyncIterator]();
   const trimStartOfStream = trimStartOfStreamHelper();
+  let llmTextResponse = "";
 
   return new ReadableStream<string>({
     async pull(controller): Promise<void> {
       const { value, done } = await it.next();
       if (done) {
         controller.close();
+        // LLM stream is done, generate the next questions with a new LLM call
+        chatHistory.push({ role: "assistant", content: llmTextResponse });
+        const questions: string[] = await generateNextQuestions(chatHistory);
+        if (questions.length > 0) {
+          data.appendMessageAnnotation({
+            type: "suggested_questions",
+            data: questions,
+          });
+        }
         data.close();
         return;
       }
 
       const text = trimStartOfStream(value.delta ?? "");
""); if (text) { + llmTextResponse += text; controller.enqueue(text); } }, diff --git a/templates/components/llamaindex/typescript/streaming/suggestion.ts b/templates/components/llamaindex/typescript/streaming/suggestion.ts new file mode 100644 index 00000000..3cf60faa --- /dev/null +++ b/templates/components/llamaindex/typescript/streaming/suggestion.ts @@ -0,0 +1,54 @@ +import { ChatMessage, Settings } from "llamaindex"; + +const NEXT_QUESTION_PROMPT_TEMPLATE = `You're a helpful assistant! Your task is to suggest the next question that user might ask. +Here is the conversation history +--------------------- +$conversation +--------------------- +Given the conversation history, please give me $number_of_questions questions that you might ask next! +Your answer should be wrapped in three sticks which follows the following format: +\`\`\` +<question 1> +<question 2>\`\`\` +`; +const N_QUESTIONS_TO_GENERATE = 3; + +export async function generateNextQuestions( + conversation: ChatMessage[], + numberOfQuestions: number = N_QUESTIONS_TO_GENERATE, +) { + const llm = Settings.llm; + + // Format conversation + const conversationText = conversation + .map((message) => `${message.role}: ${message.content}`) + .join("\n"); + const message = NEXT_QUESTION_PROMPT_TEMPLATE.replace( + "$conversation", + conversationText, + ).replace("$number_of_questions", numberOfQuestions.toString()); + + try { + const response = await llm.complete({ prompt: message }); + const questions = extractQuestions(response.text); + return questions; + } catch (error) { + console.error("Error: ", error); + throw error; + } +} + +// TODO: instead of parsing the LLM's result we can use structured predict, once LITS supports it +function extractQuestions(text: string): string[] { + // Extract the text inside the triple backticks + const contentMatch = text.match(/```(.*?)```/s); + const content = contentMatch ? 
+
+  // Split the content by newlines to get each question
+  const questions = content
+    .split("\n")
+    .map((question) => question.trim())
+    .filter((question) => question !== "");
+
+  return questions;
+}
diff --git a/templates/types/streaming/express/src/controllers/chat.controller.ts b/templates/types/streaming/express/src/controllers/chat.controller.ts
index 30d10aad..95228e8d 100644
--- a/templates/types/streaming/express/src/controllers/chat.controller.ts
+++ b/templates/types/streaming/express/src/controllers/chat.controller.ts
@@ -67,7 +67,11 @@ export const chat = async (req: Request, res: Response) => {
     });
 
     // Return a stream, which can be consumed by the Vercel/AI client
-    const stream = LlamaIndexStream(response, vercelStreamData);
+    const stream = LlamaIndexStream(
+      response,
+      vercelStreamData,
+      messages as ChatMessage[],
+    );
 
     return streamToResponse(stream, res, {}, vercelStreamData);
   } catch (error) {
diff --git a/templates/types/streaming/fastapi/app/api/routers/chat.py b/templates/types/streaming/fastapi/app/api/routers/chat.py
index b1e6b6b5..e2cffadf 100644
--- a/templates/types/streaming/fastapi/app/api/routers/chat.py
+++ b/templates/types/streaming/fastapi/app/api/routers/chat.py
@@ -61,7 +61,7 @@ async def chat(
         response = await chat_engine.astream_chat(last_message_content, messages)
         process_response_nodes(response.source_nodes, background_tasks)
 
-        return VercelStreamResponse(request, event_handler, response)
+        return VercelStreamResponse(request, event_handler, response, data)
     except Exception as e:
         logger.exception("Error in chat engine", exc_info=True)
         raise HTTPException(
diff --git a/templates/types/streaming/fastapi/app/api/routers/models.py b/templates/types/streaming/fastapi/app/api/routers/models.py
index d7247b24..7a4fdf31 100644
--- a/templates/types/streaming/fastapi/app/api/routers/models.py
+++ b/templates/types/streaming/fastapi/app/api/routers/models.py
@@ -25,7 +25,7 @@ class File(BaseModel):
     filetype: str
 
 
-class AnnotationData(BaseModel):
+class AnnotationFileData(BaseModel):
     files: List[File] = Field(
         default=[],
         description="List of files",
@@ -50,7 +50,7 @@ class AnnotationData(BaseModel):
 
 class Annotation(BaseModel):
     type: str
-    data: AnnotationData
+    data: AnnotationFileData | List[str]
 
     def to_content(self) -> str | None:
         if self.type == "document_file":
diff --git a/templates/types/streaming/fastapi/app/api/routers/vercel_response.py b/templates/types/streaming/fastapi/app/api/routers/vercel_response.py
index 7bffa7da..0222a149 100644
--- a/templates/types/streaming/fastapi/app/api/routers/vercel_response.py
+++ b/templates/types/streaming/fastapi/app/api/routers/vercel_response.py
@@ -6,7 +6,8 @@ from fastapi.responses import StreamingResponse
 from llama_index.core.chat_engine.types import StreamingAgentChatResponse
 
 from app.api.routers.events import EventCallbackHandler
-from app.api.routers.models import SourceNodes
+from app.api.routers.models import ChatData, Message, SourceNodes
+from app.api.services.suggestion import NextQuestionSuggestion
 
 
 class VercelStreamResponse(StreamingResponse):
@@ -17,15 +18,6 @@ class VercelStreamResponse(StreamingResponse):
     TEXT_PREFIX = "0:"
     DATA_PREFIX = "8:"
 
-    def __init__(
-        self,
-        request: Request,
-        event_handler: EventCallbackHandler,
-        response: StreamingAgentChatResponse,
-    ):
-        content = self.content_generator(request, event_handler, response)
-        super().__init__(content=content)
-
     @classmethod
     def convert_text(cls, token: str):
         # Escape newlines and double quotes to avoid breaking the stream
@@ -37,17 +29,48 @@ class VercelStreamResponse(StreamingResponse):
         data_str = json.dumps(data)
         return f"{cls.DATA_PREFIX}[{data_str}]\n"
 
+    def __init__(
+        self,
+        request: Request,
+        event_handler: EventCallbackHandler,
+        response: StreamingAgentChatResponse,
+        chat_data: ChatData,
+    ):
+        content = VercelStreamResponse.content_generator(
+            request, event_handler, response, chat_data
+        )
+        super().__init__(content=content)
+
     @classmethod
     async def content_generator(
         cls,
         request: Request,
         event_handler: EventCallbackHandler,
         response: StreamingAgentChatResponse,
+        chat_data: ChatData,
     ):
         # Yield the text response
         async def _chat_response_generator():
+            final_response = ""
             async for token in response.async_response_gen():
-                yield cls.convert_text(token)
+                final_response += token
+                yield VercelStreamResponse.convert_text(token)
+
+            # Generate questions that the user might be interested in
+            conversation = chat_data.messages + [
+                Message(role="assistant", content=final_response)
+            ]
+            questions = await NextQuestionSuggestion.suggest_next_questions(
+                conversation
+            )
+            if len(questions) > 0:
+                yield VercelStreamResponse.convert_data(
+                    {
+                        "type": "suggested_questions",
+                        "data": questions,
+                    }
+                )
+
             # the text_generator is the leading stream, once it's finished, also finish the event stream
             event_handler.is_done = True
 
diff --git a/templates/types/streaming/fastapi/app/api/services/suggestion.py b/templates/types/streaming/fastapi/app/api/services/suggestion.py
new file mode 100644
index 00000000..406b0aec
--- /dev/null
+++ b/templates/types/streaming/fastapi/app/api/services/suggestion.py
@@ -0,0 +1,48 @@
+from typing import List
+
+from app.api.routers.models import Message
+from llama_index.core.prompts import PromptTemplate
+from llama_index.core.settings import Settings
+from pydantic import BaseModel
+
+NEXT_QUESTIONS_SUGGESTION_PROMPT = PromptTemplate(
+    "You're a helpful assistant! Your task is to suggest the next question that the user might ask. "
+    "\nHere is the conversation history"
+    "\n---------------------\n{conversation}\n---------------------"
+    "\nGiven the conversation history, please give me {number_of_questions} questions that you might ask next!"
+)
+N_QUESTION_TO_GENERATE = 3
+
+
+class NextQuestions(BaseModel):
+    """A list of questions that the user might ask next"""
+
+    questions: List[str]
+
+
+class NextQuestionSuggestion:
+    @staticmethod
+    async def suggest_next_questions(
+        messages: List[Message],
+        number_of_questions: int = N_QUESTION_TO_GENERATE,
+    ) -> List[str]:
+        # Reduce the cost by only using the last two messages
+        last_user_message = None
+        last_assistant_message = None
+        for message in reversed(messages):
+            if message.role == "user":
+                last_user_message = f"User: {message.content}"
+            elif message.role == "assistant":
+                last_assistant_message = f"Assistant: {message.content}"
+            if last_user_message and last_assistant_message:
+                break
+        conversation: str = f"{last_user_message}\n{last_assistant_message}"
+
+        output: NextQuestions = await Settings.llm.astructured_predict(
+            NextQuestions,
+            prompt=NEXT_QUESTIONS_SUGGESTION_PROMPT,
+            conversation=conversation,
+            number_of_questions=number_of_questions,
+        )
+
+        return output.questions
diff --git a/templates/types/streaming/nextjs/app/api/chat/route.ts b/templates/types/streaming/nextjs/app/api/chat/route.ts
index c222a4fd..792ecb7b 100644
--- a/templates/types/streaming/nextjs/app/api/chat/route.ts
+++ b/templates/types/streaming/nextjs/app/api/chat/route.ts
@@ -80,7 +80,11 @@ export async function POST(request: NextRequest) {
     });
 
     // Transform LlamaIndex stream to Vercel/AI format
-    const stream = LlamaIndexStream(response, vercelStreamData);
+    const stream = LlamaIndexStream(
+      response,
+      vercelStreamData,
+      messages as ChatMessage[],
+    );
 
     // Return a StreamingTextResponse, which can be consumed by the Vercel/AI client
     return new StreamingTextResponse(stream, {}, vercelStreamData);
diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/chat-suggestedQuestions.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/chat-suggestedQuestions.tsx
new file mode 100644
index 00000000..ea662e4e
--- /dev/null
+++ b/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/chat-suggestedQuestions.tsx
@@ -0,0 +1,32 @@
+import { useState } from "react";
+import { ChatHandler, SuggestedQuestionsData } from "..";
+
+export function SuggestedQuestions({
+  questions,
+  append,
+}: {
+  questions: SuggestedQuestionsData;
+  append: Pick<ChatHandler, "append">["append"];
+}) {
+  const [showQuestions, setShowQuestions] = useState(questions.length > 0);
+
+  return (
+    showQuestions &&
+    append !== undefined && (
+      <div className="flex flex-col space-y-2">
+        {questions.map((question, index) => (
+          <a
+            key={index}
+            onClick={() => {
+              append({ role: "user", content: question });
+              setShowQuestions(false);
+            }}
+            className="text-sm italic hover:underline cursor-pointer"
+          >
+            {"->"} {question}
+          </a>
+        ))}
+      </div>
+    )
+  );
+}
diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/index.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/index.tsx
index 4f6f5572..e71903ea 100644
--- a/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/index.tsx
+++ b/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/index.tsx
@@ -5,12 +5,14 @@ import { Fragment } from "react";
 import { Button } from "../../button";
 import { useCopyToClipboard } from "../hooks/use-copy-to-clipboard";
 import {
+  ChatHandler,
   DocumentFileData,
   EventData,
   ImageData,
   MessageAnnotation,
   MessageAnnotationType,
   SourceData,
+  SuggestedQuestionsData,
   ToolData,
   getAnnotationData,
 } from "../index";
@@ -19,6 +21,7 @@ import { ChatEvents } from "./chat-events";
 import { ChatFiles } from "./chat-files";
 import { ChatImage } from "./chat-image";
 import { ChatSources } from "./chat-sources";
+import { SuggestedQuestions } from "./chat-suggestedQuestions";
 import ChatTools from "./chat-tools";
 import Markdown from "./markdown";
 
@@ -30,9 +33,11 @@ type ContentDisplayConfig = {
 function ChatMessageContent({
   message,
   isLoading,
+  append,
 }: {
   message: Message;
   isLoading: boolean;
+  append: Pick<ChatHandler, "append">["append"];
 }) {
   const annotations = message.annotations as MessageAnnotation[] | undefined;
   if (!annotations?.length) return <Markdown content={message.content} />;
@@ -57,6 +62,10 @@ function ChatMessageContent({
     annotations,
     MessageAnnotationType.TOOLS,
   );
+  const suggestedQuestionsData = getAnnotationData<SuggestedQuestionsData>(
+    annotations,
+    MessageAnnotationType.SUGGESTED_QUESTIONS,
+  );
 
   const contents: ContentDisplayConfig[] = [
     {
@@ -88,6 +97,15 @@ function ChatMessageContent({
       order: 3,
       component: sourceData[0] ? <ChatSources data={sourceData[0]} /> : null,
     },
+    {
+      order: 4,
+      component: suggestedQuestionsData[0] ? (
+        <SuggestedQuestions
+          questions={suggestedQuestionsData[0]}
+          append={append}
+        />
+      ) : null,
+    },
   ];
 
   return (
@@ -104,16 +122,22 @@ function ChatMessageContent({
 export default function ChatMessage({
   chatMessage,
   isLoading,
+  append,
 }: {
   chatMessage: Message;
   isLoading: boolean;
+  append: Pick<ChatHandler, "append">["append"];
 }) {
   const { isCopied, copyToClipboard } = useCopyToClipboard({ timeout: 2000 });
   return (
     <div className="flex items-start gap-4 pr-5 pt-5">
       <ChatAvatar role={chatMessage.role} />
       <div className="group flex flex-1 justify-between gap-2">
-        <ChatMessageContent message={chatMessage} isLoading={isLoading} />
+        <ChatMessageContent
+          message={chatMessage}
+          isLoading={isLoading}
+          append={append}
+        />
         <Button
           onClick={() => copyToClipboard(chatMessage.content)}
           size="icon"
diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/chat-messages.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/chat-messages.tsx
index f12b56b2..e0afd8b5 100644
--- a/templates/types/streaming/nextjs/app/components/ui/chat/chat-messages.tsx
+++ b/templates/types/streaming/nextjs/app/components/ui/chat/chat-messages.tsx
@@ -53,6 +53,7 @@ export default function ChatMessages(
               key={m.id}
               chatMessage={m}
               isLoading={isLoadingMessage}
+              append={props.append!}
             />
           );
         })}
diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/index.ts b/templates/types/streaming/nextjs/app/components/ui/chat/index.ts
index d3892333..dcfc9cde 100644
--- a/templates/types/streaming/nextjs/app/components/ui/chat/index.ts
+++ b/templates/types/streaming/nextjs/app/components/ui/chat/index.ts
@@ -11,6 +11,7 @@ export enum MessageAnnotationType {
   SOURCES = "sources",
   EVENTS = "events",
   TOOLS = "tools",
+  SUGGESTED_QUESTIONS = "suggested_questions",
 }
 
 export type ImageData = {
@@ -67,12 +68,15 @@ export type ToolData = {
   };
 };
 
+export type SuggestedQuestionsData = string[];
+
 export type AnnotationData =
   | ImageData
   | DocumentFileData
   | SourceData
   | EventData
-  | ToolData;
+  | ToolData
+  | SuggestedQuestionsData;
 
 export type MessageAnnotation = {
   type: MessageAnnotationType;
-- 
GitLab
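
For reference, a minimal sketch of how the suggested-questions annotation emitted above could be pulled back out of a raw Vercel data-stream line (the "8:" DATA_PREFIX written by convert_data / appendMessageAnnotation). The parseSuggestedQuestions helper below is illustrative only and not part of the templates; in the Next.js frontend the same data arrives via message.annotations and getAnnotationData, as shown in chat-message/index.tsx.

    // Minimal sketch: parse one "8:"-prefixed data line and extract the questions
    // appended under the "suggested_questions" annotation type.
    type SuggestedQuestionsAnnotation = { type: "suggested_questions"; data: string[] };

    function parseSuggestedQuestions(line: string): string[] {
      if (!line.startsWith("8:")) return [];
      const annotations = JSON.parse(line.slice(2)) as Array<{ type: string; data: unknown }>;
      return annotations
        .filter((a): a is SuggestedQuestionsAnnotation => a.type === "suggested_questions")
        .flatMap((a) => a.data);
    }

    // Example:
    // parseSuggestedQuestions('8:[{"type":"suggested_questions","data":["What else can it do?"]}]')
    // => ["What else can it do?"]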