From d7251452cf2d2255c9998da2ad9a116d9265167c Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 17 Jul 2024 19:21:53 +0200 Subject: [PATCH] Update. --- grids.py | 24 +++++++++++++----------- main.py | 9 ++++++++- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/grids.py b/grids.py index 7050b77..400bf91 100755 --- a/grids.py +++ b/grids.py @@ -1147,15 +1147,16 @@ class Grids(problem.Problem): w1, w2 = d eq.append((c[i1], w1, c[i2], w2)) - ii = torch.randperm(len(eq)) + ii = torch.randperm(self.height - 2)[: len(eq)] for k, x in enumerate(eq): i = ii[k] c1, w1, c2, w2 = x - X[i, 0:w1] = c1 - X[i, w1 : w1 + w2] = c2 - f_X[i, 0:w1] = c1 - f_X[i, w1 : w1 + w2] = c2 + s = torch.randint(self.width - (w1 + w2) + 1, (1,)).item() + X[i, s : s + w1] = c1 + X[i, s + w1 : s + w1 + w2] = c2 + f_X[i, s : s + w1] = c1 + f_X[i, s + w1 : s + w1 + w2] = c2 i1, i2 = torch.randperm(N)[:2] v1, v2 = v[i1], v[i2] @@ -1164,11 +1165,12 @@ class Grids(problem.Problem): d = d[torch.randint(d.size(0), (1,)).item()] w1, w2 = d c1, c2 = c[i1], c[i2] + s = 0 # torch.randint(self.width - (w1 + w2) + 1, (1,)).item() i = self.height - 1 - X[i, 0:w1] = c1 - X[i, w1 : w1 + 1] = c2 - f_X[i, 0:w1] = c1 - f_X[i, w1 : w1 + w2] = c2 + X[i, s : s + w1] = c1 + X[i, s + w1 : s + w1 + 1] = c2 + f_X[i, s : s + w1] = c1 + f_X[i, s + w1 : s + w1 + w2] = c2 ###################################################################### @@ -1267,12 +1269,12 @@ if __name__ == "__main__": "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow ) - exit(0) + # exit(0) nb = 1000 # for t in grids.all_tasks: - for t in [grids.task_count]: + for t in [grids.task_compute]: start_time = time.perf_counter() prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t]) delay = time.perf_counter() - start_time diff --git a/main.py b/main.py index 5a37251..178925b 100755 --- a/main.py +++ b/main.py @@ -16,7 +16,7 @@ import ffutils import mygpt import sky, grids, quiz_machine -import threading +import threading, subprocess import torch.multiprocessing as mp @@ -36,6 +36,8 @@ parser.add_argument("--resume", action="store_true", default=False) parser.add_argument("--max_percents_of_test_in_train", type=int, default=-1) +parser.add_argument("--log_command", type=str, default=None) + ######################################## parser.add_argument("--nb_epochs", type=int, default=10000) @@ -666,4 +668,9 @@ for n_epoch in range(args.nb_epochs): forward_only=args.forward_only, ) + if args.log_command is not None: + s = args.log_command.split() + s.insert(1, args.result_dir) + subprocess.run(s) + ###################################################################### -- 2.39.5