From 0116261799dc85053cebf00f674ef5321cd17110 Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Mon, 6 Nov 2023 17:35:35 +0400
Subject: [PATCH] Updated simple_classify

It now takes query results as an argument and also outputs
scores_by_category, for debugging purposes.
---
 decision_layer/decision_layer.py |  10 +-
 walkthrough.ipynb                | 190 ++++++++++++++++++++++++++++---
 2 files changed, 176 insertions(+), 24 deletions(-)

diff --git a/decision_layer/decision_layer.py b/decision_layer/decision_layer.py
index 8a88f4be..2dbfe4f6 100644
--- a/decision_layer/decision_layer.py
+++ b/decision_layer/decision_layer.py
@@ -18,7 +18,7 @@ class DecisionLayer:
 
     def __call__(self, text: str):
         results = self._query(text)
-        decision = self.simple_categorise(results)
+        decision = self.simple_classify(results)
         # return decision
         raise NotImplementedError("To implement decision logic based on scores")
 
@@ -60,14 +60,12 @@ class DecisionLayer:
             {"decision": d, "score": s.item()} for d, s in zip(decisions, scores)
         ]
 
-    def simple_categorise(self, text: str, top_k: int=5, apply_tan: bool=True):
-        """Given some text, categorises it based on the scores from _query."""
-        # get the results from _query
-        results = self._query(text, top_k)
+    def simple_classify(self, query_results: list, apply_tan: bool=True):
+        """Given results from _query, returns the top category and the scores by category."""
 
         # apply the scoring system to the results and group by category
         scores_by_category = {}
-        for result in results:
+        for result in query_results:
             score = np.tan(result['score'] * (np.pi / 2)) if apply_tan else result['score']
             if result['decision'] in scores_by_category:
                 scores_by_category[result['decision']] += score
@@ -78,6 +76,6 @@ class DecisionLayer:
         sorted_categories = sorted(scores_by_category.items(), key=lambda x: x[1], reverse=True)
 
         # return the category with the highest total score
-        return sorted_categories[0][0] if sorted_categories else None
+        return (sorted_categories[0][0] if sorted_categories else None), scores_by_category
 
diff --git a/walkthrough.ipynb b/walkthrough.ipynb
index 700f577c..04614ab0 100644
--- a/walkthrough.ipynb
+++ b/walkthrough.ipynb
@@ -26,9 +26,19 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 14,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "[notice] A new release of pip is available: 23.1.2 -> 23.3.1\n",
+      "[notice] To update, run: python.exe -m pip install --upgrade pip\n"
+     ]
+    }
+   ],
    "source": [
     "!pip install -qU \\\n",
    "    decision-layer"
   ]
  },
@@ -44,7 +54,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 15,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -58,7 +68,8 @@
     "    \"don't you just love the president\",\n",
     "    \"don't you just hate the president\",\n",
     "    \"they're going to destroy this country!\",\n",
-    "    \"they will save the country!\"\n",
+    "    \"they will save the country!\",\n",
+    "    \"did you hear about the new government proposal regarding the ownership of cats and dogs\",\n",
     "  ]\n",
     ")"
    ]
   },
@@ -72,7 +83,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 16,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -83,7 +94,8 @@
     "    \"how are things going?\",\n",
     "    \"lovely weather today\",\n",
     "    \"the weather is horrendous\",\n",
-    "    \"let's go to the chippy\"\n",
+    "    \"let's go to the chippy\",\n",
+    "    \"it's raining cats and dogs\",\n",
     "  ]\n",
     ")\n",
     "\n",
@@ -99,14 +111,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 17,
    "metadata": {},
    "outputs": [],
    "source": [
     "from decision_layer.encoders import OpenAIEncoder\n",
     "import os\n",
     "\n",
-    "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\"\n",
     "encoder = OpenAIEncoder(name=\"text-embedding-ada-002\")"
    ]
   },
@@ -119,7 +130,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 18,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -130,20 +141,20 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 26,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "[{'decision': 'politics', 'score': 0.24968127755063652},\n",
-       " {'decision': 'politics', 'score': 0.2536216026530966},\n",
-       " {'decision': 'politics', 'score': 0.27568433588684954},\n",
-       " {'decision': 'politics', 'score': 0.27732789989574913},\n",
-       " {'decision': 'politics', 'score': 0.28110307885950714}]"
+       "[{'decision': 'politics', 'score': 0.22792677421560453},\n",
+       " {'decision': 'politics', 'score': 0.2315237823644528},\n",
+       " {'decision': 'politics', 'score': 0.2516642096551168},\n",
+       " {'decision': 'politics', 'score': 0.2531645714220874},\n",
+       " {'decision': 'politics', 'score': 0.2566108224655662}]"
      ]
     },
-    "execution_count": 5,
+    "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
@@ -158,7 +169,150 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "---"
+    "Using the most similar `Decision` `utterances` and their `cosine similarity scores`, `simple_classify` applies a secondary scoring system that chooses the `decision` the utterance belongs to."
    ]
-  }
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Here we use `apply_tan=True`, which means that a `tan` function is applied to each score, boosting the scores of `decisions` whose datapoints had greater `cosine similarity` and reducing the scores of those whose datapoints had lower `cosine similarity`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'politics'"
+      ]
+     },
+     "execution_count": 27,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "decision, scores_by_category = dl.simple_classify(query_results=out, apply_tan=True)\n",
+    "decision"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'politics': 2.018519173992354}"
+      ]
+     },
+     "execution_count": 28,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "scores_by_category"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The correct category was chosen. Let's try again for a less clear-cut case:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[{'decision': 'chitchat', 'score': 0.22320888353212376},\n",
+       " {'decision': 'politics', 'score': 0.22367029584935166},\n",
+       " {'decision': 'politics', 'score': 0.2274250403127478},\n",
+       " {'decision': 'politics', 'score': 0.23451692377042876},\n",
+       " {'decision': 'chitchat', 'score': 0.24924083653953585}]"
+      ]
+     },
+     "execution_count": 31,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "out = dl._query(\"i love cats and dogs!\")\n",
+    "out"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'politics'"
+      ]
+     },
+     "execution_count": 32,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "decision, scores_by_category = dl.simple_classify(query_results=out, apply_tan=True)\n",
+    "decision"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 33,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'chitchat': 0.7785435459589187, 'politics': 1.1258003022715952}"
+      ]
+     },
+     "execution_count": 33,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "scores_by_category"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 34,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "array(['politics', 'politics', 'politics', 'politics', 'politics',\n",
+       "       'politics', 'chitchat', 'chitchat', 'chitchat', 'chitchat',\n",
+       "       'chitchat', 'chitchat'], dtype='<U8')"
+      ]
+     },
+     "execution_count": 34,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "dl.categories"
+   ]
+  }
  ],
@@ -178,7 +332,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.11.5"
+  "version": "3.11.4"
  },
  "orig_nbformat": 4
 },
--
GitLab
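
Note: for reference, a minimal standalone sketch of the scoring scheme that
simple_classify implements in this patch. It is illustrative only: the classify
helper name and the rounded sample scores are not part of the patch; only the
tan-based aggregation mirrors the patched method above.

import numpy as np

# Sketch of the tan-based scoring used by simple_classify. The helper name
# and the sample scores below are illustrative, not taken from the patch.
def classify(query_results: list, apply_tan: bool = True):
    scores_by_category = {}
    for result in query_results:
        # tan(x * pi/2) maps [0, 1) onto [0, inf), so similarities near 1
        # are boosted far more than similarities near 0.
        score = np.tan(result["score"] * (np.pi / 2)) if apply_tan else result["score"]
        scores_by_category[result["decision"]] = (
            scores_by_category.get(result["decision"], 0.0) + score
        )
    # rank categories by their summed scores, highest first
    sorted_categories = sorted(scores_by_category.items(), key=lambda x: x[1], reverse=True)
    return (sorted_categories[0][0] if sorted_categories else None), scores_by_category

# Scores rounded from the notebook's ambiguous "i love cats and dogs!" query:
results = [
    {"decision": "chitchat", "score": 0.223},
    {"decision": "politics", "score": 0.224},
    {"decision": "politics", "score": 0.227},
    {"decision": "politics", "score": 0.235},
    {"decision": "chitchat", "score": 0.249},
]
decision, scores = classify(results)
print(decision, scores)  # politics wins: three mid-range scores outweigh chitchat's two

The tan mapping is the design point: because tan(x * pi/2) diverges as x
approaches 1, one utterance that is very close to the query can outvote several
weak matches instead of being averaged away.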