diff --git a/baselines/agents/.simple_agents.py.swp b/baselines/agents/.simple_agents.py.swp new file mode 100644 index 0000000000000000000000000000000000000000..400f785aeec06f7b3006cbf9b86ec73ac4b864ec Binary files /dev/null and b/baselines/agents/.simple_agents.py.swp differ diff --git a/baselines/agents/ppo_agents.py b/baselines/agents/ppo_agents.py index 9a0080bedc6d2222f73e35809ab1a20b92909bc6..db6ad957bc8801954006dcf1be4e5be90b148daf 100644 --- a/baselines/agents/ppo_agents.py +++ b/baselines/agents/ppo_agents.py @@ -120,11 +120,14 @@ class PPOAgent(Agent): def main(): parser = argparse.ArgumentParser() parser.add_argument( - "--input_type", + "--input-type", default="blind", choices=["blind", "rgb", "depth", "rgbd"], ) - parser.add_argument("--model_path", default="", type=str) + parser.add_argument("--model-path", default="", type=str) + parser.add_argument( + "--task-config", type=str, default="tasks/pointnav.yaml" + ) args = parser.parse_args() config = get_defaut_config() @@ -132,8 +135,11 @@ def main(): config.MODEL_PATH = args.model_path agent = PPOAgent(config) - challenge = habitat.Challenge() - challenge.submit(agent) + benchmark = habitat.Benchmark(args.task_config) + metrics = benchmark.evaluate(agent) + + for k, v in metrics.items(): + habitat.logger.info("{}: {:.3f}".format(k, v)) if __name__ == "__main__": diff --git a/baselines/agents/simple_agents.py b/baselines/agents/simple_agents.py index 6e83cd0c95b5c7e5848874ae82e48bbce4ea147a..8906b86f469327bee919890b394a71f9fb4eebc6 100644 --- a/baselines/agents/simple_agents.py +++ b/baselines/agents/simple_agents.py @@ -6,7 +6,6 @@ import argparse -import random from math import pi import numpy as np @@ -26,16 +25,12 @@ NON_STOP_ACTIONS = [ class RandomAgent(habitat.Agent): - def __init__(self, config): - self.dist_threshold_to_stop = config.TASK.SUCCESS_DISTANCE + def __init__(self, success_distance): + self.dist_threshold_to_stop = success_distance def reset(self): pass - def act(self, observations): - action = SIM_NAME_TO_ACTION[SimulatorActions.FORWARD.value] - return action - def is_goal_reached(self, observations): dist = observations["pointgoal"][0] return dist <= self.dist_threshold_to_stop @@ -58,9 +53,8 @@ class ForwardOnlyAgent(RandomAgent): class RandomForwardAgent(RandomAgent): - def __init__(self, config): - super(RandomForwardAgent, self).__init__(config) - self.dist_threshold_to_stop = config.TASK.SUCCESS_DISTANCE + def __init__(self, success_distance): + super().__init__(success_distance) self.FORWARD_PROBABILITY = 0.8 def act(self, observations): @@ -81,8 +75,8 @@ class RandomForwardAgent(RandomAgent): class GoalFollower(RandomAgent): - def __init__(self, config): - super(GoalFollower, self).__init__(config) + def __init__(self, success_distance): + super().__init__(success_distance) self.pos_th = self.dist_threshold_to_stop self.angle_th = float(np.deg2rad(15)) self.random_prob = 0 @@ -135,14 +129,15 @@ def get_agent_cls(agent_class_name): def main(): parser = argparse.ArgumentParser() + parser.add_argument("--success-distance", type=float, default=0.2) parser.add_argument( "--task-config", type=str, default="tasks/pointnav.yaml" ) - parser.add_argument("--agent_class", type=str, default="GoalFollower") + parser.add_argument("--agent-class", type=str, default="GoalFollower") args = parser.parse_args() agent = get_agent_cls(args.agent_class)( - habitat.get_config(args.task_config) + success_distance=args.success_distance ) benchmark = habitat.Benchmark(args.task_config) metrics = benchmark.evaluate(agent) diff --git a/test/test_baseline_agents.py b/test/test_baseline_agents.py index f849669e6b359b8399efd2baa11a46777622d6ca..230e5949d0047b03a65260bbc1f4894b7780c792 100644 --- a/test/test_baseline_agents.py +++ b/test/test_baseline_agents.py @@ -52,6 +52,6 @@ def test_simple_agents(): simple_agents.RandomAgent, simple_agents.RandomForwardAgent, ]: - agent = agent_class(config_env) + agent = agent_class(config_env.TASK.SUCCESS_DISTANCE) habitat.logger.info(agent_class.__name__) habitat.logger.info(benchmark.evaluate(agent, num_episodes=100))