Cosmetics.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 19 Dec 2022 20:19:29 +0000 (21:19 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 19 Dec 2022 20:19:29 +0000 (21:19 +0100)
main.py
picoclvr.py

diff --git a/main.py b/main.py
index 6d9f69d..c01cc8f 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -20,7 +20,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 ######################################################################
 
 parser = argparse.ArgumentParser(
-    description="An implementation of GPT with cache to solve a toy geometric reasonning task."
+    description="An implementation of GPT with cache to solve a toy geometric reasoning task."
 )
 
 parser.add_argument("--log_filename", type=str, default="train.log")
@@ -421,9 +421,7 @@ class TaskPicoCLVR(Task):
             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
         )
 
-        img = picoclvr.descr2img(
-            result_descr, [0], height=self.height, width=self.width
-        )
+        img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
 
         if img.dim() == 5:
             if img.size(1) == 1:
index cc937af..bd0470f 100755 (executable)
@@ -241,15 +241,9 @@ def generate(
 # Extracts the image after <img> in descr as a 1x3xHxW tensor
 
 
-def descr2img(descr, n, height, width):
+def descr2img(descr, height, width):
 
-    if type(descr) == list:
-        return torch.cat([descr2img(d, n, height, width) for d in descr], 0)
-
-    if type(n) == list:
-        return torch.cat([descr2img(descr, k, height, width) for k in n], 0).unsqueeze(
-            0
-        )
+    result = []
 
     def token2color(t):
         try:
@@ -257,15 +251,15 @@ def descr2img(descr, n, height, width):
         except KeyError:
             return [128, 128, 128]
 
-    d = descr.split("<img>")
-    d = d[n + 1] if len(d) > n + 1 else ""
-    d = d.strip().split(" ")[: height * width]
-    d = d + ["<unk>"] * (height * width - len(d))
-    d = [token2color(t) for t in d]
-    img = torch.tensor(d).permute(1, 0)
-    img = img.reshape(1, 3, height, width)
+    for d in descr:
+        d = d.split("<img>")[1]
+        d = d.strip().split(" ")[: height * width]
+        d = d + ["<unk>"] * (height * width - len(d))
+        d = [token2color(t) for t in d]
+        img = torch.tensor(d).permute(1, 0).reshape(1, 3, height, width)
+        result.append(img)
 
-    return img
+    return torch.cat(result, 0)
 
 
 ######################################################################
@@ -353,7 +347,7 @@ if __name__ == "__main__":
             for d in descr:
                 f.write(f"{d}\n\n")
 
-        img = descr2img(descr, n=0, height=12, width=16)
+        img = descr2img(descr, height=12, width=16)
         if img.size(0) == 1:
             img = F.pad(img, (1, 1, 1, 1), value=64)