X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=attentiontoy1d.py;h=6540a0f03bd36316bcd875a90058b3b831dff545;hb=b27b7cc54f450bb5fe8c9ea2faf5e01d0082889a;hp=cff8350839b3f169da6512dbb73e127db7047a89;hpb=c8ca3a8eb2917f92db6e6f8ed7cb00595af02e52;p=pytorch.git diff --git a/attentiontoy1d.py b/attentiontoy1d.py index cff8350..6540a0f 100755 --- a/attentiontoy1d.py +++ b/attentiontoy1d.py @@ -1,18 +1,20 @@ #!/usr/bin/env python -# @XREMOTE_HOST: elk.fleuret.org -# @XREMOTE_EXEC: /home/fleuret/conda/bin/python -# @XREMOTE_PRE: killall -q -9 python || echo "Nothing killed" -# @XREMOTE_GET: *.pdf *.log +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret import torch, math, sys, argparse from torch import nn from torch.nn import functional as F +import matplotlib.pyplot as plt + ###################################################################### -parser = argparse.ArgumentParser(description='Toy RNN.') +parser = argparse.ArgumentParser(description='Toy attention model.') parser.add_argument('--nb_epochs', type = int, default = 250) @@ -146,9 +148,6 @@ def generate_sequences(nb): ###################################################################### -import matplotlib.pyplot as plt -import matplotlib.collections as mc - def save_sequence_images(filename, sequences, tr = None, bx = None): fig = plt.figure() ax = fig.add_subplot(1, 1, 1)