# 05360171 — created on 2022-03-18 (commit history)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import warnings



from mmcv.cnn import MODELS as MMCV_MODELS

from mmcv.utils import Registry



from mmaction.utils import import_module_error_func



# Root registry for every mmaction model component, created as a child of
# mmcv's ``MODELS`` registry.
MODELS = Registry('models', parent=MMCV_MODELS)

# Each per-component name below is bound to the very same registry object as
# ``MODELS``; the separate names only make registration/build sites readable.
BACKBONES = NECKS = HEADS = RECOGNIZERS = LOSSES = LOCALIZERS = MODELS

# mmdet is not imported in this module. ``DETECTORS`` is therefore a plain
# alias of ``MODELS`` (not a separate or empty registry), so this module and
# code importing ``DETECTORS`` from it load without mmdet installed.
DETECTORS = MODELS



@import_module_error_func('mmdet')
def build_detector(cfg, train_cfg, test_cfg):
    """Placeholder for mmdet's ``build_detector``.

    The body is intentionally empty. Calls are meant to be intercepted by the
    ``import_module_error_func('mmdet')`` decorator; presumably it reports
    that mmdet must be installed (NOTE(review): confirm in mmaction.utils).
    """





def build_backbone(cfg):
    """Build a backbone.

    Args:
        cfg (dict): Backbone config, handed unchanged to ``BACKBONES.build``.

    Returns:
        Whatever ``BACKBONES.build(cfg)`` produces.
    """
    registry = BACKBONES
    return registry.build(cfg)





def build_head(cfg):
    """Build a head.

    Args:
        cfg (dict): Head config, handed unchanged to ``HEADS.build``.

    Returns:
        Whatever ``HEADS.build(cfg)`` produces.
    """
    registry = HEADS
    return registry.build(cfg)





def build_recognizer(cfg, train_cfg=None, test_cfg=None):
    """Build a recognizer.

    Passing ``train_cfg``/``test_cfg`` as separate arguments is deprecated
    (see https://github.com/open-mmlab/mmaction2/pull/629); they belong inside
    the model config. A warning is emitted when either is given.

    Args:
        cfg (dict): Recognizer config, passed to ``RECOGNIZERS.build``.
        train_cfg (dict | None): Deprecated outer training config. Forwarded
            to ``RECOGNIZERS.build`` through ``default_args``.
        test_cfg (dict | None): Deprecated outer testing config. Forwarded
            to ``RECOGNIZERS.build`` through ``default_args``.

    Returns:
        Whatever ``RECOGNIZERS.build`` produces.

    Raises:
        AssertionError: If ``train_cfg`` (or ``test_cfg``) is set both in
            ``cfg`` and as the outer argument.
    """
    if train_cfg is not None or test_cfg is not None:
        warnings.warn(
            'train_cfg and test_cfg is deprecated, '
            'please specify them in model. Details see this '
            'PR: https://github.com/open-mmlab/mmaction2/pull/629',
            UserWarning)
    # Use explicit raises instead of ``assert`` so the conflict check still
    # runs under ``python -O``. AssertionError is kept so callers see the
    # same exception type as before.
    if cfg.get('train_cfg') is not None and train_cfg is not None:
        raise AssertionError(
            'train_cfg specified in both outer field and model field')
    if cfg.get('test_cfg') is not None and test_cfg is not None:
        raise AssertionError(
            'test_cfg specified in both outer field and model field ')
    return RECOGNIZERS.build(
        cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))





def build_loss(cfg):
    """Build a loss.

    Args:
        cfg (dict): Loss config, handed unchanged to ``LOSSES.build``.

    Returns:
        Whatever ``LOSSES.build(cfg)`` produces.
    """
    registry = LOSSES
    return registry.build(cfg)





def build_localizer(cfg):
    """Build a localizer.

    Args:
        cfg (dict): Localizer config, handed unchanged to
            ``LOCALIZERS.build``.

    Returns:
        Whatever ``LOCALIZERS.build(cfg)`` produces.
    """
    registry = LOCALIZERS
    return registry.build(cfg)





def build_model(cfg, train_cfg=None, test_cfg=None):
    """Build a model, dispatching on ``cfg['type']``.

    The type is tested for membership in LOCALIZERS, then RECOGNIZERS, then
    DETECTORS, and the first match decides which builder is used.

    NOTE: in this module all three names are bound to the same ``MODELS``
    registry. Any registered type therefore matches the LOCALIZERS test and
    is built by ``build_localizer``. The RECOGNIZERS/DETECTORS branches only
    take effect if those registries are made distinct.

    ``build_localizer`` takes no ``train_cfg``/``test_cfg``, so on that path
    the outer configs are not used. A warning is emitted instead of dropping
    them silently.

    Args:
        cfg (dict): Model config; must contain a ``type`` key.
        train_cfg (dict | None): Deprecated outer training config.
        test_cfg (dict | None): Deprecated outer testing config.

    Returns:
        The object returned by the selected builder.

    Raises:
        KeyError: If ``cfg`` has no ``type`` key.
        ImportError: If ``type`` is an unregistered mmdet model such as
            ``FastRCNN``.
        ValueError: If ``type`` is not registered in any of the registries.
    """
    args = cfg.copy()
    obj_type = args.pop('type')
    outer_cfgs_given = train_cfg is not None or test_cfg is not None

    if obj_type in LOCALIZERS:
        if outer_cfgs_given:
            # build_localizer cannot receive these; make the drop visible.
            warnings.warn(
                f'train_cfg and test_cfg passed to build_model are ignored '
                f'when building {obj_type}; please specify them in model.',
                UserWarning)
        return build_localizer(cfg)
    if obj_type in RECOGNIZERS:
        return build_recognizer(cfg, train_cfg, test_cfg)
    if obj_type in DETECTORS:
        if outer_cfgs_given:
            warnings.warn(
                'train_cfg and test_cfg is deprecated, '
                'please specify them in model. Details see this '
                'PR: https://github.com/open-mmlab/mmaction2/pull/629',
                UserWarning)
        return build_detector(cfg, train_cfg, test_cfg)

    # Types that only exist when mmdet is installed get a more helpful error.
    model_in_mmdet = ['FastRCNN']
    if obj_type in model_in_mmdet:
        raise ImportError(
            'Please install mmdet for spatial temporal detection tasks.')

    raise ValueError(f'{obj_type} is not registered in '
                     'LOCALIZERS, RECOGNIZERS or DETECTORS')





def build_neck(cfg):
    """Build a neck.

    Args:
        cfg (dict): Neck config, handed unchanged to ``NECKS.build``.

    Returns:
        Whatever ``NECKS.build(cfg)`` produces.
    """
    registry = NECKS
    return registry.build(cfg)