diff --git a/src/ObservableTask.php b/src/ObservableTask.php index 03256c9b3f26a409c3fcfd0dc043dcfe87020e2a..3cd404394397b03595ac3ad8fde0c88cb8cce496 100644 --- a/src/ObservableTask.php +++ b/src/ObservableTask.php @@ -8,6 +8,9 @@ use Closure; use Generator; use Throwable; +/** + * @psalm-type TIterableTaskCallback = callable():iterable<ObservableTaskStatusUpdate> + */ readonly class ObservableTask implements ObservableTaskInterface { /** @@ -16,7 +19,7 @@ readonly class ObservableTask implements ObservableTaskInterface private Closure $iterableTask; /** - * @param callable():iterable<ObservableTaskStatusUpdate> $iterableTask + * @param TIterableTaskCallback $iterableTask */ public function __construct( callable $iterableTask, diff --git a/src/ObservableTaskCategory.php b/src/ObservableTaskCategory.php new file mode 100644 index 0000000000000000000000000000000000000000..a9901c7640d8328522bf4d50483f61231ce695f3 --- /dev/null +++ b/src/ObservableTaskCategory.php @@ -0,0 +1,10 @@ +<?php + +declare(strict_types=1); + +namespace Distantmagic\Resonance; + +enum ObservableTaskCategory: string +{ + case LlamaCpp = 'llama_cpp'; +} diff --git a/src/ObservableTaskFactory.php b/src/ObservableTaskFactory.php new file mode 100644 index 0000000000000000000000000000000000000000..7c88bbd459dc17c584f37e2c60ee519097b2b809 --- /dev/null +++ b/src/ObservableTaskFactory.php @@ -0,0 +1,43 @@ +<?php + +declare(strict_types=1); + +namespace Distantmagic\Resonance; + +use Generator; + +/** + * @psalm-import-type TIterableTaskCallback from ObservableTask + */ +final readonly class ObservableTaskFactory +{ + public static function withTimeout( + callable $iterableTask, + float $inactivityTimeout = 5.0, + string $name = '', + string $category = '', + ): ObservableTask { + return new ObservableTask( + iterableTask: new ObservableTaskTimeoutIterator( + iterableTask: static function () use ($iterableTask): Generator { + yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Running, null); + + try { + yield from $iterableTask(); + } finally { + yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Finished, null); + } + }, + inactivityTimeout: $inactivityTimeout, + ), + name: $name, + category: $category, + ); + } + + /** + * @psalm-suppress UnusedConstructor this class is just a wrapper around + * functions + */ + private function __construct() {} +} diff --git a/src/ObservableTaskTimeoutIterator.php b/src/ObservableTaskTimeoutIterator.php index 38aac69e4f03586b0590dff330ba798f0f914c5e..516b3865bd56d1194ec8839c10d50992ef9dcb8d 100644 --- a/src/ObservableTaskTimeoutIterator.php +++ b/src/ObservableTaskTimeoutIterator.php @@ -11,17 +11,19 @@ use Swoole\Coroutine; use Swoole\Coroutine\Channel; /** + * @psalm-import-type TIterableTaskCallback from ObservableTask + * * @template-implements IteratorAggregate<ObservableTaskStatusUpdate> */ readonly class ObservableTaskTimeoutIterator implements IteratorAggregate { /** - * @var Closure():Generator<ObservableTaskStatusUpdate> + * @var Closure():iterable<ObservableTaskStatusUpdate> */ private Closure $iterableTask; /** - * @param callable():Generator<ObservableTaskStatusUpdate> $iterableTask + * @param TIterableTaskCallback $iterableTask */ public function __construct( callable $iterableTask, diff --git a/src/WebSocketJsonRPCResponder/LlamaCppSubjectActionPromptResponder.php b/src/WebSocketJsonRPCResponder/LlamaCppSubjectActionPromptResponder.php index 9a696af87ae7e377a701813dcb6626e9fcf08425..8c3603dec450c351275487fd87813a9b4508e457 100644 --- a/src/WebSocketJsonRPCResponder/LlamaCppSubjectActionPromptResponder.php +++ b/src/WebSocketJsonRPCResponder/LlamaCppSubjectActionPromptResponder.php @@ -12,11 +12,11 @@ use Distantmagic\Resonance\LlamaCppCompletionRequest; use Distantmagic\Resonance\LlmPrompt\SubjectActionPrompt; use Distantmagic\Resonance\LlmPromptTemplate; use Distantmagic\Resonance\LlmPromptTemplate\ChainPrompt; -use Distantmagic\Resonance\ObservableTask; +use Distantmagic\Resonance\ObservableTaskCategory; +use Distantmagic\Resonance\ObservableTaskFactory; use Distantmagic\Resonance\ObservableTaskStatus; use Distantmagic\Resonance\ObservableTaskStatusUpdate; use Distantmagic\Resonance\ObservableTaskTable; -use Distantmagic\Resonance\ObservableTaskTimeoutIterator; use Distantmagic\Resonance\PromptSubjectResponderAggregate; use Distantmagic\Resonance\WebSocketAuthResolution; use Distantmagic\Resonance\WebSocketConnection; @@ -86,29 +86,21 @@ abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketJs WebSocketConnection $webSocketConnection, JsonRPCRequest $rpcRequest, ): void { - $this->observableTaskTable->observe(new ObservableTask( - iterableTask: new ObservableTaskTimeoutIterator( - iterableTask: function () use ( + $this->observableTaskTable->observe(ObservableTaskFactory::withTimeout( + iterableTask: function () use ( + $webSocketAuthResolution, + $webSocketConnection, + $rpcRequest, + ): Generator { + yield from $this->onObservableRequest( $webSocketAuthResolution, $webSocketConnection, $rpcRequest, - ): Generator { - yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Running, null); - - try { - yield from $this->onObservableRequest( - $webSocketAuthResolution, - $webSocketConnection, - $rpcRequest, - ); - } finally { - yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Finished, null); - } - }, - inactivityTimeout: 5.0, - ), + ); + }, + inactivityTimeout: 5.0, name: 'websocket_jsonrpc_response', - category: 'llama_cpp', + category: ObservableTaskCategory::LlamaCpp->value, )); } @@ -160,6 +152,8 @@ abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketJs break; } + yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Running, null); + $this->onResponseChunk( webSocketAuthResolution: $webSocketAuthResolution, webSocketConnection: $webSocketConnection,