diff --git a/.changeset/poor-eels-beam.md b/.changeset/poor-eels-beam.md new file mode 100644 index 0000000000000000000000000000000000000000..ca8a49b5b1fff6c86e32cec4f66bbe439e50b2f6 --- /dev/null +++ b/.changeset/poor-eels-beam.md @@ -0,0 +1,5 @@ +--- +"create-llama": patch +--- + +Add support for displaying tool outputs (including weather widget as example) diff --git a/helpers/python.ts b/helpers/python.ts index 9141cd6292e2d9014d01f7381f6066cd7c74687f..51d4802f4f1937878930ff7882a6435814472ad9 100644 --- a/helpers/python.ts +++ b/helpers/python.ts @@ -24,7 +24,7 @@ interface Dependency { const getAdditionalDependencies = ( modelConfig: ModelConfig, vectorDb?: TemplateVectorDB, - dataSource?: TemplateDataSource, + dataSources?: TemplateDataSource[], tools?: Tool[], ) => { const dependencies: Dependency[] = []; @@ -73,38 +73,43 @@ const getAdditionalDependencies = ( } // Add data source dependencies - const dataSourceType = dataSource?.type; - switch (dataSourceType) { - case "file": - dependencies.push({ - name: "docx2txt", - version: "^0.8", - }); - break; - case "web": - dependencies.push({ - name: "llama-index-readers-web", - version: "^0.1.6", - }); - break; - case "db": - dependencies.push({ - name: "llama-index-readers-database", - version: "^0.1.3", - }); - dependencies.push({ - name: "pymysql", - version: "^1.1.0", - extras: ["rsa"], - }); - dependencies.push({ - name: "psycopg2", - version: "^2.9.9", - }); - break; + if (dataSources) { + for (const ds of dataSources) { + const dsType = ds?.type; + switch (dsType) { + case "file": + dependencies.push({ + name: "docx2txt", + version: "^0.8", + }); + break; + case "web": + dependencies.push({ + name: "llama-index-readers-web", + version: "^0.1.6", + }); + break; + case "db": + dependencies.push({ + name: "llama-index-readers-database", + version: "^0.1.3", + }); + dependencies.push({ + name: "pymysql", + version: "^1.1.0", + extras: ["rsa"], + }); + dependencies.push({ + name: "psycopg2", + version: "^2.9.9", + }); + break; + } + } } // Add tools dependencies + console.log("Adding tools dependencies"); tools?.forEach((tool) => { tool.dependencies?.forEach((dep) => { dependencies.push(dep); @@ -299,9 +304,14 @@ export const installPythonTemplate = async ({ cwd: path.join(compPath, "engines", "python", engine), }); - const addOnDependencies = dataSources - .map((ds) => getAdditionalDependencies(modelConfig, vectorDb, ds, tools)) - .flat(); + console.log("Adding additional dependencies"); + + const addOnDependencies = getAdditionalDependencies( + modelConfig, + vectorDb, + dataSources, + tools, + ); if (observability === "opentelemetry") { addOnDependencies.push({ diff --git a/helpers/tools.ts b/helpers/tools.ts index d8b3967db97336ca95d9801dc3b5559aa7a8b19a..95890d3982f0f7af1784bcf032b21b159dd59c02 100644 --- a/helpers/tools.ts +++ b/helpers/tools.ts @@ -5,12 +5,18 @@ import yaml from "yaml"; import { makeDir } from "./make-dir"; import { TemplateFramework } from "./types"; +export enum ToolType { + LLAMAHUB = "llamahub", + LOCAL = "local", +} + export type Tool = { display: string; name: string; config?: Record<string, any>; dependencies?: ToolDependencies[]; supportedFrameworks?: Array<TemplateFramework>; + type: ToolType; }; export type ToolDependencies = { @@ -35,6 +41,7 @@ export const supportedTools: Tool[] = [ }, ], supportedFrameworks: ["fastapi"], + type: ToolType.LLAMAHUB, }, { display: "Wikipedia", @@ -46,6 +53,14 @@ export const supportedTools: Tool[] = [ }, ], supportedFrameworks: ["fastapi", "express", "nextjs"], + type: ToolType.LLAMAHUB, + }, + { + display: "Weather", + name: "weather", + dependencies: [], + supportedFrameworks: ["fastapi", "express", "nextjs"], + type: ToolType.LOCAL, }, ]; @@ -90,9 +105,19 @@ export const writeToolsConfig = async ( type: ConfigFileType = ConfigFileType.YAML, ) => { if (tools.length === 0) return; // no tools selected, no config need - const configContent: Record<string, any> = {}; + const configContent: { + [key in ToolType]: Record<string, any>; + } = { + local: {}, + llamahub: {}, + }; tools.forEach((tool) => { - configContent[tool.name] = tool.config ?? {}; + if (tool.type === ToolType.LLAMAHUB) { + configContent.llamahub[tool.name] = tool.config ?? {}; + } + if (tool.type === ToolType.LOCAL) { + configContent.local[tool.name] = tool.config ?? {}; + } }); const configPath = path.join(root, "config"); await makeDir(configPath); diff --git a/helpers/typescript.ts b/helpers/typescript.ts index 244403520b96b53e1874267a3af21fabd7c50e1e..910bf36b066ebf4d20ff19caaf8dee141eab4110 100644 --- a/helpers/typescript.ts +++ b/helpers/typescript.ts @@ -105,7 +105,7 @@ export const installTSTemplate = async ({ const enginePath = path.join(root, relativeEngineDestPath, "engine"); // copy vector db component - console.log("\nUsing vector DB:", vectorDb, "\n"); + console.log("\nUsing vector DB:", vectorDb ?? "none", "\n"); await copy("**", enginePath, { parents: true, cwd: path.join(compPath, "vectordbs", "typescript", vectorDb ?? "none"), diff --git a/templates/components/engines/python/agent/tools.py b/templates/components/engines/python/agent/tools.py deleted file mode 100644 index 584947e873c70f3ebed76ff7b93528c0fd3a60eb..0000000000000000000000000000000000000000 --- a/templates/components/engines/python/agent/tools.py +++ /dev/null @@ -1,35 +0,0 @@ -import os -import yaml -import importlib - -from llama_index.core.tools.tool_spec.base import BaseToolSpec -from llama_index.core.tools.function_tool import FunctionTool - - -class ToolFactory: - - @staticmethod - def create_tool(tool_name: str, **kwargs) -> list[FunctionTool]: - try: - tool_package, tool_cls_name = tool_name.split(".") - module_name = f"llama_index.tools.{tool_package}" - module = importlib.import_module(module_name) - tool_class = getattr(module, tool_cls_name) - tool_spec: BaseToolSpec = tool_class(**kwargs) - return tool_spec.to_tool_list() - except (ImportError, AttributeError) as e: - raise ValueError(f"Unsupported tool: {tool_name}") from e - except TypeError as e: - raise ValueError( - f"Could not create tool: {tool_name}. With config: {kwargs}" - ) from e - - @staticmethod - def from_env() -> list[FunctionTool]: - tools = [] - if os.path.exists("config/tools.yaml"): - with open("config/tools.yaml", "r") as f: - tool_configs = yaml.safe_load(f) - for name, config in tool_configs.items(): - tools += ToolFactory.create_tool(name, **config) - return tools diff --git a/templates/components/engines/python/agent/tools/__init__.py b/templates/components/engines/python/agent/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d2e5da6942af4f01780cd3367366f4b0701363bb --- /dev/null +++ b/templates/components/engines/python/agent/tools/__init__.py @@ -0,0 +1,56 @@ +import os +import yaml +import importlib + +from llama_index.core.tools.tool_spec.base import BaseToolSpec +from llama_index.core.tools.function_tool import FunctionTool + + +class ToolType: + LLAMAHUB = "llamahub" + LOCAL = "local" + + +class ToolFactory: + + TOOL_SOURCE_PACKAGE_MAP = { + ToolType.LLAMAHUB: "llama_index.tools", + ToolType.LOCAL: "app.engine.tools", + } + + @staticmethod + def load_tools(tool_type: str, tool_name: str, config: dict) -> list[FunctionTool]: + source_package = ToolFactory.TOOL_SOURCE_PACKAGE_MAP[tool_type] + try: + if "ToolSpec" in tool_name: + tool_package, tool_cls_name = tool_name.split(".") + module_name = f"{source_package}.{tool_package}" + module = importlib.import_module(module_name) + tool_class = getattr(module, tool_cls_name) + tool_spec: BaseToolSpec = tool_class(**config) + return tool_spec.to_tool_list() + else: + module = importlib.import_module(f"{source_package}.{tool_name}") + tools = getattr(module, "tools") + if not all(isinstance(tool, FunctionTool) for tool in tools): + raise ValueError( + f"The module {module} does not contain valid tools" + ) + return tools + except ImportError as e: + raise ValueError(f"Failed to import tool {tool_name}: {e}") + except AttributeError as e: + raise ValueError(f"Failed to load tool {tool_name}: {e}") + + @staticmethod + def from_env() -> list[FunctionTool]: + tools = [] + if os.path.exists("config/tools.yaml"): + with open("config/tools.yaml", "r") as f: + tool_configs = yaml.safe_load(f) + for tool_type, config_entries in tool_configs.items(): + for tool_name, config in config_entries.items(): + tools.extend( + ToolFactory.load_tools(tool_type, tool_name, config) + ) + return tools diff --git a/templates/components/engines/python/agent/tools/weather.py b/templates/components/engines/python/agent/tools/weather.py new file mode 100644 index 0000000000000000000000000000000000000000..3ea0fc03fcdfd48443366833dd506d5025c74a19 --- /dev/null +++ b/templates/components/engines/python/agent/tools/weather.py @@ -0,0 +1,72 @@ +"""Open Meteo weather map tool spec.""" + +import logging +import requests +import pytz +from llama_index.core.tools import FunctionTool + +logger = logging.getLogger(__name__) + + +class OpenMeteoWeather: + geo_api = "https://geocoding-api.open-meteo.com/v1" + weather_api = "https://api.open-meteo.com/v1" + + @classmethod + def _get_geo_location(cls, location: str) -> dict: + """Get geo location from location name.""" + params = {"name": location, "count": 10, "language": "en", "format": "json"} + response = requests.get(f"{cls.geo_api}/search", params=params) + if response.status_code != 200: + raise Exception(f"Failed to fetch geo location: {response.status_code}") + else: + data = response.json() + result = data["results"][0] + geo_location = { + "id": result["id"], + "name": result["name"], + "latitude": result["latitude"], + "longitude": result["longitude"], + } + return geo_location + + @classmethod + def get_weather_information(cls, location: str) -> dict: + """Use this function to get the weather of any given location. + Note that the weather code should follow WMO Weather interpretation codes (WW): + 0: Clear sky + 1, 2, 3: Mainly clear, partly cloudy, and overcast + 45, 48: Fog and depositing rime fog + 51, 53, 55: Drizzle: Light, moderate, and dense intensity + 56, 57: Freezing Drizzle: Light and dense intensity + 61, 63, 65: Rain: Slight, moderate and heavy intensity + 66, 67: Freezing Rain: Light and heavy intensity + 71, 73, 75: Snow fall: Slight, moderate, and heavy intensity + 77: Snow grains + 80, 81, 82: Rain showers: Slight, moderate, and violent + 85, 86: Snow showers slight and heavy + 95: Thunderstorm: Slight or moderate + 96, 99: Thunderstorm with slight and heavy hail + """ + logger.info( + f"Calling open-meteo api to get weather information of location: {location}" + ) + geo_location = cls._get_geo_location(location) + timezone = pytz.timezone("UTC").zone + params = { + "latitude": geo_location["latitude"], + "longitude": geo_location["longitude"], + "current": "temperature_2m,weather_code", + "hourly": "temperature_2m,weather_code", + "daily": "weather_code", + "timezone": timezone, + } + response = requests.get(f"{cls.weather_api}/forecast", params=params) + if response.status_code != 200: + raise Exception( + f"Failed to fetch weather information: {response.status_code}" + ) + return response.json() + + +tools = [FunctionTool.from_defaults(OpenMeteoWeather.get_weather_information)] diff --git a/templates/components/engines/typescript/agent/chat.ts b/templates/components/engines/typescript/agent/chat.ts index 41d7118d915ee331138ea663ecd2f98d6a590931..9d16c8bda40e5fc4a9ec71c18054c599c17f145b 100644 --- a/templates/components/engines/typescript/agent/chat.ts +++ b/templates/components/engines/typescript/agent/chat.ts @@ -4,9 +4,10 @@ import fs from "node:fs/promises"; import path from "node:path"; import { getDataSource } from "./index"; import { STORAGE_CACHE_DIR } from "./shared"; +import { createLocalTools } from "./tools"; export async function createChatEngine() { - let tools: BaseToolWithCall[] = []; + const tools: BaseToolWithCall[] = []; // Add a query engine tool if we have a data source // Delete this code if you don't have a data source @@ -28,7 +29,14 @@ export async function createChatEngine() { const config = JSON.parse( await fs.readFile(path.join("config", "tools.json"), "utf8"), ); - tools = tools.concat(await ToolsFactory.createTools(config)); + + // add local tools from the 'tools' folder (if configured) + const localTools = createLocalTools(config.local); + tools.push(...localTools); + + // add tools from LlamaIndexTS (if configured) + const llamaTools = await ToolsFactory.createTools(config.llamahub); + tools.push(...llamaTools); } catch {} return new OpenAIAgent({ diff --git a/templates/components/engines/typescript/agent/tools/index.ts b/templates/components/engines/typescript/agent/tools/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..8d041a766f204317d84cffa12b0d5aa596722169 --- /dev/null +++ b/templates/components/engines/typescript/agent/tools/index.ts @@ -0,0 +1,26 @@ +import { BaseToolWithCall } from "llamaindex"; +import { WeatherTool, WeatherToolParams } from "./weather"; + +type ToolCreator = (config: unknown) => BaseToolWithCall; + +const toolFactory: Record<string, ToolCreator> = { + weather: (config: unknown) => { + return new WeatherTool(config as WeatherToolParams); + }, +}; + +export function createLocalTools( + localConfig: Record<string, unknown>, +): BaseToolWithCall[] { + const tools: BaseToolWithCall[] = []; + + Object.keys(localConfig).forEach((key) => { + if (key in toolFactory) { + const toolConfig = localConfig[key]; + const tool = toolFactory[key](toolConfig); + tools.push(tool); + } + }); + + return tools; +} diff --git a/templates/components/engines/typescript/agent/tools/weather.ts b/templates/components/engines/typescript/agent/tools/weather.ts new file mode 100644 index 0000000000000000000000000000000000000000..c1f601494ccde7566b662c5017da7863392253d9 --- /dev/null +++ b/templates/components/engines/typescript/agent/tools/weather.ts @@ -0,0 +1,81 @@ +import type { JSONSchemaType } from "ajv"; +import { BaseTool, ToolMetadata } from "llamaindex"; + +interface GeoLocation { + id: string; + name: string; + latitude: number; + longitude: number; +} + +export type WeatherParameter = { + location: string; +}; + +export type WeatherToolParams = { + metadata?: ToolMetadata<JSONSchemaType<WeatherParameter>>; +}; + +const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<WeatherParameter>> = { + name: "get_weather_information", + description: ` + Use this function to get the weather of any given location. + Note that the weather code should follow WMO Weather interpretation codes (WW): + 0: Clear sky + 1, 2, 3: Mainly clear, partly cloudy, and overcast + 45, 48: Fog and depositing rime fog + 51, 53, 55: Drizzle: Light, moderate, and dense intensity + 56, 57: Freezing Drizzle: Light and dense intensity + 61, 63, 65: Rain: Slight, moderate and heavy intensity + 66, 67: Freezing Rain: Light and heavy intensity + 71, 73, 75: Snow fall: Slight, moderate, and heavy intensity + 77: Snow grains + 80, 81, 82: Rain showers: Slight, moderate, and violent + 85, 86: Snow showers slight and heavy + 95: Thunderstorm: Slight or moderate + 96, 99: Thunderstorm with slight and heavy hail + `, + parameters: { + type: "object", + properties: { + location: { + type: "string", + description: "The location to get the weather information", + }, + }, + required: ["location"], + }, +}; + +export class WeatherTool implements BaseTool<WeatherParameter> { + metadata: ToolMetadata<JSONSchemaType<WeatherParameter>>; + + private getGeoLocation = async (location: string): Promise<GeoLocation> => { + const apiUrl = `https://geocoding-api.open-meteo.com/v1/search?name=${location}&count=10&language=en&format=json`; + const response = await fetch(apiUrl); + const data = await response.json(); + const { id, name, latitude, longitude } = data.results[0]; + return { id, name, latitude, longitude }; + }; + + private getWeatherByLocation = async (location: string) => { + console.log( + "Calling open-meteo api to get weather information of location:", + location, + ); + const { latitude, longitude } = await this.getGeoLocation(location); + const timezone = Intl.DateTimeFormat().resolvedOptions().timeZone; + const apiUrl = `https://api.open-meteo.com/v1/forecast?latitude=${latitude}&longitude=${longitude}¤t=temperature_2m,weather_code&hourly=temperature_2m,weather_code&daily=weather_code&timezone=${timezone}`; + const response = await fetch(apiUrl); + const data = await response.json(); + return data; + }; + + constructor(params?: WeatherToolParams) { + this.metadata = params?.metadata || DEFAULT_META_DATA; + } + + async call(input: WeatherParameter) { + return await this.getWeatherByLocation(input.location); + } +} diff --git a/templates/types/streaming/express/package.json b/templates/types/streaming/express/package.json index d59f18721f8a7e5cef5ff8fdaa1dfc1bab43f77e..f4473a6ffa9697f1a1e20594afdb44101bdb8c6a 100644 --- a/templates/types/streaming/express/package.json +++ b/templates/types/streaming/express/package.json @@ -14,8 +14,9 @@ "cors": "^2.8.5", "dotenv": "^16.3.1", "express": "^4.18.2", - "llamaindex": "0.3.3", - "pdf2json": "3.0.5" + "llamaindex": "0.3.7", + "pdf2json": "3.0.5", + "ajv": "^8.12.0" }, "devDependencies": { "@types/cors": "^2.8.16", diff --git a/templates/types/streaming/express/src/controllers/chat.controller.ts b/templates/types/streaming/express/src/controllers/chat.controller.ts index d9ba57dcab0ffcee1ed80a33ca9ecaa95e2c50c9..3de1ac6b83075702c20f4e1049c36f227b61455e 100644 --- a/templates/types/streaming/express/src/controllers/chat.controller.ts +++ b/templates/types/streaming/express/src/controllers/chat.controller.ts @@ -1,14 +1,9 @@ import { Message, StreamData, streamToResponse } from "ai"; import { Request, Response } from "express"; -import { - CallbackManager, - ChatMessage, - MessageContent, - Settings, -} from "llamaindex"; +import { ChatMessage, MessageContent, Settings } from "llamaindex"; import { createChatEngine } from "./engine/chat"; import { LlamaIndexStream } from "./llamaindex-stream"; -import { appendEventData } from "./stream-helper"; +import { createCallbackManager } from "./stream-helper"; const convertMessageContent = ( textMessage: string, @@ -52,18 +47,7 @@ export const chat = async (req: Request, res: Response) => { const vercelStreamData = new StreamData(); // Setup callbacks - const callbackManager = new CallbackManager(); - callbackManager.on("retrieve", (data) => { - const { nodes } = data.detail; - appendEventData( - vercelStreamData, - `Retrieving context for query: '${userMessage.content}'`, - ); - appendEventData( - vercelStreamData, - `Retrieved ${nodes.length} sources to use as context for the query`, - ); - }); + const callbackManager = createCallbackManager(vercelStreamData); // Calling LlamaIndex's ChatEngine to get a streamed response const response = await Settings.withCallbackManager(callbackManager, () => { diff --git a/templates/types/streaming/express/src/controllers/stream-helper.ts b/templates/types/streaming/express/src/controllers/stream-helper.ts index e74597b9a2a30e18b6812196b9e848484d025d24..9f1a8864bfa63e06e8e622a641075cd34ac604d8 100644 --- a/templates/types/streaming/express/src/controllers/stream-helper.ts +++ b/templates/types/streaming/express/src/controllers/stream-helper.ts @@ -1,5 +1,11 @@ import { StreamData } from "ai"; -import { Metadata, NodeWithScore } from "llamaindex"; +import { + CallbackManager, + Metadata, + NodeWithScore, + ToolCall, + ToolOutput, +} from "llamaindex"; export function appendImageData(data: StreamData, imageUrl?: string) { if (!imageUrl) return; @@ -37,3 +43,55 @@ export function appendEventData(data: StreamData, title?: string) { }, }); } + +export function appendToolData( + data: StreamData, + toolCall: ToolCall, + toolOutput: ToolOutput, +) { + data.appendMessageAnnotation({ + type: "tools", + data: { + toolCall: { + id: toolCall.id, + name: toolCall.name, + input: toolCall.input, + }, + toolOutput: { + output: toolOutput.output, + isError: toolOutput.isError, + }, + }, + }); +} + +export function createCallbackManager(stream: StreamData) { + const callbackManager = new CallbackManager(); + + callbackManager.on("retrieve", (data) => { + const { nodes, query } = data.detail; + appendEventData(stream, `Retrieving context for query: '${query}'`); + appendEventData( + stream, + `Retrieved ${nodes.length} sources to use as context for the query`, + ); + }); + + callbackManager.on("llm-tool-call", (event) => { + const { name, input } = event.detail.payload.toolCall; + const inputString = Object.entries(input) + .map(([key, value]) => `${key}: ${value}`) + .join(", "); + appendEventData( + stream, + `Using tool: '${name}' with inputs: '${inputString}'`, + ); + }); + + callbackManager.on("llm-tool-result", (event) => { + const { toolCall, toolResult } = event.detail.payload; + appendToolData(stream, toolCall, toolResult); + }); + + return callbackManager; +} diff --git a/templates/types/streaming/fastapi/app/api/routers/chat.py b/templates/types/streaming/fastapi/app/api/routers/chat.py index cd287f5fb5b9105abe264cf7eeead0aa6c91003c..c92ca3d425d3634b477122cb11847d09b3af4131 100644 --- a/templates/types/streaming/fastapi/app/api/routers/chat.py +++ b/templates/types/streaming/fastapi/app/api/routers/chat.py @@ -1,10 +1,7 @@ from pydantic import BaseModel from typing import List, Any, Optional, Dict, Tuple from fastapi import APIRouter, Depends, HTTPException, Request, status -from llama_index.core.chat_engine.types import ( - BaseChatEngine, - StreamingAgentChatResponse, -) +from llama_index.core.chat_engine.types import BaseChatEngine from llama_index.core.schema import NodeWithScore from llama_index.core.llms import ChatMessage, MessageRole from app.engine import get_chat_engine @@ -109,12 +106,9 @@ async def chat( # Yield the events from the event handler async def _event_generator(): async for event in event_handler.async_event_gen(): - yield VercelStreamResponse.convert_data( - { - "type": "events", - "data": {"title": event.get_title()}, - } - ) + event_response = event.to_response() + if event_response is not None: + yield VercelStreamResponse.convert_data(event_response) combine = stream.merge(_text_generator(), _event_generator()) async with combine.stream() as streamer: diff --git a/templates/types/streaming/fastapi/app/api/routers/messaging.py b/templates/types/streaming/fastapi/app/api/routers/messaging.py index 9c2a49057bfc614c23ae7d053aa3e13d34ebf5ae..a239657a8e0227287bf6496600d218d4b546c717 100644 --- a/templates/types/streaming/fastapi/app/api/routers/messaging.py +++ b/templates/types/streaming/fastapi/app/api/routers/messaging.py @@ -1,8 +1,9 @@ +import json import asyncio from typing import AsyncGenerator, Dict, Any, List, Optional - from llama_index.core.callbacks.base import BaseCallbackHandler from llama_index.core.callbacks.schema import CBEventType +from llama_index.core.tools.types import ToolOutput from pydantic import BaseModel @@ -11,19 +12,73 @@ class CallbackEvent(BaseModel): payload: Optional[Dict[str, Any]] = None event_id: str = "" - def get_title(self) -> str | None: - # Return as None for the unhandled event types - # to avoid showing them in the UI + def get_retrieval_message(self) -> dict | None: + if self.payload: + nodes = self.payload.get("nodes") + if nodes: + msg = f"Retrieved {len(nodes)} sources to use as context for the query" + else: + msg = f"Retrieving context for query: '{self.payload.get('query_str')}'" + return { + "type": "events", + "data": {"title": msg}, + } + else: + return None + + def get_tool_message(self) -> dict | None: + func_call_args = self.payload.get("function_call") + if func_call_args is not None and "tool" in self.payload: + tool = self.payload.get("tool") + return { + "type": "events", + "data": { + "title": f"Calling tool: {tool.name} with inputs: {func_call_args}", + }, + } + + def _is_output_serializable(self, output: Any) -> bool: + try: + json.dumps(output) + return True + except TypeError: + return False + + def get_agent_tool_response(self) -> dict | None: + response = self.payload.get("response") + if response is not None: + sources = response.sources + for source in sources: + # Return the tool response here to include the toolCall information + if isinstance(source, ToolOutput): + if self._is_output_serializable(source.raw_output): + output = source.raw_output + else: + output = source.content + + return { + "type": "tools", + "data": { + "toolOutput": { + "output": output, + "isError": source.is_error, + }, + "toolCall": { + "id": None, # There is no tool id in the ToolOutput + "name": source.tool_name, + "input": source.raw_input, + }, + }, + } + + def to_response(self): match self.event_type: case "retrieve": - if self.payload: - nodes = self.payload.get("nodes") - if nodes: - return f"Retrieved {len(nodes)} sources to use as context for the query" - else: - return f"Retrieving context for query: '{self.payload.get('query_str')}'" - else: - return None + return self.get_retrieval_message() + case "function_call": + return self.get_tool_message() + case "agent_step": + return self.get_agent_tool_response() case _: return None @@ -54,7 +109,7 @@ class EventCallbackHandler(BaseCallbackHandler): **kwargs: Any, ) -> str: event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload) - if event.get_title() is not None: + if event.to_response() is not None: self._aqueue.put_nowait(event) def on_event_end( @@ -65,7 +120,7 @@ class EventCallbackHandler(BaseCallbackHandler): **kwargs: Any, ) -> None: event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload) - if event.get_title() is not None: + if event.to_response() is not None: self._aqueue.put_nowait(event) def start_trace(self, trace_id: Optional[str] = None) -> None: diff --git a/templates/types/streaming/nextjs/app/api/chat/route.ts b/templates/types/streaming/nextjs/app/api/chat/route.ts index 117848d37c291a3b8dece5236c29c1c926d7ba00..86ab4945efbb9df92e4ff44386d868eb53d8c9ae 100644 --- a/templates/types/streaming/nextjs/app/api/chat/route.ts +++ b/templates/types/streaming/nextjs/app/api/chat/route.ts @@ -1,16 +1,11 @@ import { initObservability } from "@/app/observability"; import { Message, StreamData, StreamingTextResponse } from "ai"; -import { - CallbackManager, - ChatMessage, - MessageContent, - Settings, -} from "llamaindex"; +import { ChatMessage, MessageContent, Settings } from "llamaindex"; import { NextRequest, NextResponse } from "next/server"; import { createChatEngine } from "./engine/chat"; import { initSettings } from "./engine/settings"; import { LlamaIndexStream } from "./llamaindex-stream"; -import { appendEventData } from "./stream-helper"; +import { createCallbackManager } from "./stream-helper"; initObservability(); initSettings(); @@ -64,18 +59,7 @@ export async function POST(request: NextRequest) { const vercelStreamData = new StreamData(); // Setup callbacks - const callbackManager = new CallbackManager(); - callbackManager.on("retrieve", (data) => { - const { nodes } = data.detail; - appendEventData( - vercelStreamData, - `Retrieving context for query: '${userMessage.content}'`, - ); - appendEventData( - vercelStreamData, - `Retrieved ${nodes.length} sources to use as context for the query`, - ); - }); + const callbackManager = createCallbackManager(vercelStreamData); // Calling LlamaIndex's ChatEngine to get a streamed response const response = await Settings.withCallbackManager(callbackManager, () => { diff --git a/templates/types/streaming/nextjs/app/api/chat/stream-helper.ts b/templates/types/streaming/nextjs/app/api/chat/stream-helper.ts index e74597b9a2a30e18b6812196b9e848484d025d24..9f1a8864bfa63e06e8e622a641075cd34ac604d8 100644 --- a/templates/types/streaming/nextjs/app/api/chat/stream-helper.ts +++ b/templates/types/streaming/nextjs/app/api/chat/stream-helper.ts @@ -1,5 +1,11 @@ import { StreamData } from "ai"; -import { Metadata, NodeWithScore } from "llamaindex"; +import { + CallbackManager, + Metadata, + NodeWithScore, + ToolCall, + ToolOutput, +} from "llamaindex"; export function appendImageData(data: StreamData, imageUrl?: string) { if (!imageUrl) return; @@ -37,3 +43,55 @@ export function appendEventData(data: StreamData, title?: string) { }, }); } + +export function appendToolData( + data: StreamData, + toolCall: ToolCall, + toolOutput: ToolOutput, +) { + data.appendMessageAnnotation({ + type: "tools", + data: { + toolCall: { + id: toolCall.id, + name: toolCall.name, + input: toolCall.input, + }, + toolOutput: { + output: toolOutput.output, + isError: toolOutput.isError, + }, + }, + }); +} + +export function createCallbackManager(stream: StreamData) { + const callbackManager = new CallbackManager(); + + callbackManager.on("retrieve", (data) => { + const { nodes, query } = data.detail; + appendEventData(stream, `Retrieving context for query: '${query}'`); + appendEventData( + stream, + `Retrieved ${nodes.length} sources to use as context for the query`, + ); + }); + + callbackManager.on("llm-tool-call", (event) => { + const { name, input } = event.detail.payload.toolCall; + const inputString = Object.entries(input) + .map(([key, value]) => `${key}: ${value}`) + .join(", "); + appendEventData( + stream, + `Using tool: '${name}' with inputs: '${inputString}'`, + ); + }); + + callbackManager.on("llm-tool-result", (event) => { + const { toolCall, toolResult } = event.detail.payload; + appendToolData(stream, toolCall, toolResult); + }); + + return callbackManager; +} diff --git a/templates/types/streaming/nextjs/app/components/chat-section.tsx b/templates/types/streaming/nextjs/app/components/chat-section.tsx index afb59960f98486e1d5121879231d01148935b197..4f8832200a8c3ae7f49b6cb455b8f6b4c3762a9c 100644 --- a/templates/types/streaming/nextjs/app/components/chat-section.tsx +++ b/templates/types/streaming/nextjs/app/components/chat-section.tsx @@ -17,7 +17,8 @@ export default function ChatSection() { headers: { "Content-Type": "application/json", // using JSON because of vercel/ai 2.2.26 }, - onError: (error) => { + onError: (error: unknown) => { + if (!(error instanceof Error)) throw error; const message = JSON.parse(error.message); alert(message.detail); }, diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/chat-message.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/chat-message.tsx index c8dcc13a440db727dc6a39b0af6f6174d236fd2d..da1d92e95a97e566f334ba0c9cac1f5fffa90bc8 100644 --- a/templates/types/streaming/nextjs/app/components/ui/chat/chat-message.tsx +++ b/templates/types/streaming/nextjs/app/components/ui/chat/chat-message.tsx @@ -7,6 +7,7 @@ import ChatAvatar from "./chat-avatar"; import { ChatEvents } from "./chat-events"; import { ChatImage } from "./chat-image"; import { ChatSources } from "./chat-sources"; +import ChatTools from "./chat-tools"; import { AnnotationData, EventData, @@ -14,6 +15,7 @@ import { MessageAnnotation, MessageAnnotationType, SourceData, + ToolData, } from "./index"; import Markdown from "./markdown"; import { useCopyToClipboard } from "./use-copy-to-clipboard"; @@ -52,19 +54,27 @@ function ChatMessageContent({ annotations, MessageAnnotationType.SOURCES, ); + const toolData = getAnnotationData<ToolData>( + annotations, + MessageAnnotationType.TOOLS, + ); const contents: ContentDisplayConfig[] = [ { - order: -2, + order: -3, component: imageData[0] ? <ChatImage data={imageData[0]} /> : null, }, { - order: -1, + order: -2, component: eventData.length > 0 ? ( <ChatEvents isLoading={isLoading} data={eventData} /> ) : null, }, + { + order: -1, + component: toolData[0] ? <ChatTools data={toolData[0]} /> : null, + }, { order: 0, component: <Markdown content={message.content} />, diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/chat-messages.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/chat-messages.tsx index 5f0eab6b2692d0bbcd54654af9ce328cae277d50..55cfd9eb0664507b6e70bbb0cf2a18d893ed5470 100644 --- a/templates/types/streaming/nextjs/app/components/ui/chat/chat-messages.tsx +++ b/templates/types/streaming/nextjs/app/components/ui/chat/chat-messages.tsx @@ -40,9 +40,16 @@ export default function ChatMessages( className="flex h-[50vh] flex-col gap-5 divide-y overflow-y-auto pb-4" ref={scrollableChatContainerRef} > - {props.messages.map((m) => ( - <ChatMessage key={m.id} chatMessage={m} isLoading={props.isLoading} /> - ))} + {props.messages.map((m, i) => { + const isLoadingMessage = i === messageLength - 1 && props.isLoading; + return ( + <ChatMessage + key={m.id} + chatMessage={m} + isLoading={isLoadingMessage} + /> + ); + })} {isPending && ( <div className="flex justify-center items-center pt-10"> <Loader2 className="h-4 w-4 animate-spin" /> diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/chat-tools.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/chat-tools.tsx new file mode 100644 index 0000000000000000000000000000000000000000..268b4360d0c812a948581b489ee9f4f4267df642 --- /dev/null +++ b/templates/types/streaming/nextjs/app/components/ui/chat/chat-tools.tsx @@ -0,0 +1,26 @@ +import { ToolData } from "./index"; +import { WeatherCard, WeatherData } from "./widgets/WeatherCard"; + +// TODO: If needed, add displaying more tool outputs here +export default function ChatTools({ data }: { data: ToolData }) { + if (!data) return null; + const { toolCall, toolOutput } = data; + + if (toolOutput.isError) { + return ( + <div className="border-l-2 border-red-400 pl-2"> + There was an error when calling the tool {toolCall.name} with input:{" "} + <br /> + {JSON.stringify(toolCall.input)} + </div> + ); + } + + switch (toolCall.name) { + case "get_weather_information": + const weatherData = toolOutput.output as unknown as WeatherData; + return <WeatherCard data={weatherData} />; + default: + return null; + } +} diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/index.ts b/templates/types/streaming/nextjs/app/components/ui/chat/index.ts index 464f195a7ae7fbcf9dc138c03209a1710a060863..106f6294bd5138b4636da6582fc4370f027eb9d8 100644 --- a/templates/types/streaming/nextjs/app/components/ui/chat/index.ts +++ b/templates/types/streaming/nextjs/app/components/ui/chat/index.ts @@ -1,3 +1,4 @@ +import { JSONValue } from "ai"; import ChatInput from "./chat-input"; import ChatMessages from "./chat-messages"; @@ -8,6 +9,7 @@ export enum MessageAnnotationType { IMAGE = "image", SOURCES = "sources", EVENTS = "events", + TOOLS = "tools", } export type ImageData = { @@ -30,7 +32,21 @@ export type EventData = { isCollapsed: boolean; }; -export type AnnotationData = ImageData | SourceData | EventData; +export type ToolData = { + toolCall: { + id: string; + name: string; + input: { + [key: string]: JSONValue; + }; + }; + toolOutput: { + output: JSONValue; + isError: boolean; + }; +}; + +export type AnnotationData = ImageData | SourceData | EventData | ToolData; export type MessageAnnotation = { type: MessageAnnotationType; diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/widgets/WeatherCard.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/widgets/WeatherCard.tsx new file mode 100644 index 0000000000000000000000000000000000000000..f2115ae0e5d89e735c7264d6ad7a6ce7fb3cb860 --- /dev/null +++ b/templates/types/streaming/nextjs/app/components/ui/chat/widgets/WeatherCard.tsx @@ -0,0 +1,213 @@ +export interface WeatherData { + latitude: number; + longitude: number; + generationtime_ms: number; + utc_offset_seconds: number; + timezone: string; + timezone_abbreviation: string; + elevation: number; + current_units: { + time: string; + interval: string; + temperature_2m: string; + weather_code: string; + }; + current: { + time: string; + interval: number; + temperature_2m: number; + weather_code: number; + }; + hourly_units: { + time: string; + temperature_2m: string; + weather_code: string; + }; + hourly: { + time: string[]; + temperature_2m: number[]; + weather_code: number[]; + }; + daily_units: { + time: string; + weather_code: string; + }; + daily: { + time: string[]; + weather_code: number[]; + }; +} + +// Follow WMO Weather interpretation codes (WW) +const weatherCodeDisplayMap: Record< + string, + { + icon: JSX.Element; + status: string; + } +> = { + "0": { + icon: <span>☀ï¸</span>, + status: "Clear sky", + }, + "1": { + icon: <span>🌤ï¸</span>, + status: "Mainly clear", + }, + "2": { + icon: <span>â˜ï¸</span>, + status: "Partly cloudy", + }, + "3": { + icon: <span>â˜ï¸</span>, + status: "Overcast", + }, + "45": { + icon: <span>🌫ï¸</span>, + status: "Fog", + }, + "48": { + icon: <span>🌫ï¸</span>, + status: "Depositing rime fog", + }, + "51": { + icon: <span>🌧ï¸</span>, + status: "Drizzle", + }, + "53": { + icon: <span>🌧ï¸</span>, + status: "Drizzle", + }, + "55": { + icon: <span>🌧ï¸</span>, + status: "Drizzle", + }, + "56": { + icon: <span>🌧ï¸</span>, + status: "Freezing Drizzle", + }, + "57": { + icon: <span>🌧ï¸</span>, + status: "Freezing Drizzle", + }, + "61": { + icon: <span>🌧ï¸</span>, + status: "Rain", + }, + "63": { + icon: <span>🌧ï¸</span>, + status: "Rain", + }, + "65": { + icon: <span>🌧ï¸</span>, + status: "Rain", + }, + "66": { + icon: <span>🌧ï¸</span>, + status: "Freezing Rain", + }, + "67": { + icon: <span>🌧ï¸</span>, + status: "Freezing Rain", + }, + "71": { + icon: <span>â„ï¸</span>, + status: "Snow fall", + }, + "73": { + icon: <span>â„ï¸</span>, + status: "Snow fall", + }, + "75": { + icon: <span>â„ï¸</span>, + status: "Snow fall", + }, + "77": { + icon: <span>â„ï¸</span>, + status: "Snow grains", + }, + "80": { + icon: <span>🌧ï¸</span>, + status: "Rain showers", + }, + "81": { + icon: <span>🌧ï¸</span>, + status: "Rain showers", + }, + "82": { + icon: <span>🌧ï¸</span>, + status: "Rain showers", + }, + "85": { + icon: <span>â„ï¸</span>, + status: "Snow showers", + }, + "86": { + icon: <span>â„ï¸</span>, + status: "Snow showers", + }, + "95": { + icon: <span>⛈ï¸</span>, + status: "Thunderstorm", + }, + "96": { + icon: <span>⛈ï¸</span>, + status: "Thunderstorm", + }, + "99": { + icon: <span>⛈ï¸</span>, + status: "Thunderstorm", + }, +}; + +const displayDay = (time: string) => { + return new Date(time).toLocaleDateString("en-US", { + weekday: "long", + }); +}; + +export function WeatherCard({ data }: { data: WeatherData }) { + const currentDayString = new Date(data.current.time).toLocaleDateString( + "en-US", + { + weekday: "long", + month: "long", + day: "numeric", + }, + ); + + return ( + <div className="bg-[#61B9F2] rounded-2xl shadow-xl p-5 space-y-4 text-white w-fit"> + <div className="flex justify-between"> + <div className="space-y-2"> + <div className="text-xl">{currentDayString}</div> + <div className="text-5xl font-semibold flex gap-4"> + <span> + {data.current.temperature_2m} {data.current_units.temperature_2m} + </span> + {weatherCodeDisplayMap[data.current.weather_code].icon} + </div> + </div> + <span className="text-xl"> + {weatherCodeDisplayMap[data.current.weather_code].status} + </span> + </div> + <div className="gap-2 grid grid-cols-6"> + {data.daily.time.map((time, index) => { + if (index === 0) return null; // skip the current day + return ( + <div key={time} className="flex flex-col items-center gap-4"> + <span>{displayDay(time)}</span> + <div className="text-4xl"> + {weatherCodeDisplayMap[data.daily.weather_code[index]].icon} + </div> + <span className="text-sm"> + {weatherCodeDisplayMap[data.daily.weather_code[index]].status} + </span> + </div> + ); + })} + </div> + </div> + ); +} diff --git a/templates/types/streaming/nextjs/package.json b/templates/types/streaming/nextjs/package.json index 3d043a14d2fb748c7cf1e12613a791c70225e180..3093201b0c054625e411e4d73ec4020f8634207b 100644 --- a/templates/types/streaming/nextjs/package.json +++ b/templates/types/streaming/nextjs/package.json @@ -14,10 +14,11 @@ "@radix-ui/react-hover-card": "^1.0.7", "@radix-ui/react-slot": "^1.0.2", "ai": "^3.0.21", + "ajv": "^8.12.0", "class-variance-authority": "^0.7.0", - "clsx": "^1.2.1", + "clsx": "^2.1.1", "dotenv": "^16.3.1", - "llamaindex": "0.3.3", + "llamaindex": "0.3.9", "lucide-react": "^0.294.0", "next": "^14.0.3", "pdf2json": "3.0.5",