自定义层

https://zh.d2l.ai/chapter_deep-learning-computation/custom-layer.html

build的形参X_shape怎么传来的???

1 Like

请问你弄清楚这个形参怎么传的了吗??字数字数字数

练习1按我的理解。张量X降维,先用(1,-1)reshape成向量x,再算xx’,再用(1,-1)reshape成向量
权重矩阵W的shape是(i*j,units),call函数return reshape(xx’,[1,-1])*W

class MyDese(tf.keras.Model):
    def __init__(self, units):
        super().__init__()
        self.units = units

    def build(self, X_shape):
        m = 1
        for i in X_shape:
            m *= i
        self.weight = self.add_weight(name='weight',
            shape=[m*m, self.units],
            initializer=tf.random_normal_initializer())

    def call(self, X):
        print(X.shape)
        x = tf.reshape(X, (-1, 1))
        print(x.shape)
        mul = tf.matmul(x, tf.transpose(x))
        print(mul.shape)
        re = tf.reshape(mul, (1, -1))
        print(re.shape)
        print(self.weight.shape)
        return tf.matmul(re, self.weight)

因为在自定义层的时候继承了keras.Model基类,这个类中就包含build和call方法。这边相当于对build方法进行了重定义。

build方法用于在首次传入input时进行权重的初始化,keras会自动调用 build 方法来创建层的变量。

call方法是每一次进行前向传播计算时,Keras 会自动调用模型中每个层的 call 方法。

简单来说,就是在首次传入Inputs进行初始化的时候build方法被keras自动调用了,然后再执行call计算。tensorflow官方文档讲优点是: 单独实现 build() 很好地将只创建一次权重与在每次调用时使用权重分开。

需要将${\Omega}{ijk}x{i}x_{j}$改为$x_{i}x_{j}}{\Omega}_{kij}$,便于计算。

我理解的代码如下:

import tensorflow as tf

class MyLayer(tf.keras.Model):
def init(self, units):
super().init()
self.units = units # k

def build(self, X_shape):
    self.weight = self.add_weight(name='weight',
        shape=[self.units, X.shape[-1], X.shape[-1]],
        initializer=tf.random_normal_initializer())

def call(self, X):
    return tf.reduce_sum(tf.matmul(tf.matmul(tf.transpose(X), X), self.weight), axis=[1,2])

X = tf.random.uniform((2, 5))
layer = MyLayer(20)
yk = layer(X)
yk

结果为:
<tf.Tensor: shape=(20,), dtype=float32, numpy=
array([ 0.05039345, -0.01913158, 1.968323 , 0.2077624 , -0.44874996,
1.0606757 , 0.9400774 , -2.5655737 , -1.645488 , 0.19454817,
-0.11475743, 0.07864859, -0.03830071, -1.849295 , 0.67237735,
-0.36910555, 1.3296158 , 1.094212 , 0.45441827, -0.41159102],
dtype=float32)>