diff --git a/llama-index-core/tests/workflow/test_checkpointer.py b/llama-index-core/tests/workflow/test_checkpointer.py index c21292dc4128d00426318aef727b5a7f9c5f767e..14ff48897f575bc172301801ce75864b1f1b89dd 100644 --- a/llama-index-core/tests/workflow/test_checkpointer.py +++ b/llama-index-core/tests/workflow/test_checkpointer.py @@ -1,20 +1,19 @@ import asyncio -import pytest import random -from unittest.mock import patch, MagicMock -from typing import Coroutine +from unittest.mock import MagicMock, patch -from llama_index.core.workflow.events import StartEvent, StopEvent -from llama_index.core.workflow.handler import WorkflowHandler +import pytest +from llama_index.core.workflow.checkpointer import WorkflowCheckpointer from llama_index.core.workflow.events import ( StartEvent, StopEvent, ) +from llama_index.core.workflow.handler import WorkflowHandler from llama_index.core.workflow.workflow import ( Context, ) -from llama_index.core.workflow.checkpointer import WorkflowCheckpointer -from .conftest import OneTestEvent, DummyWorkflow, LastEvent + +from .conftest import DummyWorkflow, LastEvent, OneTestEvent @pytest.fixture() @@ -79,7 +78,7 @@ async def test_filter_checkpoints(workflow_checkpointer: WorkflowCheckpointer): steps = ["start_step", "middle_step", "end_step"] # sequential workflow for step in steps: checkpoints = workflow_checkpointer.filter_checkpoints(last_completed_step=step) - assert len(checkpoints) == num_runs, f"fails on step: {step.__name__}" + assert len(checkpoints) == num_runs, f"fails on step: {step}" # filter by input and output event event_types = [StartEvent, OneTestEvent, LastEvent, StopEvent] @@ -112,7 +111,7 @@ async def test_checkpoints_works_with_new_instances_concurrently( num_instances = 3 tasks = [] - async def add_random_startup(coro: Coroutine): + async def add_random_startup(coro: WorkflowHandler): """To randomly mix up the processing of the 3 runs.""" startup = random.random() await asyncio.sleep(startup) @@ -172,16 +171,20 @@ async def test_checkpointer_with_stepwise( stepwise_run_id = "stepwise_run" mock_uuid.uuid4.return_value = stepwise_run_id handler = workflow_checkpointer.run(stepwise=True) + assert handler.ctx event = await handler.run_step() + assert event assert len(workflow_checkpointer.checkpoints[stepwise_run_id]) == 1 handler.ctx.send_event(event) event = await handler.run_step() + assert event assert len(workflow_checkpointer.checkpoints[stepwise_run_id]) == 2 handler.ctx.send_event(event) event = await handler.run_step() + assert event assert len(workflow_checkpointer.checkpoints[stepwise_run_id]) == 3 handler.ctx.send_event(event) diff --git a/llama-index-core/tests/workflow/test_context.py b/llama-index-core/tests/workflow/test_context.py index 41afa9476a1a11bc3138f70daf7ba8e34c632a80..aa7aef5749632a369c8b3dda54d402553b85184e 100644 --- a/llama-index-core/tests/workflow/test_context.py +++ b/llama-index-core/tests/workflow/test_context.py @@ -65,7 +65,7 @@ async def test_get_not_found(ctx): async def test_legacy_data(workflow): c1 = Context(workflow) await c1.set(key="test_key", value=42) - assert c1.data["test_key"] == 42 + assert await c1.get("test_key") == 42 def test_send_event_step_is_none(ctx): @@ -168,6 +168,7 @@ async def test_wait_for_event_in_workflow(): workflow = TestWorkflow() handler = workflow.run() + assert handler.ctx async for ev in handler.stream_events(): if isinstance(ev, Event) and ev.msg == "foo": handler.ctx.send_event(Event(msg="bar")) diff --git a/llama-index-core/tests/workflow/test_decorator.py b/llama-index-core/tests/workflow/test_decorator.py index 3a072ef52bf666bb4bef40b86692e400ba664d2c..b4854592665287bbcd66aeb2797e330cc65bd13a 100644 --- a/llama-index-core/tests/workflow/test_decorator.py +++ b/llama-index-core/tests/workflow/test_decorator.py @@ -81,6 +81,6 @@ def test_decorate_free_function_wrong_num_workers(): WorkflowValidationError, match="num_workers must be an integer greater than 0" ): - @step(workflow=TestWorkflow, num_workers=0.5) + @step(workflow=TestWorkflow, num_workers=0.5) # type: ignore def f2(ev: Event) -> Event: return Event() diff --git a/llama-index-core/tests/workflow/test_event.py b/llama-index-core/tests/workflow/test_event.py index 490eeb76f1907e95d3d38c06f0dfd02542332ece..ff77e36ea5947069da9d6e1eb0a07ca1e82e8815 100644 --- a/llama-index-core/tests/workflow/test_event.py +++ b/llama-index-core/tests/workflow/test_event.py @@ -1,9 +1,9 @@ -import pytest +from typing import Any, cast -from llama_index.core.workflow.events import Event +import pytest from llama_index.core.bridge.pydantic import PrivateAttr from llama_index.core.workflow.context_serializers import JsonSerializer -from typing import Any, cast +from llama_index.core.workflow.events import Event class _TestEvent(Event): @@ -91,7 +91,7 @@ def test_event_dict_api(): def test_event_serialization(): - ev = _TestEvent(param="foo", not_a_field="bar") + ev = _TestEvent(param="foo", not_a_field="bar") # type: ignore serializer = JsonSerializer() serialized_ev = serializer.serialize(ev) deseriazlied_ev = serializer.deserialize(serialized_ev) diff --git a/llama-index-core/tests/workflow/test_streaming.py b/llama-index-core/tests/workflow/test_streaming.py index d66de3abc69cbd9c5b40eb3b84aaaa21dbed822a..143e40b8224f0a84bcf0b7378896e4885f6449de 100644 --- a/llama-index-core/tests/workflow/test_streaming.py +++ b/llama-index-core/tests/workflow/test_streaming.py @@ -19,7 +19,7 @@ class StreamingWorkflow(Workflow): yield word async for w in stream_messages(): - ctx.session.write_event_to_stream(Event(msg=w)) + ctx.write_event_to_stream(Event(msg=w)) return StopEvent(result=None) @@ -151,4 +151,5 @@ async def test_resume_streams(): pass await handler_2 + assert handler_2.ctx assert await handler_2.ctx.get("cur_count") == 2 diff --git a/llama-index-core/tests/workflow/test_workflow.py b/llama-index-core/tests/workflow/test_workflow.py index da5c4ac3c2e1bceb8cf2c6f9498b656c3965c1a5..40dec9f96dbd44b45ae4dd3d10a756f0e5276953 100644 --- a/llama-index-core/tests/workflow/test_workflow.py +++ b/llama-index-core/tests/workflow/test_workflow.py @@ -257,6 +257,7 @@ async def test_workflow_num_workers(): assert set(result) == {"test1", "test2", "test4"} # ctx should have 1 extra event + assert handler.ctx assert ( len(handler.ctx._events_buffer["tests.workflow.conftest.AnotherTestEvent"]) == 1 ) @@ -473,18 +474,21 @@ async def test_workflow_continue_context(): # first run r = wf.run() result = await r + assert r.ctx assert result == "Done" assert await r.ctx.get("number") == 1 # second run -- independent from the first r = wf.run() result = await r + assert r.ctx assert result == "Done" assert await r.ctx.get("number") == 1 # third run -- continue from the second run r = wf.run(ctx=r.ctx) result = await r + assert r.ctx assert result == "Done" assert await r.ctx.get("number") == 2 @@ -503,6 +507,7 @@ async def test_workflow_pickle(): wf = DummyWorkflow() handler = wf.run() + assert handler.ctx _ = await handler # by default, we can't pickle the LLM/embedding object @@ -514,6 +519,7 @@ async def test_workflow_pickle(): new_handler = WorkflowHandler( ctx=Context.from_dict(wf, state_dict, serializer=JsonPickleSerializer()) ) + assert new_handler.ctx # check that the step count is the same cur_step = await handler.ctx.get("step") @@ -531,6 +537,7 @@ async def test_workflow_pickle(): assert new_llm.max_tokens == llm.max_tokens handler = wf.run(ctx=new_handler.ctx) + assert handler.ctx _ = await handler # check that the step count is incremented @@ -636,6 +643,7 @@ async def test_human_in_the_loop(): workflow = HumanInTheLoopWorkflow() handler: WorkflowHandler = workflow.run() + assert handler.ctx async for event in handler.stream_events(): if isinstance(event, InputRequiredEvent): assert event.prefix == "Enter a number: "