diff --git a/docs/encoders/bedrock.ipynb b/docs/encoders/bedrock.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..54cc0a7dfbb3709ce0e6355e9d0a7508804298ec --- /dev/null +++ b/docs/encoders/bedrock.ipynb @@ -0,0 +1,1323 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[](https://colab.research.google.com/github/aurelio-labs/semantic-router/blob/main/docs/encoders/bedrock.ipynb) [](https://nbviewer.org/github/aurelio-labs/semantic-router/blob/main/docs/encoders/bedrock.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Using Bedrock Embedding Models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The 3rd generation embedding models from AWS Bedrock (`amazon.titan-embed-text-v1`, `amazon.titan-embed-text-v2` and `cohere.embed-english-v3`) can all be used with our `BedrockEncoder`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Getting Started" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We start by installing semantic-router. Support for the new `Bedrock` embedding models was added in `semantic-router==0.0.40`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -qU \"semantic-router[bedrock]==0.0.40\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we define `Route` objects, which map a route name to example phrases that should trigger that route." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from semantic_router import Route\n", + "\n", + "politics = Route(\n", + " name=\"politics\",\n", + " utterances=[\n", + " \"isn't politics the best thing ever\",\n", + " \"why don't you tell me about your political opinions\",\n", + " \"don't you just love the president\",\n", + " \"don't you just hate the president\",\n", + " \"they're going to destroy this country!\",\n", + " \"they will save the country!\",\n", + " ],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's define another for good measure:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "chitchat = Route(\n", + " name=\"chitchat\",\n", + " utterances=[\n", + " \"how's the weather today?\",\n", + " \"how are things going?\",\n", + " \"lovely weather today\",\n", + " \"the weather is horrendous\",\n", + " \"let's go to the chippy\",\n", + " ],\n", + ")\n", + "\n", + "routes = [politics, chitchat]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we initialize our embedding model. We will use the `amazon.titan-embed-image-v1` model with a `score_threshold` of `0.5`. As we will see below, this model produces 1024-dimensional vectors." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from getpass import getpass\n", + "from semantic_router.encoders import BedrockEncoder\n", + "\n", + "aws_access_key_id = os.getenv(\"AWS_ACCESS_KEY_ID\") or getpass(\n", + " \"Enter AWS Access Key ID: \"\n", + ")\n", + "aws_secret_access_key = os.getenv(\"AWS_SECRET_ACCESS_KEY\") or getpass(\n", + " \"Enter AWS Secret Access Key: \"\n", + ")\n", + "aws_session_token = os.getenv(\"AWS_SESSION_TOKEN\") or getpass(\n", + " \"Enter AWS Session Token: \"\n", + ")\n", + "aws_region = os.getenv(\"AWS_REGION\") or getpass(\"Enter AWS Region: \")\n", + "\n", + "encoder = BedrockEncoder(\n", + " name=\"amazon.titan-embed-image-v1\",\n", + " score_threshold=0.5,\n", + " access_key_id=aws_access_key_id,\n", + " secret_access_key=aws_secret_access_key,\n", + " session_token=aws_session_token,\n", + " region=aws_region,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[[0.012878418,\n", + " 0.028442383,\n", + " -0.022094727,\n", + " -0.020751953,\n", + " -0.008300781,\n", + " 0.033691406,\n", + " 0.09326172,\n", + " 0.0045166016,\n", + " 0.033935547,\n", + " 0.015319824,\n", + " 0.012939453,\n", + " 0.015380859,\n", + " 0.012756348,\n", + " -0.064453125,\n", + " 0.018432617,\n", + " 0.03173828,\n", + " -0.018188477,\n", + " -0.007171631,\n", + " 0.03955078,\n", + " 0.0033874512,\n", + " 0.007019043,\n", + " 0.010131836,\n", + " -0.025878906,\n", + " 0.056152344,\n", + " 0.01373291,\n", + " -0.020263672,\n", + " 0.055419922,\n", + " -0.06225586,\n", + " 0.040039062,\n", + " -0.015075684,\n", + " 0.012268066,\n", + " -0.056640625,\n", + " 0.04736328,\n", + " -0.002609253,\n", + " -0.0064086914,\n", + " 0.011291504,\n", + " -0.019165039,\n", + " -0.005493164,\n", + " 0.003189087,\n", + " 0.008666992,\n", + " 0.03564453,\n", + " -0.0027923584,\n", + " 
-0.016601562,\n", + " 0.014404297,\n", + " -0.01171875,\n", + " 0.013183594,\n", + " -0.018920898,\n", + " -0.030639648,\n", + " 0.010864258,\n", + " 0.052734375,\n", + " -0.006164551,\n", + " 0.0035705566,\n", + " 0.0060424805,\n", + " -0.021606445,\n", + " -0.040527344,\n", + " 0.020385742,\n", + " 0.004638672,\n", + " -0.010314941,\n", + " -0.010681152,\n", + " -0.010803223,\n", + " -0.038330078,\n", + " -0.029174805,\n", + " 0.036865234,\n", + " -0.03112793,\n", + " -0.034179688,\n", + " 0.017944336,\n", + " -0.03515625,\n", + " 0.068847656,\n", + " -0.032470703,\n", + " -0.03540039,\n", + " 0.017944336,\n", + " -0.024047852,\n", + " -0.05834961,\n", + " -0.049804688,\n", + " -0.009277344,\n", + " 0.021484375,\n", + " -0.036376953,\n", + " 0.03540039,\n", + " -0.012939453,\n", + " -0.03491211,\n", + " -0.028808594,\n", + " 0.017333984,\n", + " 0.021484375,\n", + " 0.0052490234,\n", + " -0.026611328,\n", + " -0.0026245117,\n", + " 0.05078125,\n", + " -0.022949219,\n", + " -0.057128906,\n", + " -0.019042969,\n", + " 0.01574707,\n", + " -0.0025482178,\n", + " 0.02355957,\n", + " -0.0011367798,\n", + " 0.0039367676,\n", + " -0.015197754,\n", + " -0.02758789,\n", + " -0.025268555,\n", + " -0.048339844,\n", + " 0.04296875,\n", + " -0.01373291,\n", + " -0.0052490234,\n", + " -0.016357422,\n", + " -0.029663086,\n", + " 0.024536133,\n", + " -0.03881836,\n", + " -0.035888672,\n", + " 0.013793945,\n", + " -0.016357422,\n", + " -0.052734375,\n", + " 0.0154418945,\n", + " -0.004058838,\n", + " -0.018432617,\n", + " -0.01574707,\n", + " 0.06225586,\n", + " 0.044433594,\n", + " -0.011474609,\n", + " 0.019897461,\n", + " -0.018432617,\n", + " -0.03515625,\n", + " 0.057861328,\n", + " -0.016967773,\n", + " 0.008666992,\n", + " 0.01574707,\n", + " 0.024780273,\n", + " 0.01953125,\n", + " -0.005554199,\n", + " 0.042236328,\n", + " -0.026123047,\n", + " -0.111328125,\n", + " 0.018798828,\n", + " 0.018066406,\n", + " -0.032958984,\n", + " -0.0025024414,\n", + " -0.01159668,\n", + " 
-0.028930664,\n", + " -0.055908203,\n", + " 0.037353516,\n", + " 0.018432617,\n", + " 0.015258789,\n", + " -0.021850586,\n", + " -0.0026245117,\n", + " 0.016723633,\n", + " 0.0095825195,\n", + " 0.05029297,\n", + " 0.011779785,\n", + " 0.04711914,\n", + " -0.064941406,\n", + " 0.0059509277,\n", + " -0.025390625,\n", + " 0.03857422,\n", + " 0.046875,\n", + " 0.015258789,\n", + " 0.03930664,\n", + " 0.02355957,\n", + " 0.03125,\n", + " -0.032958984,\n", + " 0.056640625,\n", + " -0.056396484,\n", + " 0.0146484375,\n", + " 0.0025634766,\n", + " 0.006591797,\n", + " 0.0015563965,\n", + " -0.020385742,\n", + " 0.016723633,\n", + " -0.008972168,\n", + " -0.024169922,\n", + " 0.03125,\n", + " -0.028808594,\n", + " 0.040283203,\n", + " 0.0055236816,\n", + " -0.0025787354,\n", + " 0.067871094,\n", + " -0.004119873,\n", + " -0.03515625,\n", + " 0.030517578,\n", + " 0.0077819824,\n", + " -0.026733398,\n", + " -0.01953125,\n", + " -0.014709473,\n", + " -0.045898438,\n", + " -0.012268066,\n", + " 0.022216797,\n", + " 0.008972168,\n", + " 0.017211914,\n", + " -0.0234375,\n", + " -0.017211914,\n", + " 0.030151367,\n", + " -0.0034942627,\n", + " 0.029174805,\n", + " 0.05029297,\n", + " -0.053222656,\n", + " -0.037841797,\n", + " 0.008117676,\n", + " 0.014038086,\n", + " 0.015563965,\n", + " -0.060791016,\n", + " 0.014221191,\n", + " -0.028808594,\n", + " -0.03955078,\n", + " -0.111328125,\n", + " 0.041992188,\n", + " -0.043945312,\n", + " 0.030273438,\n", + " -0.045898438,\n", + " -0.014770508,\n", + " -0.030395508,\n", + " -0.041748047,\n", + " 0.0011291504,\n", + " -0.034423828,\n", + " -0.04272461,\n", + " 0.008300781,\n", + " 0.014831543,\n", + " -0.018798828,\n", + " -0.017700195,\n", + " -0.014099121,\n", + " -0.011169434,\n", + " -0.029418945,\n", + " 0.027832031,\n", + " 0.010986328,\n", + " 0.030151367,\n", + " -0.021728516,\n", + " 0.004547119,\n", + " -0.034423828,\n", + " -0.01977539,\n", + " 0.047851562,\n", + " 0.021362305,\n", + " 0.044433594,\n", + " 0.06933594,\n", 
+ " -0.0046691895,\n", + " -0.049560547,\n", + " -0.091308594,\n", + " 0.084472656,\n", + " 0.015991211,\n", + " 0.030883789,\n", + " 0.03112793,\n", + " 0.041503906,\n", + " 0.018920898,\n", + " 0.04663086,\n", + " 0.0064697266,\n", + " 0.0058288574,\n", + " -0.007873535,\n", + " 0.016113281,\n", + " -0.0058898926,\n", + " 0.040039062,\n", + " 0.041748047,\n", + " 0.04736328,\n", + " -0.06591797,\n", + " 0.07861328,\n", + " -0.021850586,\n", + " -0.013427734,\n", + " 0.033447266,\n", + " 0.013183594,\n", + " 0.025878906,\n", + " 0.036376953,\n", + " -0.017211914,\n", + " 0.0067443848,\n", + " -0.011291504,\n", + " -0.009155273,\n", + " 0.005554199,\n", + " -0.00039863586,\n", + " 0.08251953,\n", + " -0.03491211,\n", + " -0.025878906,\n", + " -0.037109375,\n", + " 0.052734375,\n", + " -0.008911133,\n", + " -0.0390625,\n", + " 0.021362305,\n", + " -0.022949219,\n", + " 0.029907227,\n", + " 0.041259766,\n", + " 0.017211914,\n", + " -0.016845703,\n", + " 0.043701172,\n", + " -0.025512695,\n", + " -0.020019531,\n", + " 0.01953125,\n", + " -0.008422852,\n", + " 0.016357422,\n", + " -0.044921875,\n", + " -0.030761719,\n", + " 0.029541016,\n", + " -0.008422852,\n", + " 0.01977539,\n", + " 0.006652832,\n", + " 0.0031433105,\n", + " 0.044189453,\n", + " 0.00793457,\n", + " 0.02722168,\n", + " -0.043701172,\n", + " -0.01550293,\n", + " -0.0068359375,\n", + " -0.033935547,\n", + " 0.025024414,\n", + " -0.038085938,\n", + " -0.037353516,\n", + " -0.032714844,\n", + " 0.037841797,\n", + " -0.057373047,\n", + " -0.017211914,\n", + " -0.012878418,\n", + " -0.0069274902,\n", + " -0.020874023,\n", + " -0.037841797,\n", + " -0.036132812,\n", + " 0.033691406,\n", + " -0.030273438,\n", + " 0.033691406,\n", + " 0.049316406,\n", + " -0.02746582,\n", + " -0.030761719,\n", + " 0.03564453,\n", + " -0.02746582,\n", + " -0.03112793,\n", + " -0.00340271,\n", + " 0.016845703,\n", + " 0.03515625,\n", + " 0.009033203,\n", + " 0.026489258,\n", + " 0.04663086,\n", + " 0.0067443848,\n", + " 
0.017944336,\n", + " 0.008850098,\n", + " -0.008544922,\n", + " 0.0022277832,\n", + " -0.030029297,\n", + " 0.010192871,\n", + " -0.021240234,\n", + " -0.020385742,\n", + " 0.008666992,\n", + " 0.005706787,\n", + " -0.02758789,\n", + " 0.05419922,\n", + " 0.036132812,\n", + " 0.032714844,\n", + " -0.014526367,\n", + " -0.02758789,\n", + " 0.03466797,\n", + " 0.05883789,\n", + " 0.026977539,\n", + " 0.05102539,\n", + " 0.052246094,\n", + " 0.0056152344,\n", + " 0.009094238,\n", + " -0.000579834,\n", + " -0.03100586,\n", + " -0.017822266,\n", + " 0.040283203,\n", + " -0.011474609,\n", + " -0.063964844,\n", + " 0.026977539,\n", + " 0.006958008,\n", + " -0.009765625,\n", + " 0.010253906,\n", + " -0.007385254,\n", + " 0.0051574707,\n", + " -0.0030670166,\n", + " -0.011047363,\n", + " -0.017333984,\n", + " -0.015991211,\n", + " 0.026245117,\n", + " -0.030639648,\n", + " -0.022460938,\n", + " 0.0059814453,\n", + " -0.021240234,\n", + " -0.011962891,\n", + " 0.010925293,\n", + " -0.021484375,\n", + " 0.037353516,\n", + " -0.050048828,\n", + " -0.08544922,\n", + " -0.024658203,\n", + " -0.026611328,\n", + " 0.020385742,\n", + " -0.033935547,\n", + " 0.025390625,\n", + " 0.0030670166,\n", + " -0.008117676,\n", + " -0.022338867,\n", + " 0.024291992,\n", + " 0.052246094,\n", + " -0.059570312,\n", + " -0.0138549805,\n", + " -0.01940918,\n", + " 0.05517578,\n", + " -0.0006866455,\n", + " 0.0049743652,\n", + " 0.07519531,\n", + " 0.057617188,\n", + " 0.004425049,\n", + " -0.043945312,\n", + " -0.029663086,\n", + " 0.017578125,\n", + " 0.030029297,\n", + " -0.007446289,\n", + " -0.030761719,\n", + " -0.021484375,\n", + " -0.009765625,\n", + " 0.013671875,\n", + " 0.012207031,\n", + " -0.012878418,\n", + " -0.043945312,\n", + " -0.020141602,\n", + " 0.013183594,\n", + " 0.0074157715,\n", + " -0.028686523,\n", + " 0.025268555,\n", + " 0.026367188,\n", + " 0.030395508,\n", + " 0.041748047,\n", + " 0.017944336,\n", + " 0.036376953,\n", + " 0.010437012,\n", + " -0.0625,\n", + " 
0.04296875,\n", + " 0.0057373047,\n", + " 0.059570312,\n", + " 0.072753906,\n", + " 0.03881836,\n", + " -0.0021972656,\n", + " -0.027832031,\n", + " 0.0074157715,\n", + " 0.00045394897,\n", + " -0.003753662,\n", + " -0.010070801,\n", + " 0.008972168,\n", + " -0.0051574707,\n", + " 0.007537842,\n", + " -0.0079956055,\n", + " -0.03173828,\n", + " -0.012451172,\n", + " -0.015563965,\n", + " 0.027709961,\n", + " -0.039794922,\n", + " -0.016113281,\n", + " -0.056396484,\n", + " 0.016601562,\n", + " -0.030395508,\n", + " -0.033447266,\n", + " 0.052001953,\n", + " 0.001159668,\n", + " -0.02368164,\n", + " -0.046142578,\n", + " 0.01977539,\n", + " -0.02746582,\n", + " -0.038330078,\n", + " -0.052734375,\n", + " -0.030151367,\n", + " -0.030639648,\n", + " -0.0043945312,\n", + " 0.025390625,\n", + " 0.0048828125,\n", + " 0.029663086,\n", + " 0.01928711,\n", + " -0.025634766,\n", + " -0.022583008,\n", + " -0.019165039,\n", + " 0.026733398,\n", + " -0.035888672,\n", + " -0.015014648,\n", + " -0.0069274902,\n", + " -0.005126953,\n", + " 0.032958984,\n", + " 0.033203125,\n", + " -0.019897461,\n", + " -0.038330078,\n", + " 0.020874023,\n", + " 0.027954102,\n", + " -0.06689453,\n", + " -0.0069274902,\n", + " -0.0036315918,\n", + " -0.025634766,\n", + " -0.020507812,\n", + " 0.017333984,\n", + " -0.019165039,\n", + " 0.04663086,\n", + " -0.052734375,\n", + " -0.017333984,\n", + " 0.009338379,\n", + " -0.012756348,\n", + " -0.007507324,\n", + " 0.045166016,\n", + " 0.02722168,\n", + " -0.023071289,\n", + " -0.019042969,\n", + " 0.0045166016,\n", + " 0.017822266,\n", + " -0.024291992,\n", + " 0.030883789,\n", + " -0.008361816,\n", + " -0.050048828,\n", + " -0.026000977,\n", + " 0.021850586,\n", + " 0.011413574,\n", + " 0.0134887695,\n", + " 0.013000488,\n", + " -0.0068359375,\n", + " -0.040039062,\n", + " 0.007446289,\n", + " 0.020751953,\n", + " 0.037841797,\n", + " -0.03173828,\n", + " -0.044921875,\n", + " -0.012451172,\n", + " 0.00032806396,\n", + " -0.026123047,\n", + " 
-0.059570312,\n", + " -0.028564453,\n", + " 0.04272461,\n", + " 0.0064086914,\n", + " 0.030639648,\n", + " 0.018188477,\n", + " 0.016113281,\n", + " 0.043945312,\n", + " 0.015991211,\n", + " 0.020019531,\n", + " 0.055419922,\n", + " -0.016357422,\n", + " -0.002166748,\n", + " -0.025756836,\n", + " 0.015625,\n", + " -0.020263672,\n", + " 0.012573242,\n", + " 0.029296875,\n", + " -0.06689453,\n", + " 0.0062561035,\n", + " 0.03857422,\n", + " -0.010803223,\n", + " -0.026245117,\n", + " -0.016235352,\n", + " -0.04248047,\n", + " 0.033691406,\n", + " 0.02746582,\n", + " 0.024902344,\n", + " -0.025878906,\n", + " 0.046142578,\n", + " -0.029541016,\n", + " -0.015075684,\n", + " 0.015991211,\n", + " -0.030883789,\n", + " 0.017700195,\n", + " 0.03173828,\n", + " -0.005126953,\n", + " -0.0034484863,\n", + " -0.041992188,\n", + " -0.01159668,\n", + " -0.007293701,\n", + " 0.04321289,\n", + " 0.009399414,\n", + " -0.017578125,\n", + " -0.029418945,\n", + " 0.06542969,\n", + " -0.03125,\n", + " 0.020019531,\n", + " -0.05029297,\n", + " 0.033447266,\n", + " -0.0154418945,\n", + " 0.041748047,\n", + " 0.04345703,\n", + " 0.03515625,\n", + " -0.003479004,\n", + " 0.021484375,\n", + " -0.025146484,\n", + " 4.196167e-05,\n", + " 0.007659912,\n", + " -0.03540039,\n", + " -0.012512207,\n", + " 0.0087890625,\n", + " 0.041259766,\n", + " 0.015319824,\n", + " -0.018066406,\n", + " -0.0018920898,\n", + " 0.033447266,\n", + " -0.01184082,\n", + " -0.04345703,\n", + " -0.024780273,\n", + " 0.064453125,\n", + " -0.012207031,\n", + " -0.036132812,\n", + " 0.10839844,\n", + " -0.016357422,\n", + " -0.0047302246,\n", + " -0.013793945,\n", + " 0.018066406,\n", + " -0.017700195,\n", + " 0.01953125,\n", + " -0.0027313232,\n", + " -0.04272461,\n", + " 0.01940918,\n", + " -0.01586914,\n", + " 0.024414062,\n", + " -0.044433594,\n", + " -0.026123047,\n", + " 0.022094727,\n", + " -0.046142578,\n", + " 0.030761719,\n", + " 0.017578125,\n", + " -0.0028076172,\n", + " 0.059326172,\n", + " 0.025512695,\n", 
+ " 0.025146484,\n", + " -0.03125,\n", + " 0.002319336,\n", + " -0.022827148,\n", + " 0.053710938,\n", + " -0.010559082,\n", + " 0.04345703,\n", + " 0.005645752,\n", + " -0.021972656,\n", + " -0.018920898,\n", + " -0.040283203,\n", + " 0.017456055,\n", + " 0.056884766,\n", + " 0.01928711,\n", + " -0.022827148,\n", + " 0.012145996,\n", + " -0.047851562,\n", + " 0.021118164,\n", + " -0.028930664,\n", + " -0.029907227,\n", + " -0.030883789,\n", + " -0.022827148,\n", + " -0.013977051,\n", + " 0.043701172,\n", + " 0.007080078,\n", + " 0.04711914,\n", + " -0.010253906,\n", + " -0.041015625,\n", + " 0.009216309,\n", + " -0.010986328,\n", + " 0.04248047,\n", + " -0.02758789,\n", + " 0.025268555,\n", + " 0.03466797,\n", + " 0.045166016,\n", + " -0.00023365021,\n", + " -0.021240234,\n", + " -0.016845703,\n", + " 0.02355957,\n", + " 0.024536133,\n", + " 0.036376953,\n", + " -0.015014648,\n", + " 9.393692e-05,\n", + " -0.0115356445,\n", + " 0.033447266,\n", + " -0.012145996,\n", + " 0.007080078,\n", + " 0.017700195,\n", + " -0.014282227,\n", + " 0.027709961,\n", + " 0.037353516,\n", + " -0.041503906,\n", + " 0.03149414,\n", + " 0.041015625,\n", + " 0.008483887,\n", + " 0.00579834,\n", + " 0.034179688,\n", + " 0.025634766,\n", + " -0.038085938,\n", + " -0.06591797,\n", + " -0.036376953,\n", + " 0.01171875,\n", + " -0.026000977,\n", + " 0.057861328,\n", + " -0.008300781,\n", + " 0.014282227,\n", + " -0.029052734,\n", + " 0.044433594,\n", + " -0.026123047,\n", + " 0.016601562,\n", + " -0.016357422,\n", + " 0.0024261475,\n", + " -0.025268555,\n", + " 0.05493164,\n", + " -0.025756836,\n", + " 0.02746582,\n", + " 0.037353516,\n", + " 0.02734375,\n", + " -0.04736328,\n", + " -0.012756348,\n", + " 0.016601562,\n", + " 0.009765625,\n", + " 0.013366699,\n", + " 0.013305664,\n", + " 0.036621094,\n", + " -0.034423828,\n", + " 0.046875,\n", + " -0.0028533936,\n", + " 0.018310547,\n", + " 0.05517578,\n", + " -0.06591797,\n", + " 0.042236328,\n", + " -0.013305664,\n", + " -0.007446289,\n", + 
" 0.014343262,\n", + " -0.04296875,\n", + " -0.038330078,\n", + " -0.016235352,\n", + " -0.043701172,\n", + " 0.004180908,\n", + " -0.045410156,\n", + " -0.009643555,\n", + " -0.012939453,\n", + " -0.0020141602,\n", + " -0.006713867,\n", + " -0.03881836,\n", + " -0.010559082,\n", + " 0.036376953,\n", + " 0.024169922,\n", + " -0.01977539,\n", + " 0.025756836,\n", + " -0.010253906,\n", + " 0.05493164,\n", + " 0.01965332,\n", + " 0.012451172,\n", + " 0.053466797,\n", + " -0.0062561035,\n", + " 0.028076172,\n", + " 0.024902344,\n", + " 0.068847656,\n", + " 0.019897461,\n", + " 0.01361084,\n", + " 0.015991211,\n", + " 0.017089844,\n", + " -0.053710938,\n", + " -0.056152344,\n", + " 0.04296875,\n", + " -0.0021972656,\n", + " -0.05517578,\n", + " 0.022460938,\n", + " -0.041259766,\n", + " -0.0234375,\n", + " -0.048583984,\n", + " -0.029296875,\n", + " 0.034423828,\n", + " 0.008056641,\n", + " 0.011352539,\n", + " 0.0390625,\n", + " 0.013366699,\n", + " -0.023803711,\n", + " -0.03466797,\n", + " 0.043701172,\n", + " 0.02746582,\n", + " 0.051757812,\n", + " -0.07128906,\n", + " 0.0059509277,\n", + " 0.022827148,\n", + " 0.013977051,\n", + " -0.046142578,\n", + " -0.016235352,\n", + " 0.017089844,\n", + " -0.001045227,\n", + " -0.014953613,\n", + " 0.012084961,\n", + " -0.0035705566,\n", + " 0.016845703,\n", + " 0.0234375,\n", + " 0.026611328,\n", + " -0.033203125,\n", + " 0.076660156,\n", + " 0.007873535,\n", + " 0.03540039,\n", + " 0.0061950684,\n", + " -0.028564453,\n", + " -0.03491211,\n", + " 0.01586914,\n", + " -0.015991211,\n", + " -0.024780273,\n", + " -0.028686523,\n", + " 0.028076172,\n", + " -0.005645752,\n", + " -0.043945312,\n", + " 0.021118164,\n", + " -0.0027008057,\n", + " 0.01550293,\n", + " 0.031982422,\n", + " -0.024780273,\n", + " 0.025878906,\n", + " 0.05859375,\n", + " -0.0050354004,\n", + " 0.033691406,\n", + " 0.044677734,\n", + " 0.018432617,\n", + " -0.007171631,\n", + " 0.003829956,\n", + " -0.047851562,\n", + " 0.026855469,\n", + " 
-0.005065918,\n", + " -0.02722168,\n", + " -0.03173828,\n", + " -0.0703125,\n", + " -0.016967773,\n", + " -0.008605957,\n", + " 0.037353516,\n", + " 0.03149414,\n", + " -0.06347656,\n", + " 0.031982422,\n", + " -0.033691406,\n", + " -0.03540039,\n", + " 0.021728516,\n", + " 0.07080078,\n", + " -0.03491211,\n", + " -0.014221191,\n", + " 0.046142578,\n", + " -0.010803223,\n", + " 0.009094238,\n", + " -0.0048217773,\n", + " -5.7935715e-05,\n", + " 0.055664062,\n", + " 0.025512695,\n", + " -0.024291992,\n", + " 0.04663086,\n", + " -0.008300781,\n", + " 0.056640625,\n", + " -0.006713867,\n", + " -0.018188477,\n", + " 0.012268066,\n", + " -0.045898438,\n", + " -0.051513672,\n", + " 0.016357422,\n", + " -0.049316406,\n", + " -0.0020599365,\n", + " -0.04345703,\n", + " -0.08935547,\n", + " -0.056640625,\n", + " -0.048828125,\n", + " -0.020996094,\n", + " 0.036376953,\n", + " -0.052001953,\n", + " 0.020629883,\n", + " 0.048339844,\n", + " 0.029296875,\n", + " 0.021728516,\n", + " 0.028930664,\n", + " -0.024169922,\n", + " -0.030273438,\n", + " -0.036621094,\n", + " -0.028808594,\n", + " -0.0546875,\n", + " 6.3478947e-06,\n", + " 0.03857422,\n", + " 0.01965332,\n", + " -0.016235352,\n", + " -0.017089844,\n", + " -0.012451172,\n", + " 0.010498047,\n", + " 0.025024414,\n", + " 0.016601562,\n", + " 0.032958984,\n", + " 0.0047912598,\n", + " -0.011047363,\n", + " 0.011352539,\n", + " -0.0044555664,\n", + " 0.004211426,\n", + " -0.004119873,\n", + " 0.0045776367,\n", + " 0.03149414,\n", + " -0.025146484,\n", + " -0.010070801,\n", + " -0.02331543,\n", + " 0.032714844,\n", + " 0.018798828,\n", + " -0.020751953,\n", + " 0.06201172,\n", + " 0.0043029785,\n", + " -0.039794922,\n", + " 0.010131836,\n", + " 0.048828125,\n", + " -0.036621094,\n", + " -0.007873535,\n", + " -0.029296875,\n", + " -0.046142578,\n", + " -0.016845703,\n", + " 0.056640625,\n", + " -0.048095703,\n", + " 0.0051574707,\n", + " 0.0008087158,\n", + " 0.00018310547,\n", + " -0.019165039,\n", + " 0.017700195,\n", + " 
-0.032226562,\n", + " -0.0047912598,\n", + " -0.053710938,\n", + " -0.06542969,\n", + " 0.013427734,\n", + " 0.004119873,\n", + " 0.021362305,\n", + " -0.0038452148,\n", + " -0.008056641,\n", + " -0.021606445,\n", + " 0.00793457,\n", + " -0.018798828,\n", + " -0.048828125,\n", + " 0.006958008,\n", + " -0.0390625,\n", + " -0.044921875,\n", + " -0.029052734,\n", + " -0.0039367676,\n", + " -0.009460449,\n", + " 0.03149414,\n", + " -0.024658203,\n", + " -0.007171631,\n", + " -0.020751953,\n", + " 0.010620117,\n", + " 0.027709961,\n", + " -0.012878418,\n", + " -0.006134033,\n", + " -0.036376953,\n", + " 0.0234375,\n", + " -0.008056641,\n", + " -0.029296875,\n", + " 0.0048217773,\n", + " 0.053222656,\n", + " -0.03857422,\n", + " -0.03930664,\n", + " -0.041015625,\n", + " 0.012084961,\n", + " -0.025146484,\n", + " 0.03491211,\n", + " -0.041748047,\n", + " 0.04248047,\n", + " -0.0003452301,\n", + " -0.018920898,\n", + " -0.046142578,\n", + " -0.014160156,\n", + " 0.046875,\n", + " -0.022216797,\n", + " 0.052246094,\n", + " -0.026611328,\n", + " 0.029541016,\n", + " -0.016357422,\n", + " 0.04272461,\n", + " -0.018920898,\n", + " 0.0078125,\n", + " -0.018676758,\n", + " 0.014770508,\n", + " -0.016357422,\n", + " -0.040527344,\n", + " -0.004486084,\n", + " -0.018066406,\n", + " -0.03100586,\n", + " 0.033691406,\n", + " -0.016113281,\n", + " -0.051757812,\n", + " 0.028320312,\n", + " -0.0234375,\n", + " 0.005126953,\n", + " -0.01171875,\n", + " 0.022216797,\n", + " -0.03466797,\n", + " 0.044433594,\n", + " -0.012268066,\n", + " 0.020263672,\n", + " 0.016479492,\n", + " 0.050048828,\n", + " -0.059570312,\n", + " 0.016967773,\n", + " -0.010925293,\n", + " 0.0013809204,\n", + " 0.026977539,\n", + " -0.022460938,\n", + " 0.034179688,\n", + " 0.01953125,\n", + " 0.005706787,\n", + " 0.036376953,\n", + " -0.018188477,\n", + " 0.041503906,\n", + " 0.08251953,\n", + " 0.009521484,\n", + " 0.005493164,\n", + " 0.0021820068,\n", + " -0.014465332,\n", + " 0.01965332,\n", + " 
0.0008735657,\n", + " 0.029418945,\n", + " -0.057617188,\n", + " 0.021972656,\n", + " 0.008483887,\n", + " 0.064941406,\n", + " 0.0013198853,\n", + " -0.032714844,\n", + " -0.0087890625,\n", + " -0.014160156,\n", + " 0.080566406,\n", + " -0.012390137,\n", + " 0.02746582,\n", + " 0.0044555664,\n", + " -0.029541016,\n", + " 0.011657715,\n", + " -0.010803223,\n", + " -0.020874023,\n", + " 0.0030670166,\n", + " 0.013549805,\n", + " 0.0025787354,\n", + " -0.022827148,\n", + " -0.011291504,\n", + " 0.018188477,\n", + " 0.036132812,\n", + " 0.008178711,\n", + " ...]]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "encoder([\"hey\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we define the `RouteLayer`. When called, the route layer will consume text (a query) and output the category (`Route`) it belongs to — to initialize a `RouteLayer` we need our `encoder` model and a list of `routes`." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-05-13 22:26:54 INFO semantic_router.utils.logger local\u001b[0m\n" + ] + } + ], + "source": [ + "from semantic_router.layer import RouteLayer\n", + "\n", + "rl = RouteLayer(encoder=encoder, routes=routes)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can check the dimensionality of our vectors by looking at the `index` attribute of the `RouteLayer`." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(11, 1024)" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rl.index.index.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We do have 1024-dimensional vectors. 
Now let's test them:" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteChoice(name='politics', function_call=None, similarity_score=None)" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rl(\"don't you love politics?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteChoice(name='chitchat', function_call=None, similarity_score=None)" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rl(\"how's the weather today?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Both are classified accurately. What if we send a query that is unrelated to our existing `Route` objects?" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteChoice(name=None, function_call=None, similarity_score=None)" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rl(\"How does llama model work?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this case, the returned `RouteChoice` has `name=None` because no matches were identified. We always recommend optimizing your `RouteLayer` for optimal performance; you can see how in [this notebook](https://github.com/aurelio-labs/semantic-router/blob/main/docs/06-threshold-optimization.ipynb)." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "decision-layer", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.1" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py index 8598bbc58ba586cacf4b81d1ea440d7cbb3c7b0b..a1026240d37fbfecb1ec8b1445d42fc05f04265f 100644 --- a/semantic_router/encoders/__init__.py +++ b/semantic_router/encoders/__init__.py @@ -70,7 +70,7 @@ class AutoEncoder: elif self.type == EncoderType.GOOGLE: self.model = GoogleEncoder(name=name) elif self.type == EncoderType.BEDROCK: - self.model = BedrockEncoder(name=name) + self.model = BedrockEncoder(name=name) # type: ignore else: raise ValueError(f"Encoder type '{type}' not supported") diff --git a/semantic_router/encoders/bedrock.py b/semantic_router/encoders/bedrock.py index d8e11490027728bea3fe50b0698a1d76d78ca90d..ce04719be7e5fa2117a706938c811803e746b1bc 100644 --- a/semantic_router/encoders/bedrock.py +++ b/semantic_router/encoders/bedrock.py @@ -1,8 +1,25 @@ -import json -from typing import List, Optional, Any +""" +This module provides the BedrockEncoder class for generating embeddings using Amazon's Bedrock Platform. + +The BedrockEncoder class is a subclass of BaseEncoder and uses a boto3 `bedrock-runtime` client for +Amazon's Bedrock Platform to generate embeddings for given documents. It requires an AWS Access Key ID +and AWS Secret Access Key and supports customization of the pre-trained model, score threshold, and region. 
+ +Example usage: -import boto3 + from semantic_router.encoders import BedrockEncoder + encoder = BedrockEncoder(access_key_id="your-access-key-id", secret_access_key="your-secret-key", region="your-region") + embeddings = encoder(["document1", "document2"]) + +Classes: + BedrockEncoder: A class for generating embeddings using the Bedrock Platform. +""" + +import json +from typing import List, Optional, Any +import os +import tiktoken from semantic_router.encoders import BaseEncoder from semantic_router.utils.defaults import EncoderDefault @@ -11,87 +28,223 @@ class BedrockEncoder(BaseEncoder): client: Any = None type: str = "bedrock" input_type: Optional[str] = "search_query" - session: Optional[Any] = None + name: str + access_key_id: Optional[str] = None + secret_access_key: Optional[str] = None + session_token: Optional[str] = None region: Optional[str] = None def __init__( self, - name: Optional[str] = None, - session: Optional[Any] = None, - region: Optional[str] = None, - score_threshold: float = 0.3, + name: str = EncoderDefault.BEDROCK.value["embedding_model"], input_type: Optional[str] = "search_query", + score_threshold: float = 0.3, + access_key_id: Optional[str] = None, + secret_access_key: Optional[str] = None, + session_token: Optional[str] = None, + region: Optional[str] = None, ): - if name is None: - name = EncoderDefault.BEDROCK.value["embedding_model"] + """Initializes the BedrockEncoder. + + Args: + name: The name of the pre-trained model to use for embedding. + If not provided, the default model specified in EncoderDefault will + be used. + score_threshold: The threshold for similarity scores. + access_key_id: The AWS access key id for an IAM principal. + If not provided, it will be retrieved from the access_key_id + environment variable. + secret_access_key: The secret access key for an IAM principal. + If not provided, it will be retrieved from the secret_access_key + environment variable. 
+ session_token: The session token for an IAM principal. + If not provided, it will be retrieved from the AWS_SESSION_TOKEN + environment variable. + region: The location of the Bedrock resources. + If not provided, it will be retrieved from the AWS_REGION + environment variable, defaulting to "us-west-1" + + Raises: + ValueError: If a credential is missing or the Bedrock client fails to initialize. + """ + super().__init__(name=name, score_threshold=score_threshold) + self.access_key_id = self.get_env_variable("access_key_id", access_key_id) + self.secret_access_key = self.get_env_variable( + "secret_access_key", secret_access_key + ) + self.session_token = self.get_env_variable("AWS_SESSION_TOKEN", session_token) + self.region = self.get_env_variable("AWS_REGION", region, default="us-west-1") + self.input_type = input_type - self.session = session or boto3.Session() - if self.session.get_credentials() is None: - raise ValueError("Could not get AWS session") - self.region = region or self.session.region_name - if self.region is None: - raise ValueError("No AWS region provided") + try: - self.client = self.session.client( - service_name="bedrock-runtime", region_name=str(self.region) + self.client = self._initialize_client( + self.access_key_id, + self.secret_access_key, + self.session_token, + self.region, ) + except Exception as e: raise ValueError(f"Bedrock client failed to initialise. Error: {e}") from e + def _initialize_client( + self, access_key_id, secret_access_key, session_token, region + ): + """Initializes the Bedrock client. + + Args: + access_key_id: The Amazon access key ID. + secret_access_key: The Amazon secret key. + region: The AWS region of the Bedrock resources. + + Returns: + A boto3 `bedrock-runtime` client instance. + + Raises: + ImportError: If the required Bedrock libraries are not + installed. + ValueError: If the Bedrock client fails to initialize. 
+ """ + try: + import boto3 + except ImportError: + raise ImportError( + "Please install Amazon's Boto3 client library to use the BedrockEncoder. " + "You can install them with: " + "`pip install boto3`" + ) + + access_key_id = access_key_id or os.getenv("access_key_id") + aws_secret_key = secret_access_key or os.getenv("secret_access_key") + region = region or os.getenv("AWS_REGION", "us-west-2") + + if access_key_id is None: + raise ValueError("AWS access key ID cannot be 'None'.") + + if aws_secret_key is None: + raise ValueError("AWS secret access key cannot be 'None'.") + + try: + bedrock_client = boto3.client( + "bedrock-runtime", + aws_access_key_id=access_key_id, + aws_secret_access_key=secret_access_key, + aws_session_token=session_token, + region_name=region, + ) + except Exception as err: + raise ValueError( + f"The Bedrock client failed to initialize. Error: {err}" + ) from err + + return bedrock_client + def __call__(self, docs: List[str]) -> List[List[float]]: + """Generates embeddings for the given documents. + + Args: + docs: A list of strings representing the documents to embed. + + Returns: + A list of lists, where each inner list contains the embedding values for a + document. + + Raises: + ValueError: If the Bedrock Platform client is not initialized or if the + API call fails. + """ if self.client is None: raise ValueError("Bedrock client is not initialised.") try: embeddings = [] - if "amazon" in self.name: + + def chunk_strings(strings, MAX_WORDS=20): + """ + Breaks up each string in a list into smaller chunks of tokens. + + Args: + strings (list): A list of strings to be chunked. + MAX_WORDS (int): The maximum number of tokens in each chunk. Default is 20. + + Returns: + list: A list of lists, where each inner list contains the chunks of one input string. 
+ """ + encoding = tiktoken.get_encoding("cl100k_base") + chunked_strings = [] + current_chunk = [] + + for text in strings: + encoded_text = encoding.encode(text) + + if len(encoded_text) > MAX_WORDS: + current_chunk = [ + encoding.decode(encoded_text[i : i + MAX_WORDS]) + for i in range(0, len(encoded_text), MAX_WORDS) + ] + else: + current_chunk = [encoding.decode(encoded_text)] + + chunked_strings.append(current_chunk) + return chunked_strings + + if self.name and "amazon" in self.name: for doc in docs: - doc = json.dumps( + embedding_body = json.dumps( { "inputText": doc, } ) response = self.client.invoke_model( - body=doc, + body=embedding_body, modelId=self.name, - accept="*/*", + accept="application/json", contentType="application/json", ) response_body = json.loads(response.get("body").read()) + embeddings.append(response_body.get("embedding")) + elif self.name and "cohere" in self.name: + chunked_docs = chunk_strings(docs) + for chunk in chunked_docs: + chunk = json.dumps({"texts": chunk, "input_type": self.input_type}) - embedding = response_body.get("embedding") - embeddings.append(embedding) - elif "cohere" in self.name: - MAX_WORDS = 400 - for doc in docs: - words = doc.split() - if len(words) > MAX_WORDS: - chunks = [ - " ".join(words[i : i + MAX_WORDS]) - for i in range(0, len(words), MAX_WORDS) - ] - else: - chunks = [doc] - - for chunk in chunks: - chunk = json.dumps( - {"texts": [chunk], "input_type": self.input_type} - ) - - response = self.client.invoke_model( - body=chunk, - modelId=self.name, - accept="*/*", - contentType="application/json", - ) + response = self.client.invoke_model( + body=chunk, + modelId=self.name, + accept="*/*", + contentType="application/json", + ) - response_body = json.loads(response.get("body").read()) + response_body = json.loads(response.get("body").read()) - chunk_embeddings = response_body.get("embeddings") - embeddings.extend(chunk_embeddings) + chunk_embeddings = response_body.get("embeddings") + 
embeddings.extend(chunk_embeddings) else: raise ValueError("Unknown model name") return embeddings except Exception as e: raise ValueError(f"Bedrock call failed. Error: {e}") from e + + @staticmethod + def get_env_variable(var_name, provided_value, default=None): + """Retrieves environment variable or uses a provided value. + + Args: + var_name (str): The name of the environment variable. + provided_value (Optional[str]): The provided value to use if not None. + default (Optional[str]): The default value if the environment variable is not set. + + Returns: + str: The value of the environment variable or the provided/default value. + + Raises: + ValueError: If no value is provided and the environment variable is not set. + """ + if provided_value is not None: + return provided_value + value = os.getenv(var_name, default) + if value is None: + raise ValueError(f"No {var_name} provided") + return value diff --git a/tests/unit/encoders/test_bedrock.py b/tests/unit/encoders/test_bedrock.py index 6d43882414c3a2b425084d034bbfb619cc8da334..43955d453c704bc95fcce598b94e0a46479947a0 100644 --- a/tests/unit/encoders/test_bedrock.py +++ b/tests/unit/encoders/test_bedrock.py @@ -6,39 +6,40 @@ from semantic_router.encoders import BedrockEncoder @pytest.fixture def bedrock_encoder(mocker): - mocker.patch("boto3.Session") - mocker.patch("boto3.Session.client") - return BedrockEncoder() + mocker.patch("semantic_router.encoders.bedrock.BedrockEncoder._initialize_client") + return BedrockEncoder( + access_key_id="fake_id", + secret_access_key="fake_secret", + session_token="fake_token", + region="us-west-2", + ) class TestBedrockEncoder: def test_initialisation_with_default_values(self, bedrock_encoder): - assert bedrock_encoder.client is not None, "Client should be initialised" - assert bedrock_encoder.type == "bedrock", "Default type not set correctly" assert ( bedrock_encoder.input_type == "search_query" ), "Default input type not set correctly" - assert bedrock_encoder.session is 
not None, "Session should be initialised" - assert bedrock_encoder.region is not None, "Region should be initialised" + assert bedrock_encoder.region == "us-west-2", "Region should be initialised" def test_initialisation_with_custom_values(self, mocker): - mocker.patch("boto3.Session") - mocker.patch("boto3.Session.client") + # mocker.patch( + # "semantic_router.encoders.bedrock.BedrockEncoder._initialize_client" + # ) name = "custom_model" - session = mocker.Mock() - region = "us-west-2" score_threshold = 0.5 input_type = "custom_input" bedrock_encoder = BedrockEncoder( name=name, - session=session, - region=region, score_threshold=score_threshold, input_type=input_type, + access_key_id="fake_id", + secret_access_key="fake_secret", + session_token="fake_token", + region="us-west-2", ) assert bedrock_encoder.name == name, "Custom name not set correctly" - assert bedrock_encoder.session == session, "Custom session not set correctly" - assert bedrock_encoder.region == region, "Custom region not set correctly" + assert bedrock_encoder.region == "us-west-2", "Custom region not set correctly" assert ( bedrock_encoder.score_threshold == score_threshold ), "Custom score threshold not set correctly" @@ -49,28 +50,9 @@ class TestBedrockEncoder: def test_call_method(self, bedrock_encoder): response_content = json.dumps({"embedding": [0.1, 0.2, 0.3]}) response_body = BytesIO(response_content.encode("utf-8")) - - mock_response = {"body": response_body} - bedrock_encoder.client.invoke_model.return_value = mock_response - - result = bedrock_encoder(["test"]) - - assert isinstance(result, list), "Result should be a list" - assert all( - isinstance(item, list) for item in result - ), "Each item in result should be a list" - assert result == [[0.1, 0.2, 0.3]], "Embedding should be [0.1, 0.2, 0.3]" - - def test_returns_list_of_embeddings_for_valid_input(self, bedrock_encoder): - response_content = json.dumps({"embedding": [0.1, 0.2, 0.3]}) - response_body = 
BytesIO(response_content.encode("utf-8")) - mock_response = {"body": response_body} - bedrock_encoder.client.invoke_model.return_value = mock_response - result = bedrock_encoder(["test"]) - assert isinstance(result, list), "Result should be a list" assert all( isinstance(item, list) for item in result @@ -78,9 +60,17 @@ class TestBedrockEncoder: assert result == [[0.1, 0.2, 0.3]], "Embedding should be [0.1, 0.2, 0.3]" def test_raises_value_error_if_client_is_not_initialised(self, mocker): - mocker.patch("boto3.Session.client", return_value=None) + mocker.patch( + "semantic_router.encoders.bedrock.BedrockEncoder._initialize_client", + side_effect=Exception("Client initialisation failed"), + ) with pytest.raises(ValueError): - BedrockEncoder() + BedrockEncoder( + access_key_id="fake_id", + secret_access_key="fake_secret", + session_token="fake_token", + region="us-west-2", + ) def test_raises_value_error_if_call_to_bedrock_fails(self, bedrock_encoder): bedrock_encoder.client.invoke_model.side_effect = Exception( @@ -89,51 +79,17 @@ class TestBedrockEncoder: with pytest.raises(ValueError): bedrock_encoder(["test"]) - def test_raises_value_error_if_no_aws_session_credentials(self, mocker): - mocker.patch("boto3.Session") - mock_session = mocker.Mock() - mock_session.get_credentials.return_value = None - with pytest.raises(ValueError, match="Could not get AWS session"): - BedrockEncoder(session=mock_session) - - def test_raises_value_error_if_no_aws_region(self, mocker): - mocker.patch("boto3.Session") - mock_session = mocker.Mock() - mock_session.region_name = None - with pytest.raises(ValueError, match="No AWS region provided"): - BedrockEncoder(session=mock_session) - - def test_raises_value_error_if_client_initialisation_fails(self, mocker): - mocker.patch("boto3.Session") - mock_session = mocker.Mock() - mock_session.client.side_effect = Exception("Client initialisation failed") - with pytest.raises(ValueError, match="Bedrock client failed to initialise"): - 
BedrockEncoder(session=mock_session) - - def test_raises_value_error_for_unknown_model_name(self, mocker): - mocker.patch("boto3.Session") - mock_session = mocker.Mock() - mock_session.get_credentials.return_value = True - mocker.patch("boto3.Session.client") - - unknown_model_name = "unknown_model" - bedrock_encoder = BedrockEncoder( - name=unknown_model_name, - session=mock_session, - region="us-west-2", - ) - - with pytest.raises(ValueError, match="Unknown model name"): - bedrock_encoder(["test"]) - @pytest.fixture def bedrock_encoder_with_cohere(mocker): - mocker.patch("boto3.Session") - mock_session = mocker.Mock() - mock_session.get_credentials.return_value = True - mocker.patch("boto3.Session.client") - return BedrockEncoder(name="cohere_model", session=mock_session, region="us-west-2") + mocker.patch("semantic_router.encoders.bedrock.BedrockEncoder._initialize_client") + return BedrockEncoder( + name="cohere_model", + access_key_id="fake_id", + secret_access_key="fake_secret", + session_token="fake_token", + region="us-west-2", + ) class TestBedrockEncoderWithCohere: @@ -141,11 +97,8 @@ class TestBedrockEncoderWithCohere: response_content = json.dumps({"embeddings": [[0.1, 0.2, 0.3]]}) response_body = BytesIO(response_content.encode("utf-8")) mock_response = {"body": response_body} - bedrock_encoder_with_cohere.client.invoke_model.return_value = mock_response - result = bedrock_encoder_with_cohere(["short test"]) - assert isinstance(result, list), "Result should be a list" assert all( isinstance(item, list) for item in result @@ -157,10 +110,7 @@ class TestBedrockEncoderWithCohere: response_content = json.dumps({"embeddings": [[0.1, 0.2, 0.3]]}) response_body = BytesIO(response_content.encode("utf-8")) mock_response = {"body": response_body} - bedrock_encoder_with_cohere.client.invoke_model.return_value = mock_response - result = bedrock_encoder_with_cohere(["test with different input type"]) - assert isinstance(result, list), "Result should be a list" 
assert result == [[0.1, 0.2, 0.3]], "Expected specific embeddings"