From dba501c6f80f50e958a4d6a8bf2a884ce8e16b7e Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 4 Oct 2024 22:45:42 +0200 Subject: [PATCH] Update. --- attae.py | 76 ++++++++++++-------- main.py | 65 +++++++++++------ world.py | 211 +++++++++++++++++++++++++++++++++++++++++++------------ 3 files changed, 258 insertions(+), 94 deletions(-) diff --git a/attae.py b/attae.py index c04c5d3..2b231de 100755 --- a/attae.py +++ b/attae.py @@ -92,10 +92,42 @@ class MHAttention(nn.Module): ###################################################################### +def create_trunk(dim_model, dim_keys, dim_hidden, nb_heads, nb_blocks, dropout=0.0): + trunk_blocks = [] + + for b in range(nb_blocks): + trunk_blocks += [ + WithResidual( + nn.LayerNorm((dim_model,)), + MHAttention( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + attention=vanilla_attention, + attention_dropout=dropout, + ), + ), + WithResidual( + nn.LayerNorm((dim_model,)), + nn.Linear(in_features=dim_model, out_features=dim_hidden), + nn.ReLU(), + nn.Linear(in_features=dim_hidden, out_features=dim_model), + nn.Dropout(dropout), + ), + ] + + return nn.Sequential(*trunk_blocks) + + +###################################################################### + + class AttentionAE(nn.Module): def __init__( self, - vocabulary_size, + vocabulary_size_in, + vocabulary_size_out, dim_model, dim_keys, dim_hidden, @@ -109,39 +141,24 @@ class AttentionAE(nn.Module): assert dim_model % nb_heads == 0 self.embedding = nn.Sequential( - nn.Embedding(2 * vocabulary_size, dim_model), + nn.Embedding(vocabulary_size_in, dim_model), nn.Dropout(dropout), ) self.positional_encoding = VaswaniPositionalEncoding(len_max) - trunk_blocks = [] - - for b in range(nb_blocks): - trunk_blocks += [ - WithResidual( - nn.LayerNorm((dim_model,)), - MHAttention( - dim_model=dim_model, - dim_qk=dim_keys, - dim_v=dim_model // nb_heads, - nb_heads=nb_heads, - attention=vanilla_attention, - attention_dropout=dropout, - ), - ), - WithResidual( - nn.LayerNorm((dim_model,)), - nn.Linear(in_features=dim_model, out_features=dim_hidden), - nn.ReLU(), - nn.Linear(in_features=dim_hidden, out_features=dim_model), - nn.Dropout(dropout), - ), - ] - - self.trunk = nn.Sequential(*trunk_blocks) + self.trunk = create_trunk( + dim_model=dim_model, + dim_keys=dim_keys, + dim_hidden=dim_hidden, + nb_heads=nb_heads, + nb_blocks=nb_blocks, + dropout=dropout, + ) - self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size) + self.readout = nn.Linear( + in_features=dim_model, out_features=vocabulary_size_out + ) with torch.no_grad(): for m in self.modules(): @@ -271,7 +288,8 @@ class FunctionalAttentionAE(nn.Module): if __name__ == "__main__": model = FunctionalAttentionAE( - vocabulary_size=100, + vocabulary_size_in=100, + vocabulary_size_out=100, dim_model=16, dim_keys=64, dim_hidden=32, diff --git a/main.py b/main.py index d699bc6..f6cf450 100755 --- a/main.py +++ b/main.py @@ -19,7 +19,7 @@ import threading, subprocess torch.set_float32_matmul_precision("high") -# torch.set_default_dtype(torch.bfloat16) +torch.set_default_dtype(torch.bfloat16) ###################################################################### @@ -234,18 +234,17 @@ for n in vars(args): ###################################################################### -if args.gpus == "all": - gpus_idx = range(torch.cuda.device_count()) +if args.gpus == "none" or not torch.cuda.is_available(): + gpus = [torch.device("cpu")] else: - gpus_idx = [int(k) for k in args.gpus.split(",")] + if args.gpus == "all": + gpus_idx = range(torch.cuda.device_count()) + else: + gpus_idx = [int(k) for k in args.gpus.split(",")] -gpus = [torch.device(f"cuda:{n}") for n in gpus_idx] + gpus = [torch.device(f"cuda:{n}") for n in gpus_idx] -if torch.cuda.is_available(): - main_device = gpus[0] -else: - assert len(gpus) == 0 - main_device = torch.device("cpu") +main_device = gpus[0] if args.train_batch_size is None: args.train_batch_size = args.batch_size @@ -318,6 +317,8 @@ def generate_quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1): ###################################################################### +# IMT stands for image/mask/target + def add_hints_imt(imt_set, proba_hints): """Set every component of the mask to zero with probability proba, @@ -356,7 +357,7 @@ def add_input_noise_imt(imt_set, proba_input_noise): # Prediction -def samples_for_prediction_imt(input): +def make_imt_samples_for_prediction(input): nb = input.size(0) masks = input.new_zeros(input.size()) u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4) @@ -386,7 +387,7 @@ def ae_predict(model, imt_set, local_device=main_device): imt[:, 0] = imt[:, 0] * (1 - imt[:, 1]) with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = model(imt[:, 0] * 2 + imt[:, 1]) + logits = model(imt[:, 0] + imt[:, 1] * vocabulary_size) dist = torch.distributions.categorical.Categorical(logits=logits) result = (1 - imt[:, 1]) * imt[:, 0] + imt[:, 1] * dist.sample() record.append(result) @@ -421,7 +422,7 @@ def predict_the_four_grids( ###################################################################### -def samples_for_generation_imt(input): +def make_imt_samples_for_generation(input): nb = input.size(0) probs_iterations = 0.1 ** torch.linspace( 0, 1, args.diffusion_nb_iterations, device=input.device @@ -480,7 +481,7 @@ def ae_generate(model, nb, local_device=main_device): for input, masks, changed in src: with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = model(input * 2 + masks) + logits = model(input + masks * vocabulary_size) dist = torch.distributions.categorical.Categorical(logits=logits) output = dist.sample() r = prioritized_rand(input != output) @@ -510,7 +511,7 @@ def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device): q_p, q_g = quizzes.to(local_device).chunk(2) # Half of the samples are used to train the prediction. - b_p = samples_for_prediction_imt(q_p) + b_p = make_imt_samples_for_prediction(q_p) # We inject noise in all to avoid drift of the culture toward # "finding waldo" type of complexity b_p = add_input_noise_imt(b_p, args.proba_input_noise) @@ -521,7 +522,7 @@ def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device): # The other half are denoising examples to train the generative # process. - b_g = samples_for_generation_imt(q_g) + b_g = make_imt_samples_for_generation(q_g) imt_set = torch.cat([b_p, b_g]) imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)] @@ -550,7 +551,7 @@ def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device): model.optimizer.zero_grad() with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = model(input * 2 + masks) + logits = model(input + masks * vocabulary_size) loss_per_token = F.cross_entropy( logits.transpose(1, 2), targets, reduction="none" @@ -575,7 +576,7 @@ def save_inference_images(model, n_epoch, c_quizzes, c_quiz_multiplier, local_de # Save some images of the prediction results quizzes = generate_quiz_set(150, c_quizzes, args.c_quiz_multiplier) - imt_set = samples_for_prediction_imt(quizzes.to(local_device)) + imt_set = make_imt_samples_for_prediction(quizzes.to(local_device)) result = ae_predict(model, imt_set, local_device=local_device).to("cpu") masks = imt_set[:, 1].to("cpu") @@ -622,7 +623,7 @@ def one_complete_epoch( # c_quizzes=test_c_quizzes, c_quiz_multiplier=args.c_quiz_multiplier, ) - imt_set = samples_for_prediction_imt(quizzes.to(local_device)) + imt_set = make_imt_samples_for_prediction(quizzes.to(local_device)) result = ae_predict(model, imt_set, local_device=local_device).to("cpu") correct = (quizzes == result).min(dim=1).values.long() @@ -855,7 +856,8 @@ def new_model(id=-1): raise ValueError(f"Unknown model type {args.model_type}") model = model_constructor( - vocabulary_size=vocabulary_size * 2, + vocabulary_size_in=vocabulary_size * 2, + vocabulary_size_out=vocabulary_size, dim_model=args.dim_model, dim_keys=args.dim_keys, dim_hidden=args.dim_hidden, @@ -910,6 +912,29 @@ log_string(f"vocabulary_size {vocabulary_size}") ###################################################################### +if args.test == "aebn": + model = new_model() + + # model.trunk = ( + # model.trunk[: len(model.trunk) // 2] + model.trunk[len(model.trunk) // 2 :] + # ) + + model.id = 0 + model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + model.test_accuracy = 0.0 + model.nb_epochs = 0 + + for n_epoch in range(args.nb_epochs): + one_complete_epoch( + model, + n_epoch, + train_c_quizzes=None, + test_c_quizzes=None, + local_device=main_device, + ) + +###################################################################### + train_c_quizzes, test_c_quizzes = None, None models = [] diff --git a/world.py b/world.py index 3ab6944..cd34cf8 100755 --- a/world.py +++ b/world.py @@ -560,31 +560,39 @@ class Grids(problem.Problem): result_dir, prefix + t.__name__ + ".png", quizzes, nrow=nrow, delta=True ) - def detect_rectangles(self, q1, q2): + def oracle(self, q1, q2): c = torch.arange(self.nb_colors) - I = torch.arange(self.height)[None, :, None] - J = torch.arange(self.width)[None, :, None] + all_i = torch.arange(self.height)[None, :, None] + all_j = torch.arange(self.width)[None, :, None] def corners(q): q = q.reshape(-1, self.height, self.width) a = (q[:, :, :, None] == c[None, None, None, :]).long() mi = a.max(dim=2).values - i = mi * I + i = mi * all_i i1 = (i + (1 - mi) * q.size(1)).min(dim=1).values i2 = (i + (1 - mi) * (-1)).max(dim=1).values + 1 mj = a.max(dim=1).values - j = mj * J + j = mj * all_j j1 = (j + (1 - mj) * q.size(2)).min(dim=1).values j2 = (j + (1 - mj) * (-1)).max(dim=1).values + 1 m = ( - ((I > i1[:, None, :]) & (I < i2[:, None, :] - 1))[:, :, None, :] - & ((J > j1[:, None, :]) & (J < j2[:, None, :] - 1))[:, None, :, :] + ((all_i > i1[:, None, :]) & (all_i < i2[:, None, :] - 1))[:, :, None, :] + & ((all_j > j1[:, None, :]) & (all_j < j2[:, None, :] - 1))[ + :, None, :, : + ] ).long() f = ((a * m).long().sum(dim=(1, 2)) > 0).long() + i1[:, 0], i2[:, 0], j1[:, 0], j2[:, 0] = self.height, 0, self.width, 0 return i1, i2, j1, j2, f + # Coordinates and frame-shape per grid per color + # + # NxC + # q1_i1, q1_i2, q1_j1, q1_j2, q1_f = corners(q1) q2_i1, q2_i2, q2_j1, q2_j2, q2_f = corners(q2) + u1, u2 = 0, 0 for _ in range(10): @@ -592,54 +600,48 @@ class Grids(problem.Problem): r2 = q.new_zeros(q1.size(0), self.height, self.width) m1 = ( - ((I >= q1_i1[:, None, :]) & (I < q1_i2[:, None, :]))[:, :, None, :] - & ((J >= q1_j1[:, None, :]) & (J < q1_j2[:, None, :]))[:, None, :, :] + ((all_i >= q1_i1[:, None, :]) & (all_i < q1_i2[:, None, :]))[ + :, :, None, : + ] + & ((all_j >= q1_j1[:, None, :]) & (all_j < q1_j2[:, None, :]))[ + :, None, :, : + ] ).long() f1 = ( - ( - ((I == q1_i1[:, None, :]) | (I == q1_i2[:, None, :] - 1))[ - :, :, None, : - ] - & ((J >= q1_j1[:, None, :]) & (J < q1_j2[:, None, :]))[ - :, None, :, : - ] - ) - | ( - ((I >= q1_i1[:, None, :]) & (I < q1_i2[:, None, :] - 1))[ + m1 + * ( + ((all_i == q1_i1[:, None, :]) | (all_i == q1_i2[:, None, :] - 1))[ :, :, None, : ] - & ((J == q1_j1[:, None, :]) | (J == q1_j2[:, None, :] - 1))[ + | ((all_j == q1_j1[:, None, :]) | (all_j == q1_j2[:, None, :] - 1))[ :, None, :, : ] - ) - ).long() + ).long() + ) r2 = q.new_zeros(q2.size(0), self.height, self.width) m2 = ( - ((I >= q2_i1[:, None, :]) & (I < q2_i2[:, None, :]))[:, :, None, :] - & ((J >= q2_j1[:, None, :]) & (J < q2_j2[:, None, :]))[:, None, :, :] + ((all_i >= q2_i1[:, None, :]) & (all_i < q2_i2[:, None, :]))[ + :, :, None, : + ] + & ((all_j >= q2_j1[:, None, :]) & (all_j < q2_j2[:, None, :]))[ + :, None, :, : + ] ).long() f2 = ( - ( - ((I == q2_i1[:, None, :]) | (I == q2_i2[:, None, :] - 1))[ - :, :, None, : - ] - & ((J >= q2_j1[:, None, :]) & (J < q2_j2[:, None, :]))[ - :, None, :, : - ] - ) - | ( - ((I >= q2_i1[:, None, :]) & (I < q2_i2[:, None, :] - 1))[ + m2 + * ( + ((all_i == q2_i1[:, None, :]) | (all_i == q2_i2[:, None, :] - 1))[ :, :, None, : ] - & ((J == q2_j1[:, None, :]) | (J == q2_j2[:, None, :] - 1))[ + | ((all_j == q2_j1[:, None, :]) | (all_j == q2_j2[:, None, :] - 1))[ :, None, :, : ] - ) - ).long() + ).long() + ) for c in torch.randperm(self.nb_colors - 1) + 1: r1[...] = q1_f[:, None, None, c] * ( @@ -661,17 +663,119 @@ class Grids(problem.Problem): u1 = (1 - match) * u1 + match * r1 u2 = (1 - match) * u2 + match * r2 - return u1.flatten(1), u2.flatten(1) + ok = (u1.flatten(1) == q1).min(dim=1).values & (u2.flatten(1) == q2).min( + dim=1 + ).values - # o = F.one_hot(q * (1 - m)).sum(dim=1) - # print(o) - # print(o.sort(dim=1, descending=True)) - # c = N x nb_col x 4 + # q1_i1, q1_i2, q1_j1, q1_j2, q1_f = corners(q1) + # q2_i1, q2_i2, q2_j1, q2_j2, q2_f = corners(q2) + # NxC + + q1i1, q1i2, q1j1, q1j2, q1f = ( + q1_i1[:, :, None], + q1_i2[:, :, None], + q1_j1[:, :, None], + q1_j2[:, :, None], + q1_f[:, :, None], + ) + q2i1, q2i2, q2j1, q2j2, q2f = ( + q2_i1[:, None, :], + q2_i2[:, None, :], + q2_j1[:, None, :], + q2_j2[:, None, :], + q2_f[:, None, :], + ) + + match = ( + (q1i1 < q1i2) + & (q2i1 < q2i2) + & (q1i1 == q2i1) + & (q1i2 == q2i2) + & (q1j1 == q2j1) + & (q1j2 == q2j2) + ).long() + translate = ( + ((q1i1 - q2i1).abs() <= 1) + & (q1i1 - q2i1 == q1i2 - q2i2) + & ((q1j1 - q2j1).abs() <= 1) + & (q1j1 - q2j1 == q1j2 - q2j2) + & ((q1i1 - q2i1).abs() + (q1j1 - q2j1).abs() > 0) + ).long() + grow = ( + ( + (q2i1 == q1i1 - 1) + & (q2i2 == q1i2 + 1) + & (q2j1 == q1j1 - 1) + & (q2j2 == q1j2 + 1) + ) + | ( + (q2i1 == q1i1 + 1) + & (q2i2 == q1i2 - 1) + & (q2j1 == q1j1 + 1) + & (q2j2 == q1j2 - 1) + ) + ).long() + + nb_same_color_not_frame = torch.einsum("ncc->n", match * q2f) + nb_change_color = torch.einsum("ncd->n", match * q2f) - nb_same_color_not_frame + nb_frame = torch.einsum("ncc->n", match * (1 - q2f)) + nb_translate = torch.einsum("ncc->n", translate) + nb_translate_change_color = torch.einsum("ncd->n", translate) - nb_translate + nb_grow = torch.einsum("ncc->n", grow) + nb_grow_change_color = torch.einsum("ncd->n", grow) - nb_grow + + print("-------------------------") + print("nb_same_color_not_frame", nb_same_color_not_frame) + print("nb_change_color", nb_change_color) + print("nb_frame", nb_frame) + print("nb_translate", nb_translate) + print("nb_translate_change_color", nb_translate_change_color) + print("nb_grow", nb_grow) + print("nb_grow_change_color", nb_grow_change_color) + + # ok = ok & ( <= 1) & (translate.sum(dim=(1,2) == 3) + + # print("match", match, "\n\n") + + # print("translate", translate, "\n\n") + + # print("grow", grow, "\n\n") + + return u1.flatten(1), u2.flatten(1) ###################################################################### +def recenv(a): + s_row = a.sum(dim=2, keepdim=True) + c_row = s_row.cumsum(dim=1) + s_col = a.sum(dim=1, keepdim=True) + c_col = s_col.cumsum(dim=2) + env_row = ((c_row > 0) & ((c_row < c_row[:, -1:, :]) | (s_row > 0))).long() + env_col = ((c_col > 0) & ((c_col < c_col[:, :, -1:]) | (s_col > 0))).long() + return env_row * env_col + +def valid(q1, q2, m1=None, m2=None): + if m1 is None: + m1=q1.new_zeros(m1.size()) + if m2 is None: + m2=q2.new_zeros(m2.size()) + +def valid_exact_match(q1, q2, m1, m2, c1, c2): + # q1, q2, m1, m2 are NxHxW + q1,m1=F.hone_hot(q1,numclasses=self.nb_colors),m1[:,:,:,None] + q2,m2=F.hone_hot(q2,numclasses=self.nb_colors),m2[:,:,:,None] + a1 = (1 - m1) * q1 + b1 = (1 - m1) * (1-q1) + a2 = (1 - m2) * q2 + b2 = (1 - m2) * (1-q2) + rec_a1 = recenv(a1) + rec_a2 = recenv(a2) + rec = recenv(1-(1-rec_a1)*(1-rec_a2)) + ok = rec_a1 * rec * (1-m1) + if __name__ == "__main__": + import time grids = Grids() @@ -679,6 +783,23 @@ if __name__ == "__main__": nb, nrow = 64, 4 nb_rows = 12 + q = grids.generate_w_quizzes_( + 1, + tasks=[ + grids.task_replace_color, + # grids.task_translate, + # grids.task_grow, + # grids.task_frame, + ], + ) + + q = q.reshape(q.size(0), 4, q.size(1)//4) + + print(q) + print(valid_exact_match(q[:,0], q[:,1] + + exit(0) + # c_quizzes = torch.load("/home/fleuret/state.pth")["train_c_quizzes"] # c_quizzes = c_quizzes[torch.randperm(c_quizzes.size(0))[: nrow * nb_rows]] @@ -694,7 +815,7 @@ if __name__ == "__main__": # ) w_quizzes = grids.generate_w_quizzes_( - 16, + 1, tasks=[ grids.task_replace_color, grids.task_translate, @@ -705,8 +826,8 @@ if __name__ == "__main__": q = w_quizzes.reshape(-1, 4, w_quizzes.size(1) // 4) r = q.new_zeros(q.size()) - r[:, 0], r[:, 1] = grids.detect_rectangles(q[:, 0], q[:, 1]) - r[:, 2], r[:, 3] = grids.detect_rectangles(q[:, 2], q[:, 3]) + r[:, 0], r[:, 1] = grids.oracle(q[:, 0], q[:, 1]) + r[:, 2], r[:, 3] = grids.oracle(q[:, 2], q[:, 3]) grids.save_quizzes_as_image( "/tmp", -- 2.39.5