diff --git a/src/DialogueNode.php b/src/DialogueNode.php index cfeaa7d766d13792c9d3d61cf997385d25f5fefe..83ac2456eaa4c5eaea09edcee38c0f560fd573ee 100644 --- a/src/DialogueNode.php +++ b/src/DialogueNode.php @@ -47,8 +47,8 @@ readonly class DialogueNode implements DialogueNodeInterface { foreach ($other->getPotentialResponses() as $response) { $this->addPotentialResponse($response); - } } + } public function getMessageProducer(): DialogueMessageProducerInterface { diff --git a/src/DialogueResponseSortedIterator.php b/src/DialogueResponseSortedIterator.php index 4943da417465e411f4eead28097b121376ecee3c..f37614b47b1a8a3c562f2d0cac22caf93c55366c 100644 --- a/src/DialogueResponseSortedIterator.php +++ b/src/DialogueResponseSortedIterator.php @@ -4,8 +4,7 @@ declare(strict_types=1); namespace Distantmagic\Resonance; -use Ds\PriorityQueue; -use Ds\Stack; +use Ds\Set; use Generator; use IteratorAggregate; @@ -15,10 +14,10 @@ use IteratorAggregate; readonly class DialogueResponseSortedIterator implements IteratorAggregate { /** - * @param iterable<DialogueResponseInterface> $responses + * @param Set<DialogueResponseInterface> $responses */ public function __construct( - private iterable $responses, + private Set $responses, ) {} /** @@ -26,29 +25,17 @@ readonly class DialogueResponseSortedIterator implements IteratorAggregate */ public function getIterator(): Generator { - /** - * @var PriorityQueue<DialogueResponseInterface> $responsesPriorityQueue - */ - $responsesPriorityQueue = new PriorityQueue(); - - foreach ($this->responses as $response) { - $responsesPriorityQueue->push( - $response, - $response->getCost(), - ); - } - - /** - * @var Stack<DialogueResponseInterface> $sortedResponses - */ - $sortedResponses = new Stack(); + $responses = $this->responses->toArray(); - foreach ($responsesPriorityQueue as $response) { - $sortedResponses->push($response); - } + usort($responses, $this->compareResponses(...)); - foreach ($sortedResponses as $response) { + foreach ($responses as $response) { yield $response; } } + + private function compareResponses(DialogueResponseInterface $a, DialogueResponseInterface $b): int + { + return $a->getCost() <=> $b->getCost(); + } } diff --git a/src/DialogueResponseSortedIteratorTest.php b/src/DialogueResponseSortedIteratorTest.php index 991b6a88170a8ee6470d816d9445511b3414a2d1..887e715a3e44658106410720e6e9e8acc7b786c0 100644 --- a/src/DialogueResponseSortedIteratorTest.php +++ b/src/DialogueResponseSortedIteratorTest.php @@ -6,6 +6,8 @@ namespace Distantmagic\Resonance; use Distantmagic\Resonance\DialogueResponse\LiteralInputResponse; use Distantmagic\Resonance\DialogueResponse\LlamaCppExtractSubjectResponse; +use Distantmagic\Resonance\DialogueResponse\LlamaCppExtractWhenResponse; +use Ds\Set; use Mockery; use PHPUnit\Framework\Attributes\CoversClass; use PHPUnit\Framework\TestCase; @@ -31,21 +33,34 @@ final class DialogueResponseSortedIteratorTest extends TestCase }, ); - $response2 = new LiteralInputResponse( + $response2 = new LlamaCppExtractWhenResponse( + llamaCppExtractWhen: Mockery::mock(LlamaCppExtractWhenInterface::class), + condition: "user's occupation", + whenProvided: static function (): DialogueResponseResolution { + return new DialogueResponseResolution( + followUp: null, + status: DialogueResponseResolutionStatus::CannotRespond, + ); + }, + ); + + $response3 = new LiteralInputResponse( when: 'marketing', followUp: $marketingNode, ); - $responses = [ + $responses = new Set([ $response1, $response2, - ]; + $response3, + ]); $sortedResponses = iterator_to_array(new DialogueResponseSortedIterator($responses)); self::assertEquals([ - $response2, + $response3, $response1, + $response2, ], $sortedResponses); } } diff --git a/src/LlamaCppExtractWhen.php b/src/LlamaCppExtractWhen.php index cb7b097ccde3da249699440d4c6f2d82a158c184..4faf13ae29de4a9243bc1b7fc8f9363e7d181fed 100644 --- a/src/LlamaCppExtractWhen.php +++ b/src/LlamaCppExtractWhen.php @@ -56,8 +56,17 @@ readonly class LlamaCppExtractWhen implements LlamaCppExtractWhenInterface $ret .= $token; } + $yesNoMaybe = YesNoMaybe::tryFrom($ret); + + if (!($yesNoMaybe instanceof YesNoMaybe)) { + return new LlamaCppExtractYesNoMaybeResult( + result: null, + isFailed: true, + ); + } + return new LlamaCppExtractYesNoMaybeResult( - result: YesNoMaybe::from($ret), + result: $yesNoMaybe, isFailed: false, ); } diff --git a/src/ObservableTaskFactory.php b/src/ObservableTaskFactory.php index 7c88bbd459dc17c584f37e2c60ee519097b2b809..3bf9a5d155a26bed5906043b0b5eddb7360e3a5e 100644 --- a/src/ObservableTaskFactory.php +++ b/src/ObservableTaskFactory.php @@ -22,11 +22,8 @@ final readonly class ObservableTaskFactory iterableTask: static function () use ($iterableTask): Generator { yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Running, null); - try { - yield from $iterableTask(); - } finally { - yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Finished, null); - } + yield from $iterableTask(); + yield new ObservableTaskStatusUpdate(ObservableTaskStatus::Finished, null); }, inactivityTimeout: $inactivityTimeout, ),