diff --git a/model/modeling_albert.py b/model/modeling_albert.py

index 899e6e6..b8e083f 100644

--- a/model/modeling_albert.py

+++ b/model/modeling_albert.py

@@ -1,6 +1,5 @@

 """PyTorch ALBERT model. """

 from __future__ import absolute_import, division, print_function, unicode_literals

-import logging

 import math

 import os

 import sys

@@ -10,7 +9,7 @@ from torch.nn import CrossEntropyLoss, MSELoss

 from .modeling_utils import PreTrainedModel, prune_linear_layer

 from .configuration_albert import AlbertConfig

 from .file_utils import add_start_docstrings

-logger = logging.getLogger(__name__)

+from ..tools.common import logger  # get args via logger

 

 ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {

     'albert-base': "",

@@ -123,12 +122,12 @@ class AlbertEmbeddings(nn.Module):

         # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load

         self.LayerNorm = AlbertLayerNorm(config.embedding_size, eps=config.layer_norm_eps)

         self.dropout = nn.Dropout(config.hidden_dropout_prob)

+        self.position_ids = torch.LongTensor([list(range(logger.args.max_seq_length))] * logger.args.batch_size).to(

+            logger.args.device)

 

     def forward(self, input_ids, token_type_ids=None, position_ids=None):

-        seq_length = input_ids.size(1)

         if position_ids is None:

-            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)

-            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

+            position_ids = self.position_ids

         if token_type_ids is None:

             token_type_ids = torch.zeros_like(input_ids)

         words_embeddings = self.word_embeddings(input_ids)

@@ -453,7 +452,7 @@ class AlbertPreTrainedModel(PreTrainedModel):

 

 ALBERT_START_DOCSTRING = r"""    The ALBERT model was proposed in

     `ALBERT: A Lite BERT for Self-supervised Learning of Language Representations`_

-    by Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut. 

+    by Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut.

     This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and

     refer to the PyTorch documentation for all matter related to general usage and behavior.

     .. _`BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`:

@@ -461,7 +460,7 @@ ALBERT_START_DOCSTRING = r"""    The ALBERT model was proposed in

     .. _`torch.nn.Module`:

         https://pytorch.org/docs/stable/nn.html#module

     Parameters:

-        config (:class:`~transformers.ALbertConfig`): Model configuration class with all the parameters of the model. 

+        config (:class:`~transformers.ALbertConfig`): Model configuration class with all the parameters of the model.

             Initializing with a config file does not load the weights associated with the model, only the configuration.

             Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.

 """

diff --git a/model/modeling_utils.py b/model/modeling_utils.py

index 56a52e9..4a49178 100644

--- a/model/modeling_utils.py

+++ b/model/modeling_utils.py

@@ -12,6 +12,7 @@ from torch.nn import CrossEntropyLoss

 from torch.nn import functional as F

 

 from model.configuration_utils import PretrainedConfig

+from model.configuration_albert import AlbertConfig

 from model.file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME

 

 logger = logging.getLogger(__name__)

@@ -54,7 +55,7 @@ class PreTrainedModel(nn.Module):

 

     def __init__(self, config, *inputs, **kwargs):

         super(PreTrainedModel, self).__init__()

-        if not isinstance(config, PretrainedConfig):

+        if not 'AlbertConfig' in str(type(config)) : # modify via infer changes root

             raise ValueError(

                 "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "

                 "To create a model from a pretrained model use "

@@ -123,7 +124,7 @@ class PreTrainedModel(nn.Module):

         Arguments:

 

             new_num_tokens: (`optional`) int:

-                New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end. 

+                New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.

                 If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.

 

         Return: ``torch.nn.Embeddings``

diff --git a/processors/glue.py b/processors/glue.py

index 6628226..8836416 100644

--- a/processors/glue.py

+++ b/processors/glue.py

@@ -14,10 +14,6 @@ def collate_fn(batch):

     Returns a padded tensor of sequences sorted from longest to shortest,

     """

     all_input_ids, all_attention_mask, all_token_type_ids, all_lens, all_labels = map(torch.stack, zip(*batch))

-    max_len = max(all_lens).item()

-    all_input_ids = all_input_ids[:, :max_len]

-    all_attention_mask = all_attention_mask[:, :max_len]

-    all_token_type_ids = all_token_type_ids[:, :max_len]

     return all_input_ids, all_attention_mask, all_token_type_ids, all_labels

 

 

@@ -266,6 +262,11 @@ class Sst2Processor(DataProcessor):

         return self._create_examples(

             self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

 

+    def get_test_examples(self, data_dir):

+        """See base class."""

+        return self._create_examples(

+            self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

+

     def get_labels(self):

         """See base class."""

         return ["0", "1"]

diff --git a/tools/fps_counter.py b/tools/fps_counter.py

new file mode 100644

index 0000000..b6b4abf

--- /dev/null

+++ b/tools/fps_counter.py

@@ -0,0 +1,99 @@

+# Copyright 2021 Huawei Technologies Co., Ltd

+#

+# Licensed under the Apache License, Version 2.0 (the "License");

+# you may not use this file except in compliance with the License.

+# You may obtain a copy of the License at

+#

+#     http://www.apache.org/licenses/LICENSE-2.0

+#

+# Unless required by applicable law or agreed to in writing, software

+# distributed under the License is distributed on an "AS IS" BASIS,

+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

+# See the License for the specific language governing permissions and

+# limitations under the License.

+

+import time

+from threading import Lock

+

+

+class FpsCounter:

+    """

+    how to use

+

+    fps=FpsCounter()

+    fps.begin()

+    code

+    fps.end()

+    print(fps.fps())

+

+    """

+    def __init__(self):

+        self.step_sum = 0

+        self.time_sum = 0

+        self.t1 = 0

+        self.on = False

+

+    def begin(self):

+        assert self.on == False, "didnot end last time"

+        self.on = True

+        self.t1 = time.time_ns()

+

+    def end(self):

+        t2 = time.time_ns()

+        assert self.on == True, "didnot begin"

+        self.time_sum += t2 - self.t1

+        self.step_sum += 1

+        self.on = False

+

+    def reset(self):

+        self.step_sum = 0

+        self.time_sum = 0

+        self.t1 = 0

+        self.on = False

+

+    def fps(self, batch=1, n_device=1):

+        if self.step_sum == 0: return 0

+        time_avg = self.time_sum / 1e9 / self.step_sum

+        return batch * n_device / time_avg

+

+

+class FpsCounter2:

+    def __init__(self, node_num=0):

+        self.node_num = node_num

+        self.lock = Lock()

+        self.step_sum = [0 for i in range(node_num)]

+        self.time_sum = [0 for i in range(node_num)]

+        self.t1 = [0 for i in range(node_num)]

+        self.on = [False for i in range(node_num)]

+

+    def begin(self, node_idx=0):

+        assert self.on[node_idx] == False, "didnot end last time"

+        self.lock.acquire()

+        self.on[node_idx] = True

+        self.t1[node_idx] = time.time_ns()

+        self.lock.release()

+

+    def end(self, node_idx=0):

+        t2 = time.time_ns()

+        assert self.on[node_idx] == True, "didnot begin"

+        self.lock.acquire()

+        self.time_sum[node_idx] += t2 - self.t1[node_idx]

+        self.step_sum[node_idx] += 1

+        self.on[node_idx] = False

+        self.lock.release()

+

+    def reset(self, node_idx=0):

+        self.lock.acquire()

+        self.step_sum[node_idx] = 0

+        self.time_sum[node_idx] = 0

+        self.t1[node_idx] = 0

+        self.on[node_idx] = False

+        self.lock.release()

+

+    def fps(self, batch=1, n_device=1, world_size=0):

+        fps = 0

+        for i in range(world_size):

+            if self.step_sum[i] == 0: continue

+            time_avg = self.time_sum[i] / 1e9 / self.step_sum[i]

+            fps += batch * n_device / time_avg

+        return fps