首页 > 其他分享 >基于keras的残差网络

基于keras的残差网络

时间:2023-03-19 12:13:18浏览次数:44  
标签:________________________________________________________________________________

1 前言

理论上,网络层数越深,拟合效果越好。但是,层数加深也会导致梯度消失或梯度爆炸现象产生。当网络层数已经过深时,深层网络表现为“恒等映射”。实践表明,神经网络对残差的学习比对恒等关系的学习表现更好,因此,残差网络在深层模型中广泛应用。

本文以MNIST手写数字分类为例,为方便读者认识残差网络,网络中只有全连接层,没有卷积层。关于MNIST数据集的说明,见使用TensorFlow实现MNIST数据集分类

笔者工作空间如下:

img

代码资源见-->残差网络(ResNet)案例分析

2 实验

renet.py

from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Model
from keras.layers import add,Input,Dense,Activation

#载入数据
def read_data(path):
    mnist=input_data.read_data_sets(path,one_hot=True)
    train_x,train_y=mnist.train.images,mnist.train.labels,
    valid_x,valid_y=mnist.validation.images,mnist.validation.labels,
    test_x,test_y=mnist.test.images,mnist.test.labels
    return train_x,train_y,valid_x,valid_y,test_x,test_y

#残差块
def ResBlock(x,hidden_size1,hidden_size2):
    r=Dense(hidden_size1,activation='relu')(x)  #第一隐层
    r=Dense(hidden_size2)(r)  #第二隐层
    if x.shape[1]==hidden_size2:
        shortcut=x
    else:
        shortcut=Dense(hidden_size2)(x)  #shortcut(捷径)
    o=add([r,shortcut])
    o=Activation('relu')(o)  #激活函数
    return o

#残差网络
def ResNet(train_x,train_y,valid_x,valid_y,test_x,test_y):
    inputs=Input(shape=(784,))
    x=ResBlock(inputs,30,30)
    x=ResBlock(x,30,30)
    x=ResBlock(x,20,20)
    x=Dense(10,activation='softmax')(x)
    model=Model(input=inputs,output=x)
    #查看网络结构
    model.summary()
    #编译模型
    model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
    #训练模型
    model.fit(train_x,train_y,batch_size=500,nb_epoch=50,verbose=2,validation_data=(valid_x,valid_y))
    #评估模型
    pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2)
    print('test_loss:',pre[0],'- test_acc:',pre[1])
     
train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
ResNet(train_x,train_y,valid_x,valid_y,test_x,test_y)

网络各层输出尺寸:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 784)          0                                            
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 30)           23550       input_1[0][0]                    
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 30)           930         dense_1[0][0]                    
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 30)           23550       input_1[0][0]                    
__________________________________________________________________________________________________
add_1 (Add)                     (None, 30)           0           dense_2[0][0]                    
                                                                 dense_3[0][0]                    
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 30)           0           add_1[0][0]                      
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 30)           930         activation_1[0][0]               
__________________________________________________________________________________________________
dense_5 (Dense)                 (None, 30)           930         dense_4[0][0]                    
__________________________________________________________________________________________________
add_2 (Add)                     (None, 30)           0           dense_5[0][0]                    
                                                                 activation_1[0][0]               
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 30)           0           add_2[0][0]                      
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, 20)           620         activation_2[0][0]               
__________________________________________________________________________________________________
dense_7 (Dense)                 (None, 20)           420         dense_6[0][0]                    
__________________________________________________________________________________________________
dense_8 (Dense)                 (None, 20)           620         activation_2[0][0]               
__________________________________________________________________________________________________
add_3 (Add)                     (None, 20)           0           dense_7[0][0]                    
                                                                 dense_8[0][0]                    
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 20)           0           add_3[0][0]                      
__________________________________________________________________________________________________
dense_9 (Dense)                 (None, 10)           210         activation_3[0][0]               
==================================================================================================
Total params: 51,760
Trainable params: 51,760
Non-trainable params: 0

网络训练结果:

Epoch 48/50
 - 1s - loss: 0.0019 - acc: 0.9999 - val_loss: 0.1463 - val_acc: 0.9706
Epoch 49/50
 - 1s - loss: 0.0016 - acc: 0.9999 - val_loss: 0.1502 - val_acc: 0.9722
Epoch 50/50
 - 1s - loss: 0.0013 - acc: 0.9999 - val_loss: 0.1542 - val_acc: 0.9728
test_loss: 0.16228994959965348 - test_acc: 0.9721000045537949

​ 声明:本文转自基于keras的残差网络

标签:________________________________________________________________________________
From: https://www.cnblogs.com/zhyan8/p/17232719.html

相关文章