Pytorch 迭代torch.utils.data.random_split中的子集

Pytorch 迭代torch.utils.data.random_split中的子集

在本文中,我们将介绍如何使用Pytorch中的torch.utils.data.random_split函数来划分数据集,并遍历划分出的子集。

阅读更多:Pytorch 教程

1. 数据集的划分和迭代

在机器学习和深度学习中,数据集的划分是非常常见的操作。我们经常需要将整个数据集划分为训练集、验证集和测试集等,以便用于模型的训练、调优和评估。

Pytorch提供了便捷的工具函数torch.utils.data.random_split,用于将数据集划分为指定大小的子集。

该函数的定义如下:

torch.utils.data.random_split(dataset, lengths, generator=None)

其中参数含义如下:
dataset:要划分的数据集,可以是torch.utils.data.Dataset的子类的实例。
lengths:一个包含划分后子集大小的列表。列表中的元素个数应与划分出的子集个数相同。
generator:可选参数,用于指定划分的随机数生成器。

下面我们用一个简单的例子来说明如何使用torch.utils.data.random_split函数进行数据集划分和迭代。

import torch
from torch.utils.data import Dataset, random_split

# 定义一个自定义的数据集类
class MyDataset(Dataset):
    def __init__(self, filepath):
        # 加载数据集
        self.data = torch.load(filepath)

    def __len__(self):
        # 返回数据集的样本数量
        return len(self.data)

    def __getitem__(self, index):
        # 返回索引对应的样本
        return self.data[index]

# 实例化自定义数据集类
dataset = MyDataset("data.pt")

# 定义划分后子集的大小
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

# 使用torch.utils.data.random_split函数划分数据集
train_set, val_set, test_set = random_split(dataset, [train_size, val_size, test_size])

# 遍历训练集并输出样本数量
print("训练集样本数量:", len(train_set))
for sample in train_set:
    print(sample)

# 遍历验证集并输出样本数量
print("验证集样本数量:", len(val_set))
for sample in val_set:
    print(sample)

# 遍历测试集并输出样本数量
print("测试集样本数量:", len(test_set))
for sample in test_set:
    print(sample)

上述代码中,我们定义了一个自定义数据集类MyDataset,其中实现了__len__方法返回数据集的样本数量,以及__getitem__方法返回索引对应的样本。

然后我们实例化了MyDataset类,并使用torch.utils.data.random_split函数将数据集划分为训练集、验证集和测试集,并指定了划分后子集的大小。

最后,我们通过遍历每个子集并输出样本数量,验证了数据集的划分和迭代。

2. 划分不等长的子集

除了划分等长的子集外,torch.utils.data.random_split函数还可以划分不等长的子集。这在某些特殊场景下非常有用,例如一些类别的样本较少,我们希望将它们放在子集的前面,以提高样本分布的平衡性。

下面我们通过一个示例来演示如何划分不等长的子集。

假设我们有一个多类别的数据集,每个类别的样本数量如下:[100, 150, 200, 250]。我们希望按照类别分布将数据集划分为训练集、验证集和测试集,并保证每个子集中各个类别的样本比例相同。

import torch
from torch.utils.data import Dataset, random_split

# 定义一个自定义的数据集类
class MyDataset(Dataset):
    def __init__(self, filepath):
        # 加载数据集
        self.data = torch.load(filepath)

    def __len__(self):
        # 返回数据集的样本数量
        return len(self.data)

    def __getitem__(self, index):
        # 返回索引对应的样本
        return self.data[index]

# 模拟类别样本数量
class_counts = [100, 150, 200, 250]

# 计算划分后各个子集的大小
total_samples = sum(class_counts)
train_size = int(0.6 * total_samples)
val_size = int(0.2 * total_samples)
test_size = total_samples - train_size - val_size

# 划分数据集
dataset = MyDataset("data.pt")
subset_sizes = [train_size, val_size, test_size]
subsets = random_split(dataset, subset_sizes)

# 遍历每个子集并输出样本数量和类别分布
for i, subset in enumerate(subsets):
    print("子集{}样本数量:{}".format(i+1, len(subset)))
    class_counts_subset = [0] * len(class_counts)
    for sample in subset:
        class_counts_subset[sample["label"]] += 1
    print("子集{}类别分布:{}".format(i+1, class_counts_subset))

上述代码中,我们模拟了一个多类别的数据集,其中每个类别的样本数量分别为[100, 150, 200, 250]。

我们首先计算出划分后各个子集的大小,并实例化自定义数据集类MyDataset

然后我们使用torch.utils.data.random_split函数将数据集划分为训练集、验证集和测试集,并保证了每个子集中各个类别的样本比例相同。

最后,我们遍历每个子集,并输出样本数量和每个子集的类别分布,以验证划分结果的正确性。

总结

在本文中,我们介绍了如何使用Pytorch中的torch.utils.data.random_split函数来划分数据集,并遍历划分出的子集。我们学习了如何划分等长的子集,以及如何划分不等长的子集。这些操作在数据集划分和迭代中非常常见,并且通过示例代码演示了如何使用这些函数来实现这些操作。希望本文对您在使用Pytorch进行数据集划分和迭代时有所帮助。

Camera课程

Python教程

Java教程

Web教程

数据库教程

图形图像教程

办公软件教程

Linux教程

计算机教程

大数据教程

开发工具教程