diff --git a/docs/encoders/bedrock.ipynb b/docs/encoders/bedrock.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..54cc0a7dfbb3709ce0e6355e9d0a7508804298ec
--- /dev/null
+++ b/docs/encoders/bedrock.ipynb
@@ -0,0 +1,1323 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[](https://colab.research.google.com/github/aurelio-labs/semantic-router/blob/main/docs/encoders/bedrock.ipynb) [](https://nbviewer.org/github/aurelio-labs/semantic-router/blob/main/docs/encoders/bedrock.ipynb)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Using Bedrock Embedding Models"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The embedding models available through AWS Bedrock (`amazon.titan-embed-text-v1`, `amazon.titan-embed-text-v2`, and `cohere.embed-english-v3`) can all be used with our `BedrockEncoder`."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Getting Started"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We start by installing semantic-router. Support for the new `Bedrock` embedding models was added in `semantic-router==0.0.40`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install -qU \"semantic-router[bedrock]==0.0.40\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Next we define our routes, giving each a name and a list of example utterances that should trigger it."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from semantic_router import Route\n",
+    "\n",
+    "politics = Route(\n",
+    "    name=\"politics\",\n",
+    "    utterances=[\n",
+    "        \"isn't politics the best thing ever\",\n",
+    "        \"why don't you tell me about your political opinions\",\n",
+    "        \"don't you just love the president\",\n",
+    "        \"don't you just hate the president\",\n",
+    "        \"they're going to destroy this country!\",\n",
+    "        \"they will save the country!\",\n",
+    "    ],\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's define another for good measure:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "chitchat = Route(\n",
+    "    name=\"chitchat\",\n",
+    "    utterances=[\n",
+    "        \"how's the weather today?\",\n",
+    "        \"how are things going?\",\n",
+    "        \"lovely weather today\",\n",
+    "        \"the weather is horrendous\",\n",
+    "        \"let's go to the chippy\",\n",
+    "    ],\n",
+    ")\n",
+    "\n",
+    "routes = [politics, chitchat]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now we initialize our embedding model. We will use Bedrock's `amazon.titan-embed-image-v1` model, which produces 1024-dimensional vectors, and pass our AWS credentials and region to the `BedrockEncoder`."
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from getpass import getpass\n", + "from semantic_router.encoders import BedrockEncoder\n", + "\n", + "aws_access_key_id = os.getenv(\"AWS_ACCESS_KEY_ID\") or getpass(\n", + " \"Enter AWS Access Key ID: \"\n", + ")\n", + "aws_secret_access_key = os.getenv(\"AWS_SECRET_ACCESS_KEY\") or getpass(\n", + " \"Enter AWS Secret Access Key: \"\n", + ")\n", + "aws_session_token = os.getenv(\"AWS_SESSION_TOKEN\") or getpass(\n", + " \"Enter AWS Session Token: \"\n", + ")\n", + "aws_region = os.getenv(\"AWS_REGION\") or getpass(\"Enter AWS Region: \")\n", + "\n", + "encoder = BedrockEncoder(\n", + " name=\"amazon.titan-embed-image-v1\",\n", + " score_threshold=0.5,\n", + " access_key_id=aws_access_key_id,\n", + " secret_access_key=aws_secret_access_key,\n", + " session_token=aws_session_token,\n", + " region=aws_region,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[[0.012878418,\n", + " 0.028442383,\n", + " -0.022094727,\n", + " -0.020751953,\n", + " -0.008300781,\n", + " 0.033691406,\n", + " 0.09326172,\n", + " 0.0045166016,\n", + " 0.033935547,\n", + " 0.015319824,\n", + " 0.012939453,\n", + " 0.015380859,\n", + " 0.012756348,\n", + " -0.064453125,\n", + " 0.018432617,\n", + " 0.03173828,\n", + " -0.018188477,\n", + " -0.007171631,\n", + " 0.03955078,\n", + " 0.0033874512,\n", + " 0.007019043,\n", + " 0.010131836,\n", + " -0.025878906,\n", + " 0.056152344,\n", + " 0.01373291,\n", + " -0.020263672,\n", + " 0.055419922,\n", + " -0.06225586,\n", + " 0.040039062,\n", + " -0.015075684,\n", + " 0.012268066,\n", + " -0.056640625,\n", + " 0.04736328,\n", + " -0.002609253,\n", + " -0.0064086914,\n", + " 0.011291504,\n", + " -0.019165039,\n", + " -0.005493164,\n", + " 0.003189087,\n", + " 0.008666992,\n", + " 0.03564453,\n", + " -0.0027923584,\n", + " -0.016601562,\n", + " 0.014404297,\n", + " -0.01171875,\n", + " 0.013183594,\n", + " -0.018920898,\n", + " -0.030639648,\n", + " 0.010864258,\n", + " 0.052734375,\n", + " -0.006164551,\n", + " 0.0035705566,\n", + " 0.0060424805,\n", + " -0.021606445,\n", + " -0.040527344,\n", + " 0.020385742,\n", + " 0.004638672,\n", + " -0.010314941,\n", + " -0.010681152,\n", + " -0.010803223,\n", + " -0.038330078,\n", + " -0.029174805,\n", + " 0.036865234,\n", + " -0.03112793,\n", + " -0.034179688,\n", + " 0.017944336,\n", + " -0.03515625,\n", + " 0.068847656,\n", + " -0.032470703,\n", + " -0.03540039,\n", + " 0.017944336,\n", + " -0.024047852,\n", + " -0.05834961,\n", + " -0.049804688,\n", + " -0.009277344,\n", + " 0.021484375,\n", + " -0.036376953,\n", + " 0.03540039,\n", + " -0.012939453,\n", + " -0.03491211,\n", + " -0.028808594,\n", + " 0.017333984,\n", + " 0.021484375,\n", + " 0.0052490234,\n", + " -0.026611328,\n", + " -0.0026245117,\n", + " 0.05078125,\n", + " -0.022949219,\n", + " -0.057128906,\n", + " -0.019042969,\n", + " 0.01574707,\n", + " -0.0025482178,\n", + " 0.02355957,\n", + " -0.0011367798,\n", + " 0.0039367676,\n", + " -0.015197754,\n", + " -0.02758789,\n", + " -0.025268555,\n", + " -0.048339844,\n", + " 0.04296875,\n", + " -0.01373291,\n", + " -0.0052490234,\n", + " -0.016357422,\n", + " -0.029663086,\n", + " 0.024536133,\n", + " -0.03881836,\n", + " -0.035888672,\n", + " 0.013793945,\n", + " -0.016357422,\n", + " -0.052734375,\n", + " 0.0154418945,\n", + " -0.004058838,\n", + " -0.018432617,\n", + " -0.01574707,\n", + " 0.06225586,\n", 
+ " 0.044433594,\n", + " -0.011474609,\n", + " 0.019897461,\n", + " -0.018432617,\n", + " -0.03515625,\n", + " 0.057861328,\n", + " -0.016967773,\n", + " 0.008666992,\n", + " 0.01574707,\n", + " 0.024780273,\n", + " 0.01953125,\n", + " -0.005554199,\n", + " 0.042236328,\n", + " -0.026123047,\n", + " -0.111328125,\n", + " 0.018798828,\n", + " 0.018066406,\n", + " -0.032958984,\n", + " -0.0025024414,\n", + " -0.01159668,\n", + " -0.028930664,\n", + " -0.055908203,\n", + " 0.037353516,\n", + " 0.018432617,\n", + " 0.015258789,\n", + " -0.021850586,\n", + " -0.0026245117,\n", + " 0.016723633,\n", + " 0.0095825195,\n", + " 0.05029297,\n", + " 0.011779785,\n", + " 0.04711914,\n", + " -0.064941406,\n", + " 0.0059509277,\n", + " -0.025390625,\n", + " 0.03857422,\n", + " 0.046875,\n", + " 0.015258789,\n", + " 0.03930664,\n", + " 0.02355957,\n", + " 0.03125,\n", + " -0.032958984,\n", + " 0.056640625,\n", + " -0.056396484,\n", + " 0.0146484375,\n", + " 0.0025634766,\n", + " 0.006591797,\n", + " 0.0015563965,\n", + " -0.020385742,\n", + " 0.016723633,\n", + " -0.008972168,\n", + " -0.024169922,\n", + " 0.03125,\n", + " -0.028808594,\n", + " 0.040283203,\n", + " 0.0055236816,\n", + " -0.0025787354,\n", + " 0.067871094,\n", + " -0.004119873,\n", + " -0.03515625,\n", + " 0.030517578,\n", + " 0.0077819824,\n", + " -0.026733398,\n", + " -0.01953125,\n", + " -0.014709473,\n", + " -0.045898438,\n", + " -0.012268066,\n", + " 0.022216797,\n", + " 0.008972168,\n", + " 0.017211914,\n", + " -0.0234375,\n", + " -0.017211914,\n", + " 0.030151367,\n", + " -0.0034942627,\n", + " 0.029174805,\n", + " 0.05029297,\n", + " -0.053222656,\n", + " -0.037841797,\n", + " 0.008117676,\n", + " 0.014038086,\n", + " 0.015563965,\n", + " -0.060791016,\n", + " 0.014221191,\n", + " -0.028808594,\n", + " -0.03955078,\n", + " -0.111328125,\n", + " 0.041992188,\n", + " -0.043945312,\n", + " 0.030273438,\n", + " -0.045898438,\n", + " -0.014770508,\n", + " -0.030395508,\n", + " -0.041748047,\n", + " 0.0011291504,\n", + " -0.034423828,\n", + " -0.04272461,\n", + " 0.008300781,\n", + " 0.014831543,\n", + " -0.018798828,\n", + " -0.017700195,\n", + " -0.014099121,\n", + " -0.011169434,\n", + " -0.029418945,\n", + " 0.027832031,\n", + " 0.010986328,\n", + " 0.030151367,\n", + " -0.021728516,\n", + " 0.004547119,\n", + " -0.034423828,\n", + " -0.01977539,\n", + " 0.047851562,\n", + " 0.021362305,\n", + " 0.044433594,\n", + " 0.06933594,\n", + " -0.0046691895,\n", + " -0.049560547,\n", + " -0.091308594,\n", + " 0.084472656,\n", + " 0.015991211,\n", + " 0.030883789,\n", + " 0.03112793,\n", + " 0.041503906,\n", + " 0.018920898,\n", + " 0.04663086,\n", + " 0.0064697266,\n", + " 0.0058288574,\n", + " -0.007873535,\n", + " 0.016113281,\n", + " -0.0058898926,\n", + " 0.040039062,\n", + " 0.041748047,\n", + " 0.04736328,\n", + " -0.06591797,\n", + " 0.07861328,\n", + " -0.021850586,\n", + " -0.013427734,\n", + " 0.033447266,\n", + " 0.013183594,\n", + " 0.025878906,\n", + " 0.036376953,\n", + " -0.017211914,\n", + " 0.0067443848,\n", + " -0.011291504,\n", + " -0.009155273,\n", + " 0.005554199,\n", + " -0.00039863586,\n", + " 0.08251953,\n", + " -0.03491211,\n", + " -0.025878906,\n", + " -0.037109375,\n", + " 0.052734375,\n", + " -0.008911133,\n", + " -0.0390625,\n", + " 0.021362305,\n", + " -0.022949219,\n", + " 0.029907227,\n", + " 0.041259766,\n", + " 0.017211914,\n", + " -0.016845703,\n", + " 0.043701172,\n", + " -0.025512695,\n", + " -0.020019531,\n", + " 0.01953125,\n", + " -0.008422852,\n", + " 0.016357422,\n", + " -0.044921875,\n", + " 
-0.030761719,\n", + " 0.029541016,\n", + " -0.008422852,\n", + " 0.01977539,\n", + " 0.006652832,\n", + " 0.0031433105,\n", + " 0.044189453,\n", + " 0.00793457,\n", + " 0.02722168,\n", + " -0.043701172,\n", + " -0.01550293,\n", + " -0.0068359375,\n", + " -0.033935547,\n", + " 0.025024414,\n", + " -0.038085938,\n", + " -0.037353516,\n", + " -0.032714844,\n", + " 0.037841797,\n", + " -0.057373047,\n", + " -0.017211914,\n", + " -0.012878418,\n", + " -0.0069274902,\n", + " -0.020874023,\n", + " -0.037841797,\n", + " -0.036132812,\n", + " 0.033691406,\n", + " -0.030273438,\n", + " 0.033691406,\n", + " 0.049316406,\n", + " -0.02746582,\n", + " -0.030761719,\n", + " 0.03564453,\n", + " -0.02746582,\n", + " -0.03112793,\n", + " -0.00340271,\n", + " 0.016845703,\n", + " 0.03515625,\n", + " 0.009033203,\n", + " 0.026489258,\n", + " 0.04663086,\n", + " 0.0067443848,\n", + " 0.017944336,\n", + " 0.008850098,\n", + " -0.008544922,\n", + " 0.0022277832,\n", + " -0.030029297,\n", + " 0.010192871,\n", + " -0.021240234,\n", + " -0.020385742,\n", + " 0.008666992,\n", + " 0.005706787,\n", + " -0.02758789,\n", + " 0.05419922,\n", + " 0.036132812,\n", + " 0.032714844,\n", + " -0.014526367,\n", + " -0.02758789,\n", + " 0.03466797,\n", + " 0.05883789,\n", + " 0.026977539,\n", + " 0.05102539,\n", + " 0.052246094,\n", + " 0.0056152344,\n", + " 0.009094238,\n", + " -0.000579834,\n", + " -0.03100586,\n", + " -0.017822266,\n", + " 0.040283203,\n", + " -0.011474609,\n", + " -0.063964844,\n", + " 0.026977539,\n", + " 0.006958008,\n", + " -0.009765625,\n", + " 0.010253906,\n", + " -0.007385254,\n", + " 0.0051574707,\n", + " -0.0030670166,\n", + " -0.011047363,\n", + " -0.017333984,\n", + " -0.015991211,\n", + " 0.026245117,\n", + " -0.030639648,\n", + " -0.022460938,\n", + " 0.0059814453,\n", + " -0.021240234,\n", + " -0.011962891,\n", + " 0.010925293,\n", + " -0.021484375,\n", + " 0.037353516,\n", + " -0.050048828,\n", + " -0.08544922,\n", + " -0.024658203,\n", + " -0.026611328,\n", + " 0.020385742,\n", + " -0.033935547,\n", + " 0.025390625,\n", + " 0.0030670166,\n", + " -0.008117676,\n", + " -0.022338867,\n", + " 0.024291992,\n", + " 0.052246094,\n", + " -0.059570312,\n", + " -0.0138549805,\n", + " -0.01940918,\n", + " 0.05517578,\n", + " -0.0006866455,\n", + " 0.0049743652,\n", + " 0.07519531,\n", + " 0.057617188,\n", + " 0.004425049,\n", + " -0.043945312,\n", + " -0.029663086,\n", + " 0.017578125,\n", + " 0.030029297,\n", + " -0.007446289,\n", + " -0.030761719,\n", + " -0.021484375,\n", + " -0.009765625,\n", + " 0.013671875,\n", + " 0.012207031,\n", + " -0.012878418,\n", + " -0.043945312,\n", + " -0.020141602,\n", + " 0.013183594,\n", + " 0.0074157715,\n", + " -0.028686523,\n", + " 0.025268555,\n", + " 0.026367188,\n", + " 0.030395508,\n", + " 0.041748047,\n", + " 0.017944336,\n", + " 0.036376953,\n", + " 0.010437012,\n", + " -0.0625,\n", + " 0.04296875,\n", + " 0.0057373047,\n", + " 0.059570312,\n", + " 0.072753906,\n", + " 0.03881836,\n", + " -0.0021972656,\n", + " -0.027832031,\n", + " 0.0074157715,\n", + " 0.00045394897,\n", + " -0.003753662,\n", + " -0.010070801,\n", + " 0.008972168,\n", + " -0.0051574707,\n", + " 0.007537842,\n", + " -0.0079956055,\n", + " -0.03173828,\n", + " -0.012451172,\n", + " -0.015563965,\n", + " 0.027709961,\n", + " -0.039794922,\n", + " -0.016113281,\n", + " -0.056396484,\n", + " 0.016601562,\n", + " -0.030395508,\n", + " -0.033447266,\n", + " 0.052001953,\n", + " 0.001159668,\n", + " -0.02368164,\n", + " -0.046142578,\n", + " 0.01977539,\n", + " -0.02746582,\n", + " -0.038330078,\n", 
+ " -0.052734375,\n", + " -0.030151367,\n", + " -0.030639648,\n", + " -0.0043945312,\n", + " 0.025390625,\n", + " 0.0048828125,\n", + " 0.029663086,\n", + " 0.01928711,\n", + " -0.025634766,\n", + " -0.022583008,\n", + " -0.019165039,\n", + " 0.026733398,\n", + " -0.035888672,\n", + " -0.015014648,\n", + " -0.0069274902,\n", + " -0.005126953,\n", + " 0.032958984,\n", + " 0.033203125,\n", + " -0.019897461,\n", + " -0.038330078,\n", + " 0.020874023,\n", + " 0.027954102,\n", + " -0.06689453,\n", + " -0.0069274902,\n", + " -0.0036315918,\n", + " -0.025634766,\n", + " -0.020507812,\n", + " 0.017333984,\n", + " -0.019165039,\n", + " 0.04663086,\n", + " -0.052734375,\n", + " -0.017333984,\n", + " 0.009338379,\n", + " -0.012756348,\n", + " -0.007507324,\n", + " 0.045166016,\n", + " 0.02722168,\n", + " -0.023071289,\n", + " -0.019042969,\n", + " 0.0045166016,\n", + " 0.017822266,\n", + " -0.024291992,\n", + " 0.030883789,\n", + " -0.008361816,\n", + " -0.050048828,\n", + " -0.026000977,\n", + " 0.021850586,\n", + " 0.011413574,\n", + " 0.0134887695,\n", + " 0.013000488,\n", + " -0.0068359375,\n", + " -0.040039062,\n", + " 0.007446289,\n", + " 0.020751953,\n", + " 0.037841797,\n", + " -0.03173828,\n", + " -0.044921875,\n", + " -0.012451172,\n", + " 0.00032806396,\n", + " -0.026123047,\n", + " -0.059570312,\n", + " -0.028564453,\n", + " 0.04272461,\n", + " 0.0064086914,\n", + " 0.030639648,\n", + " 0.018188477,\n", + " 0.016113281,\n", + " 0.043945312,\n", + " 0.015991211,\n", + " 0.020019531,\n", + " 0.055419922,\n", + " -0.016357422,\n", + " -0.002166748,\n", + " -0.025756836,\n", + " 0.015625,\n", + " -0.020263672,\n", + " 0.012573242,\n", + " 0.029296875,\n", + " -0.06689453,\n", + " 0.0062561035,\n", + " 0.03857422,\n", + " -0.010803223,\n", + " -0.026245117,\n", + " -0.016235352,\n", + " -0.04248047,\n", + " 0.033691406,\n", + " 0.02746582,\n", + " 0.024902344,\n", + " -0.025878906,\n", + " 0.046142578,\n", + " -0.029541016,\n", + " -0.015075684,\n", + " 0.015991211,\n", + " -0.030883789,\n", + " 0.017700195,\n", + " 0.03173828,\n", + " -0.005126953,\n", + " -0.0034484863,\n", + " -0.041992188,\n", + " -0.01159668,\n", + " -0.007293701,\n", + " 0.04321289,\n", + " 0.009399414,\n", + " -0.017578125,\n", + " -0.029418945,\n", + " 0.06542969,\n", + " -0.03125,\n", + " 0.020019531,\n", + " -0.05029297,\n", + " 0.033447266,\n", + " -0.0154418945,\n", + " 0.041748047,\n", + " 0.04345703,\n", + " 0.03515625,\n", + " -0.003479004,\n", + " 0.021484375,\n", + " -0.025146484,\n", + " 4.196167e-05,\n", + " 0.007659912,\n", + " -0.03540039,\n", + " -0.012512207,\n", + " 0.0087890625,\n", + " 0.041259766,\n", + " 0.015319824,\n", + " -0.018066406,\n", + " -0.0018920898,\n", + " 0.033447266,\n", + " -0.01184082,\n", + " -0.04345703,\n", + " -0.024780273,\n", + " 0.064453125,\n", + " -0.012207031,\n", + " -0.036132812,\n", + " 0.10839844,\n", + " -0.016357422,\n", + " -0.0047302246,\n", + " -0.013793945,\n", + " 0.018066406,\n", + " -0.017700195,\n", + " 0.01953125,\n", + " -0.0027313232,\n", + " -0.04272461,\n", + " 0.01940918,\n", + " -0.01586914,\n", + " 0.024414062,\n", + " -0.044433594,\n", + " -0.026123047,\n", + " 0.022094727,\n", + " -0.046142578,\n", + " 0.030761719,\n", + " 0.017578125,\n", + " -0.0028076172,\n", + " 0.059326172,\n", + " 0.025512695,\n", + " 0.025146484,\n", + " -0.03125,\n", + " 0.002319336,\n", + " -0.022827148,\n", + " 0.053710938,\n", + " -0.010559082,\n", + " 0.04345703,\n", + " 0.005645752,\n", + " -0.021972656,\n", + " -0.018920898,\n", + " -0.040283203,\n", + " 
0.017456055,\n", + " 0.056884766,\n", + " 0.01928711,\n", + " -0.022827148,\n", + " 0.012145996,\n", + " -0.047851562,\n", + " 0.021118164,\n", + " -0.028930664,\n", + " -0.029907227,\n", + " -0.030883789,\n", + " -0.022827148,\n", + " -0.013977051,\n", + " 0.043701172,\n", + " 0.007080078,\n", + " 0.04711914,\n", + " -0.010253906,\n", + " -0.041015625,\n", + " 0.009216309,\n", + " -0.010986328,\n", + " 0.04248047,\n", + " -0.02758789,\n", + " 0.025268555,\n", + " 0.03466797,\n", + " 0.045166016,\n", + " -0.00023365021,\n", + " -0.021240234,\n", + " -0.016845703,\n", + " 0.02355957,\n", + " 0.024536133,\n", + " 0.036376953,\n", + " -0.015014648,\n", + " 9.393692e-05,\n", + " -0.0115356445,\n", + " 0.033447266,\n", + " -0.012145996,\n", + " 0.007080078,\n", + " 0.017700195,\n", + " -0.014282227,\n", + " 0.027709961,\n", + " 0.037353516,\n", + " -0.041503906,\n", + " 0.03149414,\n", + " 0.041015625,\n", + " 0.008483887,\n", + " 0.00579834,\n", + " 0.034179688,\n", + " 0.025634766,\n", + " -0.038085938,\n", + " -0.06591797,\n", + " -0.036376953,\n", + " 0.01171875,\n", + " -0.026000977,\n", + " 0.057861328,\n", + " -0.008300781,\n", + " 0.014282227,\n", + " -0.029052734,\n", + " 0.044433594,\n", + " -0.026123047,\n", + " 0.016601562,\n", + " -0.016357422,\n", + " 0.0024261475,\n", + " -0.025268555,\n", + " 0.05493164,\n", + " -0.025756836,\n", + " 0.02746582,\n", + " 0.037353516,\n", + " 0.02734375,\n", + " -0.04736328,\n", + " -0.012756348,\n", + " 0.016601562,\n", + " 0.009765625,\n", + " 0.013366699,\n", + " 0.013305664,\n", + " 0.036621094,\n", + " -0.034423828,\n", + " 0.046875,\n", + " -0.0028533936,\n", + " 0.018310547,\n", + " 0.05517578,\n", + " -0.06591797,\n", + " 0.042236328,\n", + " -0.013305664,\n", + " -0.007446289,\n", + " 0.014343262,\n", + " -0.04296875,\n", + " -0.038330078,\n", + " -0.016235352,\n", + " -0.043701172,\n", + " 0.004180908,\n", + " -0.045410156,\n", + " -0.009643555,\n", + " -0.012939453,\n", + " -0.0020141602,\n", + " -0.006713867,\n", + " -0.03881836,\n", + " -0.010559082,\n", + " 0.036376953,\n", + " 0.024169922,\n", + " -0.01977539,\n", + " 0.025756836,\n", + " -0.010253906,\n", + " 0.05493164,\n", + " 0.01965332,\n", + " 0.012451172,\n", + " 0.053466797,\n", + " -0.0062561035,\n", + " 0.028076172,\n", + " 0.024902344,\n", + " 0.068847656,\n", + " 0.019897461,\n", + " 0.01361084,\n", + " 0.015991211,\n", + " 0.017089844,\n", + " -0.053710938,\n", + " -0.056152344,\n", + " 0.04296875,\n", + " -0.0021972656,\n", + " -0.05517578,\n", + " 0.022460938,\n", + " -0.041259766,\n", + " -0.0234375,\n", + " -0.048583984,\n", + " -0.029296875,\n", + " 0.034423828,\n", + " 0.008056641,\n", + " 0.011352539,\n", + " 0.0390625,\n", + " 0.013366699,\n", + " -0.023803711,\n", + " -0.03466797,\n", + " 0.043701172,\n", + " 0.02746582,\n", + " 0.051757812,\n", + " -0.07128906,\n", + " 0.0059509277,\n", + " 0.022827148,\n", + " 0.013977051,\n", + " -0.046142578,\n", + " -0.016235352,\n", + " 0.017089844,\n", + " -0.001045227,\n", + " -0.014953613,\n", + " 0.012084961,\n", + " -0.0035705566,\n", + " 0.016845703,\n", + " 0.0234375,\n", + " 0.026611328,\n", + " -0.033203125,\n", + " 0.076660156,\n", + " 0.007873535,\n", + " 0.03540039,\n", + " 0.0061950684,\n", + " -0.028564453,\n", + " -0.03491211,\n", + " 0.01586914,\n", + " -0.015991211,\n", + " -0.024780273,\n", + " -0.028686523,\n", + " 0.028076172,\n", + " -0.005645752,\n", + " -0.043945312,\n", + " 0.021118164,\n", + " -0.0027008057,\n", + " 0.01550293,\n", + " 0.031982422,\n", + " -0.024780273,\n", + " 0.025878906,\n", + 
" 0.05859375,\n", + " -0.0050354004,\n", + " 0.033691406,\n", + " 0.044677734,\n", + " 0.018432617,\n", + " -0.007171631,\n", + " 0.003829956,\n", + " -0.047851562,\n", + " 0.026855469,\n", + " -0.005065918,\n", + " -0.02722168,\n", + " -0.03173828,\n", + " -0.0703125,\n", + " -0.016967773,\n", + " -0.008605957,\n", + " 0.037353516,\n", + " 0.03149414,\n", + " -0.06347656,\n", + " 0.031982422,\n", + " -0.033691406,\n", + " -0.03540039,\n", + " 0.021728516,\n", + " 0.07080078,\n", + " -0.03491211,\n", + " -0.014221191,\n", + " 0.046142578,\n", + " -0.010803223,\n", + " 0.009094238,\n", + " -0.0048217773,\n", + " -5.7935715e-05,\n", + " 0.055664062,\n", + " 0.025512695,\n", + " -0.024291992,\n", + " 0.04663086,\n", + " -0.008300781,\n", + " 0.056640625,\n", + " -0.006713867,\n", + " -0.018188477,\n", + " 0.012268066,\n", + " -0.045898438,\n", + " -0.051513672,\n", + " 0.016357422,\n", + " -0.049316406,\n", + " -0.0020599365,\n", + " -0.04345703,\n", + " -0.08935547,\n", + " -0.056640625,\n", + " -0.048828125,\n", + " -0.020996094,\n", + " 0.036376953,\n", + " -0.052001953,\n", + " 0.020629883,\n", + " 0.048339844,\n", + " 0.029296875,\n", + " 0.021728516,\n", + " 0.028930664,\n", + " -0.024169922,\n", + " -0.030273438,\n", + " -0.036621094,\n", + " -0.028808594,\n", + " -0.0546875,\n", + " 6.3478947e-06,\n", + " 0.03857422,\n", + " 0.01965332,\n", + " -0.016235352,\n", + " -0.017089844,\n", + " -0.012451172,\n", + " 0.010498047,\n", + " 0.025024414,\n", + " 0.016601562,\n", + " 0.032958984,\n", + " 0.0047912598,\n", + " -0.011047363,\n", + " 0.011352539,\n", + " -0.0044555664,\n", + " 0.004211426,\n", + " -0.004119873,\n", + " 0.0045776367,\n", + " 0.03149414,\n", + " -0.025146484,\n", + " -0.010070801,\n", + " -0.02331543,\n", + " 0.032714844,\n", + " 0.018798828,\n", + " -0.020751953,\n", + " 0.06201172,\n", + " 0.0043029785,\n", + " -0.039794922,\n", + " 0.010131836,\n", + " 0.048828125,\n", + " -0.036621094,\n", + " -0.007873535,\n", + " -0.029296875,\n", + " -0.046142578,\n", + " -0.016845703,\n", + " 0.056640625,\n", + " -0.048095703,\n", + " 0.0051574707,\n", + " 0.0008087158,\n", + " 0.00018310547,\n", + " -0.019165039,\n", + " 0.017700195,\n", + " -0.032226562,\n", + " -0.0047912598,\n", + " -0.053710938,\n", + " -0.06542969,\n", + " 0.013427734,\n", + " 0.004119873,\n", + " 0.021362305,\n", + " -0.0038452148,\n", + " -0.008056641,\n", + " -0.021606445,\n", + " 0.00793457,\n", + " -0.018798828,\n", + " -0.048828125,\n", + " 0.006958008,\n", + " -0.0390625,\n", + " -0.044921875,\n", + " -0.029052734,\n", + " -0.0039367676,\n", + " -0.009460449,\n", + " 0.03149414,\n", + " -0.024658203,\n", + " -0.007171631,\n", + " -0.020751953,\n", + " 0.010620117,\n", + " 0.027709961,\n", + " -0.012878418,\n", + " -0.006134033,\n", + " -0.036376953,\n", + " 0.0234375,\n", + " -0.008056641,\n", + " -0.029296875,\n", + " 0.0048217773,\n", + " 0.053222656,\n", + " -0.03857422,\n", + " -0.03930664,\n", + " -0.041015625,\n", + " 0.012084961,\n", + " -0.025146484,\n", + " 0.03491211,\n", + " -0.041748047,\n", + " 0.04248047,\n", + " -0.0003452301,\n", + " -0.018920898,\n", + " -0.046142578,\n", + " -0.014160156,\n", + " 0.046875,\n", + " -0.022216797,\n", + " 0.052246094,\n", + " -0.026611328,\n", + " 0.029541016,\n", + " -0.016357422,\n", + " 0.04272461,\n", + " -0.018920898,\n", + " 0.0078125,\n", + " -0.018676758,\n", + " 0.014770508,\n", + " -0.016357422,\n", + " -0.040527344,\n", + " -0.004486084,\n", + " -0.018066406,\n", + " -0.03100586,\n", + " 0.033691406,\n", + " -0.016113281,\n", + " 
-0.051757812,\n",
+       " 0.028320312,\n",
+       " -0.0234375,\n",
+       " 0.005126953,\n",
+       " -0.01171875,\n",
+       " 0.022216797,\n",
+       " -0.03466797,\n",
+       " 0.044433594,\n",
+       " -0.012268066,\n",
+       " 0.020263672,\n",
+       " 0.016479492,\n",
+       " 0.050048828,\n",
+       " -0.059570312,\n",
+       " 0.016967773,\n",
+       " -0.010925293,\n",
+       " 0.0013809204,\n",
+       " 0.026977539,\n",
+       " -0.022460938,\n",
+       " 0.034179688,\n",
+       " 0.01953125,\n",
+       " 0.005706787,\n",
+       " 0.036376953,\n",
+       " -0.018188477,\n",
+       " 0.041503906,\n",
+       " 0.08251953,\n",
+       " 0.009521484,\n",
+       " 0.005493164,\n",
+       " 0.0021820068,\n",
+       " -0.014465332,\n",
+       " 0.01965332,\n",
+       " 0.0008735657,\n",
+       " 0.029418945,\n",
+       " -0.057617188,\n",
+       " 0.021972656,\n",
+       " 0.008483887,\n",
+       " 0.064941406,\n",
+       " 0.0013198853,\n",
+       " -0.032714844,\n",
+       " -0.0087890625,\n",
+       " -0.014160156,\n",
+       " 0.080566406,\n",
+       " -0.012390137,\n",
+       " 0.02746582,\n",
+       " 0.0044555664,\n",
+       " -0.029541016,\n",
+       " 0.011657715,\n",
+       " -0.010803223,\n",
+       " -0.020874023,\n",
+       " 0.0030670166,\n",
+       " 0.013549805,\n",
+       " 0.0025787354,\n",
+       " -0.022827148,\n",
+       " -0.011291504,\n",
+       " 0.018188477,\n",
+       " 0.036132812,\n",
+       " 0.008178711,\n",
+       " ...]]"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "encoder([\"hey\"])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now we define the `RouteLayer`. When called, the route layer will consume text (a query) and output the category (`Route`) it belongs to — to initialize a `RouteLayer` we need our `encoder` model and a list of `routes`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[32m2024-05-13 22:26:54 INFO semantic_router.utils.logger local\u001b[0m\n"
+     ]
+    }
+   ],
+   "source": [
+    "from semantic_router.layer import RouteLayer\n",
+    "\n",
+    "rl = RouteLayer(encoder=encoder, routes=routes)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can check the dimensionality of our vectors by looking at the `index` attribute of the `RouteLayer`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(11, 1024)"
+      ]
+     },
+     "execution_count": 18,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "rl.index.index.shape"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We have 1024-dimensional vectors. Now let's test them:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "RouteChoice(name='politics', function_call=None, similarity_score=None)"
+      ]
+     },
+     "execution_count": 19,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "rl(\"don't you love politics?\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "RouteChoice(name='chitchat', function_call=None, similarity_score=None)"
+      ]
+     },
+     "execution_count": 20,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "rl(\"how's the weather today?\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Both are classified accurately. What if we send a query that is unrelated to our existing `Route` objects?"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "RouteChoice(name=None, function_call=None, similarity_score=None)"
+      ]
+     },
+     "execution_count": 22,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "rl(\"How does llama model work?\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In this case, we return `None` because no matches were identified. We always recommend optimizing your `RouteLayer` thresholds for the best performance; you can see how in [this notebook](https://github.com/aurelio-labs/semantic-router/blob/main/docs/06-threshold-optimization.ipynb)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "---"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "decision-layer",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.1"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/poetry.lock b/poetry.lock
index 53b7b5cd7df3ef496e4c97d39e3beea6cd662bc8..8b7f9ccdac9b216db117ee7f88a66f7765c2d7e2 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
 
 [[package]]
 name = "aiohttp"
@@ -259,6 +259,47 @@ d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"]
 jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
 uvloop = ["uvloop (>=0.15.2)"]
 
+[[package]]
+name = "boto3"
+version = "1.34.98"
+description = "The AWS SDK for Python"
+optional = true
+python-versions = ">=3.8"
+files = [
+    {file = "boto3-1.34.98-py3-none-any.whl", hash = "sha256:030e43b8efe22b4cf10b9f3ef9e30cd4cf9ef9784b26efe9a4583339f2b2bcec"},
+    {file = "boto3-1.34.98.tar.gz", hash = "sha256:28c10956033fa79e64529f48c3b62db86d5e4b77024a7343764b6bde6b553543"},
+]
+
+[package.dependencies]
+botocore = ">=1.34.98,<1.35.0"
+jmespath = ">=0.7.1,<2.0.0"
+s3transfer = ">=0.10.0,<0.11.0"
+
+[package.extras]
+crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
+
+[[package]]
+name = "botocore"
+version = "1.34.98"
+description = "Low-level, data-driven core of boto 3."
+optional = true +python-versions = ">=3.8" +files = [ + {file = "botocore-1.34.98-py3-none-any.whl", hash = "sha256:631c0031d8ce922b5752ab395ead896a0281b0dc74745a754d0351a27c5d83de"}, + {file = "botocore-1.34.98.tar.gz", hash = "sha256:4cee65df02f4b0be08ad1401965cc89efafebc50ef0727d2d17083c7f1ed2831"}, +] + +[package.dependencies] +jmespath = ">=0.7.1,<2.0.0" +python-dateutil = ">=2.1,<3.0.0" +urllib3 = [ + {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""}, + {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""}, +] + +[package.extras] +crt = ["awscrt (==0.20.9)"] + [[package]] name = "cachetools" version = "5.3.3" @@ -1096,12 +1137,12 @@ files = [ google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" grpcio = [ - {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] grpcio-status = [ - {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] proto-plus = ">=1.22.3,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" @@ -1857,6 +1898,17 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "jmespath" +version = "1.0.1" +description = "JSON Matching Expressions" +optional = true +python-versions = ">=3.7" +files = [ + {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"}, + {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, +] + [[package]] name = "joblib" version = "1.4.0" @@ -3508,6 +3560,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -3853,6 +3906,23 @@ files = [ {file = "ruff-0.1.15.tar.gz", hash = "sha256:f6dfa8c1b21c913c326919056c390966648b680966febcb796cc9d1aaab8564e"}, ] +[[package]] +name = "s3transfer" +version = "0.10.1" +description = "An Amazon S3 Transfer 
Manager" +optional = true +python-versions = ">= 3.8" +files = [ + {file = "s3transfer-0.10.1-py3-none-any.whl", hash = "sha256:ceb252b11bcf87080fb7850a224fb6e05c8a776bab8f2b64b7f25b969464839d"}, + {file = "s3transfer-0.10.1.tar.gz", hash = "sha256:5683916b4c724f799e600f41dd9e10a9ff19871bf87623cc8f491cb4f5fa0a19"}, +] + +[package.dependencies] +botocore = ">=1.33.2,<2.0a.0" + +[package.extras] +crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"] + [[package]] name = "safetensors" version = "0.4.2" @@ -4560,6 +4630,20 @@ files = [ {file = "types_PyYAML-6.0.12.20240311-py3-none-any.whl", hash = "sha256:b845b06a1c7e54b8e5b4c683043de0d9caf205e7434b3edc678ff2411979b8f6"}, ] +[[package]] +name = "types-requests" +version = "2.31.0.6" +description = "Typing stubs for requests" +optional = false +python-versions = ">=3.7" +files = [ + {file = "types-requests-2.31.0.6.tar.gz", hash = "sha256:cd74ce3b53c461f1228a9b783929ac73a666658f223e28ed29753771477b3bd0"}, + {file = "types_requests-2.31.0.6-py3-none-any.whl", hash = "sha256:a2db9cb228a81da8348b49ad6db3f5519452dd20a9c1e1a868c83c5fe88fd1a9"}, +] + +[package.dependencies] +types-urllib3 = "*" + [[package]] name = "types-requests" version = "2.31.0.20240406" @@ -4574,6 +4658,17 @@ files = [ [package.dependencies] urllib3 = ">=2" +[[package]] +name = "types-urllib3" +version = "1.26.25.14" +description = "Typing stubs for urllib3" +optional = false +python-versions = "*" +files = [ + {file = "types-urllib3-1.26.25.14.tar.gz", hash = "sha256:229b7f577c951b8c1b92c1bc2b2fdb0b49847bd2af6d1cc2a2e3dd340f3bda8f"}, + {file = "types_urllib3-1.26.25.14-py3-none-any.whl", hash = "sha256:9683bbb7fb72e32bfe9d2be6e04875fbe1b3eeec3cbb4ea231435aa7fd6b4f0e"}, +] + [[package]] name = "typing-extensions" version = "4.11.0" @@ -4585,6 +4680,22 @@ files = [ {file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"}, ] +[[package]] +name = "urllib3" +version = "1.26.18" +description = "HTTP library with thread-safe connection pooling, file post, and more." 
+optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" +files = [ + {file = "urllib3-1.26.18-py2.py3-none-any.whl", hash = "sha256:34b97092d7e0a3a8cf7cd10e386f401b3737364026c45e622aa02903dffe0f07"}, + {file = "urllib3-1.26.18.tar.gz", hash = "sha256:f8ecc1bba5667413457c529ab955bf8c67b45db799d159066261719e328580a0"}, +] + +[package.extras] +brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] +secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] +socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] + [[package]] name = "urllib3" version = "2.2.1" @@ -4756,6 +4867,7 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] +bedrock = ["boto3"] fastembed = ["fastembed"] google = ["google-cloud-aiplatform"] hybrid = ["pinecone-text"] @@ -4769,4 +4881,4 @@ vision = ["pillow", "torch", "torchvision", "transformers"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "9f308d2dd1c067185f9d84721b25d81e7d1e72a239059863bad1f4439a7a26cc" +content-hash = "be798556d4ad5d05ba0682534dcfab1c06e3ff1c33bcf3c24d178b665c81dde8" diff --git a/pyproject.toml b/pyproject.toml index fba7047b0eb73f7f52a66d2fbfeecfd81536375a..558ffce8de2e5537943635a84103c3c6440ce092 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ matplotlib = { version = "^3.8.3", optional = true} qdrant-client = {version = "^1.8.0", optional = true} google-cloud-aiplatform = {version = "^1.45.0", optional = true} requests-mock = "^1.12.1" +boto3 = { version = "^1.34.98", optional = true } [tool.poetry.extras] hybrid = ["pinecone-text"] @@ -49,6 +50,7 @@ processing = ["matplotlib"] mistralai = ["mistralai"] qdrant = ["qdrant-client"] google = ["google-cloud-aiplatform"] +bedrock = ["boto3"] [tool.poetry.group.dev.dependencies] ipykernel = "^6.25.0" diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py index 5efc730398a45fc3a9de5f234a6d43a0e37911be..a1026240d37fbfecb1ec8b1445d42fc05f04265f 100644 --- a/semantic_router/encoders/__init__.py +++ b/semantic_router/encoders/__init__.py @@ -1,6 +1,7 @@ from typing import List, Optional from semantic_router.encoders.base import BaseEncoder +from semantic_router.encoders.bedrock import BedrockEncoder from semantic_router.encoders.bm25 import BM25Encoder from semantic_router.encoders.clip import CLIPEncoder from semantic_router.encoders.cohere import CohereEncoder @@ -29,6 +30,7 @@ __all__ = [ "VitEncoder", "CLIPEncoder", "GoogleEncoder", + "BedrockEncoder", ] @@ -67,6 +69,8 @@ class AutoEncoder: self.model = CLIPEncoder(name=name) elif self.type == EncoderType.GOOGLE: self.model = GoogleEncoder(name=name) + elif self.type == EncoderType.BEDROCK: + self.model = BedrockEncoder(name=name) # type: ignore else: raise ValueError(f"Encoder type '{type}' not supported") diff --git a/semantic_router/encoders/bedrock.py b/semantic_router/encoders/bedrock.py new file mode 100644 index 0000000000000000000000000000000000000000..ce04719be7e5fa2117a706938c811803e746b1bc --- /dev/null +++ b/semantic_router/encoders/bedrock.py @@ -0,0 +1,250 @@ +""" +This module provides the BedrockEncoder class for generating embeddings using Amazon's Bedrock 
Platform.
+
+The BedrockEncoder class is a subclass of BaseEncoder and uses the Bedrock Runtime client (via boto3)
+from Amazon's Bedrock Platform to generate embeddings for given documents. It requires an AWS Access Key ID
+and AWS Secret Access Key and supports customization of the pre-trained model, score threshold, and region.
+
+Example usage:
+
+    from semantic_router.encoders import BedrockEncoder
+
+    encoder = BedrockEncoder(access_key_id="your-access-key-id", secret_access_key="your-secret-key", region="your-region")
+    embeddings = encoder(["document1", "document2"])
+
+Classes:
+    BedrockEncoder: A class for generating embeddings using the Bedrock Platform.
+"""
+
+import json
+from typing import List, Optional, Any
+import os
+import tiktoken
+from semantic_router.encoders import BaseEncoder
+from semantic_router.utils.defaults import EncoderDefault
+
+
+class BedrockEncoder(BaseEncoder):
+    client: Any = None
+    type: str = "bedrock"
+    input_type: Optional[str] = "search_query"
+    name: str
+    access_key_id: Optional[str] = None
+    secret_access_key: Optional[str] = None
+    session_token: Optional[str] = None
+    region: Optional[str] = None
+
+    def __init__(
+        self,
+        name: str = EncoderDefault.BEDROCK.value["embedding_model"],
+        input_type: Optional[str] = "search_query",
+        score_threshold: float = 0.3,
+        access_key_id: Optional[str] = None,
+        secret_access_key: Optional[str] = None,
+        session_token: Optional[str] = None,
+        region: Optional[str] = None,
+    ):
+        """Initializes the BedrockEncoder.
+
+        Args:
+            name: The name of the pre-trained model to use for embedding.
+                If not provided, the default model specified in EncoderDefault will
+                be used.
+            input_type: The input type passed to Cohere embedding models.
+                Defaults to "search_query".
+            score_threshold: The threshold for similarity scores.
+            access_key_id: The AWS access key ID for an IAM principal.
+                If not provided, it will be retrieved from the access_key_id
+                environment variable.
+            secret_access_key: The secret access key for an IAM principal.
+                If not provided, it will be retrieved from the secret_access_key
+                environment variable.
+            session_token: The session token for an IAM principal.
+                If not provided, it will be retrieved from the AWS_SESSION_TOKEN
+                environment variable.
+            region: The location of the Bedrock resources.
+                If not provided, it will be retrieved from the AWS_REGION
+                environment variable, defaulting to "us-west-1".
+
+        Raises:
+            ValueError: If the Bedrock Platform client fails to initialize.
+        """
+
+        super().__init__(name=name, score_threshold=score_threshold)
+        self.access_key_id = self.get_env_variable("access_key_id", access_key_id)
+        self.secret_access_key = self.get_env_variable(
+            "secret_access_key", secret_access_key
+        )
+        self.session_token = self.get_env_variable("AWS_SESSION_TOKEN", session_token)
+        self.region = self.get_env_variable("AWS_REGION", region, default="us-west-1")
+
+        self.input_type = input_type
+
+        try:
+            self.client = self._initialize_client(
+                self.access_key_id,
+                self.secret_access_key,
+                self.session_token,
+                self.region,
+            )
+
+        except Exception as e:
+            raise ValueError(f"Bedrock client failed to initialise. Error: {e}") from e
+
+    def _initialize_client(
+        self, access_key_id, secret_access_key, session_token, region
+    ):
+        """Initializes the Bedrock client.
+
+        Args:
+            access_key_id: The Amazon access key ID.
+            secret_access_key: The Amazon secret access key.
+            session_token: The Amazon session token.
+            region: The AWS region of the Bedrock resources.
+
+        Returns:
+            An instance of the boto3 Bedrock Runtime client.
+
+        Raises:
+            ImportError: If the required Bedrock libraries are not
+                installed.
+            ValueError: If the Bedrock client fails to initialize.
+        """
+        try:
+            import boto3
+        except ImportError:
+            raise ImportError(
+                "Please install Amazon's Boto3 client library to use the BedrockEncoder. "
+                "You can install it with: "
+                "`pip install boto3`"
+            )
+
+        access_key_id = access_key_id or os.getenv("access_key_id")
+        aws_secret_key = secret_access_key or os.getenv("secret_access_key")
+        region = region or os.getenv("AWS_REGION", "us-west-2")
+
+        if access_key_id is None:
+            raise ValueError("AWS access key ID cannot be 'None'.")
+
+        if aws_secret_key is None:
+            raise ValueError("AWS secret access key cannot be 'None'.")
+
+        try:
+            bedrock_client = boto3.client(
+                "bedrock-runtime",
+                aws_access_key_id=access_key_id,
+                aws_secret_access_key=aws_secret_key,
+                aws_session_token=session_token,
+                region_name=region,
+            )
+        except Exception as err:
+            raise ValueError(
+                f"The Bedrock client failed to initialize. Error: {err}"
+            ) from err
+
+        return bedrock_client
+
+    def __call__(self, docs: List[str]) -> List[List[float]]:
+        """Generates embeddings for the given documents.
+
+        Args:
+            docs: A list of strings representing the documents to embed.
+
+        Returns:
+            A list of lists, where each inner list contains the embedding values for a
+            document.
+
+        Raises:
+            ValueError: If the Bedrock Platform client is not initialized or if the
+                API call fails.
+        """
+        if self.client is None:
+            raise ValueError("Bedrock client is not initialised.")
+        try:
+            embeddings = []
+
+            def chunk_strings(strings, MAX_WORDS=20):
+                """
+                Breaks up a list of strings into smaller chunks.
+
+                Args:
+                    strings (list): A list of strings to be chunked.
+                    MAX_WORDS (int): The maximum number of tokens per chunk. Default is 20.
+
+                Returns:
+                    list: A list of lists, where each inner list contains a chunk of strings.
+                """
+                encoding = tiktoken.get_encoding("cl100k_base")
+                chunked_strings = []
+                current_chunk = []
+
+                for text in strings:
+                    encoded_text = encoding.encode(text)
+
+                    if len(encoded_text) > MAX_WORDS:
+                        current_chunk = [
+                            encoding.decode(encoded_text[i : i + MAX_WORDS])
+                            for i in range(0, len(encoded_text), MAX_WORDS)
+                        ]
+                    else:
+                        current_chunk = [encoding.decode(encoded_text)]
+
+                    chunked_strings.append(current_chunk)
+                return chunked_strings
+
+            if self.name and "amazon" in self.name:
+                for doc in docs:
+                    embedding_body = json.dumps(
+                        {
+                            "inputText": doc,
+                        }
+                    )
+                    response = self.client.invoke_model(
+                        body=embedding_body,
+                        modelId=self.name,
+                        accept="application/json",
+                        contentType="application/json",
+                    )
+
+                    response_body = json.loads(response.get("body").read())
+                    embeddings.append(response_body.get("embedding"))
+            elif self.name and "cohere" in self.name:
+                chunked_docs = chunk_strings(docs)
+                for chunk in chunked_docs:
+                    chunk = json.dumps({"texts": chunk, "input_type": self.input_type})
+
+                    response = self.client.invoke_model(
+                        body=chunk,
+                        modelId=self.name,
+                        accept="*/*",
+                        contentType="application/json",
+                    )
+
+                    response_body = json.loads(response.get("body").read())
+
+                    chunk_embeddings = response_body.get("embeddings")
+                    embeddings.extend(chunk_embeddings)
+            else:
+                raise ValueError("Unknown model name")
+            return embeddings
+        except Exception as e:
+            raise ValueError(f"Bedrock call failed. Error: {e}") from e
+
+    @staticmethod
+    def get_env_variable(var_name, provided_value, default=None):
+        """Retrieves environment variable or uses a provided value.
+
+        Args:
+            var_name (str): The name of the environment variable.
+            provided_value (Optional[str]): The provided value to use if not None.
+            default (Optional[str]): The default value if the environment variable is not set.
+
+        Returns:
+            str: The value of the environment variable or the provided/default value.
+
+        Raises:
+            ValueError: If no value is provided and the environment variable is not set.
+        """
+        if provided_value is not None:
+            return provided_value
+        value = os.getenv(var_name, default)
+        if value is None:
+            raise ValueError(f"No {var_name} provided")
+        return value
diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index 63f1e440a1a816a79a3b37b974d191eac48a24af..86ab123318812dbc3c18a8a6f98a7e4bb186669c 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -15,6 +15,7 @@ class EncoderType(Enum):
     VIT = "vit"
     CLIP = "clip"
     GOOGLE = "google"
+    BEDROCK = "bedrock"
 
 
 class EncoderInfo(BaseModel):
diff --git a/semantic_router/utils/defaults.py b/semantic_router/utils/defaults.py
index 3c9cbb2dd1010f5b861c49fcafad389c591fe9cb..75331c06581ad4692bc24f1633ba5a609ba28e47 100644
--- a/semantic_router/utils/defaults.py
+++ b/semantic_router/utils/defaults.py
@@ -31,3 +31,8 @@ class EncoderDefault(Enum):
             "GOOGLE_EMBEDDING_MODEL", "textembedding-gecko@003"
         ),
     }
+    BEDROCK = {
+        "embedding_model": os.environ.get(
+            "BEDROCK_EMBEDDING_MODEL", "amazon.titan-embed-image-v1"
+        )
+    }
diff --git a/tests/unit/encoders/test_bedrock.py b/tests/unit/encoders/test_bedrock.py
new file mode 100644
index 0000000000000000000000000000000000000000..43955d453c704bc95fcce598b94e0a46479947a0
--- /dev/null
+++ b/tests/unit/encoders/test_bedrock.py
@@ -0,0 +1,116 @@
+import pytest
+import json
+from io import BytesIO
+from semantic_router.encoders import BedrockEncoder
+
+
+@pytest.fixture
+def bedrock_encoder(mocker):
+    mocker.patch("semantic_router.encoders.bedrock.BedrockEncoder._initialize_client")
+    return BedrockEncoder(
+        access_key_id="fake_id",
+        secret_access_key="fake_secret",
+        session_token="fake_token",
+        region="us-west-2",
+    )
+
+
+class TestBedrockEncoder:
+    def test_initialisation_with_default_values(self, bedrock_encoder):
+        assert (
+            bedrock_encoder.input_type == "search_query"
+        ), "Default input type not set correctly"
+        assert bedrock_encoder.region == "us-west-2", "Region should be initialised"
+
+    def test_initialisation_with_custom_values(self, mocker):
+        mocker.patch(
+            "semantic_router.encoders.bedrock.BedrockEncoder._initialize_client"
+        )
+        name = "custom_model"
+        score_threshold = 0.5
+        input_type = "custom_input"
+        bedrock_encoder = BedrockEncoder(
+            name=name,
+            score_threshold=score_threshold,
+            input_type=input_type,
+            access_key_id="fake_id",
+            secret_access_key="fake_secret",
+            session_token="fake_token",
+            region="us-west-2",
+        )
+        assert bedrock_encoder.name == name, "Custom name not set correctly"
+        assert bedrock_encoder.region == "us-west-2", "Custom region not set correctly"
+        assert (
+            bedrock_encoder.score_threshold == score_threshold
+        ), "Custom score threshold not set correctly"
+        assert (
+            bedrock_encoder.input_type == input_type
+        ), "Custom input type not set correctly"
+
+    def test_call_method(self, bedrock_encoder):
+        response_content = json.dumps({"embedding": [0.1, 0.2, 0.3]})
+        response_body = BytesIO(response_content.encode("utf-8"))
+        mock_response = {"body": response_body}
+        bedrock_encoder.client.invoke_model.return_value = mock_response
+        result = bedrock_encoder(["test"])
+        assert isinstance(result, list), "Result should be a list"
+        assert all(
+            isinstance(item, list) for
item in result + ), "Each item in result should be a list" + assert result == [[0.1, 0.2, 0.3]], "Embedding should be [0.1, 0.2, 0.3]" + + def test_raises_value_error_if_client_is_not_initialised(self, mocker): + mocker.patch( + "semantic_router.encoders.bedrock.BedrockEncoder._initialize_client", + side_effect=Exception("Client initialisation failed"), + ) + with pytest.raises(ValueError): + BedrockEncoder( + access_key_id="fake_id", + secret_access_key="fake_secret", + session_token="fake_token", + region="us-west-2", + ) + + def test_raises_value_error_if_call_to_bedrock_fails(self, bedrock_encoder): + bedrock_encoder.client.invoke_model.side_effect = Exception( + "Bedrock call failed." + ) + with pytest.raises(ValueError): + bedrock_encoder(["test"]) + + +@pytest.fixture +def bedrock_encoder_with_cohere(mocker): + mocker.patch("semantic_router.encoders.bedrock.BedrockEncoder._initialize_client") + return BedrockEncoder( + name="cohere_model", + access_key_id="fake_id", + secret_access_key="fake_secret", + session_token="fake_token", + region="us-west-2", + ) + + +class TestBedrockEncoderWithCohere: + def test_cohere_embedding_single_chunk(self, bedrock_encoder_with_cohere): + response_content = json.dumps({"embeddings": [[0.1, 0.2, 0.3]]}) + response_body = BytesIO(response_content.encode("utf-8")) + mock_response = {"body": response_body} + bedrock_encoder_with_cohere.client.invoke_model.return_value = mock_response + result = bedrock_encoder_with_cohere(["short test"]) + assert isinstance(result, list), "Result should be a list" + assert all( + isinstance(item, list) for item in result + ), "Each item should be a list" + assert result == [[0.1, 0.2, 0.3]], "Expected embedding [0.1, 0.2, 0.3]" + + def test_cohere_input_type(self, bedrock_encoder_with_cohere): + bedrock_encoder_with_cohere.input_type = "different_type" + response_content = json.dumps({"embeddings": [[0.1, 0.2, 0.3]]}) + response_body = BytesIO(response_content.encode("utf-8")) + mock_response = {"body": response_body} + bedrock_encoder_with_cohere.client.invoke_model.return_value = mock_response + result = bedrock_encoder_with_cohere(["test with different input type"]) + assert isinstance(result, list), "Result should be a list" + assert result == [[0.1, 0.2, 0.3]], "Expected specific embeddings"
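+
+    # A minimal extra check, sketched against the behaviour of
+    # BedrockEncoder.__call__ above: model names containing neither "amazon"
+    # nor "cohere" fall through to a ValueError. The test name and the
+    # "unknown_model" value are illustrative assumptions, not part of the
+    # original test suite.
+    def test_raises_value_error_for_unknown_model_name(
+        self, bedrock_encoder_with_cohere
+    ):
+        bedrock_encoder_with_cohere.name = "unknown_model"
+        with pytest.raises(ValueError):
+            bedrock_encoder_with_cohere(["test"])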