Update.
[culture.git] / world.py
index ac201e7..118a470 100755 (executable)
--- a/world.py
+++ b/world.py
@@ -22,7 +22,7 @@ colors = torch.tensor(
         [255, 0, 0],
         [0, 128, 0],
         [0, 0, 255],
         [255, 0, 0],
         [0, 128, 0],
         [0, 0, 255],
-        [255, 255, 0],
+        [255, 200, 0],
         [192, 192, 192],
     ]
 )
         [192, 192, 192],
     ]
 )
@@ -34,16 +34,16 @@ def generate(
     nb,
     height,
     width,
     nb,
     height,
     width,
-    max_nb_obj=len(colors) - 2,
+    max_nb_obj=2,
     nb_iterations=2,
 ):
     f_start = torch.zeros(nb, height, width, dtype=torch.int64)
     f_end = torch.zeros(nb, height, width, dtype=torch.int64)
     n = torch.arange(f_start.size(0))
 
     nb_iterations=2,
 ):
     f_start = torch.zeros(nb, height, width, dtype=torch.int64)
     f_end = torch.zeros(nb, height, width, dtype=torch.int64)
     n = torch.arange(f_start.size(0))
 
-    for n in range(nb):
+    for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
         nb_fish = torch.randint(max_nb_obj, (1,)).item() + 1
         nb_fish = torch.randint(max_nb_obj, (1,)).item() + 1
-        for c in range(nb_fish):
+        for c in torch.randperm(colors.size(0) - 2)[:nb_fish].sort().values:
             i, j = (
                 torch.randint(height - 2, (1,))[0] + 1,
                 torch.randint(width - 2, (1,))[0] + 1,
             i, j = (
                 torch.randint(height - 2, (1,))[0] + 1,
                 torch.randint(width - 2, (1,))[0] + 1,
@@ -117,7 +117,7 @@ if __name__ == "__main__":
 
     height, width = 6, 8
     start_time = time.perf_counter()
 
     height, width = 6, 8
     start_time = time.perf_counter()
-    seq = generate(nb=64, height=height, width=width)
+    seq = generate(nb=64, height=height, width=width, max_nb_obj=3)
     delay = time.perf_counter() - start_time
     print(f"{seq.size(0)/delay:02f} samples/s")
 
     delay = time.perf_counter() - start_time
     print(f"{seq.size(0)/delay:02f} samples/s")