X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;fp=main.py;h=3ff64b7230f80589dd7dcbe1e7d2bb624d94ceda;hb=22415499c0a91922e51f9e2cade009fd404351dc;hp=37515b5b26ddcbd952357cbc715ef6bc13969744;hpb=d95b9b72b0f098b5c955395905a0aff710f553a7;p=picoclvr.git diff --git a/main.py b/main.py index 37515b5..3ff64b7 100755 --- a/main.py +++ b/main.py @@ -844,7 +844,7 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): input = input.to(device) bs = model(mygpt.BracketedSequence(input)) - output_ar = bs.x + output = bs.x loss = F.cross_entropy(output.transpose(1, 2), input)