From 888106500badae460e9ae2183512c7124601acad Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 27 Mar 2024 09:06:02 +0100 Subject: [PATCH] Update. --- greed.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/greed.py b/greed.py index 20cef79..dc11d14 100755 --- a/greed.py +++ b/greed.py @@ -77,6 +77,8 @@ def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3, nb_coins=2): rnd = rnd * (1 - wall.clamp(max=1)) rnd = torch.rand(nb, height, width) + rnd[:, 0, 0] = 0 # Do not put coin at the agent's starting + # position coins = torch.zeros(nb, T, height, width, dtype=torch.int64) rnd = rnd * (1 - wall.clamp(max=1)) for k in range(nb_coins): -- 2.39.5