Unverified commit d8d952d9 authored by Thuc Pham, committed by GitHub

feat: init gemini llm (#769)

parent 216ba1f2
---
"llamaindex": patch
---
feat: add gemini llm and embedding
# Gemini Embedding
To use Gemini embeddings, you need to import `GeminiEmbedding` from `llamaindex`.
```ts
import {
  Document,
  GeminiEmbedding,
  Settings,
  VectorStoreIndex,
} from "llamaindex";

// Update Embed Model
Settings.embedModel = new GeminiEmbedding();

const document = new Document({ text: essay, id_: "essay" });
const index = await VectorStoreIndex.fromDocuments([document]);

const queryEngine = index.asQueryEngine();
const query = "What is the meaning of life?";
const results = await queryEngine.query({
  query,
});
```
By default, `GeminiEmbedding` uses the `gemini-pro` model. You can change the model by passing the `model` parameter to the constructor, for example:
```ts
import { GEMINI_MODEL, GeminiEmbedding, Settings } from "llamaindex";

Settings.embedModel = new GeminiEmbedding({
  model: GEMINI_MODEL.GEMINI_PRO_LATEST,
});
```
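If you only need raw vectors without building an index, `GeminiEmbedding` also exposes `getTextEmbedding` directly (see the `GeminiEmbedding` class in the diff below). A minimal sketch, assuming `GOOGLE_API_KEY` is set in the environment:
```ts
import { GeminiEmbedding } from "llamaindex";

// Standalone usage sketch; requires GOOGLE_API_KEY in the environment.
const embedModel = new GeminiEmbedding();
const vector = await embedModel.getTextEmbedding("What is the meaning of life?");
console.log(`Embedding has ${vector.length} dimensions`);
```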
# Gemini
## Usage
```ts
import { Gemini, Settings, GEMINI_MODEL } from "llamaindex";
Settings.llm = new Gemini({
model: GEMINI_MODEL.GEMINI_PRO,
});
```
## Load and index documents
For this example, we will use a single document. In a real-world scenario, you would have multiple documents to index.
```ts
const document = new Document({ text: essay, id_: "essay" });
const index = await VectorStoreIndex.fromDocuments([document]);
```
## Query
```ts
const queryEngine = index.asQueryEngine();
const query = "What is the meaning of life?";
const results = await queryEngine.query({
query,
});
```
## Full Example
```ts
import {
Gemini,
Document,
VectorStoreIndex,
Settings,
GEMINI_MODEL,
} from "llamaindex";
Settings.llm = new Gemini({
model: GEMINI_MODEL.GEMINI_PRO,
});
async function main() {
const document = new Document({ text: essay, id_: "essay" });
// Load and index documents
const index = await VectorStoreIndex.fromDocuments([document]);
// Create a query engine
const queryEngine = index.asQueryEngine();
const query = "What is the meaning of life?";
// Query
const response = await queryEngine.query({
query,
});
// Log the response
console.log(response.response);
}

main().catch(console.error);
```
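The snippets above assume an `essay` string is already in scope. A minimal sketch of loading it from a local file (the `node:fs` import and the `essay.txt` path are illustrative assumptions, not part of this commit):
```ts
import fs from "node:fs/promises";

// Assumed location of the source text; adjust to your project layout.
const essay = await fs.readFile("essay.txt", "utf-8");
```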
import { GEMINI_MODEL, GeminiEmbedding } from "llamaindex";
async function main() {
if (!process.env.GOOGLE_API_KEY) {
throw new Error("Please set the GOOGLE_API_KEY environment variable.");
}
const embedModel = new GeminiEmbedding({
model: GEMINI_MODEL.GEMINI_PRO,
});
const texts = ["hello", "world"];
const embeddings = await embedModel.getTextEmbeddingsBatch(texts);
console.log(`\nWe have ${embeddings.length} embeddings`);
}
main().catch(console.error);
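The `GEMINI_MODEL` enum added in this commit also includes an `EMBEDDING_001` entry (see `gemini.ts` below). A sketch of selecting it explicitly, on the assumption that the Gemini embedding endpoint accepts that model name:
```ts
import { GEMINI_MODEL, GeminiEmbedding } from "llamaindex";

// Assumption: embedding-001 is the dedicated embedding model from the enum.
const embedModel = new GeminiEmbedding({
  model: GEMINI_MODEL.EMBEDDING_001,
});
const [hello, world] = await embedModel.getTextEmbeddingsBatch(["hello", "world"]);
console.log(hello.length, world.length);
```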
import { Gemini, GEMINI_MODEL } from "llamaindex";
(async () => {
if (!process.env.GOOGLE_API_KEY) {
throw new Error("Please set the GOOGLE_API_KEY environment variable.");
}
const gemini = new Gemini({
model: GEMINI_MODEL.GEMINI_PRO,
});
const result = await gemini.chat({
messages: [
{ content: "You want to talk in rhymes.", role: "system" },
{
content:
"How much wood would a woodchuck chuck if a woodchuck could chuck wood?",
role: "user",
},
],
});
console.log(result);
})();
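`Gemini.chat` also has a streaming overload (see the `chat` signatures in `gemini.ts` below). A minimal sketch of streaming a prompt, assuming `GOOGLE_API_KEY` is set and using the chunk shape (`text`, `raw`) produced by `streamChat` in this commit:
```ts
import { GEMINI_MODEL, Gemini } from "llamaindex";

const gemini = new Gemini({ model: GEMINI_MODEL.GEMINI_PRO });

// stream: true selects the streaming overload, which yields chunks as they arrive.
const stream = await gemini.chat({
  stream: true,
  messages: [{ content: "Tell me a short rhyme about wood.", role: "user" }],
});
for await (const chunk of stream) {
  process.stdout.write(chunk.text);
}
```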
@@ -8,6 +8,7 @@
     "@anthropic-ai/sdk": "^0.20.6",
     "@aws-crypto/sha256-js": "^5.2.0",
     "@datastax/astra-db-ts": "^1.0.1",
+    "@google/generative-ai": "^0.8.0",
     "@grpc/grpc-js": "^1.10.6",
     "@llamaindex/cloud": "0.0.5",
     "@llamaindex/env": "workspace:*",
...
import {
GEMINI_MODEL,
GeminiSessionStore,
type GeminiConfig,
type GeminiSession,
} from "../llm/gemini.js";
import { BaseEmbedding } from "./types.js";
/**
 * GeminiEmbedding implements BaseEmbedding using the Gemini embedding API.
 */
export class GeminiEmbedding extends BaseEmbedding {
model: GEMINI_MODEL;
temperature: number;
topP: number;
maxTokens?: number;
session: GeminiSession;
constructor(init?: GeminiConfig) {
super();
this.model = init?.model ?? GEMINI_MODEL.GEMINI_PRO;
this.temperature = init?.temperature ?? 0.1;
this.topP = init?.topP ?? 1;
this.maxTokens = init?.maxTokens ?? undefined;
this.session = init?.session ?? GeminiSessionStore.get();
}
private async getEmbedding(prompt: string): Promise<number[]> {
const client = this.session.gemini.getGenerativeModel({
model: this.model,
});
const result = await client.embedContent(prompt);
return result.embedding.values;
}
getTextEmbedding(text: string): Promise<number[]> {
return this.getEmbedding(text);
}
getQueryEmbedding(query: string): Promise<number[]> {
return this.getTextEmbedding(query);
}
}
export * from "./ClipEmbedding.js"; export * from "./ClipEmbedding.js";
export * from "./GeminiEmbedding.js";
export * from "./HuggingFaceEmbedding.js"; export * from "./HuggingFaceEmbedding.js";
export * from "./JinaAIEmbedding.js"; export * from "./JinaAIEmbedding.js";
export * from "./MistralAIEmbedding.js"; export * from "./MistralAIEmbedding.js";
......
import {
ChatSession,
GoogleGenerativeAI,
type Content as GeminiMessageContent,
type Part,
} from "@google/generative-ai";
import { getEnv } from "@llamaindex/env";
import { ToolCallLLM } from "./base.js";
import type {
ChatMessage,
ChatResponse,
ChatResponseChunk,
CompletionResponse,
LLMChatParamsNonStreaming,
LLMChatParamsStreaming,
LLMCompletionParamsNonStreaming,
LLMCompletionParamsStreaming,
LLMMetadata,
MessageContent,
MessageContentImageDetail,
MessageContentTextDetail,
MessageType,
ToolCallLLMMessageOptions,
} from "./types.js";
import { streamConverter, wrapLLMEvent } from "./utils.js";
// Session and Model Type Definitions
type GeminiSessionOptions = {
apiKey?: string;
};
export enum GEMINI_MODEL {
GEMINI_PRO = "gemini-pro",
GEMINI_PRO_VISION = "gemini-pro-vision",
EMBEDDING_001 = "embedding-001",
AQA = "aqa",
GEMINI_PRO_LATEST = "gemini-1.5-pro-latest",
}
export interface GeminiModelInfo {
contextWindow: number;
}
export const GEMINI_MODEL_INFO_MAP: Record<GEMINI_MODEL, GeminiModelInfo> = {
[GEMINI_MODEL.GEMINI_PRO]: { contextWindow: 30720 },
[GEMINI_MODEL.GEMINI_PRO_VISION]: { contextWindow: 12288 },
[GEMINI_MODEL.EMBEDDING_001]: { contextWindow: 2048 },
[GEMINI_MODEL.AQA]: { contextWindow: 7168 },
[GEMINI_MODEL.GEMINI_PRO_LATEST]: { contextWindow: 10 ** 6 },
};
const SUPPORT_TOOL_CALL_MODELS: GEMINI_MODEL[] = [
GEMINI_MODEL.GEMINI_PRO,
GEMINI_MODEL.GEMINI_PRO_VISION,
GEMINI_MODEL.EMBEDDING_001,
GEMINI_MODEL.AQA,
];
const DEFAULT_GEMINI_PARAMS = {
model: GEMINI_MODEL.GEMINI_PRO,
temperature: 0.1,
topP: 1,
maxTokens: undefined,
};
export type GeminiConfig = Partial<typeof DEFAULT_GEMINI_PARAMS> & {
session?: GeminiSession;
};
/// Chat Type Definitions
type GeminiMessageRole = "user" | "model";
export type GeminiAdditionalChatOptions = {};
export type GeminiChatParamsStreaming = LLMChatParamsStreaming<
GeminiAdditionalChatOptions,
ToolCallLLMMessageOptions
>;
export type GeminiChatStreamResponse = AsyncIterable<
ChatResponseChunk<ToolCallLLMMessageOptions>
>;
export type GeminiChatParamsNonStreaming = LLMChatParamsNonStreaming<
GeminiAdditionalChatOptions,
ToolCallLLMMessageOptions
>;
export type GeminiChatNonStreamResponse =
ChatResponse<ToolCallLLMMessageOptions>;
/**
* Gemini Session to manage the connection to the Gemini API
*/
export class GeminiSession {
gemini: GoogleGenerativeAI;
constructor(options: GeminiSessionOptions) {
if (!options.apiKey) {
options.apiKey = getEnv("GOOGLE_API_KEY");
}
if (!options.apiKey) {
throw new Error("Set Google API Key in GOOGLE_API_KEY env variable");
}
this.gemini = new GoogleGenerativeAI(options.apiKey);
}
}
/**
* Gemini Session Store to manage the current Gemini sessions
*/
export class GeminiSessionStore {
static sessions: Array<{
session: GeminiSession;
options: GeminiSessionOptions;
}> = [];
private static sessionMatched(
o1: GeminiSessionOptions,
o2: GeminiSessionOptions,
): boolean {
return o1.apiKey === o2.apiKey;
}
static get(options: GeminiSessionOptions = {}): GeminiSession {
let session = this.sessions.find((session) =>
this.sessionMatched(session.options, options),
)?.session;
if (!session) {
session = new GeminiSession(options);
this.sessions.push({ session, options });
}
return session;
}
}
/**
* Helper class providing utility functions for Gemini
*/
class GeminiHelper {
// Gemini only has user and model roles. Put the rest in user role.
public static readonly ROLES_TO_GEMINI: Record<
MessageType,
GeminiMessageRole
> = {
user: "user",
system: "user",
assistant: "user",
memory: "user",
};
public static readonly ROLES_FROM_GEMINI: Record<
GeminiMessageRole,
MessageType
> = {
user: "user",
model: "assistant",
};
public static mergeNeighboringSameRoleMessages(
messages: ChatMessage[],
): ChatMessage[] {
// Gemini does not support multiple messages of the same role in a row, so we merge them
const mergedMessages: ChatMessage[] = [];
let i: number = 0;
while (i < messages.length) {
const currentMessage: ChatMessage = messages[i];
// Initialize merged content with current message content
const mergedContent: MessageContent[] = [currentMessage.content];
// Check if the next message exists and has the same role
while (
i + 1 < messages.length &&
this.ROLES_TO_GEMINI[messages[i + 1].role] ===
this.ROLES_TO_GEMINI[currentMessage.role]
) {
i++;
const nextMessage: ChatMessage = messages[i];
mergedContent.push(nextMessage.content);
}
// Create a new ChatMessage object with merged content
const mergedMessage: ChatMessage = {
role: currentMessage.role,
content: mergedContent.join("\n"),
};
mergedMessages.push(mergedMessage);
i++;
}
return mergedMessages;
}
public static messageContentToGeminiParts(content: MessageContent): Part[] {
if (typeof content === "string") {
return [{ text: content }];
}
const parts: Part[] = [];
const imageContents = content.filter(
(i) => i.type === "image_url",
) as MessageContentImageDetail[];
parts.push(
...imageContents.map((i) => ({
fileData: {
mimeType: i.type,
fileUri: i.image_url.url,
},
})),
);
const textContents = content.filter(
(i) => i.type === "text",
) as MessageContentTextDetail[];
parts.push(...textContents.map((t) => ({ text: t.text })));
return parts;
}
public static chatMessageToGemini(
message: ChatMessage,
): GeminiMessageContent {
return {
role: this.ROLES_TO_GEMINI[message.role],
parts: this.messageContentToGeminiParts(message.content),
};
}
}
/**
* ToolCallLLM for Gemini
*/
export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
model: GEMINI_MODEL;
temperature: number;
topP: number;
maxTokens?: number;
session: GeminiSession;
constructor(init?: GeminiConfig) {
super();
this.model = init?.model ?? GEMINI_MODEL.GEMINI_PRO;
this.temperature = init?.temperature ?? 0.1;
this.topP = init?.topP ?? 1;
this.maxTokens = init?.maxTokens ?? undefined;
this.session = init?.session ?? GeminiSessionStore.get();
}
get supportToolCall(): boolean {
return SUPPORT_TOOL_CALL_MODELS.includes(this.model);
}
get metadata(): LLMMetadata {
return {
model: this.model,
temperature: this.temperature,
topP: this.topP,
maxTokens: this.maxTokens,
contextWindow: GEMINI_MODEL_INFO_MAP[this.model].contextWindow,
tokenizer: undefined,
};
}
private prepareChat(
params: GeminiChatParamsStreaming | GeminiChatParamsNonStreaming,
): {
chat: ChatSession;
messageContent: Part[];
} {
const { messages } = params;
const mergedMessages =
GeminiHelper.mergeNeighboringSameRoleMessages(messages);
const history = mergedMessages.slice(0, -1);
const nextMessage = mergedMessages[mergedMessages.length - 1];
const messageContent = GeminiHelper.chatMessageToGemini(nextMessage).parts;
const client = this.session.gemini.getGenerativeModel(this.metadata);
const chat = client.startChat({
history: history.map(GeminiHelper.chatMessageToGemini),
});
return {
chat,
messageContent,
};
}
protected async nonStreamChat(
params: GeminiChatParamsNonStreaming,
): Promise<GeminiChatNonStreamResponse> {
const { chat, messageContent } = this.prepareChat(params);
const result = await chat.sendMessage(messageContent);
const { response } = result;
const topCandidate = response.candidates![0];
return {
raw: response,
message: {
content: response.text(),
role: GeminiHelper.ROLES_FROM_GEMINI[
topCandidate.content.role as GeminiMessageRole
],
},
};
}
protected async *streamChat(
params: GeminiChatParamsStreaming,
): GeminiChatStreamResponse {
const { chat, messageContent } = this.prepareChat(params);
const result = await chat.sendMessageStream(messageContent);
return streamConverter(result.stream, (response) => {
return {
text: response.text(),
raw: response,
};
});
}
chat(params: GeminiChatParamsStreaming): Promise<GeminiChatStreamResponse>;
chat(
params: GeminiChatParamsNonStreaming,
): Promise<GeminiChatNonStreamResponse>;
@wrapLLMEvent
async chat(
params: GeminiChatParamsStreaming | GeminiChatParamsNonStreaming,
): Promise<GeminiChatStreamResponse | GeminiChatNonStreamResponse> {
if (params.stream) return this.streamChat(params);
return this.nonStreamChat(params);
}
complete(
params: LLMCompletionParamsStreaming,
): Promise<AsyncIterable<CompletionResponse>>;
complete(
params: LLMCompletionParamsNonStreaming,
): Promise<CompletionResponse>;
async complete(
params: LLMCompletionParamsStreaming | LLMCompletionParamsNonStreaming,
): Promise<CompletionResponse | AsyncIterable<CompletionResponse>> {
const { prompt, stream } = params;
const client = this.session.gemini.getGenerativeModel(this.metadata);
if (stream) {
const result = await client.generateContentStream(
GeminiHelper.messageContentToGeminiParts(prompt),
);
return streamConverter(result.stream, (response) => {
return {
text: response.text(),
raw: response,
};
});
}
const result = await client.generateContent(
GeminiHelper.messageContentToGeminiParts(prompt),
);
return {
text: result.response.text(),
raw: result.response,
};
}
}
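The `Gemini` class above also implements the `complete` API. A minimal sketch of a non-streaming completion, assuming `GOOGLE_API_KEY` is set; per the `complete` method above, the result carries `text` and `raw` fields:
```ts
import { GEMINI_MODEL, Gemini } from "llamaindex";

const gemini = new Gemini({ model: GEMINI_MODEL.GEMINI_PRO });

// Non-streaming completion: pass a prompt, read back the generated text.
const completion = await gemini.complete({
  prompt: "Write a haiku about woodchucks.",
});
console.log(completion.text);
```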
@@ -10,6 +10,7 @@ export * from "./openai.js";
 export { Portkey } from "./portkey.js";
 export * from "./replicate_ai.js";
 // Note: The type aliases for replicate are to simplify usage for Llama 2 (we're using replicate for Llama 2 support)
+export { GEMINI_MODEL, Gemini } from "./gemini.js";
 export {
   DeuceChatStrategy,
   LlamaDeuce,
...
@@ -7,6 +7,7 @@
     "@anthropic-ai/sdk": "^0.20.6",
     "@aws-crypto/sha256-js": "^5.2.0",
     "@datastax/astra-db-ts": "^1.0.1",
+    "@google/generative-ai": "^0.8.0",
     "@grpc/grpc-js": "^1.10.6",
     "@llamaindex/cloud": "0.0.5",
     "@llamaindex/env": "workspace:*",
...