diff --git a/.changeset/olive-bulldogs-boil.md b/.changeset/olive-bulldogs-boil.md new file mode 100644 index 0000000000000000000000000000000000000000..4fe2f473d94b47b78bafef1c12d553babfd08a56 --- /dev/null +++ b/.changeset/olive-bulldogs-boil.md @@ -0,0 +1,5 @@ +--- +"create-llama": patch +--- + +Add in-text citation references diff --git a/helpers/env-variables.ts b/helpers/env-variables.ts index 1ddd269c101eb0d2d4662050e66d71e9f188e985..65dc853a3b7f575a28a8e65c1f3cdadc9821bf02 100644 --- a/helpers/env-variables.ts +++ b/helpers/env-variables.ts @@ -4,6 +4,7 @@ import { TOOL_SYSTEM_PROMPT_ENV_VAR, Tool } from "./tools"; import { InstallTemplateArgs, ModelConfig, + TemplateDataSource, TemplateFramework, TemplateObservability, TemplateType, @@ -423,7 +424,11 @@ const getToolEnvs = (tools?: Tool[]): EnvVar[] => { return toolEnvs; }; -const getSystemPromptEnv = (tools?: Tool[]): EnvVar => { +const getSystemPromptEnv = ( + tools?: Tool[], + dataSources?: TemplateDataSource[], + framework?: TemplateFramework, +): EnvVar[] => { const defaultSystemPrompt = "You are a helpful assistant who helps users with their questions."; @@ -442,11 +447,49 @@ const getSystemPromptEnv = (tools?: Tool[]): EnvVar => { ? `\"${toolSystemPrompt}\"` : defaultSystemPrompt; - return { - name: "SYSTEM_PROMPT", - description: "The system prompt for the AI model.", - value: systemPrompt, - }; + const systemPromptEnv = [ + { + name: "SYSTEM_PROMPT", + description: "The system prompt for the AI model.", + value: systemPrompt, + }, + ]; + + // Citation only works with FastAPI along with the chat engine and data source provided for now. + if ( + framework === "fastapi" && + tools?.length == 0 && + (dataSources?.length ?? 0 > 0) + ) { + const citationPrompt = `'You have provided information from a knowledge base that has been passed to you in nodes of information. +Each node has useful metadata such as node ID, file name, page, etc. 
+Please add the citation to the data node for each sentence or paragraph that you reference in the provided information. +The citation format is: . [citation:<node_id>]() +Where the <node_id> is the unique identifier of the data node. + +Example: +We have two nodes: + node_id: xyz + file_name: llama.pdf + + node_id: abc + file_name: animal.pdf + +User question: Tell me a fun fact about Llama. +Your answer: +A baby llama is called "Cria" [citation:xyz](). +It often live in desert [citation:abc](). +It\\'s cute animal. +'`; + systemPromptEnv.push({ + name: "SYSTEM_CITATION_PROMPT", + description: + "An additional system prompt to add citation when responding to user questions.", + value: citationPrompt, + }); + } + + return systemPromptEnv; }; const getTemplateEnvs = (template?: TemplateType): EnvVar[] => { @@ -525,7 +568,7 @@ export const createBackendEnvFile = async ( ...getToolEnvs(opts.tools), ...getTemplateEnvs(opts.template), ...getObservabilityEnvs(opts.observability), - getSystemPromptEnv(opts.tools), + ...getSystemPromptEnv(opts.tools, opts.dataSources, opts.framework), ]; // Render and write env file const content = renderEnvVar(envVars); diff --git a/templates/components/engines/python/chat/__init__.py b/templates/components/engines/python/chat/__init__.py index 7d8df55ae0074af3f695df0a1cd3ede7ea999053..b1fd361c2411a79c3d0c3481cd6e7540a097dacd 100644 --- a/templates/components/engines/python/chat/__init__.py +++ b/templates/components/engines/python/chat/__init__.py @@ -1,11 +1,20 @@ import os + from app.engine.index import get_index +from app.engine.node_postprocessors import NodeCitationProcessor from fastapi import HTTPException +from llama_index.core.chat_engine import CondensePlusContextChatEngine def get_chat_engine(filters=None, params=None): system_prompt = os.getenv("SYSTEM_PROMPT") - top_k = os.getenv("TOP_K", 3) + citation_prompt = os.getenv("SYSTEM_CITATION_PROMPT", None) + top_k = int(os.getenv("TOP_K", 3)) + + node_postprocessors = [] + if 
class NodeCitationProcessor(BaseNodePostprocessor):
    """
    Node postprocessor that exposes every node's ID through its metadata so
    the ID can be used for citations.
    Set SYSTEM_CITATION_PROMPT in the runtime environment to enable this feature.
    """

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """
        Write each node's ID into its own metadata under the "node_id" key.

        Args:
            nodes: Scored nodes to annotate; their metadata is updated in place.
            query_bundle: Not used by this processor.

        Returns:
            The same list that was passed in.
        """
        for scored in nodes:
            source = scored.node
            source.metadata["node_id"] = source.node_id
        return nodes
/**
 * Numbered badge for a source node.
 *
 * With a node, the badge becomes a hover-card trigger whose card shows the
 * node's details; without one, only the plain badge is rendered.
 *
 * @param props.node - Source node to describe in the hover card, if any.
 * @param props.index - Zero-based position of the source.
 */
export function SourceInfo(props: { node?: SourceNode; index: number }) {
  const { node, index } = props;

  return node ? (
    <HoverCard>
      <HoverCardTrigger
        className="cursor-default"
        onClick={(e) => {
          // Swallow the click so it neither runs the default action nor
          // reaches click handlers on enclosing elements.
          e.preventDefault();
          e.stopPropagation();
        }}
      >
        <SourceNumberButton
          index={index}
          className="hover:text-white hover:bg-primary"
        />
      </HoverCardTrigger>
      <HoverCardContent className="w-[400px]">
        <NodeInfo nodeInfo={node} />
      </HoverCardContent>
    </HoverCard>
  ) : (
    <SourceNumberButton index={index} />
  );
}

/**
 * Small round badge showing a source's number.
 *
 * @param props.index - Zero-based position; displayed as `index + 1`.
 * @param props.className - Extra classes merged after the base badge classes.
 */
export function SourceNumberButton(props: {
  index: number;
  className?: string;
}) {
  const badgeClasses = cn(
    "text-xs w-5 h-5 rounded-full bg-gray-100 inline-flex items-center justify-center",
    props.className,
  );

  return <span className={badgeClasses}>{props.index + 1}</span>;
}
/**
 * Rewrite citation flags from `[citation:<node id>]()` to
 * `[citation:<node index>]()`, where the index is the position of the node
 * with that id in `sources.nodes`.
 *
 * Citations whose id matches no node are removed. When `sources` is not
 * provided the content is returned untouched.
 *
 * The whole string is rewritten in a single `replace` pass. Looping with
 * `exec` while editing the string leaves the regex's `lastIndex` pointing
 * into the old text, which skips citations that follow a shortened one.
 *
 * @param content - Markdown text that may contain citation flags.
 * @param sources - Object holding the source nodes to resolve ids against.
 *   Only `nodes[].id` is read, so it is typed structurally; `SourceData`
 *   values can be passed as-is.
 * @returns The content with every citation flag resolved or removed.
 */
const preprocessCitations = (
  content: string,
  sources?: { nodes: ReadonlyArray<{ id: string }> },
) => {
  if (!sources) return content;

  // Map each node id to the index of its first occurrence.
  const indexById = new Map<string, number>();
  sources.nodes.forEach((node, index) => {
    if (!indexById.has(node.id)) indexById.set(node.id, index);
  });

  return content.replace(
    /\[citation:(.+?)\]\(\)/g,
    (_match: string, citationId: string) => {
      const index = indexById.get(citationId);
      // Unknown ids are dropped instead of leaving a dangling reference.
      return index === undefined ? "" : `[citation:${index}]()`;
    },
  );
};
/**
 * Collect the SOURCES annotations of a message and normalize the nodes of
 * the first one (see `preprocessSourceNodes`).
 *
 * The annotation objects passed in are left untouched: the first entry is
 * returned as a copy carrying the processed nodes. Reassigning `nodes` on the
 * caller's annotation would silently rewrite message data on every call.
 *
 * @param annotations - All annotations attached to a message.
 * @returns The source annotations; only the first entry has its nodes
 *   processed, any further entries are returned as-is.
 */
export function getSourceAnnotationData(
  annotations: MessageAnnotation[],
): SourceData[] {
  const data = getAnnotationData<SourceData>(
    annotations,
    MessageAnnotationType.SOURCES,
  );
  const [first, ...rest] = data;
  // Nothing to normalize when there is no source annotation or it has no nodes.
  if (!first?.nodes) return data;
  return [{ ...first, nodes: preprocessSourceNodes(first.nodes) }, ...rest];
}

/**
 * Keep the nodes worth displaying, best score first.
 *
 * Drops nodes whose score is not above NODE_SCORE_THRESHOLD (a missing score
 * counts as 1) and nodes without a valid URL, sorts by descending score, and
 * strips one trailing slash from each URL. Returns new node objects; neither
 * the input array nor its nodes are modified.
 */
function preprocessSourceNodes(nodes: SourceNode[]): SourceNode[] {
  return nodes
    .filter((node) => (node.score ?? 1) > NODE_SCORE_THRESHOLD)
    .filter((node) => isValidUrl(node.url))
    .sort((a, b) => (b.score ?? 1) - (a.score ?? 1)) // sorts the filtered copy
    .map((node) => ({ ...node, url: node.url.replace(/\/$/, "") }));
}