Untitled
def generate(self, prompt, duration, model_size, seed, audio_input=None):
    """Generate audio from a text prompt, optionally continuing ``audio_input``.

    Args:
        prompt: Text description; passed to the model as a one-element batch.
        duration: Forwarded to ``set_generation_params(duration=...)``.
        model_size: Suffix of the checkpoint name
            ``facebook/musicgen-{model_size}``.
        seed: Seeds torch's global RNG before sampling so that identical
            inputs reproduce the same audio.
        audio_input: Optional audio to continue. When given, it is handed to
            ``generate_continuation`` together with ``self.model.sample_rate``.
            NOTE(review): this assumes the input is already a tensor at the
            model's sample rate -- confirm against the caller.

    Returns:
        A one-element tuple holding the generated audio moved to the CPU,
        laid out as (batch, channels, samples).
    """
    import torch  # local import so the method does not rely on file-level imports

    # Reload only when nothing is cached yet or a different size is requested.
    # getattr guards against `model` being set without `current_model`.
    cached_size = getattr(self, "current_model", None)
    if not hasattr(self, "model") or cached_size != model_size:
        self.model = MusicGen.get_pretrained(f"facebook/musicgen-{model_size}")
        self.current_model = model_size

    # NOTE(review): cfg_coef=1.0 and top_p=0.0 are kept as in the original;
    # confirm they are the intended sampling settings.
    self.model.set_generation_params(
        duration=duration,
        top_k=250,
        top_p=0.0,
        temperature=1.0,
        cfg_coef=1.0,
    )

    # The seed was previously ignored; seeding here makes sampling repeatable.
    torch.manual_seed(int(seed))

    if audio_input is not None:
        # Audio continuation: extend the supplied audio, guided by the prompt.
        audio_output = self.model.generate_continuation(
            audio_input,
            self.model.sample_rate,
            [prompt],
            progress=True,
        )
    else:
        # Plain text-to-music generation.
        audio_output = self.model.generate([prompt], progress=True)

    # Target layout is (batch, channels, samples). Only a 2-D (batch, samples)
    # result lacks the channel axis; a 3-D result must be left untouched,
    # otherwise it would become 4-D.
    if audio_output.dim() == 2:
        audio_output = audio_output.unsqueeze(1)

    return (audio_output.cpu(),)


@classmethod
def IS_CHANGED(cls, **kwargs):
    """Return NaN regardless of the inputs.

    NaN never compares equal to itself, so any equality-based comparison of
    successive return values reports a difference.
    NOTE(review): presumably used so the host re-runs this node on every
    execution -- confirm against the host's caching rules.
    """
    return float("nan")
Leave a Comment