// Commit 3267bf88, created 2021-10-15 (from commit history).
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ps/util.h"
#include <unordered_map>
#include <vector>
#include <memory>
#include "ps/constants.h"
#include "ps/ps_context.h"
#include "utils/ms_utils.h"

namespace mindspore {
namespace ps {
std::unordered_map<std::string, int64_t> Util::optimizer_to_ids{
  {kApplyMomentum, 0},
  {kSparseAdam, 1},
  {kSparseLazyAdam, 2},
  {kSparseFtrl, 3},
};

std::unordered_map<int64_t, std::string> Util::id_to_optimizers{
  {0, kApplyMomentum},
  {1, kSparseAdam},
  {2, kSparseLazyAdam},
  {3, kSparseFtrl},
};

std::unordered_map<int64_t, std::string> Util::id_to_optimizer_nodes{
  {0, kApplyMomentumOp},
  {1, kSparseAdamOp},
  {2, kSparseLazyAdamOp},
  {3, kSparseFtrlOp},
};

// Returns true when the global PSContext reports that this process is a parameter server.
bool Util::IsRoleOfPServer() {
  const auto &context = PSContext::instance();
  return context->is_server();
}

// Returns true when the global PSContext reports that this process is the scheduler.
bool Util::IsRoleOfScheduler() {
  const auto &context = PSContext::instance();
  return context->is_scheduler();
}

// Returns the numeric id registered for optimizer `name` in optimizer_to_ids,
// or -1 when `name` is not a known optimizer.
int64_t Util::optimizer_id(const std::string &name) {
  // One find() instead of count() + operator[]: a single hash lookup, and no
  // possibility of operator[] default-inserting into the shared static table.
  const auto iter = optimizer_to_ids.find(name);
  if (iter == optimizer_to_ids.end()) {
    return -1;
  }
  return iter->second;
}

// Returns the optimizer name registered for `id` in id_to_optimizers,
// or an empty string when `id` is unknown.
std::string Util::optimizer_name(int64_t id) {
  // Single lookup; avoids operator[] on the shared static table.
  const auto iter = id_to_optimizers.find(id);
  if (iter == id_to_optimizers.end()) {
    return "";
  }
  return iter->second;
}

// Returns the optimizer operator-node name registered for `id` in
// id_to_optimizer_nodes, or an empty string when `id` is unknown.
std::string Util::optimizer_node_name(int64_t id) {
  // Single lookup; avoids operator[] on the shared static table.
  const auto iter = id_to_optimizer_nodes.find(id);
  if (iter == id_to_optimizer_nodes.end()) {
    return "";
  }
  return iter->second;
}

// Returns true iff `name` has an entry in optimizer_to_ids.
bool Util::is_optimizer(const std::string &name) {
  const auto iter = optimizer_to_ids.find(name);
  return iter != optimizer_to_ids.end();
}

// Returns how many of the `first_dim` rows are assigned to server `rank_id`
// when the rows are sharded across `server_num` servers by AllRankLocalShard.
// Invalid arguments are reported via MS_LOG(EXCEPTION), either here or inside
// AllRankLocalShard.
int64_t Util::LocalShard(int64_t first_dim, int64_t rank_id, int64_t server_num) {
  const std::map<int64_t, int64_t> shard_dims = AllRankLocalShard(first_dim, rank_id, server_num);
  // find() on a const map: one lookup, and no default-insert path as with operator[].
  const auto iter = shard_dims.find(rank_id);
  if (iter == shard_dims.end()) {
    MS_LOG(EXCEPTION) << "Invalid rank id " << rank_id;
  }
  return iter->second;
}

// Distributes `first_dim` rows over `server_num` servers round-robin (row i goes
// to server i % server_num). Returns a map from server index to its row count.
// Every index in [0, server_num) is present, possibly with a count of 0.
// `rank_id` is only validated here; it must lie in [0, server_num).
// Invalid arguments are reported via MS_LOG(EXCEPTION).
std::map<int64_t, int64_t> Util::AllRankLocalShard(int64_t first_dim, int64_t rank_id, int64_t server_num) {
  if (first_dim <= 0 || server_num <= 0 || rank_id < 0) {
    MS_LOG(EXCEPTION) << "Input values are invalid.";
  }
  if (rank_id >= server_num) {
    MS_LOG(EXCEPTION) << "The rank ID " << rank_id << " should be less than the number of servers " << server_num;
  }
  // Closed form of the round-robin assignment:
  //   - every server receives first_dim / server_num rows;
  //   - the first first_dim % server_num servers receive one extra row.
  // This costs O(server_num) instead of iterating once per row of first_dim.
  const int64_t base_rows = first_dim / server_num;
  const int64_t extra_rows = first_dim % server_num;
  std::map<int64_t, int64_t> shard_dims;
  for (int64_t server_index = 0; server_index < server_num; ++server_index) {
    shard_dims[server_index] = base_rows + (server_index < extra_rows ? 1 : 0);
  }
  // The map holds exactly server_num keys and rank_id was checked to be in
  // [0, server_num), so shard_dims always contains rank_id here.
  return shard_dims;
}

// Reduces the sparse gradient described by (gradients, indices, indices_size)
// into `unique_sparse_grad` by calling
// SparseOptimizerCPUKernel::BucketReduceSparseGradient.
//   gradients / indices: input gradient values and their row indices.
//     NOTE(review): presumably gradients holds indices_size * segment_size
//     floats; confirm against callers.
//   segment_size:        values per index; used here only to size the scratch buffer.
//   first_dim_size:      forwarded as the reduction's max_index_.
//   outer_dim_size:      forwarded as the reduction's value_stride_.
//   unique_sparse_grad:  output gradient; must be non-null.
void Util::ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size,
                                const size_t first_dim_size, const size_t outer_dim_size,
                                mindspore::kernel::SparseGradient<int> *unique_sparse_grad) {
  // Validate every pointer before allocating scratch buffers, so bad input
  // fails fast and the output pointer is never handed on unchecked.
  MS_EXCEPTION_IF_NULL(gradients);
  MS_EXCEPTION_IF_NULL(indices);
  MS_EXCEPTION_IF_NULL(unique_sparse_grad);

  // Scratch storage for the bucket reduction: indices_size * segment_size
  // values plus indices_size indices.
  const size_t slice_segment_size = indices_size * segment_size;
  std::vector<float> workspace_grad(slice_segment_size);
  std::vector<int> workspace_indices(indices_size);

  mindspore::kernel::SparseGradient<int> workspace_sparse_grad(
    {workspace_grad.data(), workspace_indices.data(), indices_size});
  mindspore::kernel::SparseGradient<int> input_sparse_grad({gradients, indices, indices_size});
  mindspore::kernel::ReduceSparseGradientParam<int> param;
  param.input_grad_ = &input_sparse_grad;
  param.workspace_grad_ = &workspace_sparse_grad;
  param.output_grad_ = unique_sparse_grad;
  param.max_index_ = first_dim_size;
  param.value_stride_ = outer_dim_size;

  mindspore::kernel::SparseOptimizerCPUKernel::BucketReduceSparseGradient(param);
}

// Fuses the server communication ops in `res`'s function graph:
//   - all CNodes named kPullWeightOpName become one kFusedPullWeightOpName node;
//   - all CNodes named kPushWeightOpName become one kFusedPushWeightOpName node.
// Always returns true; failures are reported through exceptions raised by
// the null checks here and inside DoFusion.
bool Util::FuseServerCommOps(const pipeline::ResourcePtr &res) {
  // `res` is dereferenced immediately, so reject a null resource explicitly.
  MS_EXCEPTION_IF_NULL(res);
  FuncGraphPtr func_graph = res->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  DoFusion(func_graph, kPullWeightOpName, kFusedPullWeightOpName);
  DoFusion(func_graph, kPushWeightOpName, kFusedPushWeightOpName);
  return true;
}

void Util::DoFusion(const FuncGraphPtr &func_graph, const std::string &cnode_name,
                    const std::string &fused_cnode_name) {
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());

  std::vector<AnfNodePtr> single_nodes;
  std::vector<std::string> weight_names;
  std::vector<int64_t> indices;
  for (const AnfNodePtr &node : node_list) {
    if (node != nullptr && node->isa<CNode>()) {
      if (AnfAlgo::GetCNodeName(node) == cnode_name) {
        single_nodes.push_back(node);

        auto weight_name_value_node =
          AnfAlgo::GetInputNode(node->cast<CNodePtr>(), kNodeInputWeightNameOffset)->cast<ValueNodePtr>();
        const std::string &weight_name = GetValue<std::string>(weight_name_value_node->value());
        weight_names.push_back(weight_name);

        auto weight_index_value_node =
          AnfAlgo::GetInputNode(node->cast<CNodePtr>(), kNodeInputWeightIndexOffset)->cast<ValueNodePtr>();
        int64_t weight_index = GetValue<int64_t>(weight_index_value_node->value());
        indices.push_back(weight_index);
      }
    }
  }

  auto prim = std::make_shared<Primitive>(fused_cnode_name);
  MS_EXCEPTION_IF_NULL(prim);
  std::vector<AnfNodePtr> fused_node_inputs = {};
  fused_node_inputs.push_back(NewValueNode(prim));
  (void)std::for_each(single_nodes.begin(), single_nodes.end(), [&](const AnfNodePtr &node) {
    fused_node_inputs.push_back(AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0));
  });

  auto fused_cnode = func_graph->NewCNode(fused_node_inputs);
  MS_EXCEPTION_IF_NULL(fused_cnode);
  AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(weight_names), fused_cnode);
  AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue(indices), fused_cnode);
  AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(kCPUDevice), fused_cnode);

  auto kernel_info = std::make_shared<device::KernelInfo>();
  MS_EXCEPTION_IF_NULL(kernel_info);
  fused_cnode->set_kernel_info(kernel_info);
  auto kernel_build_info = GenerateKernelBuildInfo(single_nodes);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_cnode.get());

  AbstractBasePtrList abstract_list;
  for (const auto &node : single_nodes) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    abstract_list.push_back(cnode->abstract());
  }
  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
  MS_EXCEPTION_IF_NULL(abstract_tuple);
  fused_cnode->set_abstract(abstract_tuple);

  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  for (const auto &node : single_nodes) {
    if (!manager->Replace(node, fused_cnode)) {
      MS_LOG(EXCEPTION) << "manager replace node failed";
    }
  }
  return;
}

// Builds the kernel build info for a fused node by concatenating, in order, the
// inputs and outputs of every CNode in `node_list`:
//   - every input and output uses format kOpFormat_DEFAULT;
//   - input types are the inferred output types of the producing nodes
//     (GetPrevNodeOutputInferDataType);
//   - output types are each node's own inferred output types.
kernel::KernelBuildInfoPtr Util::GenerateKernelBuildInfo(const std::vector<AnfNodePtr> &node_list) {
  std::vector<std::string> inputs_device_format;
  std::vector<std::string> outputs_device_format;
  std::vector<TypeId> inputs_device_type;
  std::vector<TypeId> outputs_device_type;
  // Note: output shapes were previously inferred here and discarded; the builder
  // never consumed them, so that dead work has been dropped.
  for (const auto &node : node_list) {
    auto cnode = utils::cast<CNodePtr>(node);
    MS_EXCEPTION_IF_NULL(cnode);
    size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
    for (size_t input_index = 0; input_index < input_num; ++input_index) {
      (void)inputs_device_format.emplace_back(kOpFormat_DEFAULT);
      inputs_device_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
    }
    size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
    for (size_t output_index = 0; output_index < output_num; ++output_index) {
      (void)outputs_device_format.emplace_back(kOpFormat_DEFAULT);
      outputs_device_type.push_back(AnfAlgo::GetOutputInferDataType(cnode, output_index));
    }
  }
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  builder.SetInputsFormat(inputs_device_format);
  builder.SetOutputsFormat(outputs_device_format);
  builder.SetInputsDeviceType(inputs_device_type);
  builder.SetOutputsDeviceType(outputs_device_type);
  return builder.Build();
}
}  // namespace ps
}  // namespace mindspore