diff --git a/experimental/splitter_playground/app.py b/experimental/splitter_playground/app.py index d10eb03b06b140a93ec7a47aff625076b3f4e26b..6317101247baed77ec6d2e3e508ed1d3bb588c7a 100644 --- a/experimental/splitter_playground/app.py +++ b/experimental/splitter_playground/app.py @@ -1,12 +1,13 @@ import os import tempfile -from typing import List +from typing import List, Union import streamlit as st import tiktoken from langchain.text_splitter import ( CharacterTextSplitter, RecursiveCharacterTextSplitter, + TextSplitter as LCSplitter, ) from langchain.text_splitter import TokenTextSplitter as LCTokenTextSplitter from streamlit.runtime.uploaded_file_manager import UploadedFile @@ -76,7 +77,7 @@ for ind, col in enumerate(cols): key=f"splitter_cls_{ind}", ) - text_splitter: TextSplitter + text_splitter: Union[TextSplitter, LCSplitter] if text_splitter_cls == "TokenTextSplitter": text_splitter = TokenTextSplitter( chunk_size=chunk_size, chunk_overlap=chunk_overlap diff --git a/llama_index/indices/prompt_helper.py b/llama_index/indices/prompt_helper.py index e3c0ce3fe8875b7d654debec87859454f11c9a4a..14b44b7244a27e35b663f68e5ef2e58f6f04d4de 100644 --- a/llama_index/indices/prompt_helper.py +++ b/llama_index/indices/prompt_helper.py @@ -9,6 +9,7 @@ needed), or truncating them so that they fit in a single LLM call. """ import logging +from pydantic import BaseModel, Field, PrivateAttr from typing import Callable, List, Optional, Sequence from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS @@ -25,7 +26,7 @@ DEFAULT_CHUNK_OVERLAP_RATIO = 0.1 logger = logging.getLogger(__name__) -class PromptHelper: +class PromptHelper(BaseModel): """Prompt helper. General prompt helper that can help deal with LLM context window token limitations. @@ -48,6 +49,25 @@ class PromptHelper: """ + context_window: int = Field( + default=DEFAULT_CONTEXT_WINDOW, + description="The maximum context size that will get sent to the LLM.", + ) + num_output: int = Field( + default=DEFAULT_NUM_OUTPUTS, + description="The amount of token-space to leave in input for generation.", + ) + chunk_overlap_ratio: float = Field( + default=DEFAULT_CHUNK_OVERLAP_RATIO, + description="The percentage token amount that each chunk should overlap.", + ) + chunk_size_limit: Optional[int] = Field(description="The maximum size of a chunk.") + separator: str = Field( + default=" ", description="The separator when chunking tokens." + ) + + _tokenizer: Callable[[str], List] = PrivateAttr() + def __init__( self, context_window: int = DEFAULT_CONTEXT_WINDOW, @@ -58,17 +78,19 @@ class PromptHelper: separator: str = " ", ) -> None: """Init params.""" - self.context_window = context_window - self.num_output = num_output - - self.chunk_overlap_ratio = chunk_overlap_ratio - if self.chunk_overlap_ratio > 1.0 or self.chunk_overlap_ratio < 0.0: + if chunk_overlap_ratio > 1.0 or chunk_overlap_ratio < 0.0: raise ValueError("chunk_overlap_ratio must be a float between 0. 
and 1.") - self.chunk_size_limit = chunk_size_limit # TODO: make configurable self._tokenizer = tokenizer or globals_helper.tokenizer - self._separator = separator + + super().__init__( + context_window=context_window, + num_output=num_output, + chunk_overlap_ratio=chunk_overlap_ratio, + chunk_size_limit=chunk_size_limit, + separator=separator, + ) @classmethod def from_llm_metadata( @@ -151,7 +173,7 @@ class PromptHelper: raise ValueError("Got 0 as available chunk size.") chunk_overlap = int(self.chunk_overlap_ratio * chunk_size) text_splitter = TokenTextSplitter( - separator=self._separator, + separator=self.separator, chunk_size=chunk_size, chunk_overlap=chunk_overlap, tokenizer=self._tokenizer, diff --git a/llama_index/indices/service_context.py b/llama_index/indices/service_context.py index 01a4b68db85d610b3a75c6549307c4106bb609cb..496a09077fb1302a60e62a9d41fc27ca28213c0e 100644 --- a/llama_index/indices/service_context.py +++ b/llama_index/indices/service_context.py @@ -13,7 +13,7 @@ from llama_index.llms.utils import LLMType, resolve_llm from llama_index.logger import LlamaLogger from llama_index.node_parser.interface import NodeParser from llama_index.node_parser.simple import SimpleNodeParser -from llama_index.prompts.prompts import SimpleInputPrompt +from llama_index.prompts.prompts import Prompt from llama_index.embeddings.utils import resolve_embed_model, EmbedType logger = logging.getLogger(__name__) @@ -78,7 +78,7 @@ class ServiceContext: llama_logger: Optional[LlamaLogger] = None, callback_manager: Optional[CallbackManager] = None, system_prompt: Optional[str] = None, - query_wrapper_prompt: Optional[SimpleInputPrompt] = None, + query_wrapper_prompt: Optional[Prompt] = None, # node parser kwargs chunk_size: Optional[int] = None, chunk_overlap: Optional[int] = None, @@ -184,7 +184,7 @@ class ServiceContext: llama_logger: Optional[LlamaLogger] = None, callback_manager: Optional[CallbackManager] = None, system_prompt: Optional[str] = None, - query_wrapper_prompt: Optional[SimpleInputPrompt] = None, + query_wrapper_prompt: Optional[Prompt] = None, # node parser kwargs chunk_size: Optional[int] = None, chunk_overlap: Optional[int] = None, diff --git a/llama_index/llm_predictor/base.py b/llama_index/llm_predictor/base.py index c45891b108081e5af9f9bb3f848b32ca2117e7a7..d65cbb7cb2c4a87dd0aec626d98e71771d463347 100644 --- a/llama_index/llm_predictor/base.py +++ b/llama_index/llm_predictor/base.py @@ -1,8 +1,9 @@ """Wrapper functions around an LLM chain.""" import logging -from abc import abstractmethod -from typing import Any, List, Optional, Protocol, runtime_checkable +from abc import abstractmethod, ABC +from pydantic import BaseModel, PrivateAttr +from typing import Any, List, Optional from llama_index.callbacks.base import CallbackManager from llama_index.llm_predictor.utils import ( @@ -21,8 +22,7 @@ from llama_index.types import TokenAsyncGen, TokenGen logger = logging.getLogger(__name__) -@runtime_checkable -class BaseLLMPredictor(Protocol): +class BaseLLMPredictor(BaseModel, ABC): """Base LLM Predictor.""" @property @@ -63,6 +63,13 @@ class LLMPredictor(BaseLLMPredictor): deprecate this class and move all functionality into the LLM class. 
""" + class Config: + arbitrary_types_allowed = True + + system_prompt: Optional[str] + query_wrapper_prompt: Optional[Prompt] + _llm: LLM = PrivateAttr() + def __init__( self, llm: Optional[LLMType] = None, @@ -76,8 +83,9 @@ class LLMPredictor(BaseLLMPredictor): if callback_manager: self._llm.callback_manager = callback_manager - self.system_prompt = system_prompt - self.query_wrapper_prompt = query_wrapper_prompt + super().__init__( + system_prompt=system_prompt, query_wrapper_prompt=query_wrapper_prompt + ) @property def llm(self) -> LLM: diff --git a/llama_index/llm_predictor/mock.py b/llama_index/llm_predictor/mock.py index 1bf80fadc46be6e0bd7d75fcc6106afd632a530c..37e442ce80f2aabbf47459015b79911f3ae3b828 100644 --- a/llama_index/llm_predictor/mock.py +++ b/llama_index/llm_predictor/mock.py @@ -1,9 +1,7 @@ """Mock LLM Predictor.""" - +from pydantic import Field from typing import Any, Dict -from llama_index.callbacks.base import CallbackManager -from llama_index.callbacks.schema import CBEventType, EventPayload from llama_index.constants import DEFAULT_NUM_OUTPUTS from llama_index.llm_predictor.base import BaseLLMPredictor from llama_index.llms.base import LLMMetadata, LLM @@ -14,7 +12,7 @@ from llama_index.token_counter.utils import ( mock_extract_kg_triplets_response, ) from llama_index.types import TokenAsyncGen, TokenGen -from llama_index.utils import count_tokens, globals_helper +from llama_index.utils import globals_helper # TODO: consolidate with unit tests in tests/mock_utils/mock_predict.py @@ -86,10 +84,9 @@ def _mock_knowledge_graph_triplet_extract(prompt_args: Dict, max_triplets: int) class MockLLMPredictor(BaseLLMPredictor): """Mock LLM Predictor.""" - def __init__(self, max_tokens: int = DEFAULT_NUM_OUTPUTS) -> None: - """Initialize params.""" - self.max_tokens = max_tokens - self.callback_manager = CallbackManager([]) + max_tokens: int = Field( + default=DEFAULT_NUM_OUTPUTS, description="Number of tokens to mock generate." 
+ ) @property def metadata(self) -> LLMMetadata: @@ -99,38 +96,8 @@ class MockLLMPredictor(BaseLLMPredictor): def llm(self) -> LLM: raise NotImplementedError("MockLLMPredictor does not have an LLM model.") - def _log_start(self, prompt: Prompt, prompt_args: dict) -> str: - """Log start of an LLM event.""" - llm_payload = prompt_args.copy() - llm_payload[EventPayload.TEMPLATE] = prompt - event_id = self.callback_manager.on_event_start( - CBEventType.LLM, - payload=llm_payload, - ) - - return event_id - - def _log_end(self, event_id: str, output: str, formatted_prompt: str) -> None: - """Log end of an LLM event.""" - prompt_tokens_count = count_tokens(formatted_prompt) - prediction_tokens_count = count_tokens(output) - self.callback_manager.on_event_end( - CBEventType.LLM, - payload={ - EventPayload.RESPONSE: output, - EventPayload.PROMPT: formatted_prompt, - # deprecated - "formatted_prompt_tokens_count": prompt_tokens_count, - "prediction_tokens_count": prediction_tokens_count, - "total_tokens_used": prompt_tokens_count + prediction_tokens_count, - }, - event_id=event_id, - ) - def predict(self, prompt: Prompt, **prompt_args: Any) -> str: """Mock predict.""" - event_id = self._log_start(prompt, prompt_args) - formatted_prompt = prompt.format(**prompt_args) prompt_str = prompt.prompt_type if prompt_str == PromptType.SUMMARY: @@ -159,7 +126,6 @@ class MockLLMPredictor(BaseLLMPredictor): else: raise ValueError("Invalid prompt type.") - self._log_end(event_id, output, formatted_prompt) return output def stream(self, prompt: Prompt, **prompt_args: Any) -> TokenGen: diff --git a/llama_index/llm_predictor/vellum/predictor.py b/llama_index/llm_predictor/vellum/predictor.py index 6d38e43d1a5c237d291e53d61422964c5d1e460a..1cf6a05f5460649ee63b55890cfbdde4c1efc7e7 100644 --- a/llama_index/llm_predictor/vellum/predictor.py +++ b/llama_index/llm_predictor/vellum/predictor.py @@ -1,5 +1,6 @@ from __future__ import annotations +from pydantic import Field, PrivateAttr from typing import Any, Optional, Tuple, cast from llama_index import Prompt @@ -16,6 +17,17 @@ from llama_index.types import TokenAsyncGen, TokenGen class VellumPredictor(BaseLLMPredictor): + callback_manager: CallbackManager = Field( + default_factory=CallbackManager, exclude=True + ) + + _vellum_client: Any = PrivateAttr() + _async_vellum_client = PrivateAttr() + _prompt_registry: Any = PrivateAttr() + + class Config: + arbitrary_types_allowed = True + def __init__( self, vellum_api_key: str, @@ -29,13 +41,15 @@ class VellumPredictor(BaseLLMPredictor): except ImportError: raise ImportError(import_err_msg) - self.callback_manager = callback_manager or CallbackManager([]) + callback_manager = callback_manager or CallbackManager([]) # Vellum-specific self._vellum_client = Vellum(api_key=vellum_api_key) self._async_vellum_client = AsyncVellum(api_key=vellum_api_key) self._prompt_registry = VellumPromptRegistry(vellum_api_key=vellum_api_key) + super().__init__(callback_manager=callback_manager) + @property def metadata(self) -> LLMMetadata: """Get LLM metadata.""" diff --git a/llama_index/node_parser/extractors/metadata_extractors.py b/llama_index/node_parser/extractors/metadata_extractors.py index 20606c23e24c34cf16c6e01d24f39e2982b2ebf1..06c8f400730b95f45fabdb22a0a498e8ff7ce71c 100644 --- a/llama_index/node_parser/extractors/metadata_extractors.py +++ b/llama_index/node_parser/extractors/metadata_extractors.py @@ -19,10 +19,10 @@ The prompts used to generate the metadata are specifically aimed to help disambiguate the document or subsection 
from other similar documents or subsections. (similar with contrastive learning) """ - -from abc import abstractmethod import json -from typing import List, Optional, Sequence, cast, Dict, Callable +from abc import abstractmethod +from pydantic import Field, PrivateAttr +from typing import Any, List, Optional, Sequence, cast, Dict, Callable from functools import reduce from llama_index.llms.base import LLM @@ -33,7 +33,7 @@ from llama_index.schema import BaseNode, TextNode class MetadataFeatureExtractor(BaseExtractor): - is_text_node_only = True + is_text_node_only: bool = True @abstractmethod def extract(self, nodes: Sequence[BaseNode]) -> List[Dict]: @@ -54,15 +54,29 @@ Excerpt:\n-----\n{content}\n-----\n""" class MetadataExtractor(BaseExtractor): """Metadata extractor.""" + extractors: Sequence[MetadataFeatureExtractor] = Field( + default_factory=list, + description="Metadta feature extractors to apply to each node.", + ) + node_text_template: str = Field( + default=DEFAULT_NODE_TEXT_TEMPLATE, + description="Template to represent how node text is mixed with metadata text.", + ) + disable_template_rewrite: bool = Field( + default=False, description="Disable the node template rewrite." + ) + def __init__( self, extractors: Sequence[MetadataFeatureExtractor], node_text_template: str = DEFAULT_NODE_TEXT_TEMPLATE, disable_template_rewrite: bool = False, ) -> None: - self._extractors = extractors - self._node_text_template = node_text_template - self._disable_template_rewrite = disable_template_rewrite + super().__init__( + extractors=extractors, + node_text_template=node_text_template, + disable_template_rewrite=disable_template_rewrite, + ) def extract(self, nodes: Sequence[BaseNode]) -> List[Dict]: """Extract metadata from a document. @@ -72,7 +86,7 @@ class MetadataExtractor(BaseExtractor): """ metadata_list: List[Dict] = [{} for _ in nodes] - for extractor in self._extractors: + for extractor in self.extractors: cur_metadata_list = extractor.extract(nodes) for i, metadata in enumerate(metadata_list): metadata.update(cur_metadata_list[i]) @@ -96,7 +110,7 @@ class MetadataExtractor(BaseExtractor): excluded_llm_metadata_keys (Optional[List[str]]): keys to exclude from llm metadata """ - for extractor in self._extractors: + for extractor in self.extractors: cur_metadata_list = extractor.extract(nodes) for idx, node in enumerate(nodes): node.metadata.update(cur_metadata_list[idx]) @@ -106,9 +120,9 @@ class MetadataExtractor(BaseExtractor): node.excluded_embed_metadata_keys.extend(excluded_embed_metadata_keys) if excluded_llm_metadata_keys is not None: node.excluded_llm_metadata_keys.extend(excluded_llm_metadata_keys) - if not self._disable_template_rewrite: + if not self.disable_template_rewrite: if isinstance(node, TextNode): - cast(TextNode, node).text_template = self._node_text_template + cast(TextNode, node).text_template = self.node_text_template return nodes @@ -133,7 +147,21 @@ class TitleExtractor(MetadataFeatureExtractor): a document-level title """ - is_text_node_only = False # can work for mixture of text and non-text nodes + is_text_node_only: bool = False # can work for mixture of text and non-text nodes + llm_predictor: BaseLLMPredictor = Field( + description="The LLMPredictor to use for generation." + ) + nodes: int = Field( + default=5, description="The number of nodes to extract titles from." 
+ ) + node_template: str = Field( + default=DEFAULT_TITLE_NODE_TEMPLATE, + description="The prompt template to extract titles with.", + ) + combine_template: str = Field( + default=DEFAULT_TITLE_COMBINE_TEMPLATE, + description="The prompt template to merge titles with.", + ) def __init__( self, @@ -147,15 +175,19 @@ class TitleExtractor(MetadataFeatureExtractor): """Init params.""" if nodes < 1: raise ValueError("num_nodes must be >= 1") - self._nodes = nodes - self._node_template = node_template - self._combine_template = combine_template - self._llm_predictor = llm_predictor or LLMPredictor(llm=llm) + llm_predictor = llm_predictor or LLMPredictor(llm=llm) + + super().__init__( + llm_predictor=llm_predictor, + nodes=nodes, + node_template=node_template, + combine_template=combine_template, + ) def extract(self, nodes: Sequence[BaseNode]) -> List[Dict]: nodes_to_extract_title: List[BaseNode] = [] for node in nodes: - if len(nodes_to_extract_title) >= self._nodes: + if len(nodes_to_extract_title) >= self.nodes: break if self.is_text_node_only and not isinstance(node, TextNode): continue @@ -166,8 +198,8 @@ class TitleExtractor(MetadataFeatureExtractor): return [] title_candidates = [ - self._llm_predictor.predict( - Prompt(template=self._node_template), + self.llm_predictor.predict( + Prompt(template=self.node_template), context_str=cast(TextNode, node).text, ) for node in nodes_to_extract_title @@ -177,8 +209,8 @@ class TitleExtractor(MetadataFeatureExtractor): lambda x, y: x + "," + y, title_candidates[1:], title_candidates[0] ) - title = self._llm_predictor.predict( - Prompt(template=self._combine_template), + title = self.llm_predictor.predict( + Prompt(template=self.combine_template), context_str=titles, ) else: @@ -186,7 +218,7 @@ class TitleExtractor(MetadataFeatureExtractor): 0 ] # if single node, just use the title from that node - metadata_list = [{"document_title": title.strip(' \t\n\r"')} for node in nodes] + metadata_list = [{"document_title": title.strip(' \t\n\r"')} for _ in nodes] return metadata_list @@ -198,6 +230,11 @@ class KeywordExtractor(MetadataFeatureExtractor): keywords (int): number of keywords to extract """ + llm_predictor: BaseLLMPredictor = Field( + description="The LLMPredictor to use for generation." + ) + keywords: int = Field(default=5, description="The number of keywords to extract.") + def __init__( self, llm: Optional[LLM] = None, @@ -206,10 +243,10 @@ class KeywordExtractor(MetadataFeatureExtractor): keywords: int = 5, ) -> None: """Init params.""" - self._llm_predictor = llm_predictor or LLMPredictor(llm=llm) if keywords < 1: raise ValueError("num_keywords must be >= 1") - self._keywords = keywords + llm_predictor = llm_predictor or LLMPredictor(llm=llm) + super().__init__(llm_predictor=llm_predictor, keywords=keywords) def extract(self, nodes: Sequence[BaseNode]) -> List[Dict]: metadata_list: List[Dict] = [] @@ -219,10 +256,10 @@ class KeywordExtractor(MetadataFeatureExtractor): continue # TODO: figure out a good way to allow users to customize keyword template - keywords = self._llm_predictor.predict( + keywords = self.llm_predictor.predict( Prompt( template=f"""\ -{{context_str}}. Give {self._keywords} unique keywords for this \ +{{context_str}}. Give {self.keywords} unique keywords for this \ document. Format as comma separated. 
Keywords: """ ), context_str=cast(TextNode, node).text, @@ -243,6 +280,19 @@ class QuestionsAnsweredExtractor(MetadataFeatureExtractor): embedding_only (bool): whether to use embedding only """ + llm_predictor: BaseLLMPredictor = Field( + description="The LLMPredictor to use for generation." + ) + questions: int = Field( + default=5, description="The number of questions to generate." + ) + prompt_template: Optional[str] = Field( + default=None, description="Prompt template to use when generating questions." + ) + embedding_only: bool = Field( + default=True, description="Whether to use metadata for emebddings only." + ) + def __init__( self, llm: Optional[LLM] = None, @@ -255,10 +305,13 @@ class QuestionsAnsweredExtractor(MetadataFeatureExtractor): """Init params.""" if questions < 1: raise ValueError("questions must be >= 1") - self._llm_predictor = llm_predictor or LLMPredictor(llm=llm) - self._questions = questions - self._prompt_template = prompt_template - self._embedding_only = embedding_only + llm_predictor = llm_predictor or LLMPredictor(llm=llm) + super().__init__( + llm_predictor=llm_predictor, + questions=questions, + prompt_template=prompt_template, + embedding_only=embedding_only, + ) def extract(self, nodes: Sequence[BaseNode]) -> List[Dict]: metadata_list: List[Dict] = [] @@ -268,12 +321,12 @@ class QuestionsAnsweredExtractor(MetadataFeatureExtractor): continue # Extract the title from the first node # TODO: figure out a good way to allow users to customize template - questions = self._llm_predictor.predict( + questions = self.llm_predictor.predict( Prompt( - template=self._prompt_template + template=self.prompt_template or f"""\ {{context_str}}. Given the contextual information, \ -generate {self._questions} questions this document can provide \ +generate {self.questions} questions this document can provide \ specific answers to which are unlikely to be found elsewhere: \ """ ), @@ -281,7 +334,7 @@ specific answers to which are unlikely to be found elsewhere: \ metadata: {json.dumps(node.metadata)} \ content: {cast(TextNode, node).text}""", ) - if self._embedding_only: + if self.embedding_only: node.excluded_llm_metadata_keys = ["questions_this_excerpt_can_answer"] metadata_list.append( {"questions_this_excerpt_can_answer": questions.strip()} @@ -304,6 +357,21 @@ class SummaryExtractor(MetadataFeatureExtractor): summaries (List[str]): list of summaries to extract: 'self', 'prev', 'next' prompt_template (str): template for summary extraction""" + llm_predictor: BaseLLMPredictor = Field( + description="The LLMPredictor to use for generation." 
+ ) + summaries: List[str] = Field( + description="List of summaries to extract: 'self', 'prev', 'next'" + ) + prompt_template: str = Field( + default=DEFAULT_SUMMARY_EXTRACT_TEMPLATE, + description="Template to use when generating summaries.", + ) + + _self_summary: bool = PrivateAttr() + _prev_summary: bool = PrivateAttr() + _next_summary: bool = PrivateAttr() + def __init__( self, llm: Optional[LLM] = None, @@ -312,21 +380,26 @@ class SummaryExtractor(MetadataFeatureExtractor): summaries: List[str] = ["self"], prompt_template: str = DEFAULT_SUMMARY_EXTRACT_TEMPLATE, ): - self._llm_predictor = llm_predictor or LLMPredictor(llm=llm) + llm_predictor = llm_predictor or LLMPredictor(llm=llm) # validation if not all([s in ["self", "prev", "next"] for s in summaries]): raise ValueError("summaries must be one of ['self', 'prev', 'next']") self._self_summary = "self" in summaries self._prev_summary = "prev" in summaries self._next_summary = "next" in summaries - self._prompt_template = prompt_template + + super().__init__( + llm_predictor=llm_predictor, + summaries=summaries, + prompt_template=prompt_template, + ) def extract(self, nodes: Sequence[BaseNode]) -> List[Dict]: if not all([isinstance(node, TextNode) for node in nodes]): raise ValueError("Only `TextNode` is allowed for `Summary` extractor") node_summaries = [ - self._llm_predictor.predict( - Prompt(template=self._prompt_template), + self.llm_predictor.predict( + Prompt(template=self.prompt_template), context_str=cast(TextNode, node).text, ).strip() for node in nodes @@ -363,6 +436,8 @@ DEFAULT_ENTITY_MAP = { "VEHI": "vehicles", } +DEFAULT_ENTITY_MODEL = "tomaarsen/span-marker-mbert-base-multinerd" + class EntityExtractor(MetadataFeatureExtractor): """ @@ -372,9 +447,31 @@ class EntityExtractor(MetadataFeatureExtractor): Install SpanMarker with `pip install span-marker`. """ + model_name: str = Field( + default=DEFAULT_ENTITY_MODEL, + description="The model name of the SpanMarker model to use.", + ) + prediction_threshold: float = Field( + default=0.5, description="The confidence threshold for accepting predictions." + ) + span_joiner: str = Field(description="The seperator beween entity names.") + label_entities: bool = Field( + default=False, description="Include entity class labels or not." + ) + device: Optional[str] = Field( + default=None, description="Device to run model on, i.e. 
'cuda', 'cpu'" + ) + entity_map: Dict[str, str] = Field( + default_factory=dict, + description="Mapping of entity class names to usable names.", + ) + + _tokenizer: Callable = PrivateAttr() + _model: Any = PrivateAttr + def __init__( self, - model_name: str = "tomaarsen/span-marker-mbert-base-multinerd", + model_name: str = DEFAULT_ENTITY_MODEL, prediction_threshold: float = 0.5, span_joiner: str = " ", label_entities: bool = False, @@ -423,12 +520,19 @@ class EntityExtractor(MetadataFeatureExtractor): self._model = self._model.to(device) self._tokenizer = tokenizer or word_tokenize - self._prediction_threshold = prediction_threshold - self._span_joiner = span_joiner - self._label_entities = label_entities - self._entity_map = DEFAULT_ENTITY_MAP + + base_entity_map = DEFAULT_ENTITY_MAP if entity_map is not None: - self._entity_map.update(entity_map) + base_entity_map.update(entity_map) + + super().__init__( + model_name=model_name, + prediction_threshold=prediction_threshold, + span_joiner=span_joiner, + label_entities=label_entities, + device=device, + entity_map=entity_map, + ) def extract(self, nodes: Sequence[BaseNode]) -> List[Dict]: # Extract node-level entity metadata @@ -438,13 +542,18 @@ class EntityExtractor(MetadataFeatureExtractor): words = self._tokenizer(node_text) spans = self._model.predict(words) for span in spans: - if span["score"] > self._prediction_threshold: - ent_label = self._entity_map.get(span["label"], span["label"]) - metadata_label = ent_label if self._label_entities else "entities" + if span["score"] > self.prediction_threshold: + ent_label = self.entity_map.get(span["label"], span["label"]) + metadata_label = ent_label if self.label_entities else "entities" if metadata_label not in metadata: metadata[metadata_label] = set() - metadata[metadata_label].add(self._span_joiner.join(span["span"])) + metadata[metadata_label].add(self.span_joiner.join(span["span"])) + + # convert metadata from set to list + for metadata in metadata_list: + for key, val in metadata.items(): + metadata[key] = list(val) return metadata_list diff --git a/llama_index/node_parser/interface.py b/llama_index/node_parser/interface.py index 409a2726c907d0f9b1ccd331164e407fdecbca26..bbd0eb249bb5c9570828ce56b7f2a5aa0f499c60 100644 --- a/llama_index/node_parser/interface.py +++ b/llama_index/node_parser/interface.py @@ -1,4 +1,5 @@ """Node parser interface.""" +from pydantic import BaseModel from typing import List, Sequence, Dict from abc import ABC, abstractmethod @@ -7,9 +8,12 @@ from llama_index.schema import Document from llama_index.schema import BaseNode -class NodeParser(ABC): +class NodeParser(BaseModel, ABC): """Base interface for node parser.""" + class Config: + arbitrary_types_allowed = True + @abstractmethod def get_nodes_from_documents( self, @@ -24,9 +28,12 @@ class NodeParser(ABC): """ -class BaseExtractor(ABC): +class BaseExtractor(BaseModel, ABC): """Base interface for feature extractor.""" + class Config: + arbitrary_types_allowed = True + @abstractmethod def extract( self, diff --git a/llama_index/node_parser/sentence_window.py b/llama_index/node_parser/sentence_window.py index 7bb6f34316e7525d4456e57f0221357aead68760..fb62517269f07804a1b315a5b275f91ff03de388 100644 --- a/llama_index/node_parser/sentence_window.py +++ b/llama_index/node_parser/sentence_window.py @@ -1,4 +1,5 @@ """Simple node parser.""" +from pydantic import Field from typing import List, Callable, Optional, Sequence from llama_index.callbacks.base import CallbackManager @@ -28,6 +29,36 @@ class 
SentenceWindowNodeParser(NodeParser): include_prev_next_rel (bool): whether to include prev/next relationships """ + sentence_splitter: Callable[[str], List[str]] = Field( + default_factory=split_by_sentence_tokenizer, + description="The text splitter to use when splitting documents.", + exclude=True, + ) + window_size: int = Field( + default=DEFAULT_WINDOW_SIZE, + description="The number of sentences on each side of a sentence to capture.", + ) + window_metadata_key: str = Field( + default=DEFAULT_WINDOW_METADATA_KEY, + description="The metadata key to store the sentence window under.", + ) + original_text_metadata_key: str = Field( + default=DEFAULT_OG_TEXT_METADATA_KEY, + description="The metadata key to store the original sentence in.", + ) + include_metadata: bool = Field( + default=True, description="Whether or not to consider metadata when splitting." + ) + include_prev_next_rel: bool = Field( + default=True, description="Include prev/next node relationships." + ) + metadata_extractor: Optional[MetadataExtractor] = Field( + default=None, description="Metadata extraction pipeline to apply to nodes." + ) + callback_manager: CallbackManager = Field( + default_factory=CallbackManager, exclude=True + ) + def __init__( self, sentence_splitter: Optional[Callable[[str], List[str]]] = None, @@ -40,15 +71,18 @@ class SentenceWindowNodeParser(NodeParser): metadata_extractor: Optional[MetadataExtractor] = None, ) -> None: """Init params.""" - self.callback_manager = callback_manager or CallbackManager([]) - self._sentence_splitter = sentence_splitter or split_by_sentence_tokenizer() - self._window_size = window_size - self._window_metadata_key = window_metadata_key - self._original_text_metadata_key = original_text_metadata_key - - self._include_metadata = include_metadata - self._include_prev_next_rel = include_prev_next_rel - self._metadata_extractor = metadata_extractor + callback_manager = callback_manager or CallbackManager([]) + sentence_splitter = sentence_splitter or split_by_sentence_tokenizer() + super().__init__( + sentence_splitter=sentence_splitter, + window_size=window_size, + window_metadata_key=window_metadata_key, + original_text_metadata_key=original_text_metadata_key, + include_metadata=include_metadata, + include_prev_next_rel=include_prev_next_rel, + callback_manager=callback_manager, + metadata_extractor=metadata_extractor, + ) @classmethod def from_defaults( @@ -98,12 +132,12 @@ class SentenceWindowNodeParser(NodeParser): ) for document in documents_with_progress: - self._sentence_splitter(document.text) + self.sentence_splitter(document.text) nodes = self.build_window_nodes_from_documents([document]) all_nodes.extend(nodes) - if self._metadata_extractor is not None: - self._metadata_extractor.process_nodes(all_nodes) + if self.metadata_extractor is not None: + self.metadata_extractor.process_nodes(all_nodes) event.on_end(payload={EventPayload.NODES: all_nodes}) @@ -116,7 +150,7 @@ class SentenceWindowNodeParser(NodeParser): all_nodes: List[BaseNode] = [] for doc in documents: text = doc.text - text_splits = self._sentence_splitter(text) + text_splits = self.sentence_splitter(text) nodes = build_nodes_from_splits( text_splits, doc, include_prev_next_rel=True ) @@ -124,22 +158,20 @@ class SentenceWindowNodeParser(NodeParser): # add window to each node for i, node in enumerate(nodes): window_nodes = nodes[ - max(0, i - self._window_size) : min( - i + self._window_size, len(nodes) - ) + max(0, i - self.window_size) : min(i + self.window_size, len(nodes)) ] - 
node.metadata[self._window_metadata_key] = " ".join( + node.metadata[self.window_metadata_key] = " ".join( [n.text for n in window_nodes] ) - node.metadata[self._original_text_metadata_key] = node.text + node.metadata[self.original_text_metadata_key] = node.text # exclude window metadata from embed and llm node.excluded_embed_metadata_keys.extend( - [self._window_metadata_key, self._original_text_metadata_key] + [self.window_metadata_key, self.original_text_metadata_key] ) node.excluded_llm_metadata_keys.extend( - [self._window_metadata_key, self._original_text_metadata_key] + [self.window_metadata_key, self.original_text_metadata_key] ) all_nodes.extend(nodes) diff --git a/llama_index/node_parser/simple.py b/llama_index/node_parser/simple.py index f60a3433f15e384e105cfe806f6961d5cc15a89b..6a150dc88ecfc303eb9658f9eb8ce5f8b3b51f8e 100644 --- a/llama_index/node_parser/simple.py +++ b/llama_index/node_parser/simple.py @@ -1,4 +1,5 @@ """Simple node parser.""" +from pydantic import Field from typing import List, Optional, Sequence from llama_index.callbacks.base import CallbackManager @@ -23,28 +24,28 @@ class SimpleNodeParser(NodeParser): """ - def __init__( - self, - text_splitter: Optional[TextSplitter] = None, - include_metadata: bool = True, - include_prev_next_rel: bool = True, - callback_manager: Optional[CallbackManager] = None, - metadata_extractor: Optional[MetadataExtractor] = None, - ) -> None: - """Init params.""" - self.callback_manager = callback_manager or CallbackManager([]) - self._text_splitter = text_splitter or get_default_text_splitter( - callback_manager=self.callback_manager - ) - self._include_metadata = include_metadata - self._include_prev_next_rel = include_prev_next_rel - self._metadata_extractor = metadata_extractor + text_splitter: TextSplitter = Field( + description="The text splitter to use when splitting documents." + ) + include_metadata: bool = Field( + default=True, description="Whether or not to consider metadata when splitting." + ) + include_prev_next_rel: bool = Field( + default=True, description="Include prev/next node relationships." + ) + metadata_extractor: Optional[MetadataExtractor] = Field( + default=None, description="Metadata extraction pipeline to apply to nodes." 
+ ) + callback_manager: CallbackManager = Field( + default_factory=CallbackManager, exclude=True + ) @classmethod def from_defaults( cls, chunk_size: Optional[int] = None, chunk_overlap: Optional[int] = None, + text_splitter: Optional[TextSplitter] = None, include_metadata: bool = True, include_prev_next_rel: bool = True, callback_manager: Optional[CallbackManager] = None, @@ -52,7 +53,7 @@ class SimpleNodeParser(NodeParser): ) -> "SimpleNodeParser": callback_manager = callback_manager or CallbackManager([]) - text_splitter = get_default_text_splitter( + text_splitter = text_splitter or get_default_text_splitter( chunk_size=chunk_size, chunk_overlap=chunk_overlap, callback_manager=callback_manager, @@ -88,14 +89,14 @@ class SimpleNodeParser(NodeParser): for document in documents_with_progress: nodes = get_nodes_from_document( document, - self._text_splitter, - self._include_metadata, - include_prev_next_rel=self._include_prev_next_rel, + self.text_splitter, + self.include_metadata, + include_prev_next_rel=self.include_prev_next_rel, ) all_nodes.extend(nodes) - if self._metadata_extractor is not None: - self._metadata_extractor.process_nodes(all_nodes) + if self.metadata_extractor is not None: + self.metadata_extractor.process_nodes(all_nodes) event.on_end(payload={EventPayload.NODES: all_nodes}) diff --git a/llama_index/text_splitter/code_splitter.py b/llama_index/text_splitter/code_splitter.py index 4421e5e5d06351485f3e9148e90ba5087fb3d066..3153c46017f2268991648365fb0cd140502f4e9e 100644 --- a/llama_index/text_splitter/code_splitter.py +++ b/llama_index/text_splitter/code_splitter.py @@ -1,10 +1,15 @@ """Code splitter.""" +from pydantic import Field from typing import Any, List, Optional from llama_index.callbacks.base import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload from llama_index.text_splitter.types import TextSplitter +DEFAULT_CHUNK_LINES = 40 +DEFAULT_LINES_OVERLAP = 15 +DEFAULT_MAX_CHARS = 1500 + class CodeSplitter(TextSplitter): """Split code using a AST parser. @@ -13,6 +18,24 @@ class CodeSplitter(TextSplitter): https://docs.sweep.dev/blogs/chunking-2m-files """ + language: str = Field( + description="The programming languge of the code being split." + ) + chunk_lines: int = Field( + default=DEFAULT_CHUNK_LINES, + description="The number of lines to include in each chunk.", + ) + chunk_lines_overlap: int = Field( + default=DEFAULT_LINES_OVERLAP, + description="How many lines of code each chunk overlaps with.", + ) + max_chars: int = Field( + default=DEFAULT_MAX_CHARS, description="Maximum number of characters per chunk." 
+ ) + callback_manager: CallbackManager = Field( + default_factory=CallbackManager, exclude=True + ) + def __init__( self, language: str, @@ -21,11 +44,14 @@ class CodeSplitter(TextSplitter): max_chars: int = 1500, callback_manager: Optional[CallbackManager] = None, ): - self.language = language - self.chunk_lines = chunk_lines - self.chunk_lines_overlap = chunk_lines_overlap - self.max_chars = max_chars - self.callback_manager = callback_manager or CallbackManager([]) + callback_manager = callback_manager or CallbackManager([]) + super().__init__( + language=language, + chunk_lines=chunk_lines, + chunk_lines_overlap=chunk_lines_overlap, + max_chars=max_chars, + callback_manager=callback_manager, + ) def _chunk_node(self, node: Any, text: str, last_end: int = 0) -> List[str]: new_chunks = [] diff --git a/llama_index/text_splitter/sentence_splitter.py b/llama_index/text_splitter/sentence_splitter.py index 588692bbee22ea70ee0d6eaa4a95a2cbd0d2824e..1fbf6500136d0566ed05116410a7407e3eacc0cc 100644 --- a/llama_index/text_splitter/sentence_splitter.py +++ b/llama_index/text_splitter/sentence_splitter.py @@ -1,5 +1,6 @@ """Sentence splitter.""" from dataclasses import dataclass +from pydantic import Field, PrivateAttr from typing import Callable, List, Optional from llama_index.callbacks.base import CallbackManager @@ -14,6 +15,10 @@ from llama_index.text_splitter.utils import ( ) from llama_index.utils import globals_helper +SENTENCE_CHUNK_OVERLAP = 200 +CHUNKING_REGEX = "[^,.;。]+[,.;。]?" +DEFUALT_PARAGRAPH_SEP = "\n\n\n" + @dataclass class _Split: @@ -29,15 +34,50 @@ class SentenceSplitter(MetadataAwareTextSplitter): hanging sentences or parts of sentences at the end of the node chunk. """ + chunk_size: int = Field( + default=DEFAULT_CHUNK_SIZE, description="The token chunk size for each chunk." + ) + chunk_overlap: int = Field( + default=SENTENCE_CHUNK_OVERLAP, + description="The token overlap of each chunk when splitting.", + ) + seperator: str = Field( + default=" ", description="Default seperator for splitting into words" + ) + paragraph_seperator: List = Field( + default=DEFUALT_PARAGRAPH_SEP, description="Seperator between paragraphs." + ) + secondary_chunking_regex: str = Field( + default=CHUNKING_REGEX, description="Backup regex for splitting into sentences." + ) + chunking_tokenizer_fn: Callable[[str], List[str]] = Field( + exclude=True, + description=( + "Function to split text into sentences. " + "Defaults to `nltk.sent_tokenize`." 
+ ), + ) + callback_manager: CallbackManager = Field( + default_factory=CallbackManager, exclude=True + ) + tokenizer: Callable = Field( + default_factory=globals_helper.tokenizer, # type: ignore + description="Tokenizer for splitting words into tokens.", + exclude=True, + ) + + _split_fns: List[Callable] = PrivateAttr() + _sub_sentence_split_fns: List[Callable] = PrivateAttr() + def __init__( self, separator: str = " ", chunk_size: int = DEFAULT_CHUNK_SIZE, - chunk_overlap: int = 200, + chunk_overlap: int = SENTENCE_CHUNK_OVERLAP, tokenizer: Optional[Callable] = None, - paragraph_separator: str = "\n\n\n", + paragraph_separator: str = DEFUALT_PARAGRAPH_SEP, chunking_tokenizer_fn: Optional[Callable[[str], List[str]]] = None, - secondary_chunking_regex: str = "[^,.;。]+[,.;。]?", + secondary_chunking_regex: str = CHUNKING_REGEX, callback_manager: Optional[CallbackManager] = None, ): """Initialize with parameters.""" @@ -46,12 +86,10 @@ class SentenceSplitter(MetadataAwareTextSplitter): f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " f"({chunk_size}), should be smaller." ) - self._chunk_size = chunk_size - self._chunk_overlap = chunk_overlap - self.tokenizer = tokenizer or globals_helper.tokenizer - self.callback_manager = callback_manager or CallbackManager([]) + callback_manager = callback_manager or CallbackManager([]) chunking_tokenizer_fn = chunking_tokenizer_fn or split_by_sentence_tokenizer() + tokenizer = tokenizer or globals_helper.tokenizer self._split_fns = [ split_by_sep(paragraph_separator), @@ -64,13 +102,24 @@ class SentenceSplitter(MetadataAwareTextSplitter): split_by_char(), ] + super().__init__( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + chunking_tokenizer_fn=chunking_tokenizer_fn, + secondary_chunking_regex=secondary_chunking_regex, + separator=separator, + paragraph_separator=paragraph_separator, + callback_manager=callback_manager, + tokenizer=tokenizer, + ) + def split_text_metadata_aware(self, text: str, metadata_str: str) -> List[str]: metadata_len = len(self.tokenizer(metadata_str)) - effective_chunk_size = self._chunk_size - metadata_len + effective_chunk_size = self.chunk_size - metadata_len return self._split_text(text, chunk_size=effective_chunk_size) def split_text(self, text: str) -> List[str]: - return self._split_text(text, chunk_size=self._chunk_size) + return self._split_text(text, chunk_size=self.chunk_size) def _split_text(self, text: str, chunk_size: int) -> List[str]: """ @@ -149,7 +198,7 @@ class SentenceSplitter(MetadataAwareTextSplitter): else: if ( cur_split.is_sentence - or cur_chunk_len + cur_split_len < chunk_size - self._chunk_overlap + or cur_chunk_len + cur_split_len < chunk_size - self.chunk_overlap or len(cur_chunk) == 0 ): # add split to chunk diff --git a/llama_index/text_splitter/token_splitter.py b/llama_index/text_splitter/token_splitter.py index 4e63a22446afe79d78d6b890532eb8e3c53e53e5..7f2f562bc79131b816a2a847fc1fdcd6c8398084 100644 --- a/llama_index/text_splitter/token_splitter.py +++ b/llama_index/text_splitter/token_splitter.py @@ -1,5 +1,6 @@ """Token splitter.""" import logging +from pydantic import Field, PrivateAttr from typing import Callable, List, Optional from llama_index.callbacks.base import CallbackManager @@ -18,6 +19,30 @@ DEFAULT_METADATA_FORMAT_LEN = 2 class TokenTextSplitter(MetadataAwareTextSplitter): """Implementation of splitting text that looks at word tokens.""" + chunk_size: int = Field( + default=DEFAULT_CHUNK_SIZE, description="The token chunk size for each chunk." 
+ ) + chunk_overlap: int = Field( + default=DEFAULT_CHUNK_OVERLAP, + description="The token overlap of each chunk when splitting.", + ) + seperator: str = Field( + default=" ", description="Default seperator for splitting into words" + ) + backup_seperators: List = Field( + default_factory=list, description="Additional seperators for splitting." + ) + callback_manager: CallbackManager = Field( + default_factory=CallbackManager, exclude=True + ) + tokenizer: Callable = Field( + default_factory=globals_helper.tokenizer, # type: ignore + description="Tokenizer for splitting words into tokens.", + exclude=True, + ) + + _split_fns: List[Callable] = PrivateAttr() + def __init__( self, chunk_size: int = DEFAULT_CHUNK_SIZE, @@ -33,23 +58,30 @@ class TokenTextSplitter(MetadataAwareTextSplitter): f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " f"({chunk_size}), should be smaller." ) - self._chunk_size = chunk_size - self._chunk_overlap = chunk_overlap - self.tokenizer = tokenizer or globals_helper.tokenizer - self.callback_manager = callback_manager or CallbackManager([]) + callback_manager = callback_manager or CallbackManager([]) + tokenizer = tokenizer or globals_helper.tokenizer all_seps = [separator] + (backup_separators or []) self._split_fns = [split_by_sep(sep) for sep in all_seps] + [split_by_char()] + super().__init__( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + separator=separator, + backup_separators=backup_separators, + callback_manager=callback_manager, + tokenizer=tokenizer, + ) + def split_text_metadata_aware(self, text: str, metadata_str: str) -> List[str]: """Split text into chunks, reserving space required for metadata str.""" metadata_len = len(self.tokenizer(metadata_str)) + DEFAULT_METADATA_FORMAT_LEN - effective_chunk_size = self._chunk_size - metadata_len + effective_chunk_size = self.chunk_size - metadata_len return self._split_text(text, chunk_size=effective_chunk_size) def split_text(self, text: str) -> List[str]: """Split text into chunks.""" - return self._split_text(text, chunk_size=self._chunk_size) + return self._split_text(text, chunk_size=self.chunk_size) def _split_text(self, text: str, chunk_size: int) -> List[str]: """Split text into chunks up to chunk_size.""" @@ -129,7 +161,7 @@ class TokenTextSplitter(MetadataAwareTextSplitter): # keep popping off the first element of the previous chunk until: # 1. the current chunk length is less than chunk overlap # 2. the total length is less than chunk size - while cur_len > self._chunk_overlap or cur_len + split_len > chunk_size: + while cur_len > self.chunk_overlap or cur_len + split_len > chunk_size: # pop off the first element first_chunk = cur_chunk.pop(0) cur_len -= len(self.tokenizer(first_chunk)) diff --git a/llama_index/text_splitter/types.py b/llama_index/text_splitter/types.py index 6da292feee9963aebd6f1d829dfc017b1108804c..81e9982a7e40fd401d37c68acd2014d866aea810 100644 --- a/llama_index/text_splitter/types.py +++ b/llama_index/text_splitter/types.py @@ -1,16 +1,19 @@ """Text splitter implementations.""" -from typing import List, Protocol, runtime_checkable +from abc import abstractmethod, ABC +from pydantic import BaseModel +from typing import List -class TextSplitter(Protocol): - def split_text(self, text: str) -> List[str]: - ... +class TextSplitter(ABC, BaseModel): + class Config: + arbitrary_types_allowed = True - -@runtime_checkable -class MetadataAwareTextSplitter(Protocol): + @abstractmethod def split_text(self, text: str) -> List[str]: ... 
+ +class MetadataAwareTextSplitter(TextSplitter): + @abstractmethod def split_text_metadata_aware(self, text: str, metadata_str: str) -> List[str]: ... diff --git a/tests/conftest.py b/tests/conftest.py index 8c8eac9219176d81d5b3e4df2f0b1f15dd30bd57..884f86d30002e4458967d1c9395e5e41420c5a62 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,9 +45,6 @@ def patch_token_text_splitter(monkeypatch: pytest.MonkeyPatch) -> None: @pytest.fixture def patch_llm_predictor(monkeypatch: pytest.MonkeyPatch) -> None: - def do_nothing(*args: Any, **kwargs: Any) -> Any: - pass - monkeypatch.setattr( LLMPredictor, "predict", @@ -63,11 +60,6 @@ def patch_llm_predictor(monkeypatch: pytest.MonkeyPatch) -> None: "llm", MockLLM(), ) - monkeypatch.setattr( - LLMPredictor, - "__init__", - do_nothing, - ) monkeypatch.setattr( LLMPredictor, "metadata", diff --git a/tests/indices/test_node_utils.py b/tests/indices/test_node_utils.py index fa66e4d8f5a5ed71668af38fdba28164c8147e18..331d28039b2501ef7de261de709324acade9dea5 100644 --- a/tests/indices/test_node_utils.py +++ b/tests/indices/test_node_utils.py @@ -46,7 +46,7 @@ def test_get_nodes_from_document( len(text_splitter.tokenizer(node.get_content())) for node in nodes ] assert all( - chunk_size <= text_splitter._chunk_size for chunk_size in actual_chunk_sizes + chunk_size <= text_splitter.chunk_size for chunk_size in actual_chunk_sizes ) @@ -65,7 +65,7 @@ def test_get_nodes_from_document_with_metadata( for node in nodes ] assert all( - chunk_size <= text_splitter._chunk_size for chunk_size in actual_chunk_sizes + chunk_size <= text_splitter.chunk_size for chunk_size in actual_chunk_sizes ) assert all( [ @@ -85,7 +85,7 @@ def test_get_nodes_from_document_langchain_compatible( ) nodes = get_nodes_from_document( documents[0], - text_splitter, + text_splitter, # type: ignore include_metadata=False, ) assert len(nodes) == 2 diff --git a/tests/indices/test_prompt_helper.py b/tests/indices/test_prompt_helper.py index 62dc0a2d5eabd06e96ef80be76f4221c116b4181..2946ee2fb03e0ce354867076d3b37ba17daffae2 100644 --- a/tests/indices/test_prompt_helper.py +++ b/tests/indices/test_prompt_helper.py @@ -57,7 +57,7 @@ def test_get_text_splitter() -> None: text_splitter = prompt_helper.get_text_splitter_given_prompt( test_prompt, 2, padding=1 ) - assert text_splitter._chunk_size == 2 + assert text_splitter.chunk_size == 2 test_text = "Hello world foo Hello world bar" text_chunks = text_splitter.split_text(test_text) assert text_chunks == ["Hello world", "foo Hello", "world bar"] diff --git a/tests/llm_predictor/vellum/test_predictor.py b/tests/llm_predictor/vellum/test_predictor.py index 2e218c4db66035195d7b342335b852a5fe1ee7e3..785c4518a78db3af68c6ca5df33428c9ddd05d3c 100644 --- a/tests/llm_predictor/vellum/test_predictor.py +++ b/tests/llm_predictor/vellum/test_predictor.py @@ -4,11 +4,8 @@ from unittest import mock import pytest from llama_index import Prompt -from llama_index.callbacks import CBEventType from llama_index.llm_predictor.vellum import ( - VellumRegisteredPrompt, VellumPredictor, - VellumPromptRegistry, ) @@ -31,57 +28,6 @@ def test_predict__basic( assert completion_text == "Hello, world!" 
-def test_predict__callback_manager( - mock_vellum_client_factory: Callable[..., mock.MagicMock], - vellum_predictor_factory: Callable[..., VellumPredictor], - vellum_prompt_registry_factory: Callable[..., VellumPromptRegistry], - dummy_prompt: Prompt, -) -> None: - """Ensure we invoke a callback manager, when provided""" - - callback_manager = mock.MagicMock() - - vellum_client = mock_vellum_client_factory( - compiled_prompt_text="What's you're favorite greeting?", - completion_text="Hello, world!", - ) - - registered_prompt = VellumRegisteredPrompt( - deployment_id="abc", - deployment_name="my-deployment", - model_version_id="123", - ) - prompt_registry = vellum_prompt_registry_factory(vellum_client=vellum_client) - - with mock.patch.object(prompt_registry, "from_prompt") as mock_from_prompt: - mock_from_prompt.return_value = registered_prompt - - predictor = vellum_predictor_factory( - callback_manager=callback_manager, - vellum_client=vellum_client, - vellum_prompt_registry=prompt_registry, - ) - - predictor.predict(dummy_prompt, thing="greeting") - - callback_manager.on_event_start.assert_called_once_with( - CBEventType.LLM, - payload={ - "thing": "greeting", - "deployment_id": registered_prompt.deployment_id, - "model_version_id": registered_prompt.model_version_id, - }, - ) - callback_manager.on_event_end.assert_called_once_with( - CBEventType.LLM, - payload={ - "response": "Hello, world!", - "formatted_prompt": "What's you're favorite greeting?", - }, - event_id=mock.ANY, - ) - - def test_stream__basic( mock_vellum_client_factory: Callable[..., mock.MagicMock], vellum_predictor_factory: Callable[..., VellumPredictor], @@ -131,89 +77,3 @@ def test_stream__basic( assert next(completion_generator) == " world!" with pytest.raises(StopIteration): next(completion_generator) - - -def test_stream__callback_manager( - mock_vellum_client_factory: Callable[..., mock.MagicMock], - vellum_predictor_factory: Callable[..., VellumPredictor], - vellum_prompt_registry_factory: Callable[..., VellumPromptRegistry], - dummy_prompt: Prompt, -) -> None: - """Ensure we invoke a callback manager, when provided""" - - import vellum - - callback_manager = mock.MagicMock() - - vellum_client = mock_vellum_client_factory( - compiled_prompt_text="What's you're favorite greeting?", - completion_text="Hello, world!", - ) - - def fake_stream() -> Iterator[vellum.GenerateStreamResponse]: - yield vellum.GenerateStreamResponse( - delta=vellum.GenerateStreamResult( - request_index=0, - data=vellum.GenerateStreamResultData( - completion_index=0, - completion=vellum.EnrichedNormalizedCompletion( - id="123", text="Hello,", model_version_id="abc" - ), - ), - error=None, - ) - ) - yield vellum.GenerateStreamResponse( - delta=vellum.GenerateStreamResult( - request_index=0, - data=vellum.GenerateStreamResultData( - completion_index=0, - completion=vellum.EnrichedNormalizedCompletion( - id="456", text=" world!", model_version_id="abc" - ), - ), - error=None, - ) - ) - - vellum_client.generate_stream.return_value = fake_stream() - - registered_prompt = VellumRegisteredPrompt( - deployment_id="abc", - deployment_name="my-deployment", - model_version_id="123", - ) - prompt_registry = vellum_prompt_registry_factory(vellum_client=vellum_client) - - with mock.patch.object(prompt_registry, "from_prompt") as mock_from_prompt: - mock_from_prompt.return_value = registered_prompt - - predictor = vellum_predictor_factory( - callback_manager=callback_manager, - vellum_client=vellum_client, - vellum_prompt_registry=prompt_registry, - ) - - 
completion_generator = predictor.stream(dummy_prompt, thing="greeting") - - assert next(completion_generator) == "Hello," - assert next(completion_generator) == " world!" - with pytest.raises(StopIteration): - next(completion_generator) - - callback_manager.on_event_start.assert_called_once_with( - CBEventType.LLM, - payload={ - "thing": "greeting", - "deployment_id": registered_prompt.deployment_id, - "model_version_id": registered_prompt.model_version_id, - }, - ) - callback_manager.on_event_end.assert_called_once_with( - CBEventType.LLM, - payload={ - "response": "Hello, world!", - "formatted_prompt": "What's you're favorite greeting?", - }, - event_id=mock.ANY, - )
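
The recurring pattern in this changeset is converting plain classes (PromptHelper, the LLM predictors, node parsers, metadata extractors, and text splitters) into pydantic BaseModel subclasses: public configuration becomes Field declarations with defaults and descriptions, non-serializable runtime objects move into PrivateAttr, and __init__ keeps its old signature but validates and derives values before delegating to super().__init__(). The sketch below is a minimal illustration of that pattern, not code from this PR; it assumes pydantic v1 (which the class Config blocks in the diff imply), and the WordSplitter class and default_tokenizer helper are hypothetical stand-ins.

from typing import Callable, List, Optional

from pydantic import BaseModel, Field, PrivateAttr


def default_tokenizer(text: str) -> List[str]:
    # hypothetical stand-in for globals_helper.tokenizer
    return text.split()


class WordSplitter(BaseModel):
    """Hypothetical splitter illustrating the Field/PrivateAttr pattern."""

    class Config:
        # mirrors the Config added to NodeParser/TextSplitter in the diff
        arbitrary_types_allowed = True

    chunk_size: int = Field(default=64, description="Word chunk size for each chunk.")
    chunk_overlap: int = Field(default=8, description="Word overlap between chunks.")
    separator: str = Field(default=" ", description="Separator for splitting into words.")

    # runtime-only state kept out of the model schema, like _tokenizer in PromptHelper
    _tokenizer: Callable[[str], List[str]] = PrivateAttr()

    def __init__(
        self,
        chunk_size: int = 64,
        chunk_overlap: int = 8,
        separator: str = " ",
        tokenizer: Optional[Callable[[str], List[str]]] = None,
    ) -> None:
        # validate and derive first, then hand the public fields to pydantic,
        # as PromptHelper and the splitters do after this change
        if chunk_overlap > chunk_size:
            raise ValueError("chunk_overlap must be smaller than chunk_size")
        self._tokenizer = tokenizer or default_tokenizer
        super().__init__(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separator=separator,
        )

    def split_text(self, text: str) -> List[str]:
        words = self._tokenizer(text)
        step = max(self.chunk_size - self.chunk_overlap, 1)
        return [
            self.separator.join(words[i : i + self.chunk_size])
            for i in range(0, len(words), step)
        ]


splitter = WordSplitter(chunk_size=4, chunk_overlap=1)
print(splitter.split_text("one two three four five six seven"))
print(splitter.dict())  # public configuration is now introspectable/serializable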
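
On the consumer side, the main user-visible effects are that SimpleNodeParser.from_defaults now accepts a pre-built text_splitter, and that configuration previously held in private attributes (_chunk_size, _separator, _text_splitter, ...) is exposed as public pydantic fields. A hedged usage sketch, with import paths assumed from the file layout shown in this diff:

from llama_index.node_parser import SimpleNodeParser
from llama_index.text_splitter import SentenceSplitter

# build a splitter explicitly and hand it to the parser via the
# text_splitter argument added to from_defaults in this changeset
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=32)
node_parser = SimpleNodeParser.from_defaults(text_splitter=text_splitter)

# formerly private settings are now regular pydantic attributes
print(node_parser.text_splitter.chunk_size)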