diff --git a/code/configs/final_RoP_Cov_A_fMCG.yaml b/code/configs/final_RoP_Cov_A_fMCG.yaml
deleted file mode 100644
index 37d963a..0000000
--- a/code/configs/final_RoP_Cov_A_fMCG.yaml
+++ /dev/null
@@ -1,223 +0,0 @@
-alias: "A_fMCG"
-
-train:
-  data_config:
-    dataset_config:
-      data_path: "/home/stepankonev/w/data/prerendered/training_sparse/"
-      lstm_input_data: ["xy", "yaw", "speed", "width", "length", "valid"]
-      lstm_input_data_diff: ["xy", "yaw", "speed", "valid"]
-      mask_history: True
-      mask_history_fraction: 0.15
-    dataloader_config:
-      batch_size: 42
-      shuffle: True
-      num_workers: 6
-  optimizer:
-    lr: 0.0001
-  n_epochs: 120
-  validate_every_n_steps: 1000
-  max_iterations: 5000001
-  normalize: True
-  normalize_output: True
-  clip_grad_norm: 0.4
-  scheduler: True
-
-val:
-  data_config:
-    dataset_config:
-      data_path: "/home/stepankonev/w/data/prerendered/validation_sparse/"
-      lstm_input_data: ["xy", "yaw", "speed", "width", "length", "valid"]
-      lstm_input_data_diff: ["xy", "yaw", "speed", "valid"]
-      mask_history: False
-    dataloader_config:
-      batch_size: 42
-      shuffle: False
-      num_workers: 6
-
-
-model:
-  n_trajectories: 6
-  size: 640
-  make_em: False
-  multiple_predictions: True
-
-  agent_mcg_linear:
-    layers: [24, 32, 64, 128]
-    pre_activation: False
-    pre_batchnorm: False
-    batchnorm: False
-
-  interaction_mcg_linear:
-    layers: [24, 32, 64, 128]
-    pre_activation: False
-    pre_batchnorm: False
-    batchnorm: False
-
-  agent_history_encoder:
-    position_lstm_config:
-      input_size: 13
-      hidden_size: 64
-    position_diff_lstm_config:
-      input_size: 11
-      hidden_size: 64
-    position_mcg_config:
-      agg_mode: "max"
-      running_mean_mode: "real"
-      alpha: 0.1
-      beta: 0.9
-      n_blocks: 5
-      identity_c_mlp: True
-      block:
-        c_bias: True
-        mlp:
-          n_layers: 3
-          n_in: 128
-          n_out: 128
-          bias: True
-          batchnorm: False
-          dropout: False
-
-  interaction_history_encoder:
-    position_lstm_config:
-      input_size: 13
-      hidden_size: 64
-    position_diff_lstm_config:
-      input_size: 11
-      hidden_size: 64
-    position_mcg_config:
-      block:
-        c_bias: True
-        mlp:
-          n_layers: 3
-          n_in: 128
-          n_out: 128
-          bias: True
-          batchnorm: False
-          dropout: False
-      agg_mode: "max"
-      running_mean_mode: "real"
-      alpha: 0.1
-      beta: 0.9
-      n_blocks: 5
-      identity_c_mlp: True
-
-  polyline_encoder:
-    layers: [27, 32, 64, 128]
-    pre_activation: False
-    pre_batchnorm: False
-    batchnorm: False
-
-  history_mcg_encoder:
-    block:
-      c_bias: True
-      mlp:
-        n_layers: 3
-        n_in: 256
-        n_out: 256
-        bias: True
-        batchnorm: False
-        dropout: False
-    agg_mode: "max"
-    running_mean_mode: "real"
-    alpha: 0.1
-    beta: 0.9
-    n_blocks: 5
-    identity_c_mlp: True
-
-  interaction_mcg_encoder:
-    block:
-      c_bias: True
-      mlp:
-        n_layers: 3
-        n_in: 256
-        n_out: 256
-        bias: True
-        batchnorm: False
-        dropout: False
-    agg_mode: "max"
-    running_mean_mode: "real"
-    alpha: 0.1
-    beta: 0.9
-    n_blocks: 5
-    identity_c_mlp: False
-
-  roadgraph_mcg_encoder:
-    block:
-      c_bias: True
-      mlp:
-        n_layers: 3
-        n_in: 128
-        n_out: 128
-        bias: True
-        batchnorm: False
-        dropout: False
-    agg_mode: "max"
-    running_mean_mode: "real"
-    alpha: 0.1
-    beta: 0.9
-    n_blocks: 5
-    identity_c_mlp: False
-
-  agent_and_interaction_linear:
-    layers: [512, 256, 128]
-    pre_activation: True
-    pre_batchnorm: False
-    batchnorm: False
-
-  decoder_handler_config:
-    n_decoders: 5
-    return_embedding: True
-    decoder_config:
-      trainable_cov: True
-      size: 640
-      n_trajectories: 6
-      mcg_predictor:
-        block:
-          c_bias: True
-          mlp:
-            n_layers: 3
-            n_in: 640
-            n_out: 640
-            bias: True
-            batchnorm: False
-            dropout: False
-        agg_mode: "max"
-        running_mean_mode: "real"
-        alpha: 0.1
-        beta: 0.9
-        n_blocks: 5
-        identity_c_mlp: False
-      DECODER:
-        layers: [640, 512, 401]
-        pre_activation: True
-        pre_batchnorm: False
-        batchnorm: False
-
-  final_decoder:
-    trainable_cov: True
-    size: 640
-    return_embedding: False
-    n_trajectories: 6
-    mcg_predictor:
-      block:
-        c_bias: True
-        mlp:
-          n_layers: 3
-          n_in: 640
-          n_out: 640
-          bias: True
-          batchnorm: False
-          dropout: False
-      agg_mode: "max"
-      running_mean_mode: "real"
-      alpha: 0.1
-      beta: 0.9
-      n_blocks: 5
-      identity_c_mlp: False
-    DECODER:
-      layers: [640, 512, 401]
-      pre_activation: True
-      pre_batchnorm: False
-      batchnorm: False
-
-  mha_decoder: True
diff --git a/code/configs/final_RoP_Cov_Single.yaml b/code/configs/final_RoP_Cov_Single.yaml
index 23393c9..e7901dc 100644
--- a/code/configs/final_RoP_Cov_Single.yaml
+++ b/code/configs/final_RoP_Cov_Single.yaml
@@ -9,11 +9,11 @@ train:
       mask_history: True
       mask_history_fraction: 0.15
     dataloader_config:
-      batch_size: 42
+      batch_size: 128
       shuffle: True
       num_workers: 6
   optimizer:
-    lr: 0.0001
+    lr: 0.0003
   n_epochs: 120
   validate_every_n_steps: 1000
   max_iterations: 5000001
@@ -30,7 +30,7 @@ val:
       lstm_input_data_diff: ["xy", "yaw", "speed", "valid"]
       mask_history: False
     dataloader_config:
-      batch_size: 42
+      batch_size: 128
       shuffle: False
       num_workers: 6
 
@@ -168,7 +168,7 @@ model:
     n_decoders: 1
     return_embedding: False
     decoder_config:
-      trainable_cov: True
+      trainable_cov: False
       size: 640
       n_trajectories: 6
       mcg_predictor:
diff --git a/code/model/losses.py b/code/model/losses.py
index 6869f6c..77eb58d 100644
--- a/code/model/losses.py
+++ b/code/model/losses.py
@@ -3,13 +3,18 @@ from torch import nn
 import numpy as np
 from torch.distributions.lowrank_multivariate_normal import LowRankMultivariateNormal
 
+def custom_logdet(A):
+    U, S, V = torch.svd(A)
+    abs_det = torch.prod(S)
+    return torch.log(abs_det)
+
 def nll_with_covariances(gt, predictions, confidences, avails, covariance_matrices):
     precision_matrices = torch.inverse(covariance_matrices)
     gt = torch.unsqueeze(gt, 1)
     avails = avails[:, None, :, None]
     coordinates_delta = (gt - predictions).unsqueeze(-1)
     errors = coordinates_delta.permute(0, 1, 2, 4, 3) @ precision_matrices @ coordinates_delta
-    errors = avails * (-0.5 * errors.squeeze(-1) - 0.5 * torch.logdet(covariance_matrices).unsqueeze(-1))
+    errors = avails * (-0.5 * errors.squeeze(-1) - 0.5 * custom_logdet(covariance_matrices).unsqueeze(-1))
     assert torch.isfinite(errors).all()
     with np.errstate(divide="ignore"):
         errors = nn.functional.log_softmax(confidences, dim=1) + \
diff --git a/code/model/modules.py b/code/model/modules.py
index d7f63e8..a2d464f 100644
--- a/code/model/modules.py
+++ b/code/model/modules.py
@@ -2,8 +2,12 @@ import math
 import numpy as np
 import torch
 from torch import nn
-from torch_scatter import scatter_max
+from mx_driving import scatter_max
 
+def custom_logdet(A):
+    U, S, V = torch.svd(A)
+    abs_det = torch.prod(S)
+    return torch.log(abs_det)
 
 class MLP(nn.Module):
     def __init__(self, config):
@@ -49,12 +53,12 @@ class NormalMLP(nn.Module):
     def forward(self, x):
         tmp = []
         prev_x_shape = x.shape
-        assert torch.isfinite(x).all()
+        # assert torch.isfinite(x).all()
         tmp.append(x)
         for l in self._mlp:
             x = l(x)
             tmp.append(x)
-            assert torch.isfinite(x).all()
+            # assert torch.isfinite(x).all()
         return x
         output = self._mlp(x)
         return output
@@ -102,9 +106,8 @@ class MCGBlock(nn.Module):
     
     def _repeat_tensor(self, tensor, scatter_numbers, axis=0):
         result = []
-        for i in range(len(scatter_numbers)):
-            result.append(tensor[[i]].expand((int(scatter_numbers[i]), -1, -1)))
-        result = torch.cat(result, axis=0)
+        tensor_reshape = tensor[:len(scatter_numbers)].reshape(len(scatter_numbers), 1, tensor.shape[-1])
+        result = tensor_reshape.repeat_interleave(scatter_numbers, dim=0)
         return result
 
     def _compute_running_mean(self, prevoius_mean, new_value, i):
@@ -123,21 +126,22 @@ class MCGBlock(nn.Module):
         else:
             assert not self._config["identity_c_mlp"]
         c = self._repeat_tensor(c, scatter_numbers)
-        assert torch.isfinite(s).all()
-        assert torch.isfinite(c).all()
+        # assert torch.isfinite(s).all()
+        # assert torch.isfinite(c).all()
         running_mean_s, running_mean_c = s, c
         for i, cg_block in enumerate(self._blocks, start=1):
             s, c = cg_block(scatter_numbers, running_mean_s, running_mean_c)
-            assert torch.isfinite(s).all()
-            assert torch.isfinite(c).all()
+            # assert torch.isfinite(s).all()
+            # assert torch.isfinite(c).all()
             running_mean_s = self._compute_running_mean(running_mean_s, s, i)
             running_mean_c = self._compute_running_mean(running_mean_c, c, i)
-            assert torch.isfinite(running_mean_s).all()
-            assert torch.isfinite(running_mean_c).all()
+            # assert torch.isfinite(running_mean_s).all()
+            # assert torch.isfinite(running_mean_c).all()
         if return_s:
             return running_mean_s 
         if aggregate_batch:
-            return scatter_max(running_mean_c, scatter_idx, dim=0)[0]
+            out, argmax = scatter_max(running_mean_c.squeeze(1), scatter_idx.int())
+            return out
         return running_mean_c
 
 
@@ -160,27 +164,27 @@ class Decoder(nn.Module):
     
     def forward(self, target_scatter_numbers, target_scatter_idx, final_embedding, batch_size):
         # assert torch.isfinite(self._learned_anchor_embeddings).all()
-        assert torch.isfinite(final_embedding).all()
+        # assert torch.isfinite(final_embedding).all()
         trajectories_embeddings = self._mcg_predictor(
             target_scatter_numbers, target_scatter_idx, self._learned_anchor_embeddings,
             final_embedding, return_s=True)
-        assert torch.isfinite(trajectories_embeddings).all()
+        # assert torch.isfinite(trajectories_embeddings).all()
         if self._return_embedding:
             return trajectories_embeddings
         # 
         res = self._mlp_decoder(trajectories_embeddings)
         coordinates = res[:, :, :80 * 2].reshape(
             batch_size, self._config["n_trajectories"], 80, 2)
-        assert torch.isfinite(coordinates).all()
+        # assert torch.isfinite(coordinates).all()
         a = res[:, :, 80 * 2: 80 * 3].reshape(
             batch_size, self._config["n_trajectories"], 80, 1)
-        assert torch.isfinite(a).all()
+        # assert torch.isfinite(a).all()
         b = res[:, :, 80 * 3: 80 * 4].reshape(
             batch_size, self._config["n_trajectories"], 80, 1)
-        assert torch.isfinite(b).all()
+        # assert torch.isfinite(b).all()
         c = res[:, :, 80 * 4: 80 * 5].reshape(
             batch_size, self._config["n_trajectories"], 80, 1)
-        assert torch.isfinite(c).all()
+        # assert torch.isfinite(c).all()
         probas = res[:, :, -1]
         assert torch.isfinite(probas).all()
         if self._config["trainable_cov"]:
@@ -276,8 +280,8 @@ class EM(nn.Module):
         B = diff.unsqueeze(-1)
         C = precision_matrices6[:, :, self._selector, :, :].unsqueeze(2)
         qform = (A @ C @ B)[..., 0, 0]
-        logdetCovM = torch.logdet(covariance_matrices6[:, :, self._selector, :, :].unsqueeze(2))
-        assert torch.isfinite(logdetCovM).all()
+        logdetCovM = custom_logdet(covariance_matrices6[:, :, self._selector, :, :].unsqueeze(2))
+        # assert torch.isfinite(logdetCovM).all()
         pMatrix = torch.exp((
             -np.log(2 * np.pi) - 0.5 * logdetCovM - 0.5 * qform).sum(dim=-1)) + 1e-8
         pMatrix = (pMatrix * probas6.unsqueeze(2)) / ((
@@ -308,8 +312,8 @@ class EM(nn.Module):
                 (covariance_matrices.unsqueeze(1) + (diff.unsqueeze(-1) @ diff.unsqueeze(-2)))
                 ).sum(axis=2)
             covariance_matrices6 = covariance_matrices6 / probas6[..., None, None, None]
-            with torch.no_grad():
-                assert torch.isfinite(torch.logdet(covariance_matrices6)).all()
+            # with torch.no_grad():
+            #     assert torch.isfinite(torch.logdet(covariance_matrices6)).all()
         return probas6, trajectories6, covariance_matrices6
 
 
diff --git a/code/train.py b/code/train.py
index 1de5943..50b066c 100644
--- a/code/train.py
+++ b/code/train.py
@@ -1,4 +1,6 @@
 import torch
+import torch_npu
+from torch_npu.contrib import transfer_to_npu
 from torch import nn
 from torch.optim import Adam
 from torch.optim.lr_scheduler import ReduceLROnPlateau
@@ -35,12 +37,11 @@ def get_last_file(path):
 def get_git_revision_short_hash():
     return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode('ascii').strip()
 
-
-config = get_config(sys.argv[1])
-alias = sys.argv[1].split("/")[-1].split(".")[0]
+config = get_config('./configs/final_RoP_Cov_Single.yaml')
+alias = "final_RoP_Cov_Single"
 try:
     models_path = os.path.join("../models", f"{alias}__{get_git_revision_short_hash()}")
-    os.mkdir(tb_path)
+    # os.mkdir(tb_path)
     os.mkdir(models_path)
 except:
     pass
@@ -49,7 +50,7 @@ dataloader = get_dataloader(config["train"]["data_config"])
 val_dataloader = get_dataloader(config["val"]["data_config"])
 model = MultiPathPP(config["model"])
 model.cuda()
-optimizer = Adam(model.parameters(), **config["train"]["optimizer"])
+optimizer = torch_npu.optim.NpuFusedAdam(model.parameters(), **config["train"]["optimizer"])
 if config["train"]["scheduler"]:
     scheduler = ReduceLROnPlateau(optimizer, patience=20, factor=0.5, verbose=True)
 num_steps = 0
@@ -66,7 +67,7 @@ params = sum([np.prod(p.size()) for p in model_parameters])
 print("N PARAMS=", params)
 
 train_losses = []
-
+loss_log_file = './loss_log.txt'
 for epoch in tqdm(range(config["train"]["n_epochs"])):
     pbar = tqdm(dataloader)
     for data in pbar:
@@ -76,14 +77,14 @@ for epoch in tqdm(range(config["train"]["n_epochs"])):
             data = normalize(data, config)
         dict_to_cuda(data)
         probas, coordinates, covariance_matrices, loss_coeff = model(data, num_steps)
-        assert torch.isfinite(coordinates).all()
-        assert torch.isfinite(probas).all()
-        assert torch.isfinite(covariance_matrices).all()
+        # assert torch.isfinite(coordinates).all()
+        # assert torch.isfinite(probas).all()
+        # assert torch.isfinite(covariance_matrices).all()
         xy_future_gt = data["target/future/xy"]
         if config["train"]["normalize_output"]:
             # assert not (config["train"]["normalize_output"] and config["train"]["trainable_cov"])
             xy_future_gt = (data["target/future/xy"] - torch.Tensor([1.4715e+01, 4.3008e-03]).cuda()) / 10.
-            assert torch.isfinite(xy_future_gt).all()
+            # assert torch.isfinite(xy_future_gt).all()
         loss = nll_with_covariances(
             xy_future_gt, coordinates, probas, data["target/future/valid"].squeeze(-1),
             covariance_matrices) * loss_coeff
@@ -99,6 +100,8 @@ for epoch in tqdm(range(config["train"]["n_epochs"])):
         if num_steps % 10 == 0:
             pbar.set_description(f"loss = {round(loss.item(), 2)}")
         if num_steps % 1000 == 0 and this_num_steps > 0:
+            with open(loss_log_file, 'a') as file:
+                file.write("{}: {}\n".format(num_steps, round(loss.item(), 2)))
             saving_data = {
                 "num_steps": num_steps,
                 "model_state_dict": model.state_dict(),
diff --git a/models b/models
deleted file mode 100644
index e69de29..0000000
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..3c7427a
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,16 @@
+numpy==1.23.2
+tensorflow-cpu-aws==2.11.0
+torch==2.1.0
+torchvision==0.16.0
+protobuf==4.25.0
+tqdm
+decorator
+sympy
+scipy
+attrs
+cloudpickle
+psutil
+synr==0.5.0
+tornado
+matplotlib
+pyyaml
\ No newline at end of file