首页 > 其他分享 >用 Tensorflow.js 做了一个动漫分类的功能(二)

用 Tensorflow.js 做了一个动漫分类的功能(二)

时间:2023-02-22 19:47:35浏览次数:37  
标签:const img js return 动漫 tf Tensorflow model data

前言:

        前面已经通过采集拿到了图片,并且也手动对图片做了标注。接下来就要通过 Tensorflow.js 基于 mobileNet 训练模型,最后就可以实现在采集中对图片进行自动分类了。

        这种功能在应用场景里就比较多了,比如图标素材站点,用户通过上传一个图标,系统会自动匹配出相似的图标,还有二手平台,用户通过上传闲置物品图片,平台自动给出分类等,这些也都是前期对海量图片进行了标注训练而得到一个损失率极低的模型。下面就通过简答的代码实现一个小的动漫分类。

环境:

Node

Http-Server

Parcel

Tensorflow

编码:

1. 训练模型

1.1.  创建项目,安装依赖包

npm install @tensorflow/tfjs --legacy-peer-deps
npm install @tensorflow/tfjs-node-gpu --legacy-peer-deps

1.2.  全局安装 Http-Server

npm install i http-server

1.3.  下载 mobileNet 模型文件 (网上有下载)

1.4.   根目录下启动 Http 服务 (开启跨域),用于 mobileNet 和训练结果的模型可访问

http-server --cors -p 8080

1.5.  创建训练执行脚本 run.js

const tf = require('@tensorflow/tfjs-node-gpu');

const getData = require('./data');
const TRAIN_PATH = './动漫分类/train';
const OUT_PUT = 'output';
const MOBILENET_URL = 'http://127.0.0.1:8080/data/mobilenet/web_model/model.json';

(async () => {
  const { ds, classes } = await getData(TRAIN_PATH, OUT_PUT);
  console.log(ds, classes);
  //引入别人训练好的模型
  const mobilenet = await tf.loadLayersModel(MOBILENET_URL);
  //查看模型结构
  mobilenet.summary();

  const model = tf.sequential();
  //截断模型,复用了86个层
  for (let i = 0; i < 86; ++i) {
    const layer = mobilenet.layers[i];
    layer.trainable = false;
    model.add(layer);
  }
  //降维,摊平数据
  model.add(tf.layers.flatten());
  //设置全连接层
  model.add(tf.layers.dense({
    units: 10,
    activation: 'relu'//设置激活函数,用于处理非线性问题
  }));

  model.add(tf.layers.dense({
    units: classes.length,
    activation: 'softmax'//用于多分类问题
  }));
  //设置损失函数,优化器
  model.compile({
    loss: 'sparseCategoricalCrossentropy',
    optimizer: tf.train.adam(),
    metrics:['acc']
  });

  //训练模型
  await model.fitDataset(ds, { epochs: 20 });
  //保存模型
  await model.save(`file://${process.cwd()}/${OUT_PUT}`);
})();

1.6. 创建图片与 Tensor 转换库 data.js

const fs = require('fs');
const tf = require("@tensorflow/tfjs-node-gpu");

const img2x = (imgPath) => {
  const buffer = fs.readFileSync(imgPath);
  //清除数据
  return tf.tidy(() => {
    //把图片转成tensor
    const imgt = tf.node.decodeImage(new Uint8Array(buffer), 3);
    //调整图片大小
    const imgResize = tf.image.resizeBilinear(imgt, [224, 224]);
    //归一化
    return imgResize.toFloat().sub(255 / 2).div(255 / 2).reshape([1, 224, 224, 3]);
  });
}

const getData = async (traindir, output) => {
  let classes = fs.readdirSync(traindir, 'utf-8');
  fs.writeFileSync(`./${output}/classes.json`, JSON.stringify(classes));
  const data = [];
  classes.forEach((dir, dirIndex) => {
    fs.readdirSync(`${traindir}/${dir}`)
      .filter(n => n.match(/jpg$/))
      .slice(0, 1000)
      .forEach(filename => {
        const imgPath = `${traindir}/${dir}/${filename}`;

        data.push({ imgPath, dirIndex });
      });
  });

  console.log(data);

  //打乱训练顺序,提高准确度
  tf.util.shuffle(data);

  const ds = tf.data.generator(function* () {
    const count = data.length;
    const batchSize = 32;
    for (let start = 0; start < count; start += batchSize) {
      const end = Math.min(start + batchSize, count);
      console.log('当前批次', start);
      yield tf.tidy(() => {
        const inputs = [];
        const labels = [];
        for (let j = start; j < end; ++j) {
          const { imgPath, dirIndex } = data[j];
          const x = img2x(imgPath);
          inputs.push(x);
          labels.push(dirIndex);
        }
        const xs = tf.concat(inputs);
        const ys = tf.tensor(labels);
        return { xs, ys };
      });
    }
  });

  return { ds, classes };
}

module.exports = getData;

1.7. 运行执行文件

node run.js

2. 调用模型

2.1. 全局安装 parcel

npm install i parcel

2.2. 创建页面 index.html

<script src="script.js"></script>
<input type="file" onchange="predict(this.files[0])">
<br>

2.3. 创建模型调用预测脚本 script.js

import * as tf from '@tensorflow/tfjs';
import { img2x, file2img } from './utils';

const MODEL_PATH = 'http://127.0.0.1:8080/t7';
const CLASSES = ["假面骑士","奥特曼","海贼王","火影忍者","龙珠"];


window.onload = async () => {
    const model = await tf.loadLayersModel(MODEL_PATH + '/output/model.json');

    window.predict = async (file) => {
        const img = await file2img(file);
        document.body.appendChild(img);
        const pred = tf.tidy(() => {
            const x = img2x(img);
            return model.predict(x);
        });

        const index = pred.argMax(1).dataSync()[0];
        console.log(pred.argMax(1).dataSync());

        let predictStr = "";
        if (typeof CLASSES[index] == 'undefined') {
            predictStr = BRAND_CLASSES[index];
        } else {
            predictStr = CLASSES[index];
        }

        setTimeout(() => {
            alert(`预测结果:${predictStr}`);
        }, 0);
    };
};

2.4. 创建图片 tensor 格式转换库 utils.js

import * as tf from '@tensorflow/tfjs';

export function img2x(imgEl){
    return tf.tidy(() => {
        const input = tf.browser.fromPixels(imgEl)
            .toFloat()
            .sub(255 / 2)
            .div(255 / 2)
            .reshape([1, 224, 224, 3]);
        return input;
    });
}

export function file2img(f) {
    return new Promise(resolve => {
        const reader = new FileReader();
        reader.readAsDataURL(f);
        reader.onload = (e) => {
            const img = document.createElement('img');
            img.src = e.target.result;
            img.width = 224;
            img.height = 224;
            img.onload = () => resolve(img);
        };
    });
}

2.5. 打包项目并运行

parcel index.html

2.6. 运行效果

注意:

1. 模型训练过程报错

Input to reshape is a tensor with 50176 values, but the requested shape has 150528

1.1. 原因

    张量 reshape 不对,实际输入元素个数与所需矩阵元素个数不一致,就是采集过来的图片有多种图片格式,而不同格式的通道不同 (jpg3 通道,png4 通道,灰色图片 1 通道),在将图片转换 tensor 时与代码里的张量形状不匹配。

1.2. 解决方法

    一种方法是删除灰色或 png 图片,其二是修改代码 tf.node.decodeImage (new Uint8Array (buffer), 3)

 

 

标签:const,img,js,return,动漫,tf,Tensorflow,model,data
From: https://www.cnblogs.com/zerofc/p/17145613.html

相关文章

  • 小程序 app.js中登录获取token还未结束时 首页请求过快报未登录
    场景:在小程序中大家应该都有这样的场景,在onLaunch里用wx.login静默登录拿到code,再用code去发送请求获取token、用户信息等,整个过程都是异步的,然后我们在业务页面里onLoad......
  • tensorflow.js 对视频 / 直播人脸检测和特征点收集
    前言:    这里要介绍的是Tensorflow.js官方提供的两个人脸检测模型,分别是face-detection和face-landmarks-detection。他们不但可以对视频中的人间进行精确定......
  • tensorflow.js 多分类,机器学习区分企鹅种类
    前言:    在规则编码中,我们常常会遇到需要通过多种区间判断某种物品分类。比如二手物品的定价,尽管不是新品没有SKU但是基本的参数是少不了。想通过成色来区分某......
  • tensorflow.js 视频图片多目标检测
    前言:    Tensorflow.js官方提供了很多常用模型库,涵盖了平时开发中大部分场景的模型。例如,前面提到的图片识别,除此之外还有人体姿态识别,目标物体识别,语音文字等识......
  • C# json数组排序-根据指定key排序-jsonArray
    //jsonArray.ToJson()是json数组根据指定treatment_no来进行排序json字符串[{"treatment_type":"1","treatment_no":"42",},{"treatment_type":......
  • js: ElementUI表单验证validate和validateField
    文档回顾1、validate:对整个表单进行校验的方法,参数为一个回调函数。该回调函数会在校验结束后被调用,并传入两个参数:是否校验成功和未通过校验的字段。若不传入回调函数,则......
  • tensorflow2.0+TF-lite 各种报错
    generic_type:type"InterpreterWrapper"isalreadyregistered!原因:tensorflow2.5.0rc0版本太高,降低版本:pipinstalltensorflow==2.3  ValueErron:"batch_si......
  • js中的函数的各种形态 230222
    标准函数functionfn(){console.log(1111)}fn()匿名函数等号右边是匿名函数varfn=function(){console.log(222)}fn()自启动函数本质还是匿名函数(function()......
  • 2023年02月16日vue.js 教程
    1.创建第一个vuevue应用结构可以分成两个部分,一个是视图,一个是脚本;脚本有两个参数:el和data视图 <divid="app"> {{message}}{{name}} </div> 脚本 <......
  • E007Web学习笔记-JavaScript(五):JS事件
    一、概述1、事件概念某些组件被执行了某些操作后,触发某些代码的执行;事件:某些操作,如点击,双击,鼠标移动,键盘按下……; 2、事件源组件,如按钮、文......