* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
import { observer } from 'mobx-react-lite';
import React, { useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import type { TFunction } from 'i18next';
import { COLOR, getDecimalCount, getCompareName, getBaselineName } from '../Common';
import type { ConditionDataType } from './Filter';
import type { CompareData, VoidFunction } from '../../utils/interface';
import { queryCommunicationMatrix } from '../../utils/RequestUtils';
import { cloneDeep } from 'lodash';
import { type Session } from '../../entity/session';
import { CollapsiblePanel } from '@insight/lib/components';
import { safeStr, disposeAdaptiveEchart, getAdaptiveEchart, getDefaultChartOptions } from '@insight/lib/utils';
import type { TooltipComponentOption, VisualMapComponentOption } from 'echarts/components';
import type { PiecewiseVisualMapOption } from 'echarts/types/dist/shared';
import type { Condition, MatrixTypeValues, Range } from './CommunicationMatrix/Filter';
import Filter, { MatrixType } from './CommunicationMatrix/Filter';
import {
getTransportTypeSerie,
getTransportTypeVisualMap,
} from './CommunicationMatrix/transportType';
/** Metric values recorded for one src -> dst link (tooltip units: GB/s, MB, ms). */
interface MatrixData {
    opName: string;
    /** Displayed in GB/s. */
    bandwidth: number;
    /** Displayed in MB. */
    transitSize: number;
    /** Displayed in ms. */
    transitTime: number;
    transportType: string;
}

/** One raw matrix row: a directed rank pair with its compare / baseline / diff metrics. */
interface MatrixItem {
    srcRank: number;
    dstRank: number;
    matrixData: CompareData<MatrixData>;
}

/** Query result kept in component state: raw rows plus the rank ids drawn on both axes. */
interface DataSource {
    data: MatrixItem[];
    rankIds: number[];
}

/**
 * Same metrics as MatrixData, but each held in a list so that several rows falling into
 * the same heatmap cell can be displayed together.
 */
interface MergedMatrixData {
    opName: string;
    bandwidth: number[];
    transitSize: number[];
    transitTime: number[];
    transportType: string[];
}

/**
 * A heatmap cell: [src rank label, dst rank label, value the cell is coloured by,
 * merged compare / baseline / diff metrics]. Positions are named by HeatmapDataIndex.
 */
export type HeatmapData = [string, string, React.Key, CompareData<MergedMatrixData>];

/**
 * Positions inside a HeatmapData tuple. VALUE (2) is also the visualMap `dimension`,
 * i.e. the entry a cell's colour is derived from.
 */
export enum HeatmapDataIndex {
    SRC_RANK = 0,
    DST_RANK = 1,
    VALUE = 2,
    DATA = 3,
}

/** Input to wrapData: cells, axis ranks, selected metric, colour-scale bounds and compare flag. */
interface ChartData {
    data: HeatmapData[];
    rankIds: number[];
    type: MatrixType;
    min: number;
    max: number;
    isCompare: boolean;
}
/**
 * Unit suffixes appended to metric values in the tooltip.
 * Metrics without an entry here (e.g. transportType) are rendered without a unit.
 */
const matrixDataTypeUnits = {
    bandwidth: '(GB/s)',
    transitSize: '(MB)',
    transitTime: '(ms)',
} as const;
/**
 * Render the communication-matrix heatmap into the `#matrixchart` element.
 *
 * Does nothing when the element is not in the DOM. Otherwise any chart already attached
 * to the element is disposed and a fresh instance is created, so every call starts clean.
 * A `dataZoom` listener then shows cell labels only while the first dataZoom window
 * covers at most 16 x-axis categories.
 *
 * @param data - heatmap cells plus axis and colour-scale metadata.
 * @param t - i18n translator forwarded to the option builders.
 */
function InitChart(data: ChartData, t: TFunction): void {
    const container = document.getElementById('matrixchart');
    if (container === null) {
        return;
    }
    disposeAdaptiveEchart(container);
    const chart = getAdaptiveEchart(container);
    chart.setOption(wrapData(data, t), { replaceMerge: ['series', 'xAxis', 'yAxis'] });

    const syncLabelVisibility = (): void => {
        const current = chart.getOption() as any;
        const categories = current.xAxis[0].data;
        const zoom = current.dataZoom[0];
        const windowStart = zoom.start ?? 0;
        const windowEnd = zoom.end ?? 100;
        // Number of x-axis categories inside the current zoom window (start/end are percentages)
        const visibleCount = Math.round((windowEnd - windowStart) / 100 * categories.length);
        chart.setOption({
            series: [{ label: { show: visibleCount <= 16 } }],
        }, false);
    };
    chart.on('dataZoom', syncLabelVisibility);
}
/**
 * Assemble the complete ECharts option for the matrix heatmap.
 *
 * Starts from a deep copy of `baseOption`, so the shared template is never mutated,
 * then fills in the rank categories for both axes, the series and the visualMap.
 *
 * @param dataSource - processed cells plus axis / colour-scale metadata.
 * @param t - i18n translator forwarded to the series and visualMap builders.
 * @returns a ready-to-use option object for `setOption`.
 */
function wrapData(dataSource: ChartData, t: TFunction): any {
    const { data, rankIds, type, min, max, isCompare } = dataSource;
    const option: any = cloneDeep(baseOption);
    const series = getSeries({ data, rankIds, t, type, isCompare });
    const visualMap = getVisualMap({ type, min, max, dataLength: data.length, isCompare, t });
    // Both axes list the same rank ids (src on x, dst on y)
    option.xAxis.data = rankIds;
    option.yAxis.data = rankIds;
    option.series = [series];
    option.visualMap = visualMap;
    return option;
}
/**
 * Option template shared by every matrix render; `wrapData` deep-clones it and then fills
 * in axis categories, series and visualMap.
 *
 * Three `inside` dataZoom components are declared: one bound to the x axis, one to the
 * y axis, and one without an explicit axis binding. All start at the full 0-100% window.
 */
const baseOption: any = {
    xAxis: { type: 'category', name: 'Src Rank Id', splitArea: { show: true } },
    yAxis: { type: 'category', name: 'Dst Rank Id' },
    tooltip: { show: true },
    textStyle: getDefaultChartOptions().textStyle,
    dataZoom: [
        { type: 'inside', xAxisIndex: [0], start: 0, end: 100 },
        { type: 'inside', yAxisIndex: [0], start: 0, end: 100 },
        { type: 'inside', start: 0, end: 100 },
    ],
    grid: { left: '100', right: '100', height: '80%', top: '10%' },
};
/** Settings common to every heatmap series: the series type and the hover (emphasis) shadow. */
export const baseSeries = {
    type: 'heatmap',
    emphasis: { itemStyle: { shadowBlur: 10, shadowColor: COLOR.GREY_50 } },
};
/**
 * Build the heatmap series for the selected metric.
 *
 * Transport-type matrices are delegated to getTransportTypeSerie. For every other metric,
 * cells are labelled with the diff values in compare mode and with the compare values
 * otherwise. A cell that merged several rows shows its values as `[a,b,...]`. Labels start
 * visible only when there are at most 16 ranks.
 *
 * @returns an ECharts heatmap series option.
 */
function getSeries({ data, rankIds, type, isCompare, t }: {
    data: HeatmapData[];
    rankIds: number[];
    type: MatrixType;
    isCompare: boolean;
    t: TFunction;
}): any {
    if (type === MatrixType.TRANSPORT_TYPE) {
        return getTransportTypeSerie({ data, rankIds, type, isCompare, t });
    }
    const formatLabel = (params: any): string => {
        const side = isCompare ? params?.data[3].diff : params?.data[3].compare;
        if (!side) {
            return '';
        }
        const entries = side[type];
        if (entries.length > 1) {
            return `[${entries.join(',')}]`;
        }
        return entries[0];
    };
    const tooltip = getTooltip({ t, type, isCompare });
    return {
        ...baseSeries,
        data,
        label: { show: rankIds.length <= 16, formatter: formatLabel },
        tooltip,
    };
}
/** One row of the heatmap tooltip. */
interface Label {
    /** Row caption. */
    label: string | number;
    /** Row value. */
    content: string | number;
    /** Metric the row belongs to; used to look up the unit suffix. */
    key?: MatrixType;
    /** Extra CSS class for the value span (e.g. positive-number / negative-number). */
    contentClass?: string;
}
/**
 * Build the ECharts tooltip configuration for the matrix heatmap.
 *
 * The formatter renders one HTML line per entry returned by getDisplayList: the label,
 * the value wrapped in a `tooltip-value` span (plus the entry's optional contentClass),
 * and a unit suffix for metrics listed in matrixDataTypeUnits. Label and value text are
 * passed through safeStr before being interpolated.
 *
 * @param t - i18n translator used for row labels.
 * @param type - metric currently selected in the filter.
 * @param isCompare - whether compare / baseline rows are rendered.
 * @returns a tooltip option whose formatter returns an HTML string.
 */
export function getTooltip({ t, type, isCompare }: {t: TFunction;type: MatrixTypeValues;isCompare: boolean}):
TooltipComponentOption {
return {
show: true,
formatter: function (params: any): string {
// NOTE(review): params.data is presumably the hovered cell's HeatmapData tuple (getDisplayList expects one) — confirm for transport-type series
const list = getDisplayList({ t, type, isCompare, data: params.data });
return list.map(labelItem => {
// Rows without a key, or keyed by a metric with no matrixDataTypeUnits entry, get no unit suffix
const unit = matrixDataTypeUnits[labelItem.key as keyof typeof matrixDataTypeUnits];
return `<span>
${safeStr(labelItem.label)}:
</span>
<span class="tooltip-value ${labelItem.contentClass ?? ''}">
${safeStr(labelItem.content)}
<span style="font-weight: normal;color: var(--mi-text-color-tertiary)">${unit ?? ''}</span>
</span>
<br/>`;
}).join('');
},
};
}
/**
 * Produce the tooltip rows for one heatmap cell.
 *
 * Every list starts with the "src -> dst" row. In compare mode it then lists the non-empty
 * compare / baseline operator names and, except for transport type, the difference (styled
 * positive / negative). It ends with the first compare and baseline value of the selected
 * metric. Outside compare mode it lists the non-empty operator name followed by every metric.
 * A metric that holds several merged values is shown as `[a, b, ...]`.
 */
function getDisplayList({ t, type, isCompare, data }:
{data: HeatmapData;t: TFunction;type: MatrixTypeValues;isCompare: boolean}): Label[] {
    const [srcRank, dstRank, value, { compare, baseline }] = data;
    const rows: Label[] = [{ label: 'Src Rank -> Dst Rank', content: `${srcRank} -> ${dstRank}` }];

    if (!isCompare) {
        if (compare.opName !== '') {
            rows.push({ label: t('operatorName'), content: compare.opName });
        }
        for (const metric of Object.values(MatrixType)) {
            const entries = compare[metric];
            const content = entries.length > 1 ? `[${entries.join(', ')}]` : entries[0];
            rows.push({ label: t(metric), content, key: metric });
        }
        return rows;
    }

    if (compare.opName !== '') {
        rows.push({ label: t(getCompareName('operatorName')), content: compare.opName });
    }
    if (baseline.opName !== '') {
        rows.push({ label: t(getBaselineName('operatorName')), content: baseline.opName });
    }
    if (type !== MatrixType.TRANSPORT_TYPE) {
        const nonNegative = typeof value === 'number' && value >= 0;
        rows.push({ label: t('Difference'), content: value, key: type, contentClass: nonNegative ? 'positive-number' : 'negative-number' });
    }
    rows.push(
        { label: t(getCompareName(type)), content: compare[type][0], key: type },
        { label: t(getBaselineName(type)), content: baseline[type][0], key: type },
    );
    return rows;
}
/**
 * Shared visualMap placement: a horizontal legend centred below the chart, driven by the
 * cell value stored at HeatmapDataIndex.VALUE (tuple position 2).
 */
const baseVisualMap: PiecewiseVisualMapOption = {
    orient: 'horizontal', left: 'center', bottom: '0',
    textStyle: { color: COLOR.GREY_40 },
    dimension: HeatmapDataIndex.VALUE,
};
/**
 * Build the visualMap (colour legend) for the selected metric.
 *
 * Transport type delegates to getTransportTypeVisualMap. Otherwise, when there is any data
 * or `max` is finite, a calculable colour scale over [min, max] is returned. The four band
 * colours are reversed for transit time, and a single colour is used when min === max.
 * The bare base visualMap is returned only when there is no data and `max` is not finite.
 */
function getVisualMap({ dataLength, min, max, type, isCompare = false, t }: {
    dataLength: number;min: number;max: number;isCompare?: boolean;type: MatrixType; t: TFunction;
}): VisualMapComponentOption {
    if (type === MatrixType.TRANSPORT_TYPE) {
        return getTransportTypeVisualMap(isCompare, t);
    }
    const hasScale = dataLength > 0 || isFinite(max);
    if (!hasScale) {
        return baseVisualMap;
    }
    const bands = [COLOR.BAND_0, COLOR.BAND_1, COLOR.BAND_2, COLOR.BAND_3];
    if (type === MatrixType.TRANSIT_TIME) {
        // NOTE(review): presumably reversed because a larger transit time is the worse outcome — confirm
        bands.reverse();
    }
    // A degenerate scale (min === max) collapses to a single colour
    const inRange = { color: min === max ? [COLOR.BAND_1] : bands };
    return {
        ...baseVisualMap,
        calculable: true,
        itemHeight: 300,
        inRange,
        min,
        max,
        precision: Math.max(getDecimalCount(min), getDecimalCount(max)),
    };
}
/**
 * Wrap each MatrixType metric of `data` in a one-element array, producing the per-metric
 * lists of MergedMatrixData. `opName` is not copied; callers add it themselves.
 */
function mapValuesToEnum(data: Record<string, any>): Omit<MergedMatrixData, 'opName'> {
    const lists: Record<string, any[]> = {};
    for (const metric of Object.values(MatrixType)) {
        lists[metric] = [data[metric]];
    }
    return lists as Omit<MergedMatrixData, 'opName'>;
}
/** Arguments for updateChart. */
interface UpdateChartParams {
    /** Raw matrix rows plus the rank ids shown on both axes. */
    dataSource: DataSource;
    /** Selected metric and whether self-links (src === dst) are kept. */
    switchCondition: Condition;
    /** Optional value filter; cells whose value lies outside [min, max] are dropped. */
    range?: Range;
    /** When true, the computed value bounds are reported through setRange. */
    shouldUpdateRange: boolean;
    setRange: (val: Range) => void;
    /** Compare mode colours cells by the diff value instead of the compare value. */
    isCompare: boolean;
    t: TFunction;
}
/**
 * Turn the raw matrix rows into heatmap cells for the selected metric and redraw the chart.
 *
 * A row is kept only when:
 * - both ranks belong to `dataSource.rankIds`;
 * - it is not a self-link, unless `switchCondition.showInner` is set;
 * - its value is numeric and inside `range`, when `range` is given.
 *
 * The cell value is the diff in compare mode and the compare value otherwise. Rows that
 * share a src/dst pair collapse into one cell: the first row supplies the cell value and
 * the baseline/diff metrics, and later rows only prepend their compare metrics to the
 * cell's lists.
 *
 * The colour scale spans `range` when given, otherwise the min/max of the cell values
 * (non-numeric values count as 0, and an empty result gives 0..0). When
 * `shouldUpdateRange` is set, the computed bounds are reported through `setRange`.
 */
const updateChart = ({ dataSource, switchCondition, range, shouldUpdateRange, setRange, t, isCompare }: UpdateChartParams): void => {
    const { data, rankIds } = dataSource;
    // Set lookups keep filtering O(rows); rankIds.includes per row was O(rows × ranks)
    const rankSet = new Set(rankIds);
    const result = data.reduce((acc: Record<string, HeatmapData>, cur) => {
        const { srcRank, dstRank, matrixData: { compare, baseline, diff } } = cur;
        const key = `${srcRank}-${dstRank}`;
        const compareValue = compare[switchCondition.type];
        const diffValue = diff[switchCondition.type];
        const value = isCompare ? diffValue : compareValue;
        let match = rankSet.has(srcRank) && rankSet.has(dstRank);
        if (!switchCondition.showInner) {
            match = match && srcRank !== dstRank;
        }
        if (range) {
            match = match && typeof value === 'number' && value >= range.min && value <= range.max;
        }
        if (!match) {
            return acc;
        }
        if (!acc[key]) {
            acc[key] = [String(srcRank), String(dstRank), value,
                {
                    compare: {
                        opName: compare.opName,
                        ...mapValuesToEnum(compare),
                    },
                    baseline: {
                        opName: baseline.opName,
                        ...mapValuesToEnum(baseline),
                    },
                    diff: {
                        opName: diff.opName,
                        ...mapValuesToEnum(diff),
                    },
                },
            ];
        } else {
            // Only the compare lists are merged; baseline/diff keep the first row's values
            for (const item of Object.values(MatrixType)) {
                (acc[key][3].compare[item] as Array<string | number>).unshift(compare[item]);
            }
        }
        return acc;
    }, {});
    const dataList = Object.values(result);
    // Fold instead of Math.min(...values)/Math.max(...values): spreading one argument per cell
    // throws RangeError once the cell count exceeds the engine's argument limit (large clusters).
    let min = 0;
    let max = 0;
    if (dataList.length > 0) {
        min = Infinity;
        max = -Infinity;
        for (const item of dataList) {
            const cellValue = item[HeatmapDataIndex.VALUE];
            const numeric = typeof cellValue === 'number' ? cellValue : 0;
            min = Math.min(min, numeric);
            max = Math.max(max, numeric);
        }
    }
    if (shouldUpdateRange) {
        setRange({ min, max });
    }
    InitChart({ ...dataSource, data: dataList, type: switchCondition.type, min: range?.min ?? min, max: range?.max ?? max, isCompare }, t);
};
/**
 * Query the communication matrix for `condition` and hand the result to `setDataSource`.
 *
 * When no stage or operator is selected, the query is skipped and an empty matrix is
 * published. For the `p2p` stage, the rank ids are collected (deduplicated and sorted) from
 * the returned links. For other stages they are parsed from the stage string as a
 * comma-separated list, optionally wrapped in parentheses, e.g. `(0, 1, 2, 3)`.
 * NOTE(review): confirm the backend's stage format.
 *
 * @param condition - current filter selection.
 * @param setDataSource - receives `{ data, rankIds }`.
 * @param isCompare - forwarded to the query as `isCompare`.
 */
const updateData = async(condition: ConditionDataType, setDataSource: VoidFunction, isCompare: boolean): Promise<void> => {
    const { iterationId, stage, operatorName, baselineIterationId, pgName, groupIdHash, baselineGroupIdHash } = condition;
    if (stage === '' || operatorName === '') {
        setDataSource({ data: [], rankIds: [] });
        return;
    }
    const param = { iterationId, pgName, stage, operatorName, isCompare, baselineIterationId, groupIdHash, baselineGroupIdHash };
    const res = await queryCommunicationMatrix(param);
    const data = res?.matrixList ?? [];
    let rankIds: number[];
    if (stage === 'p2p') {
        rankIds = Array.from(
            new Set(
                data.flatMap(({ srcRank, dstRank }: {srcRank: number; dstRank: number}) => [srcRank, dstRank]),
            ),
        ).sort((a, b) => (a as number) - (b as number)) as number[];
    } else {
        // Remove every parenthesis (global flag). Commas must survive: they are the split separator.
        // The previous /[(),]/ stripped only the first match, so "0,1,2" became "01,2" -> [1, 2].
        rankIds = stage.replace(/[()]/g, '')
            .split(',')
            .map(value => Number.parseInt(value, 10))
            .filter(value => !Number.isNaN(value))
            .sort((a, b) => a - b);
    }
    setDataSource({ data, rankIds });
};
/**
 * Collapsible panel showing the rank-to-rank communication matrix as a heatmap.
 *
 * While the panel is shown, the matrix is (re)fetched whenever the filter conditions,
 * compare mode or the cluster-completion state change. Until the cluster is completed,
 * an empty matrix is shown. The chart is redrawn whenever the data, the metric filter,
 * the translator or compare mode change.
 */
const CommunicationMatrix = observer(({ isShow, conditions, session }: { isShow: boolean;conditions: ConditionDataType;session: Session}) => {
    const { t } = useTranslation('communication');
    const [switchCondition, setSwitchCondition] = useState<Condition>({ type: MatrixType.BANDWIDTH, showInner: false });
    const [range, setRange] = useState<Range>({ min: 0, max: 1 });
    const [dataSource, setDataSource] = useState<DataSource>({ data: [], rankIds: [] });
    const handleFilterChange = (field: string, val: string | boolean): void => {
        // Functional update: several changes dispatched before a re-render must not overwrite each other
        setSwitchCondition(prev => ({ ...prev, [field]: val }));
    };
    const handleRangeChange = (rangeVal: Range): void => {
        // Re-filter with the user's range while leaving the Filter's range bounds untouched
        updateChart({ shouldUpdateRange: false, range: rangeVal, setRange, switchCondition, dataSource, t, isCompare: session.isCompare });
    };
    useEffect(() => {
        if (!isShow) {
            return undefined;
        }
        if (!session.clusterCompleted) {
            setDataSource({ data: [], rankIds: [] });
            return undefined;
        }
        // Ignore responses of superseded requests so a slower, older query cannot overwrite newer data
        let superseded = false;
        void updateData(conditions, (next: DataSource): void => {
            if (!superseded) {
                setDataSource(next);
            }
        }, session.isCompare);
        return (): void => {
            superseded = true;
        };
        // clusterCompleted is a dependency so data loads when the cluster finishes while the panel is open
    }, [isShow, conditions, session.isCompare, session.clusterCompleted]);
    useEffect(() => {
        updateChart({ shouldUpdateRange: true, setRange, switchCondition, dataSource, t, isCompare: session.isCompare });
    }, [dataSource, switchCondition, t, session.isCompare]);
    return <CollapsiblePanel style={{ display: isShow ? 'block' : 'none' }} title={t('sessionTitle.MatrixModel')} padding={'16px 24px'}>
        <Filter condition={switchCondition} handleChange={handleFilterChange} range={range} onRangeChange={handleRangeChange}/>
        <div id={'matrixchart'} style={{ width: 'calc(100vw - 80px)', height: '800px' }}></div>
    </CollapsiblePanel>;
});
export default CommunicationMatrix;