Skip to content
Snippets Groups Projects
Unverified Commit 4e06714c authored by Huu Le, committed by GitHub
Browse files

Fix: deep research use case (#493)

parent 18c8d254
No related branches found
No related tags found
No related merge requests found
---
"create-llama": patch
---
Fix the error: Unable to view file sources due to CORS.
......@@ -16,7 +16,10 @@ class AnalysisDecision(BaseModel):
description="Whether to continue research, write a report, or cancel the research after several retries"
)
research_questions: Optional[List[str]] = Field(
description="Questions to research if continuing research. Maximum 3 questions. Set to null or empty if writing a report.",
description="""
If the decision is to research, provide a list of questions to research that related to the user request.
Maximum 3 questions. Set to null or empty if writing a report or cancel the research.
""",
default_factory=list,
)
cancel_reason: Optional[str] = Field(
......@@ -29,23 +32,23 @@ async def plan_research(
memory: SimpleComposableMemory,
context_nodes: List[Node],
user_request: str,
total_questions: int,
) -> AnalysisDecision:
analyze_prompt = PromptTemplate(
"""
analyze_prompt = """
You are a professor who is guiding a researcher to research a specific request/problem.
Your task is to decide on a research plan for the researcher.
The possible actions are:
+ Provide a list of questions for the researcher to investigate, with the purpose of clarifying the request.
+ Write a report if the researcher has already gathered enough research on the topic and can resolve the initial request.
+ Cancel the research if most of the answers from researchers indicate there is insufficient information to research the request. Do not attempt more than 3 research iterations or too many questions.
The workflow should be:
+ Always begin by providing some initial questions for the researcher to investigate.
+ Analyze the provided answers against the initial topic/request. If the answers are insufficient to resolve the initial request, provide additional questions for the researcher to investigate.
+ If the answers are sufficient to resolve the initial request, instruct the researcher to write a report.
<User request>
{user_request}
</User request>
Here are the context:
<Collected information>
{context_str}
</Collected information>
......@@ -53,8 +56,29 @@ async def plan_research(
<Conversation context>
{conversation_context}
</Conversation context>
{enhanced_prompt}
Now, provide your decision in the required format for this user request:
<User request>
{user_request}
</User request>
"""
)
# Manually craft the prompt to avoid LLM hallucination
enhanced_prompt = ""
if total_questions == 0:
# Avoid writing a report without any research context
enhanced_prompt = """
The student has no questions to research. Let start by asking some questions.
"""
elif total_questions > 6:
# Avoid asking too many questions (when the data is not ready for writing a report)
enhanced_prompt = f"""
The student has researched {total_questions} questions. Should cancel the research if the context is not enough to write a report.
"""
conversation_context = "\n".join(
[f"{message.role}: {message.content}" for message in memory.get_all()]
)
......@@ -63,10 +87,11 @@ async def plan_research(
)
res = await Settings.llm.astructured_predict(
output_cls=AnalysisDecision,
prompt=analyze_prompt,
prompt=PromptTemplate(template=analyze_prompt),
user_request=user_request,
context_str=context_str,
conversation_context=conversation_context,
enhanced_prompt=enhanced_prompt,
)
return res
......
......@@ -89,10 +89,11 @@ class DeepResearchWorkflow(Workflow):
)
@step
def retrieve(self, ctx: Context, ev: StartEvent) -> PlanResearchEvent:
async def retrieve(self, ctx: Context, ev: StartEvent) -> PlanResearchEvent:
"""
Initiate the workflow: memory, tools, agent
"""
await ctx.set("total_questions", 0)
self.user_request = ev.get("input")
self.memory.put_messages(
messages=[
......@@ -132,9 +133,7 @@ class DeepResearchWorkflow(Workflow):
nodes=nodes,
)
)
return PlanResearchEvent(
context_nodes=self.context_nodes,
)
return PlanResearchEvent()
@step
async def analyze(
......@@ -153,10 +152,12 @@ class DeepResearchWorkflow(Workflow):
},
)
)
total_questions = await ctx.get("total_questions")
res = await plan_research(
memory=self.memory,
context_nodes=self.context_nodes,
user_request=self.user_request,
total_questions=total_questions,
)
if res.decision == "cancel":
ctx.write_event_to_stream(
......@@ -172,6 +173,22 @@ class DeepResearchWorkflow(Workflow):
result=res.cancel_reason,
)
elif res.decision == "write":
# Writing a report without any research context is not allowed.
# It's a LLM hallucination.
if total_questions == 0:
ctx.write_event_to_stream(
DataEvent(
type="deep_research_event",
data={
"event": "analyze",
"state": "done",
},
)
)
return StopEvent(
result="Sorry, I have a problem when analyzing the retrieved information. Please try again.",
)
self.memory.put(
message=ChatMessage(
role=MessageRole.ASSISTANT,
......@@ -180,7 +197,11 @@ class DeepResearchWorkflow(Workflow):
)
ctx.send_event(ReportEvent())
else:
await ctx.set("n_questions", len(res.research_questions))
total_questions += len(res.research_questions)
await ctx.set("total_questions", total_questions) # For tracking
await ctx.set(
"waiting_questions", len(res.research_questions)
) # For waiting questions to be answered
self.memory.put(
message=ChatMessage(
role=MessageRole.ASSISTANT,
......@@ -270,7 +291,7 @@ class DeepResearchWorkflow(Workflow):
"""
Collect answers to all questions
"""
num_questions = await ctx.get("n_questions")
num_questions = await ctx.get("waiting_questions")
results = ctx.collect_events(
ev,
expected=[CollectAnswersEvent] * num_questions,
......@@ -284,7 +305,7 @@ class DeepResearchWorkflow(Workflow):
content=f"<Question>{result.question}</Question>\n<Answer>{result.answer}</Answer>",
)
)
await ctx.set("n_questions", 0)
await ctx.set("waiting_questions", 0)
self.memory.put(
message=ChatMessage(
role=MessageRole.ASSISTANT,
......
# flake8: noqa: E402
from app.config import DATA_DIR, STATIC_DIR
from dotenv import load_dotenv
from app.config import DATA_DIR, STATIC_DIR
load_dotenv()
import logging
import os
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from fastapi.staticfiles import StaticFiles
from app.api.routers import api_router
from app.middlewares.frontend import FrontendProxyMiddleware
from app.observability import init_observability
from app.settings import init_settings
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from fastapi.staticfiles import StaticFiles
servers = []
app_name = os.getenv("FLY_APP_NAME")
......@@ -28,6 +31,16 @@ init_observability()
environment = os.getenv("ENVIRONMENT", "dev") # Default to 'development' if not set
logger = logging.getLogger("uvicorn")
# Add CORS middleware for development only, so the browser frontend served
# from a different localhost port can fetch file sources from this API.
# Production origins stay restricted because this branch is skipped there.
if environment == "dev":
    app.add_middleware(
        CORSMiddleware,
        # Use a raw string: "\d" inside a plain string literal is an invalid
        # escape sequence (SyntaxWarning on Python 3.12+, an error in future
        # versions). The pattern allows any port on localhost or 0.0.0.0,
        # matching wherever the dev frontend happens to run.
        allow_origin_regex=r"http://localhost:\d+|http://0\.0\.0\.0:\d+",
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
def mount_static_files(directory, path, html=False):
if os.path.exists(directory):
......
"use client";

// Accordion primitives for this design system, built on Radix UI's
// unstyled accordion. Each wrapper forwards its ref and all props to the
// underlying Radix part and layers on Tailwind classes via `cn`.
import * as AccordionPrimitive from "@radix-ui/react-accordion";
import { ChevronDown } from "lucide-react";
import * as React from "react";
import { cn } from "./lib/utils";

// Root container; expansion behavior (single/multiple, collapsible) is
// controlled by the props callers pass straight through to Radix.
const Accordion = AccordionPrimitive.Root;

// One collapsible section. Adds a bottom border by default; callers can
// extend or override styling through `className`.
const AccordionItem = React.forwardRef<
  React.ElementRef<typeof AccordionPrimitive.Item>,
  React.ComponentPropsWithoutRef<typeof AccordionPrimitive.Item>
>(({ className, ...props }, ref) => (
  <AccordionPrimitive.Item
    ref={ref}
    className={cn("border-b", className)}
    {...props}
  />
));
AccordionItem.displayName = "AccordionItem";

// Clickable header that toggles its item. Renders children plus a chevron
// icon; the `[&[data-state=open]>svg]:rotate-180` class flips the chevron
// when the item is open (state is exposed by Radix as a data attribute).
const AccordionTrigger = React.forwardRef<
  React.ElementRef<typeof AccordionPrimitive.Trigger>,
  React.ComponentPropsWithoutRef<typeof AccordionPrimitive.Trigger>
>(({ className, children, ...props }, ref) => (
  <AccordionPrimitive.Header className="flex">
    <AccordionPrimitive.Trigger
      ref={ref}
      className={cn(
        "flex flex-1 items-center justify-between py-4 text-sm font-medium transition-all hover:underline text-left [&[data-state=open]>svg]:rotate-180",
        className,
      )}
      {...props}
    >
      {children}
      <ChevronDown className="h-4 w-4 shrink-0 text-neutral-500 transition-transform duration-200 dark:text-neutral-400" />
    </AccordionPrimitive.Trigger>
  </AccordionPrimitive.Header>
));
AccordionTrigger.displayName = AccordionPrimitive.Trigger.displayName;

// Collapsible body of an item. The data-state animation classes drive the
// open/close transition; `className` is applied to the inner padding div
// rather than the animated container so custom padding doesn't break the
// height animation.
const AccordionContent = React.forwardRef<
  React.ElementRef<typeof AccordionPrimitive.Content>,
  React.ComponentPropsWithoutRef<typeof AccordionPrimitive.Content>
>(({ className, children, ...props }, ref) => (
  <AccordionPrimitive.Content
    ref={ref}
    className="overflow-hidden text-sm data-[state=closed]:animate-accordion-up data-[state=open]:animate-accordion-down"
    {...props}
  >
    <div className={cn("pb-4 pt-0", className)}>{children}</div>
  </AccordionPrimitive.Content>
));
AccordionContent.displayName = AccordionPrimitive.Content.displayName;

export { Accordion, AccordionContent, AccordionItem, AccordionTrigger };
import * as React from "react";
import { cn } from "./lib/utils";

// Card primitives: plain styled <div> wrappers (shadcn/ui-style) that
// forward refs and props. Compose them as
// <Card><CardHeader><CardTitle/></CardHeader><CardContent/></Card>;
// every part accepts `className` to extend or override its defaults.

// Outer container: rounded border, background, shadow, dark-mode variants.
const Card = React.forwardRef<
  HTMLDivElement,
  React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => (
  <div
    ref={ref}
    className={cn(
      "rounded-xl border border-neutral-200 bg-white text-neutral-950 shadow dark:border-neutral-800 dark:bg-neutral-950 dark:text-neutral-50",
      className,
    )}
    {...props}
  />
));
Card.displayName = "Card";

// Padded top section; stacks its children vertically with small spacing.
const CardHeader = React.forwardRef<
  HTMLDivElement,
  React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => (
  <div
    ref={ref}
    className={cn("flex flex-col space-y-1.5 p-6", className)}
    {...props}
  />
));
CardHeader.displayName = "CardHeader";

// Heading text. Rendered as a <div> (not <h*>) so it imposes no document
// outline; callers choose their own heading semantics if needed.
const CardTitle = React.forwardRef<
  HTMLDivElement,
  React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => (
  <div
    ref={ref}
    className={cn("font-semibold leading-none tracking-tight", className)}
    {...props}
  />
));
CardTitle.displayName = "CardTitle";

// Muted secondary text below the title.
const CardDescription = React.forwardRef<
  HTMLDivElement,
  React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => (
  <div
    ref={ref}
    className={cn("text-sm text-neutral-500 dark:text-neutral-400", className)}
    {...props}
  />
));
CardDescription.displayName = "CardDescription";

// Main body; `pt-0` assumes a CardHeader above already provides top padding.
const CardContent = React.forwardRef<
  HTMLDivElement,
  React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => (
  <div ref={ref} className={cn("p-6 pt-0", className)} {...props} />
));
CardContent.displayName = "CardContent";

// Horizontal bottom row, typically for actions.
const CardFooter = React.forwardRef<
  HTMLDivElement,
  React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => (
  <div
    ref={ref}
    className={cn("flex items-center p-6 pt-0", className)}
    {...props}
  />
));
CardFooter.displayName = "CardFooter";

export {
  Card,
  CardContent,
  CardDescription,
  CardFooter,
  CardHeader,
  CardTitle,
};
......@@ -4,7 +4,6 @@ import { Message } from "@llamaindex/chat-ui";
import {
AlertCircle,
CheckCircle2,
ChevronDown,
CircleDashed,
Clock,
NotebookPen,
......@@ -12,10 +11,12 @@ import {
} from "lucide-react";
import { useMemo } from "react";
import {
Collapsible,
CollapsibleContent,
CollapsibleTrigger,
} from "../../collapsible";
Accordion,
AccordionContent,
AccordionItem,
AccordionTrigger,
} from "../../accordion";
import { Card, CardContent, CardHeader, CardTitle } from "../../card";
import { cn } from "../../lib/utils";
import { Markdown } from "./markdown";
......@@ -163,63 +164,53 @@ export function DeepResearchCard({
}
return (
<div
className={cn(
"rounded-lg border bg-card text-card-foreground shadow-sm p-5 space-y-6 w-full",
className,
)}
>
{state.retrieve.state !== null && (
<div className="border-t pt-4">
<h3 className="text-lg font-semibold flex items-center gap-2">
<Card className={cn("w-full", className)}>
<CardHeader className="space-y-4">
{state.retrieve.state !== null && (
<CardTitle className="flex items-center gap-2">
<Search className="h-5 w-5" />
<span>
{state.retrieve.state === "inprogress"
? "Searching..."
: "Search completed"}
</span>
</h3>
</div>
)}
{state.analyze.state !== null && (
<div className="border-t pt-4">
<h3 className="text-lg font-semibold flex items-center gap-2">
{state.retrieve.state === "inprogress"
? "Searching..."
: "Search completed"}
</CardTitle>
)}
{state.analyze.state !== null && (
<CardTitle className="flex items-center gap-2 border-t pt-4">
<NotebookPen className="h-5 w-5" />
<span>
{state.analyze.state === "inprogress"
? "Analyzing..."
: "Analysis"}
</span>
</h3>
{state.analyze.questions.length > 0 && (
<div className="space-y-2">
{state.analyze.questions.map((question: QuestionState) => (
<Collapsible key={question.id}>
<CollapsibleTrigger className="w-full">
<div className="flex items-center gap-2 p-3 hover:bg-accent transition-colors rounded-lg border">
<div className="flex-shrink-0">
{stateIcon[question.state]}
</div>
<span className="font-medium text-left flex-1">
{question.question}
</span>
<ChevronDown className="h-5 w-5 transition-transform ui-expanded:rotate-180" />
{state.analyze.state === "inprogress" ? "Analyzing..." : "Analysis"}
</CardTitle>
)}
</CardHeader>
<CardContent>
{state.analyze.questions.length > 0 && (
<Accordion type="single" collapsible className="space-y-2">
{state.analyze.questions.map((question: QuestionState) => (
<AccordionItem
key={question.id}
value={question.id}
className="border rounded-lg [&[data-state=open]>div]:rounded-b-none"
>
<AccordionTrigger className="hover:bg-accent hover:no-underline py-3 px-3 gap-2">
<div className="flex items-center gap-2 w-full">
<div className="flex-shrink-0">
{stateIcon[question.state]}
</div>
</CollapsibleTrigger>
{question.answer && (
<CollapsibleContent>
<div className="p-3 border border-t-0 rounded-b-lg">
<Markdown content={question.answer} />
</div>
</CollapsibleContent>
)}
</Collapsible>
))}
</div>
)}
</div>
)}
</div>
<span className="font-medium text-left flex-1">
{question.question}
</span>
</div>
</AccordionTrigger>
{question.answer && (
<AccordionContent className="border-t px-3 py-3">
<Markdown content={question.answer} />
</AccordionContent>
)}
</AccordionItem>
))}
</Accordion>
)}
</CardContent>
</Card>
);
}
......@@ -12,6 +12,7 @@
"dependencies": {
"@apidevtools/swagger-parser": "^10.1.0",
"@e2b/code-interpreter": "^1.0.4",
"@radix-ui/react-accordion": "^1.2.2",
"@radix-ui/react-collapsible": "^1.0.3",
"@radix-ui/react-select": "^2.1.1",
"@radix-ui/react-slot": "^1.0.2",
......@@ -19,7 +20,7 @@
"@llamaindex/chat-ui": "0.0.14",
"ai": "^4.0.3",
"ajv": "^8.12.0",
"class-variance-authority": "^0.7.0",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"dotenv": "^16.3.1",
"duck-duck-scrape": "^2.2.5",
......@@ -32,7 +33,7 @@
"react-dom": "^19.0.0",
"papaparse": "^5.4.1",
"supports-color": "^8.1.1",
"tailwind-merge": "^2.1.0",
"tailwind-merge": "^2.6.0",
"tiktoken": "^1.0.15",
"uuid": "^9.0.1",
"marked": "^14.1.2"
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment