feat: vision
This commit is contained in:
parent
cb8f65aacf
commit
384e2f9851
9
agent.py
9
agent.py
|
@ -6,16 +6,11 @@ import sys
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
import cv2
|
import cv2
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from control import ManualControl
|
from control import ManualControl, AutoControl
|
||||||
|
from control import Status
|
||||||
|
|
||||||
|
|
||||||
class Status(Enum):
|
|
||||||
INIT = 0
|
|
||||||
DETECT = 1
|
|
||||||
TRACK = 2
|
|
||||||
LANDING = 3
|
|
||||||
|
|
||||||
|
|
||||||
class UAV:
|
class UAV:
|
||||||
|
|
27
control.py
27
control.py
|
@ -1,10 +1,21 @@
|
||||||
from pynput import keyboard
|
from pynput import keyboard
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
from vision import Tracker
|
||||||
|
import cv2
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
TICKS_PER_SEC = 30
|
TICKS_PER_SEC = 30
|
||||||
|
|
||||||
|
|
||||||
|
class Status(Enum):
|
||||||
|
INIT = 0
|
||||||
|
DETECT = 1
|
||||||
|
TRACK = 2
|
||||||
|
LANDING = 3
|
||||||
|
|
||||||
class PIDController:
|
class PIDController:
|
||||||
def __init__(self, kp, ki, kd, current_time):
|
def __init__(self, kp, ki, kd, current_time):
|
||||||
self.kp = kp
|
self.kp = kp
|
||||||
|
@ -81,7 +92,21 @@ class AutoControl:
|
||||||
self.pid = PIDController(1, 0, 0, time.time())
|
self.pid = PIDController(1, 0, 0, time.time())
|
||||||
|
|
||||||
def detect(self):
|
def detect(self):
|
||||||
pass
|
front_img = self.uav.frame_queue.get()
|
||||||
|
cv2.imshow("ROI select", front_img[:, :, 0:3])
|
||||||
|
self.gROI = cv2.selectROI("ROI select", front_img[:, :, 0:3], False)
|
||||||
|
if (not self.gROI):
|
||||||
|
print("空框选,退出")
|
||||||
|
quit()
|
||||||
|
self.gTracker = Tracker(tracker_type="KCF")
|
||||||
|
self.gTracker.initWorking(front_img[:, :, 0:3], self.gROI)
|
||||||
|
print("start tracking")
|
||||||
|
self.uav.status = Status.TRACKING
|
||||||
|
cv2.destroyWindow("ROI select")
|
||||||
|
|
||||||
|
def track(self):
|
||||||
|
front_img = self.uav.frame_queue.get()
|
||||||
|
|
||||||
|
|
||||||
class FakeUAV:
|
class FakeUAV:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
|
|
@ -0,0 +1,146 @@
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
class MessageItem(object):
|
||||||
|
# 用于封装信息的类,包含图片和其他信息
|
||||||
|
def __init__(self,frame,message):
|
||||||
|
self._frame = frame
|
||||||
|
self._message = message
|
||||||
|
|
||||||
|
def getFrame(self):
|
||||||
|
# 图片信息
|
||||||
|
return self._frame
|
||||||
|
|
||||||
|
def getMessage(self):
|
||||||
|
#文字信息,json格式
|
||||||
|
return self._message
|
||||||
|
|
||||||
|
|
||||||
|
class Tracker(object):
|
||||||
|
'''
|
||||||
|
追踪者模块,用于追踪指定目标
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, tracker_type="BOOSTING", draw_coord=True):
|
||||||
|
'''
|
||||||
|
初始化追踪器种类
|
||||||
|
'''
|
||||||
|
# 获得opencv版本
|
||||||
|
(major_ver, minor_ver, subminor_ver) = (cv2.__version__).split('.')
|
||||||
|
self.tracker_types = ['BOOSTING', 'MIL', 'KCF', 'TLD', 'MEDIANFLOW', 'GOTURN']
|
||||||
|
self.tracker_type = tracker_type
|
||||||
|
self.isWorking = False
|
||||||
|
self.draw_coord = draw_coord
|
||||||
|
# 构造追踪器
|
||||||
|
if int(major_ver) < 3:
|
||||||
|
self.tracker = cv2.Tracker_create(tracker_type)
|
||||||
|
else:
|
||||||
|
if tracker_type == 'BOOSTING':
|
||||||
|
self.tracker = cv2.TrackerBoosting_create()
|
||||||
|
if tracker_type == 'MIL':
|
||||||
|
self.tracker = cv2.TrackerMIL_create()
|
||||||
|
if tracker_type == 'KCF':
|
||||||
|
self.tracker = cv2.TrackerKCF_create()
|
||||||
|
if tracker_type == 'TLD':
|
||||||
|
self.tracker = cv2.TrackerTLD_create()
|
||||||
|
if tracker_type == 'MEDIANFLOW':
|
||||||
|
self.tracker = cv2.TrackerMedianFlow_create()
|
||||||
|
if tracker_type == 'GOTURN':
|
||||||
|
self.tracker = cv2.TrackerGOTURN_create()
|
||||||
|
|
||||||
|
def initWorking(self, frame, box):
|
||||||
|
'''
|
||||||
|
追踪器工作初始化
|
||||||
|
frame:初始化追踪画面
|
||||||
|
box:追踪的区域
|
||||||
|
'''
|
||||||
|
if not self.tracker:
|
||||||
|
raise Exception("追踪器未初始化")
|
||||||
|
status = self.tracker.init(frame, box)
|
||||||
|
# if not status:
|
||||||
|
# raise Exception("追踪器工作初始化失败")
|
||||||
|
self.coord = box
|
||||||
|
self.isWorking = True
|
||||||
|
|
||||||
|
def track(self, frame):
|
||||||
|
'''
|
||||||
|
开启追踪
|
||||||
|
'''
|
||||||
|
message = None
|
||||||
|
if self.isWorking:
|
||||||
|
status, self.coord = self.tracker.update(frame)
|
||||||
|
if status:
|
||||||
|
message = {"coord": [((int(self.coord[0]), int(self.coord[1])),
|
||||||
|
(int(self.coord[0] + self.coord[2]), int(self.coord[1] + self.coord[3])))]}
|
||||||
|
if self.draw_coord:
|
||||||
|
p1 = (int(self.coord[0]), int(self.coord[1]))
|
||||||
|
p2 = (int(self.coord[0] + self.coord[2]), int(self.coord[1] + self.coord[3]))
|
||||||
|
center = (int(self.coord[0] + self.coord[2] / 2), int(self.coord[1] + self.coord[3] / 2))
|
||||||
|
cv2.circle(frame, center, 2, (0, 0, 255), 2)
|
||||||
|
cv2.rectangle(frame, p1, p2, (255, 0, 0), 2, 1)
|
||||||
|
message['msg'] = "is tracking"
|
||||||
|
message['target'] = center
|
||||||
|
return MessageItem(frame, message)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
# 初始化视频捕获设备
|
||||||
|
gVideoDevice = cv2.VideoCapture(0)
|
||||||
|
gCapStatus, gFrame = gVideoDevice.read()
|
||||||
|
|
||||||
|
# 选择 框选帧
|
||||||
|
print("按 n 选择下一帧,按 y 选取当前帧")
|
||||||
|
while True:
|
||||||
|
if (gCapStatus == False):
|
||||||
|
print("捕获帧失败")
|
||||||
|
quit()
|
||||||
|
|
||||||
|
_key = cv2.waitKey(0) & 0xFF
|
||||||
|
if(_key == ord('n')):
|
||||||
|
gCapStatus,gFrame = gVideoDevice.read()
|
||||||
|
if(_key == ord('y')):
|
||||||
|
break
|
||||||
|
|
||||||
|
cv2.imshow("pick frame",gFrame)
|
||||||
|
|
||||||
|
# 框选感兴趣区域region of interest
|
||||||
|
cv2.destroyWindow("pick frame")
|
||||||
|
gROI = cv2.selectROI("ROI frame",gFrame,False)
|
||||||
|
if (not gROI):
|
||||||
|
print("空框选,退出")
|
||||||
|
quit()
|
||||||
|
|
||||||
|
# 初始化追踪器
|
||||||
|
gTracker = Tracker(tracker_type="KCF")
|
||||||
|
gTracker.initWorking(gFrame,gROI)
|
||||||
|
|
||||||
|
# 循环帧读取,开始跟踪
|
||||||
|
while True:
|
||||||
|
gCapStatus, gFrame = gVideoDevice.read()
|
||||||
|
if(gCapStatus):
|
||||||
|
# 展示跟踪图片
|
||||||
|
print(gFrame)
|
||||||
|
_item = gTracker.track(gFrame)
|
||||||
|
cv2.imshow("track result",_item.getFrame())
|
||||||
|
|
||||||
|
if _item.getMessage():
|
||||||
|
# 打印跟踪数据
|
||||||
|
print(_item.getMessage())
|
||||||
|
else:
|
||||||
|
# 丢失,重新用初始ROI初始
|
||||||
|
print("丢失,重新使用初始ROI开始")
|
||||||
|
gTracker = Tracker(tracker_type="KCF")
|
||||||
|
gTracker.initWorking(gFrame, gROI)
|
||||||
|
|
||||||
|
_key = cv2.waitKey(1) & 0xFF
|
||||||
|
if (_key == ord('q')) | (_key == 27):
|
||||||
|
break
|
||||||
|
if (_key == ord('r')) :
|
||||||
|
# 用户请求用初始ROI
|
||||||
|
print("用户请求用初始ROI")
|
||||||
|
gTracker = Tracker(tracker_type="KCF")
|
||||||
|
gTracker.initWorking(gFrame, gROI)
|
||||||
|
|
||||||
|
else:
|
||||||
|
print("捕获帧失败")
|
||||||
|
quit()
|
Loading…
Reference in New Issue