Pytorch Pytorch Lightning 在每个 epoch 结束时打印准确率和损失值
在本文中,我们将介绍如何在使用 Pytorch Lightning 进行深度学习训练时,在每个 epoch 结束时打印准确率和损失值。Pytorch Lightning 是一个轻量级的 Pytorch 扩展包,能够简化机器学习和深度学习模型的开发和训练过程。
阅读更多:Pytorch 教程
Pytorch Lightning 简介
Pytorch Lightning 是一个基于 Pytorch 的轻量级开发框架,它提供了一些额外的功能和抽象,使得开发和训练深度学习模型更加容易和高效。Pytorch Lightning 的目标是解决在深度学习模型开发过程中的冗余代码和重复工作。
在每个 epoch 结束时打印准确率和损失值
要在每个 epoch 结束时打印准确率和损失值,我们可以使用 Pytorch Lightning 中的回调函数(Callback)来实现。回调函数是一种在训练过程中某些特定阶段执行的功能代码块。在 Pytorch Lightning 中,我们可以自定义回调函数来满足我们的需求。
首先,我们需要创建一个自定义回调函数。在该回调函数中,我们可以使用 Pytorch Lightning 提供的几个内置回调函数来获取当前的准确率和损失值,并在每个 epoch 结束时打印出来。下面是一个简单的例子:
import pytorch_lightning as pl
class PrintAccuracyAndLossCallback(pl.Callback):
def on_epoch_end(self, trainer, pl_module):
# 获取当前 epoch 的训练损失值
train_loss = trainer.callback_metrics['train_loss']
# 获取当前 epoch 的准确率
val_accuracy = trainer.callback_metrics['val_accuracy']
# 打印当前 epoch 的准确率和损失值
print(f"Epoch {trainer.current_epoch}: Train Loss {train_loss:.4f}, Val Accuracy {val_accuracy:.4f}")
然后,我们需要将该回调函数添加到训练过程中。在使用 Pytorch Lightning 训练模型的代码中,我们可以添加以下行来将回调函数添加到训练过程中:
print_accuracy_and_loss_callback = PrintAccuracyAndLossCallback()
trainer = pl.Trainer(callbacks=[print_accuracy_and_loss_callback])
以上代码将在每个 epoch 结束时调用自定义回调函数,并打印当前的准确率和损失值。
示例说明
我们来看一个具体的示例,使用 Pytorch Lightning 训练一个简单的卷积神经网络(CNN)模型来识别 mnist 数据集中的手写数字。
首先,我们定义一个包含卷积层、全连接层和输出层的简单 CNN 模型:
import torch.nn as nn
import torch.nn.functional as F
class CNNModel(nn.Module):
def __init__(self):
super(CNNModel, self).__init__()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3)
self.fc1 = nn.Linear(16*26*26, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2)
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
然后,我们使用 Pytorch Lightning 进行训练和验证:
import pytorch_lightning as pl
import torchvision
from torch.utils.data import DataLoader
# 创建训练数据集和验证数据集
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=torchvision.transforms.ToTensor())
val_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=torchvision.transforms.ToTensor())
# 创建训练数据加载器和验证数据加载器
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32)
class LightningModel(pl.LightningModule):
def __init__(self):
super(LightningModel, self).__init__()
self.model = CNNModel()
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
x, y = batch
y_pred = self(x)
loss = F.cross_entropy(y_pred, y)
self.log('train_loss', loss)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_pred = self(x)
val_loss = F.cross_entropy(y_pred, y)
self.log('val_loss', val_loss)
return val_loss
def validation_epoch_end(self, outputs):
avg_val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
self.log('val_loss_epoch', avg_val_loss)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
return optimizer
# 创建 Pytorch Lightning 的训练器
trainer = pl.Trainer(callbacks=[PrintAccuracyAndLossCallback()])
# 创建模型实例并开始训练
model = LightningModel()
trainer.fit(model, train_dataloader, val_dataloader)
运行以上代码后,每个 epoch 结束时将会打印当前的训练损失值和验证准确率。
总结
本文介绍了如何使用 Pytorch Lightning 在每个 epoch 结束时打印准确率和损失值。通过自定义回调函数并利用 Pytorch Lightning 的内置函数和方法,我们可以在训练过程中方便地获取并打印这些指标。这样可以帮助我们更好地了解模型的训练情况,并进行模型调优和改进。