import torch
import mhalib
class FastSoftmaxFunction(torch.autograd.Function):
    """Autograd wrapper around the ``mhalib`` fused softmax kernels.

    ``forward`` hands ``input`` to ``mhalib.FastSoftmaxFprop`` and returns that
    same tensor object. ``backward`` hands the saved tensor plus the incoming
    gradient to ``mhalib.FastSoftmaxBprop`` and returns that same gradient
    tensor. NOTE(review): both kernels appear to work in place and
    ``ctx.mark_dirty`` is never called -- confirm this is intentional.
    """

    @staticmethod
    def forward(cxt, input, dim, batch, seqlen, heads, stream, sync, timers):
        """Run the fused softmax forward kernel.

        Args:
            cxt: autograd context.
            input: tensor handed to the kernel and returned unchanged as object.
            dim: only stored on the context; not used by either kernel call.
            batch, heads: forwarded to the kernel and stashed for backward.
            seqlen: forwarded to the kernel and saved via ``save_for_backward``
                (so it must be a tensor or None).
            stream, sync: forwarded unchanged to both kernels.
            timers: None, or a dict of CUDA events keyed 'start_fprop',
                'stop_fprop', 'start_dgrad', 'stop_dgrad' that bracket the
                kernel calls.

        Returns:
            ``input`` itself.
        """
        if timers:
            timers['start_fprop'].record()
        mhalib.FastSoftmaxFprop(input, batch, seqlen, heads, stream, sync)
        if timers:
            timers['stop_fprop'].record()

        cxt.save_for_backward(input, seqlen)
        # Non-tensor configuration travels on the context object.
        cxt.dim, cxt.batch, cxt.heads = dim, batch, heads
        cxt.stream, cxt.sync, cxt.timers = stream, sync, timers
        return input

    @staticmethod
    def backward(cxt, grad_output):
        """Run the fused softmax backward kernel and return its gradient.

        Returns one entry per ``forward`` input (8): ``grad_output`` for
        ``input`` and None for every configuration argument.
        """
        saved_out, saved_seqlen = cxt.saved_tensors
        timers = cxt.timers
        if timers:
            timers['start_dgrad'].record()
        mhalib.FastSoftmaxBprop(saved_out, grad_output, cxt.batch, saved_seqlen,
                                cxt.heads, cxt.stream, cxt.sync)
        if timers:
            timers['stop_dgrad'].record()
        return (grad_output,) + (None,) * 7
class FastSoftmax(torch.nn.Module):
    """``nn.Module`` front end for ``FastSoftmaxFunction``.

    Args:
        dim: stored and passed through to the autograd function.
        stream, sync: passed through unchanged to the mhalib kernels.
        timer: when true, allocate four timing-enabled ``torch.cuda.Event``
            objects ('start_fprop', 'start_dgrad', 'stop_fprop', 'stop_dgrad')
            recorded around each kernel call; otherwise ``self.timers`` is None.
    """

    def __init__(self, dim=None, stream=True, sync=True, timer=False):
        super().__init__()
        self.dim = dim
        self.stream = stream
        self.sync = sync
        self.timers = (
            {name: torch.cuda.Event(enable_timing=True)
             for name in ('start_fprop', 'start_dgrad', 'stop_fprop', 'stop_dgrad')}
            if timer else None
        )

    def forward(self, input, batch, seqlen, heads):
        """Apply the fused softmax via ``FastSoftmaxFunction.apply``."""
        return FastSoftmaxFunction.apply(
            input, self.dim, batch, seqlen, heads,
            self.stream, self.sync, self.timers,
        )
class FastMaskSoftmaxFunction(torch.autograd.Function):
    """Autograd wrapper around the ``mhalib`` fused masked-softmax kernels.

    ``forward`` hands ``input`` and ``mask`` to ``mhalib.FastMaskSoftmaxFprop``
    and returns that same ``input`` tensor object. ``backward`` reuses the
    unmasked ``mhalib.FastSoftmaxBprop`` kernel (the mask is not passed to it)
    and returns the incoming gradient tensor.

    NOTE(review): both kernels appear to work in place and ``ctx.mark_dirty``
    is never called -- confirm this is intentional.
    """

    @staticmethod
    def forward(cxt, input, mask, dim, batch, seqlen, heads, stream, sync, timers):
        """Run the fused masked softmax forward kernel.

        Args:
            cxt: autograd context.
            input: tensor handed to the kernel and returned as the output.
            mask: forwarded to the kernel only; receives no gradient.
            dim: only stored on the context; not used by either kernel call.
            batch, heads: forwarded to the kernel and stashed for backward.
            seqlen: forwarded to the kernel and saved via ``save_for_backward``
                (so it must be a tensor or None).
            stream, sync: forwarded unchanged to both kernels.
            timers: None, or a dict of CUDA events keyed 'start_fprop',
                'stop_fprop', 'start_dgrad', 'stop_dgrad'.

        Returns:
            ``input`` itself.
        """
        if timers:
            timers['start_fprop'].record()
        mhalib.FastMaskSoftmaxFprop(input, mask, batch, seqlen, heads, stream, sync)
        if timers:
            timers['stop_fprop'].record()
        cxt.save_for_backward(input, seqlen)
        cxt.dim = dim
        cxt.batch = batch
        cxt.heads = heads
        cxt.stream = stream
        cxt.sync = sync
        cxt.timers = timers
        return input

    @staticmethod
    def backward(cxt, grad_output):
        """Run the softmax backward kernel and return one gradient per input.

        Returns:
            A 9-tuple matching ``forward``'s 9 inputs: ``grad_output`` for
            ``input`` and None for ``mask`` and every configuration argument.
        """
        output, seqlen = cxt.saved_tensors
        if cxt.timers:
            cxt.timers['start_dgrad'].record()
        mhalib.FastSoftmaxBprop(output, grad_output, cxt.batch, seqlen, cxt.heads,
                                cxt.stream, cxt.sync)
        if cxt.timers:
            cxt.timers['stop_dgrad'].record()
        # Previously returned 12 entries; autograd expects exactly one per
        # forward() input (9 here).
        return (grad_output,) + (None,) * 8
class FastMaskSoftmax(torch.nn.Module):
    """``nn.Module`` front end for ``FastMaskSoftmaxFunction``.

    Args:
        dim: stored and passed through to the autograd function.
        stream, sync: passed through unchanged to the mhalib kernels.
        timer: when true, allocate four timing-enabled ``torch.cuda.Event``
            objects ('start_fprop', 'start_dgrad', 'stop_fprop', 'stop_dgrad');
            otherwise ``self.timers`` is None.
    """

    def __init__(self, dim=None, stream=True, sync=True, timer=False):
        super().__init__()
        self.dim = dim
        self.stream = stream
        self.sync = sync
        self.timers = (
            {name: torch.cuda.Event(enable_timing=True)
             for name in ('start_fprop', 'start_dgrad', 'stop_fprop', 'stop_dgrad')}
            if timer else None
        )

    def forward(self, input, mask, batch, seqlen, heads):
        """Apply the fused masked softmax via ``FastMaskSoftmaxFunction.apply``."""
        return FastMaskSoftmaxFunction.apply(
            input, mask, self.dim, batch, seqlen, heads,
            self.stream, self.sync, self.timers,
        )
class FastMaskSoftmaxDropoutFunction(torch.autograd.Function):
    """Autograd wrapper around the ``mhalib`` fused masked-softmax + dropout kernels.

    ``forward`` calls ``mhalib.FastMaskSoftmaxDropoutFprop``, which returns
    ``(output, dropout_mask)``. It saves ``input``, ``dropout_mask`` and
    ``seqlen`` for backward and returns ``output``.

    NOTE(review): backward treats the saved ``input`` as the kernel's
    intermediate result, so the forward kernel presumably writes into ``input``
    in place -- confirm against mhalib.
    """

    @staticmethod
    def forward(cxt, input, mask, dim, batch, seqlen, heads, dropout_prob,
                stream, sync, timers, is_training):
        """Run the fused masked softmax + dropout forward kernel.

        Args:
            cxt: autograd context.
            input: tensor handed to the kernel; saved for backward.
            mask: forwarded to the kernel only; receives no gradient.
            dim: only stored on the context; not used by either kernel call.
            batch, heads, dropout_prob: forwarded to both kernels.
            seqlen: forwarded to both kernels and saved via
                ``save_for_backward`` (so it must be a tensor or None).
            stream, sync: forwarded unchanged to both kernels.
            timers: None, or a dict of CUDA events keyed 'start_fprop',
                'stop_fprop', 'start_dgrad', 'stop_dgrad'.
            is_training: forwarded to the forward kernel only.

        Returns:
            The first element returned by the forward kernel.
        """
        if timers:
            timers['start_fprop'].record()
        output, dropout_mask = mhalib.FastMaskSoftmaxDropoutFprop(
            input, mask, batch, seqlen, heads, dropout_prob, stream, sync, is_training)
        if timers:
            timers['stop_fprop'].record()
        cxt.save_for_backward(input, dropout_mask, seqlen)
        cxt.dim = dim
        cxt.batch = batch
        cxt.heads = heads
        cxt.dropout_prob = dropout_prob
        cxt.stream = stream
        cxt.sync = sync
        cxt.timers = timers
        return output

    @staticmethod
    def backward(cxt, grad_output):
        """Run the fused backward kernel and return one gradient per input.

        Returns:
            An 11-tuple matching ``forward``'s 11 inputs: ``grad_output`` for
            ``input`` and None for ``mask`` and every configuration argument.
        """
        # First saved tensor is forward's ``input`` (see save_for_backward above).
        saved_input, dropout_mask, seqlen = cxt.saved_tensors
        if cxt.timers:
            cxt.timers['start_dgrad'].record()
        mhalib.FastMaskSoftmaxDropoutBprop(saved_input, grad_output, dropout_mask,
                                           cxt.batch, seqlen, cxt.heads,
                                           cxt.dropout_prob, cxt.stream, cxt.sync)
        if cxt.timers:
            cxt.timers['stop_dgrad'].record()
        # Previously returned 14 entries; autograd expects exactly one per
        # forward() input (11 here).
        return (grad_output,) + (None,) * 10
class FastMaskSoftmaxDropout(torch.nn.Module):
    """``nn.Module`` front end for ``FastMaskSoftmaxDropoutFunction``.

    Args:
        dim: stored and passed through to the autograd function.
        dropout_prob: passed through to the mhalib kernels.
        stream, sync: passed through unchanged to the mhalib kernels.
        timer: when true, allocate four timing-enabled ``torch.cuda.Event``
            objects ('start_fprop', 'start_dgrad', 'stop_fprop', 'stop_dgrad');
            otherwise ``self.timers`` is None.
    """

    def __init__(self, dim=None, dropout_prob=None, stream=True, sync=True, timer=False):
        super().__init__()
        self.dim = dim
        self.dropout_prob = dropout_prob
        self.stream = stream
        self.sync = sync
        self.timers = (
            {name: torch.cuda.Event(enable_timing=True)
             for name in ('start_fprop', 'start_dgrad', 'stop_fprop', 'stop_dgrad')}
            if timer else None
        )

    def forward(self, input, mask, batch, seqlen, heads, is_training):
        """Apply masked softmax + dropout via ``FastMaskSoftmaxDropoutFunction.apply``."""
        return FastMaskSoftmaxDropoutFunction.apply(
            input, mask, self.dim, batch, seqlen, heads, self.dropout_prob,
            self.stream, self.sync, self.timers, is_training,
        )