+def nb_errors(model, data_set):
+ ne = 0
+ for b in range(0, data_set.nb_batches):
+ input, target = data_set.get_batch(b)
+ output = model.forward(Variable(input))
+ wta_prediction = output.data.max(1)[1].view(-1)
+
+ for i in range(0, data_set.batch_size):
+ if wta_prediction[i] != target[i]:
+ ne = ne + 1
+
+ return ne
+
+######################################################################
+
+def train_model(model, train_set, validation_set):