Ju Sun (孙举)

Dedicated to my research and life

Geometry of Two-Layer Linear Neural Network


The two-layer linear neural network is arguably the simplest nontrivial neural network. This post fleshes out the details of the seminal paper [1], in which the authors set out to understand the solution space of the optimization problem \[\begin{align} \label{eq:nn_main_global} \min_{\mathbf A, \mathbf B} \; f\left(\mathbf A, \mathbf B\right) = \frac{1}{2}\left\| \mathbf Y - \mathbf A \mathbf B \mathbf X \right\|_{F}^2, \quad \mathbf A \in \mathbb{R}^{n \times p}, \; \mathbf B \in \mathbb{R}^{p \times n}, \end{align}\] where \(\mathbf Y \in \mathbb{R}^{n \times m}\) and \(\mathbf X \in \mathbb{R}^{n \times m}\) are known, with \(m \geq n\). Their columns are the response-input vector pairs used to train the network. When \(p \ll n\) and \(\mathbf Y = \mathbf X\), it is easy to see that any optimum satisfies \(\mathbf A \mathbf B = \mathbf U \mathbf U^\top\), where the columns of \(\mathbf U\) form an orthonormal basis for the subspace spanned by the first \(p\) eigenvectors of \(\mathbf X \mathbf X^\top\) (assuming its \(p\)-th and \((p+1)\)-th largest eigenvalues differ). Hence this formulation is closely related to the well-known principal component analysis (PCA). An interesting aspect of this problem is that, under some conditions, every local minimizer is a global minimizer, and the global minimizer is unique up to the invertible transformations \(\left(\mathbf A, \mathbf B\right) \mapsto \left(\mathbf A \mathbf C, \mathbf C^{-1} \mathbf B\right)\). We will use some short-hand notation in the proof: \(\mathbf \Sigma_{\mathbf X\mathbf X} \doteq \mathbf X \mathbf X^\top\), and similarly \(\mathbf \Sigma_{\mathbf Y \mathbf Y}\), \(\mathbf \Sigma_{\mathbf X \mathbf Y}\) and \(\mathbf \Sigma_{\mathbf Y \mathbf X}\).
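The PCA special case is easy to check numerically. The minimal NumPy sketch below is an illustration under assumed sizes, seed, and random data (the helper name `f` is hypothetical). With \(\mathbf Y = \mathbf X\), the pair \((\mathbf U, \mathbf U^\top)\) built from the top-\(p\) eigenvectors of \(\mathbf X \mathbf X^\top\) attains \(\frac{1}{2}\sum_{j > p} \lambda_j(\mathbf X \mathbf X^\top)\). An equivalent pair \((\mathbf U \mathbf C, \mathbf C^{-1} \mathbf U^\top)\) attains the same value, and random factor pairs never do better.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, p = 5, 20, 2
X = rng.standard_normal((n, m))
Y = X  # autoencoder setting: the target is the input itself


def f(A, B):
    """Objective 0.5 * ||Y - A B X||_F^2."""
    return 0.5 * np.linalg.norm(Y - A @ B @ X, "fro") ** 2


# Eigendecomposition of X X^T with eigenvalues sorted in decreasing order.
evals, evecs = np.linalg.eigh(X @ X.T)
order = np.argsort(evals)[::-1]
evals, evecs = evals[order], evecs[:, order]
U = evecs[:, :p]  # orthonormal basis of the top-p eigenspace

f_pca = f(U, U.T)                  # A B = U U^T
f_pred = 0.5 * evals[p:].sum()     # energy in the discarded directions
C = rng.standard_normal((p, p))    # invertible with probability one
f_equiv = f(U @ C, np.linalg.inv(C) @ U.T)

# Random factor pairs never beat the PCA solution.
f_rand = min(f(rng.standard_normal((n, p)), rng.standard_normal((p, n)))
             for _ in range(100))
print(f_pca, f_pred, f_equiv, f_rand >= f_pca)
```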

Theorem 1 Suppose \(p \leq n\) and \(\mathbf X\) has full row rank, so that \(\mathbf \Sigma_{\mathbf X \mathbf X}\) is invertible. Further assume \(\mathbf \Sigma \doteq \mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf \Sigma_{\mathbf X \mathbf X}^{-1} \mathbf \Sigma _{\mathbf X \mathbf Y}\) is full rank with \(n\) distinct eigenvalues \(\lambda_1 > \cdots > \lambda_n > 0\). Then every local minimizer of \(\eqref{eq:nn_main_global}\) is a global minimizer, and all global minimizers are equivalent up to the invertible transformations \(\left(\mathbf A, \mathbf B\right) \mapsto \left(\mathbf A \mathbf C, \mathbf C^{-1} \mathbf B\right)\).
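Before the proof, here is a numerical illustration of the theorem. It is not a proof, and it is not a method from [1]. The sketch runs alternating least squares (ALS), swapped in here as a simple local method, from a random start. Each B-step solves the stationarity condition in \(\mathbf B\) for fixed \(\mathbf A\), and each A-step does the same for \(\mathbf A\) with \(\mathbf B\) fixed. The synthetic data are an assumption, constructed so that \(\mathbf \Sigma_{\mathbf X \mathbf X} = \mathbf I\) and \(\mathbf \Sigma = \mathbf W \mathbf W^\top\) has known eigenvalues \(d_j^2\). The iterates should reach the value \(\|\mathbf Y\|_F^2/2 - \frac{1}{2}\sum_{j \leq p} \lambda_j\).

```python
import numpy as np

rng = np.random.default_rng(1)
n, m, p = 5, 30, 2

# Hypothetical synthetic data: X X^T = I and Y = W X, so Sigma = W W^T.
V, _ = np.linalg.qr(rng.standard_normal((m, n)))  # m x n, orthonormal columns
X = V.T
Q, _ = np.linalg.qr(rng.standard_normal((n, n)))
R, _ = np.linalg.qr(rng.standard_normal((n, n)))
d = np.array([3.0, 2.0, 1.5, 1.0, 0.5])           # Sigma has eigenvalues d**2
W = Q @ np.diag(d) @ R.T
Y = W @ X

Sxx, Syx = X @ X.T, Y @ X.T
N = Syx @ np.linalg.inv(Sxx)                      # Sigma_YX Sigma_XX^{-1}


def f(A, B):
    return 0.5 * np.linalg.norm(Y - A @ B @ X, "fro") ** 2


A = rng.standard_normal((n, p))                   # random initialization
for _ in range(200):
    # B-step: minimizer of f(A, .) for full-column-rank A (crit_ch1).
    B = np.linalg.solve(A.T @ A, A.T @ N)
    # A-step: minimizer of f(., B) for full-row-rank B (crit_ch2).
    A = Syx @ B.T @ np.linalg.inv(B @ Sxx @ B.T)
B = np.linalg.solve(A.T @ A, A.T @ N)

f_final = f(A, B)
f_star = 0.5 * np.sum(d[p:] ** 2)  # ||Y||^2/2 - (lambda_1 + ... + lambda_p)/2
print(f_final, f_star)
```

One can check that the A-step maps the column space of \(\mathbf A\) to that of \(\mathbf \Sigma \mathbf A\), so in this sketch ALS behaves like subspace iteration on \(\mathbf \Sigma\).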

Proof. The partial derivatives of \(f\) are [this post provides a useful trick for this kind of calculation] \[\begin{align*} \frac{\partial}{\partial \mathbf A} f & = - \left(\mathbf Y - \mathbf A \mathbf B \mathbf X\right) \mathbf X^\top \mathbf B^\top, \newline \frac{\partial}{\partial \mathbf B} f & = - \mathbf A^\top \left(\mathbf Y - \mathbf A \mathbf B \mathbf X\right) \mathbf X^\top. \end{align*}\] Setting the derivatives to zero, we find that the critical points of \(f\) are characterized by \[\begin{align} \label{eq:crit_ch1} \mathbf A^\top \mathbf \Sigma_{\mathbf Y\mathbf X} & = \mathbf A^\top \mathbf A \mathbf B \mathbf \Sigma_{\mathbf X \mathbf X} \end{align}\] and \[\begin{align} \label{eq:crit_ch2} \mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf B^\top & = \mathbf A \mathbf B \mathbf \Sigma_{\mathbf X \mathbf X} \mathbf B^\top. \end{align}\] We claim that at any critical point of \(f\), \[\begin{align} \label{eq:crit_ch3} \mathbf A \mathbf B = \mathcal P_{\mathbf A} \mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf \Sigma_{\mathbf X \mathbf X}^{-1} \end{align}\] with \(\mathbf A\) satisfying \[\begin{align} \label{eq:crit_ch4} \mathcal P_{\mathbf A} \mathbf \Sigma = \mathcal P_{\mathbf A} \mathbf \Sigma \mathcal P_{\mathbf A} = \mathbf \Sigma \mathcal P_{\mathbf A}, \end{align}\] where \(\mathcal P_{\mathbf M}\) denotes the orthogonal projector onto the column space of the sub-indexed matrix \(\mathbf M\). To see this, multiply \(\eqref{eq:crit_ch1}\) on the right by \(\mathbf \Sigma_{\mathbf X \mathbf X}^{-1}\) to get the linear system \(\mathbf A^\top \mathbf A \mathbf B = \mathbf A^\top \mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf \Sigma_{\mathbf X \mathbf X}^{-1}\) in \(\mathbf B\), whose general solution is \[\begin{align*} \mathbf B = \left(\mathbf A^\top \mathbf A\right)^{\dagger} \mathbf A^\top \mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf \Sigma_{\mathbf X \mathbf X}^{-1} + \left(\mathbf I - \mathbf A^{\dagger} \mathbf A\right) \mathbf L = \mathbf A^{\dagger} \mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf \Sigma_{\mathbf X \mathbf X}^{-1} + \left(\mathbf I - \mathbf A^{\dagger} \mathbf A\right) \mathbf L, \end{align*}\] where \(\dagger\) denotes the Moore-Penrose pseudo-inverse, we have used \(\left(\mathbf A^\top \mathbf A\right)^{\dagger} \mathbf A^\top = \mathbf A^{\dagger}\), and \(\mathbf L\) is any \(p \times n\) matrix. 
So \[\begin{align*} \mathbf A \mathbf B = \mathbf A \mathbf A^{\dagger} \mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf \Sigma_{\mathbf X \mathbf X}^{-1} + \mathbf A \left(\mathbf I - \mathbf A^{\dagger} \mathbf A\right) \mathbf L = \mathcal P_{\mathbf A} \mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf \Sigma_{\mathbf X \mathbf X}^{-1}, \end{align*}\] where we have used \(\mathbf A \mathbf A^{\dagger} = \mathcal P_{\mathbf A}\) and the identity \(\mathbf M \mathbf M^{\dagger} \mathbf M = \mathbf M\), valid for any matrix \(\mathbf M\). Next, multiplying \(\eqref{eq:crit_ch2}\) on the right by \(\mathbf A^\top\) gives \[\begin{align*} \mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf B^\top \mathbf A^\top = \mathbf A \mathbf B \mathbf \Sigma_{\mathbf X \mathbf X} \mathbf B^\top \mathbf A^\top, \end{align*}\] and substituting the expression obtained above for \(\mathbf A \mathbf B\) (so that \(\mathbf B^\top \mathbf A^\top = \mathbf \Sigma_{\mathbf X \mathbf X}^{-1} \mathbf \Sigma_{\mathbf X \mathbf Y} \mathcal P_{\mathbf A}\)) yields \[\begin{align*} \mathcal P_{\mathbf A} \mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf \Sigma_{\mathbf X \mathbf X}^{-1} \mathbf \Sigma_{\mathbf X \mathbf Y} \mathcal P_{\mathbf A} = \mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf \Sigma_{\mathbf X \mathbf X}^{-1} \mathbf \Sigma_{\mathbf X \mathbf Y} \mathcal P_{\mathbf A} \Longleftrightarrow \mathcal P_{\mathbf A} \mathbf \Sigma \mathcal P_{\mathbf A} = \mathbf \Sigma \mathcal P_{\mathbf A}. \end{align*}\] Finally, since \(\mathcal P_{\mathbf A}\mathbf \Sigma \mathcal P_{\mathbf A}\) is symmetric, so is \(\mathbf \Sigma \mathcal P_{\mathbf A}\); hence \(\mathbf \Sigma \mathcal P_{\mathbf A} = \left(\mathbf \Sigma \mathcal P_{\mathbf A}\right)^\top = \mathcal P_{\mathbf A} \mathbf \Sigma\).
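The gradient formulas and the general solution of \(\eqref{eq:crit_ch1}\) can be sanity-checked numerically. In the minimal sketch below, the random data, seed, sizes and the helpers `f`, `grads` and `central_diff` are illustrative assumptions. Central differences are exact up to round-off here because \(f\) is quadratic in \(\mathbf A\) for fixed \(\mathbf B\), and in \(\mathbf B\) for fixed \(\mathbf A\). The second part uses a rank-deficient \(\mathbf A\) to confirm two things: every \(\mathbf B = \mathbf A^{\dagger} \mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf \Sigma_{\mathbf X \mathbf X}^{-1} + (\mathbf I - \mathbf A^{\dagger} \mathbf A) \mathbf L\) solves \(\eqref{eq:crit_ch1}\), and each such \(\mathbf B\) gives \(\mathbf A \mathbf B = \mathcal P_{\mathbf A} \mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf \Sigma_{\mathbf X \mathbf X}^{-1}\) regardless of \(\mathbf L\).

```python
import numpy as np

rng = np.random.default_rng(2)
n, m, p = 4, 12, 3
X = rng.standard_normal((n, m))
Y = rng.standard_normal((n, m))
Sxx, Syx = X @ X.T, Y @ X.T
N = Syx @ np.linalg.inv(Sxx)  # Sigma_YX Sigma_XX^{-1}


def f(A, B):
    return 0.5 * np.linalg.norm(Y - A @ B @ X, "fro") ** 2


def grads(A, B):
    """Closed-form partial derivatives of f from the text."""
    R = Y - A @ B @ X
    return -R @ X.T @ B.T, -A.T @ R @ X.T


def central_diff(g, Z, h=1e-4):
    """Entrywise central differences of the scalar function g at Z."""
    out = np.zeros_like(Z)
    for idx in np.ndindex(Z.shape):
        E = np.zeros_like(Z)
        E[idx] = h
        out[idx] = (g(Z + E) - g(Z - E)) / (2 * h)
    return out


A = rng.standard_normal((n, p))
B = rng.standard_normal((p, n))
gA, gB = grads(A, B)
err_A = np.abs(central_diff(lambda Z: f(Z, B), A) - gA).max()
err_B = np.abs(central_diff(lambda Z: f(A, Z), B) - gB).max()

# Rank-deficient A (rank 2 < p = 3) and an arbitrary L.
A = rng.standard_normal((n, 2)) @ rng.standard_normal((2, p))
A_pinv = np.linalg.pinv(A, rcond=1e-10)  # drop round-off singular values
P_A = A @ A_pinv
L = rng.standard_normal((p, n))
B = A_pinv @ N + (np.eye(p) - A_pinv @ A) @ L
solves_crit1 = np.allclose(A.T @ Syx, A.T @ A @ B @ Sxx)
matches_proj = np.allclose(A @ B, P_A @ N)
print(err_A, err_B, solves_crit1, matches_proj)
```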

Next we establish the claim about global minima. To this end, suppose that \(\operatorname{rank}\left(\mathbf A\right) = r \leq p\) at a critical point, and let \(\mathbf \Sigma = \mathbf U \mathbf \Lambda \mathbf U^\top\) be the eigendecomposition of \(\mathbf \Sigma\), with \(\mathbf U = \left[\mathbf u_1, \dots, \mathbf u_n\right]\) orthogonal and \(\mathbf \Lambda = \operatorname{diag}\left(\lambda_1, \dots, \lambda_n\right)\). Then \[\begin{align*} \mathcal P_{\mathbf U^\top \mathbf A} = \mathbf U^\top \mathbf A \left(\mathbf A^\top \mathbf U \mathbf U^\top \mathbf A\right)^{\dagger} \mathbf A^\top \mathbf U = \mathbf U^\top \mathbf A\left(\mathbf A^\top \mathbf A\right)^{\dagger} \mathbf A^\top \mathbf U = \mathbf U^\top \mathcal P_{\mathbf A} \mathbf U, \end{align*}\] or equivalently \(\mathcal P_{\mathbf A} = \mathbf U \mathcal P_{\mathbf U^\top \mathbf A} \mathbf U^\top\). Now by \(\eqref{eq:crit_ch4}\) we obtain \[\begin{align*} \mathbf U \mathcal P_{\mathbf U^\top \mathbf A} \mathbf U^\top \mathbf U \mathbf \Lambda \mathbf U^\top = \mathcal P_{\mathbf A} \mathbf \Sigma = \mathbf \Sigma \mathcal P_{\mathbf A} = \mathbf U \mathbf \Lambda \mathbf U^\top \mathbf U \mathcal P_{\mathbf U^\top \mathbf A} \mathbf U^\top, \end{align*}\] which yields \[\begin{align} \mathcal P_{\mathbf U^\top \mathbf A} \mathbf \Lambda = \mathbf \Lambda \mathcal P_{\mathbf U^\top \mathbf A}. \end{align}\] Entrywise, this reads \(\left(\mathcal P_{\mathbf U^\top \mathbf A}\right)_{ij} \left(\lambda_j - \lambda_i\right) = 0\), so \(\mathcal P_{\mathbf U^\top \mathbf A}\) must be diagonal because the \(\lambda_i\) are distinct by assumption. A diagonal orthogonal projector has diagonal entries in \(\left\{0, 1\right\}\) and trace equal to its rank \(r\), so \(\mathcal P_{\mathbf U^\top \mathbf A} = \mathbf I_{\mathcal J}\), the diagonal matrix with ones exactly at the positions in an index set \(\mathcal J = \left\{\ell_1, \cdots, \ell_r\right\} \subseteq [n]\) with \(1 \leq \ell_1 < \ell_2 < \cdots < \ell_r \leq n\), and \[\begin{align} \label{eq:diagonal_constraint} \mathcal P_{\mathbf A} = \mathbf U \mathcal P_{\mathbf U^\top \mathbf A} \mathbf U^\top = \mathbf U \mathbf I_{\mathcal J} \mathbf U^\top = \mathbf U_{\mathcal J} \mathbf U_{\mathcal J}^\top, 
\end{align}\] where \(\mathbf U_{\mathcal J}\) collects the columns of \(\mathbf U\) indexed by \(\mathcal J\). Hence the column space of \(\mathbf A\) coincides with that of \(\mathbf U_{\mathcal J}\), so at any critical point of \(f\), \(\mathbf A\) can be written in the form \[\begin{align} \label{eq:form_a_final} \mathbf A = \left[\mathbf U_{\mathcal J}, \mathbf 0_{n \times \left(p-r\right)}\right] \mathbf C \end{align}\] for some invertible \(\mathbf C \in \mathbb{R}^{p \times p}\). By \(\eqref{eq:crit_ch1}\) again, \(\mathbf B\) takes the form \[\begin{align} \mathbf B = \mathbf G \mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf \Sigma_{\mathbf X \mathbf X}^{-1} + \left(\mathbf I - \mathbf G \mathbf A\right) \mathbf L \end{align}\] for some \(\mathbf L\) constrained by \(\eqref{eq:crit_ch2}\), where \(\mathbf G\) is any matrix with \(\mathbf A \mathbf G = \mathcal P_{\mathbf A}\) and \(\mathbf A \mathbf G \mathbf A = \mathbf A\). The pseudo-inverse \(\mathbf A^{\dagger}\) used above is one such choice; in general, the solutions of \(\eqref{eq:crit_ch1}\), i.e., of \(\mathbf A \mathbf B = \mathcal P_{\mathbf A} \mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf \Sigma_{\mathbf X \mathbf X}^{-1}\), are exactly of this form. From \(\eqref{eq:form_a_final}\), a convenient choice is \[\begin{align*} \mathbf G = \mathbf C^{-1} \left[\mathbf U_{\mathcal J}^\top; \mathbf 0\right], \end{align*}\] where we have used Matlab notation for matrix concatenation. It satisfies \(\mathbf A \mathbf G = \mathbf U_{\mathcal J} \mathbf U_{\mathcal J}^\top = \mathcal P_{\mathbf A}\) and \(\mathbf A \mathbf G \mathbf A = \mathbf A\), although it coincides with \(\mathbf A^{\dagger}\) only in special cases such as orthogonal \(\mathbf C\). So \[\begin{align} \mathbf B & = \mathbf C^{-1} \begin{bmatrix} \mathbf U^\top_{\mathcal J} \newline \mathbf 0 \end{bmatrix} \mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf \Sigma_{\mathbf X \mathbf X}^{-1} + \left(\mathbf I - \mathbf C^{-1} \begin{bmatrix} \mathbf U^\top_{\mathcal J} \newline \mathbf 0 \end{bmatrix} \begin{bmatrix} \mathbf U_{\mathcal J}, \mathbf 0 \end{bmatrix} \mathbf C \right) \mathbf L \newline & = \mathbf C^{-1}\begin{bmatrix} \mathbf U^\top_{\mathcal J} \mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf \Sigma_{\mathbf X \mathbf X}^{-1} \newline \mathbf 0 \end{bmatrix} + \mathbf C^{-1} \begin{bmatrix} \mathbf 0 & \newline & \mathbf I_{p - r} \end{bmatrix} \mathbf C \mathbf L \newline & = \mathbf C^{-1} \begin{bmatrix} \mathbf U^\top_{\mathcal J} \mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf \Sigma_{\mathbf X \mathbf X}^{-1} \newline \text{last $p-r$ rows of $\mathbf C \mathbf L$} \end{bmatrix}. \end{align}\] We now treat the cases \(r < p\) and \(r = p\) separately.
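The parametrization of critical points can be checked numerically. The minimal sketch below builds \(\mathbf A = [\mathbf U_{\mathcal J}, \mathbf 0] \mathbf C\) and \(\mathbf B = \mathbf C^{-1} [\mathbf U_{\mathcal J}^\top \mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf \Sigma_{\mathbf X \mathbf X}^{-1}; \mathbf 0]\) (the choice \(\mathbf L = \mathbf 0\)). It uses one index set with \(r < p\) and one with \(r = p\), and confirms that both gradients vanish and that \(\mathcal P_{\mathbf A}\) commutes with \(\mathbf \Sigma\). The random data, the seed, the non-orthogonal \(\mathbf C\), and the helper `critical_point` are assumptions; indices are 0-based in the code. The sketch also checks \(\mathbf G = \mathbf C^{-1} [\mathbf U_{\mathcal J}^\top; \mathbf 0]\). It satisfies \(\mathbf A \mathbf G \mathbf A = \mathbf A\) and \(\mathbf A \mathbf G = \mathcal P_{\mathbf A}\), but \(\mathbf G \mathbf A\) is not symmetric for this \(\mathbf C\). So \(\mathbf G\) is a generalized inverse that differs from the Moore-Penrose \(\mathbf A^{\dagger}\), and those two properties are all the parametrization of \(\mathbf B\) relies on.

```python
import numpy as np

rng = np.random.default_rng(3)
n, m, p = 5, 15, 3
X = rng.standard_normal((n, m))
Y = rng.standard_normal((n, m))
Sxx, Syx = X @ X.T, Y @ X.T
N = Syx @ np.linalg.inv(Sxx)          # Sigma_YX Sigma_XX^{-1}
Sigma = N @ Syx.T                     # Sigma_YX Sigma_XX^{-1} Sigma_XY
Sigma = 0.5 * (Sigma + Sigma.T)       # remove round-off asymmetry
lam, U = np.linalg.eigh(Sigma)
U = U[:, np.argsort(lam)[::-1]]       # columns by decreasing eigenvalue

# A well-conditioned, non-orthogonal invertible C (illustrative choice).
Qc, _ = np.linalg.qr(rng.standard_normal((p, p)))
C = Qc @ np.diag([2.0, 1.0, 0.5])


def critical_point(J):
    """A = [U_J, 0] C and B = C^{-1} [U_J^T N; 0], i.e. the choice L = 0."""
    r = len(J)
    UJ = U[:, J]
    A = np.hstack([UJ, np.zeros((n, p - r))]) @ C
    B = np.linalg.solve(C, np.vstack([UJ.T @ N, np.zeros((p - r, n))]))
    return A, B


def grads(A, B):
    R = Y - A @ B @ X
    return -R @ X.T @ B.T, -A.T @ R @ X.T


results = {}
for J in ([0, 2], [1, 3, 4]):         # r = 2 < p and r = p = 3
    A, B = critical_point(J)
    gA, gB = grads(A, B)
    P_A = U[:, J] @ U[:, J].T
    results[tuple(J)] = (max(np.abs(gA).max(), np.abs(gB).max()),
                         np.allclose(P_A @ Sigma, Sigma @ P_A))

# The generalized inverse G for the rank-deficient case.
J = [0, 2]
A, _ = critical_point(J)
G = np.linalg.solve(C, np.vstack([U[:, J].T, np.zeros((p - len(J), n))]))
is_g_inverse = (np.allclose(A @ G @ A, A)
                and np.allclose(A @ G, U[:, J] @ U[:, J].T))
GA_symmetric = np.allclose(G @ A, (G @ A).T)  # generally False here
print(results, is_g_inverse, GA_symmetric)
```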

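When \(r < p\), we can perturb the last \(p-r\) rows of \(\mathbf C \mathbf B\) (the free rows in the bracketed matrix above) by arbitrarily small amounts to obtain \(\widehat{\mathbf B}\) of full row rank \(p\). This is possible because \(\mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf \Sigma_{\mathbf X \mathbf X}^{-1}\) is invertible (as \(\mathbf \Sigma\) is full rank), so the first \(r\) rows of \(\mathbf C \widehat{\mathbf B}\) are linearly independent. Since \(\mathbf A = \left[\mathbf U_{\mathcal J}, \mathbf 0\right] \mathbf C\) annihilates the perturbed rows, \(\mathbf A \widehat{\mathbf B} = \mathbf A \mathbf B\). When \(\widehat{\mathbf B}\) has full row rank, \(\widehat{\mathbf B} \mathbf \Sigma_{\mathbf X \mathbf X} \widehat{\mathbf B}^\top\) is positive definite, so \(f\left(\cdot, \widehat{\mathbf B}\right)\) is strictly convex in \(\mathbf A\) with a unique minimizer \(\widehat{\mathbf A}\left(\widehat{\mathbf B}\right)\). Moreover, \(\mathbf A \neq \widehat{\mathbf A}\left(\widehat{\mathbf B}\right)\): the minimizer satisfies \(\mathbf \Sigma_{\mathbf Y \mathbf X} \widehat{\mathbf B}^\top = \widehat{\mathbf A} \widehat{\mathbf B} \mathbf \Sigma_{\mathbf X \mathbf X} \widehat{\mathbf B}^\top\), whose left-hand side has rank \(p\), whereas \(\operatorname{rank}\left(\mathbf A\right) = r < p\). Hence, for any \(\varepsilon \in \left(0, 1\right)\), strict convexity gives that \(\overline{\mathbf A} = \left(1-\varepsilon\right)\mathbf A + \varepsilon \widehat{\mathbf A}\left(\widehat{\mathbf B}\right)\) satisfies \[\begin{align*} f\left(\overline{\mathbf A}, \widehat{\mathbf B}\right) < f\left(\mathbf A, \widehat{\mathbf B}\right) = f\left(\mathbf A, \mathbf B\right). \end{align*}\] Since both \(\varepsilon\) and the perturbation of \(\mathbf B\) can be made arbitrarily small, and \(\overline{\mathbf A} \to \mathbf A\) as \(\varepsilon \to 0\), there are points arbitrarily close to \(\left(\mathbf A, \mathbf B\right)\) with strictly smaller objective. So \(\left(\mathbf A, \mathbf B\right)\) is not a local minimum; it is a saddle point.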
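The escape argument for \(r < p\) can be traced on a small instance. The minimal sketch below takes \(p = 2\) and the rank-one critical point \(\mathbf A = [\mathbf u_1, \mathbf 0]\), \(\mathbf B = [\mathbf u_1^\top \mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf \Sigma_{\mathbf X \mathbf X}^{-1}; \mathbf 0]\), i.e., \(\mathbf C = \mathbf I\) and \(\mathbf L = \mathbf 0\). The data, seed, and the step sizes `delta` and `eps` are illustrative assumptions. Perturbing the free second row of \(\mathbf B\) leaves the objective unchanged. A small convex-combination step toward the minimizer of \(f(\cdot, \widehat{\mathbf B})\) then strictly lowers it.

```python
import numpy as np

rng = np.random.default_rng(4)
n, m, p = 4, 12, 2
X = rng.standard_normal((n, m))
Y = rng.standard_normal((n, m))
Sxx, Syx = X @ X.T, Y @ X.T
N = Syx @ np.linalg.inv(Sxx)
Sigma = N @ Syx.T
lam, U = np.linalg.eigh(0.5 * (Sigma + Sigma.T))
U = U[:, np.argsort(lam)[::-1]]


def f(A, B):
    return 0.5 * np.linalg.norm(Y - A @ B @ X, "fro") ** 2


# Rank-deficient critical point: r = 1 < p = 2, J = {1}, C = I, L = 0.
A = np.hstack([U[:, :1], np.zeros((n, 1))])
B = np.vstack([U[:, :1].T @ N, np.zeros((1, n))])

delta, eps = 1e-2, 1e-5
B_hat = B.copy()
B_hat[1] = delta * rng.standard_normal(n)  # only the free row changes
# Unique minimizer of the strictly convex map A -> f(A, B_hat).
A_opt = Syx @ B_hat.T @ np.linalg.inv(B_hat @ Sxx @ B_hat.T)
A_bar = (1 - eps) * A + eps * A_opt

f0, f_same, f_lower = f(A, B), f(A, B_hat), f(A_bar, B_hat)
print(f0, f_same, f_lower, np.linalg.norm(B_hat - B), np.linalg.norm(A_bar - A))
```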

When \(r = p\), we have \[\begin{align*} \mathbf A = \mathbf U_{\mathcal J} \mathbf C, \quad \mathbf B = \mathbf C^{-1} \mathbf U_{\mathcal J}^\top \mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf \Sigma_{\mathbf X \mathbf X}^{-1} \end{align*}\] for some index set \(\mathcal J\) with \(\left|\mathcal J\right| = p\) and some invertible \(\mathbf C \in \mathbb{R}^{p \times p}\). Hence there are \(\binom{n}{p}\) possible choices for \(\mathcal J\) in this case and so, up to equivalence, \(\binom{n}{p}\) critical points with full-rank \(\mathbf A\). With the eigenvalues of \(\mathbf \Sigma\) in decreasing order, we'll show that whenever \(\mathcal J \neq \left\{1, \dots, p\right\}\), the corresponding critical point is a saddle point. First notice that when \(\mathbf A\) has full column rank, \(\mathbf B\) is uniquely determined by \(\eqref{eq:crit_ch1}\), namely \(\mathbf B\left(\mathbf A\right) = \left(\mathbf A^\top \mathbf A\right)^{-1} \mathbf A^\top \mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf \Sigma_{\mathbf X \mathbf X}^{-1}\), so \[\begin{align*} f\left(\mathbf A, \mathbf B\left(\mathbf A\right)\right) & = \left\| \mathbf Y - \mathbf A \mathbf B\left(\mathbf A\right) \mathbf X \right\|_{F}^2 /2 \newline & = \left\| \mathbf Y - \mathcal P_{\mathbf A} \mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf \Sigma_{\mathbf X \mathbf X}^{-1} \mathbf X \right\|_{F}^2 /2 \newline & = \left\| \mathbf Y \right\|_{F}^2/2 - \left\langle \mathbf \Sigma_{\mathbf Y \mathbf X}, \mathcal P_{\mathbf A} \mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf \Sigma_{\mathbf X \mathbf X}^{-1} \right \rangle + \left \langle \mathbf \Sigma_{\mathbf Y \mathbf X}^\top \mathcal P_{\mathbf A}, \mathbf \Sigma_{\mathbf X \mathbf X}^{-1} \mathbf \Sigma_{\mathbf X \mathbf Y} \mathcal P_{\mathbf A} \right \rangle/2 \newline & = \left\| \mathbf Y \right\|_{F}^2/2 - \left\langle \mathcal P_{\mathbf A}, \mathbf \Sigma \right\rangle + \left\langle \mathcal P_{\mathbf A}, \mathbf \Sigma \right\rangle/2 \newline & = \left\| \mathbf Y \right\|_{F}^2/2 - \left\langle \mathcal P_{\mathbf A}, \mathbf \Sigma \right\rangle/2 = \left\| \mathbf Y \right\|_{F}^2/2 - \left\langle \mathcal P_{\mathbf U^\top \mathbf A}, \mathbf \Lambda \right\rangle/2, \end{align*}\] where in the simplification we have used \(\eqref{eq:crit_ch3}\), which holds whenever \(\mathbf B = \mathbf B\left(\mathbf A\right)\), together with \(\mathcal P_{\mathbf A}^2 = \mathcal P_{\mathbf A}\) and the cyclic property of the trace. In particular, this expression is valid for every full-column-rank \(\mathbf A\), not only at critical points. At the critical point, \(\eqref{eq:diagonal_constraint}\) gives \(\mathcal P_{\mathbf U^\top \mathbf A} = \mathbf I_{\mathcal J}\), so \[\begin{align} f\left(\mathbf A, \mathbf B\left(\mathbf A\right)\right) = \left\| \mathbf Y \right\|_{F}^2/2 - \frac{1}{2}\sum_{j \in \mathcal J} \lambda_j. \end{align}\] Now if \(\mathcal J \neq \left\{1, \cdots, p\right\}\), there exists some index \(\ell \in [p]\) with \(\ell \notin \mathcal J\); since \(\left|\mathcal J\right| = p\), there is also some \(k \in \mathcal J\) with \(k > p \geq \ell\). Consider an \(\varepsilon\)-perturbation of \(\mathbf u_k\): \(\widehat{\mathbf u}_k = \left(\mathbf u_k + \varepsilon \mathbf u_{\ell}\right) / \sqrt{1+\varepsilon^2}\), and replace \(\mathbf u_k\) in \(\mathbf U_{\mathcal J}\) by \(\widehat{\mathbf u}_k\) to form \(\widehat{\mathbf U}_{\mathcal J}\). Because \(\mathbf u_{\ell}\) is orthogonal to every column of \(\mathbf U_{\mathcal J}\), we still have \(\widehat{\mathbf U}_{\mathcal J}^\top \widehat{\mathbf U}_{\mathcal J} = \mathbf I_p\). Now let \(\widehat{\mathbf A} = \widehat{\mathbf U}_{\mathcal J} \mathbf C\) and \(\widehat{\mathbf B} = \mathbf B\left(\widehat{\mathbf A}\right) = \mathbf C^{-1} \widehat{\mathbf U}_{\mathcal J}^\top \mathbf \Sigma_{\mathbf Y \mathbf X} \mathbf \Sigma_{\mathbf X \mathbf X}^{-1}\). The diagonal elements of \(\mathbf M \doteq \mathcal P_{\mathbf U^\top \widehat{\mathbf A}} = \mathbf U^\top \widehat{\mathbf U}_{\mathcal J} \widehat{\mathbf U}_{\mathcal J}^\top \mathbf U\) are \[\begin{align*} M_{ii} = \begin{cases} 0 & i \notin \mathcal J \cup \left\{\ell\right\} \newline 1 & i \in \mathcal J \; \text{and}\; i \neq k \newline 1/\left(1+\varepsilon^2\right) & i = k \newline \varepsilon^2 /\left(1+\varepsilon^2\right) & i = \ell \end{cases}. 
\end{align*}\] So the perturbed objective is \[\begin{align*} f\left(\widehat{\mathbf A}, \widehat{\mathbf B}\right) & = \left\| \mathbf Y \right\|_{F}^2/2 - \left\langle \mathbf M, \mathbf \Lambda \right\rangle/2 \newline & = \left\| \mathbf Y \right\|_{F}^2/2 - \frac{1}{2}\sum_{j \in \mathcal J} \lambda_j - \frac{\varepsilon^2\left(\lambda_{\ell} - \lambda_{k}\right)}{2\left(1+\varepsilon^2\right)}. \end{align*}\] Since \(\lambda_{\ell} - \lambda_k > 0\) by construction, for all \(\varepsilon > 0\) \[\begin{align} f\left(\widehat{\mathbf A}, \widehat{\mathbf B}\right) < f \left(\mathbf A, \mathbf B\right), \end{align}\] and \(\left(\widehat{\mathbf A}, \widehat{\mathbf B}\right)\) can be made arbitrarily close to \(\left(\mathbf A, \mathbf B\right)\) by letting \(\varepsilon \to 0\). So \(\left(\mathbf A, \mathbf B\right)\) is a saddle point. Consequently, every local minimizer has \(r = p\) and \(\mathcal J = \left\{1, \dots, p\right\}\), and all such critical points attain the same value \(\left\| \mathbf Y \right\|_{F}^2/2 - \frac{1}{2}\sum_{j=1}^{p} \lambda_j\). A global minimizer exists, since \(f\) depends on \(\left(\mathbf A, \mathbf B\right)\) only through \(\mathbf W = \mathbf A \mathbf B\), the map \(\mathbf W \mapsto \frac{1}{2}\left\| \mathbf Y - \mathbf W \mathbf X \right\|_F^2\) is coercive because \(\mathbf \Sigma_{\mathbf X \mathbf X}\) is invertible, and matrices of rank at most \(p\) form a closed set. Hence this common value is the global minimum. \(\Box\)
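Both the closed form \(f(\mathbf A, \mathbf B(\mathbf A)) = \|\mathbf Y\|_F^2/2 - \frac{1}{2}\sum_{j \in \mathcal J} \lambda_j\) and the effect of the \(\varepsilon\)-perturbation can be checked numerically. The minimal sketch below enumerates all index sets of size \(p\) and confirms that the smallest value occurs at the first \(p\) indices. It then compares the perturbed value with \(\|\mathbf Y\|_F^2/2 - \frac{1}{2}\sum_{j \in \mathcal J} \lambda_j - \frac{\varepsilon^2(\lambda_\ell - \lambda_k)}{2(1+\varepsilon^2)}\), which follows from the diagonal entries of \(\mathbf M\). The random data, seed, sizes, the fixed \(\mathbf C\), \(\varepsilon\), and the helper `B_of` are assumptions; indices are 0-based in the code.

```python
import itertools

import numpy as np

rng = np.random.default_rng(5)
n, m, p = 5, 20, 2
X = rng.standard_normal((n, m))
Y = rng.standard_normal((n, m))
Sxx, Syx = X @ X.T, Y @ X.T
N = Syx @ np.linalg.inv(Sxx)
Sigma = N @ Syx.T
lam, U = np.linalg.eigh(0.5 * (Sigma + Sigma.T))
order = np.argsort(lam)[::-1]
lam, U = lam[order], U[:, order]          # lambda_1 > ... > lambda_n
half_normY2 = 0.5 * np.linalg.norm(Y, "fro") ** 2
C = np.array([[2.0, 1.0], [0.0, 1.0]])    # any invertible p x p matrix


def f(A, B):
    return 0.5 * np.linalg.norm(Y - A @ B @ X, "fro") ** 2


def B_of(A):
    """Unique solution of (crit_ch1) when A has full column rank."""
    return np.linalg.solve(A.T @ A, A.T @ N)


values, max_gap = {}, 0.0
for J in itertools.combinations(range(n), p):
    A = U[:, list(J)] @ C
    values[J] = f(A, B_of(A))
    closed_form = half_normY2 - 0.5 * lam[list(J)].sum()
    max_gap = max(max_gap, abs(values[J] - closed_form))
best_J = min(values, key=values.get)

# Perturb the saddle J = {1, 3} with ell = 2, k = 3 (0-based: [0, 2], 1, 2).
J, ell, k, eps = [0, 2], 1, 2, 0.5
U_hat = U[:, J].copy()
U_hat[:, J.index(k)] = (U[:, k] + eps * U[:, ell]) / np.sqrt(1 + eps**2)
A_hat = U_hat @ C
f_pert = f(A_hat, B_of(A_hat))
f_pred = (half_normY2 - 0.5 * lam[J].sum()
          - eps**2 * (lam[ell] - lam[k]) / (2 * (1 + eps**2)))
print(best_J, max_gap, f_pert, f_pred, f_pert < values[tuple(J)])
```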

One may naturally consider using second-order information directly, i.e., deriving the Hessian and examining its definiteness at each critical point. That approach is principled, but it is likely to be more cumbersome than the first-order characterization and explicit perturbation arguments presented here.
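For comparison, the second-order route can at least be carried out numerically on a small instance. The minimal sketch below assembles the exact Hessian of \(f\) from the quadratic form \(\mathbf v^\top \mathbf H \mathbf v = \|(\Delta\mathbf A\, \mathbf B + \mathbf A\, \Delta\mathbf B)\mathbf X\|_F^2 - 2\langle \mathbf Y - \mathbf A \mathbf B \mathbf X, \Delta\mathbf A\, \Delta\mathbf B\, \mathbf X\rangle\). This form comes from expanding \(f(\mathbf A + t\Delta\mathbf A, \mathbf B + t\Delta\mathbf B)\) to second order in \(t\). The sketch reports the smallest Hessian eigenvalue at the global minimizer (\(\mathcal J = \{1, 2\}\)) and at a saddle (\(\mathcal J = \{1, 3\}\)), both with \(\mathbf C = \mathbf I\). The random data, seed, sizes and the helper `hessian` are assumptions; indices are 0-based in the code. At the minimizer, the smallest eigenvalue should be zero up to round-off, reflecting the invariance under \(\mathbf C\). At the saddle, it should be negative.

```python
import numpy as np

rng = np.random.default_rng(6)
n, m, p = 4, 12, 2
X = rng.standard_normal((n, m))
Y = rng.standard_normal((n, m))
Sxx, Syx = X @ X.T, Y @ X.T
N = Syx @ np.linalg.inv(Sxx)
Sigma = N @ Syx.T
lam, U = np.linalg.eigh(0.5 * (Sigma + Sigma.T))
U = U[:, np.argsort(lam)[::-1]]


def hessian(A, B):
    """Exact Hessian of f at (A, B) in the coordinates v = [vec(A), vec(B)].

    v^T H v = ||(dA B + A dB) X||_F^2 - 2 <Y - A B X, dA dB X> is the second
    derivative of t -> f(A + t dA, B + t dB) at t = 0; H is recovered from
    this quadratic form by polarization.
    """
    R = Y - A @ B @ X

    def q(v):
        dA = v[:A.size].reshape(A.shape)
        dB = v[A.size:].reshape(B.shape)
        D = (dA @ B + A @ dB) @ X
        return np.sum(D * D) - 2.0 * np.sum(R * (dA @ dB @ X))

    dim = A.size + B.size
    I = np.eye(dim)
    d = np.array([q(I[i]) for i in range(dim)])
    H = np.diag(d)
    for i in range(dim):
        for j in range(i + 1, dim):
            H[i, j] = H[j, i] = 0.5 * (q(I[i] + I[j]) - d[i] - d[j])
    return H


min_eig = {}
for J in ([0, 1], [0, 2]):      # global minimizer vs. saddle (0-based J)
    A = U[:, J]                 # C = I
    B = U[:, J].T @ N
    min_eig[tuple(J)] = np.linalg.eigvalsh(hessian(A, B)).min()
print(min_eig)
```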

  1. Baldi, Pierre, and Kurt Hornik. "Neural networks and principal component analysis: Learning from examples without local minima." Neural Networks 2, no. 1 (1989): 53-58. ftp://canuck.seos.uvic.ca/CFD-NN/1-s2.0-0893608089900142-main.pdf