From f9ab57c594ffb5ee524323abd354aedbbd3f2ced Mon Sep 17 00:00:00 2001
From: Shengsheng Huang <shannie.huang@gmail.com>
Date: Tue, 9 Apr 2024 02:18:20 +0800
Subject: [PATCH] [community] add support for more data types to the `ipex-llm`
 llm integration (#12635)

---
 docs/docs/examples/llm/ipex_llm.ipynb         |  4 +-
 .../llms/llama-index-llms-ipex-llm/README.md  | 21 ++++++-
 .../llama-index-llms-ipex-llm/examples/BUILD  |  1 +
 .../examples/README.md                        | 25 ++++++++
 .../examples/more_data_type.py                | 58 +++++++++++++++++++
 .../llama_index/llms/ipex_llm/base.py         | 51 ++++++++++++----
 .../llama-index-llms-ipex-llm/pyproject.toml  |  2 +-
 7 files changed, 148 insertions(+), 14 deletions(-)
 create mode 100644 llama-index-integrations/llms/llama-index-llms-ipex-llm/examples/BUILD
 create mode 100644 llama-index-integrations/llms/llama-index-llms-ipex-llm/examples/README.md
 create mode 100644 llama-index-integrations/llms/llama-index-llms-ipex-llm/examples/more_data_type.py

diff --git a/docs/docs/examples/llm/ipex_llm.ipynb b/docs/docs/examples/llm/ipex_llm.ipynb
index 172099cceb..18b72b1752 100644
--- a/docs/docs/examples/llm/ipex_llm.ipynb
+++ b/docs/docs/examples/llm/ipex_llm.ipynb
@@ -8,7 +8,9 @@
     "\n",
     "> [IPEX-LLM](https://github.com/intel-analytics/ipex-llm/) is a PyTorch library for running LLM on Intel CPU and GPU (e.g., local PC with iGPU, discrete GPU such as Arc, Flex and Max) with very low latency.\n",
     "\n",
-    "This example goes over how to use LlamaIndex to interact with [`ipex-llm`](https://github.com/intel-analytics/ipex-llm/) for text generation and chat on CPU."
+    "This example goes over how to use LlamaIndex to interact with [`ipex-llm`](https://github.com/intel-analytics/ipex-llm/) for text generation and chat on CPU. \n",
+    "\n",
+    "For more examples and usage, refer to [Examples](https://github.com/run-llama/llama_index/tree/main/llama-index-integrations/llms/llama-index-llms-ipex-llm/examples)."
    ]
   },
   {
diff --git a/llama-index-integrations/llms/llama-index-llms-ipex-llm/README.md b/llama-index-integrations/llms/llama-index-llms-ipex-llm/README.md
index d0698ba4cd..e982c22f40 100644
--- a/llama-index-integrations/llms/llama-index-llms-ipex-llm/README.md
+++ b/llama-index-integrations/llms/llama-index-llms-ipex-llm/README.md
@@ -1,3 +1,22 @@
 # LlamaIndex Llms Integration: IPEX-LLM
 
-[IPEX-LLM](https://github.com/intel-analytics/ipex-llm) is a PyTorch library for running LLM on Intel CPU and GPU (e.g., local PC with iGPU, discrete GPU such as Arc, Flex and Max) with very low latency. This module allows loading LLMs with ipex-llm optimizations.
+[IPEX-LLM](https://github.com/intel-analytics/ipex-llm) is a PyTorch library for running LLM on Intel CPU and GPU (e.g., local PC with iGPU, discrete GPU such as Arc, Flex and Max) with very low latency. This module enables the use of LLMs optimized with `ipex-llm` in LlamaIndex pipelines.
+
+## Installation
+
+### On CPU
+
+```bash
+pip install llama-index-llms-ipex-llm
+```
+
+## Usage
+
+```python
+from llama_index.llms.ipex_llm import IpexLLM
+```
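+
+A minimal sketch of loading a model on CPU (the model and tokenizer paths are placeholders; the constructor arguments mirror those used in this integration's examples):
+
+```python
+from llama_index.llms.ipex_llm import IpexLLM
+
+# Load a Hugging Face checkpoint with ipex-llm optimizations (int4 by default).
+llm = IpexLLM(
+    model_name="<huggingface_repo_id_or_local_path>",
+    tokenizer_name="<huggingface_repo_id_or_local_path>",
+    context_window=512,
+    max_new_tokens=64,
+)
+
+# Stream a completion.
+for chunk in llm.stream_complete("Explain what is AI?"):
+    print(chunk.delta, end="", flush=True)
+```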
+
+## Examples
+
+- [Notebook Example](https://docs.llamaindex.ai/en/stable/examples/llm/ipex_llm/)
+- [More Examples](https://github.com/run-llama/llama_index/tree/main/llama-index-integrations/llms/llama-index-llms-ipex-llm/examples)
diff --git a/llama-index-integrations/llms/llama-index-llms-ipex-llm/examples/BUILD b/llama-index-integrations/llms/llama-index-llms-ipex-llm/examples/BUILD
new file mode 100644
index 0000000000..db46e8d6c9
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-ipex-llm/examples/BUILD
@@ -0,0 +1 @@
+python_sources()
diff --git a/llama-index-integrations/llms/llama-index-llms-ipex-llm/examples/README.md b/llama-index-integrations/llms/llama-index-llms-ipex-llm/examples/README.md
new file mode 100644
index 0000000000..674ebbabac
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-ipex-llm/examples/README.md
@@ -0,0 +1,25 @@
+# IpexLLM Examples
+
+This folder contains examples showcasing how to use LlamaIndex with the `ipex-llm` LLM integration, `llama_index.llms.ipex_llm.IpexLLM`.
+
+## Installation
+
+### On CPU
+
+Install `llama-index-llms-ipex-llm`. This will also install `ipex-llm` and its dependencies.
+
+```bash
+pip install llama-index-llms-ipex-llm
+```
+
+## List of Examples
+
+### More Data Types Example
+
+By default, `IpexLLM` loads the model in the int4 format. To load a model in a different low-bit format such as `sym_int5` or `sym_int8`, use the `load_in_low_bit` option of `IpexLLM`.
+
+The example [more_data_type.py](./more_data_type.py) shows how to use the `load_in_low_bit` option. Run the example as follows:
+
+```bash
+python more_data_type.py -m <path_to_model> -t <path_to_tokenizer> -l <low_bit_format>
+```
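+
+Equivalently, the `load_in_low_bit` option can be passed directly when constructing `IpexLLM` in Python. The following is a minimal sketch; the model and tokenizer paths are placeholders, and the remaining arguments mirror the example script:
+
+```python
+from llama_index.llms.ipex_llm import IpexLLM
+
+# Load the checkpoint in sym_int8 instead of the default int4 format.
+llm = IpexLLM(
+    model_name="<path_to_model>",
+    tokenizer_name="<path_to_tokenizer>",
+    context_window=512,
+    max_new_tokens=64,
+    load_in_low_bit="sym_int8",
+)
+```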
diff --git a/llama-index-integrations/llms/llama-index-llms-ipex-llm/examples/more_data_type.py b/llama-index-integrations/llms/llama-index-llms-ipex-llm/examples/more_data_type.py
new file mode 100644
index 0000000000..45ffdf2b57
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-ipex-llm/examples/more_data_type.py
@@ -0,0 +1,58 @@
+import argparse
+from llama_index.llms.ipex_llm import IpexLLM
+
+
+# Transform a prompt string into the Llama-2-specific prompt format
+def completion_to_prompt(completion):
+    return f"<s>[INST] <<SYS>>\n    \n<</SYS>>\n\n{completion} [/INST]"
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="More Data Types Example")
+    parser.add_argument(
+        "--model-name",
+        "-m",
+        type=str,
+        default="meta-llama/Llama-2-7b-hf",
+        help="The huggingface repo id for the large language model to be downloaded"
+        ", or the path to the huggingface checkpoint folder",
+    )
+    parser.add_argument(
+        "--tokenizer-name",
+        "-t",
+        type=str,
+        default="meta-llama/Llama-2-7b-hf",
+        help="The huggingface repo id or the path to the checkpoint containing the tokenizer"
+        "usually it is the same as the model_name",
+    )
+    parser.add_argument(
+        "--low-bit",
+        "-l",
+        type=str,
+        default="asym_int4",
+        choices=["sym_int4", "asym_int4", "sym_int5", "asym_int5", "sym_int8"],
+        help="The quantization type the model will convert to.",
+    )
+
+    args = parser.parse_args()
+    model_name = args.model_name
+    tokenizer_name = args.tokenizer_name
+    low_bit = args.low_bit
+
+    # Load the model using the specified low-bit format
+    llm = IpexLLM(
+        model_name=model_name,
+        tokenizer_name=tokenizer_name,
+        context_window=512,
+        max_new_tokens=64,
+        load_in_low_bit=low_bit,
+        completion_to_prompt=completion_to_prompt,
+        generate_kwargs={"temperature": 0.7, "do_sample": False},
+    )
+
+    print(
+        "\n----------------------- Text Stream Completion ---------------------------"
+    )
+    response_iter = llm.stream_complete("Explain what is AI?")
+    for response in response_iter:
+        print(response.delta, end="", flush=True)
diff --git a/llama-index-integrations/llms/llama-index-llms-ipex-llm/llama_index/llms/ipex_llm/base.py b/llama-index-integrations/llms/llama-index-llms-ipex-llm/llama_index/llms/ipex_llm/base.py
index 8806773492..7101385c4a 100644
--- a/llama-index-integrations/llms/llama-index-llms-ipex-llm/llama_index/llms/ipex_llm/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-ipex-llm/llama_index/llms/ipex_llm/base.py
@@ -3,7 +3,6 @@ from threading import Thread
 from typing import Any, Callable, List, Optional, Sequence
 
 import torch
-
 from llama_index.core.base.llms.types import (
     ChatMessage,
     ChatResponse,
@@ -59,6 +58,20 @@ class IpexLLM(CustomLLM):
             "Unused if `model` is passed in directly."
         ),
     )
+    load_in_4bit: bool = Field(
+        default=True,
+        description=(
+            "Whether to load the model in 4-bit precision. "
+            "Unused if `load_in_low_bit` is not None."
+        ),
+    )
+    load_in_low_bit: Optional[str] = Field(
+        default=None,
+        description=(
+            "Which low-bit precision to use when loading the model. "
+            "Example values: 'sym_int4', 'asym_int4', 'fp4', 'nf4', 'fp8', etc. "
+            "Overrides `load_in_4bit` if specified."
+        ),
+    )
     context_window: int = Field(
         default=DEFAULT_CONTEXT_WINDOW,
         description="The maximum number of tokens available for input.",
@@ -124,6 +137,8 @@ class IpexLLM(CustomLLM):
         max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
         tokenizer_name: str = DEFAULT_HUGGINGFACE_MODEL,
         model_name: str = DEFAULT_HUGGINGFACE_MODEL,
+        load_in_4bit: Optional[bool] = True,
+        load_in_low_bit: Optional[str] = None,
         model: Optional[Any] = None,
         tokenizer: Optional[Any] = None,
         device_map: Optional[str] = "auto",
@@ -176,19 +191,33 @@ class IpexLLM(CustomLLM):
             self._model = model
         else:
             try:
-                self._model = AutoModelForCausalLM.from_pretrained(
-                    model_name,
-                    load_in_4bit=True,
-                    use_cache=True,
-                    trust_remote_code=True,
-                    **model_kwargs,
-                )
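+                # `load_in_low_bit` takes precedence over the boolean `load_in_4bit` flag.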
+                if load_in_low_bit:
+                    self._model = AutoModelForCausalLM.from_pretrained(
+                        model_name,
+                        load_in_low_bit=load_in_low_bit,
+                        use_cache=True,
+                        trust_remote_code=True,
+                        **model_kwargs,
+                    )
+                else:
+                    self._model = AutoModelForCausalLM.from_pretrained(
+                        model_name,
+                        load_in_4bit=load_in_4bit,
+                        use_cache=True,
+                        trust_remote_code=True,
+                        **model_kwargs,
+                    )
             except Exception:
                 from ipex_llm.transformers import AutoModel
 
-                self._model = AutoModel.from_pretrained(
-                    model_name, load_in_4bit=True, **model_kwargs
-                )
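+                # Fall back to AutoModel for checkpoints that AutoModelForCausalLM
+                # cannot load; apply the same low-bit handling.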
+                if load_in_low_bit:
+                    self._model = AutoModel.from_pretrained(
+                        model_name, load_in_low_bit=load_in_low_bit, **model_kwargs
+                    )
+                else:
+                    self._model = AutoModel.from_pretrained(
+                        model_name, load_in_4bit=load_in_4bit, **model_kwargs
+                    )
 
         if "xpu" in device_map:
             self._model = self._model.to(device_map)
diff --git a/llama-index-integrations/llms/llama-index-llms-ipex-llm/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-ipex-llm/pyproject.toml
index 99a1ba65f0..1b7dfb4a51 100644
--- a/llama-index-integrations/llms/llama-index-llms-ipex-llm/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-ipex-llm/pyproject.toml
@@ -30,7 +30,7 @@ license = "MIT"
 name = "llama-index-llms-ipex-llm"
 packages = [{include = "llama_index/"}]
 readme = "README.md"
-version = "0.1.0"
+version = "0.1.1"
 
 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
-- 
GitLab