更新时间:2023-07-21 来源:黑马程序员 浏览量:
ResNet(Residual Network)是由Kaiming He等人提出的深度学习神经网络结构,它在2015年的ImageNet图像识别竞赛中取得了非常显著的成绩,引起了广泛的关注。ResNet的主要贡献是解决了深度神经网络的梯度消失问题,使得可以训练更深的网络,从而获得更好的性能。
问题:在传统的深度神经网络中,随着网络层数的增加,梯度在反向传播过程中逐渐变小,导致浅层网络的权重更新几乎没有效果,难以训练。这被称为梯度消失问题。
ResNet的解决方法:ResNet引入了“残差块”(residual block),每个残差块包含了一条“跳跃连接”(shortcut connection),它允许梯度能够直接穿过块,从而避免了梯度消失问题。因此,深度网络可以通过恒等映射(identity mapping)来学习残差,使得网络在增加深度时反而变得更容易训练。
ResNet结构特点:
1.残差块:每个残差块由两个或三个卷积层组成,它们的输出通过跳跃连接与块的输入相加,形成残差(residual)。
2.跳跃连接:跳跃连接允许梯度直接流过块,有助于避免梯度消失问题。
3.批量归一化:ResNet中广泛使用批量归一化层来加速训练并稳定网络。
4.残差块堆叠:ResNet通过堆叠多个残差块来构建深层网络。深度可以根据任务的复杂性而自由选择。
接下来我们看一个简化的ResNet代码演示(使用TensorFlow):
import tensorflow as tf from tensorflow.keras import layers, models # 定义一个基本的残差块 def residual_block(x, filters, downsample=False): # 如果downsample为True,使用步长为2的卷积层实现降采样 stride = 2 if downsample else 1 # 记录输入,以便在跳跃连接时使用 identity = x # 第一个卷积层 x = layers.Conv2D(filters, kernel_size=3, strides=stride, padding='same')(x) x = layers.BatchNormalization()(x) x = layers.Activation('relu')(x) # 第二个卷积层 x = layers.Conv2D(filters, kernel_size=3, strides=1, padding='same')(x) x = layers.BatchNormalization()(x) # 如果进行了降采样,需要对identity进行相应处理,保证维度一致 if downsample: identity = layers.Conv2D(filters, kernel_size=1, strides=stride, padding='same')(identity) identity = layers.BatchNormalization()(identity) # 跳跃连接:将卷积层的输出与输入相加 x = layers.add([x, identity]) x = layers.Activation('relu')(x) return x # 构建ResNet网络 def ResNet(input_shape, num_classes): input_img = layers.Input(shape=input_shape) # 第一个卷积层 x = layers.Conv2D(64, kernel_size=7, strides=2, padding='same')(input_img) x = layers.BatchNormalization()(x) x = layers.Activation('relu')(x) x = layers.MaxPooling2D(pool_size=3, strides=2, padding='same')(x) # 堆叠残差块组成网络 x = residual_block(x, filters=64) x = residual_block(x, filters=64) x = residual_block(x, filters=64) x = residual_block(x, filters=128, downsample=True) x = residual_block(x, filters=128) x = residual_block(x, filters=128) x = residual_block(x, filters=256, downsample=True) x = residual_block(x, filters=256) x = residual_block(x, filters=256) x = residual_block(x, filters=512, downsample=True) x = residual_block(x, filters=512) x = residual_block(x, filters=512) # 全局平均池化 x = layers.GlobalAveragePooling2D()(x) # 全连接层输出 x = layers.Dense(num_classes, activation='softmax')(x) # 创建模型 model = models.Model(inputs=input_img, outputs=x) return model # 在这里定义输入图像的形状和类别数 input_shape = (224, 224, 3) num_classes = 1000 # 构建ResNet模型 model = ResNet(input_shape, num_classes) model.summary()
请注意,上述代码是一个简化版本的ResNet网络,实际上,ResNet有不同的变体,可以根据任务的复杂性和资源的可用性选择适合的ResNet结构。