给生成锚框的代码简单写了个注释，希望指正。
def multibox_prior(data, sizes, ratios):
    """Generate anchor boxes centred on every pixel of a feature map.

    For each pixel, ``len(sizes) + len(ratios) - 1`` anchors are produced.
    First come all sizes paired with ``ratios[0]``. These are followed by
    ``sizes[0]`` paired with each of ``ratios[1:]``.

    Args:
        data: Tensor whose last two dimensions are (height, width), e.g. a
            (batch_size, num_channels, height, width) feature map. Only its
            spatial shape and device are read; its values are ignored.
        sizes: Sequence of anchor scales. With ratio 1 an anchor is
            ``size * height`` pixels tall and equally wide, i.e. scales are
            fractions of the image height.
        ratios: Sequence of true (pixel-space) width / height aspect ratios.

    Returns:
        Tensor of shape (1, height * width * boxes_per_pixel, 4). Each row is
        (xmin, ymin, xmax, ymax), with x divided by the width and y divided by
        the height. Anchors are grouped pixel by pixel in row-major order.
        Corners are not clipped, so values may fall outside [0, 1].

    Raises:
        ValueError: if ``sizes`` or ``ratios`` is empty.
    """
    if len(sizes) == 0 or len(ratios) == 0:
        raise ValueError("sizes and ratios must each contain at least one value")

    in_height, in_width = data.shape[-2:]
    device, num_sizes, num_ratios = data.device, len(sizes), len(ratios)
    boxes_per_pixel = num_sizes + num_ratios - 1
    size_tensor = torch.tensor(sizes, device=device)
    ratio_tensor = torch.tensor(ratios, device=device)

    # An offset of 0.5 moves each centre to the middle of its pixel.
    # Multiplying by the step rescales the coordinates into [0, 1].
    offset_h, offset_w = 0.5, 0.5
    step_h = 1.0 / in_height
    step_w = 1.0 / in_width
    center_h = (torch.arange(in_height, device=device) + offset_h) * step_h
    center_w = (torch.arange(in_width, device=device) + offset_w) * step_w

    # With 'ij' indexing, shift_y varies down the rows and shift_x across the
    # columns. Illustrative values for 4 rows x 3 columns:
    #   shift_y = [[1, 1, 1],        shift_x = [[1, 2, 3],
    #              [2, 2, 2],                   [1, 2, 3],
    #              [3, 3, 3],                   [1, 2, 3],
    #              [4, 4, 4]]                   [1, 2, 3]]
    # Flattening yields the (x, y) centre of every pixel in row-major order:
    #   shift_x = [1, 2, 3, 1, 2, 3, ...], shift_y = [1, 1, 1, 2, 2, 2, ...]
    shift_y, shift_x = torch.meshgrid(center_h, center_w, indexing='ij')
    shift_y, shift_x = shift_y.reshape(-1), shift_x.reshape(-1)

    # w and h are in normalized units: w is a fraction of the image width,
    # h a fraction of the image height. Multiplying w by in_height / in_width
    # makes the pixel-space ratio (w * in_width) / (h * in_height) equal the
    # requested ratio, so ratio 1 gives a square anchor. No further scaling
    # by in_width / in_height is needed because the centres are already
    # normalized to [0, 1].
    w = torch.cat((size_tensor * torch.sqrt(ratio_tensor[0]),
                   sizes[0] * torch.sqrt(ratio_tensor[1:]))) \
        * in_height / in_width
    h = torch.cat((size_tensor / torch.sqrt(ratio_tensor[0]),
                   sizes[0] / torch.sqrt(ratio_tensor[1:])))

    # One row per anchor: the offsets from the centre to the top-left and
    # bottom-right corners. The block of rows is tiled once per pixel.
    anchor_manipulations = torch.stack((-w, -h, w, h)).T.repeat(
        in_height * in_width, 1) / 2

    # One row per anchor: its centre (x, y) written twice. Each pixel's centre
    # is repeated boxes_per_pixel times so the rows line up with the offsets.
    out_grid = torch.stack([shift_x, shift_y, shift_x, shift_y],
                           dim=1).repeat_interleave(boxes_per_pixel, dim=0)

    # Centre + offsets = (xmin, ymin, xmax, ymax) in normalized coordinates.
    output = out_grid + anchor_manipulations
    # Prepend a batch dimension of size 1.
    return output.unsqueeze(0)