Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • mirrored_repos/machinelearning/run-llama/create-llama
1 result
Show changes
Commits on Source (4)
...@@ -24,7 +24,7 @@ class AddNodeUrl(EventCallback): ...@@ -24,7 +24,7 @@ class AddNodeUrl(EventCallback):
def _is_retrieval_result_event(self, event: Any) -> bool: def _is_retrieval_result_event(self, event: Any) -> bool:
if isinstance(event, ToolCallResult): if isinstance(event, ToolCallResult):
if event.tool_name == "query_index": if event.tool_name == "query_engine":
return True return True
return False return False
......
...@@ -64,7 +64,7 @@ def get_query_engine_tool( ...@@ -64,7 +64,7 @@ def get_query_engine_tool(
description (optional): The description of the tool. description (optional): The description of the tool.
""" """
if name is None: if name is None:
name = "query_index" name = "query_engine"
if description is None: if description is None:
description = ( description = (
"Use this tool to retrieve information about the text corpus from an index." "Use this tool to retrieve information about the text corpus from an index."
......
...@@ -9,36 +9,72 @@ import { ChatEvents, ChatSources } from "@llamaindex/chat-ui/widgets"; ...@@ -9,36 +9,72 @@ import { ChatEvents, ChatSources } from "@llamaindex/chat-ui/widgets";
import { useMemo } from "react"; import { useMemo } from "react";
import { z } from "zod"; import { z } from "zod";
const QueryIndexSchema = z.object({ type QueryIndex = {
tool_name: z.literal("query_index"), toolName: "query_engine";
tool_kwargs: z.object({ toolKwargs: {
input: z.string(), query: string;
};
toolId: string;
toolOutput?: {
id: string;
result: string;
isError: boolean;
};
returnDirect: boolean;
};
const TypeScriptSchema = z.object({
toolName: z.literal("query_engine"),
toolKwargs: z.object({
query: z.string(),
}), }),
tool_id: z.string(), toolId: z.string(),
tool_output: z.optional( toolOutput: z
z.object({ .object({
content: z.string(), id: z.string(),
tool_name: z.string(), result: z.string(),
raw_output: z.object({ isError: z.boolean(),
source_nodes: z.array( })
z.object({ .optional(),
node: z.object({ returnDirect: z.boolean(),
id_: z.string(),
metadata: z.object({
url: z.string(),
}),
text: z.string(),
}),
score: z.number(),
}),
),
}),
is_error: z.boolean().optional(),
}),
),
return_direct: z.boolean().optional(),
}); });
type QueryIndex = z.infer<typeof QueryIndexSchema>;
const PythonSchema = z
.object({
tool_name: z.literal("query_engine"),
tool_kwargs: z.object({
input: z.string(),
}),
tool_id: z.string(),
tool_output: z
.object({
content: z.string(),
tool_name: z.string(),
raw_output: z.object({
source_nodes: z.array(z.any()),
}),
is_error: z.boolean().optional(),
})
.optional(),
return_direct: z.boolean().optional(),
})
.transform((data): QueryIndex => {
return {
toolName: data.tool_name,
toolKwargs: {
query: data.tool_kwargs.input,
},
toolId: data.tool_id,
toolOutput: data.tool_output
? {
id: data.tool_id,
result: data.tool_output.content,
isError: data.tool_output.is_error || false,
}
: undefined,
returnDirect: data.return_direct || false,
};
});
type GroupedIndexQuery = { type GroupedIndexQuery = {
initial: QueryIndex; initial: QueryIndex;
...@@ -51,19 +87,22 @@ export function RetrieverComponent() { ...@@ -51,19 +87,22 @@ export function RetrieverComponent() {
const queryIndexEvents = getCustomAnnotation<QueryIndex>( const queryIndexEvents = getCustomAnnotation<QueryIndex>(
message.annotations, message.annotations,
(annotation) => { (annotation) => {
const result = QueryIndexSchema.safeParse(annotation); const schema = "toolName" in annotation ? TypeScriptSchema : PythonSchema;
return result.success; const result = schema.safeParse(annotation);
if (!result.success) return false;
// If the schema has transformed the annotation, replace the original
// annotation with the transformed data
Object.assign(annotation, result.data);
return true;
}, },
); );
// Group events by tool_id and render them in a single ChatEvents component
const groupedIndexQueries = useMemo(() => { const groupedIndexQueries = useMemo(() => {
const groups = new Map<string, GroupedIndexQuery>(); const groups = new Map<string, GroupedIndexQuery>();
queryIndexEvents?.forEach((event) => { queryIndexEvents?.forEach((event) => {
groups.set(event.tool_id, { initial: event }); groups.set(event.toolId, { initial: event });
}); });
return Array.from(groups.values()); return Array.from(groups.values());
}, [queryIndexEvents]); }, [queryIndexEvents]);
...@@ -73,21 +112,21 @@ export function RetrieverComponent() { ...@@ -73,21 +112,21 @@ export function RetrieverComponent() {
{groupedIndexQueries.map(({ initial }) => { {groupedIndexQueries.map(({ initial }) => {
const eventData = [ const eventData = [
{ {
title: `Searching index with query: ${initial.tool_kwargs.input}`, title: `Searching index with query: ${initial.toolKwargs.query}`,
}, },
]; ];
if (initial.tool_output) { if (initial.toolOutput) {
eventData.push({ eventData.push({
title: `Got ${JSON.stringify(initial.tool_output?.raw_output.source_nodes?.length ?? 0)} sources for query: ${initial.tool_kwargs.input}`, title: `Got result for query: ${initial.toolKwargs.query}`,
}); });
} }
return ( return (
<ChatEvents <ChatEvents
key={initial.tool_id} key={initial.toolId}
data={eventData} data={eventData}
showLoading={!initial.tool_output} showLoading={!initial.toolOutput}
/> />
); );
})} })}
...@@ -96,35 +135,24 @@ export function RetrieverComponent() { ...@@ -96,35 +135,24 @@ export function RetrieverComponent() {
); );
} }
/**
* Render the source nodes whenever we got query_index tool with output
*/
export function ChatSourcesComponent() { export function ChatSourcesComponent() {
const { message } = useChatMessage(); const { message } = useChatMessage();
const queryIndexEvents = getCustomAnnotation<QueryIndex>( const queryIndexEvents = getCustomAnnotation<QueryIndex>(
message.annotations, message.annotations,
(annotation) => { (annotation) => {
const result = QueryIndexSchema.safeParse(annotation); const schema = "toolName" in annotation ? TypeScriptSchema : PythonSchema;
return result.success && !!result.data.tool_output; const result = schema.safeParse(annotation);
if (!result.success) return false;
// If the schema has transformed the annotation, replace the original
Object.assign(annotation, result.data);
return !!result.data.toolOutput;
}, },
); );
const sources: SourceNode[] = useMemo(() => { const sources: SourceNode[] = useMemo(() => {
return ( return []; // TypeScript format doesn't use source nodes
queryIndexEvents?.flatMap((event) => {
const sourceNodes = event.tool_output?.raw_output?.source_nodes || [];
return sourceNodes.map((node) => {
return {
id: node.node.id_,
metadata: node.node.metadata,
score: node.score,
text: node.node.text,
url: node.node.metadata.url,
};
});
}) || []
);
}, [queryIndexEvents]); }, [queryIndexEvents]);
return <ChatSources data={{ nodes: sources }} />; return <ChatSources data={{ nodes: sources }} />;
......