# 4be15bca created on 2021-05-29 (commit history)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This module is used to write data into MindRecord.
"""
import mindspore._c_mindrecord as ms
from mindspore import log as logger
from .common.exceptions import MRMAddSchemaError, MRMAddIndexError, MRMBuildSchemaError, MRMGetMetaError

__all__ = ['ShardHeader']


class ShardHeader:
    """
    Wrapper around the ShardHeader class of the C++ module.

    The class stores the meta data of a MindRecord file.

    Args:
        header (optional): Existing header object to wrap. When it is None,
            a new ``ms.ShardHeader`` is created.
    """
    def __init__(self, header=None):
        # Compare against None explicitly so that a header supplied by the
        # caller is never replaced just because it evaluates as falsy.
        self._header = header if header is not None else ms.ShardHeader()

    def add_schema(self, schema):
        """
        Add a schema object to the header.

        Args:
            schema (ShardSchema): Object of ShardSchema.

        Returns:
            int, schema id.

        Raises:
            MRMAddSchemaError: If failed to add schema.
        """
        schema_id = self._header.add_schema(schema)
        # The wrapped header reports failure with the sentinel id -1.
        if schema_id == -1:
            logger.error("Failed to add schema.")
            raise MRMAddSchemaError
        return schema_id

    def add_index_fields(self, index_fields):
        """
        Add index fields to the header.

        Args:
            index_fields (list[str]): Fields passed on to the wrapped header
                as index fields.

        Returns:
            MSRStatus, always SUCCESS (any other status raises).

        Raises:
            MRMAddIndexError: If failed to add index field.
        """
        status = self._header.add_index_fields(index_fields)
        if status != ms.MSRStatus.SUCCESS:
            logger.error("Failed to add index field.")
            raise MRMAddIndexError
        return status

    def build_schema(self, content, desc=None):
        """
        Build a schema object from a raw, user defined schema.

        Args:
            content (dict): Dict of user defined schema.
            desc (str, optional): String of schema description. A missing or
                empty description is passed on as "".

        Returns:
            Schema object produced by ``ms.Schema.build``.

        Raises:
            MRMBuildSchemaError: If failed to build schema.
        """
        schema = ms.Schema.build(desc or "", content)
        if not schema:
            logger.error("Failed to build schema.")
            raise MRMBuildSchemaError
        return schema

    @property
    def header(self):
        """Getter of the wrapped header object."""
        return self._header

    def _get_schema(self):
        """
        Get schema info.

        Returns:
            The 'schema' entry of the metadata.

        Raises:
            MRMGetMetaError: If the metadata can not be retrieved.
        """
        return self._get_meta()['schema']

    def _get_blob_fields(self):
        """
        Get blob fields info.

        Returns:
            The 'blob_fields' entry of the metadata.

        Raises:
            MRMGetMetaError: If the metadata can not be retrieved.
        """
        return self._get_meta()['blob_fields']

    def _get_meta(self):
        """
        Get metadata including schema, blob fields, etc.

        Returns:
            The schema content of the single meta entry held by the header;
            it is indexed by keys such as 'schema' and 'blob_fields'.

        Raises:
            MRMGetMetaError: If the header does not hold exactly one meta
                entry.
        """
        meta = self._header.get_meta()
        # Exactly one meta entry is expected; anything else is an error.
        if meta and len(meta) == 1:
            return meta[0].get_schema_content()

        logger.error("Failed to get meta info.")
        raise MRMGetMetaError

    @property
    def blob_fields(self):
        """Getter of blob fields."""
        return self._get_blob_fields()

    @property
    def schema(self):
        """Getter of schema."""
        return self._get_schema()