From d8fe74070df7e91b57b3dd3e6e9ced7d8ec980cc Mon Sep 17 00:00:00 2001 From: Mateusz Charytoniuk <mateusz.charytoniuk@protonmail.com> Date: Sat, 23 Mar 2024 17:30:55 +0100 Subject: [PATCH] feat: observable tasks categories --- src/JsonRPCNotificationError.php | 38 +++++++++++++ src/JsonRPCResponseError.php | 39 ++++++++++++++ src/ObservableTask.php | 17 +++++- src/ObservableTaskInterface.php | 7 ++- src/ObservableTaskStatusUpdate.php | 2 +- src/ObservableTaskTable.php | 54 +++++++++++++++---- src/ObservableTaskTableRow.php | 17 ++++++ src/Session.php | 2 +- .../LlamaCppSubjectActionPromptResponder.php | 20 +++++-- src/views/observable_tasks_dashboard.twig | 8 ++- 10 files changed, 186 insertions(+), 18 deletions(-) create mode 100644 src/JsonRPCNotificationError.php create mode 100644 src/JsonRPCResponseError.php create mode 100644 src/ObservableTaskTableRow.php diff --git a/src/JsonRPCNotificationError.php b/src/JsonRPCNotificationError.php new file mode 100644 index 00000000..136722c5 --- /dev/null +++ b/src/JsonRPCNotificationError.php @@ -0,0 +1,38 @@ +<?php + +declare(strict_types=1); + +namespace Distantmagic\Resonance; + +use Stringable; + +/** + * @psalm-suppress PossiblyUnusedProperty + * + * @template TPayload + */ +readonly class JsonRPCNotificationError implements Stringable +{ + /** + * @param TPayload $payload + */ + public function __construct( + public JsonRPCMethodInterface $method, + public int $code = -32000, + public string $message = 'Server error', + public mixed $payload = null, + ) {} + + public function __toString(): string + { + return json_encode([ + 'jsonrpc' => '2.0', + 'method' => $this->method->getValue(), + 'error' => [ + 'code' => $this->code, + 'data' => $this->payload, + 'message' => $this->message, + ], + ]); + } +} diff --git a/src/JsonRPCResponseError.php b/src/JsonRPCResponseError.php new file mode 100644 index 00000000..7efac502 --- /dev/null +++ b/src/JsonRPCResponseError.php @@ -0,0 +1,39 @@ +<?php + +declare(strict_types=1); + +namespace Distantmagic\Resonance; + +use Stringable; + +/** + * @psalm-suppress PossiblyUnusedProperty + * + * @template TPayload + */ +readonly class JsonRPCResponseError implements Stringable +{ + /** + * @param TPayload $payload + */ + public function __construct( + private JsonRPCRequest $rpcRequest, + public int $code = -32000, + public string $message = 'Server error', + public mixed $payload = null, + ) {} + + public function __toString(): string + { + return json_encode([ + 'id' => $this->rpcRequest->requestId, + 'jsonrpc' => '2.0', + 'method' => $this->rpcRequest->method->getValue(), + 'error' => [ + 'code' => $this->code, + 'data' => $this->payload, + 'message' => $this->message, + ], + ]); + } +} diff --git a/src/ObservableTask.php b/src/ObservableTask.php index b3d5b5a1..03256c9b 100644 --- a/src/ObservableTask.php +++ b/src/ObservableTask.php @@ -18,11 +18,19 @@ readonly class ObservableTask implements ObservableTaskInterface /** * @param callable():iterable<ObservableTaskStatusUpdate> $iterableTask */ - public function __construct(callable $iterableTask) - { + public function __construct( + callable $iterableTask, + private string $name = '', + private string $category = '', + ) { $this->iterableTask = Closure::fromCallable($iterableTask); } + public function getCategory(): string + { + return $this->category; + } + public function getIterator(): Generator { try { @@ -34,4 +42,9 @@ readonly class ObservableTask implements ObservableTaskInterface ); } } + + public function getName(): string + { + return $this->name; + } } diff --git a/src/ObservableTaskInterface.php b/src/ObservableTaskInterface.php index f862534c..a09e9971 100644 --- a/src/ObservableTaskInterface.php +++ b/src/ObservableTaskInterface.php @@ -9,4 +9,9 @@ use IteratorAggregate; /** * @template-extends IteratorAggregate<ObservableTaskStatusUpdate> */ -interface ObservableTaskInterface extends IteratorAggregate {} +interface ObservableTaskInterface extends IteratorAggregate +{ + public function getCategory(): string; + + public function getName(): string; +} diff --git a/src/ObservableTaskStatusUpdate.php b/src/ObservableTaskStatusUpdate.php index bfcbe8fe..38e96c4d 100644 --- a/src/ObservableTaskStatusUpdate.php +++ b/src/ObservableTaskStatusUpdate.php @@ -16,7 +16,7 @@ readonly class ObservableTaskStatusUpdate implements JsonSerializable */ public function __construct( public ObservableTaskStatus $status, - public mixed $data + public mixed $data, ) {} public function jsonSerialize(): array diff --git a/src/ObservableTaskTable.php b/src/ObservableTaskTable.php index 7deb2d4b..b168bed6 100644 --- a/src/ObservableTaskTable.php +++ b/src/ObservableTaskTable.php @@ -13,7 +13,7 @@ use Swoole\Coroutine; use Swoole\Table; /** - * @template-implements IteratorAggregate<non-empty-string,?ObservableTaskStatusUpdate> + * @template-implements IteratorAggregate<non-empty-string,ObservableTaskTableRow> */ #[Singleton] readonly class ObservableTaskTable implements IteratorAggregate @@ -38,12 +38,14 @@ readonly class ObservableTaskTable implements IteratorAggregate ); $this->table = new Table(2 * $observableTaskConfiguration->maxTasks); + $this->table->column('category', Table::TYPE_STRING, 255); + $this->table->column('name', Table::TYPE_STRING, 255); $this->table->column('status', Table::TYPE_STRING, $observableTaskConfiguration->serializedStatusSize); $this->table->create(); } /** - * @return Generator<non-empty-string,?ObservableTaskStatusUpdate> + * @return Generator<non-empty-string,ObservableTaskTableRow> */ public function getIterator(): Generator { @@ -52,7 +54,11 @@ readonly class ObservableTaskTable implements IteratorAggregate * @var mixed $row explicitly mixed for typechecks */ foreach ($this->table as $slotId => $row) { - yield $slotId => $this->unserializeTableRow($row); + $unserializedRow = $this->unserializeTableRow($row); + + if ($unserializedRow) { + yield $slotId => $unserializedRow; + } } } @@ -61,7 +67,13 @@ readonly class ObservableTaskTable implements IteratorAggregate */ public function getStatus(string $taskId): ?ObservableTaskStatusUpdate { - return $this->unserializeTableRow($this->table->get($taskId)); + $row = $this->table->get($taskId); + + if (!is_array($row)) { + return null; + } + + return $this->unserializeTableStatusColumn($row); } /** @@ -78,6 +90,8 @@ readonly class ObservableTaskTable implements IteratorAggregate if ( !$this->table->set($slotId, [ + 'category' => $observableTask->getCategory(), + 'name' => $observableTask->getName(), 'status' => $this->serializedPendingStatus, ]) ) { @@ -85,9 +99,12 @@ readonly class ObservableTaskTable implements IteratorAggregate } foreach ($observableTask as $statusUpdate) { - if (!$this->table->set($slotId, [ - 'status' => $this->serializer->serialize($statusUpdate), - ]) + if ( + !$this->table->set($slotId, [ + 'category' => $observableTask->getCategory(), + 'name' => $observableTask->getName(), + 'status' => $this->serializer->serialize($statusUpdate), + ]) ) { throw new RuntimeException('Unable to update a slot status.'); } @@ -117,9 +134,28 @@ readonly class ObservableTaskTable implements IteratorAggregate return $slotId; } - private function unserializeTableRow(mixed $row): ?ObservableTaskStatusUpdate + private function unserializeTableRow(mixed $row): ?ObservableTaskTableRow + { + if (!is_array($row)) { + return null; + } + + $observableTaskStatusUpdate = $this->unserializeTableStatusColumn($row); + + if (is_null($observableTaskStatusUpdate) || !is_string($row['name']) || !is_string($row['category'])) { + return null; + } + + return new ObservableTaskTableRow( + name: $row['name'], + category: $row['category'], + observableTaskStatusUpdate: $observableTaskStatusUpdate, + ); + } + + private function unserializeTableStatusColumn(array $row): ?ObservableTaskStatusUpdate { - if (!is_array($row) || !is_string($row['status'])) { + if (!is_string($row['status'])) { return null; } diff --git a/src/ObservableTaskTableRow.php b/src/ObservableTaskTableRow.php new file mode 100644 index 00000000..e1db6fe3 --- /dev/null +++ b/src/ObservableTaskTableRow.php @@ -0,0 +1,17 @@ +<?php + +declare(strict_types=1); + +namespace Distantmagic\Resonance; + +/** + * @psalm-suppress PossiblyUnusedProperty it's used in the templates + */ +readonly class ObservableTaskTableRow +{ + public function __construct( + public ObservableTaskStatusUpdate $observableTaskStatusUpdate, + public string $category, + public string $name, + ) {} +} diff --git a/src/Session.php b/src/Session.php index 7849a06e..889247e6 100644 --- a/src/Session.php +++ b/src/Session.php @@ -57,7 +57,7 @@ readonly class Session { $storedValue = $this->redis->get($this->id); - if (!is_string($storedValue)) { + if (!is_string($storedValue) || empty($storedValue)) { return null; } diff --git a/src/WebSocketJsonRPCResponder/LlamaCppSubjectActionPromptResponder.php b/src/WebSocketJsonRPCResponder/LlamaCppSubjectActionPromptResponder.php index 1de9c946..fcb48713 100644 --- a/src/WebSocketJsonRPCResponder/LlamaCppSubjectActionPromptResponder.php +++ b/src/WebSocketJsonRPCResponder/LlamaCppSubjectActionPromptResponder.php @@ -42,6 +42,12 @@ abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketJs */ abstract protected function getPromptFromPayload(mixed $payload): string; + abstract protected function onRequestFailure( + WebSocketAuthResolution $webSocketAuthResolution, + WebSocketConnection $webSocketConnection, + JsonRPCRequest $rpcRequest, + ): void; + abstract protected function onResponseChunk( WebSocketAuthResolution $webSocketAuthResolution, WebSocketConnection $webSocketConnection, @@ -81,7 +87,7 @@ abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketJs JsonRPCRequest $rpcRequest, ): void { $this->observableTaskTable->observe(new ObservableTask( - new ObservableTaskTimeoutIterator( + iterableTask: new ObservableTaskTimeoutIterator( iterableTask: function () use ( $webSocketAuthResolution, $webSocketConnection, @@ -99,8 +105,10 @@ abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketJs yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Finished, null); } }, - inactivityTimeout: 1.0, - ) + inactivityTimeout: 3.0, + ), + name: 'websocket_jsonrpc_response', + category: 'llama_cpp', )); } @@ -139,6 +147,12 @@ abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketJs if ($responseChunk->isFailed) { yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Failed, null); + $this->onRequestFailure( + webSocketAuthResolution: $webSocketAuthResolution, + webSocketConnection: $webSocketConnection, + rpcRequest: $rpcRequest, + ); + break; } diff --git a/src/views/observable_tasks_dashboard.twig b/src/views/observable_tasks_dashboard.twig index 714c401c..297fed2a 100644 --- a/src/views/observable_tasks_dashboard.twig +++ b/src/views/observable_tasks_dashboard.twig @@ -9,13 +9,19 @@ <table> <thead> <tr> + <th>slot</th> + <th>status</th> + <th>category</th> + <th>name</th> </tr> </thead> <tbody> {% for slotId, observableTask in observableTaskTable %} <tr> <td>{{ slotId }}</td> - <td>{{ observableTask.status.value }}</td> + <td>{{ observableTask.observableTaskStatusUpdate.status.value }}</td> + <td>{{ observableTask.category }}</td> + <td>{{ observableTask.name }}</td> </tr> {% endfor %} </tbody> -- GitLab