Skip to content
Snippets Groups Projects
Unverified Commit 53e8bbe2 authored by Timo Kösters's avatar Timo Kösters
Browse files

e2e deblur

parent c3e2aace
Branches e2e
No related tags found
No related merge requests found
...@@ -7,8 +7,8 @@ def get_default_configs(): ...@@ -7,8 +7,8 @@ def get_default_configs():
# training # training
config.training = training = ml_collections.ConfigDict() config.training = training = ml_collections.ConfigDict()
#config.training.batch_size = 64 config.training.batch_size = 128
config.training.batch_size = 4 # config.training.batch_size = 4
training.n_iters = 1300001 training.n_iters = 1300001
...@@ -73,7 +73,7 @@ def get_default_configs(): ...@@ -73,7 +73,7 @@ def get_default_configs():
optim.grad_clip = 1. optim.grad_clip = 1.
config.seed = 42 config.seed = 42
#config.device = 'cuda:0' config.device = 'cuda:0'
config.device = 'cpu' # config.device = 'cpu'
return config return config
...@@ -7,8 +7,8 @@ def get_default_configs(): ...@@ -7,8 +7,8 @@ def get_default_configs():
# training # training
config.training = training = ml_collections.ConfigDict() config.training = training = ml_collections.ConfigDict()
config.training.batch_size = 16 # config.training.batch_size = 16
# config.training.batch_size = 1024 config.training.batch_size = 1024
training.n_iters = 1300001 training.n_iters = 1300001
...@@ -70,7 +70,7 @@ def get_default_configs(): ...@@ -70,7 +70,7 @@ def get_default_configs():
optim.grad_clip = 1. optim.grad_clip = 1.
config.seed = 42 config.seed = 42
# config.device = 'cuda:0' config.device = 'cuda:0'
config.device = 'cpu' # config.device = 'cpu'
return config return config
...@@ -4,6 +4,7 @@ rsync -av --progress --exclude workdir --exclude venv puffin:testing/workdir/sam ...@@ -4,6 +4,7 @@ rsync -av --progress --exclude workdir --exclude venv puffin:testing/workdir/sam
#rsync -av --progress --exclude workdir_mnist2 --exclude venv puffin:testing/workdir_mnist2/samples/ workdir_mnist2/samples #rsync -av --progress --exclude workdir_mnist2 --exclude venv puffin:testing/workdir_mnist2/samples/ workdir_mnist2/samples
#rsync -av --progress --exclude workdir_mnist3 --exclude venv puffin:testing/workdir_mnist3/samples/ workdir_mnist3/samples #rsync -av --progress --exclude workdir_mnist3 --exclude venv puffin:testing/workdir_mnist3/samples/ workdir_mnist3/samples
rsync -av --progress --exclude workdir_mnist5 --exclude venv puffin:testing/workdir_mnist5/samples/ workdir_mnist5/samples rsync -av --progress --exclude workdir_mnist5 --exclude venv puffin:testing/workdir_mnist5/samples/ workdir_mnist5/samples
rsync -av --progress --exclude workdir_mnist5 --exclude venv puffin:testing/workdir_mnist_e2e_deblur/samples/ workdir_mnist_e2e_deblur/samples
rsync -av --progress --exclude workdir_mnist5 --exclude venv puffin:testing/workdir_mnist_e2e_fft/samples/ workdir_mnist_e2e_fft/samples rsync -av --progress --exclude workdir_mnist5 --exclude venv puffin:testing/workdir_mnist_e2e_fft/samples/ workdir_mnist_e2e_fft/samples
# rsync -av --progress --exclude workdir_mnist5 --exclude venv puffin:testing/workdir_mnist28/samples/ workdir_mnist28/samples # rsync -av --progress --exclude workdir_mnist5 --exclude venv puffin:testing/workdir_mnist28/samples/ workdir_mnist28/samples
# rsync -av --progress --exclude workdir_mnist5 --exclude venv puffin:testing/workdir_mnist29/samples/ workdir_mnist29/samples # rsync -av --progress --exclude workdir_mnist5 --exclude venv puffin:testing/workdir_mnist29/samples/ workdir_mnist29/samples
......
...@@ -16,10 +16,12 @@ ...@@ -16,10 +16,12 @@
"""All functions related to loss computation and optimization. """All functions related to loss computation and optimization.
""" """
from utils import eprint, save_checkpoint, restore_checkpoint
import torch import torch
import torch.optim as optim import torch.optim as optim
import numpy as np import numpy as np
from models import utils as mutils from models import utils as mutils
from torchvision.transforms import GaussianBlur
from sde_lib import VESDE from sde_lib import VESDE
import logging import logging
...@@ -68,14 +70,17 @@ def loss_fn(model, sde, batch, reduce_mean, train): ...@@ -68,14 +70,17 @@ def loss_fn(model, sde, batch, reduce_mean, train):
t = torch.full((batch.shape[0],), sde.T, device=batch.device) #torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps t = torch.full((batch.shape[0],), sde.T, device=batch.device) #torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps
z = torch.randn_like(batch) z = torch.randn_like(batch)
mean, std = sde.marginal_prob(batch, t) mean, std = sde.marginal_prob(batch, t)
def measure_fn(image): measure_fn = GaussianBlur(5, 3.0)
measurements = torch.abs(torch.fft.fft2(image))
return measurements
perturbed_data = measure_fn(batch) perturbed_data = measure_fn(batch)
z = torch.randn_like(perturbed_data)
perturbed_data += 0.02 * z;
score = mutils.score_fn(model, sde, perturbed_data, t, train) score = mutils.score_fn(model, sde, perturbed_data, t, train)
# eprint("score", score)
# eprint("batch", batch)
losses = torch.square(score - batch) losses = torch.square(score - batch)
# eprint("losses", losses)
reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs) reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)
losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1) losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1)
......
...@@ -230,8 +230,8 @@ class NCSNpp(nn.Module): ...@@ -230,8 +230,8 @@ class NCSNpp(nn.Module):
self.all_modules = nn.ModuleList(modules) self.all_modules = nn.ModuleList(modules)
def forward(self, x): def forward(self, x, time_cond):
time_cond = torch.tensor(np.zeros((16,)), device=x.device, dtype=x.dtype) # time_cond = torch.tensor(np.zeros((16,)), device=x.device, dtype=x.dtype)
# timestep/noise_level embedding; only for continuous training # timestep/noise_level embedding; only for continuous training
modules = self.all_modules modules = self.all_modules
m_idx = 0 m_idx = 0
......
...@@ -157,10 +157,16 @@ def train(config, workdir): ...@@ -157,10 +157,16 @@ def train(config, workdir):
eval_batch = torch.from_numpy(next(eval_iter)['image']._numpy()).to(config.device).float() eval_batch = torch.from_numpy(next(eval_iter)['image']._numpy()).to(config.device).float()
eval_batch = eval_batch.permute(0, 3, 1, 2) eval_batch = eval_batch.permute(0, 3, 1, 2)
eval_batch = scaler(eval_batch) eval_batch = scaler(eval_batch)
def measure_fn(image): # def measure_fn(image):
measurements = torch.abs(torch.fft.fft2(image)) # measurements = torch.abs(torch.fft.fft2(image))
return measurements # return measurements
measure_fn = GaussianBlur(5, 3.0)
# def measure_fn(image):
# measurements = torch.abs(torch.fft.fft2(image))
# return measurements
perturbed_data = measure_fn(eval_batch) perturbed_data = measure_fn(eval_batch)
z = torch.randn_like(perturbed_data)
perturbed_data += 0.02 * z;
t = torch.full((eval_batch.shape[0],), sde.T, device=eval_batch.device) t = torch.full((eval_batch.shape[0],), sde.T, device=eval_batch.device)
result = mutils.score_fn(model, sde, perturbed_data, t, train) result = mutils.score_fn(model, sde, perturbed_data, t, train)
...@@ -240,9 +246,10 @@ def sample(config, workdir): ...@@ -240,9 +246,10 @@ def sample(config, workdir):
batch = scaler(batch) batch = scaler(batch)
targets = batch targets = batch
#measure_fn = GaussianBlur(15, 2.0) measure_fn = GaussianBlur(5, 3.0)
#anti_measure_fn = lambda x_tweedie, image: image anti_measure_fn = lambda x_tweedie, image: image
"""
def measure_fn(image): def measure_fn(image):
measurements = torch.abs(torch.fft.fft2(image)) measurements = torch.abs(torch.fft.fft2(image))
return measurements return measurements
...@@ -251,11 +258,17 @@ def sample(config, workdir): ...@@ -251,11 +258,17 @@ def sample(config, workdir):
# Take phase from x_tweedie and amplitude from measured_diff # Take phase from x_tweedie and amplitude from measured_diff
x_tweedie_fft = torch.fft.fft2(x_tweedie) x_tweedie_fft = torch.fft.fft2(x_tweedie)
return torch.real(torch.fft.ifft2((measured_diff / (torch.abs(x_tweedie_fft)+0.001) * x_tweedie_fft))) return torch.real(torch.fft.ifft2((measured_diff / (torch.abs(x_tweedie_fft)+0.001) * x_tweedie_fft)))
"""
measurements = measure_fn(targets) measurements = measure_fn(targets)
# Add noise # Add noise
z = torch.randn_like(measurements) z = torch.randn_like(measurements)
#measurements += 0.1* z measurements += 0.02 * z;
# Salt and Pepper
# z = torch.rand_like(measurements)
# measurements[z<0.02] = 1.0
# z = torch.rand_like(measurements)
# measurements[z<0.02] = 0.0
sample, n = sampling.euler_sampler_conditional(sample_dir, step, model, sde, sampling_shape, inverse_scaler, config.sampling.snr, config.sampling.n_steps_each, config.sampling.probability_flow, config.training.continuous, config.sampling.noise_removal, config.device, sampling_eps, measure_fn, anti_measure_fn, measurements, targets) sample, n = sampling.euler_sampler_conditional(sample_dir, step, model, sde, sampling_shape, inverse_scaler, config.sampling.snr, config.sampling.n_steps_each, config.sampling.probability_flow, config.training.continuous, config.sampling.noise_removal, config.device, sampling_eps, measure_fn, anti_measure_fn, measurements, targets)
ema.restore(model.parameters()) ema.restore(model.parameters())
......
...@@ -76,7 +76,7 @@ def langevin_update_fn(model, sde, x, t, target_snr, n_steps): ...@@ -76,7 +76,7 @@ def langevin_update_fn(model, sde, x, t, target_snr, n_steps):
diff = step_size[:, None, None, None] * grad diff = step_size[:, None, None, None] * grad
x_mean = x + diff x_mean = x + diff
noise2 = 1.00 * torch.sqrt(step_size * 2)[:, None, None, None] * noise noise2 = 1.00 * torch.sqrt(step_size * 2)[:, None, None, None] * noise
x = x_mean + noise2 x = x_mean.detach() + noise2.detach()
return x, x_mean return x, x_mean
...@@ -161,10 +161,11 @@ def euler_sampler_conditional(sample_dir, step, model, sde, shape, inverse_scale ...@@ -161,10 +161,11 @@ def euler_sampler_conditional(sample_dir, step, model, sde, shape, inverse_scale
for run in range(1, 2): for run in range(1, 2):
# STEPS nichtlinear? # STEPS nichtlinear?
eprint("run " + str(run)) eprint("run " + str(run))
steps = 400 steps = 100
iterations = 10 iterations = 10
eprint("steps=" + str(steps) + ", iters=" + str(iterations)) eprint("steps=" + str(steps) + ", iters=" + str(iterations))
timesteps = torch.linspace(eps, sde.T, steps, device=device) timesteps = torch.linspace(eps, sde.T, steps, device=device)
#timesteps = timesteps ** 0.375
#timesteps[timesteps<0.5] = ((timesteps[timesteps<0.5] * 2.0) ** 0.5) / 2.0 #timesteps[timesteps<0.5] = ((timesteps[timesteps<0.5] * 2.0) ** 0.5) / 2.0
#timesteps[timesteps>0.5] = ((timesteps[timesteps>0.5] * 2.0 - 1.0) ** 2.0) / 2.0 + 0.5 #timesteps[timesteps>0.5] = ((timesteps[timesteps>0.5] * 2.0 - 1.0) ** 2.0) / 2.0 + 0.5
total_results = None total_results = None
...@@ -194,9 +195,11 @@ def euler_sampler_conditional(sample_dir, step, model, sde, shape, inverse_scale ...@@ -194,9 +195,11 @@ def euler_sampler_conditional(sample_dir, step, model, sde, shape, inverse_scale
# score, drift, diffusion = reverse_sde_old(model, sde, x, vec_t, vec_next_t, 1.0 / steps) # score, drift, diffusion = reverse_sde_old(model, sde, x, vec_t, vec_next_t, 1.0 / steps)
score, drift, diffusion = reverse_sde(model, sde, x, vec_t, vec_next_t) score, drift, diffusion = reverse_sde(model, sde, x, vec_t, vec_next_t)
new_x = drift + diffusion * z new_x = drift + diffusion * z
# PC sampler # PC sampler
""" """
x, x_mean = langevin_update_fn(model, sde, x, vec_t, snr, n_steps) x, x_mean = langevin_update_fn(model, sde, x, vec_t, snr, n_steps)
x = x.requires_grad_()
new_x, x_mean, score = reverse_diffusion_update_fn(model, sde, x, vec_t) new_x, x_mean, score = reverse_diffusion_update_fn(model, sde, x, vec_t)
""" """
...@@ -217,12 +220,12 @@ def euler_sampler_conditional(sample_dir, step, model, sde, shape, inverse_scale ...@@ -217,12 +220,12 @@ def euler_sampler_conditional(sample_dir, step, model, sde, shape, inverse_scale
x_grad = torch.autograd.grad(lossessum, x)[0] x_grad = torch.autograd.grad(lossessum, x)[0]
#sigma_delta = (2 * (sigma_min * (sigma_max / sigma_min) ** t) ** 2 * math.log(sigma_max / sigma_min)) #sigma_delta = (2 * (sigma_min * (sigma_max / sigma_min) ** t) ** 2 * math.log(sigma_max / sigma_min))
#dt = (t - next_t) #dt = (t - next_t)
new_x -= 1.0 * x_grad #/ losses.sqrt() new_x -= 0.1 * x_grad
# Manifold constraint # Manifold constraint
# Take phase from new_x and amplitude from y_t # Take phase from new_x and amplitude from y_t
# TODO: not every time, maybe every 10 iters? # TODO: not every time, maybe every 10 iters?
if True or i < 60: if False: #and i < 60:
score2 = mutils.score_fn(model, sde, new_x, vec_next_t, False) score2 = mutils.score_fn(model, sde, new_x, vec_next_t, False)
sigma2 = sigma_min * (sigma_max /sigma_min) ** next_t sigma2 = sigma_min * (sigma_max /sigma_min) ** next_t
x_tweedie2 = new_x + sigma2*sigma2 * score2 x_tweedie2 = new_x + sigma2*sigma2 * score2
......
...@@ -3,4 +3,5 @@ source venv/bin/activate ...@@ -3,4 +3,5 @@ source venv/bin/activate
#TORCH_CUDNN_V8_API_DISABLED=1 CUDA_VISIBLE_DEVICES="0" MIOPEN_LOG_LEVEL=4 python3 main.py --config configs/ve/cifar10_ncsnpp_continuous.py --mode train --workdir workdir_mnist3 > /dev/null #TORCH_CUDNN_V8_API_DISABLED=1 CUDA_VISIBLE_DEVICES="0" MIOPEN_LOG_LEVEL=4 python3 main.py --config configs/ve/cifar10_ncsnpp_continuous.py --mode train --workdir workdir_mnist3 > /dev/null
#TORCH_CUDNN_V8_API_DISABLED=1 CUDA_VISIBLE_DEVICES="0" MIOPEN_LOG_LEVEL=4 python3 main.py --config configs/ve/mnist_ncsnpp_continuous.py --mode train --workdir workdir_mnist5 > /dev/null #TORCH_CUDNN_V8_API_DISABLED=1 CUDA_VISIBLE_DEVICES="0" MIOPEN_LOG_LEVEL=4 python3 main.py --config configs/ve/mnist_ncsnpp_continuous.py --mode train --workdir workdir_mnist5 > /dev/null
TORCH_CUDNN_V8_API_DISABLED=1 CUDA_VISIBLE_DEVICES="0" MIOPEN_LOG_LEVEL=4 python3 main.py --config configs/ve/mnist_ncsnpp_continuous.py --mode train --workdir workdir_mnist_e2e_fft > /dev/null # TORCH_CUDNN_V8_API_DISABLED=1 CUDA_VISIBLE_DEVICES="0" MIOPEN_LOG_LEVEL=4 python3 main.py --config configs/ve/mnist_ncsnpp_continuous.py --mode train --workdir workdir_mnist_e2e_fft > /dev/null
TORCH_CUDNN_V8_API_DISABLED=1 CUDA_VISIBLE_DEVICES="0" MIOPEN_LOG_LEVEL=4 python3 main.py --config configs/ve/mnist_ncsnpp_continuous.py --mode train --workdir workdir_mnist_e2e_deblur > /dev/null
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment