当前位置:首页 > 技术 > 正文内容

如何使用 Yolov4 训练人脸口罩检测模型

Lotus2022-10-06 18:16技术

前言

疫情当下,出入医院等公共场所都被要求佩戴口罩。这篇博客将会介绍如何使用 Yolov4,训练一个人脸口罩检测模型(使用 Yolov4 的原因是目前只复现到了 v4 ????),代码地址为 https://github.com/zhiyiYo/yolov4

Yolov4

Yolov4 的神经网络结构相比 Yolov3 变化不是很大,主要更换了激活函数为 Mish,增加了 SPP 块和 PAN 结构(图源 《yolo系列学习笔记----yolov4(SPP原理)》)。

Yolov4 神经网络结构

感觉 Yolov4 最大的特点就是使用了一大堆的 Trick,比如数据增强方面使用了马赛克数据增强、Mixup 数据增强,将定位损失函数更换为 CIOU 损失。论文中提到了很多的 Trick,我的代码中没有全部复现,不过在 VOC2012 数据集训练了 160 个 epoch 之后 mAP 也能达到 83%,效果还是不错的。

可以在终端使用下述指令下载 Yolov4 的代码:

git clone https://github.com/zhiyiYo/yolov4.git

人脸口罩数据集

网上可以找到很多人脸口罩数据集,这里使用的是 AIZOOTech 提供的数据集。由于这个数据集的结构和 Pascal VOC 数据集不一样,所以重新组织一下数据集,并且修复和移除了数据集中的非法标签,可以在 Kaggle 上下载此数据集。目前这个数据集包含 6130 张训练图像,1839 张测试图像,对于 Yolov4 的训练来说应该是绰绰有余的。下载完数据集将其解压到 data 文件夹下。

在训练之前,我们需要使用 K-means 聚类算法对训练集中的边界框进行聚类,对于 416×416 的输入图像,聚类结果如下:

anchors = [
    [[100, 146], [147, 203], [208, 260]],
    [[26, 43], [44, 65], [65, 105]],
    [[4, 8], [8, 15], [15, 27]]
]

训练神经网络

训练目标检测模型一般都需要加载预训练的主干网络的权重,可以从谷歌云盘下载预训练好的权重 CSPDarknet53.pth 并将其放在 model 文件夹下。这里给出训练所用的代码 train.py,使用 python train.py 就能开始训练。模型会先冻结训练上 50 个 epoch,接着解冻训练 110 个 epoch:

# coding:utf-8
from net import TrainPipeline, VOCDataset
from utils.augmentation_utils import YoloAugmentation, ColorAugmentation

# 训练配置
config = {
    "n_classes": len(VOCDataset.classes),
    "image_size": 416,
    "anchors": [
        [[100, 146], [147, 203], [208, 260]],
        [[26, 43], [44, 65], [65, 105]],
        [[4, 8], [8, 15], [15, 27]]
    ],
    "darknet_path": "model/CSPdarknet53.pth",
    "lr": 1e-2,
    "batch_size": 8,
    "freeze_batch_size": 16,
    "freeze": True,
    "freeze_epoch": 50,
    "max_epoch": 160,
    "start_epoch": 0,
    "num_workers": 4,
    "save_frequency": 10,
    "no_aug_ratio": 0
}

# 加载数据集
root = 'data/FaceMaskDataset/train'
dataset = VOCDataset(
    root,
    'all',
    transformer=YoloAugmentation(config['image_size']),
    color_transformer=ColorAugmentation(config['image_size']),
    use_mosaic=True,
    use_mixup=True,
    image_size=config["image_size"]
)

if __name__ == '__main__':
    train_pipeline = TrainPipeline(dataset=dataset, **config)
    train_pipeline.train()

测试神经网络

训练完使用 python evals.py 可以测试所有保存的模型,evals.py 代码如下:

# coding:utf-8
import json
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt

from net import EvalPipeline, VOCDataset

mpl.rc_file('resource/theme/matlab.mplstyle')


# 载入数据集
root = 'data/FaceMaskDataset/val'
dataset = VOCDataset(root, 'all')
anchors = [
    [[100, 146], [147, 203], [208, 260]],
    [[26, 43], [44, 65], [65, 105]],
    [[4, 8], [8, 15], [15, 27]]
]

# 列出所有模型,记得修改 Yolo 模型文件夹的路径
model_dir = Path('model/2022-10-05_22-59-44')
model_paths = [i for i in model_dir.glob('Yolo_*')]
model_paths.sort(key=lambda i: int(i.stem.split("_")[1]))

# 测试所有模型
mAPs = []
iterations = []
for model_path in model_paths:
    iterations.append(int(model_path.stem[5:]))
    ep = EvalPipeline(model_path, dataset, anchors=anchors, conf_thresh=0.001)
    mAPs.append(ep.eval()*100)

# 保存数据
with open('eval/mAPs.json', 'w', encoding='utf-8') as f:
    json.dump(mAPs, f)
    
# 绘制 mAP 曲线
fig, ax = plt.subplots(1, 1, num='mAP 曲线')
ax.plot(iterations, mAPs)
ax.set(xlabel='iteration', ylabel='mAP', title='mAP curve')
plt.show()

得到的 mAP 曲线如下图所示,在第 120 个 epoch 达到最大值 94.14%:

mAP 曲线

下面使用一张真实图像看看训练效果如何,运行 demo.py

# coding:utf-8
from net import VOCDataset
from utils.detection_utils import image_detect

# 模型文件和图片路径
model_path = 'model/Yolo_120.pth'
image_path = 'resource/image/三上老师.jpg'

# 检测目标
anchors = [
    [[100, 146], [147, 203], [208, 260]],
    [[26, 43], [44, 65], [65, 105]],
    [[4, 8], [8, 15], [15, 27]]
]
image = image_detect(model_path, image_path, VOCDataset.classes, anchors=anchors, conf_thresh=0.5)
image.show()

不错,效果非常好 ????:

三上老师

后记

至此,介绍完了训练 Yolov4 人脸口罩检测模型的过程,代码放在了 https://github.com/zhiyiYo/yolov4,以上~~

扫描二维码推送至手机访问。

版权声明:本文来源于网络,仅供学习,如侵权请联系站长删除。

本文链接:https://news.layui.org.cn/post/112.html

分享给朋友:

“如何使用 Yolov4 训练人脸口罩检测模型” 的相关文章

2022.9.30 Java第四次课后总结

1.public class BoxAndUnbox { /** * @param args */ public static void main(String[] args) { int value=100; Integer obj=value; //装箱 int result=obj*2; //拆箱 } } 创建了一个value 并定义了相关变量 public cl...

TTD 专题 (第一篇):C# 那些短命线程都在干什么?

一:背景 1.讲故事 在分析的众多dump中,经常会遇到各种奇葩的问题,仅通过dump这种快照形式还是有很多问题搞不定,而通过 perfview 这种粒度又太粗,很难找到问题之所在,真的很头疼,比如本篇的 短命线程 问题,参考图如下: 我们在 t2 时刻抓取的dump对查看 短命线程 毫无帮助,我根本就不知道这个线程生前执行了什么代码,为什么这么短命,还就因为这样的短命让 线程池 的线程暴增。...

IPv6报文头深度解析

IPv6报文由IPv6基本报文头、IPv6扩展报文头以及上层协议数据单元3部分组成。上层协议数据单元一般由上层协议报文头和它的有效载荷构成,上层协议数据单元可以是一个ICMPv6报文、一个TCP报文或一个UDP报文。 1、IPv6基本报文头 IPv6基本报文头有8个字段,固定大小为40 Byte,每一个IPv6数据报都必须包含基本报文头。基本报文头提供报文转发的基本信息,由转发路径上的所有设备解...

JS奇淫技巧:数值的七种写法

JS奇淫技巧:数值的七种写法 JS奇淫技巧:挑战前端黑科技,数值的七种写法,能全看懂的一定是高手 你知道吗?在JS编程中,数值可以有很多种写法。 第一种写法: 一般情况而言,数值就是数值。 比如: var a = 1; 你可知,这个1可以有很多种变形的写法,甚至是变态的写法。 第二种写法: var a= +!!{}; console.log(a); 即:1变成了+!!{}。 数值1为什么能...

用深度强化学习玩FlappyBird

摘要:学习玩游戏一直是当今AI研究的热门话题之一。使用博弈论/搜索算法来解决这些问题需要特别地进行周密的特性定义,使得其扩展性不强。使用深度学习算法训练的卷积神经网络模型(CNN)自提出以来在图像处理领域的多个大规模识别任务上取得了令人瞩目的成绩。本文是要开发一个一般的框架来学习特定游戏的特性并解决这个问题,其应用的项目是受欢迎的手机游戏Flappy Bird,控制游戏中的小鸟穿过一堆障碍物。本文...

前端三剑客快速入门(二)

前言 本文书接上回,继续css的知识,序号就重新开始了。上篇内容:前端三剑客快速入门(一) CSS 盒子模型 盒子模型属性: border外框 margin外边距 padding内边距 <!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8"> <me...

发表评论

访客

看不清,换一张

◎欢迎参与讨论,请在这里发表您的看法和观点。