X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=mi_estimator.py;h=47381ef3007bf511341dbb0aa3dbc654a45e29a2;hb=cfa42f2006a72819ad507f60429ea5e26308e5d0;hp=f8b859dd6fc07dc69f2b8fecfa2c30f4d1eeace3;hpb=236238fdfe7d65612b58fbbb5bb29cff4ec45d54;p=pytorch.git
diff --git a/mi_estimator.py b/mi_estimator.py
index f8b859d..47381ef 100755
--- a/mi_estimator.py
+++ b/mi_estimator.py
@@ -1,21 +1,9 @@
#!/usr/bin/env python
-#########################################################################
-# This program is free software: you can redistribute it and/or modify #
-# it under the terms of the version 3 of the GNU General Public License #
-# as published by the Free Software Foundation. #
-# #
-# This program is distributed in the hope that it will be useful, but #
-# WITHOUT ANY WARRANTY; without even the implied warranty of #
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU #
-# General Public License for more details. #
-# #
-# You should have received a copy of the GNU General Public License #
-# along with this program. If not, see . #
-# #
-# Written by and Copyright (C) Francois Fleuret #
-# Contact for comments & bug reports #
-#########################################################################
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret
import argparse, math, sys
from copy import deepcopy
@@ -38,19 +26,20 @@ else:
parser = argparse.ArgumentParser(
description = '''An implementation of a Mutual Information estimator with a deep model
-Three different toy data-sets are implemented:
+ Three different toy data-sets are implemented, each consists of
+ pairs of samples, that may be from different spaces:
- (1) Two MNIST images of same class. The "true" MI is the log of the
- number of used MNIST classes.
+ (1) Two MNIST images of same class. The "true" MI is the log of the
+ number of used MNIST classes.
- (2) One MNIST image and a pair of real numbers whose difference is
- the class of the image. The "true" MI is the log of the number of
- used MNIST classes.
+ (2) One MNIST image and a pair of real numbers whose difference is
+ the class of the image. The "true" MI is the log of the number of
+ used MNIST classes.
- (3) Two 1d sequences, the first with a single peak, the second with
- two peaks, and the height of the peak in the first is the
- difference of timing of the peaks in the second. The "true" MI is
- the log of the number of possible peak heights.''',
+ (3) Two 1d sequences, the first with a single peak, the second with
+ two peaks, and the height of the peak in the first is the
+ difference of timing of the peaks in the second. The "true" MI is
+ the log of the number of possible peak heights.''',
formatter_class = argparse.ArgumentDefaultsHelpFormatter
)
@@ -197,6 +186,8 @@ def create_image_values_pairs(train = False):
######################################################################
+#
+
def create_sequences_pairs(train = False):
nb, length = 10000, 1024
noise_level = 2e-2
@@ -229,16 +220,13 @@ def create_sequences_pairs(train = False):
noise = b.new(b.size()).normal_(0, noise_level)
b = b + noise
- # a = (a - a.mean()) / a.std()
- # b = (b - b.mean()) / b.std()
-
return a, b, ha
######################################################################
class NetForImagePair(nn.Module):
def __init__(self):
- super(NetForImagePair, self).__init__()
+ super().__init__()
self.features_a = nn.Sequential(
nn.Conv2d(1, 16, kernel_size = 5),
nn.MaxPool2d(3), nn.ReLU(),
@@ -269,7 +257,7 @@ class NetForImagePair(nn.Module):
class NetForImageValuesPair(nn.Module):
def __init__(self):
- super(NetForImageValuesPair, self).__init__()
+ super().__init__()
self.features_a = nn.Sequential(
nn.Conv2d(1, 16, kernel_size = 5),
nn.MaxPool2d(3), nn.ReLU(),
@@ -318,7 +306,7 @@ class NetForSequencePair(nn.Module):
)
def __init__(self):
- super(NetForSequencePair, self).__init__()
+ super().__init__()
self.nc = 32
self.nh = 256