Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package parser
import (
"bufio"
"math"
"os"
"regexp"
"strconv"
"strings"
"sync"
"github.com/containerd/containerd/oci"
"k8s.io/apimachinery/pkg/util/sets"
"ascend-common/api"
"ascend-common/common-utils/hwlog"
"ascend-common/common-utils/utils"
)
// Limits, separators and expected part counts used by the device parsing
// helpers in this file.
const (
// maxEnvLength is the largest accepted length of the value part of the
// device environment entry; longer values are rejected.
maxEnvLength = 1024
// comma separates the entries of a device list.
comma = ","
// minus separates the two ends of a device range, and the name from the
// ID in an Ascend-style entry.
minus = "-"
// ascend is the substring whose presence selects Ascend-style parsing.
ascend = "Ascend"
// envSliceLen is the number of parts expected when an environment entry
// is split on "=".
envSliceLen = 2
// deviceSliceLen is the number of parts expected when a range or an
// Ascend-style entry is split on minus.
deviceSliceLen = 2
// formatIntBase is the numeric base used when a device major number is
// formatted as a string.
formatIntBase = 10
)
// Package-level cache of the NPU major numbers, populated through npuMajor.
var (
// npuMajorFetchCtrl makes npuMajor run its lookup function only once.
npuMajorFetchCtrl sync.Once
// npuMajorID holds the set returned by getNPUMajorID, as stored by
// npuMajor.
npuMajorID sets.String
)
func ParseAscendDeviceInfo(env, containerID string) []int {
parts := strings.SplitN(env, "=", envSliceLen)
if len(parts) != envSliceLen {
hwlog.RunLog.Warnf("Invalid %s format in container %s", api.AscendDeviceInfo, containerID)
return nil
}
devicesStr := parts[1]
if len(devicesStr) > maxEnvLength {
hwlog.RunLog.Warnf("%s value too long in container %s", api.AscendDeviceInfo, containerID)
return nil
}
return parseDeviceIDs(devicesStr, containerID)
}
func parseDeviceIDs(devices, containerID string) []int {
if strings.Contains(devices, ascend) {
return parseAscendStyle(devices, containerID)
}
if strings.Contains(devices, comma) && strings.Contains(devices, minus) {
return parseCommaMinusStyle(devices, containerID)
}
if strings.Contains(devices, minus) {
return parseMinusStyle(devices, containerID)
}
return parseCommaStyle(devices, containerID)
}
// parseCommaStyle parses a comma-separated list of integer device IDs such as
// "0,1,2". Each entry is trimmed of surrounding whitespace before conversion.
// Entries that are not valid integers are logged with a warning and skipped.
// The returned slice is never nil.
func parseCommaStyle(devices, containerID string) []int {
	entries := strings.Split(devices, comma)
	ids := make([]int, 0, len(entries))
	for i := range entries {
		raw := entries[i]
		trimmed := strings.TrimSpace(raw)
		value, convErr := strconv.Atoi(trimmed)
		if convErr != nil {
			// The untrimmed entry is logged on purpose, as received.
			hwlog.RunLog.Warnf("Invalid device ID %s in container %s: %v", raw, containerID, convErr)
			continue
		}
		ids = append(ids, value)
	}
	return ids
}
// parseMinusStyle expands an inclusive range written as "start-end" into the
// list start, start+1, ..., end. Both ends are trimmed of surrounding
// whitespace before conversion.
//
// An empty, non-nil slice is returned (after a warning is logged) when the
// input does not split into exactly deviceSliceLen parts, when either end is
// not an integer, when start is greater than end, or when end is greater than
// math.MaxInt16.
func parseMinusStyle(devices, containerID string) []int {
	none := []int{}

	bounds := strings.Split(devices, minus)
	if len(bounds) != deviceSliceLen {
		hwlog.RunLog.Warnf("Invalid device range %s in container %s", devices, containerID)
		return none
	}
	lowText, highText := bounds[0], bounds[1]

	low, lowErr := strconv.Atoi(strings.TrimSpace(lowText))
	if lowErr != nil {
		hwlog.RunLog.Warnf("Invalid start device ID %s in container %s: %v", lowText, containerID, lowErr)
		return none
	}
	high, highErr := strconv.Atoi(strings.TrimSpace(highText))
	if highErr != nil {
		hwlog.RunLog.Warnf("Invalid end device ID %s in container %s: %v", highText, containerID, highErr)
		return none
	}

	switch {
	case low > high:
		hwlog.RunLog.Warnf("Invalid device range %d-%d in container %s: start > end", low, high, containerID)
		return none
	case high > math.MaxInt16:
		hwlog.RunLog.Warnf("End device ID %d exceeds maximum in container %s", high, containerID)
		return none
	}

	// low <= high <= math.MaxInt16 holds here, so the size is small and positive.
	ids := make([]int, 0, high-low+1)
	for id := low; id <= high; id++ {
		ids = append(ids, id)
	}
	return ids
}
// parseCommaMinusStyle parses a comma-separated list whose entries are either
// single integer IDs or "start-end" ranges, for example "0,2-4,7". Each entry
// is trimmed of surrounding whitespace; entries containing minus are expanded
// by parseMinusStyle, all others are converted directly. Invalid single IDs
// are logged with a warning and skipped. The returned slice is never nil.
func parseCommaMinusStyle(devices, containerID string) []int {
	ids := []int{}
	for _, entry := range strings.Split(devices, comma) {
		token := strings.TrimSpace(entry)

		if strings.Contains(token, minus) {
			ids = append(ids, parseMinusStyle(token, containerID)...)
			continue
		}

		single, convErr := strconv.Atoi(token)
		if convErr != nil {
			hwlog.RunLog.Warnf("Invalid device ID %s in container %s: %v", token, containerID, convErr)
			continue
		}
		ids = append(ids, single)
	}
	return ids
}
// parseAscendStyle parses a comma-separated list of "<name>-<id>" entries and
// returns the integer found after the minus of each entry. Entries are trimmed
// of surrounding whitespace first. An entry is logged with a warning and
// skipped when it does not split on minus into exactly deviceSliceLen parts,
// or when the part after the minus is not an integer. The returned slice is
// never nil.
func parseAscendStyle(devices, containerID string) []int {
	ids := []int{}
	for _, entry := range strings.Split(devices, comma) {
		entry = strings.TrimSpace(entry)

		// An entry without any minus splits into a single piece, so this one
		// length check also covers the "no separator" case.
		pieces := strings.Split(entry, minus)
		if len(pieces) != deviceSliceLen {
			hwlog.RunLog.Warnf("Invalid Ascend device format %s in container %s", entry, containerID)
			continue
		}

		idText := pieces[1]
		id, convErr := strconv.Atoi(idText)
		if convErr != nil {
			hwlog.RunLog.Warnf("Invalid device ID %s in container %s: %v", idText, containerID, convErr)
			continue
		}
		ids = append(ids, id)
	}
	return ids
}
// npuMajor returns the set of NPU major numbers, loading it with
// getNPUMajorID on the first call and returning the stored value afterwards.
//
// Whatever getNPUMajorID returns is stored even when it also reports an
// error, so the result may be nil or incomplete in that case; the error is
// logged instead of being dropped. Because the lookup is guarded by a
// sync.Once it is not retried after a failure.
func npuMajor() sets.String {
	npuMajorFetchCtrl.Do(func() {
		ids, err := getNPUMajorID()
		if err != nil {
			hwlog.RunLog.Warnf("failed to get NPU major IDs: %v", err)
		}
		npuMajorID = ids
	})
	return npuMajorID
}
// getNPUMajorID reads /proc/devices and collects the first field of every
// line that matches "<1-3 digits><whitespace>[v]devdrv-cdev".
//
// The path is validated with utils.CheckPath first; if that fails the result
// is (nil, err). For every later failure (opening the file, compiling the
// pattern, a scanner error) the set collected so far is returned together
// with the error. Scanning stops once more than maxSearchLine lines have been
// consumed, so at most maxSearchLine+1 lines are inspected.
func getNPUMajorID() (sets.String, error) {
	const maxSearchLine = 512

	path, err := utils.CheckPath("/proc/devices")
	if err != nil {
		return nil, err
	}

	majorID := sets.NewString()
	f, err := os.Open(path)
	if err != nil {
		return majorID, err
	}
	defer func() {
		// Use a dedicated variable so the close result never overwrites err.
		if closeErr := f.Close(); closeErr != nil {
			hwlog.RunLog.Error(closeErr)
		}
	}()

	// Compile the pattern once instead of once per scanned line.
	pattern, err := regexp.Compile("^[0-9]{1,3}\\s[v]?devdrv-cdev$")
	if err != nil {
		return majorID, err
	}

	s := bufio.NewScanner(f)
	for count := 0; s.Scan(); count++ {
		if count > maxSearchLine {
			break
		}
		text := s.Text()
		if !pattern.MatchString(text) {
			continue
		}
		// A matching line starts with digits, so the first field exists.
		majorID.Insert(strings.Fields(text)[0])
	}
	// Scan returning false can mean a read error rather than end of file.
	if scanErr := s.Err(); scanErr != nil {
		return majorID, scanErr
	}
	return majorID, nil
}
// FilterNPUDevices returns the minor numbers of the device entries in
// spec.Linux.Resources.Devices whose type is "c" and whose major number is in
// the set returned by npuMajor.
//
// It returns nil when spec, spec.Linux or spec.Linux.Resources is nil.
// Entries with a nil major or minor are skipped. If any entry that has both
// numbers set carries a minor above math.MaxInt32, the whole call returns
// nil, regardless of that entry's type or major. Otherwise the result is a
// non-nil, possibly empty, slice.
func FilterNPUDevices(spec *oci.Spec) []int {
	if spec == nil {
		return nil
	}
	if spec.Linux == nil || spec.Linux.Resources == nil {
		return nil
	}

	ids := []int{}
	knownMajors := npuMajor()
	for _, rule := range spec.Linux.Resources.Devices {
		majorPtr, minorPtr := rule.Major, rule.Minor
		if minorPtr == nil || majorPtr == nil {
			continue
		}
		if *minorPtr > math.MaxInt32 {
			return nil
		}
		if rule.Type != "c" {
			continue
		}
		if knownMajors.Has(strconv.FormatInt(*majorPtr, formatIntBase)) {
			ids = append(ids, int(*minorPtr))
		}
	}
	return ids
}