import torch
import mhalib
class Bmm2Function(torch.autograd.Function):
    """Autograd function wrapping the ``mhalib`` FastBmm2 kernels.

    The forward pass calls ``mhalib.FastBmm2Fprop``; the backward pass calls
    ``mhalib.FastBmm2Dgrad1`` (gradient for ``batch1``) and
    ``mhalib.FastBmm2Dgrad2`` (gradient for ``batch2``).  All buffers are
    allocated as float16 on the current CUDA device.

    NOTE: ``forward`` takes ``sync`` *before* ``stream``, which is the reverse
    of ``Bmm2StridedFunction.forward``.  Callers using ``apply`` must pass the
    two flags in this order.
    """

    @staticmethod
    def forward(ctx, batch1, batch2, seqlen, batch, maxseqlen, heads, embed, sync, stream):
        """Run the forward kernel and return a ``[ntokens, heads, embed]`` tensor.

        Args:
            ctx: autograd context; receives the tensors and scalars that
                ``backward`` needs.
            batch1: first operand; flattened and made contiguous before the
                kernel call.
            batch2: second operand; flattened and made contiguous before the
                kernel call.
            seqlen: tensor of per-sequence lengths; its sum is the number of
                output rows.
            batch: number of sequences.
            maxseqlen: stored on ``ctx`` only; not used by the kernel call.
            heads: size of the second output dimension.
            embed: size of the third output dimension.
            sync: forwarded to the kernel as its last argument.
            stream: forwarded to the kernel as its second-to-last argument.

        Returns:
            The float16 CUDA tensor filled by ``mhalib.FastBmm2Fprop``.
        """
        ctx.save_for_backward(batch1, batch2, seqlen)
        ctx.batch = batch
        ctx.maxseqlen = maxseqlen
        ctx.heads = heads
        ctx.embed = embed
        ctx.stream = stream
        ctx.sync = sync

        # .item() copies the total to the host so it can be used as a size.
        ntokens = seqlen.sum().item()
        ctx.ntokens = ntokens

        output = torch.empty([ntokens, heads, embed], device="cuda", dtype=torch.float16)
        # The kernel takes batch2 first, then batch1.  The two literal False
        # flags are opaque here -- NOTE(review): confirm their meaning against
        # the mhalib binding (the strided variant passes False, True).
        mhalib.FastBmm2Fprop(
            batch2.flatten().contiguous(),
            batch1.flatten().contiguous(),
            output,
            batch,
            seqlen,
            heads,
            embed,
            False,
            False,
            stream,
            sync,
        )
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients for ``batch1`` and ``batch2``.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_output: gradient with respect to the forward output.

        Returns:
            A 9-tuple matching the inputs of ``forward``: a flat
            ``[sum(seqlen[i]**2) * heads]`` gradient for ``batch1``, an
            ``[ntokens, heads * embed]`` gradient for ``batch2``, and ``None``
            for the seven non-differentiable arguments.
        """
        batch1, batch2, seqlen = ctx.saved_tensors
        batch = ctx.batch
        heads = ctx.heads
        embed = ctx.embed
        ntokens = ctx.ntokens

        # Sum of squared sequence lengths over the first `batch` entries,
        # computed in one reduction instead of one tensor op per sequence.
        lengths = seqlen[:batch]
        ntokens2 = int((lengths * lengths).sum().item())

        grad_batch1 = torch.empty([ntokens2 * heads], device="cuda", dtype=torch.float16)
        grad_batch2 = torch.empty([ntokens, heads * embed], device="cuda", dtype=torch.float16)

        mhalib.FastBmm2Dgrad1(
            batch2.flatten().contiguous(),
            grad_output,
            grad_batch1,
            batch,
            seqlen,
            heads,
            embed,
            False,
            False,
            ctx.stream,
            ctx.sync,
        )
        # NOTE(review): batch1 is passed as saved here, whereas forward passes
        # batch1.flatten().contiguous() -- confirm the kernel accepts both.
        mhalib.FastBmm2Dgrad2(
            grad_output,
            batch1,
            grad_batch2,
            batch,
            seqlen,
            heads,
            embed,
            False,
            False,
            ctx.stream,
            ctx.sync,
        )
        return grad_batch1, grad_batch2, None, None, None, None, None, None, None
class Bmm2(torch.nn.Module):
    """Module front-end for ``Bmm2Function``.

    Holds the static configuration (heads, embed size, maximum sequence
    length, stream/sync flags) and forwards per-call tensors to the autograd
    function.
    """

    def __init__(self, batch, seqlen, heads, embed, stream=True, sync=True):
        """Store the configuration.

        Args:
            batch: accepted for signature compatibility; not stored.
            seqlen: stored as ``self.maxseqlen``.
            heads: stored as ``self.heads``.
            embed: stored as ``self.embed``.
            stream: flag forwarded to the kernels on every call.
            sync: flag forwarded to the kernels on every call.
        """
        super(Bmm2, self).__init__()
        self.heads = heads
        self.embed = embed
        self.maxseqlen = seqlen
        self.stream = stream
        self.sync = sync

    def forward(self, batch1, batch2, batch, seqlen):
        """Apply ``Bmm2Function`` to the given operands.

        Args:
            batch1: first operand tensor.
            batch2: second operand tensor.
            batch: number of sequences in this call.
            seqlen: tensor of per-sequence lengths.

        Returns:
            The tensor produced by ``Bmm2Function.apply``.
        """
        # Bmm2Function.forward declares its trailing parameters as
        # (..., sync, stream), so the flags must be passed in that order;
        # passing (stream, sync) silently swaps them whenever they differ.
        return Bmm2Function.apply(
            batch1,
            batch2,
            seqlen,
            batch,
            self.maxseqlen,
            self.heads,
            self.embed,
            self.sync,
            self.stream,
        )
class Bmm2StridedFunction(torch.autograd.Function):
    """Autograd function wrapping the ``mhalib`` FastBmm2 kernels (strided form).

    Same kernels as the non-strided variant, but the second operand is a
    single ``mixed`` tensor passed to the kernels as-is, and the kernels are
    called with the flag pair ``False, True``.  The gradient buffer for
    ``mixed`` has ``heads * 3 * embed`` columns.  Optional CUDA events
    bracket each kernel call for timing.
    """

    @staticmethod
    def forward(ctx, batch1, mixed, seqlen, batch, maxseqlen, heads, embed, stream, sync, timers):
        """Run the forward kernel and return a ``[ntokens, heads, embed]`` tensor.

        Args:
            ctx: autograd context; receives the tensors and scalars that
                ``backward`` needs.
            batch1: first operand, passed to the kernel unchanged.
            mixed: second operand, passed to the kernel unchanged.
            seqlen: tensor of per-sequence lengths; its sum is the number of
                output rows.
            batch: number of sequences.
            maxseqlen: stored on ``ctx`` only; not used by the kernel call.
            heads: size of the second output dimension.
            embed: size of the third output dimension.
            stream: forwarded to the kernel as its second-to-last argument.
            sync: forwarded to the kernel as its last argument.
            timers: mapping with ``'start_fprop'`` / ``'stop_fprop'`` (and the
                dgrad/wgrad equivalents) event objects, or a falsy value to
                disable timing.

        Returns:
            The float16 CUDA tensor filled by ``mhalib.FastBmm2Fprop``.
        """
        ctx.save_for_backward(batch1, mixed, seqlen)
        ctx.batch = batch
        ctx.maxseqlen = maxseqlen
        ctx.heads = heads
        ctx.embed = embed
        ctx.stream = stream
        ctx.sync = sync
        ctx.timers = timers

        # .item() copies the total to the host so it can be used as a size.
        ntokens = seqlen.sum().item()
        ctx.ntokens = ntokens

        output = torch.empty([ntokens, heads, embed], device="cuda", dtype=torch.float16)

        if timers:
            timers['start_fprop'].record()
        # NOTE(review): the literal (False, True) flags are opaque here;
        # confirm their meaning against the mhalib binding.
        mhalib.FastBmm2Fprop(mixed, batch1, output, batch, seqlen, heads, embed, False, True, stream, sync)
        if timers:
            timers['stop_fprop'].record()

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients for ``batch1`` and ``mixed``.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_output: gradient with respect to the forward output.

        Returns:
            A 10-tuple matching the inputs of ``forward``: a flat
            ``[sum(seqlen[i]**2) * heads]`` gradient for ``batch1``, an
            ``[ntokens, heads * 3 * embed]`` gradient for ``mixed``, and
            ``None`` for the eight non-differentiable arguments.
        """
        batch1, mixed, seqlen = ctx.saved_tensors
        batch = ctx.batch
        heads = ctx.heads
        embed = ctx.embed
        ntokens = ctx.ntokens
        timers = ctx.timers

        # Sum of squared sequence lengths over the first `batch` entries,
        # computed in one reduction instead of one tensor op per sequence.
        lengths = seqlen[:batch]
        ntokens2 = int((lengths * lengths).sum().item())

        grad_batch1 = torch.empty(ntokens2 * heads, device="cuda", dtype=torch.float16)
        # NOTE(review): allocated uninitialized with 3 * embed columns per
        # head; presumably FastBmm2Dgrad2 writes every element that autograd
        # later reads -- confirm against the kernel.
        grad_mixed = torch.empty([ntokens, heads * 3 * embed], device="cuda", dtype=torch.float16)

        if timers:
            timers['start_dgrad'].record()
        mhalib.FastBmm2Dgrad1(mixed, grad_output, grad_batch1, batch, seqlen, heads, embed, False, True, ctx.stream, ctx.sync)
        if timers:
            timers['stop_dgrad'].record()

        if timers:
            timers['start_wgrad'].record()
        mhalib.FastBmm2Dgrad2(grad_output, batch1, grad_mixed, batch, seqlen, heads, embed, False, True, ctx.stream, ctx.sync)
        if timers:
            timers['stop_wgrad'].record()

        return grad_batch1, grad_mixed, None, None, None, None, None, None, None, None
class Bmm2Strided(torch.nn.Module):
    """Module front-end for ``Bmm2StridedFunction`` with optional CUDA timers."""

    def __init__(self, batch, seqlen, heads, embed, stream=True, sync=True, timer=False):
        """Store the configuration and, if requested, create timing events.

        Args:
            batch: accepted for signature compatibility; not stored.
            seqlen: stored as ``self.maxseqlen``.
            heads: stored as ``self.heads``.
            embed: stored as ``self.embed``.
            stream: flag forwarded to the kernels on every call.
            sync: flag forwarded to the kernels on every call.
            timer: when truthy, ``self.timers`` is a dict of six
                ``torch.cuda.Event(enable_timing=True)`` objects keyed by
                start/stop of fprop, dgrad and wgrad; otherwise ``None``.
        """
        super().__init__()
        self.heads = heads
        self.embed = embed
        self.maxseqlen = seqlen
        self.stream = stream
        self.sync = sync

        # Key order mirrors the insertion order of the timer dictionary.
        event_names = (
            'start_fprop',
            'start_dgrad',
            'start_wgrad',
            'stop_fprop',
            'stop_dgrad',
            'stop_wgrad',
        )
        self.timers = (
            {name: torch.cuda.Event(enable_timing=True) for name in event_names}
            if timer
            else None
        )

    def forward(self, batch1, mixed, batch, seqlen):
        """Apply ``Bmm2StridedFunction`` to the given operands.

        Args:
            batch1: first operand tensor.
            mixed: second operand tensor.
            batch: number of sequences in this call.
            seqlen: tensor of per-sequence lengths.

        Returns:
            The tensor produced by ``Bmm2StridedFunction.apply``.
        """
        call_args = (
            batch1,
            mixed,
            seqlen,
            batch,
            self.maxseqlen,
            self.heads,
            self.embed,
            self.stream,
            self.sync,
            self.timers,
        )
        return Bmm2StridedFunction.apply(*call_args)