# 05360171, created on 2022-03-18 (see commit history)
# Copyright 2021 Huawei Technologies Co., Ltd

#

# Licensed under the Apache License, Version 2.0 (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

#     http://www.apache.org/licenses/LICENSE-2.0

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.



# Plotting utils



import glob

import math

import os

import random

from copy import copy

from pathlib import Path



import cv2

import matplotlib

import matplotlib.pyplot as plt

import numpy as np

import torch

import yaml

from PIL import Image

from scipy.signal import butter, filtfilt



from utils.general import xywh2xyxy, xyxy2xywh

from utils.metrics import fitness



# Settings
# Select matplotlib's non-interactive 'Agg' backend at import time, so every plot
# produced by this module is rendered straight to image files (no GUI window).
matplotlib.use('Agg')  # for writing to files only





def color_list():
    """Return matplotlib's current default color cycle as a list of (r, g, b) int tuples.

    Each '#rrggbb' entry of ``plt.rcParams['axes.prop_cycle']`` is decoded into three
    0-255 integers. See
    https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb
    """
    cycle_hex = plt.rcParams['axes.prop_cycle'].by_key()['color']
    palette = []
    for code in cycle_hex:
        # skip the leading '#', then read two hex digits per channel
        palette.append((int(code[1:3], 16), int(code[3:5], 16), int(code[5:7], 16)))
    return palette





def hist2d(x, y, n=100):
    """Return, for every (x, y) point, the log of the count of its 2-D histogram bin.

    Used to color scatter plots by local point density (labels.png, evolve.png).

    Args:
        x, y: 1-D numpy arrays of equal length.
        n: number of bin edges per axis (giving n - 1 bins per axis).
    """
    edges_x = np.linspace(x.min(), x.max(), n)
    edges_y = np.linspace(y.min(), y.max(), n)
    counts, edges_x, edges_y = np.histogram2d(x, y, (edges_x, edges_y))
    max_ix, max_iy = counts.shape[0] - 1, counts.shape[1] - 1
    # digitize is 1-based; clip so the maximum value lands in the last bin
    bin_x = np.clip(np.digitize(x, edges_x) - 1, 0, max_ix)
    bin_y = np.clip(np.digitize(y, edges_y) - 1, 0, max_iy)
    return np.log(counts[bin_x, bin_y])





def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
    """Apply a zero-phase (forward-backward) Butterworth low-pass filter to *data*.

    Args:
        data: 1-D signal to smooth.
        cutoff: cut-off frequency, in the same units as *fs*.
        fs: sampling frequency.
        order: Butterworth filter order.

    https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
    """
    nyquist = 0.5 * fs
    # butter() expects the cut-off as a fraction of the Nyquist frequency
    numerator, denominator = butter(order, cutoff / nyquist, btype='low', analog=False)
    return filtfilt(numerator, denominator, data)  # forward-backward filter





def plot_one_box(x, img, color=None, label=None, line_thickness=None):
    """Draw one box (x1, y1, x2, y2) on *img* in place, optionally with a filled text label.

    Args:
        x: box corners; the first four items are read as x1, y1, x2, y2 and cast to int.
        img: image array drawn on with OpenCV.
        color: box color; a random one is drawn when falsy.
        label: optional text written just above the box's top-left corner.
        line_thickness: line thickness; derived from the image size when falsy.
    """
    thickness = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1
    if not color:
        color = [random.randint(0, 255) for _ in range(3)]
    top_left = (int(x[0]), int(x[1]))
    bottom_right = (int(x[2]), int(x[3]))
    cv2.rectangle(img, top_left, bottom_right, color, thickness=thickness, lineType=cv2.LINE_AA)
    if not label:
        return

    font_thickness = max(thickness - 1, 1)
    text_w, text_h = cv2.getTextSize(label, 0, fontScale=thickness / 3, thickness=font_thickness)[0]
    # filled background rectangle sitting on top of the box's upper edge
    label_corner = (top_left[0] + text_w, top_left[1] - text_h - 3)
    cv2.rectangle(img, top_left, label_corner, color, -1, cv2.LINE_AA)
    cv2.putText(img, label, (top_left[0], top_left[1] - 2), 0, thickness / 3, [225, 255, 255],
                thickness=font_thickness, lineType=cv2.LINE_AA)





def plot_wh_methods():  # from utils.general import *; plot_wh_methods()
    """Compare width/height anchor-multiplier curves and save them to comparison.png.

    Plots exp(x) against (2 * sigmoid(x)) ** 2 and (2 * sigmoid(x)) ** 1.6 over [-4, 4).
    https://github.com/ultralytics/yolov3/issues/168
    """
    inputs = np.arange(-4.0, 4.0, .1)
    doubled_sigmoid = torch.sigmoid(torch.from_numpy(inputs)).numpy() * 2
    curves = (
        (np.exp(inputs), 'YOLO'),
        (doubled_sigmoid ** 2, 'YOLO ^2'),
        (doubled_sigmoid ** 1.6, 'YOLO ^1.6'),
    )

    fig = plt.figure(figsize=(6, 3), dpi=150)
    for values, name in curves:
        plt.plot(inputs, values, '.-', label=name)
    plt.xlim(left=-4, right=4)
    plt.ylim(bottom=0, top=6)
    plt.xlabel('input')
    plt.ylabel('output')
    plt.grid()
    plt.legend()
    fig.tight_layout()
    fig.savefig('comparison.png', dpi=200)





def output_to_target(output, width, height):
    """Convert per-image detections into the target rows consumed by plot_images.

    Args:
        output: iterable over images (a list, or a batched tensor/array). Each item is
            None or an (N, >=6) tensor/array whose rows start with
            [x1, y1, x2, y2, conf, cls] in pixels.
        width, height: image size in pixels used to normalise the boxes.

    Returns:
        float64 array of shape (M, 7) with rows [batch_id, class_id, x, y, w, h, conf],
        where x, y is the normalised box centre and w, h the normalised size. Returns an
        empty array when there are no detections.
    """
    if isinstance(output, torch.Tensor):
        output = output.detach().cpu().numpy()

    rows = []
    for i, o in enumerate(output):
        if o is None or len(o) == 0:
            continue
        if isinstance(o, torch.Tensor):
            # a single device->host copy per image, instead of one .cpu() sync per scalar
            o = o.detach().cpu().numpy()
        o = np.asarray(o, dtype=np.float64)
        w = (o[:, 2] - o[:, 0]) / width
        h = (o[:, 3] - o[:, 1]) / height
        x = o[:, 0] / width + w / 2
        y = o[:, 1] / height + h / 2
        cls = o[:, 5].astype(int)  # truncates like int(), so class ids stay whole numbers
        batch_id = np.full(len(o), i, dtype=np.float64)
        rows.append(np.stack([batch_id, cls, x, y, w, h, o[:, 4]], axis=1))

    return np.concatenate(rows, axis=0) if rows else np.array([])





def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
    """Tile a batch of images into one mosaic and draw their boxes and file names.

    Args:
        images: (batch, channels, height, width) tensor or array. Values whose maximum in
            the first image is <= 1 are scaled to 0-255. The caller's data is not modified.
        targets: rows of [image_index, class, x, y, w, h] (labels) or
            [image_index, class, x, y, w, h, conf] (predictions), with xywh normalised to
            0-1 and converted for drawing via xywh2xyxy.
        paths: optional per-image paths; the file name (trimmed to 40 chars) is drawn per tile.
        fname: output file name; pass a falsy value to skip saving.
        names: optional class-id -> name lookup for box labels.
        max_size: tiles are downscaled so their longest side is at most this.
        max_subplots: maximum number of images tiled.

    Returns:
        The uint8 mosaic (resized so its longest side is at most ~1280 px when *fname* is set).
    """
    if isinstance(images, torch.Tensor):
        images = images.cpu().float().numpy()
    if isinstance(targets, torch.Tensor):
        targets = targets.cpu().numpy()

    # un-normalise into a NEW array: `images *= 255` would write through to the caller's
    # numpy array, or to a CPU tensor that shares memory with the `.numpy()` view above
    if np.max(images[0]) <= 1:
        images = images * 255

    tl = 3  # line thickness
    tf = max(tl - 1, 1)  # font thickness
    bs, _, h, w = images.shape  # batch size, _, height, width
    bs = min(bs, max_subplots)  # limit plot images
    ns = np.ceil(bs ** 0.5)  # number of subplots (square)

    # Check if we should resize
    scale_factor = max_size / max(h, w)
    if scale_factor < 1:
        h = math.ceil(scale_factor * h)
        w = math.ceil(scale_factor * w)

    colors = color_list()  # list of colors
    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
    for i, img in enumerate(images):
        if i == max_subplots:  # if last batch has fewer images than we expect
            break

        # tiles are filled column by column: i // ns picks the column, i % ns the row
        block_x = int(w * (i // ns))
        block_y = int(h * (i % ns))

        img = img.transpose(1, 2, 0)  # CHW -> HWC
        if scale_factor < 1:
            img = cv2.resize(img, (w, h))

        mosaic[block_y:block_y + h, block_x:block_x + w, :] = img
        if len(targets) > 0:
            image_targets = targets[targets[:, 0] == i]
            boxes = xywh2xyxy(image_targets[:, 2:6]).T
            classes = image_targets[:, 1].astype('int')
            labels = image_targets.shape[1] == 6  # labels if no conf column
            conf = None if labels else image_targets[:, 6]  # check for confidence presence (label vs pred)

            # normalised coordinates -> pixel coordinates inside this tile
            boxes[[0, 2]] *= w
            boxes[[0, 2]] += block_x
            boxes[[1, 3]] *= h
            boxes[[1, 3]] += block_y
            for j, box in enumerate(boxes.T):
                cls = int(classes[j])
                color = colors[cls % len(colors)]
                cls = names[cls] if names else cls
                if labels or conf[j] > 0.25:  # 0.25 conf thresh
                    label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j])
                    plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl)

        # Draw image filename labels
        if paths:
            label = Path(paths[i]).name[:40]  # trim to 40 char
            t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
            cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, tl / 3, [220, 220, 220], thickness=tf,
                        lineType=cv2.LINE_AA)

        # Image border
        cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3)

    if fname:
        r = min(1280. / max(h, w) / ns, 1.0)  # ratio to limit image size
        mosaic = cv2.resize(mosaic, (int(ns * w * r), int(ns * h * r)), interpolation=cv2.INTER_AREA)
        Image.fromarray(mosaic).save(fname)  # PIL save
    return mosaic





def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
    """Simulate *epochs* scheduler steps and save the resulting LR curve to save_dir/LR.png.

    Only the learning rate of the first param group is plotted. The learning rates of
    the caller's optimizer are restored afterwards.
    """
    # copy() is shallow: the copied optimizer and the copied scheduler (whose .optimizer
    # still points at the original) share the same param_group dicts. scheduler.step()
    # therefore writes 'lr' into the caller's optimizer, so remember and restore it.
    saved_lrs = [group['lr'] for group in optimizer.param_groups]
    optimizer, scheduler = copy(optimizer), copy(scheduler)
    y = []
    try:
        for _ in range(epochs):
            scheduler.step()
            y.append(optimizer.param_groups[0]['lr'])
    finally:
        for group, lr in zip(optimizer.param_groups, saved_lrs):
            group['lr'] = lr

    plt.plot(y, '.-', label='LR')
    plt.xlabel('epoch')
    plt.ylabel('LR')
    plt.grid()
    plt.xlim(0, epochs)
    plt.ylim(0)
    plt.tight_layout()
    plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
    plt.close()  # do not leave the LR figure as the current figure for later plots





def plot_test_txt():  # from utils.general import *; plot_test()
    """Plot histograms of box centres read from test.txt.

    The first four columns of each row are treated as x1, y1, x2, y2 and converted with
    xyxy2xywh. Writes hist2d.png (joint centre histogram) and hist1d.png (per-axis).
    """
    rows = np.loadtxt('test.txt', dtype=np.float32)
    xywh = xyxy2xywh(rows[:, :4])
    centre_x = xywh[:, 0]
    centre_y = xywh[:, 1]

    fig, axis = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
    axis.hist2d(centre_x, centre_y, bins=600, cmax=10, cmin=0)
    axis.set_aspect('equal')
    plt.savefig('hist2d.png', dpi=300)

    fig, (axis_x, axis_y) = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
    axis_x.hist(centre_x, bins=600)
    axis_y.hist(centre_y, bins=600)
    plt.savefig('hist1d.png', dpi=200)





def plot_targets_txt():  # from utils.general import *; plot_targets_txt()
    """Plot a histogram of each of the first four columns of targets.txt to targets.jpg.

    Each panel's legend shows the column mean +/- standard deviation.
    """
    columns = np.loadtxt('targets.txt', dtype=np.float32).T
    titles = ('x targets', 'y targets', 'width targets', 'height targets')
    fig, axes = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
    panels = axes.ravel()
    for k, title in enumerate(titles):
        values = columns[k]
        summary = '%.3g +/- %.3g' % (values.mean(), values.std())
        panels[k].hist(values, bins=100, label=summary)
        panels[k].legend()
        panels[k].set_title(title)
    plt.savefig('targets.jpg', dpi=200)





def plot_study_txt(f='study.txt', x=None):  # from utils.general import *; plot_study_txt()
    """Plot the speed/accuracy study files generated by test.py.

    Reads study/study_coco_yolo{s,m,l,x}.txt (columns 0-3 and 7-9). Saves:
      * a per-metric grid (P, R, mAP, timings) to *f* with its suffix replaced by .png;
      * an mAP-vs-latency comparison (with EfficientDet reference points) to
        study_mAP_latency.png.

    Args:
        f: study file name; only used to derive the metric-grid output name.
        x: optional x-axis values shared by all files; defaults to each file's row indices.
    """
    fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)
    ax = ax.ravel()
    fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)

    s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_inference (ms/img)', 't_NMS (ms/img)', 't_total (ms/img)']
    for study_file in ['study/study_coco_yolo%s.txt' % m for m in ['s', 'm', 'l', 'x']]:
        y = np.loadtxt(study_file, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
        # per-file x: do not overwrite the `x` argument, files may have different row counts
        xi = np.arange(y.shape[1]) if x is None else np.array(x)
        for i in range(7):
            ax[i].plot(xi, y[i], '.-', linewidth=2, markersize=8)
            ax[i].set_title(s[i])

        j = y[3].argmax() + 1  # plot up to (and including) the best mAP@.5:.95 point
        ax2.plot(y[6, :j], y[3, :j] * 1E2, '.-', linewidth=2, markersize=8,
                 label=Path(study_file).stem.replace('study_coco_', '').replace('yolo', 'YOLO'))

    ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
             'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')

    ax2.grid()
    ax2.set_xlim(0, 30)
    ax2.set_ylim(28, 50)
    ax2.set_yticks(np.arange(30, 55, 5))
    ax2.set_xlabel('GPU Speed (ms/img)')
    ax2.set_ylabel('COCO AP val')
    ax2.legend(loc='lower right')
    fig2.savefig('study_mAP_latency.png', dpi=300)
    # save the metric grid (previously never saved: plt.savefig targeted fig2 twice)
    fig.savefig(Path(f).with_suffix('.png'), dpi=300)
    plt.close(fig)
    plt.close(fig2)





def plot_labels(labels, save_dir=''):
    """Plot dataset label statistics into *save_dir*.

    Args:
        labels: 2-D array with the class id in column 0 and box values in columns 1-4
            (plotted as x, y, width, height).
        save_dir: output directory.

    Writes labels.png (class histogram plus density-colored x/y and width/height
    scatters). It also writes labels_correlogram.png when seaborn and pandas can be
    imported. A failure of that optional plot is reported, not raised.
    """
    if len(labels) == 0:
        # c.max() below raises on an empty array, and there is nothing to plot anyway
        print('WARNING: plot_labels() received no labels, skipping.')
        return

    c, b = labels[:, 0], labels[:, 1:].transpose()  # classes, boxes
    nc = int(c.max() + 1)  # number of classes

    fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
    ax = ax.ravel()
    ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)  # one bar per class id
    ax[0].set_xlabel('classes')
    ax[1].scatter(b[0], b[1], c=hist2d(b[0], b[1], 90), cmap='jet')
    ax[1].set_xlabel('x')
    ax[1].set_ylabel('y')
    ax[2].scatter(b[2], b[3], c=hist2d(b[2], b[3], 90), cmap='jet')
    ax[2].set_xlabel('width')
    ax[2].set_ylabel('height')
    plt.savefig(Path(save_dir) / 'labels.png', dpi=200)
    plt.close()

    # seaborn correlogram (optional dependencies)
    try:
        import seaborn as sns
        import pandas as pd
    except ImportError:
        return

    try:
        x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
        sns.pairplot(x, corner=True, diag_kind='hist', kind='scatter', markers='o',
                     plot_kws=dict(s=3, edgecolor=None, linewidth=1, alpha=0.02),
                     diag_kws=dict(bins=50))
        plt.savefig(Path(save_dir) / 'labels_correlogram.png', dpi=200)
    except Exception as e:
        # best effort: a failed correlogram must not abort the caller, but say why
        print('WARNING: labels correlogram not saved; %s' % e)
    finally:
        plt.close()  # also release the figure when plotting failed midway





def plot_evolution(yaml_file='data/hyp.finetune.yaml'):  # from utils.general import *; plot_evolution()
    """Plot hyperparameter-evolution results from evolve.txt into evolve.png.

    Column i + 7 of evolve.txt is plotted against fitness for the i-th key of *yaml_file*.
    The best (max-fitness) value of each hyperparameter is marked and printed.

    Args:
        yaml_file: hyperparameter YAML mapping whose key order matches evolve.txt columns 7+.
    """
    with open(yaml_file) as f:
        # hyperparameter files are plain mappings; safe_load cannot build arbitrary objects
        hyp = yaml.safe_load(f)
    x = np.loadtxt('evolve.txt', ndmin=2)
    f = fitness(x)
    # keep the original 6x5 grid, but add rows instead of crashing when there are > 30 keys
    rows = max(6, math.ceil(len(hyp) / 5))
    plt.figure(figsize=(10, 2 * rows), tight_layout=True)
    matplotlib.rc('font', **{'size': 8})
    for i, k in enumerate(hyp):
        y = x[:, i + 7]
        mu = y[f.argmax()]  # best single result
        plt.subplot(rows, 5, i + 1)
        plt.scatter(y, f, c=hist2d(y, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
        plt.plot(mu, f.max(), 'k+', markersize=15)
        plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9})
        if i % 5 != 0:
            plt.yticks([])  # only the first column of the grid keeps y ticks
        print('%15s: %.3g' % (k, mu))
    plt.savefig('evolve.png', dpi=200)
    plt.close()
    print('\nPlot saved as evolve.png')





def plot_results_overlay(start=0, stop=0):  # from utils.general import *; plot_results_overlay()
    """Plot each results*.txt with train and val curves overlaid, one PNG per file.

    Searches ./results*.txt and ../../Downloads/results*.txt. Each panel overlays column
    pair (i, i + 5) of the selected columns. The figure is saved next to its source file.

    Args:
        start: first row (epoch) to plot.
        stop: row to stop before; 0 means plot to the end.
    """
    s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95']  # legends
    t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1']  # titles
    for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
        results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
        n = results.shape[1]  # number of rows in the file
        x = range(start, min(stop, n) if stop else n)
        fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True)
        ax = ax.ravel()
        for i in range(5):
            for j in (i, i + 5):
                ax[i].plot(x, results[j, x], marker='.', label=s[j])
            ax[i].set_title(t[i])
            ax[i].legend()
        ax[0].set_ylabel(f)  # add filename
        fig.savefig(f.replace('.txt', '.png'), dpi=200)
        plt.close(fig)  # one figure per file: close it so figures do not accumulate





def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
    """Plot training results files into save_dir/results.png.

    Example: from utils.general import *; plot_results(save_dir='runs/train/exp0')

    Args:
        start: first row (epoch) to plot.
        stop: row to stop before; 0 means plot to the end.
        bucket: if set, '<id>.txt' files are first downloaded from gs://<bucket>/ with gsutil.
        id: numeric ids of the files to download from *bucket*.
        labels: optional legend label per file (defaults to the file stem).
        save_dir: directory searched for *.txt files (when no bucket) and written to.

    Zero values in the loss panels are hidden. A file that fails to load or plot is
    reported and skipped.
    """
    fig, ax = plt.subplots(2, 5, figsize=(12, 6))
    ax = ax.ravel()
    s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
         'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
    if bucket:
        import subprocess

        files = ['%g.txt' % x for x in id]
        urls = ['gs://%s/%g.txt' % (bucket, x) for x in id]
        if urls:
            # argument list, no shell: bucket names cannot inject shell commands
            try:
                subprocess.run(['gsutil', 'cp', *urls, '.'])
            except OSError as e:  # e.g. gsutil is not installed
                print('Warning: gsutil download failed; %s' % e)
    else:
        files = glob.glob(str(Path(save_dir) / '*.txt')) + glob.glob('../../Downloads/results*.txt')
    assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir)
    for fi, f in enumerate(files):
        try:
            results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
            n = results.shape[1]  # number of rows in the file
            x = range(start, min(stop, n) if stop else n)
            for i in range(10):
                y = results[i, x]  # fancy indexing returns a copy, results is untouched
                if i in [0, 1, 2, 5, 6, 7]:
                    y[y == 0] = np.nan  # don't show zero loss values
                label = labels[fi] if len(labels) else Path(f).stem
                ax[i].plot(x, y, marker='.', label=label, linewidth=1, markersize=6)
                ax[i].set_title(s[i])
        except Exception as e:
            print('Warning: Plotting error for %s; %s' % (f, e))

    fig.tight_layout()
    ax[1].legend()
    fig.savefig(Path(save_dir) / 'results.png', dpi=200)
    plt.close(fig)