diff --git a/.changeset/chilled-mangos-cheat.md b/.changeset/chilled-mangos-cheat.md new file mode 100644 index 0000000000000000000000000000000000000000..a10a3b5073df263cdba53d3d3554da93e2f6b226 --- /dev/null +++ b/.changeset/chilled-mangos-cheat.md @@ -0,0 +1,5 @@ +--- +"create-llama": patch +--- + +Ensure the generation script always works diff --git a/helpers/datasources.ts b/helpers/datasources.ts index e56e2a7243c4d367a265818455d6d6d8c90f6f40..dc80207e43bf6061f97f4ac40d27734cba6d6b1b 100644 --- a/helpers/datasources.ts +++ b/helpers/datasources.ts @@ -36,74 +36,66 @@ export async function writeLoadersConfig( dataSources: TemplateDataSource[], useLlamaParse?: boolean, ) { - if (dataSources.length === 0) return; // no datasources, no config needed - const loaderConfig = new Document({}); - // Web loader config - if (dataSources.some((ds) => ds.type === "web")) { - const webLoaderConfig = new Document({}); - - // Create config for browser driver arguments - const driverArgNodeValue = webLoaderConfig.createNode([ - "--no-sandbox", - "--disable-dev-shm-usage", - ]); - driverArgNodeValue.commentBefore = - " The arguments to pass to the webdriver. E.g.: add --headless to run in headless mode"; - webLoaderConfig.set("driver_arguments", driverArgNodeValue); - - // Create config for urls - const urlConfigs = dataSources - .filter((ds) => ds.type === "web") - .map((ds) => { - const dsConfig = ds.config as WebSourceConfig; - return { - base_url: dsConfig.baseUrl, - prefix: dsConfig.prefix, - depth: dsConfig.depth, - }; - }); - const urlConfigNode = webLoaderConfig.createNode(urlConfigs); - urlConfigNode.commentBefore = ` base_url: The URL to start crawling with - prefix: Only crawl URLs matching the specified prefix - depth: The maximum depth for BFS traversal - You can add more websites by adding more entries (don't forget the - prefix from YAML)`; - webLoaderConfig.set("urls", urlConfigNode); + const loaderConfig: Record<string, any> = {}; - // Add web config to the loaders config - loaderConfig.set("web", webLoaderConfig); - } + // Always set file loader config + loaderConfig.file = createFileLoaderConfig(useLlamaParse); - // File loader config - if (dataSources.some((ds) => ds.type === "file")) { - // Add documentation to web loader config - const node = loaderConfig.createNode({ - use_llama_parse: useLlamaParse, - }); - node.commentBefore = ` use_llama_parse: Use LlamaParse if \`true\`. Needs a \`LLAMA_CLOUD_API_KEY\` from https://cloud.llamaindex.ai set as environment variable`; - loaderConfig.set("file", node); + if (dataSources.some((ds) => ds.type === "web")) { + loaderConfig.web = createWebLoaderConfig(dataSources); } - // DB loader config const dbLoaders = dataSources.filter((ds) => ds.type === "db"); if (dbLoaders.length > 0) { - const dbLoaderConfig = new Document({}); - const configEntries = dbLoaders.map((ds) => { - const dsConfig = ds.config as DbSourceConfig; - return { - uri: dsConfig.uri, - queries: [dsConfig.queries], - }; - }); - - const node = dbLoaderConfig.createNode(configEntries); - node.commentBefore = ` The configuration for the database loader, only supports MySQL and PostgreSQL databases for now. - uri: The URI for the database. E.g.: mysql+pymysql://user:password@localhost:3306/db or postgresql+psycopg2://user:password@localhost:5432/db - query: The query to fetch data from the database. E.g.: SELECT * FROM table`; - loaderConfig.set("db", node); + loaderConfig.db = createDbLoaderConfig(dbLoaders); } + // Create a new Document with the loaderConfig + const yamlDoc = new Document(loaderConfig); + // Write loaders config const loaderConfigPath = path.join(root, "config", "loaders.yaml"); await fs.mkdir(path.join(root, "config"), { recursive: true }); - await fs.writeFile(loaderConfigPath, yaml.stringify(loaderConfig)); + await fs.writeFile(loaderConfigPath, yaml.stringify(yamlDoc)); +} + +function createWebLoaderConfig(dataSources: TemplateDataSource[]): any { + const webLoaderConfig: Record<string, any> = {}; + + // Create config for browser driver arguments + webLoaderConfig.driver_arguments = [ + "--no-sandbox", + "--disable-dev-shm-usage", + ]; + + // Create config for urls + const urlConfigs = dataSources + .filter((ds) => ds.type === "web") + .map((ds) => { + const dsConfig = ds.config as WebSourceConfig; + return { + base_url: dsConfig.baseUrl, + prefix: dsConfig.prefix, + depth: dsConfig.depth, + }; + }); + webLoaderConfig.urls = urlConfigs; + + return webLoaderConfig; +} + +function createFileLoaderConfig(useLlamaParse?: boolean): any { + return { + use_llama_parse: useLlamaParse, + }; +} + +function createDbLoaderConfig(dbLoaders: TemplateDataSource[]): any { + return dbLoaders.map((ds) => { + const dsConfig = ds.config as DbSourceConfig; + return { + uri: dsConfig.uri, + queries: [dsConfig.queries], + }; + }); } diff --git a/helpers/index.ts b/helpers/index.ts index 6ff16e93d3d03d8483fe4f4ff85ce1e7c8dc09cd..56d07808f92480d2fe771429af73521eac52d05f 100644 --- a/helpers/index.ts +++ b/helpers/index.ts @@ -96,10 +96,11 @@ async function generateContextData( } } -const copyContextData = async ( +const prepareContextData = async ( root: string, dataSources: TemplateDataSource[], ) => { + await makeDir(path.join(root, "data")); for (const dataSource of dataSources) { const dataSourceConfig = dataSource?.config as FileSourceConfig; // Copy local data @@ -174,25 +175,25 @@ export const installTemplate = async ( await createBackendEnvFile(props.root, props); } - if (props.dataSources.length > 0) { + await prepareContextData( + props.root, + props.dataSources.filter((ds) => ds.type === "file"), + ); + + if ( + props.dataSources.length > 0 && + (props.postInstallAction === "runApp" || + props.postInstallAction === "dependencies") + ) { console.log("\nGenerating context data...\n"); - await copyContextData( - props.root, - props.dataSources.filter((ds) => ds.type === "file"), + await generateContextData( + props.framework, + props.modelConfig, + props.packageManager, + props.vectorDb, + props.llamaCloudKey, + props.useLlamaParse, ); - if ( - props.postInstallAction === "runApp" || - props.postInstallAction === "dependencies" - ) { - await generateContextData( - props.framework, - props.modelConfig, - props.packageManager, - props.vectorDb, - props.llamaCloudKey, - props.useLlamaParse, - ); - } } // Create outputs directory diff --git a/questions.ts b/questions.ts index 5e855cf9f86e997e1ecbcd32cdfe6c6d48a04c33..973aeba5a44930b6433206899662289fa76d8cf1 100644 --- a/questions.ts +++ b/questions.ts @@ -637,6 +637,7 @@ export const askQuestions = async ( type: "db", config: await prompts(dbPrompts, questionHandlers), }); + break; } case "llamacloud": { program.dataSources.push({