torch.distributions.categorical详解

torch.distributions.categorical详解

torch.distributions.categorical详解

在PyTorch中,torch.distributions模块提供了许多常见的概率分布,用于生成各种随机变量。其中,torch.distributions.categorical是用来表示多元分布的类,可以用来生成服从多元分布的随机样本。

创建Categorical分布

要创建一个Categorical分布对象,可以使用torch.distributions.Categorical类。下面是一个简单的示例代码:

import torch
from torch.distributions import Categorical

probs = torch.tensor([0.1, 0.3, 0.6])  # 定义多元分布的概率向量
dist = Categorical(probs)  # 创建Categorical分布对象

在上面的代码中,我们首先定义了一个概率向量probs,表示多元分布的各个类别的概率。然后使用Categorical类来创建一个Categorical分布对象dist

生成样本

一旦我们创建了Categorical分布对象,就可以使用sample()方法来生成符合该分布的随机样本。下面是一个示例代码:

sample = dist.sample()
print(sample)

上面的代码中,我们调用了sample()方法来生成一个随机样本,并将结果打印出来。

计算概率密度

除了生成随机样本外,我们还可以使用log_prob()方法来计算给定样本的对数概率密度值。下面是一个示例代码:

print(dist.log_prob(sample))

批处理操作

Categorical分布还支持批处理操作,可以一次生成多个样本。我们可以通过sample()方法的sample_shape参数来指定生成样本的个数。下面是一个示例代码:

samples = dist.sample(sample_shape=(3, 2))
print(samples)

上面的代码中,我们生成了一个3×2的样本矩阵。每一行表示一个样本。

总结

通过本文的介绍,我们了解了如何使用torch.distributions.categorical来表示多元分布,并且生成符合该分布的随机样本。除了生成随机样本外,我们还可以计算给定样本的概率密度值。

Camera课程

Python教程

Java教程

Web教程

数据库教程

图形图像教程

办公软件教程

Linux教程

计算机教程

大数据教程

开发工具教程