From: François Fleuret Date: Tue, 31 Oct 2023 15:33:15 +0000 (+0100) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=7cf92d14892ccce7c5a1eaa38c0d6b8fff03e751;p=pytorch.git Update. --- diff --git a/eingather.py b/eingather.py index c7552d7..03b713c 100755 --- a/eingather.py +++ b/eingather.py @@ -71,7 +71,7 @@ def lambda_eingather(op, src_shape, *indexes_shape): idx.append(lambda indexes: a) print(f"{idx=}") - return lambda indexes: [ f(indexes) for f in idx] + return lambda indexes: [f(indexes) for f in idx] f = do(src_shape, s_src) print(f"{f(0)=}") @@ -102,12 +102,12 @@ index2 = torch.randint(src.size(3), (src.size(1),)) # result[a, c, e] = src[c, a, index1[e, a, e], index2[a]] -#result = eingather("ca(eae)(a) -> ace", src, index1, index2) +# result = eingather("ca(eae)(a) -> ace", src, index1, index2) from functorch.dim import dims -a,c,e=dims(3) -result=src[c,a,index1[e,a,e],index2[a]].order(a,c,e) +a, c, e = dims(3) +result = src[c, a, index1[e, a, e], index2[a]].order(a, c, e) # Check diff --git a/picocrafter.py b/picocrafter.py index 33a00c1..31ba1e4 100755 --- a/picocrafter.py +++ b/picocrafter.py @@ -35,11 +35,13 @@ # 5pt. # # The agent can carry "keys" ("a", "b", "c") that open "vaults" ("A", -# "B", "C"). They keys can only be used in sequence: initially the -# agent can move only to free spaces, or to the "a", in which case it -# now carries it, and can move to free spaces or the "A". When it -# moves to the "A", it gets a reward and loses the "a", but can now -# move to the "b", etc. Rewards are 1 for "A" and "B" and 10 for "C". +# "B", "C"). The keys and vault can only be used in sequence: +# initially the agent can move only to free spaces, or to the "a", in +# which case the key is removed from the environment and the agent now +# carries it, and can move to free spaces or the "A". When it moves to +# the "A", it gets a reward, loses the "a", the "A" is removed from +# the environment, but can now move to the "b", etc. Rewards are 1 for +# "A" and "B" and 10 for "C". ###################################################################### @@ -90,6 +92,7 @@ class PicroCrafterEngine: ("b", "B"), ("B", "c"), ("c", "C"), + ("C", " "), ] ] ) @@ -245,6 +248,11 @@ class PicroCrafterEngine: ] ) + self.life_level_in_100th = ( + self.life_level_in_100th + * (self.accessible_object != self.token2id[" "]).long() + ) + reward[torch.logical_not(alive_before)] = 0 return reward, inventory, self.life_level_in_100th // 100 @@ -387,7 +395,7 @@ if __name__ == "__main__": ansi_term = False # nb_agents, nb_iter, display = 1000, 100, False nb_agents, nb_iter, display = 3, 10000, True - ansi_term = True + # ansi_term = True start_time = time.perf_counter() engine = PicroCrafterEngine( @@ -427,7 +435,7 @@ if __name__ == "__main__": width=engine.world_width, ansi_term=ansi_term, ) - time.sleep(0.5) + time.sleep(0.25) if (life_levels > 0).long().sum() == 0: break