# Untitled
# plain_text
# 2 months ago
# 574 B
# 1
# Indexable
# Never
using Random

# Seed the global RNG so the benchmark input is reproducible across runs.
Random.seed!(42)

"""
    julia_selu2(x, alpha=1.67f0, lmbda=1.05f0)

Apply a SELU-style activation element-wise to `x` and return the result.

For each element `xᵢ` the output is `lmbda * xᵢ` when `xᵢ > 0`, and
`lmbda * alpha * (exp(xᵢ) - 1)` otherwise.

The defaults match those of the JAX `selu` reference snippet in the comment below.
NOTE(review): the standard SELU constants are α ≈ 1.6733 and λ ≈ 1.0507; these
look like rounded values, so confirm that is intended.

# Arguments
- `x`: a number or an array of numbers. The computation is broadcast with `@.`,
  so an array input yields a newly allocated array of the same shape.
- `alpha`: scale of the negative (exponential) branch.
- `lmbda`: overall output scale applied to both branches.

# Returns
The broadcast result. For a `Float32` input with the default `Float32` constants,
the element type stays `Float32`.
"""
function julia_selu2(x, alpha=1.67f0, lmbda=1.05f0)
    # `ifelse` is an ordinary function: both branches are evaluated for every
    # element, which keeps the fused broadcast loop branch-free. A large positive
    # x can make expm1(x) overflow to Inf, but that value is discarded by `ifelse`.
    #
    # `expm1(x)` replaces `exp(x) - 1`. The subtraction cancels catastrophically
    # near 0. In Float32, exp(-1f-8) rounds to 1f0, so the old form returned
    # exactly 0 for small negative inputs.
    return @. ifelse(x > 0, lmbda * x, lmbda * (alpha * expm1(x)))
end

x = randn(Float32, 1_000_000);

# The first call includes JIT compilation; time the second call.
@time julia_selu2(x);
@time julia_selu2(x);
# 0.013850 seconds  (author's recorded timing, measured with the earlier
#                    exp(x) - alpha formula; re-measure after changes)

#=
############ in jax ##############
# NOTE(review): `jit` is imported but `selu` is never wrapped with it, so this
# times the un-jitted Python function. Confirm whether a jitted comparison was
# intended.
# NOTE(review): JAX's PRNG and Julia's global RNG differ. Both inputs are
# standard-normal, but they are not the same samples.
# NOTE(review): this version still uses `alpha * jnp.exp(x) - alpha`. It matches
# the Julia function mathematically but is less accurate near 0.

import jax.numpy as jnp
from jax import jit
from jax import random

key = random.PRNGKey(42)

def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,))

%timeit selu(x).block_until_ready()
# 6.15 ms ± 694 µs
=#