PyTorch 实现深度自动编码器进行图像重建
机器学习是人工智能的一个分支,包括开发能够使计算机从输入数据中学习并进行决策或预测而不需要硬编码的统计模型和算法。它涉及使用大型数据集训练机器学习算法,使机器能够识别数据中的模式和关系。
什么是自动编码器
具有自动编码器的神经网络架构用于无监督学习任务。它由一组编码器和解码器网络组成,经过训练能够通过将输入数据压缩成低维表示(编码)然后解码以恢复其原始形式来重建输入数据。
为了鼓励网络学习到有价值的特征或数据表示,目标是最小化输入和输出之间的重建误差。自动编码器广泛用于数据压缩、图像去噪和异常检测等领域。这可以减少与数据传输相关的大量工作和成本。
在本文中,我们将探讨如何使用PyTorch的深度自动编码器进行图像重建。这个深度学习模型将在MNIST手写数字上进行训练,学习输入图像的表示后,将重建数字图像。一个基本的自动编码器包括两个主要的函数:
- 编码器
-
解码器
编码器将输入数据经过一系列层次将高维数据转换为低维潜在表示。解码器使用这个潜在表示来使用Python库torch、torch vision以及numpy和matplotlib等常用库生成重建的数据。
步骤
- 导入所有所需的库。
-
初始化将应用于获得的数据集中的每个条目的转换操作。
-
由于PyTorch需要张量来进行计算,因此我们首先将每个条目转换为张量并对其进行归一化,以保持像素值在0到1之间的范围。
-
使用torchvision.datasets程序下载数据集,并将其分别保存在文件夹./MNIST/train和./MNIST/test中,用于训练集和测试集。
-
为了加快学习速度,将这些数据集转换为批量大小为64的数据加载器。
-
随机从集合中打印出25张照片,以便更好地了解我们正在处理的信息。
步骤1:初始化
这一步涉及导入所有必要的库,如numpy、matplotlib、pytorch和torchvision。
语法
torchvision.transforms.ToTensor():
将输入图像(以PIL或numpy格式)转换为PyTorch张量格式。此转换还将像素强度从范围[0, 255]缩放到[0, 1]。
torchvision.transforms.Normalize(mean, std)
根据均值和标准差值对输入图像张量进行归一化。这个转换有助于在训练期间提高深度学习模型的收敛速率。均值和标准差值通常是从训练数据集中计算出来的。
torchvision.transforms.Compose(transforms)
允许将多个图像转换链接到单个对象中。此对象可以传递给PyTorch数据集对象,在训练或推理过程中动态应用转换。
示例
#importing modules
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision import datasets, transforms
plt.rcParams['figure.figsize'] = 15, 10
# Initialize the transform operation
transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5), (0.5))
])
# Download the inbuilt MNIST data
train_dataset = torchvision.datasets.MNIST(
root="./MNIST/train", train=True,
transform=torchvision.transforms.ToTensor(),
download=True)
test_dataset = torchvision.datasets.MNIST(
root="./MNIST/test", train=False,
transform=torchvision.transforms.ToTensor(),
download=True)
输出
步骤2:初始化自编码器
我们首先初始化自编码器类,它是torch.nn.Module的一个子类。现在我们可以专注于创建我们的模型架构,如下所示,因为这样可以为我们抽象出很多样板代码。
语法
torch.nn.Linear()
将输入张量应用线性变换的模块。
my_linear_layer = nn.Linear(in_features, out_features, bias=True)
torch.nn.ReLU()
应用修正线性单元(ReLU)函数到输入张量的一个激活函数。
torch.nn.Sigmoid()
一个将sigmoid函数应用于输入张量的激活函数。
示例
#Creating the autoencoder classes
class Autoencoder(torch.nn.Module):
def __init__(self):
super().__init__()
self.encoder=torch.nn.Sequential(
torch.nn.Linear(28*28,128), #N, 784 -> 128
torch.nn.ReLU(),
torch.nn.Linear(128,64),
torch.nn.ReLU(),
torch.nn.Linear(64,12),
torch.nn.ReLU(),
torch.nn.Linear(12,3), # --> N, 3
torch.nn.ReLU()
)
self.decoder=torch.nn.Sequential(
torch.nn.Linear(3,12), #N, 3 -> 12
torch.nn.ReLU(),
torch.nn.Linear(12,64),
torch.nn.ReLU(),
torch.nn.Linear(64,128),
torch.nn.ReLU(),
torch.nn.Linear(128,28*28), # --> N, 28*28
torch.nn.Sigmoid()
)
def forward(self,x):
encoded=self.encoder(x)
decoded = self.decoder(encoded)
return decoded
# Instantiating the model and hyperparameters
model = Autoencoder()
criterion = torch.nn.MSELoss()
num_epochs = 10
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
步骤3:创建训练循环
我们正在训练一个自编码器模型来学习图像的压缩表示。训练循环总共通过数据集进行了10次。
- 计算模型对每批照片的输出,迭代处理每批图像。
-
然后计算输出照片与原始图片之间的质量差异。
-
对于每批次平均计算损失,并将图像和它们的输出存储每个周期。
-
我们在循环结束时绘制训练损失,以帮助理解训练过程。
图形表明,损失随着每个周期的过去而降低,表明模型正在吸收新的信息,训练过程成功。
训练循环通过最小化输出图像与原始图像之间的损失,训练自编码器模型来学习图像的压缩表示。损失随着每个周期减小,表示训练成功。
示例
# Create empty list to store the training loss
train_loss = []
# Create empty dictionary to store the images and their reconstructed outputs
outputs = {}
# Loop through each epoch
for epoch in range(num_epochs):
# Initialize variable for storing the running loss
running_loss = 0
# Loop through each batch in the training data
for batch in train_loader:
# Load the images and their labels
img, _ = batch
# Flatten the images into a 1D tensor
img = img.view(img.size(0), -1)
# Generate the output for the autoencoder model
out = model(img)
# Calculate the loss between the input and output images
loss = criterion(out, img)
# Reset the gradients
optimizer.zero_grad()
# Compute the gradients
loss.backward()
# Update the weights
optimizer.step()
# Increment the running loss by the batch loss
running_loss += loss.item()
# Calculate the average running loss over the entire dataset
running_loss /= len(train_loader)
# Add the running loss to the list of training losses
train_loss.append(running_loss)
# Store the input and output images for the last batch
outputs[epoch+1] = {'input': img, 'output': out}
# Plot the training loss over epochs
plt.plot(range(1, num_epochs+1), train_loss)
plt.xlabel("Number of epochs")
plt.ylabel("Training Loss")
plt.show()
输出
步骤4:可视化
使用此代码绘制经过训练的自编码器模型的原始图像和重构图像。输出变量包括关于模型输出的数据,例如在不同训练时期中记录的重构图像和损失值。要绘制特定时期的重构图像,请使用list_epochs变量。
程序绘制给定时期最新批次中的前五个重构图像。
示例
# Plot the re-constructed images
# Initializing the counter
count = 1
# Plotting the reconstructed images
list_epochs = [1, 5, 10]
# Iterate over specified epochs
for val in list_epochs:
# Extract recorded information
temp = outputs[val]['out'].detach().numpy()
title_text = f"Epoch = {val}"
# Plot first 5 images of the last batch
for idx in range(5):
plt.subplot(7, 5, count)
plt.title(title_text)
plt.imshow(temp[idx].reshape(28,28), cmap= 'gray')
plt.axis('off')
# Increment the count
count+=1
# Plot of the original images
# Iterating over first five
# images of the last batch
for idx in range(5):
# Obtaining image from the dictionary
val = outputs[10]['img']
# Plotting image
plt.subplot(7,5,count)
plt.imshow(val[idx].reshape(28, 28),
cmap = 'gray')
plt.title("Original Image")
plt.axis('off')
# Increment the count
count+=1
plt.tight_layout()
plt.show()
输出
步骤 5:测试集性能评估
这段代码是一个示例,展示了如何评估训练好的自编码器模型在一个测试集上的性能。
根据重建图像的视觉检查结果,该代码得出结论,自编码器模型在测试集上表现良好。如果模型在测试集上表现良好,则很可能在新的、未见过的数据上也表现良好。
示例
outputs = {}
# Extract the last batch dataset
img, _ = list(test_loader)[-1]
img = img.reshape(-1, 28 * 28)
#Generating output
out = model(img)
# Storing results in the dictionary
outputs['img'] = img
outputs['out'] = out
# Initialize subplot count
count = 1
val = outputs['out'].detach().numpy()
# Plot first 10 images of the batch
for idx in range(10):
plt.subplot(2, 10, count)
plt.title("Reconstructed \n image")
plt.imshow(val[idx].reshape(28, 28), cmap='gray')
plt.axis('off')
# Increment subplot count
count += 1
# Plotting original images
# Plotting first 10 images
for idx in range(10):
val = outputs['img']
plt.subplot(2, 10, count)
plt.imshow(val[idx].reshape(28, 28), cmap='gray')
plt.title("Original Image")
plt.axis('off')
count += 1
plt.tight_layout()
plt.show()
输出
结论
总之,自动编码器是强大的神经网络,可以应用于许多不同的任务,包括数据压缩、异常检测和图像生成。TensorFlow、Keras和PyTorch是一些使自动编码器开发简单的Python工具。通过深入理解架构并调整设置,您可以开发出非常强大的自动编码器模型。随着机器学习作为一个领域的发展,自动编码器可能会继续成为各种应用的有用工具。