Multi-Head Neural Network Design in PyTorch

Looking for a way to spice up your relationship with deep neural networks? Try bringing up a multi-head design and watch the sparks start flying!

Neural networks come in a diverse range of design architectures, often uniquely suited to specific problem domains or performance requirements. The multi-head design offers both semantic and computational isolation of elements of the network, which benefits both model performance and development workflow. While the concept is not new to computer science, its implementation in modern systems has brought significant improvements in model performance.

Highlights

  • Crash course on PyTorch Deep Neural Network design architecture
  • Historical origins of multi-head design
  • Where multi-headed design is used in modern applications/systems/networks
  • Implementing basic deep neural networks in PyTorch
  • Implementing a multi-head design in deep neural networks in PyTorch
  • Considerations for when to use Multi-head network architectures

TL;DR – Multi-head architecture separates conceptually and computationally distinct components of a network.

Quick History

Multi-head network design stems from a broader concept of multi-task design for systems, with literature dating back to the late 1990s. In one paper, titled Multitask Learning [1], author Rick Caruana states the motivating case for multi-headed design as such:

Multitask Learning is an approach to inductive transfer that improves generalization by using the domain information contained in the training signals of related tasks as an inductive bias. It does this by learning tasks in parallel while using a shared representation; what is learned for each task can help other tasks be learned better.

These concepts were discussed before the widespread adoption of deep learning models (largely due to resource constraints at the time) but remain relevant in modern AI systems. In modern systems, multi-headed design is expressed in a number of ways:

  • Models pre-training on multiple datasets.
  • Multi-head components are used by large language models like ChatGPT [2].
  • Fine-tuning models based on user prompt context.
  • Separating specific action groups into separate network “heads.”
  • and many more.

These are sparse examples of the many ways in which multi-headed design concepts are expressed in neural network architectures.

PyTorch Crash Course: Basic Deep Neural Network Design

PyTorch makes implementing neural networks easy. The abstraction of network components, deep integration with GPU-optimized libraries like CUDA, and a vast array of components like loss functions and optimizers make it a leading deep learning framework. Essentially, deep learning networks need the following:

  • Input Layer
  • Hidden Layer
  • Output Layer
  • Considerations for bias, activation functions, and multi-head design.

We’ll save multi-head design for last. For now, let’s focus on the essential components to lay the groundwork for multi-head use cases.

Designing a deep neural network in PyTorch can be as simple as using a pre-defined layer class like nn.Linear as such: layer = nn.Linear(3, 5), where 3 is the number of input features and 5 is the number of output features (a.k.a. nodes in many visualizations). This syntax doesn’t account for activation functions, however, but PyTorch provides support there too.
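To make those dimensions concrete, here is a minimal sketch (the variable names are just for illustration) showing the weight and bias tensors PyTorch creates for such a layer and the shape of its output:

import torch
import torch.nn as nn

# a single linear layer: 3 input features -> 5 output features
layer = nn.Linear(3, 5)

# the layer holds a weight matrix and (by default) a bias vector
print(layer.weight.shape)  # torch.Size([5, 3])
print(layer.bias.shape)    # torch.Size([5])

# passing a batch of inputs through the layer
x = torch.randn(4, 3)      # batch of 4 samples, 3 features each
out = layer(x)
print(out.shape)           # torch.Size([4, 5])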

Applying an activation can be as simple as F.relu(layer(x)), where relu is an activation function from PyTorch’s torch.nn.functional module, layer is the Linear layer we just created, and x is the input tensor. Here, the relu function transforms the output of the network layer. Here’s an example of defining each layer separately:

import torch
import torch.nn as nn
import torch.nn.functional as F

# verbose definition of params
input_features = 5
output_features = 2
hidden_layer_size = 128

# section 1: network design
layer_1 = nn.Linear(input_features, hidden_layer_size)
layer_2 = nn.Linear(hidden_layer_size, hidden_layer_size)
layer_3 = nn.Linear(hidden_layer_size, output_features)

# section 2: forward pass
x = ...
x = F.relu(layer_1(x))
x = F.relu(layer_2(x))
x = layer_3(x)  # <----- the final output layer

# an example of applying the action
action = torch.argmax(x).item()

Here we see a simple three-layer design:

  • layer_1: the layer that takes the initial inputs and outputs into the hidden layer.
  • layer_2: the “hidden layer” that makes this a “deep” neural network.
  • layer_3: the output layer, which translates the hidden representation into a score (logit) for each possible action, for example a Buy/Sell signal.

In the block labeled “section 2” we see an implementation of a “forward pass,” via which input passes through the network, being transformed by weights, biases, and activation functions (relu in this case) into the output layer. Practically, section 2 would be implemented as the forward method in a subclass of PyTorch’s nn.Module class. Here, however, it serves only to illustrate basic use.
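For reference, here is a minimal sketch of how the same three-layer design might be wrapped in an nn.Module subclass (the class name SimpleNet is just for illustration), with the forward pass living in the forward method:

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleNet(nn.Module):
    """A minimal three-layer network; hyperparameters mirror the example above."""

    def __init__(self, input_features=5, hidden_layer_size=128, output_features=2):
        super().__init__()
        self.layer_1 = nn.Linear(input_features, hidden_layer_size)
        self.layer_2 = nn.Linear(hidden_layer_size, hidden_layer_size)
        self.layer_3 = nn.Linear(hidden_layer_size, output_features)

    def forward(self, x):
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        return self.layer_3(x)  # raw output scores (logits)


# example usage
net = SimpleNet()
x = torch.randn(5)  # a single sample with 5 input features
action = torch.argmax(net(x)).item()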

Multi-Headed Deep Neural Network in PyTorch

Let’s consider what a multi-headed design might look like — then we’ll move on to discussing why it might be beneficial:

import torch
import torch.nn as nn
import torch.nn.functional as F

# defines our network params
input_features = 5
output_features = 2
hidden_layer_size = 128

# defines the network structure
layer_1 = nn.Linear(input_features, hidden_layer_size)
layer_2 = nn.Linear(hidden_layer_size, hidden_layer_size)

# the multi-headed components, isolated from each other
trade_action = nn.Linear(hidden_layer_size, output_features)   # buy/sell head
leverage_amt = nn.Linear(hidden_layer_size, out_features=100)  # leverage head: 100 discrete levels
risk_amt = nn.Linear(hidden_layer_size, out_features=100)      # risk head: 100 discrete levels

# example of a forward pass
x = ...
x = F.relu(layer_1(x))
x = F.relu(layer_2(x))

# separate actions for each "head"
buy_or_sell = torch.argmax(trade_action(x)).item()
leverage = torch.argmax(leverage_amt(x)).item()
risk = torch.argmax(risk_amt(x)).item()

Here we see a stark difference in our forward-pass approach. Rather than outputting a single value for the entire network, we implement three separate network “heads” for three conceptually different concerns:

  • trade_action: The signal to Buy/Sell an asset.
  • leverage_amt: The amount of leverage to use.
  • risk_amt: The amount of risk to take.

Here we see that all the input values into the network first pass through the input layer and then the hidden layer. After that, however, rather than all actions being encoded into a single layer, they are isolated. Two immediate benefits to note are:

  • Conceptual separation of tasks/outputs.
  • Parameter separation of learning.

The first benefit helps keep a more semantic flow during development and allows for stronger alignment with domain-specific task definitions. The second benefit is more foundational and tougher to measure objectively. Does it make sense to separate these actions? There’s no universal answer to this question, and each network designer should carefully consider the problem domain, what’s being optimized for, and the developer workflow. For this particular case, here is an example of how this might make sense:

# reset a training environment
agent = ...
env = ...
state, info = env.reset()
done = False
while not done:
    
    # generate an action(s) from the network
    buy_sell, leverage_amount, risk_amount = agent(state)

    # apply the action to the environment
    state, reward, done, info = env(buy_sell, leverage_amount, risk_amount)
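For the agent(state) call above to hand back three separate values, one approach is to wrap the shared layers and the three heads in a single nn.Module whose forward method returns all three outputs. Here is a minimal sketch of that idea (the class name and the argmax post-processing are illustrative assumptions, not a fixed API):

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadNet(nn.Module):
    """Shared body with three isolated output heads (illustrative example)."""

    def __init__(self, input_features=5, hidden_layer_size=128):
        super().__init__()
        # shared layers
        self.layer_1 = nn.Linear(input_features, hidden_layer_size)
        self.layer_2 = nn.Linear(hidden_layer_size, hidden_layer_size)
        # isolated heads
        self.trade_action = nn.Linear(hidden_layer_size, 2)    # buy/sell
        self.leverage_amt = nn.Linear(hidden_layer_size, 100)  # discretized leverage
        self.risk_amt = nn.Linear(hidden_layer_size, 100)      # discretized risk

    def forward(self, x):
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        return self.trade_action(x), self.leverage_amt(x), self.risk_amt(x)


# example: turning each head's logits into a discrete action
net = MultiHeadNet()
state = torch.randn(5)
trade_logits, leverage_logits, risk_logits = net(state)
buy_sell = torch.argmax(trade_logits).item()
leverage_amount = torch.argmax(leverage_logits).item()
risk_amount = torch.argmax(risk_logits).item()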

Here we note an immediate semantic benefit of having three distinct components of an agent’s action returned from the network. This can help avoid simple syntactic errors during the development of training routines and expresses clearer intent in code. In this case, the actions do the following:

  • buy_sell: Whether to Buy or Sell an asset.
  • leverage_amount: The amount of leverage (if relevant) to use on the order.
  • risk_amount: The risk to accept during order sizing.

How these action components are applied in the environment is case-specific. However, the clarity of their intent and the simplified application process are, in my opinion, broadly applicable good software design practices.
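As a purely hypothetical illustration (the function name, scaling choices, and order fields below are invented for this example and not taken from any real trading system), an environment might map the three discrete head outputs onto an order roughly like this:

# hypothetical sketch: mapping the three head outputs onto an order
def apply_action(buy_sell: int, leverage_amount: int, risk_amount: int,
                 price: float, account_balance: float) -> dict:
    side = "buy" if buy_sell == 0 else "sell"
    leverage = 1 + leverage_amount             # map index 0..99 to 1x..100x leverage
    risk_fraction = (risk_amount + 1) / 100.0  # map index 0..99 to 1%..100% of balance
    order_size = account_balance * risk_fraction * leverage / price
    return {"side": side, "size": order_size, "leverage": leverage}


# example usage
order = apply_action(buy_sell=0, leverage_amount=4, risk_amount=9,
                     price=100.0, account_balance=10_000.0)
print(order)  # {'side': 'buy', 'size': 50.0, 'leverage': 5}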

Discussion

Multi-headed network designs offer a useful discretization of network components. This can provide both semantic and performance improvements, depending on the use case. Modern LLMs rely on multi-head attention components [2], which run multiple “attention heads” in parallel and have contributed to previously unmatched emergent properties.

Essentially, the multi-headed components of LLMs are often credited with giving rise to the new generation of models, such as ChatGPT, that offer significant performance improvements.

Will dropping a multi-head design into your network turn your startup into the next OpenAI? Doubtful. However, understanding the applicability of multi-headed architectures can help you design and develop deep learning networks that are more contextual and better suited to their problem domains.

References

  1. Caruana, R. Multitask Learning. Machine Learning 28, 41–75 (1997). https://doi.org/10.1023/A:1007379606734
  2. Vaswani, A., et al. Attention Is All You Need. arXiv (2017). https://doi.org/10.48550/arXiv.1706.03762