diff --git a/examples/post_training/configs/rl.toml b/examples/post_training/configs/rl.toml
index 417b607..3788778 100644
--- a/examples/post_training/configs/rl.toml
+++ b/examples/post_training/configs/rl.toml
@@ -17,7 +17,7 @@ redis = "12800"
 
 [train]
 resume = false
-epoch = 80
+epoch = 8
 output_dir = "./outputs/rl"
 epsilon = 1e-6
 optm_name = "AdamW"
@@ -46,7 +46,7 @@ batch_size = 4
 quantization = "none"
 
 [policy]
-model_name_or_path = "nvidia/Cosmos-Reason1-7B"
+model_name_or_path = "./Cosmos-Reason1-7B"
 model_max_length = 10240
 model_gradient_checkpointing = true
 
@@ -57,7 +57,7 @@ experiment_name = "post_training/rl"
 
 [train.train_policy]
 type = "grpo"
-dataset.name = "nvidia/Cosmos-Reason1-RL-Dataset"
+dataset.name = "./datasets/Cosmos-Reason1-RL-Dataset"
 dataset.subset = "robovqa"
 dataset.split = "rl"
 enable_dataset_cache = false
@@ -82,14 +82,14 @@ save_mode = "async"
 
 [rollout.parallelism]
 n_init_replicas = 1
-tp_size = 2
+tp_size = 4
 pp_size = 1
 
 [policy.parallelism]
 n_init_replicas = 1
 tp_size = 1
 cp_size = 1
-dp_shard_size = 2
+dp_shard_size = 4
 pp_size = 1
 dp_replicate_size = 1
 cp_rotate_method = "allgather"
diff --git a/examples/post_training/configs/sft.toml b/examples/post_training/configs/sft.toml
index ac34a15..08b86eb 100644
--- a/examples/post_training/configs/sft.toml
+++ b/examples/post_training/configs/sft.toml
@@ -40,7 +40,7 @@ validation_step = 30
 validation_batch_per_replica = 2
 
 [policy]
-model_name_or_path = "nvidia/Cosmos-Reason1-7B"
+model_name_or_path = "./Cosmos-Reason1-7B"
 model_max_length = 4096
 model_gradient_checkpointing = true
 
@@ -51,7 +51,7 @@ experiment_name = "post_training/sft"
 
 [train.train_policy]
 type = "sft"
-dataset.name = "nvidia/Cosmos-Reason1-SFT-Dataset"
+dataset.name = "./datasets/Cosmos-Reason1-SFT-Dataset"
 dataset.subset = "robovqa"
 dataset.split = "understanding"
 dataset.test_size = 100
@@ -71,7 +71,7 @@ save_mode = "async"
 n_init_replicas = 1
 tp_size = 1
 cp_size = 1
-dp_shard_size = 4
+dp_shard_size = 8
 pp_size = 1
 dp_replicate_size = 1
 cp_rotate_method = "allgather"
diff --git a/examples/post_training/tools/dataset/cosmos_grpo.py b/examples/post_training/tools/dataset/cosmos_grpo.py
index 6240105..90bbed1 100644
--- a/examples/post_training/tools/dataset/cosmos_grpo.py
+++ b/examples/post_training/tools/dataset/cosmos_grpo.py
@@ -154,12 +154,6 @@ class CosmosGRPOValDataset(CosmosGRPODataset):
     """
 
     def setup(self, config: Config, tokenizer: AutoTokenizer, *args, **kwargs):
-        if not config.train.enable_validation:
-            logger.warning(
-                "Validation is not enabled in the config. Skipping setup for CosmosGRPOValDataset."
-            )
-            return
-
         self.config = config
         self.tokenizer = tokenizer
         self.dataset = load_dataset(
@@ -268,8 +262,6 @@ if __name__ == "__main__":
     config = Config.from_dict(config)
 
     util.prepare_cosmos_data(dataset=config.train.train_policy.dataset)
-    if config.train.enable_validation:
-        util.prepare_cosmos_data(dataset=config.validation.dataset)
 
     # It is best practice to pass the dataset and val_dataset as factory functions
     # so that the dataset and val_dataset can be loaded on demand. (Not all workers need them)
@@ -283,6 +275,4 @@ if __name__ == "__main__":
         dataset=get_dataset,
         reward_fns=[custom_reward_fn],
         data_packer=DemoDataPacker(),
-        val_dataset=get_val_dataset,
-        val_data_packer=DemoDataPacker(),
     )
diff --git a/examples/post_training/tools/dataset/cosmos_sft.py b/examples/post_training/tools/dataset/cosmos_sft.py
index 6398e41..493c28f 100644
--- a/examples/post_training/tools/dataset/cosmos_sft.py
+++ b/examples/post_training/tools/dataset/cosmos_sft.py
@@ -15,6 +15,8 @@
 
 """Supervised Fine-Tuning (SFT) dataset."""
 # ruff: noqa: E402
+import torch,torch_npu
+from torch_npu.contrib import transfer_to_npu
 
 from cosmos_reason1_utils.script import init_script
 
diff --git a/scripts/inference_sample.py b/scripts/inference_sample.py
index 2b6b46a..5a5f18d 100644
--- a/scripts/inference_sample.py
+++ b/scripts/inference_sample.py
@@ -22,6 +22,8 @@ uv run scripts/inference_sample.py
 ```
 """
 
+import torch, torch_npu
+from torch_npu.contrib import transfer_to_npu
 from pathlib import Path
 
 import qwen_vl_utils