X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=world.py;h=118a470b29b159d97826dc895362db8e8d673ded;hb=8adf0586ee5aeb9fbdf81b78c7ff4b484a9b82ab;hp=97c7b1d5dd3d85ebe9889caad1691e5df302e160;hpb=0ab695df8f6a2a0cc70a424e57943a0d5606903b;p=culture.git diff --git a/world.py b/world.py index 97c7b1d..118a470 100755 --- a/world.py +++ b/world.py @@ -22,7 +22,7 @@ colors = torch.tensor( [255, 0, 0], [0, 128, 0], [0, 0, 255], - [255, 255, 0], + [255, 200, 0], [192, 192, 192], ] ) @@ -34,14 +34,14 @@ def generate( nb, height, width, - max_nb_obj=colors.size(0) - 2, + max_nb_obj=2, nb_iterations=2, ): f_start = torch.zeros(nb, height, width, dtype=torch.int64) f_end = torch.zeros(nb, height, width, dtype=torch.int64) n = torch.arange(f_start.size(0)) - for n in range(nb): + for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"): nb_fish = torch.randint(max_nb_obj, (1,)).item() + 1 for c in torch.randperm(colors.size(0) - 2)[:nb_fish].sort().values: i, j = (