class DumbRec(nn.Module):
def __init__(
self,
- dim_in,
+ dim_model,
dim_qk,
dim_v,
nb_heads,
self.k_star = randw(nb_lines, dim_qk)
- self.w_qw = randw(nb_heads, dim_qk, dim_in)
- self.w_qr = randw(nb_heads, dim_qk, dim_in)
- # self.w_k = randw(nb_heads, dim_qk, dim_in)
- self.w_v = randw(nb_heads, dim_v, dim_in)
- self.w_o = randw(dim_v * nb_heads, dim_in)
+ self.w_qw = randw(nb_heads, dim_qk, dim_model)
+ self.w_qr = randw(nb_heads, dim_qk, dim_model)
+ # self.w_k = randw(nb_heads, dim_qk, dim_model)
+ self.w_v = randw(nb_heads, dim_v, dim_model)
+ self.w_o = randw(dim_v * nb_heads, dim_model)
def reset_inner_loss(self):
self.acc_attention = 0
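All of these constructors build their projection matrices with randw, which this diff neither shows nor touches. A minimal sketch of such a helper, assuming the scaled-Gaussian initialization the name suggests, could look like this:

import math
import torch
from torch import nn

# Hypothetical sketch, not part of this diff: a learnable tensor of the given
# shape, scaled by the square root of its last dimension so that projections
# contracting over that dimension keep roughly unit variance.
def randw(*d):
    return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))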
class KVRec(nn.Module):
def __init__(
self,
- dim_in,
+ dim_model,
dim_qk,
dim_v,
nb_heads,
self.k_star = randw(nb_lines, dim_qk)
- self.w_qw = randw(nb_heads, dim_qk, dim_in)
- self.w_qr = randw(nb_heads, dim_qk, dim_in)
- self.w_k = randw(nb_heads, dim_qk, dim_in)
- self.w_v = randw(nb_heads, dim_v, dim_in)
- self.w_o = randw(dim_v * nb_heads, dim_in)
+ self.w_qw = randw(nb_heads, dim_qk, dim_model)
+ self.w_qr = randw(nb_heads, dim_qk, dim_model)
+ self.w_k = randw(nb_heads, dim_qk, dim_model)
+ self.w_v = randw(nb_heads, dim_v, dim_model)
+ self.w_o = randw(dim_v * nb_heads, dim_model)
def reset_inner_loss(self):
self.acc_attention = 0
class Caterpillar(nn.Module):
def __init__(
self,
- dim_in,
+ dim_model,
dim_qk,
dim_v,
nb_heads,
self.caterpillar_height = caterpillar_height
self.attention_dropout = attention_dropout
- self.w_G = randw(nb_heads, caterpillar_height, dim_in)
+ self.w_G = randw(nb_heads, caterpillar_height, dim_model)
self.b_G = nn.Parameter(
torch.full(
(nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
)
)
- self.w_K = randw(nb_heads, dim_qk, dim_in)
- self.w_V = randw(nb_heads, dim_v, dim_in)
- self.w_Q = randw(nb_heads, dim_qk, dim_in)
- self.w_O = randw(dim_v * nb_heads, dim_in)
+ self.w_K = randw(nb_heads, dim_qk, dim_model)
+ self.w_V = randw(nb_heads, dim_v, dim_model)
+ self.w_Q = randw(nb_heads, dim_qk, dim_model)
+ self.w_O = randw(dim_v * nb_heads, dim_model)
self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk)
self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v)
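One detail worth noting in the hunk above (unchanged by the rename): b_G is initialized to -log(caterpillar_height - 1). If the gates are sigmoids of w_G·x + b_G, as the parameter shapes suggest but the diff does not show, each of the caterpillar_height recurrent lines starts with a gate value of exactly 1/caterpillar_height:

import math
import torch

# sigmoid(-log(H - 1)) = 1 / (1 + (H - 1)) = 1 / H, so the initial gating
# mass is spread evenly over the H recurrent lines (illustrative check only).
for H in (2, 3, 8):
    g0 = torch.sigmoid(torch.tensor(-math.log(H - 1)))
    assert abs(g0.item() - 1.0 / H) < 1e-6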
class QKVAttention(nn.Module):
def __init__(
self,
- dim_in,
+ dim_model,
dim_qk,
dim_v,
nb_heads=1,
self.attention_dropout = attention_dropout
self.record_attention = False
- self.w_q = randw(nb_heads, dim_qk, dim_in)
- self.w_k = randw(nb_heads, dim_qk, dim_in)
- self.w_v = randw(nb_heads, dim_v, dim_in)
- self.w_o = randw(dim_v * nb_heads, dim_in)
+ self.w_q = randw(nb_heads, dim_qk, dim_model)
+ self.w_k = randw(nb_heads, dim_qk, dim_model)
+ self.w_v = randw(nb_heads, dim_v, dim_model)
+ self.w_o = randw(dim_v * nb_heads, dim_model)
def forward(self, bs):
x_q = bs.x
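The renamed last dimension makes the shape contract of these per-head weights explicit: the input lives in dim_model, and each head projects it down to dim_qk (or dim_v). A small, self-contained shape check of that contract, assuming the usual batched einsum projection (the body of forward is not shown in this diff):

import torch

# Assumed shapes: x is (N, T, dim_model), w_q is (nb_heads, dim_qk, dim_model).
N, T, dim_model, dim_qk, nb_heads = 2, 5, 4, 3, 1
x = torch.randn(N, T, dim_model)
w_q = torch.randn(nb_heads, dim_qk, dim_model)
# One einsum gives every head's queries at once, contracting over dim_model.
q = torch.einsum("ntc,hdc->nhtd", x, w_q)
assert q.shape == (N, nb_heads, T, dim_qk)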
def attlayer():
if attention_layer == "mha":
return QKVAttention(
- dim_in=dim_model,
+ dim_model=dim_model,
dim_qk=dim_keys,
dim_v=dim_model // nb_heads,
nb_heads=nb_heads,
)
elif attention_layer == "dumbrec":
return DumbRec(
- dim_in=dim_model,
+ dim_model=dim_model,
dim_qk=dim_keys,
dim_v=dim_rec_v,
nb_heads=nb_heads,
)
elif attention_layer == "kvrec":
return KVRec(
- dim_in=dim_model,
+ dim_model=dim_model,
dim_qk=dim_keys,
dim_v=dim_rec_v,
nb_heads=nb_heads,
)
elif attention_layer == "caterpillar":
return Caterpillar(
- dim_in=dim_model,
+ dim_model=dim_model,
dim_qk=dim_keys,
dim_v=dim_rec_v,
nb_heads=nb_heads,
print("Basic check.")
m = Caterpillar(
- dim_in=4,
+ dim_model=4,
dim_qk=3,
dim_v=7,
nb_heads=1,
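Since all of the call sites above use keyword arguments, the rename fails loudly rather than silently: any caller that still passes the old name raises at construction time. A quick illustrative check, assuming QKVAttention is in scope as in the basic check above:

try:
    # A call site that was not updated still uses the old keyword.
    QKVAttention(dim_in=4, dim_qk=3, dim_v=7, nb_heads=1)
except TypeError as err:
    # e.g. "__init__() got an unexpected keyword argument 'dim_in'"
    print("stale call site:", err)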