X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=attentiontoy1d.py;h=d7f06fe0b587ba8f08dbfdda93ca58728a955f84;hb=4d0e56bee81c535293367628dd73cbf993d0690a;hp=ef203403f93fd2a3b3793e746dcbb3ab5939bcdb;hpb=0ca6f6d779888f1a9fa2caeb1326814094ed6904;p=pytorch.git diff --git a/attentiontoy1d.py b/attentiontoy1d.py index ef20340..d7f06fe 100755 --- a/attentiontoy1d.py +++ b/attentiontoy1d.py @@ -10,6 +10,8 @@ import torch, math, sys, argparse from torch import nn from torch.nn import functional as F +import matplotlib.pyplot as plt + ###################################################################### parser = argparse.ArgumentParser(description='Toy attention model.') @@ -146,9 +148,6 @@ def generate_sequences(nb): ###################################################################### -import matplotlib.pyplot as plt -import matplotlib.collections as mc - def save_sequence_images(filename, sequences, tr = None, bx = None): fig = plt.figure() ax = fig.add_subplot(1, 1, 1) @@ -310,8 +309,9 @@ test_input = torch.cat((test_input, positional_input.expand(test_input.size(0), test_outputs = model((test_input - mu) / std).detach() if args.with_attention: - x = model[0:4]((test_input - mu) / std) - test_A = model[4].attention(x) + k = next(k for k, l in enumerate(model) if isinstance(l, AttentionLayer)) + x = model[0:k]((test_input - mu) / std) + test_A = model[k].attention(x) test_A = test_A.detach().to('cpu') test_input = test_input.detach().to('cpu')