import math

import torch
from torch import nn


class AttentionHead(nn.Module):
    """
    A single attention head.
    This module is used in the MultiHeadAttention module.
    """

    def __init__(self, hidden_size, attention_head_size, dropout, bias=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.attention_head_size = attention_head_size
        # Project the input into query, key, and value vectors for this head.
        self.query = nn.Linear(hidden_size, attention_head_size, bias=bias)
        self.key = nn.Linear(hidden_size, attention_head_size, bias=bias)
        self.value = nn.Linear(hidden_size, attention_head_size, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)
        # Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V
        attention_scores = torch.matmul(query, key.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        attention_output = torch.matmul(attention_probs, value)
        return (attention_output, attention_probs)
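

# A minimal shape-check sketch for AttentionHead (not part of the original module);
# the sizes and dropout value below are illustrative assumptions.
def _attention_head_shape_check():
    head = AttentionHead(hidden_size=48, attention_head_size=12, dropout=0.0)
    x = torch.randn(2, 5, 48)  # (batch_size, sequence_length, hidden_size)
    output, probs = head(x)
    assert output.shape == (2, 5, 12)  # one head returns attention_head_size features
    assert probs.shape == (2, 5, 5)    # one attention weight per (query, key) pair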


class MultiHeadAttention(nn.Module):
    """
    Multi-head attention module.
    This module is used in the TransformerEncoder module.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config["hidden_size"]
        self.num_attention_heads = config["num_attention_heads"]
        # Each head works on an equal slice of the hidden size.
        self.attention_head_size = self.hidden_size // self.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # Whether to use bias in the query, key, and value projections.
        self.qkv_bias = config["qkv_bias"]
        # Create the list of attention heads.
        self.heads = nn.ModuleList([])
        for _ in range(self.num_attention_heads):
            head = AttentionHead(
                self.hidden_size,
                self.attention_head_size,
                config["attention_probs_dropout_prob"],
                self.qkv_bias,
            )
            self.heads.append(head)
        # Project the concatenated head outputs back to the hidden size.
        self.output_projection = nn.Linear(self.all_head_size, self.hidden_size)
        self.output_dropout = nn.Dropout(config["hidden_dropout_prob"])

    def forward(self, x, output_attentions=False):
        # Run each attention head and collect its (output, probabilities) pair.
        attention_outputs = [head(x) for head in self.heads]
        # Concatenate the per-head outputs along the feature dimension.
        attention_output = torch.cat([attention_output for attention_output, _ in attention_outputs], dim=-1)
        attention_output = self.output_projection(attention_output)
        attention_output = self.output_dropout(attention_output)
        # Optionally return the stacked attention probabilities (e.g. for visualization).
        if not output_attentions:
            return (attention_output, None)
        else:
            attention_probs = torch.stack([attention_probs for _, attention_probs in attention_outputs], dim=1)
            return (attention_output, attention_probs)
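

# A minimal usage sketch (not part of the original module). The config keys match
# the ones read in MultiHeadAttention.__init__ above; the specific values are
# illustrative assumptions.
if __name__ == "__main__":
    config = {
        "hidden_size": 48,
        "num_attention_heads": 4,
        "qkv_bias": True,
        "attention_probs_dropout_prob": 0.0,
        "hidden_dropout_prob": 0.0,
    }
    attention = MultiHeadAttention(config)
    x = torch.randn(2, 5, config["hidden_size"])  # (batch_size, sequence_length, hidden_size)
    output, attention_probs = attention(x, output_attentions=True)
    print(output.shape)           # torch.Size([2, 5, 48])
    print(attention_probs.shape)  # torch.Size([2, 4, 5, 5]) -- (batch, heads, seq, seq)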