Skip to content
Snippets Groups Projects
Commit 0139a114 authored by leehuwuj's avatar leehuwuj
Browse files

support source nodes

parent 7e23d779
No related branches found
No related tags found
No related merge requests found
......@@ -5,6 +5,7 @@ from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status
from app.api.callbacks.llamacloud import LlamaCloudFileDownload
from app.api.callbacks.next_question import SuggestNextQuestions
from app.api.callbacks.stream_handler import StreamHandler
from app.api.callbacks.add_node_url import AddNodeUrl
from app.api.routers.models import (
ChatData,
)
......@@ -45,6 +46,7 @@ async def chat(
callbacks=[
LlamaCloudFileDownload.from_default(background_tasks),
SuggestNextQuestions.from_default(data),
AddNodeUrl.from_default(),
],
).vercel_stream()
except Exception as e:
......
import logging
import os
from typing import Any, Dict, Optional
from app.api.callbacks.base import EventCallback
from app.config import DATA_DIR
from llama_index.core.agent.workflow.workflow_events import ToolCallResult
logger = logging.getLogger("uvicorn")
class AddNodeUrl(EventCallback):
    """Callback that enriches retrieved source nodes with a download URL.

    When a ``query_index`` tool call finishes, each returned source node
    gets a ``url`` entry added to its metadata so the frontend can link
    to the originating file (local data dir, private upload, LlamaCloud
    pipeline output, or a website URL already present in metadata).
    """

    async def run(self, event: Any) -> Any:
        """Attach URLs to the source nodes of a retrieval result event.

        Non-retrieval events are passed through unchanged.
        """
        return self._add_urls_to_source_nodes(event)

    def _add_urls_to_source_nodes(self, event: Any) -> Any:
        # Shared implementation for `run` and `convert_to_source_nodes`
        # (previously duplicated verbatim in both methods).
        if self._is_retrieval_result_event(event):
            for node_score in event.tool_output.raw_output.source_nodes:
                node_score.node.metadata["url"] = self._get_url_from_metadata(
                    node_score.node.metadata
                )
        return event

    def _is_retrieval_result_event(self, event: Any) -> bool:
        """Return True if *event* is the result of the `query_index` tool."""
        return isinstance(event, ToolCallResult) and event.tool_name == "query_index"

    def _get_url_from_metadata(self, metadata: Dict[str, Any]) -> Optional[str]:
        """Derive a file-server URL from node metadata.

        Resolution order:
        1. LlamaCloud pipeline files -> /output/llamacloud/<pipeline_id>$<file>
        2. Private uploads            -> /output/uploaded/<file>
        3. Files under DATA_DIR       -> /data/<relative path>
        4. Fallback: a `URL` key already present in metadata (e.g. websites).

        Returns None when no URL can be determined.
        """
        url_prefix = os.getenv("FILESERVER_URL_PREFIX")
        if not url_prefix:
            logger.warning(
                "Warning: FILESERVER_URL_PREFIX not set in environment variables. Can't use file server"
            )
        file_name = metadata.get("file_name")

        if file_name and url_prefix:
            # file_name exists and file server is configured
            pipeline_id = metadata.get("pipeline_id")
            if pipeline_id:
                # file is from LlamaCloud
                file_name = f"{pipeline_id}${file_name}"
                return f"{url_prefix}/output/llamacloud/{file_name}"
            is_private = metadata.get("private", "false") == "true"
            if is_private:
                # file is a private upload
                return f"{url_prefix}/output/uploaded/{file_name}"
            # file is from calling the 'generate' script;
            # build URL from the path of the file relative to DATA_DIR
            file_path = metadata.get("file_path")
            data_dir = os.path.abspath(DATA_DIR)
            if file_path and data_dir:
                relative_path = os.path.relpath(file_path, data_dir)
                return f"{url_prefix}/data/{relative_path}"
        # fallback to URL in metadata (e.g. for websites)
        return metadata.get("URL")

    def convert_to_source_nodes(self, event: Any) -> Any:
        """Deprecated alias kept for backward compatibility; same as `run`."""
        return self._add_urls_to_source_nodes(event)

    @classmethod
    def from_default(cls) -> "AddNodeUrl":
        """Build a callback instance with default configuration."""
        return cls()
......@@ -7,6 +7,7 @@ from llama_index.core.llms import MessageRole
from app.api.callbacks.llamacloud import LlamaCloudFileDownload
from app.api.callbacks.next_question import SuggestNextQuestions
from app.api.callbacks.source_nodes import AddNodeUrl
from app.api.callbacks.stream_handler import StreamHandler
from app.api.routers.models import (
ChatData,
......@@ -49,6 +50,7 @@ async def chat(
callbacks=[
LlamaCloudFileDownload.from_default(background_tasks),
SuggestNextQuestions.from_default(data),
AddNodeUrl.from_default(),
],
).vercel_stream()
except Exception as e:
......
import { ChatMessage } from "@llamaindex/chat-ui";
import { DeepResearchCard } from "./custom/deep-research-card";
import { ToolAnnotations } from "./tools/chat-tools";
import { RetrieverComponent } from "./tools/query-index";
import { ChatSourcesComponent, RetrieverComponent } from "./tools/query-index";
import { WeatherToolComponent } from "./tools/weather-card";
export function ChatMessageContent() {
return (
<ChatMessage.Content>
<ChatMessage.Content.Event />
<ChatMessage.Content.AgentEvent />
<RetrieverComponent />
<WeatherToolComponent />
<ChatMessage.Content.AgentEvent />
<DeepResearchCard />
<ToolAnnotations />
<ChatMessage.Content.Image />
<ChatMessage.Content.Markdown />
<ChatMessage.Content.DocumentFile />
<ChatSourcesComponent />
<ChatMessage.Content.Source />
<ChatMessage.Content.SuggestedQuestions />
</ChatMessage.Content>
......
"use client";
import { getCustomAnnotation, useChatMessage } from "@llamaindex/chat-ui";
import { ChatEvents } from "@llamaindex/chat-ui/widgets";
import {
getCustomAnnotation,
SourceNode,
useChatMessage,
} from "@llamaindex/chat-ui";
import { ChatEvents, ChatSources } from "@llamaindex/chat-ui/widgets";
import { useMemo } from "react";
import { z } from "zod";
......@@ -79,3 +83,38 @@ export function RetrieverComponent() {
</div>
);
}
/**
 * Render the source nodes whenever we got query_index tool with output
 */
export function ChatSourcesComponent() {
  const { message } = useChatMessage();

  // Only keep query_index annotations that actually carry tool output.
  const queryIndexEvents = getCustomAnnotation<QueryIndex>(
    message.annotations,
    (annotation) => {
      const parsed = QueryIndexSchema.safeParse(annotation);
      return parsed.success && !!parsed.data.tool_output;
    },
  );

  const sources: SourceNode[] = useMemo(() => {
    if (!queryIndexEvents) return [];
    const collected: SourceNode[] = [];
    for (const event of queryIndexEvents) {
      const scoredNodes =
        (event.tool_output?.raw_output?.source_nodes as any[]) || [];
      for (const scored of scoredNodes) {
        collected.push({
          id: scored.node.id_,
          metadata: scored.node.metadata,
          score: scored.score,
          text: scored.node.text,
          url: scored.node.metadata.url,
        });
      }
    }
    return collected;
  }, [queryIndexEvents]);

  return <ChatSources data={{ nodes: sources }} />;
}
......@@ -37,7 +37,8 @@
"tiktoken": "^1.0.15",
"uuid": "^9.0.1",
"marked": "^14.1.2",
"wikipedia": "^2.1.2"
"wikipedia": "^2.1.2",
"zod": "^3.24.2"
},
"devDependencies": {
"@types/node": "^20.10.3",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment