Commit 33da341a authored by abhilash1910

upstream resolve conflict

parent 81fecf3d
@@ -55,7 +55,10 @@ def main(
    # Set the seeds for reproducibility
    torch.cuda.manual_seed(seed)
    if is_xpu_available():
        torch.xpu.manual_seed(seed)
    else:
        torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)
    model = load_model(model_name, quantization)
    if peft_model:
@@ -105,7 +108,10 @@ def main(
    sys.exit(1) # Exit the program with an error status
    tokens= torch.tensor(chat).long()
    tokens= tokens.unsqueeze(0)
    tokens= tokens.to("cuda:0")
    if is_xpu_available():
        tokens= tokens.to("xpu:0")
    else:
        tokens= tokens.to("cuda:0")
    outputs = model.generate(
        input_ids=tokens,
        max_new_tokens=max_new_tokens,
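
The two hunks above repeat the same dispatch: seed and place tensors on an Intel XPU when `is_xpu_available()` reports one, otherwise fall back to CUDA. A minimal sketch of that pattern factored into helpers (the helper names are illustrative, not part of llama-recipes):

import torch
from accelerate.utils import is_xpu_available

def get_device() -> str:
    # Illustrative helper: prefer the Intel XPU backend when present,
    # otherwise fall back to the first CUDA device.
    return "xpu:0" if is_xpu_available() else "cuda:0"

def seed_accelerator(seed: int) -> None:
    # Mirrors the seeding added above: seed the active accelerator RNG,
    # then the CPU RNG.
    if is_xpu_available():
        torch.xpu.manual_seed(seed)
    else:
        torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)

# Usage in the example would reduce to:
#   seed_accelerator(seed)
#   tokens = tokens.to(get_device())
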
@@ -10,10 +10,16 @@ import time
import torch
from transformers import LlamaTokenizer
<<<<<<< HEAD:examples/inference.py
from llama_recipes.inference.safety_utils import get_safety_checker
from llama_recipes.inference.model_utils import load_model, load_peft_model
=======
from safety_utils import get_safety_checker
from model_utils import load_model, load_peft_model, load_llama_from_config
from accelerate.utils import is_xpu_available
>>>>>>> ed7ba99 (enable xpu finetuning and inference):inference/inference.py
def main(
    model_name,
@@ -50,7 +56,10 @@ def main(
    sys.exit(1)
    # Set the seeds for reproducibility
    torch.cuda.manual_seed(seed)
    if is_xpu_available():
        torch.xpu.manual_seed(seed)
    else:
        torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)
    model = load_model(model_name, quantization)
@@ -102,7 +111,15 @@ def main(
    batch = tokenizer(user_prompt, padding='max_length', truncation=True, max_length=max_padding_length, return_tensors="pt")
<<<<<<< HEAD:examples/inference.py
    batch = {k: v.to("cuda") for k, v in batch.items()}
=======
    batch = tokenizer(user_prompt, return_tensors="pt")
    if is_xpu_available():
        batch = {k: v.to("xpu") for k, v in batch.items()}
    else:
        batch = {k: v.to("cuda") for k, v in batch.items()}
>>>>>>> ed7ba99 (enable xpu finetuning and inference):inference/inference.py
    start = time.perf_counter()
    with torch.no_grad():
        outputs = model.generate(
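
Both hunks above still carry Git conflict markers from the `ed7ba99` merge; the diff records them exactly as committed. A plausible resolution, assuming the package-style `llama_recipes` imports from the HEAD side are kept and only the device dispatch from the XPU branch is merged in, would reduce the batch placement to a helper like this (hypothetical, for illustration):

from accelerate.utils import is_xpu_available

def move_batch_to_device(batch: dict) -> dict:
    # Place every tensor produced by the tokenizer on the XPU when one is
    # available, otherwise on the default CUDA device.
    device = "xpu" if is_xpu_available() else "cuda"
    return {k: v.to(device) for k, v in batch.items()}

# e.g. batch = move_batch_to_device(tokenizer(user_prompt, return_tensors="pt"))
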
@@ -6,9 +6,13 @@ import fire
import torch
from vllm import LLM
from vllm import LLM, SamplingParams
from accelerate.utils import is_xpu_available
if is_xpu_available():
    torch.xpu.manual_seed(42)
else:
    torch.cuda.manual_seed(42)
torch.cuda.manual_seed(42)
torch.manual_seed(42)
def load_model(model_name, tp_size=1):
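
The vLLM hunk above widens the import to include `SamplingParams` alongside `LLM` and makes the seeding device-aware. A short usage sketch of that import, assuming vLLM is installed and using a placeholder model name:

from vllm import LLM, SamplingParams

# Placeholder model name, for illustration only.
llm = LLM(model="meta-llama/Llama-2-7b-hf", tensor_parallel_size=1)
sampling = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)
outputs = llm.generate(["I believe the meaning of life is"], sampling)
print(outputs[0].outputs[0].text)
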
@@ -42,6 +42,7 @@ from llama_recipes.utils.train_utils import (
    print_model_size,
    get_policies
)
from accelerate.utils import is_xpu_available
def main(**kwargs):
@@ -49,7 +50,10 @@ def main(**kwargs):
    update_config((train_config, fsdp_config), **kwargs)
    # Set the seeds for reproducibility
    torch.cuda.manual_seed(train_config.seed)
    if is_xpu_available():
        torch.xpu.manual_seed(train_config.seed)
    else:
        torch.cuda.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)
    if train_config.enable_fsdp:
@@ -60,7 +64,10 @@ def main(**kwargs):
    world_size = int(os.environ["WORLD_SIZE"])
    if torch.distributed.is_initialized():
        torch.cuda.set_device(local_rank)
        if is_xpu_available():
            torch.xpu.set_device(local_rank)
        else:
            torch.cuda.set_device(local_rank)
        clear_gpu_cache(local_rank)
        setup_environ_flags(rank)
@@ -146,7 +153,7 @@ def main(**kwargs):
        auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
        mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
        sharding_strategy=fsdp_config.sharding_strategy,
        device_id=torch.cuda.current_device(),
        device_id=torch.xpu.current_device() if is_xpu_available() else torch.cuda.current_device(),
        limit_all_gathers=True,
        sync_module_states=train_config.low_cpu_fsdp,
        param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
@@ -155,7 +162,10 @@ def main(**kwargs):
    if fsdp_config.fsdp_activation_checkpointing:
        apply_fsdp_checkpointing(model)
    elif not train_config.quantization and not train_config.enable_fsdp:
        model.to("cuda")
        if is_xpu_available():
            model.to("xpu:0")
        else:
            model.to("cuda")
    dataset_config = generate_dataset_config(train_config, kwargs)
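
In the finetuning hunks above, the same `is_xpu_available()` check selects the device for seeding, `set_device`, the FSDP `device_id`, and the final `model.to(...)`. A compact sketch of the device-index selection (the helper name is illustrative, not part of the repo):

import torch
from accelerate.utils import is_xpu_available

def current_accelerator_index() -> int:
    # Mirrors the device_id expression passed to FSDP above: the index of
    # the accelerator bound to this process.
    if is_xpu_available():
        return torch.xpu.current_device()
    return torch.cuda.current_device()

# e.g. FSDP(..., device_id=current_accelerator_index(), ...)
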
@@ -6,6 +6,7 @@ import psutil
import threading
import torch
from accelerate.utils import is_xpu_available
def byte2gb(x):
    return int(x / 2**30)
@@ -13,9 +14,14 @@ def byte2gb(x):
class MemoryTrace:
    def __enter__(self):
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero
        self.begin = byte2gb(torch.cuda.memory_allocated())
        if is_xpu_available():
            torch.xpu.empty_cache()
            torch.xpu.reset_max_memory_allocated() # reset the peak gauge to zero
            self.begin = byte2gb(torch.xpu.memory_allocated())
        elif torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero
            self.begin = byte2gb(torch.cuda.memory_allocated())
        self.process = psutil.Process()
        self.cpu_begin = byte2gb(self.cpu_mem_used())
        self.peak_monitoring = True
@@ -44,17 +50,30 @@ class MemoryTrace:
        self.peak_monitoring = False
        gc.collect()
        torch.cuda.empty_cache()
        self.end = byte2gb(torch.cuda.memory_allocated())
        self.peak = byte2gb(torch.cuda.max_memory_allocated())
        cuda_info = torch.cuda.memory_stats()
        self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
        self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
        self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
        self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
        self.used = byte2gb(self.end - self.begin)
        self.peaked = byte2gb(self.peak - self.begin)
        self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())
        if is_xpu_available():
            torch.xpu.empty_cache()
            self.end = byte2gb(torch.xpu.memory_allocated())
            self.peak = byte2gb(torch.xpu.max_memory_allocated())
            xpu_info = torch.xpu.memory_stats()
            self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
            self.xpu_malloc_retires = xpu_info.get("num_alloc_retries", 0)
            self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
            self.m_xpu_ooms = xpu_info.get("num_ooms", 0)
            self.used = byte2gb(self.end - self.begin)
            self.peaked = byte2gb(self.peak - self.begin)
            self.max_reserved = byte2gb(torch.xpu.max_memory_reserved())
        else:
            torch.cuda.empty_cache()
            self.end = byte2gb(torch.cuda.memory_allocated())
            self.peak = byte2gb(torch.cuda.max_memory_allocated())
            cuda_info = torch.cuda.memory_stats()
            self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
            self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
            self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
            self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
            self.used = byte2gb(self.end - self.begin)
            self.peaked = byte2gb(self.peak - self.begin)
            self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())
        self.cpu_end = self.cpu_mem_used()
        self.cpu_used = byte2gb(self.cpu_end - self.cpu_begin)
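
The `MemoryTrace` changes above duplicate the CUDA bookkeeping for XPU. One way to express the same statistics without the twin branches is to pick the backend module once; a minimal sketch, assuming `torch.xpu` exposes the same memory API that the diff already calls:

import torch
from accelerate.utils import is_xpu_available

def accelerator_memory_summary() -> dict:
    # Collect the same figures MemoryTrace records, in GB, from whichever
    # backend is active. Uses only the calls that appear in the diff above.
    backend = torch.xpu if is_xpu_available() else torch.cuda
    stats = backend.memory_stats()
    gb = 2**30
    return {
        "allocated_gb": backend.memory_allocated() // gb,
        "max_allocated_gb": backend.max_memory_allocated() // gb,
        "max_reserved_gb": backend.max_memory_reserved() // gb,
        "peak_active_gb": stats["active_bytes.all.peak"] // gb,
        "malloc_retries": stats.get("num_alloc_retries", 0),
        "ooms": stats.get("num_ooms", 0),
    }
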
@@ -20,6 +20,7 @@ from transformers import LlamaTokenizer
from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
from llama_recipes.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
from llama_recipes.utils.memory_utils import MemoryTrace
from accelerate.utils import is_xpu_available, is_ccl_available
def set_tokenizer_params(tokenizer: LlamaTokenizer):
@@ -101,7 +102,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
    epoch_end_time = time.perf_counter()-epoch_start_time
    epoch_times.append(epoch_end_time)
    # Reducing total_loss across all devices if there's more than one CUDA device
    if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
    if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
        dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
    elif torch.cuda.device_count() > 1 and train_config.enable_fsdp:
        dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
    train_epoch_loss = total_loss / len(train_dataloader)
    if train_config.enable_fsdp:
@@ -113,17 +116,29 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
    if train_config.enable_fsdp:
        if rank==0:
            if is_xpu_available():
                print(f"Max XPU memory allocated was {memtrace.peak} GB")
                print(f"Max XPU memory reserved was {memtrace.max_reserved} GB")
                print(f"Peak active XPU memory was {memtrace.peak_active_gb} GB")
                print(f"Xpu Malloc retires : {memtrace.xpu_malloc_retires}")
            else:
                print(f"Max CUDA memory allocated was {memtrace.peak} GB")
                print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
                print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
                print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
            print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
    else:
        if is_xpu_available():
            print(f"Max XPU memory allocated was {memtrace.peak} GB")
            print(f"Max XPU memory reserved was {memtrace.max_reserved} GB")
            print(f"Peak active XPU memory was {memtrace.peak_active_gb} GB")
            print(f"Xpu Malloc retires : {memtrace.xpu_malloc_retires}")
        else:
            print(f"Max CUDA memory allocated was {memtrace.peak} GB")
            print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
            print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
            print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
        print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
    else:
        print(f"Max CUDA memory allocated was {memtrace.peak} GB")
        print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
        print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
        print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
        print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
    # Update the learning rate as needed
    lr_scheduler.step()
@@ -246,6 +261,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
    )
    # If there's more than one CUDA device, reduce evaluation loss across all devices
    if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
        dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
    if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
        dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
@@ -279,7 +296,11 @@ def check_frozen_layers_peft_model(model):
def setup():
    """Initialize the process group for distributed training"""
    dist.init_process_group("nccl")
    if is_ccl_available():
        # distributed training on xpus
        dist.init_process_group("ccl")
    else:
        dist.init_process_group("nccl")
def setup_environ_flags(rank):
@@ -303,7 +324,10 @@ def clear_gpu_cache(rank=None):
    """Clear the GPU cache for all ranks"""
    if rank == 0:
        print(f"Clearing GPU cache for all ranks")
    torch.cuda.empty_cache()
    if is_xpu_available():
        torch.xpu.empty_cache()
    else:
        torch.cuda.empty_cache()
def get_parameter_dtypes(model):
@@ -335,13 +359,14 @@ def print_model_size(model, config, rank: int = 0) -> None:
def get_policies(cfg, rank):
    """Get the policies for mixed precision and fsdp wrapping"""
    verify_bfloat_support = (
    verify_bfloat_support = ((
        torch.version.cuda
        and torch.cuda.is_bf16_supported()
        and packaging.version.parse(torch.version.cuda).release >= (11, 0)
        and dist.is_nccl_available()
        and nccl.version() >= (2, 10)
    )
    ) or
    (is_xpu_available()))
    mixed_precision_policy = None
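
The last hunks switch the process-group backend and the bf16 verification on the same availability checks. A condensed sketch of both, assuming `oneccl_bindings_for_pytorch` provides the `ccl` backend whenever `is_ccl_available()` is true (the helper names are illustrative):

import packaging.version
import torch
import torch.cuda.nccl as nccl
import torch.distributed as dist
from accelerate.utils import is_ccl_available, is_xpu_available

def setup_distributed() -> None:
    # Mirrors setup() above: oneCCL for XPU ranks, NCCL for CUDA ranks.
    dist.init_process_group("ccl" if is_ccl_available() else "nccl")

def bf16_ready() -> bool:
    # Condensed form of the verify_bfloat_support expression above: bf16 is
    # assumed usable on XPU, or on CUDA >= 11 with NCCL >= 2.10.
    return bool(is_xpu_available() or (
        torch.version.cuda
        and torch.cuda.is_bf16_supported()
        and packaging.version.parse(torch.version.cuda).release >= (11, 0)
        and dist.is_nccl_available()
        and nccl.version() >= (2, 10)
    ))
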