diff --git a/poetry.lock b/poetry.lock index 8e64ef1d97c81c96f2e3d8bc7e9a70c242e50649..7cf51ec13aa93c6d88ee529e4a0b4906b4817620 100644 --- a/poetry.lock +++ b/poetry.lock @@ -449,13 +449,13 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} [[package]] name = "cohere" -version = "4.50" +version = "4.54" description = "Python SDK for the Cohere API" optional = false python-versions = ">=3.8,<4.0" files = [ - {file = "cohere-4.50-py3-none-any.whl", hash = "sha256:790744034c76cabd9c3d8e05c0b2e85733caee64e695e5a9904c4527202c913e"}, - {file = "cohere-4.50.tar.gz", hash = "sha256:64908677069fee23bb5dc968d576eab3ed644278e800bb7dcc3e5a336e5fc206"}, + {file = "cohere-4.54-py3-none-any.whl", hash = "sha256:6f83bd2530d461f91dfea828c1a7162b37cf03bdad4d801a218681064821f167"}, + {file = "cohere-4.54.tar.gz", hash = "sha256:c4bd84f0d766575430db91262e3ba0f4b655e40fec8a08fde98652e49a18608d"}, ] [package.dependencies] @@ -513,13 +513,13 @@ development = ["black", "flake8", "mypy", "pytest", "types-colorama"] [[package]] name = "comm" -version = "0.2.1" +version = "0.2.2" description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc." optional = false python-versions = ">=3.8" files = [ - {file = "comm-0.2.1-py3-none-any.whl", hash = "sha256:87928485c0dfc0e7976fd89fc1e187023cf587e7c353e4a9b417555b44adf021"}, - {file = "comm-0.2.1.tar.gz", hash = "sha256:0bc91edae1344d39d3661dcbc36937181fdaddb304790458f8b044dbc064b89a"}, + {file = "comm-0.2.2-py3-none-any.whl", hash = "sha256:e6fb86cb70ff661ee8c9c14e7d36d6de3b4066f1441be4063df9c5009f0a64d3"}, + {file = "comm-0.2.2.tar.gz", hash = "sha256:3fd7a84065306e07bea1773df6eb8282de51ba82f77c72f9c85716ab11fe980e"}, ] [package.dependencies] @@ -862,13 +862,13 @@ typing = ["typing-extensions (>=4.8)"] [[package]] name = "flatbuffers" -version = "23.5.26" +version = "24.3.7" description = "The FlatBuffers serialization format for Python" optional = true python-versions = "*" files = [ - {file = "flatbuffers-23.5.26-py2.py3-none-any.whl", hash = "sha256:c0ff356da363087b915fde4b8b45bdda73432fc17cddb3c8157472eab1422ad1"}, - {file = "flatbuffers-23.5.26.tar.gz", hash = "sha256:9ea1144cac05ce5d86e2859f431c6cd5e66cd9c78c558317c7955fb8d4c78d89"}, + {file = "flatbuffers-24.3.7-py2.py3-none-any.whl", hash = "sha256:80c4f5dcad0ee76b7e349671a0d657f2fbba927a0244f88dd3f5ed6a3694e1fc"}, + {file = "flatbuffers-24.3.7.tar.gz", hash = "sha256:0895c22b9a6019ff2f4de2e5e2f7cd15914043e6e7033a94c0c6369422690f22"}, ] [[package]] @@ -1192,13 +1192,13 @@ testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs [[package]] name = "importlib-resources" -version = "6.1.1" +version = "6.1.3" description = "Read resources from Python packages" optional = true python-versions = ">=3.8" files = [ - {file = "importlib_resources-6.1.1-py3-none-any.whl", hash = "sha256:e8bf90d8213b486f428c9c39714b920041cb02c184686a3dee24905aaa8105d6"}, - {file = "importlib_resources-6.1.1.tar.gz", hash = "sha256:3893a00122eafde6894c59914446a512f728a0c1a45f9bb9b63721b6bacf0b4a"}, + {file = "importlib_resources-6.1.3-py3-none-any.whl", hash = "sha256:4c0269e3580fe2634d364b39b38b961540a7738c02cb984e98add8b4221d793d"}, + {file = "importlib_resources-6.1.3.tar.gz", hash = "sha256:56fb4525197b78544a3354ea27793952ab93f935bb4bf746b846bb1015020f2b"}, ] [package.dependencies] @@ -1206,7 +1206,7 @@ zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-ruff", "zipp (>=3.17)"] +testing = ["jaraco.collections", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"] [[package]] name = "iniconfig" @@ -1221,13 +1221,13 @@ files = [ [[package]] name = "ipykernel" -version = "6.29.2" +version = "6.29.3" description = "IPython Kernel for Jupyter" optional = false python-versions = ">=3.8" files = [ - {file = "ipykernel-6.29.2-py3-none-any.whl", hash = "sha256:50384f5c577a260a1d53f1f59a828c7266d321c9b7d00d345693783f66616055"}, - {file = "ipykernel-6.29.2.tar.gz", hash = "sha256:3bade28004e3ff624ed57974948116670604ac5f676d12339693f3142176d3f0"}, + {file = "ipykernel-6.29.3-py3-none-any.whl", hash = "sha256:5aa086a4175b0229d4eca211e181fb473ea78ffd9869af36ba7694c947302a21"}, + {file = "ipykernel-6.29.3.tar.gz", hash = "sha256:e14c250d1f9ea3989490225cc1a542781b095a18a19447fcf2b5eaf7d0ac5bd2"}, ] [package.dependencies] @@ -1250,7 +1250,7 @@ cov = ["coverage[toml]", "curio", "matplotlib", "pytest-cov", "trio"] docs = ["myst-parser", "pydata-sphinx-theme", "sphinx", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "trio"] pyqt5 = ["pyqt5"] pyside6 = ["pyside6"] -test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio (==0.23.4)", "pytest-cov", "pytest-timeout"] +test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio (>=0.23.5)", "pytest-cov", "pytest-timeout"] [[package]] name = "ipython" @@ -1338,13 +1338,13 @@ files = [ [[package]] name = "jupyter-client" -version = "8.6.0" +version = "8.6.1" description = "Jupyter protocol implementation and client libraries" optional = false python-versions = ">=3.8" files = [ - {file = "jupyter_client-8.6.0-py3-none-any.whl", hash = "sha256:909c474dbe62582ae62b758bca86d6518c85234bdee2d908c778db6d72f39d99"}, - {file = "jupyter_client-8.6.0.tar.gz", hash = "sha256:0642244bb83b4764ae60d07e010e15f0e2d275ec4e918a8f7b80fbbef3ca60c7"}, + {file = "jupyter_client-8.6.1-py3-none-any.whl", hash = "sha256:3b7bd22f058434e3b9a7ea4b1500ed47de2713872288c0d511d19926f99b459f"}, + {file = "jupyter_client-8.6.1.tar.gz", hash = "sha256:e842515e2bab8e19186d89fdfea7abd15e39dd581f94e399f00e2af5a1652d3f"}, ] [package.dependencies] @@ -1361,13 +1361,13 @@ test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko", "pre-commit", "pyt [[package]] name = "jupyter-core" -version = "5.7.1" +version = "5.7.2" description = "Jupyter core package. A base package on which Jupyter projects rely." optional = false python-versions = ">=3.8" files = [ - {file = "jupyter_core-5.7.1-py3-none-any.whl", hash = "sha256:c65c82126453a723a2804aa52409930434598fd9d35091d63dfb919d2b765bb7"}, - {file = "jupyter_core-5.7.1.tar.gz", hash = "sha256:de61a9d7fc71240f688b2fb5ab659fbb56979458dc66a71decd098e03c79e218"}, + {file = "jupyter_core-5.7.2-py3-none-any.whl", hash = "sha256:4f7315d2f6b4bcf2e3e7cb6e46772eba760ae459cd1f59d29eb57b0a01bd7409"}, + {file = "jupyter_core-5.7.2.tar.gz", hash = "sha256:aa5f8d32bbf6b431ac830496da7392035d6f61b4f54872f15c4bd2a9c3f536d9"}, ] [package.dependencies] @@ -1377,7 +1377,7 @@ traitlets = ">=5.3" [package.extras] docs = ["myst-parser", "pydata-sphinx-theme", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "traitlets"] -test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] +test = ["ipykernel", "pre-commit", "pytest (<8)", "pytest-cov", "pytest-timeout"] [[package]] name = "kiwisolver" @@ -1494,12 +1494,12 @@ files = [ [[package]] name = "llama-cpp-python" -version = "0.2.50" +version = "0.2.56" description = "Python bindings for the llama.cpp library" optional = true python-versions = ">=3.8" files = [ - {file = "llama_cpp_python-0.2.50.tar.gz", hash = "sha256:28caf4e665dac62ad1d347061b7a96669af7fb9e7f1e4e8c17e736504e321a51"}, + {file = "llama_cpp_python-0.2.56.tar.gz", hash = "sha256:9c82db80e929ae93c2ab069a76a8a52aac82479cf9d0523c3550af48554cc785"}, ] [package.dependencies] @@ -1650,7 +1650,7 @@ traitlets = "*" name = "mistralai" version = "0.0.12" description = "" -optional = false +optional = true python-versions = ">=3.8,<4.0" files = [ {file = "mistralai-0.0.12-py3-none-any.whl", hash = "sha256:d489d1f0a31bf0edbe15c6d12f68b943148d2a725a088be0d8a5d4c888f8436c"}, @@ -1824,38 +1824,38 @@ files = [ [[package]] name = "mypy" -version = "1.8.0" +version = "1.9.0" description = "Optional static typing for Python" optional = false python-versions = ">=3.8" files = [ - {file = "mypy-1.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:485a8942f671120f76afffff70f259e1cd0f0cfe08f81c05d8816d958d4577d3"}, - {file = "mypy-1.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:df9824ac11deaf007443e7ed2a4a26bebff98d2bc43c6da21b2b64185da011c4"}, - {file = "mypy-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2afecd6354bbfb6e0160f4e4ad9ba6e4e003b767dd80d85516e71f2e955ab50d"}, - {file = "mypy-1.8.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8963b83d53ee733a6e4196954502b33567ad07dfd74851f32be18eb932fb1cb9"}, - {file = "mypy-1.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:e46f44b54ebddbeedbd3d5b289a893219065ef805d95094d16a0af6630f5d410"}, - {file = "mypy-1.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:855fe27b80375e5c5878492f0729540db47b186509c98dae341254c8f45f42ae"}, - {file = "mypy-1.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4c886c6cce2d070bd7df4ec4a05a13ee20c0aa60cb587e8d1265b6c03cf91da3"}, - {file = "mypy-1.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d19c413b3c07cbecf1f991e2221746b0d2a9410b59cb3f4fb9557f0365a1a817"}, - {file = "mypy-1.8.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9261ed810972061388918c83c3f5cd46079d875026ba97380f3e3978a72f503d"}, - {file = "mypy-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:51720c776d148bad2372ca21ca29256ed483aa9a4cdefefcef49006dff2a6835"}, - {file = "mypy-1.8.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:52825b01f5c4c1c4eb0db253ec09c7aa17e1a7304d247c48b6f3599ef40db8bd"}, - {file = "mypy-1.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f5ac9a4eeb1ec0f1ccdc6f326bcdb464de5f80eb07fb38b5ddd7b0de6bc61e55"}, - {file = "mypy-1.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afe3fe972c645b4632c563d3f3eff1cdca2fa058f730df2b93a35e3b0c538218"}, - {file = "mypy-1.8.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:42c6680d256ab35637ef88891c6bd02514ccb7e1122133ac96055ff458f93fc3"}, - {file = "mypy-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:720a5ca70e136b675af3af63db533c1c8c9181314d207568bbe79051f122669e"}, - {file = "mypy-1.8.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:028cf9f2cae89e202d7b6593cd98db6759379f17a319b5faf4f9978d7084cdc6"}, - {file = "mypy-1.8.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4e6d97288757e1ddba10dd9549ac27982e3e74a49d8d0179fc14d4365c7add66"}, - {file = "mypy-1.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f1478736fcebb90f97e40aff11a5f253af890c845ee0c850fe80aa060a267c6"}, - {file = "mypy-1.8.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:42419861b43e6962a649068a61f4a4839205a3ef525b858377a960b9e2de6e0d"}, - {file = "mypy-1.8.0-cp38-cp38-win_amd64.whl", hash = "sha256:2b5b6c721bd4aabaadead3a5e6fa85c11c6c795e0c81a7215776ef8afc66de02"}, - {file = "mypy-1.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5c1538c38584029352878a0466f03a8ee7547d7bd9f641f57a0f3017a7c905b8"}, - {file = "mypy-1.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4ef4be7baf08a203170f29e89d79064463b7fc7a0908b9d0d5114e8009c3a259"}, - {file = "mypy-1.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7178def594014aa6c35a8ff411cf37d682f428b3b5617ca79029d8ae72f5402b"}, - {file = "mypy-1.8.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ab3c84fa13c04aeeeabb2a7f67a25ef5d77ac9d6486ff33ded762ef353aa5592"}, - {file = "mypy-1.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:99b00bc72855812a60d253420d8a2eae839b0afa4938f09f4d2aa9bb4654263a"}, - {file = "mypy-1.8.0-py3-none-any.whl", hash = "sha256:538fd81bb5e430cc1381a443971c0475582ff9f434c16cd46d2c66763ce85d9d"}, - {file = "mypy-1.8.0.tar.gz", hash = "sha256:6ff8b244d7085a0b425b56d327b480c3b29cafbd2eff27316a004f9a7391ae07"}, + {file = "mypy-1.9.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f8a67616990062232ee4c3952f41c779afac41405806042a8126fe96e098419f"}, + {file = "mypy-1.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d357423fa57a489e8c47b7c85dfb96698caba13d66e086b412298a1a0ea3b0ed"}, + {file = "mypy-1.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49c87c15aed320de9b438ae7b00c1ac91cd393c1b854c2ce538e2a72d55df150"}, + {file = "mypy-1.9.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:48533cdd345c3c2e5ef48ba3b0d3880b257b423e7995dada04248725c6f77374"}, + {file = "mypy-1.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:4d3dbd346cfec7cb98e6cbb6e0f3c23618af826316188d587d1c1bc34f0ede03"}, + {file = "mypy-1.9.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:653265f9a2784db65bfca694d1edd23093ce49740b2244cde583aeb134c008f3"}, + {file = "mypy-1.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3a3c007ff3ee90f69cf0a15cbcdf0995749569b86b6d2f327af01fd1b8aee9dc"}, + {file = "mypy-1.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2418488264eb41f69cc64a69a745fad4a8f86649af4b1041a4c64ee61fc61129"}, + {file = "mypy-1.9.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:68edad3dc7d70f2f17ae4c6c1b9471a56138ca22722487eebacfd1eb5321d612"}, + {file = "mypy-1.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:85ca5fcc24f0b4aeedc1d02f93707bccc04733f21d41c88334c5482219b1ccb3"}, + {file = "mypy-1.9.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aceb1db093b04db5cd390821464504111b8ec3e351eb85afd1433490163d60cd"}, + {file = "mypy-1.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0235391f1c6f6ce487b23b9dbd1327b4ec33bb93934aa986efe8a9563d9349e6"}, + {file = "mypy-1.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4d5ddc13421ba3e2e082a6c2d74c2ddb3979c39b582dacd53dd5d9431237185"}, + {file = "mypy-1.9.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:190da1ee69b427d7efa8aa0d5e5ccd67a4fb04038c380237a0d96829cb157913"}, + {file = "mypy-1.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:fe28657de3bfec596bbeef01cb219833ad9d38dd5393fc649f4b366840baefe6"}, + {file = "mypy-1.9.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e54396d70be04b34f31d2edf3362c1edd023246c82f1730bbf8768c28db5361b"}, + {file = "mypy-1.9.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5e6061f44f2313b94f920e91b204ec600982961e07a17e0f6cd83371cb23f5c2"}, + {file = "mypy-1.9.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81a10926e5473c5fc3da8abb04119a1f5811a236dc3a38d92015cb1e6ba4cb9e"}, + {file = "mypy-1.9.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b685154e22e4e9199fc95f298661deea28aaede5ae16ccc8cbb1045e716b3e04"}, + {file = "mypy-1.9.0-cp38-cp38-win_amd64.whl", hash = "sha256:5d741d3fc7c4da608764073089e5f58ef6352bedc223ff58f2f038c2c4698a89"}, + {file = "mypy-1.9.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:587ce887f75dd9700252a3abbc9c97bbe165a4a630597845c61279cf32dfbf02"}, + {file = "mypy-1.9.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f88566144752999351725ac623471661c9d1cd8caa0134ff98cceeea181789f4"}, + {file = "mypy-1.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61758fabd58ce4b0720ae1e2fea5cfd4431591d6d590b197775329264f86311d"}, + {file = "mypy-1.9.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e49499be624dead83927e70c756970a0bc8240e9f769389cdf5714b0784ca6bf"}, + {file = "mypy-1.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:571741dc4194b4f82d344b15e8837e8c5fcc462d66d076748142327626a1b6e9"}, + {file = "mypy-1.9.0-py3-none-any.whl", hash = "sha256:a260627a570559181a9ea5de61ac6297aa5af202f06fd7ab093ce74e7181e43e"}, + {file = "mypy-1.9.0.tar.gz", hash = "sha256:3cc5da0127e6a478cddd906068496a97a7618a21ce9b54bde5bf7e539c7af974"}, ] [package.dependencies] @@ -2100,13 +2100,13 @@ files = [ [[package]] name = "nvidia-nvjitlink-cu12" -version = "12.3.101" +version = "12.4.99" description = "Nvidia JIT LTO Library" optional = true python-versions = ">=3" files = [ - {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"}, - {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-win_amd64.whl", hash = "sha256:1b2e317e437433753530792f13eece58f0aec21a2b05903be7bffe58a606cbd1"}, + {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_x86_64.whl", hash = "sha256:c6428836d20fe7e327191c175791d38570e10762edc588fb46749217cd444c74"}, + {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-win_amd64.whl", hash = "sha256:991905ffa2144cb603d8ca7962d75c35334ae82bf92820b6ba78157277da1ad2"}, ] [[package]] @@ -2205,13 +2205,13 @@ sympy = "*" [[package]] name = "openai" -version = "1.12.0" +version = "1.13.3" description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-1.12.0-py3-none-any.whl", hash = "sha256:a54002c814e05222e413664f651b5916714e4700d041d5cf5724d3ae1a3e3481"}, - {file = "openai-1.12.0.tar.gz", hash = "sha256:99c5d257d09ea6533d689d1cc77caa0ac679fa21efef8893d8b0832a86877f1b"}, + {file = "openai-1.13.3-py3-none-any.whl", hash = "sha256:5769b62abd02f350a8dd1a3a242d8972c947860654466171d60fb0972ae0a41c"}, + {file = "openai-1.13.3.tar.gz", hash = "sha256:ff6c6b3bc7327e715e4b3592a923a5a1c7519ff5dd764a83d69f633d49e77a7b"}, ] [package.dependencies] @@ -2230,7 +2230,7 @@ datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] name = "orjson" version = "3.9.15" description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "orjson-3.9.15-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:d61f7ce4727a9fa7680cd6f3986b0e2c732639f46a5e0156e550e35258aa313a"}, @@ -2287,13 +2287,13 @@ files = [ [[package]] name = "packaging" -version = "23.2" +version = "24.0" description = "Core utilities for Python packages" optional = false python-versions = ">=3.7" files = [ - {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"}, - {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, + {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"}, + {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"}, ] [[package]] @@ -2602,13 +2602,13 @@ files = [ [[package]] name = "pydantic" -version = "2.6.2" +version = "2.6.4" description = "Data validation using Python type hints" optional = false python-versions = ">=3.8" files = [ - {file = "pydantic-2.6.2-py3-none-any.whl", hash = "sha256:37a5432e54b12fecaa1049c5195f3d860a10e01bdfd24f1840ef14bd0d3aeab3"}, - {file = "pydantic-2.6.2.tar.gz", hash = "sha256:a09be1c3d28f3abe37f8a78af58284b236a92ce520105ddc91a6d29ea1176ba7"}, + {file = "pydantic-2.6.4-py3-none-any.whl", hash = "sha256:cc46fce86607580867bdc3361ad462bab9c222ef042d3da86f2fb333e1d916c5"}, + {file = "pydantic-2.6.4.tar.gz", hash = "sha256:b1704e0847db01817624a6b86766967f552dd9dbf3afba4004409f908dcc84e6"}, ] [package.dependencies] @@ -2727,13 +2727,13 @@ windows-terminal = ["colorama (>=0.4.6)"] [[package]] name = "pyparsing" -version = "3.1.1" +version = "3.1.2" description = "pyparsing module - Classes and methods to define and execute parsing grammars" optional = true python-versions = ">=3.6.8" files = [ - {file = "pyparsing-3.1.1-py3-none-any.whl", hash = "sha256:32c7c0b711493c72ff18a981d24f28aaf9c1fb7ed5e9667c9e84e3db623bdbfb"}, - {file = "pyparsing-3.1.1.tar.gz", hash = "sha256:ede28a1a32462f5a9705e07aea48001a08f7cf81a021585011deba701581a0db"}, + {file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"}, + {file = "pyparsing-3.1.2.tar.gz", hash = "sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad"}, ] [package.extras] @@ -2829,13 +2829,13 @@ testing = ["filelock"] [[package]] name = "python-dateutil" -version = "2.8.2" +version = "2.9.0.post0" description = "Extensions to the standard Python datetime module" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ - {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, - {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, + {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, + {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, ] [package.dependencies] @@ -3323,13 +3323,13 @@ files = [ [[package]] name = "sniffio" -version = "1.3.0" +version = "1.3.1" description = "Sniff out which async library your code is running under" optional = false python-versions = ">=3.7" files = [ - {file = "sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"}, - {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, + {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, + {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, ] [[package]] @@ -3697,28 +3697,28 @@ telegram = ["requests"] [[package]] name = "traitlets" -version = "5.14.1" +version = "5.14.2" description = "Traitlets Python configuration system" optional = false python-versions = ">=3.8" files = [ - {file = "traitlets-5.14.1-py3-none-any.whl", hash = "sha256:2e5a030e6eff91737c643231bfcf04a65b0132078dad75e4936700b213652e74"}, - {file = "traitlets-5.14.1.tar.gz", hash = "sha256:8585105b371a04b8316a43d5ce29c098575c2e477850b62b848b964f1444527e"}, + {file = "traitlets-5.14.2-py3-none-any.whl", hash = "sha256:fcdf85684a772ddeba87db2f398ce00b40ff550d1528c03c14dbf6a02003cd80"}, + {file = "traitlets-5.14.2.tar.gz", hash = "sha256:8cdd83c040dab7d1dee822678e5f5d100b514f7b72b01615b26fc5718916fdf9"}, ] [package.extras] docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] -test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<7.5)", "pytest-mock", "pytest-mypy-testing"] +test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<8.1)", "pytest-mock", "pytest-mypy-testing"] [[package]] name = "transformers" -version = "4.38.1" +version = "4.38.2" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = true python-versions = ">=3.8.0" files = [ - {file = "transformers-4.38.1-py3-none-any.whl", hash = "sha256:a7a9265fb060183e9d975cbbadc4d531b10281589c43f6d07563f86322728973"}, - {file = "transformers-4.38.1.tar.gz", hash = "sha256:86dc84ccbe36123647e84cbd50fc31618c109a41e6be92514b064ab55bf1304c"}, + {file = "transformers-4.38.2-py3-none-any.whl", hash = "sha256:c4029cb9f01b3dd335e52f364c52d2b37c65b4c78e02e6a08b1919c5c928573e"}, + {file = "transformers-4.38.2.tar.gz", hash = "sha256:c5fc7ad682b8a50a48b2a4c05d4ea2de5567adb1bdd00053619dbe5960857dd5"}, ] [package.dependencies] @@ -3803,24 +3803,24 @@ tutorials = ["matplotlib", "pandas", "tabulate", "torch"] [[package]] name = "types-pyyaml" -version = "6.0.12.12" +version = "6.0.12.20240311" description = "Typing stubs for PyYAML" optional = false -python-versions = "*" +python-versions = ">=3.8" files = [ - {file = "types-PyYAML-6.0.12.12.tar.gz", hash = "sha256:334373d392fde0fdf95af5c3f1661885fa10c52167b14593eb856289e1855062"}, - {file = "types_PyYAML-6.0.12.12-py3-none-any.whl", hash = "sha256:c05bc6c158facb0676674b7f11fe3960db4f389718e19e62bd2b84d6205cfd24"}, + {file = "types-PyYAML-6.0.12.20240311.tar.gz", hash = "sha256:a9e0f0f88dc835739b0c1ca51ee90d04ca2a897a71af79de9aec5f38cb0a5342"}, + {file = "types_PyYAML-6.0.12.20240311-py3-none-any.whl", hash = "sha256:b845b06a1c7e54b8e5b4c683043de0d9caf205e7434b3edc678ff2411979b8f6"}, ] [[package]] name = "types-requests" -version = "2.31.0.20240218" +version = "2.31.0.20240311" description = "Typing stubs for requests" optional = false python-versions = ">=3.8" files = [ - {file = "types-requests-2.31.0.20240218.tar.gz", hash = "sha256:f1721dba8385958f504a5386240b92de4734e047a08a40751c1654d1ac3349c5"}, - {file = "types_requests-2.31.0.20240218-py3-none-any.whl", hash = "sha256:a82807ec6ddce8f00fe0e949da6d6bc1fbf1715420218a9640d695f70a9e5a9b"}, + {file = "types-requests-2.31.0.20240311.tar.gz", hash = "sha256:b1c1b66abfb7fa79aae09097a811c4aa97130eb8831c60e47aee4ca344731ca5"}, + {file = "types_requests-2.31.0.20240311-py3-none-any.whl", hash = "sha256:47872893d65a38e282ee9f277a4ee50d1b28bd592040df7d1fdaffdf3779937d"}, ] [package.dependencies] @@ -3828,13 +3828,13 @@ urllib3 = ">=2" [[package]] name = "typing-extensions" -version = "4.9.0" +version = "4.10.0" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.9.0-py3-none-any.whl", hash = "sha256:af72aea155e91adfc61c3ae9e0e342dbc0cba726d6cba4b6c72c1f34e47291cd"}, - {file = "typing_extensions-4.9.0.tar.gz", hash = "sha256:23478f88c37f27d76ac8aee6c905017a143b0b1b886c3c9f66bc2fd94f9f5783"}, + {file = "typing_extensions-4.10.0-py3-none-any.whl", hash = "sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475"}, + {file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"}, ] [[package]] @@ -3997,6 +3997,7 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p fastembed = ["fastembed"] hybrid = ["pinecone-text"] local = ["llama-cpp-python", "torch", "transformers"] +mistralai = ["mistralai"] pinecone = ["pinecone-client"] processing = ["matplotlib"] vision = ["pillow", "torch", "torchvision", "transformers"] @@ -4004,4 +4005,4 @@ vision = ["pillow", "torch", "torchvision", "transformers"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "7e533decdcadb8bed91697d492129d5f6136d10624b9c1f5322b6b7c26c47bc7" +content-hash = "90df5d7e6e850b7175bbbfddbaa11d2f45666501caac2559798dfb00daa363d8" diff --git a/pyproject.toml b/pyproject.toml index 024b9e3de66bd003879e1708cd3ec80d6405efec..624771fa8e3736008ac13273db25e6c44921fb5b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ python = ">=3.9,<3.13" pydantic = "^2.5.3" openai = "^1.10.0" cohere = "^4.32" -mistralai= "^0.0.12" +mistralai= {version = "^0.0.12", optional = true} numpy = "^1.25.2" colorlog = "^6.8.0" pyyaml = "^6.0.1" @@ -44,6 +44,7 @@ local = ["torch", "transformers", "llama-cpp-python"] pinecone = ["pinecone-client"] vision = ["torch", "torchvision", "transformers", "pillow"] processing = ["matplotlib"] +mistralai = ["mistralai"] [tool.poetry.group.dev.dependencies] ipykernel = "^6.25.0" @@ -67,4 +68,4 @@ build-backend = "poetry.core.masonry.api" line-length = 88 [tool.mypy] -ignore_missing_imports = true +ignore_missing_imports = true \ No newline at end of file diff --git a/semantic_router/encoders/mistral.py b/semantic_router/encoders/mistral.py index cf1e290a0f88b90a75c9604e63a71e7475fd5f69..544c629f81dcd2199b0b5b10e704c177c7c144d3 100644 --- a/semantic_router/encoders/mistral.py +++ b/semantic_router/encoders/mistral.py @@ -2,20 +2,19 @@ import os from time import sleep -from typing import List, Optional +from typing import List, Optional, Any -from mistralai.client import MistralClient -from mistralai.exceptions import MistralException -from mistralai.models.embeddings import EmbeddingResponse from semantic_router.encoders import BaseEncoder from semantic_router.utils.defaults import EncoderDefault +from pydantic.v1 import PrivateAttr class MistralEncoder(BaseEncoder): """Class to encode text using MistralAI""" - client: Optional[MistralClient] + _client: Any = PrivateAttr() + _mistralai: Any = PrivateAttr() type: str = "mistral" def __init__( @@ -27,16 +26,30 @@ class MistralEncoder(BaseEncoder): if name is None: name = EncoderDefault.MISTRAL.value["embedding_model"] super().__init__(name=name, score_threshold=score_threshold) - api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY") + self._client, self._mistralai = self._initialize_client(mistralai_api_key) + + def _initialize_client(self, api_key): + try: + import mistralai + from mistralai.client import MistralClient + except ImportError: + raise ImportError( + "Please install MistralAI to use MistralEncoder. " + "You can install it with: " + "`pip install 'semantic-router[mistralai]'`" + ) + + api_key = api_key or os.getenv("MISTRALAI_API_KEY") if api_key is None: raise ValueError("Mistral API key not provided") try: - self.client = MistralClient(api_key=api_key) + client = MistralClient(api_key=api_key) except Exception as e: raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e + return client, mistralai def __call__(self, docs: List[str]) -> List[List[float]]: - if self.client is None: + if self._client is None: raise ValueError("Mistral client not initialized") embeds = None error_message = "" @@ -44,16 +57,22 @@ class MistralEncoder(BaseEncoder): # Exponential backoff for _ in range(3): try: - embeds = self.client.embeddings(model=self.name, input=docs) + embeds = self._client.embeddings(model=self.name, input=docs) if embeds.data: break - except MistralException as e: + except self._mistralai.exceptions.MistralException as e: sleep(2**_) error_message = str(e) except Exception as e: raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e - if not embeds or not isinstance(embeds, EmbeddingResponse) or not embeds.data: + if ( + not embeds + or not isinstance( + embeds, self._mistralai.models.embeddings.EmbeddingResponse + ) + or not embeds.data + ): raise ValueError(f"No embeddings returned from MistralAI: {error_message}") embeddings = [embeds_obj.embedding for embeds_obj in embeds.data] return embeddings diff --git a/semantic_router/llms/__init__.py b/semantic_router/llms/__init__.py index 4e2eef16f35d5ef726ed121c6159ab27446ada6a..36f13c8dcb003372cd1bf87112ea5838388ff911 100644 --- a/semantic_router/llms/__init__.py +++ b/semantic_router/llms/__init__.py @@ -1,5 +1,6 @@ from semantic_router.llms.base import BaseLLM from semantic_router.llms.cohere import CohereLLM +from semantic_router.llms.llamacpp import LlamaCppLLM from semantic_router.llms.mistral import MistralAILLM from semantic_router.llms.openai import OpenAILLM from semantic_router.llms.openrouter import OpenRouterLLM @@ -8,6 +9,7 @@ from semantic_router.llms.zure import AzureOpenAILLM __all__ = [ "BaseLLM", "OpenAILLM", + "LlamaCppLLM", "OpenRouterLLM", "CohereLLM", "AzureOpenAILLM", diff --git a/semantic_router/llms/llamacpp.py b/semantic_router/llms/llamacpp.py index 2586d2e4253e485445c9c5e5bc1b3b81061c8279..9a66b732c16b1c0779d60a45c8551f4ebc17e300 100644 --- a/semantic_router/llms/llamacpp.py +++ b/semantic_router/llms/llamacpp.py @@ -2,26 +2,27 @@ from contextlib import contextmanager from pathlib import Path from typing import Any, Optional -from llama_cpp import Llama, LlamaGrammar - from semantic_router.llms.base import BaseLLM from semantic_router.schema import Message from semantic_router.utils.logger import logger +from pydantic.v1 import PrivateAttr + class LlamaCppLLM(BaseLLM): - llm: Llama + llm: Any temperature: float max_tokens: Optional[int] = 200 - grammar: Optional[LlamaGrammar] = None + grammar: Optional[Any] = None + _llama_cpp: Any = PrivateAttr() def __init__( self, - llm: Llama, + llm: Any, name: str = "llama.cpp", temperature: float = 0.2, max_tokens: Optional[int] = 200, - grammar: Optional[LlamaGrammar] = None, + grammar: Optional[Any] = None, ): super().__init__( name=name, @@ -30,6 +31,18 @@ class LlamaCppLLM(BaseLLM): max_tokens=max_tokens, grammar=grammar, ) + + try: + import llama_cpp + except ImportError: + raise ImportError( + "Please install LlamaCPP to use Llama CPP llm. " + "You can install it with: " + "`pip install 'semantic-router[local]'`" + ) + self._llama_cpp = llama_cpp + llm = self._llama_cpp.Llama + grammar = self._llama_cpp.LlamaGrammar self.llm = llm self.temperature = temperature self.max_tokens = max_tokens @@ -62,7 +75,7 @@ class LlamaCppLLM(BaseLLM): grammar_path = Path(__file__).parent.joinpath("grammars", "json.gbnf") assert grammar_path.exists(), f"{grammar_path}\ndoes not exist" try: - self.grammar = LlamaGrammar.from_file(grammar_path) + self.grammar = self._llama_cpp.LlamaGrammar.from_file(grammar_path) yield finally: self.grammar = None diff --git a/semantic_router/llms/mistral.py b/semantic_router/llms/mistral.py index e17ba8bab050d6dbdccf3f22376729fd921c692e..647d4073e5c7b50591dd7cd686536c0bebba0714 100644 --- a/semantic_router/llms/mistral.py +++ b/semantic_router/llms/mistral.py @@ -1,18 +1,20 @@ import os -from typing import List, Optional +from typing import List, Optional, Any -from mistralai.client import MistralClient from semantic_router.llms import BaseLLM from semantic_router.schema import Message from semantic_router.utils.defaults import EncoderDefault from semantic_router.utils.logger import logger +from pydantic.v1 import PrivateAttr + class MistralAILLM(BaseLLM): - client: Optional[MistralClient] + _client: Any = PrivateAttr() temperature: Optional[float] max_tokens: Optional[int] + _mistralai: Any = PrivateAttr() def __init__( self, @@ -24,25 +26,45 @@ class MistralAILLM(BaseLLM): if name is None: name = EncoderDefault.MISTRAL.value["language_model"] super().__init__(name=name) - api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY") + self._client, self._mistralai = self._initialize_client(mistralai_api_key) + self.temperature = temperature + self.max_tokens = max_tokens + + def _initialize_client(self, api_key): + try: + import mistralai + from mistralai.client import MistralClient + except ImportError: + raise ImportError( + "Please install MistralAI to use MistralAI LLM. " + "You can install it with: " + "`pip install 'semantic-router[mistralai]'`" + ) + api_key = api_key or os.getenv("MISTRALAI_API_KEY") if api_key is None: raise ValueError("MistralAI API key cannot be 'None'.") try: - self.client = MistralClient(api_key=api_key) + client = MistralClient(api_key=api_key) except Exception as e: raise ValueError( f"MistralAI API client failed to initialize. Error: {e}" ) from e - self.temperature = temperature - self.max_tokens = max_tokens + return client, mistralai def __call__(self, messages: List[Message]) -> str: - if self.client is None: + if self._client is None: raise ValueError("MistralAI client is not initialized.") + + chat_messages = [ + self._mistralai.models.chat_completion.ChatMessage( + role=m.role, content=m.content + ) + for m in messages + ] try: - completion = self.client.chat( + completion = self._client.chat( model=self.name, - messages=[m.to_mistral() for m in messages], + messages=chat_messages, temperature=self.temperature, max_tokens=self.max_tokens, ) diff --git a/tests/unit/encoders/test_mistral.py b/tests/unit/encoders/test_mistral.py index c2a91c128d22966da177dfc812e37d3bda215b09..f36f5037abaab256b03ab912714831b33a23fe2a 100644 --- a/tests/unit/encoders/test_mistral.py +++ b/tests/unit/encoders/test_mistral.py @@ -4,6 +4,8 @@ from mistralai.models.embeddings import EmbeddingObject, EmbeddingResponse, Usag from semantic_router.encoders import MistralEncoder +from unittest.mock import patch + @pytest.fixture def mistralai_encoder(mocker): @@ -12,9 +14,21 @@ def mistralai_encoder(mocker): class TestMistralEncoder: + def test_mistral_encoder_import_errors(self): + with patch.dict("sys.modules", {"mistralai": None}): + with pytest.raises(ImportError) as error: + MistralEncoder() + + assert ( + "Please install MistralAI to use MistralEncoder. " + "You can install it with: " + "`pip install 'semantic-router[mistralai]'`" in str(error.value) + ) + def test_mistralai_encoder_init_success(self, mocker): encoder = MistralEncoder(mistralai_api_key="test_api_key") - assert encoder.client is not None + assert encoder._client is not None + assert encoder._mistralai is not None def test_mistralai_encoder_init_no_api_key(self, mocker): mocker.patch("os.getenv", return_value=None) @@ -23,7 +37,7 @@ class TestMistralEncoder: def test_mistralai_encoder_call_uninitialized_client(self, mistralai_encoder): # Set the client to None to simulate an uninitialized client - mistralai_encoder.client = None + mistralai_encoder._client = None with pytest.raises(ValueError) as e: mistralai_encoder(["test document"]) assert "Mistral client not initialized" in str(e.value) @@ -60,7 +74,7 @@ class TestMistralEncoder: responses = [MistralException("mistralai error"), mock_response] mocker.patch.object( - mistralai_encoder.client, "embeddings", side_effect=responses + mistralai_encoder._client, "embeddings", side_effect=responses ) embeddings = mistralai_encoder(["test document"]) assert embeddings == [[0.1, 0.2]] @@ -69,7 +83,7 @@ class TestMistralEncoder: mocker.patch("os.getenv", return_value="fake-api-key") mocker.patch("time.sleep", return_value=None) # To speed up the test mocker.patch.object( - mistralai_encoder.client, + mistralai_encoder._client, "embeddings", side_effect=MistralException("Test error"), ) @@ -83,7 +97,7 @@ class TestMistralEncoder: mocker.patch("os.getenv", return_value="fake-api-key") mocker.patch("time.sleep", return_value=None) # To speed up the test mocker.patch.object( - mistralai_encoder.client, + mistralai_encoder._client, "embeddings", side_effect=Exception("Non-MistralException"), ) @@ -118,7 +132,7 @@ class TestMistralEncoder: responses = [MistralException("mistralai error"), mock_response] mocker.patch.object( - mistralai_encoder.client, "embeddings", side_effect=responses + mistralai_encoder._client, "embeddings", side_effect=responses ) embeddings = mistralai_encoder(["test document"]) assert embeddings == [[0.1, 0.2]] diff --git a/tests/unit/llms/test_llm_llamacpp.py b/tests/unit/llms/test_llm_llamacpp.py index f0a5253f909ecce92769b50ccf7b6578720c3f63..63d92ee8dd0ff4dad8b0c60846433e7197ee1699 100644 --- a/tests/unit/llms/test_llm_llamacpp.py +++ b/tests/unit/llms/test_llm_llamacpp.py @@ -4,6 +4,8 @@ from llama_cpp import Llama from semantic_router.llms.llamacpp import LlamaCppLLM from semantic_router.schema import Message +from unittest.mock import patch + @pytest.fixture def llamacpp_llm(mocker): @@ -13,6 +15,17 @@ def llamacpp_llm(mocker): class TestLlamaCppLLM: + def test_llama_cpp_import_errors(self, llamacpp_llm): + with patch.dict("sys.modules", {"llama_cpp": None}): + with pytest.raises(ImportError) as error: + LlamaCppLLM(llamacpp_llm.llm) + + assert ( + "Please install LlamaCPP to use Llama CPP llm. " + "You can install it with: " + "`pip install 'semantic-router[local]'`" in str(error.value) + ) + def test_llamacpp_llm_init_success(self, llamacpp_llm): assert llamacpp_llm.name == "llama.cpp" assert llamacpp_llm.temperature == 0.2 diff --git a/tests/unit/llms/test_llm_mistral.py b/tests/unit/llms/test_llm_mistral.py index 4e01ca5af528ea3d92b316c04b2fcf1f19f068c6..d73406e726c721242743f36e9de062b7460daaf0 100644 --- a/tests/unit/llms/test_llm_mistral.py +++ b/tests/unit/llms/test_llm_mistral.py @@ -3,6 +3,8 @@ import pytest from semantic_router.llms import MistralAILLM from semantic_router.schema import Message +from unittest.mock import patch + @pytest.fixture def mistralai_llm(mocker): @@ -11,14 +13,25 @@ def mistralai_llm(mocker): class TestMistralAILLM: + def test_mistral_llm_import_errors(self): + with patch.dict("sys.modules", {"mistralai": None}): + with pytest.raises(ImportError) as error: + MistralAILLM() + + assert ( + "Please install MistralAI to use MistralAI LLM. " + "You can install it with: " + "`pip install 'semantic-router[mistralai]'`" in str(error.value) + ) + def test_mistralai_llm_init_with_api_key(self, mistralai_llm): - assert mistralai_llm.client is not None, "Client should be initialized" + assert mistralai_llm._client is not None, "Client should be initialized" assert mistralai_llm.name == "mistral-tiny", "Default name not set correctly" def test_mistralai_llm_init_success(self, mocker): mocker.patch("os.getenv", return_value="fake-api-key") llm = MistralAILLM() - assert llm.client is not None + assert llm._client is not None def test_mistralai_llm_init_without_api_key(self, mocker): mocker.patch("os.getenv", return_value=None) @@ -27,7 +40,7 @@ class TestMistralAILLM: def test_mistralai_llm_call_uninitialized_client(self, mistralai_llm): # Set the client to None to simulate an uninitialized client - mistralai_llm.client = None + mistralai_llm._client = None with pytest.raises(ValueError) as e: llm_input = [Message(role="user", content="test")] mistralai_llm(llm_input) @@ -48,7 +61,7 @@ class TestMistralAILLM: mocker.patch("os.getenv", return_value="fake-api-key") mocker.patch.object( - mistralai_llm.client, + mistralai_llm._client, "chat", return_value=mock_completion, )