Skip to content
Snippets Groups Projects
Unverified Commit ae7b3010 authored by Thuc Pham's avatar Thuc Pham Committed by GitHub
Browse files

feat: display sources in chat messages (#45)

parent 5fb64b74
No related branches found
No related tags found
No related merge requests found
Showing
with 2787 additions and 2045 deletions
......@@ -220,18 +220,16 @@ Given this information, please answer the question: {query_str}
],
];
} else {
const nextJsEnvs = [
{
name: "NEXT_PUBLIC_MODEL",
description: "The LLM model to use (hardcode to front-end artifact).",
value: opts.model,
},
];
envVars = [
...defaultEnvs,
...[
opts.framework === "nextjs"
? {
name: "NEXT_PUBLIC_MODEL",
description:
"The LLM model to use (hardcode to front-end artifact).",
value: opts.model,
}
: {},
],
...(opts.framework === "nextjs" ? nextJsEnvs : []),
];
}
// Render and write env file
......
......@@ -77,5 +77,5 @@
"engines": {
"node": ">=16.14.0"
},
"packageManager": "pnpm@8.15.1"
"packageManager": "pnpm@9.0.1"
}
Source diff could not be displayed: it is too large. Options to address this: view the blob.
......@@ -10,11 +10,11 @@
"dev": "concurrently \"tsup index.ts --format esm --dts --watch\" \"nodemon -q dist/index.mjs\""
},
"dependencies": {
"ai": "^2.2.25",
"ai": "^3.0.21",
"cors": "^2.8.5",
"dotenv": "^16.3.1",
"express": "^4.18.2",
"llamaindex": "latest"
"llamaindex": "0.2.9"
},
"devDependencies": {
"@types/cors": "^2.8.16",
......@@ -22,12 +22,12 @@
"@types/node": "^20.9.5",
"concurrently": "^8.2.2",
"eslint": "^8.54.0",
"eslint-config-prettier": "^8.10.0",
"nodemon": "^3.0.1",
"tsup": "^8.0.1",
"typescript": "^5.3.2",
"prettier": "^3.2.5",
"prettier-plugin-organize-imports": "^3.2.4",
"eslint-config-prettier": "^8.10.0",
"ts-node": "^10.9.2"
"ts-node": "^10.9.2",
"tsup": "^8.0.1",
"typescript": "^5.3.2"
}
}
import {
JSONValue,
StreamData,
createCallbacksTransformer,
createStreamDataTransformer,
experimental_StreamData,
trimStartOfStreamHelper,
type AIStreamCallbacksAndOptions,
} from "ai";
import { Response, StreamingAgentChatResponse } from "llamaindex";
import {
Metadata,
NodeWithScore,
Response,
StreamingAgentChatResponse,
} from "llamaindex";
type ParserOptions = {
image_url?: string;
};
function appendImageData(data: StreamData, imageUrl?: string) {
  // No image URL means there is nothing to annotate.
  if (!imageUrl) return;
  // Attach the image to the message as a stream annotation.
  const annotation = {
    type: "image",
    data: { url: imageUrl },
  };
  data.appendMessageAnnotation(annotation);
}
function appendSourceData(
  data: StreamData,
  sourceNodes?: NodeWithScore<Metadata>[],
) {
  // Nothing to report when no source nodes were retrieved.
  if (!sourceNodes?.length) return;
  // Serialize each retrieved node, exposing its id and a non-undefined score.
  const nodes = sourceNodes.map((sourceNode) => {
    const serialized = sourceNode.node.toMutableJSON();
    return {
      ...serialized,
      id: sourceNode.node.id_,
      score: sourceNode.score ?? null,
    };
  });
  data.appendMessageAnnotation({ type: "sources", data: { nodes } });
}
function createParser(
res: AsyncIterable<Response>,
data: experimental_StreamData,
data: StreamData,
opts?: ParserOptions,
) {
const it = res[Symbol.asyncIterator]();
const trimStartOfStream = trimStartOfStreamHelper();
let sourceNodes: NodeWithScore<Metadata>[] | undefined;
return new ReadableStream<string>({
start() {
// if image_url is provided, send it via the data stream
if (opts?.image_url) {
const message: JSONValue = {
type: "image_url",
image_url: {
url: opts.image_url,
},
};
data.append(message);
} else {
data.append({}); // send an empty image response for the user's message
}
appendImageData(data, opts?.image_url);
},
async pull(controller): Promise<void> {
const { value, done } = await it.next();
if (done) {
appendSourceData(data, sourceNodes);
controller.close();
data.append({}); // send an empty image response for the assistant's message
data.close();
return;
}
if (!sourceNodes) {
// get source nodes from the first response
sourceNodes = value.sourceNodes;
}
const text = trimStartOfStream(value.response ?? "");
if (text) {
controller.enqueue(text);
......@@ -57,8 +83,8 @@ export function LlamaIndexStream(
callbacks?: AIStreamCallbacksAndOptions;
parserOptions?: ParserOptions;
},
): { stream: ReadableStream; data: experimental_StreamData } {
const data = new experimental_StreamData();
): { stream: ReadableStream; data: StreamData } {
const data = new StreamData();
const res =
response instanceof StreamingAgentChatResponse
? response.response
......@@ -66,7 +92,7 @@ export function LlamaIndexStream(
return {
stream: createParser(res, data, opts?.parserOptions)
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
.pipeThrough(createStreamDataTransformer(true)),
.pipeThrough(createStreamDataTransformer()),
data,
};
}
from pydantic import BaseModel
from typing import List, Any, Optional, Dict, Tuple
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import StreamingResponse
from llama_index.core.chat_engine.types import (
BaseChatEngine,
StreamingAgentChatResponse,
)
from llama_index.core.schema import NodeWithScore
from llama_index.core.llms import ChatMessage, MessageRole
from app.engine import get_chat_engine
from app.api.routers.vercel_response import VercelStreamResponse
chat_router = r = APIRouter()
......@@ -37,6 +38,7 @@ class _SourceNodes(BaseModel):
id: str
metadata: Dict[str, Any]
score: Optional[float]
text: str
@classmethod
def from_source_node(cls, source_node: NodeWithScore):
......@@ -44,6 +46,7 @@ class _SourceNodes(BaseModel):
id=source_node.node.node_id,
metadata=source_node.node.metadata,
score=source_node.score,
text=source_node.node.text,
)
@classmethod
......@@ -91,13 +94,28 @@ async def chat(
response = await chat_engine.astream_chat(last_message_content, messages)
async def event_generator():
async def event_generator(request: Request, response: StreamingAgentChatResponse):
# Yield the text response
async for token in response.async_response_gen():
# If client closes connection, stop sending events
if await request.is_disconnected():
break
yield token
yield VercelStreamResponse.convert_text(token)
# Yield the source nodes
yield VercelStreamResponse.convert_data(
{
"type": "sources",
"data": {
"nodes": [
_SourceNodes.from_source_node(node).dict()
for node in response.source_nodes
]
},
}
)
return StreamingResponse(event_generator(), media_type="text/plain")
return VercelStreamResponse(content=event_generator(request, response))
# non-streaming endpoint - delete if not needed
......
import json
from typing import Any
from fastapi.responses import StreamingResponse
class VercelStreamResponse(StreamingResponse):
    """
    Class to convert the response from the chat engine to the streaming format expected by Vercel
    """

    # Stream-part prefixes of the Vercel AI SDK data-stream protocol:
    # "0:" marks a text part, "8:" marks a message-annotation (data) part.
    TEXT_PREFIX = "0:"
    DATA_PREFIX = "8:"

    @classmethod
    def convert_text(cls, token: str):
        # JSON-encode the token so quotes, backslashes, newlines and other
        # control characters are all escaped. The previous manual
        # replace("\n", "\\n") produced an invalid stream line for any token
        # containing '"' or '\\'.
        return f"{cls.TEXT_PREFIX}{json.dumps(token)}\n"

    @classmethod
    def convert_data(cls, data: dict):
        # Annotation parts are serialized as a JSON array of payloads.
        data_str = json.dumps(data)
        return f"{cls.DATA_PREFIX}[{data_str}]\n"

    def __init__(self, content: Any, **kwargs):
        super().__init__(
            content=content,
            **kwargs,
        )
import {
JSONValue,
StreamData,
createCallbacksTransformer,
createStreamDataTransformer,
experimental_StreamData,
trimStartOfStreamHelper,
type AIStreamCallbacksAndOptions,
} from "ai";
import { Response, StreamingAgentChatResponse } from "llamaindex";
import {
Metadata,
NodeWithScore,
Response,
StreamingAgentChatResponse,
} from "llamaindex";
type ParserOptions = {
image_url?: string;
};
function appendImageData(data: StreamData, imageUrl?: string) {
  // No image URL means there is nothing to annotate.
  if (!imageUrl) return;
  // Attach the image to the message as a stream annotation.
  const annotation = {
    type: "image",
    data: { url: imageUrl },
  };
  data.appendMessageAnnotation(annotation);
}
function appendSourceData(
  data: StreamData,
  sourceNodes?: NodeWithScore<Metadata>[],
) {
  // Nothing to report when no source nodes were retrieved.
  if (!sourceNodes?.length) return;
  // Serialize each retrieved node, exposing its id and a non-undefined score.
  const nodes = sourceNodes.map((sourceNode) => {
    const serialized = sourceNode.node.toMutableJSON();
    return {
      ...serialized,
      id: sourceNode.node.id_,
      score: sourceNode.score ?? null,
    };
  });
  data.appendMessageAnnotation({ type: "sources", data: { nodes } });
}
function createParser(
res: AsyncIterable<Response>,
data: experimental_StreamData,
data: StreamData,
opts?: ParserOptions,
) {
const it = res[Symbol.asyncIterator]();
const trimStartOfStream = trimStartOfStreamHelper();
let sourceNodes: NodeWithScore<Metadata>[] | undefined;
return new ReadableStream<string>({
start() {
// if image_url is provided, send it via the data stream
if (opts?.image_url) {
const message: JSONValue = {
type: "image_url",
image_url: {
url: opts.image_url,
},
};
data.append(message);
} else {
data.append({}); // send an empty image response for the user's message
}
appendImageData(data, opts?.image_url);
},
async pull(controller): Promise<void> {
const { value, done } = await it.next();
if (done) {
appendSourceData(data, sourceNodes);
controller.close();
data.append({}); // send an empty image response for the assistant's message
data.close();
return;
}
if (!sourceNodes) {
// get source nodes from the first response
sourceNodes = value.sourceNodes;
}
const text = trimStartOfStream(value.response ?? "");
if (text) {
controller.enqueue(text);
......@@ -57,8 +83,8 @@ export function LlamaIndexStream(
callbacks?: AIStreamCallbacksAndOptions;
parserOptions?: ParserOptions;
},
): { stream: ReadableStream; data: experimental_StreamData } {
const data = new experimental_StreamData();
): { stream: ReadableStream; data: StreamData } {
const data = new StreamData();
const res =
response instanceof StreamingAgentChatResponse
? response.response
......@@ -66,7 +92,7 @@ export function LlamaIndexStream(
return {
stream: createParser(res, data, opts?.parserOptions)
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
.pipeThrough(createStreamDataTransformer(true)),
.pipeThrough(createStreamDataTransformer()),
data,
};
}
"use client";
import { useChat } from "ai/react";
import { useMemo } from "react";
import { insertDataIntoMessages } from "./transform";
import { ChatInput, ChatMessages } from "./ui/chat";
export default function ChatSection() {
......@@ -14,7 +12,6 @@ export default function ChatSection() {
handleInputChange,
reload,
stop,
data,
} = useChat({
api: process.env.NEXT_PUBLIC_CHAT_API,
headers: {
......@@ -22,14 +19,10 @@ export default function ChatSection() {
},
});
const transformedMessages = useMemo(() => {
return insertDataIntoMessages(messages, data);
}, [messages, data]);
return (
<div className="space-y-4 max-w-5xl w-full">
<ChatMessages
messages={transformedMessages}
messages={messages}
isLoading={isLoading}
reload={reload}
stop={stop}
......
import { JSONValue, Message } from "ai";
/**
 * True when `rawData` is a non-null object with at least one own key —
 * the only shape treated as renderable per-message data.
 */
export const isValidMessageData = (rawData: JSONValue | undefined) => {
  if (!rawData || typeof rawData !== "object") return false;
  if (Object.keys(rawData).length === 0) return false;
  return true;
};

/**
 * Pairs each message with the stream-data entry at the same index.
 *
 * Returns a new array with shallow-copied messages where data is attached;
 * the caller's `messages` array and its elements are never mutated (the
 * previous implementation wrote `message.data` in place, mutating its
 * parameter). When `data` is undefined the input array is returned as-is.
 */
export const insertDataIntoMessages = (
  messages: Message[],
  data: JSONValue[] | undefined,
) => {
  if (!data) return messages;
  return messages.map((message, i) => {
    const rawData = data[i];
    // Only attach entries that are non-empty objects; otherwise keep the
    // original message object untouched.
    return isValidMessageData(rawData)
      ? { ...message, data: rawData }
      : message;
  });
};
import Image from "next/image";
import { type ImageData } from "./index";
export function ChatImage({ data }: { data: ImageData }) {
  // Let the image fill the card width while keeping its aspect ratio.
  const responsiveStyle = { width: "100%", height: "auto" };
  return (
    <div className="rounded-md max-w-[200px] shadow-md">
      <Image
        src={data.url}
        width={0}
        height={0}
        sizes="100vw"
        style={responsiveStyle}
        alt=""
      />
    </div>
  );
}
import { Check, Copy } from "lucide-react";
import { JSONValue, Message } from "ai";
import Image from "next/image";
import { Message } from "ai";
import { Fragment } from "react";
import { Button } from "../button";
import ChatAvatar from "./chat-avatar";
import { ChatImage } from "./chat-image";
import { ChatSources } from "./chat-sources";
import {
AnnotationData,
ImageData,
MessageAnnotation,
MessageAnnotationType,
SourceData,
} from "./index";
import Markdown from "./markdown";
import { useCopyToClipboard } from "./use-copy-to-clipboard";
interface ChatMessageImageData {
type: "image_url";
image_url: {
url: string;
};
// One renderable piece of a chat message plus its vertical sort position.
// NOTE(review): name has a typo ("Diplay") — kept to avoid breaking callers.
type ContentDiplayConfig = {
  order: number;
  component: JSX.Element | null;
};
function getAnnotationData<T extends AnnotationData>(
  annotations: MessageAnnotation[],
  type: MessageAnnotationType,
): T | undefined {
  // Return the payload of the first annotation matching the requested type.
  for (const annotation of annotations) {
    if (annotation.type === type) {
      return annotation.data as T;
    }
  }
  return undefined;
}
// This component will parse message data and render the appropriate UI.
function ChatMessageData({ messageData }: { messageData: JSONValue }) {
const { image_url, type } = messageData as unknown as ChatMessageImageData;
if (type === "image_url") {
return (
<div className="rounded-md max-w-[200px] shadow-md">
<Image
src={image_url.url}
width={0}
height={0}
sizes="100vw"
style={{ width: "100%", height: "auto" }}
alt=""
/>
</div>
);
}
return null;
function ChatMessageContent({ message }: { message: Message }) {
  const annotations = message.annotations as MessageAnnotation[] | undefined;
  // Without annotations the message is plain markdown text.
  if (!annotations?.length) return <Markdown content={message.content} />;

  // Extract any image / sources payloads attached to this message.
  const imageData = getAnnotationData<ImageData>(
    annotations,
    MessageAnnotationType.IMAGE,
  );
  const sourceData = getAnnotationData<SourceData>(
    annotations,
    MessageAnnotationType.SOURCES,
  );

  // The image renders above the markdown body; sources render below it.
  const contents: ContentDiplayConfig[] = [
    {
      order: -1,
      component: imageData ? <ChatImage data={imageData} /> : null,
    },
    {
      order: 0,
      component: <Markdown content={message.content} />,
    },
    {
      order: 1,
      component: sourceData ? <ChatSources data={sourceData} /> : null,
    },
  ];
  const ordered = [...contents].sort((a, b) => a.order - b.order);

  return (
    <div className="flex-1 gap-4 flex flex-col">
      {ordered.map((content, index) => (
        <Fragment key={index}>{content.component}</Fragment>
      ))}
    </div>
  );
}
export default function ChatMessage(chatMessage: Message) {
......@@ -40,12 +73,7 @@ export default function ChatMessage(chatMessage: Message) {
<div className="flex items-start gap-4 pr-5 pt-5">
<ChatAvatar role={chatMessage.role} />
<div className="group flex flex-1 justify-between gap-2">
<div className="flex-1 space-y-4">
{chatMessage.data && (
<ChatMessageData messageData={chatMessage.data} />
)}
<Markdown content={chatMessage.content} />
</div>
<ChatMessageContent message={chatMessage} />
<Button
onClick={() => copyToClipboard(chatMessage.content)}
size="icon"
......
import { ArrowUpRightSquare, Check, Copy } from "lucide-react";
import { useMemo } from "react";
import { Button } from "../button";
import { HoverCard, HoverCardContent, HoverCardTrigger } from "../hover-card";
import { SourceData, SourceNode } from "./index";
import { useCopyToClipboard } from "./use-copy-to-clipboard";
const SCORE_THRESHOLD = 0.5;
export function ChatSources({ data }: { data: SourceData }) {
  // Keep only nodes scoring above the threshold (unscored nodes pass),
  // ordered from highest to lowest score.
  const sources = useMemo(() => {
    const nodes = data.nodes ?? [];
    return nodes
      .filter((node) => (node.score ?? 1) > SCORE_THRESHOLD)
      .sort((a, b) => (b.score ?? 1) - (a.score ?? 1));
  }, [data.nodes]);

  // Hide the whole section when nothing passed the filter.
  if (sources.length === 0) return null;

  return (
    <div className="space-x-2 text-sm">
      <span className="font-semibold">Sources:</span>
      <div className="inline-flex gap-1 items-center">
        {sources.map((node, index) => (
          <div key={node.id}>
            <HoverCard>
              <HoverCardTrigger>
                <div className="text-xs w-5 h-5 rounded-full bg-gray-100 mb-2 flex items-center justify-center hover:text-white hover:bg-primary hover:cursor-pointer">
                  {index + 1}
                </div>
              </HoverCardTrigger>
              <HoverCardContent>
                <NodeInfo node={node} />
              </HoverCardContent>
            </HoverCard>
          </div>
        ))}
      </div>
    </div>
  );
}
function NodeInfo({ node }: { node: SourceNode }) {
  const { isCopied, copyToClipboard } = useCopyToClipboard({ timeout: 1000 });

  const url = node.metadata["URL"];
  if (typeof url === "string") {
    // Web-loader node: it carries an external URL — render it as a link.
    return (
      <a
        className="space-x-2 flex items-center my-2 hover:text-blue-900"
        href={url}
        target="_blank"
      >
        <span>{url}</span>
        <ArrowUpRightSquare className="w-4 h-4" />
      </a>
    );
  }

  const filePath = node.metadata["file_path"];
  if (typeof filePath === "string") {
    // File-loader node: show the path with a copy-to-clipboard button.
    return (
      <div className="flex items-center px-2 py-1 justify-between my-2">
        <span>{filePath}</span>
        <Button
          onClick={() => copyToClipboard(filePath)}
          size="icon"
          variant="ghost"
          className="h-12 w-12"
        >
          {isCopied ? (
            <Check className="h-4 w-4" />
          ) : (
            <Copy className="h-4 w-4" />
          )}
        </Button>
      </div>
    );
  }

  // node generated by unknown loader, implement renderer by analyzing logged out metadata
  console.log("Node metadata", node.metadata);
  return <p>Sorry, unknown node type. Please add a new renderer.</p>;
}
......@@ -3,3 +3,30 @@ import ChatMessages from "./chat-messages";
export { type ChatHandler } from "./chat.interface";
export { ChatInput, ChatMessages };
// The kinds of annotations attached to chat messages; the values match the
// "type" field written by the server-side stream helpers.
export enum MessageAnnotationType {
  IMAGE = "image",
  SOURCES = "sources",
}
// Payload of an IMAGE annotation.
export type ImageData = {
  url: string;
};
// One retrieved source chunk; `score` is its relevance score when available.
export type SourceNode = {
  id: string;
  metadata: Record<string, unknown>;
  score?: number;
  text: string;
};
// Payload of a SOURCES annotation: the nodes used to produce the message.
export type SourceData = {
  nodes: SourceNode[];
};
// Union of every annotation payload shape.
export type AnnotationData = ImageData | SourceData;
// A single entry of `message.annotations` as delivered over the data stream.
export type MessageAnnotation = {
  type: MessageAnnotationType;
  data: AnnotationData;
};
"use client";
import * as HoverCardPrimitive from "@radix-ui/react-hover-card";
import * as React from "react";
import { cn } from "./lib/utils";
// Thin wrappers around Radix UI's hover-card primitives; Root manages the
// open state and Trigger is the element that opens the card on hover.
const HoverCard = HoverCardPrimitive.Root;
const HoverCardTrigger = HoverCardPrimitive.Trigger;
// Styled popup body: forwards its ref to the Radix Content element and merges
// any caller-supplied className with the default popover styling/animations.
const HoverCardContent = React.forwardRef<
  React.ElementRef<typeof HoverCardPrimitive.Content>,
  React.ComponentPropsWithoutRef<typeof HoverCardPrimitive.Content>
>(({ className, align = "center", sideOffset = 4, ...props }, ref) => (
  <HoverCardPrimitive.Content
    ref={ref}
    align={align}
    sideOffset={sideOffset}
    className={cn(
      "z-50 w-64 rounded-md border bg-popover p-4 text-popover-foreground shadow-md outline-none data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2",
      className,
    )}
    {...props}
  />
));
HoverCardContent.displayName = HoverCardPrimitive.Content.displayName;
export { HoverCard, HoverCardContent, HoverCardTrigger };
......@@ -10,12 +10,13 @@
"lint": "next lint"
},
"dependencies": {
"@radix-ui/react-hover-card": "^1.0.7",
"@radix-ui/react-slot": "^1.0.2",
"ai": "^2.2.27",
"ai": "^3.0.21",
"class-variance-authority": "^0.7.0",
"clsx": "^1.2.1",
"dotenv": "^16.3.1",
"llamaindex": "latest",
"llamaindex": "0.2.9",
"lucide-react": "^0.294.0",
"next": "^14.0.3",
"react": "^18.2.0",
......@@ -33,17 +34,17 @@
"@types/node": "^20.10.3",
"@types/react": "^18.2.42",
"@types/react-dom": "^18.2.17",
"@types/react-syntax-highlighter": "^15.5.11",
"autoprefixer": "^10.4.16",
"cross-env": "^7.0.3",
"eslint": "^8.55.0",
"eslint-config-next": "^14.0.3",
"eslint-config-prettier": "^8.10.0",
"postcss": "^8.4.32",
"tailwindcss": "^3.3.6",
"typescript": "^5.3.2",
"@types/react-syntax-highlighter": "^15.5.11",
"cross-env": "^7.0.3",
"prettier": "^3.2.5",
"prettier-plugin-organize-imports": "^3.2.4",
"eslint-config-prettier": "^8.10.0",
"ts-node": "^10.9.2"
"tailwindcss": "^3.3.6",
"ts-node": "^10.9.2",
"typescript": "^5.3.2"
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment