
Neural network gradients, chain rule and PyTorch forward/backward


Photo by Miltiadis Fragkidis on Unsplash

How do you perform gradient descent on a neural network, both at the math level and at the code level? I have found that many people have difficulty answering this seemingly basic data science interview question.

This article answers this question from different perspectives:

  • How do you compute a network's gradient the non-modular way?
  • What is vector differentiation?
  • How do you use the chain rule of differentiation to achieve modular gradient computation?
  • How are the forward and backward passes implemented in PyTorch?
  • Why and how does the mysterious context argument in PyTorch’s forward method allow us to pass data from the forward pass to the backward pass?
  • What is the purpose of the even more mysterious grad_output argument in PyTorch’s backward method, which allows us to implement the chain rule? (A minimal sketch of both arguments follows this list.)
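To make the last two questions concrete before we dive in, here is a minimal sketch of a custom torch.autograd.Function. The Square operation is a toy example of my own, not the network discussed in this article; it only illustrates how the context argument ctx carries data from the forward pass to the backward pass, and how grad_output delivers the upstream gradient that implements the chain rule.

```python
import torch

class Square(torch.autograd.Function):
    """Toy custom op y = x² to illustrate ctx and grad_output."""

    @staticmethod
    def forward(ctx, x):
        # ctx lets us stash the input; backward needs it to compute dy/dx = 2x.
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Chain rule: dL/dx = dL/dy · dy/dx = grad_output · 2x.
        return grad_output * 2 * x

x = torch.tensor(3.0, requires_grad=True)
y = Square.apply(x)
y.backward()     # for a scalar output, PyTorch passes grad_output = 1.0
print(x.grad)    # tensor(6.)
```

Keep this sketch in mind; the rest of the article explains why backward is written exactly this way.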

A simple regression neural network

Let’s look at the following neural network:

  • It has input x = [x₁, x₂]ᵀ, where “ᵀ” is the transpose operator; in other words, x is a 2×1 matrix.
  • It has two hidden layers V = [v₁, v₂]ᵀ and Z = [z₁, z₂]ᵀ. Both V and Z are 2×1 matrices. The two hidden layers contain trainable…
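The description of the trainable parameters is cut off above, so the following is only a minimal sketch under stated assumptions: I assume each hidden layer is a fully connected layer with a sigmoid activation, that the trainable weights are 2×2 matrices W1 and W2 plus a 1×2 output weight w3, and that the network produces a single regression output y. None of these names come from the original text.

```python
import torch

torch.manual_seed(0)

x = torch.tensor([[1.0], [2.0]])             # input x = [x₁, x₂]ᵀ, a 2×1 matrix

W1 = torch.randn(2, 2, requires_grad=True)   # assumed trainable weights, layer 1
W2 = torch.randn(2, 2, requires_grad=True)   # assumed trainable weights, layer 2
w3 = torch.randn(1, 2, requires_grad=True)   # assumed trainable output weights

V = torch.sigmoid(W1 @ x)                    # first hidden layer, 2×1
Z = torch.sigmoid(W2 @ V)                    # second hidden layer, 2×1
y = w3 @ Z                                   # scalar regression output

y.backward()            # backpropagate; PyTorch applies the chain rule for us
print(W1.grad.shape)    # torch.Size([2, 2]): one gradient per trainable weight
```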
