X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=attentiontoy1d.py;h=1dbd61415d3066671033c850c7ae00c269922146;hb=1f16f4ade71103d9d7445d83e1e242314b735e25;hp=ad0c0b142a15007a1af64f1b4148056ecf78cf59;hpb=602df813edceaae1b4eb2c69bb5cf0f5823444a7;p=pytorch.git diff --git a/attentiontoy1d.py b/attentiontoy1d.py index ad0c0b1..1dbd614 100755 --- a/attentiontoy1d.py +++ b/attentiontoy1d.py @@ -12,7 +12,7 @@ from torch.nn import functional as F ###################################################################### -parser = argparse.ArgumentParser(description='Toy RNN.') +parser = argparse.ArgumentParser(description='Toy attention model.') parser.add_argument('--nb_epochs', type = int, default = 250) @@ -147,7 +147,6 @@ def generate_sequences(nb): ###################################################################### import matplotlib.pyplot as plt -import matplotlib.collections as mc def save_sequence_images(filename, sequences, tr = None, bx = None): fig = plt.figure()