Python 如何在不减慢控制流和数组速度的情况下使用jit

Python 如何在不减慢控制流和数组速度的情况下使用jit

问题描述

import jax
from jax import jit
import jax.numpy as jnp
import numpy as np


array1 = np.random.normal(size=(1000,1000))
def first():
  for i in range(1000):
    for j in range(1000):
      if array1[i,j] >= 0:
        array1[i,j] = 1
      else:
        array1[i,j] = -1
# %timeit first()

from functools import partial
key = jax.random.PRNGKey(seed=0)
array2 = jax.random.normal(key, shape=(1000,1000))

@partial(jit, static_argnums=(0,1,2))
def f( i,j):
  r = jax.lax.cond(array2[i,j] >= 0, lambda x: 1, lambda x: -1, None)
  # if array2[i,j] >= 0:
  # # if i == j:
  #   array2.at[i,j].set(1)
  # else: array2.at[i,j].set(-1)
  array2.at[i,j].set(r)

# f_jit = jit(f, static_argnums=(0,1))
def second():
  for i in range(1000):
    for j in range(1000):
      # jax.lax.cond(array2[i,j]>=0, lambda x: True, lambda x: False, None)
      f(i,j)
%timeit second()

我有两个函数: firstsecond 。我希望 second 能像(或更快地)运行 first 一样快。 first 函数是使用 numpy 的函数。 second 函数使用 jax 。 在这种情况下,最好的方法是如何使用 jax 实现 first 函数? jax.lax.cond 会明显减慢运行速度,我认为。

我故意保留了注释以展示我尝试过的东西。

解决方案

first 运行相对快的原因是它执行了1,000,000个numpy数组操作,并且numpy已经针对快速每次操作分发进行了优化。

second 运行相对慢的原因是它执行了1,000,000个JAX数组操作,而JAX并没有针对快速每次操作分发进行优化。

关于这个问题的一些基本背景,请参阅 JAX FAQ:JAX比NumPy更快吗? 。

但是,如果你问如何以最快的方式完成你的工作,在NumPy和JAX中的答案都是避免编写循环。下面是等效的代码,将计算变为纯函数,以便比较(你的原始 second 函数实际上什么也没做,因为 array.at[i].set() 没有进行原地操作):

def first_fast(array):
  return np.where(array >= 0, 1, 0)

def second_fast(array):
  return jnp.where(array >= 0, 1, 0)

通常情况下,如果你发现自己在NumPy或JAX中编写循环来处理数组数值,那么可以预期结果代码会执行得很慢。在NumPy和JAX中,几乎总会有更好的方法使用内置的向量化操作来计算结果。

如果你对JAX和NumPy之间的进一步性能测试感兴趣,请务必阅读 FAQ: Benchmarking JAX Code 以确保你正在比较正确的事物。

Camera课程

Python教程

Java教程

Web教程

数据库教程

图形图像教程

办公软件教程

Linux教程

计算机教程

大数据教程

开发工具教程