import sys, cv2, time, os
from UI import Ui_TabWidget
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtWidgets import QFileDialog,QTabWidget
from PyQt5.QtCore import QTimer, QThread, pyqtSignal, Qt
from PyQt5.QtGui import QPixmap, QImage
from PyQt5.QtWidgets import QLabel,QWidget, QProgressBar
from py_util import read_show
from center_main import main
from opts2 import opts
import glob
from py_util import product_show
from opts2 import opts
from detectors.detector_factory import detector_factory
# Default detector configuration used by the validation and production-line tabs.
model_path = '/home/yangna/deepblue/2_MOT/CenterNet/exp/ctdet/dla/model_best.pth'
arch = 'dla_34'
task = 'ctdet'
# Build the argv list directly instead of formatting a string and splitting it
# on spaces: a model path containing a space would otherwise be torn apart.
opt = opts().init(['--task', task, '--load_model', model_path, '--arch', arch])
class mywindow(QTabWidget, Ui_TabWidget):
    '''Main window: data collection, training, validation and production-line tabs.

    Widget attributes used below (label11, label52, line51, combobox21, ...)
    are created by Ui_TabWidget.setupUi.
    '''

    def __init__(self):
        global imgnums
        super(mywindow, self).__init__()
        self.setupUi(self)
        # Training runs in a worker thread; its per-epoch signal drives the progress bar.
        # NOTE(review): this instance attribute shadows QObject.thread(); kept for
        # compatibility with any external code that reads window.thread.
        self.thread = train_thred()
        self.thread.my_signal.connect(self.set_step)
        path = r'/home/yangna/deepblue/2_MOT/CenterNet/data/pig/image/*.png'
        self.datas = glob.glob(path)
        # Module-level image count; not read anywhere else in this file.
        imgnums = len(self.datas)
        # Number of frames saved so far by collect_save_image (also the next file name).
        self.save_nums = 0

    def collect_image(self):
        '''Start the camera preview used for automatic image collection.

        The camera can only be displayed through a worker thread. If a preview
        is already running it is left alone, instead of opening camera 0 a
        second time and dropping the only reference to a live QThread.
        '''
        running = getattr(self, 'collect_image_thread', None)
        if running is not None and running.isRunning():
            return
        self.collect_image_thread = collect_image_thread()
        self.collect_image_thread.signal.connect(self.set_label)
        self.collect_image_thread.start()

    def collect_save_image(self):
        '''Save the frame shown in label53 to ./data/<line51 text>/image/<n>.jpg.

        On success the counter is incremented and shown in label52.
        Does nothing if no frame has been displayed yet or the save fails.
        '''
        pixmap = self.label53.pixmap()
        if pixmap is None or pixmap.isNull():
            return
        folder = f'./data/{self.line51.text()}/image'
        os.makedirs(folder, exist_ok=True)
        if not pixmap.save(f'{folder}/{self.save_nums}.jpg'):
            return
        self.save_nums += 1
        self.label52.setText('已采集图片: ' + str(self.save_nums))

    def set_label(self, image):
        '''Display the latest camera frame (a QImage) in label53.'''
        self.label53.setPixmap(QPixmap.fromImage(image))

    def choose_train(self):
        '''Ask for the training annotation file; store it in the global train_json.'''
        global train_json
        train_json, _ = QFileDialog.getOpenFileName(self,
                                                    '选择训练数据集',
                                                    "",
                                                    'All Files (*)')
        self.label11.setText(train_json)

    def choose_val(self):
        '''Ask for the validation annotation file; store it in the global val_json.'''
        global val_json
        val_json, _ = QFileDialog.getOpenFileName(self,
                                                  '选择验证数据集',
                                                  "",
                                                  'All Files (*)')
        self.label12.setText(val_json)

    def count_func(self):
        '''Start the training worker thread (no-op if it is already running).'''
        self.thread.start()

    def set_step(self, num):
        '''Update the training progress bar with the epoch index emitted by the worker.'''
        self.bar.setValue(num)

    def load_model(self):
        '''Build the detector from the module-level opt (debug output clamped to 0).'''
        opt.debug = min(opt.debug, 0)
        self.detector = detector_factory[opt.task](opt)

    def load_picture(self):
        '''Pick an image for the validation workflow and display it in label_21.

        Cancelling the dialog keeps the previously selected image.
        '''
        global imgname
        if self.pushbutton_22.text() == '选择图片':
            chosen, _ = QFileDialog.getOpenFileName(self,
                                                    '选择图片',
                                                    "",
                                                    'All Files (*)')
            if not chosen:
                return
            imgname = chosen
            read_show(imgname, self.label_21,
                      choose_id=self.combobox21.currentIndex() + 1)

    def test(self):
        '''Run the detector on the picked image (validation workflow).

        Loads the model on first use, like product_start does. Does nothing
        if no picture has been chosen yet.
        '''
        image_path = globals().get('imgname')
        if not image_path:
            print('please choose a picture first')
            return
        if not hasattr(self, 'detector'):
            self.load_model()
        read_show(image_path, self.label_21, self.detector,
                  choose_id=self.combobox21.currentIndex() + 1)

    def product_start(self):
        '''Start (or resume) the production-line stream on camera 0.

        The detector and the worker thread are created once and reused, so the
        model stays loaded across stop/start cycles.
        '''
        if not hasattr(self, 'detector'):
            self.load_model()
        if not hasattr(self, 'product_thread'):
            video_path = 0
            self.product_thread = product_thread(self.detector, video_path, self.combobox41)
            self.product_thread.mysignal.connect(self.product_cess)
        self.product_thread.start()

    def product_stop(self):
        '''Pause the production line and block until the worker thread has exited.'''
        if not hasattr(self, 'product_thread'):
            return
        self.product_thread.stop()
        self.product_thread.quit()
        self.product_thread.wait()

    def exit(self):
        sys.exit()

    def product_cess(self, image):
        '''Display a processed production-line frame in label41.'''
        self.label41.setPixmap(QPixmap.fromImage(image))
class collect_image_thread(QThread):
    '''Camera reader for the data-collection page.

    Reads frames from camera 0 in this worker thread, resizes them to
    1000x600, converts BGR to RGB and pushes each one to the UI thread as a
    QImage through `signal`. Saving frames and counting them is done by the UI.
    '''
    signal = pyqtSignal(QImage)

    def __init__(self):
        super(collect_image_thread, self).__init__()
        self.cap = cv2.VideoCapture(0)
        # Cleared by stop() to end the read loop.
        self.flag = 1

    def run(self):
        '''Stream frames until the camera closes, stop() is called, or a frame fails.'''
        self.flag = 1
        try:
            while self.cap.isOpened() and self.flag:
                try:
                    ret, frame = self.cap.read()
                    if not ret:
                        continue
                    img = cv2.resize(frame, (1000, 600))
                    h, w, c = img.shape
                    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    # QImage wraps rgb's buffer without copying; copy() detaches it so the
                    # queued cross-thread signal never refers to a freed numpy array.
                    image = QImage(rgb.data, w, h, c * w, QImage.Format_RGB888).copy()
                except Exception as err:
                    # The signal only carries QImage: emitting a str here would raise
                    # TypeError inside run() and abort the application.
                    print('something wrong with the input video source:', err)
                    break
                self.signal.emit(image)
        finally:
            self.cap.release()

    def stop(self):
        '''Ask the read loop to finish; run() releases the camera on exit.'''
        self.flag = 0
class product_thread(QThread):
    '''Production-line worker exposed like an API.

    The detector stays loaded for the lifetime of the thread object, while
    the video stream can be released by stop() and reopened by the next start().
    '''
    mysignal = pyqtSignal(QImage)

    def __init__(self, detector, video_path, combobox, frame_step=4):
        super(product_thread, self).__init__()
        self.flag = 1
        self.video_path = video_path
        self.cap = cv2.VideoCapture(video_path)
        self.detector = detector
        # Its current index (+1) selects choose_id for product_show.
        # NOTE(review): the widget is read from the worker thread.
        self.combobox = combobox
        # Run detection on one grabbed frame out of every `frame_step`.
        self.frame_step = max(1, int(frame_step))
        self.index = 0

    def run(self):
        '''Grab frames continuously and run detection on every `frame_step`-th one.'''
        self.flag = 1
        if not self.cap.isOpened():
            # The stream was released by a previous stop(): reopen it.
            self.cap = cv2.VideoCapture(self.video_path)
        try:
            while self.cap.isOpened() and self.flag:
                self.index = (self.index + 1) % self.frame_step
                try:
                    # grab() on every iteration keeps the stream current; only the
                    # selected frames pay for retrieve() and detection.
                    grabbed = self.cap.grab()
                    if not grabbed or self.index != 0:
                        continue
                    ok, frame = self.cap.retrieve()
                    if not ok:
                        continue
                    image = product_show(frame, self.detector,
                                         choose_id=self.combobox.currentIndex() + 1)
                    self.mysignal.emit(image)
                except Exception as err:
                    print('something wrong with the product_thread:', err)
        finally:
            # Released here, on the thread that uses the capture, to avoid racing grab().
            self.cap.release()

    def stop(self):
        '''Ask the loop to finish; the capture is released when run() exits.'''
        self.flag = 0
        if not self.isRunning():
            self.cap.release()
class train_thred(QThread):
    '''Runs CenterNet training in a worker thread and reports progress via my_signal.'''
    my_signal = pyqtSignal(int)

    def __init__(self, max_iter=50):
        super(train_thred, self).__init__()
        # Number of epochs to train.
        self.max_iter = max_iter

    def run(self):
        '''Train on the datasets chosen in the UI (globals train_json / val_json).

        Emits each epoch index just before training that epoch. Returns
        immediately if either dataset has not been chosen yet.
        '''
        train_file = globals().get('train_json')
        val_file = globals().get('val_json')
        if not train_file or not val_file:
            # An unhandled NameError inside run() would abort the whole app.
            print('please choose both the training and validation datasets first')
            return
        train_opt = opts(train_file, val_file).parse()
        center_train = main(train_opt)
        try:
            for epoch in range(self.max_iter):
                self.my_signal.emit(epoch)
                center_train.train(epoch)
        finally:
            center_train.logger.close()
if __name__ == '__main__':
    # Create the Qt application, show the main window and hand the event
    # loop's exit status back to the shell.
    qt_app = QtWidgets.QApplication(sys.argv)
    main_window = mywindow()
    main_window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)