X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=attentiontoy1d.py;h=6540a0f03bd36316bcd875a90058b3b831dff545;hp=1dbd61415d3066671033c850c7ae00c269922146;hb=b27b7cc54f450bb5fe8c9ea2faf5e01d0082889a;hpb=1f16f4ade71103d9d7445d83e1e242314b735e25 diff --git a/attentiontoy1d.py b/attentiontoy1d.py index 1dbd614..6540a0f 100755 --- a/attentiontoy1d.py +++ b/attentiontoy1d.py @@ -10,6 +10,8 @@ import torch, math, sys, argparse from torch import nn from torch.nn import functional as F +import matplotlib.pyplot as plt + ###################################################################### parser = argparse.ArgumentParser(description='Toy attention model.') @@ -146,8 +148,6 @@ def generate_sequences(nb): ###################################################################### -import matplotlib.pyplot as plt - def save_sequence_images(filename, sequences, tr = None, bx = None): fig = plt.figure() ax = fig.add_subplot(1, 1, 1)