Skip to content
Snippets Groups Projects
Commit 130a1508 authored by Mateusz Charytoniuk's avatar Mateusz Charytoniuk
Browse files

chore: feature extraction from string

parent f56959d0
No related branches found
No related tags found
No related merge requests found
Showing
with 247 additions and 157 deletions
......@@ -8,14 +8,19 @@
displayDetailsOnTestsThatTriggerWarnings="true"
processIsolation="false"
>
<testsuites>
<testsuite name="Unit">
<directory suffix="Test.php">src</directory>
</testsuite>
</testsuites>
<groups>
<exclude>
<group>llamacpp</group>
</exclude>
</groups>
<source>
<include>
<directory suffix=".php">src</directory>
</include>
</source>
<testsuites>
<testsuite name="Unit">
<directory suffix="Test.php">src</directory>
</testsuite>
</testsuites>
</phpunit>
......@@ -13,14 +13,12 @@ readonly class DialogueNode implements DialogueNodeInterface
*/
private Set $responses;
public function __construct(
private DialogueMessageProducerInterface $message,
private DialogueResponseDiscriminatorInterface $responseDiscriminator,
) {
public function __construct(private DialogueMessageProducerInterface $message)
{
$this->responses = new Set();
}
public function addResponse(DialogueResponseInterface $response): void
public function addPotentialResponse(DialogueResponseInterface $response): void
{
$this->responses->add($response);
}
......@@ -30,8 +28,14 @@ readonly class DialogueNode implements DialogueNodeInterface
return $this->message;
}
public function respondTo(DialogueInputInterface $prompt): ?DialogueNodeInterface
public function respondTo(DialogueInputInterface $dialogueInput): ?DialogueNodeInterface
{
return $this->responseDiscriminator->discriminate($this->responses, $prompt);
foreach (new DialogueResponseSortedIterator($this->responses) as $response) {
if ($response->getCondition()->isMetBy($dialogueInput)) {
return $response->getFollowUp();
}
}
return null;
}
}
......@@ -6,7 +6,7 @@ namespace Distantmagic\Resonance;
interface DialogueNodeInterface
{
public function addResponse(DialogueResponseInterface $response): void;
public function addPotentialResponse(DialogueResponseInterface $response): void;
public function getMessageProducer(): DialogueMessageProducerInterface;
......
......@@ -6,7 +6,7 @@ namespace Distantmagic\Resonance;
use Distantmagic\Resonance\DialogueInput\UserInput;
use Distantmagic\Resonance\DialogueMessageProducer\ConstMessageProducer;
use Distantmagic\Resonance\DialogueResponseCondition\ExactInputCondition;
use Distantmagic\Resonance\DialogueResponse\LiteralInputResponse;
use Distantmagic\Resonance\DialogueResponseCondition\LlamaCppInputCondition;
use Mockery;
use PHPUnit\Framework\Attributes\CoversClass;
......@@ -20,38 +20,33 @@ final class DialogueNodeTest extends TestCase
{
public function test_dialogue_produces_no_response(): void
{
$responseDiscriminator = new DialogueResponseDiscriminator();
$rootNode = new DialogueNode(
message: new ConstMessageProducer('What is your current role?'),
responseDiscriminator: $responseDiscriminator,
);
$marketingNode = new DialogueNode(
message: new ConstMessageProducer('Hello, marketer!'),
responseDiscriminator: $responseDiscriminator,
);
$rootNode->addResponse(new DialogueResponse(
$rootNode->addPotentialResponse(new DialogueResponse(
when: new LlamaCppInputCondition(
Mockery::mock(LlamaCppClientInterface::class),
'marketing'
'User states that they are working in a marketing department',
),
followUp: $marketingNode,
));
$rootNode->addResponse(new DialogueResponse(
when: new ExactInputCondition('marketing'),
$rootNode->addPotentialResponse(new LiteralInputResponse(
when: 'marketing',
followUp: $marketingNode,
));
$invalidNode = new DialogueNode(
message: new ConstMessageProducer('nope :('),
responseDiscriminator: $responseDiscriminator,
);
$rootNode->addResponse(new DialogueResponse(
when: new ExactInputCondition('not_a_marketing'),
$rootNode->addPotentialResponse(new LiteralInputResponse(
when: 'not_a_marketing',
followUp: $invalidNode,
));
......
......@@ -4,20 +4,4 @@ declare(strict_types=1);
namespace Distantmagic\Resonance;
readonly class DialogueResponse implements DialogueResponseInterface
{
public function __construct(
private DialogueNodeInterface $followUp,
private DialogueResponseConditionInterface $when,
) {}
public function getCondition(): DialogueResponseConditionInterface
{
return $this->when;
}
public function getFollowUp(): DialogueNodeInterface
{
return $this->followUp;
}
}
abstract readonly class DialogueResponse implements DialogueResponseInterface {}
<?php
declare(strict_types=1);
namespace Distantmagic\Resonance\DialogueResponse;
use Distantmagic\Resonance\DialogueInputInterface;
use Distantmagic\Resonance\DialogueNodeInterface;
use Distantmagic\Resonance\DialogueResponse;
use Distantmagic\Resonance\DialogueResponseResolution;
use Distantmagic\Resonance\DialogueResponseResolutionStatus;
readonly class LiteralInputResponse extends DialogueResponse
{
public function __construct(
private string $when,
private DialogueNodeInterface $followUp,
) {}
public function getCost(): int
{
return 2;
}
public function resolveResponse(DialogueInputInterface $dialogueInput): DialogueResponseResolution
{
if ($dialogueInput->getContent() === $this->when) {
return new DialogueResponseResolution(
status: DialogueResponseResolutionStatus::CanRespond,
);
}
return new DialogueResponseResolution(
status: DialogueResponseResolutionStatus::CannotRespond,
);
}
}
<?php
declare(strict_types=1);
namespace Distantmagic\Resonance\DialogueResponse;
use Distantmagic\Resonance\DialogueInputInterface;
use Distantmagic\Resonance\DialogueResponse;
use Distantmagic\Resonance\DialogueResponseResolution;
readonly class LlamaCppExtractInputResponse extends DialogueResponse
{
public function getCost(): int
{
return 50;
}
public function resolveResponse(DialogueInputInterface $prompt): DialogueResponseResolution {}
}
<?php
declare(strict_types=1);
namespace Distantmagic\Resonance;
abstract readonly class DialogueResponseCondition implements DialogueResponseConditionInterface {}
<?php
declare(strict_types=1);
namespace Distantmagic\Resonance\DialogueResponseCondition;
use Distantmagic\Resonance\DialogueInputInterface;
use Distantmagic\Resonance\DialogueResponseCondition;
readonly class ExactInputCondition extends DialogueResponseCondition
{
public function __construct(
public string $content,
) {}
public function getCost(): int
{
return 2;
}
public function isMetBy(DialogueInputInterface $dialogueInput): bool
{
return $dialogueInput->getContent() === $this->content;
}
}
<?php
declare(strict_types=1);
namespace Distantmagic\Resonance\DialogueResponseCondition;
use Distantmagic\Resonance\DialogueInputInterface;
use Distantmagic\Resonance\DialogueResponseCondition;
use Distantmagic\Resonance\LlamaCppClientInterface;
readonly class LlamaCppInputCondition extends DialogueResponseCondition
{
public function __construct(
public LlamaCppClientInterface $llamaCppClient,
public string $content,
) {}
public function getCost(): int
{
return 50;
}
public function isMetBy(DialogueInputInterface $dialogueInput): bool {}
}
<?php
declare(strict_types=1);
namespace Distantmagic\Resonance;
interface DialogueResponseConditionInterface
{
public function getCost(): int;
public function isMetBy(DialogueInputInterface $dialogueInput): bool;
}
<?php
declare(strict_types=1);
namespace Distantmagic\Resonance;
use Distantmagic\Resonance\Attribute\Singleton;
#[Singleton]
readonly class DialogueResponseDiscriminator implements DialogueResponseDiscriminatorInterface
{
/**
* @param iterable<DialogueResponseInterface> $responses
*/
public function discriminate(
iterable $responses,
DialogueInputInterface $dialogueInput,
): ?DialogueNodeInterface {
foreach (new DialogueResponseSortedIterator($responses) as $response) {
if ($response->getCondition()->isMetBy($dialogueInput)) {
return $response->getFollowUp();
}
}
return null;
}
}
<?php
declare(strict_types=1);
namespace Distantmagic\Resonance;
interface DialogueResponseDiscriminatorInterface
{
/**
* @param iterable<DialogueResponseInterface> $responses
*/
public function discriminate(
iterable $responses,
DialogueInputInterface $dialogueInput,
): ?DialogueNodeInterface;
}
......@@ -6,7 +6,7 @@ namespace Distantmagic\Resonance;
interface DialogueResponseInterface
{
public function getCondition(): DialogueResponseConditionInterface;
public function getCost(): int;
public function getFollowUp(): DialogueNodeInterface;
public function resolveResponse(DialogueInputInterface $dialogueInput): DialogueResponseResolutionInterface;
}
......@@ -4,4 +4,8 @@ declare(strict_types=1);
namespace Distantmagic\Resonance;
enum DialogueResponseResolutionStatus {}
enum DialogueResponseResolutionStatus
{
case CannotRespond;
case CanRespond;
}
......@@ -34,7 +34,7 @@ readonly class DialogueResponseSortedIterator implements IteratorAggregate
foreach ($this->responses as $response) {
$responsesPriorityQueue->push(
$response,
$response->getCondition()->getCost(),
$response->getCost(),
);
}
......
<?php
declare(strict_types=1);
namespace Distantmagic\Resonance;
use PHPUnit\Framework\Attributes\CoversClass;
use PHPUnit\Framework\Attributes\Group;
use PHPUnit\Framework\TestCase;
/**
* @internal
*/
#[CoversClass(LlamaCppClient::class)]
#[Group('llamacpp')]
final class LlamaCppClientTest extends TestCase
{
use TestsDependencyInectionContainerTrait;
public function test_request_header_is_parsed(): void
{
$llamaCppClient = self::$container->make(LlamaCppClient::class);
self::assertSame(LlamaCppHealthStatus::Ok, $llamaCppClient->getHealth());
}
}
<?php
declare(strict_types=1);
namespace Distantmagic\Resonance;
use Distantmagic\Resonance\LlmPromptTemplate\MistralInstructChat;
readonly class LlamaCppExtractString
{
public function __construct(
private LlamaCppClientInterface $llamaCppClient,
) {}
public function extract(
string $input,
string $subject,
): ?string {
$completion = $this->llamaCppClient->generateCompletion(
new LlamaCppCompletionRequest(
promptTemplate: new MistralInstructChat(<<<PROMPT
User is about to provide the $subject.
If user provides the $subject, repeat only that $subject, without any additional comment.
If user did not provide $subject or it is not certain, write the empty string: ""
User input:
$input
PROMPT),
),
);
$ret = '';
foreach ($completion as $token) {
$ret .= $token;
}
$trimmed = trim($ret, ' "');
if (0 === strlen($trimmed)) {
return null;
}
return $trimmed;
}
}
<?php
declare(strict_types=1);
namespace Distantmagic\Resonance;
use Generator;
use PHPUnit\Framework\Attributes\CoversClass;
use PHPUnit\Framework\Attributes\DataProvider;
use PHPUnit\Framework\Attributes\Group;
use PHPUnit\Framework\TestCase;
use Swoole\Event;
/**
* @internal
*/
#[CoversClass(LlamaCppExtractString::class)]
#[Group('llamacpp')]
final class LlamaCppExtractStringTest extends TestCase
{
use TestsDependencyInectionContainerTrait;
public static function inputSubjectProvider(): Generator
{
yield 'application name is provided' => [
'application name',
'My application is called PHP Resonance',
'PHP Resonance',
];
yield 'only application name is provided' => [
'application name',
'PHP Resonance',
'PHP Resonance',
];
yield 'application name is not provided' => [
'application name',
'How are you?',
null,
];
yield 'not on topic' => [
'application name',
'Suggest me the best application name',
null,
];
yield 'not sure' => [
'application name',
'I am not really sure at the moment, was thinking about PHP Resonance, but I have to ask my friends first',
null,
];
yield 'feature' => [
'feature',
'I want to add a blog',
'blog',
];
}
protected function tearDown(): void
{
Event::wait();
}
#[DataProvider('inputSubjectProvider')]
public function test_application_name_is_provided(string $subject, string $input, ?string $expected): void
{
$llamaCppExtractString = self::$container->make(LlamaCppExtractString::class);
SwooleCoroutineHelper::mustRun(static function () use ($expected, $input, $llamaCppExtractString, $subject) {
$extracted = $llamaCppExtractString->extract(
subject: $subject,
input: $input,
);
self::assertSame($expected, $extracted);
});
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment