From f8511acbe5f302e990a83b908e55ebc92c62fde2 Mon Sep 17 00:00:00 2001
From: Ming <tslmy@users.noreply.github.com>
Date: Mon, 1 Jan 2024 19:01:21 -0800
Subject: [PATCH] chore: + the linter `mypy` as a pre-commit hook (#9791)

---
 .pre-commit-config.yaml                   | 14 ++++++++++++++
 llama_index/embeddings/pooling.py         |  4 +++-
 tests/vector_stores/test_elasticsearch.py |  1 +
 3 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2c6e71b03f..6f9329703c 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -34,6 +34,20 @@ repos:
         name: black-src
         alias: black
         exclude: ^(docs/|llama_index/_static)
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.0.1
+    hooks:
+      - id: mypy
+        additional_dependencies:
+          [
+            "types-requests",
+            "types-Deprecated",
+            "types-redis",
+            "types-setuptools",
+            "types-PyYAML",
+            "types-protobuf",
+          ]
+        exclude: ^(docs/|llama_index/_static)
   - repo: https://github.com/psf/black-pre-commit-mirror
     rev: 23.10.1
     hooks:
diff --git a/llama_index/embeddings/pooling.py b/llama_index/embeddings/pooling.py
index 1f046d1691..ec591af109 100644
--- a/llama_index/embeddings/pooling.py
+++ b/llama_index/embeddings/pooling.py
@@ -25,7 +25,9 @@ class Pooling(str, Enum):
 
     @classmethod
     @overload
-    def cls_pooling(cls, array: "torch.Tensor") -> "torch.Tensor":
+    # TODO: Remove this `type: ignore` after the false positive problem
+    #  is addressed in mypy: https://github.com/python/mypy/issues/15683 .
+    def cls_pooling(cls, array: "torch.Tensor") -> "torch.Tensor":  # type: ignore
         ...
 
     @classmethod
diff --git a/tests/vector_stores/test_elasticsearch.py b/tests/vector_stores/test_elasticsearch.py
index d4b7c681b8..18bb7f07f4 100644
--- a/tests/vector_stores/test_elasticsearch.py
+++ b/tests/vector_stores/test_elasticsearch.py
@@ -87,6 +87,7 @@ def elasticsearch_connection() -> Union[dict, Generator[dict, None, None]]:
         if index_name.startswith("test_"):
             es.indices.delete(index=index_name)
     es.indices.refresh(index="_all")
+    return {}
 
 
 @pytest.fixture(scope="session")
-- 
GitLab