diff --git a/allennlp/modules/elmo.py b/allennlp/modules/elmo.py
index fb92985e..c527481d 100644
--- a/allennlp/modules/elmo.py
+++ b/allennlp/modules/elmo.py
@@ -605,7 +605,7 @@ class _ElmoBiLm(torch.nn.Module):
output_tensors = [
torch.cat([type_representation, type_representation], dim=-1) * mask.unsqueeze(-1)
]
- for layer_activations in torch.chunk(lstm_outputs, lstm_outputs.size(0), dim=0):
+ for layer_activations in torch.split(lstm_outputs, 1, dim=0):
output_tensors.append(layer_activations.squeeze(0))
return {"activations": output_tensors, "mask": mask}