From beae303394ff89a77ef180cb4e7f42af3e72d59a Mon Sep 17 00:00:00 2001 From: Oleksandr Maksymets <maksymets.o@gmail.com> Date: Thu, 4 Apr 2019 10:35:07 -0700 Subject: [PATCH] Made simple agents suitable for challenge submissions (#35) --- baselines/agents/.simple_agents.py.swp | Bin 0 -> 16384 bytes baselines/agents/ppo_agents.py | 14 ++++++++++---- baselines/agents/simple_agents.py | 23 +++++++++-------------- test/test_baseline_agents.py | 2 +- 4 files changed, 20 insertions(+), 19 deletions(-) create mode 100644 baselines/agents/.simple_agents.py.swp diff --git a/baselines/agents/.simple_agents.py.swp b/baselines/agents/.simple_agents.py.swp new file mode 100644 index 0000000000000000000000000000000000000000..400f785aeec06f7b3006cbf9b86ec73ac4b864ec GIT binary patch literal 16384 zcmeI3O^h5z6~~JZ;)IY7B_KsYNaZmyJ@HQOtQVBc%4@VXUT46&Ywd1`&|;dJo|>7~ z_E)-k){K)7p+uaxAR!@f00}`3D5pq~9Onj3a*C7$E&;+vq9Blf2yuZ6@PAd^^SQRy z3PKjDmVZ0b)m5+Fd-b|L$Ibc|8>@W5tutJAGxqe?p84{^<89{sfw4&JsD31KMbEBs zwd8JF1X+FV!B1Z$7h_)^2%PPTW{``#CemQGDU=L@SgKjkk#VkORS>12^vs=`_OB|` zKC~XN9=Ns#dU?CHa0i<|I;UUtLzRR4-g_Rtb~&5Fdcb<Xdcb<Xdcb<Xdcb<Xdcb<% z|HcD(dJ}sR*4$KBR4%qZP~4ZFA1v-aQG8yQcwXGv59<Nz0qX(l0qX(l0qX(l0qX(l z0qX(l0qX(lfg8{RqQ%(F(CcQ}0KoZwx&HstU5xz^JP#Cj5G;TPz^A~y;2!YrI~n^I zcn|y(JPW=99tRgdAGE+%z&yAe+y-t12f({`F!myN8axGJAi!6_JU9S;@e#&;4xR%~ zf#aYK{&72FZ-d{1-+>pw_rMM~4DJK3eVDOJ;016Id>eGYG4L_)QSkTM7`qIf1doCv z;J3Fj_A2-x_!c+`{&|40cfjl5HE;>M04{>>f(V=dcZ0WYVeBPv8GIjn6A18m@JVnN zxD)*5L(m(11MGr2sDQW7D6fE*!H>Z%m;;Bv->?~T2|Nq>pawn;evLN#6?h3;20sJO zfD2#_90B)(D$wm+#A8dgc}M2v2N#|nDo6kH0^fvEUMwP6T|`!ntE&Ifw{6ASNydF1 z#2iIj&r`joE+LiYIaU2s(Nirt3`d?je)XIdl*zo8#azkU*-J*%?ZrV7J8YoggqTnb zG$Wk~al+40$|y|9<z>?O66y=vF4gSi$$62rl-Ef_=!m$3pPl51z8hLf#bv1EA}a-z zp96uLja6@LdG(aHweBsS+-j_^ZEprqFBEx_Eo*I*yRyFV$nwVL++7j&q%psg6STQ# z!s}!8^J7NOBDFhS{kVE0C{A6GWuosSO(nBkK|WC1l{5+BoXo47+nzf&UTmzw(l2q6 zMIsC?;0HLQkHbnWr+si;_HS;jpPe)whU>NnObd9K6n<B>#!Ns9lLiz5sL!GNGknxW z3IWPN*LVb&H1qembUkxwWouH`_i6q{<McyQG7H;EJ;x&e0Z^}U5x4k6`f*-MgJDZj z^o&LhKhD#D<J~k(M~<70b>qgda(nra<g!vd`qU*5Dz&3-o!?It9WiNxT%|f%jq#u@ z3tzkJF*0%3@RKn)qW!4WI+ikcEQ!1{OPbKLKF~`)EDY}ANjpqL?%=Iy$<EPCw4Axm zREJL0xzj{>Ipx)Ntw80t&7|rkVT<|+l_%*?elL}oQ$<G@#6gZOQ1yIYq64LJBKD;b zPR%jQqJw!V<08`?DL;+Zb|neJ<UAyAXuFo_Cvs%-px*J2A0+eQ8k$t5(C2vnGx}sd zUD;?X&#=+bqu=)<Y2-!QThR8D_`!>erZa8r`rRb(M|vvKV)}KZ)Is6>D0lT9dssaN zDGY^?Q=j2;_*X4^?xMVRc4Pg-@`=Wo#@3gH@58iy*QMM4q&L-JzlRm{Ubzd597>n^ z??=7~_LU@~;aSK0KI7kLWd2^4L$A=-H*hnVKBDl^dK~uOx1WX~UA$3i>uVlGS~02F z9Cju*uwPgi14YkB{VzYZzfN4|T^gHSj8CEJrst$yd~(u4Z_AZuCMo99S~Dt&F$7ay zDlj#C?LErS7&I}Q4211>G#ME{oZ|P_?r=CUc46?|<p-De)a-TG-{^X2>c2CCyqMop zS867`H2tM%up8tuQjUq19X1(7FlHO85G`oBA2#E_w9q1bVp=`J+F26u(jNSDT)D-K zhZ{2m=lDT+lI55X6<>-Rth|qVQQGH1VG_-Zz(`==FLI|rnNMV$RAfqG_&vnWG)|sc z+dReFK`6-&c~|l*Npju_GU;RH-{%RYUfn=(mGm-S@;sB$Maiv^G(TyfwBjK~uvp0! z@5L>dX<@64tuh&_DWn&*K*Gr+?Po!!3!D5ZUlG1+CJ7p(5&N#5G6uQgq7BsokxS)5 z--EL~mCZJTcvi-{JniS*B%Wuq{(k^#=~n=)|I73JKVxnGOYj2_feu&&XTTT0XTg2o zb*$kpg7e@}a2U|K{ub~i*6u$6KLROu2%H9o!F#Chui!7>4M6q34D4k+U_D?xU_D?x zU_D?xU_D?xU_D?xa19S|T#*Q3tU0lS_4J{Y=kcW_UhzD7;CYqNP1AC{7$%*L%-kSu zCr;(DCl>i*i|%~;iHb`bde|POHEyvhJi6q?e~QimUzO|J;sByZ#qzc(W4{}TY{!*! z+=C^UIZ`ug-oaDkEEPFXmT<7fRBkTRj^`(FJLs4)<}aB!xeu_U3rwam$|`8H7?n&M z?okE$kvZY3I%YGO<RG(N;AXD=@F=ZV#~10M)uPm11kp~VR?~uNdM~On!}ESBmsFn3 z(0PsMh51rtbQ8NWX$%CAsWoHj-U29tn$_o6vDkI{A_^<oyuCEK;sexF_^4KH$d}dW zlYpVdb@!-I)5wQQV6Xn}E*FNMrA+{A_UnU56L)2F6`M{t3^m@BYP;Ao8t%CxL<i4t sUO*?TCRj&WMW<?O(2(XZwG?@UL)9{rmh}<Apy;3`oiezS)YjYVKg1HaZ2$lO literal 0 HcmV?d00001 diff --git a/baselines/agents/ppo_agents.py b/baselines/agents/ppo_agents.py index 9a0080bed..db6ad957b 100644 --- a/baselines/agents/ppo_agents.py +++ b/baselines/agents/ppo_agents.py @@ -120,11 +120,14 @@ class PPOAgent(Agent): def main(): parser = argparse.ArgumentParser() parser.add_argument( - "--input_type", + "--input-type", default="blind", choices=["blind", "rgb", "depth", "rgbd"], ) - parser.add_argument("--model_path", default="", type=str) + parser.add_argument("--model-path", default="", type=str) + parser.add_argument( + "--task-config", type=str, default="tasks/pointnav.yaml" + ) args = parser.parse_args() config = get_defaut_config() @@ -132,8 +135,11 @@ def main(): config.MODEL_PATH = args.model_path agent = PPOAgent(config) - challenge = habitat.Challenge() - challenge.submit(agent) + benchmark = habitat.Benchmark(args.task_config) + metrics = benchmark.evaluate(agent) + + for k, v in metrics.items(): + habitat.logger.info("{}: {:.3f}".format(k, v)) if __name__ == "__main__": diff --git a/baselines/agents/simple_agents.py b/baselines/agents/simple_agents.py index 6e83cd0c9..8906b86f4 100644 --- a/baselines/agents/simple_agents.py +++ b/baselines/agents/simple_agents.py @@ -6,7 +6,6 @@ import argparse -import random from math import pi import numpy as np @@ -26,16 +25,12 @@ NON_STOP_ACTIONS = [ class RandomAgent(habitat.Agent): - def __init__(self, config): - self.dist_threshold_to_stop = config.TASK.SUCCESS_DISTANCE + def __init__(self, success_distance): + self.dist_threshold_to_stop = success_distance def reset(self): pass - def act(self, observations): - action = SIM_NAME_TO_ACTION[SimulatorActions.FORWARD.value] - return action - def is_goal_reached(self, observations): dist = observations["pointgoal"][0] return dist <= self.dist_threshold_to_stop @@ -58,9 +53,8 @@ class ForwardOnlyAgent(RandomAgent): class RandomForwardAgent(RandomAgent): - def __init__(self, config): - super(RandomForwardAgent, self).__init__(config) - self.dist_threshold_to_stop = config.TASK.SUCCESS_DISTANCE + def __init__(self, success_distance): + super().__init__(success_distance) self.FORWARD_PROBABILITY = 0.8 def act(self, observations): @@ -81,8 +75,8 @@ class RandomForwardAgent(RandomAgent): class GoalFollower(RandomAgent): - def __init__(self, config): - super(GoalFollower, self).__init__(config) + def __init__(self, success_distance): + super().__init__(success_distance) self.pos_th = self.dist_threshold_to_stop self.angle_th = float(np.deg2rad(15)) self.random_prob = 0 @@ -135,14 +129,15 @@ def get_agent_cls(agent_class_name): def main(): parser = argparse.ArgumentParser() + parser.add_argument("--success-distance", type=float, default=0.2) parser.add_argument( "--task-config", type=str, default="tasks/pointnav.yaml" ) - parser.add_argument("--agent_class", type=str, default="GoalFollower") + parser.add_argument("--agent-class", type=str, default="GoalFollower") args = parser.parse_args() agent = get_agent_cls(args.agent_class)( - habitat.get_config(args.task_config) + success_distance=args.success_distance ) benchmark = habitat.Benchmark(args.task_config) metrics = benchmark.evaluate(agent) diff --git a/test/test_baseline_agents.py b/test/test_baseline_agents.py index f849669e6..230e5949d 100644 --- a/test/test_baseline_agents.py +++ b/test/test_baseline_agents.py @@ -52,6 +52,6 @@ def test_simple_agents(): simple_agents.RandomAgent, simple_agents.RandomForwardAgent, ]: - agent = agent_class(config_env) + agent = agent_class(config_env.TASK.SUCCESS_DISTANCE) habitat.logger.info(agent_class.__name__) habitat.logger.info(benchmark.evaluate(agent, num_episodes=100)) -- GitLab