From d011e3c041b07648aab5b46e0bd411c89c5e0ee6 Mon Sep 17 00:00:00 2001 From: Mateusz Charytoniuk <mateusz.charytoniuk@protonmail.com> Date: Tue, 26 Mar 2024 21:05:42 +0100 Subject: [PATCH] feat: observable tasks factory --- src/ObservableTask.php | 5 ++- src/ObservableTaskCategory.php | 10 +++++ src/ObservableTaskFactory.php | 43 +++++++++++++++++++ src/ObservableTaskTimeoutIterator.php | 6 ++- .../LlamaCppSubjectActionPromptResponder.php | 36 +++++++--------- 5 files changed, 76 insertions(+), 24 deletions(-) create mode 100644 src/ObservableTaskCategory.php create mode 100644 src/ObservableTaskFactory.php diff --git a/src/ObservableTask.php b/src/ObservableTask.php index 03256c9b..3cd40439 100644 --- a/src/ObservableTask.php +++ b/src/ObservableTask.php @@ -8,6 +8,9 @@ use Closure; use Generator; use Throwable; +/** + * @psalm-type TIterableTaskCallback = callable():iterable<ObservableTaskStatusUpdate> + */ readonly class ObservableTask implements ObservableTaskInterface { /** @@ -16,7 +19,7 @@ readonly class ObservableTask implements ObservableTaskInterface private Closure $iterableTask; /** - * @param callable():iterable<ObservableTaskStatusUpdate> $iterableTask + * @param TIterableTaskCallback $iterableTask */ public function __construct( callable $iterableTask, diff --git a/src/ObservableTaskCategory.php b/src/ObservableTaskCategory.php new file mode 100644 index 00000000..a9901c76 --- /dev/null +++ b/src/ObservableTaskCategory.php @@ -0,0 +1,10 @@ +<?php + +declare(strict_types=1); + +namespace Distantmagic\Resonance; + +enum ObservableTaskCategory: string +{ + case LlamaCpp = 'llama_cpp'; +} diff --git a/src/ObservableTaskFactory.php b/src/ObservableTaskFactory.php new file mode 100644 index 00000000..7c88bbd4 --- /dev/null +++ b/src/ObservableTaskFactory.php @@ -0,0 +1,43 @@ +<?php + +declare(strict_types=1); + +namespace Distantmagic\Resonance; + +use Generator; + +/** + * @psalm-import-type TIterableTaskCallback from ObservableTask + */ +final readonly class ObservableTaskFactory +{ + public static function withTimeout( + callable $iterableTask, + float $inactivityTimeout = 5.0, + string $name = '', + string $category = '', + ): ObservableTask { + return new ObservableTask( + iterableTask: new ObservableTaskTimeoutIterator( + iterableTask: static function () use ($iterableTask): Generator { + yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Running, null); + + try { + yield from $iterableTask(); + } finally { + yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Finished, null); + } + }, + inactivityTimeout: $inactivityTimeout, + ), + name: $name, + category: $category, + ); + } + + /** + * @psalm-suppress UnusedConstructor this class is just a wrapper around + * functions + */ + private function __construct() {} +} diff --git a/src/ObservableTaskTimeoutIterator.php b/src/ObservableTaskTimeoutIterator.php index 38aac69e..516b3865 100644 --- a/src/ObservableTaskTimeoutIterator.php +++ b/src/ObservableTaskTimeoutIterator.php @@ -11,17 +11,19 @@ use Swoole\Coroutine; use Swoole\Coroutine\Channel; /** + * @psalm-import-type TIterableTaskCallback from ObservableTask + * * @template-implements IteratorAggregate<ObservableTaskStatusUpdate> */ readonly class ObservableTaskTimeoutIterator implements IteratorAggregate { /** - * @var Closure():Generator<ObservableTaskStatusUpdate> + * @var Closure():iterable<ObservableTaskStatusUpdate> */ private Closure $iterableTask; /** - * @param callable():Generator<ObservableTaskStatusUpdate> $iterableTask + * @param TIterableTaskCallback $iterableTask */ public function __construct( callable $iterableTask, diff --git a/src/WebSocketJsonRPCResponder/LlamaCppSubjectActionPromptResponder.php b/src/WebSocketJsonRPCResponder/LlamaCppSubjectActionPromptResponder.php index 9a696af8..8c3603de 100644 --- a/src/WebSocketJsonRPCResponder/LlamaCppSubjectActionPromptResponder.php +++ b/src/WebSocketJsonRPCResponder/LlamaCppSubjectActionPromptResponder.php @@ -12,11 +12,11 @@ use Distantmagic\Resonance\LlamaCppCompletionRequest; use Distantmagic\Resonance\LlmPrompt\SubjectActionPrompt; use Distantmagic\Resonance\LlmPromptTemplate; use Distantmagic\Resonance\LlmPromptTemplate\ChainPrompt; -use Distantmagic\Resonance\ObservableTask; +use Distantmagic\Resonance\ObservableTaskCategory; +use Distantmagic\Resonance\ObservableTaskFactory; use Distantmagic\Resonance\ObservableTaskStatus; use Distantmagic\Resonance\ObservableTaskStatusUpdate; use Distantmagic\Resonance\ObservableTaskTable; -use Distantmagic\Resonance\ObservableTaskTimeoutIterator; use Distantmagic\Resonance\PromptSubjectResponderAggregate; use Distantmagic\Resonance\WebSocketAuthResolution; use Distantmagic\Resonance\WebSocketConnection; @@ -86,29 +86,21 @@ abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketJs WebSocketConnection $webSocketConnection, JsonRPCRequest $rpcRequest, ): void { - $this->observableTaskTable->observe(new ObservableTask( - iterableTask: new ObservableTaskTimeoutIterator( - iterableTask: function () use ( + $this->observableTaskTable->observe(ObservableTaskFactory::withTimeout( + iterableTask: function () use ( + $webSocketAuthResolution, + $webSocketConnection, + $rpcRequest, + ): Generator { + yield from $this->onObservableRequest( $webSocketAuthResolution, $webSocketConnection, $rpcRequest, - ): Generator { - yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Running, null); - - try { - yield from $this->onObservableRequest( - $webSocketAuthResolution, - $webSocketConnection, - $rpcRequest, - ); - } finally { - yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Finished, null); - } - }, - inactivityTimeout: 5.0, - ), + ); + }, + inactivityTimeout: 5.0, name: 'websocket_jsonrpc_response', - category: 'llama_cpp', + category: ObservableTaskCategory::LlamaCpp->value, )); } @@ -160,6 +152,8 @@ abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketJs break; } + yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Running, null); + $this->onResponseChunk( webSocketAuthResolution: $webSocketAuthResolution, webSocketConnection: $webSocketConnection, -- GitLab