import torch
import torch.nn as nn
import torch.nn.functional as F
from bmm1 import *
from bmm2 import *
from padding import *
from softmax import *
class FastUnpadBertSelfAttention(nn.Module):
    """BERT self-attention assembled from the project's batched-matmul and softmax ops.

    The layer runs five stages in order: Q/K/V projection, ``bmm1`` (attention
    scores), optional scaling, masking + softmax (+ dropout), and ``bmm2``
    (context).  Constructor flags select which stages are delegated to the
    fused ops star-imported from ``bmm1``, ``bmm2`` and ``softmax`` and which
    are done with plain PyTorch.

    The exact semantics of ``Bmm1``/``Bmm1Strided``/``Bmm2``/``Bmm2Strided``
    and the ``Fast*Softmax*`` classes are defined outside this file; only the
    way they are called is documented here.
    """

    def __init__(self, config, enable_stream=True, enable_sync=True, fuse_mask=True, fuse_scale=True, fuse_qkv=True, fuse_dropout=True, apex_softmax=True, pad=True):
        """Build the projection layers and select the matmul/softmax ops.

        Args:
            config: Object exposing ``hidden_size``, ``num_attention_heads``
                and ``attention_probs_dropout_prob``.
            enable_stream: Forwarded as ``stream=`` to the custom ops; when
                truthy, ``forward`` also calls ``torch.cuda.synchronize()``
                before every stage.
            enable_sync: Forwarded as ``sync=`` to the custom ops (except the
                fused mask+softmax+dropout op, see note below).
            fuse_mask: Use a softmax op that receives the attention mask.
            fuse_scale: Forwarded as ``scale=`` to ``bmm1``; when falsy the
                scores are divided by ``sqrt(attention_head_size)`` in
                ``forward`` instead.
            fuse_qkv: Project Q, K and V with one ``addmm`` and use the
                ``*Strided`` matmul ops.
            fuse_dropout: When falsy, an ``nn.Dropout`` is applied after the
                softmax; when truthy together with ``fuse_mask`` the combined
                mask+softmax+dropout op is used.
            apex_softmax: In the unfused-mask path, use the ``FastSoftmax`` op
                rather than ``torch.nn.functional.softmax``.
            pad: Selects the padded PyTorch softmax path in ``forward`` and is
                used to derive ``sync`` for the fused softmax-dropout op.

        Raises:
            ValueError: If ``hidden_size`` is not a multiple of
                ``num_attention_heads``.
        """
        super(FastUnpadBertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))

        self.num_attention_heads = config.num_attention_heads
        # Divisibility was checked above, so floor division is exact.
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.hidden_size = config.hidden_size

        self.fuse_qkv = fuse_qkv
        self.fuse_scale = fuse_scale
        self.fuse_mask = fuse_mask
        self.fuse_dropout = fuse_dropout
        self.apex_softmax = apex_softmax
        self.pad = pad
        self.enable_stream = enable_stream

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        if self.fuse_qkv:
            self.bmm1 = Bmm1Strided(None, None, self.num_attention_heads, self.attention_head_size, scale=self.fuse_scale, stream=enable_stream, sync=enable_sync, timer=False)
            self.bmm2 = Bmm2Strided(None, None, self.num_attention_heads, self.attention_head_size, stream=enable_stream, sync=enable_sync, timer=False)
        else:
            self.bmm1 = Bmm1(None, None, self.num_attention_heads, self.attention_head_size, scale=self.fuse_scale, stream=enable_stream, sync=enable_sync)
            self.bmm2 = Bmm2(None, None, self.num_attention_heads, self.attention_head_size, stream=enable_stream, sync=enable_sync)

        # NOTE(review): with fuse_dropout=True and fuse_mask=False neither this
        # module nor the selected softmax op applies dropout, so
        # attention_probs_dropout_prob is ignored -- confirm this is intended.
        if not self.fuse_dropout:
            self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        if self.fuse_mask and self.fuse_dropout:
            # NOTE(review): sync is derived from `pad` here rather than from
            # `enable_sync` as for the other ops -- confirm this is intended.
            self.softmax = FastMaskSoftmaxDropout(dim=-1, dropout_prob=config.attention_probs_dropout_prob, stream=enable_stream, sync=(not self.pad), timer=False)
        elif self.fuse_mask:
            self.softmax = FastMaskSoftmax(dim=-1, stream=enable_stream, sync=enable_sync, timer=False)
        else:
            self.softmax = FastSoftmax(dim=-1, stream=enable_stream, sync=enable_sync, timer=False)

    def transpose_for_scores(self, x):
        """Split the last dim of a 3-D ``x`` into (heads, head_size) and return it permuted to (0, 2, 1, 3)."""
        split_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return torch.reshape(x, split_shape).permute(0, 2, 1, 3)

    def transpose_key_for_scores(self, x):
        """Split the last dim of a 3-D ``x`` into (heads, head_size) and return it permuted to (0, 2, 3, 1)."""
        split_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return torch.reshape(x, split_shape).permute(0, 2, 3, 1)

    def pytorch_softmax(self, attention_scores, batch, seqlen, heads):
        """Apply a per-sequence softmax to flat, unpadded attention scores.

        ``attention_scores`` is read as ``batch`` consecutive chunks; chunk
        ``i`` holds ``num_attention_heads * seqlen[i] * seqlen[i]`` values and
        is softmaxed over its last axis after being viewed as
        ``(1, num_attention_heads, seqlen[i], seqlen[i])``.

        Args:
            attention_scores: 1-D tensor of concatenated score chunks.
            batch: Number of sequences (chunks) to process.
            seqlen: Indexable of per-sequence lengths (ints or 0-dim tensors).
            heads: Unused; ``self.num_attention_heads`` is used instead.  Kept
                so existing callers keep working.

        Returns:
            A 1-D tensor of probabilities laid out like the input, created on
            the same device and with the same dtype as ``attention_scores``.
        """
        num_heads = self.num_attention_heads
        # int() lets tensor-valued lengths be used as sizes and slice bounds.
        chunk_sizes = [int(seqlen[i]) * int(seqlen[i]) * num_heads for i in range(batch)]

        # Allocate next to the input instead of hard-coding cuda/float16, so
        # this path matches the padded F.softmax branch in forward().
        attention_probs = attention_scores.new_zeros(sum(chunk_sizes))

        start = 0
        for i, size in enumerate(chunk_sizes):
            length = int(seqlen[i])
            stop = start + size
            chunk = attention_scores[start:stop].view(1, num_heads, length, length)
            attention_probs[start:stop] = F.softmax(chunk, dim=-1).flatten().contiguous()
            start = stop
        return attention_probs

    def _sync_if_streaming(self):
        """Call ``torch.cuda.synchronize()`` when ``enable_stream`` is set."""
        if self.enable_stream:
            torch.cuda.synchronize()

    def _fused_qkv_params(self):
        """Return ``(weight, bias)`` packing the Q, K, V projections head by head.

        For every head the rows of the query, key and value weights (and the
        matching bias entries) are placed next to each other, giving a
        ``(3 * all_head_size, hidden_size)`` weight and a flat bias.
        """
        heads, head_size = self.num_attention_heads, self.attention_head_size
        projections = (self.query, self.key, self.value)

        weight = torch.cat(
            [p.weight.view(heads, head_size, 1, self.hidden_size) for p in projections],
            dim=1,
        ).reshape(self.all_head_size * 3, self.hidden_size).contiguous()
        bias = torch.cat(
            [p.bias.view(heads, 1, head_size) for p in projections],
            dim=1,
        ).reshape(3 * self.hidden_size).contiguous()
        return weight, bias

    def forward(self, hidden_states, attention_mask, seqlen, batch, is_training=True):
        """Compute the self-attention context for a batch of sequences.

        Args:
            hidden_states: Input activations.  In the fused-QKV path they are
                passed to ``torch.addmm`` and therefore must be 2-D with
                ``hidden_size`` columns.
            attention_mask: Mask handed to the fused softmax op, or flattened
                with ``view(-1)`` and added to the scores when ``fuse_mask``
                is off.
            seqlen: Per-sequence lengths, forwarded to the custom ops; in the
                padded PyTorch softmax path only ``seqlen[0]`` is used.
            batch: Number of sequences; also stored on ``self.batch``.
            is_training: Forwarded to the fused mask+softmax+dropout op only.

        Returns:
            The output of ``bmm2`` with its last two dimensions merged into
            one of size ``all_head_size``.
        """
        # `math` is not imported at file level; import it here so the unfused
        # scaling path does not depend on a star import leaking the name.
        import math

        self.batch = batch
        heads = self.num_attention_heads

        # Q/K/V projection.
        if self.fuse_qkv:
            weight, bias = self._fused_qkv_params()
            mixed_x_layer = torch.addmm(bias, hidden_states, weight.t())
        else:
            query_layer = self.query(hidden_states)
            key_layer = self.key(hidden_states)
            value_layer = self.value(hidden_states)

        # Attention scores.
        self._sync_if_streaming()
        if self.fuse_qkv:
            attention_scores, qkv_layer = self.bmm1(mixed_x_layer, self.batch, seqlen)
        else:
            attention_scores = self.bmm1(query_layer, key_layer, self.batch, seqlen)

        # Scaling, when bmm1 was not asked to do it.
        self._sync_if_streaming()
        if not self.fuse_scale:
            attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Mask + softmax.
        self._sync_if_streaming()
        if self.fuse_mask and self.fuse_dropout:
            attention_probs = self.softmax(attention_scores, attention_mask, self.batch, seqlen, heads, is_training)
        elif self.fuse_mask:
            attention_probs = self.softmax(attention_scores, attention_mask, self.batch, seqlen, heads)
        else:
            attention_scores = attention_scores + attention_mask.view(-1)
            if self.apex_softmax:
                attention_probs = self.softmax(attention_scores, self.batch, seqlen, heads)
            elif self.pad:
                # Padded layout: every sequence has length seqlen[0].
                padded_scores = attention_scores.view(batch, heads, seqlen[0], seqlen[0])
                attention_probs = F.softmax(padded_scores, dim=-1).flatten().contiguous()
            else:
                attention_probs = self.pytorch_softmax(attention_scores, self.batch, seqlen, heads)

        # Dropout, when it is not fused into the softmax op.
        self._sync_if_streaming()
        if not self.fuse_dropout:
            attention_probs = self.dropout(attention_probs)

        # Context.
        self._sync_if_streaming()
        if self.fuse_qkv:
            context_layer = self.bmm2(attention_probs, qkv_layer, self.batch, seqlen)
        else:
            context_layer = self.bmm2(attention_probs, value_layer, self.batch, seqlen)

        self._sync_if_streaming()
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        return torch.reshape(context_layer, merged_shape)