线性回归代码问题

今天在自己手敲一遍线性回归代码的时候碰到了个bug。代码和李沐老师课上的相同:

```python
import numpy as np
import torch
import random
    
def generate_data(weight, bias, num_of_data):
    data = torch.normal(1, 0.001, (num_of_data, weight.shape[0]))
    result = torch.matmul(data, weight) + bias
    label = result + torch.normal(1, 0.001, result.shape)
    return data, label

def data_iter(batch_size, data, label):
    number = len(data)
    indice = list(range(number))
    random.shuffle(indice)
    for i in range(0, number, batch_size):
            batch_indices = torch.tensor(indice[i:min(i+batch_size, number)])
            yield data[batch_indices], label[batch_indices]

def linear_model(x, w, b):
    return torch.matmul(x,w)+b`

def sgd(lr, batch_size, params):
    with torch.no_grad():
         for param in params:
                    param -= param.grad * lr / batch_size
                    param.grad.zero_()

def square_loss(predict, label):
   return (predict-label)**2`

num_of_data = 100
weight = torch.normal(1, 0.001, (2,1), requires_grad=True)
bias = torch.zeros(1, requires_grad=True)
data, label = generate_data(weight, bias, num_of_data)

epoch = 10
batch_size = 10
learning_rate = 0.01

for i in range(epoch):
    for X, Y in data_iter(batch_size, data, label):
            output = linear_model(X, weight, bias)
            l = square_loss(output, Y)
            l.sum().backward()
            sgd(learning_rate, batch_size, [weight, bias])
    with torch.no_grad():
            print(f'epoch: {epoch}, loss: {square_loss(output(data, weight, bias), label)}')`

    ```

在l.sum().backward()这一步遇到了问题:Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time. 尝试解决,但是确实没看出来哪里出现了错误。为什么会出现重复使用计算图的情况呢… 感谢各位!

Hi, @Eric_Gao! 在你的代码中,你把 weight 和 bias 用于 generate_data 中,我认为这部分的计算是会被归于 weight 和 bias 的计算图中的并且在 backward 时被计算。我认为你可以再设置一个 true_weight 和 true_bias 专门用于数据生成,然后 weight 和 bias 用做模型参数并进行训练优化,thanks!

1 Like

问题成功解决,非常感谢您的回答! :grinning:

@Eric_Gao

Your squared_loss is wrong!
1/2 is related to the differentiation of square_loss.

I forgot this part when I tried to finish the code. Really thanks for your reply :slight_smile: