######################################################################
if torch.cuda.is_available():
- device = torch.device('cuda')
    # let cuDNN autotune its convolution algorithms; a win here since input shapes are fixed
    torch.backends.cudnn.benchmark = True
+ device = torch.device('cuda')
else:
device = torch.device('cpu')
type = str, default = '0, 1, 3, 5, 6, 7, 8, 9',
help = 'What MNIST classes to use')
+parser.add_argument('--nb_classes',
+ type = int, default = 2,
+ help = 'How many classes for sequences')
+
+parser.add_argument('--nb_epochs',
+ type = int, default = 50,
+ help = 'How many epochs')
+
+parser.add_argument('--batch_size',
+ type = int, default = 100,
+ help = 'Batch size')
+
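+# hypothetical invocation (the script name is assumed; --data is defined elsewhere in the file):
+#   python mine.py --data sequence_pair --nb_classes 4 --nb_epochs 100 --batch_size 100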
######################################################################
def entropy(target):
    probas = target.bincount().float()  # per-class counts (the original definition of probas is elided; bincount is an assumption)
    probas = probas[probas > 0]         # drop empty classes so 0 * log(0) never produces nan
    probas /= probas.sum()
    return - (probas * probas.log()).sum().item()
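
# NB: entropy() returns nats; the printouts below divide by math.log(2) to report bits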
+def robust_log_mean_exp(x):
+    # log-sum-exp trick: shifting by the max keeps exp() from overflowing
+    a = x.max()
+    return (x - a).exp().mean().log() + a
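+# quick sanity check (hypothetical values): for x = torch.tensor([100.0, 101.0]),
+# x.exp().mean().log() overflows to inf in float32, while robust_log_mean_exp(x)
+# returns approximately 100.62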
+
######################################################################
args = parser.parse_args()
c = target
b = a.new(a.size(0), 2)
- b[:, 0].uniform_(10)
- b[:, 1].uniform_(0.5)
+ b[:, 0].uniform_(0.0, 10.0)
+ b[:, 1].uniform_(0.0, 0.5)
b[:, 1] += b[:, 0] + target.float()
return a, b, c
def create_sequences_pairs(train = False):
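    # each pair (a, b) shares one latent class ha = hb: a is a single decaying
    # pulse whose amplitude encodes the class, b is a pair of decaying pulses
    # whose spacing encodes it; i.i.d. Gaussian noise is added to both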
nb, length = 10000, 1024
- noise_level = 1e-2
+ noise_level = 2e-2
- nb_classes = 4
- ha = torch.randint(nb_classes, (nb, ), device = device) + 1
- # hb = torch.randint(nb_classes, (nb, ), device = device)
+ ha = torch.randint(args.nb_classes, (nb, ), device = device) + 1
+ # hb = torch.randint(args.nb_classes, (nb, ), device = device)
hb = ha
pos = torch.empty(nb, device = device).uniform_(0.0, 0.9)
a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
a = a - pos.view(nb, 1)
a = (a >= 0).float() * torch.exp(-a * math.log(2) / 0.1)
- a = a * ha.float().view(-1, 1).expand_as(a) / (1 + nb_classes)
+ a = a * ha.float().view(-1, 1).expand_as(a) / (1 + args.nb_classes)
noise = a.new(a.size()).normal_(0, noise_level)
a = a + noise
- pos = torch.empty(nb, device = device).uniform_(0.5)
+ pos = torch.empty(nb, device = device).uniform_(0.0, 0.5)
b1 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
b1 = b1 - pos.view(nb, 1)
- b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1)
- pos = pos + hb.float() / (nb_classes + 1) * 0.5
+ b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1) * 0.25
+ pos = pos + hb.float() / (args.nb_classes + 1) * 0.5
b2 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
b2 = b2 - pos.view(nb, 1)
- b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1)
+ b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1) * 0.25
b = b1 + b2
noise = b.new(b.size()).normal_(0, noise_level)
b = b + noise
- ######################################################################
- # for k in range(10):
- # file = open(f'/tmp/dat{k:02d}', 'w')
- # for i in range(a.size(1)):
- # file.write(f'{a[k, i]:f} {b[k,i]:f}\n')
- # file.close()
- # exit(0)
- ######################################################################
-
- a = (a - a.mean()) / a.std()
- b = (b - b.mean()) / b.std()
+ # a = (a - a.mean()) / a.std()
+ # b = (b - b.mean()) / b.std()
return a, b, ha
class NetForSequencePair(nn.Module):
def feature_model(self):
+ kernel_size = 11
+ pooling_size = 4
return nn.Sequential(
- nn.Conv1d(1, self.nc, kernel_size = 5),
- nn.MaxPool1d(2), nn.ReLU(),
- nn.Conv1d(self.nc, self.nc, kernel_size = 5),
- nn.MaxPool1d(2), nn.ReLU(),
- nn.Conv1d(self.nc, self.nc, kernel_size = 5),
- nn.MaxPool1d(2), nn.ReLU(),
- nn.Conv1d(self.nc, self.nc, kernel_size = 5),
- nn.MaxPool1d(2), nn.ReLU(),
- nn.Conv1d(self.nc, self.nc, kernel_size = 5),
- nn.MaxPool1d(2), nn.ReLU(),
+            nn.Conv1d( 1, self.nc, kernel_size = kernel_size),
+            nn.AvgPool1d(pooling_size),  # 1024 -> 1014 -> 253
+            nn.LeakyReLU(),
+            nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
+            nn.AvgPool1d(pooling_size),  # 253 -> 243 -> 60
+            nn.LeakyReLU(),
+            nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
+            nn.AvgPool1d(pooling_size),  # 60 -> 50 -> 12
+            nn.LeakyReLU(),
+            nn.Conv1d(self.nc, self.nc, kernel_size = kernel_size),
+            # no pooling after the last conv: its output length is 12 - 10 = 2,
+            # and a 4-wide pool would reduce that to zero and raise at runtime
+            nn.LeakyReLU(),
)
def __init__(self):
elif args.data == 'sequence_pair':
create_pairs = create_sequences_pairs
model = NetForSequencePair()
+    ######################################################################
+    # dump a few pairs for visual inspection
+    a, b, c = create_pairs()
+    for k in range(10):
+        with open(f'/tmp/train_{k:02d}.dat', 'w') as file:
+            for i in range(a.size(1)):
+                file.write(f'{a[k, i]:f} {b[k, i]:f}\n')
+ # exit(0)
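+    # the two-column dumps can be plotted with any tool, e.g. (hypothetical):
+    #   gnuplot -e "plot '/tmp/train_00.dat' u 0:1 w l, '' u 0:2 w l; pause -1"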
+ ######################################################################
else:
raise Exception('Unknown data ' + args.data)
######################################################################
-nb_epochs, batch_size = 50, 100
-
print('nb_parameters %d' % sum(x.numel() for x in model.parameters()))
-optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
-
model.to(device)
-for e in range(nb_epochs):
+for e in range(args.nb_epochs):
input_a, input_b, classes = create_pairs(train = True)
    input_br = input_b[torch.randperm(input_b.size(0))]
    acc_mi = 0.0
- for batch_a, batch_b, batch_br in zip(input_a.split(batch_size),
- input_b.split(batch_size),
- input_br.split(batch_size)):
+    # NB: re-created at every epoch, which resets Adam's running moment estimates
+    optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
+
+ for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
+ input_b.split(args.batch_size),
+ input_br.split(args.batch_size)):
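+        # Donsker-Varadhan bound: I(A;B) >= E_{p(a,b)}[T(a,b)] - log E_{p(a)p(b)}[exp T(a,b')];
+        # batch_br pairs each a with a shuffled b, i.e. a sample from the product of the
+        # marginals; robust_log_mean_exp could replace .exp().mean().log() if T grows large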
mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
- loss = - mi
acc_mi += mi.item()
+ loss = - mi
optimizer.zero_grad()
loss.backward()
optimizer.step()
- acc_mi /= (input_a.size(0) // batch_size)
+ acc_mi /= (input_a.size(0) // args.batch_size)
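    # both values are reported in bits (nats / log 2); H(classes) upper-bounds
    # the estimable I(A;B), since a pair shares only its class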
print('%d %.04f %.04f' % (e + 1, acc_mi / math.log(2), entropy(classes) / math.log(2)))
input_a, input_b, classes = create_pairs(train = False)
-for e in range(nb_epochs):
- input_br = input_b[torch.randperm(input_b.size(0))]
+input_br = input_b[torch.randperm(input_b.size(0))]
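+# one evaluation pass on held-out pairs: same DV estimate, no parameter updates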
- acc_mi = 0.0
+acc_mi = 0.0
- for batch_a, batch_b, batch_br in zip(input_a.split(batch_size),
- input_b.split(batch_size),
- input_br.split(batch_size)):
- mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
- acc_mi += mi.item()
+for batch_a, batch_b, batch_br in zip(input_a.split(args.batch_size),
+ input_b.split(args.batch_size),
+ input_br.split(args.batch_size)):
+ mi = model(batch_a, batch_b).mean() - model(batch_a, batch_br).exp().mean().log()
+ acc_mi += mi.item()
- acc_mi /= (input_a.size(0) // batch_size)
+acc_mi /= (input_a.size(0) // args.batch_size)
print('test %.04f %.04f'%(acc_mi / math.log(2), entropy(classes) / math.log(2)))