Skip to content
Snippets Groups Projects
Unverified Commit 0e78ba46 authored by Marcus Schiesser's avatar Marcus Schiesser Committed by GitHub
Browse files

fix: .env not loaded on poetry run generate (#348)


--------
Co-authored-by: leehuwuj <leehuwuj@gmail.com>
parent 7652b2b3
No related branches found
No related tags found
No related merge requests found
---
"create-llama": patch
---
Fix: programmatically ensure index for LlamaCloud
---
"create-llama": patch
---
Fix .env not loaded on poetry run generate
...@@ -5,7 +5,7 @@ from app.api.routers.models import ( ...@@ -5,7 +5,7 @@ from app.api.routers.models import (
ChatData, ChatData,
) )
from app.api.routers.vercel_response import VercelStreamResponse from app.api.routers.vercel_response import VercelStreamResponse
from app.engine import get_chat_engine from app.engine.engine import get_chat_engine
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status
chat_router = r = APIRouter() chat_router = r = APIRouter()
......
# flake8: noqa: E402 # flake8: noqa: E402
import os
from dotenv import load_dotenv from dotenv import load_dotenv
from app.engine.index import get_index
load_dotenv() load_dotenv()
from llama_cloud import PipelineType
from app.settings import init_settings
from llama_index.core.settings import Settings
from app.engine.index import get_client, get_index
import logging import logging
from llama_index.core.readers import SimpleDirectoryReader from llama_index.core.readers import SimpleDirectoryReader
from app.engine.service import LLamaCloudFileService from app.engine.service import LLamaCloudFileService
...@@ -13,10 +20,49 @@ logging.basicConfig(level=logging.INFO) ...@@ -13,10 +20,49 @@ logging.basicConfig(level=logging.INFO)
logger = logging.getLogger() logger = logging.getLogger()
def ensure_index(index):
    """Ensure a LlamaCloud managed pipeline exists for *index*, creating it if missing.

    Searches the index's project for a managed pipeline with the index's
    name; when none is found, creates one configured with the OpenAI
    embedding model from the environment and the global chunking settings.

    Args:
        index: A LlamaCloudIndex whose backing pipeline should exist.

    Raises:
        ValueError: if a pipeline must be created but the configured
            ``Settings.embed_model`` is not an OpenAI embedding model —
            programmatic creation only supports the OPENAI_EMBEDDING config.
    """
    project_id = index._get_project_id()
    client = get_client()

    pipelines = client.pipelines.search_pipelines(
        project_id=project_id,
        pipeline_name=index.name,
        pipeline_type=PipelineType.MANAGED.value,
    )
    if not pipelines:  # idiomatic emptiness check instead of len(...) == 0
        # Imported lazily so the OpenAI embedding dependency is only
        # required when a pipeline actually has to be created.
        from llama_index.embeddings.openai import OpenAIEmbedding

        if not isinstance(Settings.embed_model, OpenAIEmbedding):
            raise ValueError(
                "Creating a new pipeline with a non-OpenAI embedding model is not supported."
            )
        client.pipelines.upsert_pipeline(
            project_id=project_id,
            request={
                "name": index.name,
                "embedding_config": {
                    "type": "OPENAI_EMBEDDING",
                    "component": {
                        "api_key": os.getenv("OPENAI_API_KEY"),  # editable
                        "model_name": os.getenv("EMBEDDING_MODEL"),
                    },
                },
                "transform_config": {
                    "mode": "auto",
                    "config": {
                        "chunk_size": Settings.chunk_size,  # editable
                        "chunk_overlap": Settings.chunk_overlap,  # editable
                    },
                },
            },
        )
def generate_datasource(): def generate_datasource():
init_settings()
logger.info("Generate index for the provided data") logger.info("Generate index for the provided data")
index = get_index() index = get_index()
ensure_index(index)
project_id = index._get_project_id() project_id = index._get_project_id()
pipeline_id = index._get_pipeline_id() pipeline_id = index._get_pipeline_id()
......
...@@ -7,7 +7,7 @@ from llama_index.core.ingestion.api_utils import ( ...@@ -7,7 +7,7 @@ from llama_index.core.ingestion.api_utils import (
get_client as llama_cloud_get_client, get_client as llama_cloud_get_client,
) )
from llama_index.indices.managed.llama_cloud import LlamaCloudIndex from llama_index.indices.managed.llama_cloud import LlamaCloudIndex
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, field_validator
logger = logging.getLogger("uvicorn") logger = logging.getLogger("uvicorn")
...@@ -15,31 +15,26 @@ logger = logging.getLogger("uvicorn") ...@@ -15,31 +15,26 @@ logger = logging.getLogger("uvicorn")
class LlamaCloudConfig(BaseModel): class LlamaCloudConfig(BaseModel):
# Private attributes # Private attributes
api_key: str = Field( api_key: str = Field(
default=os.getenv("LLAMA_CLOUD_API_KEY"),
exclude=True, # Exclude from the model representation exclude=True, # Exclude from the model representation
) )
base_url: Optional[str] = Field( base_url: Optional[str] = Field(
default=os.getenv("LLAMA_CLOUD_BASE_URL"),
exclude=True, exclude=True,
) )
organization_id: Optional[str] = Field( organization_id: Optional[str] = Field(
default=os.getenv("LLAMA_CLOUD_ORGANIZATION_ID"),
exclude=True, exclude=True,
) )
# Configuration attributes, can be set by the user # Configuration attributes, can be set by the user
pipeline: str = Field( pipeline: str = Field(
description="The name of the pipeline to use", description="The name of the pipeline to use",
default=os.getenv("LLAMA_CLOUD_INDEX_NAME"),
) )
project: str = Field( project: str = Field(
description="The name of the LlamaCloud project", description="The name of the LlamaCloud project",
default=os.getenv("LLAMA_CLOUD_PROJECT_NAME"),
) )
# Validate and throw error if the env variables are not set before starting the app # Validate and throw error if the env variables are not set before starting the app
@validator("pipeline", "project", "api_key", pre=True, always=True) @field_validator("pipeline", "project", "api_key", mode="before")
@classmethod @classmethod
def validate_env_vars(cls, value): def validate_fields(cls, value):
if value is None: if value is None:
raise ValueError( raise ValueError(
"Please set LLAMA_CLOUD_INDEX_NAME, LLAMA_CLOUD_PROJECT_NAME and LLAMA_CLOUD_API_KEY" "Please set LLAMA_CLOUD_INDEX_NAME, LLAMA_CLOUD_PROJECT_NAME and LLAMA_CLOUD_API_KEY"
...@@ -53,10 +48,20 @@ class LlamaCloudConfig(BaseModel): ...@@ -53,10 +48,20 @@ class LlamaCloudConfig(BaseModel):
"base_url": self.base_url, "base_url": self.base_url,
} }
@classmethod
def from_env(cls):
    """Build a config instance from the ``LLAMA_CLOUD_*`` environment variables.

    Reading the environment at call time (instead of in ``Field`` defaults
    evaluated at import time) ensures values loaded later via
    ``load_dotenv()`` are picked up — the purpose of this change. Missing
    required variables surface as validation errors on construction.

    Returns:
        A new config populated from the environment.
    """
    # Use cls(...) rather than the hard-coded class name so subclasses
    # calling from_env() receive an instance of the subclass.
    return cls(
        api_key=os.getenv("LLAMA_CLOUD_API_KEY"),
        base_url=os.getenv("LLAMA_CLOUD_BASE_URL"),
        organization_id=os.getenv("LLAMA_CLOUD_ORGANIZATION_ID"),
        pipeline=os.getenv("LLAMA_CLOUD_INDEX_NAME"),
        project=os.getenv("LLAMA_CLOUD_PROJECT_NAME"),
    )
class IndexConfig(BaseModel): class IndexConfig(BaseModel):
llama_cloud_pipeline_config: LlamaCloudConfig = Field( llama_cloud_pipeline_config: LlamaCloudConfig = Field(
default=LlamaCloudConfig(), default_factory=LlamaCloudConfig.from_env,
alias="llamaCloudPipeline", alias="llamaCloudPipeline",
) )
callback_manager: Optional[CallbackManager] = Field( callback_manager: Optional[CallbackManager] = Field(
...@@ -83,5 +88,5 @@ def get_index(config: IndexConfig = None): ...@@ -83,5 +88,5 @@ def get_index(config: IndexConfig = None):
def get_client(): def get_client():
config = LlamaCloudConfig() config = LlamaCloudConfig.from_env()
return llama_cloud_get_client(**config.to_client_kwargs()) return llama_cloud_get_client(**config.to_client_kwargs())
...@@ -13,7 +13,7 @@ from app.api.routers.models import ( ...@@ -13,7 +13,7 @@ from app.api.routers.models import (
SourceNodes, SourceNodes,
) )
from app.api.routers.vercel_response import VercelStreamResponse from app.api.routers.vercel_response import VercelStreamResponse
from app.engine import get_chat_engine from app.engine.engine import get_chat_engine
from app.engine.query_filter import generate_filters from app.engine.query_filter import generate_filters
chat_router = r = APIRouter() chat_router = r = APIRouter()
......
from .engine import get_chat_engine as get_chat_engine
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment