Kernel Methods for Neural Architecture Search
Published:
Summary
The aim of this project was to propose a principled, training-free metric for scoring models on a given dataset, thereby introducing a new strategy for Neural Architecture Search (NAS). The score is defined as the (kernel) canonical correlation (Akaho, 2007; Melzer et al., 2001) between the inputs, \(\mathbf{X}\), and the outputs, \(\mathbf{Y}\), with respect to the Reproducing Kernel Hilbert Spaces (RKHS) corresponding to the Neural Tangent Kernel (NTK) and the linear kernel, respectively. We provide a theoretical motivation for this metric, and aim to validate its efficacy on NAS benchmarks: NAS-Bench-201 (Dong & Yang, 2020) and DARTS (Liu et al., 2019).
Introduction
Selecting Neural Network (NN) architectures for a task has traditionally been a manual, time-intensive process requiring domain expertise and extensive trial-and-error. For example, cross-validation is a popular choice for model selection, and it involves training a number of randomly initialized models. NAS addresses this challenge by automating the discovery of high-performing architectures within a predefined search space (Poyser et al., 2024). Most NAS methods incur a high search cost, either by (partially) training the architectures to score them, by training a *search model* that makes the architecture search efficient, or both. Recently, there has been an increased focus on making the search process cheaper while retaining or improving the quality of the selected architectures. Some of these approaches are NTK-based, such as TE-NAS (Chen et al., 2021), which uses the condition number of the NTK Gram matrix, and KNAS (Xu et al., 2021), which uses the mean of the Gram matrix's entries. These studies report competitive performance on real-world computer vision tasks in NAS benchmarks (Dong & Yang, 2020; Liu et al., 2019; Ying et al., 2019).
Existing NTK-based NAS methods measure the efficacy of their strategy by the correlation between architecture scores and the corresponding test accuracy, and report success on that basis. Motivated by this, we aim to propose a scoring method that is closely correlated with the training loss. Like other NTK-based methods, it uses the Gram matrix associated with the NTK, but for the purpose of computing the kernel canonical correlation (KCC) (Akaho, 2007; Melzer et al., 2001) between the inputs and the outputs.
Theory
For simplicity, we assume that we have a univariate regression task at hand. Consider the set of neural network functions, \(\mathcal{A}\), parameterized by some architecture, \(\mathbf{A}\). Most neural networks designed for regression have a linear output layer, and we assume the same for \(\mathbf{A}\). In that case, minimizing the Mean Squared Error (MSE) between the network output and the regression labels is equivalent to maximizing their correlation, since the scale and offset of the output layer can be adjusted post-training to obtain the minimizer of the MSE (Englisch et al., 1994). Accordingly, we define the score for architecture \(\mathbf{A}\) as
\[\begin{aligned} S\left(\mathbf{A}\right) = \max_{f\in\mathcal{A}} \text{Corr}\left(f\left(\mathbf{X}\right), \mathbf{Y}\right) = \max_{f\in\mathcal{A}, g\in\mathcal{L}} \text{Corr}\left(f\left(\mathbf{X}\right), g\left(\mathbf{Y}\right)\right) \end{aligned}\]where \(\mathcal{L}\) is the space of linear functions on \(\mathbf{Y}\). This optimization is, in general, hard to perform. Instead, if we were to optimize over some RKHS, then the optimization is equivalent to performing kernel canonical correlation analysis (KCCA) with the associated reproducing kernel on \(\mathbf{X}\) and the linear kernel on \(\mathbf{Y}\). Accordingly, we seek an RKHS that can approximate the space of functions the network can converge to.
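The link between MSE and correlation used above can be checked numerically. The sketch below is only an illustration on synthetic data: the helper `best_affine_mse` and the variables `f_out` and `y` are hypothetical names, not part of the project. It assumes the output layer can apply any scale and offset, in which case the smallest attainable MSE is \(\text{Var}(\mathbf{Y})\left(1 - \text{Corr}^2\right)\), so a higher (squared) correlation means a lower MSE.

```python
import numpy as np

def best_affine_mse(f_out, y):
    """Smallest MSE reachable by rescaling and shifting f_out (a * f_out + b)."""
    design = np.column_stack([f_out, np.ones_like(f_out)])
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    residual = y - design @ coef
    return float(np.mean(residual ** 2))

rng = np.random.default_rng(0)
f_out = rng.standard_normal(200)             # hypothetical network outputs before the readout is refit
y = 0.7 * f_out + rng.standard_normal(200)   # synthetic regression labels

corr = float(np.corrcoef(f_out, y)[0, 1])
mse = best_affine_mse(f_out, y)
# Closed form for the optimal affine readout (population variance, ddof=0).
predicted = float(np.var(y) * (1.0 - corr ** 2))
```

Here `mse` and `predicted` should agree up to floating-point error, which is the sense in which the two objectives are interchangeable under a linear output layer.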
Neural Tangent Kernel
Consider an \(L\)-layer fully-connected feed-forward network under the NTK parameterization:
\[\begin{aligned} \mathbf{x}^{l+1} = \frac{1}{\sqrt{n_l}} \mathbf{W}^{l+1}\mathbf{x}^l + \mathbf{b}^{l+1} \end{aligned}\]with \(n_l\) being the width of layer \(l\), and all parameters initialized i.i.d. from the standard normal distribution, \(\mathcal{N}\left(0, 1\right)\). We define \(\mathbf{\theta}^l := \text{vec}\left(\left\{\mathbf{W}^l,\mathbf{b}^l\right\}\right)\) as the collection of parameters in layer \(l\), and \(\mathbf{\theta} := \text{vec}\left(\cup_{l=1}^L \mathbf{\theta}^l \right)\) as the collection of all parameters.
The parameter dynamics and the predictive dynamics for this model under gradient flow can be written as:
\[\begin{aligned} \dot{\theta}_t &= -\eta \nabla_{\theta} \mathcal{L}\left(\mathcal{D}; \theta_t\right) = -\eta \nabla_{\theta} f\left(\mathbf{X}; \theta_t\right)^T \nabla_{f} \mathcal{L}(\mathcal{D}; \theta_t) \\ \dot{f}(\mathbf{X}; \theta_t) &= \nabla_{\theta}f(\mathbf{X};\theta_t) \dot{\theta}_t = -\eta \underbrace{\nabla_{\theta} f(\mathbf{X}; \theta_t) \nabla_{\theta} f(\mathbf{X}; \theta_t)^T}_{\triangleq \hat{\Theta}_t(\mathbf{X}, \mathbf{X})} \nabla_{f} \mathcal{L}(\mathcal{D}; \theta_t) \end{aligned}\]where \(\mathcal{D}\) is the training data set, \(\mathcal{L}\) is the loss function, \(\eta\) is the learning rate, and \(\hat{\Theta}_t\) is the Empirical Neural Tangent Kernel (NTK) (Jacot et al., 2018).
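As a concrete illustration of \(\hat{\Theta}_t(\mathbf{X}, \mathbf{X})\), the following minimal sketch builds the parameter Jacobian of a small network by hand and forms the Gram matrix as its outer product. It makes several assumptions that are not in the text: a single hidden layer, a scalar output, a tanh activation, and the helper names `init_params`, `forward`, `jacobian` and `empirical_ntk`, all of which are hypothetical.

```python
import numpy as np

def init_params(d_in, width, rng):
    """Standard-normal initialization, as in the NTK parameterization."""
    return {
        "W1": rng.standard_normal((width, d_in)),
        "b1": rng.standard_normal(width),
        "w2": rng.standard_normal(width),
        "b2": rng.standard_normal(),
    }

def forward(params, X):
    """Scalar output for each row of X (assumed tanh hidden layer)."""
    d_in, width = X.shape[1], params["w2"].shape[0]
    h = X @ params["W1"].T / np.sqrt(d_in) + params["b1"]
    return np.tanh(h) @ params["w2"] / np.sqrt(width) + params["b2"]

def jacobian(params, X):
    """Row i holds the gradient of f(x_i) w.r.t. all parameters, flattened."""
    m, d_in = X.shape
    width = params["w2"].shape[0]
    h = X @ params["W1"].T / np.sqrt(d_in) + params["b1"]      # (m, width)
    a = np.tanh(h)
    delta = (1.0 - a ** 2) * params["w2"] / np.sqrt(width)     # df/dh, (m, width)
    dW1 = (delta[:, :, None] * X[:, None, :] / np.sqrt(d_in)).reshape(m, -1)
    db1 = delta
    dw2 = a / np.sqrt(width)
    db2 = np.ones((m, 1))
    return np.concatenate([dW1, db1, dw2, db2], axis=1)

def empirical_ntk(params, X):
    """Gram matrix of the gradient features: J @ J.T."""
    J = jacobian(params, X)
    return J @ J.T

rng = np.random.default_rng(0)
X = rng.standard_normal((5, 3))
params = init_params(d_in=3, width=64, rng=rng)
Theta_hat = empirical_ntk(params, X)   # shape (5, 5), symmetric PSD
```

For real architectures the Jacobian would normally come from automatic differentiation rather than hand-written gradients; the hand-derived version is used here only to keep the example dependency-free.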
In the infinite-width limit, the NTK converges in distribution to an analytical limit, \(\Theta\), and the NNs evolve as linear models (Lee et al., 2019). Under gradient flow, the predictive distribution of this wide network converges to a normal distribution (Lee et al., 2019), \(f^{\text{lin}}_{\theta_{\infty}}\left(x\right) \sim \mathcal{N}\left(\mu_{\text{NN}}\left(x\right),\Sigma_{\text{NN}}\left(x,x\right)\right)\), where
\[\begin{aligned} &\mu_{\text{NN}}\left(x\right) = \Theta\left(x,\mathbf{X}\right) \Theta\left(\mathbf{X}\right)^{-1} \mathbf{Y} \\ &\begin{split} \Sigma_{\text{NN}}\left(x,x'\right) &= \mathcal{K}\left(x,x'\right) + \Theta\left(x,\mathbf{X}\right) \Theta\left(\mathbf{X}\right)^{-1} \mathcal{K}\left(\mathbf{X}\right) \Theta\left(\mathbf{X}\right)^{-1} \Theta\left(\mathbf{X},x'\right) \\ &\quad - \left(\Theta\left(x,\mathbf{X}\right) \Theta\left(\mathbf{X}\right)^{-1} \mathcal{K}\left(\mathbf{X},x'\right) + \mathcal{K}\left(x,\mathbf{X}\right) \Theta\left(\mathbf{X}\right)^{-1} \Theta\left(\mathbf{X},x'\right)\right) \end{split} \end{aligned}\]Here, \(\mathcal{K}\) denotes the NN-GP kernel (Matthews et al., 2018), defined as \(\mathcal{K}\left(x,x'\right) = \mathbb{E} \left[f_{\theta}\left(x\right) \cdot f_{\theta}\left(x'\right)\right]\), which also converges in the infinite-width limit.
The covariance \(\Sigma_{\text{NN}}\) is inconvenient to deal with, involving two computationally expensive kernel computations, and a series of cubic-time matrix operations. To tackle this, we can augment the forward pass (denoted by \(\tilde{f}_{\theta}\)) by adding a random, untrainable function, which results in the distribution at convergence having a GP-posterior-like form, with \(\Theta\) as the covariance kernel (He et al., 2020), \(\tilde{f}_{\theta_{\infty}} \sim \mathcal{N}\left(\mu_{\text{NTK}},\Sigma_{\text{NTK}}\right)\), where \(\mu_{\text{NTK}} = \mu_{\text{NN}}\) and:
\[\begin{aligned} \Sigma_{\text{NTK}}\left(x,x'\right) = \Theta\left(x,x'\right) - \Theta\left(x,\mathbf{X}\right) \Theta\left(\mathbf{X}\right)^{-1} \Theta\left(\mathbf{X},x'\right) \end{aligned}\]Importantly, in my Bachelor's thesis project (Hemachandra et al., 2023), we showed that the ratio between \(\Sigma_{\text{NN}}\left(x,x'\right)\) and \(\Sigma_{\text{NTK}}\left(x,x'\right)\) can be tightly upper bounded, and hence, the NTK-GP posterior, \(\mathcal{N}\left(\mu_{\text{NTK}}\left(x\right),\Sigma_{\text{NTK}}\left(x,x\right)\right)\), may be considered a reasonable approximation for the predictive distribution, \(\mathcal{N}\left(\mu_{\text{NN}}\left(x\right),\Sigma_{\text{NN}}\left(x,x\right)\right)\).
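The NTK-GP posterior mean and covariance above have the usual noise-free GP-regression form, so they can be computed from precomputed Gram blocks. The sketch below is a minimal illustration under stated assumptions: `ntk_gp_posterior` and `stand_in_kernel` are hypothetical names, the kernel is a random-feature stand-in rather than an actual NTK, and a small `jitter` is added to the training Gram matrix for numerical stability.

```python
import numpy as np

def ntk_gp_posterior(K_train, K_cross, K_test, Y, jitter=1e-10):
    """Posterior mean and covariance from Gram blocks.

    K_train: Theta(X, X), shape (m, m)
    K_cross: Theta(x*, X), shape (k, m)
    K_test:  Theta(x*, x*), shape (k, k)
    """
    m = K_train.shape[0]
    chol = np.linalg.cholesky(K_train + jitter * np.eye(m))

    def solve(B):
        # Applies Theta(X)^{-1} via the Cholesky factor instead of an explicit inverse.
        return np.linalg.solve(chol.T, np.linalg.solve(chol, B))

    mean = K_cross @ solve(Y)
    cov = K_test - K_cross @ solve(K_cross.T)
    return mean, cov

rng = np.random.default_rng(0)
feature_map = rng.standard_normal((3, 40))   # assumption: stand-in for NTK gradient features

def stand_in_kernel(A, B):
    return np.tanh(A @ feature_map) @ np.tanh(B @ feature_map).T

X = rng.standard_normal((8, 3))
Y = rng.standard_normal(8)
X_new = rng.standard_normal((4, 3))

mean, cov = ntk_gp_posterior(
    stand_in_kernel(X, X), stand_in_kernel(X_new, X), stand_in_kernel(X_new, X_new), Y
)
# Evaluating at the training inputs: the mean interpolates Y and the covariance vanishes.
K = stand_in_kernel(X, X)
mean_train, cov_train = ntk_gp_posterior(K, K, K, Y)
```

Evaluating at the training inputs gives a quick sanity check, since the posterior mean should reproduce \(\mathbf{Y}\) and the posterior covariance should collapse to (numerically) zero there.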
Kernel Canonical Correlation Analysis
We now consider the more practical architectures which have finite width, so that the feature mapping, \(x \mapsto \nabla_{\theta}f_{\theta}\left(x\right)\), associated with the empirical NTK, \(\Theta\), is finite dimensional. Hence, the samples from the NTK-GP prior are almost-surely contained in the RKHS, \(\mathcal{H}_{\Theta}\), associated with \(\Theta\) (<span style="color:red">not sure about this part</span>). Since the GP posterior does not have support where the prior does not, the posterior samples are also contained in this RKHS. Therefore, we can use the RKHS associated with the NTK to compute the KCC:
\[\begin{aligned} S\left(\mathbf{A}\right) \approx \max_{f\in\mathcal{H}_{\Theta}, g\in\mathcal{L}} \text{Corr}\left(f\left(\mathbf{X}\right), g\left(\mathbf{Y}\right)\right) \end{aligned}\]This value is trivially equal to \(\pm 1\) when the kernel matrices associated with \(\mathbf{X}\) and \(\mathbf{Y}\) are full-rank (Gretton, Herbrich, et al., 2005). A common practice is to add some regularization to this problem by penalizing rougher witness functions, \(f\) and \(g\), which yields the following generalized-eigenvalue problem:
\[\begin{aligned} \begin{bmatrix} 0 & \tilde{\Theta}\tilde{\mathbf{L}} \\ \tilde{\mathbf{L}}\tilde{\Theta} & 0 \end{bmatrix} \mathbf{u} = \lambda \begin{bmatrix} \tilde{\Theta}^2 + m\epsilon\tilde{\Theta} & 0 \\ 0 & \tilde{\mathbf{L}}^2 + m\epsilon\tilde{\mathbf{L}} \end{bmatrix} \mathbf{u} \end{aligned}\]where \(\tilde{\Theta} = \mathbf{H}\Theta\left(\mathbf{X}\right)\mathbf{H}\) and \(\tilde{\mathbf{L}} = \mathbf{H}\mathbf{Y}\mathbf{Y}^T\mathbf{H}\) are the centered Gram matrices, \(\mathbf{H} = \mathbf{I}_m - \frac{1}{m} \mathbf{1}_{m\times m}\) is the centering matrix, \(m\) is the number of data points, and \(\epsilon\) is the regularization constant. The regularized canonical correlation is the maximum eigenvalue of this problem, \(\gamma = \lambda_{\text{max}}\).
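One way to evaluate \(\lambda_{\text{max}}\) with NumPy alone is sketched below. It is an illustration, not the project's implementation: the names `_whitener` and `regularized_kcc` are hypothetical, and the Gram matrix is built from random features as a stand-in for the NTK. Because the centered matrices are singular, the right-hand block-diagonal matrix is whitened only on its numerical range; the largest eigenvalue of the block problem is then the largest singular value of the whitened cross term.

```python
import numpy as np

def _whitener(B, rtol=1e-10):
    """Return P with P.T @ B @ P = I on the numerical range of the PSD matrix B."""
    w, V = np.linalg.eigh((B + B.T) / 2)
    keep = w > rtol * w.max()
    return V[:, keep] / np.sqrt(w[keep])

def regularized_kcc(Theta, Y, eps=1e-3):
    """Largest eigenvalue of the regularized KCCA problem (linear kernel on Y)."""
    m = Theta.shape[0]
    Y = Y.reshape(m, -1)
    H = np.eye(m) - np.ones((m, m)) / m          # centering matrix
    Theta_c = H @ Theta @ H
    L_c = H @ (Y @ Y.T) @ H
    P_x = _whitener(Theta_c @ Theta_c + m * eps * Theta_c)
    P_y = _whitener(L_c @ L_c + m * eps * L_c)
    # Eigenvalues of [[0, C], [C.T, 0]] are +/- the singular values of C.
    C = P_x.T @ Theta_c @ L_c @ P_y
    return float(np.linalg.svd(C, compute_uv=False)[0])

rng = np.random.default_rng(0)
m = 20
features = rng.standard_normal((m, 50))   # assumption: stand-in for NTK gradient features
Theta = features @ features.T
Y = rng.standard_normal(m)

score_weak = regularized_kcc(Theta, Y, eps=1e-6)    # nearly unregularized
score_strong = regularized_kcc(Theta, Y, eps=1.0)   # heavily regularized
```

With almost no regularization and a full-rank Gram matrix, the score should sit very close to 1, which matches the degenerate case described above; increasing \(\epsilon\) pulls it below 1 and makes it informative. A constant \(\mathbf{Y}\) is not handled by this sketch.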
Limitations and Extensions
The proposed scoring function is expected to be over-confident for two reasons:
First, using \(\Theta\left(\mathbf{X}\right)\) means that we are optimizing the correlation over the NTK-GP prior's support, which is larger than the posterior's support.
Second, the witness function, \(f\), used for computing the canonical correlation is the best NN fitting the training set. This amounts to ignoring the probabilistic information in the prior/posterior altogether. Therefore, it could represent an over-fitting scenario.
Using a linear kernel might make sense for regression, but needs justification for classification tasks. What even is an appropriate kernel in the classification case? This is an important issue because most NAS benchmarks involve image classification tasks, like CIFAR-10/100 and ImageNet.
Another limitation is the lack of interpretability of KCC, which is crucial in many real-world applications. To address this limitation, we may explore alternative kernel-based measures, such as those based on the Hilbert-Schmidt Independence Criterion (HSIC) (Gretton, Bousquet, et al., 2005) and the Kernel Target Alignment (KTA) (Cortes et al., 2012), as proposed by Chang et al. (2013).
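For reference, both alternative measures can be computed from the same pair of Gram matrices. The sketch below uses the biased empirical HSIC estimator, \(\text{tr}\left(\mathbf{K}\mathbf{H}\mathbf{L}\mathbf{H}\right)/(m-1)^2\), and the centered kernel alignment. The function names and the random-feature Gram matrix are assumptions made for illustration; how either quantity would be turned into an architecture score is left open here.

```python
import numpy as np

def center(K):
    """Centered Gram matrix H K H."""
    m = K.shape[0]
    H = np.eye(m) - np.ones((m, m)) / m
    return H @ K @ H

def hsic_biased(K, L):
    """Biased empirical HSIC: trace(K H L H) / (m - 1)^2."""
    m = K.shape[0]
    return float(np.trace(center(K) @ L)) / (m - 1) ** 2

def centered_alignment(K, L):
    """Centered kernel alignment: <Kc, Lc>_F / (||Kc||_F * ||Lc||_F)."""
    Kc, Lc = center(K), center(L)
    return float(np.sum(Kc * Lc) / (np.linalg.norm(Kc) * np.linalg.norm(Lc)))

rng = np.random.default_rng(0)
features = rng.standard_normal((20, 50))   # assumption: stand-in for NTK gradient features
K = features @ features.T
y = rng.standard_normal(20)
L = np.outer(y, y)                         # linear kernel on the targets

hsic_value = hsic_biased(K, L)
kta_value = centered_alignment(K, L)
```

Unlike the KCC, neither quantity requires solving an eigenvalue problem, and the alignment is normalized to lie in \([0, 1]\) for positive semi-definite kernels.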