diff --git a/examples/inference.py b/examples/inference.py
index 3b554440ba37fb63d0058dc0a1f3274c3230bfc4..4f83c8f2caec1fbac1558a2182671fe5de67f24e 100644
--- a/examples/inference.py
+++ b/examples/inference.py
@@ -7,6 +7,7 @@
 import fire
 import os
 import sys
 import time
+import gradio as gr
 import torch
 from transformers import LlamaTokenizer
@@ -39,18 +40,8 @@
     use_fast_kernels: bool = False, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
     **kwargs
 ):
-    if prompt_file is not None:
-        assert os.path.exists(
-            prompt_file
-        ), f"Provided Prompt file does not exist {prompt_file}"
-        with open(prompt_file, "r") as f:
-            user_prompt = "\n".join(f.readlines())
-    elif not sys.stdin.isatty():
-        user_prompt = "\n".join(sys.stdin.readlines())
-    else:
-        print("No user prompt provided. Exiting.")
-        sys.exit(1)
+  def inference(user_prompt, temperature, top_p, top_k, max_new_tokens, **kwargs,):
 
     safety_checker = get_safety_checker(enable_azure_content_safety,
                                         enable_sensitive_topics,
                                         enable_salesforce_content_safety,
@@ -126,7 +117,49 @@
         if not is_safe:
             print(method)
             print(report)
-
+    return output_text
+
+  if prompt_file is not None:
+      assert os.path.exists(
+          prompt_file
+      ), f"Provided Prompt file does not exist {prompt_file}"
+      with open(prompt_file, "r") as f:
+          user_prompt = "\n".join(f.readlines())
+      inference(user_prompt, temperature, top_p, top_k, max_new_tokens)
+  elif not sys.stdin.isatty():
+      user_prompt = "\n".join(sys.stdin.readlines())
+      inference(user_prompt, temperature, top_p, top_k, max_new_tokens)
+  else:
+      gr.Interface(
+          fn=inference,
+          inputs=[
+              gr.components.Textbox(
+                  lines=9,
+                  label="User Prompt",
+                  placeholder="none",
+              ),
+              gr.components.Slider(
+                  minimum=0, maximum=1, value=1.0, label="Temperature"
+              ),
+              gr.components.Slider(
+                  minimum=0, maximum=1, value=1.0, label="Top p"
+              ),
+              gr.components.Slider(
+                  minimum=0, maximum=100, step=1, value=50, label="Top k"
+              ),
+              gr.components.Slider(
+                  minimum=1, maximum=2000, step=1, value=200, label="Max tokens"
+              ),
+          ],
+          outputs=[
+              gr.components.Textbox(
+                  lines=5,
+                  label="Output",
+              )
+          ],
+          title="Llama2 Playground",
+          description="https://github.com/facebookresearch/llama-recipes",
+      ).queue().launch(server_name="0.0.0.0", share=True)
 
 if __name__ == "__main__":
     fire.Fire(main)
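
Note: with this patch, a prompt file or piped stdin still runs one-shot inference, and only when neither is supplied does the script fall through to the Gradio UI. Below is a minimal, self-contained sketch of that Gradio wiring, assuming the gradio 3.x API and using a stand-in generate() function in place of the script's real model-backed inference() closure:

    import gradio as gr

    def generate(user_prompt, temperature, top_p, top_k, max_new_tokens):
        # Stand-in for the script's inner inference() closure: echo the prompt and
        # sampling parameters instead of tokenizing and calling model.generate().
        return (f"temperature={temperature}, top_p={top_p}, top_k={top_k}, "
                f"max_new_tokens={max_new_tokens}\n\n{user_prompt}")

    gr.Interface(
        fn=generate,
        inputs=[
            gr.components.Textbox(lines=9, label="User Prompt"),
            gr.components.Slider(minimum=0, maximum=1, value=1.0, label="Temperature"),
            gr.components.Slider(minimum=0, maximum=1, value=1.0, label="Top p"),
            gr.components.Slider(minimum=0, maximum=100, step=1, value=50, label="Top k"),
            gr.components.Slider(minimum=1, maximum=2000, step=1, value=200, label="Max tokens"),
        ],
        outputs=gr.components.Textbox(lines=5, label="Output"),
        title="Llama2 Playground",
        description="https://github.com/facebookresearch/llama-recipes",
    ).queue().launch(server_name="0.0.0.0", share=True)

Launching with share=True also creates a temporary public Gradio link in addition to the local server bound to 0.0.0.0.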