X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=attentiontoy1d.py;h=92d90cf79bb62ac618b1bd7590ee8f8fb498cba8;hb=21f814e17a7022e3332ea1f8ff6fc43c769e7e92;hp=ad0c0b142a15007a1af64f1b4148056ecf78cf59;hpb=602df813edceaae1b4eb2c69bb5cf0f5823444a7;p=pytorch.git

diff --git a/attentiontoy1d.py b/attentiontoy1d.py
index ad0c0b1..92d90cf 100755
--- a/attentiontoy1d.py
+++ b/attentiontoy1d.py
@@ -10,9 +10,11 @@ import torch, math, sys, argparse
 from torch import nn
 from torch.nn import functional as F
 
+import matplotlib.pyplot as plt
+
 ######################################################################
 
-parser = argparse.ArgumentParser(description='Toy RNN.')
+parser = argparse.ArgumentParser(description='Toy attention model.')
 
 parser.add_argument('--nb_epochs', type = int, default = 250)
 
@@ -29,8 +31,15 @@ parser.add_argument('--positional_encoding',
                     help = 'Provide a positional encoding',
                     action='store_true', default=False)
 
+parser.add_argument('--seed',
+                    type = int, default = 0,
+                    help = 'Random seed (default 0, < 0 is no seeding)')
+
 args = parser.parse_args()
 
+if args.seed >= 0:
+    torch.manual_seed(args.seed)
+
 ######################################################################
 
 label=''
@@ -60,8 +69,6 @@ if torch.cuda.is_available():
 else:
     device = torch.device('cpu')
 
-torch.manual_seed(1)
-
 ######################################################################
 
 seq_height_min, seq_height_max = 1.0, 25.0
@@ -146,9 +153,6 @@ def generate_sequences(nb):
 
 ######################################################################
 
-import matplotlib.pyplot as plt
-import matplotlib.collections as mc
-
 def save_sequence_images(filename, sequences, tr = None, bx = None):
     fig = plt.figure()
     ax = fig.add_subplot(1, 1, 1)
@@ -310,8 +314,9 @@ test_input = torch.cat((test_input, positional_input.expand(test_input.size(0),
 test_outputs = model((test_input - mu) / std).detach()
 
 if args.with_attention:
-    x = model[0:4]((test_input - mu) / std)
-    test_A = model[4].attention(x)
+    k = next(k for k, l in enumerate(model) if isinstance(l, AttentionLayer))
+    x = model[0:k]((test_input - mu) / std)
+    test_A = model[k].attention(x)
     test_A = test_A.detach().to('cpu')
 
 test_input = test_input.detach().to('cpu')
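
Note on the last hunk: the patched code no longer assumes the AttentionLayer sits at index 4 of the nn.Sequential container; it scans the container for the first AttentionLayer instance, runs the prefix of the model up to that index, and queries that layer's attention map. Below is a minimal, self-contained sketch of the same lookup idiom, using a stand-in AttentionLayer and a made-up layer stack rather than the script's actual model.

    import torch
    from torch import nn

    # Stand-in for the script's AttentionLayer, only to illustrate the type-based lookup.
    class AttentionLayer(nn.Module):
        def attention(self, x):
            # Placeholder uniform attention map of shape (batch, T, T), not the real computation.
            return torch.full((x.size(0), x.size(2), x.size(2)), 1.0 / x.size(2))
        def forward(self, x):
            return x

    # Hypothetical stack: the attention layer's position is not hard-coded anywhere.
    model = nn.Sequential(
        nn.Conv1d(1, 16, kernel_size = 5, padding = 2), nn.ReLU(),
        AttentionLayer(),
        nn.Conv1d(16, 1, kernel_size = 5, padding = 2),
    )

    # Same idiom as the patch: find the first AttentionLayer by type, run the model
    # prefix up to it, then ask that layer for its attention map.
    k = next(k for k, l in enumerate(model) if isinstance(l, AttentionLayer))
    x = model[0:k](torch.randn(4, 1, 32))
    A = model[k].attention(x)
    print(k, A.size())  # found index and the (batch, T, T) attention shape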