自定义层

https://zh.d2l.ai/chapter_deep-learning-computation/custom-layer.html

build的形参X_shape怎么传来的???

1 Like

请问你弄清楚这个形参怎么传的了吗??字数字数字数

练习1按我的理解。张量X降维,先用(1,-1)reshape成向量x,再算xx’,再用(1,-1)reshape成向量
权重矩阵W的shape是(i*j,units),call函数return reshape(xx’,[1,-1])*W

class MyDese(tf.keras.Model):
    def __init__(self, units):
        super().__init__()
        self.units = units

    def build(self, X_shape):
        m = 1
        for i in X_shape:
            m *= i
        self.weight = self.add_weight(name='weight',
            shape=[m*m, self.units],
            initializer=tf.random_normal_initializer())

    def call(self, X):
        print(X.shape)
        x = tf.reshape(X, (-1, 1))
        print(x.shape)
        mul = tf.matmul(x, tf.transpose(x))
        print(mul.shape)
        re = tf.reshape(mul, (1, -1))
        print(re.shape)
        print(self.weight.shape)
        return tf.matmul(re, self.weight)