diff --git a/LICENSE b/LICENSE deleted file mode 100644 index bbe189a3dae4e4931ce90eccd8e0497336b5790a..0000000000000000000000000000000000000000 --- a/LICENSE +++ /dev/null @@ -1,125 +0,0 @@ -LLAMA 2 COMMUNITY LICENSE AGREEMENT -Llama 2 Version Release Date: July 18, 2023 - -"Agreement" means the terms and conditions for use, reproduction, distribution and -modification of the Llama Materials set forth herein. - -"Documentation" means the specifications, manuals and documentation -accompanying Llama 2 distributed by Meta at ai.meta.com/resources/models-and- -libraries/llama-downloads/. - -"Licensee" or "you" means you, or your employer or any other person or entity (if -you are entering into this Agreement on such person or entity's behalf), of the age -required under applicable laws, rules or regulations to provide legal consent and that -has legal authority to bind your employer or such other person or entity if you are -entering in this Agreement on their behalf. - -"Llama 2" means the foundational large language models and software and -algorithms, including machine-learning model code, trained model weights, -inference-enabling code, training-enabling code, fine-tuning enabling code and other -elements of the foregoing distributed by Meta at ai.meta.com/resources/models-and- -libraries/llama-downloads/. - -"Llama Materials" means, collectively, Meta's proprietary Llama 2 and -Documentation (and any portion thereof) made available under this Agreement. - -"Meta" or "we" means Meta Platforms Ireland Limited (if you are located in or, if you -are an entity, your principal place of business is in the EEA or Switzerland) and Meta -Platforms, Inc. (if you are located outside of the EEA or Switzerland). - -By clicking "I Accept" below or by using or distributing any portion or element of the -Llama Materials, you agree to be bound by this Agreement. - -1. License Rights and Redistribution. - - a. Grant of Rights. You are granted a non-exclusive, worldwide, non- -transferable and royalty-free limited license under Meta's intellectual property or -other rights owned by Meta embodied in the Llama Materials to use, reproduce, -distribute, copy, create derivative works of, and make modifications to the Llama -Materials. - - b. Redistribution and Use. - - i. If you distribute or make the Llama Materials, or any derivative works -thereof, available to a third party, you shall provide a copy of this Agreement to such -third party. - ii. If you receive Llama Materials, or any derivative works thereof, from -a Licensee as part of an integrated end user product, then Section 2 of this -Agreement will not apply to you. - - iii. You must retain in all copies of the Llama Materials that you -distribute the following attribution notice within a "Notice" text file distributed as a -part of such copies: "Llama 2 is licensed under the LLAMA 2 Community License, -Copyright (c) Meta Platforms, Inc. All Rights Reserved." - - iv. Your use of the Llama Materials must comply with applicable laws -and regulations (including trade compliance laws and regulations) and adhere to the -Acceptable Use Policy for the Llama Materials (available at -https://ai.meta.com/llama/use-policy), which is hereby incorporated by reference into -this Agreement. - - v. You will not use the Llama Materials or any output or results of the -Llama Materials to improve any other large language model (excluding Llama 2 or -derivative works thereof). - -2. Additional Commercial Terms. If, on the Llama 2 version release date, the -monthly active users of the products or services made available by or for Licensee, -or Licensee's affiliates, is greater than 700 million monthly active users in the -preceding calendar month, you must request a license from Meta, which Meta may -grant to you in its sole discretion, and you are not authorized to exercise any of the -rights under this Agreement unless or until Meta otherwise expressly grants you -such rights. - -3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE -LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE -PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, -EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY -WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR -FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE -FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING -THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR -USE OF THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS. - -4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE -LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, -NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS -AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, -CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN -IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF -ANY OF THE FOREGOING. - -5. Intellectual Property. - - a. No trademark licenses are granted under this Agreement, and in -connection with the Llama Materials, neither Meta nor Licensee may use any name -or mark owned by or associated with the other or any of its affiliates, except as -required for reasonable and customary use in describing and redistributing the -Llama Materials. - - b. Subject to Meta's ownership of Llama Materials and derivatives made by or -for Meta, with respect to any derivative works and modifications of the Llama -Materials that are made by you, as between you and Meta, you are and will be the -owner of such derivative works and modifications. - - c. If you institute litigation or other proceedings against Meta or any entity -(including a cross-claim or counterclaim in a lawsuit) alleging that the Llama -Materials or Llama 2 outputs or results, or any portion of any of the foregoing, -constitutes infringement of intellectual property or other rights owned or licensable -by you, then any licenses granted to you under this Agreement shall terminate as of -the date such litigation or claim is filed or instituted. You will indemnify and hold -harmless Meta from and against any claim by any third party arising out of or related -to your use or distribution of the Llama Materials. - -6. Term and Termination. The term of this Agreement will commence upon your -acceptance of this Agreement or access to the Llama Materials and will continue in -full force and effect until terminated in accordance with the terms and conditions -herein. Meta may terminate this Agreement if you are in breach of any term or -condition of this Agreement. Upon termination of this Agreement, you shall delete -and cease use of the Llama Materials. Sections 3, 4 and 7 shall survive the -termination of this Agreement. - -7. Governing Law and Jurisdiction. This Agreement will be governed and -construed under the laws of the State of California without regard to choice of law -principles, and the UN Convention on Contracts for the International Sale of Goods -does not apply to this Agreement. The courts of California shall have exclusive -jurisdiction of any dispute arising out of this Agreement. diff --git a/README.md b/README.md index 45a48ff35f0a8e65092c4d41683ecab08c055893..cc73483771c64c6a50ddddf116d6533504953172 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,29 @@ # Llama Recipes: Examples to get started using the Llama models from Meta - -The 'llama-recipes' repository is a companion to the [Llama 2 model](https://github.com/facebookresearch/llama). The goal of this repository is to provide a scalable library for fine-tuning Llama 2, along with some example scripts and notebooks to quickly get started with using the Llama 2 models in a variety of use-cases, including fine-tuning for domain adaptation and building LLM-based applications with Llama 2 and other tools in the LLM ecosystem. The examples here showcase how to run Llama 2 locally, in the cloud, and on-prem. - +<!-- markdown-link-check-disable --> +The 'llama-recipes' repository is a companion to the [Meta Llama 2](https://github.com/meta-llama/llama) and [Meta Llama 3](https://github.com/meta-llama/llama3) models. The goal of this repository is to provide a scalable library for fine-tuning Meta Llama models, along with some example scripts and notebooks to quickly get started with using the models in a variety of use-cases, including fine-tuning for domain adaptation and building LLM-based applications with Meta Llama and other tools in the LLM ecosystem. The examples here showcase how to run Meta Llama locally, in the cloud, and on-prem. +<!-- markdown-link-check-enable --> +> [!IMPORTANT] +> Llama 3 has a new prompt template and special tokens (based on the tiktoken tokenizer). +> | Token | Description | +> |---|---| +> `<\|begin_of_text\|>` | This is equivalent to the BOS token. | +> `<\|eot_id\|>` | This signifies the end of the message in a turn. | +> `<\|start_header_id\|>{role}<\|end_header_id\|>` | These tokens enclose the role for a particular message. The possible roles can be: system, user, assistant. | +> `<\|end_of_text\|>` | This is equivalent to the EOS token. On generating this token, Llama 3 will cease to generate more tokens | +> +> A multiturn-conversation with Llama 3 follows this prompt template: +> ``` +> <|begin_of_text|><|start_header_id|>system<|end_header_id|> +> +> {{ system_prompt }}<|eot_id|><|start_header_id|>user<|end_header_id|> +> +> {{ user_message_1 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|> +> +> {{ model_answer_1 }}<|eot_id|><|start_header_id|>user<|end_header_id|> +> +> {{ user_message_2 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|> +> ``` +> More details on the new tokenizer and prompt template: <PLACEHOLDER_URL> > [!NOTE] > The llama-recipes repository was recently refactored to promote a better developer experience of using the examples. Some files have been moved to new locations. The `src/` folder has NOT been modified, so the functionality of this repo and package is not impacted. > @@ -9,7 +31,7 @@ The 'llama-recipes' repository is a companion to the [Llama 2 model](https://git ## Table of Contents -- [Llama Recipes: Examples to get started using the Llama models from Meta](#llama-recipes-examples-to-get-started-using-the-llama-models-from-meta) +- [Llama Recipes: Examples to get started using the Meta Llama models from Meta](#llama-recipes-examples-to-get-started-using-the-llama-models-from-meta) - [Table of Contents](#table-of-contents) - [Getting Started](#getting-started) - [Prerequisites](#prerequisites) @@ -76,11 +98,11 @@ pip install --extra-index-url https://download.pytorch.org/whl/test/cu118 -e .[t ``` -### Getting the Llama models -You can find Llama 2 models on Hugging Face hub [here](https://huggingface.co/meta-llama), **where models with `hf` in the name are already converted to Hugging Face checkpoints so no further conversion is needed**. The conversion step below is only for original model weights from Meta that are hosted on Hugging Face model hub as well. +### Getting the Meta Llama models +You can find Meta Llama models on Hugging Face hub [here](https://huggingface.co/meta-llama), **where models with `hf` in the name are already converted to Hugging Face checkpoints so no further conversion is needed**. The conversion step below is only for original model weights from Meta that are hosted on Hugging Face model hub as well. #### Model conversion to Hugging Face -The recipes and notebooks in this folder are using the Llama 2 model definition provided by Hugging Face's transformers library. +The recipes and notebooks in this folder are using the Meta Llama model definition provided by Hugging Face's transformers library. Given that the original checkpoint resides under models/7B you can install all requirements and convert the checkpoint with: @@ -105,15 +127,15 @@ Most of the code dealing with Llama usage is organized across 2 main folders: `r Contains examples are organized in folders by topic: | Subfolder | Description | |---|---| -[quickstart](./recipes/quickstart) | The "Hello World" of using Llama2, start here if you are new to using Llama2. -[finetuning](./recipes/finetuning)|Scripts to finetune Llama2 on single-GPU and multi-GPU setups -[inference](./recipes/inference)|Scripts to deploy Llama2 for inference locally and using model servers +[quickstart](./recipes/quickstart) | The "Hello World" of using Llama, start here if you are new to using Llama. +[finetuning](./recipes/finetuning)|Scripts to finetune Llama on single-GPU and multi-GPU setups +[inference](./recipes/inference)|Scripts to deploy Llama for inference locally and using model servers [use_cases](./recipes/use_cases)|Scripts showing common applications of Llama2 [responsible_ai](./recipes/responsible_ai)|Scripts to use PurpleLlama for safeguarding model outputs [llama_api_providers](./recipes/llama_api_providers)|Scripts to run inference on Llama via hosted endpoints -[benchmarks](./recipes/benchmarks)|Scripts to benchmark Llama 2 models inference on various backends +[benchmarks](./recipes/benchmarks)|Scripts to benchmark Llama models inference on various backends [code_llama](./recipes/code_llama)|Scripts to run inference with the Code Llama models -[evaluation](./recipes/evaluation)|Scripts to evaluate fine-tuned Llama2 models using `lm-evaluation-harness` from `EleutherAI` +[evaluation](./recipes/evaluation)|Scripts to evaluate fine-tuned Llama models using `lm-evaluation-harness` from `EleutherAI` ### `src/` @@ -133,5 +155,8 @@ Contains modules which support the example recipes: Please read [CONTRIBUTING.md](CONTRIBUTING.md) for details on our code of conduct, and the process for submitting pull requests to us. ## License -See the License file [here](LICENSE) and Acceptable Use Policy [here](USE_POLICY.md) +<!-- markdown-link-check-disable --> +See the License file for Meta Llama 2 [here](https://llama.meta.com/llama2/license/) and Acceptable Use Policy [here](https://llama.meta.com/llama2/use-policy/) +See the License file for Meta Llama 3 [here](https://llama.meta.com/llama3/license/) and Acceptable Use Policy [here](https://llama.meta.com/llama3/use-policy/) +<!-- markdown-link-check-enable --> diff --git a/USE_POLICY.md b/USE_POLICY.md deleted file mode 100644 index 4299e1d150581897ebca4706eeb537d236626b3d..0000000000000000000000000000000000000000 --- a/USE_POLICY.md +++ /dev/null @@ -1,49 +0,0 @@ -# Llama 2 Acceptable Use Policy - -Meta is committed to promoting safe and fair use of its tools and features, including Llama 2. If you access or use Llama 2, you agree to this Acceptable Use Policy (“Policyâ€). The most recent copy of this policy can be found at [ai.meta.com/llama/use-policy](http://ai.meta.com/llama/use-policy). - -## Prohibited Uses -We want everyone to use Llama 2 safely and responsibly. You agree you will not use, or allow others to use, Llama 2 to: - -1. Violate the law or others’ rights, including to: - 1. Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as: - 1. Violence or terrorism - 2. Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material - 3. Human trafficking, exploitation, and sexual violence - 4. The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials. - 5. Sexual solicitation - 6. Any other criminal activity - 2. Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals - 3. Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services - 4. Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices - 5. Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws - 6. Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any products or services using the Llama 2 Materials - 7. Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system - - - -2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of Llama 2 related to the following: - 1. Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State - 2. Guns and illegal weapons (including weapon development) - 3. Illegal drugs and regulated/controlled substances - 4. Operation of critical infrastructure, transportation technologies, or heavy machinery - 5. Self-harm or harm to others, including suicide, cutting, and eating disorders - 6. Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual - - - -3. Intentionally deceive or mislead others, including use of Llama 2 related to the following: - 1. Generating, promoting, or furthering fraud or the creation or promotion of disinformation - 2. Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content - 3. Generating, promoting, or further distributing spam - 4. Impersonating another individual without consent, authorization, or legal right - 5. Representing that the use of Llama 2 or outputs are human-generated - 6. Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement -4. Fail to appropriately disclose to end users any known dangers of your AI system - -Please report any violation of this Policy, software “bug,†or other problems that could lead to a violation of this Policy through one of the following means: - -* Reporting issues with the model: [github.com/facebookresearch/llama](http://github.com/facebookresearch/llama) -* Reporting risky content generated by the model: [developers.facebook.com/llama_output_feedback](http://developers.facebook.com/llama_output_feedback) -* Reporting bugs and security concerns: [facebook.com/whitehat/info](http://facebook.com/whitehat/info) -* Reporting violations of the Acceptable Use Policy or unlicensed uses of Llama: [LlamaUseReport@meta.com](mailto:LlamaUseReport@meta.com) diff --git a/recipes/README.md b/recipes/README.md index 88d3a10c42b8454c321e2e68136e659de14eaa99..e373d4cb30a22f67147177e93d58b3d8bfebfc28 100644 --- a/recipes/README.md +++ b/recipes/README.md @@ -2,7 +2,8 @@ This folder contains examples organized by topic: | Subfolder | Description | |---|---| -[quickstart](./quickstart) | The "Hello World" of using Llama2, start here if you are new to using Llama2. +[quickstart](./quickstart)|The "Hello World" of using Llama2, start here if you are new to using Llama2 +[multilingual](./multilingual)|Scripts to add a new language to Llama2 [finetuning](./finetuning)|Scripts to finetune Llama2 on single-GPU and multi-GPU setups [inference](./inference)|Scripts to deploy Llama2 for inference locally and using model servers [use_cases](./use_cases)|Scripts showing common applications of Llama2 diff --git a/recipes/finetuning/datasets/custom_dataset.py b/recipes/finetuning/datasets/custom_dataset.py index 18fa960e6354e28f87c6a12834899f4bf3a5d7e0..d80494da4307c2a64d79cc1a1304203c31d7ed3b 100644 --- a/recipes/finetuning/datasets/custom_dataset.py +++ b/recipes/finetuning/datasets/custom_dataset.py @@ -11,11 +11,27 @@ import itertools B_INST, E_INST = "[INST]", "[/INST]" def tokenize_dialog(dialog, tokenizer): - prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[::2]] - answer_tokens = [tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False) for answer in dialog[1::2]] - dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens))) - #Add labels, convert prompt token to -100 in order to ignore in loss function - labels_tokens = [len(c)*[-100,] if i % 2 == 0 else c for i,c in enumerate(dialog_tokens)] + if tokenizer.vocab_size >= 128000: + dialog_tokens = tokenizer.apply_chat_template(dialog) + dialog_tokens = dialog_tokens[:-4] # Remove generation prompt <|start_header_id|>assistant<|end_header_id|>\n\n + eot_indices = [i for i,n in enumerate(dialog_tokens) if n == 128009] + labels = copy.copy(dialog_tokens) + last_idx = 0 + for n, idx in enumerate(eot_indices): + if n % 2 == 1: + last_idx = idx + else: + labels[last_idx:idx+1] = [-100] * (idx-last_idx+1) + + dialog_tokens = [dialog_tokens] + labels_tokens = [labels] + else: + prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[::2]] + answer_tokens = [tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False) for answer in dialog[1::2]] + dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens))) + + #Add labels, convert prompt token to -100 in order to ignore in loss function + labels_tokens = [len(c)*[-100,] if i % 2 == 0 else c for i,c in enumerate(dialog_tokens)] combined_tokens = { "input_ids": list(itertools.chain(*(t for t in dialog_tokens))), diff --git a/recipes/inference/local_inference/chat_completion/chat_completion.py b/recipes/inference/local_inference/chat_completion/chat_completion.py index 7756c1cfce09ad1b2564c644510a3fa941328f3e..1395326ecb8269fc4bb428e70951c2537189502c 100644 --- a/recipes/inference/local_inference/chat_completion/chat_completion.py +++ b/recipes/inference/local_inference/chat_completion/chat_completion.py @@ -8,9 +8,9 @@ import os import sys import torch -from transformers import LlamaTokenizer +from transformers import AutoTokenizer -from llama_recipes.inference.chat_utils import read_dialogs_from_file, format_tokens +from llama_recipes.inference.chat_utils import read_dialogs_from_file from llama_recipes.inference.model_utils import load_model, load_peft_model from llama_recipes.inference.safety_utils import get_safety_checker from accelerate.utils import is_xpu_available @@ -65,15 +65,15 @@ def main( if peft_model: model = load_peft_model(model, peft_model) - tokenizer = LlamaTokenizer.from_pretrained(model_name) + tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer.add_special_tokens( { - + "pad_token": "<PAD>", } ) - - chats = format_tokens(dialogs, tokenizer) + + chats = tokenizer.apply_chat_template(dialogs) with torch.no_grad(): for idx, chat in enumerate(chats): diff --git a/recipes/inference/local_inference/inference.py b/recipes/inference/local_inference/inference.py index 4f83c8f2caec1fbac1558a2182671fe5de67f24e..2f81f82f2bfa9265a9364e2504290d9abd34eed9 100644 --- a/recipes/inference/local_inference/inference.py +++ b/recipes/inference/local_inference/inference.py @@ -10,7 +10,7 @@ import time import gradio as gr import torch -from transformers import LlamaTokenizer +from transformers import AutoTokenizer from llama_recipes.inference.safety_utils import get_safety_checker, AgentType from llama_recipes.inference.model_utils import load_model, load_peft_model @@ -69,17 +69,16 @@ def main( else: torch.cuda.manual_seed(seed) torch.manual_seed(seed) - + model = load_model(model_name, quantization, use_fast_kernels) if peft_model: model = load_peft_model(model, peft_model) model.eval() - - tokenizer = LlamaTokenizer.from_pretrained(model_name) + tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer.pad_token = tokenizer.eos_token - + batch = tokenizer(user_prompt, padding='max_length', truncation=True, max_length=max_padding_length, return_tensors="pt") if is_xpu_available(): batch = {k: v.to("xpu") for k, v in batch.items()} diff --git a/recipes/multilingual/README.md b/recipes/multilingual/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d4fb7c97badcf5bfb49d873c0c9899f46bb5fbaa --- /dev/null +++ b/recipes/multilingual/README.md @@ -0,0 +1,156 @@ +# Extending Llama to a new language +Authored by : Sarvam team +In this recipe, we will see how to add a new language to the Llama family of models. The steps are quite general and can be easily adapted to other models as well. Using this recipe, you should be able to replicate the findings of [OpenHathi](https://huggingface.co/sarvamai/OpenHathi-7B-Hi-v0.1-Base). +Please read more about OpenHathi [here](https://www.sarvam.ai/blog/announcing-openhathi-series) +## Data +The original OpenHathi model uses a combination of [Sangraha](https://huggingface.co/datasets/ai4bharat/sangraha) and Wikipedia as its primary data sources. If the reader is interested in using these sources, they would also have to preprocess the data: clean, filter, and deduplicate. See [Setu](https://github.com/AI4Bharat/setu) for an easy way to do this at scale. + +In this tutorial, we will use the [Varta](https://huggingface.co/datasets/rahular/varta) dataset which contains 40M+ news articles taken from [DailyHunt](https://m.dailyhunt.in/). Since this data is already high-quality, we can skip the pre-processing step mentioned above. We will use the Hindi subset here, but you can add any other language present in the dataset by only passing the right language code (advanced users can also tweak the code to add multiple languages at once). + +## Tokenizer +Our first step towards augmenting a new language to an LLM is creating a better tokenizer. We define 'better' in terms of fertility score or the number of in-language tokens present in the tokenizer. Note that we should add new tokens without disturbing the original vocabulary, and therefore creating a better tokenizer usually involves 2 steps: (i) building a new, in-language only tokenizer, and (ii) merging this new tokenizer with the original. + +### Building the in-language tokenizer +For this, we will first download and prepare the data for training the tokenizer: + +``` +python prepare_data.py --split=validation --lang=hi --docs_to_sample=10000 --save_path=./data +``` + +Here we sample 10,000 Hindi documents from the validation split (we should ideally sample from the training split, but this is much faster) and save it as a text file inside `./data`. Next, we use this text to train a Hindi-only [sentencepiece](https://github.com/google/sentencepiece) tokenizer with a vocabulary size of 16,000. + +``` +python train_tokenizer.py --data_file=./data/hi.txt --save_path=./hi_tokenizer --vocab_size=16000 +``` + +This creates a new sentencepiece Hindi tokenizer and saves it in `./hi_tokenizer`. + +### Merging the tokenizers +This process can again be divided into 2 steps: +- add new tokens to the original Llama2 tokenizer without disturbing its original vocabulary in any way +- expand the input and output embedding matrices of Llama2 to be equal to the new vocabulary size + +We can do the first step by (i) downloading Llama2's `tokenizer.model` file, (ii) loading our Hindi `tokenizer.model` file, (iii) appending the Hindi tokens to Llama2 tokenizer's vocabulary if they are not already present, and (iv) save the extended tokenizer for future use. All this can be done by running + +``` +python extend_tokenizer.py --new_tokenizer_path=./hi_tokenizer --extended_tokenizer_save_path=./extended_tokenizer +``` + +Now, you have a new Llama2 tokenizer which works the same way on English text but can efficiently tokenize Hindi words as well. You can also test to see if it works as intended: + +``` +>>> from transformers import LlamaTokenizer +>>> llama_tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf') +>>> our_tokenizer = LlamaTokenizer.from_pretrained('./extended_tokenizer') +>>> for i in range(len(llama_tokenizer)): +... assert llama_tokenizer.convert_ids_to_tokens(i) == our_tokenizer.convert_ids_to_tokens(i), f"Token mismatch at index {i}." +... +>>> text = "मैं à¤à¤• अचà¥à¤›à¤¾ हाथी हूà¤" +>>> llama_tokenizer.tokenize(text) +['â–', 'म', 'ै', 'ं', 'â–', '<0xE0>', '<0xA4>', '<0x8F>', 'क', 'â–', 'अ', 'च', 'à¥', '<0xE0>', '<0xA4>', '<0x9B>', 'ा', 'â–', 'ह', 'ा', 'थ', 'ी', 'â–', 'ह', 'ू', '<0xE0>', '<0xA4>', '<0x81>'] +>>> our_tokenizer.tokenize(text) +['â–मैं', 'â–à¤à¤•', 'â–अच', 'à¥', 'छा', 'â–हाथी', 'â–हूà¤'] +``` + +## Continual pre-training +OpenHathi uses a two-stage pre-training process: +- Phase 1: learn to translate paragraphs of text (use translated text as context and generate the original text, ~15B tokens) +- Phase 2: bilingual next token prediction (train on text where the language changes after every sentence, ~15B tokens) + +Note: OpenHathi's final data mixture also contains monolingual data and romanized transliterations. + +We can easily create data for both phases using any translation model. OpenHathi uses [IndicTrans2](https://github.com/AI4Bharat/IndicTrans2). We provide sample code for both phases below. + +### Phase 1 +With the assumption that we don't have source-native data, let us first get some English data to translate. + +``` +from datasets import load_dataset +ds = load_dataset("rahular/varta", split="train", streaming=True) +english_paragraphs = [] +for d in ds: + if d["langCode"] != "en": continue + english_paragraphs.append(" ".join(d["text"].split("\n"))) +``` + +Now, our goal is to create data in the format `{translated_paragraph}\n\n{english_paragraph}`. We can use the `translate_paragraph` function ([link](https://github.com/AI4Bharat/IndicTrans2/blob/main/huggingface_interface/example.py#L150])) from the IndicTrans2 codebase to do this easily. + +``` +quantization = "" +en_indic_ckpt_dir = "ai4bharat/indictrans2-en-indic-1B" +en_indic_tokenizer, en_indic_model = initialize_model_and_tokenizer(en_indic_ckpt_dir, "en-indic", quantization) +ip = IndicProcessor(inference=True) + +phase1_data = [] +for para in english_paragraphs: + trans_para = translate_paragraph(para, "eng_Latn", "hin_Deva", en_indic_model, en_indic_tokenizer, ip) + phase1_data.append({"text": f"{trans_para}\n\n{para}"}) + +# if you want to save it for future, you can do so easily with HF datasets +from datasets import Dataset +phase1_ds = Dataset.from_list(phase1_data) +phase1_ds.save_to_disk("data/phase1") +``` + +### Phase 2 +This is almost the same as phase 1, except that we have to replace the original sentences in an alternating manner to get the data in the required format. We can use the `split_sentences` ([link](https://github.com/AI4Bharat/IndicTrans2/blob/main/huggingface_interface/example.py#L60])) and `batch_translate` ([link](https://github.com/AI4Bharat/IndicTrans2/blob/main/huggingface_interface/example.py#L109)) functions to do this. + +``` +quantization = "" +en_indic_ckpt_dir = "ai4bharat/indictrans2-en-indic-1B" +en_indic_tokenizer, en_indic_model = initialize_model_and_tokenizer(en_indic_ckpt_dir, "en-indic", quantization) +ip = IndicProcessor(inference=True) + +phase2_data = [] +for para in english_paragraphs: + en_sents = split_sentences(para, "eng_Latn") + trans_sents = batch_translate(input_sentences, "eng_Latn", "hin_Deva, en_indic_model, en_indic_tokenizer, ip) + final_para = [] + for idx, (en_sent, trans_sent) in enumerate(zip(en_sents, trans_sents)): + sent_to_append = en_sent if idx % 2 == 0 else trans_sent + final_para.append(sent_to_append) + phase2_data.append({"text": " ".join(final_para)}) + +# if you want to save it for future, you can do so easily with HF datasets +from datasets import Dataset +phase2_ds = Dataset.from_list(phase2_data) +phase2_ds.save_to_disk("data/phase2") +``` + +### Train +Finally, we can start finetuning Llama2 on these datasets by following the [finetuning recipes](https://github.com/meta-llama/llama-recipes/tree/main/recipes/finetuning). Remember to pass the new tokenizer path as an argument to the script: `--tokenizer_name=./extended_tokenizer`. + +OpenHathi was trained on 64 A100 80GB GPUs. Here are the hyperparameters used and other training details: +- maximum learning rate: 2e-4 +- minimum learning rate: 2e-6 +- optimizer: AdamW (weight decay = 0.1) +- beta1: 0.9 +- beta2: 0.95 +- lora rank: 128 +- lora alpha: 64 +- lora trainable: q_proj, v_proj, k_proj, o_proj, gate_proj, down_proj, up_proj +- lora dropout: 0.05 +- block size: 4096 +- global batch size: 4M tokens +- input and output embeddings are trainable +- lr schedule: cosine decay with warmup (warmup ratio = 0.1, number of cycles = 3) +- deepspeed stage 2 +- dtype: bfloat16 + +The resulting (partial) loss plots from the OpenHathi training are shown below: + +Phase 1: train loss + + + +Phase 1: eval loss + + + +Phase 2: train loss + + + +Phase 2: eval loss + + diff --git a/recipes/multilingual/extend_tokenizer.py b/recipes/multilingual/extend_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..1e2b3d53d0103f9c860cbb438c2f0913c0f45e28 --- /dev/null +++ b/recipes/multilingual/extend_tokenizer.py @@ -0,0 +1,52 @@ +""" +Code borrowed from https://github.com/ymcui/Chinese-LLaMA-Alpaca/blob/main/scripts/merge_tokenizer/merge_tokenizers.py +""" + +import os +import fire +import re +from transformers import LlamaTokenizer + +os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" +from huggingface_hub import hf_hub_download +from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model + + +def main(new_tokenizer_path, extended_tokenizer_save_path): + original_tokenizer_path = hf_hub_download(repo_id="meta-llama/Llama-2-7b-chat-hf", filename="tokenizer.model", local_dir="original_tokenizer") + original_tokenizer_spm = sp_pb2_model.ModelProto() + original_tokenizer_spm.ParseFromString(open(original_tokenizer_path, "rb").read()) + new_tokenizer_spm = sp_pb2_model.ModelProto() + new_tokenizer_spm.ParseFromString(open(os.path.join(new_tokenizer_path, "tokenizer.model"), "rb").read()) + + def contains_eng(text): + eng_pattern = re.compile(r"[\u0020-\u007E]+") + return True if eng_pattern.search(text) else False + + original_tokenizer_tokenset = set(p.piece for p in original_tokenizer_spm.pieces) + print(f"Number of tokens before merge: {len(original_tokenizer_tokenset)}") + for p in new_tokenizer_spm.pieces: + piece = p.piece + if piece not in original_tokenizer_tokenset and not contains_eng(piece): + new_p = sp_pb2_model.ModelProto().SentencePiece() + new_p.piece = piece + new_p.score = 0 + original_tokenizer_spm.pieces.append(new_p) + print(f"Number of tokens after merge: {len(original_tokenizer_spm.pieces)}") + + os.makedirs(extended_tokenizer_save_path, exist_ok=True) + with open(os.path.join(extended_tokenizer_save_path, "tokenizer.model"), "wb") as f: + f.write(original_tokenizer_spm.SerializeToString()) + tokenizer = LlamaTokenizer(vocab_file=os.path.join(extended_tokenizer_save_path, "tokenizer.model"), legacy=False) + tokenizer.save_pretrained(extended_tokenizer_save_path) + print(f"Tokenizer saved to {extended_tokenizer_save_path}") + + # Verify that the extended tokenizer's English vocab matches with that of the original Llama tokenizer + tok1 = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf') + tok2 = LlamaTokenizer.from_pretrained(extended_tokenizer_save_path) + for i in range(len(tok1)): + assert tok1.convert_ids_to_tokens(i) == tok2.convert_ids_to_tokens(i), f"Token mismatch at index {i}." + + +if __name__ == "__main__": + fire.Fire(main) \ No newline at end of file diff --git a/recipes/multilingual/imgs/phase1-eval-loss.png b/recipes/multilingual/imgs/phase1-eval-loss.png new file mode 100644 index 0000000000000000000000000000000000000000..a1a0492edfd7f72dc02b6e77821578ceab3566d2 Binary files /dev/null and b/recipes/multilingual/imgs/phase1-eval-loss.png differ diff --git a/recipes/multilingual/imgs/phase1-train-loss.png b/recipes/multilingual/imgs/phase1-train-loss.png new file mode 100644 index 0000000000000000000000000000000000000000..ca6ffc8c2b0790f839d302fe80578839b6d725e2 Binary files /dev/null and b/recipes/multilingual/imgs/phase1-train-loss.png differ diff --git a/recipes/multilingual/imgs/phase2-eval-loss.png b/recipes/multilingual/imgs/phase2-eval-loss.png new file mode 100644 index 0000000000000000000000000000000000000000..ab49c624aabeefec67cb68d81dceb7e461c3ddb8 Binary files /dev/null and b/recipes/multilingual/imgs/phase2-eval-loss.png differ diff --git a/recipes/multilingual/imgs/phase2-train-loss.png b/recipes/multilingual/imgs/phase2-train-loss.png new file mode 100644 index 0000000000000000000000000000000000000000..4e242ea05aa1295c8ff34bd01bd73eacc9a53347 Binary files /dev/null and b/recipes/multilingual/imgs/phase2-train-loss.png differ diff --git a/recipes/multilingual/prepare_data.py b/recipes/multilingual/prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..340dcafcb4b85c3a107b31dbd47d413fde160303 --- /dev/null +++ b/recipes/multilingual/prepare_data.py @@ -0,0 +1,23 @@ +import fire +import os +from datasets import load_dataset + +DATASET = "rahular/varta" + +def main(split="validation", lang="hi", docs_to_sample=10_000, save_path="data"): + dataset = load_dataset(DATASET, split=split, streaming=True) + os.makedirs(save_path, exist_ok=True) + with open(os.path.join(save_path, f"{lang}.txt"), "w") as f: + count = 0 + for idx, d in enumerate(dataset): + if idx % 10_000 == 0: + print(f"Searched {idx} documents for {lang} documents. Found {count} documents.") + if count >= docs_to_sample: + break + if d["langCode"] == lang: + f.write(d["headline"] + "\n" + d["text"] + "\n") + count += 1 + + +if __name__ == "__main__": + fire.Fire(main) \ No newline at end of file diff --git a/recipes/multilingual/train_tokenizer.py b/recipes/multilingual/train_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..f0319f08862b947b2d8a51694fe332116cf319d9 --- /dev/null +++ b/recipes/multilingual/train_tokenizer.py @@ -0,0 +1,22 @@ +import fire +import os +import sentencepiece as spm + +def main(data_file, save_path, vocab_size=16_000, num_threads=8): + os.makedirs(save_path, exist_ok=True) + tokenizer_name = os.path.join(save_path, "tokenizer") + + spm.SentencePieceTrainer.train( + input=data_file, + model_prefix=tokenizer_name, + vocab_size=vocab_size, + num_threads=num_threads, + model_type="bpe", + max_sentence_length=1073741824, + shuffle_input_sentence="true", + character_coverage=1.0, + hard_vocab_limit="false", + ) + +if __name__ == "__main__": + fire.Fire(main) diff --git a/recipes/responsible_ai/CodeShieldUsageDemo.ipynb b/recipes/responsible_ai/CodeShieldUsageDemo.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..2866bac9053eb368858ddc4340d779cbc4b4afaa --- /dev/null +++ b/recipes/responsible_ai/CodeShieldUsageDemo.ipynb @@ -0,0 +1,196 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# [CodeShield](https://github.com/meta-llama/PurpleLlama/tree/main/CodeShield) Usage Walkthrough\n", + "\n", + "This notebook shows examples of how to use CodeShield. For further information, see the main repository README [here](https://github.com/meta-llama/PurpleLlama/tree/main/CodeShield).\n", + "\n", + "# Getting Started \n", + "\n", + "Either install via PyPi using pip, or install it locally from source. \n", + "\n", + "#### Install Option 1. Install CodeShield package. Run the following in your terminal\n", + "\n", + "```\n", + "pip3 install codeshield\n", + "```\n", + "\n", + "#### Install Option 2. Install CodeShield package locally from source code\n", + "\n", + "```\n", + "git clone https://github.com/meta-llama/PurpleLlama\n", + "cd PurpleLlama/CodeShield\n", + "pip install .\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Examples\n", + "\n", + "Define a helper function to run scan and process output." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from codeshield.cs import CodeShield\n", + "\n", + "async def scan_llm_output(llm_output_code):\n", + " result = await CodeShield.scan_code(llm_output_code)\n", + " if result.is_insecure:\n", + " # perform actions based on treatment recommendation\n", + " if result.recommended_treatment == \"block\":\n", + " llm_output_code = \"*** Code Security issues found, blocking the code ***\"\n", + " if result.recommended_treatment == \"warn\":\n", + " llm_output_code = llm_output_code + \"*** Warning: The generated snippit contains insecure code ***\"\n", + " \n", + " \n", + " summary = \"Security issue detected\" if result.is_insecure else \"No issues found\"\n", + " print(\"__LLM output after treatment___\")\n", + " print(llm_output_code)\n", + " print (\"__Results__\")\n", + " print(summary)\n", + " print(result.recommended_treatment)\n", + " print (\"__Details__\")\n", + " print(result.issues_found)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 1 - Simple hard-coded example\n", + "In this example, we are taking a hard coded example where the code snippet has weak hash md5 function usage. " + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "__LLM output after treatment___\n", + "\n", + "def hashString(input):\n", + " return hashlib.md5(input)\n", + "*** Warning: The generated snippit contains insecure code ***\n", + "__Results__\n", + "Security issue detected\n", + "Treatment.WARN\n", + "__Details__\n", + "[Issue(description='Use of weak hashing algorithm', cwe_id='CWE-327', severity=<Severity.WARNING: 'warning'>, rule='\\\\.getMd5Digest\\\\(\\\\)|\\\\.md5\\\\(|\\\\.md5Hex\\\\(|\\\\.getInstance\\\\(\"(MD5|md5)\"', line=3, path=None, char=None, name=None, original=None, replacement=None, analyzer=<Analyzer.REGEX: 'regex'>, pattern_id='weak-md5-hashing'), Issue(description='Use of a Broken or Risky Cryptographic Algorithm', cwe_id='CWE-327', severity=<Severity.WARNING: 'warning'>, rule='\\\\b(md5|sha1)\\\\s*\\\\(', line=3, path=None, char=None, name=None, original=None, replacement=None, analyzer=<Analyzer.REGEX: 'regex'>, pattern_id='risky-crypto-algorithm'), Issue(description='The MD5 hash function is considered insecure. Avoid using it unless explicitly needed for compatibility reasons', cwe_id='CWE-328', severity=<Severity.WARNING: 'warning'>, rule='\\\\bhashlib\\\\.md5\\\\(', line=3, path=None, char=None, name=None, original=None, replacement=None, analyzer=<Analyzer.REGEX: 'regex'>, pattern_id='insecure-md5-hash-usage')]\n" + ] + } + ], + "source": [ + "llm_output_code = \"\"\"\n", + "def hashString(input):\n", + " return hashlib.md5(input)\n", + "\"\"\"\n", + "\n", + "await scan_llm_output(llm_output_code)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 2 - use openAI API\n", + "\n", + "Requires openai package (pip install openai)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prompt = \"please generate some example code to demonstrate strcpy usage\"\n", + "\n", + "import openai\n", + "client = openai.OpenAI(api_key=\"YOUR_OPEN_AI_KEY\")\n", + "response = client.chat.completions.create(\n", + " model= \"gpt-3.5-turbo\",\n", + " messages=[\n", + " {\"role\": \"user\", \"content\": prompt},\n", + " ],\n", + " max_tokens=1000,\n", + ")\n", + "\n", + "await scan_llm_output(response.choices[0].message.content)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 3 - use externally hosted LLM \n", + "\n", + "Requires [llama-recipes package](https://github.com/meta-llama/llama-recipes)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import getpass\n", + "\n", + "from llama_recipes.inference.llm import TOGETHER, OPENAI, ANYSCALE\n", + "\n", + "if \"EXTERNALLY_HOSTED_LLM_TOKEN\" not in os.environ:\n", + " os.environ[\"EXTERNALLY_HOSTED_LLM_TOKEN\"] = getpass.getpass(prompt=\"Provide token for LLM provider\")\n", + "\n", + "# Delete as appropriate\n", + "model = TOGETHER(\"togethercomputer/CodeLlama-13b-Instruct\", os.environ[\"EXTERNALLY_HOSTED_LLM_TOKEN\"])\n", + "model = OPENAI(\"gpt-4\",os.environ[\"EXTERNALLY_HOSTED_LLM_TOKEN\"])\n", + "model = ANYSCALE(\"codellama/CodeLlama-34b-Instruct-hf\",os.environ[\"EXTERNALLY_HOSTED_LLM_TOKEN\"])\n", + "\n", + "llm_output_code = model.query_with_system_prompt_with_retries(\n", + " system_prompt= \"You are an expert code developer. You output only code and nothing else\", \n", + " prompt= \"Output a single python function which calculates the md5 hash of a string provided as an argument to the function. Output only the code and nothing else.\"\n", + " )\n", + "await scan_llm_output(llm_output_code)" + ] + } + ], + "metadata": { + "fileHeader": "", + "fileUid": "a811f690-1583-439b-98c3-98bd7eb9880c", + "isAdHoc": false, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/recipes/responsible_ai/README.md b/recipes/responsible_ai/README.md index 128dcfd567a436d0d2961f71b3cee984231cd52a..e268f85b5b68b77111d290ecc4feb1d99f7509f8 100644 --- a/recipes/responsible_ai/README.md +++ b/recipes/responsible_ai/README.md @@ -1,11 +1,11 @@ -# Llama Guard +# Meta Llama Guard -Llama Guard is a new experimental model that provides input and output guardrails for LLM deployments. For more details, please visit the main [repository](https://github.com/facebookresearch/PurpleLlama/tree/main/Llama-Guard). +Meta Llama Guard and Meta Llama Guard 2 are new models that provide input and output guardrails for LLM inference. For more details, please visit the main [repository](https://github.com/facebookresearch/PurpleLlama/tree/main/Llama-Guard2). -**Note** Please find the right model on HF side [here](https://huggingface.co/meta-llama/LlamaGuard-7b). +**Note** Please find the right model on HF side [here](https://huggingface.co/meta-llama/Meta-Llama-Guard-2-8B). ### Running locally -The [llama_guard](llama_guard) folder contains the inference script to run Llama Guard locally. Add test prompts directly to the [inference script](llama_guard/inference.py) before running it. +The [llama_guard](llama_guard) folder contains the inference script to run Meta Llama Guard locally. Add test prompts directly to the [inference script](llama_guard/inference.py) before running it. ### Running on the cloud -The notebooks [Purple_Llama_Anyscale](Purple_Llama_Anyscale.ipynb) & [Purple_Llama_OctoAI](Purple_Llama_OctoAI.ipynb) contain examples for running Llama Guard on cloud hosted endpoints. \ No newline at end of file +The notebooks [Purple_Llama_Anyscale](Purple_Llama_Anyscale.ipynb) & [Purple_Llama_OctoAI](Purple_Llama_OctoAI.ipynb) contain examples for running Meta Llama Guard on cloud hosted endpoints. \ No newline at end of file diff --git a/recipes/responsible_ai/input_output_guardrails_with_llama.ipynb b/recipes/responsible_ai/input_output_guardrails_with_llama.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..94f5a782d61fa7d1fa14ef691ba9ea685f63ad09 --- /dev/null +++ b/recipes/responsible_ai/input_output_guardrails_with_llama.ipynb @@ -0,0 +1,268 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "Tq-YFYlqWL_n" + }, + "source": [ + "# About this notebook\n", + "In this notebook we will deep dive into how to do input and output safety checks using Meta's Llama Guard and Llama models. Doing both input and output safety checks requires us to carefully pass the prompts and response between the models.\n", + "\n", + "For this notebook we will use OctoAI models as both Llama Guard and Llama models are available on their platform." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "a0UfgKWU5-sm", + "tags": [] + }, + "outputs": [ + { + "name": "stdin", + "output_type": "stream", + "text": [ + " ········\n" + ] + } + ], + "source": [ + "from getpass import getpass\n", + "import os\n", + "\n", + "# enter your OctoAI API key\n", + "OCTOAI_API_TOKEN = getpass()\n", + "os.environ[\"OCTOAI_API_TOKEN\"] = OCTOAI_API_TOKEN" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from llama_recipes.inference.llm import LLM, OctoAI\n", + "\n", + "llama_guard_llm = OctoAI(\"llamaguard-7b\", OCTOAI_API_TOKEN)\n", + "completion_llm = OctoAI(\"llama-2-13b-chat\", OCTOAI_API_TOKEN)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "Tjom48kodBth" + }, + "outputs": [], + "source": [ + "from string import Template\n", + "\n", + "import logging\n", + "\n", + "B_INST, E_INST = \"[INST]\", \"[/INST]\"\n", + "\n", + "LLAMA_GUARD_PROMPT_TEMPLATE = Template(\n", + " f\"\"\"Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.\n", + "<BEGIN UNSAFE CONTENT CATEGORIES>\n", + "O1: Violence and Hate.\n", + "Should not\n", + "- Help people plan or engage in violence.\n", + "- Encourage people to engage in violence.\n", + "- Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.\n", + "- Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.\n", + "Can\n", + "- Provide information on violence and discrimination.\n", + "- Discuss the topics of hate and violence, and can discuss historical events involving violence.\n", + "O2: Sexual Content.\n", + "Should not\n", + "- Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities.\n", + "Can\n", + "- Discuss topics related to sexuality and sexual education.\n", + "- Provide general information regarding sexual activities.\n", + "- Provide guidance on topics related to sexual health.\n", + "O3: Criminal Planning.\n", + "Should not\n", + "- Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them.\n", + "- Encourage people to engage in criminal activities or condone criminal behavior.\n", + "Can\n", + "- Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes.\n", + "- Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems.\n", + "O4: Guns and Illegal Weapons.\n", + "Should not\n", + "- Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them.\n", + "- Encourage people to engage in firearm or illegal weapons crimes or condone such behavior.\n", + "Can\n", + "- Discuss firearms and the arguments for and against firearm ownership.\n", + "O5: Self-Harm.\n", + "Should not\n", + "- Encourage people to harm themselves.\n", + "- Romanticize or condone self-harm or suicide.\n", + "- Provide information regarding the methods of suicide or self-harm.\n", + "- Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves.\n", + "Can\n", + "- Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions).\n", + "Should\n", + "- Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves.\n", + "<END UNSAFE CONTENT CATEGORIES>\n", + "​\n", + "<BEGIN CONVERSATION>\n", + "\n", + "$prompt\n", + "\n", + "<END CONVERSATION>\n", + "\n", + "Provide your safety assessment for $agent_type in the above conversation:\n", + "- First line must read 'safe' or 'unsafe'.\n", + "- If unsafe, a second line must include a comma-separated list of violated categories.\"\"\"\n", + ")\n", + "\n", + "LLAMA_PROMPT_TEMPLATE = Template(\n", + " f\"\"\"<<SYS>>\n", + "You are a helpful agent, answering questions of the user in a short and consise manner.\n", + "<</SYS>>\n", + "\n", + "{{ $user_message }}\"\"\"\n", + ")\n", + "\n", + "def llama_format_prompt(message):\n", + " prompt = LLAMA_PROMPT_TEMPLATE.substitute(user_message=message)\n", + " prompt = f\"<s>{B_INST} {prompt.strip()} {E_INST}\"\n", + " return prompt\n", + "\n", + "def llama_guard_format_prompt(role, message):\n", + " prompt = LLAMA_GUARD_PROMPT_TEMPLATE.substitute(prompt=message, agent_type=role)\n", + " prompt = f\"<s>{B_INST} {prompt.strip()} {E_INST}\"\n", + " return prompt\n", + "\n", + "def get_category_description(category):\n", + " CATEGORY_MAP = {\n", + " '01': 'Violence and Hate',\n", + " '02': 'Sexual Content',\n", + " '03': 'Criminal Planning',\n", + " '04': 'Guns and Illegal Weapons',\n", + " '05': 'Self-Harm'\n", + " }\n", + "\n", + " if not isinstance(category, str) or category is None:\n", + " raise ValueError(\"Invalid input, expected a non-empty string\")\n", + "\n", + " if category not in CATEGORY_MAP:\n", + " raise ValueError(f\"Invalid input, unknown category: {category}\")\n", + "\n", + " return CATEGORY_MAP[category]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "KwONCw1oSroO" + }, + "outputs": [], + "source": [ + "def parse_llamaguard_response(message):\n", + " if not isinstance(message, str) or message is None:\n", + " return (False, \"Invalid input, expected a non-empty string\")\n", + " message = message.strip()\n", + "\n", + " if message == \"\":\n", + " return (False, \"Invalid input, message cannot be empty\")\n", + "\n", + " tokens = message.split(\"\\n\")\n", + " if tokens[0] == 'safe':\n", + " return (True, \"\")\n", + " else:\n", + " return (False, tokens[1])\n", + "\n", + "def check_input_guardrail(user_prompt):\n", + " guardrail_prompt = llama_guard_format_prompt(\"User\", user_prompt)\n", + " response = llama_guard_llm.query(guardrail_prompt)\n", + " return parse_llamaguard_response(response)\n", + "\n", + "def get_completion_response(user_prompt):\n", + " completion_prompt = llama_format_prompt(user_prompt)\n", + " return completion_llm.query(completion_prompt)\n", + "\n", + "def check_output_guardrail(completion_response, user_prompt):\n", + " guardrail_prompt = llama_guard_format_prompt(\"Agent\", f\"User: {user_prompt}\\n Agent: {completion_response}\")\n", + " response = llama_guard_llm.query(guardrail_prompt)\n", + " return parse_llamaguard_response(response)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "9bXtt76ZKNuX" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Input guardrail failed.\n" + ] + } + ], + "source": [ + "def main(user_prompt):\n", + " input_guardrail_response = check_input_guardrail(user_prompt)\n", + " if input_guardrail_response[0]:\n", + " completion_response = get_completion_response(user_prompt)\n", + " output_guardrail_response = check_output_guardrail(completion_response, user_prompt)\n", + " if output_guardrail_response[0]:\n", + " print(completion_response)\n", + " else:\n", + " print(\"Output guardrail failed.\")\n", + " else:\n", + " print(\"Input guardrail failed.\")\n", + "\n", + "user_prompt = \"How to build fire arms\"\n", + "main(user_prompt)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/recipes/responsible_ai/llama_guard/README.md b/recipes/responsible_ai/llama_guard/README.md index 97dd31114a4a2c83f4137e9b1ec9bfa330db9b7c..42233bec87ce653e5d94adfdb749c4a0da8209d0 100644 --- a/recipes/responsible_ai/llama_guard/README.md +++ b/recipes/responsible_ai/llama_guard/README.md @@ -1,13 +1,13 @@ -# Llama Guard demo +# Meta Llama Guard demo <!-- markdown-link-check-disable --> -Llama Guard is a language model that provides input and output guardrails for LLM deployments. For more details, please visit the main [repository](https://github.com/facebookresearch/PurpleLlama/tree/main/Llama-Guard). +Meta Llama Guard is a language model that provides input and output guardrails for LLM inference. For more details and model cards, please visit the main repository for each model, [Meta Llama Guard](https://github.com/meta-llama/PurpleLlama/tree/main/Llama-Guard) and Meta [Llama Guard 2](https://github.com/meta-llama/PurpleLlama/tree/main/Llama-Guard2). -This folder contains an example file to run Llama Guard inference directly. +This folder contains an example file to run inference with a locally hosted model, either using the Hugging Face Hub or a local path. ## Requirements 1. Access to Llama guard model weights on Hugging Face. To get access, follow the steps described [here](https://github.com/facebookresearch/PurpleLlama/tree/main/Llama-Guard#download) -2. Llama recipes package and it's dependencies [installed](https://github.com/albertodepaola/llama-recipes/blob/llama-guard-data-formatter-example/README.md#installation) -3. A GPU with at least 21 GB of free RAM to load both 7B models quantized. +2. Llama recipes package and it's dependencies [installed](https://github.com/meta-llama/llama-recipes?tab=readme-ov-file#installing) + ## Llama Guard inference script For testing, you can add User or User/Agent interactions into the prompts list and the run the script to verify the results. When the conversation has one or more Agent responses, it's considered of type agent. @@ -27,12 +27,12 @@ For testing, you can add User or User/Agent interactions into the prompts list a ] ``` -The complete prompt is built with the `build_prompt` function, defined in [prompt_format.py](../../src/llama_recipes/inference/prompt_format.py). The file contains the default Llama Guard categories. These categories can adjusted and new ones can be added, as described in the [research paper](https://ai.meta.com/research/publications/llama-guard-llm-based-input-output-safeguard-for-human-ai-conversations/), on section 4.5 Studying the adaptability of the model. +The complete prompt is built with the `build_custom_prompt` function, defined in [prompt_format.py](../../../src/llama_recipes/inference/prompt_format_utils.py). The file contains the default Meta Llama Guard categories. These categories can adjusted and new ones can be added, as described in the [research paper](https://ai.meta.com/research/publications/llama-guard-llm-based-input-output-safeguard-for-human-ai-conversations/), on section 4.5 Studying the adaptability of the model. <!-- markdown-link-check-enable --> To run the samples, with all the dependencies installed, execute this command: -`python examples/llama_guard/inference.py` +`python recipes/responsible_ai/llama_guard/inference.py` This is the output: @@ -53,8 +53,14 @@ This is the output: ================================== ``` +To run it with a local model, you can use the `model_id` param in the inference script: + +`python recipes/responsible_ai/llama_guard/inference.py --model_id=/home/ubuntu/models/llama3/llama_guard_2-hf/ --llama_guard_version=LLAMA_GUARD_2` + +Note: Make sure to also add the llama_guard_version if when it does not match the default, the script allows you to run the prompt format from Meta Llama Guard 1 on Meta Llama Guard 2 + ## Inference Safety Checker -When running the regular inference script with prompts, Llama Guard will be used as a safety checker on the user prompt and the model output. If both are safe, the result will be shown, else a message with the error will be shown, with the word unsafe and a comma separated list of categories infringed. Llama Guard is always loaded quantized using Hugging Face Transformers library. +When running the regular inference script with prompts, Meta Llama Guard will be used as a safety checker on the user prompt and the model output. If both are safe, the result will be shown, else a message with the error will be shown, with the word unsafe and a comma separated list of categories infringed. Meta Llama Guard is always loaded quantized using Hugging Face Transformers library with bitsandbytes. In this case, the default categories are applied by the tokenizer, using the `apply_chat_template` method. diff --git a/recipes/responsible_ai/llama_guard/inference.py b/recipes/responsible_ai/llama_guard/inference.py index 4685bd3b6ffff437bc110bc471ef16579b1c2b38..abfee7603527529c83e5fabf785d105b9a1d4be1 100644 --- a/recipes/responsible_ai/llama_guard/inference.py +++ b/recipes/responsible_ai/llama_guard/inference.py @@ -2,10 +2,10 @@ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. import fire -from transformers import AutoTokenizer, AutoModelForCausalLM +from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig -from llama_recipes.inference.prompt_format_utils import build_prompt, create_conversation, LLAMA_GUARD_CATEGORY +from llama_recipes.inference.prompt_format_utils import build_default_prompt, create_conversation, LlamaGuardVersion from typing import List, Tuple from enum import Enum @@ -13,20 +13,25 @@ class AgentType(Enum): AGENT = "Agent" USER = "User" -def main(): +def main( + model_id: str = "meta-llama/LlamaGuard-7b", + llama_guard_version: LlamaGuardVersion = LlamaGuardVersion.LLAMA_GUARD_1 +): """ - Entry point of the program for generating text using a pretrained model. + Entry point for Llama Guard inference sample script. + + This function loads Llama Guard from Hugging Face or a local model and + executes the predefined prompts in the script to showcase how to do inference with Llama Guard. + Args: - ckpt_dir (str): The directory containing checkpoint files for the pretrained model. - tokenizer_path (str): The path to the tokenizer model used for text encoding/decoding. - temperature (float, optional): The temperature value for controlling randomness in generation. - Defaults to 0.6. - top_p (float, optional): The top-p sampling parameter for controlling diversity in generation. - Defaults to 0.9. - max_seq_len (int, optional): The maximum sequence length for input prompts. Defaults to 128. - max_gen_len (int, optional): The maximum length of generated sequences. Defaults to 64. - max_batch_size (int, optional): The maximum batch size for generating sequences. Defaults to 4. + model_id (str): The ID of the pretrained model to use for generation. This can be either the path to a local folder containing the model files, + or the repository ID of a model hosted on the Hugging Face Hub. Defaults to 'meta-llama/LlamaGuard-7b'. + llama_guard_version (LlamaGuardVersion): The version of the Llama Guard model to use for formatting prompts. Defaults to LLAMA_GUARD_1. """ + try: + llama_guard_version = LlamaGuardVersion[llama_guard_version] + except KeyError as e: + raise ValueError(f"Invalid Llama Guard version '{llama_guard_version}'. Valid values are: {', '.join([lgv.name for lgv in LlamaGuardVersion])}") from e prompts: List[Tuple[List[str], AgentType]] = [ (["<Sample user prompt>"], AgentType.USER), @@ -41,17 +46,16 @@ def main(): ] - model_id = "meta-llama/LlamaGuard-7b" - - tokenizer = AutoTokenizer.from_pretrained(model_id) - model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto") + quantization_config = BitsAndBytesConfig(load_in_8bit=True) + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config, device_map="auto") for prompt in prompts: - formatted_prompt = build_prompt( + formatted_prompt = build_default_prompt( prompt[1], - LLAMA_GUARD_CATEGORY, - create_conversation(prompt[0])) + create_conversation(prompt[0]), + llama_guard_version) input = tokenizer([formatted_prompt], return_tensors="pt").to("cuda") @@ -65,4 +69,7 @@ def main(): print("\n==================================\n") if __name__ == "__main__": - fire.Fire(main) \ No newline at end of file + try: + fire.Fire(main) + except Exception as e: + print(e) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 3e81870e6c65905af3336a55b14e175afab702d9..721cc2527a56a1c00faf7a505192b8d1dba0e5e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -torch>=2.0.1 +torch>=2.2 accelerate appdirs loralib @@ -8,10 +8,13 @@ black[jupyter] datasets fire peft -transformers>=4.34.1 +transformers>=4.40.0 sentencepiece py7zr scipy optimum matplotlib gradio +chardet +openai +typing-extensions==4.8.0 diff --git a/scripts/spellcheck_conf/wordlist.txt b/scripts/spellcheck_conf/wordlist.txt index b9f9be475182dfa13f5a7282e4979bbd9044b7f8..eeb2771eb6148df9faf65f09e112e346bc69560f 100644 --- a/scripts/spellcheck_conf/wordlist.txt +++ b/scripts/spellcheck_conf/wordlist.txt @@ -1268,4 +1268,29 @@ singlegpu Jfleg nnodes patht -sbatch \ No newline at end of file +sbatch +DailyHunt +IndicTrans +OpenHathi +OpenHathi's +Sangraha +Sarvam +Setu +Varta +bfloat +codebase +deduplicate +dtype +imgs +lr +proj +romanized +tokenize +tokenizer's +tokenizers +warmup +BOS +EOS +eot +multiturn +tiktoken diff --git a/src/llama_recipes/configs/training.py b/src/llama_recipes/configs/training.py index eac8d1980c597cc6d631444e67181cc872cef3be..4b21872cd25a0280f700a5c5ddd23deda7900578 100644 --- a/src/llama_recipes/configs/training.py +++ b/src/llama_recipes/configs/training.py @@ -7,6 +7,7 @@ from dataclasses import dataclass @dataclass class train_config: model_name: str="PATH/to/LLAMA/7B" + tokenizer_name: str=None enable_fsdp: bool=False low_cpu_fsdp: bool=False run_validation: bool=True diff --git a/src/llama_recipes/data/sampler.py b/src/llama_recipes/data/sampler.py index 36d6fd9e39ed99d1361f4e866e3c22187501f4f9..8798b641e13deb25910f15fc10dd800911913145 100644 --- a/src/llama_recipes/data/sampler.py +++ b/src/llama_recipes/data/sampler.py @@ -20,7 +20,7 @@ class LengthBasedBatchSampler(torch.utils.data.BatchSampler): self.shuffle = shuffle def __iter__(self): - ids = np.argsort(self.lengths) + ids = np.argsort(self.lengths, kind='mergesort') if self.drop_last: ids = ids[:len(ids) // self.batch_size * self.batch_size] @@ -47,11 +47,10 @@ class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler): ) self.num_replicas = num_replicas self.rank = rank - + def __iter__(self): max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas return islice(self.batch_sampler, self.rank, max_length, self.num_replicas) - + def __len__(self): return len(self.batch_sampler) // self.num_replicas - \ No newline at end of file diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py index d276857773152e4c5a577dcbd83c68cd7348d936..0759809b895d579075850883193916e970a8e81e 100644 --- a/src/llama_recipes/finetuning.py +++ b/src/llama_recipes/finetuning.py @@ -2,7 +2,6 @@ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. import os -from pkg_resources import packaging import dataclasses import fire @@ -18,8 +17,8 @@ from torch.distributed.fsdp import ( from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload from torch.optim.lr_scheduler import StepLR from transformers import ( + AutoTokenizer, LlamaForCausalLM, - LlamaTokenizer, LlamaConfig, ) from transformers.models.llama.modeling_llama import LlamaDecoderLayer @@ -51,7 +50,7 @@ from llama_recipes.utils.train_utils import ( from accelerate.utils import is_xpu_available def setup_wandb(train_config, fsdp_config, **kwargs): - try: + try: import wandb except ImportError: raise ImportError( @@ -97,7 +96,7 @@ def main(**kwargs): if train_config.use_wandb: if not train_config.enable_fsdp or rank==0: - wandb_run = setup_wandb(train_config, fsdp_config, **kwargs) + wandb_run = setup_wandb(train_config, fsdp_config, **kwargs) # Load the pre-trained model and setup its configuration use_cache = False if train_config.enable_fsdp else None @@ -108,11 +107,6 @@ def main(**kwargs): model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms overhead and currently requires latest nightly. """ - v = packaging.version.parse(torch.__version__) - verify_latest_nightly = v.is_devrelease and v.dev >= 20230701 - if not verify_latest_nightly: - raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, " - "please install latest nightly.") if rank == 0: model = LlamaForCausalLM.from_pretrained( train_config.model_name, @@ -137,9 +131,15 @@ def main(**kwargs): ) # Load the tokenizer and add special tokens - tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name) + tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name) tokenizer.pad_token_id = tokenizer.eos_token_id + # If there is a mismatch between tokenizer vocab size and embedding matrix, + # throw a warning and then expand the embedding matrix + if len(tokenizer) > model.get_input_embeddings().weight.shape[0]: + print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.") + model.resize_token_embeddings(len(tokenizer)) + print_model_size(model, train_config, rank if train_config.enable_fsdp else 0) # Prepare the model for int8 training if quantization is enabled @@ -157,12 +157,12 @@ def main(**kwargs): if wandb_run: wandb_run.config.update(peft_config) - + hsdp_device_mesh = None if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD: hsdp_device_mesh = hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size) print("HSDP device mesh is ready") - + #setting up FSDP if enable_fsdp is enabled if train_config.enable_fsdp: if not train_config.use_peft and train_config.freeze_layers: @@ -171,7 +171,7 @@ def main(**kwargs): mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank) my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer) - + device_id = 0 if is_xpu_available(): device_id = torch.xpu.current_device() diff --git a/src/llama_recipes/inference/chat_utils.py b/src/llama_recipes/inference/chat_utils.py index 530fdcf7d9ee0a4b397be00edbc3eadce937f388..06493ee9456245c5bd15f0cf4c68d6df3e7fdd38 100644 --- a/src/llama_recipes/inference/chat_utils.py +++ b/src/llama_recipes/inference/chat_utils.py @@ -2,62 +2,6 @@ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. import json -from typing import List, Literal, TypedDict - - -Role = Literal["user", "assistant"] - - -class Message(TypedDict): - role: Role - content: str - - -Dialog = List[Message] - -B_INST, E_INST = "[INST]", "[/INST]" -B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n" -def format_tokens(dialogs, tokenizer): - prompt_tokens = [] - for dialog in dialogs: - if dialog[0]["role"] == "system": - dialog = [ - { - "role": dialog[1]["role"], - "content": B_SYS - + dialog[0]["content"] - + E_SYS - + dialog[1]["content"], - } - ] + dialog[2:] - assert all([msg["role"] == "user" for msg in dialog[::2]]) and all( - [msg["role"] == "assistant" for msg in dialog[1::2]] - ), ( - "model only supports 'system','user' and 'assistant' roles, " - "starting with user and alternating (u/a/u/a/u...)" - ) - """ - Please verify that your tokenizer support adding "[INST]", "[/INST]" to your inputs. - Here, we are adding it manually. - """ - dialog_tokens: List[int] = sum( - [ - tokenizer.encode( - f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ", - ) + [tokenizer.eos_token_id] - for prompt, answer in zip(dialog[::2], dialog[1::2]) - ], - [], - ) - assert ( - dialog[-1]["role"] == "user" - ), f"Last message must be from user, got {dialog[-1]['role']}" - dialog_tokens += tokenizer.encode( - f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}", - ) - prompt_tokens.append(dialog_tokens) - return prompt_tokens - def read_dialogs_from_file(file_path): with open(file_path, 'r') as file: diff --git a/src/llama_recipes/inference/llm.py b/src/llama_recipes/inference/llm.py index b6a33b2579a1d79909a9bc3d0ba4f56dc76d3b50..56a1e2dfe02bb5624849be51bc4eb8d9d7478b68 100644 --- a/src/llama_recipes/inference/llm.py +++ b/src/llama_recipes/inference/llm.py @@ -14,14 +14,12 @@ from abc import ABC, abstractmethod from typing import Callable import openai -from langchain_together import Together - from typing_extensions import override - NUM_LLM_RETRIES = 10 - MAX_TOKENS = 1000 +TEMPERATURE = 0.1 +TOP_P = 0.9 LOG: logging.Logger = logging.getLogger(__name__) @@ -160,38 +158,35 @@ class ANYSCALE(LLM): "HuggingFaceH4/zephyr-7b-beta", ] +class OctoAI(LLM): + """Accessing OctoAI""" -class TOGETHER(LLM): - """Accessing TOGETHER""" + def __init__(self, model: str, api_key: str) -> None: + super().__init__(model, api_key) + self.client = openai.OpenAI(base_url="https://text.octoai.run/v1", api_key=api_key) # noqa @override def query(self, prompt: str) -> str: - llm = Together( + # Best-level effort to suppress openai log-spew. + # Likely not work well in multi-threaded environment. + level = logging.getLogger().level + logging.getLogger().setLevel(logging.WARNING) + response = self.client.chat.completions.create( model=self.model, - temperature=0.75, - top_p=1, + messages=[ + {"role": "system", "content": "You are a helpful assistant. Keep your responses limited to one short paragraph if possible."}, + {"role": "user", "content": prompt}, + ], max_tokens=MAX_TOKENS, - together_api_key=self.api_key, + temperature=TEMPERATURE, + top_p=TOP_P, ) - response = llm(prompt) - return "".join(response) + logging.getLogger().setLevel(level) + return response.choices[0].message.content @override def valid_models(self) -> list[str]: return [ - "mistralai/Mistral-7B-v0.1", - "lmsys/vicuna-7b-v1.5", - "togethercomputer/CodeLlama-7b", - "togethercomputer/CodeLlama-7b-Python", - "togethercomputer/CodeLlama-7b-Instruct", - "togethercomputer/CodeLlama-13b", - "togethercomputer/CodeLlama-13b-Python", - "togethercomputer/CodeLlama-13b-Instruct", - "togethercomputer/falcon-40b", - "togethercomputer/llama-2-7b", - "togethercomputer/llama-2-7b-chat", - "togethercomputer/llama-2-13b", - "togethercomputer/llama-2-13b-chat", - "togethercomputer/llama-2-70b", - "togethercomputer/llama-2-70b-chat", + "llamaguard-7b", + "llama-2-13b-chat", ] diff --git a/src/llama_recipes/inference/prompt_format_utils.py b/src/llama_recipes/inference/prompt_format_utils.py index 78fc6fabfa392d4f28cc7115982ae02d56400d8d..004c9d14786c267c5de7d54d7ff532fe1981ed94 100644 --- a/src/llama_recipes/inference/prompt_format_utils.py +++ b/src/llama_recipes/inference/prompt_format_utils.py @@ -6,6 +6,10 @@ from string import Template from typing import List from enum import Enum +class LlamaGuardVersion(Enum): + LLAMA_GUARD_1 = "Llama Guard 1" + LLAMA_GUARD_2 = "Llama Guard 2" + class AgentType(Enum): AGENT = "Agent" USER = "User" @@ -20,7 +24,8 @@ class ConversationTurn: message: str agent_type: AgentType -CATEGORY_SHORT_NAME_PREFIX = "O" +LLAMA_GUARD_1_CATEGORY_SHORT_NAME_PREFIX = "O" +LLAMA_GUARD_2_CATEGORY_SHORT_NAME_PREFIX = "S" PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories." @@ -38,14 +43,21 @@ $conversations <END CONVERSATION> """ -PROMPT_INSTRUCTIONS = """ + +PROMPT_INSTRUCTIONS_1 = """ Provide your safety assessment for $agent_type in the above conversation: - First line must read 'safe' or 'unsafe'. - If unsafe, a second line must include a comma-separated list of violated categories.""" -PROMPT_TEMPLATE = Template(f"[INST] {PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS} [/INST]") +PROMPT_INSTRUCTIONS_2 = """ +Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation: + - First line must read 'safe' or 'unsafe'. + - If unsafe, a second line must include a comma-separated list of violated categories.""" + +PROMPT_TEMPLATE_1 = Template(f"[INST] {PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS_1} [/INST]") +PROMPT_TEMPLATE_2 = Template(f"[INST] {PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS_2} [/INST]") -LLAMA_GUARD_CATEGORY = [ +LLAMA_GUARD_1_CATEGORY = [ SafetyCategory( "Violence and Hate.", """Should not @@ -100,6 +112,62 @@ Should - Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves."""), ] + +LLAMA_GUARD_2_CATEGORY = [ + SafetyCategory( + "Violent Crimes. ", + """AI models should not create content that enables, encourages, or excuses the commission of violent crimes. Examples of violent crimes include, but are not limited to: + - Unlawful violence toward people (ex: terrorism, genocide, murder, hate crimes, child abuse, assault, battery, kidnapping) + - Unlawful violence toward animals (ex: animal abuse)"""), + SafetyCategory( + "Non-Violent Crimes. ", + """AI models should not create content that enables, encourages, or excuses the commission of non-violent crimes. Examples of non-violent crimes include, but are not limited to: + - Financial crimes (ex: fraud, scams, money laundering) + - Property crimes (ex: burglary, robbery, arson, vandalism) + - Drug crimes (ex: creating or trafficking narcotics) + - Weapons crimes (ex: producing unlicensed firearms) + - Cyber crimes (ex: hacking, spyware, malware)"""), + SafetyCategory( + "Sex Crimes. ", + """AI models should not create content that enables, encourages, or excuses the commission of sex crimes. Examples of sex crimes include, but are not limited to: + - Human trafficking + - Sexual assault (ex: rape) + - Sexual harassment (ex: groping) + - Lewd conduct"""), + SafetyCategory( + "Child Exploitation. ", + "AI models should not create content that depicts child nudity or that enables, encourages, excuses, or depicts the sexual abuse of children."), + SafetyCategory( + "Specialized Advice. ", + "AI models should not create content that contains specialized financial, medical, or legal advice."), + SafetyCategory( + "Privacy. ", + "AI models should not create content that contains sensitive, personal information about private individuals."), + SafetyCategory( + "Intellectual Property. ", + "AI models should not create content that violates the intellectual property rights of third parties."), + SafetyCategory( + "Indiscriminate Weapons. ", + """AI models should not create content that enables the creation of weapons of mass destruction. Examples of such weapons include, but are not limited to: + - Chemical weapons (ex: nerve gas) + - Biological weapons (ex: anthrax) + - Radiological weapons (ex: salted bombs) + - Nuclear weapons (ex: atomic warheads) + - High-yield explosive weapons (ex: cluster munitions)"""), + SafetyCategory( + "Hate. ", + "AI models should not create content that is hateful toward people on the basis of their protected characteristics (race, color, ethnicity, national origin, disability, religious affiliation, caste, sexual orientation, sex, gender identity and serious disease). AI models should not create content that perpetuates negative stereotypes."), + SafetyCategory( + "Self-Harm. ", + """AI models should not create content that enables, encourages, or excuses acts of intentional self-harm. Examples of such actions include, but are not limited to: + - Suicide + - Self-injury (ex: cutting) + - Disordered eating"""), + SafetyCategory( + "Sexual Content. ", + "AI models should not create depictions of nude adults or content that contains erotic descriptions or explicit depictions of sex acts."), +] + def create_conversation(messges: List[str]) -> List[ConversationTurn]: conversations = [] for i, messge in enumerate(messges): @@ -107,28 +175,59 @@ def create_conversation(messges: List[str]) -> List[ConversationTurn]: return conversations -def build_prompt( +def build_default_prompt( + agent_type: AgentType, + conversations: List[ConversationTurn], + llama_guard_version: LlamaGuardVersion = LlamaGuardVersion.LLAMA_GUARD_2): + + if llama_guard_version == LlamaGuardVersion.LLAMA_GUARD_2: + categories = LLAMA_GUARD_2_CATEGORY + category_short_name_prefix = LLAMA_GUARD_2_CATEGORY_SHORT_NAME_PREFIX + prompt_template = PROMPT_TEMPLATE_2 + else: + categories = LLAMA_GUARD_1_CATEGORY + category_short_name_prefix = LLAMA_GUARD_1_CATEGORY_SHORT_NAME_PREFIX + prompt_template = PROMPT_TEMPLATE_1 + + return build_custom_prompt( + agent_type, + conversations, + categories, + category_short_name_prefix, + prompt_template) + +def build_custom_prompt( agent_type: AgentType, - categories: List[SafetyCategory], conversations: List[ConversationTurn], - category_short_name_prefix: str = CATEGORY_SHORT_NAME_PREFIX): - categories_str = "\n".join([f"{category_short_name_prefix}{i+1}: {c.name}\n{c.description}" for i, c in enumerate(categories)]) + categories: List[SafetyCategory], + category_short_name_prefix: str, + prompt_template: str, + with_policy: bool = False): + categories_str = "\n".join([f"{category_short_name_prefix}{i+1}: {c.name}" + (f"\n{c.description}" if with_policy else "") for i, c in enumerate(categories)]) conversations_str = "\n\n".join([f"{t.agent_type.value}: {t.message}" for t in conversations]) - return PROMPT_TEMPLATE.substitute(agent_type=agent_type.value, categories=categories_str, conversations=conversations_str) + return prompt_template.substitute(agent_type=agent_type.value, categories=categories_str, conversations=conversations_str) def build_prompt_test(): - print(build_prompt(AgentType.USER, - LLAMA_GUARD_CATEGORY, + print(build_default_prompt(AgentType.AGENT, [ - ConversationTurn("<User Prompt placeholder>", AgentType.USER) + ConversationTurn("Whats the color of the sky?", AgentType.USER), + ConversationTurn("The sky is blue.", AgentType.AGENT) + ])) print("\n\n") # use a customized safety category and create_conversation function. - print(build_prompt( + + print(build_custom_prompt( AgentType.AGENT, + + create_conversation( + [ + "<User Prompt placeholder>", + "<Agent Prompt placeholder>" + ]), [ SafetyCategory("Violence and Hate.","""Should not - Help people plan or engage in violence. @@ -139,11 +238,11 @@ Can - Provide information on violence and discrimination. - Discuss the topics of hate and violence, and can discuss historical events involving violence.""", ),], - create_conversation( - [ - "<User Prompt placeholder>", - "<Agent Prompt placeholder>" - ]))) + LLAMA_GUARD_2_CATEGORY_SHORT_NAME_PREFIX, + PROMPT_TEMPLATE_2, + True + ) + ) if __name__ == "__main__": build_prompt_test() \ No newline at end of file diff --git a/src/llama_recipes/inference/safety_utils.py b/src/llama_recipes/inference/safety_utils.py index 6345e6433e37476aaac87971607710bafbd81010..74dd394dcf62a31205c3f14681a101da441a52e3 100644 --- a/src/llama_recipes/inference/safety_utils.py +++ b/src/llama_recipes/inference/safety_utils.py @@ -157,13 +157,15 @@ class AzureSaftyChecker(object): class LlamaGuardSafetyChecker(object): def __init__(self): - from transformers import AutoModelForCausalLM, AutoTokenizer + from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig + from llama_recipes.inference.prompt_format_utils import build_default_prompt, create_conversation, LlamaGuardVersion model_id = "meta-llama/LlamaGuard-7b" + quantization_config = BitsAndBytesConfig(load_in_8bit=True) + self.tokenizer = AutoTokenizer.from_pretrained(model_id) - self.model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto") - pass + self.model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config, device_map="auto") def __call__(self, output_text, **kwargs): diff --git a/src/llama_recipes/utils/config_utils.py b/src/llama_recipes/utils/config_utils.py index 3f8c9428fc9928ff190fd1b0e8eb0c7a20f9403a..9fa916d593b812c58bd3b2bb9dded563ab95adc7 100644 --- a/src/llama_recipes/utils/config_utils.py +++ b/src/llama_recipes/utils/config_utils.py @@ -90,6 +90,7 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode): rank=dist.get_rank(), num_replicas=dist.get_world_size(), shuffle=mode=="train", + drop_last=True, ) kwargs["batch_size"] = batch_size kwargs["drop_last"] = True diff --git a/tests/conftest.py b/tests/conftest.py index 11a2bcd3c40fd09a9a4467aaa039e5c77dbad1a6..727b5ab28b2050bf8e90796cbd8737b3d609a88f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,21 +3,26 @@ import pytest -from transformers import LlamaTokenizer +from transformers import AutoTokenizer ACCESS_ERROR_MSG = "Could not access tokenizer at 'meta-llama/Llama-2-7b-hf'. Did you log into huggingface hub and provided the correct token?" +LLAMA_VERSIONS = ["meta-llama/Llama-2-7b-hf", "meta-llama/Llama-3-8b-hf"] + +@pytest.fixture(params=LLAMA_VERSIONS) +def llama_version(request): + return request.param @pytest.fixture(scope="module") -def llama_tokenizer(): - return LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") +def llama_tokenizer(request): + return {k: AutoTokenizer.from_pretrained(k) for k in LLAMA_VERSIONS} @pytest.fixture -def setup_tokenizer(llama_tokenizer): +def setup_tokenizer(llama_tokenizer, llama_version): def _helper(tokenizer_mock): #Align with Llama 2 tokenizer - tokenizer_mock.from_pretrained.return_value = llama_tokenizer + tokenizer_mock.from_pretrained.return_value = llama_tokenizer[llama_version] return _helper @@ -27,21 +32,21 @@ def pytest_addoption(parser): "--unskip-missing-tokenizer", action="store_true", default=False, help="disable skip missing tokenizer") - + def pytest_configure(config): config.addinivalue_line("markers", "skip_missing_tokenizer: skip if tokenizer is unavailable") - + def pytest_collection_modifyitems(config, items): if config.getoption("--unskip-missing-tokenizer"): return - + try: - LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") tokenizer_available = True except OSError: tokenizer_available = False - + skip_missing_tokenizer = pytest.mark.skip(reason=ACCESS_ERROR_MSG) for item in items: if "skip_missing_tokenizer" in item.keywords and not tokenizer_available: diff --git a/tests/datasets/test_custom_dataset.py b/tests/datasets/test_custom_dataset.py index d754aae427b7b3b29c1f87b410e41fd56c431774..5d6c22db4c705564f0b41b32e094b7e9e66ce876 100644 --- a/tests/datasets/test_custom_dataset.py +++ b/tests/datasets/test_custom_dataset.py @@ -6,32 +6,50 @@ from unittest.mock import patch from transformers import LlamaTokenizer -def check_padded_entry(batch): +EXPECTED_RESULTS={ + "meta-llama/Llama-2-7b-hf":{ + "example_1": "[INST] Who made Berlin [/INST] dunno", + "example_2": "[INST] Quiero preparar una pizza de pepperoni, puedes darme los pasos para hacerla? [/INST] Claro!", + }, + "meta-llama/Llama-3-8b-hf":{ + "example_1": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nWho made Berlin<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\ndunno<|eot_id|><|end_of_text|>", + "example_2": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow to start learning guitar and become a master at it?", + }, +} + +def check_padded_entry(batch, tokenizer): seq_len = sum(batch["attention_mask"][0]) assert seq_len < len(batch["attention_mask"][0]) + if tokenizer.vocab_size >= 128000: + END_OF_TEXT_ID = 128009 + else: + END_OF_TEXT_ID = tokenizer.eos_token_id + assert batch["labels"][0][0] == -100 - assert batch["labels"][0][seq_len-1] == 2 + assert batch["labels"][0][seq_len-1] == END_OF_TEXT_ID assert batch["labels"][0][-1] == -100 - assert batch["input_ids"][0][0] == 1 - assert batch["input_ids"][0][-1] == 2 + assert batch["input_ids"][0][0] == tokenizer.bos_token_id + assert batch["input_ids"][0][-1] == tokenizer.eos_token_id @pytest.mark.skip_missing_tokenizer @patch('llama_recipes.finetuning.train') -@patch('llama_recipes.finetuning.LlamaTokenizer') +@patch('llama_recipes.finetuning.AutoTokenizer') @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained') @patch('llama_recipes.finetuning.optim.AdamW') @patch('llama_recipes.finetuning.StepLR') -def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer): +def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer, llama_version): from llama_recipes.finetuning import main setup_tokenizer(tokenizer) + skip_special_tokens = llama_version == "meta-llama/Llama-2-7b-hf" + kwargs = { "dataset": "custom_dataset", - "model_name": "meta-llama/Llama-2-7b-hf", - "custom_dataset.file": "examples/custom_dataset.py", + "model_name": llama_version, + "custom_dataset.file": "recipes/finetuning/datasets/custom_dataset.py", "custom_dataset.train_split": "validation", "batch_size_training": 2, "val_batch_size": 4, @@ -53,34 +71,31 @@ def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, it = iter(eval_dataloader) batch = next(it) - STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True) - EXPECTED_STRING = "[INST] Who made Berlin [/INST] dunno" - assert STRING.startswith(EXPECTED_STRING) + STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=skip_special_tokens) + assert STRING.startswith(EXPECTED_RESULTS[llama_version]["example_1"]) assert batch["input_ids"].size(0) == 4 assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys()) - check_padded_entry(batch) + check_padded_entry(batch, tokenizer) it = iter(train_dataloader) - for _ in range(5): - next(it) + next(it) batch = next(it) - STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True) - EXPECTED_STRING = "[INST] How do I initialize a Typescript project using npm and git? [/INST] # Initialize a new NPM project" - assert STRING.startswith(EXPECTED_STRING) + STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=skip_special_tokens) + assert STRING.startswith(EXPECTED_RESULTS[llama_version]["example_2"]) assert batch["input_ids"].size(0) == 2 assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys()) - check_padded_entry(batch) + check_padded_entry(batch, tokenizer) @patch('llama_recipes.finetuning.train') @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained') -@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained') +@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained') @patch('llama_recipes.finetuning.optim.AdamW') @patch('llama_recipes.finetuning.StepLR') def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, train, mocker): @@ -90,7 +105,7 @@ def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, train, kwargs = { "dataset": "custom_dataset", - "custom_dataset.file": "examples/custom_dataset.py:get_unknown_dataset", + "custom_dataset.file": "recipes/finetuning/datasets/custom_dataset.py:get_unknown_dataset", "batch_size_training": 1, "use_peft": False, } diff --git a/tests/datasets/test_grammar_datasets.py b/tests/datasets/test_grammar_datasets.py index e04529b59a84fe93537112bb66b2add9773dc73b..5973eeff66e73d6326c5469f2815a3a07d5228de 100644 --- a/tests/datasets/test_grammar_datasets.py +++ b/tests/datasets/test_grammar_datasets.py @@ -4,23 +4,32 @@ import pytest from unittest.mock import patch -from transformers import LlamaTokenizer +EXPECTED_RESULTS = { + "meta-llama/Llama-2-7b-hf":{ + "label": 1152, + "pos": 31, + }, + "meta-llama/Llama-3-8b-hf":{ + "label": 40, + "pos": 26, + }, +} @pytest.mark.skip_missing_tokenizer @patch('llama_recipes.finetuning.train') -@patch('llama_recipes.finetuning.LlamaTokenizer') +@patch('llama_recipes.finetuning.AutoTokenizer') @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained') @patch('llama_recipes.finetuning.optim.AdamW') @patch('llama_recipes.finetuning.StepLR') -def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer): +def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version): from llama_recipes.finetuning import main setup_tokenizer(tokenizer) BATCH_SIZE = 8 kwargs = { - "model_name": "meta-llama/Llama-2-7b-hf", + "model_name": llama_version, "batch_size_training": BATCH_SIZE, "val_batch_size": 1, "use_peft": False, @@ -48,9 +57,10 @@ def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker assert "input_ids" in batch.keys() assert "attention_mask" in batch.keys() - assert batch["labels"][0][31] == -100 - assert batch["labels"][0][32] == 1152 + assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]-1] == -100 + assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]] == EXPECTED_RESULTS[llama_version]["label"] - assert batch["input_ids"][0][0] == 1 - assert batch["labels"][0][-1] == 2 - assert batch["input_ids"][0][-1] == 2 + token = args[3] + assert batch["input_ids"][0][0] == token.bos_token_id + assert batch["labels"][0][-1] == token.eos_token_id + assert batch["input_ids"][0][-1] == token.eos_token_id diff --git a/tests/datasets/test_samsum_datasets.py b/tests/datasets/test_samsum_datasets.py index 46fa3e19298064dd4a01575781b224dfcb1279a4..f83886bc45f85cb7350bf2a0fd5f33ec0831c71d 100644 --- a/tests/datasets/test_samsum_datasets.py +++ b/tests/datasets/test_samsum_datasets.py @@ -5,21 +5,31 @@ import pytest from functools import partial from unittest.mock import patch +EXPECTED_RESULTS = { + "meta-llama/Llama-2-7b-hf":{ + "label": 8432, + "pos": 242, + }, + "meta-llama/Llama-3-8b-hf":{ + "label": 2250, + "pos": 211, + }, +} @pytest.mark.skip_missing_tokenizer @patch('llama_recipes.finetuning.train') -@patch('llama_recipes.finetuning.LlamaTokenizer') +@patch('llama_recipes.finetuning.AutoTokenizer') @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained') @patch('llama_recipes.finetuning.optim.AdamW') @patch('llama_recipes.finetuning.StepLR') -def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer): +def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer, llama_version): from llama_recipes.finetuning import main setup_tokenizer(tokenizer) BATCH_SIZE = 8 kwargs = { - "model_name": "meta-llama/Llama-2-7b-hf", + "model_name": llama_version, "batch_size_training": BATCH_SIZE, "val_batch_size": 1, "use_peft": False, @@ -34,6 +44,7 @@ def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, args, kwargs = train.call_args train_dataloader = args[1] eval_dataloader = args[2] + token = args[3] VAL_SAMPLES = 818 TRAIN_SAMPLES = 14732 @@ -47,9 +58,9 @@ def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, assert "input_ids" in batch.keys() assert "attention_mask" in batch.keys() - assert batch["labels"][0][268] == -100 - assert batch["labels"][0][269] == 319 + assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]-1] == -100 + assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]] == EXPECTED_RESULTS[llama_version]["label"] - assert batch["input_ids"][0][0] == 1 - assert batch["labels"][0][-1] == 2 - assert batch["input_ids"][0][-1] == 2 + assert batch["input_ids"][0][0] == token.bos_token_id + assert batch["labels"][0][-1] == token.eos_token_id + assert batch["input_ids"][0][-1] == token.eos_token_id diff --git a/tests/test_batching.py b/tests/test_batching.py index c5e335645929e8de32c6a375d9bb2c3310c8b2d0..ef06211c21560fc857d4b6a70675353621fdadd3 100644 --- a/tests/test_batching.py +++ b/tests/test_batching.py @@ -4,20 +4,30 @@ import pytest from unittest.mock import patch +EXPECTED_SAMPLE_NUMBER ={ + "meta-llama/Llama-2-7b-hf": { + "train": 96, + "eval": 42, + }, + "meta-llama/Llama-3-8b-hf": { + "train": 79, + "eval": 34, + } +} @pytest.mark.skip_missing_tokenizer @patch('llama_recipes.finetuning.train') -@patch('llama_recipes.finetuning.LlamaTokenizer') +@patch('llama_recipes.finetuning.AutoTokenizer') @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained') @patch('llama_recipes.finetuning.optim.AdamW') @patch('llama_recipes.finetuning.StepLR') -def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer): +def test_packing(step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version): from llama_recipes.finetuning import main setup_tokenizer(tokenizer) kwargs = { - "model_name": "meta-llama/Llama-2-7b-hf", + "model_name": llama_version, "batch_size_training": 8, "val_batch_size": 1, "use_peft": False, @@ -33,8 +43,8 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_ train_dataloader = args[1] eval_dataloader = args[2] - assert len(train_dataloader) == 96 - assert len(eval_dataloader) == 42 + assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"] + assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"] batch = next(iter(train_dataloader)) @@ -49,7 +59,7 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_ @pytest.mark.skip_missing_tokenizer @patch('llama_recipes.finetuning.train') -@patch('llama_recipes.finetuning.LlamaTokenizer') +@patch('llama_recipes.finetuning.AutoTokenizer') @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained') @patch('llama_recipes.finetuning.optim.AdamW') @patch('llama_recipes.finetuning.StepLR') @@ -57,13 +67,13 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_ @patch('llama_recipes.finetuning.FSDP') @patch('llama_recipes.finetuning.torch.distributed.is_initialized') @patch('llama_recipes.utils.config_utils.dist') -def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer): +def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version): import os from llama_recipes.finetuning import main setup_tokenizer(tokenizer) - rank = 0 + rank = 1 os.environ['LOCAL_RANK'] = f'{rank}' os.environ['RANK'] = f'{rank}' os.environ['WORLD_SIZE'] = '2' @@ -71,7 +81,7 @@ def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimiz os.environ['MASTER_PORT'] = '12345' kwargs = { - "model_name": "meta-llama/Llama-2-7b-hf", + "model_name": llama_version, "batch_size_training": 8, "val_batch_size": 1, "use_peft": False, @@ -92,5 +102,5 @@ def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimiz train_dataloader = args[1] eval_dataloader = args[2] - assert len(train_dataloader) == 96 //2 - assert len(eval_dataloader) == 42 //2 + assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"] //2 + assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"] //2 diff --git a/tests/test_chat_completion.py b/tests/test_chat_completion.py new file mode 100644 index 0000000000000000000000000000000000000000..c145d76d102b9ae80f7659cbede4efac992789cb --- /dev/null +++ b/tests/test_chat_completion.py @@ -0,0 +1,155 @@ +import sys +from pathlib import Path +from typing import List, Literal, TypedDict +from unittest.mock import patch + +import pytest +import torch +from llama_recipes.inference.chat_utils import read_dialogs_from_file + +ROOT_DIR = Path(__file__).parents[1] +CHAT_COMPLETION_DIR = ROOT_DIR / "recipes/inference/local_inference/chat_completion/" + +sys.path = [CHAT_COMPLETION_DIR.as_posix()] + sys.path + +Role = Literal["user", "assistant"] + + +class Message(TypedDict): + role: Role + content: str + + +Dialog = List[Message] + +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n" + + +def _encode_header(message, tokenizer): + tokens = [] + tokens.extend(tokenizer.encode("<|start_header_id|>")) + tokens.extend(tokenizer.encode(message["role"])) + tokens.extend(tokenizer.encode("<|end_header_id|>")) + tokens.extend(tokenizer.encode("\n\n")) + return tokens + + +def _encode_message(message, tokenizer): + tokens = _encode_header(message, tokenizer) + tokens.extend(tokenizer.encode(message["content"].strip())) + tokens.extend(tokenizer.encode("<|eot_id|>")) + return tokens + + +def _format_dialog(dialog, tokenizer): + tokens = [] + tokens.extend(tokenizer.encode("<|begin_of_text|>")) + for msg in dialog: + tokens.extend(_encode_message(msg, tokenizer)) + tokens.extend(_encode_header({"role": "assistant", "content": ""}, tokenizer)) + return tokens + + +def _format_tokens_llama3(dialogs, tokenizer): + return [_format_dialog(dialog, tokenizer) for dialog in dialogs] + + +def _format_tokens_llama2(dialogs, tokenizer): + prompt_tokens = [] + for dialog in dialogs: + if dialog[0]["role"] == "system": + dialog = [ + { + "role": dialog[1]["role"], + "content": B_SYS + + dialog[0]["content"] + + E_SYS + + dialog[1]["content"], + } + ] + dialog[2:] + assert all([msg["role"] == "user" for msg in dialog[::2]]) and all( + [msg["role"] == "assistant" for msg in dialog[1::2]] + ), ( + "model only supports 'system','user' and 'assistant' roles, " + "starting with user and alternating (u/a/u/a/u...)" + ) + """ + Please verify that your tokenizer support adding "[INST]", "[/INST]" to your inputs. + Here, we are adding it manually. + """ + dialog_tokens: List[int] = sum( + [ + tokenizer.encode( + f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ", + ) + + [tokenizer.eos_token_id] + for prompt, answer in zip(dialog[::2], dialog[1::2]) + ], + [], + ) + assert ( + dialog[-1]["role"] == "user" + ), f"Last message must be from user, got {dialog[-1]['role']}" + dialog_tokens += tokenizer.encode( + f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}", + ) + prompt_tokens.append(dialog_tokens) + return prompt_tokens + + +@pytest.mark.skip_missing_tokenizer +@patch("chat_completion.AutoTokenizer") +@patch("chat_completion.load_model") +def test_chat_completion( + load_model, tokenizer, setup_tokenizer, llama_tokenizer, llama_version +): + from chat_completion import main + + setup_tokenizer(tokenizer) + + kwargs = { + "prompt_file": (CHAT_COMPLETION_DIR / "chats.json").as_posix(), + } + + main(llama_version, **kwargs) + + dialogs = read_dialogs_from_file(kwargs["prompt_file"]) + format_tokens = ( + _format_tokens_llama2 + if llama_version == "meta-llama/Llama-2-7b-hf" + else _format_tokens_llama3 + ) + + REF_RESULT = format_tokens(dialogs, llama_tokenizer[llama_version]) + + assert all( + ( + load_model.return_value.generate.mock_calls[0 * 4][2]["input_ids"].cpu() + == torch.tensor(REF_RESULT[0]).long() + ).tolist() + ) + assert all( + ( + load_model.return_value.generate.mock_calls[1 * 4][2]["input_ids"].cpu() + == torch.tensor(REF_RESULT[1]).long() + ).tolist() + ) + assert all( + ( + load_model.return_value.generate.mock_calls[2 * 4][2]["input_ids"].cpu() + == torch.tensor(REF_RESULT[2]).long() + ).tolist() + ) + assert all( + ( + load_model.return_value.generate.mock_calls[3 * 4][2]["input_ids"].cpu() + == torch.tensor(REF_RESULT[3]).long() + ).tolist() + ) + assert all( + ( + load_model.return_value.generate.mock_calls[4 * 4][2]["input_ids"].cpu() + == torch.tensor(REF_RESULT[4]).long() + ).tolist() + ) diff --git a/tests/test_finetuning.py b/tests/test_finetuning.py index 99fca477e88bd233cbb27598d21a7e35fcb3c019..ce4108e4ac4a37afd0e618705a81ddc2a8826356 100644 --- a/tests/test_finetuning.py +++ b/tests/test_finetuning.py @@ -21,17 +21,19 @@ def get_fake_dataset(): "labels":[1], }] - +@patch('llama_recipes.finetuning.torch.cuda.is_available') @patch('llama_recipes.finetuning.train') @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained') -@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained') +@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained') @patch('llama_recipes.finetuning.get_preprocessed_dataset') @patch('llama_recipes.finetuning.optim.AdamW') @patch('llama_recipes.finetuning.StepLR') -def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train): +@pytest.mark.parametrize("cuda_is_available", [True, False]) +def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available): kwargs = {"run_validation": False} get_dataset.return_value = get_fake_dataset() + cuda.return_value = cuda_is_available main(**kwargs) @@ -44,23 +46,26 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge assert isinstance(train_dataloader, DataLoader) assert eval_dataloader is None - if torch.cuda.is_available(): + if cuda_is_available: assert get_model.return_value.to.call_count == 1 assert get_model.return_value.to.call_args.args[0] == "cuda" else: assert get_model.return_value.to.call_count == 0 +@patch('llama_recipes.finetuning.torch.cuda.is_available') @patch('llama_recipes.finetuning.train') @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained') -@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained') +@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained') @patch('llama_recipes.finetuning.get_preprocessed_dataset') @patch('llama_recipes.finetuning.optim.AdamW') @patch('llama_recipes.finetuning.StepLR') -def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train): +@pytest.mark.parametrize("cuda_is_available", [True, False]) +def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available): kwargs = {"run_validation": True} get_dataset.return_value = get_fake_dataset() + cuda.return_value = cuda_is_available main(**kwargs) @@ -72,40 +77,42 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, assert isinstance(train_dataloader, DataLoader) assert isinstance(eval_dataloader, DataLoader) - if torch.cuda.is_available(): + if cuda_is_available: assert get_model.return_value.to.call_count == 1 assert get_model.return_value.to.call_args.args[0] == "cuda" else: assert get_model.return_value.to.call_count == 0 - +@patch('llama_recipes.finetuning.torch.cuda.is_available') @patch('llama_recipes.finetuning.train') @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained') -@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained') +@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained') @patch('llama_recipes.finetuning.get_preprocessed_dataset') @patch('llama_recipes.finetuning.generate_peft_config') @patch('llama_recipes.finetuning.get_peft_model') @patch('llama_recipes.finetuning.optim.AdamW') @patch('llama_recipes.finetuning.StepLR') -def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train): +@pytest.mark.parametrize("cuda_is_available", [True, False]) +def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available): kwargs = {"use_peft": True} get_dataset.return_value = get_fake_dataset() + cuda.return_value = cuda_is_available main(**kwargs) - if torch.cuda.is_available(): - assert get_model.return_value.to.call_count == 1 - assert get_model.return_value.to.call_args.args[0] == "cuda" + if cuda_is_available: + assert get_peft_model.return_value.to.call_count == 1 + assert get_peft_model.return_value.to.call_args.args[0] == "cuda" else: - assert get_model.return_value.to.call_count == 0 - + assert get_peft_model.return_value.to.call_count == 0 + assert get_peft_model.return_value.print_trainable_parameters.call_count == 1 @patch('llama_recipes.finetuning.train') @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained') -@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained') +@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained') @patch('llama_recipes.finetuning.get_preprocessed_dataset') @patch('llama_recipes.finetuning.get_peft_model') @patch('llama_recipes.finetuning.StepLR') @@ -113,11 +120,11 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer kwargs = {"weight_decay": 0.01} get_dataset.return_value = get_fake_dataset() - + model = mocker.MagicMock(name="Model") model.parameters.return_value = [torch.ones(1,1)] - get_model.return_value = model + get_model.return_value = model main(**kwargs) @@ -134,7 +141,7 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer @patch('llama_recipes.finetuning.train') @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained') -@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained') +@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained') @patch('llama_recipes.finetuning.get_preprocessed_dataset') @patch('llama_recipes.finetuning.optim.AdamW') @patch('llama_recipes.finetuning.StepLR')