Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package app
import (
"context"
"errors"
"fmt"
"math"
"strconv"
"strings"
"syscall"
"github.com/containerd/containerd"
"github.com/containerd/containerd/cio"
"github.com/containerd/containerd/namespaces"
"github.com/containerd/containerd/oci"
errors2 "k8s.io/apimachinery/pkg/api/errors"
"ascend-common/api"
"ascend-common/common-utils/hwlog"
"ascend-common/common-utils/utils"
"container-manager/pkg/common"
)
const (
charDevice = "c"
maxDevicesNum = 100000
maxEnvNum = 10000
base = 10
)
type ContainerdClient struct {
client *containerd.Client
stoppedContainer map[string]containerd.Container
}
func NewContainerdClient() *ContainerdClient {
return &ContainerdClient{
stoppedContainer: make(map[string]containerd.Container),
}
}
func (c *ContainerdClient) init() error {
cClient, err := containerd.New(common.ParamOption.SockPath)
if err != nil {
hwlog.RunLog.Errorf("connect to container runtime failed, error: %v", err)
return errors.New("connect to container runtime failed")
}
c.client = cClient
return nil
}
func (c *ContainerdClient) close() error {
return c.client.Close()
}
func (c *ContainerdClient) getAllContainers() (interface{}, error) {
nss, err := c.client.NamespaceService().List(context.Background())
if err != nil {
hwlog.RunLog.Errorf("failed to get namespace list, error: %v", err)
return nil, errors.New("failed to get namespace list")
}
var ctrs = make(map[string][]containerd.Container)
for _, ns := range nss {
if ns == "moby" {
continue
}
ctx := namespaces.WithNamespace(context.Background(), ns)
containers, err := c.client.ContainerService().List(ctx)
if err != nil {
hwlog.RunLog.Errorf("failed to get container list for ns %s, skip. error: %v", ns, err)
continue
}
for _, container := range containers {
containerObj, err := c.client.LoadContainer(ctx, container.ID)
if err != nil {
hwlog.RunLog.Errorf("failed to load container %s, error: %v", container.ID, err)
continue
}
ctrs[ns] = append(ctrs[ns], containerObj)
}
}
return ctrs, nil
}
func (c *ContainerdClient) doStart(containerID, ns string) error {
ctx := namespaces.WithNamespace(context.Background(), ns)
container, ok := c.stoppedContainer[containerID]
if !ok {
return fmt.Errorf("container %s have not stopped", containerID)
}
task, err := container.NewTask(ctx, cio.NewCreator(cio.WithStdio))
if err != nil {
return fmt.Errorf("failed to new task for container: %v", err)
}
if err = task.Start(ctx); err != nil {
return fmt.Errorf("failed to start task for container: %v", err)
}
delete(c.stoppedContainer, containerID)
return nil
}
func (c *ContainerdClient) doStop(containerID, ns string) error {
ctx := namespaces.WithNamespace(context.Background(), ns)
container, err := c.client.LoadContainer(ctx, containerID)
if err != nil {
return fmt.Errorf("failed to load container: %v", err)
}
task, err := container.Task(ctx, nil)
if err != nil {
return fmt.Errorf("failed to get container %s , error: %v", containerID, err)
}
if err = task.Kill(ctx, syscall.SIGKILL); err != nil {
return fmt.Errorf("failed to kill container %s, error: %v", containerID, err)
}
if err = task.Kill(ctx, syscall.SIGKILL, containerd.WithKillAll); err != nil && errors2.IsNotFound(err) {
return fmt.Errorf("failed to kill container %s, error: %v", containerID, err)
}
if _, err = task.Delete(ctx, containerd.WithProcessKill); err != nil {
return fmt.Errorf("failed to delete task for container %s, error: %v", containerID, err)
}
c.stoppedContainer[containerID] = container
return nil
}
func (c *ContainerdClient) getUsedDevs(containerObj interface{}, ctx context.Context) ([]int32, error) {
switch cs := containerObj.(type) {
case containerd.Container:
return c.doGetUsedDevs(cs, ctx)
default:
return nil, nil
}
}
func (c *ContainerdClient) doGetUsedDevs(cs containerd.Container, ctx context.Context) ([]int32, error) {
spec, err := getCtrValidSpec(cs, ctx)
if err != nil {
return nil, fmt.Errorf("get container %s valid spec failed, error: %v", cs.ID(), err)
}
envs := spec.Process.Env
for i := len(envs) - 1; i >= 0; i-- {
env := envs[i]
if strings.Contains(env, api.AscendDeviceInfo) {
usedDevs, err := getUsedDevsWithAscendRuntime(env)
if err != nil {
return nil, fmt.Errorf("parse env %s failed, error: %v", api.AscendDeviceInfo, err)
}
return usedDevs, nil
}
}
hwlog.RunLog.Debugf("get container %s used devs by env %s failed, not used ascend docker runtime",
cs.ID(), api.AscendDeviceInfo)
usedDevs, err := getUsedDevsWithoutAscendRuntime(spec)
if err != nil {
return nil, fmt.Errorf("get container %s device ids failed, error: %v", cs.ID(), err)
}
return usedDevs, nil
}
func getCtrValidSpec(containerObj containerd.Container, ctx context.Context) (*oci.Spec, error) {
spec, err := containerObj.Spec(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get container spec:%v", err)
}
if spec.Linux == nil || spec.Linux.Resources == nil || len(spec.Linux.Resources.Devices) > maxDevicesNum {
return nil, fmt.Errorf("devices in container is too much (%v) or empty", maxDevicesNum)
}
if spec.Process == nil || len(spec.Process.Env) > maxEnvNum {
return nil, fmt.Errorf("env in container is too much (%v) or empty", maxEnvNum)
}
return spec, nil
}
func getUsedDevsWithoutAscendRuntime(spec *oci.Spec) ([]int32, error) {
if spec.Linux == nil || spec.Linux.Resources == nil {
return nil, errors.New("empty spec info")
}
phyIds := make([]int32, 0, sliceLen16)
majorIDs := npuMajor()
for _, dev := range spec.Linux.Resources.Devices {
if dev.Minor == nil || dev.Major == nil {
continue
}
if *dev.Minor > math.MaxInt32 {
return nil, fmt.Errorf("get wrong device ID (%v)", dev.Minor)
}
major := strconv.FormatInt(*dev.Major, base)
if dev.Type == charDevice && utils.Contains(majorIDs, major) {
phyIds = append(phyIds, int32(*dev.Minor))
}
}
return phyIds, nil
}