Skip to content
Snippets Groups Projects
Unverified Commit 5ec1cef3 authored by Timo Kösters
Browse files

More experiments

parent d5ba09cc
No related branches found
No related tags found
No related merge requests found
......@@ -7,7 +7,7 @@ def get_default_configs():
# training
config.training = training = ml_collections.ConfigDict()
#config.training.batch_size = 16
#config.training.batch_size = 64
config.training.batch_size = 4
training.n_iters = 1300001
......
......@@ -7,8 +7,8 @@ def get_default_configs():
# training
config.training = training = ml_collections.ConfigDict()
config.training.batch_size = 10
#config.training.batch_size = 1024
# config.training.batch_size = 25
config.training.batch_size = 1024
training.n_iters = 1300001
......@@ -70,7 +70,7 @@ def get_default_configs():
optim.grad_clip = 1.
config.seed = 42
#config.device = 'cuda:0'
config.device = 'cpu'
config.device = 'cuda:0'
# config.device = 'cpu'
return config
\ No newline at end of file
#!/bin/bash
rsync -av --progress --exclude workdir --exclude venv puffin:testing/workdir/samples/ workdir/samples
rsync -av --progress --exclude workdir_mnist2 --exclude venv puffin:testing/workdir_mnist2/samples/ workdir_mnist2/samples
rsync -av --progress --exclude workdir_mnist3 --exclude venv puffin:testing/workdir_mnist3/samples/ workdir_mnist3/samples
#rsync -av --progress --exclude workdir_mnist2 --exclude venv puffin:testing/workdir_mnist2/samples/ workdir_mnist2/samples
#rsync -av --progress --exclude workdir_mnist3 --exclude venv puffin:testing/workdir_mnist3/samples/ workdir_mnist3/samples
rsync -av --progress --exclude workdir_mnist5 --exclude venv puffin:testing/workdir_mnist5/samples/ workdir_mnist5/samples
rsync -av --progress --exclude workdir_mnist5 --exclude venv puffin:testing/workdir_mnist28/samples/ workdir_mnist28/samples
rsync -av --progress --exclude workdir_mnist5 --exclude venv puffin:testing/workdir_mnist29/samples/ workdir_mnist29/samples
rsync -av --progress --exclude workdir --exclude venv puffin:testing/workdir_mnist3/checkpoints-meta/ workdir_mnist3/checkpoints-meta
rsync -av --progress --exclude workdir --exclude venv puffin:testing/workdir_mnist5/checkpoints-meta/ workdir_mnist5/checkpoints-meta
rsync -av --progress --exclude workdir --exclude venv puffin:testing/workdir_mnist28/checkpoints-meta/ workdir_mnist28/checkpoints-meta
# rsync -av --progress --exclude workdir_mnist5 --exclude venv puffin:testing/workdir_mnist28/samples/ workdir_mnist28/samples
# rsync -av --progress --exclude workdir_mnist5 --exclude venv puffin:testing/workdir_mnist29/samples/ workdir_mnist29/samples
# rsync -av --progress --exclude workdir --exclude venv puffin:testing/workdir_mnist3/checkpoints-meta/ workdir_mnist3/checkpoints-meta
# rsync -av --progress --exclude workdir --exclude venv puffin:testing/workdir_mnist5/checkpoints-meta/ workdir_mnist5/checkpoints-meta
# rsync -av --progress --exclude workdir --exclude venv puffin:testing/workdir_mnist28/checkpoints-meta/ workdir_mnist28/checkpoints-meta
......@@ -210,9 +210,7 @@ def sample(config, workdir):
#anti_measure_fn = lambda x_tweedie, image: image
def measure_fn(image):
# Add noise
z = torch.randn_like(image)
measurements = torch.abs(torch.fft.fft2(image + 0.0*z))
measurements = torch.abs(torch.fft.fft2(image))
return measurements
def anti_measure_fn(x_tweedie, measured_diff):
......@@ -221,6 +219,9 @@ def sample(config, workdir):
return torch.real(torch.fft.ifft2((measured_diff / (torch.abs(x_tweedie_fft)+0.001) * x_tweedie_fft)))
measurements = measure_fn(targets)
# Add noise
z = torch.randn_like(measurements)
#measurements += 0.1* z
sample, n = sampling.euler_sampler_conditional(sample_dir, step, model, sde, sampling_shape, inverse_scaler, config.sampling.snr, config.sampling.n_steps_each, config.sampling.probability_flow, config.training.continuous, config.sampling.noise_removal, config.device, sampling_eps, measure_fn, anti_measure_fn, measurements, targets)
ema.restore(model.parameters())
......
#!/bin/bash
source venv/bin/activate
#TORCH_CUDNN_V8_API_DISABLED=1 CUDA_VISIBLE_DEVICES="3" MIOPEN_LOG_LEVEL=4 python3 main.py --config configs/ve/cifar10_ncsnpp_continuous.py --mode sample --workdir workdir > /dev/null
#TORCH_CUDNN_V8_API_DISABLED=1 CUDA_VISIBLE_DEVICES="0" MIOPEN_LOG_LEVEL=4 python3 main.py --config configs/ve/cifar10_ncsnpp_continuous.py --mode sample --workdir workdir > /dev/null
TORCH_CUDNN_V8_API_DISABLED=1 CUDA_VISIBLE_DEVICES="0" MIOPEN_LOG_LEVEL=4 python3 main.py --config configs/ve/mnist_ncsnpp_continuous.py --mode sample --workdir workdir_mnist5 > /dev/null
......@@ -19,6 +19,7 @@
import functools
import os
import math
import tensorflow as tf
from torchvision.utils import make_grid, save_image
import torch.nn.functional as F
......@@ -89,21 +90,27 @@ def ve_sde(x, t):
return drift, diffusion
def reverse_sde_old(model, sde, x, t, next_t):
def reverse_sde_old(model, sde, x, t, next_t, dt):
drift, diffusion = ve_sde(x, t)
score = mutils.score_fn(model, sde, x, t, False)
drift = drift - diffusion[:, None, None, None] ** 2 * score
return score, drift, diffusion
drift = x - (drift - diffusion[:, None, None, None] ** 2 * score) * dt
return score, drift, (diffusion * torch.sqrt(dt))[:, None, None, None]
def reverse_sde(model, sde, x, t, next_t, nextnext_t):
def reverse_sde(model, sde, x, t, next_t):
dt = (t - next_t)[:, None, None, None]
score = mutils.score_fn(model, sde, x, t, False)
sigma_min = 0.01
sigma_max = 50.0
sigma_i2 = (sigma_min * (sigma_max /sigma_min) ** t) ** 2
sigma_next_i2 = (sigma_min * (sigma_max /sigma_min) ** next_t) ** 2
sigma_nextnext_i2 = (sigma_min * (sigma_max /sigma_min) ** nextnext_t) ** 2
drift = x + (sigma_i2 - sigma_next_i2)[:, None, None, None] * score
diffusion = torch.sqrt(sigma_next_i2 - sigma_nextnext_i2)
# = diffusion^2 in old method ve_sde
sigma_delta = (2 * (sigma_min * (sigma_max / sigma_min) ** t) ** 2 * math.log(sigma_max / sigma_min))[:, None, None, None]
#sigma_i2 = (sigma_min * (sigma_max /sigma_min) ** t) ** 2
#sigma_next_i2 = (sigma_min * (sigma_max /sigma_min) ** next_t) ** 2
#sigma_delta = (sigma_i2 - sigma_next_i2)[:, None, None, None] / dt
drift = x + sigma_delta * dt * score
diffusion = torch.sqrt(sigma_delta * dt)
return score, drift, diffusion
......@@ -154,10 +161,12 @@ def euler_sampler_conditional(sample_dir, step, model, sde, shape, inverse_scale
for run in range(1, 2):
# Nonlinear steps?
eprint("run " + str(run))
steps = 100
steps = 400
iterations = 10
eprint("steps=" + str(steps) + ", iters=" + str(iterations))
timesteps = torch.linspace(eps, sde.T, steps, device=device) ** 1.5
timesteps = torch.linspace(eps, sde.T, steps, device=device)
#timesteps[timesteps<0.5] = ((timesteps[timesteps<0.5] * 2.0) ** 0.5) / 2.0
#timesteps[timesteps>0.5] = ((timesteps[timesteps>0.5] * 2.0 - 1.0) ** 2.0) / 2.0 + 0.5
total_results = None
for samplei in range(10, 10+iterations):
# Initial sample
......@@ -182,9 +191,9 @@ def euler_sampler_conditional(sample_dir, step, model, sde, shape, inverse_scale
# Euler sampler
z = torch.randn_like(x)
score, drift, diffusion = reverse_sde(model, sde, x, vec_t, vec_next_t, vec_nextnext_t)
new_x = drift + diffusion[:, None, None, None] * z
# score, drift, diffusion = reverse_sde_old(model, sde, x, vec_t, vec_next_t, 1.0 / steps)
score, drift, diffusion = reverse_sde(model, sde, x, vec_t, vec_next_t)
new_x = drift + diffusion * z
# PC sampler
"""
x, x_mean = langevin_update_fn(model, sde, x, vec_t, snr, n_steps)
......@@ -192,13 +201,13 @@ def euler_sampler_conditional(sample_dir, step, model, sde, shape, inverse_scale
"""
# Gradient step
timestep = (t * (sde.N - 1) / sde.T).long()
sigma_min = 0.01
sigma_max = 50.0
sigma = sigma_min * (sigma_max /sigma_min) ** t
actual_variance = x.var((0, 1)).sqrt()[0, :].mean()
eprint("target variance=", sigma, " actual=", actual_variance, " error=", 1.0 - sigma / actual_variance)
#new_x = new_x / actual_variance * sigma
#sigma = sde.discrete_sigmas.to(t.device)[timestep]
x_tweedie = x + sigma*sigma * score
x_tweedie_measured = measure_fn(x_tweedie)
diff = anti_measure_fn(x_tweedie, torch.abs(target_measurements - x_tweedie_measured))
......@@ -206,18 +215,23 @@ def euler_sampler_conditional(sample_dir, step, model, sde, shape, inverse_scale
losses = torch.sum(losses, (-1, -2), keepdim=True)
lossessum = torch.sum(losses)
x_grad = torch.autograd.grad(lossessum, x)[0]
new_x -= 1.0 * x_grad
#sigma_delta = (2 * (sigma_min * (sigma_max / sigma_min) ** t) ** 2 * math.log(sigma_max / sigma_min))
#dt = (t - next_t)
new_x -= 1.0 * x_grad #/ losses.sqrt()
# Manifold constraint
# Take phase from new_x and amplitude from y_t
# TODO: not every time, maybe every 10 iters?
if i < 60:
if True or i < 60:
score2 = mutils.score_fn(model, sde, new_x, vec_next_t, False)
sigma2 = sigma_min * (sigma_max /sigma_min) ** next_t
x_tweedie2 = new_x + sigma2*sigma2 * score2
new_x = anti_measure_fn(x_tweedie2, target_measurements)
#new_x = anti_measure_fn(x_tweedie2, target_measurements)
new_x = x_tweedie2 # for renoising in deblurring
z = torch.randn_like(x)
new_x = new_x + diffusion[:, None, None, None]*z
new_x = new_x + sigma2*z
#z = torch.randn_like(x)
#new_x = anti_measure_fn(new_x, measure_fn(sigma*z + targets))
......@@ -226,6 +240,7 @@ def euler_sampler_conditional(sample_dir, step, model, sde, shape, inverse_scale
if i == steps-1:
save_sample(sample_dir, 0, steps-i, inverse_scaler(targets.detach()))
save_sample(sample_dir, 1, steps-i, inverse_scaler(target_measurements.detach()))
if i % 10 == 5:
save_sample(sample_dir, 1000, steps-i, inverse_scaler(x_tweedie.detach()))
# save_sample(sample_dir, 1000, steps-i, inverse_scaler(new_x.detach()))
......@@ -233,7 +248,7 @@ def euler_sampler_conditional(sample_dir, step, model, sde, shape, inverse_scale
#elif (i-1) % 25 == 0:
# save_sample(sample_dir, step, steps-i, inverse_scaler(measure_fn(x_tweedie).detach()))
result = x
result = x_tweedie
result_measured = measure_fn(x_tweedie) # Use tweedie here so it's not constrained
measurement_error = mse(result_measured, target_measurements)
if total_results == None or total_results_measurement_error == None:
......@@ -321,8 +336,6 @@ def img2gray(rgb):
if rgb.shape[2] == 1:
return rgb.squeeze(2)
eprint(rgb.shape)
eprint(rgb[...,:3].shape)
return rgb[...,:3] @ np.array([0.2989, 0.5870, 0.1140])
def cross_correlation(true, pred):
true_gray, pred_gray = img2gray(true), img2gray(pred)
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment