首页 > 其他分享 >学习笔记426—keras中to_categorical函数解析

学习笔记426—keras中to_categorical函数解析

时间:2023-11-17 12:33:03浏览次数:42  
标签:num keras labels categorical shape classes np 426

keras中to_categorical函数解析

1.to_categorical的功能

简单来说,to_categorical就是将类别向量转换为二进制(只有0和1)的矩阵类型表示。其表现为将原有的类别向量转换为独热编码的形式。先上代码看一下效果:

from keras.utils.np_utils import *
#类别向量定义
b = [0,1,2,3,4,5,6,7,8]
#调用to_categorical将b按照9个类别来进行转换
b = to_categorical(b, 9)
print(b)
 
执行结果如下:
[[1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1.]]

to_categorical最为keras中提供的一个工具方法,从以上代码运行可以看出,将原来类别向量中的每个值都转换为矩阵里的一个行向量,从左到右依次是0,1,2,...8个类别。2表示为[0. 0. 1. 0. 0. 0. 0. 0. 0.],只有第3个为1,作为有效位,其余全部为0。
2.one_hot encoding(独热编码)介绍

独热编码又称为一位有效位编码,上边代码例子中其实就是将类别向量转换为独热编码的类别矩阵。也就是如下转换:

0  1  2  3  4  5  6  7  8
0=> [1. 0. 0. 0. 0. 0. 0. 0. 0.]
1=> [0. 1. 0. 0. 0. 0. 0. 0. 0.]
2=> [0. 0. 1. 0. 0. 0. 0. 0. 0.]
3=> [0. 0. 0. 1. 0. 0. 0. 0. 0.]
4=> [0. 0. 0. 0. 1. 0. 0. 0. 0.]
5=> [0. 0. 0. 0. 0. 1. 0. 0. 0.]
6=> [0. 0. 0. 0. 0. 0. 1. 0. 0.]
7=> [0. 0. 0. 0. 0. 0. 0. 1. 0.]
8=> [0. 0. 0. 0. 0. 0. 0. 0. 1.]

 那么一道思考题来了,让你自己编码实现类别向量向独热编码的转换,该怎样实现呢?

以下是我自己粗浅写的一个小例子,仅供参考:

def convert_to_one_hot(labels, num_classes):
    #计算向量有多少行
    num_labels = len(labels)
    #生成值全为0的独热编码的矩阵
    labels_one_hot = np.zeros((num_labels, num_classes))
    #计算向量中每个类别值在最终生成的矩阵“压扁”后的向量里的位置
    index_offset = np.arange(num_labels) * num_classes
    #遍历矩阵,为每个类别的位置填充1
    labels_one_hot.flat[index_offset + labels] = 1
    return labels_one_hot
#进行测试
b = [2, 4, 6, 8, 6, 2, 3, 7]
print(convert_to_one_hot(b,9))
 
测试结果:
[[0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0.]]

 3.源码解析

to_categorical在keras的utils/np_utils.py中,源码如下:

def to_categorical(y, num_classes=None, dtype='float32'):
    """Converts a class vector (integers) to binary class matrix.
    E.g. for use with categorical_crossentropy.
    # Arguments
        y: class vector to be converted into a matrix
            (integers from 0 to num_classes).
        num_classes: total number of classes.
        dtype: The data type expected by the input, as a string
            (`float32`, `float64`, `int32`...)
    # Returns
        A binary matrix representation of the input. The classes axis
        is placed last.
    # Example
    ```python
    # Consider an array of 5 labels out of a set of 3 classes {0, 1, 2}:
    > labels
    array([0, 2, 1, 2, 0])
    # `to_categorical` converts this into a matrix with as many
    # columns as there are classes. The number of rows
    # stays the same.
    > to_categorical(labels)
    array([[ 1.,  0.,  0.],
           [ 0.,  0.,  1.],
           [ 0.,  1.,  0.],
           [ 0.,  0.,  1.],
           [ 1.,  0.,  0.]], dtype=float32)
    ```
    """
    #将输入y向量转换为数组
    y = np.array(y, dtype='int')
    #获取数组的行列大小
    input_shape = y.shape
    if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
        input_shape = tuple(input_shape[:-1])
    #y变为1维数组
    y = y.ravel()
    #如果用户没有输入分类个数,则自行计算分类个数
    if not num_classes:
        num_classes = np.max(y) + 1
    n = y.shape[0]
    #生成全为0的n行num_classes列的值全为0的矩阵
    categorical = np.zeros((n, num_classes), dtype=dtype)
    #np.arange(n)得到每个行的位置值,y里边则是每个列的位置值
    categorical[np.arange(n), y] = 1
    #进行reshape矫正
    output_shape = input_shape + (num_classes,)
    categorical = np.reshape(categorical, output_shape)
    return categorical


 

意在交流学习,欢迎点赞评论,并关注微信公众号:弈介布衣


标签:num,keras,labels,categorical,shape,classes,np,426
From: https://blog.51cto.com/hechangchun/8439807

相关文章

  • 解决 keras 首次装载预训练模型VGG16 时下载失败问题
    解决:Exception:URLfetchfailureonhttps://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5:None--[Errno104]Connectionresetbypeer解决方案:1、先将数据集单独下载下来:models/vgg16_weights_tf_d......
  • 2023-2024-1 20232426刘嘉成 《网络空间安全导论》第1周学习总结
    2023-2024-120232426刘嘉成《网络空间安全导论》第1周学习总结教材学习内容总结简要总结本周学习内容:本周我自学了《网络空间安全导论》第一章:网络空间安全概述,分别从信息时代与信息安全,网络空间安全学科浅谈,网络空间安全法律法规,信息安全标准四个方面进行了学习。对我国网......
  • 学期:2023-2024-1 学号:20231426 《计算机基础与程序设计》第七周学习总结
    作业信息这个作业属于哪个课程2022-2023-1-计算机基础与程序设计这个作业要求在哪里2022-2023-1计算机基础与程序设计作业这个作业的目标通过教材内容了解数组、子程序与参数作业正文https://www.cnblogs.com/hhaxx/p/17826871.html教材学习内容总结《计......
  • Keras_Quantization
    PTQ训练后量化的实现代码;过程:权重量化;infer校准数据集统计示例代码:QAT量化训练的实现代码;过程(量化后小模型平均精度损失1~2个点)训练模拟顶点模型(卷积参数为定点数,batchnormalization参数为高精度浮点数)combinesbatchnormalizationwiththeprecedingconvoluti......
  • 学期:2023-2024-1 学号:20231426 《计算机基础与程序设计》第六周学习总结
    作业信息这个作业属于哪个课程2022-2023-1-计算机基础与程序设计这个作业要求在哪里2022-2023-1计算机基础与程序设计作业这个作业的目标通过教材内容了解复合数据结构、查找与排序算法、递归、代码安全、简单类型与组合类型作业正文https://www.cnblogs.com/......
  • Mysql为什么存储表数据为什么不能超过2000万行,深度解释 转发 https://www.toutiao.co
    下面是我朋友的面试记录:面试官:讲一下你实习做了什么。朋友:我在实习期间做了一个存储用户操作记录的功能,主要是从MQ获取上游服务发送过来的用户操作信息,然后把这些信息存到MySQL里面,提供给数仓的同事使用。朋友:由于数据量比较大,每天大概有四五千多万条,所以我还给它做了分表的操......
  • TensorFlow、PyTorch、Keras、Scikit-learn和ChatGPT。视觉开发软件工具 Halcon、Visi
     目录TensorFlow、PyTorch、Keras、Scikit-learn和ChatGPT1.TensorFlow2.PyTorch3.Keras视觉开发软件工具Halcon、VisionPro、LabView、OpenCV,还有eVision、Mil、Sapera等。(一)、Halcon(二)OpenCV:ComputerVision(计算机视觉)(三)VisionProTensorFlow、PyTorch、Keras、Scikit-learn和......
  • Keras TypeError: ('Keyword argument not understood:', 'input')
    TypeError:('Keywordargumentnotunderstood:','input') model=Model(input=[inputs],output=output)报错信息TypeError:('Keywordargumentnotunderstood:','input')解决方法换成model=Model(inputs=...,outputs=...) ......
  • keras中 keras.layers merge is not callable
       旧版本中:   fromkeras.layersimportmerge       merge6=merge([layer1,layer2],mode='concat',concat_axis=3)新版本中:   fromkeras.layers.mergeimportconcatenate       merge=concatenate([layer1,layer2],axis=3) ......
  • 【AutoML】AutoKeras 的安装和环境配置(VSCode)
    本地环境中已经有太多的工作配置了(Python、Java、Maven、Docker等等),为了不影响其他环境运行,我选择直接在VSCode中创建工作空间并配置好AutoKeras(反正最后也是要在VSCode中进行开发的)。<br>打开VSCode后先创建一个工作区,然后在终端运行以下代码:python3-mvenvautokeras-......