From d18f0399e5db60bf0bffbf6d7447684f47750864 Mon Sep 17 00:00:00 2001 From: Huu Le <39040748+leehuwuj@users.noreply.github.com> Date: Mon, 7 Oct 2024 14:47:44 +0700 Subject: [PATCH] feat: Add e2b code artifact tool support for the FastAPI template (#339) --- .changeset/shiny-lamps-retire.md | 5 + e2e/python/resolve_dependencies.spec.ts | 1 + helpers/python.ts | 14 ++ helpers/tools.ts | 13 +- .../engines/python/agent/tools/artifact.py | 100 +++++++++++ .../components/routers/python/sandbox.py | 158 ++++++++++++++++++ .../fastapi/app/api/routers/__init__.py | 18 ++ .../fastapi/app/api/routers/models.py | 39 ++++- .../fastapi/app/engine/utils/file_helper.py | 64 +++++++ templates/types/streaming/fastapi/main.py | 11 +- .../components/ui/chat/widgets/Artifact.tsx | 63 +++---- 11 files changed, 444 insertions(+), 42 deletions(-) create mode 100644 .changeset/shiny-lamps-retire.md create mode 100644 templates/components/engines/python/agent/tools/artifact.py create mode 100644 templates/components/routers/python/sandbox.py create mode 100644 templates/types/streaming/fastapi/app/engine/utils/file_helper.py diff --git a/.changeset/shiny-lamps-retire.md b/.changeset/shiny-lamps-retire.md new file mode 100644 index 00000000..9b8149fb --- /dev/null +++ b/.changeset/shiny-lamps-retire.md @@ -0,0 +1,5 @@ +--- +"create-llama": patch +--- + +Add e2b code artifact tool for the FastAPI template diff --git a/e2e/python/resolve_dependencies.spec.ts b/e2e/python/resolve_dependencies.spec.ts index 159f2a78..a1b02802 100644 --- a/e2e/python/resolve_dependencies.spec.ts +++ b/e2e/python/resolve_dependencies.spec.ts @@ -34,6 +34,7 @@ if ( "wikipedia.WikipediaToolSpec", "google.GoogleSearchToolSpec", "document_generator", + "artifact", ]; const dataSources = [ diff --git a/helpers/python.ts b/helpers/python.ts index 575aa7e9..20d93394 100644 --- a/helpers/python.ts +++ b/helpers/python.ts @@ -280,6 +280,17 @@ const mergePoetryDependencies = ( } }; +const copyRouterCode = async (root: string, tools: Tool[]) => { + // Copy sandbox router if the artifact tool is selected + if (tools?.some((t) => t.name === "artifact")) { + await copy("sandbox.py", path.join(root, "app", "api", "routers"), { + parents: true, + cwd: path.join(templatesDir, "components", "routers", "python"), + rename: assetRelocator, + }); + } +}; + export const addDependencies = async ( projectDir: string, dependencies: Dependency[], @@ -431,6 +442,9 @@ export const installPythonTemplate = async ({ parents: true, cwd: path.join(compPath, "engines", "python", engine), }); + + // Copy router code + await copyRouterCode(root, tools ?? []); } if (template === "multiagent") { diff --git a/helpers/tools.ts b/helpers/tools.ts index b65957e7..0684a780 100644 --- a/helpers/tools.ts +++ b/helpers/tools.ts @@ -139,7 +139,7 @@ For better results, you can specify the region parameter to get results from a s dependencies: [ { name: "e2b_code_interpreter", - version: "0.0.7", + version: "0.0.10", }, ], supportedFrameworks: ["fastapi", "express", "nextjs"], @@ -165,8 +165,15 @@ For better results, you can specify the region parameter to get results from a s { display: "Artifact Code Generator", name: "artifact", - dependencies: [], - supportedFrameworks: ["express", "nextjs"], + // Using pre-release version of e2b_code_interpreter + // TODO: Update to stable version when 0.0.11 is released + dependencies: [ + { + name: "e2b_code_interpreter", + version: "^0.0.11b38", + }, + ], + supportedFrameworks: ["fastapi", "express", "nextjs"], type: ToolType.LOCAL, envVars: [ { diff --git a/templates/components/engines/python/agent/tools/artifact.py b/templates/components/engines/python/agent/tools/artifact.py new file mode 100644 index 00000000..9b132c64 --- /dev/null +++ b/templates/components/engines/python/agent/tools/artifact.py @@ -0,0 +1,100 @@ +import logging +from typing import Dict, List, Optional + +from llama_index.core.base.llms.types import ChatMessage +from llama_index.core.settings import Settings +from llama_index.core.tools import FunctionTool +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) + +# Prompt based on https://github.com/e2b-dev/ai-artifacts +CODE_GENERATION_PROMPT = """You are a skilled software engineer. You do not make mistakes. Generate an artifact. You can install additional dependencies. You can use one of the following templates: + +1. code-interpreter-multilang: "Runs code as a Jupyter notebook cell. Strong data analysis angle. Can use complex visualisation to explain results.". File: script.py. Dependencies installed: python, jupyter, numpy, pandas, matplotlib, seaborn, plotly. Port: none. + +2. nextjs-developer: "A Next.js 13+ app that reloads automatically. Using the pages router.". File: pages/index.tsx. Dependencies installed: nextjs@14.2.5, typescript, @types/node, @types/react, @types/react-dom, postcss, tailwindcss, shadcn. Port: 3000. + +3. vue-developer: "A Vue.js 3+ app that reloads automatically. Only when asked specifically for a Vue app.". File: app.vue. Dependencies installed: vue@latest, nuxt@3.13.0, tailwindcss. Port: 3000. + +4. streamlit-developer: "A streamlit app that reloads automatically.". File: app.py. Dependencies installed: streamlit, pandas, numpy, matplotlib, request, seaborn, plotly. Port: 8501. + +5. gradio-developer: "A gradio app. Gradio Blocks/Interface should be called demo.". File: app.py. Dependencies installed: gradio, pandas, numpy, matplotlib, request, seaborn, plotly. Port: 7860. + +Make sure to use the correct syntax for the programming language you're using. +""" + + +class CodeArtifact(BaseModel): + commentary: str = Field( + ..., + description="Describe what you're about to do and the steps you want to take for generating the artifact in great detail.", + ) + template: str = Field( + ..., description="Name of the template used to generate the artifact." + ) + title: str = Field(..., description="Short title of the artifact. Max 3 words.") + description: str = Field( + ..., description="Short description of the artifact. Max 1 sentence." + ) + additional_dependencies: List[str] = Field( + ..., + description="Additional dependencies required by the artifact. Do not include dependencies that are already included in the template.", + ) + has_additional_dependencies: bool = Field( + ..., + description="Detect if additional dependencies that are not included in the template are required by the artifact.", + ) + install_dependencies_command: str = Field( + ..., + description="Command to install additional dependencies required by the artifact.", + ) + port: Optional[int] = Field( + ..., + description="Port number used by the resulted artifact. Null when no ports are exposed.", + ) + file_path: str = Field( + ..., description="Relative path to the file, including the file name." + ) + code: str = Field( + ..., + description="Code generated by the artifact. Only runnable code is allowed.", + ) + + +class CodeGeneratorTool: + def __init__(self): + pass + + def artifact(self, query: str, old_code: Optional[str] = None) -> Dict: + """Generate a code artifact based on the input. + + Args: + query (str): The description of the application you want to build. + old_code (Optional[str], optional): The existing code to be modified. Defaults to None. + + Returns: + Dict: A dictionary containing the generated artifact information. + """ + + if old_code: + user_message = f"{query}\n\nThe existing code is: \n```\n{old_code}\n```" + else: + user_message = query + + messages: List[ChatMessage] = [ + ChatMessage(role="system", content=CODE_GENERATION_PROMPT), + ChatMessage(role="user", content=user_message), + ] + try: + sllm = Settings.llm.as_structured_llm(output_cls=CodeArtifact) + response = sllm.chat(messages) + data: CodeArtifact = response.raw + return data.model_dump() + except Exception as e: + logger.error(f"Failed to generate artifact: {str(e)}") + raise e + + +def get_tools(**kwargs): + return [FunctionTool.from_defaults(fn=CodeGeneratorTool().artifact)] diff --git a/templates/components/routers/python/sandbox.py b/templates/components/routers/python/sandbox.py new file mode 100644 index 00000000..c5a2a367 --- /dev/null +++ b/templates/components/routers/python/sandbox.py @@ -0,0 +1,158 @@ +# Copyright 2024 FoundryLabs, Inc. and LlamaIndex, Inc. +# Portions of this file are copied from the e2b project (https://github.com/e2b-dev/ai-artifacts) and then converted to Python +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import logging +import os +import uuid +from typing import Dict, List, Optional, Union + +from app.engine.tools.artifact import CodeArtifact +from app.engine.utils.file_helper import save_file +from e2b_code_interpreter import CodeInterpreter, Sandbox +from fastapi import APIRouter, HTTPException, Request +from pydantic import BaseModel + +logger = logging.getLogger("uvicorn") + +sandbox_router = APIRouter() + +SANDBOX_TIMEOUT = 10 * 60 # timeout in seconds +MAX_DURATION = 60 # max duration in seconds + + +class ExecutionResult(BaseModel): + template: str + stdout: List[str] + stderr: List[str] + runtime_error: Optional[Dict[str, Union[str, List[str]]]] = None + output_urls: List[Dict[str, str]] + url: Optional[str] + + def to_response(self): + """ + Convert the execution result to a response object (camelCase) + """ + return { + "template": self.template, + "stdout": self.stdout, + "stderr": self.stderr, + "runtimeError": self.runtime_error, + "outputUrls": self.output_urls, + "url": self.url, + } + + +@sandbox_router.post("") +async def create_sandbox(request: Request): + request_data = await request.json() + + try: + artifact = CodeArtifact(**request_data["artifact"]) + except Exception: + logger.error(f"Could not create artifact from request data: {request_data}") + return HTTPException( + status_code=400, detail="Could not create artifact from the request data" + ) + + sbx = None + + # Create an interpreter or a sandbox + if artifact.template == "code-interpreter-multilang": + sbx = CodeInterpreter(api_key=os.getenv("E2B_API_KEY"), timeout=SANDBOX_TIMEOUT) + logger.debug(f"Created code interpreter {sbx}") + else: + sbx = Sandbox( + api_key=os.getenv("E2B_API_KEY"), + template=artifact.template, + metadata={"template": artifact.template, "user_id": "default"}, + timeout=SANDBOX_TIMEOUT, + ) + logger.debug(f"Created sandbox {sbx}") + + # Install packages + if artifact.has_additional_dependencies: + if isinstance(sbx, CodeInterpreter): + sbx.notebook.exec_cell(artifact.install_dependencies_command) + logger.debug( + f"Installed dependencies: {', '.join(artifact.additional_dependencies)} in code interpreter {sbx}" + ) + elif isinstance(sbx, Sandbox): + sbx.commands.run(artifact.install_dependencies_command) + logger.debug( + f"Installed dependencies: {', '.join(artifact.additional_dependencies)} in sandbox {sbx}" + ) + + # Copy code to disk + if isinstance(artifact.code, list): + for file in artifact.code: + sbx.files.write(file.file_path, file.file_content) + logger.debug(f"Copied file to {file.file_path}") + else: + sbx.files.write(artifact.file_path, artifact.code) + logger.debug(f"Copied file to {artifact.file_path}") + + # Execute code or return a URL to the running sandbox + if artifact.template == "code-interpreter-multilang": + result = sbx.notebook.exec_cell(artifact.code or "") + output_urls = _download_cell_results(result.results) + return ExecutionResult( + template=artifact.template, + stdout=result.logs.stdout, + stderr=result.logs.stderr, + runtime_error=result.error, + output_urls=output_urls, + url=None, + ).to_response() + else: + return ExecutionResult( + template=artifact.template, + stdout=[], + stderr=[], + runtime_error=None, + output_urls=[], + url=f"https://{sbx.get_host(artifact.port or 80)}", + ).to_response() + + +def _download_cell_results(cell_results: Optional[List]) -> List[Dict[str, str]]: + """ + To pull results from code interpreter cell and save them to disk for serving + """ + if not cell_results: + return [] + + output = [] + for result in cell_results: + try: + formats = result.formats() + for ext in formats: + data = result[ext] + + if ext in ["png", "svg", "jpeg", "pdf"]: + file_path = f"output/tools/{uuid.uuid4()}.{ext}" + base64_data = data + buffer = base64.b64decode(base64_data) + file_meta = save_file(content=buffer, file_path=file_path) + output.append( + { + "type": ext, + "filename": file_meta.filename, + "url": file_meta.url, + } + ) + except Exception as e: + logger.error(f"Error processing result: {str(e)}") + + return output diff --git a/templates/types/streaming/fastapi/app/api/routers/__init__.py b/templates/types/streaming/fastapi/app/api/routers/__init__.py index e69de29b..8c897aa5 100644 --- a/templates/types/streaming/fastapi/app/api/routers/__init__.py +++ b/templates/types/streaming/fastapi/app/api/routers/__init__.py @@ -0,0 +1,18 @@ +from fastapi import APIRouter + +from .chat import chat_router # noqa: F401 +from .chat_config import config_router # noqa: F401 +from .upload import file_upload_router # noqa: F401 + +api_router = APIRouter() +api_router.include_router(chat_router, prefix="/chat") +api_router.include_router(config_router, prefix="/chat/config") +api_router.include_router(file_upload_router, prefix="/chat/upload") + +# Dynamically adding additional routers if they exist +try: + from .sandbox import sandbox_router # noqa: F401 + + api_router.include_router(sandbox_router, prefix="/sandbox") +except ImportError: + pass diff --git a/templates/types/streaming/fastapi/app/api/routers/models.py b/templates/types/streaming/fastapi/app/api/routers/models.py index 123f97ba..c5c8d4a3 100644 --- a/templates/types/streaming/fastapi/app/api/routers/models.py +++ b/templates/types/streaming/fastapi/app/api/routers/models.py @@ -55,9 +55,14 @@ class AgentAnnotation(BaseModel): text: str +class ArtifactAnnotation(BaseModel): + toolCall: Dict[str, Any] + toolOutput: Dict[str, Any] + + class Annotation(BaseModel): type: str - data: AnnotationFileData | List[str] | AgentAnnotation + data: AnnotationFileData | List[str] | AgentAnnotation | ArtifactAnnotation def to_content(self) -> str | None: if self.type == "document_file": @@ -146,8 +151,29 @@ class ChatData(BaseModel): break return agent_messages + def _get_latest_code_artifact(self) -> Optional[str]: + """ + Get latest code artifact from annotations to append to the user message + """ + for message in reversed(self.messages): + if ( + message.role == MessageRole.ASSISTANT + and message.annotations is not None + ): + for annotation in message.annotations: + # type is tools and has `toolOutput` attribute + if annotation.type == "tools" and isinstance( + annotation.data, ArtifactAnnotation + ): + tool_output = annotation.data.toolOutput + if tool_output and not tool_output.get("isError", False): + return tool_output.get("output", {}).get("code", None) + return None + def get_history_messages( - self, include_agent_messages: bool = False + self, + include_agent_messages: bool = False, + include_code_artifact: bool = True, ) -> List[ChatMessage]: """ Get the history messages @@ -164,7 +190,14 @@ class ChatData(BaseModel): content="Previous agent events: \n" + "\n".join(agent_messages), ) chat_messages.append(message) - + if include_code_artifact: + latest_code_artifact = self._get_latest_code_artifact() + if latest_code_artifact: + message = ChatMessage( + role=MessageRole.ASSISTANT, + content=f"The existing code is:\n```\n{latest_code_artifact}\n```", + ) + chat_messages.append(message) return chat_messages def is_last_message_from_user(self) -> bool: diff --git a/templates/types/streaming/fastapi/app/engine/utils/file_helper.py b/templates/types/streaming/fastapi/app/engine/utils/file_helper.py new file mode 100644 index 00000000..c794a3b4 --- /dev/null +++ b/templates/types/streaming/fastapi/app/engine/utils/file_helper.py @@ -0,0 +1,64 @@ +import logging +import os +import uuid +from typing import Optional + +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + + +class FileMetadata(BaseModel): + outputPath: str + filename: str + url: str + + +def save_file( + content: bytes | str, + file_name: Optional[str] = None, + file_path: Optional[str] = None, +) -> FileMetadata: + """ + Save the content to a file in the local file server (accessible via URL) + Args: + content (bytes | str): The content to save, either bytes or string. + file_name (Optional[str]): The name of the file. If not provided, a random name will be generated with .txt extension. + file_path (Optional[str]): The path to save the file to. If not provided, a random name will be generated. + Returns: + The metadata of the saved file. + """ + if file_name is not None and file_path is not None: + raise ValueError("Either file_name or file_path should be provided") + + if file_path is None: + if file_name is None: + file_name = f"{uuid.uuid4()}.txt" + file_path = os.path.join(os.getcwd(), file_name) + else: + file_name = os.path.basename(file_path) + + if isinstance(content, str): + content = content.encode() + + try: + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, "wb") as file: + file.write(content) + except PermissionError as e: + logger.error(f"Permission denied when writing to file {file_path}: {str(e)}") + raise + except IOError as e: + logger.error(f"IO error occurred when writing to file {file_path}: {str(e)}") + raise + except Exception as e: + logger.error(f"Unexpected error when writing to file {file_path}: {str(e)}") + raise + + logger.info(f"Saved file to {file_path}") + + return FileMetadata( + outputPath=file_path, + filename=file_name, + url=f"{os.getenv('FILESERVER_URL_PREFIX')}/{file_path}", + ) diff --git a/templates/types/streaming/fastapi/main.py b/templates/types/streaming/fastapi/main.py index 12a54872..cf1a4e8c 100644 --- a/templates/types/streaming/fastapi/main.py +++ b/templates/types/streaming/fastapi/main.py @@ -1,7 +1,6 @@ # flake8: noqa: E402 -from dotenv import load_dotenv - from app.config import DATA_DIR +from dotenv import load_dotenv load_dotenv() @@ -9,9 +8,7 @@ import logging import os import uvicorn -from app.api.routers.chat import chat_router -from app.api.routers.chat_config import config_router -from app.api.routers.upload import file_upload_router +from app.api.routers import api_router from app.observability import init_observability from app.settings import init_settings from fastapi import FastAPI @@ -58,9 +55,7 @@ mount_static_files(DATA_DIR, "/api/files/data") # Mount the output files from tools mount_static_files("output", "/api/files/output") -app.include_router(chat_router, prefix="/api/chat") -app.include_router(config_router, prefix="/api/chat/config") -app.include_router(file_upload_router, prefix="/api/chat/upload") +app.include_router(api_router, prefix="/api") if __name__ == "__main__": app_host = os.getenv("APP_HOST", "0.0.0.0") diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/widgets/Artifact.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/widgets/Artifact.tsx index ff8b8077..6f808efd 100644 --- a/templates/types/streaming/nextjs/app/components/ui/chat/widgets/Artifact.tsx +++ b/templates/types/streaming/nextjs/app/components/ui/chat/widgets/Artifact.tsx @@ -119,6 +119,22 @@ export function Artifact({ className="w-[45vw] fixed top-0 right-0 h-screen z-50 artifact-panel animate-slideIn" ref={panelRef} > + <div className="flex justify-between items-center pl-5 pr-10 py-6 border-b"> + <div className="space-y-2"> + <h2 className="text-2xl font-bold m-0">{artifact?.title}</h2> + <span className="text-sm text-gray-500">Version: v{version}</span> + </div> + <Button + onClick={() => { + closePanel(); + setOpenOutputPanel(false); + }} + variant="outline" + > + Close + </Button> + </div> + {sandboxCreating && ( <div className="flex justify-center items-center h-full"> <Loader2 className="h-6 w-6 animate-spin" /> @@ -159,35 +175,26 @@ function ArtifactOutput({ const { url: sandboxUrl, outputUrls, runtimeError, stderr, stdout } = result; return ( - <> - <div className="flex justify-between items-center pl-5 pr-10 py-6"> - <div className="space-y-2"> - <h2 className="text-2xl font-bold m-0">{artifact.title}</h2> - <span className="text-sm text-gray-500">Version: v{version}</span> + <Tabs defaultValue="code" className="h-full p-4 overflow-auto"> + <TabsList className="grid grid-cols-2 max-w-[400px] mx-auto"> + <TabsTrigger value="code">Code</TabsTrigger> + <TabsTrigger value="preview">Preview</TabsTrigger> + </TabsList> + <TabsContent value="code" className="h-[80%] mb-4 overflow-auto"> + <div className="m-4 overflow-auto"> + <Markdown content={markdownCode} /> </div> - <Button onClick={closePanel}>Close</Button> - </div> - <Tabs defaultValue="code" className="h-full p-4 overflow-auto"> - <TabsList className="grid grid-cols-2 max-w-[400px] mx-auto"> - <TabsTrigger value="code">Code</TabsTrigger> - <TabsTrigger value="preview">Preview</TabsTrigger> - </TabsList> - <TabsContent value="code" className="h-[80%] mb-4 overflow-auto"> - <div className="m-4 overflow-auto"> - <Markdown content={markdownCode} /> - </div> - </TabsContent> - <TabsContent - value="preview" - className="h-[80%] mb-4 overflow-auto mt-4 space-y-4" - > - {runtimeError && <RunTimeError runtimeError={runtimeError} />} - <ArtifactLogs stderr={stderr} stdout={stdout} /> - {sandboxUrl && <CodeSandboxPreview url={sandboxUrl} />} - {outputUrls && <InterpreterOutput outputUrls={outputUrls} />} - </TabsContent> - </Tabs> - </> + </TabsContent> + <TabsContent + value="preview" + className="h-[80%] mb-4 overflow-auto mt-4 space-y-4" + > + {runtimeError && <RunTimeError runtimeError={runtimeError} />} + <ArtifactLogs stderr={stderr} stdout={stdout} /> + {sandboxUrl && <CodeSandboxPreview url={sandboxUrl} />} + {outputUrls && <InterpreterOutput outputUrls={outputUrls} />} + </TabsContent> + </Tabs> ); } -- GitLab