首页 > 其他分享 >工业缺陷检测实战——磁瓦表面缺陷分类

工业缺陷检测实战——磁瓦表面缺陷分类

时间:2024-12-23 21:56:04浏览次数:6  
标签:实战 val 磁瓦 args add train path model 缺陷

 第一步:准备数据

6种磁瓦表面缺陷:self.class_indict = ["MT_Blowhole", "MT_Break", "MT_Crack", "MT_Fray", "MT_Free", "MT_Uneven"]

,总共有1330张图片,每个文件夹单独放一种数据

第二步:搭建模型

本文选择一个ConvNext网络,其原理介绍如下:

ConvNext (Convolutional Network Net Generation), 即下一代卷积神经网络, 是近些年来 CV 领域的一个重要发展. ConvNext 由 Facebook AI Research 提出, 仅仅通过卷积结构就达到了与 Transformer 结构相媲美的 ImageNet Top-1 准确率, 这在近年来以 Transformer 为主导的视觉问题解决趋势中显得尤为突出.

第三步:训练代码

1)损失函数为:交叉熵损失函数

2)训练代码:

import json
import os
import argparse
import time

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, datasets

from model import convnext_tiny as create_model
from utils import  create_lr_scheduler, get_params_groups, train_one_epoch, evaluate,plot_class_preds


def main(args):
    os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    print(args)
    # 创建定义文件夹以及文件
    filename = 'record.txt'
    save_path = 'runs'
    path_num = 1
    while os.path.exists(save_path + f'{path_num}'):
        path_num += 1
    save_path = save_path + f'{path_num}'
    os.mkdir(save_path)
    f = open(save_path + "/" + filename, 'w')
    f.write("{}\n".format(args))

    # print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
    # 实例化SummaryWriter对象
    # #######################################
    tb_writer = SummaryWriter(log_dir=save_path + "/experiment")
    if os.path.exists(save_path + "/weights") is False:
        os.makedirs(save_path + "/weights")

    img_size = 224
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),
                                   transforms.CenterCrop(img_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # 实例化训练数据集
    train_data_set = datasets.ImageFolder(root=os.path.join(args.data_path, "train"),
                                          transform=data_transform["train"])

    # 实例化验证数据集
    val_data_set = datasets.ImageFolder(root=os.path.join(args.data_path, "val"),
                                        transform=data_transform["val"])

    # 生成class_indices.json文件,包括有模型对应的序列号
    # #######################################
    classes_list = train_data_set.class_to_idx
    cla_dict = dict((val, key) for key, val in classes_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_data_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw)

    val_loader = torch.utils.data.DataLoader(val_data_set,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw)

    model = create_model(num_classes=args.num_classes).to(device)

    # Write the model into tensorboard
    # #######################################
    init_img = torch.zeros((1, 3, 224, 224), device=device)
    tb_writer.add_graph(model, init_img)

    if args.weights != "":
        assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
        # weights_dict = torch.load(args.weights, map_location=device)
        weights_dict = torch.load(args.weights, map_location=device)["model"]
        # 删除有关分类类别的权重
        for k in list(weights_dict.keys()):
            if "head" in k:
                del weights_dict[k]
        print(model.load_state_dict(weights_dict, strict=False))

    if args.freeze_layers:
        for name, para in model.named_parameters():
            # 除head外,其他权重全部冻结
            if "head" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    # pg = [p for p in model.parameters() if p.requires_grad]
    pg = get_params_groups(model, weight_decay=args.wd)
    optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=args.wd)
    lr_scheduler = create_lr_scheduler(optimizer, len(train_loader), args.epochs,
                                       warmup=True, warmup_epochs=10)

    best_acc = 0.0
    for epoch in range(args.epochs):
        # 计时器time_start
        time_start = time.time()
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch,
                                                lr_scheduler=lr_scheduler)

        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)
        time_end = time.time()
        f.write("[epoch {}] train_loss: {:.3f},train_acc:{:.3f},val_loss:{:.3f},val_acc:{:.3f},Spend_time:{:.3f}S"
                .format(epoch + 1, train_loss, train_acc, val_loss, val_acc, time_end - time_start))
        f.flush()

        # add Training results into tensorboard
        # #######################################
        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

        # add figure into tensorboard
        # #######################################
        fig = plot_class_preds(net=model,
                               images_dir=r"plot_img",
                               transform=data_transform["val"],
                               num_plot=6,
                               device=device)
        if fig is not None:
            tb_writer.add_figure("predictions vs. actuals",
                                 figure=fig,
                                 global_step=epoch)

        if val_acc >= best_acc:
            best_acc = val_acc
            f.write(',save best model')
            torch.save(model.state_dict(), save_path + "/weights/bestmodel.pth")
        f.write('\n')
    f.close()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=6)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--wd', type=float, default=5e-2)
    parser.add_argument('--data-path', type=str,
                        default=r"E:\Industrial_inspection\data\MagneticTile_c")
    parser.add_argument('--weights', type=str,
                        default=r"convnext_tiny_1k_224_ema.pth",
                        help='initial weights path')
    # 是否冻结head以外所有权重
    parser.add_argument('--freeze-layers', type=bool, default=False)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)

第四步:统计正确率

第五步:搭建GUI界面

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

项目完整文件下载请见演示与介绍视频的简介处给出:➷➷➷

工业缺陷检测实战——磁瓦表面缺陷分类_哔哩哔哩_bilibili

标签:实战,val,磁瓦,args,add,train,path,model,缺陷
From: https://blog.csdn.net/u013289254/article/details/144678240

相关文章

  • 4-Gin HTML 模板渲染 --[Gin 框架入门精讲与实战案例]
    HTML模板渲染下面是使用Gin框架在Go语言中进行HTML模板渲染的四个示例。每个示例都包含了必要的注释来解释代码的作用。示例1:基本模板渲染packagemainimport( "github.com/gin-gonic/gin" "net/http")funcmain(){ r:=gin.Default() //加载HTML模......
  • 【嵌入式Linux】---- 基于petaLinux和SDK开发的LED驱动和应用测试(全流程实战)
    1配置petaLinux环境变量在Linuxproject目录下,打开终端,输入命令source/opt/pkg/petalinux/2018.3/settings.sh2新建petaLinux工程petalinux-create-tproject--templatezynq-nZYNQ7010_LED3配置petaLinux工程输入cdZYNQ7010_LED,进入刚刚创建的工程文件;输入p......
  • Java项目实战之基于 Spring Boot、MyBatis 和 Vue.js 的智能停车场系统设计与技术选型
    1.系统概述本智能停车场系统旨在为停车场提供高效、便捷的管理解决方案,涵盖车辆进出管理、车位预订、停车费用计算、用户信息管理等功能,同时提供管理员操作界面和用户移动端应用,提升停车场运营效率和用户体验。1.1目标实现停车场自动化管理,提高车位利用率,减少人工成本,为用户提......
  • 利用Python爬虫高效获取苏宁商品信息:按关键字搜索的实战指南
    在信息爆炸的今天,数据的获取和处理能力成为了衡量一个企业竞争力的重要指标。对于电商平台而言,如何快速、准确地获取商品信息,成为了提升运营效率的关键。本文将详细介绍如何使用Python爬虫技术,高效地按关键字搜索苏宁商品,并提供详细的代码示例。1.Python爬虫技术概述Python......
  • 鸿蒙Next ArkTS高性能编程实战
    一、引言在应用开发中,高性能编程对于提升用户体验至关重要。本文将详细介绍鸿蒙NextArkTS在高性能编程方面的实践经验,包括声明与表达式、函数、数组以及异常处理等方面的优化技巧,助力开发者打造高效能的应用。二、声明与表达式(一)使用const声明不变的变量在编程过程中,对于那些......
  • Python数据分析-爬虫实战
    数据分析1.爬虫相关概念爬虫的分类聚焦爬虫完成某一项特定数据的采集百分之九十的爬虫都是聚焦爬虫通用爬虫什么内容都采集,都存下来搜索引擎百度谷歌增量爬虫既可以是聚焦爬虫也可以是通用爬虫当内容发生变化,可以增量的获取内容(比如爬取博客,第二天又新......
  • YOLO冲沟缺陷数据集(边坡、地貌)与训练结果分享 - 幽络源
    概述分享这个数据集,一是群内有用户需要,二是自己正好也在做这个数据集,本次分享的数据集为幽络源自行寻找原图手动标注并增强处理,然后已经经过训练测试,F1分数接近1,能覆盖92%的冲沟缺陷与地貌。图像共984张,标注缺陷有1300+处。下载链接:YOLO冲沟数据集,含训练结果与模型展示图使......
  • 【WebGIS项目实战】共享电动车管理系统
    近些年,共享单车、共享充电宝、共享按摩仪,共享电动车、甚至共享汽车,逐渐融入我们的日常。共享经济爆发式增长,对背后的编程技术也提出了更高的要求,在地图应用板块,WebGIS开发的作用也十分亮眼。如何在共享模式下,更好地进行综合调度?如何让用户在使用时,更便捷快速?如何跨越地域,......
  • 【阿尼亚探索大模型】书生大模型实战营-进阶岛第2关(L2G2000)Lagent 自定义你的 Agent
    任务类型任务内容任务一使用Lagent复现“制作一个属于自己的Agent”任务二使用Lagent复现 “Multi-Agents博客写作系统的搭建”任务三将你的Agent部署到HuggingFace或ModelScope平台基础环境配置依然选择30%A100开发机进行实验。使用conda创建虚拟环境。启动......
  • Java 项目实战:基于 Spring Boot 与 Vue.js 技术构建护士排班管理系统的架构设计方案
    一、引言1.1项目背景随着医疗行业的不断发展,医院护士排班管理的复杂性日益增加。传统的手工排班方式难以满足高效、公平、合理的需求,容易出现人力分配不均、员工满意度低等问题。为了提高护士排班的科学性和管理效率,特开发此护士排班管理系统。1.2项目目标本系统旨在实现医......