From 504f61114d90b57e1d0faf55a298756da2c8fbfa Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 29 Jun 2024 16:24:23 +0300 Subject: [PATCH] Update. --- main.py | 15 ++++++++++++--- wireworld.py | 21 +++++++++------------ 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/main.py b/main.py index 1565499..b88847e 100755 --- a/main.py +++ b/main.py @@ -13,7 +13,7 @@ from torch.nn import functional as F import ffutils import mygpt -import sky, quizz_machine +import sky, wireworld, quizz_machine # world quizzes vs. culture quizzes @@ -37,7 +37,7 @@ parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) -parser.add_argument("--log_filename", type=str, default="train.log", help=" ") +parser.add_argument("--log_filename", type=str, default="train.log") parser.add_argument("--result_dir", type=str, default=None) @@ -79,6 +79,8 @@ parser.add_argument("--dropout", type=float, default=0.1) parser.add_argument("--deterministic_synthesis", action="store_true", default=False) +parser.add_argument("--problem", type=str, default="sky") + parser.add_argument("--nb_gpts", type=int, default=5) parser.add_argument("--nb_models_for_generation", type=int, default=1) @@ -219,8 +221,15 @@ else: assert args.nb_train_samples % args.batch_size == 0 assert args.nb_test_samples % args.batch_size == 0 -quizz_machine = quizz_machine.QuizzMachine( +if args.problem=="sky": problem=sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2, speed=2), +elif args.problem="wireworld": + problem=wireworld.Wireworld(height=10, width=15, nb_iterations=4) +else: + raise ValueError + +quizz_machine = quizz_machine.QuizzMachine( + problem=problem, nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.physical_batch_size, diff --git a/wireworld.py b/wireworld.py index 98e2334..219d7dd 100755 --- a/wireworld.py +++ b/wireworld.py @@ -17,7 +17,7 @@ from torch.nn import functional as F import problem -class Physics(problem.Problem): +class Wireworld(problem.Problem): colors = torch.tensor( [ [128, 128, 128], @@ -38,14 +38,11 @@ class Physics(problem.Problem): "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><" ) - def __init__( - self, height=6, width=8, nb_objects=2, nb_walls=2, speed=1, nb_iterations=4 - ): + def __init__(self, height=6, width=8, nb_objects=2, nb_walls=2, nb_iterations=4): self.height = height self.width = width self.nb_objects = nb_objects self.nb_walls = nb_walls - self.speed = speed self.nb_iterations = nb_iterations def direction_tokens(self): @@ -55,7 +52,7 @@ class Physics(problem.Problem): frame_sequences = [] result = torch.full( - (nb * 100, self.nb_iterations, self.height, self.width), self.token_empty + (nb * 4, self.nb_iterations, self.height, self.width), self.token_empty ) for n in range(result.size(0)): @@ -114,7 +111,7 @@ class Physics(problem.Problem): result = result[i] if result.size(0) < nb: - print(result.size(0)) + # print(result.size(0)) result = torch.cat( [result, self.generate_frame_sequences(nb - result.size(0))], dim=0 ) @@ -264,17 +261,17 @@ class Physics(problem.Problem): if __name__ == "__main__": import time - sky = Physics(height=10, width=15, speed=1, nb_iterations=100) + wireworld = Wireworld(height=10, width=15, nb_iterations=4) start_time = time.perf_counter() - frame_sequences = sky.generate_frame_sequences(nb=96) + frame_sequences = wireworld.generate_frame_sequences(nb=96) delay = time.perf_counter() - start_time print(f"{frame_sequences.size(0)/delay:02f} seq/s") - # print(sky.seq2str(seq[:4])) + # print(wireworld.seq2str(seq[:4])) for t in range(frame_sequences.size(1)): - img = sky.seq2img(frame_sequences[:, t]) + img = wireworld.seq2img(frame_sequences[:, t]) torchvision.utils.save_image( img.float() / 255.0, f"/tmp/frame_{t:03d}.png", @@ -286,7 +283,7 @@ if __name__ == "__main__": # m = (torch.rand(seq.size()) < 0.05).long() # seq = (1 - m) * seq + m * 23 - # img = sky.seq2img(frame_sequences[:60]) + # img = wireworld.seq2img(frame_sequences[:60]) # torchvision.utils.save_image( # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=10, pad_value=0.1 -- 2.20.1