import numpy as np
import torchvision
import time
import math
import os
import copy
import pdb
import argparse
import sys
import cv2
import skimage.io
import skimage.transform
import skimage.color
import skimage
import torch
import model

from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, models, transforms
from dataloader import CSVDataset, collater, Resizer, AspectRatioBasedSampler, Augmenter, UnNormalizer, Normalizer, RGB_MEAN, RGB_STD
from scipy.optimize import linear_sum_assignment

# assert torch.__version__.split('.')[1] == '4'

print('CUDA available: {}'.format(torch.cuda.is_available()))
color_list = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 0, 255), (0, 255, 255),
              (255, 255, 0), (128, 0, 255), (0, 128, 255), (128, 255, 0), (0, 255, 128),
              (255, 128, 0), (255, 0, 128), (128, 128, 255), (128, 255, 128),
              (255, 128, 128), (128, 128, 0), (128, 0, 128)]
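
# A detect_rect holds one chained detection: the detector regresses a *pair*
# of boxes per target (curr_rect in the current frame, next_rect predicted in
# the next frame), which is what lets adjacent frames be associated by box
# overlap alone.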

class detect_rect:
    def __init__(self):
        self.curr_frame = 0
        self.curr_rect = np.array([0, 0, 1, 1])
        self.next_rect = np.array([0, 0, 1, 1])
        self.conf = 0
        self.id = 0

    # exposed as properties because callers use attribute access
    # (e.g. det.position, not det.position())
    @property
    def position(self):
        # center of the current-frame box
        x = (self.curr_rect[0] + self.curr_rect[2]) / 2
        y = (self.curr_rect[1] + self.curr_rect[3]) / 2
        return np.array([x, y])

    @property
    def size(self):
        w = self.curr_rect[2] - self.curr_rect[0]
        h = self.curr_rect[3] - self.curr_rect[1]
        return np.array([w, h])
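
# A tracklet is the growing trajectory of one identity: the list of matched
# detect_rects, bookkeeping for the last matched frame, and a counter of
# consecutive unmatched frames used to retire stale tracks.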

class tracklet:
    def __init__(self, det_rect):
        self.id = det_rect.id
        self.rect_list = [det_rect]
        self.rect_num = 1
        self.last_rect = det_rect
        self.last_frame = det_rect.curr_frame
        self.no_match_frame = 0

    def add_rect(self, det_rect):
        self.rect_list.append(det_rect)
        self.rect_num = self.rect_num + 1
        self.last_rect = det_rect
        self.last_frame = det_rect.curr_frame

    @property
    def velocity(self):
        # average per-frame displacement of the box center; once six or more
        # detections exist, three 3-frame finite differences are averaged to
        # smooth out detection noise
        if self.rect_num < 2:
            return (0, 0)
        elif self.rect_num < 6:
            return (self.rect_list[self.rect_num - 1].position - self.rect_list[self.rect_num - 2].position) / (self.rect_list[self.rect_num - 1].curr_frame - self.rect_list[self.rect_num - 2].curr_frame)
        else:
            v1 = (self.rect_list[self.rect_num - 1].position - self.rect_list[self.rect_num - 4].position) / (self.rect_list[self.rect_num - 1].curr_frame - self.rect_list[self.rect_num - 4].curr_frame)
            v2 = (self.rect_list[self.rect_num - 2].position - self.rect_list[self.rect_num - 5].position) / (self.rect_list[self.rect_num - 2].curr_frame - self.rect_list[self.rect_num - 5].curr_frame)
            v3 = (self.rect_list[self.rect_num - 3].position - self.rect_list[self.rect_num - 6].position) / (self.rect_list[self.rect_num - 3].curr_frame - self.rect_list[self.rect_num - 6].curr_frame)
            return (v1 + v2 + v3) / 3
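
# Standard intersection-over-union of two [x1, y1, x2, y2] boxes:
# IoU = intersection / (area1 + area2 - intersection).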

def cal_iou(rect1, rect2):
    x1, y1, x2, y2 = rect1
    x3, y3, x4, y4 = rect2
    i_w = min(x2, x4) - max(x1, x3)
    i_h = min(y2, y4) - max(y1, y3)
    if i_w <= 0 or i_h <= 0:
        return 0
    i_s = i_w * i_h
    s_1 = (x2 - x1) * (y2 - y1)
    s_2 = (x4 - x3) * (y4 - y3)
    return float(i_s) / (s_1 + s_2 - i_s)
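
# Similarity between detections in adjacent frames: overlap of the first
# detection's predicted next-frame box with the second detection's current
# box. For gaps longer than one frame, the track's last box is extrapolated
# with its average velocity before the overlap is computed.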

def cal_simi(det_rect1, det_rect2):
    return cal_iou(det_rect1.next_rect, det_rect2.curr_rect)


def cal_simi_track_det(track, det_rect):
    if det_rect.curr_frame <= track.last_frame:
        print("cal_simi_track_det error: detection frame does not follow track's last frame")
        return 0
    elif det_rect.curr_frame - track.last_frame == 1:
        return cal_iou(track.last_rect.next_rect, det_rect.curr_rect)
    else:
        # the track missed some frames: shift its last box by the average
        # velocity times the frame gap before measuring overlap
        pred_rect = track.last_rect.curr_rect + np.append(track.velocity, track.velocity) * (det_rect.curr_frame - track.last_frame)
        return cal_iou(pred_rect, det_rect.curr_rect)
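
# Tracks and detections are associated with the Hungarian algorithm (scipy's
# linear_sum_assignment) on a cost matrix of negated similarities; any
# assignment whose similarity falls below min_iou is rejected and returned
# as unmatched on both sides.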

def track_det_match(tracklet_list, det_rect_list, min_iou=0.5):
    num1 = len(tracklet_list)
    num2 = len(det_rect_list)
    cost_mat = np.zeros((num1, num2))
    for i in range(num1):
        for j in range(num2):
            cost_mat[i, j] = -cal_simi_track_det(tracklet_list[i], det_rect_list[j])

    match_result = linear_sum_assignment(cost_mat)
    match_result = np.asarray(match_result)
    match_result = np.transpose(match_result)

    matches, unmatched1, unmatched2 = [], [], []
    for i in range(num1):
        if i not in match_result[:, 0]:
            unmatched1.append(i)
    for j in range(num2):
        if j not in match_result[:, 1]:
            unmatched2.append(j)
    for i, j in match_result:
        # reject assignments whose similarity is below min_iou
        if cost_mat[i, j] > -min_iou:
            unmatched1.append(i)
            unmatched2.append(j)
        else:
            matches.append((i, j))
    return matches, unmatched1, unmatched2

def draw_caption(image, box, caption, color):
    b = np.array(box).astype(int)
    cv2.putText(image, caption, (b[0], b[1] - 8), cv2.FONT_HERSHEY_PLAIN, 2, color, 2)
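
# Per-sequence pipeline: run the detector over the frames (chaining last_feat
# between consecutive forward passes), associate detections with live
# tracklets, then draw boxes and ID traces and write the results in
# MOTChallenge text format under <model_dir>/results/.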

def run_each_dataset(model_dir, retinanet, dataset_path, subset, cur_dataset):
    print(cur_dataset)

    img_list = os.listdir(os.path.join(dataset_path, subset, cur_dataset, 'img1'))
    img_list = [os.path.join(dataset_path, subset, cur_dataset, 'img1', _) for _ in img_list if ('jpg' in _) or ('png' in _)]
    img_list = sorted(img_list)
    img_len = len(img_list)
    last_feat = None

    confidence_threshold = 0.4
    IOU_threshold = 0.5
    retention_threshold = 10   # frames a track may go unmatched before removal

    det_list_all = []          # per-frame detection lists
    tracklet_all = []          # live tracklets
    max_id = 0
    max_draw_len = 100
    draw_interval = 5
    img_width = 1920
    img_height = 1080
    fps = 30

    for i in range(img_len):
        det_list_all.append([])
    # track over the second half of each sequence
    for idx in range(int(img_len / 2), img_len + 1):
        i = idx - 1
        print('tracking: ', i)
        with torch.no_grad():
            data_path1 = img_list[min(idx, img_len - 1)]
            img_origin1 = skimage.io.imread(data_path1)
            img_h, img_w, _ = img_origin1.shape
            img_height, img_width = img_h, img_w
            # pad to a multiple of 32 so the feature-pyramid strides divide evenly
            resize_h, resize_w = math.ceil(img_h / 32) * 32, math.ceil(img_w / 32) * 32
            img1 = np.zeros((resize_h, resize_w, 3), dtype=img_origin1.dtype)
            img1[:img_h, :img_w, :] = img_origin1
            img1 = (img1.astype(np.float32) / 255.0 - np.array([[RGB_MEAN]])) / np.array([[RGB_STD]])
            img1 = torch.from_numpy(img1).permute(2, 0, 1).view(1, 3, resize_h, resize_w)

            # last_feat chains the previous frame's features into the current pass
            scores, transformed_anchors, last_feat = retinanet(img1.cuda().float(), last_feat=last_feat)
            scores = scores.cpu()                            # move outputs to CPU so
            transformed_anchors = transformed_anchors.cpu()  # the NumPy indexing below works

        # if idx > 0:
        if idx > int(img_len / 2):
            idxs = np.where(scores > 0.1)
            for j in range(idxs[0].shape[0]):
                # each detection carries a pair of boxes: (x1..y2) in the
                # current frame, (x3..y4) predicted in the next frame
                bbox = transformed_anchors[idxs[0][j], :]
                x1 = int(bbox[0])
                y1 = int(bbox[1])
                x2 = int(bbox[2])
                y2 = int(bbox[3])
                x3 = int(bbox[4])
                y3 = int(bbox[5])
                x4 = int(bbox[6])
                y4 = int(bbox[7])

                det_conf = float(scores[idxs[0][j]])

                det_rect = detect_rect()
                det_rect.curr_frame = idx
                det_rect.curr_rect = np.array([x1, y1, x2, y2])
                det_rect.next_rect = np.array([x3, y3, x4, y4])
                det_rect.conf = det_conf

                if det_rect.conf > confidence_threshold:
                    det_list_all[det_rect.curr_frame - 1].append(det_rect)
        # if i == 0:
        if i == int(img_len / 2):
            # first processed frame: every detection seeds a new tracklet
            for j in range(len(det_list_all[i])):
                det_list_all[i][j].id = j + 1
                max_id = max(max_id, j + 1)
                track = tracklet(det_list_all[i][j])
                tracklet_all.append(track)
            continue

        matches, unmatched1, unmatched2 = track_det_match(tracklet_all, det_list_all[i], IOU_threshold)

        for j in range(len(matches)):
            det_list_all[i][matches[j][1]].id = tracklet_all[matches[j][0]].id
            tracklet_all[matches[j][0]].add_rect(det_list_all[i][matches[j][1]])

        # retire tracklets that have gone unmatched for too long
        delete_track_list = []
        for j in range(len(unmatched1)):
            tracklet_all[unmatched1[j]].no_match_frame = tracklet_all[unmatched1[j]].no_match_frame + 1
            if tracklet_all[unmatched1[j]].no_match_frame >= retention_threshold:
                delete_track_list.append(unmatched1[j])

        origin_index = set([k for k in range(len(tracklet_all))])
        delete_index = set(delete_track_list)
        left_index = list(origin_index - delete_index)
        tracklet_all = [tracklet_all[k] for k in left_index]

        # unmatched detections start new tracklets with fresh ids
        for j in range(len(unmatched2)):
            det_list_all[i][unmatched2[j]].id = max_id + 1
            max_id = max_id + 1
            track = tracklet(det_list_all[i][unmatched2[j]])
            tracklet_all.append(track)
    # **************** visualize tracking results and save evaluation file ****************
    fout_tracking = open(os.path.join(model_dir, 'results', cur_dataset + '.txt'), 'w')

    save_img_dir = os.path.join(model_dir, 'results', cur_dataset)
    if not os.path.exists(save_img_dir):
        os.makedirs(save_img_dir)

    out_video = os.path.join(model_dir, 'results', cur_dataset + '.mp4')
    videoWriter = cv2.VideoWriter(out_video, cv2.VideoWriter_fourcc('m', 'p', '4', 'v'), fps, (img_width, img_height))

    id_dict = {}

    for i in range(int(img_len / 2), img_len):
        print('saving: ', i)
        img = cv2.imread(img_list[i])

        for j in range(len(det_list_all[i])):
            x1, y1, x2, y2 = det_list_all[i][j].curr_rect.astype(int)
            trace_id = det_list_all[i][j].id

            # remember the bottom-center point for drawing the motion trace
            id_dict.setdefault(str(trace_id), []).append((int((x1 + x2) / 2), y2))

            draw_trace_id = str(trace_id)
            draw_caption(img, (x1, y1, x2, y2), draw_trace_id, color=color_list[trace_id % len(color_list)])
            cv2.rectangle(img, (x1, y1), (x2, y2), color=color_list[trace_id % len(color_list)], thickness=2)

            trace_len = len(id_dict[str(trace_id)])
            trace_len_draw = min(max_draw_len, trace_len)
            for k in range(trace_len_draw - draw_interval):
                if k % draw_interval == 0:
                    draw_point1 = id_dict[str(trace_id)][trace_len - k - 1]
                    draw_point2 = id_dict[str(trace_id)][trace_len - k - 1 - draw_interval]
                    cv2.line(img, draw_point1, draw_point2, color=color_list[trace_id % len(color_list)], thickness=2)

            # MOTChallenge format: frame, id, x, y, w, h, conf, -1, -1, -1
            fout_tracking.write(str(i + 1) + ',' + str(trace_id) + ',' + str(x1) + ',' + str(y1) + ',' + str(x2 - x1) + ',' + str(y2 - y1) + ',-1,-1,-1,-1\n')

        cv2.imwrite(os.path.join(save_img_dir, str(i + 1).zfill(6) + '.jpg'), img)
        videoWriter.write(img)
        # cv2.waitKey(0)

    fout_tracking.close()
    videoWriter.release()
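
# Convenience entry point for evaluating a freshly trained checkpoint.
# MOT17 sequences {02, 04, 05, 09, 10, 11, 13} form the train split and
# {01, 03, 06, 07, 08, 12, 14} the test split.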

def run_from_train(model_dir, root_path):
    if not os.path.exists(os.path.join(model_dir, 'results')):
        os.makedirs(os.path.join(model_dir, 'results'))

    retinanet = torch.load(os.path.join(model_dir, 'model_final.pt'))

    use_gpu = True
    if use_gpu:
        retinanet = retinanet.cuda()
    retinanet.eval()

    for seq_num in [2, 4, 5, 9, 10, 11, 13]:
        run_each_dataset(model_dir, retinanet, root_path, 'train', 'MOT17-{:02d}'.format(seq_num))
    for seq_num in [1, 3, 6, 7, 8, 12, 14]:
        run_each_dataset(model_dir, retinanet, root_path, 'test', 'MOT17-{:02d}'.format(seq_num))

def main(args=None):
    parser = argparse.ArgumentParser(description='Simple script for testing a CTracker network.')
    parser.add_argument('--dataset_path', default='/dockerdata/home/jeromepeng/data/MOT/MOT17/', type=str, help='Dataset path, location of the image sequences.')
    parser.add_argument('--model_dir', default='./trained_model/', help='Directory that receives the results/ output.')
    parser.add_argument('--model_path', default='./trained_model/model_final.pth', help='Path to the model (.pt/.pth) file.')
    args = parser.parse_args(args)

    if not os.path.exists(os.path.join(args.model_dir, 'results')):
        os.makedirs(os.path.join(args.model_dir, 'results'))

    retinanet = model.resnet50(num_classes=1, pretrained=True)
    # retinanet_save = torch.load(os.path.join(args.model_dir, 'model_final.pth'))
    retinanet_save = torch.load(args.model_path)

    # strip the 'module.' prefix that nn.DataParallel adds to checkpoint keys
    state_dict = retinanet_save.state_dict()
    for k in list(state_dict.keys()):
        if k.startswith('module.'):
            state_dict[k[len("module."):]] = state_dict[k]
            # delete the old, prefixed key
            del state_dict[k]
    retinanet.load_state_dict(state_dict)

    use_gpu = True
    if use_gpu:
        retinanet = retinanet.cuda()
    retinanet.eval()

    for seq_num in [2, 4, 5, 9, 10, 11, 13]:
        run_each_dataset(args.model_dir, retinanet, args.dataset_path, 'train', 'MOT17-{:02d}'.format(seq_num))
    # for seq_num in [1, 3, 6, 7, 8, 12, 14]:
    #     run_each_dataset(args.model_dir, retinanet, args.dataset_path, 'test', 'MOT17-{:02d}'.format(seq_num))


if __name__ == '__main__':
    main()
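
# Example invocation (paths are illustrative; substitute your own MOT17 root
# and checkpoint, and adjust the script name if this file is saved under a
# different one):
#   python test.py --dataset_path /path/to/MOT17/ \
#                  --model_dir ./trained_model/ \
#                  --model_path ./trained_model/model_final.pth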