Untitled
unknown
plain_text
9 months ago
1.4 kB
5
Indexable
def generate(self, prompt, duration, model_size, seed, audio_input=None):
    """Generate audio from a text prompt, optionally continuing an audio clip.

    The MusicGen checkpoint is cached on the instance and reloaded only when
    ``model_size`` differs from the one used on the previous call.

    Args:
        prompt: Text description; passed to the model as a one-element batch.
        duration: Forwarded to ``set_generation_params`` as ``duration``.
        model_size: Suffix of the ``facebook/musicgen-{model_size}`` checkpoint.
        seed: Seed applied to torch's global RNG before sampling so that
            repeated calls with the same inputs are reproducible. ``None``
            leaves the RNG state untouched.
        audio_input: Optional audio to continue. It is handed to the model
            together with the model's own sample rate, so it is presumably
            expected to already be at that rate -- TODO confirm with caller.

    Returns:
        A one-element tuple holding the generated tensor moved to the CPU.
        A 2-D model output gets a channel axis inserted at position 1; an
        output that already has a channel axis is returned as-is.
    """
    # Local import keeps this block self-contained even if the module-level
    # imports change; torch is already a hard dependency of the model.
    import torch

    # Reload only when no model is cached or a different size is requested.
    # getattr() avoids an AttributeError if only one of the two cache
    # attributes has been set.
    cached_size = getattr(self, "current_model", None)
    if not hasattr(self, "model") or cached_size != model_size:
        self.model = MusicGen.get_pretrained(f"facebook/musicgen-{model_size}")
        self.current_model = model_size

    # NOTE(review): cfg_coef=1.0 looks lower than the library default --
    # confirm this is intentional.
    self.model.set_generation_params(
        duration=duration,
        top_k=250,
        top_p=0.0,
        temperature=1.0,
        cfg_coef=1.0,
    )

    # The seed was previously accepted but ignored; seed the global RNG so
    # sampling is deterministic for a given seed.
    if seed is not None:
        torch.manual_seed(int(seed))

    if audio_input is not None:
        # Audio continuation: extend the supplied clip, guided by the prompt.
        audio_output = self.model.generate_continuation(
            audio_input,
            self.model.sample_rate,
            [prompt],
            progress=True,
        )
    else:
        # Plain text-to-music generation.
        audio_output = self.model.generate(
            [prompt],
            progress=True,
        )

    # Target layout is (batch, channels, samples). Only add the channel axis
    # when it is actually missing; unconditionally unsqueezing would turn an
    # already 3-D result into a 4-D tensor.
    if audio_output.dim() == 2:
        audio_output = audio_output.unsqueeze(1)

    return (audio_output.cpu(),)
@classmethod
def IS_CHANGED(cls, **kwargs):
    """Return NaN so the result never compares equal to a previous one.

    NaN is unequal to every value, including itself, so any caller that
    decides whether to re-run by comparing this value with the last one
    will always see a change. All keyword arguments are ignored.
    """
    return float("nan")
Leave a Comment