From 348cad7430a1e411f33b5880d9aa1df40765961e Mon Sep 17 00:00:00 2001
From: Grigoriy <skvrd@users.noreply.github.com>
Date: Thu, 29 Feb 2024 08:16:19 -0800
Subject: [PATCH] fix name typo, add unit test to ollama (#11493)

---
 .../llama_index/llms/ollama/base.py              | 16 +++++++++-------
 .../llama-index-llms-ollama/tests/test_utils.py  | 12 ++++++++++++
 2 files changed, 21 insertions(+), 7 deletions(-)
 create mode 100644 llama-index-integrations/llms/llama-index-llms-ollama/tests/test_utils.py

diff --git a/llama-index-integrations/llms/llama-index-llms-ollama/llama_index/llms/ollama/base.py b/llama-index-integrations/llms/llama-index-llms-ollama/llama_index/llms/ollama/base.py
index 562b0c7342..78281eb591 100644
--- a/llama-index-integrations/llms/llama-index-llms-ollama/llama_index/llms/ollama/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-ollama/llama_index/llms/ollama/base.py
@@ -20,7 +20,7 @@ from llama_index.core.llms.custom import CustomLLM
 DEFAULT_REQUEST_TIMEOUT = 30.0
 
 
-def get_addtional_kwargs(
+def get_additional_kwargs(
     response: Dict[str, Any], exclude: Tuple[str, ...]
 ) -> Dict[str, Any]:
     return {k: v for k, v in response.items() if k not in exclude}
@@ -109,12 +109,12 @@ class Ollama(CustomLLM):
                 message=ChatMessage(
                     content=message.get("content"),
                     role=MessageRole(message.get("role")),
-                    additional_kwargs=get_addtional_kwargs(
+                    additional_kwargs=get_additional_kwargs(
                         message, ("content", "role")
                     ),
                 ),
                 raw=raw,
-                additional_kwargs=get_addtional_kwargs(raw, ("message",)),
+                additional_kwargs=get_additional_kwargs(raw, ("message",)),
             )
 
     @llm_chat_callback()
@@ -156,13 +156,15 @@ class Ollama(CustomLLM):
                             message=ChatMessage(
                                 content=text,
                                 role=MessageRole(message.get("role")),
-                                additional_kwargs=get_addtional_kwargs(
+                                additional_kwargs=get_additional_kwargs(
                                     message, ("content", "role")
                                 ),
                             ),
                             delta=delta,
                             raw=chunk,
-                            additional_kwargs=get_addtional_kwargs(chunk, ("message",)),
+                            additional_kwargs=get_additional_kwargs(
+                                chunk, ("message",)
+                            ),
                         )
 
     @llm_completion_callback()
@@ -188,7 +190,7 @@ class Ollama(CustomLLM):
             return CompletionResponse(
                 text=text,
                 raw=raw,
-                additional_kwargs=get_addtional_kwargs(raw, ("response",)),
+                additional_kwargs=get_additional_kwargs(raw, ("response",)),
             )
 
     @llm_completion_callback()
@@ -220,7 +222,7 @@ class Ollama(CustomLLM):
                             delta=delta,
                             text=text,
                             raw=chunk,
-                            additional_kwargs=get_addtional_kwargs(
+                            additional_kwargs=get_additional_kwargs(
                                 chunk, ("response",)
                             ),
                         )
diff --git a/llama-index-integrations/llms/llama-index-llms-ollama/tests/test_utils.py b/llama-index-integrations/llms/llama-index-llms-ollama/tests/test_utils.py
new file mode 100644
index 0000000000..0d300048cc
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-ollama/tests/test_utils.py
@@ -0,0 +1,12 @@
+from llama_index.llms.ollama.base import get_additional_kwargs
+
+
+def test_get_additional_kwargs():
+    response = {"key1": "value1", "key2": "value2", "exclude_me": "value3"}
+    exclude = ("exclude_me", "exclude_me_too")
+
+    expected = {"key1": "value1", "key2": "value2"}
+
+    actual = get_additional_kwargs(response, exclude)
+
+    assert actual == expected
-- 
GitLab