Compare commits

...

10 Commits

Author SHA1 Message Date
raiot 318f710ef8 Merge branch 'main' of https://github.com/raiots/USVLander 2024-06-29 19:22:46 +08:00
raiot 68ff57d137 doc 2024-06-29 19:22:40 +08:00
raiot 9770c5528c doc 2024-06-29 19:22:11 +08:00
raiot 1d62a017af doc 2024-06-29 19:21:55 +08:00
raiot 4e67b0c9d2 feat: requirements 2024-06-29 18:59:35 +08:00
raiots ba7f14f654 fix: bug on landing 2024-06-24 02:04:55 +08:00
raiot 6b13f0337c fix 2024-06-24 00:23:34 +08:00
raiots f14adf0b65 fix: mod 2024-06-23 17:11:22 +08:00
raiots 01b4131b95 fix: data record bug 2024-06-23 15:28:14 +08:00
raiot 023793651d feat: first demo 2024-06-23 14:57:02 +08:00
12 changed files with 500 additions and 35 deletions

197
.gitignore vendored Normal file
View File

@ -0,0 +1,197 @@
# Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode
# Edit at https://www.toptal.com/developers/gitignore?templates=python,visualstudiocode
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml
# ruff
.ruff_cache/
# LSP config files
pyrightconfig.json
### VisualStudioCode ###
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets
# Local History for Visual Studio Code
.history/
# Built Visual Studio Code Extensions
*.vsix
### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide
data/
# End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode

29
README.md Normal file
View File

@ -0,0 +1,29 @@
# USVLander
## 项目介绍
USVLander 是一个完全基于计算机视觉的无人机控制算法,目前针对 DJI Tello 无人机设计,能够实现视觉自主降落至无人船。该项目基于 Python 编程语言,并集成了 YOLO 实时目标检测系统和 OpenCV 计算机视觉库,以提供精确的导航和控制。
## 安装说明
要安装 USVLander您需要确保您的系统中已安装 Python 3.10 或更高版本(开发过程基于 Python3.10.8 )。此外,您还需要安装 YOLO 和 OpenCV 库。可以通过以下命令安装所需的依赖项:
```bash
pip install -r requirements.txt
```
## 如何使用
在安装所有必要的依赖项后,您可以通过以下步骤使用 USVLander 控制算法:
1. 将无人机与您的计算机连接。
2. 打开终端并导航到 USVLander 项目目录。
3. 运行主控制脚本:
```bash
python main.py
```
## License
USVLander 根据 MIT 许可证发布。这意味着您可以自由地使用、修改和分发该软件,但您必须包含原始作者的版权声明和许可声明。
请注意,此 README 是一个基本模板,您可能需要根据项目的具体情况进行调整。确保在发布之前,您已经完全测试了所有的功能,并且代码的文档是最新的。祝您的项目成功!
![alt text](docs\Draft_process.png)

Binary file not shown.

After

Width:  |  Height:  |  Size: 45 KiB

View File

@ -6,9 +6,9 @@ import sys
import time import time
import logging import logging
import cv2 import cv2
import numpy as np
from control import ManualControl, AutoControl from control import ManualControl, AutoControl
from control import Status from utils import Status, DataRecorder
@ -20,6 +20,11 @@ class UAV:
self.tello.connect() self.tello.connect()
print('Current_battery: ' + str(self.tello.get_battery())) print('Current_battery: ' + str(self.tello.get_battery()))
self.tello.streamon() self.tello.streamon()
self.tello.set_video_resolution(Tello.RESOLUTION_480P)
self.tello.set_video_fps(Tello.FPS_30)
self.tello.set_video_bitrate(Tello.BITRATE_1MBPS)
self.tello.set_video_direction(Tello.CAMERA_FORWARD)
self.frame_read = self.tello.get_frame_read() self.frame_read = self.tello.get_frame_read()
self.left_right_velocity = 0 self.left_right_velocity = 0
@ -33,51 +38,77 @@ class UAV:
self.status = Status.DETECT self.status = Status.DETECT
self.frame_queue = Queue(maxsize=1) # 用于存放视频帧, 1用于避免帧堆积 self.frame_queue = Queue(maxsize=1) # 用于存放视频帧, 1用于避免帧堆积
record_name = input("请输入数据备注(英文):")
if record_name:
self.recorder = DataRecorder(exp_note=record_name, need_record=True)
else:
self.recorder = DataRecorder()
self.front_cam_shape = (648, 478) # x, y
def update(self): def update(self):
""" Update routine. Send velocities to Tello. """ Update routine. Send velocities to Tello.
向Tello发送各方向速度信息 向Tello发送各方向速度信息
""" """
if self.yaw_velocity > 50:
self.yaw_velocity = 50
elif self.yaw_velocity < -50:
self.yaw_velocity = -50
if self.send_rc_control: if self.send_rc_control:
self.tello.send_rc_control(self.left_right_velocity, self.for_back_velocity, self.tello.send_rc_control(self.left_right_velocity, self.for_back_velocity,
self.up_down_velocity, self.yaw_velocity) self.up_down_velocity, self.yaw_velocity)
print(self.left_right_velocity, self.for_back_velocity,
self.up_down_velocity, self.yaw_velocity)
def run(self): def run(self):
""" 主程序 """ """ 主程序 """
# self.init_flight() self.init_flight()
# self.tello.turn_motor_on()
state_thread = Thread(target=self.video_stream) state_thread = Thread(target=self.video_stream)
state_thread.start() state_thread.start()
# self.send_rc_control = True self.manual_control.start() # 启动手动控制键盘监听
# self.manual_control.start() # 启动手动控制键盘监听
while True: while True:
# self.up_down_velocity(10) # self.up_down_velocity(10)
# self.send_rc_control = True self.auto_track_land()
self.auto_control()
self.update() self.update()
self.video_display() # self.video_display()
time.sleep(0.05) time.sleep(0.05)
def signal_handler(self, signal, frame): def signal_handler(self, signal, frame):
print ('\nSignal Catched! 执行急停!') print ('\nSignal Catched! 执行急停!')
self.tello.emergency() self.tello.emergency()
self.tello.streamoff()
sys.exit(0) sys.exit(0)
def init_flight(self): def init_flight(self):
self.tello.turn_motor_on() self.tello.turn_motor_on()
time.sleep(10) time.sleep(5)
logging.warn('自检完毕,可以起飞') logging.warn('自检完毕,可以起飞')
self.tello.takeoff() self.tello.takeoff()
self.send_rc_control = True # 开启发送控制信号
self.update()
self.tello.go_xyz_speed(0, 0, 50, 20) # 初始升高高度
def video_stream(self): def video_stream(self):
# 独立线程,用于获取视频流并记录状态
height, width, _ = self.frame_read.frame.shape height, width, _ = self.frame_read.frame.shape
while True: while True:
frame = self.frame_read.frame frame = self.frame_read.frame
self.frame_queue.put(frame) self.frame_queue.put(frame)
states = self.tello.get_current_state()
states['timestamp'] = time.time()
states['status'] = self.status
states['cmd_vel'] = [self.left_right_velocity, self.for_back_velocity, self.up_down_velocity, self.yaw_velocity]
self.recorder.state_record(states)
self.recorder.frame_record(frame, 'origin')
time.sleep(1 / 30) time.sleep(1 / 30)
def video_display(self): def video_display(self):
@ -88,20 +119,31 @@ class UAV:
if cv2.waitKey(1) & 0xFF == ord('q'): if cv2.waitKey(1) & 0xFF == ord('q'):
self.keep_recording = False self.keep_recording = False
cv2.destroyAllWindows() cv2.destroyAllWindows()
if cv2.waitKey(1) & 0xFF == ord('s'):
cv2.imwrite(str(time.time()) + '.jpg', frame)
def auto_track_land(self): def auto_track_land(self):
# print(self.status)
if self.status == Status.INIT or self.status == Status.TRACKING: if self.status == Status.TRACK or self.status == Status.DETECT:
front_img = self.frame_queue.get() front_img = self.frame_queue.get()
front_img = cv2.cvtColor(front_img, cv2.COLOR_RGB2BGR)
if self.status == Status.DETECT: if self.status == Status.DETECT:
self.auto_control.detect(front_img=front_img) self.auto_control.detect(front_img=front_img)
elif self.status == Status.TRACK: elif self.status == Status.TRACK:
self.auto_control.track(front_img=front_img) self.auto_control.track(front_img=front_img)
elif self.status == Status.LANDING:
bottom_img = self.frame_queue.get()
self.auto_control.landing(bottom_img=bottom_img)
elif self.status == Status.FALLING:
# 自由落体阶段
pass
time.sleep(1 / 30)
if __name__ == '__main__': if __name__ == '__main__':
uav = UAV() uav = UAV()
# while True:
# print(uav.tello.get_current_state())
uav.run() uav.run()

View File

@ -1,20 +1,18 @@
from pynput import keyboard from pynput import keyboard
import time import time
import logging
from vision import Tracker from vision import Tracker
import cv2 import cv2
from enum import Enum from utils import Status
import pupil_apriltags as apriltag
from djitellopy import Tello
TICKS_PER_SEC = 30 TICKS_PER_SEC = 30
class Status(Enum):
INIT = 0
DETECT = 1
TRACK = 2
LANDING = 3
class PIDController: class PIDController:
def __init__(self, kp, ki, kd, current_time): def __init__(self, kp, ki, kd, current_time):
@ -25,7 +23,7 @@ class PIDController:
self.integral = 0 self.integral = 0
self.last_time = current_time self.last_time = current_time
def control(self, error, current_time): def control(self, error, current_time) -> int:
current_time = current_time current_time = current_time
delta_time = current_time - self.last_time delta_time = current_time - self.last_time
if delta_time == 0: if delta_time == 0:
@ -38,7 +36,7 @@ class PIDController:
self.previous_error = error self.previous_error = error
self.last_time = current_time self.last_time = current_time
return output return int(output)
class ManualControl: class ManualControl:
@ -63,10 +61,12 @@ class ManualControl:
self.uav.up_down_velocity = 50 self.uav.up_down_velocity = 50
elif key.char == 's': elif key.char == 's':
self.uav.up_down_velocity = -50 self.uav.up_down_velocity = -50
elif key.char == 'q': elif key.char == 'z':
self.uav.yaw_velocity = -50 self.uav.yaw_velocity = -50
elif key.char == 'e': elif key.char == 'c':
self.uav.yaw_velocity = 50 self.uav.yaw_velocity = 50
elif key.char == 'l':
self.uav.tello.land()
# print(f'left_right_velocity: {self.uav.left_right_velocity}, for_back_velocity: {self.uav.for_back_velocity}, up_down_velocity: {self.uav.up_down_velocity}, yaw_velocity: {self.uav.yaw_velocity}') # print(f'left_right_velocity: {self.uav.left_right_velocity}, for_back_velocity: {self.uav.for_back_velocity}, up_down_velocity: {self.uav.up_down_velocity}, yaw_velocity: {self.uav.yaw_velocity}')
except AttributeError: except AttributeError:
pass pass
@ -79,7 +79,7 @@ class ManualControl:
self.uav.left_right_velocity = 0 self.uav.left_right_velocity = 0
elif key.char == 'w' or key.char == 's': elif key.char == 'w' or key.char == 's':
self.uav.up_down_velocity = 0 self.uav.up_down_velocity = 0
elif key.char == 'q' or key.char == 'e': elif key.char == 'z' or key.char == 'c':
self.uav.yaw_velocity = 0 self.uav.yaw_velocity = 0
except AttributeError: except AttributeError:
pass pass
@ -89,27 +89,132 @@ class AutoControl:
def __init__(self, uav): def __init__(self, uav):
self.uav = uav self.uav = uav
self.target = [0, 0, 0, 0] self.target = [0, 0, 0, 0]
self.pid = PIDController(1, 0, 0, time.time()) self.tag_detector = apriltag.Detector(families='tag36h11')
self.yaw_controller = PIDController(0.3, 0, 0.1, time.time())
self.for_back_controler = PIDController(100, 0, 0.1, time.time())
self.left_right_controler = PIDController(50, 0, 0.1, time.time()) # 摄像头左右视野较小
self.loss_target_times = 0
self.is_get_tag = False
self.get_tag_times = 0
def detect(self, front_img): def detect(self, front_img):
cv2.imshow("ROI select", front_img[:, :, 0:3]) cv2.imshow("ROI select", front_img)
self.gROI = cv2.selectROI("ROI select", front_img[:, :, 0:3], False)
# 在 selectROI 前等待按键(可选)
key = cv2.waitKey(10) # 等待按键
if key == ord('q'): # 如果按下 'q' 键
print("Quit selected before ROI.")
self.uav.tello.land()
cv2.destroyAllWindows()
exit()
self.gROI = cv2.selectROI("ROI select", front_img, False)
if (not self.gROI): if (not self.gROI):
print("空框选,退出") print("空框选,退出")
quit() self.uav.tello.land()
self.gTracker = Tracker(tracker_type="KCF") self.gTracker = Tracker(tracker_type="KCF")
self.gTracker.initWorking(front_img[:, :, 0:3], self.gROI) self.gTracker.initWorking(front_img, self.gROI)
print("start tracking") logging.info("start tracking")
self.uav.status = Status.TRACK self.uav.status = Status.TRACK
cv2.destroyWindow("ROI select") cv2.destroyWindow("ROI select")
def track(self, front_img): def track(self, front_img):
_item = self.gTracker.track(front_img) _item = self.gTracker.track(front_img)
if _item.message: if _item.getMessage():
self.loss_target_times = 0
logging.info('tracking')
self.target = _item.getMessage()['target'] self.target = _item.getMessage()['target']
# print(self.target)
target_dict = {'timestamp': time.time(), 'x': self.target[0], 'y':self.target[1]}
self.uav.recorder.anything_record(target_dict, 'track_result', ['timestamp', 'x', 'y'])
frame = _item.getFrame()
cv2.imshow("track result", frame)
self.uav.recorder.frame_record(frame, 'KCF_tracking')
# store the frame to video
_key = cv2.waitKey(1) & 0xFF
if (_key == ord('q')) | (_key == 27):
self.uav.status = Status.LANDING
if (_key == ord('r')) :
self.uav.status = Status.DETECT
# print(_item.getFrame())
# print(self.target)
yaw_vel = self.yaw_controller.control(self.target[0] - self.uav.front_cam_shape[0]/2, time.time())
self.uav.yaw_velocity = int(yaw_vel)
self.uav.for_back_velocity = 20
else: else:
return None self.loss_target_times += 1
if self.loss_target_times == 6:
# 等待6帧0.2秒),仍丢失则重新检测并重置计数器
self.loss_target_times = 0 # 在函数头部还有一个
self.uav.status = Status.DETECT
self.uav.yaw_velocity = 0
self.uav.for_back_velocity = 0
if self.target[1] > self.uav.front_cam_shape[1] * 0.8:
print(self.target[1], self.uav.front_cam_shape[0]/10)
logging.warn("ready to land")
self.uav.yaw_velocity = 0
self.uav.for_back_velocity = 10
self.uav.status = Status.LANDING
self.uav.tello.set_video_direction(Tello.CAMERA_DOWNWARD)
cv2.destroyAllWindows()
time.sleep(1) # 等待1秒确保摄像头回传到位
def landing(self, bottom_img):
tag_large = self.tag_detector.detect(bottom_img[:, :, 0], estimate_tag_pose=True, camera_params=([353.836278849487, 353.077028029136, 163.745870517989, 115.130883974855]), tag_size=0.169)
tag_small = self.tag_detector.detect(bottom_img[:, :, 0], estimate_tag_pose=True, camera_params=([353.836278849487, 353.077028029136, 163.745870517989, 115.130883974855]), tag_size=0.018)
tag_list = []
for tag in tag_large:
if tag.tag_id == 1:
tag_list.append(tag)
for tag in tag_small:
if tag.tag_id == 0:
tag_list.append(tag)
if tag_list:
# 如果检测到标签
self.get_tag_times += 1
if self.get_tag_times >= 5:
# 检测到5次后才认为大部分tag已经进入视野
self.is_get_tag = True
for_vel = -1 * self.for_back_controler.control(tag_list[0].pose_t[0], time.time())
self.uav.for_back_velocity = int(for_vel)
left_vel = -1 * self.left_right_controler.control(tag_list[0].pose_t[1], time.time())
self.uav.left_right_velocity = int(left_vel)
# print(self.uav.for_back_velocity, self.uav.left_right_velocity)
if tag_list[0].pose_t[0] < 0.2 and tag_list[0].pose_t[1] < 0.2:
self.uav.up_down_velocity = -20
# print(tag_list[0].pose_t) # 后正左正高度
if tag_list[0].pose_t[2] < 0.3:
self.uav.tello.land()
self.uav.status = Status.FALLING
else:
self.uav.up_down_velocity = 0
self.uav.left_right_velocity = 0
if self.is_get_tag:
self.uav.for_back_velocity = 0
logging.warn("no tag detected")
cv2.imshow('aa', bottom_img[:, :, 0])
cv2.waitKey(1)
# if tags:
# self.uav.status = Status.LANDED
# self.uav.tello.land()
# logging.warn('landed')
# else:
# self.uav.up_down_velocity = -20
class FakeUAV: class FakeUAV:

BIN
docs/Draft_process.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 206 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 941 B

BIN
requirements.txt Normal file

Binary file not shown.

14
test.py Normal file
View File

@ -0,0 +1,14 @@
import pupil_apriltags as apriltag
import cv2
tag_detector = apriltag.Detector(families='tag36h11')
img = cv2.imread('Snipaste_2024-06-22_21-30-43.png')
# cv2.imshow('aaa', img)
# print(img.shape)
tag = tag_detector.detect(img[:,:,0], estimate_tag_pose=True, camera_params=([353.836278849487, 353.077028029136, 163.745870517989, 115.130883974855]), tag_size=0.018)
for t in tag:
if t.tag_id == 0:
print(t.pose_t)
print(t.tag_id)

76
utils.py Normal file
View File

@ -0,0 +1,76 @@
from enum import Enum
from pathlib import Path
from datetime import datetime
import time
import cv2
import logging
import csv
class Status(Enum):
# INIT = 0
DETECT = 1
TRACK = 2
LANDING = 3
FALLING = 4
class DataRecorder:
def __init__(self, exp_note='default', need_record=False):
# use path lib to manage file path
self.need_record = need_record
if need_record:
self.root_folder = Path('data') / str(datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' + exp_note)
self.root_folder.mkdir(parents=True, exist_ok=False)
self.video_folder = self.root_folder / 'video'
self.video_folder.mkdir(parents=True, exist_ok=False)
with open(str(self.root_folder) + '/sensor_data.csv', mode='a', newline='') as file:
writer = csv.DictWriter(file, fieldnames=['timestamp', 'mid', 'x', 'y', 'z', 'mpry', 'pitch', 'roll', 'yaw', 'vgx', 'vgy', 'vgz', 'templ', 'temph', 'tof', 'h', 'bat', 'baro', 'time', 'agx', 'agy', 'agz', 'status', 'cmd_vel'])
writer.writeheader()
else:
logging.warn("实验数据将不被记录!")
def frame_record(self, frame, cam_name:str):
if self.need_record:
cam_folder = self.video_folder / cam_name # 分离不同来源的视频
if not cam_folder.exists():
# 如果文件夹不存在则创建
print('create folder:', str(cam_folder))
cam_folder.mkdir(parents=True, exist_ok=False)
frame_name = cam_folder / str(str(time.time()) + '.jpg')
if cv2.imwrite(str(frame_name), frame):
pass
else:
print('write frame failed')
def state_record(self, states):
if self.need_record:
with open(str(self.root_folder) + '/sensor_data.csv', mode='a', newline='') as file:
writer = csv.DictWriter(file, fieldnames=['timestamp', 'mid', 'x', 'y', 'z', 'mpry', 'pitch', 'roll', 'yaw', 'vgx', 'vgy', 'vgz', 'templ', 'temph', 'tof', 'h', 'bat', 'baro', 'time', 'agx', 'agy', 'agz', 'status', 'cmd_vel'])
# writer.writeheader()
writer.writerow(states)
def anything_record(self, data, data_name:str, header:list):
if self.need_record:
with open(str(self.root_folder) + '/' + data_name + '.csv', mode='a', newline='') as file:
writer = csv.DictWriter(file, fieldnames=header)
# writer.writeheader()
writer.writerow(data)
if __name__ == '__main__':
recorder = DataRecorder('aa')
# print(recorder.root_folder)
camera = cv2.VideoCapture(0)
while True:
ret, frame = camera.read()
if ret:
print(frame.shape)
cv2.imshow('frame', frame)
recorder.frame_record(frame, 'front')
cv2.waitKey(1)
time.sleep(0.1)

BIN
video.avi

Binary file not shown.

View File

@ -26,7 +26,7 @@ class Tracker(object):
''' '''
# 获得opencv版本 # 获得opencv版本
(major_ver, minor_ver, subminor_ver) = (cv2.__version__).split('.') (major_ver, minor_ver, subminor_ver) = (cv2.__version__).split('.')
self.tracker_types = ['BOOSTING', 'MIL', 'KCF', 'TLD', 'MEDIANFLOW', 'GOTURN'] self.tracker_types = ['BOOSTING', 'MIL', 'KCF', 'TLD', 'MEDIANFLOW', 'GOTURN', 'CSRT']
self.tracker_type = tracker_type self.tracker_type = tracker_type
self.isWorking = False self.isWorking = False
self.draw_coord = draw_coord self.draw_coord = draw_coord
@ -46,6 +46,8 @@ class Tracker(object):
self.tracker = cv2.TrackerMedianFlow_create() self.tracker = cv2.TrackerMedianFlow_create()
if tracker_type == 'GOTURN': if tracker_type == 'GOTURN':
self.tracker = cv2.TrackerGOTURN_create() self.tracker = cv2.TrackerGOTURN_create()
if tracker_type == 'CSRT':
self.tracker = cv2.TrackerCSRT_create()
def initWorking(self, frame, box): def initWorking(self, frame, box):
''' '''