[2025-07-28-Monday]

Stage 1: graph ingestion

  • Goal: transform the model into an intermediate representation (IR).
  • Process: nn.Module → computation graph (see the sketch after this list)
    • Node: an operation, e.g. a convolution or an add
    • Edge: the flow of a tensor between operations
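
A minimal sketch of Stage 1 using torch.fx as the tracer (TinyNet is a made-up toy module, just to show an nn.Module being turned into a graph IR):

```python
import torch
import torch.nn as nn
import torch.fx as fx

class TinyNet(nn.Module):
    """Hypothetical toy model used only for illustration."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.relu(self.conv(x) + 1.0)

# Trace the nn.Module into an IR: nodes are operations (call_module,
# call_function, ...) and edges are the tensor values flowing between them.
gm = fx.symbolic_trace(TinyNet())
print(gm.graph)   # the captured computation graph
print(gm.code)    # the Python that the graph regenerates
```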

Stage 2: High-level graph optimization

  • Goal: Hardware-agnostic, graph-level optimizations
  • Example:
    • Operator fusion (also called kernel fusion) - see the sketch after this list
      • Merge several consecutive simple operations (e.g. convolution + bias addition + ReLU activation) into one fused operation.
    • Constant folding (constant propagation)
    • DCE (dead code elimination)
    • Layout transformation
      • Adjust the order in which tensors are stored in memory, e.g. NCHW vs NHWC.
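
A conceptual sketch of operator fusion (the helper names unfused and fused_conv_bias_relu are made up; in a real compiler the fused version would be emitted as a single kernel):

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 3, 32, 32)   # NCHW input
w = torch.randn(16, 3, 3, 3)    # conv weights
b = torch.randn(16)             # bias

# Unfused: three separate ops, each materializing an intermediate tensor.
def unfused(x):
    y = F.conv2d(x, w, padding=1)   # op 1: convolution
    y = y + b.view(1, -1, 1, 1)     # op 2: bias addition
    return F.relu(y)                # op 3: ReLU activation

# Fused: one logical op; a compiler would generate a single kernel so the
# intermediate results never round-trip through memory.
def fused_conv_bias_relu(x):
    return F.relu(F.conv2d(x, w, b, padding=1))

assert torch.allclose(unfused(x), fused_conv_bias_relu(x), atol=1e-5)
```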

Stage 3: Low-level optimization and code generation

  • hardware-specific optimizations
    • cache hierarchy, number of compute cores, special intrinsics
  • Examples:
    • Tiling or blocking (see the sketch after this list)
    • instruction scheduling
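
A minimal sketch of tiling/blocking on a NumPy matrix multiply (BLOCK = 32 is an arbitrary assumption; a real compiler would pick the tile size from the target's cache sizes):

```python
import numpy as np

BLOCK = 32  # tile size; in a real compiler this is tuned to the cache

def tiled_matmul(A, B):
    n, k = A.shape
    k2, m = B.shape
    assert k == k2
    C = np.zeros((n, m), dtype=A.dtype)
    # Iterate over tiles so each BLOCK x BLOCK working set stays in cache
    # while it is reused, instead of streaming whole rows and columns.
    for i0 in range(0, n, BLOCK):
        for j0 in range(0, m, BLOCK):
            for k0 in range(0, k, BLOCK):
                C[i0:i0+BLOCK, j0:j0+BLOCK] += (
                    A[i0:i0+BLOCK, k0:k0+BLOCK] @ B[k0:k0+BLOCK, j0:j0+BLOCK]
                )
    return C

A, B = np.random.rand(128, 96), np.random.rand(96, 64)
assert np.allclose(tiled_matmul(A, B), A @ B)
```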

Mainstream solutions:

  • TensorRT (NVIDIA) - but only available on NVIDIA GPUs
  • TVM (Apache TVM) - open-source framework. Supports x86, ARM, and various AI accelerators.
  • MLIR (Multi-Level Intermediate Representation) - Google + LLVM - not really a compiler itself, but a toolbox for building compilers
    • It is a multi-level IR system that tries to make it easier to build new AI compilers by providing a modular and simpler approach.
  • XLA (Accelerated Linear Algebra) - JIT compiler - Google - mainly used as the backend of TensorFlow and JAX; it compiles to machine code just in time.
  • IREE (Intermediate Representation Execution Environment) - compiler and runtime - Google - based on MLIR, a complete end-to-end solution meant to run across heterogeneous hardware.
  • torch.compile - core functionality of PyTorch 2.0 - PyTorch (Meta) - user-friendly interface (see the sketch below)
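
A minimal torch.compile usage sketch (the tiny Sequential model is a made-up placeholder):

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
)

# One call wraps the model; graph capture, fusion, and code generation
# happen lazily on the first forward pass (requires PyTorch >= 2.0).
compiled = torch.compile(model)

x = torch.randn(8, 3, 32, 32)   # NCHW
out = compiled(x)               # same result as model(x), faster after warm-up
```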

Questions

  • What is nn.Conv2d and what is tf.nn.relu (see the first sketch after this list)
    • Fundamental building blocks in neural networks from the two popular deep learning frameworks.
      • nn.Conv2d is the 2D convolutional layer from the PyTorch library. Core component of CNNs. Its main job is to apply a set of learnable filters (or kernels) across a 2D input, like an image, to detect features such as edges, textures, or more complex shapes.
      • tf.nn.relu is the rectified linear unit activation function from the TensorFlow library. It seems to just be ReLU exposed as a plain function.
  • What is 偏置加法 (bias addition)?
    • OK, so it just means adding the bias term onto the output, i.e. y = Wx + b.
  • What is NCHW and what is NHWC
    • N (Batch - how many images are processed at once), C (Channels - how many color channels are there in each image, e.g. 3 for an RGB image), H (Height - the height of each image), W (Width - the width of each image)
    • NCHW is channel-first, NHWC is channel-last (see the first sketch after this list)
      • Channel-first: the data is stored in the order [batch, channels, height, width]. The default format used by PyTorch, and often more efficient on NVIDIA GPUs with the cuDNN library.
      • Channel-last: stored as [batch, height, width, channels]. The default format used by TensorFlow, and often more efficient on CPUs.
  • What is JAX?
    • JAX is a high performance machine learning research library from Google.
    • Known for combining a NumPy-like API with powerful automatic transformations.
      • grad: automatic differentiation for calculating gradients
      • jit: JIT compiler that uses Google’s XLA (Accelerated Linear Algebra) compiler to fuse operations and generate highly optimized machine code for accelerators like GPUs and TPUs.
      • vmap and pmap: functions for automatic vectorization (processing batches of data) and parallelization (running code across multiple devices)
    • In short, users can write cleaner, more Pythonic code that looks like NumPy but runs magically faster (see the second sketch after this list).
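
First sketch, tying the nn.Conv2d / ReLU and NCHW vs NHWC questions together in PyTorch (shapes and channel counts are arbitrary illustrative choices; tf.nn.relu computes the same max(x, 0) as the ReLU used here):

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
relu = nn.ReLU()                  # same math as tf.nn.relu: elementwise max(x, 0)

x = torch.randn(8, 3, 32, 32)     # NCHW: batch=8, channels=3, height=32, width=32
y = relu(conv(x))                 # learnable filters, then the nonlinearity
print(y.shape)                    # torch.Size([8, 16, 32, 32])

# PyTorch keeps the NCHW logical shape but can store the data channel-last
# (NHWC order in memory), which some kernels prefer.
x_cl = x.to(memory_format=torch.channels_last)
print(x_cl.shape, x_cl.is_contiguous(memory_format=torch.channels_last))
```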
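
Second sketch, a small JAX example of grad, jit, and vmap (the toy function f is made up for illustration):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)          # toy scalar-valued function

x = jnp.arange(4.0)

grad_f = jax.grad(f)                # automatic differentiation: df/dx = 2x
print(grad_f(x))                    # [0. 2. 4. 6.]

fast_f = jax.jit(f)                 # JIT-compile through XLA, fusing the ops
print(fast_f(x))

batched_grad = jax.vmap(grad_f)     # vectorize over a leading batch dimension
print(batched_grad(jnp.stack([x, x + 1])).shape)   # (2, 4)
```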

[2025-07-30-Wednesday]

Attention

Flash Attention