projects
/
pytorch.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
560b7d5
)
Update.
author
Francois Fleuret
<francois@fleuret.org>
Sun, 14 Aug 2022 13:29:45 +0000
(15:29 +0200)
committer
Francois Fleuret
<francois@fleuret.org>
Sun, 14 Aug 2022 13:29:45 +0000
(15:29 +0200)
minidiffusion.py
patch
|
blob
|
history
diff --git a/minidiffusion.py b/minidiffusion.py
index 879b796..e1f6abd 100755 (executable)
--- a/minidiffusion.py
+++ b/minidiffusion.py
@@ -105,7 +105,7 @@ parser.add_argument('--learning_rate',
parser.add_argument('--ema_decay',
type = float, default = 0.9999,
parser.add_argument('--ema_decay',
type = float, default = 0.9999,
- help = 'EMA decay, < 0 is no EMA')
+ help = 'EMA decay, <= 0 is no EMA')
data_list = ', '.join( [ str(k) for k in samplers ])
data_list = ', '.join( [ str(k) for k in samplers ])
@@ -129,23 +129,20 @@ class EMA:
def __init__(self, model, decay):
self.model = model
self.decay = decay
def __init__(self, model, decay):
self.model = model
self.decay = decay
- if self.decay < 0: return
- self.ema = { }
+ self.mem = { }
with torch.no_grad():
for p in model.parameters():
with torch.no_grad():
for p in model.parameters():
- self.ema[p] = p.clone()
+ self.mem[p] = p.clone()
def step(self):
def step(self):
- if self.decay < 0: return
with torch.no_grad():
for p in self.model.parameters():
with torch.no_grad():
for p in self.model.parameters():
- self.ema[p].copy_(self.decay * self.ema[p] + (1 - self.decay) * p)
+ self.mem[p].copy_(self.decay * self.mem[p] + (1 - self.decay) * p)
- def copy(self):
- if self.decay < 0: return
+ def copy_to_model(self):
with torch.no_grad():
for p in self.model.parameters():
with torch.no_grad():
for p in self.model.parameters():
- p.copy_(self.ema[p])
+ p.copy_(self.mem[p])
######################################################################
######################################################################
@@ -234,7 +231,7 @@ alpha = 1 - beta
alpha_bar = alpha.log().cumsum(0).exp()
sigma = beta.sqrt()
alpha_bar = alpha.log().cumsum(0).exp()
sigma = beta.sqrt()
-ema = EMA(model, decay = args.ema_decay)
+ema = EMA(model, decay = args.ema_decay) if args.ema_decay > 0 else None
for k in range(args.nb_epochs):
for k in range(args.nb_epochs):
@@ -245,8 +242,8 @@ for k in range(args.nb_epochs):
x0 = (x0 - train_mean) / train_std
t = torch.randint(T, (x0.size(0),) + (1,) * (x0.dim() - 1), device = x0.device)
eps = torch.randn_like(x0)
x0 = (x0 - train_mean) / train_std
t = torch.randint(T, (x0.size(0),) + (1,) * (x0.dim() - 1), device = x0.device)
eps = torch.randn_like(x0)
- input = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps
- input = torch.cat((input, t.expand_as(x0[:,:1]) / (T - 1) - 0.5), 1)
+ xt = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps
+ input = torch.cat((xt, t.expand_as(x0[:,:1]) / (T - 1) - 0.5), 1)
loss = (eps - model(input)).pow(2).mean()
acc_loss += loss.item() * x0.size(0)
loss = (eps - model(input)).pow(2).mean()
acc_loss += loss.item() * x0.size(0)
@@ -254,11 +251,11 @@ for k in range(args.nb_epochs):
loss.backward()
optimizer.step()
loss.backward()
optimizer.step()
- ema.step()
+ if ema is not None: ema.step()
print(f'{k} {acc_loss / train_input.size(0)}')
print(f'{k} {acc_loss / train_input.size(0)}')
-ema.copy()
+if ema is not None: ema.copy_to_model()
######################################################################
# Plot
######################################################################
# Plot