import type { ClientOptions } from "@anthropic-ai/sdk";
import { Anthropic as SDKAnthropic } from "@anthropic-ai/sdk";
import type {
BetaCacheControlEphemeral,
BetaTextBlockParam,
} from "@anthropic-ai/sdk/resources/beta/index";
import type { TextBlock } from "@anthropic-ai/sdk/resources/index";
import type {
MessageParam,
Model,
Tool,
ToolUseBlock,
} from "@anthropic-ai/sdk/resources/messages";
import { wrapLLMEvent } from "@llamaindex/core/decorator";
import type {
BaseTool,
ChatMessage,
ChatResponse,
ChatResponseChunk,
LLMChatParamsNonStreaming,
LLMChatParamsStreaming,
ToolCallLLMMessageOptions,
} from "@llamaindex/core/llms";
import { ToolCallLLM } from "@llamaindex/core/llms";
import { extractText } from "@llamaindex/core/utils";
import { getEnv } from "@llamaindex/env";
import { isDeepEqual } from "remeda";
export class AnthropicSession {
anthropic: SDKAnthropic;
constructor(options: ClientOptions = {}) {
if (!options.apiKey) {
options.apiKey = getEnv("ANTHROPIC_API_KEY");
}
if (!options.apiKey) {
throw new Error("Set Anthropic Key in ANTHROPIC_API_KEY env variable");
}
this.anthropic = new SDKAnthropic(options);
}
}
// I'm not 100% sure this is necessary vs. just starting a new session
// every time we make a call. They say they try to reuse connections
// so in theory this is more efficient, but we should test it in the future.
const defaultAnthropicSession: {
session: AnthropicSession;
options: ClientOptions;
}[] = [];
/**
* Get a session for the Anthropic API. If one already exists with the same options,
* it will be returned. Otherwise, a new session will be created.
* @param options
* @returns
*/
export function getAnthropicSession(options: ClientOptions = {}) {
let session = defaultAnthropicSession.find((session) => {
return isDeepEqual(session.options, options);
})?.session;
if (!session) {
session = new AnthropicSession(options);
defaultAnthropicSession.push({ session, options });
}
return session;
}
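// Usage sketch (key value is illustrative; pass apiKey explicitly so the cached
// options compare deep-equal on later calls):
// const a = getAnthropicSession({ apiKey: "sk-...", maxRetries: 2 });
// const b = getAnthropicSession({ apiKey: "sk-...", maxRetries: 2 });
// a === b; // true — same options, reused session
// const c = getAnthropicSession({ apiKey: "sk-...", maxRetries: 5 }); // new session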
export const ALL_AVAILABLE_ANTHROPIC_LEGACY_MODELS = {
"claude-2.1": {
contextWindow: 200000,
},
"claude-2.0": {
contextWindow: 100000,
},
"claude-instant-1.2": {
contextWindow: 100000,
},
};
export const ALL_AVAILABLE_V3_MODELS = {
"claude-3-opus": { contextWindow: 200000 },
"claude-3-opus-latest": { contextWindow: 200000 },
"claude-3-opus-20240229": { contextWindow: 200000 },
"claude-3-sonnet": { contextWindow: 200000 },
"claude-3-sonnet-20240229": { contextWindow: 200000 },
"claude-3-haiku": { contextWindow: 200000 },
"claude-3-haiku-20240307": { contextWindow: 200000 },
};
export const ALL_AVAILABLE_V3_5_MODELS = {
"claude-3-5-sonnet": { contextWindow: 200000 },
"claude-3-5-sonnet-20241022": { contextWindow: 200000 },
"claude-3-5-sonnet-20240620": { contextWindow: 200000 },
"claude-3-5-sonnet-latest": { contextWindow: 200000 },
"claude-3-5-haiku": { contextWindow: 200000 },
"claude-3-5-haiku-latest": { contextWindow: 200000 },
"claude-3-5-haiku-20241022": { contextWindow: 200000 },
};
export const ALL_AVAILABLE_ANTHROPIC_MODELS = {
...ALL_AVAILABLE_ANTHROPIC_LEGACY_MODELS,
...ALL_AVAILABLE_V3_MODELS,
...ALL_AVAILABLE_V3_5_MODELS,
} satisfies {
[key in Model]: { contextWindow: number };
};
const AVAILABLE_ANTHROPIC_MODELS_WITHOUT_DATE: { [key: string]: string } = {
"claude-3-opus": "claude-3-opus-20240229",
"claude-3-sonnet": "claude-3-sonnet-20240229",
"claude-3-haiku": "claude-3-haiku-20240307",
"claude-3-5-sonnet": "claude-3-5-sonnet-20240620",
} as { [key in keyof typeof ALL_AVAILABLE_ANTHROPIC_MODELS]: string };
export type AnthropicAdditionalChatOptions = object;
export type AnthropicToolCallLLMMessageOptions = ToolCallLLMMessageOptions & {
cache_control?: BetaCacheControlEphemeral | null;
};
export class Anthropic extends ToolCallLLM<
AnthropicAdditionalChatOptions,
AnthropicToolCallLLMMessageOptions
> {
// Per completion Anthropic params
model: keyof typeof ALL_AVAILABLE_ANTHROPIC_MODELS | ({} & string);
temperature: number;
topP: number;
maxTokens?: number | undefined;
// Anthropic session params
apiKey?: string | undefined;
maxRetries: number;
timeout?: number;
session: AnthropicSession;
constructor(init?: Partial<Anthropic>) {
super();
this.model = init?.model ?? "claude-3-opus";
this.temperature = init?.temperature ?? 0.1;
this.topP = init?.topP ?? 0.999; // Per Ben Mann
this.maxTokens = init?.maxTokens ?? undefined;
this.apiKey = init?.apiKey ?? undefined;
this.maxRetries = init?.maxRetries ?? 10;
this.timeout = init?.timeout ?? 60 * 1000; // Default is 60 seconds
this.session =
init?.session ??
getAnthropicSession({
apiKey: this.apiKey,
maxRetries: this.maxRetries,
timeout: this.timeout,
});
}
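// Minimal usage sketch (model and temperature values are illustrative, not defaults):
// const llm = new Anthropic({ model: "claude-3-5-sonnet", temperature: 0 });
// const response = await llm.chat({
//   messages: [{ role: "user", content: "Hello" }],
// });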
get supportToolCall() {
return this.model.startsWith("claude-3");
}
get metadata() {
return {
model: this.model,
temperature: this.temperature,
topP: this.topP,
maxTokens: this.maxTokens,
contextWindow:
this.model in ALL_AVAILABLE_ANTHROPIC_MODELS
? ALL_AVAILABLE_ANTHROPIC_MODELS[
this.model as keyof typeof ALL_AVAILABLE_ANTHROPIC_MODELS
].contextWindow
: 200000,
tokenizer: undefined,
};
}
getModelName = (model: string): string => {
if (Object.keys(AVAILABLE_ANTHROPIC_MODELS_WITHOUT_DATE).includes(model)) {
return AVAILABLE_ANTHROPIC_MODELS_WITHOUT_DATE[model]!;
}
return model;
};
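// e.g. getModelName("claude-3-5-sonnet") returns "claude-3-5-sonnet-20240620",
// while already-dated ids such as "claude-3-opus-20240229" pass through unchanged.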
formatMessages(
messages: ChatMessage<ToolCallLLMMessageOptions>[],
): MessageParam[] {
const formattedMessages = messages.flatMap((message) => {
const options = message.options ?? {};
if (message.role === "system") {
// Skip system messages
return [];
}
if ("toolCall" in options) {
const formattedMessage: MessageParam = {
role: "assistant",
content: [
{
type: "text" as const,
text: extractText(message.content),
},
...options.toolCall.map((tool) => ({
type: "tool_use" as const,
id: tool.id,
name: tool.name,
input:
typeof tool.input === "string"
? JSON.parse(tool.input)
: tool.input,
})),
],
};
return formattedMessage;
}
// Handle tool results
if ("toolResult" in options) {
const formattedMessage: MessageParam = {
role: "user",
content: [
{
type: "tool_result" as const,
tool_use_id: options.toolResult.id,
content: extractText(message.content),
},
],
};
return formattedMessage;
}
// Handle regular messages
if (typeof message.content === "string") {
const role: "user" | "assistant" =
message.role === "assistant" ? "assistant" : "user";
return {
role,
content: message.content,
} satisfies MessageParam;
}
// Handle multi-modal content
const role: "user" | "assistant" =
message.role === "assistant" ? "assistant" : "user";
return {
role,
content: message.content.map((content) => {
if (content.type === "text") {
return {
type: "text" as const,
text: content.text,
};
}
return {
type: "image" as const,
source: {
type: "base64" as const,
media_type: `image/${content.image_url.url.substring(
"data:image/".length,
content.image_url.url.indexOf(";base64"),
)}` as "image/jpeg" | "image/png" | "image/gif" | "image/webp",
data: content.image_url.url.substring(
content.image_url.url.indexOf(",") + 1,
),
},
};
}),
} satisfies MessageParam;
});
return this.mergeConsecutiveMessages(formattedMessages);
}
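// Mapping sketch (ids, tool names, and values are hypothetical): a message
// carrying options.toolCall becomes an assistant turn with text + tool_use
// blocks; the matching options.toolResult becomes a user turn holding a
// tool_result block that references the same id:
// llm.formatMessages([
//   { role: "assistant", content: "Checking the weather...", options: { toolCall: [{ id: "call_1", name: "getWeather", input: '{"city":"Paris"}' }] } },
//   { role: "user", content: "18°C and sunny", options: { toolResult: { id: "call_1" } } },
// ]);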
// Add helper method to prepare tools for API call
private prepareToolsForAPI(tools: BaseTool[]): Tool[] {
return tools.map((tool) => {
if (tool.metadata.parameters?.type !== "object") {
throw new TypeError("Tool parameters must be an object");
}
return {
input_schema: {
type: "object",
properties: tool.metadata.parameters.properties,
required: tool.metadata.parameters.required,
},
name: tool.metadata.name,
description: tool.metadata.description,
};
});
}
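// Input sketch (hypothetical tool): metadata.parameters must be a JSON-schema
// object such as
// { type: "object", properties: { city: { type: "string" } }, required: ["city"] }
// and is re-shaped into the API's { name, description, input_schema } format.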
private mergeConsecutiveMessages(messages: MessageParam[]): MessageParam[] {
const result: MessageParam[] = [];
for (let i = 0; i < messages.length; i++) {
// The first message has no predecessor to merge into.
if (i === 0) {
result.push(messages[i]!);
continue;
}
const current = messages[i]!;
const previous = result[result.length - 1]!;
if (current.role === previous.role) {
// Merge content based on type
if (Array.isArray(previous.content)) {
if (Array.isArray(current.content)) {
previous.content.push(...current.content);
} else {
previous.content.push({
type: "text",
text: current.content,
});
}
} else {
if (Array.isArray(current.content)) {
previous.content = [
{ type: "text", text: previous.content },
...current.content,
];
} else {
previous.content = `${previous.content}\n${current.content}`;
}
}
} else {
result.push(current);
}
}
return result;
}
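// Merge sketch: two consecutive same-role turns such as
// [{ role: "user", content: "a" }, { role: "user", content: "b" }]
// collapse to [{ role: "user", content: "a\nb" }], since the Messages API
// expects strictly alternating user/assistant roles.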
chat(
params: LLMChatParamsStreaming<
AnthropicAdditionalChatOptions,
AnthropicToolCallLLMMessageOptions
>,
): Promise<
AsyncIterable<ChatResponseChunk<AnthropicToolCallLLMMessageOptions>>
>;
chat(
params: LLMChatParamsNonStreaming<
AnthropicAdditionalChatOptions,
AnthropicToolCallLLMMessageOptions
>,
): Promise<ChatResponse<AnthropicToolCallLLMMessageOptions>>;
@wrapLLMEvent
async chat(
params:
| LLMChatParamsNonStreaming<
AnthropicAdditionalChatOptions,
AnthropicToolCallLLMMessageOptions
>
| LLMChatParamsStreaming<
AnthropicAdditionalChatOptions,
AnthropicToolCallLLMMessageOptions
>,
): Promise<
| ChatResponse<AnthropicToolCallLLMMessageOptions>
| AsyncIterable<ChatResponseChunk<AnthropicToolCallLLMMessageOptions>>
> {
const { messages, stream, tools } = params;
// Handle system messages
let systemPrompt: string | BetaTextBlockParam[] | null = null;
const systemMessages = messages.filter(
(message) => message.role === "system",
);
if (systemMessages.length > 0) {
systemPrompt = systemMessages.map((message): BetaTextBlockParam => {
const textContent = extractText(message.content);
if (message.options && "cache_control" in message.options) {
return {
type: "text" as const,
text: textContent,
cache_control: message.options
.cache_control as BetaCacheControlEphemeral,
};
}
return {
type: "text" as const,
text: textContent,
};
});
}
const beta =
Array.isArray(systemPrompt) &&
systemPrompt.some((message) => "cache_control" in message);
let anthropic = this.session.anthropic;
if (beta) {
// @ts-expect-error type casting
anthropic = anthropic.beta.promptCaching;
}
if (stream) {
if (tools) {
console.error("Tools are not supported in streaming mode");
}
return this.streamChat(
messages.filter((m) => m.role !== "system"),
systemPrompt,
anthropic,
);
}
const apiParams = {
model: this.getModelName(this.model),
messages: this.mergeConsecutiveMessages(
this.formatMessages(messages.filter((m) => m.role !== "system")),
),
max_tokens: this.maxTokens ?? 4096,
temperature: this.temperature,
top_p: this.topP,
...(systemPrompt && { system: systemPrompt }),
};
if (tools?.length) {
Object.assign(apiParams, {
tools: this.prepareToolsForAPI(tools),
});
}
const response = await anthropic.messages.create(apiParams);
const toolUseBlock = response.content.filter(
(content): content is ToolUseBlock => content.type === "tool_use",
);
return {
raw: response,
message: {
content: response.content
.filter((content): content is TextBlock => content.type === "text")
.map((content) => ({
type: "text" as const,
text: content.text,
})),
role: "assistant",
options:
toolUseBlock.length > 0
? {
toolCall: toolUseBlock.map((block) => ({
id: block.id,
name: block.name,
input: JSON.stringify(block.input),
})),
}
: {},
},
};
}
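// Sketch: reading tool calls from a non-streaming response (myTool is a
// hypothetical BaseTool):
// const res = await llm.chat({ messages, tools: [myTool] });
// if ("toolCall" in res.message.options) {
//   for (const call of res.message.options.toolCall) {
//     console.log(call.name, call.input); // input is a JSON string here
//   }
// }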
protected async *streamChat(
messages: ChatMessage<AnthropicToolCallLLMMessageOptions>[],
systemPrompt: string | Array<BetaTextBlockParam> | null,
anthropic: SDKAnthropic,
): AsyncIterable<ChatResponseChunk<AnthropicToolCallLLMMessageOptions>> {
const stream = await anthropic.messages.create({
model: this.getModelName(this.model),
messages: this.formatMessages(messages),
max_tokens: this.maxTokens ?? 4096,
temperature: this.temperature,
top_p: this.topP,
stream: true,
...(systemPrompt && { system: systemPrompt }),
});
let idx_counter: number = 0;
for await (const part of stream) {
const content =
part.type === "content_block_delta"
? part.delta.type === "text_delta"
? part.delta.text
: part.delta
: undefined;
if (typeof content !== "string") continue;
idx_counter++;
yield {
raw: part,
delta: content,
options: {},
};
}
return;
}
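// Consumption sketch (llm and messages as in the earlier examples):
// const stream = await llm.chat({ messages, stream: true });
// for await (const chunk of stream) {
//   process.stdout.write(chunk.delta); // text deltas only; other parts are skipped
// }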
static toTool(tool: BaseTool): Tool {
if (tool.metadata.parameters?.type !== "object") {
throw new TypeError("Tool parameters must be an object");
}
return {
input_schema: {
type: "object",
properties: tool.metadata.parameters.properties,
required: tool.metadata.parameters.required,
},
name: tool.metadata.name,
description: tool.metadata.description,
};
}
}