projects
/
mygptrnn.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
3da672d
)
Update.
author
François Fleuret
<francois@fleuret.org>
Sat, 27 Jan 2024 17:45:49 +0000
(18:45 +0100)
committer
François Fleuret
<francois@fleuret.org>
Sat, 27 Jan 2024 17:45:49 +0000
(18:45 +0100)
maxval.py
patch
|
blob
|
history
diff --git
a/maxval.py
b/maxval.py
index
a202b17
..
31aa4b6
100755
(executable)
--- a/
maxval.py
+++ b/
maxval.py
@@
-59,33
+59,43
@@
def pscan(X, V, s=1):
def pscan_diff(X, V, s=1):
if X.size(1) == 1:
def pscan_diff(X, V, s=1):
if X.size(1) == 1:
- return
+ return
X, V
T = 2 * (X.size(1) // 2)
Xf = X[:, :T].view(X.size(0), X.size(1) // 2, 2, X.size(2))
Vf = V[:, :T].view(V.size(0), V.size(1) // 2, 2)
T = 2 * (X.size(1) // 2)
Xf = X[:, :T].view(X.size(0), X.size(1) // 2, 2, X.size(2))
Vf = V[:, :T].view(V.size(0), V.size(1) // 2, 2)
+ Xr = X.new(X.size())
+ Vr = V.new(V.size())
+ Xrf = Xr[:, :T].view(Xr.size(0), Xr.size(1) // 2, 2, Xr.size(2))
+ Vrf = Vr[:, :T].view(Vr.size(0), Vr.size(1) // 2, 2)
+
# [:, :, 0] < [:, :, 1]
m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long()
# [:, :, 0] < [:, :, 1]
m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long()
- V
f[:, :, 1]
= m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1]
+ V
v
= m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1]
m = m[:, :, None]
m = m[:, :, None]
- Xf[:, :, 1] = m * Xf[:, :, 0] + (1 - m) * Xf[:, :, 1]
+ Xx = m * Xf[:, :, 0] + (1 - m) * Xf[:, :, 1]
+
+ Xrf[:, :, 1], Vrf[:, :, 1] = pscan_diff(Xx, Vv, s * 2)
- pscan_diff(Xf[:, :, 1], Vf[:, :, 1], s * 2)
+ Xr[:, 0] = X[:, 0]
+ Vr[:, 0] = V[:, 0]
# [:, :-1, 1] < [:, 1:, 0]
# [:, :-1, 1] < [:, 1:, 0]
- m = (Vf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long()
- V
f[:, 1:, 0] = m * (V
f[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0]
+ m = (V
r
f[:, :-1, 1] - s >= Vf[:, 1:, 0]).long()
+ V
rf[:, 1:, 0] = m * (Vr
f[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0]
m = m[:, :, None]
m = m[:, :, None]
- X
f[:, 1:, 0] = m * X
f[:, :-1, 1] + (1 - m) * Xf[:, 1:, 0]
+ X
rf[:, 1:, 0] = m * Xr
f[:, :-1, 1] + (1 - m) * Xf[:, 1:, 0]
if T < X.size(1):
# [:, -2] < [:, -1]
m = (V[:, -2] - s >= V[:, -1]).long()
if T < X.size(1):
# [:, -2] < [:, -1]
m = (V[:, -2] - s >= V[:, -1]).long()
- V[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1]
+ V
r
[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1]
m = m[:, None]
m = m[:, None]
- X[:, -1] = m * X[:, -2] + (1 - m) * X[:, -1]
+ Xr[:, -1] = m * X[:, -2] + (1 - m) * X[:, -1]
+
+ return Xr, Vr
######################################################################
######################################################################
@@
-104,8
+114,7
@@
if __name__ == "__main__":
# print(V0)
# print(X0)
# print(V0)
# print(X0)
- X1, V1 = X.clone(), V.clone()
- pscan_diff(X1, V1)
+ X1, V1 = pscan_diff(X, V)
# print("########### X V ############################################")
# print(V)
# print("########### X V ############################################")
# print(V)