Python 如何在不减慢控制流和数组速度的情况下使用jit
问题描述
import jax
from jax import jit
import jax.numpy as jnp
import numpy as np
array1 = np.random.normal(size=(1000,1000))
def first():
for i in range(1000):
for j in range(1000):
if array1[i,j] >= 0:
array1[i,j] = 1
else:
array1[i,j] = -1
# %timeit first()
from functools import partial
key = jax.random.PRNGKey(seed=0)
array2 = jax.random.normal(key, shape=(1000,1000))
@partial(jit, static_argnums=(0,1,2))
def f( i,j):
r = jax.lax.cond(array2[i,j] >= 0, lambda x: 1, lambda x: -1, None)
# if array2[i,j] >= 0:
# # if i == j:
# array2.at[i,j].set(1)
# else: array2.at[i,j].set(-1)
array2.at[i,j].set(r)
# f_jit = jit(f, static_argnums=(0,1))
def second():
for i in range(1000):
for j in range(1000):
# jax.lax.cond(array2[i,j]>=0, lambda x: True, lambda x: False, None)
f(i,j)
%timeit second()
我有两个函数: first
和 second
。我希望 second
能像(或更快地)运行 first
一样快。 first
函数是使用 numpy
的函数。 second
函数使用 jax
。 在这种情况下,最好的方法是如何使用 jax
实现 first
函数? jax.lax.cond
会明显减慢运行速度,我认为。
我故意保留了注释以展示我尝试过的东西。
解决方案
first
运行相对快的原因是它执行了1,000,000个numpy数组操作,并且numpy已经针对快速每次操作分发进行了优化。
second
运行相对慢的原因是它执行了1,000,000个JAX数组操作,而JAX并没有针对快速每次操作分发进行优化。
关于这个问题的一些基本背景,请参阅 JAX FAQ:JAX比NumPy更快吗? 。
但是,如果你问如何以最快的方式完成你的工作,在NumPy和JAX中的答案都是避免编写循环。下面是等效的代码,将计算变为纯函数,以便比较(你的原始 second
函数实际上什么也没做,因为 array.at[i].set()
没有进行原地操作):
def first_fast(array):
return np.where(array >= 0, 1, 0)
def second_fast(array):
return jnp.where(array >= 0, 1, 0)
通常情况下,如果你发现自己在NumPy或JAX中编写循环来处理数组数值,那么可以预期结果代码会执行得很慢。在NumPy和JAX中,几乎总会有更好的方法使用内置的向量化操作来计算结果。
如果你对JAX和NumPy之间的进一步性能测试感兴趣,请务必阅读 FAQ: Benchmarking JAX Code 以确保你正在比较正确的事物。