"""load tensor and combine tensor"""
import numpy as np
from mindspore.common.tensor import Tensor
from ..communication.management import get_rank, get_group_size
def _get_tensor_strategy(dev_mat, tensor_map):
    """
    Derive the per-dimension split counts of a tensor.

    Args:
        dev_mat (list): The device matrix.
        tensor_map (list): For every tensor dimension, either -1 (the dimension
            is not split) or an index into the device matrix counted from its
            last element (0 selects the last entry of dev_mat).

    Returns:
        List, one split count per entry of tensor_map.
    """
    # A mapped dimension is cut into as many pieces as the device-matrix axis it
    # refers to; an unmapped (-1) dimension stays in one piece.
    return [1 if axis_map == -1 else dev_mat[-axis_map - 1] for axis_map in tensor_map]
def _get_tensor_slice_index(device_arrangement, tensor_strategy, tensor_map, rank_index):
    """
    Locate the tensor slice that belongs to a given device rank.

    Args:
        device_arrangement (list): The device matrix.
        tensor_strategy (list): The split count of every tensor dimension.
        tensor_map (list): The map relation between tensor and devices.
        rank_index (int): The rank of the device.

    Returns:
        Number, the flat index of the slice held by the device. The value comes
        out of arithmetic on a numpy float array, so callers in this file wrap
        it in int() before using it as a list index.
    """
    # rank -> coordinate in the device matrix -> coordinate in the slice grid
    # -> flat index in the slice grid.
    coordinate = _rank_to_coordinate(rank_index, device_arrangement)
    slice_coordinate = _convert_to_new_device_coordinate(coordinate, tensor_map)
    return _coordinate_to_rank(slice_coordinate, tensor_strategy)
def _rank_to_coordinate(rank_index, device_arrangement):
    """
    Decompose a flat rank into a coordinate inside the device matrix.

    The last axis of the device matrix varies fastest.

    Args:
        rank_index (int): The index of the device.
        device_arrangement (list): The device matrix.

    Returns:
        numpy.ndarray, the coordinate of the device in the device matrix. It is
        created with np.zeros, so its elements are floating-point values.
    """
    dim_total = len(device_arrangement)
    coordinate = np.zeros(dim_total)
    remaining = rank_index
    # Peel off one axis at a time, starting from the fastest-varying (last) one.
    for axis in reversed(range(dim_total)):
        extent = device_arrangement[axis]
        coordinate[axis] = remaining % extent
        remaining = int(remaining / extent)
    return coordinate
def _coordinate_to_rank(device_coordinate, device_arrangement):
    """
    Flatten a coordinate into a single index.

    The last axis varies fastest, i.e. it has stride 1.

    Args:
        device_coordinate (list): The coordinate to flatten.
        device_arrangement (list): The extent of every axis; it is indexed with
            the same positions as device_coordinate.

    Returns:
        Number, the flat index that corresponds to the coordinate.
    """
    flat_index = 0
    stride = 1
    # Walk from the last axis to the first, growing the stride as we go.
    for pos in range(len(device_coordinate) - 1, -1, -1):
        flat_index += stride * device_coordinate[pos]
        stride *= device_arrangement[pos]
    return flat_index
def _convert_to_new_device_coordinate(device_coordinate, tensor_map):
    """
    Project a device coordinate onto the tensor dimensions.

    Args:
        device_coordinate (list): The coordinate of the device in the device matrix.
        tensor_map (list): The map relation between tensor and devices.

    Returns:
        List, one entry per tensor dimension: 0 where the dimension is not
        mapped (-1), otherwise the device coordinate on the mapped axis.
    """
    # tensor_map counts device-matrix axes from the end, so entry 0 refers to
    # the last element of the coordinate.
    last = len(device_coordinate) - 1
    return [0 if axis_map == -1 else device_coordinate[last - axis_map] for axis_map in tensor_map]
def _chunk_tensor(np_tensor, strategy, depth):
    """
    Recursively split a tensor along its trailing `depth` axes.

    Args:
        np_tensor (NDarray): The matrix to be split.
        strategy (list): The split counts for the axes that are still to be
            processed; strategy[0] applies to the current axis.
        depth (int): Number of axes left to split (recursion depth).

    Returns:
        List, the split pieces as NDarrays, ordered with the last axis varying fastest.

    Raises:
        ValueError: If np_tensor can not be split by strategy.
    """
    split_axis = len(np_tensor.shape) - depth
    parts = strategy[0]
    if np_tensor.shape[split_axis] % parts != 0:
        raise ValueError("np_tensor can not be split by strategy!")

    pieces = list(np.split(np_tensor, parts, split_axis))
    if depth == 1:
        return pieces

    # Drop the split count that was just consumed and recurse into every piece.
    remaining = strategy[len(strategy) - depth + 1:len(strategy)]
    chunks = []
    for piece in pieces:
        chunks.extend(_chunk_tensor(piece, remaining, depth - 1))
    return chunks
def _chunk_tensor_by_strategy(np_tensor, strategy):
    """
    Validate the inputs and split the tensor according to the strategy.

    Args:
        np_tensor (NDarray): The matrix to be split.
        strategy (list): The split count of every dimension of np_tensor.

    Returns:
        List, the split pieces as NDarrays.

    Raises:
        TypeError: If np_tensor is not ndarray.
        ValueError: If the number of dimensions of np_tensor does not match
            the length of strategy.
    """
    if not isinstance(np_tensor, np.ndarray):
        raise TypeError("np_tensor should be ndarray!")

    depth = len(strategy)
    if depth != len(np_tensor.shape):
        raise ValueError("The length of np_tensor does not match the length of strategy!")

    # Every dimension has a split count, so the recursion covers all axes.
    return _chunk_tensor(np_tensor, strategy, depth)
def _get_slice_index(dev_mat, tensor_map):
    """
    Get the index of the tensor slice that belongs to the current device.

    The rank of the current device is obtained from get_rank().

    Args:
        dev_mat (list): The device matrix of devices.
        tensor_map (list): The map relation between tensor and devices.

    Returns:
        Number, the slice index for the slice on this device.
    """
    local_rank = get_rank()
    strategy = _get_tensor_strategy(dev_mat, tensor_map)
    return _get_tensor_slice_index(dev_mat, strategy, tensor_map, local_rank)
def _load_tensor(tensor, dev_mat, tensor_map):
    """
    Get the tensor slice of the local device by the device matrix and the tensor map.

    The rank of the local device is obtained from get_rank().

    Args:
        tensor (Tensor): The tensor to be split.
        dev_mat (list): The device matrix of devices.
        tensor_map (list): The map relation between tensor and devices.

    Returns:
        numpy.array, the sliced array.

    Examples:
        >>> tensor = Tensor(np.ones([32, 32]))
        >>> dev_mat = [2, 4]
        >>> tensor_map = [1, -1]
        >>> tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
    """
    local_rank = get_rank()
    strategy = _get_tensor_strategy(dev_mat, tensor_map)
    slice_index = _get_tensor_slice_index(dev_mat, strategy, tensor_map, local_rank)
    # Cut the full array into all slices, then pick the one owned by this rank.
    chunks = _chunk_tensor_by_strategy(tensor.asnumpy(), strategy)
    # The slice index may be a float, so convert it before indexing.
    return chunks[int(slice_index)]
def _load_tensor_by_layout(tensor, layout):
    """
    Load the slice of a tensor that belongs to the local device, by layout.

    Args:
        tensor (Tensor): The input tensor.
        layout (tuple): The tensor layout in auto parallel. It must hold at
            least six items; this function reads item 0 (device matrix),
            item 1 (tensor map), item 4 (uniform split flag) and item 5
            (communication group).

    Returns:
        Tensor, the sliced tensor. The input tensor itself is returned
        unchanged when its size is 1.

    Raises:
        TypeError: If layout is not a tuple.
        ValueError: If layout has fewer than 6 items.
        RuntimeError: If the uniform split flag is 0.
    """
    if not isinstance(layout, tuple):
        raise TypeError("The layout should be tuple! layout is {}".format(layout))
    if len(layout) < 6:
        raise ValueError("The length of layout must be larger than 5! layout is {}".format(layout))

    dev_mat, tensor_map = layout[:2]
    uniform_split, group = layout[4:6]
    if uniform_split == 0:
        raise RuntimeError("The load tensor only support uniform split now")
    if tensor.size == 1:
        return tensor

    local_slice = _load_tensor(tensor, dev_mat, tensor_map)
    if not group:
        return Tensor(local_slice)

    # With a group set, the slice is split again evenly along axis 0 among the
    # members of the group and only this member's share is kept.
    group_rank = get_rank(group)
    group_size = get_group_size(group)
    return Tensor(np.split(local_slice, group_size)[group_rank])
def _reshape_param_data(param_data, dev_mat, tensor_map):
    """
    Combine param slices by the device matrix and the tensor map, used in model parallel scenario.

    Args:
        param_data (Tensor): The tensor to be reshaped, generated from all the devices by AllGatherParamNet.
            It is split evenly along axis 0 into one piece per device.
        dev_mat (list): The device matrix of devices.
        tensor_map (list): The map relation between tensor and devices.

    Returns:
        Tensor, the combined tensor which holds the whole data value.

    Examples:
        >>> param_data = _allgather_param_net(param_data)
        >>> dev_mat = [2, 2]
        >>> tensor_map = [1, 0]
        >>> tensor = _reshape_param_data(param_data, dev_mat, tensor_map)
    """
    device_total = 1
    for extent in dev_mat:
        device_total *= extent
    gathered = np.split(param_data.asnumpy(), device_total, axis=0)

    tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
    slice_total = 1
    for parts in tensor_strategy:
        slice_total *= parts

    # Put every device's piece at the position of the slice it holds. Devices
    # that hold the same slice overwrite each other; the highest rank wins.
    ordered = list(range(slice_total))
    for rank, piece in enumerate(gathered):
        position = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
        ordered[int(position)] = np.array(piece)

    # Merge neighbouring slices axis by axis, from the last axis to the first,
    # until a single array is left.
    for axis in reversed(range(len(tensor_strategy))):
        parts = tensor_strategy[axis]
        group_total = int(len(ordered) / parts)
        merged = []
        for group in range(group_total):
            first = group * parts
            pieces = ordered[first:first + parts]
            # A lone piece is taken as it is; several pieces are joined along the current axis.
            joined = pieces[0] if parts == 1 else np.concatenate(pieces, axis=axis)
            merged.append(np.array(joined))
        ordered = merged

    return Tensor(ordered[0])
def _reshape_param_data_with_weight(param_data, dev_mat, field_size):
    """
    Combine param slices by the device matrix, used in model parallel scenario.

    Args:
        param_data (Tensor): The tensor to be reshaped and rearranged,
            generated from all the devices by AllGatherParamNet. It is split
            evenly along axis 0 into one piece per device; each piece is
            indexed as a 2-D array.
        dev_mat (list): The device matrix of devices.
        field_size: The leading dimension used when reshaping each column of
            a device piece, i.e. it is passed as reshape(field_size, -1).

    Returns:
        Tensor, the combined tensor which holds the whole data value.

    Examples:
        >>> param_data = _allgather_param_net(param_data)
        >>> dev_mat = [2, 2]
        >>> field_size = 39
        >>> tensor = _reshape_param_data_with_weight(param_data, dev_mat, field_size)
    """
    device_total = 1
    for extent in dev_mat:
        device_total *= extent
    gathered = np.split(param_data.asnumpy(), device_total, axis=0)

    # For every column, lay the per-device pieces side by side as
    # (field_size, -1) blocks, in rank order.
    merged_columns = []
    for col in range(len(gathered[0][0])):
        per_device = [np.array(piece[:, col]).reshape(field_size, -1) for piece in gathered]
        merged_columns.append(np.concatenate(per_device, axis=1))

    # Flatten every merged column into a single column vector and put the
    # vectors next to each other again.
    flat_columns = [np.array(column).reshape(-1, 1) for column in merged_columns]
    combined = flat_columns[0]
    for extra in flat_columns[1:]:
        combined = np.concatenate((combined, extra), axis=1)
    return Tensor(combined)