diff --git a/llama-index-integrations/llms/llama-index-llms-mlx/.gitignore b/llama-index-integrations/llms/llama-index-llms-mlx/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..990c18de229088f55c6c514fd0f2d49981d1b0e7 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-mlx/.gitignore @@ -0,0 +1,153 @@ +llama_index/_static +.DS_Store +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +bin/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +etc/ +include/ +lib/ +lib64/ +parts/ +sdist/ +share/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +.ruff_cache + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints +notebooks/ + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +pyvenv.cfg + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# Jetbrains +.idea +modules/ +*.swp + +# VsCode +.vscode + +# pipenv +Pipfile +Pipfile.lock + +# pyright +pyrightconfig.json diff --git a/llama-index-integrations/llms/llama-index-llms-mlx/BUILD b/llama-index-integrations/llms/llama-index-llms-mlx/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..0896ca890d8bffd60a44fa824f8d57fecd73ee53 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-mlx/BUILD @@ -0,0 +1,3 @@ +poetry_requirements( + name="poetry", +) diff --git a/llama-index-integrations/llms/llama-index-llms-mlx/Makefile b/llama-index-integrations/llms/llama-index-llms-mlx/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..b9eab05aa370629a4a3de75df3ff64cd53887b68 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-mlx/Makefile @@ -0,0 +1,17 @@ +GIT_ROOT ?= $(shell git rev-parse --show-toplevel) + +help: ## Show all Makefile targets. + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}' + +format: ## Run code autoformatters (black). 
+	pre-commit install
+	git ls-files | xargs pre-commit run black --files
+
+lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy
+	pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files
+
+test: ## Run tests via pytest.
+	pytest tests
+
+watch-docs: ## Build and watch documentation.
+	sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
diff --git a/llama-index-integrations/llms/llama-index-llms-mlx/README.md b/llama-index-integrations/llms/llama-index-llms-mlx/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1ac63be82742ef56ca149654c61e2b0b6d967d23
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-mlx/README.md
@@ -0,0 +1,34 @@
+# LlamaIndex Llms Integration: MLX
+
+## Overview
+
+---
+
+Integrate with local LLMs from the `mlx-lm` library, which runs models with Apple's MLX framework on Apple silicon.
+
+## Installation
+
+---
+
+```bash
+pip install llama-index-llms-mlx
+```
+
+## Example
+
+---
+
+```python
+from llama_index.llms.mlx import MLXLLM
+
+llm = MLXLLM(
+    model_name="microsoft/phi-2",
+    context_window=3900,
+    max_new_tokens=256,
+    generate_kwargs={"temp": 0.7, "top_p": 0.95},
+)
+
+response = llm.complete("What is the meaning of life?")
+print(str(response))
+```
diff --git a/llama-index-integrations/llms/llama-index-llms-mlx/llama_index/llms/mlx/BUILD b/llama-index-integrations/llms/llama-index-llms-mlx/llama_index/llms/mlx/BUILD
new file mode 100644
index 0000000000000000000000000000000000000000..db46e8d6c978c67e301dd6c47bee08c1b3fd141c
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-mlx/llama_index/llms/mlx/BUILD
@@ -0,0 +1 @@
+python_sources()
diff --git a/llama-index-integrations/llms/llama-index-llms-mlx/llama_index/llms/mlx/__init__.py b/llama-index-integrations/llms/llama-index-llms-mlx/llama_index/llms/mlx/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a2ebfcafca5cf2f3ecbdd1da59a6ec81452a035
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-mlx/llama_index/llms/mlx/__init__.py
@@ -0,0 +1,4 @@
+from llama_index.llms.mlx.base import MLXLLM
+
+
+__all__ = ["MLXLLM"]
diff --git a/llama-index-integrations/llms/llama-index-llms-mlx/llama_index/llms/mlx/base.py b/llama-index-integrations/llms/llama-index-llms-mlx/llama_index/llms/mlx/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8f4c0e4181a0089df35ee722c78666969b5288d
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-mlx/llama_index/llms/mlx/base.py
@@ -0,0 +1,292 @@
+from .utils import gen_stream
+
+from llama_index.core.base.llms.types import (
+    ChatMessage,
+    ChatResponse,
+    ChatResponseGen,
+    CompletionResponse,
+    CompletionResponseGen,
+    LLMMetadata,
+)
+from llama_index.core.bridge.pydantic import Field, PrivateAttr
+from llama_index.core.callbacks import CallbackManager
+from llama_index.core.constants import (
+    DEFAULT_CONTEXT_WINDOW,
+    DEFAULT_NUM_OUTPUTS,
+)
+from llama_index.core.llms.callbacks import (
+    llm_chat_callback,
+    llm_completion_callback,
+)
+from llama_index.core.llms.custom import CustomLLM
+from llama_index.core.base.llms.generic_utils import (
+    completion_response_to_chat_response,
+    stream_completion_response_to_chat_response,
+)
+from llama_index.core.base.llms.generic_utils import (
+    messages_to_prompt as generic_messages_to_prompt,
+)
+from llama_index.core.prompts.base import PromptTemplate
+from llama_index.core.types import BaseOutputParser, PydanticProgramMode
+
+import logging
+
+from typing import Any, Callable, Optional, Union, Sequence
+
+from mlx_lm import load, generate
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_MLX_MODEL = "microsoft/phi-2"
+
+
+class MLXLLM(CustomLLM):
+    r"""MLX LLM.
+
+    Examples:
+        Thanks to the HuggingFace team for the example code.
+        `pip install llama-index-llms-mlx`
+
+        ```python
+        from llama_index.llms.mlx import MLXLLM
+
+        def messages_to_prompt(messages):
+            prompt = ""
+            for message in messages:
+                if message.role == 'system':
+                    prompt += f"<|system|>\n{message.content}</s>\n"
+                elif message.role == 'user':
+                    prompt += f"<|user|>\n{message.content}</s>\n"
+                elif message.role == 'assistant':
+                    prompt += f"<|assistant|>\n{message.content}</s>\n"
+
+            # ensure we start with a system prompt, insert blank if needed
+            if not prompt.startswith("<|system|>\n"):
+                prompt = "<|system|>\n</s>\n" + prompt
+
+            # add final assistant prompt
+            prompt = prompt + "<|assistant|>\n"
+
+            return prompt
+
+        def completion_to_prompt(completion):
+            return f"<|system|>\n</s>\n<|user|>\n{completion}</s>\n<|assistant|>\n"
+
+        llm = MLXLLM(
+            model_name="microsoft/phi-2",
+            context_window=3900,
+            max_new_tokens=256,
+            generate_kwargs={"temp": 0.7, "top_p": 0.95},
+            messages_to_prompt=messages_to_prompt,
+            completion_to_prompt=completion_to_prompt,
+        )
+
+        response = llm.complete("What is the meaning of life?")
+        print(str(response))
+        ```
+    """
+
+    model_name: str = Field(
+        default=DEFAULT_MLX_MODEL,
+        description=(
+            "The model name to use from HuggingFace. "
+            "Unused if `model` is passed in directly."
+        ),
+    )
+    context_window: int = Field(
+        default=DEFAULT_CONTEXT_WINDOW,
+        description="The maximum number of tokens available for input.",
+        gt=0,
+    )
+    max_new_tokens: int = Field(
+        default=DEFAULT_NUM_OUTPUTS,
+        description="The maximum number of tokens to generate.",
+        gt=0,
+    )
+    system_prompt: str = Field(
+        default="",
+        description=(
+            "The system prompt, containing any extra instructions or context. "
+            "The model card on HuggingFace should specify if this is needed."
+        ),
+    )
+    query_wrapper_prompt: PromptTemplate = Field(
+        default=PromptTemplate("{query_str}"),
+        description=(
+            "The query wrapper prompt, containing the query placeholder. "
+            "The model card on HuggingFace should specify if this is needed. "
+            "Should contain a `{query_str}` placeholder."
+        ),
+    )
+
+    tokenizer_outputs_to_remove: list = Field(
+        default_factory=list,
+        description=(
+            "The outputs to remove from the tokenizer. "
+            "Sometimes huggingface tokenizers return extra inputs that cause errors."
+        ),
+    )
+    tokenizer_kwargs: dict = Field(
+        default_factory=dict, description="The kwargs to pass to the tokenizer."
+    )
+    model_kwargs: dict = Field(
+        default_factory=dict,
+        description="The kwargs to pass to the model during initialization.",
+    )
+    generate_kwargs: dict = Field(
+        default_factory=dict,
+        description="The kwargs to pass to the model during generation.",
+    )
+
+    _model: Any = PrivateAttr()
+    _tokenizer: Any = PrivateAttr()
+
+    def __init__(
+        self,
+        context_window: int = DEFAULT_CONTEXT_WINDOW,
+        max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
+        query_wrapper_prompt: Union[str, PromptTemplate] = "{query_str}",
+        model_name: str = DEFAULT_MLX_MODEL,
+        model: Optional[Any] = None,
+        tokenizer: Optional[Any] = None,
+        tokenizer_kwargs: Optional[dict] = None,
+        tokenizer_outputs_to_remove: Optional[list] = None,
+        model_kwargs: Optional[dict] = None,
+        generate_kwargs: Optional[dict] = None,
+        callback_manager: Optional[CallbackManager] = None,
+        system_prompt: str = "",
+        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
+        completion_to_prompt: Optional[Callable[[str], str]] = None,
+        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
+        output_parser: Optional[BaseOutputParser] = None,
+    ) -> None:
+        """Initialize params."""
+        model_kwargs = model_kwargs or {}
+        if model is None:
+            self._model, self._tokenizer = load(model_name, **model_kwargs)
+        else:
+            self._model = model
+            self._tokenizer = tokenizer
+
+        # default the tokenizer's max_length to the context window
+        tokenizer_kwargs = tokenizer_kwargs or {}
+        if "max_length" not in tokenizer_kwargs:
+            tokenizer_kwargs["max_length"] = context_window
+
+        if isinstance(query_wrapper_prompt, str):
+            query_wrapper_prompt = PromptTemplate(query_wrapper_prompt)
+
+        messages_to_prompt = messages_to_prompt or self._tokenizer_messages_to_prompt
+
+        super().__init__(
+            context_window=context_window,
+            max_new_tokens=max_new_tokens,
+            query_wrapper_prompt=query_wrapper_prompt,
+            model_name=model_name,
+            tokenizer_kwargs=tokenizer_kwargs,
+            tokenizer_outputs_to_remove=tokenizer_outputs_to_remove or [],
+            model_kwargs=model_kwargs,
+            generate_kwargs=generate_kwargs or {},
+            callback_manager=callback_manager,
+            system_prompt=system_prompt,
+            messages_to_prompt=messages_to_prompt,
+            completion_to_prompt=completion_to_prompt,
+            pydantic_program_mode=pydantic_program_mode,
+            output_parser=output_parser,
+        )
+
+    @classmethod
+    def class_name(cls) -> str:
+        return "MLXLLM"
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        """LLM metadata."""
+        return LLMMetadata(
+            context_window=self.context_window,
+            num_output=self.max_new_tokens,
+            model_name=self.model_name,
+        )
+
+    def _tokenizer_messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
+        """Use the tokenizer to convert messages to prompt.
Fallback to generic.""" + if hasattr(self._tokenizer, "apply_chat_template"): + messages_dict = [ + {"role": message.role.value, "content": message.content} + for message in messages + ] + tokens = self._tokenizer.apply_chat_template(messages_dict) + return self._tokenizer.decode(tokens) + + return generic_messages_to_prompt(messages) + + @llm_completion_callback() + def complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponse: + """Completion endpoint.""" + full_prompt = prompt + if not formatted: + if self.query_wrapper_prompt: + full_prompt = self.query_wrapper_prompt.format(query_str=prompt) + if self.system_prompt: + full_prompt = f"{self.system_prompt} {full_prompt}" + + completion = generate( + self._model, + self._tokenizer, + full_prompt, + max_tokens=self.max_new_tokens, + **self.generate_kwargs, + ) + tokens = self._tokenizer.encode(completion, return_tensors="pt") + return CompletionResponse(text=completion, raw={"model_output": tokens}) + + @llm_completion_callback() + def stream_complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponseGen: + """Streaming completion endpoint.""" + full_prompt = prompt + if not formatted: + if self.query_wrapper_prompt: + full_prompt = self.query_wrapper_prompt.format(query_str=prompt) + if self.system_prompt: + full_prompt = f"{self.system_prompt} {full_prompt}" + + def gen() -> CompletionResponseGen: + text = "" + for x in gen_stream( + self._model, + self._tokenizer, + full_prompt, + max_tokens=self.max_new_tokens, + **self.generate_kwargs, + ): + text += x + yield CompletionResponse(text=text, delta=x) + + return gen() + + @llm_chat_callback() + def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: + prompt = self.messages_to_prompt(messages) + completion_response = self.complete(prompt, formatted=True, **kwargs) + return completion_response_to_chat_response(completion_response) + + @llm_chat_callback() + def stream_chat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponseGen: + prompt = self.messages_to_prompt(messages) + completion_response = self.stream_complete(prompt, formatted=True, **kwargs) + return stream_completion_response_to_chat_response(completion_response) diff --git a/llama-index-integrations/llms/llama-index-llms-mlx/llama_index/llms/mlx/tokenizer_utils.py b/llama-index-integrations/llms/llama-index-llms-mlx/llama_index/llms/mlx/tokenizer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..307decedc8b1790eb77cd3fb431122287108369c --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-mlx/llama_index/llms/mlx/tokenizer_utils.py @@ -0,0 +1,329 @@ +import json +from functools import partial + +from transformers import AutoTokenizer + +REPLACEMENT_CHAR = "\ufffd" + + +def _remove_space(x): + if x and x[0] == " ": + return x[1:] + return x + + +class StreamingDetokenizer: + """The streaming detokenizer interface so that we can detokenize one token at a time. + + Example usage is as follows: + + detokenizer = ... + + # Reset the tokenizer state + detokenizer.reset() + + for token in generate(...): + detokenizer.add_token(token.item()) + + # Contains the whole text so far. Some tokens may not be included + # since it contains whole words usually. 
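+            # (reading .text repeatedly is safe; it does not consume .last_segment)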
+ detokenizer.text + + # Contains the printable segment (usually a word) since the last + # time it was accessed + detokenizer.last_segment + + # Contains all the tokens added so far + detokenizer.tokens + + # Make sure that we detokenize any remaining tokens + detokenizer.finalize() + + # Now detokenizer.text should match tokenizer.decode(detokenizer.tokens) + """ + + __slots__ = ("text", "tokens", "offset") + + def reset(self): + raise NotImplementedError + + def add_token(self, token): + raise NotImplementedError + + def finalize(self): + raise NotImplementedError + + @property + def last_segment(self): + """Return the last segment of readable text since last time this property was accessed.""" + text = self.text + if text and text[-1] != REPLACEMENT_CHAR: + segment = text[self.offset :] + self.offset = len(text) + return segment + return "" + + +class NaiveStreamingDetokenizer(StreamingDetokenizer): + """NaiveStreamingDetokenizer relies on the underlying tokenizer + implementation and should work with every tokenizer. + + Its complexity is O(T^2) where T is the longest line since it will + repeatedly detokenize the same tokens until a new line is generated. + """ + + def __init__(self, tokenizer) -> None: + self._tokenizer = tokenizer + self._tokenizer.decode([0]) + self.reset() + + def reset(self): + self.offset = 0 + self._tokens = [] + self._text = "" + self._current_tokens = [] + self._current_text = "" + + def add_token(self, token): + self._current_tokens.append(token) + + def finalize(self): + self._tokens.extend(self._current_tokens) + self._text += self._tokenizer.decode(self._current_tokens) + self._current_tokens = [] + self._current_text = "" + + @property + def text(self): + if self._current_tokens: + self._current_text = self._tokenizer.decode(self._current_tokens) + if self._current_text and self._current_text[-1] == "\n": + self._tokens.extend(self._current_tokens) + self._text += self._current_text + self._current_tokens.clear() + self._current_text = "" + return self._text + self._current_text + + @property + def tokens(self): + return self._tokens + + +class SPMStreamingDetokenizer(StreamingDetokenizer): + """A streaming detokenizer for SPM models. + + It adds tokens to the text if the next token starts with the special SPM + underscore which results in linear complexity. + """ + + def __init__(self, tokenizer, trim_space=True) -> None: + self.trim_space = trim_space + + # Extract the tokens in a list from id to text + self.tokenmap = [None] * len(tokenizer.vocab) + for value, tokenid in tokenizer.vocab.items(): + self.tokenmap[tokenid] = value + + # Replace bytes with their value + for i in range(len(self.tokenmap)): + if self.tokenmap[i].startswith("<0x"): + self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16)) + + self.reset() + + def reset(self): + self.offset = 0 + self._unflushed = "" + self.text = "" + self.tokens = [] + + def add_token(self, token): + v = self.tokenmap[token] + if v[0] == "\u2581": + if self.text or not self.trim_space: + self.text += self._unflushed.replace("\u2581", " ") + else: + self.text = _remove_space(self._unflushed.replace("\u2581", " ")) + self._unflushed = v + else: + self._unflushed += v + + def finalize(self): + if self.text or not self.trim_space: + self.text += self._unflushed.replace("\u2581", " ") + else: + self.text = _remove_space(self._unflushed.replace("\u2581", " ")) + self._unflushed = "" + + +class BPEStreamingDetokenizer(StreamingDetokenizer): + """A streaming detokenizer for OpenAI style BPE models. 
+ + It adds tokens to the text if the next token starts with a space similar to + the SPM detokenizer. + """ + + _byte_decoder = None + + def __init__(self, tokenizer, trim_space=False) -> None: + self.trim_space = trim_space + + # Extract the tokens in a list from id to text + self.tokenmap = [None] * len(tokenizer.vocab) + for value, tokenid in tokenizer.vocab.items(): + self.tokenmap[tokenid] = value + + self.reset() + + # Make the BPE byte decoder from + # https://github.com/openai/gpt-2/blob/master/src/encoder.py + self.make_byte_decoder() + + def reset(self): + self.offset = 0 + self._unflushed = "" + self.text = "" + self.tokens = [] + + def add_token(self, token): + v = self.tokenmap[token] + # if the token starts with space + if self._byte_decoder[v[0]] == 32: + current_text = bytearray( + self._byte_decoder[c] for c in self._unflushed + ).decode("utf-8") + if self.text or not self.trim_space: + self.text += current_text + else: + self.text += _remove_space(current_text) + self._unflushed = v + else: + self._unflushed += v + + def finalize(self): + current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode( + "utf-8" + ) + if self.text or not self.trim_space: + self.text += current_text + else: + self.text += _remove_space(current_text) + self._unflushed = "" + + @classmethod + def make_byte_decoder(cls): + """See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale.""" + if cls._byte_decoder is not None: + return + + char_to_bytes = {} + limits = [ + 0, + ord("!"), + ord("~") + 1, + ord("¡"), + ord("¬") + 1, + ord("®"), + ord("ÿ") + 1, + ] + n = 0 + for i, (start, stop) in enumerate(zip(limits, limits[1:])): + if i % 2 == 0: + for b in range(start, stop): + char_to_bytes[chr(2**8 + n)] = b + n += 1 + else: + for b in range(start, stop): + char_to_bytes[chr(b)] = b + cls._byte_decoder = char_to_bytes + + +class TokenizerWrapper: + """A wrapper that combines an HF tokenizer and a detokenizer. + + Accessing any attribute other than the ``detokenizer`` is forwarded to the + huggingface tokenizer. 
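+
+    Illustrative usage (the model id below is only an example):
+
+        wrapper = TokenizerWrapper(AutoTokenizer.from_pretrained("gpt2"))
+        ids = wrapper.encode("hello world")  # forwarded to the HF tokenizer
+        detok = wrapper.detokenizer  # the streaming detokenizer instance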
+ """ + + def __init__(self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer) -> None: + self._tokenizer = tokenizer + self._detokenizer = detokenizer_class(tokenizer) + + def __getattr__(self, attr) -> object: + if attr == "detokenizer": + return self._detokenizer + else: + return getattr(self._tokenizer, attr) + + +def _match(a, b): + if type(a) != type(b): + return False + if isinstance(a, dict): + return len(a) == len(b) and all(k in b and _match(a[k], b[k]) for k in a) + if isinstance(a, list): + return len(a) == len(b) and all(_match(ai, bi) for ai, bi in zip(a, b)) + + return a == b + + +def _is_spm_decoder(decoder): + _target_description = { + "type": "Sequence", + "decoders": [ + {"type": "Replace", "pattern": {"String": "▁"}, "content": " "}, + {"type": "ByteFallback"}, + {"type": "Fuse"}, + {"type": "Strip", "content": " ", "start": 1, "stop": 0}, + ], + } + return _match(_target_description, decoder) + + +def _is_spm_decoder_no_space(decoder): + _target_description = { + "type": "Sequence", + "decoders": [ + {"type": "Replace", "pattern": {"String": "▁"}, "content": " "}, + {"type": "ByteFallback"}, + {"type": "Fuse"}, + ], + } + return _match(_target_description, decoder) + + +def _is_bpe_decoder(decoder): + _target_description = { + "type": "ByteLevel", + "add_prefix_space": False, + "trim_offsets": False, + "use_regex": False, + } + + return _match(_target_description, decoder) + + +def load_tokenizer(model_path, tokenizer_config_extra={}): + """Load a huggingface tokenizer and try to infer the type of streaming + detokenizer to use. + + Note, to use a fast streaming tokenizer, pass a local file path rather than + a Hugging Face repo ID. + """ + detokenizer_class = NaiveStreamingDetokenizer + + tokenizer_file = model_path / "tokenizer.json" + if tokenizer_file.exists(): + tokenizer_content = json.load(tokenizer_file.open()) + if "decoder" in tokenizer_content: + if _is_spm_decoder(tokenizer_content["decoder"]): + detokenizer_class = SPMStreamingDetokenizer + elif _is_spm_decoder_no_space(tokenizer_content["decoder"]): + detokenizer_class = partial(SPMStreamingDetokenizer, trim_space=False) + elif _is_bpe_decoder(tokenizer_content["decoder"]): + detokenizer_class = BPEStreamingDetokenizer + + return TokenizerWrapper( + AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra), + detokenizer_class, + ) diff --git a/llama-index-integrations/llms/llama-index-llms-mlx/llama_index/llms/mlx/utils.py b/llama-index-integrations/llms/llama-index-llms-mlx/llama_index/llms/mlx/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ec22f2fd5775b50d2c251a499fe83065f662aee2 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-mlx/llama_index/llms/mlx/utils.py @@ -0,0 +1,268 @@ +from datetime import time +from typing import Dict, Optional, Tuple, Generator, Union, Callable, Any + +import mlx.core as mx +import mlx.nn as nn +from .tokenizer_utils import TokenizerWrapper +from transformers import PreTrainedTokenizer + + +def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: float): + """ + Apply repetition penalty to specific logits based on the given context. + + Paper: https://arxiv.org/abs/1909.05858 + + Args: + logits (mx.array): The logits produced by the language model. + generated_tokens (any): A list of N previous tokens. + penalty (float): The repetition penalty factor to be applied. + + Returns: + logits (mx.array): Logits with repetition penalty applied to generated tokens. 
+ """ + if len(generated_tokens) > 0: + indices = mx.array([list(generated_tokens)]) + selected_logits = logits[:, indices] + selected_logits = mx.where( + selected_logits < 0, selected_logits * penalty, selected_logits / penalty + ) + logits[:, indices] = selected_logits + return logits + + +def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array: + """ + Apply top-p (nucleus) sampling to logits. + + Args: + logits: The logits from the model's output. + top_p: The cumulative probability threshold for top-p filtering. + temperature: Temperature parameter for softmax distribution reshaping. + + Returns: + token selected based on the top-p criterion. + """ + # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460 + probs = mx.softmax(logits / temperature, axis=-1) + + # sort probs in ascending order + sorted_indices = mx.argsort(probs, axis=-1) + sorted_probs = probs[..., sorted_indices.squeeze(0)] + + cumulative_probs = mx.cumsum(sorted_probs, axis=-1) + + # select tokens with cumulative probs below threshold + top_probs = mx.where( + cumulative_probs > 1 - top_p, + sorted_probs, + mx.zeros_like(sorted_probs), + ) + + sorted_token = mx.random.categorical(mx.log(top_probs)) + + return sorted_indices.squeeze(0)[sorted_token] + + +def generate_step( + prompt: mx.array, + model: nn.Module, + temp: float = 0.0, + repetition_penalty: Optional[float] = None, + repetition_context_size: Optional[int] = 20, + top_p: float = 1.0, + logit_bias: Optional[Dict[int, float]] = None, +) -> Generator[Tuple[mx.array, mx.array], None, None]: + """ + A generator producing text based on the given prompt from the model. + + Args: + prompt (mx.array): The input prompt. + model (nn.Module): The model to use for generation. + temp (float): The temperature for sampling, if 0 the argmax is used. + repetition_penalty (float, optional): The penalty factor for repeating tokens. + repetition_context_size (int, optional): The number of tokens to consider for repetition penalty (default 20). + top_p (float, optional): Nulceus sampling, higher means model considers more less likely words + + Yields: + Generator[Tuple[mx.array, mx.array]]: A generator producing + one token and probability per call. 
+ """ + + def sample(logits: mx.array) -> Tuple[mx.array, float]: + logits = logits + logit_bias if logit_bias else logits + softmax_logits = mx.softmax(logits) + + if temp == 0: + token = mx.argmax(logits, axis=-1) + else: + if top_p > 0 and top_p < 1.0: + token = top_p_sampling(logits, top_p, temp) + else: + token = mx.random.categorical(logits * (1 / temp)) + + prob = softmax_logits[0, token] + return token, prob + + if repetition_penalty and ( + repetition_penalty < 0 or not isinstance(repetition_penalty, float) + ): + raise ValueError( + f"repetition_penalty must be a non-negative float, got {repetition_penalty}" + ) + + y = prompt + cache = None + + repetition_context = prompt.tolist() + + if repetition_context_size: + repetition_context = repetition_context[-repetition_context_size:] + + def _step(y): + nonlocal cache, repetition_context + logits, cache = model(y[None], cache=cache) + logits = logits[:, -1, :] + + if repetition_penalty: + logits = apply_repetition_penalty( + logits, repetition_context, repetition_penalty + ) + y, prob = sample(logits) + repetition_context.append(y.item()) + else: + y, prob = sample(logits) + + if repetition_context_size: + if len(repetition_context) > repetition_context_size: + repetition_context = repetition_context[-repetition_context_size:] + return y, prob + + y, p = _step(y) + + mx.async_eval(y) + while True: + next_y, next_p = _step(y) + mx.async_eval(next_y) + yield y.item(), p + y, p = next_y, next_p + + +def gen_full( + model: nn.Module, + tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], + prompt: str, + temp: float = 0.0, + max_tokens: int = 100, + formatter: Optional[Callable] = None, + repetition_penalty: Optional[float] = None, + repetition_context_size: Optional[int] = None, + top_p: float = 1.0, + logit_bias: Optional[Dict[int, float]] = None, +) -> str: + """ + Generate text from the model. + + Args: + model (nn.Module): The language model. + tokenizer (PreTrainedTokenizer): The tokenizer. + prompt (str): The string prompt. + temp (float): The temperature for sampling (default 0). + max_tokens (int): The maximum number of tokens (default 100). + verbose (bool): If ``True``, print tokens and timing information + (default ``False``). + formatter (Optional[Callable]): A function which takes a token and a + probability and displays it. + repetition_penalty (float, optional): The penalty factor for repeating tokens. + repetition_context_size (int, optional): The number of tokens to consider for repetition penalty. 
+ """ + if not isinstance(tokenizer, TokenizerWrapper): + tokenizer = TokenizerWrapper(tokenizer) + + prompt_tokens = mx.array(tokenizer.encode(prompt)) + detokenizer = tokenizer.detokenizer + + tic = time.perf_counter() + detokenizer.reset() + + for (token, prob), n in zip( + generate_step( + prompt_tokens, + model, + temp, + repetition_penalty, + repetition_context_size, + top_p, + logit_bias, + ), + range(max_tokens), + ): + if n == 0: + prompt_time = time.perf_counter() - tic + tic = time.perf_counter() + if token == tokenizer.eos_token_id: + break + detokenizer.add_token(token) + + detokenizer.finalize() + + return detokenizer.text + + +def gen_stream( + model: nn.Module, + tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], + prompt: str, + temp: float = 0.0, + max_tokens: int = 100, + formatter: Optional[Callable] = None, + repetition_penalty: Optional[float] = None, + repetition_context_size: Optional[int] = None, + top_p: float = 1.0, + logit_bias: Optional[Dict[int, float]] = None, +) -> str: + """ + Generate text from the model. + + Args: + model (nn.Module): The language model. + tokenizer (PreTrainedTokenizer): The tokenizer. + prompt (str): The string prompt. + temp (float): The temperature for sampling (default 0). + max_tokens (int): The maximum number of tokens (default 100). + verbose (bool): If ``True``, print tokens and timing information + (default ``False``). + formatter (Optional[Callable]): A function which takes a token and a + probability and displays it. + repetition_penalty (float, optional): The penalty factor for repeating tokens. + repetition_context_size (int, optional): The number of tokens to consider for repetition penalty. + """ + if not isinstance(tokenizer, TokenizerWrapper): + tokenizer = TokenizerWrapper(tokenizer) + + prompt_tokens = mx.array(tokenizer.encode(prompt)) + detokenizer = tokenizer.detokenizer + + detokenizer.reset() + + for (token, prob), n in zip( + generate_step( + prompt_tokens, + model, + temp, + repetition_penalty, + repetition_context_size, + top_p, + logit_bias, + ), + range(max_tokens), + ): + if token == tokenizer.eos_token_id: + break + detokenizer.add_token(token) + if formatter: + # We have to finalize so that the prob corresponds to the last segment + detokenizer.finalize() + formatter(detokenizer.last_segment, prob.item()) + else: + yield detokenizer.last_segment diff --git a/llama-index-integrations/llms/llama-index-llms-mlx/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-mlx/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..99631f994f7fb709bd491b39017bca4cf3863fdf --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-mlx/pyproject.toml @@ -0,0 +1,58 @@ +[build-system] +build-backend = "poetry.core.masonry.api" +requires = ["poetry-core"] + +[tool.codespell] +check-filenames = true +check-hidden = true +# Feel free to un-skip examples, and experimental, you will just need to +# work through many typos (--write-changes and --interactive will help) +skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb" + +[tool.llamahub] +contains_example = false +import_path = "llama_index.llms.mlx" + +[tool.llamahub.class_authors] +MLX = "dwight-foster" + +[tool.mypy] +disallow_untyped_defs = true +# Remove venv skip when integrated with pre-commit +exclude = ["_static", "build", "examples", "notebooks", "venv"] +ignore_missing_imports = true +python_version = "3.8" + +[tool.poetry] +authors = ["Dwight Foster <dwightfoster03@gmail.com>"] +description = "llama-index 
llms mlx integration" +exclude = ["**/BUILD"] +license = "MIT" +name = "llama-index-llms-mlx" +readme = "README.md" +version = "0.1.0" + +[tool.poetry.dependencies] +python = ">=3.10,<4.0" +llama-index-core = "^0.10.0" +mlx-lm = ">=0.11.0" +mlx = ">=0.11.0" + +[tool.poetry.group.dev.dependencies] +black = {extras = ["jupyter"], version = "<=23.9.1,>=23.7.0"} +codespell = {extras = ["toml"], version = ">=v2.2.6"} +ipython = "8.10.0" +jupyter = "^1.0.0" +mypy = "0.991" +pre-commit = "3.2.0" +pylint = "2.15.10" +pytest = "7.2.1" +pytest-mock = "3.11.1" +ruff = "0.0.292" +tree-sitter-languages = "^1.8.0" +types-Deprecated = ">=0.1.0" +types-PyYAML = "^6.0.12.12" +types-protobuf = "^4.24.0.4" +types-redis = "4.5.5.0" +types-requests = "2.28.11.8" +types-setuptools = "67.1.0.0"