Pytorch 迭代torch.utils.data.random_split中的子集
在本文中,我们将介绍如何使用Pytorch中的torch.utils.data.random_split
函数来划分数据集,并遍历划分出的子集。
阅读更多:Pytorch 教程
1. 数据集的划分和迭代
在机器学习和深度学习中,数据集的划分是非常常见的操作。我们经常需要将整个数据集划分为训练集、验证集和测试集等,以便用于模型的训练、调优和评估。
Pytorch提供了便捷的工具函数torch.utils.data.random_split
,用于将数据集划分为指定大小的子集。
该函数的定义如下:
torch.utils.data.random_split(dataset, lengths, generator=None)
其中参数含义如下:
– dataset
:要划分的数据集,可以是torch.utils.data.Dataset
的子类的实例。
– lengths
:一个包含划分后子集大小的列表。列表中的元素个数应与划分出的子集个数相同。
– generator
:可选参数,用于指定划分的随机数生成器。
下面我们用一个简单的例子来说明如何使用torch.utils.data.random_split
函数进行数据集划分和迭代。
import torch
from torch.utils.data import Dataset, random_split
# 定义一个自定义的数据集类
class MyDataset(Dataset):
def __init__(self, filepath):
# 加载数据集
self.data = torch.load(filepath)
def __len__(self):
# 返回数据集的样本数量
return len(self.data)
def __getitem__(self, index):
# 返回索引对应的样本
return self.data[index]
# 实例化自定义数据集类
dataset = MyDataset("data.pt")
# 定义划分后子集的大小
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size
# 使用torch.utils.data.random_split函数划分数据集
train_set, val_set, test_set = random_split(dataset, [train_size, val_size, test_size])
# 遍历训练集并输出样本数量
print("训练集样本数量:", len(train_set))
for sample in train_set:
print(sample)
# 遍历验证集并输出样本数量
print("验证集样本数量:", len(val_set))
for sample in val_set:
print(sample)
# 遍历测试集并输出样本数量
print("测试集样本数量:", len(test_set))
for sample in test_set:
print(sample)
上述代码中,我们定义了一个自定义数据集类MyDataset
,其中实现了__len__
方法返回数据集的样本数量,以及__getitem__
方法返回索引对应的样本。
然后我们实例化了MyDataset
类,并使用torch.utils.data.random_split
函数将数据集划分为训练集、验证集和测试集,并指定了划分后子集的大小。
最后,我们通过遍历每个子集并输出样本数量,验证了数据集的划分和迭代。
2. 划分不等长的子集
除了划分等长的子集外,torch.utils.data.random_split
函数还可以划分不等长的子集。这在某些特殊场景下非常有用,例如一些类别的样本较少,我们希望将它们放在子集的前面,以提高样本分布的平衡性。
下面我们通过一个示例来演示如何划分不等长的子集。
假设我们有一个多类别的数据集,每个类别的样本数量如下:[100, 150, 200, 250]。我们希望按照类别分布将数据集划分为训练集、验证集和测试集,并保证每个子集中各个类别的样本比例相同。
import torch
from torch.utils.data import Dataset, random_split
# 定义一个自定义的数据集类
class MyDataset(Dataset):
def __init__(self, filepath):
# 加载数据集
self.data = torch.load(filepath)
def __len__(self):
# 返回数据集的样本数量
return len(self.data)
def __getitem__(self, index):
# 返回索引对应的样本
return self.data[index]
# 模拟类别样本数量
class_counts = [100, 150, 200, 250]
# 计算划分后各个子集的大小
total_samples = sum(class_counts)
train_size = int(0.6 * total_samples)
val_size = int(0.2 * total_samples)
test_size = total_samples - train_size - val_size
# 划分数据集
dataset = MyDataset("data.pt")
subset_sizes = [train_size, val_size, test_size]
subsets = random_split(dataset, subset_sizes)
# 遍历每个子集并输出样本数量和类别分布
for i, subset in enumerate(subsets):
print("子集{}样本数量:{}".format(i+1, len(subset)))
class_counts_subset = [0] * len(class_counts)
for sample in subset:
class_counts_subset[sample["label"]] += 1
print("子集{}类别分布:{}".format(i+1, class_counts_subset))
上述代码中,我们模拟了一个多类别的数据集,其中每个类别的样本数量分别为[100, 150, 200, 250]。
我们首先计算出划分后各个子集的大小,并实例化自定义数据集类MyDataset
。
然后我们使用torch.utils.data.random_split
函数将数据集划分为训练集、验证集和测试集,并保证了每个子集中各个类别的样本比例相同。
最后,我们遍历每个子集,并输出样本数量和每个子集的类别分布,以验证划分结果的正确性。
总结
在本文中,我们介绍了如何使用Pytorch中的torch.utils.data.random_split
函数来划分数据集,并遍历划分出的子集。我们学习了如何划分等长的子集,以及如何划分不等长的子集。这些操作在数据集划分和迭代中非常常见,并且通过示例代码演示了如何使用这些函数来实现这些操作。希望本文对您在使用Pytorch进行数据集划分和迭代时有所帮助。