Numpy使用numpy提取每行最小值
在本文中,我们将介绍如何使用numpy库从多维数组的每行中提取最小值。
在numpy中,可以轻松地使用min函数提取整个数组中的最小值,如下所示:
import numpy as np
array = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
min_val = np.min(array)
print(min_val)
输出:
1
但是,如果我们想要提取每行的最小值,我们该怎么做呢?
阅读更多:Numpy 教程
提取每行最小值
在numpy中,可以使用argmin函数来获取数组中每行的最小值的索引。然后,我们可以使用这些索引来提取每行的最小值。如下所示:
import numpy as np
array = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
min_index = np.argmin(array, axis=1)
min_val = array[np.arange(len(array)), min_index]
print(min_index)
print(min_val)
输出:
[0 0 0]
[1 4 7]
在上面的例子中,我们先使用argmin函数获取每行的最小值的索引,然后使用它们来提取每行的最小值。仔细观察该代码片段,可以发现第二行中的axis参数是1,表示我们将沿着第二个轴(即每行)寻找最小值。
我们还可以使用where函数以更简洁的方式编写上面的代码。如下所示:
import numpy as np
array = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
min_index = np.argmin(array, axis=1)
min_val = np.take_along_axis(array, np.expand_dims(min_index, axis=1), axis=1).ravel()
print(min_index)
print(min_val)
输出:
[0 0 0]
[1 4 7]
使用where函数时,我们不需要使用arange和[]操作符来提取最小值。相反,我们使用了take_along_axis函数来提取最小值,并将其展平。
总结
本文介绍了如何使用numpy库从多维数组的每行中提取最小值。我们展示了如何使用argmin函数获取每行的最小值的索引,并使用它们来提取每行的最小值。我们还展示了如何使用where函数以更简洁的方式编写这些代码。希望这篇文章能够帮助你更好地理解如何使用numpy来处理多维数组。