feat: vision

This commit is contained in:
raiots 2024-06-16 20:16:27 +08:00
parent cb8f65aacf
commit 384e2f9851
3 changed files with 174 additions and 8 deletions

View File

@ -6,16 +6,11 @@ import sys
import time
import logging
import cv2
from enum import Enum
from control import ManualControl
from control import ManualControl, AutoControl
from control import Status
class Status(Enum):
INIT = 0
DETECT = 1
TRACK = 2
LANDING = 3
class UAV:

View File

@ -1,10 +1,21 @@
from pynput import keyboard
import time
from vision import Tracker
import cv2
from enum import Enum
TICKS_PER_SEC = 30
class Status(Enum):
INIT = 0
DETECT = 1
TRACK = 2
LANDING = 3
class PIDController:
def __init__(self, kp, ki, kd, current_time):
self.kp = kp
@ -81,7 +92,21 @@ class AutoControl:
self.pid = PIDController(1, 0, 0, time.time())
def detect(self):
pass
front_img = self.uav.frame_queue.get()
cv2.imshow("ROI select", front_img[:, :, 0:3])
self.gROI = cv2.selectROI("ROI select", front_img[:, :, 0:3], False)
if (not self.gROI):
print("空框选,退出")
quit()
self.gTracker = Tracker(tracker_type="KCF")
self.gTracker.initWorking(front_img[:, :, 0:3], self.gROI)
print("start tracking")
self.uav.status = Status.TRACKING
cv2.destroyWindow("ROI select")
def track(self):
front_img = self.uav.frame_queue.get()
class FakeUAV:
def __init__(self) -> None:

146
vision.py Normal file
View File

@ -0,0 +1,146 @@
import cv2
class MessageItem(object):
# 用于封装信息的类,包含图片和其他信息
def __init__(self,frame,message):
self._frame = frame
self._message = message
def getFrame(self):
# 图片信息
return self._frame
def getMessage(self):
#文字信息,json格式
return self._message
class Tracker(object):
'''
追踪者模块,用于追踪指定目标
'''
def __init__(self, tracker_type="BOOSTING", draw_coord=True):
'''
初始化追踪器种类
'''
# 获得opencv版本
(major_ver, minor_ver, subminor_ver) = (cv2.__version__).split('.')
self.tracker_types = ['BOOSTING', 'MIL', 'KCF', 'TLD', 'MEDIANFLOW', 'GOTURN']
self.tracker_type = tracker_type
self.isWorking = False
self.draw_coord = draw_coord
# 构造追踪器
if int(major_ver) < 3:
self.tracker = cv2.Tracker_create(tracker_type)
else:
if tracker_type == 'BOOSTING':
self.tracker = cv2.TrackerBoosting_create()
if tracker_type == 'MIL':
self.tracker = cv2.TrackerMIL_create()
if tracker_type == 'KCF':
self.tracker = cv2.TrackerKCF_create()
if tracker_type == 'TLD':
self.tracker = cv2.TrackerTLD_create()
if tracker_type == 'MEDIANFLOW':
self.tracker = cv2.TrackerMedianFlow_create()
if tracker_type == 'GOTURN':
self.tracker = cv2.TrackerGOTURN_create()
def initWorking(self, frame, box):
'''
追踪器工作初始化
frame:初始化追踪画面
box:追踪的区域
'''
if not self.tracker:
raise Exception("追踪器未初始化")
status = self.tracker.init(frame, box)
# if not status:
# raise Exception("追踪器工作初始化失败")
self.coord = box
self.isWorking = True
def track(self, frame):
'''
开启追踪
'''
message = None
if self.isWorking:
status, self.coord = self.tracker.update(frame)
if status:
message = {"coord": [((int(self.coord[0]), int(self.coord[1])),
(int(self.coord[0] + self.coord[2]), int(self.coord[1] + self.coord[3])))]}
if self.draw_coord:
p1 = (int(self.coord[0]), int(self.coord[1]))
p2 = (int(self.coord[0] + self.coord[2]), int(self.coord[1] + self.coord[3]))
center = (int(self.coord[0] + self.coord[2] / 2), int(self.coord[1] + self.coord[3] / 2))
cv2.circle(frame, center, 2, (0, 0, 255), 2)
cv2.rectangle(frame, p1, p2, (255, 0, 0), 2, 1)
message['msg'] = "is tracking"
message['target'] = center
return MessageItem(frame, message)
if __name__ == '__main__':
# 初始化视频捕获设备
gVideoDevice = cv2.VideoCapture(0)
gCapStatus, gFrame = gVideoDevice.read()
# 选择 框选帧
print("按 n 选择下一帧,按 y 选取当前帧")
while True:
if (gCapStatus == False):
print("捕获帧失败")
quit()
_key = cv2.waitKey(0) & 0xFF
if(_key == ord('n')):
gCapStatus,gFrame = gVideoDevice.read()
if(_key == ord('y')):
break
cv2.imshow("pick frame",gFrame)
# 框选感兴趣区域region of interest
cv2.destroyWindow("pick frame")
gROI = cv2.selectROI("ROI frame",gFrame,False)
if (not gROI):
print("空框选,退出")
quit()
# 初始化追踪器
gTracker = Tracker(tracker_type="KCF")
gTracker.initWorking(gFrame,gROI)
# 循环帧读取,开始跟踪
while True:
gCapStatus, gFrame = gVideoDevice.read()
if(gCapStatus):
# 展示跟踪图片
print(gFrame)
_item = gTracker.track(gFrame)
cv2.imshow("track result",_item.getFrame())
if _item.getMessage():
# 打印跟踪数据
print(_item.getMessage())
else:
# 丢失重新用初始ROI初始
print("丢失重新使用初始ROI开始")
gTracker = Tracker(tracker_type="KCF")
gTracker.initWorking(gFrame, gROI)
_key = cv2.waitKey(1) & 0xFF
if (_key == ord('q')) | (_key == 27):
break
if (_key == ord('r')) :
# 用户请求用初始ROI
print("用户请求用初始ROI")
gTracker = Tracker(tracker_type="KCF")
gTracker.initWorking(gFrame, gROI)
else:
print("捕获帧失败")
quit()