Jukebox: CUBLAS_STATUS_EXECUTION_FAILED during ancestral sampling (fp16 addmm)

Sampling 1392 tokens for [0,1392]. Conditioning on 0 tokens
Ancestral sampling 3 samples with temp=0.98, top_k=0, top_p=0.0
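The traceback below is raised out of sample_partial_window on the top-level prior. For context, here is a minimal sketch of the call site, reconstructed from the <module> frame in the traceback plus the log lines above; labels, top_prior, and hps come from earlier notebook cells, and the sampling_kwargs values other than temp/top_k/top_p are assumptions:

import torch as t
from jukebox.sample import sample_partial_window

# Assumed sampling configuration; fp16=True is an inference from the
# half-precision (CUDA_R_16F) GEMM in the error at the bottom, and
# chunk_size is a placeholder.
sampling_kwargs = dict(temp=0.98, fp16=True, top_k=0, top_p=0.0,
                       chunk_size=32)
tokens_to_sample = 1392  # "Sampling 1392 tokens for [0,1392]"

# These two lines are copied from the <module> frame below.
zs = [t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda') for _ in range(3)]
zs = sample_partial_window(zs, labels, sampling_kwargs, 2, top_prior,
                           tokens_to_sample, hps)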
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_375/1755400942.py in <module>
      1 zs=[t.zeros(hps.n_samples,0,dtype=t.long, device='cuda') for _ in range(3)]
----> 2 zs=sample_partial_window(zs, labels, sampling_kwargs, 2, top_prior, tokens_to_sample, hps)
      3 x = vqvae.decode(zs[2:], start_level=2).cpu().numpy()

/opt/conda/lib/python3.7/site-packages/jukebox/sample.py in sample_partial_window(zs, labels, sampling_kwargs, level, prior, tokens_to_sample, hps)
     26         start = current_tokens - n_ctx + tokens_to_sample
     27 
---> 28     return sample_single_window(zs, labels, sampling_kwargs, level, prior, start, hps)
     29 
     30 # Sample a single window of length=n_ctx at position=start on level=level

/opt/conda/lib/python3.7/site-packages/jukebox/sample.py in sample_single_window(zs, labels, sampling_kwargs, level, prior, start, hps)
     67     z_samples = []
     68     for z_i, z_conds_i, y_i in zip(z_list, z_conds_list, y_list):
---> 69         z_samples_i = prior.sample(n_samples=z_i.shape[0], z=z_i, z_conds=z_conds_i, y=y_i, **sampling_kwargs)
     70         z_samples.append(z_samples_i)
     71     z = t.cat(z_samples, dim=0)

/opt/conda/lib/python3.7/site-packages/jukebox/prior/prior.py in sample(self, n_samples, z, z_conds, y, fp16, temp, top_k, top_p, chunk_size, sample_tokens)
    272                 z = self.prior_postprocess(z)
    273             else:
--> 274                 encoder_kv = self.get_encoder_kv(prime, fp16=fp16, sample=True)
    275                 if no_past_context:
    276                     z = self.prior.sample(n_samples, x_cond, y_cond, encoder_kv, fp16=fp16, temp=temp, top_k=top_k,

/opt/conda/lib/python3.7/site-packages/jukebox/prior/prior.py in get_encoder_kv(self, prime, fp16, sample)
    288                 self.prime_prior.cuda()
    289             N = prime.shape[0]
--> 290             prime_acts = self.prime_prior(prime, None, None, None, fp16=fp16)
    291             assert_shape(prime_acts, (N, self.prime_loss_dims, self.prime_acts_width))
    292             assert prime_acts.dtype == t.float, f'Expected t.float, got {prime_acts.dtype}'

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/opt/conda/lib/python3.7/site-packages/jukebox/prior/autoregressive.py in forward(self, x, x_cond, y_cond, encoder_kv, fp16, loss_full, encode, get_preds, get_acts, get_sep_loss)
    147         x = self.x_emb_dropout(x) + self.pos_emb_dropout(self.pos_emb()) + x_cond # Pos emb and dropout
    148 
--> 149         x = self.transformer(x, encoder_kv=encoder_kv, fp16=fp16) # Transformer
    150         if self.add_cond_after_transformer: # Piped doesnt add x_cond
    151             x = x + x_cond

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/opt/conda/lib/python3.7/site-packages/jukebox/transformer/transformer.py in forward(self, x, encoder_kv, sample, fp16, fp16_out)
    185                     x = l(x, encoder_kv=encoder_kv, sample=sample)
    186                 else:
--> 187                     x = l(x, encoder_kv=None, sample=sample)
    188             if l.attn.record_attn:
    189                 self.ws.append(l.attn.w)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/opt/conda/lib/python3.7/site-packages/jukebox/transformer/transformer.py in forward(self, x, encoder_kv, sample)
     76                                (x,),
     77                                (*self.attn.parameters(), *self.ln_0.parameters()),
---> 78                                self.checkpoint_attn == 3)  # 2 recomputes after the projections, and 1 recomputes after head splitting.
     79             m = checkpoint(lambda _x: self.mlp(self.ln_1(_x)), (x + a,),
     80                            (*self.mlp.parameters(), *self.ln_1.parameters()),

/opt/conda/lib/python3.7/site-packages/jukebox/utils/checkpoint.py in checkpoint(func, inputs, params, flag)
      7         return CheckpointFunction.apply(func, len(inputs), *args)
      8     else:
----> 9         return func(*inputs)
     10 
     11 class CheckpointFunction(t.autograd.Function):

/opt/conda/lib/python3.7/site-packages/jukebox/transformer/transformer.py in <lambda>(_x, _enc_kv, _s)
     73             else:
     74                 assert encoder_kv is None
---> 75                 a = checkpoint(lambda _x,_enc_kv=None,_s=sample: self.attn(self.ln_0(_x),_enc_kv,_s),
     76                                (x,),
     77                                (*self.attn.parameters(), *self.ln_0.parameters()),

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/opt/conda/lib/python3.7/site-packages/jukebox/transformer/factored_attention.py in forward(self, x, encoder_kv, sample)
    289     def forward(self, x, encoder_kv=None, sample=False):
    290         curr_ctx = x.shape[1]
--> 291         x = self.c_attn(x)
    292         query, key, value, sample = self.qkv(x, encoder_kv=encoder_kv, sample=sample)
    293         if self.checkpoint_attn == 2 and not sample:

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/opt/conda/lib/python3.7/site-packages/jukebox/transformer/ops.py in forward(self, x)
     97     def forward(self, x):
     98         size_out = (*x.size()[:-1], self.n_out)
---> 99         x = t.addmm(self.b.type_as(x), x.view(-1, x.size(-1)), self.w.type_as(x)) # If x if float then float else half
    100         x = x.view(*size_out)
    101         return x

RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasGemmEx( handle, opa, opb, m, n, k, &falpha, a, CUDA_R_16F, lda, b, CUDA_R_16F, ldb, &fbeta, c, CUDA_R_16F, ldc, CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP)`
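The failing frame is the t.addmm in jukebox/transformer/ops.py line 99, which cuBLAS executes with half-precision operands (CUDA_R_16F), so the sampler was running with fp16 enabled. CUBLAS_STATUS_EXECUTION_FAILED at that point usually indicates a GPU, driver, or memory problem with the fp16 GEMM rather than a bug in the Python code. Two standard checks, sketched as suggestions rather than a confirmed fix (the kwargs mirror the assumed setup at the top of this paste):

# 1) Re-run with synchronous kernel launches so the reported frame is the
#    kernel that actually failed (CUDA errors are raised asynchronously).
#    The variable must be set before CUDA is initialised, e.g. in the shell
#    that launches the notebook server:
#    $ CUDA_LAUNCH_BLOCKING=1 jupyter notebook

# 2) Hypothetical workaround: sample in full precision. fp16 is a real
#    keyword of prior.sample (see its signature in the traceback above);
#    chunk_size=16 is an assumption.
sampling_kwargs = dict(temp=0.98, fp16=False, top_k=0, top_p=0.0,
                       chunk_size=16)
zs = sample_partial_window(zs, labels, sampling_kwargs, 2, top_prior,
                           tokens_to_sample, hps)

Full-precision sampling is slower and needs more activation memory, which is why a smaller chunk_size is used here to compensate.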