Untitled
unknown
python
3 months ago
527 B
47
Indexable
from jax.experimental import shard_map P = jax.sharding.PartitionSpec import jax import jax.numpy as jnp mesh = jax.make_mesh((8,), ('x')) def foo(x): # return x.reshape(8, -1).mean(axis=0) bar = shard_map.shard_map(lambda x: x.mean(keepdims=True), mesh=mesh, in_specs=(jax.sharding.PartitionSpec('x',),), out_specs=jax.sharding.PartitionSpec('x',)) fn = jax.jit(foo, out_shardings=jax.NamedSharding(mesh, P('x'))) x = jnp.zeros((8 * 1024), device=jax.NamedSharding(mesh, P('x'))) y = bar(x) jax.block_until_ready(y)
Editor is loading...
Leave a Comment