Unverified commit c06d4af7, authored by Huu Le (Lee), committed by GitHub

feat: Update FastAPI endpoint to support nodeSources (#30)

parent 27397143
---
"create-llama": patch
---
Add nodes to the response and support Vercel streaming format
@@ -505,7 +505,7 @@ export const askQuestions = async (
   if (program.framework === "nextjs" || program.frontend) {
     if (!program.ui) {
-      program.ui = getPrefOrDefault("ui");
+      program.ui = defaults.ui;
     }
   }
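For context before the Python changes below: the streaming chat endpoint now emits the experimental Vercel/AI stream-data format, in which every line is prefixed with a part type ("0:" for text tokens, "2:" for JSON data). A rough sketch of the wire output, with purely illustrative token and node values:

0:"Hello"
0:" world"
2:[{"nodes": [{"id": "a1b2c3", "metadata": {"file_name": "example.pdf"}, "score": 0.82}]}]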
-from typing import List
 from pydantic import BaseModel
-from fastapi.responses import StreamingResponse
+from typing import List, Any, Optional, Dict, Tuple
 from fastapi import APIRouter, Depends, HTTPException, Request, status
-from llama_index.core.chat_engine.types import BaseChatEngine
+from llama_index.core.chat_engine.types import (
+    BaseChatEngine,
+    StreamingAgentChatResponse,
+)
+from llama_index.core.schema import NodeWithScore
 from llama_index.core.llms import ChatMessage, MessageRole
 from app.engine import get_chat_engine
-from typing import List, Tuple
+from app.api.routers.vercel_response import VercelStreamResponse

 chat_router = r = APIRouter()
@@ -19,8 +22,27 @@ class _ChatData(BaseModel):
     messages: List[_Message]


+class _SourceNodes(BaseModel):
+    id: str
+    metadata: Dict[str, Any]
+    score: Optional[float]
+
+    @classmethod
+    def from_source_node(cls, source_node: NodeWithScore):
+        return cls(
+            id=source_node.node.node_id,
+            metadata=source_node.node.metadata,
+            score=source_node.score,
+        )
+
+    @classmethod
+    def from_source_nodes(cls, source_nodes: List[NodeWithScore]):
+        return [cls.from_source_node(node) for node in source_nodes]
+
+
 class _Result(BaseModel):
     result: _Message
+    nodes: List[_SourceNodes]


 async def parse_chat_data(data: _ChatData) -> Tuple[str, List[ChatMessage]]:
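The _SourceNodes model added above is a thin serializer over llama-index's NodeWithScore. A quick sketch of how it maps a retrieved node (not part of the commit; the node text, metadata, and score are made up):

from llama_index.core.schema import NodeWithScore, TextNode

# Build a fake retrieval hit; node_id is auto-generated by llama-index.
node = TextNode(text="LlamaIndex is a data framework.", metadata={"file_name": "docs.md"})
hit = NodeWithScore(node=node, score=0.87)

source = _SourceNodes.from_source_node(hit)
print(source.dict())
# {'id': '<auto-generated uuid>', 'metadata': {'file_name': 'docs.md'}, 'score': 0.87}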
@@ -58,13 +80,25 @@ async def chat(
     response = await chat_engine.astream_chat(last_message_content, messages)

-    async def event_generator():
+    async def event_generator(request: Request, response: StreamingAgentChatResponse):
+        # Yield the text response
         async for token in response.async_response_gen():
+            # If client closes connection, stop sending events
             if await request.is_disconnected():
                 break
-            yield token
+            yield VercelStreamResponse.convert_text(token)
+
+        # Yield the source nodes
+        yield VercelStreamResponse.convert_data(
+            {
+                "nodes": [
+                    _SourceNodes.from_source_node(node).dict()
+                    for node in response.source_nodes
+                ]
+            }
+        )

-    return StreamingResponse(event_generator(), media_type="text/plain")
+    return VercelStreamResponse(content=event_generator(request, response))


 # non-streaming endpoint - delete if not needed
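To make the streaming behavior concrete, here is a minimal client-side sketch (not part of the commit) that consumes the endpoint with httpx, assuming the router is mounted at /api/chat on a local dev server at port 8000:

import asyncio
import httpx

async def main():
    payload = {"messages": [{"role": "user", "content": "What is LlamaIndex?"}]}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", "http://localhost:8000/api/chat", json=payload) as resp:
            async for line in resp.aiter_lines():
                if line.startswith("0:"):
                    print("text part:", line[2:])   # streamed token
                elif line.startswith("2:"):
                    print("data part:", line[2:])   # JSON array carrying the source nodes

asyncio.run(main())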
@@ -77,5 +111,6 @@ async def chat_request(
     response = await chat_engine.achat(last_message_content, messages)

     return _Result(
-        result=_Message(role=MessageRole.ASSISTANT, content=response.response)
+        result=_Message(role=MessageRole.ASSISTANT, content=response.response),
+        nodes=_SourceNodes.from_source_nodes(response.source_nodes),
     )
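With the new nodes field, the non-streaming endpoint's JSON body now looks roughly like this (illustrative values):

{
  "result": {"role": "assistant", "content": "LlamaIndex is a data framework ..."},
  "nodes": [
    {"id": "a1b2c3", "metadata": {"file_name": "docs.md"}, "score": 0.87}
  ]
}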
New file app/api/routers/vercel_response.py:

import json
from typing import Any

from fastapi.responses import StreamingResponse


class VercelStreamResponse(StreamingResponse):
    """
    Class to convert the response from the chat engine to the streaming format expected by Vercel/AI
    """

    TEXT_PREFIX = "0:"
    DATA_PREFIX = "2:"
    VERCEL_HEADERS = {
        "X-Experimental-Stream-Data": "true",
        "Content-Type": "text/plain; charset=utf-8",
        "Access-Control-Expose-Headers": "X-Experimental-Stream-Data",
    }

    @classmethod
    def convert_text(cls, token: str):
        return f'{cls.TEXT_PREFIX}"{token}"\n'

    @classmethod
    def convert_data(cls, data: dict):
        data_str = json.dumps(data)
        return f"{cls.DATA_PREFIX}[{data_str}]\n"

    def __init__(self, content: Any, **kwargs):
        super().__init__(
            content=content,
            headers=self.VERCEL_HEADERS,
            **kwargs,
        )
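A quick illustration (not in the commit) of the framing the two helpers produce:

print(VercelStreamResponse.convert_text("Hello"), end="")
# 0:"Hello"
print(VercelStreamResponse.convert_data({"nodes": []}), end="")
# 2:[{"nodes": []}]

Note that convert_text interpolates the token between literal quotes rather than JSON-encoding it, so a token containing a double quote or newline would break the 0: framing; convert_data avoids this by going through json.dumps.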