Untitled
unknown
plain_text
2 years ago
574 B
8
Indexable
using Random
# Seed the default RNG explicitly so the random benchmark input below is reproducible.
Random.seed!(Random.default_rng(), 42)
"""
    julia_selu2(x, alpha=1.67f0, lmbda=1.05f0)

Apply the SELU activation element-wise to `x` (an array, or a single number):
`lmbda * x` where `x > 0`, and `lmbda * alpha * (exp(x) - 1)` elsewhere.

The defaults are `Float32` literals, so a `Float32` input yields a `Float32` result.
NOTE(review): the defaults look like truncated versions of the published SELU
constants (α ≈ 1.6733, λ ≈ 1.0507) — confirm the truncation is intentional.

`ifelse` evaluates both branches for every element (presumably chosen over `?:` to
keep the fused broadcast branch-free). For large positive `x` the unused negative
branch may be `Inf`, but it is discarded.
"""
function julia_selu2(x, alpha=1.67f0, lmbda=1.05f0)
    # expm1(x) computes exp(x) - 1 without cancellation. The former
    # `alpha * exp(x) - alpha` collapses to exactly 0 for tiny negative Float32 x.
    return @. ifelse(x > 0, lmbda * x, lmbda * (alpha * expm1(x)))
end
# Benchmark input: 10^6 standard-normal Float32 samples drawn from the seeded default RNG.
x = randn(Float32, 10^6);
# Time the call twice. The first measurement includes JIT compilation of
# `julia_selu2` for this argument type, so only the second reflects run time
# (one recorded second run: 0.013850 seconds).
for _ in 1:2
    @time julia_selu2(x)
end
############ Same SELU benchmark in JAX (Python) ##############
import jax.numpy as jnp
from jax import jit
from jax import random
# Seed 42, mirroring Random.seed!(42) in the Julia snippet.
key = random.PRNGKey(seed=42)
def selu(x, alpha=1.67, lmbda=1.05):
    """Apply SELU element-wise.

    Returns ``lmbda * x`` where ``x > 0`` and ``lmbda * alpha * (exp(x) - 1)``
    elsewhere.

    ``jnp.where`` evaluates both branches, so the discarded negative branch may
    overflow for large positive ``x``. This does not affect the selected values.
    """
    # expm1 computes exp(x) - 1 without cancellation. The former
    # ``alpha * jnp.exp(x) - alpha`` yields exactly 0 for tiny negative float32 x.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.expm1(x))
x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready() # 6.15 ms ± 694 µs
Editor is loading...