Guide

Assume:

  • N layers in total
  • n represents a specific layer
  • batchSize is the number of input images handled at once.
  • For the very last layer, the activation function is Softmax.
  • For all hidden layers, the activation function is ReLU.
  • m is the number of dimensions of the final prediction vector.
    • This is also the number of ‘columns’ of the output of the very last layer.
  • j is the index of an input in one batch.
  • s represents the Softmax result of the very last layer.
  • y is the matrix of true labels for all the inputs in one batch. Shape: (batchSize, m)
  • η is the learning rate.
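
For reference, the forward-pass convention assumed by the formulas and kernels below (the superscript notation $x^{(n)}$, $a^{(n)}$, $W^{(n)}$, $b^{(n)}$ is my own shorthand; it matches the row-major, one-input-per-row layout used in the code):

$x^{(n)} = a^{(n-1)} W^{(n)} + b^{(n)}, \quad a^{(n)} = \mathrm{ReLU}\left(x^{(n)}\right) \text{ for hidden layers}, \quad s = \mathrm{Softmax}\left(x^{(N)}\right)$

Here $a^{(0)}$ is the batch of input images, $x^{(n)}$ and $a^{(n)}$ have one row per input in the batch, $W^{(n)}$ has one row per input feature of layer n and one column per output feature, and $b^{(n)}$ is a row vector broadcast over the batch.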

Things we need to calculate:

  • Loss Function: Cross Entropy, $L$
  • Derivative of $L$ with respect to the raw output of the last layer: $\frac{\partial L}{\partial x^{(N)}}$
    • Handling of Softmax is covered inherently in this step.
  • Derivative of $L$ with respect to the weights of any layer: $\frac{\partial L}{\partial W^{(n)}}$
    • Used to update the weights of the network.
  • Derivative of $L$ with respect to the biases of any layer: $\frac{\partial L}{\partial b^{(n)}}$
    • Used to update the biases of the network.
  • Derivative of $L$ with respect to the activation results of layer n-1, $\frac{\partial L}{\partial a^{(n-1)}}$, from the derivative of $L$ with respect to the raw output of layer n, $\frac{\partial L}{\partial x^{(n)}}$
    • This is crossing the matrix multiplication boundary
    • This is not needed when $n = 1$ (i.e. skip if n is the first layer)
  • Derivative of $L$ with respect to the raw output of layer n, $\frac{\partial L}{\partial x^{(n)}}$, from the derivative of $L$ with respect to the activation results of layer n, $\frac{\partial L}{\partial a^{(n)}}$

Loss function: Cross Entropy

Given batchSize > 1, $L$ should be a column vector with batchSize rows and 1 element per row (i.e. it only has one column):

$L_{j} = -\sum_{k=1}^{m} y_{j,k} \log\left(s_{j,k}\right)$

But we don't need a separate kernel for this step. The calculation can be merged into the kernel that computes the derivative of $L$ with respect to the raw output of the very last layer.
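
For completeness, a sketch of the standard Softmax-plus-cross-entropy identity that makes this merge possible (textbook math, assuming each row of $y$ sums to 1, e.g. one-hot labels):

$\frac{\partial L_{j}}{\partial x^{(N)}_{j,k}} = \sum_{c=1}^{m} \frac{\partial L_{j}}{\partial s_{j,c}} \frac{\partial s_{j,c}}{\partial x^{(N)}_{j,k}} = s_{j,k} - y_{j,k}$

So the kernel below can write $s - y$ directly, without ever materializing $L$ or the Softmax Jacobian.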

Loss derivative of raw output of the very last layer

I.e. the derivative of $L$ with respect to the raw output of the last layer: $\frac{\partial L}{\partial x^{(N)}} = s - y$

// Context: Only for the very last layer, not crossing layers.
// Input: actual_label, predicted_label (i.e. softmax result)
// Output: dl_dx
__global__ void cross_entropy_backprop(
  int batch_size,  // this is also the number of rows of the input matrices
  int preds_width, // the width of the 'preds' and 'actual' matrices
  float* preds,    // the flattened predictions array (predicted labels, i.e. the softmax result)
  float* actual,   // the flattened true-label array, shape: (batch_size, preds_width)
  float* dl_dx     // the partial derivative of the loss with respect to the raw
                   // output of the last layer, shape: (batch_size, preds_width).
                   // Note that this is not the 'weights' matrix shape. The last
                   // layer's weights matrix has:
                   // - Rows: last layer input data width,
                   // - Columns: preds_width (or actual_width).
) {
  int in_column_index = blockIdx.x * blockDim.x + threadIdx.x;
  int in_row_index = blockIdx.y * blockDim.y + threadIdx.y;
  if (in_row_index < batch_size && in_column_index < preds_width) {
    int idx =  in_row_index * preds_width + in_column_index;
    dl_dx[idx] = preds[idx] - actual[idx];
  }
}
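
A minimal host-side launch sketch for this kernel (the wrapper name, the 16x16 block shape, and the d_-prefixed device pointers are my own choices, not part of the original):

// Hypothetical host-side wrapper; assumes all three arrays already live in device memory.
void launch_cross_entropy_backprop(int batch_size, int preds_width,
                                   float* d_preds, float* d_actual, float* d_dl_dx) {
  dim3 block(16, 16);
  // x covers columns (preds_width), y covers rows (batch_size), matching the kernel's indexing.
  dim3 grid((preds_width + block.x - 1) / block.x,
            (batch_size + block.y - 1) / block.y);
  cross_entropy_backprop<<<grid, block>>>(batch_size, preds_width, d_preds, d_actual, d_dl_dx);
}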

Loss derivative of weights of a layer, given the loss derivative of the raw output

I.e. the derivative of $L$ with respect to the weights of any layer, $\frac{\partial L}{\partial W^{(n)}}$, given $\frac{\partial L}{\partial x^{(n)}}$:

$\frac{\partial L}{\partial W^{(n)}} = \sum_{j=1}^{\text{batchSize}} \left(a^{(n-1)}_{j}\right)^{T} \frac{\partial L}{\partial x^{(n)}_{j}} = \left(a^{(n-1)}\right)^{T} \frac{\partial L}{\partial x^{(n)}}$

where $a^{(n-1)}_{j}$ and $\frac{\partial L}{\partial x^{(n)}_{j}}$ denote row $j$ (one input of the batch) of the corresponding matrices.

Note that each term of the sum on the left is itself a matrix with the same shape as the weights matrix, and the total sum is still a single matrix of that shape.
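
In index form (using the notation above), entry $(i, k)$ of this sum is

$\left(\frac{\partial L}{\partial W^{(n)}}\right)_{i,k} = \sum_{j=1}^{\text{batchSize}} a^{(n-1)}_{j,i} \left(\frac{\partial L}{\partial x^{(n)}}\right)_{j,k}$

which is exactly the per-element $\left(a^{(n-1)}\right)^{T} \cdot \frac{\partial L}{\partial x^{(n)}}$ product that the update_weights_biases kernel below accumulates in its dw_sum loop.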

Updating Weights

So, for weights updating:

$W^{(n)} \leftarrow W^{(n)} - \frac{\eta}{\text{batchSize}} \left(a^{(n-1)}\right)^{T} \frac{\partial L}{\partial x^{(n)}}$

Loss derivative of biases of a layer, given the loss derivative of the raw output

I.e. the derivative of $L$ with respect to the biases of any layer, $\frac{\partial L}{\partial b^{(n)}}$, given $\frac{\partial L}{\partial x^{(n)}}$:

$\frac{\partial L}{\partial b^{(n)}} = \sum_{j=1}^{\text{batchSize}} \frac{\partial L}{\partial x^{(n)}_{j}} = \mathbf{1}^{T} \frac{\partial L}{\partial x^{(n)}}$

Updating Biases

So, for biases updating:

$b^{(n)} \leftarrow b^{(n)} - \frac{\eta}{\text{batchSize}} \mathbf{1}^{T} \frac{\partial L}{\partial x^{(n)}}$

Updating biases and weights together

// Context: Same layer, update w, b with this layer's input and output
// Input: dl_dx, weights, biases
// Output: weights, biases
__global__ void update_weights_biases(
  int w, int h,   // The width and height of the weights matrix.
                  // 'w' is also the width of this layer's raw output, i.e. the width of 'x'.
                  // 'h' is also the width of the previous layer's output, i.e. the width of 'a',
                  // and therefore the width of the previous layer's weights matrix.
  int batch_size, // The batch size.
  float lr,       // Learning rate.
  float* a,       // The activation results of the previous layer, shape: [batch_size x h].
  float* dl_dx,   // The partial derivative of the loss function with respect to
                  // the raw output of this layer, shape: [batch_size x w].
  float* weights, // The weights matrix of this layer, shape: [h x w].
  float* biases   // The biases vector of this layer, shape: [1 x w].
) {
  // row_index, col_index point to element [i, j] in the weights matrix.
  // col_index also gives the position in the biases row vector.
  int col_index = blockIdx.x * blockDim.x + threadIdx.x;
  int row_index = blockIdx.y * blockDim.y + threadIdx.y;
 
  if (row_index >= h || col_index >= w) {
    return;
  }
 
  // calculate the a^T * dl_dx:
  // Full row of a^T with row_index -> Full column of a with row_index as column index,
  // Full column of dl_dx with col_index.
  float dw_sum = 0.0f;
  for (int i = 0; i < batch_size; i++) {
    dw_sum += a[i * h + row_index] * dl_dx[i * w + col_index];
  }
  // update weights
  weights[row_index * w + col_index] -= lr * dw_sum / float(batch_size); 
 
  // calculate the [1...1] * dl_dx:
  float db_sum = 0.0f;
  for (int i = 0; i < batch_size; i++) {
    db_sum += dl_dx[i * w + col_index];
  }
  // update biases
  biases[col_index] -= lr * db_sum / float(batch_size);
 
  // Notes for the biases update:
  // There is some 'duplicated' work here: the bias update for a given column only
  // needs to be computed once, but every thread with a different 'row_index'
  // repeats the same calculation and write in parallel.
  // This is still correct: all of these threads read the same initial value of
  // biases[col_index], compute the same db_sum, and write the same final value,
  // so whichever write lands last leaves exactly the same result.
}
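
If the redundant per-row bias work ever becomes a concern, one option (a sketch of my own, not part of the original design) is to move the bias update into a separate 1-D kernel, so that each biases[col_index] is computed and written exactly once per launch:

// Sketch: bias-only update kernel, launched with a 1-D grid over the w columns.
__global__ void update_biases_only(
  int w,          // The width of the biases vector (and of dl_dx).
  int batch_size, // The batch size.
  float lr,       // Learning rate.
  float* dl_dx,   // The loss derivative w.r.t. this layer's raw output, shape: [batch_size x w].
  float* biases   // The biases vector of this layer, shape: [1 x w].
) {
  int col_index = blockIdx.x * blockDim.x + threadIdx.x;
  if (col_index >= w) {
    return;
  }
  float db_sum = 0.0f;
  for (int i = 0; i < batch_size; i++) {
    db_sum += dl_dx[i * w + col_index];
  }
  biases[col_index] -= lr * db_sum / float(batch_size);
}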

Loss derivative of a layer input, given the loss derivative of the raw output

I.e. the derivative of $L$ with respect to the activation results of layer n-1, $\frac{\partial L}{\partial a^{(n-1)}}$, given the derivative of $L$ with respect to the raw output of layer n, $\frac{\partial L}{\partial x^{(n)}}$:

$\frac{\partial L}{\partial a^{(n-1)}} = \frac{\partial L}{\partial x^{(n)}} \left(W^{(n)}\right)^{T}$

I.e. crossing the weights matrix multiplication

// Context: Crossing layers, this layer's raw output -> previous layer's activation result
// Input: dl_dx, weights
// Output: dl_da (the gradient for the previous layer's activation results)
__global__ void backprop_multiplication(
  int batch_size,   // The batch size.
  int x_width,      // The width of the raw output of the layer, i.e. the width of 'x'.
  int a_width,      // The width of the input of the layer, i.e. the width of 'a'.
                    // This is also the height of this layer's weights matrix.
  float* weights,   // The weights matrix of the layer, shape: [a_width x x_width].
  float* dl_dx,     // The partial derivative of the loss function with respect to
                    // the raw output of the layer, shape: [batch_size x x_width].
  float* dl_da      // The partial derivative of the loss function with respect to
                    // the activation results of the previous layer, shape: [batch_size x a_width].
) {
  // row_index, col_index point to element [i, j] in the output matrix dl_da.
  int col_index = blockIdx.x * blockDim.x + threadIdx.x;
  int row_index = blockIdx.y * blockDim.y + threadIdx.y;
 
  if (row_index >= batch_size || col_index >= a_width) {
    return;
  }
  float dl = 0.0f;
  for (int i = 0; i < x_width; i++) {
    // The Transformation for weights and weights_T:
    //  weights_T[i, j] = weights[j, i] = weights[j * x_width + i]
    dl += dl_dx[row_index * x_width + i] * weights[col_index * x_width + i];
  }
  dl_da[row_index * a_width + col_index] = dl;
}
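
In index form, the loop above computes entry $(j, i)$ of $\frac{\partial L}{\partial a^{(n-1)}} = \frac{\partial L}{\partial x^{(n)}} \left(W^{(n)}\right)^{T}$:

$\left(\frac{\partial L}{\partial a^{(n-1)}}\right)_{j,i} = \sum_{k=1}^{\text{x\_width}} \left(\frac{\partial L}{\partial x^{(n)}}\right)_{j,k} W^{(n)}_{i,k}$

which is why the kernel can read weights[col_index * x_width + i] directly instead of materializing a transposed weights matrix.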

Loss derivative of a layer raw output, given the loss derivative of the activation results

I.e. the derivative of $L$ with respect to the raw output of layer n, $\frac{\partial L}{\partial x^{(n)}}$, given the derivative of $L$ with respect to the activation results of layer n, $\frac{\partial L}{\partial a^{(n)}}$:

$\frac{\partial L}{\partial x^{(n)}} = \frac{\partial L}{\partial a^{(n)}} \odot \mathbf{1}\left[a^{(n)} > 0\right]$

I.e. crossing the ReLU activation function

// Context: Same layer, activation result -> raw result
// Input: a, dl_da
// Output: dl_dx
__global__ void backprop_relu(
  int batch_size, // The batch size.
  int x_width,    // The width of the raw output of the layer, i.e. the width of 'x'.
                  // This is also the width of this layer's weights matrix,
                  // though the weights matrix is not used at all in this step.
  float* a,       // The activation results of this layer, shape:
                  // [batch_size x x_width].
  float* dl_da,   // The partial derivative of the loss function with respect to
                  // the activation results of this layer, shape:
                  // [batch_size x x_width].
  float* dl_dx    // The partial derivative of the loss function with respect to
                  // the raw output of this layer, shape: [batch_size x x_width].
) {
  // row_index, col_index point to element [i, j] in the output matrix dl_dx.
  int col_index = blockIdx.x * blockDim.x + threadIdx.x;
  int row_index = blockIdx.y * blockDim.y + threadIdx.y;
 
  if (row_index >= batch_size || col_index >= x_width) {
    return;
  }
  float act_val = a[row_index * x_width + col_index];
  dl_dx[row_index * x_width + col_index] = act_val > 0.0f ? dl_da[row_index * x_width + col_index] : 0.0f;
}
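
Finally, a hypothetical host-side sketch of how these four kernels could be chained for one backward pass over the whole network. Every name here (widths, d_act, d_w, d_b, the scratch buffers, the 16x16 block shape) is an illustrative assumption, not part of the original code; the one important ordering detail is that backprop_multiplication must run before update_weights_biases so that it still sees the old weights.

// Hypothetical backward pass for one batch, chaining the kernels above.
// widths[n] is the output width of layer n (widths[0] = network input width).
// d_act[n] holds layer n's activation results for this batch (d_act[0] = the input,
// d_act[N] = the Softmax output). d_dl_dx and d_dl_da are scratch buffers sized
// for the widest layer.
void backward_pass(int N, int batch_size, float lr,
                   const int* widths,
                   float** d_act, float** d_w, float** d_b,
                   float* d_labels,
                   float* d_dl_dx, float* d_dl_da) {
  dim3 block(16, 16);

  // Very last layer: Softmax + cross-entropy gives dl_dx directly.
  {
    int w = widths[N];
    dim3 grid((w + block.x - 1) / block.x, (batch_size + block.y - 1) / block.y);
    cross_entropy_backprop<<<grid, block>>>(batch_size, w, d_act[N], d_labels, d_dl_dx);
  }

  for (int n = N; n >= 1; n--) {
    int w = widths[n];
    int h = widths[n - 1];

    if (n > 1) {
      // Cross the matrix multiplication boundary first, while d_w[n] still holds
      // the old (pre-update) weights.
      dim3 grid((h + block.x - 1) / block.x, (batch_size + block.y - 1) / block.y);
      backprop_multiplication<<<grid, block>>>(batch_size, w, h, d_w[n], d_dl_dx, d_dl_da);
    }

    // Update this layer's weights and biases from dl_dx and the previous layer's activations.
    dim3 grid_wb((w + block.x - 1) / block.x, (h + block.y - 1) / block.y);
    update_weights_biases<<<grid_wb, block>>>(w, h, batch_size, lr,
                                              d_act[n - 1], d_dl_dx, d_w[n], d_b[n]);

    if (n > 1) {
      // Cross layer n-1's ReLU: turn dl_da into the dl_dx used by the next iteration.
      dim3 grid_relu((h + block.x - 1) / block.x, (batch_size + block.y - 1) / block.y);
      backprop_relu<<<grid_relu, block>>>(batch_size, h, d_act[n - 1], d_dl_da, d_dl_dx);
    }
  }
}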