From b2dd6e39f0570cb748f17fe8f0bab03734b701f2 Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood <siraj@aurelio.ai> Date: Wed, 8 Nov 2023 23:19:44 +0400 Subject: [PATCH] Added in more utterances of type 'other' ('NULL') These now match the number of non-other types. --- 00_performance_tests.ipynb | 180 ++++++++++++++++++++++--------- decision_layer/decision_layer.py | 2 +- 2 files changed, 129 insertions(+), 53 deletions(-) diff --git a/00_performance_tests.ipynb b/00_performance_tests.ipynb index e59e8823..18ef8571 100644 --- a/00_performance_tests.ipynb +++ b/00_performance_tests.ipynb @@ -151,29 +151,9 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 3, "metadata": {}, - "outputs": [ - { - "ename": "APIError", - "evalue": "Bad gateway. {\"error\":{\"code\":502,\"message\":\"Bad gateway.\",\"param\":null,\"type\":\"cf_bad_gateway\"}} 502 {'error': {'code': 502, 'message': 'Bad gateway.', 'param': None, 'type': 'cf_bad_gateway'}} {'Date': 'Wed, 08 Nov 2023 13:48:30 GMT', 'Content-Type': 'application/json', 'Content-Length': '84', 'Connection': 'keep-alive', 'X-Frame-Options': 'SAMEORIGIN', 'Referrer-Policy': 'same-origin', 'Cache-Control': 'private, max-age=0, no-store, no-cache, must-revalidate, post-check=0, pre-check=0', 'Expires': 'Thu, 01 Jan 1970 00:00:01 GMT', 'Server': 'cloudflare', 'CF-RAY': '822e45f93ceeb482-DXB', 'alt-svc': 'h3=\":443\"; ma=86400'}", - "output_type": "error", - "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mAPIError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32mc:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\20231106 Semantic Layer\\Repo\\semantic-layer\\00_performance_tests.ipynb Cell 3\u001b[0m line \u001b[0;36m1\n\u001b[0;32m <a href='vscode-notebook-cell:/c%3A/Users/Siraj/Documents/Personal/Work/Aurelio/20231106%20Semantic%20Layer/Repo/semantic-layer/00_performance_tests.ipynb#Y135sZmlsZQ%3D%3D?line=0'>1</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mdecision_layer\u001b[39;00m \u001b[39mimport\u001b[39;00m DecisionLayer\n\u001b[0;32m <a href='vscode-notebook-cell:/c%3A/Users/Siraj/Documents/Personal/Work/Aurelio/20231106%20Semantic%20Layer/Repo/semantic-layer/00_performance_tests.ipynb#Y135sZmlsZQ%3D%3D?line=2'>3</a>\u001b[0m decisions \u001b[39m=\u001b[39m [\n\u001b[0;32m <a href='vscode-notebook-cell:/c%3A/Users/Siraj/Documents/Personal/Work/Aurelio/20231106%20Semantic%20Layer/Repo/semantic-layer/00_performance_tests.ipynb#Y135sZmlsZQ%3D%3D?line=3'>4</a>\u001b[0m politics,\n\u001b[0;32m <a href='vscode-notebook-cell:/c%3A/Users/Siraj/Documents/Personal/Work/Aurelio/20231106%20Semantic%20Layer/Repo/semantic-layer/00_performance_tests.ipynb#Y135sZmlsZQ%3D%3D?line=4'>5</a>\u001b[0m other_brands,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m <a href='vscode-notebook-cell:/c%3A/Users/Siraj/Documents/Personal/Work/Aurelio/20231106%20Semantic%20Layer/Repo/semantic-layer/00_performance_tests.ipynb#Y135sZmlsZQ%3D%3D?line=9'>10</a>\u001b[0m mathematics,\n\u001b[0;32m <a href='vscode-notebook-cell:/c%3A/Users/Siraj/Documents/Personal/Work/Aurelio/20231106%20Semantic%20Layer/Repo/semantic-layer/00_performance_tests.ipynb#Y135sZmlsZQ%3D%3D?line=10'>11</a>\u001b[0m ]\n\u001b[1;32m---> <a href='vscode-notebook-cell:/c%3A/Users/Siraj/Documents/Personal/Work/Aurelio/20231106%20Semantic%20Layer/Repo/semantic-layer/00_performance_tests.ipynb#Y135sZmlsZQ%3D%3D?line=12'>13</a>\u001b[0m dl \u001b[39m=\u001b[39m DecisionLayer(encoder\u001b[39m=\u001b[39;49mencoder, decisions\u001b[39m=\u001b[39;49mdecisions)\n", - "File \u001b[1;32mc:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\20231106 Semantic Layer\\Repo\\semantic-layer\\decision_layer\\decision_layer.py:16\u001b[0m, in \u001b[0;36mDecisionLayer.__init__\u001b[1;34m(self, encoder, decisions)\u001b[0m\n\u001b[0;32m 13\u001b[0m \u001b[39mif\u001b[39;00m decisions:\n\u001b[0;32m 14\u001b[0m \u001b[39m# initialize index now\u001b[39;00m\n\u001b[0;32m 15\u001b[0m \u001b[39mfor\u001b[39;00m decision \u001b[39min\u001b[39;00m decisions:\n\u001b[1;32m---> 16\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_add_decision(decision\u001b[39m=\u001b[39;49mdecision)\n", - "File \u001b[1;32mc:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\20231106 Semantic Layer\\Repo\\semantic-layer\\decision_layer\\decision_layer.py:29\u001b[0m, in \u001b[0;36mDecisionLayer._add_decision\u001b[1;34m(self, decision)\u001b[0m\n\u001b[0;32m 27\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_add_decision\u001b[39m(\u001b[39mself\u001b[39m, decision: Decision):\n\u001b[0;32m 28\u001b[0m \u001b[39m# create embeddings\u001b[39;00m\n\u001b[1;32m---> 29\u001b[0m embeds \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mencoder(decision\u001b[39m.\u001b[39;49mutterances)\n\u001b[0;32m 31\u001b[0m \u001b[39m# create decision array\u001b[39;00m\n\u001b[0;32m 32\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcategories \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n", - "File \u001b[1;32mc:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\20231106 Semantic Layer\\Repo\\semantic-layer\\decision_layer\\encoders\\openai.py:24\u001b[0m, in \u001b[0;36mOpenAIEncoder.__call__\u001b[1;34m(self, texts)\u001b[0m\n\u001b[0;32m 21\u001b[0m \u001b[39mfor\u001b[39;00m j \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(\u001b[39m5\u001b[39m):\n\u001b[0;32m 22\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m 23\u001b[0m \u001b[39m# create embeddings\u001b[39;00m\n\u001b[1;32m---> 24\u001b[0m res \u001b[39m=\u001b[39m openai\u001b[39m.\u001b[39;49mEmbedding\u001b[39m.\u001b[39;49mcreate(\n\u001b[0;32m 25\u001b[0m \u001b[39minput\u001b[39;49m\u001b[39m=\u001b[39;49mtexts, engine\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mname\n\u001b[0;32m 26\u001b[0m )\n\u001b[0;32m 27\u001b[0m passed \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n\u001b[0;32m 28\u001b[0m \u001b[39mexcept\u001b[39;00m openai\u001b[39m.\u001b[39merror\u001b[39m.\u001b[39mRateLimitError:\n", - "File \u001b[1;32mc:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_layer\\Lib\\site-packages\\openai\\api_resources\\embedding.py:33\u001b[0m, in \u001b[0;36mEmbedding.create\u001b[1;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[0;32m 31\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mTrue\u001b[39;00m:\n\u001b[0;32m 32\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m---> 33\u001b[0m response \u001b[39m=\u001b[39m \u001b[39msuper\u001b[39;49m()\u001b[39m.\u001b[39;49mcreate(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[0;32m 35\u001b[0m \u001b[39m# If a user specifies base64, we'll just return the encoded string.\u001b[39;00m\n\u001b[0;32m 36\u001b[0m \u001b[39m# This is only for the default case.\u001b[39;00m\n\u001b[0;32m 37\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m user_provided_encoding_format:\n", - "File \u001b[1;32mc:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_layer\\Lib\\site-packages\\openai\\api_resources\\abstract\\engine_api_resource.py:155\u001b[0m, in \u001b[0;36mEngineAPIResource.create\u001b[1;34m(cls, api_key, api_base, api_type, request_id, api_version, organization, **params)\u001b[0m\n\u001b[0;32m 129\u001b[0m \u001b[39m@classmethod\u001b[39m\n\u001b[0;32m 130\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mcreate\u001b[39m(\n\u001b[0;32m 131\u001b[0m \u001b[39mcls\u001b[39m,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 138\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mparams,\n\u001b[0;32m 139\u001b[0m ):\n\u001b[0;32m 140\u001b[0m (\n\u001b[0;32m 141\u001b[0m deployment_id,\n\u001b[0;32m 142\u001b[0m engine,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 152\u001b[0m api_key, api_base, api_type, api_version, organization, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mparams\n\u001b[0;32m 153\u001b[0m )\n\u001b[1;32m--> 155\u001b[0m response, _, api_key \u001b[39m=\u001b[39m requestor\u001b[39m.\u001b[39;49mrequest(\n\u001b[0;32m 156\u001b[0m \u001b[39m\"\u001b[39;49m\u001b[39mpost\u001b[39;49m\u001b[39m\"\u001b[39;49m,\n\u001b[0;32m 157\u001b[0m url,\n\u001b[0;32m 158\u001b[0m params\u001b[39m=\u001b[39;49mparams,\n\u001b[0;32m 159\u001b[0m headers\u001b[39m=\u001b[39;49mheaders,\n\u001b[0;32m 160\u001b[0m stream\u001b[39m=\u001b[39;49mstream,\n\u001b[0;32m 161\u001b[0m request_id\u001b[39m=\u001b[39;49mrequest_id,\n\u001b[0;32m 162\u001b[0m request_timeout\u001b[39m=\u001b[39;49mrequest_timeout,\n\u001b[0;32m 163\u001b[0m )\n\u001b[0;32m 165\u001b[0m \u001b[39mif\u001b[39;00m stream:\n\u001b[0;32m 166\u001b[0m \u001b[39m# must be an iterator\u001b[39;00m\n\u001b[0;32m 167\u001b[0m \u001b[39massert\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(response, OpenAIResponse)\n", - "File \u001b[1;32mc:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_layer\\Lib\\site-packages\\openai\\api_requestor.py:299\u001b[0m, in \u001b[0;36mAPIRequestor.request\u001b[1;34m(self, method, url, params, headers, files, stream, request_id, request_timeout)\u001b[0m\n\u001b[0;32m 278\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mrequest\u001b[39m(\n\u001b[0;32m 279\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[0;32m 280\u001b[0m method,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 287\u001b[0m request_timeout: Optional[Union[\u001b[39mfloat\u001b[39m, Tuple[\u001b[39mfloat\u001b[39m, \u001b[39mfloat\u001b[39m]]] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m,\n\u001b[0;32m 288\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], \u001b[39mbool\u001b[39m, \u001b[39mstr\u001b[39m]:\n\u001b[0;32m 289\u001b[0m result \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mrequest_raw(\n\u001b[0;32m 290\u001b[0m method\u001b[39m.\u001b[39mlower(),\n\u001b[0;32m 291\u001b[0m url,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 297\u001b[0m request_timeout\u001b[39m=\u001b[39mrequest_timeout,\n\u001b[0;32m 298\u001b[0m )\n\u001b[1;32m--> 299\u001b[0m resp, got_stream \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_interpret_response(result, stream)\n\u001b[0;32m 300\u001b[0m \u001b[39mreturn\u001b[39;00m resp, got_stream, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mapi_key\n", - "File \u001b[1;32mc:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_layer\\Lib\\site-packages\\openai\\api_requestor.py:710\u001b[0m, in \u001b[0;36mAPIRequestor._interpret_response\u001b[1;34m(self, result, stream)\u001b[0m\n\u001b[0;32m 702\u001b[0m \u001b[39mreturn\u001b[39;00m (\n\u001b[0;32m 703\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_interpret_response_line(\n\u001b[0;32m 704\u001b[0m line, result\u001b[39m.\u001b[39mstatus_code, result\u001b[39m.\u001b[39mheaders, stream\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m\n\u001b[0;32m 705\u001b[0m )\n\u001b[0;32m 706\u001b[0m \u001b[39mfor\u001b[39;00m line \u001b[39min\u001b[39;00m parse_stream(result\u001b[39m.\u001b[39miter_lines())\n\u001b[0;32m 707\u001b[0m ), \u001b[39mTrue\u001b[39;00m\n\u001b[0;32m 708\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m 709\u001b[0m \u001b[39mreturn\u001b[39;00m (\n\u001b[1;32m--> 710\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_interpret_response_line(\n\u001b[0;32m 711\u001b[0m result\u001b[39m.\u001b[39;49mcontent\u001b[39m.\u001b[39;49mdecode(\u001b[39m\"\u001b[39;49m\u001b[39mutf-8\u001b[39;49m\u001b[39m\"\u001b[39;49m),\n\u001b[0;32m 712\u001b[0m result\u001b[39m.\u001b[39;49mstatus_code,\n\u001b[0;32m 713\u001b[0m result\u001b[39m.\u001b[39;49mheaders,\n\u001b[0;32m 714\u001b[0m stream\u001b[39m=\u001b[39;49m\u001b[39mFalse\u001b[39;49;00m,\n\u001b[0;32m 715\u001b[0m ),\n\u001b[0;32m 716\u001b[0m \u001b[39mFalse\u001b[39;00m,\n\u001b[0;32m 717\u001b[0m )\n", - "File \u001b[1;32mc:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_layer\\Lib\\site-packages\\openai\\api_requestor.py:775\u001b[0m, in \u001b[0;36mAPIRequestor._interpret_response_line\u001b[1;34m(self, rbody, rcode, rheaders, stream)\u001b[0m\n\u001b[0;32m 773\u001b[0m stream_error \u001b[39m=\u001b[39m stream \u001b[39mand\u001b[39;00m \u001b[39m\"\u001b[39m\u001b[39merror\u001b[39m\u001b[39m\"\u001b[39m \u001b[39min\u001b[39;00m resp\u001b[39m.\u001b[39mdata\n\u001b[0;32m 774\u001b[0m \u001b[39mif\u001b[39;00m stream_error \u001b[39mor\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39m200\u001b[39m \u001b[39m<\u001b[39m\u001b[39m=\u001b[39m rcode \u001b[39m<\u001b[39m \u001b[39m300\u001b[39m:\n\u001b[1;32m--> 775\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mhandle_error_response(\n\u001b[0;32m 776\u001b[0m rbody, rcode, resp\u001b[39m.\u001b[39mdata, rheaders, stream_error\u001b[39m=\u001b[39mstream_error\n\u001b[0;32m 777\u001b[0m )\n\u001b[0;32m 778\u001b[0m \u001b[39mreturn\u001b[39;00m resp\n", - "\u001b[1;31mAPIError\u001b[0m: Bad gateway. {\"error\":{\"code\":502,\"message\":\"Bad gateway.\",\"param\":null,\"type\":\"cf_bad_gateway\"}} 502 {'error': {'code': 502, 'message': 'Bad gateway.', 'param': None, 'type': 'cf_bad_gateway'}} {'Date': 'Wed, 08 Nov 2023 13:48:30 GMT', 'Content-Type': 'application/json', 'Content-Length': '84', 'Connection': 'keep-alive', 'X-Frame-Options': 'SAMEORIGIN', 'Referrer-Policy': 'same-origin', 'Cache-Control': 'private, max-age=0, no-store, no-cache, must-revalidate, post-check=0, pre-check=0', 'Expires': 'Thu, 01 Jan 1970 00:00:01 GMT', 'Server': 'cloudflare', 'CF-RAY': '822e45f93ceeb482-DXB', 'alt-svc': 'h3=\":443\"; ma=86400'}" - ] - } - ], + "outputs": [], "source": [ "from decision_layer import DecisionLayer\n", "\n", @@ -192,7 +172,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -343,45 +323,141 @@ " (\"Do you like to watch TV?\", \"NULL\"),\n", " (\"What's your favorite type of cuisine?\", \"NULL\"),\n", " (\"Do you like to play video games?\", \"NULL\"),\n", + " (\"What's the weather like today?\", \"NULL\"),\n", + " (\"Tell me a fun fact.\", \"NULL\"),\n", + " (\"What's the time?\", \"NULL\"),\n", + " (\"Can you recommend a good book?\", \"NULL\"),\n", + " (\"What's the latest news?\", \"NULL\"),\n", + " (\"Tell me a story.\", \"NULL\"),\n", + " (\"What's your favorite joke?\", \"NULL\"),\n", + " (\"Can you play music?\", \"NULL\"),\n", + " (\"What's the capital of France?\", \"NULL\"),\n", + " (\"Who won the last World Cup?\", \"NULL\"),\n", + " (\"What's the tallest mountain in the world?\", \"NULL\"),\n", + " (\"Who is the current president of the United States?\", \"NULL\"),\n", + " (\"What's the distance to the moon?\", \"NULL\"),\n", + " (\"Can you set a reminder for me?\", \"NULL\"),\n", + " (\"What's the meaning of life?\", \"NULL\"),\n", + " (\"Can you tell me a riddle?\", \"NULL\"),\n", + " (\"What's the population of China?\", \"NULL\"),\n", + " (\"Who wrote 'To Kill a Mockingbird'?\", \"NULL\"),\n", + " (\"What's the longest river in the world?\", \"NULL\"),\n", + " (\"Can you translate 'hello' to Spanish?\", \"NULL\"),\n", + " (\"What's the speed of light?\", \"NULL\"),\n", + " (\"Who invented the telephone?\", \"NULL\"),\n", + " (\"What's the currency of Japan?\", \"NULL\"),\n", + " (\"Who painted the Mona Lisa?\", \"NULL\"),\n", + " (\"What's the largest ocean in the world?\", \"NULL\"),\n", + " (\"Who is the richest person in the world?\", \"NULL\"),\n", + " (\"What's the national animal of Australia?\", \"NULL\"),\n", + " (\"Who discovered gravity?\", \"NULL\"),\n", + " (\"What's the lifespan of a turtle?\", \"NULL\"),\n", + " (\"Can you tell me a tongue twister?\", \"NULL\"),\n", + " (\"What's the national flower of India?\", \"NULL\"),\n", + " (\"Who is the author of 'Harry Potter'?\", \"NULL\"),\n", + " (\"What's the diameter of the Earth?\", \"NULL\"),\n", + " (\"Who was the first person to climb Mount Everest?\", \"NULL\"),\n", + " (\"What's the national bird of the United States?\", \"NULL\"),\n", + " (\"Who is the CEO of Tesla?\", \"NULL\"),\n", + " (\"What's the highest grossing movie of all time?\", \"NULL\"),\n", + " (\"Can you tell me a nursery rhyme?\", \"NULL\"),\n", + " (\"What's the national sport of Canada?\", \"NULL\"),\n", + " (\"Who is the Prime Minister of the United Kingdom?\", \"NULL\"),\n", + " (\"What's the deepest part of the ocean?\", \"NULL\"),\n", + " (\"Who composed the Fifth Symphony?\", \"NULL\"),\n", + " (\"What's the largest country in the world?\", \"NULL\"),\n", + " (\"Who is the fastest man in the world?\", \"NULL\"),\n", + " (\"What's the national dish of Spain?\", \"NULL\"),\n", + " (\"Who won the Nobel Prize in Literature last year?\", \"NULL\"),\n", + " (\"What's the smallest planet in the solar system?\", \"NULL\"),\n", + " (\"Who is the current Pope?\", \"NULL\"),\n", + " (\"What's the national anthem of France?\", \"NULL\"),\n", + " (\"Who was the first man on the moon?\", \"NULL\"),\n", + " (\"What's the oldest civilization in the world?\", \"NULL\"),\n", + " (\"Who is the most followed person on Instagram?\", \"NULL\"),\n", + " (\"What's the most spoken language in the world?\", \"NULL\"),\n", + " (\"Who is the director of 'Inception'?\", \"NULL\"),\n", + " (\"What's the national fruit of New Zealand?\", \"NULL\"),\n", + " (\"What's the weather like in London?\", \"NULL\"),\n", + " (\"Can you tell me a fun fact about cats?\", \"NULL\"),\n", + " (\"What's the current time in Tokyo?\", \"NULL\"),\n", + " (\"Can you recommend a good movie?\", \"NULL\"),\n", + " (\"What's the latest sports news?\", \"NULL\"),\n", + " (\"Tell me an interesting historical event.\", \"NULL\"),\n", + " (\"What's your favorite science fact?\", \"NULL\"),\n", + " (\"Can you play a trivia game?\", \"NULL\"),\n", + " (\"What's the capital of Sweden?\", \"NULL\"),\n", + " (\"Who won the last season of 'The Voice'?\", \"NULL\"),\n", + " (\"What's the tallest building in the world?\", \"NULL\"),\n", + " (\"Who is the current Prime Minister of Japan?\", \"NULL\"),\n", + " (\"What's the distance to Mars?\", \"NULL\"),\n", + " (\"Can you set an alarm for me?\", \"NULL\"),\n", + " (\"What's the secret to happiness?\", \"NULL\"),\n", + " (\"Can you tell me a brain teaser?\", \"NULL\"),\n", + " (\"What's the population of Brazil?\", \"NULL\"),\n", + " (\"Who wrote '1984'?\", \"NULL\"),\n", + " (\"What's the longest highway in the world?\", \"NULL\"),\n", + " (\"Can you translate 'good morning' to Italian?\", \"NULL\"),\n", + " (\"What's the speed of sound?\", \"NULL\"),\n", + " (\"Who invented the internet?\", \"NULL\"),\n", + " (\"What's the currency of Switzerland?\", \"NULL\"),\n", + " (\"Who sculpted 'The Thinker'?\", \"NULL\"),\n", + " (\"What's the largest continent in the world?\", \"NULL\"),\n", + " (\"Who is the most successful Olympian?\", \"NULL\"),\n", + " (\"What's the national animal of Scotland?\", \"NULL\"),\n", + " (\"Who discovered penicillin?\", \"NULL\"),\n", + " (\"What's the lifespan of a parrot?\", \"NULL\"),\n", + " (\"Can you tell me a palindrome?\", \"NULL\"),\n", + " (\"What's the national flower of the United States?\", \"NULL\"),\n", + " (\"Who is the author of 'The Hobbit'?\", \"NULL\"),\n", + " (\"What's the diameter of Jupiter?\", \"NULL\"),\n", + " (\"Who was the first woman to win a Nobel Prize?\", \"NULL\"),\n", + " (\"What's the national bird of Australia?\", \"NULL\"),\n", + " (\"Who is the CEO of Google?\", \"NULL\"),\n", + " (\"What's the highest grossing book of all time?\", \"NULL\"),\n", + " (\"Can you tell me a limerick?\", \"NULL\"),\n", + " (\"What's the national sport of Japan?\", \"NULL\"),\n", + " (\"Who is the President of Russia?\", \"NULL\"),\n", + " (\"What's the deepest cave in the world?\", \"NULL\"),\n", + " (\"Who composed the 'Four Seasons'?\", \"NULL\"),\n", + " (\"What's the smallest ocean in the world?\", \"NULL\"),\n", + " (\"Who is the fastest woman in the world?\", \"NULL\"),\n", + " (\"What's the national dish of Germany?\", \"NULL\"),\n", + " (\"Who won the Nobel Prize in Peace last year?\", \"NULL\"),\n", + " (\"What's the hottest planet in the solar system?\", \"NULL\"),\n", + " (\"Who is the current Secretary-General of the United Nations?\", \"NULL\"),\n", + " (\"What's the national anthem of Australia?\", \"NULL\"),\n", + " (\"Who was the first woman in the moon?\", \"NULL\"),\n", + " (\"Who is the author of 'The Catcher in the Rye'?\", \"NULL\"),\n", "]" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ + "import statistics\n", + "\n", "def max_threshold_test(threshold: float, scores: list):\n", " return max(scores) > threshold\n", "\n", "\n", "def mean_threshold_test(threshold: float, scores: list):\n", - " return mean(scores) > threshold" + " return statistics.mean(scores) > threshold" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "c:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_layer\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - }, - { - "ename": "NameError", - "evalue": "name 'queries' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32mc:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\20231106 Semantic Layer\\Repo\\semantic-layer\\00_performance_tests.ipynb Cell 6\u001b[0m line \u001b[0;36m8\n\u001b[0;32m <a href='vscode-notebook-cell:/c%3A/Users/Siraj/Documents/Personal/Work/Aurelio/20231106%20Semantic%20Layer/Repo/semantic-layer/00_performance_tests.ipynb#Y141sZmlsZQ%3D%3D?line=4'>5</a>\u001b[0m thresholds \u001b[39m=\u001b[39m [\u001b[39m0.5\u001b[39m, \u001b[39m0.55\u001b[39m, \u001b[39m0.6\u001b[39m, \u001b[39m0.65\u001b[39m, \u001b[39m0.7\u001b[39m, \u001b[39m0.75\u001b[39m, \u001b[39m0.8\u001b[39m, \u001b[39m0.85\u001b[39m, \u001b[39m0.9\u001b[39m, \u001b[39m0.95\u001b[39m]\n\u001b[0;32m <a href='vscode-notebook-cell:/c%3A/Users/Siraj/Documents/Personal/Work/Aurelio/20231106%20Semantic%20Layer/Repo/semantic-layer/00_performance_tests.ipynb#Y141sZmlsZQ%3D%3D?line=5'>6</a>\u001b[0m threshold_method \u001b[39m=\u001b[39m \u001b[39m'\u001b[39m\u001b[39mmax\u001b[39m\u001b[39m'\u001b[39m \u001b[39m# 'mean', 'max'\u001b[39;00m\n\u001b[1;32m----> <a href='vscode-notebook-cell:/c%3A/Users/Siraj/Documents/Personal/Work/Aurelio/20231106%20Semantic%20Layer/Repo/semantic-layer/00_performance_tests.ipynb#Y141sZmlsZQ%3D%3D?line=7'>8</a>\u001b[0m \u001b[39mfor\u001b[39;00m q, expected \u001b[39min\u001b[39;00m tqdm(queries):\n\u001b[0;32m <a href='vscode-notebook-cell:/c%3A/Users/Siraj/Documents/Personal/Work/Aurelio/20231106%20Semantic%20Layer/Repo/semantic-layer/00_performance_tests.ipynb#Y141sZmlsZQ%3D%3D?line=8'>9</a>\u001b[0m \n\u001b[0;32m <a href='vscode-notebook-cell:/c%3A/Users/Siraj/Documents/Personal/Work/Aurelio/20231106%20Semantic%20Layer/Repo/semantic-layer/00_performance_tests.ipynb#Y141sZmlsZQ%3D%3D?line=9'>10</a>\u001b[0m \u001b[39m# Attempt Query 3 Times.\u001b[39;00m\n\u001b[0;32m <a href='vscode-notebook-cell:/c%3A/Users/Siraj/Documents/Personal/Work/Aurelio/20231106%20Semantic%20Layer/Repo/semantic-layer/00_performance_tests.ipynb#Y141sZmlsZQ%3D%3D?line=10'>11</a>\u001b[0m out \u001b[39m=\u001b[39m \u001b[39m'\u001b[39m\u001b[39mUNDEFINED_CLASS\u001b[39m\u001b[39m'\u001b[39m \u001b[39m# Initialize actual_decision here\u001b[39;00m\n\u001b[0;32m <a href='vscode-notebook-cell:/c%3A/Users/Siraj/Documents/Personal/Work/Aurelio/20231106%20Semantic%20Layer/Repo/semantic-layer/00_performance_tests.ipynb#Y141sZmlsZQ%3D%3D?line=11'>12</a>\u001b[0m all_attempts_failed \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m \u001b[39m# Initialize flag here\u001b[39;00m\n", - "\u001b[1;31mNameError\u001b[0m: name 'queries' is not defined" + "100%|██████████| 252/252 [07:25<00:00, 1.77s/it]\n" ] } ], @@ -390,8 +466,9 @@ "import time\n", "\n", "results = {}\n", - "thresholds = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]\n", - "threshold_method = 'max' # 'mean', 'max'\n", + "# thresholds = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]\n", + "thresholds = [0.75, 0.77, 0.79, 0.81, 0.83, 0.85, 0.87, 0.89, 0.91]\n", + "threshold_method = 'mean' # 'mean', 'max'\n", "\n", "for q, expected in tqdm(queries):\n", "\n", @@ -434,23 +511,22 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Threshold: 0.5, Accuracy: 0.8287671232876712\n", - "Threshold: 0.55, Accuracy: 0.8287671232876712\n", - "Threshold: 0.6, Accuracy: 0.8287671232876712\n", - "Threshold: 0.65, Accuracy: 0.8287671232876712\n", - "Threshold: 0.7, Accuracy: 0.8287671232876712\n", - "Threshold: 0.75, Accuracy: 0.8287671232876712\n", - "Threshold: 0.8, Accuracy: 0.8287671232876712\n", - "Threshold: 0.85, Accuracy: 0.8287671232876712\n", - "Threshold: 0.9, Accuracy: 0.8287671232876712\n", - "Threshold: 0.95, Accuracy: 0.8287671232876712\n" + "Threshold: 0.75, Accuracy: 0.5317460317460317\n", + "Threshold: 0.77, Accuracy: 0.6626984126984127\n", + "Threshold: 0.79, Accuracy: 0.8015873015873016\n", + "Threshold: 0.81, Accuracy: 0.8531746031746031\n", + "Threshold: 0.83, Accuracy: 0.7420634920634921\n", + "Threshold: 0.85, Accuracy: 0.6388888888888888\n", + "Threshold: 0.87, Accuracy: 0.5714285714285714\n", + "Threshold: 0.89, Accuracy: 0.5158730158730159\n", + "Threshold: 0.91, Accuracy: 0.503968253968254\n" ] } ], diff --git a/decision_layer/decision_layer.py b/decision_layer/decision_layer.py index 499fc65f..c02e737b 100644 --- a/decision_layer/decision_layer.py +++ b/decision_layer/decision_layer.py @@ -81,4 +81,4 @@ class DecisionLayer: total_scores = {decision: sum(scores) for decision, scores in scores_by_class.items()} top_class = max(total_scores, key=total_scores.get, default=None) # Return the top class and its associated scores - return top_class, scores_by_class.get(top_class, []) \ No newline at end of file + return top_class, scores_by_class.get(top_class, []) -- GitLab