# Paste-site metadata captured together with this snippet (kept verbatim):
#   Untitled
#   unknown
#   python
#   10 months ago
#   527 B
#   104
#   Indexable
import jax
import jax.numpy as jnp
from jax.experimental import shard_map

# Short alias for the partition-spec constructor used throughout this script.
# It must come after `import jax`: `from jax.experimental import shard_map`
# does not bind the name `jax`, so defining this alias first raises NameError.
P = jax.sharding.PartitionSpec

# 1-D device mesh of shape (8,) with a single axis named "x".
# axis_names must be a sequence of names: ('x',) is a 1-tuple, whereas ('x')
# is just the string 'x'.
# NOTE(review): assumes at least 8 visible devices. On a CPU-only host this
# presumably needs XLA_FLAGS=--xla_force_host_platform_device_count=8 set
# before jax is imported. Confirm for the target environment.
mesh = jax.make_mesh((8,), ('x',))
def foo(x):
    """Average ``x`` element-wise across 8 equal contiguous chunks.

    ``x`` is reshaped to ``(8, -1)``, so row ``i`` holds the ``i``-th
    contiguous chunk of ``k = x.size // 8`` elements. The mean is then taken
    over axis 0.

    Args:
        x: Array whose total size is divisible by 8. Otherwise the reshape
            raises.

    Returns:
        Array of shape ``(x.size // 8,)``. Element ``j`` is the mean of
        ``x.flat[i * k + j]`` over ``i = 0..7``.
    """
    return x.reshape(8, -1).mean(axis=0)
# Per-shard mean under shard_map. Each device reduces its local block of the
# input to a length-1 array (keepdims=True). With out_specs=P('x'), the
# per-device results are concatenated along mesh axis "x".
bar = shard_map.shard_map(
    lambda x: x.mean(keepdims=True),
    mesh=mesh,
    in_specs=(P('x'),),
    out_specs=P('x'),
)

# jit-compiled `foo` with its output sharded along mesh axis "x".
# NOTE(review): `fn` is never called in this script. It is presumably kept
# for comparison against `bar`; confirm before removing.
fn = jax.jit(foo, out_shardings=jax.NamedSharding(mesh, P('x')))

# 8 * 1024 zeros, created already sharded along mesh axis "x".
x = jnp.zeros((8 * 1024,), device=jax.NamedSharding(mesh, P('x')))
y = bar(x)
# Wait for the asynchronously dispatched computation to finish.
jax.block_until_ready(y)
# (paste-site footer, kept verbatim: "Leave a Comment")