import os
import shutil
import stat
import json
from unittest.mock import patch

from torch_npu.profiler.analysis.prof_bean._ge_memory_record_bean import GeMemoryRecordBean
from torch_npu.profiler.analysis.prof_common_func._file_manager import FileManager

from torch_npu.testing.testcase import TestCase, run_tests


class TestFileManager(TestCase):
    """Exercise the read/write helpers exposed by ``FileManager``.

    Every test works inside one scratch directory that is created in
    ``setUpClass`` and removed in ``tearDownClass``; each test uses its own
    file name so the outcome does not depend on test execution order.
    """

    # Dotted path of the module whose attributes are patched in the size-limit
    # and permission tests below.
    _FILE_MANAGER_MODULE = "torch_npu.profiler.analysis.prof_common_func._file_manager"

    @classmethod
    def setUpClass(cls):
        """Create an empty scratch directory for the whole test class."""
        super().setUpClass()
        cls.tmp_dir = "./tmp_dir"
        # A previous run that died before tearDownClass leaves the directory
        # behind; drop it so os.makedirs does not raise FileExistsError and
        # stale files cannot leak into this run.
        shutil.rmtree(cls.tmp_dir, ignore_errors=True)
        os.makedirs(cls.tmp_dir)

    @classmethod
    def tearDownClass(cls) -> None:
        """Remove the scratch directory, then run the base-class teardown."""
        try:
            shutil.rmtree(cls.tmp_dir)
        finally:
            # Always give the base TestCase a chance to clean up as well.
            super().tearDownClass()

    def test_file_all(self):
        """file_read_all returns the exact content of a readable file."""
        test_file_path = os.path.join(self.tmp_dir, "test_file.log")
        # Create the file with owner read/write permission only.
        with os.fdopen(os.open(test_file_path,
                               os.O_WRONLY | os.O_CREAT, stat.S_IWUSR | stat.S_IRUSR), 'w') as fp:
            fp.write("something")
        self.assertEqual("something", FileManager.file_read_all(test_file_path))

    def test_read_csv_file(self):
        """read_csv_file yields [] for a directory and one bean per CSV row."""
        test_file1 = os.path.join(self.tmp_dir, "test_file1.csv")
        with os.fdopen(os.open(test_file1,
                               os.O_WRONLY | os.O_CREAT, stat.S_IWUSR | stat.S_IRUSR), 'w') as fp:
            fp.write("Component,Timestamp(us),Total Allocated(KB),Total Reserved(KB),Device\n")
            fp.write("APP,18.927,1024,2048,NPU:0\n")

        # Passing the directory instead of a file is expected to give an empty result.
        self.assertEqual([], FileManager.read_csv_file(self.tmp_dir, GeMemoryRecordBean))

        test_dict = {
            "Component": "APP", "Timestamp(us)": "18.927", "Total Allocated(KB)": 1024,
            "Total Reserved(KB)": 2048, "Device": "NPU:0"
        }
        expect = GeMemoryRecordBean(test_dict)
        read_result = FileManager.read_csv_file(test_file1, GeMemoryRecordBean)
        self.assertEqual(1, len(read_result))
        self.assertEqual(expect.component, read_result[0].component)
        self.assertEqual(expect.device_tag, read_result[0].device_tag)
        self.assertEqual(expect.time_ns, read_result[0].time_ns)
        self.assertEqual(expect.total_allocated, read_result[0].total_allocated)
        self.assertEqual(expect.total_reserved, read_result[0].total_reserved)

    def test_create_csv_file(self):
        """create_csv_file writes the header line followed by the data rows."""
        test_file = "test_file.csv"
        headers = ["H1", "H2", "H3"]
        data = [["header10", "header11", "header12"]]
        FileManager.create_csv_file(self.tmp_dir, data, test_file, headers)
        test_file_path = os.path.join(self.tmp_dir, test_file)
        with open(test_file_path, 'r') as fp:
            read_header = fp.readline()
            self.assertEqual("H1,H2,H3\n", read_header)
            line1 = fp.readline()
            self.assertEqual("header10,header11,header12\n", line1)

    def test_create_json_file_by_path(self):
        """create_json_file_by_path writes JSON that round-trips through json.load."""
        test_file = "test_file.json"
        data = {"Name": "ZhangShan", "Age": 666}
        output_path = os.path.join(self.tmp_dir, test_file)
        FileManager.create_json_file_by_path(output_path, data)
        with open(output_path, 'r') as fp:
            read_data = json.load(fp)
        self.assertEqual(data, read_data)

    def test_append_trace(self):
        """A prepared trace followed by an append yields the merged JSON object."""
        # Dedicated file name: sharing "test_file.json" with
        # test_create_json_file_by_path would couple the two tests to their
        # execution order.
        test_file = "test_append_trace.json"
        data1 = {"Name": "ZhangShan", "Age": 666}
        data2 = {"Height": 180, "Addr": "China"}
        output_path = os.path.join(self.tmp_dir, test_file)
        FileManager.create_prepare_trace_json_by_path(output_path, data1)
        FileManager.append_trace_json_by_path(output_path, data2, output_path)
        with open(output_path, 'r') as fp:
            read_data = json.load(fp)
        expect = {**data1, **data2}
        self.assertEqual(read_data, expect)

    @patch('os.stat')
    @patch('os.geteuid')
    def test_check_file_owner(self, mock_geteuid, mock_stat):
        """check_file_owner accepts uid 0 and the current euid, rejects other uids."""
        test_file = "file_owner.json"
        test_path = os.path.join(self.tmp_dir, test_file)
        mock_geteuid.return_value = 1000
        # File owned by uid 0 is accepted.
        mock_stat.return_value.st_uid = 0
        self.assertTrue(FileManager.check_file_owner(test_path))
        # File owned by the (mocked) effective user is accepted.
        mock_stat.return_value.st_uid = 1000
        self.assertTrue(FileManager.check_file_owner(test_path))
        # File owned by any other uid is rejected.
        mock_stat.return_value.st_uid = 9999
        self.assertFalse(FileManager.check_file_owner(test_path))

    def test_check_db_file_valid_invalid_size(self):
        """A db file larger than MAX_FILE_SIZE makes the validity check raise."""
        test_file_path = os.path.join(self.tmp_dir, "invalid_db.db")
        with open(test_file_path, 'w') as fp:
            fp.write("a" * 20)

        with patch(f'{self._FILE_MANAGER_MODULE}.Constant.MAX_FILE_SIZE', 10):
            with self.assertRaises(RuntimeError):
                # NOTE(review): "vaild" is the spelling this test has always
                # called; presumably it matches the FileManager API - confirm.
                FileManager.check_db_file_vaild(test_file_path)

    def test_file_read_all_nonexistent_file(self):
        """file_read_all returns '' for a missing file once the path check is bypassed."""
        test_file_path = os.path.join(self.tmp_dir, "nonexistent.log")
        # Replace the readability check with a mock so the call reaches the
        # missing-file handling.
        with patch(f'{self._FILE_MANAGER_MODULE}.PathManager.check_directory_path_readable'):
            result = FileManager.file_read_all(test_file_path)
            self.assertEqual('', result)

    def test_read_csv_file_empty_file(self):
        """read_csv_file returns [] for a zero-byte CSV file."""
        test_file_path = os.path.join(self.tmp_dir, "empty.csv")
        with open(test_file_path, 'w'):
            pass
        result = FileManager.read_csv_file(test_file_path, GeMemoryRecordBean)
        self.assertEqual([], result)

    def test_create_csv_file_empty_data(self):
        """create_csv_file does not create a file when there is no data."""
        test_file = "empty_file.csv"
        FileManager.create_csv_file(self.tmp_dir, [], test_file)
        test_file_path = os.path.join(self.tmp_dir, test_file)
        self.assertFalse(os.path.exists(test_file_path))

    def test_read_json_file_exceeds_max_size(self):
        """read_json_file returns {} when the file is larger than MAX_FILE_SIZE."""
        test_file_path = os.path.join(self.tmp_dir, "large_file.json")
        with open(test_file_path, 'w') as fp:
            fp.write("a" * 10000)

        with patch(f'{self._FILE_MANAGER_MODULE}.Constant.MAX_FILE_SIZE', 10):
            result = FileManager.read_json_file(test_file_path)
            self.assertEqual({}, result)

    def test_file_read_all_exceeds_max_size(self):
        """file_read_all returns '' when the file is larger than MAX_FILE_SIZE."""
        test_file_path = os.path.join(self.tmp_dir, "large_file.log")
        with open(test_file_path, 'w') as fp:
            fp.write("a" * 10000)

        with patch(f'{self._FILE_MANAGER_MODULE}.Constant.MAX_FILE_SIZE', 10):
            result = FileManager.file_read_all(test_file_path)
            self.assertEqual("", result)

    def test_create_json_file_empty_data(self):
        """create_json_file does not create a file when there is no data."""
        test_file = "empty_data.json"
        FileManager.create_json_file(self.tmp_dir, [], test_file)
        test_file_path = os.path.join(self.tmp_dir, test_file)
        self.assertFalse(os.path.exists(test_file_path))

    def test_read_csv_file_exceeds_max_size(self):
        """read_csv_file returns [] when the file is larger than MAX_CSV_SIZE."""
        test_file_path = os.path.join(self.tmp_dir, "large_file.csv")
        with open(test_file_path, 'w') as fp:
            fp.write("A" * 1024 * 1024 * 2)

        with patch(f'{self._FILE_MANAGER_MODULE}.Constant.MAX_CSV_SIZE', 1024):
            result = FileManager.read_csv_file(test_file_path, GeMemoryRecordBean)
            self.assertEqual([], result)

    def test_file_read_all_empty_file(self):
        """file_read_all returns '' for a zero-byte file."""
        test_file_path = os.path.join(self.tmp_dir, "empty_file.log")
        # Only the file's existence matters, so nothing is written to it.
        with os.fdopen(os.open(test_file_path,
                               os.O_WRONLY | os.O_CREAT, stat.S_IWUSR | stat.S_IRUSR), 'w'):
            pass
        self.assertEqual("", FileManager.file_read_all(test_file_path))

    def test_should_create_empty_file_when_file_name_exist(self):
        """create_empty_file produces an existing, zero-length file."""
        parser_done_file = "parser.done"
        parser_done_path = os.path.join(self.tmp_dir, parser_done_file)

        FileManager.create_empty_file(self.tmp_dir, parser_done_file)

        self.assertTrue(os.path.exists(parser_done_path))
        with open(parser_done_path, 'r') as f:
            content = f.read()
        self.assertEqual(content, "")

# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    run_tests()