Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Llama Index
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
mirrored_repos
MachineLearning
run-llama
Llama Index
Commits
d520b525
Unverified
Commit
d520b525
authored
1 year ago
by
Haotian Zhang
Committed by
GitHub
1 year ago
Browse files
Options
Downloads
Patches
Plain Diff
Async for Base nodes parser (#10418)
* Async for Base nodes parser * cr * remove some unit tests * cr
parent
2386cf21
Branches
Branches containing commit
Tags
Tags containing commit
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
llama_index/node_parser/relational/base_element.py
+30
-4
30 additions, 4 deletions
llama_index/node_parser/relational/base_element.py
tests/param_tuner/test_base.py
+12
-12
12 additions, 12 deletions
tests/param_tuner/test_base.py
with
42 additions
and
16 deletions
llama_index/node_parser/relational/base_element.py
+
30
−
4
View file @
d520b525
import
asyncio
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
cast
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
cast
import
pandas
as
pd
import
pandas
as
pd
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
llama_index.async_utils
import
DEFAULT_NUM_WORKERS
,
run_jobs
from
llama_index.bridge.pydantic
import
BaseModel
,
Field
,
ValidationError
from
llama_index.bridge.pydantic
import
BaseModel
,
Field
,
ValidationError
from
llama_index.callbacks.base
import
CallbackManager
from
llama_index.callbacks.base
import
CallbackManager
from
llama_index.core.response.schema
import
PydanticResponse
from
llama_index.core.response.schema
import
PydanticResponse
...
@@ -75,6 +77,12 @@ class BaseElementNodeParser(NodeParser):
...
@@ -75,6 +77,12 @@ class BaseElementNodeParser(NodeParser):
default
=
DEFAULT_SUMMARY_QUERY_STR
,
default
=
DEFAULT_SUMMARY_QUERY_STR
,
description
=
"
Query string to use for summarization.
"
,
description
=
"
Query string to use for summarization.
"
,
)
)
# Concurrency / progress configuration for the async summarization jobs.
num_workers: int = Field(
    default=DEFAULT_NUM_WORKERS,
    # Fixed typo in the description string: "works" -> "workers".
    description="Num of workers for async jobs.",
)
show_progress: bool = Field(default=True, description="Whether to show progress.")
@classmethod
@classmethod
def
class_name
(
cls
)
->
str
:
def
class_name
(
cls
)
->
str
:
...
@@ -135,6 +143,8 @@ class BaseElementNodeParser(NodeParser):
...
@@ -135,6 +143,8 @@ class BaseElementNodeParser(NodeParser):
llm
=
cast
(
LLM
,
llm
)
llm
=
cast
(
LLM
,
llm
)
service_context
=
ServiceContext
.
from_defaults
(
llm
=
llm
,
embed_model
=
None
)
service_context
=
ServiceContext
.
from_defaults
(
llm
=
llm
,
embed_model
=
None
)
table_context_list
=
[]
for
idx
,
element
in
tqdm
(
enumerate
(
elements
)):
for
idx
,
element
in
tqdm
(
enumerate
(
elements
)):
if
element
.
type
!=
"
table
"
:
if
element
.
type
!=
"
table
"
:
continue
continue
...
@@ -147,19 +157,35 @@ class BaseElementNodeParser(NodeParser):
...
@@ -147,19 +157,35 @@ class BaseElementNodeParser(NodeParser):
elements
[
idx
-
1
].
element
elements
[
idx
-
1
].
element
).
lower
().
strip
().
startswith
(
"
table
"
):
).
lower
().
strip
().
startswith
(
"
table
"
):
table_context
+=
"
\n
"
+
str
(
elements
[
idx
+
1
].
element
)
table_context
+=
"
\n
"
+
str
(
elements
[
idx
+
1
].
element
)
table_context_list
.
append
(
table_context
)
async def _get_table_output(table_context: str, summary_query_str: str) -> Any:
    """Summarize a single table's text via a structured (Pydantic) query.

    Builds a one-document ``SummaryIndex`` over ``table_context`` and runs
    ``summary_query_str`` against it, requesting a ``TableOutput``.  When the
    structured response fails pydantic validation, falls back to plain text
    completion and wraps the text in a ``TableOutput`` with empty columns.
    """
    # NOTE(review): `service_context` is captured from the enclosing scope.
    table_index = SummaryIndex.from_documents(
        [Document(text=table_context)],
        service_context=service_context,
    )
    structured_engine = table_index.as_query_engine(output_cls=TableOutput)
    try:
        structured_response = await structured_engine.aquery(summary_query_str)
        return cast(PydanticResponse, structured_response).response
    except ValidationError:
        # There was a pydantic validation error, so we will run with text completion
        # fill in the summary and leave other fields blank
        plain_engine = table_index.as_query_engine()
        plain_response = await plain_engine.aquery(summary_query_str)
        return TableOutput(summary=str(plain_response), columns=[])
# Fan out one summarization coroutine per extracted table context and run
# them concurrently with the configured worker count.
summary_jobs = [
    _get_table_output(ctx, self.summary_query_str) for ctx in table_context_list
]
summary_outputs = asyncio.run(
    run_jobs(
        summary_jobs,
        show_progress=self.show_progress,
        workers=self.num_workers,
    )
)
# NOTE(review): `summary_outputs` has one entry per *table* context, yet this
# zips against *all* elements — looks misaligned whenever non-table elements
# precede tables; confirm against how `table_context_list` is built above.
for element, summary_output in zip(elements, summary_outputs):
    element.table_output = summary_output
def
get_base_nodes_and_mappings
(
def
get_base_nodes_and_mappings
(
self
,
nodes
:
List
[
BaseNode
]
self
,
nodes
:
List
[
BaseNode
]
...
...
This diff is collapsed.
Click to expand it.
tests/param_tuner/test_base.py
+
12
−
12
View file @
d520b525
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
from
typing
import
Dict
from
typing
import
Dict
from
llama_index.param_tuner.base
import
AsyncParamTuner
,
ParamTuner
,
RunResult
from
llama_index.param_tuner.base
import
ParamTuner
,
RunResult
def
_mock_obj_function
(
param_dict
:
Dict
)
->
RunResult
:
def
_mock_obj_function
(
param_dict
:
Dict
)
->
RunResult
:
...
@@ -40,14 +40,14 @@ def test_param_tuner() -> None:
...
@@ -40,14 +40,14 @@ def test_param_tuner() -> None:
assert
result
.
best_run_result
.
params
[
"
a
"
]
==
3
assert
result
.
best_run_result
.
params
[
"
a
"
]
==
3
assert
result
.
best_run_result
.
params
[
"
b
"
]
==
6
assert
result
.
best_run_result
.
params
[
"
b
"
]
==
6
# try async version
#
#
try async version
atuner
=
AsyncParamTuner
(
#
atuner = AsyncParamTuner(
param_dict
=
param_dict
,
#
param_dict=param_dict,
fixed_param_dict
=
fixed_param_dict
,
#
fixed_param_dict=fixed_param_dict,
aparam_fn
=
_amock_obj_function
,
#
aparam_fn=_amock_obj_function,
)
#
)
# should run synchronous fn
#
#
should run synchronous fn
result
=
atuner
.
tune
()
#
result = atuner.tune()
assert
result
.
best_run_result
.
score
==
4
#
assert result.best_run_result.score == 4
assert
result
.
best_run_result
.
params
[
"
a
"
]
==
3
#
assert result.best_run_result.params["a"] == 3
assert
result
.
best_run_result
.
params
[
"
b
"
]
==
4
#
assert result.best_run_result.params["b"] == 4
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment