From d8fe74070df7e91b57b3dd3e6e9ced7d8ec980cc Mon Sep 17 00:00:00 2001
From: Mateusz Charytoniuk <mateusz.charytoniuk@protonmail.com>
Date: Sat, 23 Mar 2024 17:30:55 +0100
Subject: [PATCH] feat: observable tasks categories

---
 src/JsonRPCNotificationError.php              | 38 +++++++++++++
 src/JsonRPCResponseError.php                  | 39 ++++++++++++++
 src/ObservableTask.php                        | 17 +++++-
 src/ObservableTaskInterface.php               |  7 ++-
 src/ObservableTaskStatusUpdate.php            |  2 +-
 src/ObservableTaskTable.php                   | 54 +++++++++++++++----
 src/ObservableTaskTableRow.php                | 17 ++++++
 src/Session.php                               |  2 +-
 .../LlamaCppSubjectActionPromptResponder.php  | 20 +++++--
 src/views/observable_tasks_dashboard.twig     |  8 ++-
 10 files changed, 186 insertions(+), 18 deletions(-)
 create mode 100644 src/JsonRPCNotificationError.php
 create mode 100644 src/JsonRPCResponseError.php
 create mode 100644 src/ObservableTaskTableRow.php

diff --git a/src/JsonRPCNotificationError.php b/src/JsonRPCNotificationError.php
new file mode 100644
index 00000000..136722c5
--- /dev/null
+++ b/src/JsonRPCNotificationError.php
@@ -0,0 +1,38 @@
+<?php
+
+declare(strict_types=1);
+
+namespace Distantmagic\Resonance;
+
+use Stringable;
+
+/**
+ * @psalm-suppress PossiblyUnusedProperty
+ *
+ * @template TPayload
+ */
+readonly class JsonRPCNotificationError implements Stringable
+{
+    /**
+     * @param TPayload $payload
+     */
+    public function __construct(
+        public JsonRPCMethodInterface $method,
+        public int $code = -32000,
+        public string $message = 'Server error',
+        public mixed $payload = null,
+    ) {}
+
+    public function __toString(): string
+    {
+        return json_encode([
+            'jsonrpc' => '2.0',
+            'method' => $this->method->getValue(),
+            'error' => [
+                'code' => $this->code,
+                'data' => $this->payload,
+                'message' => $this->message,
+            ],
+        ]);
+    }
+}
diff --git a/src/JsonRPCResponseError.php b/src/JsonRPCResponseError.php
new file mode 100644
index 00000000..7efac502
--- /dev/null
+++ b/src/JsonRPCResponseError.php
@@ -0,0 +1,39 @@
+<?php
+
+declare(strict_types=1);
+
+namespace Distantmagic\Resonance;
+
+use Stringable;
+
+/**
+ * @psalm-suppress PossiblyUnusedProperty
+ *
+ * @template TPayload
+ */
+readonly class JsonRPCResponseError implements Stringable
+{
+    /**
+     * @param TPayload $payload
+     */
+    public function __construct(
+        private JsonRPCRequest $rpcRequest,
+        public int $code = -32000,
+        public string $message = 'Server error',
+        public mixed $payload = null,
+    ) {}
+
+    public function __toString(): string
+    {
+        return json_encode([
+            'id' => $this->rpcRequest->requestId,
+            'jsonrpc' => '2.0',
+            'method' => $this->rpcRequest->method->getValue(),
+            'error' => [
+                'code' => $this->code,
+                'data' => $this->payload,
+                'message' => $this->message,
+            ],
+        ]);
+    }
+}
diff --git a/src/ObservableTask.php b/src/ObservableTask.php
index b3d5b5a1..03256c9b 100644
--- a/src/ObservableTask.php
+++ b/src/ObservableTask.php
@@ -18,11 +18,19 @@ readonly class ObservableTask implements ObservableTaskInterface
     /**
      * @param callable():iterable<ObservableTaskStatusUpdate> $iterableTask
      */
-    public function __construct(callable $iterableTask)
-    {
+    public function __construct(
+        callable $iterableTask,
+        private string $name = '',
+        private string $category = '',
+    ) {
         $this->iterableTask = Closure::fromCallable($iterableTask);
     }
 
+    public function getCategory(): string
+    {
+        return $this->category;
+    }
+
     public function getIterator(): Generator
     {
         try {
@@ -34,4 +42,9 @@ readonly class ObservableTask implements ObservableTaskInterface
             );
         }
     }
+
+    public function getName(): string
+    {
+        return $this->name;
+    }
 }
diff --git a/src/ObservableTaskInterface.php b/src/ObservableTaskInterface.php
index f862534c..a09e9971 100644
--- a/src/ObservableTaskInterface.php
+++ b/src/ObservableTaskInterface.php
@@ -9,4 +9,9 @@ use IteratorAggregate;
 /**
  * @template-extends IteratorAggregate<ObservableTaskStatusUpdate>
  */
-interface ObservableTaskInterface extends IteratorAggregate {}
+interface ObservableTaskInterface extends IteratorAggregate
+{
+    public function getCategory(): string;
+
+    public function getName(): string;
+}
diff --git a/src/ObservableTaskStatusUpdate.php b/src/ObservableTaskStatusUpdate.php
index bfcbe8fe..38e96c4d 100644
--- a/src/ObservableTaskStatusUpdate.php
+++ b/src/ObservableTaskStatusUpdate.php
@@ -16,7 +16,7 @@ readonly class ObservableTaskStatusUpdate implements JsonSerializable
      */
     public function __construct(
         public ObservableTaskStatus $status,
-        public mixed $data
+        public mixed $data,
     ) {}
 
     public function jsonSerialize(): array
diff --git a/src/ObservableTaskTable.php b/src/ObservableTaskTable.php
index 7deb2d4b..b168bed6 100644
--- a/src/ObservableTaskTable.php
+++ b/src/ObservableTaskTable.php
@@ -13,7 +13,7 @@ use Swoole\Coroutine;
 use Swoole\Table;
 
 /**
- * @template-implements IteratorAggregate<non-empty-string,?ObservableTaskStatusUpdate>
+ * @template-implements IteratorAggregate<non-empty-string,ObservableTaskTableRow>
  */
 #[Singleton]
 readonly class ObservableTaskTable implements IteratorAggregate
@@ -38,12 +38,14 @@ readonly class ObservableTaskTable implements IteratorAggregate
         );
 
         $this->table = new Table(2 * $observableTaskConfiguration->maxTasks);
+        $this->table->column('category', Table::TYPE_STRING, 255);
+        $this->table->column('name', Table::TYPE_STRING, 255);
         $this->table->column('status', Table::TYPE_STRING, $observableTaskConfiguration->serializedStatusSize);
         $this->table->create();
     }
 
     /**
-     * @return Generator<non-empty-string,?ObservableTaskStatusUpdate>
+     * @return Generator<non-empty-string,ObservableTaskTableRow>
      */
     public function getIterator(): Generator
     {
@@ -52,7 +54,11 @@ readonly class ObservableTaskTable implements IteratorAggregate
          * @var mixed            $row explicitly mixed for typechecks
          */
         foreach ($this->table as $slotId => $row) {
-            yield $slotId => $this->unserializeTableRow($row);
+            $unserializedRow = $this->unserializeTableRow($row);
+
+            if ($unserializedRow) {
+                yield $slotId => $unserializedRow;
+            }
         }
     }
 
@@ -61,7 +67,13 @@ readonly class ObservableTaskTable implements IteratorAggregate
      */
     public function getStatus(string $taskId): ?ObservableTaskStatusUpdate
     {
-        return $this->unserializeTableRow($this->table->get($taskId));
+        $row = $this->table->get($taskId);
+
+        if (!is_array($row)) {
+            return null;
+        }
+
+        return $this->unserializeTableStatusColumn($row);
     }
 
     /**
@@ -78,6 +90,8 @@ readonly class ObservableTaskTable implements IteratorAggregate
 
             if (
                 !$this->table->set($slotId, [
+                    'category' => $observableTask->getCategory(),
+                    'name' => $observableTask->getName(),
                     'status' => $this->serializedPendingStatus,
                 ])
             ) {
@@ -85,9 +99,12 @@ readonly class ObservableTaskTable implements IteratorAggregate
             }
 
             foreach ($observableTask as $statusUpdate) {
-                if (!$this->table->set($slotId, [
-                    'status' => $this->serializer->serialize($statusUpdate),
-                ])
+                if (
+                    !$this->table->set($slotId, [
+                        'category' => $observableTask->getCategory(),
+                        'name' => $observableTask->getName(),
+                        'status' => $this->serializer->serialize($statusUpdate),
+                    ])
                 ) {
                     throw new RuntimeException('Unable to update a slot status.');
                 }
@@ -117,9 +134,28 @@ readonly class ObservableTaskTable implements IteratorAggregate
         return $slotId;
     }
 
-    private function unserializeTableRow(mixed $row): ?ObservableTaskStatusUpdate
+    private function unserializeTableRow(mixed $row): ?ObservableTaskTableRow
+    {
+        if (!is_array($row)) {
+            return null;
+        }
+
+        $observableTaskStatusUpdate = $this->unserializeTableStatusColumn($row);
+
+        if (is_null($observableTaskStatusUpdate) || !is_string($row['name']) || !is_string($row['category'])) {
+            return null;
+        }
+
+        return new ObservableTaskTableRow(
+            name: $row['name'],
+            category: $row['category'],
+            observableTaskStatusUpdate: $observableTaskStatusUpdate,
+        );
+    }
+
+    private function unserializeTableStatusColumn(array $row): ?ObservableTaskStatusUpdate
     {
-        if (!is_array($row) || !is_string($row['status'])) {
+        if (!is_string($row['status'])) {
             return null;
         }
 
diff --git a/src/ObservableTaskTableRow.php b/src/ObservableTaskTableRow.php
new file mode 100644
index 00000000..e1db6fe3
--- /dev/null
+++ b/src/ObservableTaskTableRow.php
@@ -0,0 +1,17 @@
+<?php
+
+declare(strict_types=1);
+
+namespace Distantmagic\Resonance;
+
+/**
+ * @psalm-suppress PossiblyUnusedProperty it's used in the templates
+ */
+readonly class ObservableTaskTableRow
+{
+    public function __construct(
+        public ObservableTaskStatusUpdate $observableTaskStatusUpdate,
+        public string $category,
+        public string $name,
+    ) {}
+}
diff --git a/src/Session.php b/src/Session.php
index 7849a06e..889247e6 100644
--- a/src/Session.php
+++ b/src/Session.php
@@ -57,7 +57,7 @@ readonly class Session
     {
         $storedValue = $this->redis->get($this->id);
 
-        if (!is_string($storedValue)) {
+        if (!is_string($storedValue) || empty($storedValue)) {
             return null;
         }
 
diff --git a/src/WebSocketJsonRPCResponder/LlamaCppSubjectActionPromptResponder.php b/src/WebSocketJsonRPCResponder/LlamaCppSubjectActionPromptResponder.php
index 1de9c946..fcb48713 100644
--- a/src/WebSocketJsonRPCResponder/LlamaCppSubjectActionPromptResponder.php
+++ b/src/WebSocketJsonRPCResponder/LlamaCppSubjectActionPromptResponder.php
@@ -42,6 +42,12 @@ abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketJs
      */
     abstract protected function getPromptFromPayload(mixed $payload): string;
 
+    abstract protected function onRequestFailure(
+        WebSocketAuthResolution $webSocketAuthResolution,
+        WebSocketConnection $webSocketConnection,
+        JsonRPCRequest $rpcRequest,
+    ): void;
+
     abstract protected function onResponseChunk(
         WebSocketAuthResolution $webSocketAuthResolution,
         WebSocketConnection $webSocketConnection,
@@ -81,7 +87,7 @@ abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketJs
         JsonRPCRequest $rpcRequest,
     ): void {
         $this->observableTaskTable->observe(new ObservableTask(
-            new ObservableTaskTimeoutIterator(
+            iterableTask: new ObservableTaskTimeoutIterator(
                 iterableTask: function () use (
                     $webSocketAuthResolution,
                     $webSocketConnection,
@@ -99,8 +105,10 @@ abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketJs
                         yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Finished, null);
                     }
                 },
-                inactivityTimeout: 1.0,
-            )
+                inactivityTimeout: 3.0,
+            ),
+            name: 'websocket_jsonrpc_response',
+            category: 'llama_cpp',
         ));
     }
 
@@ -139,6 +147,12 @@ abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketJs
             if ($responseChunk->isFailed) {
                 yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Failed, null);
 
+                $this->onRequestFailure(
+                    webSocketAuthResolution: $webSocketAuthResolution,
+                    webSocketConnection: $webSocketConnection,
+                    rpcRequest: $rpcRequest,
+                );
+
                 break;
             }
 
diff --git a/src/views/observable_tasks_dashboard.twig b/src/views/observable_tasks_dashboard.twig
index 714c401c..297fed2a 100644
--- a/src/views/observable_tasks_dashboard.twig
+++ b/src/views/observable_tasks_dashboard.twig
@@ -9,13 +9,19 @@
     <table>
         <thead>
             <tr>
+                <th>slot</th>
+                <th>status</th>
+                <th>category</th>
+                <th>name</th>
             </tr>
         </thead>
         <tbody>
             {% for slotId, observableTask in observableTaskTable %}
                 <tr>
                     <td>{{ slotId }}</td>
-                    <td>{{ observableTask.status.value }}</td>
+                    <td>{{ observableTask.observableTaskStatusUpdate.status.value }}</td>
+                    <td>{{ observableTask.category }}</td>
+                    <td>{{ observableTask.name }}</td>
                 </tr>
             {% endfor %}
         </tbody>
-- 
GitLab