Untitled
unknown
plain_text
6 days ago
1.0 kB
13
Indexable
def sdedit(model, x_start, measurement, record, save_root, skip_steps=500):
    """Run an SDEdit-style reverse diffusion chain starting from ``measurement``.

    Note: initialization is different for sdedit compared to other methods --
    the chain is seeded directly with ``measurement`` and only the reverse
    steps ``num_timesteps - skip_steps - 1`` down to ``0`` are executed.

    Args:
        model: Denoising network, forwarded unchanged to ``p_sample``.
        x_start: Only its ``.device`` is used, to place the timestep tensor.
        measurement: Tensor the reverse chain starts from. This function does
            not modify it in place.
        record: If truthy, save an image of the current sample every 10 steps.
        save_root: Directory under which ``progress/x_XXXX.png`` snapshots are
            written when ``record`` is truthy (``progress/`` is created if
            missing).
        skip_steps: Number of highest timesteps to skip before starting the
            reverse chain. Defaults to 500, the previously hard-coded value.

    Returns:
        The final sample produced by the reverse chain (or ``measurement``
        itself if no steps are run).
    """
    device = x_start.device

    ############### begin: complete the following code section ###############
    # NOTE(review): standard SDEdit first perturbs the input to the starting
    # timestep with the forward (noising) process; here the chain starts from
    # the un-noised measurement. Confirm this is intended.
    img = measurement
    start_step = num_timesteps - skip_steps
    # Descending range == list(range(start_step))[::-1], but without building
    # a list and while keeping len() so tqdm can display a total.
    pbar = tqdm(range(start_step - 1, -1, -1))
    ############### end: complete the following code section ###############

    progress_dir = os.path.join(save_root, "progress")
    if record:
        # Make sure snapshot writes cannot fail on a missing directory.
        os.makedirs(progress_dir, exist_ok=True)

    for idx in pbar:
        # One identical timestep per batch element (int64, as torch.tensor
        # of Python ints would produce).
        time = torch.full((img.shape[0],), idx, device=device, dtype=torch.long)

        ############### begin: complete the following code section ###############
        # Pure sampling needs no gradients. The previous requires_grad_() call
        # mutated the caller's `measurement` on the first step, so it is dropped.
        with torch.no_grad():
            out = p_sample(model, img, time)
        img = out["sample"].detach()
        ############### end: complete the following code section ###############

        if record and idx % 10 == 0:
            file_path = os.path.join(progress_dir, f"x_{idx:04d}.png")
            plt.imsave(file_path, clear_color(img))

    return img
Editor is loading...
Leave a Comment