@@ -166,8 +166,8 @@ class HubertCtc(BaseFairseqModel):
return logits
- def forward(self, **kwargs):
- x = self.w2v_encoder(**kwargs)
+ def forward(self, source, **kwargs):
+ x = self.w2v_encoder(source ,**kwargs)
return x
@@ -302,7 +302,7 @@ class HubertEncoder(FairseqEncoder):
super().set_num_updates(num_updates)
self.num_updates = num_updates
- def forward(self, source, padding_mask, tbc=True, **kwargs):
+ def forward(self, source, padding_mask=None, tbc=True, **kwargs):
w2v_args = {
"source": source,
@@ -324,11 +324,7 @@ class HubertEncoder(FairseqEncoder):
if self.proj:
x = self.proj(x)
- return {
- "encoder_out": x, # T x B x C
- "encoder_padding_mask": padding_mask, # B x T
- "padding_mask": padding_mask,
- }
+ return x
def reorder_encoder_out(self, encoder_out, new_order):
if encoder_out["encoder_out"] is not None:
@@ -14,8 +14,6 @@ def pad_to_multiple(x, multiple, dim=-1, value=0):
tsz = x.size(dim)
m = tsz / multiple
remainder = math.ceil(m) * multiple - tsz
- if m.is_integer():
- return x, 0
pad_offset = (0,) * (-1 - dim) * 2
return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder