
Training the MNIST dataset with deeplearning4j and verifying the model


Official example for training the model

MNIST data download: http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz
GitHub example: https://github.com/deeplearning4j/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/convolution/LeNetMNISTReLu.java
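
The archive extracts to mnist_png/training and mnist_png/testing, each containing one sub-directory per digit (0-9) full of 28x28 PNG images; the example below expects exactly this layout under BASE_PATH. If you download and unpack the tarball by hand, a quick sanity check like the following sketch can confirm the layout and image counts (the base path here is only an assumption, adjust it to wherever you extracted the archive):

import java.io.File;

public class CheckMnistLayout {
    public static void main(String[] args) {
        // Assumed extraction root; must match BASE_PATH in the training example below.
        String basePath = "D:\\Documents\\Downloads\\mnist_png";
        for (String split : new String[]{"training", "testing"}) {
            for (int digit = 0; digit <= 9; digit++) {
                File dir = new File(basePath + "/mnist_png/" + split + "/" + digit);
                File[] pngs = dir.listFiles((d, name) -> name.endsWith(".png"));
                System.out.println(split + "/" + digit + ": " + (pngs == null ? 0 : pngs.length) + " images");
            }
        }
    }
}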

/*******************************************************************************
 *
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *  See the NOTICE file distributed with this work for additional
 *  information regarding copyright ownership.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.examples.quickstart.modeling.convolution;

import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.examples.utils.DataUtilities;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

/**
 * Implementation of LeNet-5 for handwritten digits image classification on MNIST dataset (99% accuracy)
 * <a href="http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf">[LeCun et al., 1998. Gradient based learning applied to document recognition]</a>
 * Some minor changes are made to the architecture like using ReLU and identity activation instead of
 * sigmoid/tanh, max pooling instead of avg pooling and softmax output layer.
 * <p>
 * This example will download 15 Mb of data on the first run.
 *
 * @author hanlon
 * @author agibsonccc
 * @author fvaleri
 * @author dariuszzbyrad
 */
public class LeNetMNISTReLu {
    private static final Logger LOGGER = LoggerFactory.getLogger(LeNetMNISTReLu.class);
    //    private static final String BASE_PATH = System.getProperty("java.io.tmpdir") + "/mnist";
    private static final String BASE_PATH = "D:\\Documents\\Downloads\\mnist_png";
    private static final String DATA_URL = "http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz";

    public static void main(String[] args) throws Exception {
        int height = 28;    // height of the picture in px
        int width = 28;     // width of the picture in px
        int channels = 1;   // single channel for grayscale images
        int outputNum = 10; // 10 output classes: the digits 0-9
        int batchSize = 54; // number of samples propagated through the network in each iteration
        int nEpochs = 1;    // number of training epochs
        int seed = 1234;    // number used to initialize the pseudorandom number generator
        Random randNumGen = new Random(seed);

        LOGGER.info("Data load...");
        if (!new File(BASE_PATH + "/mnist_png").exists()) {

            LOGGER.debug("Data downloaded from {}", DATA_URL);
            String localFilePath = BASE_PATH + "/mnist_png.tar.gz";
            if (DataUtilities.downloadFile(DATA_URL, localFilePath)) {
                DataUtilities.extractTarGz(localFilePath, BASE_PATH);
            }
        }

        LOGGER.info("Data vectorization...");
        // vectorization of train data
        File trainData = new File(BASE_PATH + "/mnist_png/training");
        FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); // use parent directory name as the image label
        ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker);
        trainRR.initialize(trainSplit);
        // MNIST training data as a DataSetIterator
        DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputNum);

        // pixel values from 0-255 to 0-1 (min-max scaling)
        DataNormalization imageScaler = new ImagePreProcessingScaler();
        imageScaler.fit(trainIter);
        trainIter.setPreProcessor(imageScaler);

        // vectorization of test data
        File testData = new File(BASE_PATH + "/mnist_png/testing");
        FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
        ImageRecordReader testRR = new ImageRecordReader(height, width, channels, labelMaker);
        testRR.initialize(testSplit);
        DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputNum);
        testIter.setPreProcessor(imageScaler); // same normalization for better results

        LOGGER.info("Network configuration and training...");
        // reduce the learning rate as the number of training epochs increases
        // iteration #, learning rate
        Map<Integer, Double> learningRateSchedule = new HashMap<>();
        learningRateSchedule.put(0, 0.06);
        learningRateSchedule.put(200, 0.05);
        learningRateSchedule.put(600, 0.028);
        learningRateSchedule.put(800, 0.0060);
        learningRateSchedule.put(1000, 0.001);

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .l2(0.0005) // ridge regression value
                .updater(new Nesterovs(new MapSchedule(ScheduleType.ITERATION, learningRateSchedule)))
                .weightInit(WeightInit.XAVIER)
                .list()
                .layer(new ConvolutionLayer.Builder(5, 5)
                        .nIn(channels)
                        .stride(1, 1)
                        .nOut(20)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(new ConvolutionLayer.Builder(5, 5)
                        .stride(1, 1) // nIn need not specified in later layers
                        .nOut(50)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(new DenseLayer.Builder().activation(Activation.RELU)
                        .nOut(500)
                        .build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(outputNum)
                        .activation(Activation.SOFTMAX)
                        .build())
                .setInputType(InputType.convolutionalFlat(height, width, channels)) // InputType.convolutional for normal image
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        net.setListeners(new ScoreIterationListener(10));
        LOGGER.info("Total num of params: {}", net.numParams());

        // evaluation while training (the score should go down)
        for (int i = 0; i < nEpochs; i++) {
            net.fit(trainIter);
            LOGGER.info("Completed epoch {}", i);
            Evaluation eval = net.evaluate(testIter);
            LOGGER.info(eval.stats());

            trainIter.reset();
            testIter.reset();
        }

        File ministModelPath = new File(BASE_PATH + "/minist-model.zip");
        ModelSerializer.writeModel(net, ministModelPath, true);
        LOGGER.info("The MINIST model has been saved in {}", ministModelPath.getPath());
    }
}

Verifying the model

package org.deeplearning4j.examples.quickstart.modeling.convolution;

import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;

import java.io.File;
import java.io.IOException;

/**
 * @description:
 * @author: Mr.Fang
 * @create: 2023-07-14 15:06
 **/

public class VerifyMNSIT {
    public static void main(String[] args) throws IOException {

        // load the trained model
        File modelFile = new File("D:\\Documents\\Downloads\\mnist_png\\minist-model.zip");
        MultiLayerNetwork model = MultiLayerNetwork.load(modelFile, true);

        // load the image to verify
        File imageFile = new File("D:\\Documents\\Downloads\\mnist_png\\mnist_png\\testing\\8\\1717.png");
        NativeImageLoader loader = new NativeImageLoader(28, 28, 1);
        INDArray image = loader.asMatrix(imageFile);
        DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
        scaler.transform(image);

        // run the image through the network
        INDArray output = model.output(image);
        int predictedLabel = output.argMax().getInt();
        // Here `output.argMax()` finds the index of the largest value in `output`. `output` is an NDArray
        // holding the model's output probabilities; for the MNIST model it is a vector of length 10
        // representing the probability distribution over the digits 0-9.
        //
        // `.argMax()` returns the index of the largest value. For example, if `output` is
        // [0.1, 0.3, 0.2, 0.05, 0.25, 0.05, 0.05, 0.1, 0.05, 0.05], `.argMax()` returns index 1,
        // because the value 0.3 at position 1 is the largest.
        //
        // Finally, `.getInt()` converts the result of `.argMax()` to a plain integer, the predicted label.
        // So `predictedLabel` holds the digit the model predicts for this image.
        //
        // In short, this line picks the label with the highest output probability as the prediction.
        System.out.println("Predicted label: " + predictedLabel);
    }
}
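
The argMax logic described in the comments is easier to see with a hand-built vector. The short stand-alone sketch below (not part of the original example) reuses the numbers from the comment; it calls argMax(1) on a [1, 10] row vector, the shape model.output() returns for a single image, which is just an explicit form of the output.argMax().getInt() call used above. It prints 1:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ArgMaxDemo {
    public static void main(String[] args) {
        // Shape [1, 10], mimicking the network output for a single input image.
        INDArray output = Nd4j.create(new double[][]{
                {0.1, 0.3, 0.2, 0.05, 0.25, 0.05, 0.05, 0.1, 0.05, 0.05}});
        // argMax along dimension 1 gives the column index of the largest value in each row.
        int predictedLabel = output.argMax(1).getInt(0);
        System.out.println("Predicted label: " + predictedLabel); // prints: Predicted label: 1
    }
}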

Output from running VerifyMNSIT

o.n.l.f.Nd4jBackend - Loaded [CpuBackend] backend
o.n.n.NativeOpsHolder - Number of threads used for linear algebra: 6
o.n.l.c.n.CpuNDArrayFactory - Binary level Generic x86 optimization level AVX/AVX2
o.n.n.Nd4jBlas - Number of threads used for OpenMP BLAS: 6
o.n.l.a.o.e.DefaultOpExecutioner - Backend used: [CPU]; OS: [Windows 10]
o.n.l.a.o.e.DefaultOpExecutioner - Cores: [12]; Memory: [4.0GB];
o.n.l.a.o.e.DefaultOpExecutioner - Blas vendor: [OPENBLAS]
o.n.l.c.n.CpuBackend - Backend build information:
 GCC: "12.1.0"
STD version: 201103L
DEFAULT_ENGINE: samediff::ENGINE_CPU
HAVE_FLATBUFFERS
HAVE_OPENBLAS
o.d.n.m.MultiLayerNetwork - Starting MultiLayerNetwork with WorkspaceModes set to [training: ENABLED; inference: ENABLED], cacheMode set to [NONE]
Predicted label: 8
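
Beyond checking a single image, the saved model can be evaluated against the whole testing directory by reusing the same iterator setup as in training. A minimal sketch under that assumption (paths, image size, batch size and label count are taken from the code above):

import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;

import java.io.File;
import java.util.Random;

public class EvaluateSavedModel {
    public static void main(String[] args) throws Exception {
        // Load the model saved by the training example.
        MultiLayerNetwork model = MultiLayerNetwork.load(
                new File("D:\\Documents\\Downloads\\mnist_png\\minist-model.zip"), true);

        // Same record reader / iterator setup as in training, pointed at the testing folder.
        File testData = new File("D:\\Documents\\Downloads\\mnist_png\\mnist_png\\testing");
        ImageRecordReader rr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator());
        rr.initialize(new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, new Random(1234)));
        DataSetIterator testIter = new RecordReaderDataSetIterator(rr, 54, 1, 10);
        testIter.setPreProcessor(new ImagePreProcessingScaler()); // same 0-1 pixel scaling as in training

        // Accuracy, precision, recall and the confusion matrix over all test images.
        Evaluation eval = model.evaluate(testIter);
        System.out.println(eval.stats());
    }
}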

From: https://www.cnblogs.com/bxmm/p/17554573.html
