#!/usr/bin/env python3

# coding: utf-8

# Copyright (c) 2025-2026 Huawei Technologies Co., Ltd.

# This program is free software, you can redistribute it and/or modify it under the terms and conditions of

# CANN Open Software License Agreement Version 2.0 (the "License").

# Please refer to the License for details. You may not use this file except in compliance with the License.

# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,

# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.

# See LICENSE in the root of the software repository for the full text of the License.

# -----------------------------------------------------------------------------------------------------------

"""

"""

import pypto





def test_unsqueeze_validshape():

    """Test whether unsqueeze correctly propagates validShape"""

    dtype = pypto.DT_FP32

    shape = [32, 32]

    x = pypto.tensor(shape, dtype, "x")



    with pypto.function("UNSQUEEZE_VALIDSHAPE", x):

        pypto.set_vec_tile_shapes(32, 32)



        # Create a view with validShape different from shape to test validShape propagation

        # View shape is [32, 32], but validShape is [16, 16]

        x_view = pypto.view(x, [32, 32], [0, 0], valid_shape=[16, 16])



        # Test unsqueeze at dimension 0

        res = pypto.unsqueeze(x_view, 0)



        # Verify shape: [32, 32] -> [1, 32, 32]

        assert res.shape == [1, 32, 32]



        # Verify validShape: [16, 16] -> [1, 16, 16]

        assert len(res.valid_shape) == 3

        assert res.valid_shape[0].concrete() == 1

        assert res.valid_shape[1].concrete() == 16

        assert res.valid_shape[2].concrete() == 16



        assert pypto.reshape(res, [-1]).shape == [1024]