From 8105c5cf06465b5eaa25abbe90bd20b2e41f0c41 Mon Sep 17 00:00:00 2001
From: Huu Le <39040748+leehuwuj@users.noreply.github.com>
Date: Mon, 9 Sep 2024 14:39:36 +0700
Subject: [PATCH] feat: Make suggest next questions configurable (#275)
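
The next-question suggestion prompt is now read from the NEXT_QUESTION_PROMPT
environment variable in both the streaming and multiagent templates. If the
variable is removed or left empty, the feature is disabled and no suggested
questions are generated.

A sketch of a generated .env entry (a shortened, illustrative prompt, not the
shipped default from helpers/env-variables.ts; {conversation} is replaced with
the recent chat history, and the questions should be wrapped in triple backticks
so the services can parse them from the response):

NEXT_QUESTION_PROMPT="Given the following conversation:
{conversation}
Suggest 3 questions the user might ask next.
Wrap the questions in triple backticks, one question per line."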

---------
Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de>
---
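A minimal sketch of how the new Python suggestion service is consumed by the
streaming routes (assumes a generated FastAPI project where app.api.routers.models
and app.api.services.suggestion exist; stream_suggested_questions is an
illustrative helper name, not part of this patch):

    from typing import AsyncGenerator, List

    from app.api.routers.models import Message
    from app.api.services.suggestion import NextQuestionSuggestion

    async def stream_suggested_questions(
        chat_history: List[Message], final_response: str
    ) -> AsyncGenerator[dict, None]:
        # suggest_next_questions returns None when NEXT_QUESTION_PROMPT is unset
        # or an error occurs, so the event is simply skipped in those cases.
        questions = await NextQuestionSuggestion.suggest_next_questions(
            chat_history, final_response
        )
        if questions:
            # The streaming routes pass this dict to VercelStreamResponse.convert_data,
            # which emits it as a "suggested_questions" data event for the frontend.
            yield {"type": "suggested_questions", "data": questions}
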
 .changeset/cyan-buttons-clean.md              |   5 +
 helpers/env-variables.ts                      |  49 ++++----
 helpers/python.ts                             |   7 ++
 .../typescript/streaming/suggestion.ts        |  28 ++---
 .../services/python}/file.py                  |   0
 .../components/services/python/suggestion.py  |  78 ++++++++++++
 .../app/api/routers/vercel_response.py        |  55 +++++---
 .../fastapi/app/api/services/suggestion.py    |  60 ---------
 .../app/api/routers/vercel_response.py        |  39 +++---
 .../fastapi/app/api/services/file.py          | 119 ------------------
 .../fastapi/app/api/services/suggestion.py    |  60 ---------
 .../types/streaming/fastapi/pyproject.toml    |   2 +-
 12 files changed, 182 insertions(+), 320 deletions(-)
 create mode 100644 .changeset/cyan-buttons-clean.md
 rename templates/{types/multiagent/fastapi/app/api/services => components/services/python}/file.py (100%)
 create mode 100644 templates/components/services/python/suggestion.py
 delete mode 100644 templates/types/multiagent/fastapi/app/api/services/suggestion.py
 delete mode 100644 templates/types/streaming/fastapi/app/api/services/file.py
 delete mode 100644 templates/types/streaming/fastapi/app/api/services/suggestion.py

diff --git a/.changeset/cyan-buttons-clean.md b/.changeset/cyan-buttons-clean.md
new file mode 100644
index 00000000..0ad06218
--- /dev/null
+++ b/.changeset/cyan-buttons-clean.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Add an env variable to configure the next question suggestions feature
diff --git a/helpers/env-variables.ts b/helpers/env-variables.ts
index 783eb43e..11beb0e8 100644
--- a/helpers/env-variables.ts
+++ b/helpers/env-variables.ts
@@ -487,33 +487,30 @@ It\\'s cute animal.
 };
 
 const getTemplateEnvs = (template?: TemplateType): EnvVar[] => {
-  if (template === "multiagent") {
-    return [
-      {
-        name: "MESSAGE_QUEUE_PORT",
-      },
-      {
-        name: "CONTROL_PLANE_PORT",
-      },
-      {
-        name: "HUMAN_CONSUMER_PORT",
-      },
-      {
-        name: "AGENT_QUERY_ENGINE_PORT",
-        value: "8003",
-      },
-      {
-        name: "AGENT_QUERY_ENGINE_DESCRIPTION",
-        value: "Query information from the provided data",
-      },
-      {
-        name: "AGENT_DUMMY_PORT",
-        value: "8004",
-      },
-    ];
-  } else {
-    return [];
+  const nextQuestionEnvs: EnvVar[] = [
+    {
+      name: "NEXT_QUESTION_PROMPT",
+      description: `Customize the prompt used to generate follow-up question suggestions based on the conversation history.
+Remove this variable or leave it empty to disable the next question suggestions feature.`,
+      value: `"You're a helpful assistant! Your task is to suggest the next question that the user might ask.
+Here is the conversation history
+---------------------
+{conversation}
+---------------------
+Given the conversation history, please give me 3 questions that the user might ask next!
+Your answer should be wrapped in triple backticks using the following format:
+\`\`\`
+<question 1>
+<question 2>
+<question 3>
+\`\`\`"`,
+    },
+  ];
+
+  if (template === "multiagent" || template === "streaming") {
+    return nextQuestionEnvs;
   }
+  return [];
 };
 
 const getObservabilityEnvs = (
diff --git a/helpers/python.ts b/helpers/python.ts
index aee63dad..b9f8b200 100644
--- a/helpers/python.ts
+++ b/helpers/python.ts
@@ -395,6 +395,13 @@ export const installPythonTemplate = async ({
     cwd: path.join(compPath, "settings", "python"),
   });
 
+  // Copy services
+  if (template === "streaming" || template === "multiagent") {
+    await copy("**", path.join(root, "app", "api", "services"), {
+      cwd: path.join(compPath, "services", "python"),
+    });
+  }
+
   if (template === "streaming") {
     // For the streaming template only:
     // Select and copy engine code based on data sources and tools
diff --git a/templates/components/llamaindex/typescript/streaming/suggestion.ts b/templates/components/llamaindex/typescript/streaming/suggestion.ts
index 0dacaead..d8949cc3 100644
--- a/templates/components/llamaindex/typescript/streaming/suggestion.ts
+++ b/templates/components/llamaindex/typescript/streaming/suggestion.ts
@@ -1,32 +1,20 @@
 import { ChatMessage, Settings } from "llamaindex";
 
-const NEXT_QUESTION_PROMPT_TEMPLATE = `You're a helpful assistant! Your task is to suggest the next question that user might ask. 
-Here is the conversation history
----------------------
-$conversation
----------------------
-Given the conversation history, please give me $number_of_questions questions that you might ask next!
-Your answer should be wrapped in three sticks which follows the following format:
-\`\`\`
-<question 1>
-<question 2>\`\`\`
-`;
-const N_QUESTIONS_TO_GENERATE = 3;
-
-export async function generateNextQuestions(
-  conversation: ChatMessage[],
-  numberOfQuestions: number = N_QUESTIONS_TO_GENERATE,
-) {
+export async function generateNextQuestions(conversation: ChatMessage[]) {
   const llm = Settings.llm;
+  const NEXT_QUESTION_PROMPT = process.env.NEXT_QUESTION_PROMPT;
+  if (!NEXT_QUESTION_PROMPT) {
+    return [];
+  }
 
   // Format conversation
   const conversationText = conversation
     .map((message) => `${message.role}: ${message.content}`)
     .join("\n");
-  const message = NEXT_QUESTION_PROMPT_TEMPLATE.replace(
-    "$conversation",
+  const message = NEXT_QUESTION_PROMPT.replace(
+    "{conversation}",
     conversationText,
-  ).replace("$number_of_questions", numberOfQuestions.toString());
+  );
 
   try {
     const response = await llm.complete({ prompt: message });
diff --git a/templates/types/multiagent/fastapi/app/api/services/file.py b/templates/components/services/python/file.py
similarity index 100%
rename from templates/types/multiagent/fastapi/app/api/services/file.py
rename to templates/components/services/python/file.py
diff --git a/templates/components/services/python/suggestion.py b/templates/components/services/python/suggestion.py
new file mode 100644
index 00000000..7959088e
--- /dev/null
+++ b/templates/components/services/python/suggestion.py
@@ -0,0 +1,78 @@
+import logging
+import os
+import re
+from typing import List, Optional
+
+from app.api.routers.models import Message
+from llama_index.core.prompts import PromptTemplate
+from llama_index.core.settings import Settings
+
+logger = logging.getLogger("uvicorn")
+
+
+class NextQuestionSuggestion:
+    """
+    Suggest the next questions that the user might ask based on the conversation history.
+    Disable this feature by removing the NEXT_QUESTION_PROMPT environment variable or leaving it empty.
+    """
+
+    @classmethod
+    def get_configured_prompt(cls) -> Optional[PromptTemplate]:
+        prompt = os.getenv("NEXT_QUESTION_PROMPT", None)
+        if not prompt:
+            return None
+        return PromptTemplate(prompt)
+
+    @classmethod
+    async def suggest_next_questions_all_messages(
+        cls,
+        messages: List[Message],
+    ) -> Optional[List[str]]:
+        """
+        Suggest the next questions that the user might ask based on the conversation history.
+        Return None if the suggestion feature is disabled or an error occurs.
+        """
+        prompt_template = cls.get_configured_prompt()
+        if not prompt_template:
+            return None
+
+        try:
+            # Reduce the cost by only using the last two messages
+            last_user_message = None
+            last_assistant_message = None
+            for message in reversed(messages):
+                if message.role == "user":
+                    last_user_message = f"User: {message.content}"
+                elif message.role == "assistant":
+                    last_assistant_message = f"Assistant: {message.content}"
+                if last_user_message and last_assistant_message:
+                    break
+            conversation: str = f"{last_user_message}\n{last_assistant_message}"
+
+            # Call the LLM and parse questions from the output
+            prompt = prompt_template.format(conversation=conversation)
+            output = await Settings.llm.acomplete(prompt)
+            questions = cls._extract_questions(output.text)
+
+            return questions
+        except Exception as e:
+            logger.error(f"Error when generating next question: {e}")
+            return None
+
+    @classmethod
+    def _extract_questions(cls, text: str) -> List[str]:
+        content_match = re.search(r"```(.*?)```", text, re.DOTALL)
+        content = content_match.group(1) if content_match else ""
+        return content.strip().split("\n")
+
+    @classmethod
+    async def suggest_next_questions(
+        cls,
+        chat_history: List[Message],
+        response: str,
+    ) -> List[str]:
+        """
+        Suggest the next questions that the user might ask based on the chat history and the last response.
+        """
+        messages = chat_history + [Message(role="assistant", content=response)]
+        return await cls.suggest_next_questions_all_messages(messages)
diff --git a/templates/types/multiagent/fastapi/app/api/routers/vercel_response.py b/templates/types/multiagent/fastapi/app/api/routers/vercel_response.py
index ec03fb6c..29bcf852 100644
--- a/templates/types/multiagent/fastapi/app/api/routers/vercel_response.py
+++ b/templates/types/multiagent/fastapi/app/api/routers/vercel_response.py
@@ -1,15 +1,15 @@
-from asyncio import Task
 import json
 import logging
-from typing import AsyncGenerator
+from asyncio import Task
+from typing import AsyncGenerator, List
 
 from aiostream import stream
+from app.agents.single import AgentRunEvent, AgentRunResult
+from app.api.routers.models import ChatData, Message
+from app.api.services.suggestion import NextQuestionSuggestion
 from fastapi import Request
 from fastapi.responses import StreamingResponse
 
-from app.api.routers.models import ChatData
-from app.agents.single import AgentRunEvent, AgentRunResult
-
 logger = logging.getLogger("uvicorn")
 
 
@@ -57,26 +57,35 @@ class VercelStreamResponse(StreamingResponse):
         # Yield the text response
         async def _chat_response_generator():
             result = await task
+            final_response = ""
 
             if isinstance(result, AgentRunResult):
                 for token in result.response.message.content:
-                    yield VercelStreamResponse.convert_text(token)
+                    final_response += token
+                    yield cls.convert_text(token)
 
             if isinstance(result, AsyncGenerator):
                 async for token in result:
-                    yield VercelStreamResponse.convert_text(token.delta)
+                    final_response += token.delta
+                    yield cls.convert_text(token.delta)
+
+            # Generate next questions if the next-question prompt is configured
+            question_data = await cls._generate_next_questions(
+                chat_data.messages, final_response
+            )
+            if question_data:
+                yield cls.convert_data(question_data)
 
-            # TODO: stream NextQuestionSuggestion
             # TODO: stream sources
 
         # Yield the events from the event handler
         async def _event_generator():
             async for event in events():
-                event_response = _event_to_response(event)
+                event_response = cls._event_to_response(event)
                 if verbose:
                     logger.debug(event_response)
                 if event_response is not None:
-                    yield VercelStreamResponse.convert_data(event_response)
+                    yield cls.convert_data(event_response)
 
         combine = stream.merge(_chat_response_generator(), _event_generator())
 
@@ -85,16 +94,28 @@ class VercelStreamResponse(StreamingResponse):
             if not is_stream_started:
                 is_stream_started = True
                 # Stream a blank message to start the stream
-                yield VercelStreamResponse.convert_text("")
+                yield cls.convert_text("")
 
             async for output in streamer:
                 yield output
                 if await request.is_disconnected():
                     break
 
-
-def _event_to_response(event: AgentRunEvent) -> dict:
-    return {
-        "type": "agent",
-        "data": {"agent": event.name, "text": event.msg},
-    }
+    @staticmethod
+    def _event_to_response(event: AgentRunEvent) -> dict:
+        return {
+            "type": "agent",
+            "data": {"agent": event.name, "text": event.msg},
+        }
+
+    @staticmethod
+    async def _generate_next_questions(chat_history: List[Message], response: str):
+        questions = await NextQuestionSuggestion.suggest_next_questions(
+            chat_history, response
+        )
+        if questions:
+            return {
+                "type": "suggested_questions",
+                "data": questions,
+            }
+        return None
diff --git a/templates/types/multiagent/fastapi/app/api/services/suggestion.py b/templates/types/multiagent/fastapi/app/api/services/suggestion.py
deleted file mode 100644
index f881962e..00000000
--- a/templates/types/multiagent/fastapi/app/api/services/suggestion.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import logging
-from typing import List
-
-from app.api.routers.models import Message
-from llama_index.core.prompts import PromptTemplate
-from llama_index.core.settings import Settings
-from pydantic import BaseModel
-
-NEXT_QUESTIONS_SUGGESTION_PROMPT = PromptTemplate(
-    "You're a helpful assistant! Your task is to suggest the next question that user might ask. "
-    "\nHere is the conversation history"
-    "\n---------------------\n{conversation}\n---------------------"
-    "Given the conversation history, please give me {number_of_questions} questions that you might ask next!"
-)
-N_QUESTION_TO_GENERATE = 3
-
-
-logger = logging.getLogger("uvicorn")
-
-
-class NextQuestions(BaseModel):
-    """A list of questions that user might ask next"""
-
-    questions: List[str]
-
-
-class NextQuestionSuggestion:
-    @staticmethod
-    async def suggest_next_questions(
-        messages: List[Message],
-        number_of_questions: int = N_QUESTION_TO_GENERATE,
-    ) -> List[str]:
-        """
-        Suggest the next questions that user might ask based on the conversation history
-        Return as empty list if there is an error
-        """
-        try:
-            # Reduce the cost by only using the last two messages
-            last_user_message = None
-            last_assistant_message = None
-            for message in reversed(messages):
-                if message.role == "user":
-                    last_user_message = f"User: {message.content}"
-                elif message.role == "assistant":
-                    last_assistant_message = f"Assistant: {message.content}"
-                if last_user_message and last_assistant_message:
-                    break
-            conversation: str = f"{last_user_message}\n{last_assistant_message}"
-
-            output: NextQuestions = await Settings.llm.astructured_predict(
-                NextQuestions,
-                prompt=NEXT_QUESTIONS_SUGGESTION_PROMPT,
-                conversation=conversation,
-                number_of_questions=number_of_questions,
-            )
-
-            return output.questions
-        except Exception as e:
-            logger.error(f"Error when generating next question: {e}")
-            return []
diff --git a/templates/types/streaming/fastapi/app/api/routers/vercel_response.py b/templates/types/streaming/fastapi/app/api/routers/vercel_response.py
index 1e32c265..924c60ce 100644
--- a/templates/types/streaming/fastapi/app/api/routers/vercel_response.py
+++ b/templates/types/streaming/fastapi/app/api/routers/vercel_response.py
@@ -1,4 +1,5 @@
 import json
+from typing import List
 
 from aiostream import stream
 from fastapi import Request
@@ -54,22 +55,14 @@ class VercelStreamResponse(StreamingResponse):
             final_response = ""
             async for token in response.async_response_gen():
                 final_response += token
-                yield VercelStreamResponse.convert_text(token)
-
-            # Generate questions that user might interested to
-            conversation = chat_data.messages + [
-                Message(role="assistant", content=final_response)
-            ]
-            questions = await NextQuestionSuggestion.suggest_next_questions(
-                conversation
+                yield cls.convert_text(token)
+
+            # Generate next questions if the next-question prompt is configured
+            question_data = await cls._generate_next_questions(
+                chat_data.messages, final_response
             )
-            if len(questions) > 0:
-                yield VercelStreamResponse.convert_data(
-                    {
-                        "type": "suggested_questions",
-                        "data": questions,
-                    }
-                )
+            if question_data:
+                yield cls.convert_data(question_data)
 
             # the text_generator is the leading stream, once it's finished, also finish the event stream
             event_handler.is_done = True
@@ -92,7 +85,7 @@ class VercelStreamResponse(StreamingResponse):
             async for event in event_handler.async_event_gen():
                 event_response = event.to_response()
                 if event_response is not None:
-                    yield VercelStreamResponse.convert_data(event_response)
+                    yield cls.convert_data(event_response)
 
         combine = stream.merge(_chat_response_generator(), _event_generator())
         is_stream_started = False
@@ -101,9 +94,21 @@ class VercelStreamResponse(StreamingResponse):
                 if not is_stream_started:
                     is_stream_started = True
                     # Stream a blank message to start the stream
-                    yield VercelStreamResponse.convert_text("")
+                    yield cls.convert_text("")
 
                 yield output
 
                 if await request.is_disconnected():
                     break
+
+    @staticmethod
+    async def _generate_next_questions(chat_history: List[Message], response: str):
+        questions = await NextQuestionSuggestion.suggest_next_questions(
+            chat_history, response
+        )
+        if questions:
+            return {
+                "type": "suggested_questions",
+                "data": questions,
+            }
+        return None
diff --git a/templates/types/streaming/fastapi/app/api/services/file.py b/templates/types/streaming/fastapi/app/api/services/file.py
deleted file mode 100644
index 9441db6e..00000000
--- a/templates/types/streaming/fastapi/app/api/services/file.py
+++ /dev/null
@@ -1,119 +0,0 @@
-import base64
-import mimetypes
-import os
-from io import BytesIO
-from pathlib import Path
-from typing import Any, List, Tuple
-
-from app.engine.index import IndexConfig, get_index
-from llama_index.core import VectorStoreIndex
-from llama_index.core.ingestion import IngestionPipeline
-from llama_index.core.readers.file.base import (
-    _try_loading_included_file_formats as get_file_loaders_map,
-)
-from llama_index.core.schema import Document
-from llama_index.indices.managed.llama_cloud.base import LlamaCloudIndex
-from llama_index.readers.file import FlatReader
-
-
-def get_llamaparse_parser():
-    from app.engine.loaders import load_configs
-    from app.engine.loaders.file import FileLoaderConfig, llama_parse_parser
-
-    config = load_configs()
-    file_loader_config = FileLoaderConfig(**config["file"])
-    if file_loader_config.use_llama_parse:
-        return llama_parse_parser()
-    else:
-        return None
-
-
-def default_file_loaders_map():
-    default_loaders = get_file_loaders_map()
-    default_loaders[".txt"] = FlatReader
-    return default_loaders
-
-
-class PrivateFileService:
-    PRIVATE_STORE_PATH = "output/uploaded"
-
-    @staticmethod
-    def preprocess_base64_file(base64_content: str) -> Tuple[bytes, str | None]:
-        header, data = base64_content.split(",", 1)
-        mime_type = header.split(";")[0].split(":", 1)[1]
-        extension = mimetypes.guess_extension(mime_type)
-        # File data as bytes
-        return base64.b64decode(data), extension
-
-    @staticmethod
-    def store_and_parse_file(file_name, file_data, extension) -> List[Document]:
-        # Store file to the private directory
-        os.makedirs(PrivateFileService.PRIVATE_STORE_PATH, exist_ok=True)
-        file_path = Path(os.path.join(PrivateFileService.PRIVATE_STORE_PATH, file_name))
-
-        # write file
-        with open(file_path, "wb") as f:
-            f.write(file_data)
-
-        # Load file to documents
-        # If LlamaParse is enabled, use it to parse the file
-        # Otherwise, use the default file loaders
-        reader = get_llamaparse_parser()
-        if reader is None:
-            reader_cls = default_file_loaders_map().get(extension)
-            if reader_cls is None:
-                raise ValueError(f"File extension {extension} is not supported")
-            reader = reader_cls()
-        documents = reader.load_data(file_path)
-        # Add custom metadata
-        for doc in documents:
-            doc.metadata["file_name"] = file_name
-            doc.metadata["private"] = "true"
-        return documents
-
-    @staticmethod
-    def process_file(file_name: str, base64_content: str, params: Any) -> List[str]:
-        file_data, extension = PrivateFileService.preprocess_base64_file(base64_content)
-
-        # Add the nodes to the index and persist it
-        index_config = IndexConfig(**params)
-        current_index = get_index(index_config)
-
-        # Insert the documents into the index
-        if isinstance(current_index, LlamaCloudIndex):
-            from app.engine.service import LLamaCloudFileService
-
-            project_id = current_index._get_project_id()
-            pipeline_id = current_index._get_pipeline_id()
-            # LlamaCloudIndex is a managed index so we can directly use the files
-            upload_file = (file_name, BytesIO(file_data))
-            return [
-                LLamaCloudFileService.add_file_to_pipeline(
-                    project_id,
-                    pipeline_id,
-                    upload_file,
-                    custom_metadata={
-                        # Set private=true to mark the document as private user docs (required for filtering)
-                        "private": "true",
-                    },
-                )
-            ]
-        else:
-            # First process documents into nodes
-            documents = PrivateFileService.store_and_parse_file(
-                file_name, file_data, extension
-            )
-            pipeline = IngestionPipeline()
-            nodes = pipeline.run(documents=documents)
-
-            # Add the nodes to the index and persist it
-            if current_index is None:
-                current_index = VectorStoreIndex(nodes=nodes)
-            else:
-                current_index.insert_nodes(nodes=nodes)
-            current_index.storage_context.persist(
-                persist_dir=os.environ.get("STORAGE_DIR", "storage")
-            )
-
-            # Return the document ids
-            return [doc.doc_id for doc in documents]
diff --git a/templates/types/streaming/fastapi/app/api/services/suggestion.py b/templates/types/streaming/fastapi/app/api/services/suggestion.py
deleted file mode 100644
index f881962e..00000000
--- a/templates/types/streaming/fastapi/app/api/services/suggestion.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import logging
-from typing import List
-
-from app.api.routers.models import Message
-from llama_index.core.prompts import PromptTemplate
-from llama_index.core.settings import Settings
-from pydantic import BaseModel
-
-NEXT_QUESTIONS_SUGGESTION_PROMPT = PromptTemplate(
-    "You're a helpful assistant! Your task is to suggest the next question that user might ask. "
-    "\nHere is the conversation history"
-    "\n---------------------\n{conversation}\n---------------------"
-    "Given the conversation history, please give me {number_of_questions} questions that you might ask next!"
-)
-N_QUESTION_TO_GENERATE = 3
-
-
-logger = logging.getLogger("uvicorn")
-
-
-class NextQuestions(BaseModel):
-    """A list of questions that user might ask next"""
-
-    questions: List[str]
-
-
-class NextQuestionSuggestion:
-    @staticmethod
-    async def suggest_next_questions(
-        messages: List[Message],
-        number_of_questions: int = N_QUESTION_TO_GENERATE,
-    ) -> List[str]:
-        """
-        Suggest the next questions that user might ask based on the conversation history
-        Return as empty list if there is an error
-        """
-        try:
-            # Reduce the cost by only using the last two messages
-            last_user_message = None
-            last_assistant_message = None
-            for message in reversed(messages):
-                if message.role == "user":
-                    last_user_message = f"User: {message.content}"
-                elif message.role == "assistant":
-                    last_assistant_message = f"Assistant: {message.content}"
-                if last_user_message and last_assistant_message:
-                    break
-            conversation: str = f"{last_user_message}\n{last_assistant_message}"
-
-            output: NextQuestions = await Settings.llm.astructured_predict(
-                NextQuestions,
-                prompt=NEXT_QUESTIONS_SUGGESTION_PROMPT,
-                conversation=conversation,
-                number_of_questions=number_of_questions,
-            )
-
-            return output.questions
-        except Exception as e:
-            logger.error(f"Error when generating next question: {e}")
-            return []
diff --git a/templates/types/streaming/fastapi/pyproject.toml b/templates/types/streaming/fastapi/pyproject.toml
index b66ca346..69d87ba6 100644
--- a/templates/types/streaming/fastapi/pyproject.toml
+++ b/templates/types/streaming/fastapi/pyproject.toml
@@ -14,8 +14,8 @@ fastapi = "^0.109.1"
 uvicorn = { extras = ["standard"], version = "^0.23.2" }
 python-dotenv = "^1.0.0"
 aiostream = "^0.5.2"
-llama-index = "0.11.6"
 cachetools = "^5.3.3"
+llama-index = "0.11.6"
 
 [build-system]
 requires = ["poetry-core"]
-- 
GitLab