Update.
[pytorch.git] / attentiontoy1d.py
index d7f06fe..2cecad8 100755 (executable)
@@ -31,8 +31,15 @@ parser.add_argument('--positional_encoding',
                     help = 'Provide a positional encoding',
                     action='store_true', default=False)
 
+parser.add_argument('--seed',
+                    type = int, default = 0,
+                    help = 'Random seed (default 0, < 0 is no seeding)')
+
 args = parser.parse_args()
 
+if args.seed >= 0:
+    torch.manual_seed(args.seed)
+
 ######################################################################
 
 label=''
@@ -62,8 +69,6 @@ if torch.cuda.is_available():
 else:
     device = torch.device('cpu')
 
-torch.manual_seed(1)
-
 ######################################################################
 
 seq_height_min, seq_height_max = 1.0, 25.0
@@ -71,7 +76,7 @@ seq_width_min, seq_width_max = 5.0, 11.0
 seq_length = 100
 
 def positions_to_sequences(tr = None, bx = None, noise_level = 0.3):
-    st = torch.arange(seq_length).float()
+    st = torch.arange(seq_length, device = device).float()
     st = st[None, :, None]
     tr = tr[:, None, :, :]
     bx = bx[:, None, :, :]
@@ -81,7 +86,6 @@ def positions_to_sequences(tr = None, bx = None, noise_level = 0.3):
 
     x = torch.cat((xtr, xbx), 2)
 
-    # u = x.sign()
     u = F.max_pool1d(x.sign().permute(0, 2, 1), kernel_size = 2, stride = 1).permute(0, 2, 1)
 
     collisions = (u.sum(2) > 1).max(1).values
@@ -95,12 +99,12 @@ def generate_sequences(nb):
 
     # Position / height / width
 
-    tr = torch.empty(nb, 2, 3)
+    tr = torch.empty(nb, 2, 3, device = device)
     tr[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
     tr[:, :, 1].uniform_(seq_height_min, seq_height_max)
     tr[:, :, 2].uniform_(seq_width_min, seq_width_max)
 
-    bx = torch.empty(nb, 2, 3)
+    bx = torch.empty(nb, 2, 3, device = device)
     bx[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
     bx[:, :, 1].uniform_(seq_height_min, seq_height_max)
     bx[:, :, 2].uniform_(seq_width_min, seq_width_max)
@@ -164,10 +168,10 @@ def save_sequence_images(filename, sequences, tr = None, bx = None):
 
     delta = -1.
     if tr is not None:
-        ax.scatter(test_tr[k, :, 0], torch.full((test_tr.size(1),), delta), color = 'black', marker = '^', clip_on=False)
+        ax.scatter(tr[:, 0].cpu(), torch.full((tr.size(0),), delta), color = 'black', marker = '^', clip_on=False)
 
     if bx is not None:
-        ax.scatter(test_bx[k, :, 0], torch.full((test_bx.size(1),), delta), color = 'black', marker = 's', clip_on=False)
+        ax.scatter(bx[:, 0].cpu(), torch.full((bx.size(0),), delta), color = 'black', marker = 's', clip_on=False)
 
     fig.savefig(filename, bbox_inches='tight')
 
@@ -317,6 +321,8 @@ if args.with_attention:
 test_input = test_input.detach().to('cpu')
 test_outputs = test_outputs.detach().to('cpu')
 test_targets = test_targets.detach().to('cpu')
+test_bx = test_bx.detach().to('cpu')
+test_tr = test_tr.detach().to('cpu')
 
 for k in range(15):
     save_sequence_images(