Skip to content
Snippets Groups Projects
Commit df9c524e authored by thucpn's avatar thucpn
Browse files

Merge branch 'main' into tp/llamaindex-server-package

parents 95b6dcfa 04f8c96c
No related branches found
No related tags found
No related merge requests found
Showing
with 925 additions and 375 deletions
---
"@llamaindex/core": patch
---
fix: include additional options for context chat engine
---
"@llamaindex/google": patch
---
fix: don't ignore parts that only have inline data for google studio
---
"@llamaindex/mongodb": patch
---
Added mongo db document and key value store
---
"@llamaindex/workflow": patch
---
Fix: multi-agent handover
......@@ -56,10 +56,10 @@ const vectorStore = new QdrantVectorStore({
```ts
const document = new Document({ text: essay, id_: path });
const index = await VectorStoreIndex.fromDocuments([document], {
vectorStore,
});
const storageContext = await storageContextFromDefaults({ vectorStore });
const index = await VectorStoreIndex.fromDocuments([document], {
storageContext,
});
```
## Query the index
......@@ -91,11 +91,11 @@ async function main() {
});
const document = new Document({ text: essay, id_: path });
const storageContext = await storageContextFromDefaults({ vectorStore });
const index = await VectorStoreIndex.fromDocuments([document], {
vectorStore,
storageContext,
});
const queryEngine = index.asQueryEngine();
const response = await queryEngine.query({
......
......@@ -102,6 +102,7 @@ export class ContextChatEngine extends PromptMixin implements BaseChatEngine {
const stream = await this.chatModel.chat({
messages: requestMessages.messages,
stream: true,
additionalChatOptions: params.chatOptions as object,
});
return streamConverter(
streamReducer({
......@@ -117,6 +118,7 @@ export class ContextChatEngine extends PromptMixin implements BaseChatEngine {
}
const response = await this.chatModel.chat({
messages: requestMessages.messages,
additionalChatOptions: params.chatOptions as object,
});
chatHistory.put(response.message);
return EngineResponse.fromChatResponse(response, requestMessages.nodes);
......
......@@ -124,7 +124,7 @@ export const mapChatMessagesToGoogleMessages = <
return mapMessageContentToMessageContentDetails(msg.content)
.map((detail: MessageContentDetail): ContentUnion | null => {
const part = mapMessageContentDetailToGooglePart(detail);
if (!part.text) return null;
if (!part.text && !part.inlineData) return null;
return {
role: msg.role === "assistant" ? "model" : "user",
......
......@@ -35,10 +35,13 @@
},
"scripts": {
"build": "bunchee",
"dev": "bunchee --watch"
"dev": "bunchee --watch",
"test": "vitest"
},
"devDependencies": {
"bunchee": "6.4.0"
"bunchee": "6.4.0",
"vitest": "2.1.0",
"mongodb-memory-server": "^10.1.4"
},
"dependencies": {
"@llamaindex/core": "workspace:*",
......
import { KVDocumentStore } from "@llamaindex/core/storage/doc-store";
import { MongoClient } from "mongodb";
import { MongoKVStore } from "../kvStore/MongoKVStore";
const DEFAULT_DATABASE = "DocumentStoreDB";
const DEFAULT_COLLECTION = "DocumentStoreCollection";
interface MongoDBDocumentStoreConfig {
mongoKVStore: MongoKVStore;
namespace?: string;
}
export class MongoDocumentStore extends KVDocumentStore {
constructor({ mongoKVStore, namespace }: MongoDBDocumentStoreConfig) {
super(mongoKVStore, namespace);
}
/**
* Static method for creating an instance using a MongoClient.
* @returns Instance of MongoDBDocumentStore
* @param mongoClient - MongoClient instance
* @param dbName - Database name
* @param collectionName - Collection name
* @example
* ```ts
* const mongoClient = new MongoClient("mongodb://localhost:27017");
* const documentStore = MongoDBDocumentStore.fromMongoClient(mongoClient, "my_db", "my_collection");
* ```
*/
static fromMongoClient(
mongoClient: MongoClient,
dbName: string = DEFAULT_DATABASE,
collectionName: string = DEFAULT_COLLECTION,
): MongoDocumentStore {
const mongoKVStore = new MongoKVStore({
mongoClient,
dbName,
});
return new MongoDocumentStore({
mongoKVStore,
namespace: `${dbName}.${collectionName}`,
});
}
/**
* Static method for creating an instance using a connection string.
* @returns Instance of MongoDBDocumentStore
* @param connectionString - MongoDB connection string
* @param dbName - Database name
* @param collectionName - Collection name
* @example
* ```ts
* const documentStore = MongoDBDocumentStore.fromConnectionString("mongodb://localhost:27017", "my_db", "my_collection");
* ```
*/
static fromConnectionString(
connectionString: string,
dbName: string = DEFAULT_DATABASE,
collectionName: string = DEFAULT_COLLECTION,
): MongoDocumentStore {
const mongoClient = new MongoClient(connectionString, {
appName: "LLAMAINDEX_JS",
});
const mongoKVStore = new MongoKVStore({
mongoClient,
dbName,
});
return new MongoDocumentStore({
mongoKVStore,
namespace: `${dbName}.${collectionName}`,
});
}
}
export * from "./docStore/MongoDBDocumentStore";
export * from "./kvStore/MongoKVStore";
export * from "./MongoDBAtlasVectorStore";
import type { StoredValue } from "@llamaindex/core/schema";
import { BaseKVStore } from "@llamaindex/core/storage/kv-store";
import type { Collection, MongoClient } from "mongodb";
const DEFAULT_DB_NAME = "KVStoreDB";
const DEFAULT_COLLECTION_NAME = "KVStoreCollection";
interface MongoKVStoreConfig {
mongoClient: MongoClient;
dbName?: string;
}
export class MongoKVStore extends BaseKVStore {
private mongoClient: MongoClient;
private dbName: string;
constructor({ mongoClient, dbName = DEFAULT_DB_NAME }: MongoKVStoreConfig) {
super();
if (!mongoClient) {
throw new Error(
"MongoClient is required for MongoKVStore initialization",
);
}
this.mongoClient = mongoClient;
this.dbName = dbName;
}
get client(): MongoClient {
return this.mongoClient;
}
private async ensureCollection(collectionName: string): Promise<Collection> {
const collection = this.mongoClient
.db(this.dbName)
.collection(collectionName);
return collection;
}
async put(
key: string,
val: Exclude<StoredValue, null>,
collectionName: string = DEFAULT_COLLECTION_NAME,
): Promise<void> {
const collection = await this.ensureCollection(collectionName);
await collection.updateOne({ id: key }, { $set: val }, { upsert: true });
}
async get(
key: string,
collectionName: string = DEFAULT_COLLECTION_NAME,
): Promise<StoredValue> {
const collection = await this.ensureCollection(collectionName);
const result = await collection.findOne(
{ id: key },
{ projection: { _id: 0 } },
);
if (!result) {
return null;
}
return result;
}
async getAll(
collectionName: string = DEFAULT_COLLECTION_NAME,
): Promise<Record<string, StoredValue>> {
const collection = await this.ensureCollection(collectionName);
const cursor = collection.find({}, { projection: { _id: 0 } });
const output: Record<string, StoredValue> = {};
await cursor.forEach((item) => {
output[item.id] = item;
});
return output;
}
async delete(
key: string,
collectionName: string = DEFAULT_COLLECTION_NAME,
): Promise<boolean> {
const collection = await this.ensureCollection(collectionName);
await collection.deleteOne({ id: key });
return true;
}
}
import { Document, MetadataMode } from "@llamaindex/core/schema";
import { MongoClient } from "mongodb";
import { afterAll, beforeAll, beforeEach, describe, expect, it } from "vitest";
import { MongoDocumentStore } from "../docStore/MongoDBDocumentStore";
import { setupTestDb } from "./setuptTestDb";
describe("MongoDocumentStore", () => {
let cleanup: () => Promise<void>;
let mongoClient: MongoClient;
let documentStore: MongoDocumentStore;
let mongoUri: string;
beforeAll(async () => {
const testDb = await setupTestDb();
cleanup = testDb.cleanup;
mongoClient = testDb.mongoClient;
mongoUri = testDb.mongoUri;
documentStore = MongoDocumentStore.fromMongoClient(mongoClient);
}, 120000);
afterAll(async () => {
await cleanup();
});
describe("constructor", () => {
it("should create instance with mongoClient", () => {
const store = MongoDocumentStore.fromMongoClient(mongoClient);
expect(store).toBeInstanceOf(MongoDocumentStore);
});
it("should create instance with custom namespace", () => {
const store = MongoDocumentStore.fromMongoClient(
mongoClient,
"custom",
"namespace",
);
expect(store).toBeInstanceOf(MongoDocumentStore);
});
});
describe("static constructors", () => {
it("should create instance from connection string", () => {
const store = MongoDocumentStore.fromConnectionString(mongoUri);
expect(store).toBeInstanceOf(MongoDocumentStore);
});
it("should create instance from MongoClient", () => {
const store = MongoDocumentStore.fromMongoClient(mongoClient);
expect(store).toBeInstanceOf(MongoDocumentStore);
});
});
describe("document operations", () => {
beforeEach(async () => {
await mongoClient.db("test").collection("test").deleteMany({});
});
it("should store and retrieve a document", async () => {
const doc = new Document({ text: "test document", id_: "test_id" });
await documentStore.addDocuments([doc]);
const retrievedDoc = await documentStore.getDocument("test_id");
const text = retrievedDoc?.getContent(MetadataMode.ALL);
expect(text).toBe(doc.text);
});
it("should store and retrieve multiple documents", async () => {
const docs = [
new Document({ text: "doc1", id_: "id1" }),
new Document({ text: "doc2", id_: "id2" }),
];
await documentStore.addDocuments(docs);
const retrievedDocs = await documentStore.getNodes(["id1", "id2"]);
expect(retrievedDocs.map((d) => d?.getContent(MetadataMode.ALL))).toEqual(
docs.map((d) => d.text),
);
});
it("should handle missing documents", async () => {
const doc = await documentStore.getDocument("non_existent_id", false);
expect(doc).toBeUndefined();
});
});
describe("document updates", () => {
it("should update existing document when allowUpdate is true", async () => {
const doc1 = new Document({ text: "original", id_: "test_id" });
const doc2 = new Document({ text: "updated", id_: "test_id" });
await documentStore.addDocuments([doc1]);
await documentStore.addDocuments([doc2], true);
const retrieved = await documentStore.getDocument("test_id");
const text = retrieved?.getContent(MetadataMode.ALL);
expect(text).toBe("updated");
});
it("should throw error when updating with allowUpdate false", async () => {
const doc1 = new Document({ text: "original", id_: "test_id" });
await documentStore.addDocuments([doc1]);
const doc2 = new Document({ text: "updated", id_: "test_id" });
await expect(documentStore.addDocuments([doc2], false)).rejects.toThrow(
/doc_id.*already exists/,
);
});
});
describe("document deletion", () => {
it("should delete a document", async () => {
const doc = new Document({ text: "test document", id_: "test_id" });
await documentStore.addDocuments([doc]);
await documentStore.deleteDocument("test_id");
const retrieved = await documentStore.getDocument("test_id", false);
expect(retrieved).toBeUndefined();
});
it("should handle deleting non-existent document", async () => {
await expect(
documentStore.deleteDocument("non_existent_id"),
).resolves.not.toThrow();
});
});
describe("document existence", () => {
it("should check if document exists", async () => {
const doc = new Document({ text: "test document", id_: "test_id" });
await documentStore.addDocuments([doc]);
const exists = await documentStore.documentExists("test_id");
expect(exists).toBe(true);
});
it("should return false for non-existent document", async () => {
const exists = await documentStore.documentExists("non_existent_id");
expect(exists).toBe(false);
});
});
describe("document hash", () => {
it("should get document hash", async () => {
const doc = new Document({ text: "test document", id_: "test_id" });
await documentStore.addDocuments([doc]);
const hash = await documentStore.getDocumentHash("test_id");
expect(hash).toBe(doc.hash);
});
it("should return null for non-existent document hash", async () => {
const hash = await documentStore.getDocumentHash("non_existent_id");
expect(hash).toBeUndefined();
});
});
});
import type { MongoClient } from "mongodb";
import { afterAll, beforeAll, beforeEach, describe, expect, it } from "vitest";
import { MongoKVStore } from "../kvStore/MongoKVStore";
import { setupTestDb } from "./setuptTestDb";
describe("MongoKVStore", () => {
let cleanUp: () => Promise<void>;
let mongoClient: MongoClient;
let mongoUri: string;
let kvStore: MongoKVStore;
beforeAll(async () => {
const testDb = await setupTestDb();
cleanUp = testDb.cleanup;
mongoClient = testDb.mongoClient;
mongoUri = testDb.mongoUri;
kvStore = new MongoKVStore({
mongoClient,
dbName: "test",
});
}, 120000);
afterAll(async () => {
await cleanUp();
});
describe("Mongod KV store constructor", () => {
it("should create instance with mongoClient", () => {
const kvStore = new MongoKVStore({
mongoClient,
dbName: "test",
});
expect(kvStore).toBeInstanceOf(MongoKVStore);
});
it("should create db with custom db and collection name", () => {
const kvStore = new MongoKVStore({
mongoClient,
dbName: "test",
});
expect(kvStore).toBeInstanceOf(MongoKVStore);
});
});
describe("mongo kv store put ", () => {
it("should store a value", async () => {
const key = "test_key";
const value = { data: "test_value" };
await kvStore.put(key, value);
const result = await kvStore.get(key);
expect(result).toMatchObject(value);
});
it("should update an existing value", async () => {
const key = "test_key2";
const value = { data: "test_value" };
const value2 = { data: "test_value2" };
await kvStore.put(key, value);
await kvStore.put(key, value2);
const result = await kvStore.get(key);
expect(result).toMatchObject(value2);
});
});
describe("mongo kv store get", () => {
it("should return null for non-existent key", async () => {
const result = await kvStore.get("non_existent_key");
expect(result).toBeNull();
});
it("should return a value for stored key", async () => {
const key = "test_key";
const value = { data: "test_value" };
const result = await kvStore.get(key);
expect(result).toMatchObject(value);
});
});
describe("mongo kv store getAll", () => {
//reset the db before each test
beforeEach(async () => {
await mongoClient.db("test").collection("test").deleteMany({});
});
it("should return all values", async () => {
const items = {
test_key1: { data: "test_value1" },
test_key2: { data: "test_value2" },
};
await Promise.all([
kvStore.put("test_key1", items["test_key1"]),
kvStore.put("test_key2", items["test_key2"]),
]);
const result = await kvStore.getAll();
expect(result).toMatchObject(items);
});
});
describe("mongo kv store delete", () => {
it("should delete a value", async () => {
const key = "test_key";
const value = { data: "test_value" };
await kvStore.put(key, value);
const deleted = await kvStore.delete(key);
const result = await kvStore.get(key);
expect(result).toBeNull();
expect(deleted).toBe(true);
});
});
});
import { MongoClient } from "mongodb";
import { MongoMemoryServer } from "mongodb-memory-server";
//setup a in memory test db for testing
export async function setupTestDb() {
const mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
const mongoClient = new MongoClient(mongoUri);
await mongoClient.connect();
return {
mongoServer,
mongoClient,
mongoUri,
// return a cleanup function to close the db and stop the server
cleanup: async () => {
await mongoClient.close();
await mongoServer.stop();
},
};
}
......@@ -163,6 +163,24 @@ export class AgentWorkflow {
this.addAgents(processedAgents);
}
private addAgents(agents: BaseWorkflowAgent[]): void {
const agentNames = new Set(agents.map((a) => a.name));
if (agentNames.size !== agents.length) {
throw new Error("The agent names must be unique!");
}
agents.forEach((agent) => {
this.agents.set(agent.name, agent);
});
if (agents.length > 1) {
agents.forEach((agent) => {
this.validateAgent(agent);
this.addHandoffTool(agent);
});
}
}
private validateAgent(agent: BaseWorkflowAgent) {
// Validate that all canHandoffTo agents exist
const invalidAgents = agent.canHandoffTo.filter(
......@@ -176,7 +194,14 @@ export class AgentWorkflow {
}
private addHandoffTool(agent: BaseWorkflowAgent) {
const handoffTool = createHandoffTool(this.agents);
if (agent.tools.some((t) => t.metadata.name === "handOff")) {
return;
}
const toHandoffAgents: Map<string, BaseWorkflowAgent> = new Map();
agent.canHandoffTo.forEach((name) => {
toHandoffAgents.set(name, this.agents.get(name)!);
});
const handoffTool = createHandoffTool(toHandoffAgents);
if (
agent.canHandoffTo.length > 0 &&
!agent.tools.some((t) => t.metadata.name === handoffTool.metadata.name)
......@@ -185,24 +210,6 @@ export class AgentWorkflow {
}
}
private addAgents(agents: BaseWorkflowAgent[]): void {
const agentNames = new Set(agents.map((a) => a.name));
if (agentNames.size !== agents.length) {
throw new Error("The agent names must be unique!");
}
// First pass: add all agents to the map
agents.forEach((agent) => {
this.agents.set(agent.name, agent);
});
// Second pass: validate and setup handoff tools
agents.forEach((agent) => {
this.validateAgent(agent);
this.addHandoffTool(agent);
});
}
/**
* Adds a new agent to the workflow
*/
......@@ -226,7 +233,6 @@ export class AgentWorkflow {
* @param params - Parameters for the single agent workflow
* @returns A new AgentWorkflow instance
*/
static fromTools(params: SingleAgentParams): AgentWorkflow {
const agent = new FunctionAgent({
name: params.name,
......@@ -234,6 +240,7 @@ export class AgentWorkflow {
tools: params.tools,
llm: params.llm,
systemPrompt: params.systemPrompt,
canHandoffTo: params.canHandoffTo,
});
const workflow = new AgentWorkflow({
......
......@@ -3,7 +3,7 @@ import { FunctionTool } from "@llamaindex/core/tools";
import { MockLLM } from "@llamaindex/core/utils";
import { describe, expect, test, vi } from "vitest";
import { z } from "zod";
import { AgentWorkflow, FunctionAgent } from "../src/agent";
import { AgentWorkflow, FunctionAgent, agent, multiAgent } from "../src/agent";
import { setupToolCallingMockLLM } from "./mock";
describe("AgentWorkflow", () => {
......@@ -157,3 +157,125 @@ describe("AgentWorkflow", () => {
//
});
});
describe("Multiple agents", () => {
test("multiple agents are set up correctly with handoff capabilities", () => {
// Create mock LLM
const mockLLM = new MockLLM();
mockLLM.supportToolCall = true;
// Create tools for agents
const addTool = FunctionTool.from(
(params: { x: number; y: number }) => params.x + params.y,
{
name: "add",
description: "Adds two numbers",
parameters: z.object({
x: z.number(),
y: z.number(),
}),
},
);
const multiplyTool = FunctionTool.from(
(params: { x: number; y: number }) => params.x * params.y,
{
name: "multiply",
description: "Multiplies two numbers",
parameters: z.object({
x: z.number(),
y: z.number(),
}),
},
);
const subtractTool = FunctionTool.from(
(params: { x: number; y: number }) => params.x - params.y,
{
name: "subtract",
description: "Subtracts two numbers",
parameters: z.object({
x: z.number(),
y: z.number(),
}),
},
);
// Create agents using the agent() function
const addAgent = agent({
name: "AddAgent",
description: "Agent that can add numbers",
tools: [addTool],
llm: mockLLM,
});
const multiplyAgent = agent({
name: "MultiplyAgent",
description: "Agent that can multiply numbers",
tools: [multiplyTool],
llm: mockLLM,
});
const mathAgent = agent({
name: "MathAgent",
description: "Agent that can do various math operations",
tools: [addTool, multiplyTool, subtractTool],
llm: mockLLM,
canHandoffTo: ["AddAgent", "MultiplyAgent"],
});
// Create workflow with multiple agents using multiAgent
const workflow = multiAgent({
agents: [mathAgent, addAgent, multiplyAgent],
rootAgent: mathAgent,
verbose: false,
});
// Verify agents are set up correctly
expect(workflow).toBeDefined();
expect(workflow.getAgents().length).toBe(3);
// Verify that the mathAgent has a handoff tool
const mathAgentInstance = workflow
.getAgents()
.find((agent) => agent.name === "MathAgent");
expect(mathAgentInstance).toBeDefined();
expect(
mathAgentInstance?.tools.some((tool) => tool.metadata.name === "handOff"),
).toBe(true);
// Verify that addAgent and multiplyAgent don't have handoff tools since they don't handoff to other agents
const addAgentInstance = workflow
.getAgents()
.find((agent) => agent.name === "AddAgent");
expect(addAgentInstance).toBeDefined();
expect(
addAgentInstance?.tools.some((tool) => tool.metadata.name === "handOff"),
).toBe(false);
const multiplyAgentInstance = workflow
.getAgents()
.find((agent) => agent.name === "MultiplyAgent");
expect(multiplyAgentInstance).toBeDefined();
expect(
multiplyAgentInstance?.tools.some(
(tool) => tool.metadata.name === "handOff",
),
).toBe(false);
// Verify agent specific tools are preserved
expect(
mathAgentInstance?.tools.some(
(tool) => tool.metadata.name === "subtract",
),
).toBe(true);
expect(
addAgentInstance?.tools.some((tool) => tool.metadata.name === "add"),
).toBe(true);
expect(
multiplyAgentInstance?.tools.some(
(tool) => tool.metadata.name === "multiply",
),
).toBe(true);
});
});
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment