Concise Implementation of Linear Regression

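I think the steps should be:

  1. The optimizer zeroes the gradients (zero_grad)
  2. The input goes through the model (forward pass)
  3. Compute the loss
  4. Backpropagate the loss (backward pass)
  5. The optimizer updates the weights (optimizer.step)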

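But the whole point is to keep gradients from accumulating when the parameters are updated (strictly speaking, to stop them piling up before the gradient is computed). So the gradients must be cleared before each update, which means zero_grad() can go anywhere after the previous step() and before backward(). You can swap l = loss(net(X), y) and trainer.zero_grad(); it makes no difference. Put simply: before leaving the house in the morning, it doesn't matter whether you eat first or brush your teeth first, as long as both are done before you walk out the door.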
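A minimal sketch of the five steps above, assuming a toy nn.Linear model on synthetic data (the function `train` and its flag `zero_grad_first` are hypothetical names). It runs the same loop twice with `zero_grad()` in the two positions and checks that the final parameters are identical, since the forward pass never reads `.grad`:

```python
import torch
from torch import nn

def train(zero_grad_first: bool, steps: int = 5) -> torch.Tensor:
    """Train a tiny linear model; only the position of zero_grad() differs."""
    torch.manual_seed(0)  # same data and same initial weights for both runs
    X = torch.randn(20, 2)
    y = X @ torch.tensor([[2.0], [-3.4]]) + 4.2
    net = nn.Linear(2, 1)
    loss = nn.MSELoss()
    trainer = torch.optim.SGD(net.parameters(), lr=0.03)
    for _ in range(steps):
        if zero_grad_first:
            trainer.zero_grad()      # 1. clear old gradients
            l = loss(net(X), y)      # 2-3. forward pass and loss
        else:
            l = loss(net(X), y)      # 2-3. forward pass and loss
            trainer.zero_grad()      # 1. clear old gradients, still before backward()
        l.backward()                 # 4. backward pass
        trainer.step()               # 5. parameter update
    return torch.cat([net.weight.detach().flatten(), net.bias.detach()])

print(torch.equal(train(zero_grad_first=True), train(zero_grad_first=False)))  # True
```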


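Look at the line that prints the loss at the end: what you get there is l.sum(), i.e. the sum of the losses over all training samples, so of course it's large! The instructor's code prints the average loss, so just divide your result by 1000 (the number of samples) and you'll see that it's perfectly reasonable.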
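To see the scale difference, here is a sketch with 1000 synthetic samples (assumed, matching the /1000 above) that compares the summed loss with the default mean reduction of `nn.MSELoss`:

```python
import torch
from torch import nn

torch.manual_seed(0)
X = torch.randn(1000, 2)                     # 1000 synthetic samples
y = X @ torch.tensor([[2.0], [-3.4]]) + 4.2
net = nn.Linear(2, 1)                        # untrained, so the loss is large

with torch.no_grad():
    total = nn.MSELoss(reduction='sum')(net(X), y)  # sum over all samples
    mean = nn.MSELoss()(net(X), y)                  # default: mean over samples

# total is 1000 times larger; dividing by len(X) recovers the mean
print(total.item(), mean.item(), (total / len(X)).item())
```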

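I think so too. The loss printed at the end of each epoch builds its own separate computation graph, so even if that graph is never cleared it shouldn't cause much of a problem.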
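If the per-epoch loss is only printed for monitoring, one option is to compute it under `torch.no_grad()` so that no graph is recorded at all. A sketch with toy tensors (assumed):

```python
import torch
from torch import nn

net = nn.Linear(2, 1)
loss = nn.MSELoss()
X, y = torch.randn(10, 2), torch.randn(10, 1)

l_train = loss(net(X), y)          # records a graph, so backward() is possible
with torch.no_grad():
    l_eval = loss(net(X), y)       # same value, but no graph is recorded

print(l_train.requires_grad, l_eval.requires_grad)  # True False
print(l_eval.grad_fn)                                # None
```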

```python
#!/usr/bin/env python3

import torch
from torch import nn

class Linear:
    def __init__(self, n, eta=0.01):
        # self.theta = torch.normal(0.0, 0.01, size=(n+1, 1), requires_grad=True)
        # theta = [w1, w2, b] as a column vector, initialised to zero
        self.theta = torch.tensor([[0.0], [0.0], [0.0]], requires_grad=True)
        self.eta = eta

    def cal(self, x):
        # append a column of ones so the last entry of theta acts as the bias
        return torch.concat((x, torch.ones(x.shape[0], 1)), dim=1) @ self.theta

    def loss(self, x, y):
        # return (y - self.cal(x)) ** 2
        return nn.MSELoss()(self.cal(x), y)

    def batch(self, x, y):
        loss = self.loss(x, y)
        loss.backward()
        with torch.no_grad():
            # nn.MSELoss() already averages over the batch,
            # so dividing by x.shape[0] shrinks the step a second time
            self.theta -= self.eta * self.theta.grad / x.shape[0]
            print(f'loss={loss};theta={self.theta.reshape(3)}')
            # not sure if this is necessary
            # self.theta.grad.zero_()


def synthetic_data(w, b, num_examples):
    """Generate y = Xw + b + noise."""
    X = torch.normal(0.0, 1.0, (num_examples, len(w)))
    y = torch.matmul(X, w) + b
    y += torch.normal(0, 0.01, y.shape)
    return X, y.reshape((-1, 1))

true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 10)
mod = Linear(2)
net = nn.Sequential(nn.Linear(2, 1))
net[0].weight.data.fill_(0.0)
net[0].bias.data.fill_(0.0)
loss = nn.MSELoss()
trainer = torch.optim.SGD(net.parameters(), lr=0.01)
for epoch in range(100):
    print(f'epoch {epoch}')
    x, y = synthetic_data(true_w, true_b, 10)
    # framework model: one SGD step per sample
    for xi, yi in zip(x, y):
        l = loss(net(xi), yi)
        trainer.zero_grad()
        l.backward()
        trainer.step()
    l = loss(net(features), labels)
    theta = torch.concat((net[0].weight.data[0], net[0].bias.data), dim=0)
    print(f'loss={l};theta={theta} +vs+ ', end='')
    # hand-written model: one step per epoch on the whole batch of 10
    mod.batch(x, y)
```

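In my previous post I shared the version I wrote from my own understanding. Here I've put it side by side with the result of training through torch for comparison: same lr (0.01), same initial values, yet the results are far apart.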

epoch 0
loss=40.199493408203125;theta=tensor([ 0.7239, -0.1254,  0.8538]) +vs+ loss=29.813140869140625;theta=tensor([ 0.0080, -0.0013,  0.0093], requires_grad=True)
epoch 1
loss=23.488977432250977;theta=tensor([ 1.1116, -0.7396,  1.7217]) +vs+ loss=43.919803619384766;theta=tensor([ 0.0221, -0.0100,  0.0307], requires_grad=True)
epoch 2
loss=15.439271926879883;theta=tensor([ 1.2333, -1.1667,  2.2884]) +vs+ loss=35.333126068115234;theta=tensor([ 0.0408, -0.0238,  0.0627], requires_grad=True)
epoch 3
loss=9.107278823852539;theta=tensor([ 1.3226, -1.6589,  2.8143]) +vs+ loss=41.886661529541016;theta=tensor([ 0.0616, -0.0472,  0.1062], requires_grad=True)
epoch 4
loss=4.869641304016113;theta=tensor([ 1.4891, -2.0748,  3.2445]) +vs+ loss=47.77497482299805;theta=tensor([ 0.0866, -0.0816,  0.1619], requires_grad=True)
epoch 5
loss=2.0043132305145264;theta=tensor([ 1.6313, -2.5625,  3.6038]) +vs+ loss=62.35139083862305;theta=tensor([ 0.1175, -0.1330,  0.2318], requires_grad=True)
epoch 6
loss=1.372653841972351;theta=tensor([ 1.6587, -2.6748,  3.7628]) +vs+ loss=31.220203399658203;theta=tensor([ 0.1505, -0.1898,  0.3120], requires_grad=True)
epoch 7
loss=0.8469464182853699;theta=tensor([ 1.8353, -2.7874,  3.8238]) +vs+ loss=31.033191680908203;theta=tensor([ 0.1938, -0.2514,  0.3992], requires_grad=True)
epoch 8
loss=0.4125572144985199;theta=tensor([ 1.8380, -2.9837,  3.9591]) +vs+ loss=45.233741760253906;theta=tensor([ 0.2380, -0.3272,  0.4982], requires_grad=True)
epoch 9
loss=0.2835714817047119;theta=tensor([ 1.8280, -3.0981,  3.9877]) +vs+ loss=23.174217224121094;theta=tensor([ 0.2812, -0.4115,  0.6030], requires_grad=True)
epoch 10
loss=0.19699090719223022;theta=tensor([ 1.8194, -3.1531,  4.0481]) +vs+ loss=27.476261138916016;theta=tensor([ 0.3224, -0.5036,  0.7175], requires_grad=True)
epoch 11
loss=0.08127938956022263;theta=tensor([ 1.9149, -3.2377,  4.0813]) +vs+ loss=39.93544387817383;theta=tensor([ 0.3781, -0.6057,  0.8398], requires_grad=True)
epoch 12
loss=0.07076151669025421;theta=tensor([ 1.9121, -3.2523,  4.0923]) +vs+ loss=9.7304105758667;theta=tensor([ 0.4335, -0.7093,  0.9666], requires_grad=True)
epoch 13
loss=0.04454811289906502;theta=tensor([ 1.9538, -3.2907,  4.0909]) +vs+ loss=17.899154663085938;theta=tensor([ 0.4971, -0.8190,  1.0956], requires_grad=True)
epoch 14
loss=0.028745323419570923;theta=tensor([ 1.9640, -3.3082,  4.1150]) +vs+ loss=19.298988342285156;theta=tensor([ 0.5644, -0.9330,  1.2316], requires_grad=True)
epoch 15
loss=0.014224288985133171;theta=tensor([ 1.9800, -3.3323,  4.1395]) +vs+ loss=27.38421058654785;theta=tensor([ 0.6372, -1.0552,  1.3765], requires_grad=True)
epoch 16
loss=0.007395690772682428;theta=tensor([ 1.9806, -3.3572,  4.1549]) +vs+ loss=23.562267303466797;theta=tensor([ 0.7108, -1.1880,  1.5290], requires_grad=True)
epoch 17
loss=0.004921122919768095;theta=tensor([ 1.9848, -3.3638,  4.1641]) +vs+ loss=14.054956436157227;theta=tensor([ 0.7868, -1.3247,  1.6876], requires_grad=True)
epoch 18
loss=0.0037674028426408768;theta=tensor([ 1.9860, -3.3691,  4.1687]) +vs+ loss=8.764888763427734;theta=tensor([ 0.8635, -1.4649,  1.8499], requires_grad=True)
epoch 19
loss=0.00217236066237092;theta=tensor([ 1.9892, -3.3788,  4.1752]) +vs+ loss=14.227798461914062;theta=tensor([ 0.9433, -1.6116,  2.0175], requires_grad=True)
epoch 20
loss=0.0014804631937295198;theta=tensor([ 1.9888, -3.3842,  4.1805]) +vs+ loss=8.872236251831055;theta=tensor([ 1.0231, -1.7626,  2.1897], requires_grad=True)
epoch 21
loss=0.0012339865788817406;theta=tensor([ 1.9911, -3.3848,  4.1822]) +vs+ loss=4.2144083976745605;theta=tensor([ 1.1060, -1.9144,  2.3639], requires_grad=True)
epoch 22
loss=0.0007785544148646295;theta=tensor([ 1.9924, -3.3900,  4.1860]) +vs+ loss=7.574267387390137;theta=tensor([ 1.1906, -2.0702,  2.5423], requires_grad=True)
epoch 23
loss=0.000381416582968086;theta=tensor([ 1.9960, -3.3950,  4.1908]) +vs+ loss=11.108306884765625;theta=tensor([ 1.2794, -2.2328,  2.7266], requires_grad=True)
epoch 24
loss=0.00030026386957615614;theta=tensor([ 1.9966, -3.3958,  4.1930]) +vs+ loss=3.9018654823303223;theta=tensor([ 1.3683, -2.3976,  2.9144], requires_grad=True)
epoch 25
loss=0.0002697670424822718;theta=tensor([ 1.9967, -3.3964,  4.1940]) +vs+ loss=1.767190933227539;theta=tensor([ 1.4576, -2.5630,  3.1042], requires_grad=True)
epoch 26
loss=0.00023465680715162307;theta=tensor([ 1.9968, -3.3982,  4.1951]) +vs+ loss=2.983660936355591;theta=tensor([ 1.5483, -2.7309,  3.2970], requires_grad=True)
epoch 27
loss=0.00020080586546100676;theta=tensor([ 1.9970, -3.3992,  4.1971]) +vs+ loss=1.412642240524292;theta=tensor([ 1.6387, -2.9000,  3.4920], requires_grad=True)
epoch 28
loss=0.00018914752581622452;theta=tensor([ 1.9976, -3.3985,  4.1978]) +vs+ loss=0.9602915048599243;theta=tensor([ 1.7311, -3.0703,  3.6880], requires_grad=True)
epoch 29
loss=0.0001670362544246018;theta=tensor([ 1.9993, -3.3990,  4.1994]) +vs+ loss=0.6071179509162903;theta=tensor([ 1.8252, -3.2410,  3.8851], requires_grad=True)
epoch 30
loss=0.0001656902168178931;theta=tensor([ 1.9989, -3.3986,  4.1997]) +vs+ loss=0.11326725780963898;theta=tensor([ 1.9192, -3.4121,  4.0829], requires_grad=True)
epoch 31
loss=0.00017027936701197177;theta=tensor([ 1.9986, -3.3979,  4.1991]) +vs+ loss=0.017175797373056412;theta=tensor([ 2.0132, -3.5832,  4.2809], requires_grad=True)
epoch 32
loss=0.00017005577683448792;theta=tensor([ 1.9984, -3.3992,  4.1995]) +vs+ loss=0.07364974915981293;theta=tensor([ 2.1074, -3.7535,  4.4786], requires_grad=True)
epoch 33
loss=0.0001676440006121993;theta=tensor([ 1.9991, -3.3987,  4.1992]) +vs+ loss=0.25866103172302246;theta=tensor([ 2.2012, -3.9231,  4.6757], requires_grad=True)
epoch 34
loss=0.00015877402620390058;theta=tensor([ 2.0002, -3.3987,  4.2006]) +vs+ loss=0.5912841558456421;theta=tensor([ 2.2952, -4.0914,  4.8714], requires_grad=True)
epoch 35
loss=0.00016528442210983485;theta=tensor([ 2.0008, -3.3997,  4.2003]) +vs+ loss=1.139038324356079;theta=tensor([ 2.3881, -4.2578,  5.0664], requires_grad=True)
epoch 36
loss=0.0001614181383047253;theta=tensor([ 2.0006, -3.3992,  4.2010]) +vs+ loss=1.3117530345916748;theta=tensor([ 2.4807, -4.4225,  5.2600], requires_grad=True)
epoch 37
loss=0.00016502727521583438;theta=tensor([ 2.0006, -3.3996,  4.2002]) +vs+ loss=2.4010090827941895;theta=tensor([ 2.5720, -4.5848,  5.4522], requires_grad=True)
epoch 38
loss=0.0001618371024960652;theta=tensor([ 1.9999, -3.3992,  4.2007]) +vs+ loss=1.7867286205291748;theta=tensor([ 2.6609, -4.7470,  5.6425], requires_grad=True)
epoch 39
loss=0.00016230842447839677;theta=tensor([ 1.9996, -3.3982,  4.1996]) +vs+ loss=6.560575008392334;theta=tensor([ 2.7493, -4.9038,  5.8292], requires_grad=True)
epoch 40
loss=0.00016202664119191468;theta=tensor([ 1.9992, -3.3980,  4.1998]) +vs+ loss=3.9900474548339844;theta=tensor([ 2.8344, -5.0588,  6.0140], requires_grad=True)
epoch 41
loss=0.00016224353748839349;theta=tensor([ 1.9992, -3.3984,  4.2000]) +vs+ loss=3.3899340629577637;theta=tensor([ 2.9202, -5.2123,  6.1963], requires_grad=True)
epoch 42
loss=0.0001664403680479154;theta=tensor([ 1.9990, -3.3989,  4.1996]) +vs+ loss=3.928368330001831;theta=tensor([ 3.0061, -5.3636,  6.3765], requires_grad=True)
epoch 43
loss=0.00016333414532709867;theta=tensor([ 1.9999, -3.3990,  4.1998]) +vs+ loss=7.6767778396606445;theta=tensor([ 3.0890, -5.5130,  6.5528], requires_grad=True)
epoch 44
loss=0.00016104819951578975;theta=tensor([ 2.0003, -3.3989,  4.2003]) +vs+ loss=14.649149894714355;theta=tensor([ 3.1676, -5.6565,  6.7238], requires_grad=True)
epoch 45
loss=0.00015725786215625703;theta=tensor([ 2.0002, -3.3983,  4.2007]) +vs+ loss=7.038060188293457;theta=tensor([ 3.2458, -5.7980,  6.8913], requires_grad=True)
epoch 46
loss=0.00015804634313099086;theta=tensor([ 1.9999, -3.3987,  4.2012]) +vs+ loss=6.785189628601074;theta=tensor([ 3.3240, -5.9385,  7.0547], requires_grad=True)
epoch 47
loss=0.0001646848686505109;theta=tensor([ 2.0000, -3.3996,  4.2004]) +vs+ loss=11.080676078796387;theta=tensor([ 3.4008, -6.0762,  7.2133], requires_grad=True)
epoch 48
loss=0.00016850080282893032;theta=tensor([ 2.0003, -3.3996,  4.1991]) +vs+ loss=9.164220809936523;theta=tensor([ 3.4758, -6.2131,  7.3675], requires_grad=True)
epoch 49
loss=0.00017005126574076712;theta=tensor([ 1.9994, -3.3999,  4.1995]) +vs+ loss=37.14043045043945;theta=tensor([ 3.5417, -6.3406,  7.5109], requires_grad=True)
epoch 50
loss=0.00017227436183020473;theta=tensor([ 1.9995, -3.4002,  4.1992]) +vs+ loss=17.466419219970703;theta=tensor([ 3.6056, -6.4652,  7.6472], requires_grad=True)
epoch 51
loss=0.0001666880853008479;theta=tensor([ 1.9994, -3.3997,  4.2000]) +vs+ loss=22.53124237060547;theta=tensor([ 3.6692, -6.5842,  7.7755], requires_grad=True)
epoch 52
loss=0.0001748951181070879;theta=tensor([ 1.9999, -3.4009,  4.1996]) +vs+ loss=29.02981948852539;theta=tensor([ 3.7296, -6.6960,  7.8955], requires_grad=True)
epoch 53
loss=0.0001712619123281911;theta=tensor([ 2.0000, -3.4003,  4.1995]) +vs+ loss=22.801115036010742;theta=tensor([ 3.7873, -6.8039,  8.0079], requires_grad=True)
epoch 54
loss=0.0001726361660985276;theta=tensor([ 2.0001, -3.4008,  4.2001]) +vs+ loss=34.576438903808594;theta=tensor([ 3.8427, -6.9041,  8.1101], requires_grad=True)
epoch 55
loss=0.00017347534594591707;theta=tensor([ 1.9999, -3.4007,  4.1994]) +vs+ loss=21.741161346435547;theta=tensor([ 3.8985, -6.9985,  8.2061], requires_grad=True)
epoch 56
loss=0.00017639700672589242;theta=tensor([ 1.9993, -3.4010,  4.1995]) +vs+ loss=17.54810905456543;theta=tensor([ 3.9536, -7.0890,  8.2973], requires_grad=True)
epoch 57
loss=0.00017972828936763108;theta=tensor([ 1.9998, -3.4012,  4.1987]) +vs+ loss=38.845252990722656;theta=tensor([ 4.0028, -7.1705,  8.3804], requires_grad=True)
epoch 58
loss=0.0001698939158814028;theta=tensor([ 1.9989, -3.3998,  4.1997]) +vs+ loss=24.035633087158203;theta=tensor([ 4.0506, -7.2466,  8.4576], requires_grad=True)
epoch 59
loss=0.00016960882931016386;theta=tensor([ 1.9987, -3.3993,  4.1995]) +vs+ loss=74.61981201171875;theta=tensor([ 4.0879, -7.3052,  8.5205], requires_grad=True)
epoch 60
loss=0.00017015331832226366;theta=tensor([ 1.9987, -3.3988,  4.1992]) +vs+ loss=55.80016326904297;theta=tensor([ 4.1123, -7.3517,  8.5749], requires_grad=True)
epoch 61
loss=0.00017030151502694935;theta=tensor([ 1.9983, -3.3988,  4.1994]) +vs+ loss=21.357677459716797;theta=tensor([ 4.1337, -7.3972,  8.6219], requires_grad=True)
epoch 62
loss=0.00016912302817218006;theta=tensor([ 1.9990, -3.4003,  4.2007]) +vs+ loss=62.74077606201172;theta=tensor([ 4.1522, -7.4262,  8.6567], requires_grad=True)
epoch 63
loss=0.00017075186769943684;theta=tensor([ 1.9997, -3.4006,  4.2007]) +vs+ loss=23.32436752319336;theta=tensor([ 4.1711, -7.4515,  8.6842], requires_grad=True)
epoch 64
loss=0.00017476364155299962;theta=tensor([ 2.0004, -3.4010,  4.2003]) +vs+ loss=59.611053466796875;theta=tensor([ 4.1798, -7.4662,  8.6996], requires_grad=True)
epoch 65
loss=0.0001723233435768634;theta=tensor([ 2.0004, -3.4007,  4.1999]) +vs+ loss=51.41381072998047;theta=tensor([ 4.1790, -7.4733,  8.7038], requires_grad=True)
epoch 66
loss=0.00017222696624230593;theta=tensor([ 2.0008, -3.4005,  4.1994]) +vs+ loss=19.92046546936035;theta=tensor([ 4.1777, -7.4790,  8.7005], requires_grad=True)
epoch 67
loss=0.00017040420789271593;theta=tensor([ 2.0016, -3.3999,  4.1990]) +vs+ loss=36.07746124267578;theta=tensor([ 4.1770, -7.4757,  8.6890], requires_grad=True)
epoch 68
loss=0.0001699236163403839;theta=tensor([ 2.0014, -3.3997,  4.1988]) +vs+ loss=60.047691345214844;theta=tensor([ 4.1709, -7.4596,  8.6651], requires_grad=True)
epoch 69
loss=0.00017378793563693762;theta=tensor([ 2.0015, -3.4006,  4.1993]) +vs+ loss=38.65812301635742;theta=tensor([ 4.1598, -7.4350,  8.6340], requires_grad=True)
epoch 70
loss=0.00017749964899849147;theta=tensor([ 2.0021, -3.4009,  4.1992]) +vs+ loss=40.08657455444336;theta=tensor([ 4.1406, -7.4041,  8.5945], requires_grad=True)
epoch 71
loss=0.00017457007197663188;theta=tensor([ 2.0013, -3.4006,  4.1989]) +vs+ loss=46.079322814941406;theta=tensor([ 4.1117, -7.3663,  8.5450], requires_grad=True)
epoch 72
loss=0.00017598320846445858;theta=tensor([ 2.0006, -3.4009,  4.1991]) +vs+ loss=38.65458297729492;theta=tensor([ 4.0776, -7.3232,  8.4852], requires_grad=True)
epoch 73
loss=0.00018053247185889632;theta=tensor([ 2.0008, -3.4013,  4.1985]) +vs+ loss=70.78762817382812;theta=tensor([ 4.0319, -7.2583,  8.4178], requires_grad=True)
epoch 74
loss=0.00018454341625329107;theta=tensor([ 2.0007, -3.4012,  4.1975]) +vs+ loss=20.685842514038086;theta=tensor([ 3.9843, -7.1897,  8.3450], requires_grad=True)
epoch 75
loss=0.0001805703795980662;theta=tensor([ 2.0007, -3.4011,  4.1981]) +vs+ loss=21.911151885986328;theta=tensor([ 3.9389, -7.1172,  8.2640], requires_grad=True)
epoch 76
loss=0.00017642680904828012;theta=tensor([ 2.0001, -3.4013,  4.2001]) +vs+ loss=30.90773582458496;theta=tensor([ 3.8961, -7.0373,  8.1734], requires_grad=True)
epoch 77
loss=0.00017017331265378743;theta=tensor([ 2.0000, -3.4005,  4.2005]) +vs+ loss=23.25639533996582;theta=tensor([ 3.8498, -6.9554,  8.0746], requires_grad=True)
epoch 78
loss=0.00016636324289720505;theta=tensor([ 1.9996, -3.4000,  4.2009]) +vs+ loss=13.450610160827637;theta=tensor([ 3.8072, -6.8692,  7.9711], requires_grad=True)
epoch 79
loss=0.00016418930317740887;theta=tensor([ 1.9996, -3.3997,  4.2013]) +vs+ loss=14.077974319458008;theta=tensor([ 3.7614, -6.7814,  7.8630], requires_grad=True)
epoch 80
loss=0.0001676852989476174;theta=tensor([ 1.9991, -3.4002,  4.2013]) +vs+ loss=29.78688621520996;theta=tensor([ 3.7098, -6.6875,  7.7471], requires_grad=True)
epoch 81
loss=0.00016680985572747886;theta=tensor([ 2.0005, -3.4000,  4.2017]) +vs+ loss=35.38242721557617;theta=tensor([ 3.6517, -6.5870,  7.6205], requires_grad=True)
epoch 82
loss=0.00016362463065888733;theta=tensor([ 1.9993, -3.3996,  4.2015]) +vs+ loss=29.195911407470703;theta=tensor([ 3.5847, -6.4811,  7.4862], requires_grad=True)
epoch 83
loss=0.00016521224461030215;theta=tensor([ 2.0000, -3.3998,  4.2019]) +vs+ loss=22.931034088134766;theta=tensor([ 3.5167, -6.3684,  7.3447], requires_grad=True)
epoch 84
loss=0.0001610375620657578;theta=tensor([ 2.0000, -3.3993,  4.2013]) +vs+ loss=23.00523567199707;theta=tensor([ 3.4449, -6.2458,  7.1998], requires_grad=True)
epoch 85
loss=0.000163917793543078;theta=tensor([ 2.0000, -3.3996,  4.2008]) +vs+ loss=19.570383071899414;theta=tensor([ 3.3679, -6.1186,  7.0488], requires_grad=True)
epoch 86
loss=0.00016396815772168338;theta=tensor([ 1.9997, -3.3995,  4.2005]) +vs+ loss=17.819217681884766;theta=tensor([ 3.2883, -5.9845,  6.8930], requires_grad=True)
epoch 87
loss=0.0001636362576391548;theta=tensor([ 1.9990, -3.3995,  4.2010]) +vs+ loss=4.124261856079102;theta=tensor([ 3.2094, -5.8497,  6.7346], requires_grad=True)
epoch 88
loss=0.0001627235469641164;theta=tensor([ 1.9995, -3.3995,  4.2015]) +vs+ loss=16.016010284423828;theta=tensor([ 3.1250, -5.7108,  6.5700], requires_grad=True)
epoch 89
loss=0.00016649524332024157;theta=tensor([ 1.9998, -3.3992,  4.1994]) +vs+ loss=8.818704605102539;theta=tensor([ 3.0429, -5.5676,  6.4012], requires_grad=True)
epoch 90
loss=0.00017015125195030123;theta=tensor([ 2.0000, -3.4003,  4.1997]) +vs+ loss=18.328296661376953;theta=tensor([ 2.9567, -5.4145,  6.2274], requires_grad=True)
epoch 91
loss=0.00017368022236041725;theta=tensor([ 2.0007, -3.4006,  4.1992]) +vs+ loss=6.945518493652344;theta=tensor([ 2.8665, -5.2590,  6.0510], requires_grad=True)
epoch 92
loss=0.00017715533613227308;theta=tensor([ 2.0012, -3.4012,  4.2005]) +vs+ loss=7.404855251312256;theta=tensor([ 2.7714, -5.1014,  5.8710], requires_grad=True)
epoch 93
loss=0.00017647151253186166;theta=tensor([ 2.0004, -3.4012,  4.2007]) +vs+ loss=7.432589054107666;theta=tensor([ 2.6745, -4.9397,  5.6871], requires_grad=True)
epoch 94
loss=0.00016964862879831344;theta=tensor([ 2.0006, -3.4004,  4.2007]) +vs+ loss=2.2681782245635986;theta=tensor([ 2.5772, -4.7772,  5.5012], requires_grad=True)
epoch 95
loss=0.00016688762116245925;theta=tensor([ 2.0002, -3.4001,  4.2013]) +vs+ loss=1.6974331140518188;theta=tensor([ 2.4791, -4.6135,  5.3141], requires_grad=True)
epoch 96
loss=0.00016662708367221057;theta=tensor([ 2.0001, -3.4000,  4.2016]) +vs+ loss=4.312170028686523;theta=tensor([ 2.3805, -4.4447,  5.1253], requires_grad=True)
epoch 97
loss=0.00016773059905972332;theta=tensor([ 2.0002, -3.4002,  4.2010]) +vs+ loss=2.584296703338623;theta=tensor([ 2.2811, -4.2734,  4.9339], requires_grad=True)
epoch 98
loss=0.0001651171623962;theta=tensor([ 1.9996, -3.3997,  4.2005]) +vs+ loss=2.0876243114471436;theta=tensor([ 2.1805, -4.0996,  4.7404], requires_grad=True)
epoch 99
loss=0.00016961248184088618;theta=tensor([ 2.0002, -3.4004,  4.2003]) +vs+ loss=0.5221318006515503;theta=tensor([ 2.0794, -3.9249,  4.5461], requires_grad=True)

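First, you definitely need the gradient-zeroing step self.theta.grad.zero_(). Only then is each parameter update based on the gradient backpropagated from the current batch alone; otherwise it will struggle to converge. Second, in my tests, at the same learning rate, linear regression from scratch (both your model and the one provided in the previous section of the book) simply converges more slowly than the model built on the API. Raising the learning rate speeds the former up, but I don't understand why either.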
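A hedged sketch, not a verified diagnosis: the hand-written `batch` in the post differs from the framework loop in at least three ways. It never zeroes `theta.grad`; it divides by `x.shape[0]` even though `nn.MSELoss()` already averages (so the step is 10 times smaller for a batch of 10); and it makes one update per epoch where the framework loop makes one per sample. Never zeroing means each update uses the sum of all past gradients, which behaves a bit like momentum and may explain the overshoot to about (4.18, -7.47, 8.70) in the log. The toy script below (synthetic data and variable names are my own) applies per-sample updates with the division removed and the gradient zeroed; the two parameter vectors then agree up to float rounding:

```python
import torch
from torch import nn

torch.manual_seed(0)
X = torch.randn(10, 2)
y = X @ torch.tensor([[2.0], [-3.4]]) + 4.2
lr = 0.01
loss = nn.MSELoss()

# Framework version: per-sample SGD, as in the post's inner loop
net = nn.Linear(2, 1)
nn.init.zeros_(net.weight)
nn.init.zeros_(net.bias)
trainer = torch.optim.SGD(net.parameters(), lr=lr)

# Hand-written version: theta = [w1, w2, b], also starting at zero
theta = torch.zeros(3, 1, requires_grad=True)

for xi, yi in zip(X, y):
    # framework step on one sample
    trainer.zero_grad()
    loss(net(xi), yi).backward()
    trainer.step()

    # hand-written step on the same sample
    xi1 = torch.cat((xi, torch.ones(1))).reshape(1, 3)  # append the bias feature
    l = loss(xi1 @ theta, yi.reshape(1, 1))
    l.backward()
    with torch.no_grad():
        theta -= lr * theta.grad  # no extra / batch_size: MSELoss already averages
        theta.grad.zero_()        # otherwise gradients from earlier steps pile up

framework = torch.cat((net.weight.detach().flatten(), net.bias.detach()))
print(framework, theta.detach().flatten())  # the two agree up to float rounding
```

For the book's from-scratch version, one likely contributor is that its squared_loss (as I recall) divides by 2, which nn.MSELoss does not, so at the same lr the concise version's steps are about twice as large.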

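In the concise implementation, why don't we call sum() on l before backpropagating? Because the loss function has already reduced l to a single scalar (nn.MSELoss averages over the batch by default); that's the benefit of the encapsulation.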
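A small sketch of the `reduction` argument of `nn.MSELoss` on hand-picked toy values: the default returns a scalar mean, while `reduction='none'` returns one loss per sample, which must be reduced before `backward()`:

```python
import torch
from torch import nn

y_hat = torch.tensor([[1.0], [2.0], [4.0]])
y = torch.tensor([[1.0], [1.0], [1.0]])

mean_loss = nn.MSELoss()(y_hat, y)                   # default reduction='mean'
sum_loss = nn.MSELoss(reduction='sum')(y_hat, y)
per_sample = nn.MSELoss(reduction='none')(y_hat, y)  # no reduction, shape (3, 1)
# squared errors are 0, 1, 9 -> mean 10/3, sum 10

# A non-scalar loss needs an explicit reduction before backward();
# calling l.backward() directly here would raise a RuntimeError.
w = torch.tensor(2.0, requires_grad=True)
l = nn.MSELoss(reduction='none')(w * y, y_hat)
l.sum().backward()
print(mean_loss.item(), sum_loss.item(), per_sample.shape, w.grad)
```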

Exercise 3

```python
# print the gradient of every model parameter
for parameter in net.parameters():
    print(parameter.grad)
```
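A small sketch, assuming the `nn.Sequential(nn.Linear(2, 1))` model from the chapter and random toy data: besides looping over `net.parameters()`, the individual gradients can be read by index. Before the first `backward()` the `.grad` attributes are `None`:

```python
import torch
from torch import nn

net = nn.Sequential(nn.Linear(2, 1))
loss = nn.MSELoss()
X, y = torch.randn(10, 2), torch.randn(10, 1)

before = net[0].weight.grad  # None: no backward() has run yet
loss(net(X), y).backward()
print(net[0].weight.grad)    # same shape as net[0].weight, i.e. (1, 2)
print(net[0].bias.grad)      # same shape as net[0].bias, i.e. (1,)
```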

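Each loss computation builds a separate computation graph, so the graphs themselves don't accumulate. For that to happen you would have to use loss again, I think.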
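To separate the two ideas in this reply: each forward pass does build its own graph, but `backward()` adds into each leaf tensor's `.grad` rather than overwriting it, so gradients do accumulate across separate graphs unless they are zeroed. A scalar sketch with a toy value `w = 3` (assumed):

```python
import torch

w = torch.tensor(3.0, requires_grad=True)

for _ in range(2):
    l = w ** 2      # each iteration builds a brand-new graph
    l.backward()    # dl/dw = 2w = 6 is ADDED to w.grad, not written over it
accumulated = w.grad.clone()
print(accumulated)  # tensor(12.)

w.grad.zero_()      # manual counterpart of optimizer.zero_grad()
(w ** 2).backward()
print(w.grad)       # tensor(6.)
```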

I wrapped the training code in a timer and found that the concise implementation is actually slightly slower (0.05 s vs. 0.07 s).
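A minimal timing sketch using the standard library's `time.perf_counter()` (the helper `time_training` and the toy setup are hypothetical). At this scale the loop is dominated by Python overhead, so differences of a few hundredths of a second between single runs may well be noise; repeating the measurement gives a fairer comparison:

```python
import time
import torch
from torch import nn

def time_training(num_epochs: int = 3) -> float:
    """Time a concise-style training loop on synthetic data (toy setup)."""
    X = torch.randn(1000, 2)
    y = X @ torch.tensor([[2.0], [-3.4]]) + 4.2
    net = nn.Linear(2, 1)
    loss = nn.MSELoss()
    trainer = torch.optim.SGD(net.parameters(), lr=0.03)
    start = time.perf_counter()          # monotonic, high-resolution clock
    for _ in range(num_epochs):
        for i in range(0, len(X), 10):   # minibatches of 10
            l = loss(net(X[i:i + 10]), y[i:i + 10])
            trainer.zero_grad()
            l.backward()
            trainer.step()
    return time.perf_counter() - start

# Repeat a few times: a single short run is easily skewed by noise
print([round(time_training(), 3) for _ in range(3)])
```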