torch.distributions.categorical详解
在PyTorch中,torch.distributions
模块提供了许多常见的概率分布,用于生成各种随机变量。其中,torch.distributions.categorical
是用来表示多元分布的类,可以用来生成服从多元分布的随机样本。
创建Categorical分布
要创建一个Categorical分布对象,可以使用torch.distributions.Categorical
类。下面是一个简单的示例代码:
import torch
from torch.distributions import Categorical
probs = torch.tensor([0.1, 0.3, 0.6]) # 定义多元分布的概率向量
dist = Categorical(probs) # 创建Categorical分布对象
在上面的代码中,我们首先定义了一个概率向量probs
,表示多元分布的各个类别的概率。然后使用Categorical
类来创建一个Categorical分布对象dist
。
生成样本
一旦我们创建了Categorical分布对象,就可以使用sample()
方法来生成符合该分布的随机样本。下面是一个示例代码:
sample = dist.sample()
print(sample)
上面的代码中,我们调用了sample()
方法来生成一个随机样本,并将结果打印出来。
计算概率密度
除了生成随机样本外,我们还可以使用log_prob()
方法来计算给定样本的对数概率密度值。下面是一个示例代码:
print(dist.log_prob(sample))
批处理操作
Categorical分布还支持批处理操作,可以一次生成多个样本。我们可以通过sample()
方法的sample_shape
参数来指定生成样本的个数。下面是一个示例代码:
samples = dist.sample(sample_shape=(3, 2))
print(samples)
上面的代码中,我们生成了一个3×2的样本矩阵。每一行表示一个样本。
总结
通过本文的介绍,我们了解了如何使用torch.distributions.categorical
来表示多元分布,并且生成符合该分布的随机样本。除了生成随机样本外,我们还可以计算给定样本的概率密度值。