From: François Fleuret Date: Mon, 23 Sep 2024 06:15:48 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=d677b4d46788459d7d302d34d8fad0e1741f3b74;p=culture.git Update. --- diff --git a/grids.py b/grids.py index 0f7e554..4c132c3 100755 --- a/grids.py +++ b/grids.py @@ -146,8 +146,8 @@ class Grids(problem.Problem): # grid_gray = 192 # thickness = 0 - # background_gray = 255 - # dots = True + # background_gray = 240 + # dots = False named_colors = [ ("white", [background_gray, background_gray, background_gray]), @@ -1820,7 +1820,7 @@ class Grids(problem.Problem): print(t.__name__) quizzes = self.generate_w_quizzes_(nb, tasks=[t]) self.save_quizzes_as_image( - result_dir, prefix + t.__name__ + ".png", quizzes, nrow=nrow + result_dir, prefix + t.__name__ + ".png", quizzes, nrow=nrow, delta=True ) @@ -1835,6 +1835,44 @@ if __name__ == "__main__": nb, nrow = 64, 4 # nb, nrow = 8, 2 + nb_rows = 13 + + c_quizzes = torch.load("/home/fleuret/state.pth")["train_c_quizzes"] + c_quizzes = c_quizzes[torch.randperm(c_quizzes.size(0))[: nrow * nb_rows]] + + grids.save_quizzes_as_image( + "/tmp", + "c_quizzes.png", + c_quizzes, + # delta=True, + nrow=nrow, + margin=10, + # grids=False + # comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))], + ) + + w_quizzes = grids.generate_w_quizzes_( + nrow * nb_rows, + tasks=[ + grids.task_replace_color, + grids.task_translate, + grids.task_grow, + grids.task_frame, + ], + ) + + grids.save_quizzes_as_image( + "/tmp", + "w_quizzes.png", + w_quizzes, + # delta=True, + nrow=nrow, + margin=10, + # grids=False + # comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))], + ) + + exit(0) # for t in grids.all_tasks: diff --git a/main.py b/main.py index 00722d6..4fee0f2 100755 --- a/main.py +++ b/main.py @@ -47,7 +47,7 @@ parser.add_argument("--eval_batch_size", type=int, default=25) parser.add_argument("--nb_train_samples", type=int, default=50000) -parser.add_argument("--nb_test_samples", type=int, default=2500) +parser.add_argument("--nb_test_samples", type=int, default=10000) parser.add_argument("--nb_c_quizzes", type=int, default=5000) @@ -605,7 +605,10 @@ def one_complete_epoch( # Compute the test accuracy quizzes = generate_quiz_set( - args.nb_test_samples, test_c_quizzes, args.c_quiz_multiplier + args.nb_test_samples, + c_quizzes=None, + c_quiz_multiplier=args.c_quiz_multiplier + # args.nb_test_samples, test_c_quizzes, args.c_quiz_multiplier ) imt_set = samples_for_prediction_imt(quizzes.to(local_device)) result = ae_predict(model, imt_set, local_device=local_device).to("cpu") @@ -658,8 +661,6 @@ def evaluate_quizzes(quizzes, models, with_hints, local_device): nb_correct += (nb_mistakes == 0).long() nb_wrong += (nb_mistakes >= args.nb_mistakes_to_be_wrong).long() - # print("\n\n", nb_correct, nb_wrong) - return nb_correct, nb_wrong @@ -796,6 +797,7 @@ def save_models(models, suffix=""): "state_dict": model.state_dict(), "optimizer_state_dict": model.optimizer.state_dict(), "test_accuracy": model.test_accuracy, + "nb_epochs": model.nb_epochs, }, os.path.join(args.result_dir, filename), ) @@ -830,6 +832,35 @@ def save_quiz_image(models, c_quizzes, filename, local_device=main_device): log_string(f"wrote {filename}") +###################################################################### + + +def new_model(i): + if args.model_type == "standard": + model_constructor = attae.AttentionAE + elif args.model_type == "functional": + model_constructor = attae.FunctionalAttentionAE + else: + raise ValueError(f"Unknown model type {args.model_type}") + + model = model_constructor( + vocabulary_size=vocabulary_size * 2, + dim_model=args.dim_model, + dim_keys=args.dim_keys, + dim_hidden=args.dim_hidden, + nb_heads=args.nb_heads, + nb_blocks=args.nb_blocks, + dropout=args.dropout, + ) + + model.id = i + model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + model.test_accuracy = 0.0 + model.nb_epochs = 0 + + return model + + ###################################################################### problem = grids.Grids( @@ -853,31 +884,10 @@ log_string(f"vocabulary_size {vocabulary_size}") models = [] -if args.model_type == "standard": - model_constructor = attae.AttentionAE -elif args.model_type == "functional": - model_constructor = attae.FunctionalAttentionAE -else: - raise ValueError(f"Unknown model type {args.model_type}") - - for i in range(args.nb_models): - model = model_constructor( - vocabulary_size=vocabulary_size * 2, - dim_model=args.dim_model, - dim_keys=args.dim_keys, - dim_hidden=args.dim_hidden, - nb_heads=args.nb_heads, - nb_blocks=args.nb_blocks, - dropout=args.dropout, - ) - + model = new_model(i) # model = torch.compile(model) - model.id = i - model.test_accuracy = 0.0 - model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) - models.append(model) ###################################################################### @@ -896,6 +906,8 @@ if args.resume: model.load_state_dict(d["state_dict"]) model.optimizer.load_state_dict(d["optimizer_state_dict"]) model.test_accuracy = d["test_accuracy"] + model.nb_epochs = d["nb_epochs"] + log_string(f"successfully loaded {filename}") filename = "state.pth" @@ -972,6 +984,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): # The c quizzes used to estimate the test accuracy have to be # solvable without hints + nb_correct, _ = evaluate_quizzes( quizzes=train_c_quizzes, models=models,