Skip to content

Commit 2162c0f

Browse files
author
yixu.cui
committed
fix indices device(GPU) bug
1 parent a2be414 commit 2162c0f

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

utils/loss.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -642,7 +642,7 @@ def build_targets(self, p, targets, imgs):
642642
#indices, anch = self.find_4_positive(p, targets)
643643
#indices, anch = self.find_5_positive(p, targets)
644644
#indices, anch = self.find_9_positive(p, targets)
645-
645+
device = torch.device(targets.device)
646646
matching_bs = [[] for pp in p]
647647
matching_as = [[] for pp in p]
648648
matching_gjs = [[] for pp in p]
@@ -682,7 +682,7 @@ def build_targets(self, p, targets, imgs):
682682
all_gj.append(gj)
683683
all_gi.append(gi)
684684
all_anch.append(anch[i][idx])
685-
from_which_layer.append(torch.ones(size=(len(b),)) * i)
685+
from_which_layer.append((torch.ones(size=(len(b),)) * i).to(device))
686686

687687
fg_pred = pi[b, a, gj, gi]
688688
p_obj.append(fg_pred[:, 4:5])
@@ -739,7 +739,7 @@ def build_targets(self, p, targets, imgs):
739739
+ 3.0 * pair_wise_iou_loss
740740
)
741741

742-
matching_matrix = torch.zeros_like(cost)
742+
matching_matrix = torch.zeros_like(cost, device=device)
743743

744744
for gt_idx in range(num_gt):
745745
_, pos_idx = torch.topk(
@@ -753,7 +753,7 @@ def build_targets(self, p, targets, imgs):
753753
_, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
754754
matching_matrix[:, anchor_matching_gt > 1] *= 0.0
755755
matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
756-
fg_mask_inboxes = matching_matrix.sum(0) > 0.0
756+
fg_mask_inboxes = (matching_matrix.sum(0) > 0.0).to(device)
757757
matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
758758

759759
from_which_layer = from_which_layer[fg_mask_inboxes]
@@ -1288,6 +1288,7 @@ def build_targets(self, p, targets, imgs):
12881288

12891289
indices, anch = self.find_3_positive(p, targets)
12901290

1291+
device = torch.device(targets.device)
12911292
matching_bs = [[] for pp in p]
12921293
matching_as = [[] for pp in p]
12931294
matching_gjs = [[] for pp in p]
@@ -1327,7 +1328,8 @@ def build_targets(self, p, targets, imgs):
13271328
all_gj.append(gj)
13281329
all_gi.append(gi)
13291330
all_anch.append(anch[i][idx])
1330-
from_which_layer.append(torch.ones(size=(len(b),)) * i)
1331+
# from_which_layer.append(torch.ones(size=(len(b),)) * i)
1332+
from_which_layer.append((torch.ones(size=(len(b),)) * i).to(device))
13311333

13321334
fg_pred = pi[b, a, gj, gi]
13331335
p_obj.append(fg_pred[:, 4:5])
@@ -1384,7 +1386,7 @@ def build_targets(self, p, targets, imgs):
13841386
+ 3.0 * pair_wise_iou_loss
13851387
)
13861388

1387-
matching_matrix = torch.zeros_like(cost)
1389+
matching_matrix = torch.zeros_like(cost, device=device)
13881390

13891391
for gt_idx in range(num_gt):
13901392
_, pos_idx = torch.topk(
@@ -1398,9 +1400,11 @@ def build_targets(self, p, targets, imgs):
13981400
_, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
13991401
matching_matrix[:, anchor_matching_gt > 1] *= 0.0
14001402
matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
1401-
fg_mask_inboxes = matching_matrix.sum(0) > 0.0
1403+
fg_mask_inboxes = (matching_matrix.sum(0) > 0.0).to(device)
14021404
matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
14031405

1406+
# [cui] Ref: https://github.com/WongKinYiu/yolov7/issues/1045
1407+
# from_which_layer = from_which_layer.to(fg_mask_inboxes.device)[fg_mask_inboxes]
14041408
from_which_layer = from_which_layer[fg_mask_inboxes]
14051409
all_b = all_b[fg_mask_inboxes]
14061410
all_a = all_a[fg_mask_inboxes]
@@ -1441,6 +1445,8 @@ def build_targets2(self, p, targets, imgs):
14411445

14421446
indices, anch = self.find_5_positive(p, targets)
14431447

1448+
device = torch.device(targets.device)
1449+
# print(f"cuda index: {device.index}")
14441450
matching_bs = [[] for pp in p]
14451451
matching_as = [[] for pp in p]
14461452
matching_gjs = [[] for pp in p]
@@ -1480,8 +1486,7 @@ def build_targets2(self, p, targets, imgs):
14801486
all_gj.append(gj)
14811487
all_gi.append(gi)
14821488
all_anch.append(anch[i][idx])
1483-
from_which_layer.append(torch.ones(size=(len(b),)) * i)
1484-
1489+
from_which_layer.append((torch.ones(size=(len(b),)) * i).to(device))
14851490
fg_pred = pi[b, a, gj, gi]
14861491
p_obj.append(fg_pred[:, 4:5])
14871492
p_cls.append(fg_pred[:, 5:])
@@ -1537,7 +1542,7 @@ def build_targets2(self, p, targets, imgs):
15371542
+ 3.0 * pair_wise_iou_loss
15381543
)
15391544

1540-
matching_matrix = torch.zeros_like(cost)
1545+
matching_matrix = torch.zeros_like(cost, device=device)
15411546

15421547
for gt_idx in range(num_gt):
15431548
_, pos_idx = torch.topk(
@@ -1551,7 +1556,7 @@ def build_targets2(self, p, targets, imgs):
15511556
_, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
15521557
matching_matrix[:, anchor_matching_gt > 1] *= 0.0
15531558
matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
1554-
fg_mask_inboxes = matching_matrix.sum(0) > 0.0
1559+
fg_mask_inboxes = (matching_matrix.sum(0) > 0.0).to(device)
15551560
matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
15561561

15571562
from_which_layer = from_which_layer[fg_mask_inboxes]

0 commit comments

Comments
 (0)