Untitled

 avatar
unknown
plain_text
12 days ago
1.4 kB
3
Indexable
    def generate(self, prompt, duration, model_size, seed, audio_input=None):
        """Generate audio from a text prompt, optionally continuing ``audio_input``.

        Args:
            prompt: Text description, passed to the model as a one-element batch.
            duration: Forwarded to ``set_generation_params`` as ``duration``.
            model_size: Suffix of the ``facebook/musicgen-{model_size}``
                checkpoint. The loaded model is reused while this is unchanged.
            seed: Seed for torch's global RNG so repeated calls with the same
                inputs sample the same audio. ``None`` leaves the RNG untouched.
            audio_input: Optional audio to continue. NOTE(review): it is passed
                together with ``self.model.sample_rate``, i.e. it is assumed to
                already be at the model's sample rate -- confirm with the caller.

        Returns:
            A one-element tuple holding the generated audio tensor moved to the
            CPU. A channel dimension is inserted only when the model returned a
            2-D tensor, so a 3-D result is passed through unchanged.
        """
        # Function-scope import: the file's import block is not visible here,
        # and the seed handling below needs torch in scope.
        import torch

        # (Re)load the model only when none is cached or the size changed.
        # getattr() avoids an AttributeError when only one of the two
        # attributes has been set on the instance.
        cached_model = getattr(self, "model", None)
        if cached_model is None or getattr(self, "current_model", None) != model_size:
            self.model = MusicGen.get_pretrained(f"facebook/musicgen-{model_size}")
            self.current_model = model_size

        # Configure generation parameters
        self.model.set_generation_params(
            duration=duration,
            top_k=250,
            top_p=0.0,
            temperature=1.0,
            cfg_coef=1.0,
        )

        # Seed right before sampling; previously ``seed`` was accepted but
        # ignored, so identical inputs produced different audio.
        if seed is not None:
            torch.manual_seed(int(seed))

        if audio_input is not None:
            # Audio continuation generation
            audio_output = self.model.generate_continuation(
                audio_input,
                self.model.sample_rate,
                [prompt],
                progress=True,
            )
        else:
            # Text-to-music generation
            audio_output = self.model.generate(
                [prompt],
                progress=True,
            )

        # Format output as (batch, channels, samples). Only a 2-D result lacks
        # the channel axis; unsqueezing a tensor that is already 3-D would
        # produce a 4-D tensor.
        if audio_output.dim() == 2:
            audio_output = audio_output.unsqueeze(1)

        return (audio_output.cpu(),)

    @classmethod
    def IS_CHANGED(cls, **kwargs):
        """Return NaN regardless of the keyword arguments received.

        NaN compares unequal to every value, including itself, so any
        equality check against a previously returned value fails.
        NOTE(review): presumably this makes the host treat the node as
        changed on every run -- confirm against the host's caching rules.
        """
        never_equal = float("nan")
        return never_equal
Leave a Comment