From c06d4af7b6d4c6f639bcfe797f150ba93f6cb03d Mon Sep 17 00:00:00 2001
From: "Huu Le (Lee)" <39040748+leehuwuj@users.noreply.github.com>
Date: Mon, 1 Apr 2024 15:59:14 +0700
Subject: [PATCH] feat: Update FastAPI endpoint to support nodeSources (#30)

---
 .changeset/eleven-lemons-look.md              |  5 ++
 questions.ts                                  |  2 +-
 .../streaming/fastapi/app/api/routers/chat.py | 51 ++++++++++++++++---
 .../app/api/routers/vercel_response.py        | 33 ++++++++++++
 4 files changed, 82 insertions(+), 9 deletions(-)
 create mode 100644 .changeset/eleven-lemons-look.md
 create mode 100644 templates/types/streaming/fastapi/app/api/routers/vercel_response.py

diff --git a/.changeset/eleven-lemons-look.md b/.changeset/eleven-lemons-look.md
new file mode 100644
index 00000000..84d3879a
--- /dev/null
+++ b/.changeset/eleven-lemons-look.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Add source nodes to the response and support the Vercel streaming format
diff --git a/questions.ts b/questions.ts
index 231be855..d81c9726 100644
--- a/questions.ts
+++ b/questions.ts
@@ -505,7 +505,7 @@ export const askQuestions = async (
 
   if (program.framework === "nextjs" || program.frontend) {
     if (!program.ui) {
-      program.ui = getPrefOrDefault("ui");
+      program.ui = defaults.ui;
     }
   }
 
diff --git a/templates/types/streaming/fastapi/app/api/routers/chat.py b/templates/types/streaming/fastapi/app/api/routers/chat.py
index 2ef7ff1f..80fa7070 100644
--- a/templates/types/streaming/fastapi/app/api/routers/chat.py
+++ b/templates/types/streaming/fastapi/app/api/routers/chat.py
@@ -1,11 +1,14 @@
-from typing import List
 from pydantic import BaseModel
-from fastapi.responses import StreamingResponse
+from typing import List, Any, Optional, Dict, Tuple
 from fastapi import APIRouter, Depends, HTTPException, Request, status
-from llama_index.core.chat_engine.types import BaseChatEngine
+from llama_index.core.chat_engine.types import (
+    BaseChatEngine,
+    StreamingAgentChatResponse,
+)
+from llama_index.core.schema import NodeWithScore
 from llama_index.core.llms import ChatMessage, MessageRole
 from app.engine import get_chat_engine
-from typing import List, Tuple
+from app.api.routers.vercel_response import VercelStreamResponse
 
 chat_router = r = APIRouter()
 
@@ -19,8 +22,27 @@ class _ChatData(BaseModel):
     messages: List[_Message]
 
 
+class _SourceNodes(BaseModel):
+    id: str
+    metadata: Dict[str, Any]
+    score: Optional[float]
+
+    @classmethod
+    def from_source_node(cls, source_node: NodeWithScore):
+        return cls(
+            id=source_node.node.node_id,
+            metadata=source_node.node.metadata,
+            score=source_node.score,
+        )
+
+    @classmethod
+    def from_source_nodes(cls, source_nodes: List[NodeWithScore]):
+        return [cls.from_source_node(node) for node in source_nodes]
+
+
 class _Result(BaseModel):
     result: _Message
+    nodes: List[_SourceNodes]
 
 
 async def parse_chat_data(data: _ChatData) -> Tuple[str, List[ChatMessage]]:
@@ -58,13 +80,25 @@ async def chat(
 
     response = await chat_engine.astream_chat(last_message_content, messages)
 
-    async def event_generator():
+    async def event_generator(request: Request, response: StreamingAgentChatResponse):
+        # Yield the text response
         async for token in response.async_response_gen():
+            # If client closes connection, stop sending events
             if await request.is_disconnected():
                 break
-            yield token
+            yield VercelStreamResponse.convert_text(token)
+
+        # Yield the source nodes
+        yield VercelStreamResponse.convert_data(
+            {
+                "nodes": [
+                    _SourceNodes.from_source_node(node).dict()
+                    for node in response.source_nodes
+                ]
+            }
+        )
 
-    return StreamingResponse(event_generator(), media_type="text/plain")
+    return VercelStreamResponse(content=event_generator(request, response))
 
 
 # non-streaming endpoint - delete if not needed
@@ -77,5 +111,6 @@ async def chat_request(
 
     response = await chat_engine.achat(last_message_content, messages)
     return _Result(
-        result=_Message(role=MessageRole.ASSISTANT, content=response.response)
+        result=_Message(role=MessageRole.ASSISTANT, content=response.response),
+        nodes=_SourceNodes.from_source_nodes(response.source_nodes),
     )
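
Note on the resulting wire format: with this patch the streaming endpoint no longer returns plain text. Each token is emitted as a Vercel AI SDK text part ('0:"..."') and the source nodes are appended once as a data part ('2:[{"nodes": [...]}]'). A minimal consumer sketch, assuming a local server at http://localhost:8000/api/chat and an illustrative payload (neither the URL nor the payload shape is defined by this patch):

    import json
    import httpx  # any streaming HTTP client works; httpx is only an example

    payload = {"messages": [{"role": "user", "content": "What is in the documents?"}]}

    with httpx.stream("POST", "http://localhost:8000/api/chat", json=payload, timeout=None) as resp:
        for line in resp.iter_lines():
            if line.startswith("0:"):
                # Text part: a JSON-encoded token, print it as it arrives
                print(json.loads(line[2:]), end="", flush=True)
            elif line.startswith("2:"):
                # Data part: a JSON array; this endpoint sends [{"nodes": [...]}] once at the end
                nodes = json.loads(line[2:])[0]["nodes"]
                print("\nsources:", [n["id"] for n in nodes])
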
diff --git a/templates/types/streaming/fastapi/app/api/routers/vercel_response.py b/templates/types/streaming/fastapi/app/api/routers/vercel_response.py
new file mode 100644
index 00000000..37392cc9
--- /dev/null
+++ b/templates/types/streaming/fastapi/app/api/routers/vercel_response.py
@@ -0,0 +1,33 @@
+import json
+from typing import Any
+from fastapi.responses import StreamingResponse
+
+
+class VercelStreamResponse(StreamingResponse):
+    """
+    Converts the chat engine's streaming response into the stream data format expected by the Vercel AI SDK
+    """
+
+    TEXT_PREFIX = "0:"
+    DATA_PREFIX = "2:"
+    VERCEL_HEADERS = {
+        "X-Experimental-Stream-Data": "true",
+        "Content-Type": "text/plain; charset=utf-8",
+        "Access-Control-Expose-Headers": "X-Experimental-Stream-Data",
+    }
+
+    @classmethod
+    def convert_text(cls, token: str):
+        return f"{cls.TEXT_PREFIX}{json.dumps(token)}\n"  # json.dumps escapes quotes/newlines so each stream line stays valid
+
+    @classmethod
+    def convert_data(cls, data: dict):
+        data_str = json.dumps(data)
+        return f"{cls.DATA_PREFIX}[{data_str}]\n"
+
+    def __init__(self, content: Any, **kwargs):
+        super().__init__(
+            content=content,
+            headers=self.VERCEL_HEADERS,
+            **kwargs,
+        )
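
For reference, the two converters only prepend the protocol prefix and terminate each part with a newline; roughly what they emit (illustrative values):

    >>> VercelStreamResponse.convert_text("Hello")
    '0:"Hello"\n'
    >>> VercelStreamResponse.convert_data({"nodes": [{"id": "abc", "score": 0.87}]})
    '2:[{"nodes": [{"id": "abc", "score": 0.87}]}]\n'

The 0: prefix marks a text part and the 2: prefix a data part in the Vercel AI SDK's experimental stream data protocol, which the X-Experimental-Stream-Data header signals to the client.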
-- 
GitLab