Unverified commit 98fcc538, authored by Hamid Shojanazeri and committed via GitHub

Add option to enable Llamaguard content safety check in chat_completion (#354)

Parents: d89a02fe, ed3e11e9
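From the caller's side, the change is an opt-in keyword argument: the new flag defaults to False, so existing calls to main() are unaffected and the Llama Guard check only runs when explicitly requested. The stub below is a hedged sketch that mirrors only the safety-related parameters whose defaults appear in the diff; the print body and the omission of the other parameters are illustrative, not the repository's actual implementation.

# Minimal stand-in for the real main() in chat_completion; only the
# safety flags with defaults visible in the diff are reproduced here.
def main(
    enable_sensitive_topics: bool = False,
    enable_saleforce_content_safety: bool = True,
    enable_llamaguard_content_safety: bool = False,  # new, opt-in flag from this commit
    **kwargs,
):
    print("Llama Guard check enabled:", enable_llamaguard_content_safety)

main()                                           # existing behaviour, flag stays off
main(enable_llamaguard_content_safety=True)      # opt in to the Llama Guard check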
@@ -35,6 +35,7 @@ def main(
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
     enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5
     use_fast_kernels: bool = False, # Enable using SDPA from PyTorch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
+    enable_llamaguard_content_safety: bool = False,
     **kwargs
 ):
     if prompt_file is not None:
@@ -90,6 +91,7 @@ def main(
     safety_checker = get_safety_checker(enable_azure_content_safety,
                                         enable_sensitive_topics,
                                         enable_saleforce_content_safety,
                                         enable_llamaguard_content_safety,
                                         )
     # Safety check of the user prompt
     safety_results = [check(dialogs[idx][0]["content"]) for check in safety_checker]
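The second hunk relies on a checker-list pattern: get_safety_checker returns a list of callables, and each callable is applied to the content of the first user turn of a dialog. The sketch below is a self-contained illustration of that pattern under stated assumptions; get_safety_checker_sketch, llamaguard_stub_check, and the (method, is_safe, report) result shape are illustrative names and an assumed return format, not code taken from the repository.

from typing import Callable, List, Tuple

# Assumed result shape for a checker: (method name, is_safe flag, report text).
SafetyCheck = Callable[[str], Tuple[str, bool, str]]

def llamaguard_stub_check(prompt: str) -> Tuple[str, bool, str]:
    """Stand-in for a real Llama Guard classifier call; always reports safe."""
    return ("LlamaGuard", True, "no unsafe categories flagged")

def get_safety_checker_sketch(enable_llamaguard_content_safety: bool) -> List[SafetyCheck]:
    """Builds the list of enabled checks, mirroring the pattern in the diff."""
    checks: List[SafetyCheck] = []
    if enable_llamaguard_content_safety:
        checks.append(llamaguard_stub_check)
    return checks

# Mirrors the comprehension in the second hunk: run every enabled check
# against the first user message of a dialog.
dialogs = [[{"role": "user", "content": "How do I bake bread?"}]]
safety_checker = get_safety_checker_sketch(enable_llamaguard_content_safety=True)
safety_results = [check(dialogs[0][0]["content"]) for check in safety_checker]
print(all(is_safe for _, is_safe, _ in safety_results))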