From 4817708db50a18242ade3ba88971dd4ef0a73004 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Mon, 3 Dec 2018 11:58:07 -0500 Subject: [PATCH] OCD update. --- mine_mnist.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mine_mnist.py b/mine_mnist.py index 06458b5..7845d81 100755 --- a/mine_mnist.py +++ b/mine_mnist.py @@ -8,6 +8,13 @@ from torch import nn ###################################################################### +if torch.cuda.is_available(): + device = torch.device('cuda') +else: + device = torch.device('cpu') + +###################################################################### + parser = argparse.ArgumentParser( description = 'An implementation of Mutual Information estimator with a deep model', formatter_class = argparse.ArgumentDefaultsHelpFormatter @@ -27,13 +34,6 @@ parser.add_argument('--mnist_classes', ###################################################################### -if torch.cuda.is_available(): - device = torch.device('cuda') -else: - device = torch.device('cpu') - -###################################################################### - def entropy(target): probas = [] for k in range(target.max() + 1): -- 2.20.1