TensorBoard 是TensorFlow的可视化工具包,提供机器学习实验所需的可视化功能和工具:
- 跟踪和可视化损失及准确率等指标
- 可视化模型图(操作和层)
- 查看权重、偏差或其他张量随时间变化的直方图
- 将嵌入投射到较低的维度空间
- 显示图片、文字和音频数据
- 剖析 TensorFlow 程序
安装 TensorBoard
pip install tensorboard
启动 TensorBoard
log_dir(日志目录)相关请看SummaryWriter API
从命令行进入'log_dir'所在目录,然后运行如下命令来启动 TensorBoard:
tensorboard --logdir=日志目录名
默认在6006端口启动,也可以通过以下命令指定 TensorBoard 的启动端口:
tensorboard --logdir=日志目录名 --port=6007
SummaryWriter API
SummaryWriter
API用于在给定日志目录中创建事件文件,并向其中添加摘要和事件,以供TensorBoard
使用。
创建 SummaryWriter
实例
from torch.utils.tensorboard import SummaryWriter
'''
writer = SummaryWriter(log_dir=None, comment="")
log_dir:事件文件保存的目录地址,默认是 runs/**CURRENT_DATETIME_HOSTNAME**。、
comment:注释日志目录后缀附加到默认的“log_dir”。如果指定了“log_dir”,则此参数不起作用。
'''
# 日志目录地址:runs/Apr30_23-04-41_DESKTOP-56I3UUD
writer = SummaryWriter()
# 日志目录地址:runs/Apr30_23-04-43_DESKTOP-56I3UUDtest-comment
writer = SummaryWriter(comment='test-comment')
# 日志目录地址:logs
writer = SummaryWriter('logs')
# 在命令行中通过 tensorboard --logdir=logs 启动 TensorBoard
add_scalar()
add_scalar
方法向摘要中添加标量数据,通常用来可视化网络训练中的各类标量参数,例如损失、学习率和准确率等。
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('logs')
'''
writer.add_scalar(tag, scalar_value, global_step=None)
tag(str): 标签,标识数据
scalar_value(float/str):要保存的标量值
global_step(int):要记录的全局步长值
'''
for i in range(100):
writer.add_scalar('y=2x', 2*i, i)
writer.close()
如果再创建一个新的事件文件,tag
也是y=2x
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('logs')
for i in range(100):
writer.add_scalar('y=2x', 3*i, i)
writer.close()
可以看到他们被可视化到同一个图里,怎么解决这个问题呢?
方法一:删除 log_dir 下的原事件文件并杀死程序重新启动。这会摧毁训练历史信息。
方法二:建一个顶层的日志目录,每个新的训练工作都在顶层日志目录下新建一个子目录。
add_image()的使用
add_image
方法将图片数据添加到摘要,常用来观察训练结果,可视化相应的像素矩阵,例如本地图片,或者是特征图等。
import numpy as np
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('logs') # 创建目录文件
'''
writer.add_image(tag, img_tensor, global_step=None, dataformats='CHW')
tag(str):数据标识
img_tensor(torch.Tensor/numpy.ndarray/str/blobbane):图片数据
global_step(int):要记录的全局步长值
dataformats:图片数据格式规范的表单CHW,HWC,HW,WH等(C即channels通道,H即特征图的高,W即特征图的宽)
'''
img = np.random.randn(1, 100, 100)
writer.add_image('test-img', img)
img_path = 'dataset/train/ants/5650366_e22b7e1065.jpg' # 图片路径
img_pil = Image.open(img_path) # 读取图片
img_arr = np.array(img_pil) # 将图片转化为numpy.ndarray类型
writer.add_image('test-img', img_arr, 1, dataformats='HWC') # 添加图片数据
writer.close()
注意:从PIL转到numpy,在add_image()
中要指定维度信息HWC
,dataformats
默认是CHW
。
还可以通过 opencv 读取图片来获得 numpy 型图片数据。