From d011e3c041b07648aab5b46e0bd411c89c5e0ee6 Mon Sep 17 00:00:00 2001
From: Mateusz Charytoniuk <mateusz.charytoniuk@protonmail.com>
Date: Tue, 26 Mar 2024 21:05:42 +0100
Subject: [PATCH] feat: observable tasks factory

---
 src/ObservableTask.php                        |  5 ++-
 src/ObservableTaskCategory.php                | 10 +++++
 src/ObservableTaskFactory.php                 | 43 +++++++++++++++++++
 src/ObservableTaskTimeoutIterator.php         |  6 ++-
 .../LlamaCppSubjectActionPromptResponder.php  | 36 +++++++---------
 5 files changed, 76 insertions(+), 24 deletions(-)
 create mode 100644 src/ObservableTaskCategory.php
 create mode 100644 src/ObservableTaskFactory.php

diff --git a/src/ObservableTask.php b/src/ObservableTask.php
index 03256c9b..3cd40439 100644
--- a/src/ObservableTask.php
+++ b/src/ObservableTask.php
@@ -8,6 +8,9 @@ use Closure;
 use Generator;
 use Throwable;
 
+/**
+ * @psalm-type TIterableTaskCallback = callable():iterable<ObservableTaskStatusUpdate>
+ */
 readonly class ObservableTask implements ObservableTaskInterface
 {
     /**
@@ -16,7 +19,7 @@ readonly class ObservableTask implements ObservableTaskInterface
     private Closure $iterableTask;
 
     /**
-     * @param callable():iterable<ObservableTaskStatusUpdate> $iterableTask
+     * @param TIterableTaskCallback $iterableTask
      */
     public function __construct(
         callable $iterableTask,
diff --git a/src/ObservableTaskCategory.php b/src/ObservableTaskCategory.php
new file mode 100644
index 00000000..a9901c76
--- /dev/null
+++ b/src/ObservableTaskCategory.php
@@ -0,0 +1,10 @@
+<?php
+
+declare(strict_types=1);
+
+namespace Distantmagic\Resonance;
+
+enum ObservableTaskCategory: string
+{
+    case LlamaCpp = 'llama_cpp';
+}
diff --git a/src/ObservableTaskFactory.php b/src/ObservableTaskFactory.php
new file mode 100644
index 00000000..7c88bbd4
--- /dev/null
+++ b/src/ObservableTaskFactory.php
@@ -0,0 +1,43 @@
+<?php
+
+declare(strict_types=1);
+
+namespace Distantmagic\Resonance;
+
+use Generator;
+
+/**
+ * @psalm-import-type TIterableTaskCallback from ObservableTask
+ */
+final readonly class ObservableTaskFactory
+{
+    public static function withTimeout(
+        callable $iterableTask,
+        float $inactivityTimeout = 5.0,
+        string $name = '',
+        string $category = '',
+    ): ObservableTask {
+        return new ObservableTask(
+            iterableTask: new ObservableTaskTimeoutIterator(
+                iterableTask: static function () use ($iterableTask): Generator {
+                    yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Running, null);
+
+                    try {
+                        yield from $iterableTask();
+                    } finally {
+                        yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Finished, null);
+                    }
+                },
+                inactivityTimeout: $inactivityTimeout,
+            ),
+            name: $name,
+            category: $category,
+        );
+    }
+
+    /**
+     * @psalm-suppress UnusedConstructor this class is just a wrapper around
+     *     functions
+     */
+    private function __construct() {}
+}
diff --git a/src/ObservableTaskTimeoutIterator.php b/src/ObservableTaskTimeoutIterator.php
index 38aac69e..516b3865 100644
--- a/src/ObservableTaskTimeoutIterator.php
+++ b/src/ObservableTaskTimeoutIterator.php
@@ -11,17 +11,19 @@ use Swoole\Coroutine;
 use Swoole\Coroutine\Channel;
 
 /**
+ * @psalm-import-type TIterableTaskCallback from ObservableTask
+ *
  * @template-implements IteratorAggregate<ObservableTaskStatusUpdate>
  */
 readonly class ObservableTaskTimeoutIterator implements IteratorAggregate
 {
     /**
-     * @var Closure():Generator<ObservableTaskStatusUpdate>
+     * @var Closure():iterable<ObservableTaskStatusUpdate>
      */
     private Closure $iterableTask;
 
     /**
-     * @param callable():Generator<ObservableTaskStatusUpdate> $iterableTask
+     * @param TIterableTaskCallback $iterableTask
      */
     public function __construct(
         callable $iterableTask,
diff --git a/src/WebSocketJsonRPCResponder/LlamaCppSubjectActionPromptResponder.php b/src/WebSocketJsonRPCResponder/LlamaCppSubjectActionPromptResponder.php
index 9a696af8..8c3603de 100644
--- a/src/WebSocketJsonRPCResponder/LlamaCppSubjectActionPromptResponder.php
+++ b/src/WebSocketJsonRPCResponder/LlamaCppSubjectActionPromptResponder.php
@@ -12,11 +12,11 @@ use Distantmagic\Resonance\LlamaCppCompletionRequest;
 use Distantmagic\Resonance\LlmPrompt\SubjectActionPrompt;
 use Distantmagic\Resonance\LlmPromptTemplate;
 use Distantmagic\Resonance\LlmPromptTemplate\ChainPrompt;
-use Distantmagic\Resonance\ObservableTask;
+use Distantmagic\Resonance\ObservableTaskCategory;
+use Distantmagic\Resonance\ObservableTaskFactory;
 use Distantmagic\Resonance\ObservableTaskStatus;
 use Distantmagic\Resonance\ObservableTaskStatusUpdate;
 use Distantmagic\Resonance\ObservableTaskTable;
-use Distantmagic\Resonance\ObservableTaskTimeoutIterator;
 use Distantmagic\Resonance\PromptSubjectResponderAggregate;
 use Distantmagic\Resonance\WebSocketAuthResolution;
 use Distantmagic\Resonance\WebSocketConnection;
@@ -86,29 +86,21 @@ abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketJs
         WebSocketConnection $webSocketConnection,
         JsonRPCRequest $rpcRequest,
     ): void {
-        $this->observableTaskTable->observe(new ObservableTask(
-            iterableTask: new ObservableTaskTimeoutIterator(
-                iterableTask: function () use (
+        $this->observableTaskTable->observe(ObservableTaskFactory::withTimeout(
+            iterableTask: function () use (
+                $webSocketAuthResolution,
+                $webSocketConnection,
+                $rpcRequest,
+            ): Generator {
+                yield from $this->onObservableRequest(
                     $webSocketAuthResolution,
                     $webSocketConnection,
                     $rpcRequest,
-                ): Generator {
-                    yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Running, null);
-
-                    try {
-                        yield from $this->onObservableRequest(
-                            $webSocketAuthResolution,
-                            $webSocketConnection,
-                            $rpcRequest,
-                        );
-                    } finally {
-                        yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Finished, null);
-                    }
-                },
-                inactivityTimeout: 5.0,
-            ),
+                );
+            },
+            inactivityTimeout: 5.0,
             name: 'websocket_jsonrpc_response',
-            category: 'llama_cpp',
+            category: ObservableTaskCategory::LlamaCpp->value,
         ));
     }
 
@@ -160,6 +152,8 @@ abstract readonly class LlamaCppSubjectActionPromptResponder extends WebSocketJs
                 break;
             }
 
+            yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Running, null);
+
             $this->onResponseChunk(
                 webSocketAuthResolution: $webSocketAuthResolution,
                 webSocketConnection: $webSocketConnection,
-- 
GitLab