Unverified Commit 2774e802 authored by Parham Saidi, committed by GitHub

feat: Meta Llama 3.2 via bedrock (#1285)

parent 449274ca
---
"@llamaindex/community": patch
---
feat: added Meta Llama 3.2 support via Bedrock, including vision, tool calls, and cross-region inference support
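
For context, a minimal usage sketch of the feature this changeset describes, based on the Bedrock class and model IDs in the diff below (the region value is an assumption; credentials follow AWS SDK defaults):

import { Bedrock, BEDROCK_MODELS } from "@llamaindex/community";

// Llama 3.2 via Bedrock; the model ID comes from the BEDROCK_MODELS map added below.
const llm = new Bedrock({
  model: BEDROCK_MODELS.META_LLAMA3_2_3B_INSTRUCT,
  region: "us-east-1", // assumption: any region that offers the model works
});

const response = await llm.chat({
  messages: [{ role: "user", content: "Hello from Llama 3.2 on Bedrock!" }],
});
console.log(response.message.content);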
@@ -5,9 +5,11 @@
 ## Current Features:
 - Bedrock support for the Anthropic Claude Models [usage](https://ts.llamaindex.ai/modules/llms/available_llms/bedrock)
-- Bedrock support for the Meta LLama 2, 3 and 3.1 Models [usage](https://ts.llamaindex.ai/modules/llms/available_llms/bedrock)
-- Meta LLama3.1 405b tool call support
+- Bedrock support for the Meta LLama 2, 3, 3.1 and 3.2 Models [usage](https://ts.llamaindex.ai/modules/llms/available_llms/bedrock)
+- Meta LLama3.1 405b and Llama3.2 tool call support
+- Meta 3.2 11B and 90B vision support
 - Bedrock support for querying Knowledge Base
+- Bedrock: [Supported Regions and models for cross-region inference](https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-support.html)
 ## LICENSE
...
@@ -2,5 +2,7 @@ export {
   BEDROCK_MODELS,
   BEDROCK_MODEL_MAX_TOKENS,
   Bedrock,
+  INFERENCE_BEDROCK_MODELS,
+  INFERENCE_TO_BEDROCK_MAP,
 } from "./llm/bedrock/index.js";
 export { AmazonKnowledgeBaseRetriever } from "./retrievers/bedrock.js";
@@ -6,7 +6,10 @@ import type {
   MessageContentDetail,
   ToolCallLLMMessageOptions,
 } from "@llamaindex/core/llms";
-import { mapMessageContentToMessageContentDetails } from "../utils";
+import {
+  extractDataUrlComponents,
+  mapMessageContentToMessageContentDetails,
+} from "../utils";
 import type {
   AnthropicContent,
   AnthropicImageContent,
@@ -143,27 +146,6 @@ export const mapTextContent = (text: string): AnthropicTextContent => {
   return { type: "text", text };
 };

-export const extractDataUrlComponents = (
-  dataUrl: string,
-): {
-  mimeType: string;
-  base64: string;
-} => {
-  const parts = dataUrl.split(";base64,");
-  if (parts.length !== 2 || !parts[0]!.startsWith("data:")) {
-    throw new Error("Invalid data URL");
-  }
-  const mimeType = parts[0]!.slice(5);
-  const base64 = parts[1]!;
-  return {
-    mimeType,
-    base64,
-  };
-};
-
 export const mapImageContent = (imageUrl: string): AnthropicImageContent => {
   if (!imageUrl.startsWith("data:"))
     throw new Error(
...
@@ -47,35 +47,96 @@ export type BedrockChatParamsNonStreaming = LLMChatParamsNonStreaming<
 export type BedrockChatNonStreamResponse =
   ChatResponse<ToolCallLLMMessageOptions>;

-export enum BEDROCK_MODELS {
-  AMAZON_TITAN_TG1_LARGE = "amazon.titan-tg1-large",
-  AMAZON_TITAN_TEXT_EXPRESS_V1 = "amazon.titan-text-express-v1",
-  AI21_J2_GRANDE_INSTRUCT = "ai21.j2-grande-instruct",
-  AI21_J2_JUMBO_INSTRUCT = "ai21.j2-jumbo-instruct",
-  AI21_J2_MID = "ai21.j2-mid",
-  AI21_J2_MID_V1 = "ai21.j2-mid-v1",
-  AI21_J2_ULTRA = "ai21.j2-ultra",
-  AI21_J2_ULTRA_V1 = "ai21.j2-ultra-v1",
-  COHERE_COMMAND_TEXT_V14 = "cohere.command-text-v14",
-  ANTHROPIC_CLAUDE_INSTANT_1 = "anthropic.claude-instant-v1",
-  ANTHROPIC_CLAUDE_1 = "anthropic.claude-v1", // EOF: No longer supported
-  ANTHROPIC_CLAUDE_2 = "anthropic.claude-v2",
-  ANTHROPIC_CLAUDE_2_1 = "anthropic.claude-v2:1",
-  ANTHROPIC_CLAUDE_3_SONNET = "anthropic.claude-3-sonnet-20240229-v1:0",
-  ANTHROPIC_CLAUDE_3_HAIKU = "anthropic.claude-3-haiku-20240307-v1:0",
-  ANTHROPIC_CLAUDE_3_OPUS = "anthropic.claude-3-opus-20240229-v1:0",
-  ANTHROPIC_CLAUDE_3_5_SONNET = "anthropic.claude-3-5-sonnet-20240620-v1:0",
-  META_LLAMA2_13B_CHAT = "meta.llama2-13b-chat-v1",
-  META_LLAMA2_70B_CHAT = "meta.llama2-70b-chat-v1",
-  META_LLAMA3_8B_INSTRUCT = "meta.llama3-8b-instruct-v1:0",
-  META_LLAMA3_70B_INSTRUCT = "meta.llama3-70b-instruct-v1:0",
-  META_LLAMA3_1_8B_INSTRUCT = "meta.llama3-1-8b-instruct-v1:0",
-  META_LLAMA3_1_70B_INSTRUCT = "meta.llama3-1-70b-instruct-v1:0",
-  META_LLAMA3_1_405B_INSTRUCT = "meta.llama3-1-405b-instruct-v1:0",
-  MISTRAL_7B_INSTRUCT = "mistral.mistral-7b-instruct-v0:2",
-  MISTRAL_MIXTRAL_7B_INSTRUCT = "mistral.mixtral-8x7b-instruct-v0:1",
-  MISTRAL_MIXTRAL_LARGE_2402 = "mistral.mistral-large-2402-v1:0",
-}
+export const BEDROCK_MODELS = {
+  AMAZON_TITAN_TG1_LARGE: "amazon.titan-tg1-large",
+  AMAZON_TITAN_TEXT_EXPRESS_V1: "amazon.titan-text-express-v1",
+  AI21_J2_GRANDE_INSTRUCT: "ai21.j2-grande-instruct",
+  AI21_J2_JUMBO_INSTRUCT: "ai21.j2-jumbo-instruct",
+  AI21_J2_MID: "ai21.j2-mid",
+  AI21_J2_MID_V1: "ai21.j2-mid-v1",
+  AI21_J2_ULTRA: "ai21.j2-ultra",
+  AI21_J2_ULTRA_V1: "ai21.j2-ultra-v1",
+  COHERE_COMMAND_TEXT_V14: "cohere.command-text-v14",
+  ANTHROPIC_CLAUDE_INSTANT_1: "anthropic.claude-instant-v1",
+  ANTHROPIC_CLAUDE_1: "anthropic.claude-v1", // EOF: No longer supported
+  ANTHROPIC_CLAUDE_2: "anthropic.claude-v2",
+  ANTHROPIC_CLAUDE_2_1: "anthropic.claude-v2:1",
+  ANTHROPIC_CLAUDE_3_SONNET: "anthropic.claude-3-sonnet-20240229-v1:0",
+  ANTHROPIC_CLAUDE_3_HAIKU: "anthropic.claude-3-haiku-20240307-v1:0",
+  ANTHROPIC_CLAUDE_3_OPUS: "anthropic.claude-3-opus-20240229-v1:0",
+  ANTHROPIC_CLAUDE_3_5_SONNET: "anthropic.claude-3-5-sonnet-20240620-v1:0",
+  META_LLAMA2_13B_CHAT: "meta.llama2-13b-chat-v1",
+  META_LLAMA2_70B_CHAT: "meta.llama2-70b-chat-v1",
+  META_LLAMA3_8B_INSTRUCT: "meta.llama3-8b-instruct-v1:0",
+  META_LLAMA3_70B_INSTRUCT: "meta.llama3-70b-instruct-v1:0",
+  META_LLAMA3_1_8B_INSTRUCT: "meta.llama3-1-8b-instruct-v1:0",
+  META_LLAMA3_1_70B_INSTRUCT: "meta.llama3-1-70b-instruct-v1:0",
+  META_LLAMA3_1_405B_INSTRUCT: "meta.llama3-1-405b-instruct-v1:0",
+  META_LLAMA3_2_1B_INSTRUCT: "meta.llama3-2-1b-instruct-v1:0",
+  META_LLAMA3_2_3B_INSTRUCT: "meta.llama3-2-3b-instruct-v1:0",
+  META_LLAMA3_2_11B_INSTRUCT: "meta.llama3-2-11b-instruct-v1:0",
+  META_LLAMA3_2_90B_INSTRUCT: "meta.llama3-2-90b-instruct-v1:0",
+  MISTRAL_7B_INSTRUCT: "mistral.mistral-7b-instruct-v0:2",
+  MISTRAL_MIXTRAL_7B_INSTRUCT: "mistral.mixtral-8x7b-instruct-v0:1",
+  MISTRAL_MIXTRAL_LARGE_2402: "mistral.mistral-large-2402-v1:0",
+};
+export type BEDROCK_MODELS =
+  (typeof BEDROCK_MODELS)[keyof typeof BEDROCK_MODELS];
+
+export const INFERENCE_BEDROCK_MODELS = {
+  US_ANTHROPIC_CLAUDE_3_HAIKU: "us.anthropic.claude-3-haiku-20240307-v1:0",
+  US_ANTHROPIC_CLAUDE_3_OPUS: "us.anthropic.claude-3-opus-20240229-v1:0",
+  US_ANTHROPIC_CLAUDE_3_SONNET: "us.anthropic.claude-3-sonnet-20240229-v1:0",
+  US_ANTHROPIC_CLAUDE_3_5_SONNET:
+    "us.anthropic.claude-3-5-sonnet-20240620-v1:0",
+  US_META_LLAMA_3_2_1B_INSTRUCT: "us.meta.llama3-2-1b-instruct-v1:0",
+  US_META_LLAMA_3_2_3B_INSTRUCT: "us.meta.llama3-2-3b-instruct-v1:0",
+  US_META_LLAMA_3_2_11B_INSTRUCT: "us.meta.llama3-2-11b-instruct-v1:0",
+  US_META_LLAMA_3_2_90B_INSTRUCT: "us.meta.llama3-2-90b-instruct-v1:0",
+  EU_ANTHROPIC_CLAUDE_3_HAIKU: "eu.anthropic.claude-3-haiku-20240307-v1:0",
+  EU_ANTHROPIC_CLAUDE_3_SONNET: "eu.anthropic.claude-3-sonnet-20240229-v1:0",
+  EU_ANTHROPIC_CLAUDE_3_5_SONNET:
+    "eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
+  EU_META_LLAMA_3_2_1B_INSTRUCT: "eu.meta.llama3-2-1b-instruct-v1:0",
+  EU_META_LLAMA_3_2_3B_INSTRUCT: "eu.meta.llama3-2-3b-instruct-v1:0",
+};
+export type INFERENCE_BEDROCK_MODELS =
+  (typeof INFERENCE_BEDROCK_MODELS)[keyof typeof INFERENCE_BEDROCK_MODELS];
+
+export const INFERENCE_TO_BEDROCK_MAP: Record<
+  INFERENCE_BEDROCK_MODELS,
+  BEDROCK_MODELS
+> = {
+  [INFERENCE_BEDROCK_MODELS.US_ANTHROPIC_CLAUDE_3_HAIKU]:
+    BEDROCK_MODELS.ANTHROPIC_CLAUDE_3_HAIKU,
+  [INFERENCE_BEDROCK_MODELS.US_ANTHROPIC_CLAUDE_3_OPUS]:
+    BEDROCK_MODELS.ANTHROPIC_CLAUDE_3_OPUS,
+  [INFERENCE_BEDROCK_MODELS.US_ANTHROPIC_CLAUDE_3_SONNET]:
+    BEDROCK_MODELS.ANTHROPIC_CLAUDE_3_SONNET,
+  [INFERENCE_BEDROCK_MODELS.US_ANTHROPIC_CLAUDE_3_5_SONNET]:
+    BEDROCK_MODELS.ANTHROPIC_CLAUDE_3_5_SONNET,
+  [INFERENCE_BEDROCK_MODELS.US_META_LLAMA_3_2_1B_INSTRUCT]:
+    BEDROCK_MODELS.META_LLAMA3_2_1B_INSTRUCT,
+  [INFERENCE_BEDROCK_MODELS.US_META_LLAMA_3_2_3B_INSTRUCT]:
+    BEDROCK_MODELS.META_LLAMA3_2_3B_INSTRUCT,
+  [INFERENCE_BEDROCK_MODELS.US_META_LLAMA_3_2_11B_INSTRUCT]:
+    BEDROCK_MODELS.META_LLAMA3_2_11B_INSTRUCT,
+  [INFERENCE_BEDROCK_MODELS.US_META_LLAMA_3_2_90B_INSTRUCT]:
+    BEDROCK_MODELS.META_LLAMA3_2_90B_INSTRUCT,
+  [INFERENCE_BEDROCK_MODELS.EU_ANTHROPIC_CLAUDE_3_HAIKU]:
+    BEDROCK_MODELS.ANTHROPIC_CLAUDE_3_HAIKU,
+  [INFERENCE_BEDROCK_MODELS.EU_ANTHROPIC_CLAUDE_3_SONNET]:
+    BEDROCK_MODELS.ANTHROPIC_CLAUDE_3_SONNET,
+  [INFERENCE_BEDROCK_MODELS.EU_ANTHROPIC_CLAUDE_3_5_SONNET]:
+    BEDROCK_MODELS.ANTHROPIC_CLAUDE_3_5_SONNET,
+  [INFERENCE_BEDROCK_MODELS.EU_META_LLAMA_3_2_1B_INSTRUCT]:
+    BEDROCK_MODELS.META_LLAMA3_2_1B_INSTRUCT,
+  [INFERENCE_BEDROCK_MODELS.EU_META_LLAMA_3_2_3B_INSTRUCT]:
+    BEDROCK_MODELS.META_LLAMA3_2_3B_INSTRUCT,
+};
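
The new INFERENCE_BEDROCK_MODELS IDs are regional cross-region inference profiles; INFERENCE_TO_BEDROCK_MAP resolves each profile to its underlying foundation model so capability lookups (context window, streaming, tool calls) keep working. A small sketch of that resolution, using only the values from the diff above:

import {
  INFERENCE_BEDROCK_MODELS,
  INFERENCE_TO_BEDROCK_MAP,
} from "@llamaindex/community";

// "eu.meta.llama3-2-1b-instruct-v1:0" resolves to "meta.llama3-2-1b-instruct-v1:0"
const profile = INFERENCE_BEDROCK_MODELS.EU_META_LLAMA_3_2_1B_INSTRUCT;
const baseModel = INFERENCE_TO_BEDROCK_MAP[profile];
console.log(profile, "->", baseModel);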
 /*
  * Values taken from https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html#model-parameters-claude
@@ -109,6 +170,10 @@ const CHAT_ONLY_MODELS = {
   [BEDROCK_MODELS.META_LLAMA3_1_8B_INSTRUCT]: 128000,
   [BEDROCK_MODELS.META_LLAMA3_1_70B_INSTRUCT]: 128000,
   [BEDROCK_MODELS.META_LLAMA3_1_405B_INSTRUCT]: 128000,
+  [BEDROCK_MODELS.META_LLAMA3_2_1B_INSTRUCT]: 131000,
+  [BEDROCK_MODELS.META_LLAMA3_2_3B_INSTRUCT]: 131000,
+  [BEDROCK_MODELS.META_LLAMA3_2_11B_INSTRUCT]: 128000,
+  [BEDROCK_MODELS.META_LLAMA3_2_90B_INSTRUCT]: 128000,
   [BEDROCK_MODELS.MISTRAL_7B_INSTRUCT]: 32000,
   [BEDROCK_MODELS.MISTRAL_MIXTRAL_7B_INSTRUCT]: 32000,
   [BEDROCK_MODELS.MISTRAL_MIXTRAL_LARGE_2402]: 32000,
@@ -139,17 +204,25 @@ export const STREAMING_MODELS = new Set([
   BEDROCK_MODELS.META_LLAMA3_1_8B_INSTRUCT,
   BEDROCK_MODELS.META_LLAMA3_1_70B_INSTRUCT,
   BEDROCK_MODELS.META_LLAMA3_1_405B_INSTRUCT,
+  BEDROCK_MODELS.META_LLAMA3_2_1B_INSTRUCT,
+  BEDROCK_MODELS.META_LLAMA3_2_3B_INSTRUCT,
+  BEDROCK_MODELS.META_LLAMA3_2_11B_INSTRUCT,
+  BEDROCK_MODELS.META_LLAMA3_2_90B_INSTRUCT,
   BEDROCK_MODELS.MISTRAL_7B_INSTRUCT,
   BEDROCK_MODELS.MISTRAL_MIXTRAL_7B_INSTRUCT,
   BEDROCK_MODELS.MISTRAL_MIXTRAL_LARGE_2402,
 ]);

-export const TOOL_CALL_MODELS = [
+export const TOOL_CALL_MODELS: BEDROCK_MODELS[] = [
   BEDROCK_MODELS.ANTHROPIC_CLAUDE_3_SONNET,
   BEDROCK_MODELS.ANTHROPIC_CLAUDE_3_HAIKU,
   BEDROCK_MODELS.ANTHROPIC_CLAUDE_3_OPUS,
   BEDROCK_MODELS.ANTHROPIC_CLAUDE_3_5_SONNET,
   BEDROCK_MODELS.META_LLAMA3_1_405B_INSTRUCT,
+  BEDROCK_MODELS.META_LLAMA3_2_1B_INSTRUCT,
+  BEDROCK_MODELS.META_LLAMA3_2_3B_INSTRUCT,
+  BEDROCK_MODELS.META_LLAMA3_2_11B_INSTRUCT,
+  BEDROCK_MODELS.META_LLAMA3_2_90B_INSTRUCT,
 ];
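
With the 3.2 models added to TOOL_CALL_MODELS, tool calling follows the usual ToolCallLLM flow. A hedged sketch (the tool object follows the BaseTool shape from @llamaindex/core/llms; the name, description, and parameters here are illustrative, not part of this commit):

import { Bedrock, BEDROCK_MODELS } from "@llamaindex/community";

const llm = new Bedrock({ model: BEDROCK_MODELS.META_LLAMA3_2_11B_INSTRUCT });

// Hypothetical tool for illustration only.
const weatherTool = {
  call: async ({ city }: { city: string }) => `Sunny in ${city}`,
  metadata: {
    name: "get_weather",
    description: "Returns the weather for a city",
    parameters: {
      type: "object",
      properties: { city: { type: "string" } },
      required: ["city"],
    },
  },
};

const response = await llm.chat({
  messages: [{ role: "user", content: "What is the weather in Berlin?" }],
  tools: [weatherTool],
});
// If the model decides to call the tool, the parsed call surfaces in
// response.message.options.toolCall (see the MetaProvider changes below).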
 const getProvider = (model: string): Provider => {
@@ -166,7 +239,7 @@ const getProvider = (model: string): Provider => {
 };

 export type BedrockModelParams = {
-  model: keyof typeof BEDROCK_FOUNDATION_LLMS;
+  model: BEDROCK_MODELS | INFERENCE_BEDROCK_MODELS;
   temperature?: number;
   topP?: number;
   maxTokens?: number;
@@ -185,6 +258,10 @@ export const BEDROCK_MODEL_MAX_TOKENS: Partial<Record<BEDROCK_MODELS, number>> =
   [BEDROCK_MODELS.META_LLAMA3_1_8B_INSTRUCT]: 2048,
   [BEDROCK_MODELS.META_LLAMA3_1_70B_INSTRUCT]: 2048,
   [BEDROCK_MODELS.META_LLAMA3_1_405B_INSTRUCT]: 2048,
+  [BEDROCK_MODELS.META_LLAMA3_2_1B_INSTRUCT]: 2048,
+  [BEDROCK_MODELS.META_LLAMA3_2_3B_INSTRUCT]: 2048,
+  [BEDROCK_MODELS.META_LLAMA3_2_11B_INSTRUCT]: 2048,
+  [BEDROCK_MODELS.META_LLAMA3_2_90B_INSTRUCT]: 2048,
 };
 const DEFAULT_BEDROCK_PARAMS = {
@@ -193,14 +270,15 @@ const DEFAULT_BEDROCK_PARAMS = {
   maxTokens: 1024, // required by anthropic
 };

-export type BedrockParams = BedrockModelParams & BedrockRuntimeClientConfig;
+export type BedrockParams = BedrockRuntimeClientConfig & BedrockModelParams;

 /**
  * ToolCallLLM for Bedrock
  */
 export class Bedrock extends ToolCallLLM<BedrockAdditionalChatOptions> {
   private client: BedrockRuntimeClient;
-  model: keyof typeof BEDROCK_FOUNDATION_LLMS;
+  protected actualModel: BEDROCK_MODELS | INFERENCE_BEDROCK_MODELS;
+  model: BEDROCK_MODELS;
   temperature: number;
   topP: number;
   maxTokens?: number;
@@ -217,8 +295,8 @@ export class Bedrock extends ToolCallLLM<BedrockAdditionalChatOptions> {
     ...params
   }: BedrockParams) {
     super();
-    this.model = model;
+    this.actualModel = model;
+    this.model = INFERENCE_TO_BEDROCK_MAP[model] ?? model;
     this.provider = getProvider(this.model);
     this.maxTokens = maxTokens ?? DEFAULT_BEDROCK_PARAMS.maxTokens;
     this.temperature = temperature ?? DEFAULT_BEDROCK_PARAMS.temperature;
@@ -241,7 +319,7 @@ export class Bedrock extends ToolCallLLM<BedrockAdditionalChatOptions> {
       temperature: this.temperature,
       topP: this.topP,
       maxTokens: this.maxTokens,
-      contextWindow: BEDROCK_FOUNDATION_LLMS[this.model],
+      contextWindow: BEDROCK_FOUNDATION_LLMS[this.model] ?? 128000,
       tokenizer: undefined,
     };
   }
@@ -256,6 +334,8 @@ export class Bedrock extends ToolCallLLM<BedrockAdditionalChatOptions> {
       params.additionalChatOptions,
     );
     const command = new InvokeModelCommand(input);
+    command.input.modelId = this.actualModel;
+
     const response = await this.client.send(command);
     let options: ToolCallLLMMessageOptions = {};
     if (this.supportToolCall) {
@@ -287,6 +367,8 @@ export class Bedrock extends ToolCallLLM<BedrockAdditionalChatOptions> {
       params.additionalChatOptions,
     );
     const command = new InvokeModelWithResponseStreamCommand(input);
+    command.input.modelId = this.actualModel;
+
     const response = await this.client.send(command);
     if (response.body) yield* this.provider.reduceStream(response.body);
...
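
The constructor and the two chat paths above implement cross-region inference: the public model field is normalized to the base foundation model (used for provider selection and capability lookups), while actualModel keeps the inference-profile ID and is written into command.input.modelId so Bedrock routes the request through the regional profile. A sketch of the observable behavior under those assumptions:

import { Bedrock, INFERENCE_BEDROCK_MODELS } from "@llamaindex/community";

const llm = new Bedrock({
  model: INFERENCE_BEDROCK_MODELS.US_META_LLAMA_3_2_11B_INSTRUCT,
});

// Capability lookups use the mapped base model...
console.log(llm.model); // "meta.llama3-2-11b-instruct-v1:0"
// ...while requests are sent with modelId "us.meta.llama3-2-11b-instruct-v1:0".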
@@ -67,21 +67,26 @@ export class MetaProvider extends Provider<MetaStreamEvent> {
     for await (const response of stream) {
       const event = this.getStreamingEventResponse(response);
       const delta = this.getTextFromStreamResponse(response);

       // odd quirk of llama3.1, start token is \n\n
       if (
+        !toolId &&
         !event?.generation.trim() &&
         event?.generation_token_count === 1 &&
-        event.prompt_token_count !== null
+        event?.prompt_token_count !== null
       )
         continue;

-      if (delta === TOKENS.TOOL_CALL) {
+      if (delta.startsWith(TOKENS.TOOL_CALL)) {
         toolId = randomUUID();
+        const parts = delta.split(TOKENS.TOOL_CALL).filter((part) => part);
+        collecting.push(...parts);
         continue;
       }

       let options: undefined | ToolCallLLMMessageOptions = undefined;
       if (toolId && event?.stop_reason === "stop") {
+        if (delta) collecting.push(delta);
+
         const tool = JSON.parse(collecting.join(""));
         options = {
           toolCall: [
@@ -110,11 +115,18 @@ export class MetaProvider extends Provider<MetaStreamEvent> {
   getRequestBody<T extends ChatMessage>(
     metadata: LLMMetadata,
     messages: T[],
-    tools?: BaseTool[],
+    tools: BaseTool[] = [],
   ): InvokeModelCommandInput | InvokeModelWithResponseStreamCommandInput {
     let prompt: string = "";
+    let images: string[] = [];
+
     if (metadata.model.startsWith("meta.llama3")) {
-      prompt = mapChatMessagesToMetaLlama3Messages(messages, tools);
+      const mapped = mapChatMessagesToMetaLlama3Messages({
+        messages,
+        tools,
+        model: metadata.model,
+      });
+      prompt = mapped.prompt;
+      images = mapped.images;
     } else if (metadata.model.startsWith("meta.llama2")) {
       prompt = mapChatMessagesToMetaLlama2Messages(messages);
     } else {
@@ -127,6 +139,7 @@ export class MetaProvider extends Provider<MetaStreamEvent> {
       accept: "application/json",
       body: JSON.stringify({
         prompt,
+        images: images.length ? images : undefined,
         max_gen_len: metadata.maxTokens,
         temperature: metadata.temperature,
         top_p: metadata.topP,
...
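
Because getRequestBody now lifts base64 images out of the chat messages and into the Meta request body's images field, the 3.2 vision models (11B and 90B) can be driven with a standard image_url content part. A sketch, reusing the llm from the earlier example and assuming a PNG data URL (the payload is a placeholder):

const response = await llm.chat({
  messages: [
    {
      role: "user",
      content: [
        { type: "text", text: "Describe this image." },
        {
          type: "image_url",
          // Must be a data URL; extractDataUrlComponents (added in utils below)
          // strips the "data:image/png;base64," prefix before sending.
          image_url: { url: "data:image/png;base64,iVBORw0KGgo..." },
        },
      ],
    },
  ],
});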
 import type {
   BaseTool,
   ChatMessage,
+  LLMMetadata,
   MessageContentTextDetail,
   ToolCallLLMMessageOptions,
 } from "@llamaindex/core/llms";
+import { extractDataUrlComponents } from "../utils";
+import { TOKENS } from "./constants";
 import type { MetaMessage } from "./types";

 const getToolCallInstructionString = (tool: BaseTool): string => {
@@ -24,7 +27,7 @@ const getToolCallParametersString = (tool: BaseTool): string => {

 // ported from https://github.com/meta-llama/llama-agentic-system/blob/main/llama_agentic_system/system_prompt.py
 // NOTE: using json instead of the above xml style tool calling works more reliability
-export const getToolsPrompt = (tools?: BaseTool[]) => {
+export const getToolsPrompt_3_1 = (tools?: BaseTool[]) => {
   if (!tools?.length) return "";

   const customToolParams = tools.map((tool) => {
@@ -77,6 +80,46 @@ Reminder:
 `;
 };
+export const getToolsPrompt_3_2 = (tools?: BaseTool[]) => {
+  if (!tools?.length) return "";
+
+  return `
+You are an expert in composing functions. You are given a question and a set of possible functions.
+Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
+If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
+also point it out. You should only return the function call in tools call sections.
+
+If you decide to invoke any of the function(s), you MUST put it in the format of and start with the token: ${TOKENS.TOOL_CALL}:
+{
+  "name": function_name,
+  "parameters": parameters,
+}
+
+where
+{
+  "name": function_name,
+  "parameters": parameters, => a JSON dict with the function argument name as key and function argument value as value.
+}
+
+Here is an example,
+{
+  "name": "example_function_name",
+  "parameters": {"example_name": "example_value"}
+}
+
+Reminder:
+- Function calls MUST follow the specified format
+- Required parameters MUST be specified
+- Only call one function at a time
+- You SHOULD NOT include any other text in the response
+- Put the entire function call reply on one line
+
+Here is a list of functions in JSON format that you can invoke.
+${JSON.stringify(tools)}
+`;
+};
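
The prompt above instructs a 3.2 model to prefix its function call with the TOKENS.TOOL_CALL sentinel and emit a single-line JSON object; the streaming handler shown earlier splits on that sentinel and JSON.parses the collected text. A minimal sketch of that contract (the sentinel's literal value lives in ./constants and is assumed here, not quoted from this diff):

// The model is expected to emit, after the sentinel, a single line such as:
const raw = '{"name": "get_weather", "parameters": {"city": "Berlin"}}';
const tool = JSON.parse(raw) as {
  name: string;
  parameters: Record<string, unknown>;
};
console.log(tool.name, tool.parameters); // get_weather { city: "Berlin" }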
 export const mapChatRoleToMetaRole = (
   role: ChatMessage["role"],
 ): MetaMessage["role"] => {
@@ -125,16 +168,46 @@ export const mapChatMessagesToMetaMessages = <
 /**
  * Documentation at https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3
  */
-export const mapChatMessagesToMetaLlama3Messages = <T extends ChatMessage>(
-  messages: T[],
-  tools?: BaseTool[],
-): string => {
+export const mapChatMessagesToMetaLlama3Messages = <T extends ChatMessage>({
+  messages,
+  model,
+  tools,
+}: {
+  messages: T[];
+  model: LLMMetadata["model"];
+  tools?: BaseTool[];
+}): { prompt: string; images: string[] } => {
+  const images: string[] = [];
+  const textMessages: T[] = [];
+
+  messages.forEach((message) => {
+    if (Array.isArray(message.content)) {
+      message.content.forEach((content) => {
+        if (content.type === "image_url") {
+          const { base64 } = extractDataUrlComponents(content.image_url.url);
+          images.push(base64);
+        } else {
+          textMessages.push(message);
+        }
+      });
+    } else {
+      textMessages.push(message);
+    }
+  });
+
   const parts: string[] = [];
-  parts.push(
-    "<|begin_of_text|>",
-    "<|start_header_id|>system<|end_header_id|>",
-    getToolsPrompt(tools),
-    "<|eot_id|>",
-  );
+
+  if (tools?.length) {
+    let toolsPrompt = "";
+    if (model.startsWith("meta.llama3-2")) {
+      toolsPrompt = getToolsPrompt_3_2(tools);
+    } else if (model.startsWith("meta.llama3-1")) {
+      toolsPrompt = getToolsPrompt_3_1(tools);
+    }
+
+    if (toolsPrompt) {
+      parts.push(
+        "<|begin_of_text|>",
+        "<|start_header_id|>system<|end_header_id|>",
+        toolsPrompt,
+        "<|eot_id|>",
+      );
+    }
+  }
@@ -154,7 +227,9 @@ export const mapChatMessagesToMetaLlama3Messages = <T extends ChatMessage>(
     ...mapped,
     "<|start_header_id|>assistant<|end_header_id|>",
   );
-  return parts.join("\n");
+
+  const prompt = parts.join("\n");
+  return { prompt, images };
 };

 /**
...
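
For reference, the mapper now returns both the assembled prompt and the extracted images, following Meta's documented Llama 3 chat template: an optional system block carrying the tools prompt, each message wrapped in header/eot tokens, and a trailing assistant header so the model continues from there. A sketch of the new object-style signature, assuming direct access to the helper (the exact token layout of the output comes partly from code elided in this diff):

const { prompt, images } = mapChatMessagesToMetaLlama3Messages({
  messages: [{ role: "user", content: "Hi" }],
  model: "meta.llama3-2-11b-instruct-v1:0",
});
// prompt resembles (simplified):
// <|start_header_id|>user<|end_header_id|>
// Hi
// <|eot_id|>
// <|start_header_id|>assistant<|end_header_id|>
// images is [] here, since no image_url content parts were supplied.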
@@ -11,3 +11,24 @@ export const mapMessageContentToMessageContentDetails = (
 export const toUtf8 = (input: Uint8Array): string =>
   new TextDecoder("utf-8").decode(input);
+
+export const extractDataUrlComponents = (
+  dataUrl: string,
+): {
+  mimeType: string;
+  base64: string;
+} => {
+  const parts = dataUrl.split(";base64,");
+
+  if (parts.length !== 2 || !parts[0]!.startsWith("data:")) {
+    throw new Error("Invalid data URL");
+  }
+
+  const mimeType = parts[0]!.slice(5);
+  const base64 = parts[1]!;
+  return {
+    mimeType,
+    base64,
+  };
+};
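
extractDataUrlComponents moves into the shared utils so both the Anthropic and Meta providers can split a data URL into its MIME type and base64 payload. A usage sketch with a placeholder payload:

const { mimeType, base64 } = extractDataUrlComponents(
  "data:image/jpeg;base64,/9j/4AAQSkZJRg==",
);
// mimeType === "image/jpeg"; base64 === "/9j/4AAQSkZJRg=="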