"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2026 Huawei Technologies Co.,Ltd.

MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:

         http://license.coscl.org.cn/MulanPSL2

THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""

import copy
import bisect
from typing import Dict, Tuple, Optional

from base import DeviceSnapshot, BlockState, Block, Segment, TraceEntry
from util import get_logger

from .hooker_defs import AllocatorHooker

allocator_logger = get_logger("ALLOCATOR")


class AllocatorContext:
    def __init__(self, snapshot: DeviceSnapshot):
        self.device_snapshot = snapshot
        self.current_undo_event: TraceEntry = None
        self.workspace_flag = False

    def set_current_undo_event(self, undo_event: TraceEntry):
        self.current_undo_event = undo_event


class SimulatedCachingAllocator:
    def __init__(self, ctx: AllocatorContext):
        self.ctx = ctx
        self.hookers: Dict[int, AllocatorHooker] = {}

    def register_hooker(self, hooker: AllocatorHooker) -> int:
        idx = hash(hooker)
        self.hookers[idx] = hooker
        return idx

    def unregister_hooker(self, hooker_id: int):
        if hooker_id in self.hookers:
            del self.hookers[hooker_id]

    def alloc_block(self, new_block: Block) -> bool:
        """
            回放时模拟分配一个新的block
        :param new_block: 待分配的block
        """
        _error = "Failed to simulate alloc block"
        gap_result = self.find_gap_for_alloc_block(new_block.address, new_block.size,
                                                   self.ctx.current_undo_event.stream if self.ctx.current_undo_event else None)
        if gap_result is None:
            allocator_logger.error(f"{_error}: cannot find gap for block (addr={new_block.address}, size={new_block.size})")
            return False
        segment, insert_idx = gap_result
        if self.ctx.current_undo_event:
            new_block.free_event_idx = self.ctx.current_undo_event.idx
        if self.ctx.current_undo_event and self.ctx.current_undo_event.action == 'free_completed':
            new_block.state = BlockState.ACTIVE_PENDING_FREE
        else:
            new_block.state = BlockState.ACTIVE_ALLOCATED
        new_block.segment_ptr = segment
        for hooker in self.hookers.values():
            hooker.pre_replay_alloc_block(new_block, self.ctx.device_snapshot)
        blocks = segment.blocks
        blocks.insert(insert_idx, new_block)
        segment.active_size += new_block.size
        self.ctx.device_snapshot.total_activated += new_block.size
        if new_block.state == BlockState.ACTIVE_ALLOCATED:
            segment.allocated_size += new_block.size
            self.ctx.device_snapshot.total_allocated += new_block.size
        for hooker in self.hookers.values():
            hooker.post_replay_alloc_block(new_block, self.ctx.device_snapshot)
        return True

    def free_block(self, alloc_event: TraceEntry) -> bool:
        """
            回放时模拟释放一个block,可能涉及到拆分合并
        :param alloc_event: 待回滚的alloc事件
        """
        _error = "Failed to simulate free block"
        seg_idx = self.ctx.device_snapshot.find_segment_idx_by_addr(alloc_event.addr, alloc_event.stream)
        if seg_idx == -1:
            allocator_logger.error(f"{_error}: cannot find segment for block (addr={alloc_event.addr})")
            return False
        exist_block = self.find_block_by_addr(seg_idx, alloc_event.addr)
        if exist_block is None:
            # workspace场景容忍
            if self.ctx.workspace_flag:
                allocator_logger.warning(f"{_error}: cannot find block (addr={alloc_event.addr}), workspace scenario tolerance")
                return True
            allocator_logger.error(f"{_error}: cannot find block (addr={alloc_event.addr})")
            return False
        if exist_block.size < alloc_event.size:
            allocator_logger.error(f"{_error}: block size ({exist_block.size}) < event size ({alloc_event.size})")
            return False
        exist_block.alloc_event_idx = alloc_event.idx
        exist_block_copy = copy.copy(exist_block)
        for hooker in self.hookers.values():
            hooker.pre_replay_free_block(exist_block, self.ctx.device_snapshot)
        segment = exist_block.segment_ptr
        if segment is None:
            allocator_logger.error(f"{_error}: block has no segment_ptr")
            return False
        segment.active_size -= exist_block.size
        self.ctx.device_snapshot.total_activated -= exist_block.size
        if exist_block.state == BlockState.ACTIVE_ALLOCATED:
            segment.allocated_size -= exist_block.size
            self.ctx.device_snapshot.total_allocated -= exist_block.size
        segment.blocks.remove(exist_block)
        for hooker in self.hookers.values():
            hooker.post_replay_free_block(exist_block_copy, self.ctx.device_snapshot)
        return True

    def active_block(self, free_requested_event: TraceEntry) -> bool:
        """
            回放时模拟active一个block
        :param free_requested_event: 待回放的free_requested请求
        """
        _error = "Failed to simulate active block"
        seg_idx = self.ctx.device_snapshot.find_segment_idx_by_addr(free_requested_event.addr, free_requested_event.stream)
        if seg_idx == -1:
            allocator_logger.error(f"{_error}: cannot find segment for block (addr={free_requested_event.addr})")
            return False
        active_pending_free_block = self.find_block_by_addr(seg_idx, free_requested_event.addr)
        if active_pending_free_block is None:
            allocator_logger.error(f"{_error}: cannot find block (addr={free_requested_event.addr})")
            return False
        if active_pending_free_block.state != BlockState.ACTIVE_PENDING_FREE:
            # workspace场景容忍异常
            if self.ctx.workspace_flag:
                allocator_logger.warning(
                    f"{_error}: block (addr={free_requested_event.addr}) is not in {BlockState.ACTIVE_PENDING_FREE} state, "
                    f"but workspace_flag is True, skipping")
                return True
            allocator_logger.error(
                f"{_error}: block (addr={free_requested_event.addr}) is not in {BlockState.ACTIVE_PENDING_FREE} state, "
                f"current state: {active_pending_free_block.state}")
            return False
        if active_pending_free_block.segment_ptr is None:
            allocator_logger.error(f"{_error}: the found active pending block's segment is none.")
            return False
        active_pending_free_block.state = BlockState.ACTIVE_ALLOCATED
        active_pending_free_block.segment_ptr.allocated_size += active_pending_free_block.size
        self.ctx.device_snapshot.total_allocated += active_pending_free_block.size
        return True

    def alloc_or_map_segment(self, new_segment: Segment, merge: bool = False) -> bool:
        """
            回放时模拟alloc或map一个新的内存段
        :param new_segment: 新内存段
        :param merge: 是否合并,map时对应虚拟内存场景,否则仅为alloc
        """
        _error = "Failed to alloc or map segment"
        segments = self.ctx.device_snapshot.segments
        for hooker in self.hookers.values():
            hooker.pre_replay_map_or_alloc_segment(new_segment, self.ctx.device_snapshot)
        if self.ctx.current_undo_event:
            new_segment.free_or_unmap_event_idx = self.ctx.current_undo_event.idx
        if not merge:
            self.insert_segment_sorted(new_segment)
            self.ctx.device_snapshot.total_reserved += new_segment.total_size
            for hooker in self.hookers.values():
                hooker.post_replay_map_or_alloc_segment(new_segment, self.ctx.device_snapshot)
            return True
        new_seg_start = new_segment.address
        new_seg_end = new_seg_start + new_segment.total_size
        left_adjacent_idx = -1
        right_adjacent_idx = -1
        for i, seg in enumerate(segments):
            if seg.stream != new_segment.stream:
                continue
            if seg.address + seg.total_size == new_seg_start:
                left_adjacent_idx = i
            elif new_seg_end == seg.address:
                right_adjacent_idx = i
        if left_adjacent_idx == -1 and right_adjacent_idx == -1:
            self.insert_segment_sorted(new_segment)
            self.ctx.device_snapshot.total_reserved += new_segment.total_size
            for hooker in self.hookers.values():
                hooker.post_replay_map_or_alloc_segment(new_segment, self.ctx.device_snapshot)
            return True
        virtual_map_segment = copy.deepcopy(new_segment)
        if left_adjacent_idx != -1:
            left_seg = segments[left_adjacent_idx]
            left_seg.total_size += new_segment.total_size
            left_seg.allocated_size += new_segment.allocated_size
            left_seg.active_size += new_segment.active_size
            for block in new_segment.blocks:
                block.segment_ptr = left_seg
                left_seg.blocks.append(block)
            new_segment = left_seg
            if right_adjacent_idx != -1:
                if not self.merge_segments(left_adjacent_idx, right_adjacent_idx):
                    allocator_logger.error(f"{_error}: failed to merge right segment")
                    return False
        else:
            self.insert_segment_sorted(new_segment)
            new_idx = segments.index(new_segment)
            # 插入后右相邻索引后移一位,重新计算以保持健壮性
            corrected_right_idx = new_idx + 1
            if corrected_right_idx < len(segments) and segments[corrected_right_idx].address == new_seg_end:
                if not self.merge_segments(new_idx, corrected_right_idx):
                    allocator_logger.error(f"{_error}: failed to merge right segment")
                    return False
            else:
                allocator_logger.error(f"{_error}: right adjacent segment not found after insert (expected addr={new_seg_end})")
        self.ctx.device_snapshot.total_reserved += virtual_map_segment.total_size
        for hooker in self.hookers.values():
            hooker.post_replay_map_or_alloc_segment(virtual_map_segment, self.ctx.device_snapshot)
        return True

    def free_segment(self, alloc_seg_event: TraceEntry) -> bool:
        """
            回放时模拟free一个内存段(非虚拟内存场景)
        :param alloc_seg_event: 待回滚的alloc事件
        """
        _error = "Free segment failed"
        seg_addr = alloc_seg_event.addr
        exist_seg = self.find_segment_by_exact_addr(seg_addr, alloc_seg_event.stream)
        if exist_seg is None:
            allocator_logger.error(f"{_error}: cannot found segment(addr={seg_addr}, stream={alloc_seg_event.stream})")
            return False
        if exist_seg.total_size != alloc_seg_event.size:
            allocator_logger.error(f"{_error}: cannot free segment(addr={seg_addr}, size={alloc_seg_event.size}) in "
                                   f"exist segment(addr={exist_seg.address}, size={exist_seg.total_size})")
            return False
        if exist_seg.active_size > 0:
            allocator_logger.error(f"{_error}: cannot free a segment that still has active allocations.")
            return False

        exist_seg.alloc_or_map_event_idx = alloc_seg_event.idx
        for hooker in self.hookers.values():
            hooker.pre_replay_unmap_or_free_segment(exist_seg, self.ctx.device_snapshot)
        self.ctx.device_snapshot.total_reserved -= exist_seg.total_size
        self.ctx.device_snapshot.segments.remove(exist_seg)
        for hooker in self.hookers.values():
            hooker.post_replay_unmap_or_free_segment(exist_seg, self.ctx.device_snapshot)
        return True

    def unmap_segment(self, map_event):
        """
            回放时模拟unmap一个已有的内存段(虚拟内存场景)
        :param map_event: 待回滚的map事件
        """
        _error = "Unmap segment failed"
        segments = self.ctx.device_snapshot.segments
        virtual_free_segment = Segment.build_from_event(map_event)
        seg_addr = virtual_free_segment.address
        unmap_size = virtual_free_segment.total_size
        exist_seg_idx = self.ctx.device_snapshot.find_segment_idx_by_addr(seg_addr, map_event.stream)
        if exist_seg_idx < 0 or exist_seg_idx >= len(segments):
            allocator_logger.error(f"{_error}: cannot found segment(addr={seg_addr})")
            return False
        exist_seg = segments[exist_seg_idx]
        virtual_free_segment.free_or_unmap_event_idx = exist_seg.free_or_unmap_event_idx
        virtual_free_segment.alloc_or_map_event_idx = map_event.idx
        if not (seg_addr >= exist_seg.address and seg_addr + unmap_size <= exist_seg.address + exist_seg.total_size):
            allocator_logger.error(
                f"{_error}: cannot unmap segment(addr={seg_addr}, unmap_size={unmap_size}) in exist segment("
                f"addr={exist_seg.address}, total_size={exist_seg.total_size})")
            return False
        for hooker in self.hookers.values():
            hooker.pre_replay_unmap_or_free_segment(virtual_free_segment, self.ctx.device_snapshot)
        seg_start = exist_seg.address
        seg_end = seg_start + exist_seg.total_size
        unmap_end = seg_addr + unmap_size
        if exist_seg.stream != map_event.stream:
            allocator_logger.error(f"{_error}: stream mismatch (segment: {exist_seg.stream}, event: {map_event.stream})")
            return False
        if seg_addr == seg_start:
            if not self.shrink_segment(exist_seg_idx, seg_addr, unmap_size, 'left'):
                allocator_logger.error(f"{_error}: failed to shrink segment from left")
                return False
        elif unmap_end == seg_end:
            if not self.shrink_segment(exist_seg_idx, seg_addr, unmap_size, 'right'):
                allocator_logger.error(f"{_error}: failed to shrink segment from right")
                return False
        else:
            if not self.split_segment_at(exist_seg_idx, seg_addr, unmap_size):
                allocator_logger.error(f"{_error}: failed to split segment")
                return False
        self.ctx.device_snapshot.total_reserved -= unmap_size
        for hooker in self.hookers.values():
            hooker.post_replay_unmap_or_free_segment(virtual_free_segment, self.ctx.device_snapshot)
        return True

    def find_segment_by_exact_addr(self, addr: int, stream: int) -> Optional[Segment]:
        seg_idx = self.ctx.device_snapshot.find_segment_idx_by_addr(addr, stream)
        if seg_idx != -1:
            seg = self.ctx.device_snapshot.segments[seg_idx]
            if seg.address == addr and seg.stream == stream:
                return seg
        return None

    def find_block_by_addr(self, seg_idx: int, block_addr: int) -> Optional[Block]:
        if seg_idx < 0 or seg_idx >= len(self.ctx.device_snapshot.segments):
            return None
        segment = self.ctx.device_snapshot.segments[seg_idx]
        blocks = segment.blocks
        left, right = 0, len(blocks) - 1
        while left <= right:
            mid = (left + right) // 2
            if blocks[mid].address == block_addr:
                return blocks[mid]
            elif blocks[mid].address < block_addr:
                left = mid + 1
            else:
                right = mid - 1
        return None

    def find_gap_for_alloc_block(self, event_addr: int, event_size: int, stream: int = None) -> Optional[Tuple[Segment, int]]:
        seg_idx = self.ctx.device_snapshot.find_segment_idx_by_addr(event_addr, stream)
        if seg_idx == -1:
            return None
        segment = self.ctx.device_snapshot.segments[seg_idx]
        blocks = segment.blocks
        event_end = event_addr + event_size
        seg_start = segment.address
        seg_end = seg_start + segment.total_size
        if len(blocks) == 0:
            if seg_start <= event_addr and event_end <= seg_end:
                return segment, 0
            return None

        if blocks[0].address >= event_end:
            if seg_start <= event_addr:
                return segment, 0
            return None

        left, right = 0, len(blocks) - 1
        while left < right:
            mid = (left + right + 1) // 2
            if blocks[mid].address <= event_addr:
                left = mid
            else:
                right = mid - 1

        gap_start = blocks[left].address + blocks[left].size
        if left + 1 < len(blocks):
            gap_end = blocks[left + 1].address
            if gap_start <= event_addr and event_end <= gap_end:
                return segment, left + 1
        else:
            if gap_start <= event_addr and event_end <= seg_end:
                return segment, len(blocks)

        return None

    def insert_segment_sorted(self, segment: Segment):
        segments = self.ctx.device_snapshot.segments
        keys = [(seg.address, seg.stream) for seg in segments]
        idx = bisect.bisect_left(keys, (segment.address, segment.stream))
        segments.insert(idx, segment)

    def split_segment_at(self, seg_idx: int, cut_addr: int, cut_size: int) -> bool:
        _error = "Failed to split segment"
        segments = self.ctx.device_snapshot.segments
        if seg_idx < 0 or seg_idx >= len(segments):
            allocator_logger.error(f"{_error}: invalid segment index {seg_idx}")
            return False
        original_segment = segments[seg_idx]
        seg_start = original_segment.address
        seg_end = seg_start + original_segment.total_size
        cut_end = cut_addr + cut_size
        if cut_addr < seg_start or cut_end > seg_end:
            allocator_logger.error(f"{_error}: cut range [{cut_addr}, {cut_end}) is outside segment [{seg_start}, {seg_end})")
            return False
        if cut_addr == seg_start and cut_end == seg_end:
            allocator_logger.warning("Split Seg: cut range covers entire segment, nothing to split, just remove it")
            del self.ctx.device_snapshot.segments[seg_idx]
            return True
        left_segment = Segment(
            address=seg_start,
            total_size=cut_addr - seg_start,
            stream=original_segment.stream,
            segment_type=original_segment.segment_type,
            allocated_size=0,
            active_size=0,
            blocks=[],
            device=original_segment.device,
            frames=original_segment.frames,
            is_expandable=original_segment.is_expandable,
            free_or_unmap_event_idx=original_segment.free_or_unmap_event_idx,
            alloc_or_map_event_idx=original_segment.alloc_or_map_event_idx
        )
        right_segment = Segment(
            address=cut_end,
            total_size=seg_end - cut_end,
            stream=original_segment.stream,
            segment_type=original_segment.segment_type,
            allocated_size=0,
            active_size=0,
            blocks=[],
            device=original_segment.device,
            frames=original_segment.frames,
            is_expandable=original_segment.is_expandable,
            free_or_unmap_event_idx=original_segment.free_or_unmap_event_idx,
            alloc_or_map_event_idx=original_segment.alloc_or_map_event_idx
        )
        for block in original_segment.blocks:
            block_start = block.address
            block_end = block_start + block.size
            if block_end <= cut_addr:
                block.segment_ptr = left_segment
                left_segment.blocks.append(block)
                left_segment.active_size += block.size
                if block.state == BlockState.ACTIVE_ALLOCATED:
                    left_segment.allocated_size += block.size
            elif block_start >= cut_end:
                block.segment_ptr = right_segment
                right_segment.blocks.append(block)
                right_segment.active_size += block.size
                if block.state == BlockState.ACTIVE_ALLOCATED:
                    right_segment.allocated_size += block.size
            else:
                allocator_logger.warning(f"{_error}: active block [{block_start}, {block_end}) overlaps with cut range [{cut_addr}, {cut_end}), just drop it.")
        del segments[seg_idx]
        if left_segment.total_size > 0:
            self.insert_segment_sorted(left_segment)
        if right_segment.total_size > 0:
            self.insert_segment_sorted(right_segment)
        return True

    def shrink_segment(self, seg_idx: int, shrink_addr: int, shrink_size: int, direction: str) -> bool:
        _error = "Failed to shrink segment"
        segments = self.ctx.device_snapshot.segments
        if seg_idx < 0 or seg_idx >= len(segments):
            allocator_logger.error(f"{_error}: invalid segment index {seg_idx}")
            return False
        if direction not in ['left', 'right']:
            allocator_logger.error(f"{_error}: invalid direction '{direction}', must be 'left' or 'right'")
            return False
        segment = segments[seg_idx]
        seg_start = segment.address
        seg_end = seg_start + segment.total_size
        shrink_end = shrink_addr + shrink_size
        if direction == 'left':
            if shrink_addr < seg_start or shrink_end > seg_end:
                allocator_logger.error(f"{_error}: shrink range [{shrink_addr}, {shrink_end}) is outside segment [{seg_start}, {seg_end})")
                return False
            new_start = shrink_end
            new_size = seg_end - new_start
            if new_size < 0:
                allocator_logger.error(f"{_error}: shrink results in negative segment size")
                return False
            for block in segment.blocks:
                block_start = block.address
                block_end = block_start + block.size
                if block_end <= shrink_end:
                    allocator_logger.error(f"{_error}: active block [{block_start}, {block_end}) in shrink range [{shrink_addr}, {shrink_end})")
                    return False
            segment.address = new_start
            segment.total_size = new_size
            new_blocks = []
            for block in segment.blocks:
                if block.address >= new_start:
                    new_blocks.append(block)
            segment.blocks = new_blocks
        else:
            if shrink_addr < seg_start or shrink_end > seg_end:
                allocator_logger.error(f"{_error}: shrink range [{shrink_addr}, {shrink_end}) is outside segment [{seg_start}, {seg_end})")
                return False
            new_size = shrink_addr - seg_start
            if new_size < 0:
                allocator_logger.error(f"{_error}: shrink results in negative segment size")
                return False
            for block in segment.blocks:
                block_start = block.address
                block_end = block_start + block.size
                if block_start >= shrink_addr:
                    allocator_logger.error(f"{_error}: active block [{block_start}, {block_end}) in shrink range [{shrink_addr}, {shrink_end})")
                    return False
            new_blocks = []
            for block in segment.blocks:
                if block.address + block.size <= shrink_addr:
                    new_blocks.append(block)
            segment.blocks = new_blocks
            segment.total_size = new_size
        segment.allocated_size = sum(b.size for b in segment.blocks if b.state == BlockState.ACTIVE_ALLOCATED)
        segment.active_size = sum(b.size for b in segment.blocks)
        if segment.total_size == 0:
            del segments[seg_idx]
        return True

    def merge_segments(self, target_idx: int, source_idx: int) -> bool:
        _error = "Failed to merge segments"
        segments = self.ctx.device_snapshot.segments
        if target_idx < 0 or target_idx >= len(segments):
            allocator_logger.error(f"{_error}: invalid target segment index {target_idx}")
            return False
        if source_idx < 0 or source_idx >= len(segments):
            allocator_logger.error(f"{_error}: invalid source segment index {source_idx}")
            return False
        if target_idx == source_idx:
            allocator_logger.error(f"{_error}: target and source are the same segment")
            return False
        target = segments[target_idx]
        source = segments[source_idx]
        if target.stream != source.stream:
            allocator_logger.error(f"{_error}: segments have different streams (target: {target.stream}, source: {source.stream})")
            return False
        are_adjacent = (target.address + target.total_size == source.address or
                        source.address + source.total_size == target.address)
        if not are_adjacent:
            allocator_logger.error(f"{_error}: segments are not adjacent (target: [{target.address}, {target.address + target.total_size}), source: [{source.address}, {source.address + source.total_size}))")
            return False
        if target.address > source.address:
            target, source = source, target
            target_idx, source_idx = source_idx, target_idx
        target.total_size += source.total_size
        target.allocated_size += source.allocated_size
        target.active_size += source.active_size
        for block in source.blocks:
            block.segment_ptr = target
            target.blocks.append(block)
        del segments[source_idx]
        return True