import os
import time
from datetime import datetime
from datetime import timedelta

import torch
from torch.distributed import PrefixStore
from torch.distributed.rendezvous import rendezvous
import torch_npu


class TorchNpuRunStoreSample:
    """Sample showing a store-based barrier on top of a rendezvous store.

    The constructor rendezvouses at ``parallel://{ip}:{port}`` and keeps the
    store it yields (wrapped in a ``PrefixStore`` under ``"default_pg"``)
    together with the rank and world size reported by the rendezvous.
    ``store_based_barrier`` then uses a single shared counter key in that
    store to wait until every rank has checked in.

    NOTE(review): the ``parallel://`` scheme is presumably registered as a
    rendezvous handler by importing ``torch_npu`` -- confirm.
    """

    # Delay between two polls of the shared counter, in seconds.
    _POLL_INTERVAL_SECONDS = 0.01

    def __init__(self, ip: str, port: int) -> None:
        """Rendezvous with the store at ``ip:port`` and prepare the barrier key.

        Args:
            ip: Host part of the ``parallel://`` rendezvous URL.
            port: Port part of the ``parallel://`` rendezvous URL.
        """
        # Initial values come from the launcher environment; the rendezvous
        # below replaces them with the values it yields.
        self._current_rank = int(os.getenv('RANK', 0))
        self._world_size = int(os.getenv('WORLD_SIZE', 0))
        self._init_method = f'parallel://{ip}:{port}'
        self._timeout = timedelta(minutes=1)
        self._key = 'sample_torch_npu_run_store:test_case_001'
        # Number of barriers this instance has entered so far. The shared
        # counter is never reset, so round k completes once the counter has
        # reached k * world_size.
        self._barrier_round = 0

        rendezvous_iterator = rendezvous(
            self._init_method, self._current_rank, self._world_size, timeout=self._timeout
        )

        self._store, self._current_rank, self._world_size = next(rendezvous_iterator)
        self._store.set_timeout(self._timeout)
        self._store = PrefixStore("default_pg", self._store)

    @staticmethod
    def _timestamp() -> str:
        """Return the current local time formatted for the log lines."""
        return datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')

    def store_based_barrier(self) -> None:
        """Block until every rank has entered this barrier round.

        Each call adds 1 to the shared counter key and then polls the counter
        (``add(key, 0)`` reads it without changing it) until it reaches
        ``round * world_size``, where ``round`` is the number of times this
        instance has entered the barrier. Because the target grows with each
        round, the barrier can be called repeatedly on the same instance,
        provided all ranks call it the same number of times.

        Raises:
            RuntimeError: If the counter does not reach the target within
                ``self._timeout``.
        """
        self._barrier_round += 1
        expected_count = self._barrier_round * self._world_size

        self._store.add(self._key, 1)
        print(f'[{self._timestamp()}]rank: {self._current_rank} add key: {self._key} first time')

        # Monotonic clock: the deadline must not move when the wall clock is
        # adjusted while we are waiting.
        deadline = time.monotonic() + self._timeout.total_seconds()
        alive_count = self._store.add(self._key, 0)
        # ">=" rather than "==": a faster rank may already have checked in for
        # the next round while this rank is still polling the current one.
        while alive_count < expected_count:
            if time.monotonic() > deadline:
                raise RuntimeError(
                    f'[{self._timestamp()}]rank: {self._current_rank} wait all workers ready timeout'
                )
            time.sleep(self._POLL_INTERVAL_SECONDS)
            alive_count = self._store.add(self._key, 0)

        # Report the world size rather than the raw counter, which is
        # cumulative across rounds.
        print(
            f'[{self._timestamp()}] Rank: {self._current_rank} '
            f'complete store-based barrier for worker count: {self._world_size}'
        )


if __name__ == "__main__":
    # Rendezvous against the loopback store endpoint, then run the barrier once.
    store_ip, store_port = '127.0.0.1', 29513
    runner = TorchNpuRunStoreSample(ip=store_ip, port=store_port)
    runner.store_based_barrier()