From beae303394ff89a77ef180cb4e7f42af3e72d59a Mon Sep 17 00:00:00 2001
From: Oleksandr Maksymets <maksymets.o@gmail.com>
Date: Thu, 4 Apr 2019 10:35:07 -0700
Subject: [PATCH] Made simple agents suitable for challenge submissions (#35)

---
 baselines/agents/.simple_agents.py.swp | Bin 0 -> 16384 bytes
 baselines/agents/ppo_agents.py         |  14 ++++++++++----
 baselines/agents/simple_agents.py      |  23 +++++++++--------------
 test/test_baseline_agents.py           |   2 +-
 4 files changed, 20 insertions(+), 19 deletions(-)
 create mode 100644 baselines/agents/.simple_agents.py.swp

diff --git a/baselines/agents/.simple_agents.py.swp b/baselines/agents/.simple_agents.py.swp
new file mode 100644
index 0000000000000000000000000000000000000000..400f785aeec06f7b3006cbf9b86ec73ac4b864ec
GIT binary patch
literal 16384
zcmeI3O^h5z6~~JZ;)IY7B_KsYNaZmyJ@HQOtQVBc%4@VXUT46&Ywd1`&|;dJo|>7~
z_E)-k){K)7p+uaxAR!@f00}`3D5pq~9Onj3a*C7$E&;+vq9Blf2yuZ6@PAd^^SQRy
z3PKjDmVZ0b)m5+Fd-b|L$Ibc|8>@W5tutJAGxqe?p84{^<89{sfw4&JsD31KMbEBs
zwd8JF1X+FV!B1Z$7h_)^2%PPTW{``#CemQGDU=L@SgKjkk#VkORS>12^vs=`_OB|`
zKC~XN9=Ns#dU?CHa0i<|I;UUtLzRR4-g_Rtb~&5Fdcb<Xdcb<Xdcb<Xdcb<Xdcb<%
z|HcD(dJ}sR*4$KBR4%qZP~4ZFA1v-aQG8yQcwXGv59<Nz0qX(l0qX(l0qX(l0qX(l
z0qX(l0qX(lfg8{RqQ%(F(CcQ}0KoZwx&HstU5xz^JP#Cj5G;TPz^A~y;2!YrI~n^I
zcn|y(JPW=99tRgdAGE+%z&yAe+y-t12f({`F!myN8axGJAi!6_JU9S;@e#&;4xR%~
zf#aYK{&72FZ-d{1-+>pw_rMM~4DJK3eVDOJ;016Id>eGYG4L_)QSkTM7`qIf1doCv
z;J3Fj_A2-x_!c+`{&|40cfjl5HE;>M04{>>f(V=dcZ0WYVeBPv8GIjn6A18m@JVnN
zxD)*5L(m(11MGr2sDQW7D6fE*!H>Z%m;;Bv->?~T2|Nq>pawn;evLN#6?h3;20sJO
zfD2#_90B)(D$wm+#A8dgc}M2v2N#|nDo6kH0^fvEUMwP6T|`!ntE&Ifw{6ASNydF1
z#2iIj&r`joE+LiYIaU2s(Nirt3`d?je)XIdl*zo8#azkU*-J*%?ZrV7J8YoggqTnb
zG$Wk~al+40$|y|9<z>?O66y=vF4gSi$$62rl-Ef_=!m$3pPl51z8hLf#bv1EA}a-z
zp96uLja6@LdG(aHweBsS+-j_^ZEprqFBEx_Eo*I*yRyFV$nwVL++7j&q%psg6STQ#
z!s}!8^J7NOBDFhS{kVE0C{A6GWuosSO(nBkK|WC1l{5+BoXo47+nzf&UTmzw(l2q6
zMIsC?;0HLQkHbnWr+si;_HS;jpPe)whU>NnObd9K6n<B>#!Ns9lLiz5sL!GNGknxW
z3IWPN*LVb&H1qembUkxwWouH`_i6q{<McyQG7H;EJ;x&e0Z^}U5x4k6`f*-MgJDZj
z^o&LhKhD#D<J~k(M~<70b>qgda(nra<g!vd`qU*5Dz&3-o!?It9WiNxT%|f%jq#u@
z3tzkJF*0%3@RKn)qW!4WI+ikcEQ!1{OPbKLKF~`)EDY}ANjpqL?%=Iy$<EPCw4Axm
zREJL0xzj{>Ipx)Ntw80t&7|rkVT<|+l_%*?elL}oQ$<G@#6gZOQ1yIYq64LJBKD;b
zPR%jQqJw!V<08`?DL;+Zb|neJ<UAyAXuFo_Cvs%-px*J2A0+eQ8k$t5(C2vnGx}sd
zUD;?X&#=+bqu=)<Y2-!QThR8D_`!>erZa8r`rRb(M|vvKV)}KZ)Is6>D0lT9dssaN
zDGY^?Q=j2;_*X4^?xMVRc4Pg-@`=Wo#@3gH@58iy*QMM4q&L-JzlRm{Ubzd597>n^
z??=7~_LU@~;aSK0KI7kLWd2^4L$A=-H*hnVKBDl^dK~uOx1WX~UA$3i>uVlGS~02F
z9Cju*uwPgi14YkB{VzYZzfN4|T^gHSj8CEJrst$yd~(u4Z_AZuCMo99S~Dt&F$7ay
zDlj#C?LErS7&I}Q4211>G#ME{oZ|P_?r=CUc46?|<p-De)a-TG-{^X2>c2CCyqMop
zS867`H2tM%up8tuQjUq19X1(7FlHO85G`oBA2#E_w9q1bVp=`J+F26u(jNSDT)D-K
zhZ{2m=lDT+lI55X6<>-Rth|qVQQGH1VG_-Zz(`==FLI|rnNMV$RAfqG_&vnWG)|sc
z+dReFK`6-&c~|l*Npju_GU;RH-{%RYUfn=(mGm-S@;sB$Maiv^G(TyfwBjK~uvp0!
z@5L>dX<@64tuh&_DWn&*K*Gr+?Po!!3!D5ZUlG1+CJ7p(5&N#5G6uQgq7BsokxS)5
z--EL~mCZJTcvi-{JniS*B%Wuq{(k^#=~n=)|I73JKVxnGOYj2_feu&&XTTT0XTg2o
zb*$kpg7e@}a2U|K{ub~i*6u$6KLROu2%H9o!F#Chui!7>4M6q34D4k+U_D?xU_D?x
zU_D?xU_D?xU_D?xa19S|T#*Q3tU0lS_4J{Y=kcW_UhzD7;CYqNP1AC{7$%*L%-kSu
zCr;(DCl>i*i|%~;iHb`bde|POHEyvhJi6q?e~QimUzO|J;sByZ#qzc(W4{}TY{!*!
z+=C^UIZ`ug-oaDkEEPFXmT<7fRBkTRj^`(FJLs4)<}aB!xeu_U3rwam$|`8H7?n&M
z?okE$kvZY3I%YGO<RG(N;AXD=@F=ZV#~10M)uPm11kp~VR?~uNdM~On!}ESBmsFn3
z(0PsMh51rtbQ8NWX$%CAsWoHj-U29tn$_o6vDkI{A_^<oyuCEK;sexF_^4KH$d}dW
zlYpVdb@!-I)5wQQV6Xn}E*FNMrA+{A_UnU56L)2F6`M{t3^m@BYP;Ao8t%CxL<i4t
sUO*?TCRj&WMW<?O(2(XZwG?@UL)9{rmh}<Apy;3`oiezS)YjYVKg1HaZ2$lO

literal 0
HcmV?d00001

diff --git a/baselines/agents/ppo_agents.py b/baselines/agents/ppo_agents.py
index 9a0080bed..db6ad957b 100644
--- a/baselines/agents/ppo_agents.py
+++ b/baselines/agents/ppo_agents.py
@@ -120,11 +120,14 @@ class PPOAgent(Agent):
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--input_type",
+        "--input-type",
         default="blind",
         choices=["blind", "rgb", "depth", "rgbd"],
     )
-    parser.add_argument("--model_path", default="", type=str)
+    parser.add_argument("--model-path", default="", type=str)
+    parser.add_argument(
+        "--task-config", type=str, default="tasks/pointnav.yaml"
+    )
     args = parser.parse_args()
 
     config = get_defaut_config()
@@ -132,8 +135,11 @@ def main():
     config.MODEL_PATH = args.model_path
 
     agent = PPOAgent(config)
-    challenge = habitat.Challenge()
-    challenge.submit(agent)
+    benchmark = habitat.Benchmark(args.task_config)
+    metrics = benchmark.evaluate(agent)
+
+    for k, v in metrics.items():
+        habitat.logger.info("{}: {:.3f}".format(k, v))
 
 
 if __name__ == "__main__":
diff --git a/baselines/agents/simple_agents.py b/baselines/agents/simple_agents.py
index 6e83cd0c9..8906b86f4 100644
--- a/baselines/agents/simple_agents.py
+++ b/baselines/agents/simple_agents.py
@@ -6,7 +6,6 @@
 
 
 import argparse
-import random
 from math import pi
 
 import numpy as np
@@ -26,16 +25,12 @@ NON_STOP_ACTIONS = [
 
 
 class RandomAgent(habitat.Agent):
-    def __init__(self, config):
-        self.dist_threshold_to_stop = config.TASK.SUCCESS_DISTANCE
+    def __init__(self, success_distance):
+        self.dist_threshold_to_stop = success_distance
 
     def reset(self):
         pass
 
-    def act(self, observations):
-        action = SIM_NAME_TO_ACTION[SimulatorActions.FORWARD.value]
-        return action
-
     def is_goal_reached(self, observations):
         dist = observations["pointgoal"][0]
         return dist <= self.dist_threshold_to_stop
@@ -58,9 +53,8 @@ class ForwardOnlyAgent(RandomAgent):
 
 
 class RandomForwardAgent(RandomAgent):
-    def __init__(self, config):
-        super(RandomForwardAgent, self).__init__(config)
-        self.dist_threshold_to_stop = config.TASK.SUCCESS_DISTANCE
+    def __init__(self, success_distance):
+        super().__init__(success_distance)
         self.FORWARD_PROBABILITY = 0.8
 
     def act(self, observations):
@@ -81,8 +75,8 @@ class RandomForwardAgent(RandomAgent):
 
 
 class GoalFollower(RandomAgent):
-    def __init__(self, config):
-        super(GoalFollower, self).__init__(config)
+    def __init__(self, success_distance):
+        super().__init__(success_distance)
         self.pos_th = self.dist_threshold_to_stop
         self.angle_th = float(np.deg2rad(15))
         self.random_prob = 0
@@ -135,14 +129,15 @@ def get_agent_cls(agent_class_name):
 
 def main():
     parser = argparse.ArgumentParser()
+    parser.add_argument("--success-distance", type=float, default=0.2)
     parser.add_argument(
         "--task-config", type=str, default="tasks/pointnav.yaml"
     )
-    parser.add_argument("--agent_class", type=str, default="GoalFollower")
+    parser.add_argument("--agent-class", type=str, default="GoalFollower")
     args = parser.parse_args()
 
     agent = get_agent_cls(args.agent_class)(
-        habitat.get_config(args.task_config)
+        success_distance=args.success_distance
     )
     benchmark = habitat.Benchmark(args.task_config)
     metrics = benchmark.evaluate(agent)
diff --git a/test/test_baseline_agents.py b/test/test_baseline_agents.py
index f849669e6..230e5949d 100644
--- a/test/test_baseline_agents.py
+++ b/test/test_baseline_agents.py
@@ -52,6 +52,6 @@ def test_simple_agents():
         simple_agents.RandomAgent,
         simple_agents.RandomForwardAgent,
     ]:
-        agent = agent_class(config_env)
+        agent = agent_class(config_env.TASK.SUCCESS_DISTANCE)
         habitat.logger.info(agent_class.__name__)
         habitat.logger.info(benchmark.evaluate(agent, num_episodes=100))
-- 
GitLab