Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 17 Jul 2024 17:21:53 +0000 (19:21 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 17 Jul 2024 17:21:53 +0000 (19:21 +0200)
grids.py
main.py

index 7050b77..400bf91 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -1147,15 +1147,16 @@ class Grids(problem.Problem):
                 w1, w2 = d
                 eq.append((c[i1], w1, c[i2], w2))
 
-            ii = torch.randperm(len(eq))
+            ii = torch.randperm(self.height - 2)[: len(eq)]
 
             for k, x in enumerate(eq):
                 i = ii[k]
                 c1, w1, c2, w2 = x
-                X[i, 0:w1] = c1
-                X[i, w1 : w1 + w2] = c2
-                f_X[i, 0:w1] = c1
-                f_X[i, w1 : w1 + w2] = c2
+                s = torch.randint(self.width - (w1 + w2) + 1, (1,)).item()
+                X[i, s : s + w1] = c1
+                X[i, s + w1 : s + w1 + w2] = c2
+                f_X[i, s : s + w1] = c1
+                f_X[i, s + w1 : s + w1 + w2] = c2
 
             i1, i2 = torch.randperm(N)[:2]
             v1, v2 = v[i1], v[i2]
@@ -1164,11 +1165,12 @@ class Grids(problem.Problem):
             d = d[torch.randint(d.size(0), (1,)).item()]
             w1, w2 = d
             c1, c2 = c[i1], c[i2]
+            s = 0  # torch.randint(self.width - (w1 + w2) + 1, (1,)).item()
             i = self.height - 1
-            X[i, 0:w1] = c1
-            X[i, w1 : w1 + 1] = c2
-            f_X[i, 0:w1] = c1
-            f_X[i, w1 : w1 + w2] = c2
+            X[i, s : s + w1] = c1
+            X[i, s + w1 : s + w1 + 1] = c2
+            f_X[i, s : s + w1] = c1
+            f_X[i, s + w1 : s + w1 + w2] = c2
 
     ######################################################################
 
@@ -1267,12 +1269,12 @@ if __name__ == "__main__":
             "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow
         )
 
-    exit(0)
+    exit(0)
 
     nb = 1000
 
     # for t in grids.all_tasks:
-    for t in [grids.task_count]:
+    for t in [grids.task_compute]:
         start_time = time.perf_counter()
         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
         delay = time.perf_counter() - start_time
diff --git a/main.py b/main.py
index 5a37251..178925b 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -16,7 +16,7 @@ import ffutils
 import mygpt
 import sky, grids, quiz_machine
 
-import threading
+import threading, subprocess
 
 import torch.multiprocessing as mp
 
@@ -36,6 +36,8 @@ parser.add_argument("--resume", action="store_true", default=False)
 
 parser.add_argument("--max_percents_of_test_in_train", type=int, default=-1)
 
+parser.add_argument("--log_command", type=str, default=None)
+
 ########################################
 
 parser.add_argument("--nb_epochs", type=int, default=10000)
@@ -666,4 +668,9 @@ for n_epoch in range(args.nb_epochs):
             forward_only=args.forward_only,
         )
 
+    if args.log_command is not None:
+        s = args.log_command.split()
+        s.insert(1, args.result_dir)
+        subprocess.run(s)
+
 ######################################################################