参考
神经网络15分钟入门——使用Python从零开始写一个两层神经网络（Mr.看海的博客，CSDN）
# Reference: https://blog.csdn.net/fengzhuqiaoqiu/article/details/102810048
#
# A two-layer neural network (affine -> ReLU -> affine -> softmax) written from
# scratch with NumPy. It is trained on four 2-D points, one per quadrant, and
# then asked for the quadrant (1-4) of four unseen points.
import numpy as np


def affine_forward(x, w, b):
    """Forward pass of a fully connected (affine) layer: ``out = x_row @ w + b``.

    Args:
        x: Input batch. Axis 0 indexes samples and all remaining axes are
            flattened, so ``x`` of shape ``(N, d_1, ..., d_k)`` becomes
            ``x_row`` of shape ``(N, D)`` with ``D = d_1 * ... * d_k``.
        w: Weight matrix of shape ``(D, M)``.
        b: Bias broadcast over the batch, e.g. shape ``(1, M)`` or ``(M,)``.

    Returns:
        ``(out, cache)``: ``out`` has shape ``(N, M)``; ``cache`` is the
        ``(x, w, b)`` tuple consumed by ``affine_backward``.
    """
    x_row = x.reshape(x.shape[0], -1)
    out = np.dot(x_row, w) + b
    return out, (x, w, b)


# The original code spelled this function ``affine_formard``; keep that name
# working for existing callers.
affine_formard = affine_forward


def affine_backward(dout, cache):
    """Backward pass of an affine layer.

    Args:
        dout: Upstream gradient ``dL/dout`` of shape ``(N, M)``.
        cache: The ``(x, w, b)`` tuple returned by ``affine_forward``.

    Returns:
        ``(dx, dw, db)``: ``dx`` has the shape of ``x`` and ``dw`` the shape of
        ``w``. ``db`` has shape ``(M,)`` for a 1-D bias and ``(1, M)``
        otherwise.
    """
    x, w, b = cache
    x_row = x.reshape(x.shape[0], -1)
    dx = np.dot(dout, w.T).reshape(x.shape)
    dw = np.dot(x_row.T, dout)
    # Match a 1-D bias so that an in-place ``b -= lr * db`` does not try to
    # store a (1, M) result into an (M,) array; any other bias keeps the
    # original (1, M) result.
    db = np.sum(dout, axis=0, keepdims=np.ndim(b) != 1)
    return dx, dw, db


def softmax(scores):
    """Row-wise softmax of an ``(N, C)`` score matrix.

    Each row is shifted by its maximum before exponentiating. This leaves the
    result unchanged but keeps ``np.exp`` from overflowing.
    """
    exp_scores = np.exp(scores - np.max(scores, axis=1, keepdims=True))
    return exp_scores / np.sum(exp_scores, axis=1, keepdims=True)


def predict_proba(x, W1, b1, W2, b2):
    """Return the ``(N, C)`` class probabilities of the two-layer network."""
    hidden, _ = affine_forward(x, W1, b1)
    hidden = np.maximum(0, hidden)  # ReLU
    scores, _ = affine_forward(hidden, W2, b2)
    return softmax(scores)


def train_two_layer_net(X, t, hidden_dim=50, reg=0.001, learning_rate=0.001,
                        num_iters=10000, seed=1, num_classes=None,
                        print_every=0):
    """Train the network with full-batch gradient descent.

    The objective is mean softmax cross-entropy plus
    ``0.5 * reg * (|W1|^2 + |W2|^2)``.

    Args:
        X: Training inputs; axis 0 indexes samples.
        t: Integer class labels of shape ``(N,)`` with values in
            ``[0, num_classes)``.
        hidden_dim: Width of the hidden ReLU layer.
        reg: L2 regularization strength applied to ``W1`` and ``W2``.
        learning_rate: Gradient-descent step size.
        num_iters: Number of full-batch update steps.
        seed: Seed for the weight initialization. The weights equal those
            produced by ``np.random.seed(seed)`` followed by
            ``np.random.randn`` calls, but the global NumPy random state is
            left untouched.
        num_classes: Number of output classes. When ``None`` it is inferred
            as ``max(t) + 1``.
        print_every: Print the loss every ``print_every`` iterations; ``0``
            disables printing.

    Returns:
        ``(W1, b1, W2, b2)``, the trained parameters.

    Raises:
        ValueError: If ``X`` is empty or ``X`` and ``t`` differ in length.
    """
    X = np.asarray(X)
    t = np.asarray(t)
    n_samples = X.shape[0]
    if n_samples == 0:
        raise ValueError("X must contain at least one sample")
    if t.shape[0] != n_samples:
        raise ValueError("X and t must have the same number of samples")
    if num_classes is None:
        # The class count comes from the labels, not from the sample count.
        num_classes = int(np.max(t)) + 1
    input_dim = int(np.prod(X.shape[1:]))

    # Same draw order (W1, then W2) as seeding the global generator.
    rng = np.random.RandomState(seed)
    W1 = rng.randn(input_dim, hidden_dim)
    W2 = rng.randn(hidden_dim, num_classes)
    b1 = np.zeros((1, hidden_dim))
    b2 = np.zeros((1, num_classes))

    rows = np.arange(n_samples)
    for it in range(num_iters):
        # Forward pass.
        hidden_pre, fc1_cache = affine_forward(X, W1, b1)
        hidden = np.maximum(0, hidden_pre)
        scores, fc2_cache = affine_forward(hidden, W2, b2)
        probs = softmax(scores)

        if print_every and it % print_every == 0:
            data_loss = -np.sum(np.log(probs[rows, t])) / n_samples
            reg_loss = 0.5 * reg * (np.sum(W1 * W1) + np.sum(W2 * W2))
            print('loss:', data_loss + reg_loss)

        # Backward pass: d(cross-entropy)/d(scores) = (probs - one_hot(t)) / N.
        dscores = probs.copy()
        dscores[rows, t] -= 1
        dscores /= n_samples
        dhidden, dW2, db2 = affine_backward(dscores, fc2_cache)
        dhidden[hidden_pre <= 0] = 0  # ReLU passes gradient only where active
        _, dW1, db1 = affine_backward(dhidden, fc1_cache)
        # Gradient of the L2 penalty 0.5 * reg * |W|^2.
        dW2 += reg * W2
        dW1 += reg * W1

        W1 -= learning_rate * dW1
        b1 -= learning_rate * db1
        W2 -= learning_rate * dW2
        b2 -= learning_rate * db2

    return W1, b1, W2, b2


def main():
    """Train on one point per quadrant, then classify four new points."""
    # Label k marks quadrant k + 1: (+,+), (-,+), (-,-), (+,-).
    X = np.array([[2, 1], [-1, 1], [-1, -1], [1, -1]])
    t = np.array([0, 1, 2, 3])
    W1, b1, W2, b2 = train_two_layer_net(X, t, print_every=1000)

    test_points = np.array([[2, 2], [-2, 2], [-2, -2], [2, -2]])
    probs = predict_proba(test_points, W1, b1, W2, b2)
    print(probs)
    for point, point_probs in zip(test_points, probs):
        # Output reads "<point> lies in quadrant <q>".
        print(point, '所在的象限为', np.argmax(point_probs) + 1)


if __name__ == "__main__":
    main()
标签：Python, cache, 两层, shape, 神经网络, dx, np, probs
来源：https://www.cnblogs.com/ghzhan/p/16753690.html