diff --git a/s2s_pipeline.py b/s2s_pipeline.py index bd91fcd3ecad37880b1c99bb7b8791205da17062..fc3d433b6ec1feefd53b7289024ab2caa9280891 100644 --- a/s2s_pipeline.py +++ b/s2s_pipeline.py @@ -85,8 +85,10 @@ def parse_arguments(): ) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # Parse configurations from a JSON file if specified return parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) else: + # Parse arguments from command line if no JSON file is provided return parser.parse_args_into_dataclasses() @@ -98,6 +100,7 @@ def setup_logger(log_level): ) logger = logging.getLogger(__name__) + # torch compile logs if log_level == "debug": torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)