yolov3和yolov4损失函数

yolov3和yolov4损失函数

  • yolov3损失函数
    • loss公式
    • loss代码
  • yolov4损失函数
    • loss公式
    • loss代码

yolov3损失函数

loss公式

yolov3和yolov4损失函数_第1张图片
其中:
网格共有KxK个,每个网格产生M个候选框anchor,每个anchor经过网络会得到相应的bounding box,最终形成KxKxM个bounding box,如果box内noobj,则只计算该box的置信loss。

1.回归loss会乘以一个(2-wxh)的比例系数,用来加大对小box的损失。
2.置信度loss损失函数采用交叉熵,分为两部分:obj和noobj,其中noobj的loss还增加了权重系数lambda,这是为了减少noobj计算部分的贡献权重。
3.分类loss损失函数采用交叉熵,当第i个网格的第j个anchor box负责某一个真实目标时,那么这个anchor box所产生的bounding box才会去计算分类损失函数。

loss代码

box_loss_scale = 2 - y_true[l][...,2:3]*y_true[l][...,3:4]
xy_loss = object_mask * box_loss_scale * K.binary_crossentropy(raw_true_xy, raw_pred[...,0:2], from_logits=True)
wh_loss = object_mask * box_loss_scale * 0.5 * K.square(raw_true_wh-raw_pred[...,2:4])
confidence_loss = object_mask * K.binary_crossentropy(object_mask, raw_pred[...,4:5], from_logits=True)+ (1-object_mask) * K.binary_crossentropy(object_mask, raw_pred[...,4:5], from_logits=True) * ignore_mask
class_loss = object_mask * K.binary_crossentropy(true_class_probs, raw_pred[...,5:], from_logits=True)
xy_loss = K.sum(xy_loss) / mf
wh_loss = K.sum(wh_loss) / mf
confidence_loss = K.sum(confidence_loss) / mf
class_loss = K.sum(class_loss) / mf
loss += xy_loss + wh_loss + confidence_loss + class_loss

yolov4损失函数

loss公式

与yolov3不同的是,location_loss使用了ciou。ciou在iou的基础上考虑了边框的重合度、中心距离和宽高比的尺度信息。
ciou——loss函数如下:
yolov3和yolov4损失函数_第2张图片

loss代码

box_loss_scale = 2 - y_true[l][...,2:3]*y_true[l][...,3:4]
raw_true_box = y_true[l][...,0:4]
ciou = box_ciou(pred_box, raw_true_box)
ciou_loss = object_mask * box_loss_scale * (1 - ciou)
ciou_loss = K.sum(ciou_loss) / mf
location_loss = ciou_loss
confidence_loss = object_mask * K.binary_crossentropy(object_mask, raw_pred[...,4:5], from_logits=True)+(1-object_mask) * K.binary_crossentropy(object_mask, raw_pred[...,4:5], from_logits=True) * ignore_mask
class_loss = object_mask * K.binary_crossentropy(true_class_probs, raw_pred[...,5:], from_logits=True)
confidence_loss = K.sum(confidence_loss) / mf
class_loss = K.sum(class_loss) / mf
loss += location_loss + confidence_loss + class_loss

你可能感兴趣的