From b67df84ca9412aeb1101925b7d8efe15a15a4673 Mon Sep 17 00:00:00 2001 From: Lawrence Tsang <pikalaw@gmail.com> Date: Wed, 13 Dec 2023 09:32:13 -0500 Subject: [PATCH] Introduce Google Generative Language Semantic Retriever (#9440) --- .../community/integrations/managed_indices.md | 38 +- docs/examples/managed/GoogleDemo.ipynb | 599 ++++++++++++++++++ .../querying/retriever/retrievers.md | 1 + .../managed/google/generativeai/__init__.py | 5 + .../managed/google/generativeai/base.py | 240 +++++++ .../google/generativeai/__init__.py | 6 + .../google/generativeai/base.py | 255 ++++++++ .../google/generativeai/__init__.py | 6 + .../vector_stores/google/generativeai/base.py | 399 ++++++++++++ .../google/generativeai/genai_extension.py | 589 +++++++++++++++++ poetry.lock | 293 ++++----- pyproject.toml | 3 +- tests/indices/managed/__init__.py | 0 tests/indices/managed/test_google.py | 206 ++++++ tests/response_synthesizers/test_google.py | 176 +++++ tests/vector_stores/test_google.py | 306 +++++++++ 16 files changed, 2963 insertions(+), 159 deletions(-) create mode 100644 docs/examples/managed/GoogleDemo.ipynb create mode 100644 llama_index/indices/managed/google/generativeai/__init__.py create mode 100644 llama_index/indices/managed/google/generativeai/base.py create mode 100644 llama_index/response_synthesizers/google/generativeai/__init__.py create mode 100644 llama_index/response_synthesizers/google/generativeai/base.py create mode 100644 llama_index/vector_stores/google/generativeai/__init__.py create mode 100644 llama_index/vector_stores/google/generativeai/base.py create mode 100644 llama_index/vector_stores/google/generativeai/genai_extension.py create mode 100644 tests/indices/managed/__init__.py create mode 100644 tests/indices/managed/test_google.py create mode 100644 tests/response_synthesizers/test_google.py create mode 100644 tests/vector_stores/test_google.py diff --git a/docs/community/integrations/managed_indices.md b/docs/community/integrations/managed_indices.md index 5f81ea540b..2680fb3ed1 100644 --- a/docs/community/integrations/managed_indices.md +++ b/docs/community/integrations/managed_indices.md @@ -9,10 +9,42 @@ of documents. Once constructed, the index can be used for querying. If the Index has been previously populated with documents - it can also be used directly for querying. -`VectaraIndex` is currently the only supported managed index, although we expect more to be available soon. -Below we show how to use it. +## Google Generative Language Semantic Retriever. -**Vectara Index Construction/Querying** +Google's Semantic Retrieve provides both querying and retrieval capabilities. Create a managed index, insert documents, and use a query engine or retriever anywhere in LlamaIndex! + +```python +from llama_index import SimpleDirectoryReader +from llama_index.indices.managed.google.generativeai import GoogleIndex + +# Create a corpus +index = GoogleIndex.create_corpus(display_name="My first corpus!") +print(f"Newly created corpus ID is {index.corpus_id}.") + +# Ingestion +documents = SimpleDirectoryReader("data").load_data() +index.insert_documents(documents) + +# Querying +query_engine = index.as_query_engine() +response = query_engine.query("What did the author do growing up?") + +# Retrieving +retriever = index.as_retriever() +source_nodes = retriever.retrieve("What did the author do growing up?") +``` + +See the notebook guide for full details. + +```{toctree} +--- +caption: Examples +maxdepth: 1 +--- +/examples/managed/GoogleDemo.ipynb +``` + +## Vectara First, [sign up](https://vectara.com/integrations/llama_index) and use the Vectara Console to create a corpus (aka Index), and add an API key for access. Then put the customer id, corpus id, and API key in your environment. diff --git a/docs/examples/managed/GoogleDemo.ipynb b/docs/examples/managed/GoogleDemo.ipynb new file mode 100644 index 0000000000..f76afd99fe --- /dev/null +++ b/docs/examples/managed/GoogleDemo.ipynb @@ -0,0 +1,599 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "<a href=\"https://colab.research.google.com/github/run-llama/llama_index/blob/main/docs/examples/managed/GoogleDemo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Google Generative Language Semantic Retriever\n", + "\n", + "In this notebook, we will show you how to get started quickly with using Google's Generative Language Semantic Retriever, which offers specialized embedding models for high quality retrieval and a tuned model for producing grounded output with customizable safety settings. We will also show you some advanced examples on how to combine the power of LlamaIndex and this unique offering from Google." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Installation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install llama-index\n", + "%pip install \"google-ai-generativelanguage>=0.4,<=1.0\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Authentication: Get OAuth Credentials\n", + "\n", + "The Google Semantic Retriever API lets you perform semantic search on your own data. Since it's **your data**, this needs stricter access controls than API Keys. Please follow [OAuth Quickstart](https://developers.generativeai.google/tutorials/oauth_quickstart) to setup OAuth authentication. Below are overview of steps from the documentation that are required.\n", + "\n", + "1. Enable the `Generative Language API`: [Documentation](https://developers.generativeai.google/tutorials/oauth_quickstart#1_enable_the_api)\n", + "\n", + "1. Configure the OAuth consent screen: [Documentation](https://developers.generativeai.google/tutorials/oauth_quickstart#2_configure_the_oauth_consent_screen)\n", + "\n", + "1. Authorize credentials for a desktop application: [Documentation](https://developers.generativeai.google/tutorials/oauth_quickstart#3_authorize_credentials_for_a_desktop_application)\n", + " * If you want to run this notebook in Colab start by uploading your\n", + "`client_secret*.json` file using the \"File > Upload\" option.\n", + "\n", + " * Rename the uploaded file to `client_secret.json` or change the variable `client_file_name` in the code below.\n", + "\n", + "<img width=400 src=\"https://developers.generativeai.google/tutorials/images/colab_upload.png\">\n", + "\n", + "\n", + "**Note**: At this time, the Google Generative AI Semantic Retriever API is [only available in certain regions](https://developers.generativeai.google/available_regions)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Replace TODO-your-project-name with the project used in the OAuth Quickstart\n", + "project_name = \"TODO-your-project-name\" # @param {type:\"string\"}\n", + "# Replace TODO-your-email@gmail.com with the email added as a test user in the OAuth Quickstart\n", + "email = \"TODO-your-email@gmail.com\" # @param {type:\"string\"}\n", + "# Replace client_secret.json with the client_secret_* file name you uploaded.\n", + "client_file_name = \"client_secret.json\"\n", + "\n", + "# IMPORTANT: Follow the instructions from the output - you must copy the command\n", + "# to your terminal and copy the output after authentication back here.\n", + "!gcloud config set project $project_name\n", + "!gcloud config set account $email\n", + "\n", + "# NOTE: The simplified project setup in this tutorial triggers a \"Google hasn't verified this app.\" dialog.\n", + "# This is normal, click \"Advanced\" -> \"Go to [app name] (unsafe)\"\n", + "!gcloud auth application-default login --no-browser --client-id-file=$client_file_name --scopes=\"https://www.googleapis.com/auth/generative-language.retriever,https://www.googleapis.com/auth/cloud-platform\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will provide you with a URL, which you should enter into your local browser.\n", + "Follow the instruction to complete the authentication and authorization." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Download Data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!mkdir -p 'data/'\n", + "!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/paul_graham/paul_graham_essay.txt' -O 'data/paul_graham/paul_graham_essay.txt'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Basic Usage\n", + "\n", + "A `corpus` is a collection of `document`s. A `document` is a body of text that is broken into `chunk`s." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index import SimpleDirectoryReader\n", + "from llama_index.indices.managed.google.generativeai import GoogleIndex\n", + "\n", + "# Create a corpus.\n", + "index = GoogleIndex.create_corpus(display_name='My first corpus!')\n", + "print(f\"Newly created corpus ID is {index.corpus_id}.\")\n", + "\n", + "# Ingestion.\n", + "documents = SimpleDirectoryReader(\"data\").load_data()\n", + "index.insert_documents(documents)\n", + "\n", + "# Querying.\n", + "query_engine = index.as_query_engine()\n", + "response = query_engine.query(\"What did the author do growing up?\")\n", + "\n", + "# Show response.\n", + "print(f\"Response is {response.response}\")\n", + "\n", + "# Show cited passages that were used to construct the response.\n", + "for cited_text in [node.text for node in response.source_nodes]:\n", + " print(f\"Cited text: {cited_text}\")\n", + "\n", + "# Show answerability. 0 means not answerable from the passages.\n", + "# 1 means the model is certain the answer can be provided from the passages.\n", + "print(f\"Answerability: {response.metadata.get(\"answerable_probability\", 0)}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating a Corpus\n", + "\n", + "There are various ways to create a corpus." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The Google server will provide a corpus ID for you.\n", + "index = GoogleIndex.create_corpus(display_name=\"My first corpus!\")\n", + "print(index.corpus_id)\n", + "\n", + "# You can also provide your own corpus ID. However, this ID needs to be globally\n", + "# unique. You will get an exception if someone else has this ID already.\n", + "index = GoogleIndex.create_corpus(\n", + " corpus_id=\"my-first-corpus\", display_name=\"My first corpus!\"\n", + ")\n", + "\n", + "# If you do not provide any parameter, Google will provide ID and a default\n", + "# display name for you.\n", + "index = GoogleIndex.create_corpus()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Reusing a Corpus\n", + "\n", + "Corpora you created persists on the Google servers under your account.\n", + "You can use its ID to get a handle back.\n", + "Then, you can query it, add more document to it, etc." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Use a previously created corpus.\n", + "index = GoogleIndex.from_corpus(corpus_id=\"abc-123\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Listing and Deleting Corpora\n", + "\n", + "See the Python library [google-generativeai](https://github.com/google/generative-ai-python) for further documentation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loading Documents\n", + "\n", + "Many node parsers and text splitters in LlamaIndex automatically add to each node a *source_node* to associate it to a file, e.g.\n", + "\n", + "```python\n", + " relationships={\n", + " NodeRelationship.SOURCE: RelatedNodeInfo(\n", + " node_id=\"abc-123\",\n", + " metadata={\"file_name\": \"Title for the document\"},\n", + " )\n", + " },\n", + "```\n", + "\n", + "Both `GoogleIndex` and `GoogleVectorStore` recognize this source node,\n", + "and will automatically create documents under your corpus on the Google servers.\n", + "\n", + "In case you are writing your own chunker, you should supply this source node relationship too like below:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.schema import NodeRelationship, RelatedNodeInfo, TextNode\n", + "\n", + "index = GoogleIndex.from_corpus(corpus_id=\"123\")\n", + "index.insert_nodes(\n", + " [\n", + " TextNode(\n", + " text=\"It was the best of times.\",\n", + " relationships={\n", + " NodeRelationship.SOURCE: RelatedNodeInfo(\n", + " node_id=\"123\",\n", + " metadata={\"file_name\": \"Tale of Two Cities\"},\n", + " )\n", + " },\n", + " ),\n", + " TextNode(\n", + " text=\"It was the worst of times.\",\n", + " relationships={\n", + " NodeRelationship.SOURCE: RelatedNodeInfo(\n", + " node_id=\"123\",\n", + " metadata={\"file_name\": \"Tale of Two Cities\"},\n", + " )\n", + " },\n", + " ),\n", + " TextNode(\n", + " text=\"Wassup doc\",\n", + " relationships={\n", + " NodeRelationship.SOURCE: RelatedNodeInfo(\n", + " node_id=\"456\",\n", + " metadata={\"file_name\": \"Bugs Bunny Adventure\"},\n", + " )\n", + " },\n", + " ),\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If your nodes do not have a source node, then Google server will put your nodes in a default document under your corpus." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Listing and Deleting Documents\n", + "\n", + "See the Python library [google-generativeai](https://github.com/google/generative-ai-python) for further documentation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Querying Corpus\n", + "\n", + "Google's query engine is backed by a specially tuned LLM that grounds its response based on retrieved passages. For each response, an *answerability probability* is returned to indicate how confident the LLM was in answering the question from the retrieved passages.\n", + "\n", + "Furthermore, Google's query engine supports *answering styles*, such as `ABSTRACTIVE` (succint but abstract), `EXTRACTIVE` (very brief and extractive) and `VERBOSE` (extra details).\n", + "\n", + "The engine also supports *safety settings*.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from google.ai.generativelanguage import (\n", + " GenerateAnswerRequest,\n", + " HarmCategory,\n", + " SafetySetting,\n", + ")\n", + "\n", + "index = GoogleIndex.from_corpus(corpus_id=\"123\")\n", + "query_engine = index.as_query_engine(\n", + " # Extra parameters specific to the Google query engine.\n", + " temperature=0.7,\n", + " answer_style=GenerateAnswerRequest.AnswerStyle.ABSTRACTIVE,\n", + " safety_setting=[\n", + " SafetySetting(\n", + " category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,\n", + " threshold=SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,\n", + " ),\n", + " SafetySetting(\n", + " category=HarmCategory.HARM_CATEGORY_VIOLENCE,\n", + " threshold=SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH,\n", + " ),\n", + " ],\n", + ")\n", + "\n", + "response = query_engine.query(\"What movie should I watch with my family?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "See the Python library [google-generativeai](https://github.com/google/generative-ai-python) for further documentation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Interpreting the Response" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.response.schema import Response\n", + "\n", + "response = query_engine.query(\"What movie should I watch with my family?\")\n", + "assert isinstance(response, Response)\n", + "\n", + "# Show response.\n", + "print(f\"Response is {response.response}\")\n", + "\n", + "# Show cited passages that were used to construct the response.\n", + "for cited_text in [node.text for node in response.source_nodes]:\n", + " print(f\"Cited text: {cited_text}\")\n", + "\n", + "# Show answerability. 0 means not answerable from the passages.\n", + "# 1 means the model is certain the answer can be provided from the passages.\n", + "print(f\"Answerability: {response.metadata.get(\"answerable_probability\", 0)}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Advanced RAG\n", + "\n", + "The `GoogleIndex` is built based on `GoogleVectorStore` and `GoogleTextSynthesizer`.\n", + "These components can be combined with other powerful constructs in LlamaIndex to produce advanced RAG applications.\n", + "\n", + "Below we show a few examples." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Reranker + Google Retriever\n", + "\n", + "Converting content into vectors is a lossy process. LLM-based Reranking\n", + "remediates this by reranking the retrieved content using LLM, which has higher\n", + "fidelity because it has access to both the actual query and the passage." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.response_synthesizers.google.generativeai import (\n", + " GoogleTextSynthesizer,\n", + ")\n", + "from llama_index.vector_stores.google.generativeai import (\n", + " GoogleVectorStore,\n", + " google_service_context,\n", + ")\n", + "from llama_index import ServiceContext, VectorStoreIndex\n", + "from llama_index.llms import PaLM\n", + "from llama_index.postprocessor import LLMRerank\n", + "from llama_index.query_engine import RetrieverQueryEngine\n", + "from llama_index.retrievers import VectorIndexRetriever\n", + "\n", + "# Set up the query engine with a reranker.\n", + "store = GoogleVectorStore.from_corpus(corpus_id=\"some-corpus-id\")\n", + "index = VectorStoreIndex.from_vector_store(\n", + " vector_store=store, service_context=google_service_context\n", + ")\n", + "response_synthesizer = GoogleTextSynthesizer.from_defaults(\n", + " temperature=0.7, answer_style=GenerateAnswerRequest.AnswerStyle.ABSTRACTIVE\n", + ")\n", + "reranker = LLMRerank(\n", + " top_n=10, service_context=ServiceContext.from_defaults(llm=PaLM())\n", + ")\n", + "query_engine = RetrieverQueryEngine.from_args(\n", + " retriever=VectorIndexRetriever(\n", + " index=index,\n", + " similarity_top_k=20,\n", + " ),\n", + " response_synthesizer=response_synthesizer,\n", + " node_postprocessors=[reranker],\n", + ")\n", + "\n", + "# Query for better result!\n", + "response = query_engine.query(\"What movie should I watch with my family?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Multi-Query + Google Retriever\n", + "\n", + "Sometimes, a user's query can be too complex. You may get better retrieval result if you break down the original query into smaller, better focused queries." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.indices.query.query_transform.base import (\n", + " StepDecomposeQueryTransform,\n", + ")\n", + "from llama_index.query_engine.multistep_query_engine import (\n", + " MultiStepQueryEngine,\n", + ")\n", + "\n", + "# Set up the query engine with multi-turn query-rewriter.\n", + "store = GoogleVectorStore.from_corpus(corpus_id=\"some-corpus-id\")\n", + "index = VectorStoreIndex.from_vector_store(\n", + " vector_store=store, service_context=google_service_context\n", + ")\n", + "response_synthesizer = GoogleTextSynthesizer.from_defaults(\n", + " temperature=0.7, answer_style=GenerateAnswerRequest.AnswerStyle.ABSTRACTIVE\n", + ")\n", + "single_step_query_engine = index.as_query_engine(\n", + " response_synthesizer=response_synthesizer\n", + ")\n", + "step_decompose_transform = StepDecomposeQueryTransform(\n", + " llm=PaLM(), verbose=True\n", + ")\n", + "query_engine = MultiStepQueryEngine(\n", + " query_engine=single_step_query_engine,\n", + " query_transform=step_decompose_transform,\n", + " response_synthesizer=response_synthesizer,\n", + " index_summary=\"Ask me anything.\",\n", + " num_steps=6,\n", + ")\n", + "\n", + "# Query for better result!\n", + "response = query_engine.query(\"What movie should I watch with my family?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### HyDE + Google Retriever\n", + "\n", + "When you can write prompt that would produce fake answers that share many traits\n", + "with the real answer, you can try HyDE!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.indices.query.query_transform import HyDEQueryTransform\n", + "from llama_index.query_engine.transform_query_engine import (\n", + " TransformQueryEngine,\n", + ")\n", + "\n", + "# Set up the query engine with multi-turn query-rewriter.\n", + "store = GoogleVectorStore.from_corpus(corpus_id=\"some-corpus-id\")\n", + "index = VectorStoreIndex.from_vector_store(\n", + " vector_store=store, service_context=google_service_context\n", + ")\n", + "response_synthesizer = GoogleTextSynthesizer.from_defaults(\n", + " temperature=0.7, answer_style=GenerateAnswerRequest.AnswerStyle.ABSTRACTIVE\n", + ")\n", + "base_query_engine = index.as_query_engine(\n", + " response_synthesizer=response_synthesizer\n", + ")\n", + "hyde = HyDEQueryTransform(include_original=False)\n", + "hyde_query_engine = TransformQueryEngine(base_query_engine, hyde)\n", + "\n", + "# Query for better result!\n", + "response = hyde_query_engine.query(\"What movie should I watch with my family?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Multi-Query + Reranker + HyDE + Google Retriever\n", + "\n", + "Or combine them all!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Google's retriever and AQA model setup.\n", + "store = GoogleVectorStore.from_corpus(corpus_id=\"some-corpus-id\")\n", + "index = VectorStoreIndex.from_vector_store(\n", + " vector_store=store, service_context=google_service_context\n", + ")\n", + "response_synthesizer = GoogleTextSynthesizer.from_defaults(\n", + " temperature=0.7, answer_style=GenerateAnswerRequest.AnswerStyle.ABSTRACTIVE\n", + ")\n", + "\n", + "# Reranker setup.\n", + "reranker = LLMRerank(\n", + " top_n=10, service_context=ServiceContext.from_defaults(llm=PaLM())\n", + ")\n", + "single_step_query_engine = index.as_query_engine(\n", + " response_synthesizer=response_synthesizer, node_postprocessors=[reranker]\n", + ")\n", + "\n", + "# HyDE setup.\n", + "hyde = HyDEQueryTransform(include_original=True)\n", + "hyde_query_engine = TransformQueryEngine(single_step_query_engine, hyde)\n", + "\n", + "# Multi-query setup.\n", + "step_decompose_transform = StepDecomposeQueryTransform(\n", + " llm=PaLM(), verbose=True\n", + ")\n", + "query_engine = MultiStepQueryEngine(\n", + " query_engine=hyde_query_engine,\n", + " query_transform=step_decompose_transform,\n", + " response_synthesizer=response_synthesizer,\n", + " index_summary=\"Ask me anything.\",\n", + " num_steps=6,\n", + ")\n", + "\n", + "# Query for better result!\n", + "response = query_engine.query(\"What movie should I watch with my family?\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/module_guides/querying/retriever/retrievers.md b/docs/module_guides/querying/retriever/retrievers.md index c447ac95be..d91f3810c9 100644 --- a/docs/module_guides/querying/retriever/retrievers.md +++ b/docs/module_guides/querying/retriever/retrievers.md @@ -54,4 +54,5 @@ Auto-Retrieval (with BagelDB) </examples/vector_stores/BagelAutoRetriever.ipynb> /examples/node_postprocessor/MetadataReplacementDemo.ipynb /examples/index_structs/struct_indices/SQLIndexDemo.ipynb DeepMemory (Activeloop) </examples/retrievers/deep_memory.ipynb> +/examples/managed/GoogleDemo.ipynb ``` diff --git a/llama_index/indices/managed/google/generativeai/__init__.py b/llama_index/indices/managed/google/generativeai/__init__.py new file mode 100644 index 0000000000..3e1930d14c --- /dev/null +++ b/llama_index/indices/managed/google/generativeai/__init__.py @@ -0,0 +1,5 @@ +from .base import GoogleIndex + +__all__ = [ + "GoogleIndex", +] diff --git a/llama_index/indices/managed/google/generativeai/base.py b/llama_index/indices/managed/google/generativeai/base.py new file mode 100644 index 0000000000..1b29ec8d1d --- /dev/null +++ b/llama_index/indices/managed/google/generativeai/base.py @@ -0,0 +1,240 @@ +"""Google GenerativeAI Semantic Vector Store & Attributed Question and Answering. + +Google Generative AI Semantic Retriever API is a managed end to end service that +allows developers to create a corpus of documents to perform semantic search on +related passages given a user query. + +Google Generative AI Attributed Question and Answering API is a managed +end-to-end service that allows developers to create responses grounded on +specified passages based on user queries. + +For more information visit: +https://developers.generativeai.google/guide +""" + +import datetime +import logging +from typing import Any, List, Optional, Sequence, Type, cast + +from llama_index import VectorStoreIndex +from llama_index.data_structs.data_structs import IndexDict +from llama_index.indices.base import IndexType +from llama_index.indices.base_retriever import BaseRetriever +from llama_index.indices.managed.base import BaseManagedIndex +from llama_index.indices.query.base import BaseQueryEngine +from llama_index.indices.service_context import ServiceContext +from llama_index.response_synthesizers.google.generativeai import ( + GoogleTextSynthesizer, +) +from llama_index.schema import BaseNode, Document +from llama_index.storage.storage_context import StorageContext +from llama_index.vector_stores.google.generativeai import ( + GoogleVectorStore, + google_service_context, +) + +_logger = logging.getLogger(__name__) + + +class GoogleIndex(BaseManagedIndex): + """Google's Generative AI Semantic vector store with AQA.""" + + _store: GoogleVectorStore + _index: VectorStoreIndex + + def __init__( + self, + vector_store: GoogleVectorStore, + service_context: Optional[ServiceContext] = None, + **kwargs: Any, + ) -> None: + """Creates an instance of GoogleIndex. + + Prefer to use the factories `from_corpus` or `create_corpus` instead. + """ + actual_service_context = service_context or google_service_context + + self._store = vector_store + self._index = VectorStoreIndex.from_vector_store( + vector_store, service_context=actual_service_context, **kwargs + ) + + super().__init__( + index_struct=self._index.index_struct, + service_context=actual_service_context, + **kwargs, + ) + + @classmethod + def from_corpus( + cls: Type[IndexType], *, corpus_id: str, **kwargs: Any + ) -> IndexType: + """Creates a GoogleIndex from an existing corpus. + + Args: + corpus_id: ID of an existing corpus on Google's server. + + Returns: + An instance of GoogleIndex pointing to the specified corpus. + """ + _logger.debug(f"\n\nGoogleIndex.from_corpus(corpus_id={corpus_id})") + return cls( + vector_store=GoogleVectorStore.from_corpus(corpus_id=corpus_id), **kwargs + ) + + @classmethod + def create_corpus( + cls: Type[IndexType], + *, + corpus_id: Optional[str] = None, + display_name: Optional[str] = None, + **kwargs: Any, + ) -> IndexType: + """Creates a GoogleIndex from a new corpus. + + Args: + corpus_id: ID of the new corpus to be created. If not provided, + Google server will provide one. + display_name: Title of the new corpus. If not provided, Google + server will provide one. + + Returns: + An instance of GoogleIndex pointing to the specified corpus. + """ + _logger.debug( + f"\n\nGoogleIndex.from_new_corpus(new_corpus_id={corpus_id}, new_display_name={display_name})" + ) + return cls( + vector_store=GoogleVectorStore.create_corpus( + corpus_id=corpus_id, display_name=display_name + ), + **kwargs, + ) + + @classmethod + def from_documents( + cls: Type[IndexType], + documents: Sequence[Document], + storage_context: Optional[StorageContext] = None, + service_context: Optional[ServiceContext] = None, + show_progress: bool = False, + **kwargs: Any, + ) -> IndexType: + """Build an index from a sequence of documents.""" + _logger.debug(f"\n\nGoogleIndex.from_documents(...)") + + new_display_name = f"Corpus created on {datetime.datetime.now()}" + instance = cls( + vector_store=GoogleVectorStore.create_corpus(display_name=new_display_name), + **kwargs, + ) + + index = cast(GoogleIndex, instance) + index.insert_documents(documents=documents, service_context=service_context) + + return instance + + @property + def corpus_id(self) -> str: + """Returns the corpus ID being used by this GoogleIndex.""" + return self._store.corpus_id + + def _insert(self, nodes: Sequence[BaseNode], **insert_kwargs: Any) -> None: + """Inserts a set of nodes.""" + self._index.insert_nodes(nodes=nodes, **insert_kwargs) + + def insert_documents(self, documents: Sequence[Document], **kwargs: Any) -> None: + """Inserts a set of documents.""" + for document in documents: + self.insert(document=document, **kwargs) + + def delete_ref_doc( + self, ref_doc_id: str, delete_from_docstore: bool = False, **delete_kwargs: Any + ) -> None: + """Deletes a document and its nodes by using ref_doc_id.""" + self._index.delete_ref_doc(ref_doc_id=ref_doc_id, **delete_kwargs) + + def update_ref_doc(self, document: Document, **update_kwargs: Any) -> None: + """Updates a document and its corresponding nodes.""" + self._index.update(document=document, **update_kwargs) + + def as_retriever(self, **kwargs: Any) -> BaseRetriever: + """Returns a Retriever for this managed index.""" + return self._index.as_retriever(**kwargs) + + def as_query_engine( + self, + *, + temperature: float = 0.7, + answer_style: Any = 1, + safety_setting: List[Any] = [], + **kwargs: Any, + ) -> BaseQueryEngine: + """Returns the AQA engine for this index. + + Example: + query_engine = index.as_query_engine( + temperature=0.7, + answer_style=AnswerStyle.ABSTRACTIVE, + safety_setting=[ + SafetySetting( + category=HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + ), + ] + ) + + Args: + temperature: 0.0 to 1.0. + answer_style: See `google.ai.generativelanguage.GenerateAnswerRequest.AnswerStyle` + safety_setting: See `google.ai.generativelanguage.SafetySetting`. + + Returns: + A query engine that uses Google's AQA model. The query engine will + return a `Response` object. + + `Response`'s `source_nodes` will begin with a list of attributed + passages. These passages are the ones that were used to construct + the grounded response. These passages will always have no score, + the only way to mark them as attributed passages. Then, the list + will follow with the originally provided passages, which will have + a score from the retrieval. + + `Response`'s `metadata` may also have have an entry with key + `answerable_probability`, which is the probability that the grounded + answer is likely correct. + """ + # NOTE: lazy import + from llama_index.query_engine.retriever_query_engine import RetrieverQueryEngine + + # Don't overwrite the caller's kwargs, which may surprise them. + local_kwargs = kwargs.copy() + + if "retriever" in kwargs: + _logger.warning( + "Ignoring user's retriever to GoogleIndex.as_query_engine, " + "which uses its own retriever." + ) + del local_kwargs["retriever"] + + if "response_synthesizer" in kwargs: + _logger.warning( + "Ignoring user's response synthesizer to " + "GoogleIndex.as_query_engine, which uses its own retriever." + ) + del local_kwargs["response_synthesizer"] + + local_kwargs["retriever"] = self.as_retriever(**local_kwargs) + local_kwargs["response_synthesizer"] = GoogleTextSynthesizer.from_defaults( + temperature=temperature, + answer_style=answer_style, + safety_setting=safety_setting, + ) + if "service_context" not in local_kwargs: + local_kwargs["service_context"] = self._service_context + + return RetrieverQueryEngine.from_args(**local_kwargs) + + def _build_index_from_nodes(self, nodes: Sequence[BaseNode]) -> IndexDict: + """Build the index from nodes.""" + return self._index._build_index_from_nodes(nodes) diff --git a/llama_index/response_synthesizers/google/generativeai/__init__.py b/llama_index/response_synthesizers/google/generativeai/__init__.py new file mode 100644 index 0000000000..d2193a1938 --- /dev/null +++ b/llama_index/response_synthesizers/google/generativeai/__init__.py @@ -0,0 +1,6 @@ +from .base import GoogleTextSynthesizer, SynthesizedResponse + +__all__ = [ + "GoogleTextSynthesizer", + "SynthesizedResponse", +] diff --git a/llama_index/response_synthesizers/google/generativeai/base.py b/llama_index/response_synthesizers/google/generativeai/base.py new file mode 100644 index 0000000000..e9daa9cbf3 --- /dev/null +++ b/llama_index/response_synthesizers/google/generativeai/base.py @@ -0,0 +1,255 @@ +"""Google GenerativeAI Attributed Question and Answering (AQA) service. + +The GenAI Semantic AQA API is a managed end to end service that allows +developers to create responses grounded on specified passages based on +a user query. For more information visit: +https://developers.generativeai.google/guide +""" + +import logging +from typing import TYPE_CHECKING, Any, List, Optional, Sequence, cast + +from llama_index.bridge.pydantic import BaseModel # type: ignore +from llama_index.callbacks.schema import CBEventType, EventPayload +from llama_index.indices.query.schema import QueryBundle +from llama_index.prompts.mixin import PromptDictType +from llama_index.response.schema import Response +from llama_index.response_synthesizers.base import BaseSynthesizer, QueryTextType +from llama_index.schema import MetadataMode, NodeWithScore, TextNode +from llama_index.types import RESPONSE_TEXT_TYPE +from llama_index.vector_stores.google.generativeai import google_service_context + +if TYPE_CHECKING: + import google.ai.generativelanguage as genai + + +_logger = logging.getLogger(__name__) +_import_err_msg = "`google.generativeai` package not found, please run `pip install google-generativeai`" +_separator = "\n\n" + + +class SynthesizedResponse(BaseModel): + """Response of `GoogleTextSynthesizer.get_response`.""" + + answer: str + """The grounded response to the user's question.""" + + attributed_passages: List[str] + """The list of passages the AQA model used for its response.""" + + answerable_probability: float + """The model's estimate of the probability that its answer is correct and grounded in the input passages.""" + + +class GoogleTextSynthesizer(BaseSynthesizer): + """Google's Attributed Question and Answering service. + + Given a user's query and a list of passages, Google's server will return + a response that is grounded to the provided list of passages. It will not + base the response on parametric memory. + """ + + _client: Any + _temperature: float + _answer_style: Any + _safety_setting: List[Any] + + def __init__( + self, + *, + temperature: float, + answer_style: Any, + safety_setting: List[Any], + **kwargs: Any, + ): + """Create a new Google AQA. + + Prefer to use the factory `from_defaults` instead for type safety. + See `from_defaults` for more documentation. + """ + try: + import llama_index.vector_stores.google.generativeai.genai_extension as genaix + except ImportError: + raise ImportError(_import_err_msg) + + super().__init__( + service_context=google_service_context, + output_cls=SynthesizedResponse, + **kwargs, + ) + + self._client = genaix.build_generative_service() + self._temperature = temperature + self._answer_style = answer_style + self._safety_setting = safety_setting + + # Type safe factory that is only available if Google is installed. + @classmethod + def from_defaults( + cls, + temperature: float = 0.7, + answer_style: int = 1, + safety_setting: List["genai.SafetySetting"] = [], + ) -> "GoogleTextSynthesizer": + """Create a new Google AQA. + + Example: + responder = GoogleTextSynthesizer.create( + temperature=0.7, + answer_style=AnswerStyle.ABSTRACTIVE, + safety_setting=[ + SafetySetting( + category=HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + ), + ] + ) + + Args: + temperature: 0.0 to 1.0. + answer_style: See `google.ai.generativelanguage.GenerateAnswerRequest.AnswerStyle` + The default is ABSTRACTIVE (1). + safety_setting: See `google.ai.generativelanguage.SafetySetting`. + + Returns: + an instance of GoogleTextSynthesizer. + """ + return cls( + temperature=temperature, + answer_style=answer_style, + safety_setting=safety_setting, + ) + + def get_response( + self, + query_str: str, + text_chunks: Sequence[str], + **response_kwargs: Any, + ) -> SynthesizedResponse: + """Generate a grounded response on provided passages. + + Args: + query_str: The user's question. + text_chunks: A list of passages that should be used to answer the + question. + + Returns: + A `SynthesizedResponse` object. + """ + try: + import google.ai.generativelanguage as genai + + import llama_index.vector_stores.google.generativeai.genai_extension as genaix + except ImportError: + raise ImportError(_import_err_msg) + + client = cast(genai.GenerativeServiceClient, self._client) + response = genaix.generate_answer( + prompt=query_str, + passages=list(text_chunks), + answer_style=self._answer_style, + safety_settings=self._safety_setting, + temperature=self._temperature, + client=client, + ) + + return SynthesizedResponse( + answer=response.answer, + attributed_passages=[ + passage.text for passage in response.attributed_passages + ], + answerable_probability=response.answerable_probability, + ) + + async def aget_response( + self, + query_str: str, + text_chunks: Sequence[str], + **response_kwargs: Any, + ) -> RESPONSE_TEXT_TYPE: + # TODO: Implement a true async version. + return self.get_response(query_str, text_chunks, **response_kwargs) + + def synthesize( + self, + query: QueryTextType, + nodes: List[NodeWithScore], + additional_source_nodes: Optional[Sequence[NodeWithScore]] = None, + **response_kwargs: Any, + ) -> Response: + """Returns a grounded response based on provided passages. + + Returns: + Response's `source_nodes` will begin with a list of attributed + passages. These passages are the ones that were used to construct + the grounded response. These passages will always have no score, + the only way to mark them as attributed passages. Then, the list + will follow with the originally provided passages, which will have + a score from the retrieval. + + Response's `metadata` may also have have an entry with key + `answerable_probability`, which is the model's estimate of the + probability that its answer is correct and grounded in the input + passages. + """ + if len(nodes) == 0: + return Response("Empty Response") + + if isinstance(query, str): + query = QueryBundle(query_str=query) + + with self._callback_manager.event( + CBEventType.SYNTHESIZE, payload={EventPayload.QUERY_STR: query.query_str} + ) as event: + internal_response = self.get_response( + query_str=query.query_str, + text_chunks=[ + n.node.get_content(metadata_mode=MetadataMode.LLM) for n in nodes + ], + **response_kwargs, + ) + + additional_source_nodes = list(additional_source_nodes or []) + + external_response = self._prepare_external_response( + internal_response, nodes + additional_source_nodes + ) + + event.on_end(payload={EventPayload.RESPONSE: external_response}) + + return external_response + + async def asynthesize( + self, + query: QueryTextType, + nodes: List[NodeWithScore], + additional_source_nodes: Optional[Sequence[NodeWithScore]] = None, + **response_kwargs: Any, + ) -> Response: + # TODO: Implement a true async version. + return self.synthesize(query, nodes, additional_source_nodes, **response_kwargs) + + def _prepare_external_response( + self, + response: SynthesizedResponse, + source_nodes: List[NodeWithScore], + ) -> Response: + return Response( + response=response.answer, + source_nodes=[ + NodeWithScore(node=TextNode(text=passage)) + for passage in response.attributed_passages + ] + + source_nodes, + metadata={ + "answerable_probability": response.answerable_probability, + }, + ) + + def _get_prompts(self) -> PromptDictType: + # Not used. + return {} + + def _update_prompts(self, prompts_dict: PromptDictType) -> None: + # Not used. + ... diff --git a/llama_index/vector_stores/google/generativeai/__init__.py b/llama_index/vector_stores/google/generativeai/__init__.py new file mode 100644 index 0000000000..bd0db8bc77 --- /dev/null +++ b/llama_index/vector_stores/google/generativeai/__init__.py @@ -0,0 +1,6 @@ +from .base import GoogleVectorStore, google_service_context + +__all__ = [ + "google_service_context", + "GoogleVectorStore", +] diff --git a/llama_index/vector_stores/google/generativeai/base.py b/llama_index/vector_stores/google/generativeai/base.py new file mode 100644 index 0000000000..3ed4dab925 --- /dev/null +++ b/llama_index/vector_stores/google/generativeai/base.py @@ -0,0 +1,399 @@ +"""Google Generative AI Vector Store. + +The GenAI Semantic Retriever API is a managed end-to-end service that allows +developers to create a corpus of documents to perform semantic search on +related passages given a user query. For more information visit: +https://developers.generativeai.google/guide +""" + +import logging +import uuid +from typing import Any, Dict, List, Optional, Sequence, cast + +from llama_index.bridge.pydantic import BaseModel, Field, PrivateAttr # type: ignore +from llama_index.indices.service_context import ServiceContext +from llama_index.schema import BaseNode, RelatedNodeInfo, TextNode +from llama_index.vector_stores.types import ( + BasePydanticVectorStore, + MetadataFilters, + VectorStoreQuery, + VectorStoreQueryResult, +) + +_logger = logging.getLogger(__name__) +_import_err_msg = "`google.generativeai` package not found, please run `pip install google-generativeai`" +_default_doc_id = "default-doc" + + +google_service_context = ServiceContext.from_defaults( + # Avoids instantiating OpenAI as the default model. + llm=None, + # Avoids instantiating HuggingFace as the default model. + embed_model=None, +) +"""Google GenerativeAI service context. + +Use this to provide the correct service context for `GoogleVectorStore`. + +See the docstring for `GoogleVectorStore` for usage example. +""" + + +class NoSuchCorpusException(Exception): + def __init__(self, *, corpus_id: str) -> None: + super().__init__(f"No such corpus {corpus_id} found") + + +class GoogleVectorStore(BasePydanticVectorStore): + """Google GenerativeAI Vector Store. + + Currently, it computes the embedding vectors on the server side. + + Example: + google_vector_store = GoogleVectorStore.from_corpus( + corpus_id="my-corpus-id") + index = VectorStoreIndex.from_vector_store( + google_vector_store, + service_context=google_service_context) + + Attributes: + corpus_id: The corpus ID that this vector store instance will read and + write to. + """ + + # Semantic Retriever stores the document node's text as string and embeds + # the vectors on the server automatically. + stores_text: bool = True + is_embedding_query: bool = False + + # This is not the Google's corpus name but an ID generated in the LlamaIndex + # world. + corpus_id: str = Field(frozen=True) + """Corpus ID that this instance of the vector store is using.""" + + _client: Any = PrivateAttr() + + def __init__(self, *, client: Any, **kwargs: Any): + """Raw constructor. + + Use the class method `from_corpus` or `create_corpus` instead. + + Args: + client: The low-level retriever class from google.ai.generativelanguage. + """ + try: + import google.ai.generativelanguage as genai + except ImportError: + raise ImportError(_import_err_msg) + + super().__init__(**kwargs) + + assert isinstance(client, genai.RetrieverServiceClient) + self._client = client + + @classmethod + def from_corpus(cls, *, corpus_id: str) -> "GoogleVectorStore": + """Create an instance that points to an existing corpus. + + Args: + corpus_id: ID of an existing corpus on Google's server. + + Returns: + An instance of the vector store that points to the specified corpus. + + Raises: + NoSuchCorpusException if no such corpus is found. + """ + try: + import llama_index.vector_stores.google.generativeai.genai_extension as genaix + except ImportError: + raise ImportError(_import_err_msg) + + _logger.debug(f"\n\nGoogleVectorStore.from_corpus(corpus_id={corpus_id})") + client = genaix.build_semantic_retriever() + if genaix.get_corpus(corpus_id=corpus_id, client=client) is None: + raise NoSuchCorpusException(corpus_id=corpus_id) + + return cls(corpus_id=corpus_id, client=client) + + @classmethod + def create_corpus( + cls, *, corpus_id: Optional[str] = None, display_name: Optional[str] = None + ) -> "GoogleVectorStore": + """Create an instance that points to a newly created corpus. + + Examples: + store = GoogleVectorStore.create_corpus() + print(f"Created corpus with ID: {store.corpus_id}) + + store = GoogleVectorStore.create_corpus( + display_name="My first corpus" + ) + + store = GoogleVectorStore.create_corpus( + corpus_id="my-corpus-1", + display_name="My first corpus" + ) + + Args: + corpus_id: ID of the new corpus to be created. If not provided, + Google server will provide one for you. + display_name: Title of the corpus. If not provided, Google server + will provide one for you. + + Returns: + An instance of the vector store that points to the specified corpus. + + Raises: + An exception if the corpus already exists or the user hits the + quota limit. + """ + try: + import llama_index.vector_stores.google.generativeai.genai_extension as genaix + except ImportError: + raise ImportError(_import_err_msg) + + _logger.debug( + f"\n\nGoogleVectorStore.create_corpus(new_corpus_id={corpus_id}, new_display_name={display_name})" + ) + + client = genaix.build_semantic_retriever() + new_corpus_id = corpus_id or str(uuid.uuid4()) + new_corpus = genaix.create_corpus( + corpus_id=new_corpus_id, display_name=display_name, client=client + ) + name = genaix.EntityName.from_str(new_corpus.name) + return cls(corpus_id=name.corpus_id, client=client) + + @classmethod + def class_name(cls) -> str: + return "GoogleVectorStore" + + @property + def client(self) -> Any: + return self._client + + def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]: + """Add nodes with embedding to vector store. + + If a node has a source node, the source node's ID will be used to create + a document. Otherwise, a default document for that corpus will be used + to house the node. + + Furthermore, if the source node has a metadata field "file_name", it + will be used as the title of the document. If the source node has no + such field, Google server will assign a title to the document. + + Example: + store = GoogleVectorStore.from_corpus(corpus_id="123") + store.add([ + TextNode( + text="Hello, my darling", + relationships={ + NodeRelationship.SOURCE: RelatedNodeInfo( + node_id="doc-456", + metadata={"file_name": "Title for doc-456"}, + ) + }, + ), + TextNode( + text="Goodbye, my baby", + relationships={ + NodeRelationship.SOURCE: RelatedNodeInfo( + node_id="doc-456", + metadata={"file_name": "Title for doc-456"}, + ) + }, + ), + ]) + + The above code will create one document with ID `doc-456` and title + `Title for doc-456`. This document will house both nodes. + """ + try: + import google.ai.generativelanguage as genai + + import llama_index.vector_stores.google.generativeai.genai_extension as genaix + except ImportError: + raise ImportError(_import_err_msg) + + _logger.debug(f"\n\nGoogleVectorStore.add(nodes={nodes})") + + client = cast(genai.RetrieverServiceClient, self.client) + + created_node_ids: List[str] = [] + for nodeGroup in _group_nodes_by_source(nodes): + source = nodeGroup.source_node + document_id = source.node_id + document = genaix.get_document( + corpus_id=self.corpus_id, document_id=document_id, client=client + ) + + if not document: + genaix.create_document( + corpus_id=self.corpus_id, + display_name=source.metadata.get("file_name", None), + document_id=document_id, + metadata=source.metadata, + client=client, + ) + + created_chunks = genaix.batch_create_chunk( + corpus_id=self.corpus_id, + document_id=document_id, + texts=[node.get_content() for node in nodeGroup.nodes], + metadatas=[node.metadata for node in nodeGroup.nodes], + client=client, + ) + created_node_ids.extend([chunk.name for chunk in created_chunks]) + + return created_node_ids + + def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: + """Delete nodes by ref_doc_id. + + Both the underlying nodes and the document will be deleted from Google + server. + + Args: + ref_doc_id: The document ID to be deleted. + """ + try: + import google.ai.generativelanguage as genai + + import llama_index.vector_stores.google.generativeai.genai_extension as genaix + except ImportError: + raise ImportError(_import_err_msg) + + _logger.debug(f"\n\nGoogleVectorStore.delete(ref_doc_id={ref_doc_id})") + + client = cast(genai.RetrieverServiceClient, self.client) + genaix.delete_document( + corpus_id=self.corpus_id, document_id=ref_doc_id, client=client + ) + + def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult: + """Query vector store. + + Example: + store = GoogleVectorStore.from_corpus(corpus_id="123") + store.query( + query=VectorStoreQuery( + query_str="What is the meaning of life?", + # Only nodes with this author. + filters=MetadataFilters( + filters=[ + ExactMatchFilter( + key="author", + value="Arthur Schopenhauer", + ) + ] + ), + # Only from these docs. If not provided, + # the entire corpus is searched. + doc_ids=["doc-456"], + similarity_top_k=3, + ) + ) + + Args: + query: See `llama_index.vector_stores.types.VectorStoreQuery`. + """ + try: + import google.ai.generativelanguage as genai + + import llama_index.vector_stores.google.generativeai.genai_extension as genaix + except ImportError: + raise ImportError(_import_err_msg) + + _logger.debug(f"\n\nGoogleVectorStore.query(query={query})") + + query_str = query.query_str + if query_str is None: + raise ValueError("VectorStoreQuery.query_str should not be None.") + + client = cast(genai.RetrieverServiceClient, self.client) + + relevant_chunks: List[genai.RelevantChunk] = [] + if query.doc_ids is None: + # The chunks from query_corpus should be sorted in reverse order by + # relevant score. + relevant_chunks = genaix.query_corpus( + corpus_id=self.corpus_id, + query=query_str, + filter=_convert_filter(query.filters), + k=query.similarity_top_k, + client=client, + ) + else: + for doc_id in query.doc_ids: + relevant_chunks.extend( + genaix.query_document( + corpus_id=self.corpus_id, + document_id=doc_id, + query=query_str, + filter=_convert_filter(query.filters), + k=query.similarity_top_k, + client=client, + ) + ) + # Make sure the chunks are reversed sorted according to relevant + # scores even across multiple documents. + relevant_chunks.sort(key=lambda c: c.chunk_relevance_score, reverse=True) + + return VectorStoreQueryResult( + nodes=[ + TextNode( + text=chunk.chunk.data.string_value, + id_=_extract_chunk_id(chunk.chunk.name), + ) + for chunk in relevant_chunks + ], + ids=[_extract_chunk_id(chunk.chunk.name) for chunk in relevant_chunks], + similarities=[chunk.chunk_relevance_score for chunk in relevant_chunks], + ) + + +def _extract_chunk_id(entity_name: str) -> str: + try: + import llama_index.vector_stores.google.generativeai.genai_extension as genaix + except ImportError: + raise ImportError(_import_err_msg) + + id = genaix.EntityName.from_str(entity_name).chunk_id + assert id is not None + return id + + +class _NodeGroup(BaseModel): + """Every node in nodes have the same source node.""" + + source_node: RelatedNodeInfo + nodes: List[BaseNode] + + +def _group_nodes_by_source(nodes: Sequence[BaseNode]) -> List[_NodeGroup]: + """Returns a list of lists of nodes where each list has all the nodes + from the same document. + """ + groups: Dict[str, _NodeGroup] = {} + for node in nodes: + source_node: RelatedNodeInfo + if isinstance(node.source_node, RelatedNodeInfo): + source_node = node.source_node + else: + source_node = RelatedNodeInfo(node_id=_default_doc_id) + + if source_node.node_id not in groups: + groups[source_node.node_id] = _NodeGroup(source_node=source_node, nodes=[]) + + groups[source_node.node_id].nodes.append(node) + + return list(groups.values()) + + +def _convert_filter(fs: Optional[MetadataFilters]) -> Dict[str, Any]: + if fs is None: + return {} + assert isinstance(fs, MetadataFilters) + return {f.key: f.value for f in fs.filters} diff --git a/llama_index/vector_stores/google/generativeai/genai_extension.py b/llama_index/vector_stores/google/generativeai/genai_extension.py new file mode 100644 index 0000000000..eb0e77fb3c --- /dev/null +++ b/llama_index/vector_stores/google/generativeai/genai_extension.py @@ -0,0 +1,589 @@ +"""Temporary high-level library of the Google GenerativeAI API. + +The content of this file should eventually go into the Python package +google.generativeai. +""" + +import datetime +import logging +import re +from dataclasses import dataclass +from typing import Any, Dict, Iterator, List, MutableSequence, Optional + +import google.ai.generativelanguage as genai +from google.api_core import client_options as client_options_lib +from google.api_core import exceptions as gapi_exception +from google.api_core import gapic_v1 +from google.auth import credentials, exceptions +from google.protobuf import timestamp_pb2 + +import llama_index + +_logger = logging.getLogger(__name__) +_DEFAULT_API_ENDPOINT = "generativelanguage.googleapis.com" +_USER_AGENT = f"llama_index/{llama_index.__version__}" +_DEFAULT_PAGE_SIZE = 20 +_DEFAULT_GENERATE_SERVICE_MODEL = "models/aqa" +_MAX_REQUEST_PER_CHUNK = 100 +_NAME_REGEX = re.compile(r"^corpora/([^/]+?)(/documents/([^/]+?)(/chunks/([^/]+?))?)?$") + + +@dataclass +class EntityName: + corpus_id: str + document_id: Optional[str] = None + chunk_id: Optional[str] = None + + def __post_init__(self) -> None: + if self.chunk_id is not None and self.document_id is None: + raise ValueError(f"Chunk must have document ID but found {self}") + + @classmethod + def from_str(cls, encoded: str) -> "EntityName": + matched = _NAME_REGEX.match(encoded) + if not matched: + raise ValueError(f"Invalid entity name: {encoded}") + + return cls( + corpus_id=matched.group(1), + document_id=matched.group(3), + chunk_id=matched.group(5), + ) + + def __repr__(self) -> str: + name = f"corpora/{self.corpus_id}" + if self.document_id is None: + return name + name += f"/documents/{self.document_id}" + if self.chunk_id is None: + return name + name += f"/chunks/{self.chunk_id}" + return name + + def __str__(self) -> str: + return repr(self) + + def is_corpus(self) -> bool: + return self.document_id is None + + def is_document(self) -> bool: + return self.document_id is not None and self.chunk_id is None + + def is_chunk(self) -> bool: + return self.chunk_id is not None + + +@dataclass +class Corpus: + name: str + display_name: Optional[str] + create_time: Optional[timestamp_pb2.Timestamp] + update_time: Optional[timestamp_pb2.Timestamp] + + @property + def corpus_id(self) -> str: + name = EntityName.from_str(self.name) + return name.corpus_id + + @classmethod + def from_corpus(cls, c: genai.Corpus) -> "Corpus": + return cls( + name=c.name, + display_name=c.display_name, + create_time=c.create_time, + update_time=c.update_time, + ) + + +@dataclass +class Document: + name: str + display_name: Optional[str] + create_time: Optional[timestamp_pb2.Timestamp] + update_time: Optional[timestamp_pb2.Timestamp] + custom_metadata: Optional[MutableSequence[genai.CustomMetadata]] + + @property + def corpus_id(self) -> str: + name = EntityName.from_str(self.name) + return name.corpus_id + + @property + def document_id(self) -> str: + name = EntityName.from_str(self.name) + assert isinstance(name.document_id, str) + return name.document_id + + @classmethod + def from_document(cls, d: genai.Document) -> "Document": + return cls( + name=d.name, + display_name=d.display_name, + create_time=d.create_time, + update_time=d.update_time, + custom_metadata=d.custom_metadata, + ) + + +@dataclass +class Config: + """Global configuration for Google Generative AI API. + + Normally, the defaults should work fine. Change them only if you understand + why. + + Attributes: + api_endpoint: The Google Generative API endpoint address. + user_agent: The user agent to use for logging. + page_size: For paging RPCs, how many entities to return per RPC. + testing: Are the unit tests running? + """ + + api_endpoint: str = _DEFAULT_API_ENDPOINT + user_agent: str = _USER_AGENT + page_size: int = _DEFAULT_PAGE_SIZE + testing: bool = False + + +def set_defaults(config: Config) -> None: + """Set global defaults for operations with Google Generative AI API.""" + global _config + _config = config + + +_config = Config() + + +class TestCredentials(credentials.Credentials): + """Credentials that do not provide any authentication information. + + Useful for unit tests where the credentials are not used. + """ + + @property + def expired(self) -> bool: + """Returns `False`, test credentials never expire.""" + return False + + @property + def valid(self) -> bool: + """Returns `True`, test credentials are always valid.""" + return True + + def refresh(self, request: Any) -> None: + """Raises :class:``InvalidOperation``, test credentials cannot be + refreshed. + """ + raise exceptions.InvalidOperation("Test credentials cannot be refreshed.") + + def apply(self, headers: Any, token: Any = None) -> None: + """Anonymous credentials do nothing to the request. + + The optional ``token`` argument is not supported. + + Raises: + google.auth.exceptions.InvalidValue: If a token was specified. + """ + if token is not None: + raise exceptions.InvalidValue("Test credentials don't support tokens.") + + def before_request(self, request: Any, method: Any, url: Any, headers: Any) -> None: + """Test credentials do nothing to the request.""" + + +def _get_test_credentials() -> Optional[credentials.Credentials]: + """Returns a fake credential for testing or None. + + If _config.testing is True, a fake credential is returned. + Otherwise, we are in a real environment and a None is returned. + + If None is passed to the clients later on, the actual credentials will be + inferred by the rules specified in google.auth package. + """ + return TestCredentials() if _config.testing else None + + +def build_semantic_retriever() -> genai.RetrieverServiceClient: + credentials = _get_test_credentials() + return genai.RetrieverServiceClient( + credentials=credentials, + client_info=gapic_v1.client_info.ClientInfo(user_agent=_USER_AGENT), + client_options=client_options_lib.ClientOptions( + api_endpoint=_config.api_endpoint + ), + ) + + +def build_generative_service() -> genai.GenerativeServiceClient: + credentials = _get_test_credentials() + return genai.GenerativeServiceClient( + credentials=credentials, + client_info=gapic_v1.client_info.ClientInfo(user_agent=_USER_AGENT), + client_options=client_options_lib.ClientOptions( + api_endpoint=_config.api_endpoint + ), + ) + + +def list_corpora( + *, + client: genai.RetrieverServiceClient, +) -> Iterator[Corpus]: + for corpus in client.list_corpora( + genai.ListCorporaRequest(page_size=_config.page_size) + ): + yield Corpus.from_corpus(corpus) + + +def get_corpus( + *, + corpus_id: str, + client: genai.RetrieverServiceClient, +) -> Optional[Corpus]: + try: + corpus = client.get_corpus( + genai.GetCorpusRequest(name=str(EntityName(corpus_id=corpus_id))) + ) + return Corpus.from_corpus(corpus) + except Exception as e: + # If the corpus does not exist, the server returns a permission error. + if not isinstance(e, gapi_exception.PermissionDenied): + raise + _logger.warning(f"Corpus {corpus_id} not found: {e}") + return None + + +def create_corpus( + *, + corpus_id: Optional[str] = None, + display_name: Optional[str] = None, + client: genai.RetrieverServiceClient, +) -> Corpus: + name: Optional[str] + if corpus_id is not None: + name = str(EntityName(corpus_id=corpus_id)) + else: + name = None + + new_display_name = display_name or f"Untitled {datetime.datetime.now()}" + + new_corpus = client.create_corpus( + genai.CreateCorpusRequest( + corpus=genai.Corpus(name=name, display_name=new_display_name) + ) + ) + + return Corpus.from_corpus(new_corpus) + + +def delete_corpus( + *, + corpus_id: str, + client: genai.RetrieverServiceClient, +) -> None: + client.delete_corpus( + genai.DeleteCorpusRequest(name=str(EntityName(corpus_id=corpus_id)), force=True) + ) + + +def list_documents( + *, + corpus_id: str, + client: genai.RetrieverServiceClient, +) -> Iterator[Document]: + for document in client.list_documents( + genai.ListDocumentsRequest( + parent=str(EntityName(corpus_id=corpus_id)), page_size=_DEFAULT_PAGE_SIZE + ) + ): + yield Document.from_document(document) + + +def get_document( + *, + corpus_id: str, + document_id: str, + client: genai.RetrieverServiceClient, +) -> Optional[Document]: + try: + document = client.get_document( + genai.GetDocumentRequest( + name=str(EntityName(corpus_id=corpus_id, document_id=document_id)) + ) + ) + return Document.from_document(document) + except Exception as e: + if not isinstance(e, gapi_exception.NotFound): + raise + _logger.warning(f"Document {document_id} in corpus {corpus_id} not found: {e}") + return None + + +def create_document( + *, + corpus_id: str, + document_id: Optional[str] = None, + display_name: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + client: genai.RetrieverServiceClient, +) -> Document: + name: Optional[str] + if document_id is not None: + name = str(EntityName(corpus_id=corpus_id, document_id=document_id)) + else: + name = None + + new_display_name = display_name or f"Untitled {datetime.datetime.now()}" + new_metadatas = _convert_to_metadata(metadata) if metadata else None + + new_document = client.create_document( + genai.CreateDocumentRequest( + parent=str(EntityName(corpus_id=corpus_id)), + document=genai.Document( + name=name, display_name=new_display_name, custom_metadata=new_metadatas + ), + ) + ) + + return Document.from_document(new_document) + + +def delete_document( + *, + corpus_id: str, + document_id: str, + client: genai.RetrieverServiceClient, +) -> None: + client.delete_document( + genai.DeleteDocumentRequest( + name=str(EntityName(corpus_id=corpus_id, document_id=document_id)), + force=True, + ) + ) + + +def batch_create_chunk( + *, + corpus_id: str, + document_id: str, + texts: List[str], + metadatas: Optional[List[Dict[str, Any]]] = None, + client: genai.RetrieverServiceClient, +) -> List[genai.Chunk]: + if metadatas is None: + metadatas = [{} for _ in texts] + if len(texts) != len(metadatas): + raise ValueError( + f"metadatas's length {len(metadatas)} and texts's length {len(texts)} are mismatched" + ) + + doc_name = str(EntityName(corpus_id=corpus_id, document_id=document_id)) + + created_chunks: List[genai.Chunk] = [] + + batch_request = genai.BatchCreateChunksRequest( + parent=doc_name, + requests=[], + ) + for text, metadata in zip(texts, metadatas): + batch_request.requests.append( + genai.CreateChunkRequest( + parent=doc_name, + chunk=genai.Chunk( + data=genai.ChunkData(string_value=text), + custom_metadata=_convert_to_metadata(metadata), + ), + ) + ) + + if len(batch_request.requests) >= _MAX_REQUEST_PER_CHUNK: + response = client.batch_create_chunks(batch_request) + created_chunks.extend(list(response.chunks)) + # Prepare a new batch for next round. + batch_request = genai.BatchCreateChunksRequest( + parent=doc_name, + requests=[], + ) + + # Process left over. + if len(batch_request.requests) > 0: + response = client.batch_create_chunks(batch_request) + created_chunks.extend(list(response.chunks)) + + return created_chunks + + +def delete_chunk( + *, + corpus_id: str, + document_id: str, + chunk_id: str, + client: genai.RetrieverServiceClient, +) -> None: + client.delete_chunk( + genai.DeleteChunkRequest( + name=str( + EntityName( + corpus_id=corpus_id, document_id=document_id, chunk_id=chunk_id + ) + ) + ) + ) + + +def query_corpus( + *, + corpus_id: str, + query: str, + k: int = 4, + filter: Optional[Dict[str, Any]] = None, + client: genai.RetrieverServiceClient, +) -> List[genai.RelevantChunk]: + response = client.query_corpus( + genai.QueryCorpusRequest( + name=str(EntityName(corpus_id=corpus_id)), + query=query, + metadata_filters=_convert_filter(filter), + results_count=k, + ) + ) + return list(response.relevant_chunks) + + +def query_document( + *, + corpus_id: str, + document_id: str, + query: str, + k: int = 4, + filter: Optional[Dict[str, Any]] = None, + client: genai.RetrieverServiceClient, +) -> List[genai.RelevantChunk]: + response = client.query_document( + genai.QueryDocumentRequest( + name=str(EntityName(corpus_id=corpus_id, document_id=document_id)), + query=query, + metadata_filters=_convert_filter(filter), + results_count=k, + ) + ) + return list(response.relevant_chunks) + + +@dataclass +class Passage: + text: str + id: str + + +@dataclass +class GroundedAnswer: + answer: str + attributed_passages: List[Passage] + answerable_probability: Optional[float] + + +@dataclass +class GenerateAnswerError(Exception): + finish_reason: genai.Candidate.FinishReason + finish_message: str + safety_ratings: MutableSequence[genai.SafetyRating] + + def __str__(self) -> str: + return ( + f"finish_reason: {self.finish_reason.name} " + f"finish_message: {self.finish_message} " + f"safety ratings: {self.safety_ratings}" + ) + + +def generate_answer( + *, + prompt: str, + passages: List[str], + answer_style: int = genai.GenerateAnswerRequest.AnswerStyle.ABSTRACTIVE, + safety_settings: List[genai.SafetySetting] = [], + temperature: Optional[float] = None, + client: genai.GenerativeServiceClient, +) -> GroundedAnswer: + # TODO: Consider passing in the corpus ID instead of the actual + # passages. + response = client.generate_answer( + genai.GenerateAnswerRequest( + contents=[ + genai.Content(parts=[genai.Part(text=prompt)]), + ], + model=_DEFAULT_GENERATE_SERVICE_MODEL, + answer_style=answer_style, + safety_settings=safety_settings, + temperature=temperature, + inline_passages=genai.GroundingPassages( + passages=[ + genai.GroundingPassage( + # IDs here takes alphanumeric only. No dashes allowed. + id=str(index), + content=genai.Content(parts=[genai.Part(text=chunk)]), + ) + for index, chunk in enumerate(passages) + ] + ), + ) + ) + + if response.answer.finish_reason != genai.Candidate.FinishReason.STOP: + raise GenerateAnswerError( + finish_reason=response.answer.finish_reason, + finish_message=response.answer.finish_message, + safety_ratings=response.answer.safety_ratings, + ) + + assert len(response.answer.content.parts) == 1 + return GroundedAnswer( + answer=response.answer.content.parts[0].text, + attributed_passages=[ + Passage( + text=passage.content.parts[0].text, + id=passage.source_id.grounding_passage.passage_id, + ) + for passage in response.answer.grounding_attributions + if len(passage.content.parts) > 0 + ], + answerable_probability=response.answerable_probability, + ) + + +def _convert_to_metadata(metadata: Dict[str, Any]) -> List[genai.CustomMetadata]: + cs: List[genai.CustomMetadata] = [] + for key, value in metadata.items(): + if isinstance(value, str): + c = genai.CustomMetadata(key=key, string_value=value) + elif isinstance(value, (float, int)): + c = genai.CustomMetadata(key=key, numeric_value=value) + else: + raise ValueError(f"Metadata value {value} is not supported") + + cs.append(c) + return cs + + +def _convert_filter(fs: Optional[Dict[str, Any]]) -> List[genai.MetadataFilter]: + if fs is None: + return [] + assert isinstance(fs, dict) + + filters: List[genai.MetadataFilter] = [] + for key, value in fs.items(): + if isinstance(value, str): + condition = genai.Condition( + operation=genai.Condition.Operator.EQUAL, string_value=value + ) + elif isinstance(value, (float, int)): + condition = genai.Condition( + operation=genai.Condition.Operator.EQUAL, numeric_value=value + ) + else: + raise ValueError(f"Filter value {value} is not supported") + + filters.append(genai.MetadataFilter(key=key, conditions=[condition])) + + return filters diff --git a/poetry.lock b/poetry.lock index 0204408c78..2d77252643 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "accelerate" @@ -153,24 +153,24 @@ files = [ [[package]] name = "anyio" -version = "3.7.1" +version = "4.1.0" description = "High level compatibility layer for multiple asynchronous event loop implementations" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "anyio-3.7.1-py3-none-any.whl", hash = "sha256:91dee416e570e92c64041bd18b900d1d6fa78dff7048769ce5ac5ddad004fbb5"}, - {file = "anyio-3.7.1.tar.gz", hash = "sha256:44a3c9aba0f5defa43261a8b3efb97891f2bd7d804e0e1f56419befa1adfc780"}, + {file = "anyio-4.1.0-py3-none-any.whl", hash = "sha256:56a415fbc462291813a94528a779597226619c8e78af7de0507333f700011e5f"}, + {file = "anyio-4.1.0.tar.gz", hash = "sha256:5a0bec7085176715be77df87fc66d6c9d70626bd752fcc85f57cdbee5b3760da"}, ] [package.dependencies] -exceptiongroup = {version = "*", markers = "python_version < \"3.11\""} +exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""} idna = ">=2.8" sniffio = ">=1.1" [package.extras] -doc = ["Sphinx", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme (>=1.2.2)", "sphinxcontrib-jquery"] -test = ["anyio[trio]", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=4)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] -trio = ["trio (<0.22)"] +doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] +trio = ["trio (>=0.23)"] [[package]] name = "appnope" @@ -596,13 +596,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.33.10" +version = "1.33.12" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">= 3.7" files = [ - {file = "botocore-1.33.10-py3-none-any.whl", hash = "sha256:9619609692a7f99f33093a305a0fb88ca5c83104f97ead3a4405b1e4ba6058f9"}, - {file = "botocore-1.33.10.tar.gz", hash = "sha256:82be3da9ceac9d847d115a80f0a0dae020c3534ef88839ef907eb3205309fd4a"}, + {file = "botocore-1.33.12-py3-none-any.whl", hash = "sha256:48b9cfb9c5f7f9634a71782f16a324acb522b65856ad46be69efe04c3322b23c"}, + {file = "botocore-1.33.12.tar.gz", hash = "sha256:067c94fa88583c04ae897d48a11d2be09f280363b8e794b82d78d631d3a3e910"}, ] [package.dependencies] @@ -1479,21 +1479,18 @@ sphinx-basic-ng = "*" [[package]] name = "google-ai-generativelanguage" -version = "0.3.3" +version = "0.4.0" description = "Google Ai Generativelanguage API client library" optional = false python-versions = ">=3.7" files = [ - {file = "google-ai-generativelanguage-0.3.3.tar.gz", hash = "sha256:4b59993e0fd63593171cbb089e7f76f71a4333a62741d3929159aeb2e3532a83"}, - {file = "google_ai_generativelanguage-0.3.3-py3-none-any.whl", hash = "sha256:2696fe952ceea233e1a95b89a428b7dd587eac6687bd20cc62edd5c8abc32b98"}, + {file = "google-ai-generativelanguage-0.4.0.tar.gz", hash = "sha256:c8199066c08f74c4e91290778329bb9f357ba1ea5d6f82de2bc0d10552bf4f8c"}, + {file = "google_ai_generativelanguage-0.4.0-py3-none-any.whl", hash = "sha256:e4c425376c1ee26c78acbc49a24f735f90ebfa81bf1a06495fae509a2433232c"}, ] [package.dependencies] google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]} -proto-plus = [ - {version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""}, - {version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""}, -] +proto-plus = ">=1.22.3,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev" [[package]] @@ -1511,11 +1508,11 @@ files = [ google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" grpcio = [ - {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""}, + {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] grpcio-status = [ - {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""}, + {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" @@ -1528,13 +1525,13 @@ grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] [[package]] name = "google-auth" -version = "2.25.1" +version = "2.25.2" description = "Google Authentication Library" optional = false python-versions = ">=3.7" files = [ - {file = "google-auth-2.25.1.tar.gz", hash = "sha256:d5d66b8f4f6e3273740d7bb73ddefa6c2d1ff691704bd407d51c6b5800e7c97b"}, - {file = "google_auth-2.25.1-py2.py3-none-any.whl", hash = "sha256:dfd7b44935d498e106c08883b2dac0ad36d8aa10402a6412e9a1c9d74b4773f1"}, + {file = "google-auth-2.25.2.tar.gz", hash = "sha256:42f707937feb4f5e5a39e6c4f343a17300a459aaf03141457ba505812841cc40"}, + {file = "google_auth-2.25.2-py2.py3-none-any.whl", hash = "sha256:473a8dfd0135f75bb79d878436e568f2695dce456764bf3a02b6f8c540b1d256"}, ] [package.dependencies] @@ -1549,26 +1546,6 @@ pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"] reauth = ["pyu2f (>=0.1.5)"] requests = ["requests (>=2.20.0,<3.0.0.dev0)"] -[[package]] -name = "google-generativeai" -version = "0.2.2" -description = "Google Generative AI High level API client library and tools." -optional = false -python-versions = ">=3.9" -files = [ - {file = "google_generativeai-0.2.2-py3-none-any.whl", hash = "sha256:0fc3e61fbaeddaca590d30cfa1a4b2945db85d2a782f31eef20982457f4cb31f"}, -] - -[package.dependencies] -google-ai-generativelanguage = "0.3.3" -google-api-core = "*" -google-auth = "*" -protobuf = "*" -tqdm = "*" - -[package.extras] -dev = ["absl-py", "black", "nose2", "pandas", "pytype", "pyyaml"] - [[package]] name = "googleapis-common-protos" version = "1.62.0" @@ -1604,68 +1581,69 @@ requests = "*" [[package]] name = "greenlet" -version = "3.0.1" +version = "3.0.2" description = "Lightweight in-process concurrent programming" optional = false python-versions = ">=3.7" files = [ - {file = "greenlet-3.0.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f89e21afe925fcfa655965ca8ea10f24773a1791400989ff32f467badfe4a064"}, - {file = "greenlet-3.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28e89e232c7593d33cac35425b58950789962011cc274aa43ef8865f2e11f46d"}, - {file = "greenlet-3.0.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8ba29306c5de7717b5761b9ea74f9c72b9e2b834e24aa984da99cbfc70157fd"}, - {file = "greenlet-3.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:19bbdf1cce0346ef7341705d71e2ecf6f41a35c311137f29b8a2dc2341374565"}, - {file = "greenlet-3.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:599daf06ea59bfedbec564b1692b0166a0045f32b6f0933b0dd4df59a854caf2"}, - {file = "greenlet-3.0.1-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b641161c302efbb860ae6b081f406839a8b7d5573f20a455539823802c655f63"}, - {file = "greenlet-3.0.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d57e20ba591727da0c230ab2c3f200ac9d6d333860d85348816e1dca4cc4792e"}, - {file = "greenlet-3.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5805e71e5b570d490938d55552f5a9e10f477c19400c38bf1d5190d760691846"}, - {file = "greenlet-3.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:52e93b28db27ae7d208748f45d2db8a7b6a380e0d703f099c949d0f0d80b70e9"}, - {file = "greenlet-3.0.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f7bfb769f7efa0eefcd039dd19d843a4fbfbac52f1878b1da2ed5793ec9b1a65"}, - {file = "greenlet-3.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91e6c7db42638dc45cf2e13c73be16bf83179f7859b07cfc139518941320be96"}, - {file = "greenlet-3.0.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1757936efea16e3f03db20efd0cd50a1c86b06734f9f7338a90c4ba85ec2ad5a"}, - {file = "greenlet-3.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:19075157a10055759066854a973b3d1325d964d498a805bb68a1f9af4aaef8ec"}, - {file = "greenlet-3.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e9d21aaa84557d64209af04ff48e0ad5e28c5cca67ce43444e939579d085da72"}, - {file = "greenlet-3.0.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2847e5d7beedb8d614186962c3d774d40d3374d580d2cbdab7f184580a39d234"}, - {file = "greenlet-3.0.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:97e7ac860d64e2dcba5c5944cfc8fa9ea185cd84061c623536154d5a89237884"}, - {file = "greenlet-3.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b2c02d2ad98116e914d4f3155ffc905fd0c025d901ead3f6ed07385e19122c94"}, - {file = "greenlet-3.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:22f79120a24aeeae2b4471c711dcf4f8c736a2bb2fabad2a67ac9a55ea72523c"}, - {file = "greenlet-3.0.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:100f78a29707ca1525ea47388cec8a049405147719f47ebf3895e7509c6446aa"}, - {file = "greenlet-3.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:60d5772e8195f4e9ebf74046a9121bbb90090f6550f81d8956a05387ba139353"}, - {file = "greenlet-3.0.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:daa7197b43c707462f06d2c693ffdbb5991cbb8b80b5b984007de431493a319c"}, - {file = "greenlet-3.0.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea6b8aa9e08eea388c5f7a276fabb1d4b6b9d6e4ceb12cc477c3d352001768a9"}, - {file = "greenlet-3.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d11ebbd679e927593978aa44c10fc2092bc454b7d13fdc958d3e9d508aba7d0"}, - {file = "greenlet-3.0.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dbd4c177afb8a8d9ba348d925b0b67246147af806f0b104af4d24f144d461cd5"}, - {file = "greenlet-3.0.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:20107edf7c2c3644c67c12205dc60b1bb11d26b2610b276f97d666110d1b511d"}, - {file = "greenlet-3.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8bef097455dea90ffe855286926ae02d8faa335ed8e4067326257cb571fc1445"}, - {file = "greenlet-3.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:b2d3337dcfaa99698aa2377c81c9ca72fcd89c07e7eb62ece3f23a3fe89b2ce4"}, - {file = "greenlet-3.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:80ac992f25d10aaebe1ee15df45ca0d7571d0f70b645c08ec68733fb7a020206"}, - {file = "greenlet-3.0.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:337322096d92808f76ad26061a8f5fccb22b0809bea39212cd6c406f6a7060d2"}, - {file = "greenlet-3.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b9934adbd0f6e476f0ecff3c94626529f344f57b38c9a541f87098710b18af0a"}, - {file = "greenlet-3.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dc4d815b794fd8868c4d67602692c21bf5293a75e4b607bb92a11e821e2b859a"}, - {file = "greenlet-3.0.1-cp37-cp37m-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:41bdeeb552d814bcd7fb52172b304898a35818107cc8778b5101423c9017b3de"}, - {file = "greenlet-3.0.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:6e6061bf1e9565c29002e3c601cf68569c450be7fc3f7336671af7ddb4657166"}, - {file = "greenlet-3.0.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:fa24255ae3c0ab67e613556375a4341af04a084bd58764731972bcbc8baeba36"}, - {file = "greenlet-3.0.1-cp37-cp37m-win32.whl", hash = "sha256:b489c36d1327868d207002391f662a1d163bdc8daf10ab2e5f6e41b9b96de3b1"}, - {file = "greenlet-3.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:f33f3258aae89da191c6ebaa3bc517c6c4cbc9b9f689e5d8452f7aedbb913fa8"}, - {file = "greenlet-3.0.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:d2905ce1df400360463c772b55d8e2518d0e488a87cdea13dd2c71dcb2a1fa16"}, - {file = "greenlet-3.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a02d259510b3630f330c86557331a3b0e0c79dac3d166e449a39363beaae174"}, - {file = "greenlet-3.0.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:55d62807f1c5a1682075c62436702aaba941daa316e9161e4b6ccebbbf38bda3"}, - {file = "greenlet-3.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3fcc780ae8edbb1d050d920ab44790201f027d59fdbd21362340a85c79066a74"}, - {file = "greenlet-3.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4eddd98afc726f8aee1948858aed9e6feeb1758889dfd869072d4465973f6bfd"}, - {file = "greenlet-3.0.1-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:eabe7090db68c981fca689299c2d116400b553f4b713266b130cfc9e2aa9c5a9"}, - {file = "greenlet-3.0.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:f2f6d303f3dee132b322a14cd8765287b8f86cdc10d2cb6a6fae234ea488888e"}, - {file = "greenlet-3.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d923ff276f1c1f9680d32832f8d6c040fe9306cbfb5d161b0911e9634be9ef0a"}, - {file = "greenlet-3.0.1-cp38-cp38-win32.whl", hash = "sha256:0b6f9f8ca7093fd4433472fd99b5650f8a26dcd8ba410e14094c1e44cd3ceddd"}, - {file = "greenlet-3.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:990066bff27c4fcf3b69382b86f4c99b3652bab2a7e685d968cd4d0cfc6f67c6"}, - {file = "greenlet-3.0.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:ce85c43ae54845272f6f9cd8320d034d7a946e9773c693b27d620edec825e376"}, - {file = "greenlet-3.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89ee2e967bd7ff85d84a2de09df10e021c9b38c7d91dead95b406ed6350c6997"}, - {file = "greenlet-3.0.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:87c8ceb0cf8a5a51b8008b643844b7f4a8264a2c13fcbcd8a8316161725383fe"}, - {file = "greenlet-3.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d6a8c9d4f8692917a3dc7eb25a6fb337bff86909febe2f793ec1928cd97bedfc"}, - {file = "greenlet-3.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fbc5b8f3dfe24784cee8ce0be3da2d8a79e46a276593db6868382d9c50d97b1"}, - {file = "greenlet-3.0.1-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:85d2b77e7c9382f004b41d9c72c85537fac834fb141b0296942d52bf03fe4a3d"}, - {file = "greenlet-3.0.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:696d8e7d82398e810f2b3622b24e87906763b6ebfd90e361e88eb85b0e554dc8"}, - {file = "greenlet-3.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:329c5a2e5a0ee942f2992c5e3ff40be03e75f745f48847f118a3cfece7a28546"}, - {file = "greenlet-3.0.1-cp39-cp39-win32.whl", hash = "sha256:cf868e08690cb89360eebc73ba4be7fb461cfbc6168dd88e2fbbe6f31812cd57"}, - {file = "greenlet-3.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:ac4a39d1abae48184d420aa8e5e63efd1b75c8444dd95daa3e03f6c6310e9619"}, - {file = "greenlet-3.0.1.tar.gz", hash = "sha256:816bd9488a94cba78d93e1abb58000e8266fa9cc2aa9ccdd6eb0696acb24005b"}, + {file = "greenlet-3.0.2-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:9acd8fd67c248b8537953cb3af8787c18a87c33d4dcf6830e410ee1f95a63fd4"}, + {file = "greenlet-3.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:339c0272a62fac7e602e4e6ec32a64ff9abadc638b72f17f6713556ed011d493"}, + {file = "greenlet-3.0.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:38878744926cec29b5cc3654ef47f3003f14bfbba7230e3c8492393fe29cc28b"}, + {file = "greenlet-3.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b3f0497db77cfd034f829678b28267eeeeaf2fc21b3f5041600f7617139e6773"}, + {file = "greenlet-3.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ed1a8a08de7f68506a38f9a2ddb26bbd1480689e66d788fcd4b5f77e2d9ecfcc"}, + {file = "greenlet-3.0.2-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:89a6f6ddcbef4000cda7e205c4c20d319488ff03db961d72d4e73519d2465309"}, + {file = "greenlet-3.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c1f647fe5b94b51488b314c82fdda10a8756d650cee8d3cd29f657c6031bdf73"}, + {file = "greenlet-3.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9560c580c896030ff9c311c603aaf2282234643c90d1dec738a1d93e3e53cd51"}, + {file = "greenlet-3.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:2e9c5423046eec21f6651268cb674dfba97280701e04ef23d312776377313206"}, + {file = "greenlet-3.0.2-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:b1fd25dfc5879a82103b3d9e43fa952e3026c221996ff4d32a9c72052544835d"}, + {file = "greenlet-3.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cecfdc950dd25f25d6582952e58521bca749cf3eeb7a9bad69237024308c8196"}, + {file = "greenlet-3.0.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:edf7a1daba1f7c54326291a8cde58da86ab115b78c91d502be8744f0aa8e3ffa"}, + {file = "greenlet-3.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f4cf532bf3c58a862196b06947b1b5cc55503884f9b63bf18582a75228d9950e"}, + {file = "greenlet-3.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e79fb5a9fb2d0bd3b6573784f5e5adabc0b0566ad3180a028af99523ce8f6138"}, + {file = "greenlet-3.0.2-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:006c1028ac0cfcc4e772980cfe73f5476041c8c91d15d64f52482fc571149d46"}, + {file = "greenlet-3.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:fefd5eb2c0b1adffdf2802ff7df45bfe65988b15f6b972706a0e55d451bffaea"}, + {file = "greenlet-3.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0c0fdb8142742ee68e97c106eb81e7d3e883cc739d9c5f2b28bc38a7bafeb6d1"}, + {file = "greenlet-3.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:8f8d14a0a4e8c670fbce633d8b9a1ee175673a695475acd838e372966845f764"}, + {file = "greenlet-3.0.2-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:654b84c9527182036747938b81938f1d03fb8321377510bc1854a9370418ab66"}, + {file = "greenlet-3.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cd5bc4fde0842ff2b9cf33382ad0b4db91c2582db836793d58d174c569637144"}, + {file = "greenlet-3.0.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c27b142a9080bdd5869a2fa7ebf407b3c0b24bd812db925de90e9afe3c417fd6"}, + {file = "greenlet-3.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0df7eed98ea23b20e9db64d46eb05671ba33147df9405330695bcd81a73bb0c9"}, + {file = "greenlet-3.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb5d60805057d8948065338be6320d35e26b0a72f45db392eb32b70dd6dc9227"}, + {file = "greenlet-3.0.2-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e0e28f5233d64c693382f66d47c362b72089ebf8ac77df7e12ac705c9fa1163d"}, + {file = "greenlet-3.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3e4bfa752b3688d74ab1186e2159779ff4867644d2b1ebf16db14281f0445377"}, + {file = "greenlet-3.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c42bb589e6e9f9d8bdd79f02f044dff020d30c1afa6e84c0b56d1ce8a324553c"}, + {file = "greenlet-3.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:b2cedf279ca38ef3f4ed0d013a6a84a7fc3d9495a716b84a5fc5ff448965f251"}, + {file = "greenlet-3.0.2-cp37-cp37m-macosx_11_0_universal2.whl", hash = "sha256:6d65bec56a7bc352bcf11b275b838df618651109074d455a772d3afe25390b7d"}, + {file = "greenlet-3.0.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0acadbc3f72cb0ee85070e8d36bd2a4673d2abd10731ee73c10222cf2dd4713c"}, + {file = "greenlet-3.0.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:14b5d999aefe9ffd2049ad19079f733c3aaa426190ffecadb1d5feacef8fe397"}, + {file = "greenlet-3.0.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f27aa32466993c92d326df982c4acccd9530fe354e938d9e9deada563e71ce76"}, + {file = "greenlet-3.0.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f34a765c5170c0673eb747213a0275ecc749ab3652bdbec324621ed5b2edaef"}, + {file = "greenlet-3.0.2-cp37-cp37m-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:520fcb53a39ef90f5021c77606952dbbc1da75d77114d69b8d7bded4a8e1a813"}, + {file = "greenlet-3.0.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d1fceb5351ab1601903e714c3028b37f6ea722be6873f46e349a960156c05650"}, + {file = "greenlet-3.0.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:7363756cc439a503505b67983237d1cc19139b66488263eb19f5719a32597836"}, + {file = "greenlet-3.0.2-cp37-cp37m-win32.whl", hash = "sha256:d5547b462b8099b84746461e882a3eb8a6e3f80be46cb6afb8524eeb191d1a30"}, + {file = "greenlet-3.0.2-cp37-cp37m-win_amd64.whl", hash = "sha256:950e21562818f9c771989b5b65f990e76f4ac27af66e1bb34634ae67886ede2a"}, + {file = "greenlet-3.0.2-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:d64643317e76b4b41fdba659e7eca29634e5739b8bc394eda3a9127f697ed4b0"}, + {file = "greenlet-3.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5f9ea7c2c9795549653b6f7569f6bc75d2c7d1f6b2854eb8ce0bc6ec3cb2dd88"}, + {file = "greenlet-3.0.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:db4233358d3438369051a2f290f1311a360d25c49f255a6c5d10b5bcb3aa2b49"}, + {file = "greenlet-3.0.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ed9bf77b41798e8417657245b9f3649314218a4a17aefb02bb3992862df32495"}, + {file = "greenlet-3.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4d0df07a38e41a10dfb62c6fc75ede196572b580f48ee49b9282c65639f3965"}, + {file = "greenlet-3.0.2-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10d247260db20887ae8857c0cbc750b9170f0b067dd7d38fb68a3f2334393bd3"}, + {file = "greenlet-3.0.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:a37ae53cca36823597fd5f65341b6f7bac2dd69ecd6ca01334bb795460ab150b"}, + {file = "greenlet-3.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:80d068e4b6e2499847d916ef64176811ead6bf210a610859220d537d935ec6fd"}, + {file = "greenlet-3.0.2-cp38-cp38-win32.whl", hash = "sha256:b1405614692ac986490d10d3e1a05e9734f473750d4bee3cf7d1286ef7af7da6"}, + {file = "greenlet-3.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:8756a94ed8f293450b0e91119eca2a36332deba69feb2f9ca410d35e74eae1e4"}, + {file = "greenlet-3.0.2-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:2c93cd03acb1499ee4de675e1a4ed8eaaa7227f7949dc55b37182047b006a7aa"}, + {file = "greenlet-3.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1dac09e3c0b78265d2e6d3cbac2d7c48bd1aa4b04a8ffeda3adde9f1688df2c3"}, + {file = "greenlet-3.0.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2ee59c4627c8c4bb3e15949fbcd499abd6b7f4ad9e0bfcb62c65c5e2cabe0ec4"}, + {file = "greenlet-3.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18fe39d70d482b22f0014e84947c5aaa7211fb8e13dc4cc1c43ed2aa1db06d9a"}, + {file = "greenlet-3.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e84bef3cfb6b6bfe258c98c519811c240dbc5b33a523a14933a252e486797c90"}, + {file = "greenlet-3.0.2-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:aecea0442975741e7d69daff9b13c83caff8c13eeb17485afa65f6360a045765"}, + {file = "greenlet-3.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f260e6c2337871a52161824058923df2bbddb38bc11a5cbe71f3474d877c5bd9"}, + {file = "greenlet-3.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:fc14dd9554f88c9c1fe04771589ae24db76cd56c8f1104e4381b383d6b71aff8"}, + {file = "greenlet-3.0.2-cp39-cp39-win32.whl", hash = "sha256:bfcecc984d60b20ffe30173b03bfe9ba6cb671b0be1e95c3e2056d4fe7006590"}, + {file = "greenlet-3.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:c235131bf59d2546bb3ebaa8d436126267392f2e51b85ff45ac60f3a26549af0"}, + {file = "greenlet-3.0.2.tar.gz", hash = "sha256:1c1129bc47266d83444c85a8e990ae22688cf05fb20d7951fd2866007c2ba9bc"}, ] [package.extras] @@ -2089,21 +2067,15 @@ arrow = ">=0.15.0" [[package]] name = "isort" -version = "5.12.0" +version = "5.13.1" description = "A Python utility / library to sort Python imports." optional = false python-versions = ">=3.8.0" files = [ - {file = "isort-5.12.0-py3-none-any.whl", hash = "sha256:f84c2818376e66cf843d497486ea8fed8700b340f308f076c6fb1229dff318b6"}, - {file = "isort-5.12.0.tar.gz", hash = "sha256:8bef7dde241278824a6d83f44a544709b065191b95b6e50894bdc722fcba0504"}, + {file = "isort-5.13.1-py3-none-any.whl", hash = "sha256:56a51732c25f94ca96f6721be206dd96a95f42950502eb26c1015d333bc6edb7"}, + {file = "isort-5.13.1.tar.gz", hash = "sha256:aaed790b463e8703fb1eddb831dfa8e8616bacde2c083bd557ef73c8189b7263"}, ] -[package.extras] -colors = ["colorama (>=0.4.3)"] -pipfile-deprecated-finder = ["pip-shims (>=0.5.2)", "pipreqs", "requirementslib"] -plugins = ["setuptools"] -requirements-deprecated-finder = ["pip-api", "pipreqs"] - [[package]] name = "jedi" version = "0.19.1" @@ -2453,13 +2425,13 @@ test = ["flaky", "ipykernel", "pre-commit", "pytest (>=7.0)", "pytest-console-sc [[package]] name = "jupyter-server-terminals" -version = "0.4.4" +version = "0.5.0" description = "A Jupyter Server Extension Providing Terminals." optional = false python-versions = ">=3.8" files = [ - {file = "jupyter_server_terminals-0.4.4-py3-none-any.whl", hash = "sha256:75779164661cec02a8758a5311e18bb8eb70c4e86c6b699403100f1585a12a36"}, - {file = "jupyter_server_terminals-0.4.4.tar.gz", hash = "sha256:57ab779797c25a7ba68e97bcfb5d7740f2b5e8a83b5e8102b10438041a7eac5d"}, + {file = "jupyter_server_terminals-0.5.0-py3-none-any.whl", hash = "sha256:2fc0692c883bfd891f4fba0c4b4a684a37234b0ba472f2e97ed0a3888f46e1e4"}, + {file = "jupyter_server_terminals-0.5.0.tar.gz", hash = "sha256:ebcd68c9afbf98a480a533e6f3266354336e645536953b7abcc7bdeebc0154a3"}, ] [package.dependencies] @@ -2467,8 +2439,8 @@ pywinpty = {version = ">=2.0.3", markers = "os_name == \"nt\""} terminado = ">=0.8.3" [package.extras] -docs = ["jinja2", "jupyter-server", "mistune (<3.0)", "myst-parser", "nbformat", "packaging", "pydata-sphinx-theme", "sphinxcontrib-github-alt", "sphinxcontrib-openapi", "sphinxcontrib-spelling", "sphinxemoji", "tornado"] -test = ["coverage", "jupyter-server (>=2.0.0)", "pytest (>=7.0)", "pytest-cov", "pytest-jupyter[server] (>=0.5.3)", "pytest-timeout"] +docs = ["jinja2", "jupyter-server", "mistune (<4.0)", "myst-parser", "nbformat", "packaging", "pydata-sphinx-theme", "sphinxcontrib-github-alt", "sphinxcontrib-openapi", "sphinxcontrib-spelling", "sphinxemoji", "tornado"] +test = ["jupyter-server (>=2.0.0)", "pytest (>=7.0)", "pytest-jupyter[server] (>=0.5.3)", "pytest-timeout"] [[package]] name = "jupyterlab" @@ -3774,17 +3746,17 @@ sympy = "*" [[package]] name = "openai" -version = "1.3.7" +version = "1.3.8" description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-1.3.7-py3-none-any.whl", hash = "sha256:e5c51367a910297e4d1cd33d2298fb87d7edf681edbe012873925ac16f95bee0"}, - {file = "openai-1.3.7.tar.gz", hash = "sha256:18074a0f51f9b49d1ae268c7abc36f7f33212a0c0d08ce11b7053ab2d17798de"}, + {file = "openai-1.3.8-py3-none-any.whl", hash = "sha256:ac5a17352b96db862390d2e6f51de9f7eb32e733f412467b2f160fbd3d0f2609"}, + {file = "openai-1.3.8.tar.gz", hash = "sha256:54963ff247abe185aad6ee443820e48ad9f87eb4de970acb2514bc113ced748c"}, ] [package.dependencies] -anyio = ">=3.5.0,<4" +anyio = ">=3.5.0,<5" distro = ">=1.7.0,<2" httpx = ">=0.23.0,<1" pydantic = ">=1.9.0,<3" @@ -3809,7 +3781,7 @@ files = [ [package.dependencies] coloredlogs = "*" datasets = [ - {version = "*"}, + {version = "*", optional = true, markers = "extra != \"onnxruntime\""}, {version = ">=1.2.1", optional = true, markers = "extra == \"onnxruntime\""}, ] evaluate = {version = "*", optional = true, markers = "extra == \"onnxruntime\""} @@ -3905,7 +3877,7 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, ] python-dateutil = ">=2.8.2" @@ -3963,13 +3935,13 @@ testing = ["docopt", "pytest (<6.0.0)"] [[package]] name = "pathspec" -version = "0.11.2" +version = "0.12.1" description = "Utility library for gitignore style pattern matching of file paths." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "pathspec-0.11.2-py3-none-any.whl", hash = "sha256:1d6ed233af05e679efb96b1851550ea95bbb64b7c490b0f5aa52996c11e92a20"}, - {file = "pathspec-0.11.2.tar.gz", hash = "sha256:e0d8d0ac2f12da61956eb2306b69f9469b42f4deb0f3cb6ed47b9cce9996ced3"}, + {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, + {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, ] [[package]] @@ -6146,7 +6118,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\" or extra == \"asyncio\""} +greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"asyncio\""} typing-extensions = ">=4.2.0" [package.extras] @@ -6725,19 +6697,19 @@ test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0, [[package]] name = "transformers" -version = "4.35.2" +version = "4.36.0" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = true python-versions = ">=3.8.0" files = [ - {file = "transformers-4.35.2-py3-none-any.whl", hash = "sha256:9dfa76f8692379544ead84d98f537be01cd1070de75c74efb13abcbc938fbe2f"}, - {file = "transformers-4.35.2.tar.gz", hash = "sha256:2d125e197d77b0cdb6c9201df9fa7e2101493272e448b9fba9341c695bee2f52"}, + {file = "transformers-4.36.0-py3-none-any.whl", hash = "sha256:e5a9d9424bcbc5008782ddd79ecbc3a50991e168cc730a14c4c89e80c61f419d"}, + {file = "transformers-4.36.0.tar.gz", hash = "sha256:64e120d252db4bdcd355288d19e857dac9d89886f9d0ac20244cb9af3142bb50"}, ] [package.dependencies] -accelerate = {version = ">=0.20.3", optional = true, markers = "extra == \"torch\""} +accelerate = {version = ">=0.21.0", optional = true, markers = "extra == \"torch\""} filelock = "*" -huggingface-hub = ">=0.16.4,<1.0" +huggingface-hub = ">=0.19.3,<1.0" numpy = ">=1.17" packaging = ">=20.0" protobuf = {version = "*", optional = true, markers = "extra == \"sentencepiece\""} @@ -6751,30 +6723,30 @@ torch = {version = ">=1.10,<1.12.0 || >1.12.0", optional = true, markers = "extr tqdm = ">=4.27" [package.extras] -accelerate = ["accelerate (>=0.20.3)"] -agents = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.10,!=1.12.0)"] -all = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] +accelerate = ["accelerate (>=0.21.0)"] +agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.10,!=1.12.0)"] +all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] codecarbon = ["codecarbon (==1.2.0)"] -deepspeed = ["accelerate (>=0.20.3)", "deepspeed (>=0.9.3)"] -deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.20.3)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] -dev = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.14,<0.19)", "urllib3 (<2.0.0)"] -dev-torch = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "accelerate (>=0.20.3)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -docs = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] +deepspeed = ["accelerate (>=0.21.0)", "deepspeed (>=0.9.3)"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.14,<0.19)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +docs = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] docs-specific = ["hf-doc-builder"] flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)"] flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] ftfy = ["ftfy"] -integrations = ["optuna", "ray[tune]", "sigopt"] +integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"] ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"] modelcreation = ["cookiecutter (==1.7.3)"] natten = ["natten (>=0.14.6)"] onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] optuna = ["optuna"] -quality = ["GitPython (<3.1.19)", "black (>=23.1,<24.0)", "datasets (!=2.5.0)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (>=0.0.241,<=0.0.259)", "urllib3 (<2.0.0)"] -ray = ["ray[tune]"] +quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<2.0.0)"] +ray = ["ray[tune] (>=2.7.0)"] retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] sagemaker = ["sagemaker (>=2.31.0)"] sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"] @@ -6782,18 +6754,18 @@ serving = ["fastapi", "pydantic (<2)", "starlette", "uvicorn"] sigopt = ["sigopt"] sklearn = ["scikit-learn"] speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -testing = ["GitPython (<3.1.19)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "tensorboard", "timeout-decorator"] -tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx"] -tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf", "psutil", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "tensorboard", "timeout-decorator"] +tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] +tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] timm = ["timm"] tokenizers = ["tokenizers (>=0.14,<0.19)"] -torch = ["accelerate (>=0.20.3)", "torch (>=1.10,!=1.12.0)"] +torch = ["accelerate (>=0.21.0)", "torch (>=1.10,!=1.12.0)"] torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -torch-vision = ["Pillow (<10.0.0)", "torchvision"] -torchhub = ["filelock", "huggingface-hub (>=0.16.4,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "tqdm (>=4.27)"] +torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] +torchhub = ["filelock", "huggingface-hub (>=0.19.3,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "tqdm (>=4.27)"] video = ["av (==9.2.0)", "decord (==0.6.0)"] -vision = ["Pillow (<10.0.0)"] +vision = ["Pillow (>=10.0.1,<=15.0)"] [[package]] name = "tree-sitter" @@ -7021,6 +6993,17 @@ files = [ {file = "types_docutils-0.20.0.3-py3-none-any.whl", hash = "sha256:a930150d8e01a9170f9bca489f46808ddebccdd8bc1e47c07968a77e49fb9321"}, ] +[[package]] +name = "types-protobuf" +version = "4.24.0.4" +description = "Typing stubs for protobuf" +optional = false +python-versions = ">=3.7" +files = [ + {file = "types-protobuf-4.24.0.4.tar.gz", hash = "sha256:57ab42cb171dfdba2c74bb5b50c250478538cc3c5ed95b8b368929ad0c9f90a5"}, + {file = "types_protobuf-4.24.0.4-py3-none-any.whl", hash = "sha256:131ab7d0cbc9e444bc89c994141327dcce7bcaeded72b1acb72a94827eb9c7af"}, +] + [[package]] name = "types-pyopenssl" version = "23.3.0.0" @@ -7113,13 +7096,13 @@ files = [ [[package]] name = "typing-extensions" -version = "4.8.0" +version = "4.9.0" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.8.0-py3-none-any.whl", hash = "sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0"}, - {file = "typing_extensions-4.8.0.tar.gz", hash = "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef"}, + {file = "typing_extensions-4.9.0-py3-none-any.whl", hash = "sha256:af72aea155e91adfc61c3ae9e0e342dbc0cba726d6cba4b6c72c1f34e47291cd"}, + {file = "typing_extensions-4.9.0.tar.gz", hash = "sha256:23478f88c37f27d76ac8aee6c905017a143b0b1b886c3c9f66bc2fd94f9f5783"}, ] [[package]] @@ -7653,4 +7636,4 @@ query-tools = ["guidance", "jsonpath-ng", "lm-format-enforcer", "rank-bm25", "sc [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "59333a1ca7794f76ec07498891f2adeadf8536659538096b188e7a142daf1884" +content-hash = "46961671af5ebcf0f66093857b641a09daa501b3751f60cf5bdf161ccb3dfedd" diff --git a/pyproject.toml b/pyproject.toml index 3aae2033a8..1c1f982d3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ rank-bm25 = {optional = true, version = "^0.2.2"} scikit-learn = {optional = true, version = "*"} spacy = {optional = true, version = "^3.7.1"} aiohttp = "^3.8.6" +types-protobuf = "^4.24.0.4" [tool.poetry.extras] langchain = [ @@ -103,7 +104,7 @@ query_tools = [ black = {extras = ["jupyter"], version = "<=23.9.1,>=23.7.0"} boto3 = "1.33.6" # needed for tests codespell = {extras = ["toml"], version = ">=v2.2.6"} -google-generativeai = {python = ">=3.9,<3.12", version = "^0.2.1"} +google-ai-generativelanguage = {python = ">=3.9,<3.12", version = "^0.4.0"} ipython = "8.10.0" jupyter = "^1.0.0" mypy = "0.991" diff --git a/tests/indices/managed/__init__.py b/tests/indices/managed/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/indices/managed/test_google.py b/tests/indices/managed/test_google.py new file mode 100644 index 0000000000..19146da69e --- /dev/null +++ b/tests/indices/managed/test_google.py @@ -0,0 +1,206 @@ +from unittest.mock import MagicMock, patch + +import pytest +from llama_index.response.schema import Response +from llama_index.schema import Document + +try: + import google.ai.generativelanguage as genai + + has_google = True +except ImportError: + has_google = False + +from llama_index.indices.managed.google.generativeai import GoogleIndex + +SKIP_TEST_REASON = "Google GenerativeAI is not installed" + + +if has_google: + import llama_index.vector_stores.google.generativeai.genai_extension as genaix + + genaix.set_defaults( + genaix.Config(api_endpoint="No-such-endpoint-to-prevent-hitting-real-backend") + ) + + +@pytest.mark.skipif(not has_google, reason=SKIP_TEST_REASON) +@patch("google.ai.generativelanguage.RetrieverServiceClient.get_corpus") +def test_from_corpus(mock_get_corpus: MagicMock) -> None: + # Arrange + mock_get_corpus.return_value = genai.Corpus(name="corpora/123") + + # Act + store = GoogleIndex.from_corpus(corpus_id="123") + + # Assert + assert store.corpus_id == "123" + + +@pytest.mark.skipif(not has_google, reason=SKIP_TEST_REASON) +@patch("google.ai.generativelanguage.RetrieverServiceClient.create_corpus") +def test_create_corpus(mock_create_corpus: MagicMock) -> None: + def fake_create_corpus(request: genai.CreateCorpusRequest) -> genai.Corpus: + return request.corpus + + # Arrange + mock_create_corpus.side_effect = fake_create_corpus + + # Act + store = GoogleIndex.create_corpus(display_name="My first corpus") + + # Assert + assert len(store.corpus_id) > 0 + assert mock_create_corpus.call_count == 1 + + request = mock_create_corpus.call_args.args[0] + assert request.corpus.name == f"corpora/{store.corpus_id}" + assert request.corpus.display_name == "My first corpus" + + +@pytest.mark.skipif(not has_google, reason=SKIP_TEST_REASON) +@patch("google.ai.generativelanguage.RetrieverServiceClient.create_corpus") +@patch("google.ai.generativelanguage.RetrieverServiceClient.create_document") +@patch("google.ai.generativelanguage.RetrieverServiceClient.batch_create_chunks") +@patch("google.ai.generativelanguage.RetrieverServiceClient.get_document") +def test_from_documents( + mock_get_document: MagicMock, + mock_batch_create_chunk: MagicMock, + mock_create_document: MagicMock, + mock_create_corpus: MagicMock, +) -> None: + from google.api_core import exceptions as gapi_exception + + def fake_create_corpus(request: genai.CreateCorpusRequest) -> genai.Corpus: + return request.corpus + + # Arrange + mock_get_document.side_effect = gapi_exception.NotFound("") + mock_create_corpus.side_effect = fake_create_corpus + mock_create_document.return_value = genai.Document(name="corpora/123/documents/456") + mock_batch_create_chunk.side_effect = [ + genai.BatchCreateChunksResponse( + chunks=[ + genai.Chunk(name="corpora/123/documents/456/chunks/777"), + ] + ), + genai.BatchCreateChunksResponse( + chunks=[ + genai.Chunk(name="corpora/123/documents/456/chunks/888"), + ] + ), + ] + + # Act + index = GoogleIndex.from_documents( + [ + Document(text="Hello, my darling"), + Document(text="Goodbye, my baby"), + ] + ) + + # Assert + assert mock_create_corpus.call_count == 1 + create_corpus_request = mock_create_corpus.call_args.args[0] + assert create_corpus_request.corpus.name == f"corpora/{index.corpus_id}" + + create_document_request = mock_create_document.call_args.args[0] + assert create_document_request.parent == f"corpora/{index.corpus_id}" + + assert mock_batch_create_chunk.call_count == 2 + + first_batch_request = mock_batch_create_chunk.call_args_list[0].args[0] + assert ( + first_batch_request.requests[0].chunk.data.string_value == "Hello, my darling" + ) + + second_batch_request = mock_batch_create_chunk.call_args_list[1].args[0] + assert ( + second_batch_request.requests[0].chunk.data.string_value == "Goodbye, my baby" + ) + + +@pytest.mark.skipif(not has_google, reason=SKIP_TEST_REASON) +@patch("google.ai.generativelanguage.RetrieverServiceClient.query_corpus") +@patch("google.ai.generativelanguage.GenerativeServiceClient.generate_answer") +@patch("google.ai.generativelanguage.RetrieverServiceClient.get_corpus") +def test_as_query_engine( + mock_get_corpus: MagicMock, + mock_generate_answer: MagicMock, + mock_query_corpus: MagicMock, +) -> None: + # Arrange + mock_get_corpus.return_value = genai.Corpus(name="corpora/123") + mock_query_corpus.return_value = genai.QueryCorpusResponse( + relevant_chunks=[ + genai.RelevantChunk( + chunk=genai.Chunk( + name="corpora/123/documents/456/chunks/789", + data=genai.ChunkData(string_value="It's 42"), + ), + chunk_relevance_score=0.9, + ) + ] + ) + mock_generate_answer.return_value = genai.GenerateAnswerResponse( + answer=genai.Candidate( + content=genai.Content(parts=[genai.Part(text="42")]), + grounding_attributions=[ + genai.GroundingAttribution( + content=genai.Content( + parts=[genai.Part(text="Meaning of life is 42")] + ), + source_id=genai.AttributionSourceId( + grounding_passage=genai.AttributionSourceId.GroundingPassageId( + passage_id="corpora/123/documents/456/chunks/777", + part_index=0, + ) + ), + ), + genai.GroundingAttribution( + content=genai.Content(parts=[genai.Part(text="Or maybe not")]), + source_id=genai.AttributionSourceId( + grounding_passage=genai.AttributionSourceId.GroundingPassageId( + passage_id="corpora/123/documents/456/chunks/888", + part_index=0, + ) + ), + ), + ], + finish_reason=genai.Candidate.FinishReason.STOP, + ), + answerable_probability=0.9, + ) + + # Act + index = GoogleIndex.from_corpus(corpus_id="123") + query_engine = index.as_query_engine( + answer_style=genai.GenerateAnswerRequest.AnswerStyle.EXTRACTIVE + ) + response = query_engine.query("What is the meaning of life?") + + # Assert + assert mock_query_corpus.call_count == 1 + query_corpus_request = mock_query_corpus.call_args.args[0] + assert query_corpus_request.name == "corpora/123" + assert query_corpus_request.query == "What is the meaning of life?" + + assert isinstance(response, Response) + + assert response.response == "42" + + assert mock_generate_answer.call_count == 1 + generate_answer_request = mock_generate_answer.call_args.args[0] + assert ( + generate_answer_request.contents[0].parts[0].text + == "What is the meaning of life?" + ) + assert ( + generate_answer_request.answer_style + == genai.GenerateAnswerRequest.AnswerStyle.EXTRACTIVE + ) + + passages = generate_answer_request.inline_passages.passages + assert len(passages) == 1 + passage = passages[0] + assert passage.content.parts[0].text == "It's 42" diff --git a/tests/response_synthesizers/test_google.py b/tests/response_synthesizers/test_google.py new file mode 100644 index 0000000000..6c8b6bc194 --- /dev/null +++ b/tests/response_synthesizers/test_google.py @@ -0,0 +1,176 @@ +from unittest.mock import MagicMock, patch + +import pytest + +try: + import google.ai.generativelanguage as genai + + has_google = True +except ImportError: + has_google = False + +from llama_index.response_synthesizers.google.generativeai import ( + GoogleTextSynthesizer, +) +from llama_index.schema import NodeWithScore, TextNode + +SKIP_TEST_REASON = "Google GenerativeAI is not installed" + + +if has_google: + import llama_index.vector_stores.google.generativeai.genai_extension as genaix + + genaix.set_defaults( + genaix.Config( + api_endpoint="No-such-endpoint-to-prevent-hitting-real-backend", + testing=True, + ) + ) + + +@pytest.mark.skipif(not has_google, reason=SKIP_TEST_REASON) +@patch("google.ai.generativelanguage.GenerativeServiceClient.generate_answer") +def test_get_response(mock_generate_answer: MagicMock) -> None: + # Arrange + mock_generate_answer.return_value = genai.GenerateAnswerResponse( + answer=genai.Candidate( + content=genai.Content(parts=[genai.Part(text="42")]), + grounding_attributions=[ + genai.GroundingAttribution( + content=genai.Content( + parts=[genai.Part(text="Meaning of life is 42.")] + ), + source_id=genai.AttributionSourceId( + grounding_passage=genai.AttributionSourceId.GroundingPassageId( + passage_id="corpora/123/documents/456/chunks/789", + part_index=0, + ) + ), + ), + ], + finish_reason=genai.Candidate.FinishReason.STOP, + ), + answerable_probability=0.7, + ) + + # Act + synthesizer = GoogleTextSynthesizer.from_defaults( + temperature=0.5, + answer_style=genai.GenerateAnswerRequest.AnswerStyle.ABSTRACTIVE, + safety_setting=[ + genai.SafetySetting( + category=genai.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=genai.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + ) + ], + ) + response = synthesizer.get_response( + query_str="What is the meaning of life?", + text_chunks=[ + "It's 42", + ], + ) + + # Assert + assert response.answer == "42" + assert response.attributed_passages == ["Meaning of life is 42."] + assert response.answerable_probability == pytest.approx(0.7) + + assert mock_generate_answer.call_count == 1 + request = mock_generate_answer.call_args.args[0] + assert request.contents[0].parts[0].text == "What is the meaning of life?" + + assert request.answer_style == genai.GenerateAnswerRequest.AnswerStyle.ABSTRACTIVE + + assert len(request.safety_settings) == 1 + assert ( + request.safety_settings[0].category + == genai.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT + ) + assert ( + request.safety_settings[0].threshold + == genai.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE + ) + + assert request.temperature == 0.5 + + passages = request.inline_passages.passages + assert len(passages) == 1 + passage = passages[0] + assert passage.content.parts[0].text == "It's 42" + + +@pytest.mark.skipif(not has_google, reason=SKIP_TEST_REASON) +@patch("google.ai.generativelanguage.GenerativeServiceClient.generate_answer") +def test_synthesize(mock_generate_answer: MagicMock) -> None: + # Arrange + mock_generate_answer.return_value = genai.GenerateAnswerResponse( + answer=genai.Candidate( + content=genai.Content(parts=[genai.Part(text="42")]), + grounding_attributions=[ + genai.GroundingAttribution( + content=genai.Content( + parts=[genai.Part(text="Meaning of life is 42")] + ), + source_id=genai.AttributionSourceId( + grounding_passage=genai.AttributionSourceId.GroundingPassageId( + passage_id="corpora/123/documents/456/chunks/777", + part_index=0, + ) + ), + ), + genai.GroundingAttribution( + content=genai.Content(parts=[genai.Part(text="Or maybe not")]), + source_id=genai.AttributionSourceId( + grounding_passage=genai.AttributionSourceId.GroundingPassageId( + passage_id="corpora/123/documents/456/chunks/888", + part_index=0, + ) + ), + ), + ], + finish_reason=genai.Candidate.FinishReason.STOP, + ), + answerable_probability=0.9, + ) + + # Act + synthesizer = GoogleTextSynthesizer.from_defaults() + response = synthesizer.synthesize( + query="What is the meaning of life?", + nodes=[ + NodeWithScore( + node=TextNode(text="It's 42"), + score=0.5, + ), + ], + additional_source_nodes=[ + NodeWithScore( + node=TextNode(text="Additional node"), + score=0.4, + ), + ], + ) + + # Assert + assert response.response == "42" + assert len(response.source_nodes) == 4 + + first_attributed_source = response.source_nodes[0] + assert first_attributed_source.node.text == "Meaning of life is 42" + assert first_attributed_source.score is None + + second_attributed_source = response.source_nodes[1] + assert second_attributed_source.node.text == "Or maybe not" + assert second_attributed_source.score is None + + first_input_source = response.source_nodes[2] + assert first_input_source.node.text == "It's 42" + assert first_input_source.score == pytest.approx(0.5) + + first_additional_source = response.source_nodes[3] + assert first_additional_source.node.text == "Additional node" + assert first_additional_source.score == pytest.approx(0.4) + + assert response.metadata is not None + assert response.metadata.get("answerable_probability", None) == pytest.approx(0.9) diff --git a/tests/vector_stores/test_google.py b/tests/vector_stores/test_google.py new file mode 100644 index 0000000000..6152a7f3b7 --- /dev/null +++ b/tests/vector_stores/test_google.py @@ -0,0 +1,306 @@ +from unittest.mock import MagicMock, patch + +import pytest +from llama_index.schema import NodeRelationship, RelatedNodeInfo, TextNode +from llama_index.vector_stores.types import ( + ExactMatchFilter, + MetadataFilters, + VectorStoreQuery, +) + +try: + import google.ai.generativelanguage as genai + + has_google = True +except ImportError: + has_google = False + +from llama_index.vector_stores.google.generativeai import ( + GoogleVectorStore, +) + +SKIP_TEST_REASON = "Google GenerativeAI is not installed" + + +if has_google: + import llama_index.vector_stores.google.generativeai.genai_extension as genaix + + # Make sure the tests do not hit actual production servers. + genaix.set_defaults( + genaix.Config( + api_endpoint="No-such-endpoint-to-prevent-hitting-real-backend", + testing=True, + ) + ) + + +@pytest.mark.skipif(not has_google, reason=SKIP_TEST_REASON) +@patch("google.ai.generativelanguage.RetrieverServiceClient.create_corpus") +def test_create_corpus(mock_create_corpus: MagicMock) -> None: + def fake_create_corpus(request: genai.CreateCorpusRequest) -> genai.Corpus: + return request.corpus + + # Arrange + mock_create_corpus.side_effect = fake_create_corpus + + # Act + store = GoogleVectorStore.create_corpus(display_name="My first corpus") + + # Assert + assert len(store.corpus_id) > 0 + assert mock_create_corpus.call_count == 1 + + request = mock_create_corpus.call_args.args[0] + assert request.corpus.name == f"corpora/{store.corpus_id}" + assert request.corpus.display_name == "My first corpus" + + +@pytest.mark.skipif(not has_google, reason=SKIP_TEST_REASON) +@patch("google.ai.generativelanguage.RetrieverServiceClient.get_corpus") +def test_from_corpus(mock_get_corpus: MagicMock) -> None: + # Arrange + mock_get_corpus.return_value = genai.Corpus(name="corpora/123") + + # Act + store = GoogleVectorStore.from_corpus(corpus_id="123") + + # Assert + assert store.corpus_id == "123" + + +@pytest.mark.skipif(not has_google, reason=SKIP_TEST_REASON) +def test_class_name() -> None: + # Act + class_name = GoogleVectorStore.class_name() + + # Assert + assert class_name == "GoogleVectorStore" + + +@pytest.mark.skipif(not has_google, reason=SKIP_TEST_REASON) +@patch("google.ai.generativelanguage.RetrieverServiceClient.batch_create_chunks") +@patch("google.ai.generativelanguage.RetrieverServiceClient.create_document") +@patch("google.ai.generativelanguage.RetrieverServiceClient.get_document") +@patch("google.ai.generativelanguage.RetrieverServiceClient.get_corpus") +def test_add( + mock_get_corpus: MagicMock, + mock_get_document: MagicMock, + mock_create_document: MagicMock, + mock_batch_create_chunks: MagicMock, +) -> None: + from google.api_core import exceptions as gapi_exception + + # Arrange + # We will use a max requests per batch to be 2. + # Then, we send 3 requests. + # We expect to have 2 batches where the last batch has only 1 request. + genaix._MAX_REQUEST_PER_CHUNK = 2 + mock_get_corpus.return_value = genai.Corpus(name="corpora/123") + mock_get_document.side_effect = gapi_exception.NotFound("") + mock_create_document.return_value = genai.Document(name="corpora/123/documents/456") + mock_batch_create_chunks.side_effect = [ + genai.BatchCreateChunksResponse( + chunks=[ + genai.Chunk(name="corpora/123/documents/456/chunks/777"), + genai.Chunk(name="corpora/123/documents/456/chunks/888"), + ] + ), + genai.BatchCreateChunksResponse( + chunks=[ + genai.Chunk(name="corpora/123/documents/456/chunks/999"), + ] + ), + ] + + # Act + store = GoogleVectorStore.from_corpus(corpus_id="123") + response = store.add( + [ + TextNode( + text="Hello my baby", + relationships={ + NodeRelationship.SOURCE: RelatedNodeInfo( + node_id="456", + metadata={"file_name": "Title for doc 456"}, + ) + }, + metadata={"position": 100}, + ), + TextNode( + text="Hello my honey", + relationships={ + NodeRelationship.SOURCE: RelatedNodeInfo( + node_id="456", + metadata={"file_name": "Title for doc 456"}, + ) + }, + metadata={"position": 200}, + ), + TextNode( + text="Hello my ragtime gal", + relationships={ + NodeRelationship.SOURCE: RelatedNodeInfo( + node_id="456", + metadata={"file_name": "Title for doc 456"}, + ) + }, + metadata={"position": 300}, + ), + ] + ) + + # Assert + assert response == [ + "corpora/123/documents/456/chunks/777", + "corpora/123/documents/456/chunks/888", + "corpora/123/documents/456/chunks/999", + ] + + create_document_request = mock_create_document.call_args.args[0] + assert create_document_request == genai.CreateDocumentRequest( + parent="corpora/123", + document=genai.Document( + name="corpora/123/documents/456", + display_name="Title for doc 456", + custom_metadata=[ + genai.CustomMetadata( + key="file_name", + string_value="Title for doc 456", + ), + ], + ), + ) + + assert mock_batch_create_chunks.call_count == 2 + mock_batch_create_chunks_calls = mock_batch_create_chunks.call_args_list + + first_batch_create_chunks_request = mock_batch_create_chunks_calls[0].args[0] + assert first_batch_create_chunks_request == genai.BatchCreateChunksRequest( + parent="corpora/123/documents/456", + requests=[ + genai.CreateChunkRequest( + parent="corpora/123/documents/456", + chunk=genai.Chunk( + data=genai.ChunkData(string_value="Hello my baby"), + custom_metadata=[ + genai.CustomMetadata( + key="position", + numeric_value=100, + ), + ], + ), + ), + genai.CreateChunkRequest( + parent="corpora/123/documents/456", + chunk=genai.Chunk( + data=genai.ChunkData(string_value="Hello my honey"), + custom_metadata=[ + genai.CustomMetadata( + key="position", + numeric_value=200, + ), + ], + ), + ), + ], + ) + + second_batch_create_chunks_request = mock_batch_create_chunks_calls[1].args[0] + assert second_batch_create_chunks_request == genai.BatchCreateChunksRequest( + parent="corpora/123/documents/456", + requests=[ + genai.CreateChunkRequest( + parent="corpora/123/documents/456", + chunk=genai.Chunk( + data=genai.ChunkData(string_value="Hello my ragtime gal"), + custom_metadata=[ + genai.CustomMetadata( + key="position", + numeric_value=300, + ), + ], + ), + ), + ], + ) + + +@pytest.mark.skipif(not has_google, reason=SKIP_TEST_REASON) +@patch("google.ai.generativelanguage.RetrieverServiceClient.delete_document") +@patch("google.ai.generativelanguage.RetrieverServiceClient.get_corpus") +def test_delete( + mock_get_corpus: MagicMock, + mock_delete_document: MagicMock, +) -> None: + # Arrange + mock_get_corpus.return_value = genai.Corpus(name="corpora/123") + + # Act + store = GoogleVectorStore.from_corpus(corpus_id="123") + store.delete(ref_doc_id="doc-456") + + # Assert + delete_document_request = mock_delete_document.call_args.args[0] + assert delete_document_request == genai.DeleteDocumentRequest( + name="corpora/123/documents/doc-456", + force=True, + ) + + +@pytest.mark.skipif(not has_google, reason=SKIP_TEST_REASON) +@patch("google.ai.generativelanguage.RetrieverServiceClient.query_corpus") +@patch("google.ai.generativelanguage.RetrieverServiceClient.get_corpus") +def test_query( + mock_get_corpus: MagicMock, + mock_query_corpus: MagicMock, +) -> None: + # Arrange + mock_get_corpus.return_value = genai.Corpus(name="corpora/123") + mock_query_corpus.return_value = genai.QueryCorpusResponse( + relevant_chunks=[ + genai.RelevantChunk( + chunk=genai.Chunk( + name="corpora/123/documents/456/chunks/789", + data=genai.ChunkData(string_value="42"), + ), + chunk_relevance_score=0.9, + ) + ] + ) + + # Act + store = GoogleVectorStore.from_corpus(corpus_id="123") + store.query( + query=VectorStoreQuery( + query_str="What is the meaning of life?", + filters=MetadataFilters( + filters=[ + ExactMatchFilter( + key="author", + value="Arthur Schopenhauer", + ) + ] + ), + similarity_top_k=1, + ) + ) + + # Assert + assert mock_query_corpus.call_count == 1 + query_corpus_request = mock_query_corpus.call_args.args[0] + assert query_corpus_request == genai.QueryCorpusRequest( + name="corpora/123", + query="What is the meaning of life?", + metadata_filters=[ + genai.MetadataFilter( + key="author", + conditions=[ + genai.Condition( + operation=genai.Condition.Operator.EQUAL, + string_value="Arthur Schopenhauer", + ) + ], + ) + ], + results_count=1, + ) -- GitLab