Welcome to AlphaRotate’s documentation!¶
AlphaRotate is an open-source TensorFlow benchmark for scalable rotation detection on various datasets. It is maintained by Xue Yang at Shanghai Jiao Tong University, supervised by Prof. Junchi Yan.
This repository is developed for the following purposes:
Providing modules for developing rotation detection algorithms to facilitate future research.
Providing implementation of state-of-the-art rotation detection methods.
Benchmarking existing rotation detection algorithms under different dataset & experiment settings, for the purpose of fair comparison.
Introduction to Rotation Detection¶
Arbitrary-oriented objects are ubiquitous across visual detection datasets, such as aerial images, scene text, faces, 3D objects, and retail scenes. Compared with the large body of literature on horizontal object detection, research on oriented object detection is still at a relatively early stage, with many open problems to solve.
Rotation detection techniques have been applied to the following applications:
Aerial images
Scene text
Face
3D object detection
Retail scenes
and more…
In this repository, we mainly focus on aerial images because they are particularly challenging.
Readers are referred to the following survey for more technical details about aerial image rotation detection: DOTA-DOAI
Installation¶
Docker¶
We recommend using Docker images if Docker or another container runtime (e.g. Singularity) is available on your device.
We maintain a prebuilt image on Docker Hub (CUDA version < 11):
yangxue2docker/yx-tf-det:tensorflow1.13.1-cuda10-gpu-py3
Note
For 30xx-series graphics cards (CUDA version >= 11), please download an image from tensorflow-release-notes according to your development environment, e.g. nvcr.io/nvidia/tensorflow:20.11-tf1-py3
Manual configuration¶
If Docker is not available, we provide detailed steps to install the requirements with pip (CUDA version < 11):
pip install -r requirements.txt
pip install -v -e . # or "python setup.py develop"
Or, you can simply install AlphaRotate with the following command (CUDA version < 11):
pip install alpharotate # Not suitable for dev.
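To verify the installation, a minimal sanity check (a sketch, assuming a TF 1.x environment and that importing the package has no side effects) is to import both packages:

# Installation sanity check (illustrative; assumes TF 1.x is installed).
import tensorflow as tf
import alpharotate

print(tf.__version__)        # expect a 1.x version, e.g. 1.13.1
print(alpharotate.__file__)  # confirms the package can be resolved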
Note
For 30xx-series graphics cards (CUDA version >= 11), we recommend following this blog to install TF 1.x.
For 30xx-series graphics cards (CUDA version >= 11), rebuild the Cython extensions (removing the previously built *.so/*.c/*.cpp files forces them to be regenerated for your environment):
cd alpharotate/libs/utils/cython_utils
rm *.so
rm *.c
rm *.cpp
python setup.py build_ext --inplace (or make)
cd alpharotate/libs/utils/
rm *.so
rm *.c
rm *.cpp
python setup.py build_ext --inplace
Run the Experiment¶
Download Model¶
Pretrain weights¶
Download the pretrained weights you need from the following options, and then put them in pretrained_weights.
TensorFlow pretrained weights: resnet50_v1, resnet101_v1, resnet152_v1, efficientnet, mobilenet_v2, darknet53 (Baidu Drive (extraction code: 1jg2), Google Drive).
- PyTorch pretrained weights: refer to pretrain_zoo.py and Others.
Trained weights¶
Please download the models trained by this project, then put them in trained_weights.
Compile¶
cd $PATH_ROOT/libs/utils/cython_utils
rm *.so
rm *.c
rm *.cpp
python setup.py build_ext --inplace (or make)
cd $PATH_ROOT/libs/utils/
rm *.so
rm *.c
rm *.cpp
python setup.py build_ext --inplace
Train¶
If you want to train your own dataset, please note:

- Select the detector and dataset you want to use, and mark them as #DETECTOR and #DATASET (such as #DETECTOR=retinanet and #DATASET=DOTA)
- Modify parameters (such as CLASS_NUM, DATASET_NAME, VERSION, etc.) in $PATH_ROOT/libs/configs/#DATASET/#DETECTOR/cfgs_xxx.py (see the sketch after this list)
- Copy $PATH_ROOT/libs/configs/#DATASET/#DETECTOR/cfgs_xxx.py to $PATH_ROOT/libs/configs/cfgs.py
- Add category information in $PATH_ROOT/libs/label_name_dict/label_dict.py
- Add data_name to $PATH_ROOT/dataloader/dataset/read_tfrecord.py
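For orientation, the parameters mentioned above might appear in a config as follows (a hedged sketch, not a shipped cfgs_xxx.py; only CLASS_NUM, DATASET_NAME, and VERSION are taken from this guide, and all values are illustrative):

# Illustrative excerpt of a cfgs.py; values are examples only.
VERSION = 'RetinaNet_DOTA_1x_20210101'  # hypothetical experiment tag
DATASET_NAME = 'DOTA'                   # must match the data_name registered in read_tfrecord.py
CLASS_NUM = 15                          # DOTA v1.0 has 15 object categories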
Make tfrecord
If the images are very large (as in the DOTA dataset), they need to be cropped first. Take the DOTA dataset as an example:
cd $PATH_ROOT/dataloader/dataset/DOTA
python data_crop.py
If the images do not need to be cropped, just convert the annotation files into xml format; refer to example.xml.
cd $PATH_ROOT/dataloader/dataset/
python convert_data_to_tfrecord.py --root_dir='/PATH/TO/DOTA/' \
                                   --xml_dir='labeltxt' \
                                   --image_dir='images' \
                                   --save_name='train' \
                                   --img_format='.png' \
                                   --dataset='DOTA'
Start training
cd $PATH_ROOT/tools/#DETECTOR
python train.py
Evaluation¶
For large-scale images, take the DOTA dataset as an example (the output files and visualizations are saved in $PATH_ROOT/tools/#DETECTOR/test_dota/VERSION):
cd $PATH_ROOT/tools/#DETECTOR
python test_dota.py --test_dir='/PATH/TO/IMAGES/'
--gpus=0,1,2,3,4,5,6,7
-ms (multi-scale testing, optional)
-s (visualization, optional)
-cn (use CPU NMS, slightly better (<1%) than GPU NMS but slower, optional)
or (recommended in this repo, better than multi-scale testing)
python test_dota_sota.py --test_dir='/PATH/TO/IMAGES/'
--gpus=0,1,2,3,4,5,6,7
-s (visualization, optional)
-cn (use CPU NMS, slightly better (<1%) than GPU NMS but slower, optional)
Note
To make it convenient to resume from a breakpoint, the result file is opened in 'a+' (append) mode. If the same #VERSION of the model needs to be tested again, the original test results must be deleted first.
For small-scale images, take the HRSC2016 dataset as an example:
cd $PATH_ROOT/tools/#DETECTOR
python test_hrsc2016.py --test_dir='/PATH/TO/IMAGES/'
--gpu=0
--image_ext='bmp'
--test_annotation_path='/PATH/TO/ANNOTATIONS'
-s (visualization, optional)
TensorBoard
cd $PATH_ROOT/output/summary
tensorboard --logdir=.


API Reference¶
alpharotate.utils¶
densely_coded_label¶
- alpharotate.utils.densely_coded_label.angle_label_decode(angle_encode_label, angle_range, omega=1.0, mode=0)[source]¶
Decode binary/gray label back to angle label
- Parameters
angle_encode_label – binary/gray label
angle_range – 90 or 180
omega – angle discretization granularity
mode – 0: binary label, 1: gray label
- Returns
angle label, range in [-90, 0) or [-180, 0)
- alpharotate.utils.densely_coded_label.angle_label_encode(angle_label, angle_range, omega=1.0, mode=0)[source]¶
Encode angle label as binary/gray label
- Parameters
angle_label – angle label, range in [-90,0) or [-180, 0)
angle_range – 90 or 180
omega – angle discretization granularity
mode – 0: binary label, 1: gray label
- Returns
binary/gray label
Dense Coded Label: Proposed by “Xue Yang et al. Dense Label Encoding for Boundary Discontinuity Free Rotation Detection. CVPR 2021.”
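A minimal round-trip sketch of these functions (assuming NumPy-array inputs; consult the source for the exact accepted types):

import numpy as np
from alpharotate.utils import densely_coded_label as dcl

angles = np.array([-45.])  # angle label, range in [-90, 0)
# Encode into a binary code (mode=0), then decode back to the angle label.
codes = dcl.angle_label_encode(angles, angle_range=90, omega=1.0, mode=0)
decoded = dcl.angle_label_decode(codes, angle_range=90, omega=1.0, mode=0)
# Code length for angle_range/omega = 90 discrete classes.
print(dcl.get_code_len(90, mode=0))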
- alpharotate.utils.densely_coded_label.binary_label_decode(binary_label, angle_range, omega=1.0)[source]¶
Decode binary label back to angle label
- Parameters
binary_label – binary label
angle_range – 90 or 180
omega – angle discretization granularity
- Returns
angle label
- alpharotate.utils.densely_coded_label.binary_label_encode(angle_label, angle_range, omega=1.0)[source]¶
Encode angle label as binary label
- Parameters
angle_label – angle label, range in [-90,0) or [-180, 0)
angle_range – 90 or 180
omega – angle discretization granularity
- Returns
binary label
- alpharotate.utils.densely_coded_label.get_all_binary_label(num_label, class_range)[source]¶
Get all binary labels according to num_label
- Parameters
num_label – angle_range/omega, 90/omega or 180/omega
class_range – angle_range/omega, 90/omega or 180/omega
- Returns
all binary label
- alpharotate.utils.densely_coded_label.get_all_gray_label(angle_range)[source]¶
Get all gray labels
- Parameters
angle_range – 90/omega or 180/omega
- Returns
all gray label
- alpharotate.utils.densely_coded_label.get_code_len(class_range, mode=0)[source]¶
Get encode length
- Parameters
class_range – angle_range/omega
mode – 0: binary label, 1: gray label
- Returns
encode length
smooth_label¶
- alpharotate.utils.smooth_label.angle_smooth_label(angle_label, angle_range=90, label_type=0, radius=4, omega=1)[source]¶
- Parameters
angle_label – angle label, range in [-90,0) or [-180, 0)
angle_range – 90 or 180
label_type – 0: gaussian label, 1: rectangular label, 2: pulse label, 3: triangle label
radius – window radius
omega – angle discretization granularity
- Returns
smooth label
Circular Smooth Label: Proposed by “Xue Yang et al. Arbitrary-Oriented Object Detection with Circular Smooth Label. ECCV 2020.”
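A minimal usage sketch (assuming a NumPy-array input):

import numpy as np
from alpharotate.utils import smooth_label

# Smooth a -45 degree label into a circular smooth label (CSL)
# with a Gaussian window (label_type=0) of radius 4.
csl = smooth_label.angle_smooth_label(np.array([-45.]), angle_range=90,
                                      label_type=0, radius=4, omega=1)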
- alpharotate.utils.smooth_label.gaussian_label(label, num_class, u=0, sig=4.0)[source]¶
Get gaussian label
- Parameters
label – angle_label/omega
num_class – angle_range/omega
u – mean
sig – window radius
- Returns
gaussian label
- alpharotate.utils.smooth_label.pulse_label(label, num_class)[source]¶
Get pulse label
- Parameters
label – angle_label/omega
num_class – angle_range/omega
- Returns
pulse label
gaussian_metric¶
- alpharotate.utils.gaussian_metric.box2gaussian(boxes1, boxes2)[source]¶
Convert box \((x, y, w, h, \theta)\) to Gaussian distribution \((\mathbf \mu, \mathbf \Sigma)\)
- Parameters
boxes1 – \((x_{1},y_{1},w_{1},h_{1},\theta_{1})\), shape: [-1, 5]
boxes2 – \((x_{2},y_{2},w_{2},h_{2},\theta_{2})\), shape: [-1, 5]
- Returns
\((\mathbf \mu, \mathbf \Sigma)\)
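For reference, the Gaussian modeling in this line of work (cf. "Rethinking Rotated Object Detection with Gaussian Wasserstein Distance Loss", by the same authors) takes the box center as the mean and builds the covariance from the rotation matrix \(\mathbf R_{\theta}\) of \(\theta\):

\(\mathbf \mu = (x, y)^{T}, \qquad \mathbf \Sigma^{1/2} = \mathbf R_{\theta}\begin{pmatrix}\frac{w}{2} & 0 \\ 0 & \frac{h}{2}\end{pmatrix}\mathbf R_{\theta}^{T}\)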
- alpharotate.utils.gaussian_metric.gaussian_kullback_leibler_divergence(boxes1, boxes2)[source]¶
Calculate the kullback-leibler divergence between boxes1 and boxes2
- Parameters
boxes1 – \((x_{1},y_{1},w_{1},h_{1},\theta_{1})\), shape: [-1, 5]
boxes2 – \((x_{2},y_{2},w_{2},h_{2},\theta_{2})\), shape: [-1, 5]
- Returns
kullback-leibler divergence, \(\mathbf D_{kl}\)
- alpharotate.utils.gaussian_metric.gaussian_wasserstein_distance(boxes1, boxes2)[source]¶
Calculate the wasserstein distance between boxes1 and boxes2: \(\mathbf D_{w} = ||\mathbf \mu_{1} - \mathbf \mu_{2}||^{2}_{2} + \mathbf Tr(\mathbf \Sigma_{1} + \mathbf \Sigma_{2} - 2(\mathbf \Sigma_{1}^{1/2}\mathbf \Sigma_{2}\mathbf \Sigma_{1}^{1/2})^{1/2})\)
- Parameters
boxes1 – \((x_{1},y_{1},w_{1},h_{1},\theta_{1})\), shape: [-1, 5]
boxes2 – \((x_{2},y_{2},w_{2},h_{2},\theta_{2})\), shape: [-1, 5]
- Returns
wasserstein distance, \(\mathbf D_{w}\)
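A minimal usage sketch (assuming these functions build TensorFlow 1.x graph ops over the documented [-1, 5] box tensors):

import tensorflow as tf
from alpharotate.utils import gaussian_metric

# Two rotated boxes that differ only in angle.
boxes1 = tf.constant([[50., 50., 10., 4., -30.]])
boxes2 = tf.constant([[50., 50., 10., 4., -60.]])

gwd = gaussian_metric.gaussian_wasserstein_distance(boxes1, boxes2)
kld = gaussian_metric.gaussian_kullback_leibler_divergence(boxes1, boxes2)

with tf.Session() as sess:  # TF 1.x style session
    print(sess.run([gwd, kld]))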
- alpharotate.utils.gaussian_metric.kullback_leibler_divergence(mu1, mu2, mu1_T, mu2_T, sigma1, sigma2)[source]¶
Calculate the kullback-leibler divergence between two Gaussian distributions: \(\mathbf D_{kl} = \frac{1}{2}(\mathbf \mu_{1}-\mathbf \mu_{2})^{T}\mathbf \Sigma_{2}^{-1}(\mathbf \mu_{1}-\mathbf \mu_{2})+\frac{1}{2}\mathbf{Tr}(\mathbf \Sigma_{2}^{-1}\mathbf \Sigma_{1})+\frac{1}{2}\ln\frac{|\mathbf \Sigma_{2}|}{|\mathbf \Sigma_{1}|}-1\)
- Parameters
mu1 – mean \((\mu_{1})\) of the Gaussian distribution, shape: [-1, 1, 2]
mu2 – mean \((\mu_{2})\) of the Gaussian distribution, shape: [-1, 1, 2]
mu1_T – transposition of \((\mu_{1})\), shape: [-1, 2, 1]
mu2_T – transposition of \((\mu_{2})\), shape: [-1, 2, 1]
sigma1 – covariance \((\Sigma_{1})\) of the Gaussian distribution, shape: [-1, 2, 2]
sigma2 – covariance \((\Sigma_{2})\) of the Gaussian distribution, shape: [-1, 2, 2]
- Returns
kullback-leibler divergence, \(\mathbf D_{kl}\)
- alpharotate.utils.gaussian_metric.wasserstein_distance_item2(sigma1, sigma2)[source]¶
Calculate the second term of wasserstein distance: \(\mathbf Tr(\mathbf \Sigma_{1} + \mathbf \Sigma_{2} - 2(\mathbf \Sigma_{1}^{1/2}\mathbf \Sigma_{2}\mathbf \Sigma_{1}^{1/2})^{1/2})\)
- Parameters
sigma1 – covariance \((\Sigma_{1})\) of the Gaussian distribution, shape: [-1, 2, 2]
sigma2 – covariance \((\Sigma_{2})\) of the Gaussian distribution, shape: [-1, 2, 2]
- Returns
the second term of wasserstein distance