From 0116261799dc85053cebf00f674ef5321cd17110 Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Mon, 6 Nov 2023 17:35:35 +0400
Subject: [PATCH] Updated simple_classify

It now takes query results as an argument and also outputs
scores_by_category, for debugging purposes.
---
 decision_layer/decision_layer.py |  10 +-
 walkthrough.ipynb                | 190 ++++++++++++++++++++++++++++---
 2 files changed, 176 insertions(+), 24 deletions(-)

diff --git a/decision_layer/decision_layer.py b/decision_layer/decision_layer.py
index 8a88f4be..2dbfe4f6 100644
--- a/decision_layer/decision_layer.py
+++ b/decision_layer/decision_layer.py
@@ -18,7 +18,7 @@ class DecisionLayer:
 
     def __call__(self, text: str):
         results = self._query(text)
-        decision = self.simple_categorise(results)
+        decision = self.simple_classify(results)
         # return decision
         raise NotImplementedError("To implement decision logic based on scores")
 
@@ -60,14 +60,12 @@ class DecisionLayer:
             {"decision": d, "score": s.item()} for d, s in zip(decisions, scores)
         ]
 
-    def simple_categorise(self, text: str, top_k: int=5, apply_tan: bool=True):
-        """Given some text, categorises it based on the scores from _query."""
-        # get the results from _query
-        results = self._query(text, top_k)
+    def simple_classify(self, query_results: list, apply_tan: bool=True):
+        """Given results from _query, returns the top category and the scores by category."""
 
         # apply the scoring system to the results and group by category
         scores_by_category = {}
-        for result in results:
+        for result in query_results:
             score = np.tan(result['score'] * (np.pi / 2)) if apply_tan else result['score']
             if result['decision'] in scores_by_category:
                 scores_by_category[result['decision']] += score
@@ -78,6 +76,6 @@ class DecisionLayer:
         sorted_categories = sorted(scores_by_category.items(), key=lambda x: x[1], reverse=True)
 
         # return the category with the highest total score
-        return sorted_categories[0][0] if sorted_categories else None
+        return (sorted_categories[0][0] if sorted_categories else None), scores_by_category
 
diff --git a/walkthrough.ipynb b/walkthrough.ipynb
index 700f577c..04614ab0 100644
--- a/walkthrough.ipynb
+++ b/walkthrough.ipynb
@@ -26,9 +26,19 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 14,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "[notice] A new release of pip is available: 23.1.2 -> 23.3.1\n",
+      "[notice] To update, run: python.exe -m pip install --upgrade pip\n"
+     ]
+    }
+   ],
    "source": [
     "!pip install -qU \\\n",
    "    decision-layer"
   ]
  },
@@ -44,7 +54,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 15,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -58,7 +68,8 @@
     "    \"don't you just love the president\",\n",
     "    \"don't you just hate the president\",\n",
     "    \"they're going to destroy this country!\",\n",
-    "    \"they will save the country!\"\n",
+    "    \"they will save the country!\",\n",
+    "    \"did you hear about the new government proposal regarding the ownership of cats and dogs\",\n",
     "  ]\n",
     ")"
    ]
   },
@@ -72,7 +83,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 16,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -83,7 +94,8 @@
     "    \"how are things going?\",\n",
     "    \"lovely weather today\",\n",
     "    \"the weather is horrendous\",\n",
-    "    \"let's go to the chippy\"\n",
+    "    \"let's go to the chippy\",\n",
+    "    \"it's raining cats and dogs\",\n",
     "  ]\n",
     ")\n",
     "\n",
@@ -99,14 +111,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 17,
    "metadata": {},
    "outputs": [],
    "source": [
     "from decision_layer.encoders import OpenAIEncoder\n",
     "import os\n",
     "\n",
-    "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\"\n",
     "encoder = OpenAIEncoder(name=\"text-embedding-ada-002\")"
    ]
   },
@@ -119,7 +130,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 18,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -130,20 +141,20 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 26,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "[{'decision': 'politics', 'score': 0.24968127755063652},\n",
-       " {'decision': 'politics', 'score': 0.2536216026530966},\n",
-       " {'decision': 'politics', 'score': 0.27568433588684954},\n",
-       " {'decision': 'politics', 'score': 0.27732789989574913},\n",
-       " {'decision': 'politics', 'score': 0.28110307885950714}]"
+       "[{'decision': 'politics', 'score': 0.22792677421560453},\n",
+       " {'decision': 'politics', 'score': 0.2315237823644528},\n",
+       " {'decision': 'politics', 'score': 0.2516642096551168},\n",
+       " {'decision': 'politics', 'score': 0.2531645714220874},\n",
+       " {'decision': 'politics', 'score': 0.2566108224655662}]"
      ]
     },
-    "execution_count": 5,
+    "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
@@ -158,7 +169,150 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "---"
+    "Using the most similar `Decision` `utterances` and their `cosine similarity scores`, `simple_classify` applies a secondary scoring system that chooses the `decision` the utterance belongs to."
    ]
-  }
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Here we use `apply_tan=True`, which means that a `tan` function is applied to each score, boosting the scores of `decisions` whose datapoints had greater `cosine similarity` and reducing the scores of those whose datapoints had lower `cosine similarity`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'politics'"
+      ]
+     },
+     "execution_count": 27,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "decision, scores_by_category = dl.simple_classify(query_results=out, apply_tan=True)\n",
+    "decision"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'politics': 2.018519173992354}"
+      ]
+     },
+     "execution_count": 28,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "scores_by_category"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The correct category was chosen. Let's try again for a less clear-cut case:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[{'decision': 'chitchat', 'score': 0.22320888353212376},\n",
+       " {'decision': 'politics', 'score': 0.22367029584935166},\n",
+       " {'decision': 'politics', 'score': 0.2274250403127478},\n",
+       " {'decision': 'politics', 'score': 0.23451692377042876},\n",
+       " {'decision': 'chitchat', 'score': 0.24924083653953585}]"
+      ]
+     },
+     "execution_count": 31,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "out = dl._query(\"i love cats and dogs!\")\n",
+    "out"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'politics'"
+      ]
+     },
+     "execution_count": 32,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "decision, scores_by_category = dl.simple_classify(query_results=out, apply_tan=True)\n",
+    "decision"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 33,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'chitchat': 0.7785435459589187, 'politics': 1.1258003022715952}"
+      ]
+     },
+     "execution_count": 33,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "scores_by_category"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 34,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "array(['politics', 'politics', 'politics', 'politics', 'politics',\n",
+       "       'politics', 'chitchat', 'chitchat', 'chitchat', 'chitchat',\n",
+       "       'chitchat', 'chitchat'], dtype='<U8')"
+      ]
+     },
+     "execution_count": 34,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "dl.categories"
+   ]
+  }
  ],
@@ -178,7 +332,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.11.5"
+  "version": "3.11.4"
  },
  "orig_nbformat": 4
 },
--
GitLab
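
Note: for reference, a minimal standalone sketch of the scoring scheme that
simple_classify implements in this patch. It is illustrative only: the classify
helper name and the rounded sample scores are not part of the patch; only the
tan-based aggregation mirrors the patched method above.

import numpy as np

# Sketch of the tan-based scoring used by simple_classify. The helper name
# and the sample scores below are illustrative, not taken from the patch.
def classify(query_results: list, apply_tan: bool = True):
    scores_by_category = {}
    for result in query_results:
        # tan(x * pi/2) maps [0, 1) onto [0, inf), so similarities near 1
        # are boosted far more than similarities near 0.
        score = np.tan(result["score"] * (np.pi / 2)) if apply_tan else result["score"]
        scores_by_category[result["decision"]] = (
            scores_by_category.get(result["decision"], 0.0) + score
        )
    # rank categories by their summed scores, highest first
    sorted_categories = sorted(scores_by_category.items(), key=lambda x: x[1], reverse=True)
    return (sorted_categories[0][0] if sorted_categories else None), scores_by_category

# Scores rounded from the notebook's ambiguous "i love cats and dogs!" query:
results = [
    {"decision": "chitchat", "score": 0.223},
    {"decision": "politics", "score": 0.224},
    {"decision": "politics", "score": 0.227},
    {"decision": "politics", "score": 0.235},
    {"decision": "chitchat", "score": 0.249},
]
decision, scores = classify(results)
print(decision, scores)  # politics wins: three mid-range scores outweigh chitchat's two

The tan mapping is the design point: because tan(x * pi/2) diverges as x
approaches 1, one utterance that is very close to the query can outvote several
weak matches instead of being averaged away.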