Skip to content
Snippets Groups Projects
Unverified Commit d520b525 authored by Haotian Zhang's avatar Haotian Zhang Committed by GitHub
Browse files

Async for Base nodes parser (#10418)

* Async for Base nodes parser

* cr

* remove some unit tests

* cr
parent 2386cf21
Branches
Tags
No related merge requests found
import asyncio
from abc import abstractmethod from abc import abstractmethod
from typing import Any, Dict, List, Optional, Sequence, Tuple, cast from typing import Any, Dict, List, Optional, Sequence, Tuple, cast
import pandas as pd import pandas as pd
from tqdm import tqdm from tqdm import tqdm
from llama_index.async_utils import DEFAULT_NUM_WORKERS, run_jobs
from llama_index.bridge.pydantic import BaseModel, Field, ValidationError from llama_index.bridge.pydantic import BaseModel, Field, ValidationError
from llama_index.callbacks.base import CallbackManager from llama_index.callbacks.base import CallbackManager
from llama_index.core.response.schema import PydanticResponse from llama_index.core.response.schema import PydanticResponse
...@@ -75,6 +77,12 @@ class BaseElementNodeParser(NodeParser): ...@@ -75,6 +77,12 @@ class BaseElementNodeParser(NodeParser):
default=DEFAULT_SUMMARY_QUERY_STR, default=DEFAULT_SUMMARY_QUERY_STR,
description="Query string to use for summarization.", description="Query string to use for summarization.",
) )
num_workers: int = Field(
default=DEFAULT_NUM_WORKERS,
description="Num of works for async jobs.",
)
show_progress: bool = Field(default=True, description="Whether to show progress.")
@classmethod @classmethod
def class_name(cls) -> str: def class_name(cls) -> str:
...@@ -135,6 +143,8 @@ class BaseElementNodeParser(NodeParser): ...@@ -135,6 +143,8 @@ class BaseElementNodeParser(NodeParser):
llm = cast(LLM, llm) llm = cast(LLM, llm)
service_context = ServiceContext.from_defaults(llm=llm, embed_model=None) service_context = ServiceContext.from_defaults(llm=llm, embed_model=None)
table_context_list = []
for idx, element in tqdm(enumerate(elements)): for idx, element in tqdm(enumerate(elements)):
if element.type != "table": if element.type != "table":
continue continue
...@@ -147,19 +157,35 @@ class BaseElementNodeParser(NodeParser): ...@@ -147,19 +157,35 @@ class BaseElementNodeParser(NodeParser):
elements[idx - 1].element elements[idx - 1].element
).lower().strip().startswith("table"): ).lower().strip().startswith("table"):
table_context += "\n" + str(elements[idx + 1].element) table_context += "\n" + str(elements[idx + 1].element)
table_context_list.append(table_context)
async def _get_table_output(table_context: str, summary_query_str: str) -> Any:
index = SummaryIndex.from_documents( index = SummaryIndex.from_documents(
[Document(text=table_context)], service_context=service_context [Document(text=table_context)], service_context=service_context
) )
query_engine = index.as_query_engine(output_cls=TableOutput) query_engine = index.as_query_engine(output_cls=TableOutput)
try: try:
response = query_engine.query(self.summary_query_str) response = await query_engine.aquery(summary_query_str)
element.table_output = cast(PydanticResponse, response).response return cast(PydanticResponse, response).response
except ValidationError: except ValidationError:
# There was a pydantic validation error, so we will run with text completion # There was a pydantic validation error, so we will run with text completion
# fill in the summary and leave other fields blank # fill in the summary and leave other fields blank
query_engine = index.as_query_engine() query_engine = index.as_query_engine()
response_txt = str(query_engine.query(self.summary_query_str)) response_txt = await query_engine.aquery(summary_query_str)
element.table_output = TableOutput(summary=response_txt, columns=[]) return TableOutput(summary=str(response_txt), columns=[])
summary_jobs = [
_get_table_output(table_context, self.summary_query_str)
for table_context in table_context_list
]
summary_outputs = asyncio.run(
run_jobs(
summary_jobs, show_progress=self.show_progress, workers=self.num_workers
)
)
for element, summary_output in zip(elements, summary_outputs):
element.table_output = summary_output
def get_base_nodes_and_mappings( def get_base_nodes_and_mappings(
self, nodes: List[BaseNode] self, nodes: List[BaseNode]
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from typing import Dict from typing import Dict
from llama_index.param_tuner.base import AsyncParamTuner, ParamTuner, RunResult from llama_index.param_tuner.base import ParamTuner, RunResult
def _mock_obj_function(param_dict: Dict) -> RunResult: def _mock_obj_function(param_dict: Dict) -> RunResult:
...@@ -40,14 +40,14 @@ def test_param_tuner() -> None: ...@@ -40,14 +40,14 @@ def test_param_tuner() -> None:
assert result.best_run_result.params["a"] == 3 assert result.best_run_result.params["a"] == 3
assert result.best_run_result.params["b"] == 6 assert result.best_run_result.params["b"] == 6
# try async version # # try async version
atuner = AsyncParamTuner( # atuner = AsyncParamTuner(
param_dict=param_dict, # param_dict=param_dict,
fixed_param_dict=fixed_param_dict, # fixed_param_dict=fixed_param_dict,
aparam_fn=_amock_obj_function, # aparam_fn=_amock_obj_function,
) # )
# should run synchronous fn # # should run synchronous fn
result = atuner.tune() # result = atuner.tune()
assert result.best_run_result.score == 4 # assert result.best_run_result.score == 4
assert result.best_run_result.params["a"] == 3 # assert result.best_run_result.params["a"] == 3
assert result.best_run_result.params["b"] == 4 # assert result.best_run_result.params["b"] == 4
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment