From 0139a1149364686bffe3413c8c162fb9c3144711 Mon Sep 17 00:00:00 2001 From: leehuwuj <leehuwuj@gmail.com> Date: Tue, 18 Feb 2025 18:09:42 +0700 Subject: [PATCH] support source nodes --- .../multiagent/python/app/api/routers/chat.py | 2 + .../fastapi/app/api/callbacks/source_nodes.py | 70 +++++++++++++++++++ .../streaming/fastapi/app/api/routers/chat.py | 2 + .../ui/chat/chat-message-content.tsx | 5 +- .../components/ui/chat/tools/query-index.tsx | 43 +++++++++++- templates/types/streaming/nextjs/package.json | 3 +- 6 files changed, 120 insertions(+), 5 deletions(-) create mode 100644 templates/types/streaming/fastapi/app/api/callbacks/source_nodes.py diff --git a/templates/components/multiagent/python/app/api/routers/chat.py b/templates/components/multiagent/python/app/api/routers/chat.py index d7a44e69..37c30530 100644 --- a/templates/components/multiagent/python/app/api/routers/chat.py +++ b/templates/components/multiagent/python/app/api/routers/chat.py @@ -5,6 +5,7 @@ from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status from app.api.callbacks.llamacloud import LlamaCloudFileDownload from app.api.callbacks.next_question import SuggestNextQuestions from app.api.callbacks.stream_handler import StreamHandler +from app.api.callbacks.source_nodes import AddNodeUrl from app.api.routers.models import ( ChatData, ) @@ -45,6 +46,7 @@ async def chat( callbacks=[ LlamaCloudFileDownload.from_default(background_tasks), SuggestNextQuestions.from_default(data), + AddNodeUrl.from_default(), ], ).vercel_stream() except Exception as e: diff --git a/templates/types/streaming/fastapi/app/api/callbacks/source_nodes.py b/templates/types/streaming/fastapi/app/api/callbacks/source_nodes.py new file mode 100644 index 00000000..01a35108 --- /dev/null +++ b/templates/types/streaming/fastapi/app/api/callbacks/source_nodes.py @@ -0,0 +1,70 @@ +import logging +import os +from typing import Any, Dict, Optional + +from app.api.callbacks.base import EventCallback +from 
app.config import DATA_DIR +from llama_index.core.agent.workflow.workflow_events import ToolCallResult + +logger = logging.getLogger("uvicorn") + + +class AddNodeUrl(EventCallback): + """ + Add URL to source nodes + """ + + async def run(self, event: Any) -> Any: + if self._is_retrieval_result_event(event): + for node_score in event.tool_output.raw_output.source_nodes: + node_score.node.metadata["url"] = self._get_url_from_metadata( + node_score.node.metadata + ) + return event + + def _is_retrieval_result_event(self, event: Any) -> bool: + if isinstance(event, ToolCallResult): + if event.tool_name == "query_index": + return True + return False + + def _get_url_from_metadata(self, metadata: Dict[str, Any]) -> Optional[str]: + url_prefix = os.getenv("FILESERVER_URL_PREFIX") + if not url_prefix: + logger.warning( + "Warning: FILESERVER_URL_PREFIX not set in environment variables. Can't use file server" + ) + file_name = metadata.get("file_name") + + if file_name and url_prefix: + # file_name exists and file server is configured + pipeline_id = metadata.get("pipeline_id") + if pipeline_id: + # file is from LlamaCloud + file_name = f"{pipeline_id}${file_name}" + return f"{url_prefix}/output/llamacloud/{file_name}" + is_private = metadata.get("private", "false") == "true" + if is_private: + # file is a private upload + return f"{url_prefix}/output/uploaded/{file_name}" + # file is from calling the 'generate' script + # Get the relative path of file_path to data_dir + file_path = metadata.get("file_path") + data_dir = os.path.abspath(DATA_DIR) + if file_path and data_dir: + relative_path = os.path.relpath(file_path, data_dir) + return f"{url_prefix}/data/{relative_path}" + # fallback to URL in metadata (e.g. 
for websites) + return metadata.get("URL") + + def convert_to_source_nodes(self, event: Any) -> Any: + if self._is_retrieval_result_event(event): + for node_score in event.tool_output.raw_output.source_nodes: + node_score.node.metadata["url"] = self._get_url_from_metadata( + node_score.node.metadata + ) + return event + + @classmethod + def from_default(cls) -> "AddNodeUrl": + return cls() diff --git a/templates/types/streaming/fastapi/app/api/routers/chat.py b/templates/types/streaming/fastapi/app/api/routers/chat.py index 8103a4b5..5094adec 100644 --- a/templates/types/streaming/fastapi/app/api/routers/chat.py +++ b/templates/types/streaming/fastapi/app/api/routers/chat.py @@ -7,6 +7,7 @@ from llama_index.core.llms import MessageRole from app.api.callbacks.llamacloud import LlamaCloudFileDownload from app.api.callbacks.next_question import SuggestNextQuestions +from app.api.callbacks.source_nodes import AddNodeUrl from app.api.callbacks.stream_handler import StreamHandler from app.api.routers.models import ( ChatData, @@ -49,6 +50,7 @@ async def chat( callbacks=[ LlamaCloudFileDownload.from_default(background_tasks), SuggestNextQuestions.from_default(data), + AddNodeUrl.from_default(), ], ).vercel_stream() except Exception as e: diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/chat-message-content.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/chat-message-content.tsx index 1eb12427..58507723 100644 --- a/templates/types/streaming/nextjs/app/components/ui/chat/chat-message-content.tsx +++ b/templates/types/streaming/nextjs/app/components/ui/chat/chat-message-content.tsx @@ -1,21 +1,22 @@ import { ChatMessage } from "@llamaindex/chat-ui"; import { DeepResearchCard } from "./custom/deep-research-card"; import { ToolAnnotations } from "./tools/chat-tools"; -import { RetrieverComponent } from "./tools/query-index"; +import { ChatSourcesComponent, RetrieverComponent } from "./tools/query-index"; import { WeatherToolComponent } from 
"./tools/weather-card"; export function ChatMessageContent() { return ( <ChatMessage.Content> <ChatMessage.Content.Event /> + <ChatMessage.Content.AgentEvent /> <RetrieverComponent /> <WeatherToolComponent /> - <ChatMessage.Content.AgentEvent /> <DeepResearchCard /> <ToolAnnotations /> <ChatMessage.Content.Image /> <ChatMessage.Content.Markdown /> <ChatMessage.Content.DocumentFile /> + <ChatSourcesComponent /> <ChatMessage.Content.Source /> <ChatMessage.Content.SuggestedQuestions /> </ChatMessage.Content> diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/tools/query-index.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/tools/query-index.tsx index 203fffc6..2897a019 100644 --- a/templates/types/streaming/nextjs/app/components/ui/chat/tools/query-index.tsx +++ b/templates/types/streaming/nextjs/app/components/ui/chat/tools/query-index.tsx @@ -1,7 +1,11 @@ "use client"; -import { getCustomAnnotation, useChatMessage } from "@llamaindex/chat-ui"; -import { ChatEvents } from "@llamaindex/chat-ui/widgets"; +import { + getCustomAnnotation, + SourceNode, + useChatMessage, +} from "@llamaindex/chat-ui"; +import { ChatEvents, ChatSources } from "@llamaindex/chat-ui/widgets"; import { useMemo } from "react"; import { z } from "zod"; @@ -79,3 +83,38 @@ export function RetrieverComponent() { </div> ); } + +/** + * Render the source nodes whenever we got query_index tool with output + */ +export function ChatSourcesComponent() { + const { message } = useChatMessage(); + + const queryIndexEvents = getCustomAnnotation<QueryIndex>( + message.annotations, + (annotation) => { + const result = QueryIndexSchema.safeParse(annotation); + return result.success && !!result.data.tool_output; + }, + ); + + const sources: SourceNode[] = useMemo(() => { + return ( + queryIndexEvents?.flatMap((event) => { + const sourceNodes = + (event.tool_output?.raw_output?.source_nodes as any[]) || []; + return sourceNodes.map((node) => { + return { + id: node.node.id_, + 
metadata: node.node.metadata, + score: node.score, + text: node.node.text, + url: node.node.metadata.url, + }; + }); + }) || [] + ); + }, [queryIndexEvents]); + + return <ChatSources data={{ nodes: sources }} />; +} diff --git a/templates/types/streaming/nextjs/package.json b/templates/types/streaming/nextjs/package.json index 355fd4df..5de6e3a4 100644 --- a/templates/types/streaming/nextjs/package.json +++ b/templates/types/streaming/nextjs/package.json @@ -37,7 +37,8 @@ "tiktoken": "^1.0.15", "uuid": "^9.0.1", "marked": "^14.1.2", - "wikipedia": "^2.1.2" + "wikipedia": "^2.1.2", + "zod": "^3.24.2" }, "devDependencies": { "@types/node": "^20.10.3", -- GitLab