@@ -642,7 +642,7 @@ def build_targets(self, p, targets, imgs):
642642 #indices, anch = self.find_4_positive(p, targets)
643643 #indices, anch = self.find_5_positive(p, targets)
644644 #indices, anch = self.find_9_positive(p, targets)
645-
645+ device = torch . device ( targets . device )
646646 matching_bs = [[] for pp in p ]
647647 matching_as = [[] for pp in p ]
648648 matching_gjs = [[] for pp in p ]
@@ -682,7 +682,7 @@ def build_targets(self, p, targets, imgs):
682682 all_gj .append (gj )
683683 all_gi .append (gi )
684684 all_anch .append (anch [i ][idx ])
685- from_which_layer .append (torch .ones (size = (len (b ),)) * i )
685+ from_which_layer .append (( torch .ones (size = (len (b ),)) * i ). to ( device ) )
686686
687687 fg_pred = pi [b , a , gj , gi ]
688688 p_obj .append (fg_pred [:, 4 :5 ])
@@ -739,7 +739,7 @@ def build_targets(self, p, targets, imgs):
739739 + 3.0 * pair_wise_iou_loss
740740 )
741741
742- matching_matrix = torch .zeros_like (cost )
742+ matching_matrix = torch .zeros_like (cost , device = device )
743743
744744 for gt_idx in range (num_gt ):
745745 _ , pos_idx = torch .topk (
@@ -753,7 +753,7 @@ def build_targets(self, p, targets, imgs):
753753 _ , cost_argmin = torch .min (cost [:, anchor_matching_gt > 1 ], dim = 0 )
754754 matching_matrix [:, anchor_matching_gt > 1 ] *= 0.0
755755 matching_matrix [cost_argmin , anchor_matching_gt > 1 ] = 1.0
756- fg_mask_inboxes = matching_matrix .sum (0 ) > 0.0
756+ fg_mask_inboxes = ( matching_matrix .sum (0 ) > 0.0 ). to ( device )
757757 matched_gt_inds = matching_matrix [:, fg_mask_inboxes ].argmax (0 )
758758
759759 from_which_layer = from_which_layer [fg_mask_inboxes ]
@@ -1288,6 +1288,7 @@ def build_targets(self, p, targets, imgs):
12881288
12891289 indices , anch = self .find_3_positive (p , targets )
12901290
1291+ device = torch .device (targets .device )
12911292 matching_bs = [[] for pp in p ]
12921293 matching_as = [[] for pp in p ]
12931294 matching_gjs = [[] for pp in p ]
@@ -1327,7 +1328,8 @@ def build_targets(self, p, targets, imgs):
13271328 all_gj .append (gj )
13281329 all_gi .append (gi )
13291330 all_anch .append (anch [i ][idx ])
1330- from_which_layer .append (torch .ones (size = (len (b ),)) * i )
1331+ # from_which_layer.append(torch.ones(size=(len(b),)) * i)
1332+ from_which_layer .append ((torch .ones (size = (len (b ),)) * i ).to (device ))
13311333
13321334 fg_pred = pi [b , a , gj , gi ]
13331335 p_obj .append (fg_pred [:, 4 :5 ])
@@ -1384,7 +1386,7 @@ def build_targets(self, p, targets, imgs):
13841386 + 3.0 * pair_wise_iou_loss
13851387 )
13861388
1387- matching_matrix = torch .zeros_like (cost )
1389+ matching_matrix = torch .zeros_like (cost , device = device )
13881390
13891391 for gt_idx in range (num_gt ):
13901392 _ , pos_idx = torch .topk (
@@ -1398,9 +1400,11 @@ def build_targets(self, p, targets, imgs):
13981400 _ , cost_argmin = torch .min (cost [:, anchor_matching_gt > 1 ], dim = 0 )
13991401 matching_matrix [:, anchor_matching_gt > 1 ] *= 0.0
14001402 matching_matrix [cost_argmin , anchor_matching_gt > 1 ] = 1.0
1401- fg_mask_inboxes = matching_matrix .sum (0 ) > 0.0
1403+ fg_mask_inboxes = ( matching_matrix .sum (0 ) > 0.0 ). to ( device )
14021404 matched_gt_inds = matching_matrix [:, fg_mask_inboxes ].argmax (0 )
14031405
1406+ # [cui] Ref: https://github.com/WongKinYiu/yolov7/issues/1045
1407+ # from_which_layer = from_which_layer.to(fg_mask_inboxes.device)[fg_mask_inboxes]
14041408 from_which_layer = from_which_layer [fg_mask_inboxes ]
14051409 all_b = all_b [fg_mask_inboxes ]
14061410 all_a = all_a [fg_mask_inboxes ]
@@ -1441,6 +1445,8 @@ def build_targets2(self, p, targets, imgs):
14411445
14421446 indices , anch = self .find_5_positive (p , targets )
14431447
1448+ device = torch .device (targets .device )
1449+ # print(f"cuda index: {device.index}")
14441450 matching_bs = [[] for pp in p ]
14451451 matching_as = [[] for pp in p ]
14461452 matching_gjs = [[] for pp in p ]
@@ -1480,8 +1486,7 @@ def build_targets2(self, p, targets, imgs):
14801486 all_gj .append (gj )
14811487 all_gi .append (gi )
14821488 all_anch .append (anch [i ][idx ])
1483- from_which_layer .append (torch .ones (size = (len (b ),)) * i )
1484-
1489+ from_which_layer .append ((torch .ones (size = (len (b ),)) * i ).to (device ))
14851490 fg_pred = pi [b , a , gj , gi ]
14861491 p_obj .append (fg_pred [:, 4 :5 ])
14871492 p_cls .append (fg_pred [:, 5 :])
@@ -1537,7 +1542,7 @@ def build_targets2(self, p, targets, imgs):
15371542 + 3.0 * pair_wise_iou_loss
15381543 )
15391544
1540- matching_matrix = torch .zeros_like (cost )
1545+ matching_matrix = torch .zeros_like (cost , device = device )
15411546
15421547 for gt_idx in range (num_gt ):
15431548 _ , pos_idx = torch .topk (
@@ -1551,7 +1556,7 @@ def build_targets2(self, p, targets, imgs):
15511556 _ , cost_argmin = torch .min (cost [:, anchor_matching_gt > 1 ], dim = 0 )
15521557 matching_matrix [:, anchor_matching_gt > 1 ] *= 0.0
15531558 matching_matrix [cost_argmin , anchor_matching_gt > 1 ] = 1.0
1554- fg_mask_inboxes = matching_matrix .sum (0 ) > 0.0
1559+ fg_mask_inboxes = ( matching_matrix .sum (0 ) > 0.0 ). to ( device )
15551560 matched_gt_inds = matching_matrix [:, fg_mask_inboxes ].argmax (0 )
15561561
15571562 from_which_layer = from_which_layer [fg_mask_inboxes ]
0 commit comments