X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=world.py;h=43126d5d63466e948317839e727a410e2b267c62;hb=6917d3d52a4b473d31121a471ab98fa114bdb1a6;hp=ac201e74d7ed5ee1a8af385db09bc2b2712ddbd3;hpb=6bd776c5842485db888d81e756e22623e8dc949f;p=culture.git diff --git a/world.py b/world.py index ac201e7..43126d5 100755 --- a/world.py +++ b/world.py @@ -34,7 +34,7 @@ def generate( nb, height, width, - max_nb_obj=len(colors) - 2, + max_nb_obj=2, nb_iterations=2, ): f_start = torch.zeros(nb, height, width, dtype=torch.int64) @@ -43,7 +43,7 @@ def generate( for n in range(nb): nb_fish = torch.randint(max_nb_obj, (1,)).item() + 1 - for c in range(nb_fish): + for c in torch.randperm(colors.size(0) - 2)[:nb_fish].sort().values: i, j = ( torch.randint(height - 2, (1,))[0] + 1, torch.randint(width - 2, (1,))[0] + 1, @@ -117,7 +117,7 @@ if __name__ == "__main__": height, width = 6, 8 start_time = time.perf_counter() - seq = generate(nb=64, height=height, width=width) + seq = generate(nb=64, height=height, width=width, max_nb_obj=3) delay = time.perf_counter() - start_time print(f"{seq.size(0)/delay:02f} samples/s")