From d51d2cce9cec023ec40e79a7e2debba794061c02 Mon Sep 17 00:00:00 2001 From: Hamid Shojanazeri <hamid.nazeri2010@gmail.com> Date: Tue, 6 Feb 2024 03:10:30 +0000 Subject: [PATCH] adding sdpa for flash attn --- src/llama_recipes/finetuning.py | 4 ++-- src/llama_recipes/inference/model_utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py index 3b4973cf..d0a72400 100644 --- a/src/llama_recipes/finetuning.py +++ b/src/llama_recipes/finetuning.py @@ -94,7 +94,7 @@ def main(**kwargs): load_in_8bit=True if train_config.quantization else None, device_map="auto" if train_config.quantization else None, use_cache=use_cache, - attn_implementation="eager" if train_config.use_fast_kernels else None, + attn_implementation="sdpa" if train_config.use_fast_kernels else None, ) else: llama_config = LlamaConfig.from_pretrained(train_config.model_name) @@ -108,7 +108,7 @@ def main(**kwargs): load_in_8bit=True if train_config.quantization else None, device_map="auto" if train_config.quantization else None, use_cache=use_cache, - attn_implementation="eager" if train_config.use_fast_kernels else None, + attn_implementation="sdpa" if train_config.use_fast_kernels else None, ) # Load the tokenizer and add special tokens diff --git a/src/llama_recipes/inference/model_utils.py b/src/llama_recipes/inference/model_utils.py index cc7cc6d2..e70cb82c 100644 --- a/src/llama_recipes/inference/model_utils.py +++ b/src/llama_recipes/inference/model_utils.py @@ -13,7 +13,7 @@ def load_model(model_name, quantization, use_fast_kernels): load_in_8bit=quantization, device_map="auto", low_cpu_mem_usage=True, - attn_implementation="eager" if use_fast_kernels else None, + attn_implementation="sdpa" if use_fast_kernels else None, ) return model -- GitLab