X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=lazy_linear.py;h=97530ef2829bbae5912990adce85c31bdb4d93d9;hb=2d95f238bbaa0e585b50846d39c98df4aae2b7f9;hp=7c9e398a66b16dcc025471b59b6ac874e2a6d5f3;hpb=7443d768bc437889659ba3ed737297f90fe1922e;p=pytorch.git diff --git a/lazy_linear.py b/lazy_linear.py index 7c9e398..97530ef 100755 --- a/lazy_linear.py +++ b/lazy_linear.py @@ -1,13 +1,18 @@ -#!/usr/bin/env python-for-pytorch +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret from torch import nn, Tensor -########## +###################################################################### class LazyLinear(nn.Module): def __init__(self, out_dim, bias = True): - super(LazyLinear, self).__init__() + super().__init__() self.out_dim = out_dim self.bias = bias self.core = None @@ -23,16 +28,25 @@ class LazyLinear(nn.Module): return self.core(x) -########## + def named_parameters(self, memo=None, prefix=''): + assert self.core is not None, 'Parameters not yet defined' + return super().named_parameters(memo, prefix) + +###################################################################### + +if __name__ == "__main__": + model = nn.Sequential(nn.Conv2d(3, 8, kernel_size = 5), + nn.ReLU(inplace = True), + LazyLinear(128), + nn.ReLU(inplace = True), + nn.Linear(128, 10)) + + # model.eval() -model = nn.Sequential(nn.Conv2d(1, 8, kernel_size = 5), - nn.ReLU(inplace = True), - LazyLinear(128), - nn.ReLU(inplace = True), - nn.Linear(128, 10)) + input = Tensor(100, 3, 32, 32).normal_() -# model.eval() + output = model(input) -input = Tensor(100, 1, 32, 32).normal_() + for n, x in model.named_parameters(): + print(n, x.size()) -output = model(input)