From d51d2cce9cec023ec40e79a7e2debba794061c02 Mon Sep 17 00:00:00 2001
From: Hamid Shojanazeri <hamid.nazeri2010@gmail.com>
Date: Tue, 6 Feb 2024 03:10:30 +0000
Subject: [PATCH] Use SDPA attention implementation for fast kernels (flash attention)

---
 src/llama_recipes/finetuning.py            | 4 ++--
 src/llama_recipes/inference/model_utils.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 3b4973cf..d0a72400 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -94,7 +94,7 @@ def main(**kwargs):
                 load_in_8bit=True if train_config.quantization else None,
                 device_map="auto" if train_config.quantization else None,
                 use_cache=use_cache,
-                attn_implementation="eager" if train_config.use_fast_kernels else None,
+                attn_implementation="sdpa" if train_config.use_fast_kernels else None,
             )
         else:
             llama_config = LlamaConfig.from_pretrained(train_config.model_name)
@@ -108,7 +108,7 @@ def main(**kwargs):
             load_in_8bit=True if train_config.quantization else None,
             device_map="auto" if train_config.quantization else None,
             use_cache=use_cache,
-            attn_implementation="eager" if train_config.use_fast_kernels else None,
+            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
         )
 
     # Load the tokenizer and add special tokens
diff --git a/src/llama_recipes/inference/model_utils.py b/src/llama_recipes/inference/model_utils.py
index cc7cc6d2..e70cb82c 100644
--- a/src/llama_recipes/inference/model_utils.py
+++ b/src/llama_recipes/inference/model_utils.py
@@ -13,7 +13,7 @@ def load_model(model_name, quantization, use_fast_kernels):
         load_in_8bit=quantization,
         device_map="auto",
         low_cpu_mem_usage=True,
-        attn_implementation="eager" if use_fast_kernels else None,
+        attn_implementation="sdpa" if use_fast_kernels else None,
     )
     return model
 
-- 
GitLab