diff --git a/examples/chat_completion/chat_completion.py b/examples/chat_completion/chat_completion.py
index e8665cd948201ed2f35970234b613a4d0739448d..8c2636c33c1eccb5c64840a3a78e7441af561862 100644
--- a/examples/chat_completion/chat_completion.py
+++ b/examples/chat_completion/chat_completion.py
@@ -55,7 +55,10 @@ def main(
     # Set the seeds for reproducibility
-    torch.cuda.manual_seed(seed)
+    from accelerate.utils import is_xpu_available  # needed for the XPU branch below
+    if is_xpu_available():
+        torch.xpu.manual_seed(seed)
+    else:
+        torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
     model = load_model(model_name, quantization)
     if peft_model:
@@ -105,7 +108,10 @@ def main(
                 sys.exit(1)  # Exit the program with an error status
             tokens= torch.tensor(chat).long()
             tokens= tokens.unsqueeze(0)
-            tokens= tokens.to("cuda:0")
+            if is_xpu_available():
+                tokens= tokens.to("xpu:0")
+            else:
+                tokens= tokens.to("cuda:0")
             outputs = model.generate(
                 input_ids=tokens,
                 max_new_tokens=max_new_tokens,
diff --git a/examples/inference.py b/examples/inference.py
index 28952920178ab4ba87c714d68b4c142743ded365..f872238b3db62bbc5cc3b98d23938e3fde8a8c67 100644
--- a/examples/inference.py
+++ b/examples/inference.py
@@ -10,10 +10,16 @@ import time
 import torch
 from transformers import LlamaTokenizer
 
 from llama_recipes.inference.safety_utils import get_safety_checker
 from llama_recipes.inference.model_utils import load_model, load_peft_model
+from accelerate.utils import is_xpu_available
 
 def main(
     model_name,
@@ -50,7 +56,10 @@ def main(
         sys.exit(1)
 
     # Set the seeds for reproducibility
-    torch.cuda.manual_seed(seed)
+    if is_xpu_available():
+        torch.xpu.manual_seed(seed)
+    else:
+        torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
 
     model = load_model(model_name, quantization)
@@ -102,7 +111,15 @@ def main(
     batch = tokenizer(user_prompt, padding='max_length', truncation=True, max_length=max_padding_length, return_tensors="pt")
-    batch = {k: v.to("cuda") for k, v in batch.items()}
+    if is_xpu_available():
+        batch = {k: v.to("xpu") for k, v in batch.items()}
+    else:
+        batch = {k: v.to("cuda") for k, v in batch.items()}
     start = time.perf_counter()
     with torch.no_grad():
         outputs = model.generate(
diff --git a/examples/vllm/inference.py b/examples/vllm/inference.py
index e587bc038ceca65359c0573a9b3a719943d26843..39a785766391887e3bf30e4a81c3956e77fb9e7a 100644
--- a/examples/vllm/inference.py
+++ b/examples/vllm/inference.py
@@ -6,9 +6,13 @@ import fire
 import torch
 from vllm import LLM
 from vllm import LLM, SamplingParams
+from accelerate.utils import is_xpu_available
 
+if is_xpu_available():
+    torch.xpu.manual_seed(42)
+else:
+    torch.cuda.manual_seed(42)
-torch.cuda.manual_seed(42)
 torch.manual_seed(42)
 
 
 def load_model(model_name, tp_size=1):
diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index a475c1cad42994ee8517db27548156bf1e35703f..a81f8c02cbd4af73f64c498b2ebbaa429e000485 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -42,6 +42,7 @@ from llama_recipes.utils.train_utils import (
     print_model_size,
     get_policies
 )
+from accelerate.utils import is_xpu_available
 
 
 def main(**kwargs):
@@ -49,7 +50,10 @@ def main(**kwargs):
     update_config((train_config, fsdp_config), **kwargs)
 
     # Set the seeds for reproducibility
-    torch.cuda.manual_seed(train_config.seed)
+    if is_xpu_available():
+        torch.xpu.manual_seed(train_config.seed)
+    else:
+        torch.cuda.manual_seed(train_config.seed)
     torch.manual_seed(train_config.seed)
 
     if train_config.enable_fsdp:
@@ -60,7 +64,10 @@ def main(**kwargs):
         world_size = int(os.environ["WORLD_SIZE"])
 
     if torch.distributed.is_initialized():
-        torch.cuda.set_device(local_rank)
+        if is_xpu_available():
+            torch.xpu.set_device(local_rank)
+        else:
+            torch.cuda.set_device(local_rank)
         clear_gpu_cache(local_rank)
         setup_environ_flags(rank)
 
@@ -146,7 +153,7 @@ def main(**kwargs):
            auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
            mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
            sharding_strategy=fsdp_config.sharding_strategy,
-           device_id=torch.cuda.current_device(),
+           device_id=torch.xpu.current_device() if is_xpu_available() else torch.cuda.current_device(),
            limit_all_gathers=True,
            sync_module_states=train_config.low_cpu_fsdp,
            param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
@@ -155,7 +162,10 @@ def main(**kwargs):
        if fsdp_config.fsdp_activation_checkpointing:
            apply_fsdp_checkpointing(model)
    elif not train_config.quantization and not train_config.enable_fsdp:
-        model.to("cuda")
+        if is_xpu_available():
+            model.to("xpu:0")
+        else:
+            model.to("cuda")
 
     dataset_config = generate_dataset_config(train_config, kwargs)
 
diff --git a/src/llama_recipes/utils/memory_utils.py b/src/llama_recipes/utils/memory_utils.py
index 725f2b0d5f7a5ec9551108f7bcd7e5c166af7af1..26f6851b2a1d35e1979de8a4049ef607d82d2198 100644
--- a/src/llama_recipes/utils/memory_utils.py
+++ b/src/llama_recipes/utils/memory_utils.py
@@ -6,6 +6,7 @@ import psutil
 import threading
 
 import torch
+from accelerate.utils import is_xpu_available
 
 def byte2gb(x):
     return int(x / 2**30)
@@ -13,9 +14,14 @@ def byte2gb(x):
 class MemoryTrace:
     def __enter__(self):
         gc.collect()
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
-        self.begin = byte2gb(torch.cuda.memory_allocated())
+        if is_xpu_available():
+            torch.xpu.empty_cache()
+            torch.xpu.reset_max_memory_allocated()  # reset the peak gauge to zero
+            self.begin = byte2gb(torch.xpu.memory_allocated())
+        elif torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
+            self.begin = byte2gb(torch.cuda.memory_allocated())
         self.process = psutil.Process()
         self.cpu_begin = byte2gb(self.cpu_mem_used())
         self.peak_monitoring = True
@@ -44,17 +50,30 @@ class MemoryTrace:
         self.peak_monitoring = False
 
         gc.collect()
-        torch.cuda.empty_cache()
-        self.end = byte2gb(torch.cuda.memory_allocated())
-        self.peak = byte2gb(torch.cuda.max_memory_allocated())
-        cuda_info = torch.cuda.memory_stats()
-        self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
-        self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
-        self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
-        self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
-        self.used = byte2gb(self.end - self.begin)
-        self.peaked = byte2gb(self.peak - self.begin)
-        self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())
+        if is_xpu_available():
+            torch.xpu.empty_cache()
+            self.end = byte2gb(torch.xpu.memory_allocated())
+            self.peak = byte2gb(torch.xpu.max_memory_allocated())
+            xpu_info = torch.xpu.memory_stats()
+            self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
+            self.xpu_malloc_retires = xpu_info.get("num_alloc_retries", 0)
+            self.m_xpu_ooms = xpu_info.get("num_ooms", 0)
+            self.used = byte2gb(self.end - self.begin)
+            self.peaked = byte2gb(self.peak - self.begin)
+            self.max_reserved = byte2gb(torch.xpu.max_memory_reserved())
+        else:
+            torch.cuda.empty_cache()
+            self.end = byte2gb(torch.cuda.memory_allocated())
+            self.peak = byte2gb(torch.cuda.max_memory_allocated())
+            cuda_info = torch.cuda.memory_stats()
+            self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
+            self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
+            self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
+            self.used = byte2gb(self.end - self.begin)
+            self.peaked = byte2gb(self.peak - self.begin)
+            self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())
 
         self.cpu_end = self.cpu_mem_used()
         self.cpu_used = byte2gb(self.cpu_end - self.cpu_begin)
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index c3790651a6bcc691c3227e588d349294a60ad4da..9497d2b5c968239e300b2cd682d88a656b458b6b 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -20,6 +20,7 @@ from transformers import LlamaTokenizer
 from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
 from llama_recipes.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
+from accelerate.utils import is_xpu_available, is_ccl_available
 
 
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
@@ -101,7 +102,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         epoch_end_time = time.perf_counter()-epoch_start_time
         epoch_times.append(epoch_end_time)
         # Reducing total_loss across all devices if there's more than one CUDA device
-        if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
+        if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
+            dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
+        elif torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
         train_epoch_loss = total_loss / len(train_dataloader)
         if train_config.enable_fsdp:
@@ -113,17 +116,29 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
         if train_config.enable_fsdp:
             if rank==0:
+                if is_xpu_available():
+                    print(f"Max XPU memory allocated was {memtrace.peak} GB")
+                    print(f"Max XPU memory reserved was {memtrace.max_reserved} GB")
+                    print(f"Peak active XPU memory was {memtrace.peak_active_gb} GB")
+                    print(f"Xpu Malloc retires : {memtrace.xpu_malloc_retires}")
+                else:
+                    print(f"Max CUDA memory allocated was {memtrace.peak} GB")
+                    print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
+                    print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
+                    print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
+                print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
+        else:
+            if is_xpu_available():
+                print(f"Max XPU memory allocated was {memtrace.peak} GB")
+                print(f"Max XPU memory reserved was {memtrace.max_reserved} GB")
+                print(f"Peak active XPU memory was {memtrace.peak_active_gb} GB")
+                print(f"Xpu Malloc retires : {memtrace.xpu_malloc_retires}")
+            else:
                 print(f"Max CUDA memory allocated was {memtrace.peak} GB")
                 print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
                 print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
                 print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
                 print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
-        else:
-            print(f"Max CUDA memory allocated was {memtrace.peak} GB")
-            print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
-            print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
-            print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
-            print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
 
         # Update the learning rate as needed
         lr_scheduler.step()
 
@@ -246,6 +261,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
             )
 
     # If there's more than one CUDA device, reduce evaluation loss across all devices
+    if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
+        dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
     if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
 
@@ -279,7 +296,11 @@ def check_frozen_layers_peft_model(model):
 
 def setup():
     """Initialize the process group for distributed training"""
-    dist.init_process_group("nccl")
+    if is_ccl_available():
+        # distributed training on XPUs
+        dist.init_process_group("ccl")
+    else:
+        dist.init_process_group("nccl")
 
 
 def setup_environ_flags(rank):
@@ -303,7 +324,10 @@ def clear_gpu_cache(rank=None):
     """Clear the GPU cache for all ranks"""
     if rank == 0:
         print(f"Clearing GPU cache for all ranks")
-    torch.cuda.empty_cache()
+    if is_xpu_available():
+        torch.xpu.empty_cache()
+    else:
+        torch.cuda.empty_cache()
 
 
 def get_parameter_dtypes(model):
@@ -335,13 +359,14 @@ def print_model_size(model, config, rank: int = 0) -> None:
 
 def get_policies(cfg, rank):
     """Get the policies for mixed precision and fsdp wrapping"""
-    verify_bfloat_support = (
+    verify_bfloat_support = ((
     torch.version.cuda
     and torch.cuda.is_bf16_supported()
     and packaging.version.parse(torch.version.cuda).release >= (11, 0)
     and dist.is_nccl_available()
     and nccl.version() >= (2, 10)
-    )
+    ) or
+    (is_xpu_available()))
 
     mixed_precision_policy = None
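
Note: every file above repeats the same is_xpu_available() dispatch for seeding, device placement, and cache clearing. Below is a minimal, self-contained sketch of that pattern, not part of the patch itself; it assumes accelerate is installed and that torch.xpu.* is provided by Intel Extension for PyTorch on XPU machines, and the helper names (get_device, seed_everything, empty_accelerator_cache) are illustrative only.

# Illustrative helpers mirroring the device-dispatch pattern used in this patch.
# Assumes `accelerate` is installed; torch.xpu.* requires intel_extension_for_pytorch.
import torch
from accelerate.utils import is_xpu_available


def get_device(index: int = 0) -> str:
    """Pick the accelerator device string: XPU first, then CUDA, else CPU."""
    if is_xpu_available():
        return f"xpu:{index}"
    if torch.cuda.is_available():
        return f"cuda:{index}"
    return "cpu"


def seed_everything(seed: int) -> None:
    """Seed the CPU RNG plus whichever accelerator backend is present."""
    torch.manual_seed(seed)
    if is_xpu_available():
        torch.xpu.manual_seed(seed)
    elif torch.cuda.is_available():
        torch.cuda.manual_seed(seed)


def empty_accelerator_cache() -> None:
    """Release cached allocator memory on the active accelerator, if any."""
    if is_xpu_available():
        torch.xpu.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()


# Example usage:
#   seed_everything(42)
#   model.to(get_device())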