Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • mirrored_repos/machinelearning/run-llama/create-llama
1 result
Show changes
Commits on Source (4)
......@@ -24,7 +24,7 @@ class AddNodeUrl(EventCallback):
def _is_retrieval_result_event(self, event: Any) -> bool:
if isinstance(event, ToolCallResult):
if event.tool_name == "query_index":
if event.tool_name == "query_engine":
return True
return False
......
......@@ -64,7 +64,7 @@ def get_query_engine_tool(
description (optional): The description of the tool.
"""
if name is None:
name = "query_index"
name = "query_engine"
if description is None:
description = (
"Use this tool to retrieve information about the text corpus from an index."
......
......@@ -9,36 +9,72 @@ import { ChatEvents, ChatSources } from "@llamaindex/chat-ui/widgets";
import { useMemo } from "react";
import { z } from "zod";
const QueryIndexSchema = z.object({
tool_name: z.literal("query_index"),
tool_kwargs: z.object({
input: z.string(),
type QueryIndex = {
toolName: "query_engine";
toolKwargs: {
query: string;
};
toolId: string;
toolOutput?: {
id: string;
result: string;
isError: boolean;
};
returnDirect: boolean;
};
const TypeScriptSchema = z.object({
toolName: z.literal("query_engine"),
toolKwargs: z.object({
query: z.string(),
}),
tool_id: z.string(),
tool_output: z.optional(
z.object({
content: z.string(),
tool_name: z.string(),
raw_output: z.object({
source_nodes: z.array(
z.object({
node: z.object({
id_: z.string(),
metadata: z.object({
url: z.string(),
}),
text: z.string(),
}),
score: z.number(),
}),
),
}),
is_error: z.boolean().optional(),
}),
),
return_direct: z.boolean().optional(),
toolId: z.string(),
toolOutput: z
.object({
id: z.string(),
result: z.string(),
isError: z.boolean(),
})
.optional(),
returnDirect: z.boolean(),
});
type QueryIndex = z.infer<typeof QueryIndexSchema>;
const PythonSchema = z
.object({
tool_name: z.literal("query_engine"),
tool_kwargs: z.object({
input: z.string(),
}),
tool_id: z.string(),
tool_output: z
.object({
content: z.string(),
tool_name: z.string(),
raw_output: z.object({
source_nodes: z.array(z.any()),
}),
is_error: z.boolean().optional(),
})
.optional(),
return_direct: z.boolean().optional(),
})
.transform((data): QueryIndex => {
return {
toolName: data.tool_name,
toolKwargs: {
query: data.tool_kwargs.input,
},
toolId: data.tool_id,
toolOutput: data.tool_output
? {
id: data.tool_id,
result: data.tool_output.content,
isError: data.tool_output.is_error || false,
}
: undefined,
returnDirect: data.return_direct || false,
};
});
type GroupedIndexQuery = {
initial: QueryIndex;
......@@ -51,19 +87,22 @@ export function RetrieverComponent() {
const queryIndexEvents = getCustomAnnotation<QueryIndex>(
message.annotations,
(annotation) => {
const result = QueryIndexSchema.safeParse(annotation);
return result.success;
const schema = "toolName" in annotation ? TypeScriptSchema : PythonSchema;
const result = schema.safeParse(annotation);
if (!result.success) return false;
// If the schema has transformed the annotation, replace the original
// annotation with the transformed data
Object.assign(annotation, result.data);
return true;
},
);
// Group events by tool_id and render them in a single ChatEvents component
const groupedIndexQueries = useMemo(() => {
const groups = new Map<string, GroupedIndexQuery>();
queryIndexEvents?.forEach((event) => {
groups.set(event.tool_id, { initial: event });
groups.set(event.toolId, { initial: event });
});
return Array.from(groups.values());
}, [queryIndexEvents]);
......@@ -73,21 +112,21 @@ export function RetrieverComponent() {
{groupedIndexQueries.map(({ initial }) => {
const eventData = [
{
title: `Searching index with query: ${initial.tool_kwargs.input}`,
title: `Searching index with query: ${initial.toolKwargs.query}`,
},
];
if (initial.tool_output) {
if (initial.toolOutput) {
eventData.push({
title: `Got ${JSON.stringify(initial.tool_output?.raw_output.source_nodes?.length ?? 0)} sources for query: ${initial.tool_kwargs.input}`,
title: `Got result for query: ${initial.toolKwargs.query}`,
});
}
return (
<ChatEvents
key={initial.tool_id}
key={initial.toolId}
data={eventData}
showLoading={!initial.tool_output}
showLoading={!initial.toolOutput}
/>
);
})}
......@@ -96,35 +135,24 @@ export function RetrieverComponent() {
);
}
/**
* Render the source nodes whenever we got query_index tool with output
*/
export function ChatSourcesComponent() {
const { message } = useChatMessage();
const queryIndexEvents = getCustomAnnotation<QueryIndex>(
message.annotations,
(annotation) => {
const result = QueryIndexSchema.safeParse(annotation);
return result.success && !!result.data.tool_output;
const schema = "toolName" in annotation ? TypeScriptSchema : PythonSchema;
const result = schema.safeParse(annotation);
if (!result.success) return false;
// If the schema has transformed the annotation, replace the original
Object.assign(annotation, result.data);
return !!result.data.toolOutput;
},
);
const sources: SourceNode[] = useMemo(() => {
return (
queryIndexEvents?.flatMap((event) => {
const sourceNodes = event.tool_output?.raw_output?.source_nodes || [];
return sourceNodes.map((node) => {
return {
id: node.node.id_,
metadata: node.node.metadata,
score: node.score,
text: node.node.text,
url: node.node.metadata.url,
};
});
}) || []
);
return []; // TypeScript format doesn't use source nodes
}, [queryIndexEvents]);
return <ChatSources data={{ nodes: sources }} />;
......