X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=tasks.py;h=b4e6f67c1ee8d22f0f1e202ad252a7613e6a8b94;hb=0ab695df8f6a2a0cc70a424e57943a0d5606903b;hp=7894fcd62a90cd2d629137d36e2fb99c28e0a883;hpb=f08c6c01dcc03b727c69478c3a1de7ebf9facd95;p=culture.git

diff --git a/tasks.py b/tasks.py
index 7894fcd..b4e6f67 100755
--- a/tasks.py
+++ b/tasks.py
@@ -500,38 +500,26 @@ class World(Task):
         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
 
-        if save_attention_image is not None:
-            for k in range(10):
-                ns = torch.randint(self.test_input.size(0), (1,)).item()
-                input = self.test_input[ns : ns + 1].clone()
+        ##############################
 
-                with torch.autograd.no_grad():
-                    t = model.training
-                    model.eval()
-                    # model.record_attention(True)
-                    model(BracketedSequence(input))
-                    model.train(t)
-                    # ram = model.retrieve_attention()
-                    # model.record_attention(False)
+        input, ar_mask = self.test_input[:64], self.test_ar_mask[:64]
+        result = input.clone() * (1 - ar_mask)
 
-                # tokens_output = [c for c in self.problem.seq2str(input[0])]
-                # tokens_input = ["n/a"] + tokens_output[:-1]
-                # for n_head in range(ram[0].size(1)):
-                # filename = os.path.join(
-                # result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
-                # )
-                # attention_matrices = [m[0, n_head] for m in ram]
-                # save_attention_image(
-                # filename,
-                # tokens_input,
-                # tokens_output,
-                # attention_matrices,
-                # k_top=10,
-                ##min_total_attention=0.9,
-                # token_gap=12,
-                # layer_gap=50,
-                # )
-                # logger(f"wrote {filename}")
+        masked_inplace_autoregression(
+            model,
+            self.batch_size,
+            result,
+            ar_mask,
+            deterministic_synthesis,
+            progress_bar_desc=None,
+            device=self.device,
+        )
+
+        img = world.sample2img(result.to("cpu"), self.height, self.width)
+
+        image_name = os.path.join(result_dir, f"world_result_{n_epoch:04d}.png")
+        torchvision.utils.save_image(
+            img.float() / 255.0, image_name, nrow=8, padding=2
+        )
+        logger(f"wrote {image_name}")
 
 ######################################################################
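
Note on the change above: the new hunk drops the per-sample attention dumps and instead runs a batched evaluation pass over 64 test sequences: the positions flagged by ar_mask are blanked, filled in autoregressively with deterministic (greedy) synthesis, decoded to images with world.sample2img, and written out as one PNG montage. The sketch below only illustrates that pattern; it is not the repository's masked_inplace_autoregression, and the names model, decode_to_img, and the assumed tensor shapes are stand-ins introduced for the example.

# Minimal sketch of masked in-place autoregressive completion, under the
# assumptions stated above (standard causal-LM convention: logits at
# position t-1 predict the token at position t).
import torch
import torchvision


def fill_masked_tokens(model, result, ar_mask):
    # Greedy left-to-right completion: positions where ar_mask == 1 are
    # overwritten with the model's argmax prediction, positions where
    # ar_mask == 0 keep their original tokens. A real implementation would
    # skip steps whose mask column is all zero instead of re-running the
    # model on every position.
    with torch.no_grad():
        for t in range(1, result.size(1)):
            logits = model(result)  # assumed shape (batch, seq_len, vocab_size)
            tokens = logits[:, t - 1].argmax(dim=-1)
            result[:, t] = torch.where(ar_mask[:, t] == 1, tokens, result[:, t])


# Usage mirroring the new hunk, with `decode_to_img` as a hypothetical decoder
# returning (batch, 3, H, W) uint8 images:
# result = input.clone() * (1 - ar_mask)
# fill_masked_tokens(model, result, ar_mask)
# img = decode_to_img(result)
# torchvision.utils.save_image(
#     img.float() / 255.0, "world_result_0000.png", nrow=8, padding=2
# )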