Denoising score matching offers an alternative that bypasses the computational challenge of vanilla score matching.
Denoising score matching takes its name from the procedure used to make score matching scalable: noise is added to the data points, which removes the need to compute the trace of the Jacobian term.
To perform denoising score matching we first define a perturbation kernel, denoted \( q_\sigma \).
Let \( \mathbf{x} \) be a noise-free data point, \( \mathbf{\tilde{x}} \) a perturbed data point, and \( \sigma \) the standard deviation of the (typically Gaussian) kernel. Convolving the perturbation kernel with the original data distribution \( p_{\text{data}}(\mathbf{x}) \) gives a noisy data distribution \( q_\sigma (\mathbf{\tilde{x}}) \).
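For example, with a Gaussian perturbation kernel of standard deviation \( \sigma \) (the usual choice), the noisy distribution is the marginal
\[
q_\sigma(\mathbf{\tilde{x}} \vert \mathbf{x}) = \mathcal{N}(\mathbf{\tilde{x}}; \mathbf{x}, \sigma^2 \mathbf{I}), \qquad q_\sigma(\mathbf{\tilde{x}}) = \int p_{\text{data}}(\mathbf{x}) \, q_\sigma(\mathbf{\tilde{x}} \vert \mathbf{x}) \, d\mathbf{x}.
\]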
The key idea behind denoising score matching is to estimate the score function of the noisy data density instead of the score function of the original data density:
\[
\frac{1}{2} \mathbb{E}_{q_\sigma(\mathbf{\tilde{x}})} \left[ \left\| \nabla_\mathbf{\tilde{x}} \log q_\sigma(\mathbf{\tilde{x}}) - s_\theta(\mathbf{\tilde{x}}) \right\|^2_2 \right].
\tag{Eq. 11} \label{eq:denoising-score-matching}
\]
Since the score of the noisy data distribution cannot be evaluated directly, we work with an equivalent form obtained after some algebraic manipulation (derived below). This equivalent form is the objective of denoising score matching:
\[
\frac{1}{2} \mathbb{E}_{p_{\text{data}}(\mathbf{x})} \mathbb{E}_{q_\sigma (\mathbf{\tilde{x}} \vert \mathbf{x})} \left[ \left\| \nabla_\mathbf{\tilde{x}} \log q_\sigma(\mathbf{\tilde{x}} \vert \mathbf{x}) - s_\theta(\mathbf{\tilde{x}}) \right\|^2_2 \right].
\tag{Eq. 12} \label{eq:denoising-score-matching-objective}
\]
The gradient of the perturbation kernel $\nabla_\mathbf{\tilde{x}} \log q_\sigma(\mathbf{\tilde{x}} \vert \mathbf{x})$ is fully tractable and cheap to evaluate because we define the perturbation kernel by hand.
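For the Gaussian kernel above, for instance, this gradient has a simple closed form that involves no neural network and no Jacobian trace:
\[
\nabla_\mathbf{\tilde{x}} \log q_\sigma(\mathbf{\tilde{x}} \vert \mathbf{x}) = - \frac{\mathbf{\tilde{x}} - \mathbf{x}}{\sigma^2}.
\]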
Fig. 15. Score matching objectives.
The tradeoff of denoising score matching is that, since it requires adding noise to the data points, it can only estimate the score of the noisy distribution, not of the noise-free one. As we lower the magnitude of the noise, the variance of the objective grows larger and larger and eventually explodes. Therefore, denoising score matching should not be used for noise-free score estimation.
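To see why, note that for the Gaussian kernel a perturbed sample can be written as \( \mathbf{\tilde{x}} = \mathbf{x} + \sigma \mathbf{z} \) with \( \mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \), so the regression target becomes
\[
\nabla_\mathbf{\tilde{x}} \log q_\sigma(\mathbf{\tilde{x}} \vert \mathbf{x}) = - \frac{\mathbf{\tilde{x}} - \mathbf{x}}{\sigma^2} = - \frac{\mathbf{z}}{\sigma},
\]
whose magnitude grows like \( 1/\sigma \); the Monte Carlo estimate of the loss therefore becomes increasingly high-variance as \( \sigma \to 0 \).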
Applying first the definition of expectation in continuous probability and then expanding the squared norm \( \| \mathbf{a} - \mathbf{b} \|_2^2 = \| \mathbf{a} \|_2^2 + \| \mathbf{b} \|_2^2 - 2 \mathbf{a}^\top \mathbf{b} \), we can derive the objective:
\[
\begin{align}
\frac{1}{2} \mathbb{E}_{\tilde{\mathbf{x}} \sim q_\sigma} \left[ \left\| \nabla_{\tilde{\mathbf{x}}} \log q_\sigma(\tilde{\mathbf{x}}) - s_\theta(\tilde{\mathbf{x}}) \right\|_2^2 \right] \tag{Eq. 13.1} \\
&= \frac{1}{2} \int q_\sigma(\tilde{\mathbf{x}}) \left\| \nabla_{\tilde{\mathbf{x}}} \log q_\sigma(\tilde{\mathbf{x}}) - s_\theta(\tilde{\mathbf{x}}) \right\|_2^2 d\tilde{\mathbf{x}} \tag{Eq. 13.2} \\
&= \frac{1}{2} \int q_\sigma(\tilde{\mathbf{x}}) \left\| \nabla_{\tilde{\mathbf{x}}} \log q_\sigma(\tilde{\mathbf{x}}) \right\|_2^2 d\tilde{\mathbf{x}}
+ \frac{1}{2} \int q_\sigma(\tilde{\mathbf{x}}) \left\| s_\theta(\tilde{\mathbf{x}}) \right\|_2^2 d\tilde{\mathbf{x}}
- \int q_\sigma(\tilde{\mathbf{x}}) \nabla_{\tilde{\mathbf{x}}} \log q_\sigma(\tilde{\mathbf{x}})^\top s_\theta(\tilde{\mathbf{x}}) d\tilde{\mathbf{x}}. \tag{Eq. 13.3}
\end{align}
\label{eq:denoising-score-matching-objective-expanded}
\]
where we can write the third term as an expectation:
\[
\begin{align}
- \int q_\sigma(\tilde{\mathbf{x}}) \nabla_{\tilde{\mathbf{x}}} \log q_\sigma(\tilde{\mathbf{x}})^\top \mathbf{s}_\theta(\tilde{\mathbf{x}}) \, d\tilde{\mathbf{x}} \tag{Eq. 14.1} \\
&= - \int \nabla_{\tilde{\mathbf{x}}} q_\sigma(\tilde{\mathbf{x}})^\top \mathbf{s}_\theta(\tilde{\mathbf{x}}) \, d\tilde{\mathbf{x}} \tag{Eq. 14.2} \\
&= - \int \nabla_{\tilde{\mathbf{x}}} \left( \int p_\text{data}(\mathbf{x}) q_\sigma(\tilde{\mathbf{x}} | \mathbf{x}) \, d\mathbf{x} \right)^\top \mathbf{s}_\theta(\tilde{\mathbf{x}}) \, d\tilde{\mathbf{x}} \tag{Eq. 14.3} \\
&= - \int \left( \int p_\text{data}(\mathbf{x}) \nabla_{\tilde{\mathbf{x}}} q_\sigma(\tilde{\mathbf{x}} | \mathbf{x}) \, d\mathbf{x} \right)^\top \mathbf{s}_\theta(\tilde{\mathbf{x}}) \, d\tilde{\mathbf{x}} \tag{Eq. 14.4} \\
&= - \int \left( \int p_\text{data}(\mathbf{x}) q_\sigma(\tilde{\mathbf{x}} | \mathbf{x}) \nabla_{\tilde{\mathbf{x}}} \log q_\sigma(\tilde{\mathbf{x}} | \mathbf{x}) \, d\mathbf{x} \right)^\top \mathbf{s}_\theta(\tilde{\mathbf{x}}) \, d\tilde{\mathbf{x}} \tag{Eq. 14.5} \\
&= - \int \int p_\text{data}(\mathbf{x}) q_\sigma(\tilde{\mathbf{x}} | \mathbf{x}) \nabla_{\tilde{\mathbf{x}}} \log q_\sigma(\tilde{\mathbf{x}} | \mathbf{x})^\top \mathbf{s}_\theta(\tilde{\mathbf{x}}) \, d\mathbf{x} \, d\tilde{\mathbf{x}} \tag{Eq. 14.6} \\
&= - \mathbb{E}_{\mathbf{x} \sim p_\text{data}(\mathbf{x}), \tilde{\mathbf{x}} \sim q_\sigma(\tilde{\mathbf{x}} | \mathbf{x})} \left[ \nabla_{\tilde{\mathbf{x}}} \log q_\sigma(\tilde{\mathbf{x}} | \mathbf{x})^\top \mathbf{s}_\theta(\tilde{\mathbf{x}}) \right], \tag{Eq. 14.7}
\end{align}
\]
where Eq. 14.2 and Eq. 14.5 use the identity \( \nabla_{\tilde{\mathbf{x}}} q_\sigma = q_\sigma \nabla_{\tilde{\mathbf{x}}} \log q_\sigma \), and Eq. 14.3 uses the marginal \( q_\sigma(\tilde{\mathbf{x}}) = \int p_\text{data}(\mathbf{x}) q_\sigma(\tilde{\mathbf{x}} | \mathbf{x}) \, d\mathbf{x} \).
With this derivation, we can rewrite Eq. 13.1 as:
\[
\begin{align}
\frac{1}{2} \mathbb{E}_{\tilde{\mathbf{x}} \sim q_\sigma} \left[ \left\| \nabla_{\tilde{\mathbf{x}}} \log q_\sigma(\tilde{\mathbf{x}}) - s_\theta(\tilde{\mathbf{x}}) \right\|_2^2 \right] \tag{Eq. 15.1} \\
&= \text{const.} + \frac{1}{2} \mathbb{E}_{\tilde{\mathbf{x}} \sim q_\sigma} \left[ \| \mathbf{s}_\theta(\tilde{\mathbf{x}}) \|_2^2 \right] - \int q_\sigma(\tilde{\mathbf{x}}) \nabla_{\tilde{\mathbf{x}}} \log q_\sigma(\tilde{\mathbf{x}})^\top \mathbf{s}_\theta(\tilde{\mathbf{x}}) \, d\tilde{\mathbf{x}} \tag{Eq. 15.2} \\
&= \text{const.} + \frac{1}{2} \mathbb{E}_{\tilde{\mathbf{x}} \sim q_\sigma} \left[ \| \mathbf{s}_\theta(\tilde{\mathbf{x}}) \|_2^2 \right] - \mathbb{E}_{\mathbf{x} \sim p_\text{data}(\mathbf{x}), \tilde{\mathbf{x}} \sim q_\sigma(\tilde{\mathbf{x}} | \mathbf{x})} \left[ \nabla_{\tilde{\mathbf{x}}} \log q_\sigma(\tilde{\mathbf{x}} | \mathbf{x})^\top \mathbf{s}_\theta(\tilde{\mathbf{x}}) \right] \tag{Eq. 15.3} \\
\end{align}
\]
Completing the square by adding and subtracting the squared norm of the gradient term, which does not depend on \( \theta \), we get:
\[
\begin{align}
&= \text{const.} + \frac{1}{2} \mathbb{E}_{\mathbf{x} \sim p_\text{data}(\mathbf{x}), \tilde{\mathbf{x}} \sim q_\sigma(\tilde{\mathbf{x}} | \mathbf{x})} \left[ \| \mathbf{s}_\theta(\tilde{\mathbf{x}}) - \nabla_{\tilde{\mathbf{x}}} \log q_\sigma(\tilde{\mathbf{x}} | \mathbf{x}) \|_2^2 \right] - \frac{1}{2} \mathbb{E}_{\mathbf{x} \sim p_\text{data}(\mathbf{x}), \tilde{\mathbf{x}} \sim q_\sigma(\tilde{\mathbf{x}} | \mathbf{x})} \left[ \| \nabla_{\tilde{\mathbf{x}}} \log q_\sigma(\tilde{\mathbf{x}} | \mathbf{x}) \|_2^2 \right] \tag{Eq. 16.1} \\
&= \text{const.} + \frac{1}{2} \mathbb{E}_{\mathbf{x} \sim p_\text{data}(\mathbf{x}), \tilde{\mathbf{x}} \sim q_\sigma(\tilde{\mathbf{x}} | \mathbf{x})} \left[ \| \mathbf{s}_\theta(\tilde{\mathbf{x}}) - \nabla_{\tilde{\mathbf{x}}} \log q_\sigma(\tilde{\mathbf{x}} | \mathbf{x}) \|_2^2 \right] + \text{const.} \tag{Eq. 16.2} \\
&= \frac{1}{2} \mathbb{E}_{\mathbf{x} \sim p_\text{data}(\mathbf{x}), \tilde{\mathbf{x}} \sim q_\sigma(\tilde{\mathbf{x}} | \mathbf{x})} \left[ \| \mathbf{s}_\theta(\tilde{\mathbf{x}}) - \nabla_{\tilde{\mathbf{x}}} \log q_\sigma(\tilde{\mathbf{x}} | \mathbf{x}) \|_2^2 \right] + \text{const.} \tag{Eq. 16.3}
\end{align}
\]
As a conclusion, denoising score matching can be summarized as follows:
- Sample a minibatch of datapoints: \( \{ \mathbf{x}_1, \mathbf{x}_2, \dots, \mathbf{x}_n \} \sim p_{\text{data}}(\mathbf{x}) \)
- Sample a minibatch of perturbed datapoints: \( \{ \mathbf{\tilde{x}}_1, \mathbf{\tilde{x}}_2, \dots, \mathbf{\tilde{x}}_n \} \sim q_\sigma(\mathbf{\tilde{x}}) \) with \( \mathbf{\tilde{x}}_i \sim q_\sigma (\mathbf{\tilde{x}}_i \vert \mathbf{x}_i) \)
- Estimate the denoising score matching loss with empirical means:
\[
= \frac{1}{2n} \sum_{i=1}^n \left[ \| \mathbf{s}_\theta(\tilde{\mathbf{x}}_i) - \nabla_{\tilde{\mathbf{x}}} \log q_\sigma(\tilde{\mathbf{x}}_i | \mathbf{x}_i) \|_2^2 \right].
\tag{Eq. 17} \label{eq:denoising-score-matching-loss}
\]
In the special case of a Gaussian perturbation kernel, \( \nabla_{\tilde{\mathbf{x}}} \log q_\sigma(\tilde{\mathbf{x}} | \mathbf{x}) = -(\tilde{\mathbf{x}} - \mathbf{x})/\sigma^2 \), so the loss becomes:
\[
= \frac{1}{2n} \sum_{i=1}^n \left[ \left\| \mathbf{s}_\theta(\tilde{\mathbf{x}}_i) + \frac{\tilde{\mathbf{x}}_i - \mathbf{x}_i}{\sigma^2} \right\|_2^2 \right].
\tag{Eq. 18} \label{eq:gaussian-perturbation}
\]
- Apply stochastic gradient descent on this loss (see the sketch after this list).
- Choose a small \( \sigma \) so that \( q_\sigma(\mathbf{\tilde{x}}) \approx p_{\text{data}}(\mathbf{x}) \), but not too small, because the variance of the objective explodes as \( \sigma \to 0 \).
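Putting these steps together, here is a minimal PyTorch sketch of the procedure with a Gaussian perturbation kernel. The network architecture, the toy data distribution, and all hyperparameters (`sigma`, learning rate, batch size) are illustrative placeholders, not part of the derivation above.

```python
import torch
import torch.nn as nn


class ScoreNet(nn.Module):
    """Hypothetical score network s_theta: maps a perturbed point to an
    estimate of the score (a vector of the same dimension)."""

    def __init__(self, dim: int, hidden: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x_tilde: torch.Tensor) -> torch.Tensor:
        return self.net(x_tilde)


def dsm_loss(score_net: nn.Module, x: torch.Tensor, sigma: float) -> torch.Tensor:
    """Empirical denoising score matching loss (Eq. 18).

    Perturb x with Gaussian noise: x_tilde = x + sigma * z, z ~ N(0, I).
    The kernel score is then grad log q_sigma(x_tilde | x) = -(x_tilde - x) / sigma^2.
    """
    z = torch.randn_like(x)
    x_tilde = x + sigma * z
    target = -(x_tilde - x) / sigma**2          # tractable kernel score
    s = score_net(x_tilde)                      # s_theta(x_tilde)
    return 0.5 * ((s - target) ** 2).sum(dim=-1).mean()


if __name__ == "__main__":
    dim, sigma = 2, 0.1
    score_net = ScoreNet(dim)
    opt = torch.optim.Adam(score_net.parameters(), lr=1e-3)
    for step in range(1_000):
        x = torch.randn(128, dim) * 2.0 + 1.0   # minibatch from a toy stand-in for p_data
        loss = dsm_loss(score_net, x, sigma)
        opt.zero_grad()
        loss.backward()
        opt.step()
```

As the derivation predicts, shrinking `sigma` makes the regression target `-(x_tilde - x) / sigma**2 = -z / sigma` blow up, which is exactly the variance issue discussed above.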