import jax
import jax.numpy as jnp
import jax.sharding as shd
import numpy as np  # needed below to reshape the device list into a mesh

P = shd.PartitionSpec

# Assumes 128 accelerators arranged as an 8x16 mesh; adjust the reshape to
# match whatever jax.devices() actually reports.
mesh = shd.Mesh(devices=np.asarray(jax.devices()).reshape(8, 16),
                axis_names=('x', 'y'))

d_model = 8192
batch = 8
ffw_mult = 4

# Stage the inputs on the host first so device_put controls their placement.
cpu_device = jax.devices('cpu')[0]
cpu_x = jnp.zeros((batch, d_model), dtype=jnp.bfloat16, device=cpu_device)
cpu_w1 = jnp.zeros((ffw_mult * d_model, d_model), dtype=jnp.bfloat16, device=cpu_device)
cpu_w2 = jnp.zeros((d_model, ffw_mult * d_model), dtype=jnp.bfloat16, device=cpu_device)

def matmul(w1, w2, x):
    # Two-matmul feed-forward block: h = x @ w1.T, then out = h @ w2.T.
    return jnp.einsum('wf,bf->bw', w2, jnp.einsum('fw,bw->bf', w1, x))

def make_sharding():
    # Shard w1 and w2 along their ffw dimensions over mesh axis 'x';
    # leave x fully replicated.
    model_axes = ('x',)
    return (shd.NamedSharding(mesh, P(model_axes, None)),   # w1: (ffw, d_model)
            shd.NamedSharding(mesh, P(None, model_axes)),   # w2: (d_model, ffw)
            shd.NamedSharding(mesh, P(None)))               # x:  replicated

w1_sharding, w2_sharding, x_sharding = make_sharding()

# Drop any stale device arrays from a previous run of this (notebook) cell.
try:
    del x, w1, w2
except NameError:
    pass

x = jax.device_put(cpu_x, x_sharding)
w1 = jax.device_put(cpu_w1, w1_sharding)
w2 = jax.device_put(cpu_w2, w2_sharding)

# Compile with an XLA option that caps outstanding vmem prefetches at zero,
# so the matmuls show up cleanly in the trace.
options = {'xla_vf_vmem_max_outstanding_prefetches': 0}
jit_matmul = (jax.jit(matmul, in_shardings=make_sharding())
              .lower(w1, w2, x)
              .compile(compiler_options=options))

# pw.xprof_trace is a profiling context manager not defined in this snippet
# (it is not part of public JAX); jax.profiler.trace is the public analogue.
with pw.xprof_trace(
        block_until_start=True,
        trace_python=False,
        trace_plaque=False,
        collect_pprof=False,
        devices=jax.devices()) as url:
    result = jit_matmul(w1, w2, x)
    jax.block_until_ready(result)

url  # notebook cell output: link to the captured profile
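# Optional sanity check (not in the original snippet): confirm that device_put
# and the compiled matmul produced the intended layouts before reading the
# trace. jax.debug.visualize_array_sharding is public JAX API that prints
# which devices hold each shard of a 1D/2D array.
jax.debug.visualize_array_sharding(w1)      # expect 8 row-shards along mesh axis 'x'
jax.debug.visualize_array_sharding(result)  # output sharding inferred by the compiler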