diff --git a/.changeset/calm-tables-camp.md b/.changeset/calm-tables-camp.md
new file mode 100644
index 0000000000000000000000000000000000000000..b01e7f42087d121757c91728ba5d36ece8e23a58
--- /dev/null
+++ b/.changeset/calm-tables-camp.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Support Astra VectorDB
diff --git a/helpers/env-variables.ts b/helpers/env-variables.ts
index 660b17a26f5f19a427b7fa0050abf1db900d9fc1..f7aab3f5666de01cc7789af227d50ac2725945c0 100644
--- a/helpers/env-variables.ts
+++ b/helpers/env-variables.ts
@@ -93,6 +93,21 @@ const getVectorDBEnvs = (vectorDb: TemplateVectorDB) => {
           description: "The password to access the Milvus server.",
         },
       ];
+    case "astra":
+      return [
+        {
+          name: "ASTRA_DB_APPLICATION_TOKEN",
+          description: "The generated app token for your Astra database",
+        },
+        {
+          name: "ASTRA_DB_ENDPOINT",
+          description: "The API endpoint for your Astra database",
+        },
+        {
+          name: "ASTRA_DB_COLLECTION",
+          description: "The name of the collection in your Astra database",
+        },
+      ];
     default:
       return [];
   }
diff --git a/helpers/python.ts b/helpers/python.ts
index 56dc463d4568880951e82b69b5ce0c00e030e4ed..5b240ce13536963cf9c10377d75169da2993f254 100644
--- a/helpers/python.ts
+++ b/helpers/python.ts
@@ -60,6 +60,13 @@ const getAdditionalDependencies = (
       });
       break;
     }
+    case "astra": {
+      dependencies.push({
+        name: "llama-index-vector-stores-astra-db",
+        version: "^0.1.5",
+      });
+      break;
+    }
   }
 
   // Add data source dependencies
diff --git a/helpers/types.ts b/helpers/types.ts
index f7d0f4140c1ea8491e253afccb12ae0e4066bebf..52a300a757343e4fcf73d935bf4a9f5dffba8aa0 100644
--- a/helpers/types.ts
+++ b/helpers/types.ts
@@ -4,7 +4,13 @@ import { Tool } from "./tools";
 export type TemplateType = "streaming" | "community" | "llamapack";
 export type TemplateFramework = "nextjs" | "express" | "fastapi";
 export type TemplateUI = "html" | "shadcn";
-export type TemplateVectorDB = "none" | "mongo" | "pg" | "pinecone" | "milvus";
+export type TemplateVectorDB =
+  | "none"
+  | "mongo"
+  | "pg"
+  | "pinecone"
+  | "milvus"
+  | "astra";
 export type TemplatePostInstallAction =
   | "none"
   | "VSCode"
diff --git a/questions.ts b/questions.ts
index 1e5f326848caa567febe296b4f45117f2f872b2a..3ad02d9d3fa125d4ef7ca888b5490dc6dbca973f 100644
--- a/questions.ts
+++ b/questions.ts
@@ -101,6 +101,7 @@ const getVectorDbChoices = (framework: TemplateFramework) => {
     { title: "PostgreSQL", value: "pg" },
     { title: "Pinecone", value: "pinecone" },
     { title: "Milvus", value: "milvus" },
+    { title: "Astra", value: "astra" },
   ];
 
   const vectordbLang = framework === "fastapi" ? "python" : "typescript";
diff --git a/templates/components/vectordbs/python/astra/__init__.py b/templates/components/vectordbs/python/astra/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/templates/components/vectordbs/python/astra/generate.py b/templates/components/vectordbs/python/astra/generate.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f5c65836fffa12fe9848976a1453211224a89e7
--- /dev/null
+++ b/templates/components/vectordbs/python/astra/generate.py
@@ -0,0 +1,43 @@
+from dotenv import load_dotenv
+
+load_dotenv()
+
+import os
+import logging
+from llama_index.core.storage import StorageContext
+from llama_index.core.indices import VectorStoreIndex
+from llama_index.vector_stores.astra_db import AstraDBVectorStore
+from app.settings import init_settings
+from app.engine.loaders import get_documents
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger()
+
+
+def generate_datasource():
+    """Embed the loaded documents and store them in an Astra DB collection.
+
+    Reads ASTRA_DB_APPLICATION_TOKEN, ASTRA_DB_ENDPOINT and ASTRA_DB_COLLECTION
+    from the environment; an unset variable raises KeyError.
+    """
+    logger.info("Creating new index")
+    documents = get_documents()
+    store = AstraDBVectorStore(
+        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
+        api_endpoint=os.environ["ASTRA_DB_ENDPOINT"],
+        collection_name=os.environ["ASTRA_DB_COLLECTION"],
+        # NOTE(review): hard-coded vector size; confirm it matches the embedding model.
+        embedding_dimension=1536,
+    )
+    storage_context = StorageContext.from_defaults(vector_store=store)
+    VectorStoreIndex.from_documents(
+        documents,
+        storage_context=storage_context,
+        show_progress=True,  # this will show you a progress bar as the embeddings are created
+    )
+    logger.info("Successfully created embeddings in the AstraDB")
+
+
+if __name__ == "__main__":
+    init_settings()
+    generate_datasource()
diff --git a/templates/components/vectordbs/python/astra/index.py b/templates/components/vectordbs/python/astra/index.py
new file mode 100644
index 0000000000000000000000000000000000000000..d75508d7a289e889f2cc4c862df20d78b35be130
--- /dev/null
+++ b/templates/components/vectordbs/python/astra/index.py
@@ -0,0 +1,27 @@
+import logging
+import os
+
+from llama_index.core.indices import VectorStoreIndex
+from llama_index.vector_stores.astra_db import AstraDBVectorStore
+
+
+logger = logging.getLogger("uvicorn")
+
+
+def get_index():
+    """Return a VectorStoreIndex backed by the configured Astra DB collection.
+
+    Reads ASTRA_DB_APPLICATION_TOKEN, ASTRA_DB_ENDPOINT and ASTRA_DB_COLLECTION
+    from the environment; an unset variable raises KeyError.
+    """
+    logger.info("Connecting to index from AstraDB...")
+    store = AstraDBVectorStore(
+        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
+        api_endpoint=os.environ["ASTRA_DB_ENDPOINT"],
+        collection_name=os.environ["ASTRA_DB_COLLECTION"],
+        # NOTE(review): presumably must equal the size used in generate.py - confirm.
+        embedding_dimension=1536,
+    )
+    index = VectorStoreIndex.from_vector_store(store)
+    logger.info("Finished connecting to index from AstraDB.")
+    return index
diff --git a/templates/components/vectordbs/typescript/astra/generate.mjs b/templates/components/vectordbs/typescript/astra/generate.mjs
new file mode 100644
index 0000000000000000000000000000000000000000..904ec009a7acf9cebff22c396dc08be95a29fcff
--- /dev/null
+++ b/templates/components/vectordbs/typescript/astra/generate.mjs
@@ -0,0 +1,42 @@
+/* eslint-disable turbo/no-undeclared-env-vars */
+import * as dotenv from "dotenv";
+import {
+  AstraDBVectorStore,
+  VectorStoreIndex,
+  storageContextFromDefaults,
+} from "llamaindex";
+import { getDocuments } from "./loader.mjs";
+import { checkRequiredEnvVars } from "./shared.mjs";
+
+dotenv.config();
+
+/**
+ * Loads the documents, creates the Astra collection named by
+ * ASTRA_DB_COLLECTION and builds a vector index over the documents in it.
+ */
+async function loadAndIndex() {
+  // load objects from storage and convert them into LlamaIndex Document objects
+  const documents = await getDocuments();
+
+  // create vector store and a collection
+  const collectionName = process.env.ASTRA_DB_COLLECTION;
+  const vectorStore = new AstraDBVectorStore();
+  await vectorStore.create(collectionName, {
+    vector: { dimension: 1536, metric: "cosine" },
+  });
+  await vectorStore.connect(collectionName);
+
+  // create index from documents and store them in Astra
+  console.log("Start creating embeddings...");
+  const storageContext = await storageContextFromDefaults({ vectorStore });
+  await VectorStoreIndex.fromDocuments(documents, { storageContext });
+  console.log(
+    "Successfully created embeddings and saved to your Astra database.",
+  );
+}
+
+(async () => {
+  checkRequiredEnvVars();
+  await loadAndIndex();
+  console.log("Finished generating storage.");
+})();
diff --git a/templates/components/vectordbs/typescript/astra/index.ts b/templates/components/vectordbs/typescript/astra/index.ts
new file mode 100644
index 0000000000000000000000000000000000000000..4a44a11ac7d9b372191129ab5a1ceb053f1ab4bf
--- /dev/null
+++ b/templates/components/vectordbs/typescript/astra/index.ts
@@ -0,0 +1,27 @@
+/* eslint-disable turbo/no-undeclared-env-vars */
+import {
+  AstraDBVectorStore,
+  LLM,
+  VectorStoreIndex,
+  serviceContextFromDefaults,
+} from "llamaindex";
+import { CHUNK_OVERLAP, CHUNK_SIZE, checkRequiredEnvVars } from "./shared.mjs";
+
+/**
+ * Connects to the Astra collection named by ASTRA_DB_COLLECTION and wraps it
+ * in a VectorStoreIndex.
+ *
+ * @param llm - LLM placed in the service context handed to the index.
+ * @throws Error if a required Astra environment variable is missing.
+ */
+export async function getDataSource(llm: LLM) {
+  checkRequiredEnvVars();
+  const serviceContext = serviceContextFromDefaults({
+    llm,
+    chunkSize: CHUNK_SIZE,
+    chunkOverlap: CHUNK_OVERLAP,
+  });
+  const store = new AstraDBVectorStore();
+  await store.connect(process.env.ASTRA_DB_COLLECTION!);
+  return await VectorStoreIndex.fromVectorStore(store, serviceContext);
+}
diff --git a/templates/components/vectordbs/typescript/astra/shared.mjs b/templates/components/vectordbs/typescript/astra/shared.mjs
new file mode 100644
index 0000000000000000000000000000000000000000..fb240187a3a0b58c89e79b0e3eb15f474093cf30
--- /dev/null
+++ b/templates/components/vectordbs/typescript/astra/shared.mjs
@@ -0,0 +1,29 @@
+export const CHUNK_SIZE = 512;
+export const CHUNK_OVERLAP = 20;
+
+const REQUIRED_ENV_VARS = [
+  "ASTRA_DB_APPLICATION_TOKEN",
+  "ASTRA_DB_ENDPOINT",
+  "ASTRA_DB_COLLECTION",
+];
+
+/**
+ * Throws if any variable listed in REQUIRED_ENV_VARS is unset or empty,
+ * after logging the missing names.
+ */
+export function checkRequiredEnvVars() {
+  const missingEnvVars = REQUIRED_ENV_VARS.filter((envVar) => {
+    return !process.env[envVar];
+  });
+
+  if (missingEnvVars.length > 0) {
+    console.log(
+      `The following environment variables are required but missing: ${missingEnvVars.join(
+        ", ",
+      )}`,
+    );
+    throw new Error(
+      `Missing environment variables: ${missingEnvVars.join(", ")}`,
+    );
+  }
+}