lane_detection

LaneIoU2D

class stardust.metric.lane_detection.LaneIoU2D

Bases: object

compute(gt_lanes, pd_lanes, IoU_thr, kpt_thr, img_size, lane_type)

Compute lane-level and keypoint-level metrics by comparing predicted lanes against ground-truth lanes.

Args:
gt_lanes: List[Point]

The ground-truth lanes

pd_lanes: List[Point]

The predicted lanes

IoU_thr: float

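The IoU threshold for a predicted lane to count as a true positive (TP)

kpt_thr: float

The distance threshold for a predicted keypoint to count as a true positive (TP)

img_size: tuple

The (height, width) of the original data

lane_type: str

The lane type to compute the metric for; only '2D' is currently supported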

Returns:
Tuple:

The metrics lane_gt, lane_pd, lane_tp, kpt_pd and kpt_tp, in that order
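This reference does not say how LaneIoU2D measures the IoU between two lanes. One common approach for 2D lanes, used by the CULane benchmark, draws each lane as a thick polyline on a canvas of size img_size and takes the pixel IoU of the two masks; a predicted lane whose IoU reaches IoU_thr then counts as a true positive. The sketch below illustrates that idea in pure Python, assuming lanes are lists of (x, y) tuples and a default line width of 5 px. The helpers rasterize_lane and lane_iou and the width parameter are hypothetical, not part of stardust.

```python
import math
from typing import List, Tuple

# Assumption: a plain (x, y) tuple stands in for stardust's Point type.
Point = Tuple[float, float]


def _dist_to_segment(px: float, py: float, a: Point, b: Point) -> float:
    """Euclidean distance from pixel (px, py) to the segment a-b."""
    ax, ay = a
    bx, by = b
    dx, dy = bx - ax, by - ay
    seg_len2 = dx * dx + dy * dy
    if seg_len2 == 0:
        return math.hypot(px - ax, py - ay)
    # Project the pixel onto the segment and clamp to its endpoints.
    t = max(0.0, min(1.0, ((px - ax) * dx + (py - ay) * dy) / seg_len2))
    return math.hypot(px - (ax + t * dx), py - (ay + t * dy))


def rasterize_lane(lane: List[Point], img_size: Tuple[int, int], width: float) -> set:
    """Return the (row, col) pixels within width / 2 of the lane polyline."""
    height, img_width = img_size
    half = width / 2.0
    pixels = set()
    for row in range(height):
        for col in range(img_width):
            for a, b in zip(lane, lane[1:]):
                if _dist_to_segment(col, row, a, b) <= half:
                    pixels.add((row, col))
                    break
    return pixels


def lane_iou(gt: List[Point], pd: List[Point],
             img_size: Tuple[int, int], width: float = 5.0) -> float:
    """Pixel IoU of two lanes drawn as thick polylines (hypothetical helper)."""
    gt_px = rasterize_lane(gt, img_size, width)
    pd_px = rasterize_lane(pd, img_size, width)
    union = gt_px | pd_px
    return len(gt_px & pd_px) / len(union) if union else 0.0
```

Visiting every pixel for every segment is slow at full resolution; a practical version would draw the lines with an imaging library. The line width also shapes the score: with thin lines, small lateral offsets lower the IoU sharply.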

Examples:
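from stardust.metric.lane_detection import LaneIoU2D
# Point is stardust's point type; its import path is not shown in this reference.
metric = LaneIoU2D()
gt_lanes = [Point(1, 1), Point(2, 2), Point(3, 3)]
pd_lanes = [Point(1, 1), Point(2, 2), Point(3, 3)]
# Arguments: IoU_thr=0.5, kpt_thr=3, img_size=(height, width), lane_type='2D'
lane_gt, lane_pd, lane_tp, kpt_pd, kpt_tp = metric.compute(
    gt_lanes, pd_lanes, 0.5, 3, (256, 512), '2D')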
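The keypoint counts kpt_pd and kpt_tp depend on kpt_thr, which this reference describes only as a distance threshold. A minimal sketch, assuming Euclidean distance in pixels and greedy one-to-one matching of each predicted point to its nearest still-unmatched ground-truth point, could look like this. The helper count_keypoint_tp is hypothetical, and stardust's actual matching rule may differ.

```python
import math
from typing import List, Tuple

# Assumption: a plain (x, y) tuple stands in for stardust's Point type.
Point = Tuple[float, float]


def count_keypoint_tp(gt: List[Point], pd: List[Point], kpt_thr: float) -> Tuple[int, int]:
    """Return (kpt_pd, kpt_tp): the number of predicted points and how many
    lie within kpt_thr of a not-yet-matched ground-truth point."""
    unmatched = list(gt)
    tp = 0
    for px, py in pd:
        # Find the closest ground-truth point that has not been used yet.
        best_i, best_d = -1, math.inf
        for i, (gx, gy) in enumerate(unmatched):
            d = math.hypot(px - gx, py - gy)
            if d < best_d:
                best_i, best_d = i, d
        if best_i >= 0 and best_d <= kpt_thr:
            tp += 1
            unmatched.pop(best_i)  # each ground-truth point is matched at most once
    return len(pd), tp
```

Removing matched ground-truth points prevents two predictions stacked on the same spot from both counting as true positives.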

compute_metric_single_frame

stardust.metric.lane_detection.compute_metric_single_frame(gt_lanes, pd_lanes, IoU_thr, kpt_thr, img_size, lane_type)

Compute the metrics of all lanes in a single frame

Args:
gt_lanes: List

The ground-truth lanes

pd_lanes: List

The predicted lanes

IoU_thr: float

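The IoU threshold for a predicted lane to count as a true positive (TP)

kpt_thr: float

The distance threshold for a predicted keypoint to count as a true positive (TP)

img_size: tuple

The (height, width) of the original data

lane_type: str

The lane type to compute the metric for; only '2D' is currently supported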

Returns:

metric: tuple
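compute_metric_single_frame has to decide which predicted lanes match which ground-truth lanes before it can count true positives. The sketch below assumes a greedy one-to-one matching: every (ground truth, prediction) pair is scored, pairs are taken from the highest IoU down, and a pair is accepted while its IoU is at least IoU_thr. This is a stand-in technique (a Hungarian-style optimal assignment is a common alternative). The function match_lanes and its iou_fn parameter are hypothetical; any pairwise lane IoU function can be plugged in.

```python
from typing import Callable, Sequence, Tuple, TypeVar

Lane = TypeVar("Lane")


def match_lanes(gt_lanes: Sequence[Lane], pd_lanes: Sequence[Lane],
                iou_fn: Callable[[Lane, Lane], float],
                iou_thr: float) -> Tuple[int, int, int]:
    """Greedy one-to-one lane matching; returns (lane_gt, lane_pd, lane_tp)."""
    # Score every (gt, pd) pair, sorted from highest IoU to lowest.
    pairs = sorted(
        ((iou_fn(g, p), gi, pi)
         for gi, g in enumerate(gt_lanes)
         for pi, p in enumerate(pd_lanes)),
        reverse=True,
    )
    used_gt, used_pd = set(), set()
    tp = 0
    for iou, gi, pi in pairs:
        if iou < iou_thr:
            break  # all remaining pairs are below the threshold
        if gi in used_gt or pi in used_pd:
            continue  # each lane may take part in at most one match
        used_gt.add(gi)
        used_pd.add(pi)
        tp += 1
    return len(gt_lanes), len(pd_lanes), tp
```

Greedy matching is simple and fast but can miss the assignment that maximizes the number of matches when IoU scores conflict; an optimal assignment avoids that at a higher cost.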

compute_metric

stardust.metric.lane_detection.compute_metric(data, IoU_thr=0.5, kpt_thr=3, save_path=None)

Compute lane metrics for every frame and over all frames

Args:
data: Generator

A generator that yields the information of each frame

IoU_thr: float

The IoU threshold for a predicted lane to count as a true positive (TP)

kpt_thr: float

The distance threshold for a predicted keypoint to count as a true positive (TP)

save_path: str

The local path to save the metric results to

Returns:
Tuple:

The first element holds the metrics of each individual frame; the second holds the metrics over all frames
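This reference does not say how the all-frames metric is derived from the per-frame ones. Assuming each per-frame result carries the same (lane_gt, lane_pd, lane_tp, kpt_pd, kpt_tp) counts that LaneIoU2D.compute returns, one natural aggregation sums the counts over frames and derives precision, recall, and F1 from the totals. The function aggregate and its dictionary keys are hypothetical.

```python
from typing import Dict, Iterable, Tuple

# Assumed per-frame layout: (lane_gt, lane_pd, lane_tp, kpt_pd, kpt_tp)
FrameCounts = Tuple[int, int, int, int, int]


def _safe_div(num: float, den: float) -> float:
    """Divide, returning 0.0 when the denominator is zero."""
    return num / den if den else 0.0


def aggregate(frames: Iterable[FrameCounts]) -> Dict[str, float]:
    """Sum per-frame counts, then derive overall lane precision, recall,
    F1 and keypoint precision from the totals (micro-averaging)."""
    lane_gt = lane_pd = lane_tp = kpt_pd = kpt_tp = 0
    for g, p, t, kp, kt in frames:
        lane_gt += g
        lane_pd += p
        lane_tp += t
        kpt_pd += kp
        kpt_tp += kt
    precision = _safe_div(lane_tp, lane_pd)
    recall = _safe_div(lane_tp, lane_gt)
    f1 = _safe_div(2 * precision * recall, precision + recall)
    return {
        "lane_precision": precision,
        "lane_recall": recall,
        "lane_f1": f1,
        "kpt_precision": _safe_div(kpt_tp, kpt_pd),
    }
```

Summing counts before dividing weights every lane equally, so a frame with many lanes counts more than a nearly empty one; averaging per-frame ratios instead would weight every frame equally.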

Examples:
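from stardust.metric.lane_detection import compute_metric
from stardust.rosetta.rosetta_data import RosettaData
# read_rosetta must also be imported; its module path is not shown in this reference.

project_id = 856
input_path = 'path/to/rosetta/export'  # placeholder: local path of the exported data
json_datas = read_rosetta(project_id=project_id,
                          input_path=input_path)
frame_metrics, overall_metric = compute_metric(json_datas, IoU_thr=0.5, kpt_thr=3,
                                               save_path='local/')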