
Collaborative Filtering

Posted: 2023-03-24 11:03:47

Data preparation

import numpy as np
import pandas as pd
from sklearn import model_selection as cv
from sklearn.metrics.pairwise import pairwise_distances
from sklearn.metrics import mean_squared_error
from math import sqrt
import matplotlib.pyplot as plt
import torch
# MovieLens 100K (ml-100k); u.data is tab-separated: user id, item id, rating (1-5), timestamp
data_root_path="ml-100k/"
header = ['user_id', 'item_id', 'rating', 'timestamp']
df_rating = pd.read_csv(data_root_path+'u.data', sep='\t', names=header)
print(df_rating.head(5))
print(len(df_rating))
   user_id  item_id  rating  timestamp
0      196      242       3  881250949
1      186      302       3  891717742
2       22      377       1  878887116
3      244       51       2  880606923
4      166      346       1  886397596
100000
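# u.item is '|'-separated and latin-1 encoded; after the drop below only the 19 genre flags (0/1) remain
s = """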
movie id | movie title | release date | video release date |
IMDb URL | unknown | Action | Adventure | Animation |
Children's | Comedy | Crime | Documentary | Drama | Fantasy |
Film-Noir | Horror | Musical | Mystery | Romance | Sci-Fi |
Thriller | War | Western 
""".replace("\n"," ")
df_item = pd.read_csv(data_root_path+'u.item', sep='|', 
                      names=list(map(str.strip, s.split("|"))), 
                      encoding='latin-1').drop(columns=
                                               ["movie id","movie title","release date","video release date","IMDb URL"])
print(df_item.head(5))
print(len(df_item))
df_item.shape[1]
   unknown  Action  Adventure  Animation  Children's  Comedy  Crime  \
0        0       0          0          1           1       1      0   
1        0       1          1          0           0       0      0   
2        0       0          0          0           0       0      0   
3        0       1          0          0           0       1      0   
4        0       0          0          0           0       0      1   

   Documentary  Drama  Fantasy  Film-Noir  Horror  Musical  Mystery  Romance  \
0            0      0        0          0       0        0        0        0   
1            0      0        0          0       0        0        0        0   
2            0      0        0          0       0        0        0        0   
3            0      1        0          0       0        0        0        0   
4            0      1        0          0       0        0        0        0   

   Sci-Fi  Thriller  War  Western  
0       0         0    0        0  
1       0         1    0        0  
2       0         1    0        0  
3       0         0    0        0  
4       0         1    0        0  
1682
19
# u.user is '|'-separated: user id | age | gender | occupation | zip code
s = """
user id | age | gender | occupation | zip code
""".replace("\n"," ")
df_user = pd.read_csv(data_root_path+'u.user', sep='|', 
                      names=list(map(str.strip, s.split("|")))).drop(columns=["user id","zip code"])
dict_occupation = {}
with open(data_root_path+"u.occupation", "r") as f:
    for line in f.readlines():
        line = line.strip("\n").strip()
        dict_occupation[line] = len(dict_occupation)
        # adds one column per occupation, filled with 0; nothing here sets the
        # user's own occupation to 1, so these columns stay all-zero
        df_user.insert(df_user.shape[1], line, 0)
df_user = df_user.drop(columns=["occupation"])
gender_mapping = {
    'F': 1,   
    'M': 0
}
df_user['gender'] = df_user['gender'].map(gender_mapping) 
print(df_user.head(5))
print(len(df_user))
df_user.shape[1]
   age  gender  administrator  artist  doctor  educator  engineer  \
0   24       0              0       0       0         0         0   
1   53       1              0       0       0         0         0   
2   23       0              0       0       0         0         0   
3   24       0              0       0       0         0         0   
4   33       1              0       0       0         0         0   

   entertainment  executive  healthcare  ...  marketing  none  other  \
0              0          0           0  ...          0     0      0   
1              0          0           0  ...          0     0      0   
2              0          0           0  ...          0     0      0   
3              0          0           0  ...          0     0      0   
4              0          0           0  ...          0     0      0   

   programmer  retired  salesman  scientist  student  technician  writer  
0           0        0         0          0        0           0       0  
1           0        0         0          0        0           0       0  
2           0        0         0          0        0           0       0  
3           0        0         0          0        0           0       0  
4           0        0         0          0        0           0       0  

[5 rows x 23 columns]
943
23
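In the loop above, `df_user.insert(..., line, 0)` creates one column per occupation but never sets the user's own occupation to 1. Every occupation feature therefore stays 0 for every user, which matches the rows printed above. The sketch below fills the flags from the raw `occupation` string before that column is dropped. The toy rows and the shortened `occupations` list are hypothetical stand-ins for `u.user` and `u.occupation`:

```python
import pandas as pd

# Toy stand-in for u.user right after reading it (hypothetical rows)
df_user = pd.DataFrame({
    "age": [24, 53, 23],
    "gender": ["M", "F", "M"],
    "occupation": ["technician", "other", "writer"],
})
# Hypothetical subset of the names listed in u.occupation
occupations = ["administrator", "other", "technician", "writer"]

# One 0/1 column per occupation, set from the raw occupation string
for occupation in occupations:
    df_user[occupation] = (df_user["occupation"] == occupation).astype(int)

df_user = df_user.drop(columns=["occupation"])
df_user["gender"] = df_user["gender"].map({"F": 1, "M": 0})
print(df_user)
```

Each user ends up with exactly one occupation flag set, which is the usual one-hot layout.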

Train/validation split

n_users = df_rating.user_id.unique().shape[0]
n_items = df_rating.item_id.unique().shape[0]
print (f"{n_users=}, {n_items=}")

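# random 70/30 split of the ratings; without random_state the split differs on every run
train_data, val_data = cv.train_test_split(df_rating, test_size=0.3)

# dense user x item matrices (ids are 1-based; unrated cells stay 0)
# itertuples yields (Index, user_id, item_id, rating, timestamp)
train_data_matrix = np.zeros((n_users, n_items))
for line in train_data.itertuples():
    train_data_matrix[line[1]-1, line[2]-1] = line[3]

val_data_matrix = np.zeros((n_users, n_items))
for line in val_data.itertuples():
    val_data_matrix[line[1]-1, line[2]-1] = line[3]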

print(f"{train_data_matrix=}\n{val_data_matrix=}")
n_users=943, n_items=1682
train_data_matrix=array([[5., 3., 4., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [5., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 5., 0., ..., 0., 0., 0.]])
val_data_matrix=array([[0., 0., 0., ..., 0., 0., 0.],
       [4., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])
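The two `itertuples` loops can also be written with NumPy fancy indexing. Passing `random_state` to `train_test_split` makes the split reproducible. Below is a small sketch on a toy ratings frame. The data, the `to_matrix` helper and `random_state=0` are illustrative choices, not from the original. The sketch sizes the matrix by the largest id rather than by the number of unique ids. On ML-100K both give 943 × 1682, because the ids there are contiguous:

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Toy ratings with 1-based ids, same columns as df_rating (hypothetical data)
ratings = pd.DataFrame({
    "user_id": [1, 1, 2, 3, 3, 3],
    "item_id": [1, 3, 2, 1, 2, 4],
    "rating":  [5, 3, 4, 2, 1, 5],
})
n_users = ratings["user_id"].max()
n_items = ratings["item_id"].max()

def to_matrix(df, n_users, n_items):
    """Dense user x item matrix; unrated cells stay 0."""
    mat = np.zeros((n_users, n_items))
    mat[df["user_id"].to_numpy() - 1, df["item_id"].to_numpy() - 1] = df["rating"].to_numpy()
    return mat

train_df, val_df = train_test_split(ratings, test_size=0.3, random_state=0)
train_mat = to_matrix(train_df, n_users, n_items)
val_mat = to_matrix(val_df, n_users, n_items)
```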
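# one row per training rating: user features, item genre flags, then the rating
# (the empty frame has object dtype, hence the astype(np.float32) further down)
LR_df_train = pd.DataFrame(columns=list(df_user)+list(df_item)+["rating"], index=range(len(train_data)))
cnt=0
for row in train_data.itertuples():
    LR_df_train.iloc[cnt, :df_user.shape[1]] = df_user.iloc[row[1]-1]      # user_id -> user features
    LR_df_train.iloc[cnt, df_user.shape[1]:-1] = df_item.iloc[row[2]-1]    # item_id -> genre flags
    LR_df_train.iloc[cnt, -1] = row[3]                                     # rating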
    cnt += 1
print(LR_df_train.head(5))
print(len(LR_df_train))
LR_df_train.shape[1]  
  age gender administrator artist doctor educator engineer entertainment  \
0  20      0             0      0      0        0        0             0   
1  32      1             0      0      0        0        0             0   
2  35      0             0      0      0        0        0             0   
3  14      1             0      0      0        0        0             0   
4  39      1             0      0      0        0        0             0   

  executive healthcare  ... Film-Noir Horror Musical Mystery Romance Sci-Fi  \
0         0          0  ...         0      0       0       0       0      0   
1         0          0  ...         0      0       0       0       0      0   
2         0          0  ...         0      0       0       0       0      1   
3         0          0  ...         0      0       0       0       0      0   
4         0          0  ...         0      0       0       0       0      0   

  Thriller War Western rating  
0        0   0       0      4  
1        0   1       0      4  
2        0   0       0      2  
3        0   0       0      1  
4        0   0       0      4  

[5 rows x 43 columns]
70000
43
# same construction for the validation ratings
LR_df_val = pd.DataFrame(columns=list(df_user)+list(df_item)+["rating"], index=range(len(val_data)))
cnt=0
for row in val_data.itertuples():
    LR_df_val.iloc[cnt, :df_user.shape[1]] = df_user.iloc[row[1]-1]
    LR_df_val.iloc[cnt, df_user.shape[1]:-1] = df_item.iloc[row[2]-1]
    LR_df_val.iloc[cnt, -1] = row[3]
    cnt += 1
print(LR_df_val.head(5))
print(len(LR_df_val))
LR_df_val.shape[1]  
  age gender administrator artist doctor educator engineer entertainment  \
0  44      0             0      0      0        0        0             0   
1  27      0             0      0      0        0        0             0   
2  51      0             0      0      0        0        0             0   
3  34      0             0      0      0        0        0             0   
4  30      0             0      0      0        0        0             0   

  executive healthcare  ... Film-Noir Horror Musical Mystery Romance Sci-Fi  \
0         0          0  ...         0      0       0       0       1      0   
1         0          0  ...         0      0       0       0       0      0   
2         0          0  ...         0      0       0       0       0      0   
3         0          0  ...         0      0       0       0       0      0   
4         0          0  ...         0      0       0       0       1      0   

  Thriller War Western rating  
0        0   0       0      3  
1        0   0       0      3  
2        0   0       0      3  
3        0   1       0      4  
4        0   0       0      4  

[5 rows x 43 columns]
30000
43
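Filling `LR_df_train` and `LR_df_val` cell by cell with `iloc` is slow for 70,000 and 30,000 rows. It also leaves the columns with `object` dtype, which is why `astype(np.float32)` is needed later. Below is a vectorized sketch that looks up the user and item rows by id and concatenates them. `build_features` and the toy frames are hypothetical. Like the loops, it assumes that row `id - 1` of `df_user` and `df_item` belongs to that id:

```python
import pandas as pd

# Toy stand-ins for df_user, df_item and a ratings split (hypothetical values)
df_user = pd.DataFrame({"age": [24, 53], "gender": [0, 1]})
df_item = pd.DataFrame({"Action": [0, 1, 1], "Comedy": [1, 0, 0]})
ratings = pd.DataFrame({"user_id": [1, 2, 2], "item_id": [3, 1, 2], "rating": [4, 5, 2]})

def build_features(ratings, df_user, df_item):
    """User features + item genre flags + rating, one row per rating."""
    users = df_user.iloc[ratings["user_id"].to_numpy() - 1].reset_index(drop=True)
    items = df_item.iloc[ratings["item_id"].to_numpy() - 1].reset_index(drop=True)
    target = ratings["rating"].reset_index(drop=True)
    return pd.concat([users, items, target], axis=1)

features = build_features(ratings, df_user, df_item)
print(features)
```

With the real data, `build_features(train_data, df_user, df_item)` and `build_features(val_data, df_user, df_item)` would take the place of the two loops and keep integer dtypes.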
# features as float32; label 1 if the rating is 4 or 5 (> 3), else 0
inputs_val = torch.tensor(LR_df_val.iloc[:, :-1].values.astype(np.float32))
labels_val = torch.tensor((LR_df_val.iloc[:, -1:].values > 3).astype(int))
inputs_val, labels_val
(tensor([[44.,  0.,  0.,  ...,  0.,  0.,  0.],
         [27.,  0.,  0.,  ...,  0.,  0.,  0.],
         [51.,  0.,  0.,  ...,  0.,  0.,  0.],
         ...,
         [29.,  1.,  0.,  ...,  0.,  0.,  0.],
         [31.,  1.,  0.,  ...,  0.,  0.,  0.],
         [19.,  0.,  0.,  ...,  0.,  0.,  0.]]),
 tensor([[0],
         [0],
         [0],
         ...,
         [1],
         [1],
         [0]]))
# same conversion for the training set
inputs = torch.tensor(LR_df_train.iloc[:, :-1].values.astype(np.float32))
labels = torch.tensor((LR_df_train.iloc[:, -1:].values > 3).astype(int))
inputs, labels
(tensor([[20.,  0.,  0.,  ...,  0.,  0.,  0.],
         [32.,  1.,  0.,  ...,  0.,  1.,  0.],
         [35.,  0.,  0.,  ...,  0.,  0.,  0.],
         ...,
         [42.,  0.,  0.,  ...,  0.,  0.,  0.],
         [40.,  0.,  0.,  ...,  0.,  1.,  0.],
         [24.,  0.,  0.,  ...,  0.,  1.,  0.]]),
 tensor([[1],
         [1],
         [0],
         ...,
         [1],
         [0],
         [1]]))
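Two details matter if these tensors later feed a logistic-regression style model. First, `age` takes values in the tens while every other column is 0/1. Second, `torch.nn.BCEWithLogitsLoss` expects float targets rather than the integer labels built above. The sketch below addresses both, using hypothetical toy rows. The scaling statistics come from the training rows only and would be reused for the validation rows:

```python
import numpy as np
import pandas as pd
import torch

# Hypothetical toy rows laid out like LR_df_train: feature columns, then rating
train_df = pd.DataFrame({"age": [20, 32, 35, 14],
                         "gender": [0, 1, 0, 1],
                         "rating": [4, 4, 2, 1]})

x = train_df.iloc[:, :-1].to_numpy(dtype=np.float32, copy=True)
# Standardize age (column 0) with training statistics; reuse age_mean/age_std for validation
age_mean, age_std = x[:, 0].mean(), x[:, 0].std()
x[:, 0] = (x[:, 0] - age_mean) / age_std

inputs = torch.from_numpy(x)
# Ratings 4 and 5 count as "liked"; float targets suit torch.nn.BCEWithLogitsLoss
labels = torch.from_numpy((train_df["rating"].to_numpy() > 3).astype(np.float32)).unsqueeze(1)
```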

Collaborative filtering

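# pairwise_distances(metric='cosine') returns cosine *distance*, 1 - cosine similarity:
# 0 means identical direction and larger means less similar (despite the variable names)
user_similarity = pairwise_distances(train_data_matrix, metric='cosine')
print(f"{user_similarity=}")
item_similarity = pairwise_distances(train_data_matrix.T, metric='cosine')
print(f"{item_similarity=}")
# fraction of exactly-zero entries; the diagonal alone accounts for 1/943 ≈ 0.00106 (users)
# and 1/1682 ≈ 0.00059 (items)
print(f"zeros' size of user_similarity: {np.sum(np.where(user_similarity,0,1))/user_similarity.shape[0]/user_similarity.shape[0]}, zeros' size of item_similarity: {np.sum(np.where(item_similarity,0,1))/item_similarity.shape[0]/item_similarity.shape[0]}")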
user_similarity=array([[0.        , 0.93543162, 0.97379757, ..., 0.93615996, 0.88690173,
        0.75217291],
       [0.93543162, 0.        , 0.88467208, ..., 0.89621907, 0.79223406,
        0.8920644 ],
       [0.97379757, 0.88467208, 0.        , ..., 0.95058479, 0.84574004,
        1.        ],
       ...,
       [0.93615996, 0.89621907, 0.95058479, ..., 0.        , 0.86382552,
        0.94804937],
       [0.88690173, 0.79223406, 0.84574004, ..., 0.86382552, 0.        ,
        0.90007614],
       [0.75217291, 0.8920644 , 1.        , ..., 0.94804937, 0.90007614,
        0.        ]])
item_similarity=array([[0.        , 0.71927952, 0.76663377, ..., 1.        , 1.        ,
        0.94212009],
       [0.71927952, 0.        , 0.77130723, ..., 1.        , 1.        ,
        0.91364895],
       [0.76663377, 0.77130723, 0.        , ..., 1.        , 1.        ,
        0.88141459],
       ...,
       [1.        , 1.        , 1.        , ..., 0.        , 1.        ,
        1.        ],
       [1.        , 1.        , 1.        , ..., 1.        , 0.        ,
        1.        ],
       [0.94212009, 0.91364895, 0.88141459, ..., 1.        , 1.        ,
        0.        ]])
zeros' size of user_similarity: 0.0010604453870625664, zeros' size of item_similarity: 0.0010907970099578526
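`pairwise_distances(..., metric='cosine')` returns the cosine distance 1 − cos θ, which is why the diagonal in the output above is 0. Used directly as a weight in `predict`, it gives the *least* similar neighbours the most influence. The small check below confirms the relationship on a hypothetical 3 × 3 matrix, using `sklearn.metrics.pairwise.cosine_similarity`:

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity, pairwise_distances

# Hypothetical 3 users x 3 items rating matrix (0 = unrated)
R = np.array([[5.0, 3.0, 0.0],
              [4.0, 0.0, 0.0],
              [0.0, 0.0, 5.0]])

dist = pairwise_distances(R, metric="cosine")   # 1 - cosine similarity
sim = 1.0 - dist                                # back to similarity

print(np.round(sim, 3))
```

Applied to this post, that would mean `1 - pairwise_distances(train_data_matrix, metric='cosine')` for users, and the same with `.T` for items.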
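def predict(ratings, similarity, type='user'):
    # ratings: n_users x n_items, 0 = unrated; similarity: square weight matrix
    if type == 'user':
        # mean over all items per user (unrated zeros included)
        mean_user_rating = ratings.mean(axis=1)
        ratings_diff = (ratings - mean_user_rating[:, np.newaxis])
        # user mean + weighted average of all users' mean-centred ratings
        pred = mean_user_rating[:, np.newaxis] + similarity.dot(ratings_diff) / np.array([np.abs(similarity).sum(axis=1)]).T
    elif type == 'item':
        # weighted average of the user's ratings on all items
        pred = ratings.dot(similarity) / np.array([np.abs(similarity).sum(axis=1)])
    else:
        raise ValueError("type must be 'user' or 'item'")
    return pred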
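Written out, `predict` computes the following. Here $r_{u,i}$ is the training rating (0 if unrated), $s$ is whatever matrix is passed as `similarity`, the sums run over all users $v$ and all items $j$, and $\bar r_u$ is `ratings.mean(axis=1)`, i.e. the mean over every item with unrated zeros included:

```latex
\hat r^{\text{user}}_{u,i} = \bar r_u + \frac{\sum_{v} s_{u,v}\,\bigl(r_{v,i} - \bar r_v\bigr)}{\sum_{v} \lvert s_{u,v} \rvert},
\qquad
\hat r^{\text{item}}_{u,i} = \frac{\sum_{j} r_{u,j}\, s_{j,i}}{\sum_{j} \lvert s_{i,j} \rvert},
\qquad
\bar r_u = \frac{1}{n_{\text{items}}} \sum_{i} r_{u,i}
```

Since unrated cells count as 0 in both the sums and the mean, the predictions are pulled towards 0.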


item_prediction = predict(train_data_matrix, item_similarity, type='item')
user_prediction = predict(train_data_matrix, user_similarity, type='user')

print(f"{item_prediction=}\n{user_prediction=}")
item_prediction=array([[0.32971542, 0.33393043, 0.34810928, ..., 0.38845925, 0.38845925,
        0.37806896],
       [0.08937264, 0.10062226, 0.09733019, ..., 0.10232005, 0.10232005,
        0.10252517],
       [0.06015872, 0.06360117, 0.0617485 , ..., 0.06186794, 0.06186794,
        0.06240828],
       ...,
       [0.02943056, 0.03754517, 0.03626258, ..., 0.041047  , 0.041047  ,
        0.04063216],
       [0.12106449, 0.12650717, 0.13417332, ..., 0.13920286, 0.13920286,
        0.13899303],
       [0.19063082, 0.18301932, 0.20538459, ..., 0.23319453, 0.23319453,
        0.22839333]])
user_prediction=array([[ 1.39889347,  0.54958744,  0.41911433, ...,  0.24646228,
         0.24646228,  0.24869892],
       [ 1.17894509,  0.32204947,  0.14457168, ..., -0.05586023,
        -0.05586023, -0.05242187],
       [ 1.1715174 ,  0.27811966,  0.10761687, ..., -0.09674941,
        -0.09674941, -0.09337872],
       ...,
       [ 1.0488235 ,  0.24861171,  0.07774342, ..., -0.11514045,
        -0.11514045, -0.11188236],
       [ 1.20173371,  0.33469911,  0.18547453, ..., -0.01255791,
        -0.01255791, -0.00935207],
       [ 1.24143589,  0.38703177,  0.26525567, ...,  0.08909668,
         0.08909668,  0.0917072 ]])
def rmse(prediction, ground_truth):
    # evaluate only the cells that hold a rating in ground_truth
    prediction = prediction[ground_truth.nonzero()].flatten()
    ground_truth = ground_truth[ground_truth.nonzero()].flatten()
    return sqrt(mean_squared_error(ground_truth, prediction))

print ('User-based CF RMSE: ' + str(rmse(user_prediction, val_data_matrix)))
print ('Item-based CF RMSE: ' + str(rmse(item_prediction, val_data_matrix)))
# of the cells predicted >= 0.5, the share that received a rating in the validation set
print(((user_prediction>=0.5)&(val_data_matrix>=1)).sum()/(user_prediction>=0.5).sum())
print(((item_prediction>=0.5)&(val_data_matrix>=1)).sum()/(item_prediction>=0.5).sum())
# share of all user-item cells holding a validation rating of 2 or more
(val_data_matrix.round()>1).sum()/(val_data_matrix.round()>-1).sum()

User-based CF RMSE: 3.168423228908692
Item-based CF RMSE: 3.4744206898944165
0.0956595342121314
0.06325068410168572
0.017768449669193997
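Both RMSE values above exceed 3 largely because `predict` treats every unrated cell as a rating of 0. The printed predictions therefore stay close to 0, while real ratings range from 1 to 5. A common adjustment, swapped in here for illustration and not the method used in this post, works in two steps. It centres each user on the mean of the items they actually rated. Then, for each item, it normalizes only over neighbours who rated that item. The helper names `user_mean_rated` and `predict_user_rated` are hypothetical, and `similarity` must be an actual similarity (for example 1 minus the cosine distance):

```python
import numpy as np

def user_mean_rated(ratings):
    """Mean over rated (non-zero) cells per user; 0 for users with no ratings."""
    rated = ratings > 0
    counts = rated.sum(axis=1)
    sums = ratings.sum(axis=1)
    return np.divide(sums, counts, out=np.zeros(len(sums)), where=counts > 0)

def predict_user_rated(ratings, similarity):
    """User-based CF that ignores unrated cells when centring and normalizing."""
    rated = (ratings > 0).astype(float)
    mean = user_mean_rated(ratings)
    diff = (ratings - mean[:, np.newaxis]) * rated     # centred ratings, 0 where unrated
    num = similarity @ diff                            # sum_v s[u, v] * (r[v, i] - mean[v])
    den = np.abs(similarity) @ rated                   # sum of |s[u, v]| over v who rated i
    adj = np.divide(num, den, out=np.zeros_like(num), where=den > 0)
    return mean[:, np.newaxis] + adj

# Hypothetical toy data: 3 users x 2 items, uniform similarity of 1
ratings = np.array([[5.0, 0.0],
                    [4.0, 2.0],
                    [0.0, 3.0]])
similarity = np.ones((3, 3))
print(predict_user_rated(ratings, similarity))
```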
plt.matshow(train_data_matrix)
(Figure: `plt.matshow(train_data_matrix)`, a heat map of the 943 × 1682 training matrix in which most cells are 0, i.e. unrated)

Matrix factorization

# full SVD: u is 943 x 943, sigma holds 943 singular values (descending), vt is 1682 x 1682
u, sigma, vt = np.linalg.svd(train_data_matrix)
print(f"{u=}\n {sigma=}\n {vt=}")
u=array([[ 6.20845849e-02,  2.34134567e-02, -9.77908204e-03, ...,
        -3.68163644e-03,  1.06800887e-03,  4.70209130e-04],
       [ 1.53997685e-02, -5.28537197e-02,  5.90709716e-02, ...,
        -9.74037018e-03, -2.28229936e-02,  2.47332434e-03],
       [ 5.31731805e-03, -2.59268701e-02,  2.47057415e-02, ...,
        -2.86294311e-02,  2.76010019e-02, -8.80238070e-02],
       ...,
       [ 7.79490210e-03, -2.86074739e-02,  8.48371679e-03, ...,
        -1.33064865e-02, -1.45027864e-02,  1.71574437e-02],
       [ 2.32876290e-02, -1.60593975e-03,  2.70199790e-02, ...,
        -7.39322202e-03,  3.18988957e-05,  2.86772821e-03],
       [ 4.13498137e-02, -1.36790075e-02, -5.41465134e-02, ...,
        -3.24451593e-03,  2.32779279e-03, -7.67305153e-04]])
 sigma=array([4.50088904e+02, 1.75319098e+02, 1.55991723e+02, 1.18905112e+02,
       1.17018914e+02, 1.10632968e+02, 9.50590654e+01, 9.17421756e+01,
       8.31425482e+01, 8.02585555e+01, 7.75360579e+01, 7.64208076e+01,
       7.43559227e+01, 7.23761204e+01, 7.04598763e+01, 6.98665871e+01,
       6.94229458e+01, 6.91795321e+01, 6.77294130e+01, 6.76353538e+01,
       6.70986044e+01, 6.59545326e+01, 6.56917965e+01, 6.55422990e+01,
       6.54436404e+01, 6.47449260e+01, 6.43431134e+01, 6.41397108e+01,
       6.35313171e+01, 6.31286723e+01, 6.25684928e+01, 6.23685851e+01,
       6.19748120e+01, 6.16454630e+01, 6.14745454e+01, 6.09942815e+01,
       6.07081928e+01, 6.02244286e+01, 6.00009717e+01, 5.99044526e+01,
       5.92998287e+01, 5.92161035e+01, 5.89424964e+01, 5.84965530e+01,
       5.83062238e+01, 5.81778288e+01, 5.77822592e+01, 5.74843397e+01,
       5.72412113e+01, 5.66326792e+01, 5.64579857e+01, 5.62654233e+01,
       5.62000363e+01, 5.58963030e+01, 5.55869535e+01, 5.54434274e+01,
       5.52115028e+01, 5.50604847e+01, 5.48353602e+01, 5.47468144e+01,
       5.44515304e+01, 5.42347023e+01, 5.39846036e+01, 5.36359708e+01,
       5.33782499e+01, 5.31495723e+01, 5.30462282e+01, 5.29157669e+01,
       5.27584905e+01, 5.25977777e+01, 5.20562609e+01, 5.17983997e+01,
       5.15431315e+01, 5.14816445e+01, 5.11735697e+01, 5.09903710e+01,
       5.06287091e+01, 5.05496898e+01, 5.04117965e+01, 5.02030435e+01,
       4.99271692e+01, 4.98199414e+01, 4.97808960e+01, 4.95578964e+01,
       4.93115352e+01, 4.92807327e+01, 4.90046524e+01, 4.89608703e+01,
       4.86577162e+01, 4.85598054e+01, 4.84159730e+01, 4.82433844e+01,
       4.78660226e+01, 4.78386628e+01, 4.76171885e+01, 4.74950248e+01,
       4.74736634e+01, 4.69848064e+01, 4.69457464e+01, 4.67980186e+01,
       4.65912635e+01, 4.63229533e+01, 4.60248006e+01, 4.60138665e+01,
       4.58107515e+01, 4.56791773e+01, 4.54790811e+01, 4.51961689e+01,
       4.50682250e+01, 4.49487099e+01, 4.48919462e+01, 4.47872765e+01,
       4.45138660e+01, 4.43955072e+01, 4.42939736e+01, 4.41062798e+01,
       4.39794046e+01, 4.38101815e+01, 4.37358791e+01, 4.37085736e+01,
       4.33499520e+01, 4.32643805e+01, 4.30495617e+01, 4.28930269e+01,
       4.27975484e+01, 4.27141116e+01, 4.25486635e+01, 4.23259205e+01,
       4.23199383e+01, 4.21207379e+01, 4.20156749e+01, 4.18116576e+01,
       4.15811312e+01, 4.14909696e+01, 4.14226132e+01, 4.13313688e+01,
       4.12063056e+01, 4.10820572e+01, 4.08509161e+01, 4.07907208e+01,
       4.06709815e+01, 4.05781415e+01, 4.04677230e+01, 4.02846783e+01,
       4.00953579e+01, 3.99593973e+01, 3.98453432e+01, 3.96792118e+01,
       3.96173232e+01, 3.94250493e+01, 3.93445987e+01, 3.91838373e+01,
       3.91276333e+01, 3.90519025e+01, 3.88443422e+01, 3.87179345e+01,
       3.86665175e+01, 3.85525787e+01, 3.82060670e+01, 3.80858325e+01,
       3.80321663e+01, 3.78235500e+01, 3.77639516e+01, 3.77333918e+01,
       3.76722453e+01, 3.76009064e+01, 3.74557559e+01, 3.73487384e+01,
       3.73092045e+01, 3.71947687e+01, 3.71041242e+01, 3.70505191e+01,
       3.69133014e+01, 3.67634247e+01, 3.66456278e+01, 3.65290432e+01,
       3.63140052e+01, 3.62222428e+01, 3.61004853e+01, 3.60701926e+01,
       3.59389740e+01, 3.58451833e+01, 3.57948331e+01, 3.55761379e+01,
       3.53417375e+01, 3.53174426e+01, 3.51606462e+01, 3.50649375e+01,
       3.49956128e+01, 3.48096755e+01, 3.46904704e+01, 3.45500440e+01,
       3.45019410e+01, 3.43743639e+01, 3.43105330e+01, 3.41824911e+01,
       3.41134980e+01, 3.39974675e+01, 3.38476677e+01, 3.37077079e+01,
       3.35970750e+01, 3.35166407e+01, 3.34639923e+01, 3.33747894e+01,
       3.31759406e+01, 3.30944887e+01, 3.30522816e+01, 3.29648691e+01,
       3.28281055e+01, 3.26858498e+01, 3.25139999e+01, 3.24700680e+01,
       3.23813503e+01, 3.22545963e+01, 3.21273721e+01, 3.20818845e+01,
       3.20242678e+01, 3.18360381e+01, 3.17702816e+01, 3.16213432e+01,
       3.14466549e+01, 3.13822826e+01, 3.12861607e+01, 3.12187815e+01,
       3.11363609e+01, 3.10395219e+01, 3.09178254e+01, 3.08624171e+01,
       3.07061125e+01, 3.06601635e+01, 3.05900617e+01, 3.05152173e+01,
       3.04168707e+01, 3.03077836e+01, 3.02589980e+01, 3.01651046e+01,
       2.99367954e+01, 2.98662804e+01, 2.98470019e+01, 2.97897803e+01,
       2.96758975e+01, 2.95737027e+01, 2.94663380e+01, 2.93497454e+01,
       2.91986986e+01, 2.91002229e+01, 2.90677317e+01, 2.88894075e+01,
       2.88067976e+01, 2.87567366e+01, 2.87225318e+01, 2.86139336e+01,
       2.85271629e+01, 2.84975863e+01, 2.84301851e+01, 2.83441968e+01,
       2.82601131e+01, 2.82426450e+01, 2.80797562e+01, 2.79210703e+01,
       2.78667582e+01, 2.78005642e+01, 2.76391897e+01, 2.76162324e+01,
       2.74978014e+01, 2.74619372e+01, 2.74184234e+01, 2.72695067e+01,
       2.71944711e+01, 2.71321931e+01, 2.70346004e+01, 2.69288259e+01,
       2.68937191e+01, 2.68227369e+01, 2.66573737e+01, 2.65857889e+01,
       2.65238483e+01, 2.64860534e+01, 2.63989379e+01, 2.63168675e+01,
       2.61059116e+01, 2.60367140e+01, 2.59844561e+01, 2.58884939e+01,
       2.58569273e+01, 2.57329463e+01, 2.57006182e+01, 2.56455214e+01,
       2.55186367e+01, 2.54685410e+01, 2.53813175e+01, 2.53308575e+01,
       2.52304250e+01, 2.51130057e+01, 2.50649146e+01, 2.49981568e+01,
       2.49389513e+01, 2.48608818e+01, 2.48282399e+01, 2.47475820e+01,
       2.46712052e+01, 2.45808024e+01, 2.45109802e+01, 2.44273407e+01,
       2.43644631e+01, 2.42696360e+01, 2.42409324e+01, 2.41175977e+01,
       2.40607032e+01, 2.39347665e+01, 2.38635017e+01, 2.37679574e+01,
       2.37119844e+01, 2.36928412e+01, 2.36280348e+01, 2.34928894e+01,
       2.34241843e+01, 2.33874919e+01, 2.32949613e+01, 2.32699161e+01,
       2.31787634e+01, 2.31319428e+01, 2.30625048e+01, 2.30036585e+01,
       2.29271303e+01, 2.28608291e+01, 2.27350985e+01, 2.27079917e+01,
       2.25823921e+01, 2.25591043e+01, 2.25100794e+01, 2.24369633e+01,
       2.23552947e+01, 2.22585560e+01, 2.22048496e+01, 2.21599115e+01,
       2.20921059e+01, 2.20239285e+01, 2.19405006e+01, 2.18852074e+01,
       2.18184973e+01, 2.17345279e+01, 2.16807828e+01, 2.16345422e+01,
       2.15900095e+01, 2.15320313e+01, 2.14028095e+01, 2.13324031e+01,
       2.12805329e+01, 2.12446047e+01, 2.11502878e+01, 2.10989774e+01,
       2.10426796e+01, 2.08833717e+01, 2.08682493e+01, 2.08111720e+01,
       2.06974225e+01, 2.06393711e+01, 2.05687419e+01, 2.04976828e+01,
       2.04613525e+01, 2.04461121e+01, 2.03617468e+01, 2.03302413e+01,
       2.02307252e+01, 2.01445445e+01, 2.01103441e+01, 2.00390183e+01,
       1.99616037e+01, 1.99217052e+01, 1.98900933e+01, 1.98050939e+01,
       1.97698483e+01, 1.97209281e+01, 1.96735303e+01, 1.96099809e+01,
       1.95192257e+01, 1.93945316e+01, 1.93893126e+01, 1.93097987e+01,
       1.92435843e+01, 1.92174067e+01, 1.91685938e+01, 1.90699119e+01,
       1.90558651e+01, 1.90117000e+01, 1.89784229e+01, 1.89072391e+01,
       1.88644249e+01, 1.88192075e+01, 1.87054811e+01, 1.86601567e+01,
       1.86360077e+01, 1.85037826e+01, 1.84834148e+01, 1.83288336e+01,
       1.82836562e+01, 1.82595084e+01, 1.81880737e+01, 1.81634530e+01,
       1.81171864e+01, 1.80341462e+01, 1.80194086e+01, 1.79335546e+01,
       1.78438208e+01, 1.78073597e+01, 1.77585958e+01, 1.77069111e+01,
       1.76306807e+01, 1.76105342e+01, 1.75497491e+01, 1.74637563e+01,
       1.74294247e+01, 1.73929066e+01, 1.73237673e+01, 1.72467276e+01,
       1.72006218e+01, 1.71476550e+01, 1.71009765e+01, 1.70374182e+01,
       1.69793755e+01, 1.69418545e+01, 1.68672210e+01, 1.67846576e+01,
       1.67569720e+01, 1.67273250e+01, 1.66433246e+01, 1.65885056e+01,
       1.65502119e+01, 1.65021081e+01, 1.64536178e+01, 1.63739249e+01,
       1.63188721e+01, 1.63092099e+01, 1.61712407e+01, 1.61477304e+01,
       1.61128166e+01, 1.60736905e+01, 1.60672414e+01, 1.60045678e+01,
       1.59704543e+01, 1.58943658e+01, 1.58617638e+01, 1.58069728e+01,
       1.57347047e+01, 1.56890951e+01, 1.56129100e+01, 1.55667133e+01,
       1.55298052e+01, 1.54562226e+01, 1.53785187e+01, 1.53661261e+01,
       1.53186389e+01, 1.52126260e+01, 1.51576920e+01, 1.51174220e+01,
       1.50219491e+01, 1.49976944e+01, 1.49599921e+01, 1.49543298e+01,
       1.48992515e+01, 1.48473193e+01, 1.47990390e+01, 1.47487982e+01,
       1.47270139e+01, 1.46613118e+01, 1.46252571e+01, 1.45700022e+01,
       1.45311514e+01, 1.45111056e+01, 1.44646922e+01, 1.43765858e+01,
       1.43511058e+01, 1.42720085e+01, 1.42403811e+01, 1.41812174e+01,
       1.41051625e+01, 1.40851017e+01, 1.40392498e+01, 1.39518363e+01,
       1.39161551e+01, 1.39058382e+01, 1.38272740e+01, 1.38182187e+01,
       1.37400354e+01, 1.37205621e+01, 1.36807336e+01, 1.36357070e+01,
       1.35866201e+01, 1.35455495e+01, 1.35205043e+01, 1.34631061e+01,
       1.34187795e+01, 1.33148427e+01, 1.33016628e+01, 1.32803901e+01,
       1.32765383e+01, 1.31537534e+01, 1.31360239e+01, 1.30867260e+01,
       1.30089025e+01, 1.30026527e+01, 1.29093980e+01, 1.28936006e+01,
       1.28249732e+01, 1.27995175e+01, 1.27496081e+01, 1.27375223e+01,
       1.26564684e+01, 1.26130591e+01, 1.25585343e+01, 1.25268029e+01,
       1.24738761e+01, 1.24091201e+01, 1.23412526e+01, 1.23183088e+01,
       1.22901085e+01, 1.22478261e+01, 1.22269278e+01, 1.21489042e+01,
       1.21338047e+01, 1.20859491e+01, 1.20653525e+01, 1.20380310e+01,
       1.19726400e+01, 1.19603112e+01, 1.18928156e+01, 1.18223776e+01,
       1.17980808e+01, 1.17056660e+01, 1.16903942e+01, 1.16523980e+01,
       1.16232965e+01, 1.16167851e+01, 1.15540804e+01, 1.15339094e+01,
       1.14824189e+01, 1.14074860e+01, 1.13864956e+01, 1.13705901e+01,
       1.13528108e+01, 1.12632680e+01, 1.12331620e+01, 1.11626495e+01,
       1.11507031e+01, 1.10997091e+01, 1.10492709e+01, 1.09972478e+01,
       1.09795853e+01, 1.09268573e+01, 1.08899604e+01, 1.08603348e+01,
       1.08126326e+01, 1.07533711e+01, 1.07098446e+01, 1.06939280e+01,
       1.06650111e+01, 1.06513552e+01, 1.06302706e+01, 1.05911766e+01,
       1.05217830e+01, 1.04921606e+01, 1.04756299e+01, 1.04392523e+01,
       1.03894112e+01, 1.03289886e+01, 1.02675208e+01, 1.02509836e+01,
       1.02213535e+01, 1.01788944e+01, 1.01520439e+01, 1.01420985e+01,
       1.01260369e+01, 1.00597835e+01, 1.00302594e+01, 1.00058877e+01,
       9.97287733e+00, 9.89736753e+00, 9.86620212e+00, 9.82604310e+00,
       9.79513428e+00, 9.79004609e+00, 9.75380137e+00, 9.65656148e+00,
       9.62711637e+00, 9.54330860e+00, 9.53781757e+00, 9.53127212e+00,
       9.49855759e+00, 9.44893558e+00, 9.43139650e+00, 9.41571269e+00,
       9.35941775e+00, 9.34968211e+00, 9.32005683e+00, 9.28608785e+00,
       9.27110873e+00, 9.24895643e+00, 9.14535224e+00, 9.10113658e+00,
       9.06727224e+00, 9.06421569e+00, 9.02604900e+00, 8.99572559e+00,
       8.97107022e+00, 8.94266365e+00, 8.91152090e+00, 8.89747602e+00,
       8.82630008e+00, 8.79529889e+00, 8.74716372e+00, 8.73390729e+00,
       8.67177280e+00, 8.64956462e+00, 8.62401078e+00, 8.61873430e+00,
       8.55979959e+00, 8.54483172e+00, 8.50510800e+00, 8.44646903e+00,
       8.41704632e+00, 8.40343907e+00, 8.35453776e+00, 8.33071000e+00,
       8.28116502e+00, 8.24862278e+00, 8.19963553e+00, 8.15629245e+00,
       8.13673969e+00, 8.09671099e+00, 8.08347097e+00, 8.06089938e+00,
       8.03204367e+00, 7.99914153e+00, 7.97825006e+00, 7.94608110e+00,
       7.87773586e+00, 7.86474583e+00, 7.84634310e+00, 7.80178288e+00,
       7.76260588e+00, 7.73366109e+00, 7.72146710e+00, 7.68383541e+00,
       7.64182379e+00, 7.62209135e+00, 7.60262978e+00, 7.57160655e+00,
       7.53985666e+00, 7.49501862e+00, 7.45718787e+00, 7.43211636e+00,
       7.39747021e+00, 7.38443238e+00, 7.35832775e+00, 7.31956180e+00,
       7.28473891e+00, 7.25848785e+00, 7.20979988e+00, 7.19350294e+00,
       7.18088430e+00, 7.13653835e+00, 7.09064891e+00, 7.07743674e+00,
       7.01892945e+00, 6.99554770e+00, 6.96303979e+00, 6.94089652e+00,
       6.92071929e+00, 6.88259940e+00, 6.86084879e+00, 6.81292245e+00,
       6.78337815e+00, 6.76837148e+00, 6.74458813e+00, 6.72352585e+00,
       6.68502222e+00, 6.65465564e+00, 6.64484143e+00, 6.59670357e+00,
       6.57425393e+00, 6.54862307e+00, 6.52602322e+00, 6.47231255e+00,
       6.44915238e+00, 6.43412570e+00, 6.43036738e+00, 6.40000212e+00,
       6.37828070e+00, 6.32095549e+00, 6.30396482e+00, 6.26106716e+00,
       6.23341764e+00, 6.19898404e+00, 6.18040406e+00, 6.16043183e+00,
       6.15089038e+00, 6.11657077e+00, 6.08251416e+00, 6.05169042e+00,
       6.03378723e+00, 5.99226681e+00, 5.97831296e+00, 5.93798820e+00,
       5.89507442e+00, 5.87614101e+00, 5.86026917e+00, 5.84624471e+00,
       5.82262366e+00, 5.80045815e+00, 5.78140647e+00, 5.76870878e+00,
       5.76029194e+00, 5.72789566e+00, 5.67527936e+00, 5.66496439e+00,
       5.63465543e+00, 5.60388115e+00, 5.58553073e+00, 5.56862015e+00,
       5.54177116e+00, 5.53610330e+00, 5.48356524e+00, 5.44001409e+00,
       5.40483758e+00, 5.38884255e+00, 5.34506222e+00, 5.32755430e+00,
       5.29797311e+00, 5.29683980e+00, 5.26130858e+00, 5.23716600e+00,
       5.20447993e+00, 5.19410027e+00, 5.15074988e+00, 5.11558950e+00,
       5.09907066e+00, 5.09037944e+00, 5.07330763e+00, 5.03521135e+00,
       5.01885008e+00, 4.99866550e+00, 4.96290441e+00, 4.93711237e+00,
       4.92257499e+00, 4.88249351e+00, 4.85645363e+00, 4.82604149e+00,
       4.81806490e+00, 4.78033912e+00, 4.77375855e+00, 4.71854910e+00,
       4.67062473e+00, 4.66872759e+00, 4.65494499e+00, 4.62771106e+00,
       4.60395053e+00, 4.57556267e+00, 4.53023422e+00, 4.52727225e+00,
       4.47536283e+00, 4.47221888e+00, 4.45675604e+00, 4.41140910e+00,
       4.39994965e+00, 4.36479738e+00, 4.35677012e+00, 4.33845613e+00,
       4.31125821e+00, 4.28504731e+00, 4.27387806e+00, 4.24045368e+00,
       4.22952711e+00, 4.20890781e+00, 4.18976548e+00, 4.17874231e+00,
       4.13051698e+00, 4.12019489e+00, 4.10224535e+00, 4.08333597e+00,
       4.04514933e+00, 4.01379839e+00, 3.99661888e+00, 3.98924936e+00,
       3.98061983e+00, 3.95602868e+00, 3.92072292e+00, 3.89257590e+00,
       3.85796240e+00, 3.83346550e+00, 3.79832770e+00, 3.79103769e+00,
       3.76252530e+00, 3.75869649e+00, 3.71631153e+00, 3.69536188e+00,
       3.67532119e+00, 3.63114038e+00, 3.62502571e+00, 3.62003830e+00,
       3.55982634e+00, 3.54588293e+00, 3.51567371e+00, 3.49721631e+00,
       3.49006434e+00, 3.46789050e+00, 3.46284668e+00, 3.44123759e+00,
       3.41115985e+00, 3.39209466e+00, 3.37965512e+00, 3.37367582e+00,
       3.35383213e+00, 3.29283399e+00, 3.27773192e+00, 3.26722155e+00,
       3.23281990e+00, 3.22429938e+00, 3.21892201e+00, 3.20102161e+00,
       3.16915459e+00, 3.15184375e+00, 3.10872810e+00, 3.09746257e+00,
       3.08162368e+00, 3.03576981e+00, 3.00332948e+00, 2.97837463e+00,
       2.96789003e+00, 2.93391787e+00, 2.90566957e+00, 2.89087203e+00,
       2.87437326e+00, 2.85948131e+00, 2.84818583e+00, 2.82713058e+00,
       2.80685114e+00, 2.79430970e+00, 2.77224680e+00, 2.75144853e+00,
       2.73377086e+00, 2.70881928e+00, 2.68497127e+00, 2.66440658e+00,
       2.65149870e+00, 2.63844556e+00, 2.61391745e+00, 2.58476009e+00,
       2.56964934e+00, 2.54489473e+00, 2.50787566e+00, 2.49320436e+00,
       2.48432500e+00, 2.44968785e+00, 2.43822049e+00, 2.39888674e+00,
       2.36893886e+00, 2.35554500e+00, 2.34427064e+00, 2.32142929e+00,
       2.30935255e+00, 2.28953240e+00, 2.28227215e+00, 2.24826417e+00,
       2.23538400e+00, 2.22248456e+00, 2.17180798e+00, 2.16037839e+00,
       2.13572761e+00, 2.11623227e+00, 2.10352662e+00, 2.09782798e+00,
       2.06269274e+00, 2.03640280e+00, 2.01370690e+00, 1.99489012e+00,
       1.98537561e+00, 1.97499654e+00, 1.93824725e+00, 1.91005873e+00,
       1.88688234e+00, 1.88103489e+00, 1.86453558e+00, 1.84181250e+00,
       1.82471627e+00, 1.82069345e+00, 1.80537776e+00, 1.79178662e+00,
       1.74480403e+00, 1.73422850e+00, 1.69107103e+00, 1.66475083e+00,
       1.64942221e+00, 1.64269331e+00, 1.62235392e+00, 1.58583719e+00,
       1.57218381e+00, 1.55091693e+00, 1.53545052e+00, 1.49521129e+00,
       1.48733674e+00, 1.46236235e+00, 1.42538904e+00, 1.40208236e+00,
       1.39290958e+00, 1.39041452e+00, 1.35352897e+00, 1.35099038e+00,
       1.31790298e+00, 1.29864433e+00, 1.29101331e+00, 1.28865146e+00,
       1.25941562e+00, 1.20378139e+00, 1.18373786e+00, 1.17568949e+00,
       1.13667585e+00, 1.10766687e+00, 1.10025589e+00, 1.06795801e+00,
       1.03597850e+00, 1.01010943e+00, 9.90579946e-01, 9.71100550e-01,
       9.33405335e-01, 9.26500140e-01, 8.96165577e-01, 8.83358395e-01,
       8.62402825e-01, 8.24326812e-01, 7.96071703e-01, 7.82089259e-01,
       7.39250116e-01, 7.27254648e-01, 6.91998736e-01, 6.59346290e-01,
       6.10454174e-01, 5.76874372e-01, 5.59182558e-01, 5.42232507e-01,
       4.82851864e-01, 4.56223791e-01, 4.42875457e-01])
 vt=array([[ 9.03892016e-02,  4.08593223e-02,  2.06275219e-02, ...,
         0.00000000e+00,  0.00000000e+00,  4.40259847e-04],
       [-9.62969448e-02, -5.50041174e-04, -2.46633864e-02, ...,
         0.00000000e+00,  0.00000000e+00,  3.52288265e-04],
       [-3.18876919e-02, -6.99778482e-02, -1.80100040e-02, ...,
         0.00000000e+00,  0.00000000e+00, -4.53173168e-04],
       ...,
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
         1.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
         0.00000000e+00,  1.00000000e+00,  0.00000000e+00],
       [-1.52300162e-03, -9.63514164e-04,  1.08947650e-03, ...,
         0.00000000e+00,  0.00000000e+00,  9.66657652e-01]])
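The singular values above decay slowly, so the rank k has to be chosen somehow. The notebook picks k by validation RMSE. A common alternative heuristic keeps the smallest k whose singular values retain a given share of the total energy sum(sigma**2). This is sketched below on toy numbers; the function name `smallest_rank_for_energy` and the 90% threshold are illustrative assumptions, not the author's method.

```python
import numpy as np

def smallest_rank_for_energy(sigma, threshold=0.9):
    """Smallest k such that the top-k singular values keep `threshold` of sum(sigma**2)."""
    s = np.sort(np.asarray(sigma, dtype=float))[::-1]   # largest first
    energy = np.cumsum(s ** 2) / np.sum(s ** 2)          # cumulative share of energy
    k = int(np.searchsorted(energy, threshold)) + 1      # first position reaching threshold, as a count
    return min(k, len(s))                                # guard against round-off at threshold=1.0

# hypothetical singular values, not taken from the MovieLens run
toy_sigma = np.array([4.0, 2.0, 1.0, 0.5])
print(smallest_rank_for_energy(toy_sigma, 0.9))  # -> 2
```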
def predict_mat(k=2):
    # rank-k reconstruction U_k diag(sigma_k) V_k^T from the SVD (u, sigma, vt) of the training matrix
    return u[:, :k] @ np.diag(sigma[:k]) @ vt[:k, :]
predict_mat()
array([[ 2.13051592e+00,  1.13949803e+00,  4.75167952e-01, ...,
         0.00000000e+00,  0.00000000e+00,  1.37485195e-02],
       [ 1.51882465e+00,  2.88303616e-01,  3.71512329e-01, ...,
         0.00000000e+00,  0.00000000e+00, -2.12839285e-04],
       [ 6.54040790e-01,  1.00287420e-01,  1.61473962e-01, ...,
         0.00000000e+00,  0.00000000e+00, -5.47658805e-04],
       ...,
       [ 8.00092592e-01,  1.46109500e-01,  1.96067225e-01, ...,
         0.00000000e+00,  0.00000000e+00, -2.22272240e-04],
       [ 9.74527315e-01,  4.28421992e-01,  2.23151465e-01, ...,
         0.00000000e+00,  0.00000000e+00,  4.51539766e-03],
       [ 1.91318027e+00,  7.61755724e-01,  4.43048231e-01, ...,
         0.00000000e+00,  0.00000000e+00,  7.34886203e-03]])
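`predict_mat(k)` keeps only the k largest singular triplets. By the Eckart-Young theorem this is the best rank-k approximation in Frobenius norm. Its error equals the square root of the sum of the discarded squared singular values. The toy check below illustrates both facts with `np.linalg.svd` on a random 6x5 matrix; the helper name `rank_k_approx` is hypothetical.

```python
import numpy as np

def rank_k_approx(matrix, k):
    """Rank-k reconstruction U_k diag(sigma_k) V_k^T, mirroring predict_mat."""
    u, sigma, vt = np.linalg.svd(matrix, full_matrices=False)  # sigma is sorted in descending order
    return u[:, :k] @ np.diag(sigma[:k]) @ vt[:k, :], sigma

rng = np.random.default_rng(0)
toy = rng.integers(0, 6, size=(6, 5)).astype(float)  # hypothetical 6 users x 5 items, 0 = unrated

approx2, s = rank_k_approx(toy, 2)
err = np.linalg.norm(toy - approx2)            # Frobenius norm of the residual
print(err, np.sqrt(np.sum(s[2:] ** 2)))        # the two values agree up to round-off

full, _ = rank_k_approx(toy, min(toy.shape))   # keeping every singular value recovers the matrix
print(np.allclose(full, toy))                  # True
```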
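# validation RMSE of the rank-k reconstruction for k = 1 .. (number of singular values - 1)
ks = range(1, min(u.shape[1], vt.shape[0]))
rmses = [rmse(predict_mat(k), val_data_matrix) for k in ks]
plt.plot(ks, rmses)
# hard-coded reference RMSE levels
plt.axhline(3.1664352347602613)
plt.axhline(3.4720042185951625)
# RMSE of the chosen rank k = 9, highlighted in orange-red
plt.axhline(rmse(predict_mat(9), val_data_matrix), c="#ff4400")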
(Figure: validation RMSE against rank k, with horizontal reference lines at the two hard-coded RMSE values and at the rank-9 RMSE.)
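The `rmse` helper is defined earlier in the notebook and is not shown in this section. The sketch below assumes it scores only the observed (non-zero) cells, since 0 marks an unrated movie in these matrices. It also shows one way to read the best k off the curve. If the curve is plotted without an explicit x-axis, `rmses[0]` corresponds to k = 1, so the best rank is the argmin index plus one. `masked_rmse` and `best_rank` are hypothetical names.

```python
import numpy as np

def masked_rmse(prediction, ground_truth):
    """RMSE over rated cells only (ground_truth != 0)."""
    mask = ground_truth != 0
    diff = prediction[mask] - ground_truth[mask]
    return float(np.sqrt(np.mean(diff ** 2)))

def best_rank(train, val, ranks):
    """Evaluate the rank-k SVD of `train` on `val` for each k and return the best k."""
    u, sigma, vt = np.linalg.svd(train, full_matrices=False)
    scores = [masked_rmse(u[:, :k] @ np.diag(sigma[:k]) @ vt[:k, :], val) for k in ranks]
    return ranks[int(np.argmin(scores))], scores

# usage on hypothetical toy matrices
rng = np.random.default_rng(1)
train = rng.integers(0, 6, size=(8, 6)).astype(float)
val = rng.integers(0, 6, size=(8, 6)).astype(float)
k, scores = best_rank(train, val, list(range(1, 6)))
print(k, [round(x, 3) for x in scores])
```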


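pre_mat = predict_mat(9)
# itertuples() yields (Index, user_id, item_id, ...) assuming val_data keeps df_rating's column order,
# so row[1] and row[2] are the 1-based user and item ids;
# a pair is predicted "liked" (1) when its reconstructed rating exceeds 3
mat_ans = torch.tensor([[1 if pre_mat[row[1]-1, row[2]-1] > 3 else 0 for row in val_data.itertuples()]]).T
# hits: pairs predicted liked that are also labelled liked
correct = (mat_ans & labels_val).squeeze().sum().numpy()
print(f"Val Pre {correct/mat_ans.sum():.2%}")    # precision = hits / predicted positives
print(f"Val Acc {correct/labels_val.sum():.2%}")  # hits / actual positives: this is recall, not accuracy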
Val Pre 86.64%
Val Acc 4.55%
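The second print divides the hits by the number of pairs that are actually labelled positive, so it measures recall rather than accuracy. High precision (86.64%) with very low recall (4.55%) means the rank-9 reconstruction rarely exceeds the threshold of 3. This is likely because unrated cells are stored as 0 and pull the low-rank values down. Below is a NumPy sketch of both metrics, with an explicit guard for the case where nothing is predicted positive. `precision_recall` is a hypothetical helper; CPU tensors such as `mat_ans` and `labels_val` can be passed to it directly.

```python
import numpy as np

def precision_recall(pred, truth):
    """Precision and recall of 0/1 predictions; nan instead of dividing by zero."""
    pred = np.asarray(pred).astype(bool).ravel()
    truth = np.asarray(truth).astype(bool).ravel()
    hits = int(np.sum(pred & truth))                       # predicted liked and actually liked
    n_pred, n_true = int(pred.sum()), int(truth.sum())
    precision = hits / n_pred if n_pred else float("nan")  # hits / predicted positives
    recall = hits / n_true if n_true else float("nan")     # hits / actual positives
    return precision, recall

print(precision_recall([1, 1, 0, 0], [1, 0, 1, 0]))  # (0.5, 0.5)
print(precision_recall([0, 0, 0], [1, 0, 1]))        # (nan, 0.0)
```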
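def fun(i):
    """Validation precision of the rank-i reconstruction (same thresholding as above)."""
    pre_mat = predict_mat(i)
    mat_ans = torch.tensor([[1 if pre_mat[row[1]-1, row[2]-1] > 3 else 0 for row in val_data.itertuples()]]).T
    correct = (mat_ans & labels_val).squeeze().sum().numpy()
    # if no pair exceeds 3, this is 0/0 and yields nan (the RuntimeWarning below)
    return correct/mat_ans.sum()
ans_fun = [fun(i) for i in range(1, min(u.shape[1], vt.shape[0]))]
plt.plot(ans_fun)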
/home/ran/.conda/envs/pytorch/lib/python3.10/site-packages/torch/_tensor.py:838: RuntimeWarning: invalid value encountered in multiply
  return self.reciprocal() * other

(Figure: validation precision against rank k of the SVD reconstruction.)
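Building `mat_ans` with a Python loop over `itertuples()` is repeated for every k. NumPy fancy indexing can fetch all (user, item) cells at once instead. The sketch assumes the validation frame keeps the `user_id`/`item_id` column names of `df_rating`. `threshold_predictions` is a hypothetical helper, and its result can be wrapped with `torch.from_numpy` to match `labels_val`.

```python
import numpy as np
import pandas as pd

def threshold_predictions(pred_mat, pairs, threshold=3.0):
    """0/1 column vector: 1 where pred_mat[user-1, item-1] > threshold, for every row of `pairs`."""
    users = pairs["user_id"].to_numpy() - 1   # MovieLens ids are 1-based
    items = pairs["item_id"].to_numpy() - 1
    return (pred_mat[users, items] > threshold).astype(np.int64).reshape(-1, 1)

# hypothetical 2 users x 2 items prediction matrix and three validation pairs
pred = np.array([[4.0, 1.0],
                 [2.0, 3.5]])
pairs = pd.DataFrame({"user_id": [1, 2, 2], "item_id": [1, 1, 2]})
print(threshold_predictions(pred, pairs).ravel())  # [1 0 1]
```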

Logistic Regression

import torch.optim as optim

train_bs = 16   # batch sizes are defined but unused: training below is full-batch
valid_bs = 16
lr_init = 0.1
max_epoch = 200

# logistic regression: one linear layer over the 42 input features, followed by a sigmoid
model = torch.nn.Sequential(torch.nn.Linear(42, 1), torch.nn.Sigmoid())
model[0].weight.data.normal_(0, 0.01)
model[0].bias.data.fill_(0)
criterion = torch.nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=lr_init)
# multiply the learning rate by 0.1 every 50 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

for epoch in range(max_epoch):
    optimizer.zero_grad()
    outputs = model(inputs)  # inputs: (N, 42) float tensor; labels: (N, 1) tensor of 0/1
    loss = criterion(outputs.reshape(-1), labels.reshape(-1).float())
    loss.backward()
    optimizer.step()
    scheduler.step()  # update the learning rate

    # training statistics, computed from this epoch's forward pass (before the update)
    predicted = torch.round(outputs.data)
    total = labels.size(0)
    correct = (predicted == labels).squeeze().sum().numpy()
    loss_sigma = loss.item()

    # print the full-batch loss and accuracy once per epoch
    print("Training: Epoch[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch + 1, max_epoch, loss_sigma, correct / total))
Training: Epoch[001/200] Loss: 0.7297 Acc:44.73%
Training: Epoch[002/200] Loss: 1.5042 Acc:55.27%
Training: Epoch[003/200] Loss: 0.9153 Acc:55.27%
Training: Epoch[004/200] Loss: 0.7998 Acc:44.73%
Training: Epoch[005/200] Loss: 1.0716 Acc:44.73%
Training: Epoch[006/200] Loss: 0.8786 Acc:44.73%
Training: Epoch[007/200] Loss: 0.6802 Acc:55.77%
Training: Epoch[008/200] Loss: 0.8075 Acc:55.27%
Training: Epoch[009/200] Loss: 0.9025 Acc:55.27%
Training: Epoch[010/200] Loss: 0.8257 Acc:55.27%
Training: Epoch[011/200] Loss: 0.6963 Acc:55.42%
Training: Epoch[012/200] Loss: 0.7065 Acc:49.33%
Training: Epoch[013/200] Loss: 0.8026 Acc:44.91%
Training: Epoch[014/200] Loss: 0.7927 Acc:45.11%
Training: Epoch[015/200] Loss: 0.7041 Acc:51.06%
Training: Epoch[016/200] Loss: 0.6814 Acc:56.13%
Training: Epoch[017/200] Loss: 0.7365 Acc:55.41%
Training: Epoch[018/200] Loss: 0.7615 Acc:55.35%
Training: Epoch[019/200] Loss: 0.7230 Acc:55.51%
Training: Epoch[020/200] Loss: 0.6787 Acc:56.61%
Training: Epoch[021/200] Loss: 0.6927 Acc:54.94%
Training: Epoch[022/200] Loss: 0.7285 Acc:49.12%
Training: Epoch[023/200] Loss: 0.7192 Acc:50.26%
Training: Epoch[024/200] Loss: 0.6837 Acc:56.28%
Training: Epoch[025/200] Loss: 0.6802 Acc:56.88%
Training: Epoch[026/200] Loss: 0.7043 Acc:55.85%
Training: Epoch[027/200] Loss: 0.7099 Acc:55.81%
Training: Epoch[028/200] Loss: 0.6890 Acc:56.39%
Training: Epoch[029/200] Loss: 0.6760 Acc:56.84%
Training: Epoch[030/200] Loss: 0.6890 Acc:55.51%
Training: Epoch[031/200] Loss: 0.6994 Acc:53.25%
Training: Epoch[032/200] Loss: 0.6875 Acc:55.74%
Training: Epoch[033/200] Loss: 0.6759 Acc:56.76%
Training: Epoch[034/200] Loss: 0.6828 Acc:56.82%
Training: Epoch[035/200] Loss: 0.6912 Acc:56.21%
Training: Epoch[036/200] Loss: 0.6851 Acc:56.40%
Training: Epoch[037/200] Loss: 0.6760 Acc:57.69%
Training: Epoch[038/200] Loss: 0.6792 Acc:56.98%
Training: Epoch[039/200] Loss: 0.6858 Acc:55.98%
Training: Epoch[040/200] Loss: 0.6819 Acc:56.45%
Training: Epoch[041/200] Loss: 0.6756 Acc:56.92%
Training: Epoch[042/200] Loss: 0.6781 Acc:57.35%
Training: Epoch[043/200] Loss: 0.6823 Acc:56.50%
Training: Epoch[044/200] Loss: 0.6794 Acc:57.10%
Training: Epoch[045/200] Loss: 0.6753 Acc:57.63%
Training: Epoch[046/200] Loss: 0.6775 Acc:56.98%
Training: Epoch[047/200] Loss: 0.6801 Acc:56.69%
Training: Epoch[048/200] Loss: 0.6773 Acc:56.92%
Training: Epoch[049/200] Loss: 0.6752 Acc:57.62%
Training: Epoch[050/200] Loss: 0.6774 Acc:57.33%
Training: Epoch[051/200] Loss: 0.6784 Acc:57.11%
Training: Epoch[052/200] Loss: 0.6781 Acc:57.16%
Training: Epoch[053/200] Loss: 0.6775 Acc:57.23%
Training: Epoch[054/200] Loss: 0.6767 Acc:57.35%
Training: Epoch[055/200] Loss: 0.6760 Acc:57.55%
Training: Epoch[056/200] Loss: 0.6755 Acc:57.75%
Training: Epoch[057/200] Loss: 0.6752 Acc:57.69%
Training: Epoch[058/200] Loss: 0.6752 Acc:57.54%
Training: Epoch[059/200] Loss: 0.6753 Acc:57.34%
Training: Epoch[060/200] Loss: 0.6756 Acc:57.05%
Training: Epoch[061/200] Loss: 0.6759 Acc:56.84%
Training: Epoch[062/200] Loss: 0.6761 Acc:56.79%
Training: Epoch[063/200] Loss: 0.6761 Acc:56.79%
Training: Epoch[064/200] Loss: 0.6760 Acc:56.80%
Training: Epoch[065/200] Loss: 0.6758 Acc:56.92%
Training: Epoch[066/200] Loss: 0.6756 Acc:57.07%
Training: Epoch[067/200] Loss: 0.6754 Acc:57.28%
Training: Epoch[068/200] Loss: 0.6752 Acc:57.47%
Training: Epoch[069/200] Loss: 0.6752 Acc:57.57%
Training: Epoch[070/200] Loss: 0.6752 Acc:57.69%
Training: Epoch[071/200] Loss: 0.6752 Acc:57.71%
Training: Epoch[072/200] Loss: 0.6753 Acc:57.84%
Training: Epoch[073/200] Loss: 0.6754 Acc:57.78%
Training: Epoch[074/200] Loss: 0.6755 Acc:57.75%
Training: Epoch[075/200] Loss: 0.6754 Acc:57.75%
Training: Epoch[076/200] Loss: 0.6754 Acc:57.81%
Training: Epoch[077/200] Loss: 0.6753 Acc:57.81%
Training: Epoch[078/200] Loss: 0.6752 Acc:57.69%
Training: Epoch[079/200] Loss: 0.6752 Acc:57.71%
Training: Epoch[080/200] Loss: 0.6752 Acc:57.58%
Training: Epoch[081/200] Loss: 0.6752 Acc:57.57%
Training: Epoch[082/200] Loss: 0.6752 Acc:57.52%
Training: Epoch[083/200] Loss: 0.6752 Acc:57.46%
Training: Epoch[084/200] Loss: 0.6752 Acc:57.39%
Training: Epoch[085/200] Loss: 0.6752 Acc:57.37%
Training: Epoch[086/200] Loss: 0.6752 Acc:57.38%
Training: Epoch[087/200] Loss: 0.6752 Acc:57.44%
Training: Epoch[088/200] Loss: 0.6752 Acc:57.44%
Training: Epoch[089/200] Loss: 0.6752 Acc:57.48%
Training: Epoch[090/200] Loss: 0.6752 Acc:57.53%
Training: Epoch[091/200] Loss: 0.6751 Acc:57.57%
Training: Epoch[092/200] Loss: 0.6751 Acc:57.56%
Training: Epoch[093/200] Loss: 0.6752 Acc:57.64%
Training: Epoch[094/200] Loss: 0.6752 Acc:57.68%
Training: Epoch[095/200] Loss: 0.6752 Acc:57.71%
Training: Epoch[096/200] Loss: 0.6752 Acc:57.69%
Training: Epoch[097/200] Loss: 0.6752 Acc:57.71%
Training: Epoch[098/200] Loss: 0.6752 Acc:57.68%
Training: Epoch[099/200] Loss: 0.6752 Acc:57.64%
Training: Epoch[100/200] Loss: 0.6751 Acc:57.60%
Training: Epoch[101/200] Loss: 0.6751 Acc:57.56%
Training: Epoch[102/200] Loss: 0.6751 Acc:57.55%
Training: Epoch[103/200] Loss: 0.6751 Acc:57.56%
Training: Epoch[104/200] Loss: 0.6751 Acc:57.56%
Training: Epoch[105/200] Loss: 0.6751 Acc:57.56%
Training: Epoch[106/200] Loss: 0.6751 Acc:57.55%
Training: Epoch[107/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[108/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[109/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[110/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[111/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[112/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[113/200] Loss: 0.6751 Acc:57.55%
Training: Epoch[114/200] Loss: 0.6751 Acc:57.55%
Training: Epoch[115/200] Loss: 0.6751 Acc:57.56%
Training: Epoch[116/200] Loss: 0.6751 Acc:57.56%
Training: Epoch[117/200] Loss: 0.6751 Acc:57.56%
Training: Epoch[118/200] Loss: 0.6751 Acc:57.56%
Training: Epoch[119/200] Loss: 0.6751 Acc:57.56%
Training: Epoch[120/200] Loss: 0.6751 Acc:57.56%
Training: Epoch[121/200] Loss: 0.6751 Acc:57.56%
Training: Epoch[122/200] Loss: 0.6751 Acc:57.56%
Training: Epoch[123/200] Loss: 0.6751 Acc:57.56%
Training: Epoch[124/200] Loss: 0.6751 Acc:57.56%
Training: Epoch[125/200] Loss: 0.6751 Acc:57.56%
Training: Epoch[126/200] Loss: 0.6751 Acc:57.55%
Training: Epoch[127/200] Loss: 0.6751 Acc:57.55%
Training: Epoch[128/200] Loss: 0.6751 Acc:57.55%
Training: Epoch[129/200] Loss: 0.6751 Acc:57.55%
Training: Epoch[130/200] Loss: 0.6751 Acc:57.55%
Training: Epoch[131/200] Loss: 0.6751 Acc:57.55%
Training: Epoch[132/200] Loss: 0.6751 Acc:57.55%
Training: Epoch[133/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[134/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[135/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[136/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[137/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[138/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[139/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[140/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[141/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[142/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[143/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[144/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[145/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[146/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[147/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[148/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[149/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[150/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[151/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[152/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[153/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[154/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[155/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[156/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[157/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[158/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[159/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[160/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[161/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[162/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[163/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[164/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[165/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[166/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[167/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[168/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[169/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[170/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[171/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[172/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[173/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[174/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[175/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[176/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[177/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[178/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[179/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[180/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[181/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[182/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[183/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[184/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[185/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[186/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[187/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[188/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[189/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[190/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[191/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[192/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[193/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[194/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[195/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[196/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[197/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[198/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[199/200] Loss: 0.6751 Acc:57.54%
Training: Epoch[200/200] Loss: 0.6751 Acc:57.54%
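The cell above optimises a simple objective that can be written out in a few lines of NumPy. The sketch below is full-batch logistic regression with binary cross-entropy. As a deliberate simplification, it swaps Adam for plain gradient descent. The learning rate follows the same rule as `StepLR(step_size=50, gamma=0.1)`, namely lr_init * 0.1 ** (epoch // 50). Epochs 1-50 therefore run at 0.1, 51-100 at 0.01, 101-150 at 0.001 and 151-200 at 0.0001. In the log above, the loss has settled at 0.6751 from about epoch 100 on. The synthetic data and the name `train_logreg` are assumptions for illustration.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def train_logreg(x, y, lr_init=0.1, max_epoch=200, step_size=50, gamma=0.1, seed=0):
    """Full-batch logistic regression on BCE, using plain gradient descent (not Adam)
    with the StepLR-style schedule lr_init * gamma ** (epoch // step_size)."""
    rng = np.random.default_rng(seed)
    w = rng.normal(0.0, 0.01, size=x.shape[1])  # like weight.data.normal_(0, 0.01)
    b = 0.0                                     # like bias.data.fill_(0)
    eps = 1e-12                                 # keeps log() finite
    losses = []
    for epoch in range(max_epoch):
        lr = lr_init * gamma ** (epoch // step_size)
        p = sigmoid(x @ w + b)                  # Linear followed by Sigmoid
        losses.append(-np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps)))
        grad = p - y                            # d(BCE)/d(logit) for a sigmoid output
        w -= lr * (x.T @ grad) / len(y)
        b -= lr * grad.mean()
    return w, b, losses

# hypothetical, linearly separable toy data with 3 features
rng = np.random.default_rng(42)
x = rng.normal(size=(500, 3))
y = (x @ np.array([2.0, -1.0, 0.5]) > 0).astype(float)
w, b, losses = train_logreg(x, y)
acc = np.mean((sigmoid(x @ w + b) >= 0.5) == (y == 1))
print(f"loss {losses[0]:.4f} -> {losses[-1]:.4f}, acc {acc:.2%}")
```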
def val_epoch(inputs, labels, model, criterion):
    """Evaluate the trained model on the validation set (no gradient updates)."""
    with torch.no_grad():
        outputs = model(inputs)
        loss = criterion(outputs.reshape(-1), labels.reshape(-1).float())

    # prediction statistics: threshold the predicted probability at 0.5
    predicted = (outputs >= 0.5).int()
    total = labels.size(0)
    correct = (predicted == labels).squeeze().sum().numpy()
    print(predicted.numpy(), labels.numpy(), correct)
    loss_sigma = loss.item()

    # "racc" = true positives / predicted positives, i.e. precision;
    # the last two numbers are the counts of predicted and actual positives
    print("Val Loss: {:.4f} Acc:{:.2%} racc:{:.2%},{},{} ".format(
        loss_sigma, correct / total,
        (predicted & labels).squeeze().sum().numpy() / predicted.sum(),
        predicted.sum(), labels.sum()))

val_epoch(inputs=inputs_val, labels=labels_val, model=model, criterion=criterion)
[[1]
 [0]
 [1]
 ...
 [1]
 [0]
 [0]] [[0]
 [0]
 [0]
 ...
 [1]
 [1]
 [0]] 17326
Val Loss: 0.6745 Acc:57.75% racc:59.26%,21657,16687 
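A validation accuracy of 57.75% is easier to judge against a trivial baseline. The printout reports 16,687 actual positives, and 17,326 correct predictions at 57.75% imply roughly 30,000 validation pairs. Always predicting 1 would therefore already score about 55.6%. Likewise, a constant prediction equal to that positive rate gives a BCE of roughly 0.687, compared with the 0.6745 reported here. The small helper below computes both baselines; the name `majority_baselines` is hypothetical.

```python
import numpy as np

def majority_baselines(pred, truth):
    """Accuracy of `pred`, accuracy of always predicting the majority class,
    and the BCE of a constant prediction equal to the positive rate."""
    pred = np.asarray(pred).astype(bool).ravel()
    truth = np.asarray(truth).astype(bool).ravel()
    accuracy = float(np.mean(pred == truth))
    p = float(np.clip(truth.mean(), 1e-12, 1 - 1e-12))  # positive rate, kept away from 0 and 1
    majority_acc = max(p, 1.0 - p)
    constant_bce = -(p * np.log(p) + (1 - p) * np.log(1 - p))
    return accuracy, majority_acc, float(constant_bce)

# accuracy 0.75, majority baseline 0.75, constant-prediction BCE about 0.562
print(majority_baselines([1, 0, 1, 0], [1, 1, 1, 0]))
```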
# train(lr_init = 0.001, max_epoch = 200)



From: https://www.cnblogs.com/RanX2018/p/17250744.html
