Pytorch Pytorch Lightning 在每个 epoch 结束时打印准确率和损失值

Pytorch Pytorch Lightning 在每个 epoch 结束时打印准确率和损失值

在本文中,我们将介绍如何在使用 Pytorch Lightning 进行深度学习训练时,在每个 epoch 结束时打印准确率和损失值。Pytorch Lightning 是一个轻量级的 Pytorch 扩展包,能够简化机器学习和深度学习模型的开发和训练过程。

阅读更多:Pytorch 教程

Pytorch Lightning 简介

Pytorch Lightning 是一个基于 Pytorch 的轻量级开发框架,它提供了一些额外的功能和抽象,使得开发和训练深度学习模型更加容易和高效。Pytorch Lightning 的目标是解决在深度学习模型开发过程中的冗余代码和重复工作。

在每个 epoch 结束时打印准确率和损失值

要在每个 epoch 结束时打印准确率和损失值,我们可以使用 Pytorch Lightning 中的回调函数(Callback)来实现。回调函数是一种在训练过程中某些特定阶段执行的功能代码块。在 Pytorch Lightning 中,我们可以自定义回调函数来满足我们的需求。

首先,我们需要创建一个自定义回调函数。在该回调函数中,我们可以使用 Pytorch Lightning 提供的几个内置回调函数来获取当前的准确率和损失值,并在每个 epoch 结束时打印出来。下面是一个简单的例子:

import pytorch_lightning as pl

class PrintAccuracyAndLossCallback(pl.Callback):
    def on_epoch_end(self, trainer, pl_module):
        # 获取当前 epoch 的训练损失值
        train_loss = trainer.callback_metrics['train_loss']

        # 获取当前 epoch 的准确率
        val_accuracy = trainer.callback_metrics['val_accuracy']

        # 打印当前 epoch 的准确率和损失值
        print(f"Epoch {trainer.current_epoch}: Train Loss {train_loss:.4f}, Val Accuracy {val_accuracy:.4f}")

然后,我们需要将该回调函数添加到训练过程中。在使用 Pytorch Lightning 训练模型的代码中,我们可以添加以下行来将回调函数添加到训练过程中:

print_accuracy_and_loss_callback = PrintAccuracyAndLossCallback()
trainer = pl.Trainer(callbacks=[print_accuracy_and_loss_callback])

以上代码将在每个 epoch 结束时调用自定义回调函数,并打印当前的准确率和损失值。

示例说明

我们来看一个具体的示例,使用 Pytorch Lightning 训练一个简单的卷积神经网络(CNN)模型来识别 mnist 数据集中的手写数字。

首先,我们定义一个包含卷积层、全连接层和输出层的简单 CNN 模型:

import torch.nn as nn
import torch.nn.functional as F

class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3)
        self.fc1 = nn.Linear(16*26*26, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

然后,我们使用 Pytorch Lightning 进行训练和验证:

import pytorch_lightning as pl
import torchvision
from torch.utils.data import DataLoader

# 创建训练数据集和验证数据集
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=torchvision.transforms.ToTensor())
val_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=torchvision.transforms.ToTensor())

# 创建训练数据加载器和验证数据加载器
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32)

class LightningModel(pl.LightningModule):
    def __init__(self):
        super(LightningModel, self).__init__()
        self.model = CNNModel()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = F.cross_entropy(y_pred, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        val_loss = F.cross_entropy(y_pred, y)
        self.log('val_loss', val_loss)
        return val_loss

    def validation_epoch_end(self, outputs):
        avg_val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        self.log('val_loss_epoch', avg_val_loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

# 创建 Pytorch Lightning 的训练器
trainer = pl.Trainer(callbacks=[PrintAccuracyAndLossCallback()])

# 创建模型实例并开始训练
model = LightningModel()
trainer.fit(model, train_dataloader, val_dataloader)

运行以上代码后,每个 epoch 结束时将会打印当前的训练损失值和验证准确率。

总结

本文介绍了如何使用 Pytorch Lightning 在每个 epoch 结束时打印准确率和损失值。通过自定义回调函数并利用 Pytorch Lightning 的内置函数和方法,我们可以在训练过程中方便地获取并打印这些指标。这样可以帮助我们更好地了解模型的训练情况,并进行模型调优和改进。

Camera课程

Python教程

Java教程

Web教程

数据库教程

图形图像教程

办公软件教程

Linux教程

计算机教程

大数据教程

开发工具教程