# -*- coding:utf-8 -*-
# Author: Xue Yang <yangxue-2019-sjtu@sjtu.edu.cn>, <yangxue0827@126.com>
# License: Apache-2.0 license
# Copyright (c) SJTU. ALL rights reserved.
from __future__ import absolute_import, division, print_function
import numpy as np
import math
def gaussian_label(label, num_class, u=0, sig=4.0):
    """
    Get gaussian label: a circular Gaussian window peaked at bin ``label``.

    :param label: angle_label/omega, the bin index of the peak. Treated
        circularly, so ``label == num_class`` behaves like ``label == 0``.
    :param num_class: angle_range/omega, the number of bins
    :param u: mean offset of the Gaussian relative to ``label`` (in bins)
    :param sig: standard deviation of the Gaussian (in bins), i.e. the
        width of the window
    :return: gaussian label, float array of shape (num_class,)
    """
    half = num_class // 2
    idx = np.arange(num_class)
    # Signed circular offset of every bin from `label`, in [-half, num_class - half).
    # For even num_class this is exactly the rolled [-n/2, n/2) grid; for odd
    # num_class it keeps the window symmetric around the peak on both sides of
    # the wrap-around instead of using an off-by-one grid.
    offset = (idx - label + half) % num_class - half
    return np.exp(-(offset - u) ** 2 / (2 * sig ** 2))
def rectangular_label(label, num_class, radius=4):
    """
    Get rectangular label: 1 for every bin within ``radius`` bins of
    ``label`` (measured circularly), 0 elsewhere.

    :param label: angle_label/omega, the bin index of the window centre.
        Treated circularly, so ``label == num_class`` behaves like ``label == 0``.
    :param num_class: angle_range/omega, the number of bins
    :param radius: window radius in bins; ``radius=0`` yields a single pulse
    :return: rectangular label, float array of shape (num_class,)
    """
    idx = np.arange(num_class)
    # Circular distance of every bin to `label`.
    dist = np.abs(idx - label) % num_class
    dist = np.minimum(dist, num_class - dist)
    # Thresholding the distance avoids the `x[-0:]` pitfall that turned
    # radius=0 into an all-ones label.
    return (dist <= radius).astype(np.float64)
def pulse_label(label, num_class):
    """
    Get pulse label: a one-hot vector with a 1 at bin ``label``.

    :param label: angle_label/omega, the bin index of the pulse. Treated
        circularly, so ``label == num_class`` behaves like ``label == 0``.
    :param num_class: angle_range/omega, the number of bins
    :return: pulse label, float array of shape (num_class,)
    """
    x = np.zeros([num_class])
    # Wrap the index so label == num_class maps to bin 0, matching the other
    # (circular) label generators instead of raising IndexError.
    x[label % num_class] = 1
    return x
def triangle_label(label, num_class, radius=4):
    """
    Get triangle label: a circular triangular window peaked at bin ``label``
    that decays linearly to 0 at ``radius + 1`` bins away.

    :param label: angle_label/omega, the bin index of the peak. Treated
        circularly, so ``label == num_class`` behaves like ``label == 0``.
    :param num_class: angle_range/omega, the number of bins
    :param radius: window radius in bins; ``radius=0`` yields a single pulse
    :return: triangle label, float array of shape (num_class,)
    """
    idx = np.arange(num_class)
    # Circular distance of every bin to `label`.
    dist = np.abs(idx - label) % num_class
    dist = np.minimum(dist, num_class - dist)
    # Same linear ramp as before: 1 at the peak, 1 - d/(radius+1) at distance d.
    ramp = -1 / (radius + 1) * dist + 1
    # Computing from the distance (instead of two overlapping slice writes)
    # handles radius=0 and windows wider than num_class correctly.
    return np.where(dist <= radius, ramp, 0.0)
def get_all_smooth_label(num_label, label_type=0, radius=4):
    """
    Build the smooth label for every bin index and stack them.

    :param num_label: number of bins (angle_range/omega)
    :param label_type: 0: gaussian label, 1: rectangular label, 2: pulse label, 3: triangle label
    :param radius: window radius; passed as ``sig`` to the gaussian label,
        as ``radius`` to the rectangular/triangle labels, unused by the pulse label
    :return: array whose row ``i`` is the smooth label peaked at bin ``i``
    :raises ValueError: if ``label_type`` is not 0, 1, 2 or 3
    """
    if label_type == 0:
        make_label, kwargs = gaussian_label, {'sig': radius}
    elif label_type == 1:
        make_label, kwargs = rectangular_label, {'radius': radius}
    elif label_type == 2:
        make_label, kwargs = pulse_label, {}
    elif label_type == 3:
        make_label, kwargs = triangle_label, {'radius': radius}
    else:
        # ValueError is a subclass of Exception, so existing `except Exception`
        # handlers keep working while the error is now specific.
        raise ValueError('Only support gaussian, rectangular, triangle and pulse label')
    return np.array([make_label(i, num_label, **kwargs) for i in range(num_label)])
def angle_smooth_label(angle_label, angle_range=90, label_type=0, radius=4, omega=1):
    """
    Encode angles as circular smooth labels.

    :param angle_label: angle label, range in [-90,0) or [-180, 0); a scalar,
        sequence or ndarray (integer or float). The caller's data is not modified.
    :param angle_range: 90 or 180
    :param label_type: 0: gaussian label, 1: rectangular label, 2: pulse label, 3: triangle label
    :param radius: window radius
    :param omega: angle discretization granularity
    :return: float32 array holding, for every input angle, the smooth label
        row (length angle_range/omega) selected from ``get_all_smooth_label``
    **Circular Smooth Label:**
    Proposed by `"Xue Yang et al. Arbitrary-Oriented Object Detection with Circular Smooth Label. ECCV 2020."
    <https://link.springer.com/chapter/10.1007/978-3-030-58598-3_40>`_
    .. image:: ../../images/csl.jpg
    """
    assert angle_range % omega == 0, 'wrong omega'
    num_bins = int(angle_range / omega)
    # Divide out-of-place: an in-place `/=` would overwrite the caller's float
    # array and fail outright for integer arrays and plain lists.
    scaled = np.asarray(angle_label) / omega
    angle_label = np.array(-np.round(scaled), np.int32)
    all_smooth_label = get_all_smooth_label(num_bins, label_type, radius)
    # -round() maps [-angle_range, 0) onto (0, num_bins]; fold the upper
    # endpoint into the last valid bin so it can index all_smooth_label.
    angle_label[angle_label == num_bins] = num_bins - 1
    smooth_label = all_smooth_label[angle_label]
    return np.array(smooth_label, np.float32)
if __name__ == '__main__':
    import matplotlib.pyplot as plt

    # Other calls that can be swapped in for inspection:
    #   angle_smooth_label(np.array([-89.9, -45.2, -0.3, -1.9]))
    #   triangle_label(30, 180, radius=8)
    #   pulse_label(40, 180)
    #   triangle_label(3, 180, radius=1)
    y_sig = gaussian_label(180, 180, sig=6)
    bins = np.arange(0, 180, 1)

    plt.plot(bins, y_sig, "r-", linewidth=2)
    plt.grid(True)
    plt.show()

    print(y_sig)
    print(y_sig.shape)