Sampling 1392 tokens for [0,1392]. Conditioning on 0 tokens
Ancestral sampling 3 samples with temp=0.98, top_k=0, top_p=0.0
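For context, the traceback below was raised by the top-level sampling cell shown in its first frame. A minimal, self-contained reconstruction of that cell follows; it assumes the standard Jukebox notebook setup, in which vqvae, top_prior, labels, sampling_kwargs, tokens_to_sample, and hps were created in earlier cells.

import torch as t
from jukebox.sample import sample_partial_window

# Start from empty codes at all three VQ-VAE levels, sample tokens_to_sample new
# top-level (level 2) tokens, then decode the top-level codes back to audio.
zs = [t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda') for _ in range(3)]
zs = sample_partial_window(zs, labels, sampling_kwargs, 2, top_prior, tokens_to_sample, hps)
x = vqvae.decode(zs[2:], start_level=2).cpu().numpy()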
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/tmp/ipykernel_375/1755400942.py in <module>
1 zs=[t.zeros(hps.n_samples,0,dtype=t.long, device='cuda') for _ in range(3)]
----> 2 zs=sample_partial_window(zs, labels, sampling_kwargs, 2, top_prior, tokens_to_sample, hps)
3 x = vqvae.decode(zs[2:], start_level=2).cpu().numpy()
/opt/conda/lib/python3.7/site-packages/jukebox/sample.py in sample_partial_window(zs, labels, sampling_kwargs, level, prior, tokens_to_sample, hps)
26 start = current_tokens - n_ctx + tokens_to_sample
27
---> 28 return sample_single_window(zs, labels, sampling_kwargs, level, prior, start, hps)
29
30 # Sample a single window of length=n_ctx at position=start on level=level
/opt/conda/lib/python3.7/site-packages/jukebox/sample.py in sample_single_window(zs, labels, sampling_kwargs, level, prior, start, hps)
67 z_samples = []
68 for z_i, z_conds_i, y_i in zip(z_list, z_conds_list, y_list):
---> 69 z_samples_i = prior.sample(n_samples=z_i.shape[0], z=z_i, z_conds=z_conds_i, y=y_i, **sampling_kwargs)
70 z_samples.append(z_samples_i)
71 z = t.cat(z_samples, dim=0)
/opt/conda/lib/python3.7/site-packages/jukebox/prior/prior.py in sample(self, n_samples, z, z_conds, y, fp16, temp, top_k, top_p, chunk_size, sample_tokens)
272 z = self.prior_postprocess(z)
273 else:
--> 274 encoder_kv = self.get_encoder_kv(prime, fp16=fp16, sample=True)
275 if no_past_context:
276 z = self.prior.sample(n_samples, x_cond, y_cond, encoder_kv, fp16=fp16, temp=temp, top_k=top_k,
/opt/conda/lib/python3.7/site-packages/jukebox/prior/prior.py in get_encoder_kv(self, prime, fp16, sample)
288 self.prime_prior.cuda()
289 N = prime.shape[0]
--> 290 prime_acts = self.prime_prior(prime, None, None, None, fp16=fp16)
291 assert_shape(prime_acts, (N, self.prime_loss_dims, self.prime_acts_width))
292 assert prime_acts.dtype == t.float, f'Expected t.float, got {prime_acts.dtype}'
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
530 result = self._slow_forward(*input, **kwargs)
531 else:
--> 532 result = self.forward(*input, **kwargs)
533 for hook in self._forward_hooks.values():
534 hook_result = hook(self, input, result)
/opt/conda/lib/python3.7/site-packages/jukebox/prior/autoregressive.py in forward(self, x, x_cond, y_cond, encoder_kv, fp16, loss_full, encode, get_preds, get_acts, get_sep_loss)
147 x = self.x_emb_dropout(x) + self.pos_emb_dropout(self.pos_emb()) + x_cond # Pos emb and dropout
148
--> 149 x = self.transformer(x, encoder_kv=encoder_kv, fp16=fp16) # Transformer
150 if self.add_cond_after_transformer: # Piped doesnt add x_cond
151 x = x + x_cond
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
530 result = self._slow_forward(*input, **kwargs)
531 else:
--> 532 result = self.forward(*input, **kwargs)
533 for hook in self._forward_hooks.values():
534 hook_result = hook(self, input, result)
/opt/conda/lib/python3.7/site-packages/jukebox/transformer/transformer.py in forward(self, x, encoder_kv, sample, fp16, fp16_out)
185 x = l(x, encoder_kv=encoder_kv, sample=sample)
186 else:
--> 187 x = l(x, encoder_kv=None, sample=sample)
188 if l.attn.record_attn:
189 self.ws.append(l.attn.w)
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
530 result = self._slow_forward(*input, **kwargs)
531 else:
--> 532 result = self.forward(*input, **kwargs)
533 for hook in self._forward_hooks.values():
534 hook_result = hook(self, input, result)
/opt/conda/lib/python3.7/site-packages/jukebox/transformer/transformer.py in forward(self, x, encoder_kv, sample)
76 (x,),
77 (*self.attn.parameters(), *self.ln_0.parameters()),
---> 78 self.checkpoint_attn == 3) # 2 recomputes after the projections, and 1 recomputes after head splitting.
79 m = checkpoint(lambda _x: self.mlp(self.ln_1(_x)), (x + a,),
80 (*self.mlp.parameters(), *self.ln_1.parameters()),
/opt/conda/lib/python3.7/site-packages/jukebox/utils/checkpoint.py in checkpoint(func, inputs, params, flag)
7 return CheckpointFunction.apply(func, len(inputs), *args)
8 else:
----> 9 return func(*inputs)
10
11 class CheckpointFunction(t.autograd.Function):
/opt/conda/lib/python3.7/site-packages/jukebox/transformer/transformer.py in <lambda>(_x, _enc_kv, _s)
73 else:
74 assert encoder_kv is None
---> 75 a = checkpoint(lambda _x,_enc_kv=None,_s=sample: self.attn(self.ln_0(_x),_enc_kv,_s),
76 (x,),
77 (*self.attn.parameters(), *self.ln_0.parameters()),
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
530 result = self._slow_forward(*input, **kwargs)
531 else:
--> 532 result = self.forward(*input, **kwargs)
533 for hook in self._forward_hooks.values():
534 hook_result = hook(self, input, result)
/opt/conda/lib/python3.7/site-packages/jukebox/transformer/factored_attention.py in forward(self, x, encoder_kv, sample)
289 def forward(self, x, encoder_kv=None, sample=False):
290 curr_ctx = x.shape[1]
--> 291 x = self.c_attn(x)
292 query, key, value, sample = self.qkv(x, encoder_kv=encoder_kv, sample=sample)
293 if self.checkpoint_attn == 2 and not sample:
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
530 result = self._slow_forward(*input, **kwargs)
531 else:
--> 532 result = self.forward(*input, **kwargs)
533 for hook in self._forward_hooks.values():
534 hook_result = hook(self, input, result)
/opt/conda/lib/python3.7/site-packages/jukebox/transformer/ops.py in forward(self, x)
97 def forward(self, x):
98 size_out = (*x.size()[:-1], self.n_out)
---> 99 x = t.addmm(self.b.type_as(x), x.view(-1, x.size(-1)), self.w.type_as(x)) # If x if float then float else half
100 x = x.view(*size_out)
101 return x
RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasGemmEx( handle, opa, opb, m, n, k, &falpha, a, CUDA_R_16F, lda, b, CUDA_R_16F, ldb, &fbeta, c, CUDA_R_16F, ldc, CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP)`
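The CUDA_R_16F arguments in the failing cublasGemmEx call show the GEMM ran in half precision, i.e. sampling_kwargs carried fp16=True when execution reached the prime prior's attention projection (jukebox/transformer/ops.py, line 99 above). CUBLAS_STATUS_EXECUTION_FAILED at that point is commonly a symptom of fp16 trouble on the device, such as the GPU running out of memory mid-kernel or lacking reliable half-precision GEMM support. A hedged workaround sketch follows, using only kwargs that appear in prior.sample's signature in the traceback; temp, top_k, and top_p match this run's log, while chunk_size is illustrative:

# Resample with half precision disabled; if the error was fp16-related, the addmm
# in jukebox/transformer/ops.py will now run in float32 rather than CUDA_R_16F.
sampling_kwargs = dict(temp=0.98, fp16=False, top_k=0, top_p=0.0, chunk_size=32)

If the error persists with fp16=False, it is worth rerunning with the environment variable CUDA_LAUNCH_BLOCKING=1: CUDA errors are reported asynchronously, so the frame blamed above may not be the kernel that actually failed.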