Using NonZero in JAX
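Technical background
JAX's just-in-time (JIT) compilation traces the shape of every tensor in a function. If the computation produces a tensor with a dynamic shape, meaning a shape that depends on the values of the input data, the function cannot be compiled and optimized by JAX's JIT. The most common example is numpy.where: it returns the indices that satisfy a condition, and different inputs generally yield index arrays of different lengths, so JAX's JIT cannot compile it. Note that numpy.where has two usages. One is numpy.where(condition, 1, 0), which applies a mask to the input and keeps its shape. The other is numpy.where(condition), which returns a sequence of indices; this is the case discussed here.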
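The difference between the two usages can be seen in a minimal NumPy sketch. The array x and the threshold 0.5 are made-up values for illustration only:

```python
import numpy as np

x = np.array([0.1, 0.7, 0.3, 0.9])
condition = x > 0.5

# Three-argument form: the output has the same shape as the input (a mask).
mask = np.where(condition, 1, 0)

# One-argument form: a tuple of index arrays whose length depends on the data.
(indices,) = np.where(condition)

print(mask.shape)     # always (4,), whatever the values in x
print(indices.shape)  # (2,) here, but it changes with the values in x
```

Only the second form has a data-dependent output shape, which is what conflicts with JIT tracing.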
Normal mode
Consider a simple toy model for testing:
\[\mathrm{func}(r, q) = \sum_{(i,j):\ |r_i - r_j| \le \mathrm{cutoff}} q_i q_j\]
Without JIT compilation, the JAX code can be written as follows:
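    import os
    os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
    import numpy as np
    np.random.seed(0)
    import jax
    from jax import numpy as jnp

    def func(r, q, cutoff=0.2):
        # Pairwise distances |r_i - r_j|, shape (N, N).
        dis = jnp.abs(r[:, None] - r)
        # Indices of all pairs within the cutoff; their length depends on the data.
        maski, maskj = jnp.where(dis <= cutoff)
        qi = q[maski]
        qj = q[maskj]
        return jnp.sum(qi * qj)

    N = 100
    r = jnp.array(np.random.random(N), jnp.float32)
    q = jnp.array(np.random.random(N), jnp.float32)
    print(func(r, q))
    # 1035.7422

Keep this output in mind: the random seed is the same in every run, so it can be compared directly with the JIT output later.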
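To see why the index arrays returned by jnp.where in this toy model have a data-dependent length, the following sketch counts the selected pairs for two small inputs. The helper count_pairs and the arrays r_a and r_b are hypothetical names and values chosen for illustration:

```python
from jax import numpy as jnp


def count_pairs(r, cutoff=0.2):
    # Same distance matrix and index query as in the toy model.
    dis = jnp.abs(r[:, None] - r)
    maski, maskj = jnp.where(dis <= cutoff)
    return maski.shape[0]


r_a = jnp.array([0.0, 0.1, 0.5], jnp.float32)
r_b = jnp.array([0.0, 0.1, 0.15], jnp.float32)

print(count_pairs(r_a))  # 5: the three diagonal pairs plus (0, 1) and (1, 0)
print(count_pairs(r_b))  # 9: every pair is within the cutoff
```

Both inputs have the same shape, yet the number of indices differs, so the output shape cannot be known from the input shapes alone.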
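JIT mode
There are three common ways to use JAX's JIT: adding a decorator above the function definition, wrapping the function with jax.jit(function) at the point of use, and combining jax.jit with functools.partial. None of them is complicated. The decorator form is used here to demonstrate JIT compilation: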
    import os
    os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
    import numpy as np
    np.random.seed(0)
    import jax
    from jax import numpy as jnp

    @jax.jit
    def func(r, q, cutoff=0.2):
        dis = jnp.abs(r[:, None] - r)
        # Fails under jit: the number of returned indices depends on the data.
        maski, maskj = jnp.where(dis <= cutoff)
        qi = q[maski]
        qj = q[maskj]
        return jnp.sum(qi * qj)

    N = 100
    r = jnp.array(np.random.random(N), jnp.float32)
    q = jnp.array(np.random.random(N), jnp.float32)
    print(func(r, q))

As noted above, the output of numpy.where has a dynamic shape, so the call fails while the function is being traced for compilation. The error message is as follows:
    Traceback (most recent call last):
      File "/home/dechin/projects/gitee/dechin/tests/jax_mask.py", line 21, in <module>
        print (func(r, q))
      File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/api.py", line 622, in cache_miss
        execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
      File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/dispatch.py", line 236, in _xla_call_impl_lazy
        return xla_callable(fun, device, backend, name, donated_invars, keep_unused,
      File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/linear_util.py", line 303, in memoized_fun
        ans = call(fun, *args)
      File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/dispatch.py", line 359, in _xla_callable_uncached
        return lower_xla_callable(fun, device, backend, name, donated_invars, False,
      File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
        return func(*args, **kwargs)
      File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/dispatch.py", line 445, in lower_xla_callable
        jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
      File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
        return func(*args, **kwargs)
      File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2077, in trace_to_jaxpr_final2
        jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
      File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2027, in trace_to_subjaxpr_dynamic2
        ans = fun.call_wrapped(*in_tracers_)
      File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/home/dechin/projects/gitee/dechin/tests/jax_mask.py", line 12, in func
        maski, maskj = jnp.where(dis<=cutoff)
      File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1077, in where
        return nonzero(condition, size=size, fill_value=fill_value)
      File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1332, in nonzero
        size = core.concrete_or_error(operator.index, size,
      File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/core.py", line 1278, in concrete_or_error
        raise ConcretizationTypeError(val, context)
    jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>
    The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
    The error occurred while tracing the function func at /home/dechin/projects/gitee/dechin/tests/jax_mask.py:9 for jit. This concrete value was not available in Python because it depends on the value of the argument 'r'.

    See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

    The stack trace below excludes JAX-internal frames.
    The preceding is the original exception that occurred, unmodified.

    --------------------

    The above exception was the direct cause of the following exception:

    Traceback (most recent call last):
      File "/home/dechin/projects/gitee/dechin/tests/jax_mask.py", line 21, in <module>
        print (func(r, q))
      File "/home/dechin/projects/gitee/dechin/tests/jax_mask.py", line 12, in func
        maski, maskj = jnp.where(dis<=cutoff)
      File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1077, in where
        return nonzero(condition, size=size, fill_value=fill_value)
      File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1332, in nonzero
        size = core.concrete_or_error(operator.index, size,
    jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>
    The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
    The error occurred while tracing the function func at /home/dechin/projects/gitee/dechin/tests/jax_mask.py:9 for jit. This concrete value was not available in Python because it depends on the value of the argument 'r'.

    See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

There are three ways to avoid this error: leave the function uncompiled (sacrificing performance), write a custom CUDA operator (more work), or use the approach taken here, NonZero with a fixed-length output (which requires the length to be set in advance).
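For reference, the three ways of applying jax.jit mentioned in this section can be sketched as follows. The functions square_sum, scaled_sum and power_sum are hypothetical examples, not part of the toy model:

```python
from functools import partial

import jax
from jax import numpy as jnp


# 1) Decorator form.
@jax.jit
def square_sum(x):
    return jnp.sum(x * x)


# 2) Wrapping an existing function where it is used.
def scaled_sum(x):
    return jnp.sum(2.0 * x)

scaled_sum_jit = jax.jit(scaled_sum)


# 3) partial form, useful when jax.jit itself needs keyword arguments.
@partial(jax.jit, static_argnames='power')
def power_sum(x, power=2):
    return jnp.sum(x ** power)


x = jnp.arange(4, dtype=jnp.float32)   # [0, 1, 2, 3]
print(square_sum(x))          # 0 + 1 + 4 + 9 = 14
print(scaled_sum_jit(x))      # 2 * (0 + 1 + 2 + 3) = 12
print(power_sum(x, power=3))  # 0 + 1 + 8 + 27 = 36
```

The partial form matters for the NonZero approach because it is a convenient place to pass options such as static_argnames to jax.jit.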
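Using NonZero
One point needs attention when using JAX's NonZero function: although NonZero can produce a fixed-length output, that length is itself an argument named size. In other words, the output shape of NonZero still depends on an input argument. JAX anticipates this and lets arguments be marked as static at compile time through static_argnames. In this example, marking the argument named size as static makes the function compatible with JAX's JIT: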
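    import os
    os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
    import numpy as np
    np.random.seed(0)
    import jax
    from jax import numpy as jnp
    from functools import partial

    @partial(jax.jit, static_argnames='size')
    def func(r, q, cutoff=0.2, size=5000):
        # q carries one extra trailing element that absorbs the padded indices.
        if q.shape[0] != r.shape[0] + 1:
            raise ValueError("The q.shape should be equal to r.shape+1")
        dis = jnp.abs(r[:, None] - r)
        # Fixed-length index output; unused slots are filled with -1.
        maski, maskj = jnp.nonzero(jnp.where(dis <= cutoff, 1, 0), size=size, fill_value=-1)
        qi = q[maski]
        qj = q[maskj]
        return jnp.sum(qi * qj)

    N = 100
    r = jnp.array(np.random.random(N), jnp.float32)
    q = jnp.array(np.random.random(N), jnp.float32)
    padder = jnp.array([0.], jnp.float32)
    q = jnp.append(q, padder)
    print(func(r, q))
    # 1035.7422

The function is now compiled successfully by JAX's JIT, and the output matches the uncompiled result. A small trick is used here: when NonZero finds fewer entries than the requested length, the output is automatically padded to that length with the given fill_value. Because NonZero returns indices, the padded indices can be set to -1, which selects the last element of an array. Appending a 0 to the end of the parameter \(q\) beforehand therefore makes every padded entry contribute nothing, so the result is correct without further masking.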
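The padding behaviour can be checked on a small example. This is a minimal sketch; the helper masked_sum and the arrays q_padded and hits are hypothetical names used only for illustration. It also shows a caveat worth keeping in mind: when size is smaller than the true number of hits, jnp.nonzero keeps only the first size indices and the sum is silently too small.

```python
from functools import partial

import jax
from jax import numpy as jnp


@partial(jax.jit, static_argnames='size')
def masked_sum(q_padded, hits, size):
    # size is static, so idx always has shape (size,); empty slots hold -1.
    (idx,) = jnp.nonzero(hits, size=size, fill_value=-1)
    # Index -1 selects the trailing zero of q_padded and adds nothing.
    return idx, jnp.sum(q_padded[idx])


q = jnp.array([1.0, 2.0, 3.0], jnp.float32)
q_padded = jnp.append(q, jnp.array([0.0], jnp.float32))
hits = jnp.array([True, False, True])

# size larger than the number of hits: indices 0, 2 plus two -1 pads, sum 1 + 3.
idx, total = masked_sum(q_padded, hits, size=4)
# size smaller than the number of hits: only index 0 survives, sum 1.
idx_small, total_small = masked_sum(q_padded, hits, size=1)

print(idx, total)
print(idx_small, total_small)
```

Since size is a static argument, each new value of size triggers a fresh compilation, so it is usually chosen once as a safe upper bound on the number of hits.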
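Summary
When using JAX, a function sometimes produces an output with a dynamic shape. In that case it is hard to take advantage of JAX's JIT compilation, and performance cannot be maximized. This is a characteristic of computing with tensor data structures, with both advantages and drawbacks. This article introduced another JAX function, NonZero, which makes it possible to compile functions whose output would otherwise have a dynamic shape.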
Copyright notice
This article was first published at: https://www.cnblogs.com/dechinphy/p/nonzero.html
Author ID: DechinPhy
More original articles: https://www.cnblogs.com/dechinphy/
Buy the author a coffee: https://www.cnblogs.com/dechinphy/gallery/image/379634.html