From 019c0c733fd187cd55c14830d82cad7d8d2b8056 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Timo=20K=C3=B6sters?= <timo@koesters.xyz>
Date: Sun, 19 Nov 2023 11:39:45 +0100
Subject: [PATCH] Final tests for renoising

---
 configs/default_cifar10_configs.py |  4 ++--
 sample.sh                          |  4 ++--
 sampling.py                        | 30 +++++++++++++++---------------
 3 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/configs/default_cifar10_configs.py b/configs/default_cifar10_configs.py
index c42aeb7..0f6481a 100644
--- a/configs/default_cifar10_configs.py
+++ b/configs/default_cifar10_configs.py
@@ -8,7 +8,7 @@ def get_default_configs():
   config.training = training = ml_collections.ConfigDict()
 
   #config.training.batch_size = 64
-  config.training.batch_size = 4
+  config.training.batch_size = 1
 
   training.n_iters = 1300001
 
@@ -76,4 +76,4 @@ def get_default_configs():
   #config.device = 'cuda:0'
   config.device = 'cpu'
 
-  return config
\ No newline at end of file
+  return config
diff --git a/sample.sh b/sample.sh
index e98f2b4..3507f70 100755
--- a/sample.sh
+++ b/sample.sh
@@ -1,5 +1,5 @@
 #!/bin/bash
 source venv/bin/activate
 
-#TORCH_CUDNN_V8_API_DISABLED=1 CUDA_VISIBLE_DEVICES="0" MIOPEN_LOG_LEVEL=4 python3 main.py --config configs/ve/cifar10_ncsnpp_continuous.py --mode sample --workdir workdir > /dev/null
-TORCH_CUDNN_V8_API_DISABLED=1 CUDA_VISIBLE_DEVICES="0" MIOPEN_LOG_LEVEL=4 python3 main.py --config configs/ve/mnist_ncsnpp_continuous.py --mode sample --workdir workdir_mnist5 > /dev/null
+TORCH_CUDNN_V8_API_DISABLED=1 CUDA_VISIBLE_DEVICES="1" MIOPEN_LOG_LEVEL=4 python3 main.py --config configs/ve/cifar10_ncsnpp_continuous.py --mode sample --workdir workdir > /dev/null
+# TORCH_CUDNN_V8_API_DISABLED=1 CUDA_VISIBLE_DEVICES="1" MIOPEN_LOG_LEVEL=4 python3 main.py --config configs/ve/mnist_ncsnpp_continuous.py --mode sample --workdir workdir_mnist5 > /dev/null
diff --git a/sampling.py b/sampling.py
index c9d0717..0b8c0a5 100644
--- a/sampling.py
+++ b/sampling.py
@@ -161,8 +161,8 @@ def euler_sampler_conditional(sample_dir, step, model, sde, shape, inverse_scale
   for run in range(1, 2):
     # STEPS nichtlinear?
     eprint("run " + str(run))
-    steps = 400
-    iterations = 10
+    steps = 1000
+    iterations = 1
     eprint("steps=" + str(steps) + ", iters=" + str(iterations))
     timesteps = torch.linspace(eps, sde.T, steps, device=device)
     #timesteps[timesteps<0.5] = ((timesteps[timesteps<0.5] * 2.0) ** 0.5) / 2.0
@@ -217,24 +217,24 @@ def euler_sampler_conditional(sample_dir, step, model, sde, shape, inverse_scale
         x_grad = torch.autograd.grad(lossessum, x)[0]
         #sigma_delta = (2 * (sigma_min * (sigma_max / sigma_min) ** t) ** 2 * math.log(sigma_max / sigma_min))
         #dt = (t - next_t)
-        new_x -= 1.0 * x_grad #/ losses.sqrt()
+        new_x -= 0.05 * x_grad #/ losses.sqrt()
 
         # Manifold constraint
         # Take phase from new_x and amplitude from y_t
         # TODO: not every time, maybe every 10 iters?
-        if True or i < 60:
-            score2 = mutils.score_fn(model, sde, new_x, vec_next_t, False)
-            sigma2 = sigma_min * (sigma_max /sigma_min) ** next_t
-            x_tweedie2 = new_x + sigma2*sigma2 * score2
-
-            #new_x = anti_measure_fn(x_tweedie2, target_measurements)
-            new_x = x_tweedie2 # for renoising in deblurring
-
+        # if True or i < 60:
+        if i % 10 == 0:
+            # Renoising
+            # score2 = mutils.score_fn(model, sde, new_x, vec_next_t, False)
+            # sigma2 = sigma_min * (sigma_max /sigma_min) ** next_t
+            # x_tweedie2 = new_x + sigma2*sigma2 * score2
+            # new_x = anti_measure_fn(x_tweedie2, target_measurements)
+            # z = torch.randn_like(x)
+            # new_x = new_x + sigma2*z
+
+            # Hypothetical constraint
             z = torch.randn_like(x)
-            new_x = new_x + sigma2*z
-
-            #z = torch.randn_like(x)
-            #new_x = anti_measure_fn(new_x, measure_fn(sigma*z + targets))
+            new_x = anti_measure_fn(new_x, measure_fn(sigma*z + targets))
 
         x = new_x.detach()
 
-- 
GitLab