# -*- coding: utf-8 -*-
# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
# MindIE is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#         http://license.coscl.org.cn/MulanPSL2
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.

import base64
import io
import math
from PIL import Image


# Set of JPEG Start Of Frame (SOF) markers.
# These markers indicate the beginning of the frame header, 
# which contains image dimensions (width and height).
SOF_MARKERS = {
    0xC0,  # SOF0 (Baseline DCT)
    0xC1,  # SOF1 (Extended Sequential DCT)
    0xC2,  # SOF2 (Progressive DCT)
    0xC3,  # SOF3 (Lossless Sequential)
    0xC5,  # SOF5 (Differential Sequential)
    0xC6,  # SOF6 (Differential Progressive)
    0xC7,  # SOF7 (Differential Lossless)
    0xC9,  # SOF9 (Extended Sequential DCT, Arithmetic Coding)
    0xCA,  # SOF10 (Progressive DCT, Arithmetic Coding)
    0xCB,  # SOF11 (Lossless Sequential, Arithmetic Coding)
    0xCD,  # SOF13 (Differential Sequential, Arithmetic Coding)
    0xCE,  # SOF14 (Differential Progressive, Arithmetic Coding)
    0xCF,  # SOF15 (Differential Lossless, Arithmetic Coding)
}


def parse_jpeg_size(data: bytes):
    """
    Parses JPEG binary data to extract width and height by reading the SOF marker.
    This is faster than decoding the entire image as it only reads the header.
    
    Args:
        data: Raw JPEG byte data.
        
    Returns:
        A tuple (width, height).
        
    Raises:
        ValueError: If the data is not a valid JPEG or SOF marker is not found.
    """
    idx = 0
    length = len(data)

    # Check for JPEG magic number (SOI - Start Of Image)
    if length < 2 or data[0:2] != b"\xff\xd8":
        raise ValueError("Not a JPEG")

    # Start parsing after the SOI marker
    idx = 2
    while idx + 9 < length:
        # Look for the next marker prefix (0xFF)
        if data[idx] != 0xFF:
            idx += 1
            continue

        marker = data[idx + 1]

        # Handle padding bytes (0xFF followed by 0xFF)
        if marker == 0xFF:
            idx += 1
            continue

        # Check if this is a Start Of Frame marker containing dimensions
        if marker in SOF_MARKERS:
            # Structure of SOF segment:
            # [Marker(2)] [Length(2)] [Precision(1)] [Height(2)] [Width(2)] ...
            # Height is at offset +5, +6; Width is at offset +7, +8
            h = (data[idx + 5] << 8) | data[idx + 6]
            w = (data[idx + 7] << 8) | data[idx + 8]
            return w, h

        # Stop if we reach End Of Image (EOI) or Start Of Scan (SOS)
        # SOS indicates the start of compressed image data, so headers are done.
        if marker in (0xD9, 0xDA):
            break

        # Ensure there are enough bytes to read the segment length
        if idx + 3 >= length:
            break

        # Read the length of the current segment (includes the 2 bytes for length itself)
        seg_len = (data[idx + 2] << 8) | data[idx + 3]
        if seg_len < 2:
            break

        # Skip to the next segment
        idx += 2 + seg_len

    raise ValueError("JPEG SOF marker not found")


def parse_png_size(data: bytes):
    """
    Parses PNG binary data to extract width and height from the IHDR chunk.
    
    Args:
        data: Raw PNG byte data (must be at least 24 bytes).
        
    Returns:
        A tuple (width, height).
    """
    # PNG IHDR chunk structure:
    # Bytes 16-19: Width (Big Endian)
    # Bytes 20-23: Height (Big Endian)
    w = int.from_bytes(data[16:20], "big")
    h = int.from_bytes(data[20:24], "big")
    return w, h


def fast_get_hw(b64_str: str):
    """
    Quickly extracts image dimensions from a Base64 encoded data URI using PIL.
    Note: While 'fast' compared to full processing, it still decodes the header via PIL.
    
    Args:
        b64_str: A data URI string (e.g., "data:image/jpeg;base64,...").
        
    Returns:
        A tuple (width, height).
    """
    # Split the data URI to get only the base64 encoded part
    # Assumes format: "data:<mime>;base64,<encoded_data>"
    img_bytes = base64.b64decode(b64_str.split(",")[1])
    
    # Open image using PIL to get dimensions
    with io.BytesIO(img_bytes) as f:
        with Image.open(f) as img:
            return img.width, img.height


def get_hw_from_local(path: str):
    """
    Reads the first 64KB of a local image file to determine dimensions.
    Supports PNG and JPEG formats.
    
    Args:
        path: File path or file URI (starting with "file://").
        
    Returns:
        A tuple (width, height).
    """
    # Remove file:// protocol prefix if present
    if path.startswith("file://"):
        path = path[7:]

    # Read only the first 64KB, which is sufficient for headers of most images
    with open(path, "rb") as f:
        data = f.read(65536)

    # Check PNG signature
    if data.startswith(b"\x89PNG"):
        return parse_png_size(data)
    
    # Assume JPEG if not PNG
    return parse_jpeg_size(data)


def get_mul_token(img_url: str) -> float:
    """
    Calculates a token multiplier based on image dimensions.
    The formula divides the image into 32x32 patches and counts them.
    
    Args:
        img_url: A local file path, file URI, or base64 data URI.
        
    Returns:
        The calculated multiplier (float).
    """
    if img_url.startswith("data:image"):
        # Handle base64 encoded images
        h, w = fast_get_hw(img_url)
    else:
        # Handle local file paths
        h, w = get_hw_from_local(img_url)
    
    # Calculate number of 32x32 patches needed to cover the image
    # Note: The original code had a bug: 'mul_token' was used before assignment.
    # It should likely be 'mul_token = ...' or returned directly.
    mul_token = math.ceil(h / 32) * math.ceil(w / 32)
    
    return mul_token