From 376d29a78f2593900772603a1242932523d951ee Mon Sep 17 00:00:00 2001
From: Parham Saidi <parham@parha.me>
Date: Sat, 27 Jul 2024 04:53:24 +0200
Subject: [PATCH] feat: meta3.1 405b tool calling support (#1080)

---
 .changeset/small-snails-stare.md              |   6 +
 apps/docs/docs/modules/agent/index.md         |   1 +
 .../modules/llms/available_llms/bedrock.md    |  71 ++++-
 packages/community/README.md                  |   1 +
 packages/community/package.json               |   7 +-
 packages/community/src/index.ts               |   2 +-
 .../anthropic.ts => anthropic/provider.ts}    |   9 +-
 .../src/llm/bedrock/anthropic/types.ts        | 142 ++++++++++
 .../src/llm/bedrock/anthropic/utils.ts        | 186 ++++++++++++
 .../src/llm/bedrock/{base.ts => index.ts}     |  13 +-
 .../src/llm/bedrock/meta/constants.ts         |   3 +
 .../src/llm/bedrock/meta/provider.ts          | 136 +++++++++
 .../community/src/llm/bedrock/meta/types.ts   |  21 ++
 .../community/src/llm/bedrock/meta/utils.ts   | 198 +++++++++++++
 .../community/src/llm/bedrock/provider.ts     |   1 +
 .../src/llm/bedrock/providers/index.ts        |   9 -
 .../src/llm/bedrock/providers/meta.ts         |  69 -----
 packages/community/src/llm/bedrock/types.ts   | 156 +---------
 packages/community/src/llm/bedrock/utils.ts   | 268 ------------------
 pnpm-lock.yaml                                |   3 +
 20 files changed, 790 insertions(+), 512 deletions(-)
 create mode 100644 .changeset/small-snails-stare.md
 rename packages/community/src/llm/bedrock/{providers/anthropic.ts => anthropic/provider.ts} (98%)
 create mode 100644 packages/community/src/llm/bedrock/anthropic/types.ts
 create mode 100644 packages/community/src/llm/bedrock/anthropic/utils.ts
 rename packages/community/src/llm/bedrock/{base.ts => index.ts} (97%)
 create mode 100644 packages/community/src/llm/bedrock/meta/constants.ts
 create mode 100644 packages/community/src/llm/bedrock/meta/provider.ts
 create mode 100644 packages/community/src/llm/bedrock/meta/types.ts
 create mode 100644 packages/community/src/llm/bedrock/meta/utils.ts
 delete mode 100644 packages/community/src/llm/bedrock/providers/index.ts
 delete mode 100644 packages/community/src/llm/bedrock/providers/meta.ts

diff --git a/.changeset/small-snails-stare.md b/.changeset/small-snails-stare.md
new file mode 100644
index 000000000..d296b7b0c
--- /dev/null
+++ b/.changeset/small-snails-stare.md
@@ -0,0 +1,6 @@
+---
+"@llamaindex/community": patch
+"docs": patch
+---
+
+feat: added tool calling and agent support for llama3.1 405B
diff --git a/apps/docs/docs/modules/agent/index.md b/apps/docs/docs/modules/agent/index.md
index 39121cfb4..522afde32 100644
--- a/apps/docs/docs/modules/agent/index.md
+++ b/apps/docs/docs/modules/agent/index.md
@@ -15,6 +15,7 @@ LlamaIndex.TS comes with a few built-in agents, but you can also create your own
 - Anthropic Agent both via Anthropic and Bedrock (in `@llamaIndex/community`)
 - Gemini Agent
 - ReACT Agent
+- Meta Llama 3.1 405B via Bedrock (in `@llamaIndex/community`)

 ## Examples

diff --git a/apps/docs/docs/modules/llms/available_llms/bedrock.md b/apps/docs/docs/modules/llms/available_llms/bedrock.md
index 6091b3ab2..24f5a159e 100644
--- a/apps/docs/docs/modules/llms/available_llms/bedrock.md
+++ b/apps/docs/docs/modules/llms/available_llms/bedrock.md
@@ -31,7 +31,7 @@ META_LLAMA3_8B_INSTRUCT = "meta.llama3-8b-instruct-v1:0";
 META_LLAMA3_70B_INSTRUCT = "meta.llama3-70b-instruct-v1:0";
 META_LLAMA3_1_8B_INSTRUCT = "meta.llama3-1-8b-instruct-v1:0"; // available on us-west-2
 META_LLAMA3_1_70B_INSTRUCT = "meta.llama3-1-70b-instruct-v1:0"; // available on us-west-2
-META_LLAMA3_1_405B_INSTRUCT = "meta.llama3-1-405b-instruct-v1:0"; // preview only, available on us-west-2
+META_LLAMA3_1_405B_INSTRUCT = "meta.llama3-1-405b-instruct-v1:0"; // preview only, available on us-west-2, tool calling supported
 ```

 Sonnet, Haiku and Opus are multimodal, image_url only supports base64 data url format, e.g. `data:image/jpeg;base64,SGVsbG8sIFdvcmxkIQ==`
@@ -67,3 +67,72 @@ async function main() {
   console.log(response.response);
 }
 ```
+
+## Agent Example
+
+```ts
+import { BEDROCK_MODELS, Bedrock } from "@llamaindex/community";
+import { FunctionTool, LLMAgent } from "llamaindex";
+
+const sumNumbers = FunctionTool.from(
+  ({ a, b }: { a: number; b: number }) => `${a + b}`,
+  {
+    name: "sumNumbers",
+    description: "Use this function to sum two numbers",
+    parameters: {
+      type: "object",
+      properties: {
+        a: {
+          type: "number",
+          description: "The first number",
+        },
+        b: {
+          type: "number",
+          description: "The second number",
+        },
+      },
+      required: ["a", "b"],
+    },
+  },
+);
+
+const divideNumbers = FunctionTool.from(
+  ({ a, b }: { a: number; b: number }) => `${a / b}`,
+  {
+    name: "divideNumbers",
+    description: "Use this function to divide two numbers",
+    parameters: {
+      type: "object",
+      properties: {
+        a: {
+          type: "number",
+          description: "The dividend a to divide",
+        },
+        b: {
+          type: "number",
+          description: "The divisor b to divide by",
+        },
+      },
+      required: ["a", "b"],
+    },
+  },
+);
+
+const bedrock = new Bedrock({
+  model: BEDROCK_MODELS.META_LLAMA3_1_405B_INSTRUCT,
+  ...
+});
+
+async function main() {
+  const agent = new LLMAgent({
+    llm: bedrock,
+    tools: [sumNumbers, divideNumbers],
+  });
+
+  const response = await agent.chat({
+    message: "How much is 5 + 5? Then divide by 2",
+  });
+
+  console.log(response.message);
+}
+```
diff --git a/packages/community/README.md b/packages/community/README.md
index 47cbdc2be..e1311523d 100644
--- a/packages/community/README.md
+++ b/packages/community/README.md
@@ -6,6 +6,7 @@

 - Bedrock support for the Anthropic Claude Models [usage](https://ts.llamaindex.ai/modules/llms/available_llms/bedrock)
 - Bedrock support for the Meta LLama 2, 3 and 3.1 Models [usage](https://ts.llamaindex.ai/modules/llms/available_llms/bedrock)
+- Meta Llama 3.1 405B tool call support

 ## LICENSE

diff --git a/packages/community/package.json b/packages/community/package.json
index 69f6ef7d7..8b601a39b 100644
--- a/packages/community/package.json
+++ b/packages/community/package.json
@@ -19,11 +19,11 @@
     "./llm/bedrock": {
       "import": {
         "types": "./dist/type/llm/bedrock.d.ts",
-        "default": "./dist/llm/bedrock/base.js"
+        "default": "./dist/llm/bedrock/index.js"
       },
       "require": {
         "types": "./dist/type/llm/bedrock.d.ts",
-        "default": "./dist/llm/bedrock/base.cjs"
+        "default": "./dist/llm/bedrock/index.cjs"
       }
     }
   },
@@ -47,6 +47,7 @@
   },
   "dependencies": {
     "@aws-sdk/client-bedrock-runtime": "^3.613.0",
-    "@llamaindex/core": "workspace:*"
+    "@llamaindex/core": "workspace:*",
+    "@llamaindex/env": "workspace:*"
   }
 }
diff --git a/packages/community/src/index.ts b/packages/community/src/index.ts
index 57a5cac79..6eecd6b95 100644
--- a/packages/community/src/index.ts
+++ b/packages/community/src/index.ts
@@ -2,4 +2,4 @@ export {
   BEDROCK_MODELS,
   BEDROCK_MODEL_MAX_TOKENS,
   Bedrock,
-} from "./llm/bedrock/base.js";
+} from "./llm/bedrock/index.js";
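Both entry points continue to resolve after the `base.ts` → `index.ts` rename; a consumer-side sanity check (illustrative only, not part of the patch):

```ts
// The public API is unchanged: only the compiled file names referenced by the
// "./llm/bedrock" export in package.json move from base.* to index.*.
import { BEDROCK_MODELS, Bedrock } from "@llamaindex/community";

console.log(BEDROCK_MODELS.META_LLAMA3_1_405B_INSTRUCT); // "meta.llama3-1-405b-instruct-v1:0"
console.log(typeof Bedrock); // "function"
```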
diff --git a/packages/community/src/llm/bedrock/providers/anthropic.ts b/packages/community/src/llm/bedrock/anthropic/provider.ts
similarity index 98%
rename from packages/community/src/llm/bedrock/providers/anthropic.ts
rename to packages/community/src/llm/bedrock/anthropic/provider.ts
index b59f48d2a..bb71ed859 100644
--- a/packages/community/src/llm/bedrock/providers/anthropic.ts
+++ b/packages/community/src/llm/bedrock/anthropic/provider.ts
@@ -16,17 +16,18 @@ import {
   type BedrockChatStreamResponse,
   Provider,
 } from "../provider";
+import { toUtf8 } from "../utils";
 import type {
   AnthropicNoneStreamingResponse,
   AnthropicStreamEvent,
   AnthropicTextContent,
   ToolBlock,
-} from "../types";
+} from "./types";
+
 import {
   mapBaseToolsToAnthropicTools,
   mapChatMessagesToAnthropicMessages,
-  toUtf8,
-} from "../utils";
+} from "./utils";

 export class AnthropicProvider extends Provider<AnthropicStreamEvent> {
   getResultFromResponse(
@@ -69,6 +70,7 @@ export class AnthropicProvider extends Provider<AnthropicStreamEvent> {
     let tool: ToolBlock | undefined = undefined;
     // #TODO this should be broken down into a separate consumer
     for await (const response of stream) {
+      const delta = this.getTextFromStreamResponse(response);
       const event = this.getStreamingEventResponse(response);
       if (
         event?.type === "content_block_start" &&
@@ -114,7 +116,6 @@ export class AnthropicProvider extends Provider<AnthropicStreamEvent> {
           };
         }
       }
-      const delta = this.getTextFromStreamResponse(response);
       if (!delta && !options) continue;

       yield {
diff --git a/packages/community/src/llm/bedrock/anthropic/types.ts b/packages/community/src/llm/bedrock/anthropic/types.ts
new file mode 100644
index 000000000..7a5db5665
--- /dev/null
+++ b/packages/community/src/llm/bedrock/anthropic/types.ts
@@ -0,0 +1,142 @@
+import type { ToolMetadata } from "@llamaindex/core/llms";
+import type { InvocationMetrics } from "../types";
+
+type Usage = {
+  input_tokens: number;
+  output_tokens: number;
+};
+
+type Message = {
+  id: string;
+  type: string;
+  role: string;
+  content: string[];
+  model: string;
+  stop_reason: string | null;
+  stop_sequence: string | null;
+  usage: Usage;
+};
+
+export type ToolBlock = {
+  id: string;
+  input: unknown;
+  name: string;
+  type: "tool_use";
+};
+
+export type TextBlock = {
+  type: "text";
+  text: string;
+};
+
+type ContentBlockStart = {
+  type: "content_block_start";
+  index: number;
+  content_block: ToolBlock | TextBlock;
+};
+
+type Delta =
+  | {
+      type: "text_delta";
+      text: string;
+    }
+  | {
+      type: "input_json_delta";
+      partial_json: string;
+    };
+
+type ContentBlockDelta = {
+  type: "content_block_delta";
+  index: number;
+  delta: Delta;
+};
+
+type ContentBlockStop = {
+  type: "content_block_stop";
+  index: number;
+};
+
+type MessageDelta = {
+  type: "message_delta";
+  delta: {
+    stop_reason: string;
+    stop_sequence: string | null;
+  };
+  usage: Usage;
+};
+
+export type MessageStop = {
+  type: "message_stop";
+  "amazon-bedrock-invocationMetrics": InvocationMetrics;
+};
+
+export type AnthropicStreamEvent =
+  | { type: "message_start"; message: Message }
+  | ContentBlockStart
+  | ContentBlockDelta
+  | ContentBlockStop
+  | MessageDelta
+  | MessageStop;
+
+export type AnthropicContent =
+  | AnthropicTextContent
+  | AnthropicImageContent
+  | AnthropicToolContent
+  | AnthropicToolResultContent;
+
+export type AnthropicTextContent = {
+  type: "text";
+  text: string;
+};
+
+export type AnthropicToolContent = {
+  type: "tool_use";
+  id: string;
+  name: string;
+  input: Record<string, unknown>;
+};
+
+export type AnthropicToolResultContent = {
+  type: "tool_result";
+  tool_use_id: string;
+  content: string;
+};
+
+export type AnthropicMediaTypes =
+  | "image/jpeg"
+  | "image/png"
+  | "image/webp"
+  | "image/gif";
+
+export type AnthropicImageSource = {
+  type: "base64";
+  media_type: AnthropicMediaTypes;
+  data: string; // base64 encoded image bytes
+};
+
+export type AnthropicImageContent = {
+  type: "image";
+  source: AnthropicImageSource;
+};
+
+export type AnthropicMessage = {
+  role: "user" | "assistant";
+  content: AnthropicContent[];
+};
+
+export type AnthropicNoneStreamingResponse = {
+  id: string;
+  type: "message";
+  role: "assistant";
+  content: AnthropicContent[];
+  model: string;
+  stop_reason: "end_turn" | "max_tokens" | "stop_sequence";
+  stop_sequence?: string;
+  usage: { input_tokens: number; output_tokens: number };
+};
+
+export type AnthropicTool = {
+  name: string;
+  description: string;
+  input_schema: ToolMetadata["parameters"];
+};
diff --git a/packages/community/src/llm/bedrock/anthropic/utils.ts b/packages/community/src/llm/bedrock/anthropic/utils.ts
new file mode 100644
index 000000000..2d5926032
--- /dev/null
+++ b/packages/community/src/llm/bedrock/anthropic/utils.ts
@@ -0,0 +1,186 @@
+import type { JSONObject } from "@llamaindex/core/global";
+import type {
+  BaseTool,
+  ChatMessage,
+  MessageContent,
+  MessageContentDetail,
+  ToolCallLLMMessageOptions,
+} from "@llamaindex/core/llms";
+import { mapMessageContentToMessageContentDetails } from "../utils";
+import type {
+  AnthropicContent,
+  AnthropicImageContent,
+  AnthropicMediaTypes,
+  AnthropicMessage,
+  AnthropicTextContent,
+  AnthropicTool,
+} from "./types.js";
+
+const ACCEPTED_IMAGE_MIME_TYPES = [
+  "image/jpeg",
+  "image/png",
+  "image/webp",
+  "image/gif",
+];
+
+export const mergeNeighboringSameRoleMessages = (
+  messages: AnthropicMessage[],
+): AnthropicMessage[] => {
+  return messages.reduce(
+    (result: AnthropicMessage[], current: AnthropicMessage, index: number) => {
+      if (index > 0 && messages[index - 1].role === current.role) {
+        result[result.length - 1].content = [
+          ...result[result.length - 1].content,
+          ...current.content,
+        ];
+      } else {
+        result.push(current);
+      }
+      return result;
+    },
+    [],
+  );
+};
+
+export const mapMessageContentDetailToAnthropicContent = <
+  T extends MessageContentDetail,
+>(
+  detail: T,
+): AnthropicContent => {
+  let content: AnthropicContent;
+
+  if (detail.type === "text") {
+    content = mapTextContent(detail.text);
+  } else if (detail.type === "image_url") {
+    content = mapImageContent(detail.image_url.url);
+  } else {
+    throw new Error("Unsupported content detail type");
+  }
+  return content;
+};
+
+export const mapMessageContentToAnthropicContent = <T extends MessageContent>(
+  content: T,
+): AnthropicContent[] => {
+  return mapMessageContentToMessageContentDetails(content).map(
+    mapMessageContentDetailToAnthropicContent,
+  );
+};
+
+export const mapBaseToolsToAnthropicTools = (
+  tools?: BaseTool[],
+): AnthropicTool[] => {
+  if (!tools) return [];
+  return tools.map((tool: BaseTool) => {
+    const {
+      metadata: { parameters, ...options },
+    } = tool;
+    return {
+      ...options,
+      input_schema: parameters,
+    };
+  });
+};
+
+export const mapChatMessagesToAnthropicMessages = <
+  T extends ChatMessage<ToolCallLLMMessageOptions>,
+>(
+  messages: T[],
+): AnthropicMessage[] => {
+  const mapped = messages
+    .flatMap((msg: T): AnthropicMessage[] => {
+      if (msg.options && "toolCall" in msg.options) {
+        return [
+          {
+            role: "assistant",
+            content: msg.options.toolCall.map((call) => ({
+              type: "tool_use",
+              id: call.id,
+              name: call.name,
+              input: call.input as JSONObject,
+            })),
+          },
+        ];
+      }
+      if (msg.options && "toolResult" in msg.options) {
+        return [
+          {
+            role: "user",
+            content: [
+              {
+                type: "tool_result",
+                tool_use_id: msg.options.toolResult.id,
+                content: msg.options.toolResult.result,
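A tiny usage sketch (not part of the patch) of the data-URL helper in the file above, which `mapImageContent` builds on:

```ts
// Hypothetical consumer; the import path mirrors the new module layout.
import { extractDataUrlComponents } from "./anthropic/utils";

// Splits a base64 data URL into mime type and payload; anything that is not
// of the form "data:<mime>;base64,<data>" throws.
const { mimeType, base64 } = extractDataUrlComponents(
  "data:image/jpeg;base64,SGVsbG8sIFdvcmxkIQ==",
);
console.log(mimeType); // "image/jpeg"
console.log(base64); // "SGVsbG8sIFdvcmxkIQ=="
```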
+              },
+            ],
+          },
+        ];
+      }
+      return mapMessageContentToMessageContentDetails(msg.content).map(
+        (detail: MessageContentDetail): AnthropicMessage => {
+          const content = mapMessageContentDetailToAnthropicContent(detail);
+
+          return {
+            role: msg.role === "assistant" ? "assistant" : "user",
+            content: [content],
+          };
+        },
+      );
+    })
+    .filter((message: AnthropicMessage) => {
+      const content = message.content[0];
+      if (content.type === "text" && !content.text) return false;
+      if (content.type === "image" && !content.source.data) return false;
+      if (content.type === "image" && message.role === "assistant")
+        return false;
+      return true;
+    });
+
+  return mergeNeighboringSameRoleMessages(mapped);
+};
+
+export const mapTextContent = (text: string): AnthropicTextContent => {
+  return { type: "text", text };
+};
+
+export const extractDataUrlComponents = (
+  dataUrl: string,
+): {
+  mimeType: string;
+  base64: string;
+} => {
+  const parts = dataUrl.split(";base64,");
+
+  if (parts.length !== 2 || !parts[0].startsWith("data:")) {
+    throw new Error("Invalid data URL");
+  }
+
+  const mimeType = parts[0].slice(5);
+  const base64 = parts[1];
+
+  return {
+    mimeType,
+    base64,
+  };
+};
+
+export const mapImageContent = (imageUrl: string): AnthropicImageContent => {
+  if (!imageUrl.startsWith("data:"))
+    throw new Error(
+      "For Anthropic please only use base64 data url, e.g.: data:image/jpeg;base64,SGVsbG8sIFdvcmxkIQ==",
+    );
+  const { mimeType, base64: data } = extractDataUrlComponents(imageUrl);
+  if (!ACCEPTED_IMAGE_MIME_TYPES.includes(mimeType))
+    throw new Error(
+      `Anthropic only accepts the following mimeTypes: ${ACCEPTED_IMAGE_MIME_TYPES.join("\n")}`,
+    );
+
+  return {
+    type: "image",
+    source: {
+      type: "base64",
+      media_type: mimeType as AnthropicMediaTypes,
+      data,
+    },
+  };
+};
diff --git a/packages/community/src/llm/bedrock/base.ts b/packages/community/src/llm/bedrock/index.ts
similarity index 97%
rename from packages/community/src/llm/bedrock/base.ts
rename to packages/community/src/llm/bedrock/index.ts
index f90a54e71..c0a110dd5 100644
--- a/packages/community/src/llm/bedrock/base.ts
+++ b/packages/community/src/llm/bedrock/index.ts
@@ -22,8 +22,16 @@ import {
   type BedrockChatStreamResponse,
   Provider,
 } from "./provider";
-import { PROVIDERS } from "./providers";
-import { mapMessageContentToMessageContentDetails } from "./utils.js";
+import { mapMessageContentToMessageContentDetails } from "./utils";
+
+import { AnthropicProvider } from "./anthropic/provider";
+import { MetaProvider } from "./meta/provider";
+
+// Other providers should go here
+export const PROVIDERS: { [key: string]: Provider } = {
+  anthropic: new AnthropicProvider(),
+  meta: new MetaProvider(),
+};

 export type BedrockChatParamsStreaming = LLMChatParamsStreaming<
   BedrockAdditionalChatOptions,
@@ -140,6 +148,7 @@ export const TOOL_CALL_MODELS = [
   BEDROCK_MODELS.ANTHROPIC_CLAUDE_3_HAIKU,
   BEDROCK_MODELS.ANTHROPIC_CLAUDE_3_OPUS,
   BEDROCK_MODELS.ANTHROPIC_CLAUDE_3_5_SONNET,
+  BEDROCK_MODELS.META_LLAMA3_1_405B_INSTRUCT,
 ];

 const getProvider = (model: string): Provider => {
diff --git a/packages/community/src/llm/bedrock/meta/constants.ts b/packages/community/src/llm/bedrock/meta/constants.ts
new file mode 100644
index 000000000..e93636f71
--- /dev/null
+++ b/packages/community/src/llm/bedrock/meta/constants.ts
@@ -0,0 +1,3 @@
+export const TOKENS = {
+  TOOL_CALL: "<|python_tag|>",
+};
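The `<|python_tag|>` token is how Llama 3.1 marks a tool call in its output. As a standalone sketch (not part of the patch; the generation string is invented), this is the parsing that the new `MetaProvider` below performs:

```ts
const TOOL_CALL = "<|python_tag|>"; // same value as TOKENS.TOOL_CALL above

// A hypothetical completed generation requesting a tool call: the token,
// followed by a JSON object carrying `name` and `parameters`.
const generation =
  '<|python_tag|>{"name": "sumNumbers", "parameters": {"a": 5, "b": 5}}';

if (generation.trim().startsWith(TOOL_CALL)) {
  // split() yields ["", '{"name": ...}']; index 1 is the JSON payload.
  const tool = JSON.parse(generation.trim().split(TOOL_CALL)[1]);
  console.log(tool.name); // "sumNumbers"
  console.log(tool.parameters); // { a: 5, b: 5 }
}
```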
diff --git a/packages/community/src/llm/bedrock/meta/provider.ts b/packages/community/src/llm/bedrock/meta/provider.ts
new file mode 100644
index 000000000..5b7f78ac8
--- /dev/null
+++ b/packages/community/src/llm/bedrock/meta/provider.ts
@@ -0,0 +1,136 @@
+import type {
+  InvokeModelCommandInput,
+  InvokeModelWithResponseStreamCommandInput,
+  ResponseStream,
+} from "@aws-sdk/client-bedrock-runtime";
+import type {
+  BaseTool,
+  ChatMessage,
+  LLMMetadata,
+  ToolCall,
+  ToolCallLLMMessageOptions,
+} from "@llamaindex/core/llms";
+import { toUtf8 } from "../utils";
+import type { MetaNoneStreamingResponse, MetaStreamEvent } from "./types";
+
+import { randomUUID } from "@llamaindex/env";
+import { Provider, type BedrockChatStreamResponse } from "../provider";
+import { TOKENS } from "./constants";
+import {
+  mapChatMessagesToMetaLlama2Messages,
+  mapChatMessagesToMetaLlama3Messages,
+} from "./utils";
+
+export class MetaProvider extends Provider<MetaStreamEvent> {
+  getResultFromResponse(
+    response: Record<string, any>,
+  ): MetaNoneStreamingResponse {
+    return JSON.parse(toUtf8(response.body));
+  }
+
+  getToolsFromResponse<ToolContent>(
+    response: Record<string, any>,
+  ): ToolContent[] {
+    const result = this.getResultFromResponse(response);
+    if (!result.generation.trim().startsWith(TOKENS.TOOL_CALL)) return [];
+    const tool = JSON.parse(
+      result.generation.trim().split(TOKENS.TOOL_CALL)[1],
+    );
+    return [
+      {
+        id: randomUUID(),
+        name: tool.name,
+        input: tool.parameters,
+      } as ToolContent,
+    ];
+  }
+
+  getTextFromResponse(response: Record<string, any>): string {
+    const result = this.getResultFromResponse(response);
+    if (result.generation.trim().startsWith(TOKENS.TOOL_CALL)) return "";
+    return result.generation;
+  }
+
+  getTextFromStreamResponse(response: Record<string, any>): string {
+    const event = this.getStreamingEventResponse(response);
+    if (event?.generation) {
+      return event.generation;
+    }
+    return "";
+  }
+
+  async *reduceStream(
+    stream: AsyncIterable<ResponseStream>,
+  ): BedrockChatStreamResponse {
+    const collecting: string[] = [];
+    let toolId: string | undefined = undefined;
+    for await (const response of stream) {
+      const event = this.getStreamingEventResponse(response);
+      const delta = this.getTextFromStreamResponse(response);
+      // odd quirk of llama3.1, start token is \n\n
+      if (
+        !event?.generation.trim() &&
+        event?.generation_token_count === 1 &&
+        event.prompt_token_count !== null
+      )
+        continue;
+
+      if (delta === TOKENS.TOOL_CALL) {
+        toolId = randomUUID();
+        continue;
+      }
+
+      let options: undefined | ToolCallLLMMessageOptions = undefined;
+      if (toolId && event?.stop_reason === "stop") {
+        const tool = JSON.parse(collecting.join(""));
+        options = {
+          toolCall: [
+            {
+              id: toolId,
+              name: tool.name,
+              input: tool.parameters,
+            } as ToolCall,
+          ],
+        };
+      } else if (toolId && !event?.stop_reason) {
+        collecting.push(delta);
+        continue;
+      }
+
+      if (!delta && !options) continue;
+
+      yield {
+        delta: options ? "" : delta,
+        options,
+        raw: response,
+      };
+    }
+  }
+
+  getRequestBody<T extends ChatMessage>(
+    metadata: LLMMetadata,
+    messages: T[],
+    tools?: BaseTool[],
+  ): InvokeModelCommandInput | InvokeModelWithResponseStreamCommandInput {
+    let prompt: string = "";
+    if (metadata.model.startsWith("meta.llama3")) {
+      prompt = mapChatMessagesToMetaLlama3Messages(messages, tools);
+    } else if (metadata.model.startsWith("meta.llama2")) {
+      prompt = mapChatMessagesToMetaLlama2Messages(messages);
+    } else {
+      throw new Error(`Meta model ${metadata.model} is not supported`);
+    }
+
+    return {
+      modelId: metadata.model,
+      contentType: "application/json",
+      accept: "application/json",
+      body: JSON.stringify({
+        prompt,
+        max_gen_len: metadata.maxTokens,
+        temperature: metadata.temperature,
+        top_p: metadata.topP,
+      }),
+    };
+  }
+}
diff --git a/packages/community/src/llm/bedrock/meta/types.ts b/packages/community/src/llm/bedrock/meta/types.ts
new file mode 100644
index 000000000..4dcb9f705
--- /dev/null
+++ b/packages/community/src/llm/bedrock/meta/types.ts
@@ -0,0 +1,21 @@
+import type { InvocationMetrics } from "../types";
+
+export type MetaTextContent = string;
+
+export type MetaMessage = {
+  role: "user" | "assistant" | "system" | "ipython";
+  content: MetaTextContent;
+};
+
+type MetaResponse = {
+  generation: string;
+  prompt_token_count: number;
+  generation_token_count: number;
+  stop_reason: "stop" | "length";
+};
+
+export type MetaStreamEvent = MetaResponse & {
+  "amazon-bedrock-invocationMetrics": InvocationMetrics;
+};
+
+export type MetaNoneStreamingResponse = MetaResponse;
diff --git a/packages/community/src/llm/bedrock/meta/utils.ts b/packages/community/src/llm/bedrock/meta/utils.ts
new file mode 100644
index 000000000..fd8cb1e69
--- /dev/null
+++ b/packages/community/src/llm/bedrock/meta/utils.ts
@@ -0,0 +1,198 @@
+import type {
+  BaseTool,
+  ChatMessage,
+  MessageContentTextDetail,
+  ToolCallLLMMessageOptions,
+} from "@llamaindex/core/llms";
+import type { MetaMessage } from "./types";
+
+const getToolCallInstructionString = (tool: BaseTool): string => {
+  return `Use the function '${tool.metadata.name}' to '${tool.metadata.description}'`;
+};
+
+const getToolCallParametersString = (tool: BaseTool): string => {
+  return JSON.stringify({
+    name: tool.metadata.name,
+    description: tool.metadata.description,
+    parameters: tool.metadata.parameters
+      ? Object.entries(tool.metadata.parameters.properties).map(
+          ([name, definition]) => ({ [name]: definition }),
+        )
+      : {},
+  });
+};
+
+// ported from https://github.com/meta-llama/llama-agentic-system/blob/main/llama_agentic_system/system_prompt.py
+// NOTE: using json instead of the above xml style tool calling works more reliably
+export const getToolsPrompt = (tools?: BaseTool[]) => {
+  if (!tools?.length) return "";
+
+  const customToolParams = tools.map((tool) => {
+    return [
+      getToolCallInstructionString(tool),
+      getToolCallParametersString(tool),
+    ].join("\n\n");
+  });
+
+  return `
+Environment: node
+
+# Tool Instructions
+- Never use ipython, always use javascript in node
+
+Cutting Knowledge Date: December 2023
+Today Date: ${new Date().toLocaleString("en-US", { year: "numeric", month: "long" })}
+
+You have access to the following functions:
+
+${customToolParams}
+
+Think very carefully before calling functions.
+
+If you choose to call a function ONLY reply in the following json format:
+{
+  "name": function_name,
+  "parameters": parameters,
+}
+where
+
+{
+  "name": function_name,
+  "parameters": parameters, => a JSON dict with the function argument name as key and function argument value as value.
+}
+
+Here is an example,
+
+{
+  "name": "example_function_name",
+  "parameters": {"example_name": "example_value"}
+}
+
+Reminder:
+- Function calls MUST follow the specified format
+- Required parameters MUST be specified
+- Only call one function at a time
+- Put the entire function call reply on one line
+- Always add your sources when using search results to answer the user query
+  `;
+};
+
+export const mapChatRoleToMetaRole = (
+  role: ChatMessage["role"],
+): MetaMessage["role"] => {
+  if (role === "assistant") return "assistant";
+  if (role === "user") return "user";
+  return "system";
+};
+
+export const mapChatMessagesToMetaMessages = <
+  T extends ChatMessage<ToolCallLLMMessageOptions>,
+>(
+  messages: T[],
+): MetaMessage[] => {
+  return messages.flatMap((msg) => {
+    if (msg.options && "toolCall" in msg.options) {
+      return msg.options.toolCall.map((call) => ({
+        role: "assistant",
+        content: JSON.stringify({
+          id: call.id,
+          name: call.name,
+          parameters: call.input,
+        }),
+      }));
+    }
+
+    if (msg.options && "toolResult" in msg.options) {
+      return {
+        role: "ipython",
+        content: JSON.stringify(msg.options.toolResult),
+      };
+    }
+
+    let content: string = "";
+    if (typeof msg.content === "string") {
+      content = msg.content;
+    } else if (msg.content.length) {
+      content = (msg.content[0] as MessageContentTextDetail).text;
+    }
+    return {
+      role: mapChatRoleToMetaRole(msg.role),
+      content,
+    };
+  });
+};
+
+/**
+ * Documentation at https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3
+ */
+export const mapChatMessagesToMetaLlama3Messages = <T extends ChatMessage>(
+  messages: T[],
+  tools?: BaseTool[],
+): string => {
+  const parts: string[] = [];
+  if (tools?.length) {
+    parts.push(
+      "<|begin_of_text|>",
+      "<|start_header_id|>system<|end_header_id|>",
+      getToolsPrompt(tools),
+      "<|eot_id|>",
+    );
+  }
+
+  const mapped = mapChatMessagesToMetaMessages(messages).map((message) => {
+    return [
+      "<|start_header_id|>",
+      message.role,
+      "<|end_header_id|>",
+      message.content,
+      "<|eot_id|>",
+    ].join("\n");
+  });
+
+  parts.push(
+    "<|begin_of_text|>",
+    ...mapped,
+    "<|start_header_id|>assistant<|end_header_id|>",
+  );
+  return parts.join("\n");
+};
+
+/**
+ * Documentation at https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-2
+ */
+export const mapChatMessagesToMetaLlama2Messages = <T extends ChatMessage>(
+  messages: T[],
+): string => {
+  const mapped = mapChatMessagesToMetaMessages(messages);
+  let output = "<s>";
+  let insideInst = false;
+  let needsStartAgain = false;
+  for (const message of mapped) {
+    if (needsStartAgain) {
+      output += "<s>";
+      needsStartAgain = false;
+    }
+    const text = message.content;
+    if (message.role === "system") {
+      if (!insideInst) {
+        output += "[INST] ";
+        insideInst = true;
+      }
+      output += `<<SYS>>\n${text}\n<</SYS>>\n`;
+    } else if (message.role === "user") {
+      output += text;
+      if (insideInst) {
+        output += " [/INST]";
+        insideInst = false;
+      }
+    } else if (message.role === "assistant") {
+      if (insideInst) {
+        output += " [/INST]";
+        insideInst = false;
+      }
+      output += ` ${text} </s>\n`;
+      needsStartAgain = true;
+    }
+  }
+  return output;
+};
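As a reading aid (not part of the patch): tracing `mapChatMessagesToMetaLlama3Messages` with a single user message and no tools produces the prompt below — note that the header fragments of each message are themselves newline-joined:

```
<|begin_of_text|>
<|start_header_id|>
user
<|end_header_id|>
What is 5 + 5?
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
```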
diff --git a/packages/community/src/llm/bedrock/provider.ts b/packages/community/src/llm/bedrock/provider.ts
index 43aaed8b9..1a4a6f973 100644
--- a/packages/community/src/llm/bedrock/provider.ts
+++ b/packages/community/src/llm/bedrock/provider.ts
@@ -23,6 +23,7 @@ export type BedrockChatStreamResponse = AsyncIterable<

 export abstract class Provider<ProviderStreamEvent extends {} = {}> {
   abstract getTextFromResponse(response: Record<string, any>): string;
+  // Return tool calls from non-streaming calls
   abstract getToolsFromResponse<T extends {} = {}>(
     response: Record<string, any>,
   ): T[];
diff --git a/packages/community/src/llm/bedrock/providers/index.ts b/packages/community/src/llm/bedrock/providers/index.ts
deleted file mode 100644
index 01ba640d5..000000000
--- a/packages/community/src/llm/bedrock/providers/index.ts
+++ /dev/null
@@ -1,9 +0,0 @@
-import { Provider } from "../provider";
-import { AnthropicProvider } from "./anthropic";
-import { MetaProvider } from "./meta";
-
-// Other providers should go here
-export const PROVIDERS: { [key: string]: Provider } = {
-  anthropic: new AnthropicProvider(),
-  meta: new MetaProvider(),
-};
diff --git a/packages/community/src/llm/bedrock/providers/meta.ts b/packages/community/src/llm/bedrock/providers/meta.ts
deleted file mode 100644
index 2e19ec9a5..000000000
--- a/packages/community/src/llm/bedrock/providers/meta.ts
+++ /dev/null
@@ -1,69 +0,0 @@
-import type {
-  InvokeModelCommandInput,
-  InvokeModelWithResponseStreamCommandInput,
-} from "@aws-sdk/client-bedrock-runtime";
-import type { ChatMessage, LLMMetadata } from "@llamaindex/core/llms";
-import type { MetaNoneStreamingResponse, MetaStreamEvent } from "../types";
-import {
-  mapChatMessagesToMetaLlama2Messages,
-  mapChatMessagesToMetaLlama3Messages,
-  toUtf8,
-} from "../utils";
-
-import { Provider } from "../provider";
-
-export class MetaProvider extends Provider<MetaStreamEvent> {
-  constructor() {
-    super();
-  }
-
-  getResultFromResponse(
-    response: Record<string, any>,
-  ): MetaNoneStreamingResponse {
-    return JSON.parse(toUtf8(response.body));
-  }
-
-  getToolsFromResponse(_response: Record<string, any>): never {
-    throw new Error("Not supported by this provider.");
-  }
-
-  getTextFromResponse(response: Record<string, any>): string {
-    const result = this.getResultFromResponse(response);
-    return result.generation;
-  }
-
-  getTextFromStreamResponse(response: Record<string, any>): string {
-    const event = this.getStreamingEventResponse(response);
-    if (event?.generation) {
-      return event.generation;
-    }
-    return "";
-  }
-
-  getRequestBody<T extends ChatMessage>(
-    metadata: LLMMetadata,
-    messages: T[],
-  ): InvokeModelCommandInput | InvokeModelWithResponseStreamCommandInput {
-    let promptFunction: (messages: ChatMessage[]) => string;
-
-    if (metadata.model.startsWith("meta.llama3")) {
-      promptFunction = mapChatMessagesToMetaLlama3Messages;
-    } else if (metadata.model.startsWith("meta.llama2")) {
-      promptFunction = mapChatMessagesToMetaLlama2Messages;
-    } else {
-      throw new Error(`Meta model ${metadata.model} is not supported`);
-    }
-
-    return {
-      modelId: metadata.model,
-      contentType: "application/json",
-      accept: "application/json",
-      body: JSON.stringify({
-        prompt: promptFunction(messages),
-        max_gen_len: metadata.maxTokens,
-        temperature: metadata.temperature,
-        top_p: metadata.topP,
-      }),
-    };
-  }
-}
diff --git a/packages/community/src/llm/bedrock/types.ts b/packages/community/src/llm/bedrock/types.ts
index a72554c73..86a18f3e2 100644
--- a/packages/community/src/llm/bedrock/types.ts
+++ b/packages/community/src/llm/bedrock/types.ts
@@ -1,165 +1,11 @@
-type Usage = {
-  input_tokens: number;
-  output_tokens: number;
-};
-
-type Message = {
-  id: string;
-  type: string;
-  role: string;
-  content: string[];
-  model: string;
-  stop_reason: string | null;
-  stop_sequence: string | null;
-  usage: Usage;
-};
-
-export type ToolBlock = {
-  id: string;
-  input: unknown;
-  name: string;
-  type: "tool_use";
-};
-
-export type TextBlock = {
-  type: "text";
-  text: string;
-};
-
-type ContentBlockStart = {
-  type: "content_block_start";
-  index: number;
-  content_block: ToolBlock | TextBlock;
-};
-
-type Delta =
-  | {
-      type: "text_delta";
-      text: string;
-    }
-  | {
-      type: "input_json_delta";
-      partial_json: string;
-    };
-
-type ContentBlockDelta = {
-  type: "content_block_delta";
-  index: number;
-  delta: Delta;
-};
-
-type ContentBlockStop = {
-  type: "content_block_stop";
-  index: number;
-};
-
-type MessageDelta = {
-  type: "message_delta";
-  delta: {
-    stop_reason: string;
-    stop_sequence: string | null;
-  };
-  usage: Usage;
-};
-
-type InvocationMetrics = {
+export type InvocationMetrics = {
   inputTokenCount: number;
   outputTokenCount: number;
   invocationLatency: number;
   firstByteLatency: number;
 };

-type MessageStop = {
-  type: "message_stop";
-  "amazon-bedrock-invocationMetrics": InvocationMetrics;
-};
-
 export type ToolChoice =
   | { type: "any" }
   | { type: "auto" }
   | { type: "tool"; name: string };
-
-export type AnthropicStreamEvent =
-  | { type: "message_start"; message: Message }
-  | ContentBlockStart
-  | ContentBlockDelta
-  | ContentBlockStop
-  | MessageDelta
-  | MessageStop;
-
-export type AnthropicContent =
-  | AnthropicTextContent
-  | AnthropicImageContent
-  | AnthropicToolContent
-  | AnthropicToolResultContent;
-
-export type MetaTextContent = string;
-
-export type AnthropicTextContent = {
-  type: "text";
-  text: string;
-};
-
-export type AnthropicToolContent = {
-  type: "tool_use";
-  id: string;
-  name: string;
-  input: Record<string, unknown>;
-};
-
-export type AnthropicToolResultContent = {
-  type: "tool_result";
-  tool_use_id: string;
-  content: string;
-};
-
-export type AnthropicMediaTypes =
-  | "image/jpeg"
-  | "image/png"
-  | "image/webp"
-  | "image/gif";
-
-export type AnthropicImageSource = {
-  type: "base64";
-  media_type: AnthropicMediaTypes;
-  data: string; // base64 encoded image bytes
-};
-
-export type AnthropicImageContent = {
-  type: "image";
-  source: AnthropicImageSource;
-};
-
-export type AnthropicMessage = {
-  role: "user" | "assistant";
-  content: AnthropicContent[];
-};
-
-export type MetaMessage = {
-  role: "user" | "assistant" | "system";
-  content: MetaTextContent;
-};
-
-export type AnthropicNoneStreamingResponse = {
-  id: string;
-  type: "message";
-  role: "assistant";
-  content: AnthropicContent[];
-  model: string;
-  stop_reason: "end_turn" | "max_tokens" | "stop_sequence";
-  stop_sequence?: string;
-  usage: { input_tokens: number; output_tokens: number };
-};
-
-type MetaResponse = {
-  generation: string;
-  prompt_token_count: number;
-  generation_token_count: number;
-  stop_reason: "stop" | "length";
-};
-
-export type MetaStreamEvent = MetaResponse & {
-  "amazon-bedrock-invocationMetrics": InvocationMetrics;
-};
-
-export type MetaNoneStreamingResponse = MetaResponse;
diff --git a/packages/community/src/llm/bedrock/utils.ts b/packages/community/src/llm/bedrock/utils.ts
index 4c79f994c..964651882 100644
--- a/packages/community/src/llm/bedrock/utils.ts
+++ b/packages/community/src/llm/bedrock/utils.ts
@@ -1,28 +1,7 @@
-import type { JSONObject } from "@llamaindex/core/global";
 import type {
-  BaseTool,
-  ChatMessage,
   MessageContent,
   MessageContentDetail,
-  MessageContentTextDetail,
-  ToolCallLLMMessageOptions,
-  ToolMetadata,
 } from "@llamaindex/core/llms";
-import type {
-  AnthropicContent,
-  AnthropicImageContent,
-  AnthropicMediaTypes,
-  AnthropicMessage,
-  AnthropicTextContent,
-  MetaMessage,
-} from "./types.js";
-
-const ACCEPTED_IMAGE_MIME_TYPES = [
-  "image/jpeg",
-  "image/png",
-  "image/webp",
-  "image/gif",
-];

 export const mapMessageContentToMessageContentDetails = (
   content: MessageContent,
@@ -30,252 +9,5 @@ export const mapMessageContentToMessageContentDetails = (
   return Array.isArray(content) ? content : [{ type: "text", text: content }];
 };

-export const mergeNeighboringSameRoleMessages = (
-  messages: AnthropicMessage[],
-): AnthropicMessage[] => {
-  return messages.reduce(
-    (result: AnthropicMessage[], current: AnthropicMessage, index: number) => {
-      if (index > 0 && messages[index - 1].role === current.role) {
-        result[result.length - 1].content = [
-          ...result[result.length - 1].content,
-          ...current.content,
-        ];
-      } else {
-        result.push(current);
-      }
-      return result;
-    },
-    [],
-  );
-};
-
-export const mapMessageContentDetailToAnthropicContent = <
-  T extends MessageContentDetail,
->(
-  detail: T,
-): AnthropicContent => {
-  let content: AnthropicContent;
-
-  if (detail.type === "text") {
-    content = mapTextContent(detail.text);
-  } else if (detail.type === "image_url") {
-    content = mapImageContent(detail.image_url.url);
-  } else {
-    throw new Error("Unsupported content detail type");
-  }
-  return content;
-};
-
-export const mapMessageContentToAnthropicContent = <T extends MessageContent>(
-  content: T,
-): AnthropicContent[] => {
-  return mapMessageContentToMessageContentDetails(content).map(
-    mapMessageContentDetailToAnthropicContent,
-  );
-};
-
-type AnthropicTool = {
-  name: string;
-  description: string;
-  input_schema: ToolMetadata["parameters"];
-};
-
-export const mapBaseToolsToAnthropicTools = (
-  tools?: BaseTool[],
-): AnthropicTool[] => {
-  if (!tools) return [];
-  return tools.map((tool: BaseTool) => {
-    const {
-      metadata: { parameters, ...options },
-    } = tool;
-    return {
-      ...options,
-      input_schema: parameters,
-    };
-  });
-};
-
-export const mapChatMessagesToAnthropicMessages = <
-  T extends ChatMessage<ToolCallLLMMessageOptions>,
->(
-  messages: T[],
-): AnthropicMessage[] => {
-  const mapped = messages
-    .flatMap((msg: T): AnthropicMessage[] => {
-      if (msg.options && "toolCall" in msg.options) {
-        return [
-          {
-            role: "assistant",
-            content: msg.options.toolCall.map((call) => ({
-              type: "tool_use",
-              id: call.id,
-              name: call.name,
-              input: call.input as JSONObject,
-            })),
-          },
-        ];
-      }
-      if (msg.options && "toolResult" in msg.options) {
-        return [
-          {
-            role: "user",
-            content: [
-              {
-                type: "tool_result",
-                tool_use_id: msg.options.toolResult.id,
-                content: msg.options.toolResult.result,
-              },
-            ],
-          },
-        ];
-      }
-      return mapMessageContentToMessageContentDetails(msg.content).map(
-        (detail: MessageContentDetail): AnthropicMessage => {
-          const content = mapMessageContentDetailToAnthropicContent(detail);
-
-          return {
-            role: msg.role === "assistant" ? "assistant" : "user",
-            content: [content],
-          };
-        },
-      );
-    })
-    .filter((message: AnthropicMessage) => {
-      const content = message.content[0];
-      if (content.type === "text" && !content.text) return false;
-      if (content.type === "image" && !content.source.data) return false;
-      if (content.type === "image" && message.role === "assistant")
-        return false;
-      return true;
-    });
-
-  return mergeNeighboringSameRoleMessages(mapped);
-};
-
-export const mapChatMessagesToMetaMessages = <T extends ChatMessage>(
-  messages: T[],
-): MetaMessage[] => {
-  return messages.map((msg) => {
-    let content: string = "";
-    if (typeof msg.content === "string") {
-      content = msg.content;
-    } else if (msg.content.length) {
-      content = (msg.content[0] as MessageContentTextDetail).text;
-    }
-    return {
-      role:
-        msg.role === "assistant"
-          ? "assistant"
-          : msg.role === "user"
-            ? "user"
-            : "system",
-      content,
-    };
-  });
-};
-
-/**
- * Documentation at https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3
- */
-export const mapChatMessagesToMetaLlama3Messages = <T extends ChatMessage>(
-  messages: T[],
-): string => {
-  const mapped = mapChatMessagesToMetaMessages(messages).map((message) => {
-    const text = message.content;
-    return `<|start_header_id|>${message.role}<|end_header_id|>\n${text}\n<|eot_id|>\n`;
-  });
-  return (
-    "<|begin_of_text|>" +
-    mapped.join("\n") +
-    "\n<|start_header_id|>assistant<|end_header_id|>\n"
-  );
-};
-
-/**
- * Documentation at https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-2
- */
-export const mapChatMessagesToMetaLlama2Messages = <T extends ChatMessage>(
-  messages: T[],
-): string => {
-  const mapped = mapChatMessagesToMetaMessages(messages);
-  let output = "<s>";
-  let insideInst = false;
-  let needsStartAgain = false;
-  for (const message of mapped) {
-    if (needsStartAgain) {
-      output += "<s>";
-      needsStartAgain = false;
-    }
-    const text = message.content;
-    if (message.role === "system") {
-      if (!insideInst) {
-        output += "[INST] ";
-        insideInst = true;
-      }
-      output += `<<SYS>>\n${text}\n<</SYS>>\n`;
-    } else if (message.role === "user") {
-      output += text;
-      if (insideInst) {
-        output += " [/INST]";
-        insideInst = false;
-      }
-    } else if (message.role === "assistant") {
-      if (insideInst) {
-        output += " [/INST]";
-        insideInst = false;
-      }
-      output += ` ${text} </s>\n`;
-      needsStartAgain = true;
-    }
-  }
-  return output;
-};
-
-export const mapTextContent = (text: string): AnthropicTextContent => {
-  return { type: "text", text };
-};
-
-export const extractDataUrlComponents = (
-  dataUrl: string,
-): {
-  mimeType: string;
-  base64: string;
-} => {
-  const parts = dataUrl.split(";base64,");
-
-  if (parts.length !== 2 || !parts[0].startsWith("data:")) {
-    throw new Error("Invalid data URL");
-  }
-
-  const mimeType = parts[0].slice(5);
-  const base64 = parts[1];
-
-  return {
-    mimeType,
-    base64,
-  };
-};
-
-export const mapImageContent = (imageUrl: string): AnthropicImageContent => {
-  if (!imageUrl.startsWith("data:"))
-    throw new Error(
-      "For Anthropic please only use base64 data url, e.g.: data:image/jpeg;base64,SGVsbG8sIFdvcmxkIQ==",
-    );
-  const { mimeType, base64: data } = extractDataUrlComponents(imageUrl);
-  if (!ACCEPTED_IMAGE_MIME_TYPES.includes(mimeType))
-    throw new Error(
-      `Anthropic only accepts the following mimeTypes: ${ACCEPTED_IMAGE_MIME_TYPES.join("\n")}`,
-    );
-
-  return {
-    type: "image",
-    source: {
-      type: "base64",
-      media_type: mimeType as AnthropicMediaTypes,
-      data,
-    },
-  };
-};
-
 export const toUtf8 = (input: Uint8Array): string =>
   new TextDecoder("utf-8").decode(input);
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 03cf0f81b..7d3119ed8 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -352,6 +352,9 @@ importers:
       '@llamaindex/core':
         specifier: workspace:*
         version: link:../core
+      '@llamaindex/env':
+        specifier: workspace:*
+        version: link:../env
     devDependencies:
       '@types/node':
         specifier: ^20.14.2
--
GitLab
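Finally, a minimal end-to-end sketch (not part of the patch) of calling the 405B model with tools through the plain `chat` API, which the `TOOL_CALL_MODELS` addition enables. AWS region/credentials are assumed to come from the environment, and the option shapes follow `@llamaindex/core`:

```ts
import { BEDROCK_MODELS, Bedrock } from "@llamaindex/community";
import { FunctionTool } from "llamaindex";

const sumNumbers = FunctionTool.from(
  ({ a, b }: { a: number; b: number }) => `${a + b}`,
  {
    name: "sumNumbers",
    description: "Use this function to sum two numbers",
    parameters: {
      type: "object",
      properties: {
        a: { type: "number", description: "The first number" },
        b: { type: "number", description: "The second number" },
      },
      required: ["a", "b"],
    },
  },
);

async function main() {
  const llm = new Bedrock({ model: BEDROCK_MODELS.META_LLAMA3_1_405B_INSTRUCT });

  const response = await llm.chat({
    messages: [{ role: "user", content: "How much is 5 + 5? Use a tool." }],
    tools: [sumNumbers],
  });

  // When the model decides to call a tool, the call surfaces in the message
  // options (as a toolCall entry) rather than as plain text.
  console.log(response.message.options);
}

main().catch(console.error);
```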