Skip to content
Snippets Groups Projects
Unverified Commit 27397143 authored by Huu Le (Lee)'s avatar Huu Le (Lee) Committed by GitHub
Browse files

feat: Add database data source (MySQL and PostgreSQL) (#28)

parent 665c26cc
No related branches found
No related tags found
No related merge requests found
---
"create-llama": patch
---
Use databases as data source
...@@ -9,6 +9,7 @@ import { templatesDir } from "./dir"; ...@@ -9,6 +9,7 @@ import { templatesDir } from "./dir";
import { isPoetryAvailable, tryPoetryInstall } from "./poetry"; import { isPoetryAvailable, tryPoetryInstall } from "./poetry";
import { Tool } from "./tools"; import { Tool } from "./tools";
import { import {
DbSourceConfig,
InstallTemplateArgs, InstallTemplateArgs,
TemplateDataSource, TemplateDataSource,
TemplateVectorDB, TemplateVectorDB,
...@@ -65,17 +66,34 @@ const getAdditionalDependencies = ( ...@@ -65,17 +66,34 @@ const getAdditionalDependencies = (
// Add data source dependencies // Add data source dependencies
const dataSourceType = dataSource?.type; const dataSourceType = dataSource?.type;
if (dataSourceType === "file") { switch (dataSourceType) {
// llama-index-readers-file (pdf, excel, csv) is already included in llama_index package case "file":
dependencies.push({ dependencies.push({
name: "docx2txt", name: "docx2txt",
version: "^0.8", version: "^0.8",
}); });
} else if (dataSourceType === "web") { break;
dependencies.push({ case "web":
name: "llama-index-readers-web", dependencies.push({
version: "^0.1.6", name: "llama-index-readers-web",
}); version: "^0.1.6",
});
break;
case "db":
dependencies.push({
name: "llama-index-readers-database",
version: "^0.1.3",
});
dependencies.push({
name: "pymysql",
version: "^1.1.0",
extras: ["rsa"],
});
dependencies.push({
name: "psycopg2",
version: "^2.9.9",
});
break;
} }
// Add tools dependencies // Add tools dependencies
...@@ -307,6 +325,26 @@ export const installPythonTemplate = async ({ ...@@ -307,6 +325,26 @@ export const installPythonTemplate = async ({
node.commentBefore = ` use_llama_parse: Use LlamaParse if \`true\`. Needs a \`LLAMA_CLOUD_API_KEY\` from https://cloud.llamaindex.ai set as environment variable`; node.commentBefore = ` use_llama_parse: Use LlamaParse if \`true\`. Needs a \`LLAMA_CLOUD_API_KEY\` from https://cloud.llamaindex.ai set as environment variable`;
loaderConfig.set("file", node); loaderConfig.set("file", node);
} }
// DB loader config
const dbLoaders = dataSources.filter((ds) => ds.type === "db");
if (dbLoaders.length > 0) {
const dbLoaderConfig = new Document({});
const configEntries = dbLoaders.map((ds) => {
const dsConfig = ds.config as DbSourceConfig;
return {
uri: dsConfig.uri,
queries: [dsConfig.queries],
};
});
const node = dbLoaderConfig.createNode(configEntries);
node.commentBefore = ` The configuration for the database loader, only supports MySQL and PostgreSQL databases for now.
uri: The URI for the database. E.g.: mysql+pymysql://user:password@localhost:3306/db or postgresql+psycopg2://user:password@localhost:5432/db
query: The query to fetch data from the database. E.g.: SELECT * FROM table`;
loaderConfig.set("db", node);
}
// Write loaders config // Write loaders config
if (Object.keys(loaderConfig).length > 0) { if (Object.keys(loaderConfig).length > 0) {
const loaderConfigPath = path.join(root, "config/loaders.yaml"); const loaderConfigPath = path.join(root, "config/loaders.yaml");
......
...@@ -14,7 +14,7 @@ export type TemplateDataSource = { ...@@ -14,7 +14,7 @@ export type TemplateDataSource = {
type: TemplateDataSourceType; type: TemplateDataSourceType;
config: TemplateDataSourceConfig; config: TemplateDataSourceConfig;
}; };
export type TemplateDataSourceType = "file" | "web"; export type TemplateDataSourceType = "file" | "web" | "db";
export type TemplateObservability = "none" | "opentelemetry"; export type TemplateObservability = "none" | "opentelemetry";
// Config for both file and folder // Config for both file and folder
export type FileSourceConfig = { export type FileSourceConfig = {
...@@ -25,8 +25,15 @@ export type WebSourceConfig = { ...@@ -25,8 +25,15 @@ export type WebSourceConfig = {
prefix?: string; prefix?: string;
depth?: number; depth?: number;
}; };
export type DbSourceConfig = {
uri?: string;
queries?: string;
};
export type TemplateDataSourceConfig = FileSourceConfig | WebSourceConfig; export type TemplateDataSourceConfig =
| FileSourceConfig
| WebSourceConfig
| DbSourceConfig;
export type CommunityProjectConfig = { export type CommunityProjectConfig = {
owner: string; owner: string;
......
...@@ -159,6 +159,10 @@ export const getDataSourceChoices = ( ...@@ -159,6 +159,10 @@ export const getDataSourceChoices = (
title: "Use website content (requires Chrome)", title: "Use website content (requires Chrome)",
value: "web", value: "web",
}); });
choices.push({
title: "Use data from a database (Mysql, PostgreSQL)",
value: "db",
});
} }
return choices; return choices;
}; };
...@@ -629,52 +633,93 @@ export const askQuestions = async ( ...@@ -629,52 +633,93 @@ export const askQuestions = async (
// user doesn't want another data source or any data source // user doesn't want another data source or any data source
break; break;
} }
if (selectedSource === "exampleFile") { switch (selectedSource) {
program.dataSources.push(EXAMPLE_FILE); case "exampleFile": {
} else if (selectedSource === "file" || selectedSource === "folder") { program.dataSources.push(EXAMPLE_FILE);
// Select local data source break;
const selectedPaths = await selectLocalContextData(selectedSource); }
for (const p of selectedPaths) { case "file":
case "folder": {
const selectedPaths = await selectLocalContextData(selectedSource);
for (const p of selectedPaths) {
program.dataSources.push({
type: "file",
config: {
path: p,
},
});
}
break;
}
case "web": {
const { baseUrl } = await prompts(
{
type: "text",
name: "baseUrl",
message: "Please provide base URL of the website: ",
initial: "https://www.llamaindex.ai",
validate: (value: string) => {
if (!value.includes("://")) {
value = `https://${value}`;
}
const urlObj = new URL(value);
if (
urlObj.protocol !== "https:" &&
urlObj.protocol !== "http:"
) {
return `URL=${value} has invalid protocol, only allow http or https`;
}
return true;
},
},
handlers,
);
program.dataSources.push({ program.dataSources.push({
type: "file", type: "web",
config: { config: {
path: p, baseUrl,
prefix: baseUrl,
depth: 1,
}, },
}); });
break;
} }
} else if (selectedSource === "web") { case "db": {
// Selected web data source const dbPrompts: prompts.PromptObject<string>[] = [
const { baseUrl } = await prompts( {
{ type: "text",
type: "text", name: "uri",
name: "baseUrl", message:
message: "Please provide base URL of the website: ", "Please enter the connection string (URI) for the database.",
initial: "https://www.llamaindex.ai", initial: "mysql+pymysql://user:pass@localhost:3306/mydb",
validate: (value: string) => { validate: (value: string) => {
if (!value.includes("://")) { if (!value) {
value = `https://${value}`; return "Please provide a valid connection string";
} } else if (
const urlObj = new URL(value); !(
if ( value.startsWith("mysql+pymysql://") ||
urlObj.protocol !== "https:" && value.startsWith("postgresql+psycopg://")
urlObj.protocol !== "http:" )
) { ) {
return `URL=${value} has invalid protocol, only allow http or https`; return "The connection string must start with 'mysql+pymysql://' for MySQL or 'postgresql+psycopg://' for PostgreSQL";
} }
return true; return true;
},
}, },
}, // Only ask for a query, user can provide more complex queries in the config file later
handlers, {
); type: (prev) => (prev ? "text" : null),
name: "queries",
program.dataSources.push({ message: "Please enter the SQL query to fetch data:",
type: "web", initial: "SELECT * FROM mytable",
config: { },
baseUrl, ];
prefix: baseUrl, program.dataSources.push({
depth: 1, type: "db",
}, config: await prompts(dbPrompts, handlers),
}); });
}
} }
} }
} }
......
...@@ -5,6 +5,7 @@ import logging ...@@ -5,6 +5,7 @@ import logging
from typing import Dict from typing import Dict
from app.engine.loaders.file import FileLoaderConfig, get_file_documents from app.engine.loaders.file import FileLoaderConfig, get_file_documents
from app.engine.loaders.web import WebLoaderConfig, get_web_documents from app.engine.loaders.web import WebLoaderConfig, get_web_documents
from app.engine.loaders.db import DBLoaderConfig, get_db_documents
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -22,11 +23,17 @@ def get_documents(): ...@@ -22,11 +23,17 @@ def get_documents():
logger.info( logger.info(
f"Loading documents from loader: {loader_type}, config: {loader_config}" f"Loading documents from loader: {loader_type}, config: {loader_config}"
) )
if loader_type == "file": match loader_type:
document = get_file_documents(FileLoaderConfig(**loader_config)) case "file":
documents.extend(document) document = get_file_documents(FileLoaderConfig(**loader_config))
elif loader_type == "web": case "web":
document = get_web_documents(WebLoaderConfig(**loader_config)) document = get_web_documents(WebLoaderConfig(**loader_config))
documents.extend(document) case "db":
document = get_db_documents(
configs=[DBLoaderConfig(**cfg) for cfg in loader_config]
)
case _:
raise ValueError(f"Invalid loader type: {loader_type}")
documents.extend(document)
return documents return documents
import os
import logging
from typing import List
from pydantic import BaseModel, validator
from llama_index.core.indices.vector_store import VectorStoreIndex
logger = logging.getLogger(__name__)
class DBLoaderConfig(BaseModel):
uri: str
queries: List[str]
def get_db_documents(configs: list[DBLoaderConfig]):
from llama_index.readers.database import DatabaseReader
docs = []
for entry in configs:
loader = DatabaseReader(uri=entry.uri)
for query in entry.queries:
logger.info(f"Loading data from database with query: {query}")
documents = loader.load_data(query=query)
docs.extend(documents)
return documents
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment