Skip to content
Snippets Groups Projects
Commit d60b3c5a authored by leehuwuj's avatar leehuwuj
Browse files

refactor code and add changeset

parent c3e9ed3d
No related branches found
No related tags found
No related merge requests found
---
"create-llama": patch
---
Add support E2B code interpreter tool for FastAPI
...@@ -171,6 +171,11 @@ export const installTemplate = async ( ...@@ -171,6 +171,11 @@ export const installTemplate = async (
); );
} }
} }
// Create tool-output directory
if (props.tools && props.tools.length > 0) {
await fsExtra.mkdir(path.join(props.root, "tool-output"));
}
} else { } else {
// this is a frontend for a full-stack app, create .env file with model information // this is a frontend for a full-stack app, create .env file with model information
await createFrontendEnvFile(props.root, { await createFrontendEnvFile(props.root, {
......
...@@ -32,11 +32,6 @@ class E2BCodeInterpreter: ...@@ -32,11 +32,6 @@ class E2BCodeInterpreter:
self.api_key = api_key self.api_key = api_key
self.filesever_url_prefix = filesever_url_prefix self.filesever_url_prefix = filesever_url_prefix
def code_interpret(
self, code_interpreter: CodeInterpreter, code: str
) -> Tuple[List, List]:
pass
def get_output_path(self, filename: str) -> str: def get_output_path(self, filename: str) -> str:
# if output directory doesn't exist, create it # if output directory doesn't exist, create it
if not os.path.exists(self.output_dir): if not os.path.exists(self.output_dir):
...@@ -48,8 +43,12 @@ class E2BCodeInterpreter: ...@@ -48,8 +43,12 @@ class E2BCodeInterpreter:
buffer = base64.b64decode(base64_data) buffer = base64.b64decode(base64_data)
output_path = self.get_output_path(filename) output_path = self.get_output_path(filename)
with open(output_path, "wb") as file: try:
file.write(buffer) with open(output_path, "wb") as file:
file.write(buffer)
except IOError as e:
logger.error(f"Failed to write to file {output_path}: {str(e)}")
raise e
logger.info(f"Saved file to {output_path}") logger.info(f"Saved file to {output_path}")
...@@ -89,7 +88,7 @@ class E2BCodeInterpreter: ...@@ -89,7 +88,7 @@ class E2BCodeInterpreter:
return output return output
def interpret(self, code: str) -> Dict: def interpret(self, code: str) -> E2BToolOutput:
with CodeInterpreter(api_key=self.api_key) as interpreter: with CodeInterpreter(api_key=self.api_key) as interpreter:
logger.info( logger.info(
f"\n{'='*50}\n> Running following AI-generated code:\n{code}\n{'='*50}" f"\n{'='*50}\n> Running following AI-generated code:\n{code}\n{'='*50}"
...@@ -106,7 +105,7 @@ class E2BCodeInterpreter: ...@@ -106,7 +105,7 @@ class E2BCodeInterpreter:
output = E2BToolOutput( output = E2BToolOutput(
is_error=False, logs=exec.logs, results=results is_error=False, logs=exec.logs, results=results
) )
return output.dict() return output
def code_interpret(code: str) -> Dict: def code_interpret(code: str) -> Dict:
...@@ -127,7 +126,8 @@ def code_interpret(code: str) -> Dict: ...@@ -127,7 +126,8 @@ def code_interpret(code: str) -> Dict:
interpreter = E2BCodeInterpreter( interpreter = E2BCodeInterpreter(
api_key=api_key, filesever_url_prefix=filesever_url_prefix api_key=api_key, filesever_url_prefix=filesever_url_prefix
) )
return interpreter.interpret(code) output = interpreter.interpret(code)
return output.dict()
# Specify as functions tools to be loaded by the ToolFactory # Specify as functions tools to be loaded by the ToolFactory
......
...@@ -93,7 +93,13 @@ async def chat( ...@@ -93,7 +93,13 @@ async def chat(
event_handler = EventCallbackHandler() event_handler = EventCallbackHandler()
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
response = await chat_engine.astream_chat(last_message_content, messages) try:
response = await chat_engine.astream_chat(last_message_content, messages)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error in chat engine: {e}",
)
async def content_generator(): async def content_generator():
# Yield the text response # Yield the text response
......
import json import json
import asyncio import asyncio
import logging
from typing import AsyncGenerator, Dict, Any, List, Optional from typing import AsyncGenerator, Dict, Any, List, Optional
from llama_index.core.callbacks.base import BaseCallbackHandler from llama_index.core.callbacks.base import BaseCallbackHandler
from llama_index.core.callbacks.schema import CBEventType from llama_index.core.callbacks.schema import CBEventType
...@@ -7,6 +8,9 @@ from llama_index.core.tools.types import ToolOutput ...@@ -7,6 +8,9 @@ from llama_index.core.tools.types import ToolOutput
from pydantic import BaseModel from pydantic import BaseModel
logger = logging.getLogger(__name__)
class CallbackEvent(BaseModel): class CallbackEvent(BaseModel):
event_type: CBEventType event_type: CBEventType
payload: Optional[Dict[str, Any]] = None payload: Optional[Dict[str, Any]] = None
...@@ -72,15 +76,19 @@ class CallbackEvent(BaseModel): ...@@ -72,15 +76,19 @@ class CallbackEvent(BaseModel):
} }
def to_response(self): def to_response(self):
match self.event_type: try:
case "retrieve": match self.event_type:
return self.get_retrieval_message() case "retrieve":
case "function_call": return self.get_retrieval_message()
return self.get_tool_message() case "function_call":
case "agent_step": return self.get_tool_message()
return self.get_agent_tool_response() case "agent_step":
case _: return self.get_agent_tool_response()
return None case _:
return None
except Exception as e:
logger.error(f"Error in converting event to response: {e}")
return None
class EventCallbackHandler(BaseCallbackHandler): class EventCallbackHandler(BaseCallbackHandler):
......
...@@ -41,8 +41,8 @@ if environment == "dev": ...@@ -41,8 +41,8 @@ if environment == "dev":
# Mount the data files to serve the file viewer # Mount the data files to serve the file viewer
if os.path.exists("data"): if os.path.exists("data"):
app.mount("/api/files/data", StaticFiles(directory="data"), name="data-static") app.mount("/api/files/data", StaticFiles(directory="data"), name="data-static")
# Mount the tool output files # Mount the output files from tools
if os.path.exists("config/tools.yaml"): if os.path.exists("tool-output"):
app.mount( app.mount(
"/api/files/tool-output", "/api/files/tool-output",
StaticFiles(directory="tool-output"), StaticFiles(directory="tool-output"),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment