X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=c01cc8f3dc3653d64fb465d83795fd17adad5936;hb=eea23df18f107fc65c810261c7775a9393ef7c8e;hp=6d9f69d65120dc5152c4a38ed2cf72a7a7682b72;hpb=bd9e5951a5741f7e3e44fc03379795eff83242d6;p=picoclvr.git diff --git a/main.py b/main.py index 6d9f69d..c01cc8f 100755 --- a/main.py +++ b/main.py @@ -20,7 +20,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") ###################################################################### parser = argparse.ArgumentParser( - description="An implementation of GPT with cache to solve a toy geometric reasonning task." + description="An implementation of GPT with cache to solve a toy geometric reasoning task." ) parser.add_argument("--log_filename", type=str, default="train.log") @@ -421,9 +421,7 @@ class TaskPicoCLVR(Task): f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%" ) - img = picoclvr.descr2img( - result_descr, [0], height=self.height, width=self.width - ) + img = picoclvr.descr2img(result_descr, height=self.height, width=self.width) if img.dim() == 5: if img.size(1) == 1: