https://zh.d2l.ai/chapter_deep-learning-computation/custom-layer.html
build的形参X_shape怎么传来的???
1 Like
请问你弄清楚这个形参怎么传的了吗??字数字数字数
练习1按我的理解。张量X降维,先用(1,-1)reshape成向量x,再算xx’,再用(1,-1)reshape成向量
权重矩阵W的shape是(i*j,units),call函数return reshape(xx’,[1,-1])*W
class MyDese(tf.keras.Model):
def __init__(self, units):
super().__init__()
self.units = units
def build(self, X_shape):
m = 1
for i in X_shape:
m *= i
self.weight = self.add_weight(name='weight',
shape=[m*m, self.units],
initializer=tf.random_normal_initializer())
def call(self, X):
print(X.shape)
x = tf.reshape(X, (-1, 1))
print(x.shape)
mul = tf.matmul(x, tf.transpose(x))
print(mul.shape)
re = tf.reshape(mul, (1, -1))
print(re.shape)
print(self.weight.shape)
return tf.matmul(re, self.weight)