Untitled

 avatar
unknown
python
3 months ago
527 B
47
Indexable
import jax
import jax.numpy as jnp
from jax.experimental import shard_map

# Short alias for the PartitionSpec constructor used throughout the script.
# `import jax` must come before this line: `from jax.experimental import ...`
# does not bind the top-level name `jax`.
P = jax.sharding.PartitionSpec

# 1-D device mesh of 8 devices along a single axis named 'x'. The axis names
# must be a tuple -- a bare ('x') is just the string 'x'.
# Needs 8 available devices; on a CPU-only host, set
# XLA_FLAGS=--xla_force_host_platform_device_count=8 before importing jax.
mesh = jax.make_mesh((8,), ('x',))

def foo(x):
  """Average the 8 contiguous chunks of ``x`` element by element.

  The flattened ``x`` is viewed as 8 equal rows (so ``x.size`` must be
  divisible by 8) and the mean is taken down axis 0. The result therefore has
  ``k = x.size // 8`` entries, and entry j averages flat elements
  j, j + k, j + 2k, ... .

  NOTE(review): this averages *across* chunks. The shard_map function ``bar``
  in this script instead reduces *within* each shard (one mean per shard). If
  ``foo`` is meant to be the jit counterpart of ``bar``, it likely wants
  ``axis=1`` -- confirm the intent.
  """
  chunks = x.reshape(8, -1)
  return chunks.mean(axis=0)

# Per-shard mean via shard_map: each device reduces only its own block of the
# input. For a 1-D input, each block becomes a length-1 array
# (keepdims=True), and out_specs=P('x') concatenates those along the mesh axis
# 'x'. The global result therefore has one entry per shard.
bar = shard_map.shard_map(
    lambda x: x.mean(keepdims=True),
    mesh=mesh,
    in_specs=(P('x'),),
    out_specs=P('x'),
)

# One sharding object, reused for the jit output and the input placement:
# split dimension 0 across the mesh axis 'x'.
x_sharding = jax.NamedSharding(mesh, P('x'))

# jit counterpart built from foo, with its output constrained to x_sharding.
# NOTE(review): fn is never called in this script, so jax.jit never traces or
# compiles it here -- only bar runs below. Call fn(x) too if a side-by-side
# comparison was intended.
fn = jax.jit(foo, out_shardings=x_sharding)

# 8 * 1024 zeros laid out across the mesh, so each device on 'x' holds one
# contiguous block. The trailing comma makes the shape an explicit 1-tuple;
# (8 * 1024) alone is just the int 8192.
x = jnp.zeros((8 * 1024,), device=x_sharding)

# Run the shard_map version, then wait for JAX's asynchronous dispatch to
# finish so the computation has actually completed.
y = bar(x)
jax.block_until_ready(y)
Editor is loading...
Leave a Comment