######################################################################
-problem = grids.Grids(
- max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
- chunk_size=100,
- nb_threads=args.nb_threads,
- tasks=args.grids_world_tasks,
-)
-
-if not args.resume:
- problem.save_some_examples(args.result_dir)
-
def pure_noise(nb, device):
r = problem.pure_noise(nb, device)
######################################################################
-log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}")
-
-vocabulary_size = problem.vocabulary_size()
-
-log_string(f"vocabulary_size {vocabulary_size}")
-
-######################################################################
-
def optimizer_to(optim, device):
"""Move the optimizer optim to the device"""
######################################################################
-models = []
-
-for i in range(args.nb_models):
- # model = attae.FunctionalAttentionAE(
- model = attae.AttentionAE(
- vocabulary_size=vocabulary_size * 2,
- dim_model=args.dim_model,
- dim_keys=args.dim_keys,
- dim_hidden=args.dim_hidden,
- nb_heads=args.nb_heads,
- nb_blocks=args.nb_blocks,
- dropout=args.dropout,
- )
-
- # model = torch.compile(model)
-
- model.id = i
- model.test_accuracy = 0.0
- model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
-
- models.append(model)
-
-######################################################################
-
def evaluate_quizzes(quizzes, models, local_device):
nb_correct, nb_wrong = 0, 0
######################################################################
+def multithread_execution(fun, arguments):
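+    """Call fun once per argument tuple in arguments, each call in its
+    own thread, and return the results concatenated along dim 0, in
+    thread-completion order.
+
+    With a single argument tuple, fun is called directly in the current
+    thread. If fun returns None, nothing is returned."""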
+ # Single instance, no thread
+ if len(arguments) == 1:
+ return fun(*(arguments[0]))
+
+ records, threads = [], []
+
+ def threadable_fun(*args):
+ r = fun(*args)
+        if not isinstance(r, tuple):
+            r = (r,)
+ records.append(r)
+
+ for args in arguments:
+        # Consume one random value so that each thread starts from a
+        # different point in the random sequence
+        torch.rand(1)
+ t = threading.Thread(target=threadable_fun, daemon=True, args=args)
+ threads.append(t)
+ t.start()
+
+ for t in threads:
+ t.join()
+
+ if records[0] == (None,):
+ return
+ else:
+ return [
+ torch.cat([x[k] for x in records], dim=0) for k in range(len(records[0]))
+ ]
+
+
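+# A minimal usage sketch (one_epoch and its arguments are hypothetical,
+# not defined in this file): run one training pass per model/GPU pair
+# and ignore the result, since such a one_epoch would return None,
+#
+#     multithread_execution(
+#         one_epoch,
+#         [(model, gpu) for model, gpu in zip(models, gpus)],
+#     )
+
+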
+######################################################################
+
+
+def save_models(models, suffix=""):
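+    """Write one checkpoint per model, ae_<id><suffix>.pth in
+    args.result_dir, holding its state dict, its optimizer state dict,
+    and its test accuracy."""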
+ if suffix != "":
+ suffix = "_" + suffix
+
+ for model in models:
+ filename = f"ae_{model.id:03d}{suffix}.pth"
+ torch.save(
+ {
+ "state_dict": model.state_dict(),
+ "optimizer_state_dict": model.optimizer.state_dict(),
+ "test_accuracy": model.test_accuracy,
+ },
+ os.path.join(args.result_dir, filename),
+ )
+
+ log_string(f"wrote ae_*{suffix}.pth")
+
+
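+# A checkpoint written by save_models can be restored with the standard
+# torch API, for instance (a sketch, not the resume code of this program):
+#
+#     checkpoint = torch.load(filename, map_location=main_device)
+#     model.load_state_dict(checkpoint["state_dict"])
+#     model.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+#     model.test_accuracy = checkpoint["test_accuracy"]
+
+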
+######################################################################
+
+
def save_quiz_image(models, c_quizzes, filename, local_device=main_device):
c_quizzes = c_quizzes.to(local_device)
######################################################################
+problem = grids.Grids(
+ max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
+ chunk_size=100,
+ nb_threads=args.nb_threads,
+ tasks=args.grids_world_tasks,
+)
+
+if not args.resume:
+ problem.save_some_examples(args.result_dir)
+
+
+log_string(f"main_device {main_device} gpus {[str(g) for g in gpus]}")
+
+vocabulary_size = problem.vocabulary_size()
+
+log_string(f"vocabulary_size {vocabulary_size}")
+
+######################################################################
+
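+# Build args.nb_models independent attention auto-encoders; each one
+# carries its id, its running test accuracy, and its own Adam optimizer
+# as attributes, which save_models() checkpoints together.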
+models = []
+
+for i in range(args.nb_models):
+ # model = attae.FunctionalAttentionAE(
+ model = attae.AttentionAE(
+ vocabulary_size=vocabulary_size * 2,
+ dim_model=args.dim_model,
+ dim_keys=args.dim_keys,
+ dim_hidden=args.dim_hidden,
+ nb_heads=args.nb_heads,
+ nb_blocks=args.nb_blocks,
+ dropout=args.dropout,
+ )
+
+ # model = torch.compile(model)
+
+ model.id = i
+ model.test_accuracy = 0.0
+ model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+
+ models.append(model)
+
+######################################################################
+
current_epoch = 0
if args.resume:
######################################################################
-
-def multithread_execution(fun, arguments):
- # Single instance, no thread
- if len(arguments) == 1:
- return fun(*(arguments[0]))
-
- records, threads = [], []
-
- def threadable_fun(*args):
- r = fun(*args)
- if type(r) is not tuple:
- r = (r,)
- records.append(r)
-
- for args in arguments:
- # To get a different sequence between threads
- # log_string(f"dummy_rand {torch.rand(1)}")
- torch.rand(1)
- t = threading.Thread(target=threadable_fun, daemon=True, args=args)
- threads.append(t)
- t.start()
-
- for t in threads:
- t.join()
-
- if records[0] == (None,):
- return
- else:
- return [
- torch.cat([x[k] for x in records], dim=0) for k in range(len(records[0]))
- ]
-
-
-######################################################################
-
-
-def save_models(models, suffix=""):
- if suffix != "":
- suffix = "_" + suffix
-
- for model in models:
- filename = f"ae_{model.id:03d}{suffix}.pth"
- torch.save(
- {
- "state_dict": model.state_dict(),
- "optimizer_state_dict": model.optimizer.state_dict(),
- "test_accuracy": model.test_accuracy,
- },
- os.path.join(args.result_dir, filename),
- )
-
- log_string(f"wrote ae_*{suffix}.pth")
-
-
-######################################################################
-
for n_epoch in range(current_epoch, args.nb_epochs):
start_time = time.perf_counter()