Untitled
unknown
plain_text
8 months ago
1.0 kB
14
Indexable
def sdedit(model, x_start, measurement, record, save_root,
           skip_timesteps=500, record_every=10):
    """Run an SDEdit-style reverse diffusion pass starting from ``measurement``.

    Unlike methods that start from pure noise, the chain is initialised from
    ``measurement`` and only the last ``num_timesteps - skip_timesteps``
    reverse steps are run, with ``idx`` going from
    ``num_timesteps - skip_timesteps - 1`` down to ``0``. Each step calls
    ``p_sample(model, img, time)`` and takes ``out['sample']`` as the next state.

    Args:
        model: Denoising network; passed through unchanged to ``p_sample``.
        x_start: Only its ``.device`` is read; the per-step timestep tensors
            are created on that device.
        measurement: Tensor used as the starting state. The caller's tensor is
            never modified (neither its values nor its ``requires_grad`` flag).
        record: If truthy, write intermediate samples to disk.
        save_root: Directory under which ``progress/x_XXXX.png`` snapshots are
            written; the ``progress`` subdirectory is created if missing.
        skip_timesteps: Number of highest timesteps that are skipped
            (default 500, the previously hard-coded value).
        record_every: When recording, save a snapshot whenever
            ``idx % record_every == 0`` (default 10).

    Returns:
        The last sample produced by the loop, detached from any autograd graph.
        If the loop runs zero steps, a detached copy of ``measurement``.

    NOTE(review): ``num_timesteps``, ``p_sample``, ``clear_color``, ``tqdm``,
    ``torch``, ``plt`` and ``os`` are module-level names not visible here.
    NOTE(review): SDEdit as published perturbs the guide image with
    forward-process noise up to t0 before denoising; here ``measurement`` is
    used as-is with no added noise -- confirm this is intended.
    """
    device = x_start.device
    # Note: initialization is different for sdedit compared to other methods
    ############### begin: complete the following code section ###############
    # Work on a private, gradient-free copy so the caller's tensor is never
    # mutated; sampling runs under no_grad, so no autograd graph is needed.
    img = measurement.detach().clone()
    # Slicing a range yields a range, so tqdm can still report a total.
    pbar = tqdm(range(num_timesteps - skip_timesteps)[::-1])
    ############### end: complete the following code section ###############

    if record:
        # plt.imsave does not create missing parent directories.
        os.makedirs(os.path.join(save_root, "progress"), exist_ok=True)

    for idx in pbar:
        # One timestep value per batch element (int64, like torch.tensor of ints).
        time = torch.full((img.shape[0],), idx, device=device, dtype=torch.long)
        ############### begin: complete the following code section ###############
        with torch.no_grad():
            out = p_sample(model, img, time)
        img = out['sample'].detach()
        ############### end: complete the following code section ###############
        if record and idx % record_every == 0:
            file_path = os.path.join(save_root, f"progress/x_{str(idx).zfill(4)}.png")
            plt.imsave(file_path, clear_color(img))
    return img
Leave a Comment