Skip to content
Snippets Groups Projects
Unverified Commit 019c0c73 authored by Timo Kösters's avatar Timo Kösters
Browse files

Final tests for renoising

parent 5ec1cef3
No related branches found
No related tags found
No related merge requests found
......@@ -8,7 +8,7 @@ def get_default_configs():
config.training = training = ml_collections.ConfigDict()
#config.training.batch_size = 64
config.training.batch_size = 4
config.training.batch_size = 1
training.n_iters = 1300001
......
#!/bin/bash
source venv/bin/activate
#TORCH_CUDNN_V8_API_DISABLED=1 CUDA_VISIBLE_DEVICES="0" MIOPEN_LOG_LEVEL=4 python3 main.py --config configs/ve/cifar10_ncsnpp_continuous.py --mode sample --workdir workdir > /dev/null
TORCH_CUDNN_V8_API_DISABLED=1 CUDA_VISIBLE_DEVICES="0" MIOPEN_LOG_LEVEL=4 python3 main.py --config configs/ve/mnist_ncsnpp_continuous.py --mode sample --workdir workdir_mnist5 > /dev/null
TORCH_CUDNN_V8_API_DISABLED=1 CUDA_VISIBLE_DEVICES="1" MIOPEN_LOG_LEVEL=4 python3 main.py --config configs/ve/cifar10_ncsnpp_continuous.py --mode sample --workdir workdir > /dev/null
# TORCH_CUDNN_V8_API_DISABLED=1 CUDA_VISIBLE_DEVICES="1" MIOPEN_LOG_LEVEL=4 python3 main.py --config configs/ve/mnist_ncsnpp_continuous.py --mode sample --workdir workdir_mnist5 > /dev/null
......@@ -161,8 +161,8 @@ def euler_sampler_conditional(sample_dir, step, model, sde, shape, inverse_scale
for run in range(1, 2):
# STEPS nichtlinear?
eprint("run " + str(run))
steps = 400
iterations = 10
steps = 1000
iterations = 1
eprint("steps=" + str(steps) + ", iters=" + str(iterations))
timesteps = torch.linspace(eps, sde.T, steps, device=device)
#timesteps[timesteps<0.5] = ((timesteps[timesteps<0.5] * 2.0) ** 0.5) / 2.0
......@@ -217,24 +217,24 @@ def euler_sampler_conditional(sample_dir, step, model, sde, shape, inverse_scale
x_grad = torch.autograd.grad(lossessum, x)[0]
#sigma_delta = (2 * (sigma_min * (sigma_max / sigma_min) ** t) ** 2 * math.log(sigma_max / sigma_min))
#dt = (t - next_t)
new_x -= 1.0 * x_grad #/ losses.sqrt()
new_x -= 0.05 * x_grad #/ losses.sqrt()
# Manifold constraint
# Take phase from new_x and amplitude from y_t
# TODO: not every time, maybe every 10 iters?
if True or i < 60:
score2 = mutils.score_fn(model, sde, new_x, vec_next_t, False)
sigma2 = sigma_min * (sigma_max /sigma_min) ** next_t
x_tweedie2 = new_x + sigma2*sigma2 * score2
# if True or i < 60:
if i % 10 == 0:
# Renoising
# score2 = mutils.score_fn(model, sde, new_x, vec_next_t, False)
# sigma2 = sigma_min * (sigma_max /sigma_min) ** next_t
# x_tweedie2 = new_x + sigma2*sigma2 * score2
# new_x = anti_measure_fn(x_tweedie2, target_measurements)
new_x = x_tweedie2 # for renoising in deblurring
# z = torch.randn_like(x)
# new_x = new_x + sigma2*z
# Hypothetical constraint
z = torch.randn_like(x)
new_x = new_x + sigma2*z
#z = torch.randn_like(x)
#new_x = anti_measure_fn(new_x, measure_fn(sigma*z + targets))
new_x = anti_measure_fn(new_x, measure_fn(sigma*z + targets))
x = new_x.detach()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment