import jax
import jax.numpy as jnp
import jax.sharding as shd
import numpy as np

P = shd.PartitionSpec

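# Build a 2D device mesh; the reshape assumes exactly 8 * 16 = 128 devices.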
mesh = shd.Mesh(devices=np.asarray(jax.devices()).reshape(8, 16), axis_names=('x', 'y'))

d_model = 8192
batch = 8
ffw_mult = 4

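# Allocate the inputs on the host CPU first; device_put below then performs
# the host-to-device transfer under the desired shardings.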
cpu_device = jax.devices('cpu')[0]

cpu_x = jnp.zeros((batch, d_model), dtype=jnp.bfloat16, device=cpu_device)
cpu_w1 = jnp.zeros((ffw_mult * d_model, d_model), dtype=jnp.bfloat16, device=cpu_device)
cpu_w2 = jnp.zeros((d_model, ffw_mult * d_model), dtype=jnp.bfloat16, device=cpu_device)

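# The two matmuls of an FFN block (no nonlinearity): x @ w1.T gives
# (batch, ffw_mult * d_model), then @ w2.T projects back to (batch, d_model).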
def matmul(w1, w2, x):
  return jnp.einsum('wf,bf->bw', w2, jnp.einsum('fw,bw->bf', w1, x))

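# 1D tensor parallelism: shard the ffw dimension of w1 (dim 0) and w2 (dim 1)
# across mesh axis 'x', and replicate the activations; axis 'y' is unused
# here, so everything is replicated along it. The second einsum contracts
# over the sharded ffw axis, which XLA lowers to an all-reduce.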
def make_sharding():
  model_axes = ('x',)
  return (shd.NamedSharding(mesh, P(model_axes, None)),
          shd.NamedSharding(mesh, P(None, model_axes)),
          shd.NamedSharding(mesh, P(None,)))

w1_sharding, w2_sharding, x_sharding = make_sharding()

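# Clear any arrays left over from a previous run of this cell.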
try:
  del x, w1, w2
except NameError:
  pass

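# Move the host arrays onto the device mesh with the shardings chosen above.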
x = jax.device_put(cpu_x, x_sharding)
w1 = jax.device_put(cpu_w1, w1_sharding)
w2 = jax.device_put(cpu_w2, w2_sharding)

# compiler_options takes a plain dict of XLA flags; this one caps outstanding
# vmem prefetches at 0 so they cannot overlap with compute.
options = {'xla_vf_vmem_max_outstanding_prefetches': 0}
jit_matmul = (
    jax.jit(matmul, in_shardings=make_sharding())
    .lower(w1, w2, x)
    .compile(compiler_options=options)
)
# pw.xprof_trace is an internal XProf profiling wrapper, not part of the
# public JAX API (the closest public analogue is jax.profiler.trace); it
# yields a link to the collected trace.
with pw.xprof_trace(
    block_until_start=True, trace_python=False, trace_plaque=False,
    collect_pprof=False, devices=jax.devices()) as url:
  result = jit_matmul(w1, w2, x)
  jax.block_until_ready(result)

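# The trace URL, displayed when run as the last expression of a notebook cell.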
url