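# Untitled
# Paste-site metadata (pastecode.io) captured with this snippet; not part of the code:
#   mail@pastecode.io avatar · unknown · plain_text · a year ago · 574 B · 2 · Indexable · Never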
using Random
# Seed the global RNG so the benchmark input drawn with `randn` below is reproducible
# from run to run.
Random.seed!(42)

"""
    julia_selu2(x, alpha=1.67f0, lmbda=1.05f0)

Apply a SELU-style activation element-wise to `x`, which may be an array or a scalar.
Where `x > 0` the result is `lmbda * x`; elsewhere it is `lmbda * alpha * (exp(x) - 1)`.

The negative branch is computed with `expm1(x)` instead of `alpha * exp(x) - alpha`.
The subtraction form cancels catastrophically for `x` close to zero: its relative error
can grow roughly like `eps(T) / abs(x)`. For example, `x = -1f-6` comes out about 6e-4
off in `Float32`.

The default constants are `Float32` literals that round the published SELU constants
(α ≈ 1.6733, λ ≈ 1.0507). They are kept as-is to match the JAX version of this benchmark.
With `Float64` input they are promoted, so the constants used are the `Float32`-rounded
values of 1.67 and 1.05, not the exact decimals.
"""
function julia_selu2(x, alpha=1.67f0, lmbda=1.05f0)
    # `ifelse` evaluates both branches for every element, which keeps the fused broadcast
    # loop branch-free. For large positive x, `expm1(x)` overflows to Inf, but that value
    # is never selected.
    # NOTE(review): re-measure the timing below; `expm1` may cost differently than `exp`.
    return @. ifelse(x > 0, lmbda * x, lmbda * alpha * expm1(x))
end

# Benchmark input: one million standard-normal Float32 samples from the global RNG.
x = randn(Float32, 1_000_000);

# Time the call twice. The first call of julia_selu2 on a Vector{Float32} also pays for
# JIT compilation; the second reflects the steady-state cost.
# Previously recorded for the second (warm) call: 0.013850 seconds
for _ in 1:2
    @time julia_selu2(x)
end

############ Same benchmark in JAX (Python; %timeit requires IPython/Jupyter) ##############
import jax.numpy as jnp
from jax import jit  # NOTE(review): imported but never applied to `selu` in this snippet
from jax import random
# Fixed PRNG key so the benchmark input drawn from it is reproducible.
key = random.PRNGKey(42)

def selu(x, alpha=1.67, lmbda=1.05):
    """Apply a SELU-style activation element-wise.

    Computes ``lmbda * x`` where ``x > 0`` and ``lmbda * alpha * (exp(x) - 1)``
    elsewhere. The defaults round the published SELU constants
    (alpha ~= 1.6733, lambda ~= 1.0507) and match the Julia version of this benchmark.

    Args:
        x: JAX/NumPy array or scalar.
        alpha: scale of the negative (exponential) branch.
        lmbda: overall output scale.

    Returns:
        A JAX array with the same shape as ``x``.
    """
    positive = x > 0
    # jnp.where evaluates both branches. Feeding the exponential branch only
    # non-positive values stops it overflowing to inf for large positive x.
    # Without this, jax.grad would push 0 * inf = nan back through the unselected branch.
    safe_x = jnp.where(positive, 0.0, x)
    # expm1 avoids the cancellation in exp(x) - 1 for x close to zero.
    # NOTE(review): without jit this is one extra dispatched op; re-measure the timing.
    return lmbda * jnp.where(positive, x, alpha * jnp.expm1(safe_x))

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready() # 6.15 ms ± 694 µs