X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mi_estimator.py;fp=mi_estimator.py;h=47381ef3007bf511341dbb0aa3dbc654a45e29a2;hb=114939c1db199607c9c1f6df0d23e5c23b2915f2;hp=68fd51f75e7a226b09da6e0b2667228537a848ca;hpb=fdd573490e517d38fb0477ae1b5df12b74718d45;p=pytorch.git diff --git a/mi_estimator.py b/mi_estimator.py index 68fd51f..47381ef 100755 --- a/mi_estimator.py +++ b/mi_estimator.py @@ -226,7 +226,7 @@ def create_sequences_pairs(train = False): class NetForImagePair(nn.Module): def __init__(self): - super(NetForImagePair, self).__init__() + super().__init__() self.features_a = nn.Sequential( nn.Conv2d(1, 16, kernel_size = 5), nn.MaxPool2d(3), nn.ReLU(), @@ -257,7 +257,7 @@ class NetForImagePair(nn.Module): class NetForImageValuesPair(nn.Module): def __init__(self): - super(NetForImageValuesPair, self).__init__() + super().__init__() self.features_a = nn.Sequential( nn.Conv2d(1, 16, kernel_size = 5), nn.MaxPool2d(3), nn.ReLU(), @@ -306,7 +306,7 @@ class NetForSequencePair(nn.Module): ) def __init__(self): - super(NetForSequencePair, self).__init__() + super().__init__() self.nc = 32 self.nh = 256