Densely Connected Networks (DenseNet)

http://zh-v2.d2l.ai/chapter_convolutional-modern/densenet.html

Does anyone know why the transition layer uses average pooling rather than max pooling?

I tried swapping in max pooling myself and found that accuracy barely changed.
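
If you want to reproduce that experiment, here is a minimal sketch of the swap. The transition_block below follows the shape of the chapter's transition layer (BN, ReLU, 1x1 conv, 2x2 pooling); the use_max_pool switch is a hypothetical parameter added here purely for the comparison.

import torch
from torch import nn

def transition_block(input_channels, num_channels, use_max_pool=False):
    # Transition layer: BN -> ReLU -> 1x1 conv -> 2x2 pooling.
    # use_max_pool is a hypothetical switch for comparing the two pooling
    # choices; the original layer uses average pooling.
    pool = (nn.MaxPool2d(kernel_size=2, stride=2) if use_max_pool
            else nn.AvgPool2d(kernel_size=2, stride=2))
    return nn.Sequential(
        nn.BatchNorm2d(input_channels), nn.ReLU(),
        nn.Conv2d(input_channels, num_channels, kernel_size=1),
        pool)

# Both variants halve the spatial size and change channels identically.
blk = transition_block(23, 10, use_max_pool=True)
print(blk(torch.randn(4, 23, 8, 8)).shape)  # torch.Size([4, 10, 4, 4])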

DenseNet121/169/201 model code, implemented with reference to the original DenseNet paper and the torchvision source; shared for reference:

import re
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch import nn, Tensor
from typing import List, Tuple
from collections import OrderedDict
# torchvision.models.utils was removed in newer releases; torch.hub
# provides the same helper
from torch.hub import load_state_dict_from_url


class _DenseLayer(nn.Module):
    def __init__(self, in_channels, growth_rate, bn_size,
                dropout_rate, memory_efficient: bool):
        super(_DenseLayer, self).__init__()
        self.memory_efficient = memory_efficient
        self.dropout_rate = dropout_rate
        # Type annotations for the submodules registered below:
        # norm1/relu1/conv1 form the 1x1 bottleneck,
        # norm2/relu2/conv2 the 3x3 convolution
        self.norm1: nn.BatchNorm2d
        self.relu1: nn.ReLU
        self.conv1: nn.Conv2d
        self.norm2: nn.BatchNorm2d
        self.relu2: nn.ReLU
        self.conv2: nn.Conv2d
        # Register the bottleneck and the 3x3 conv as submodules
        self.add_module('norm1', nn.BatchNorm2d(in_channels))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(in_channels, growth_rate * bn_size,
                                            kernel_size=1, stride=1, bias=False))
        self.add_module('norm2', nn.BatchNorm2d(growth_rate * bn_size))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(growth_rate * bn_size, growth_rate,
                                            kernel_size=3, stride=1, padding=1,
                                            bias=False))

    def bottleneck(self, inputs: List[Tensor]):
        # Concatenate all preceding feature maps along the channel dim
        concated_features = torch.cat(inputs, 1)
        bottleneck_outputs = self.conv1(self.relu1(self.norm1(concated_features)))
        return bottleneck_outputs

    @torch.jit.unused
    def call_checkpoints_bottleneck(self, inputs: List[Tensor]):
        # Gradient checkpointing: drop the bottleneck's intermediate
        # activations in the forward pass and recompute them during
        # backward, trading compute for memory
        def closure(*inputs):
            return self.bottleneck(inputs)
        return cp.checkpoint(closure, *inputs)

    def forward(self, input: Tensor):
        if isinstance(input, Tensor):
            prev_features = [input]
        else:
            prev_features = input
        if self.memory_efficient:
            bottleneck_outputs = self.call_checkpoints_bottleneck(prev_features)
        else:
            bottleneck_outputs = self.bottleneck(prev_features)
        new_features = self.conv2(self.relu2(self.norm2(bottleneck_outputs)))
        if self.dropout_rate > 0:
            new_features = F.dropout(new_features, p=self.dropout_rate,
                                    training=self.training)
        return new_features


class _DenseBlock(nn.ModuleDict):
    def __init__(self, num_layers, in_channels, growth_rate, bn_size,
                dropout_rate, memory_efficient):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(in_channels + growth_rate * i, growth_rate,
                                bn_size, dropout_rate, memory_efficient)
            self.add_module('denselayer%d' % (i + 1), layer)

    def forward(self, x):
        # Start with the block input, then append each dense layer's output
        features = [x]
        # self.items() iterates the layers stored as an OrderedDict in
        # self._modules
        for name, layer in self.items():
            new_features = layer(features)
            features.append(new_features)
        # Concatenate along channels so the following transition layer
        # receives a single tensor
        return torch.cat(features, 1)


class _Transition(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(in_channels))
        self.add_module('relu', nn.ReLU(inplace=True))
        # 1x1 conv reduces the number of channels (halved by the caller)
        self.add_module('conv', nn.Conv2d(in_channels, out_channels,
                                        kernel_size=1, stride=1, bias=False))
        # 2x2 average pooling halves the spatial size of the feature maps
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))


class DenseNet(nn.Module):
    def __init__(self, block_config: Tuple[int, int, int, int],
                num_classes: int = 1000,
                in_channels: int = 64,
                growth_rate: int = 32,
                bn_size: int = 4,
                dropout_rate: float = 0.,
                memory_efficient: bool = False):
        super(DenseNet, self).__init__()
        # Stem: initial conv, BN, ReLU, max pooling
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, in_channels, kernel_size=7, stride=2,
                                padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(in_channels)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        ]))
        # Dense blocks with transition layers in between
        num_features = in_channels
        for i, num_layers in enumerate(block_config):
            denseblock = _DenseBlock(num_layers, num_features, growth_rate,
                                    bn_size, dropout_rate, memory_efficient)
            self.features.add_module('denseblock%d' % (i + 1), denseblock)
            num_features += growth_rate * num_layers
            if i < len(block_config) - 1:
                trans = _Transition(num_features, num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                # The transition layer halves the channels; update the count
                num_features = num_features // 2
        # Final batch norm before the classifier
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        self.classifier = nn.Linear(num_features, num_classes)

        # Parameter initialization (following torchvision)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.features(x)
        out = F.relu(out, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out
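
# Channel bookkeeping for DenseNet-121 (block_config=(6, 12, 24, 16),
# growth_rate=32), as computed by the loop in __init__ above:
#   stem:        64
#   denseblock1: 64 + 6*32   = 256  -> transition1 halves to 128
#   denseblock2: 128 + 12*32 = 512  -> transition2 halves to 256
#   denseblock3: 256 + 24*32 = 1024 -> transition3 halves to 512
#   denseblock4: 512 + 16*32 = 1024 -> input width of norm5 / classifier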

class Constructor:
    def __init__(self, num_classes: int = 1000,
                memory_efficient: bool = False,
                load: bool = False,
                progress: bool = True):
        self.num_classes = num_classes
        self.memory_efficient = memory_efficient
        self.load = load
        self.progress = progress
        self.model_urls = {
            'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
            'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
            'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
            }

    def _load_state_dict(self, model: nn.Module, model_url: str):
        state_dict = load_state_dict_from_url(model_url, progress=self.progress)
        # Official checkpoints use the old naming scheme
        # (e.g. 'denselayer1.norm.1.weight'); rename keys to the current
        # scheme ('denselayer1.norm1.weight')
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        model.load_state_dict(state_dict)
        return model

    def _build_model(self, block_config, model_url=None):
        model = DenseNet(block_config=block_config,
                        num_classes=self.num_classes,
                        memory_efficient=self.memory_efficient)
        if self.load:
            model = self._load_state_dict(model, model_url)
        return model

    def DenseNet121(self):
        return self._build_model((6, 12, 24, 16), self.model_urls['densenet121'])

    def DenseNet169(self):
        return self._build_model((6, 12, 32, 32), self.model_urls['densenet169'])

    def DenseNet201(self):
        return self._build_model((6, 12, 48, 32), self.model_urls['densenet201'])


if __name__ == '__main__':
    num_classes = 1000
    memory_efficient = True
    load = True
    progress = True
    densenet169 = Constructor(num_classes, memory_efficient, load,
                            progress).DenseNet169()
    print(densenet169)
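    # Quick smoke test (a sketch): run a dummy ImageNet-sized input
    # through the model and check the output shape
    densenet169.eval()
    with torch.no_grad():
        logits = densenet169(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # expected: torch.Size([1, 1000])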

My take (with the usual caveat that neural networks resist interpretation): the transition layer's nn.AvgPool2d(kernel_size=2, stride=2) uses a pooling window of only 2x2, so average pooling and max pooling end up behaving very similarly, which matches the experiment above. Average pooling does, however, aggregate information from all neighboring pixels rather than keeping only the largest value.
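
A tiny demo of the difference on a single 2x2 window (a sketch using torch.nn.functional; the numbers are made up for illustration):

import torch
import torch.nn.functional as F

x = torch.tensor([[[[1., 3.],
                    [2., 6.]]]])  # one 2x2 window
print(F.avg_pool2d(x, 2))  # tensor([[[[3.]]]]) - blends all four values
print(F.max_pool2d(x, 2))  # tensor([[[[6.]]]]) - keeps only the largest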