lane_detection
LaneIoU2D
- class stardust.metric.lane_detection.LaneIoU2D
Bases: object
- compute(gt_lanes, pd_lanes, IoU_thr, kpt_thr, img_size, lane_type)
Compute lane metrics for the given ground-truth and predicted lanes.
- Args:
- gt_lanes: List[Point]
The ground truth of lanes
- pd_lanes: List[Point]
The predicted lanes
- IoU_thr: float
IoU threshold for counting a predicted lane as a true positive (TP)
- kpt_thr: float
Distance threshold for counting a predicted keypoint as a TP
- img_size: tuple
Height and width of the original image
- lane_type: str
Type of lane to compute the metric for; only '2D' is currently supported
- Returns:
- Tuple:
The metric values lane_gt, lane_pd, lane_tp, kpt_pd and kpt_tp, in that order
- Examples:
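```python
from stardust.metric.lane_detection import LaneIoU2D

# Point is the library's point type; its import path is not shown on this page.
metric = LaneIoU2D()
gt_lanes = [Point(1, 1), Point(2, 2), Point(3, 3)]
pd_lanes = [Point(1, 1), Point(2, 2), Point(3, 3)]
# All six arguments are required; img_size is (height, width).
lane_gt, lane_pd, lane_tp, kpt_pd, kpt_tp = metric.compute(
    gt_lanes, pd_lanes, IoU_thr=0.5, kpt_thr=3, img_size=(256, 512), lane_type='2D'
)
```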
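The docstring does not say how the IoU between two lanes is computed. A common approach in lane-detection evaluation is to draw each lane as a band of fixed width and take the pixel IoU of the two bands. The sketch below illustrates that idea in pure Python; it is not LaneIoU2D's actual implementation, and representing a point as an (x, y) tuple, the `half_width` parameter, and the helper names `rasterize` and `lane_iou` are all hypothetical.

```python
import math
from typing import List, Set, Tuple

Point = Tuple[float, float]  # hypothetical stand-in for the library's Point: (x, y) in pixels


def _dist_to_segment(px: float, py: float, ax: float, ay: float, bx: float, by: float) -> float:
    """Euclidean distance from point P to segment AB."""
    dx, dy = bx - ax, by - ay
    seg_len2 = dx * dx + dy * dy
    if seg_len2 == 0:  # degenerate segment: A == B
        return math.hypot(px - ax, py - ay)
    # Project P onto AB and clamp the projection to the segment.
    t = max(0.0, min(1.0, ((px - ax) * dx + (py - ay) * dy) / seg_len2))
    return math.hypot(px - (ax + t * dx), py - (ay + t * dy))


def rasterize(lane: List[Point], img_size: Tuple[int, int], half_width: float) -> Set[Tuple[int, int]]:
    """Return the (x, y) pixels lying within half_width of the lane polyline."""
    height, width = img_size  # same (height, width) order as img_size above
    pixels = set()
    for y in range(height):
        for x in range(width):
            if any(
                _dist_to_segment(x, y, *lane[i], *lane[i + 1]) <= half_width
                for i in range(len(lane) - 1)
            ):
                pixels.add((x, y))
    return pixels


def lane_iou(gt: List[Point], pd: List[Point], img_size: Tuple[int, int], half_width: float = 2.0) -> float:
    """Pixel IoU of two lanes drawn as bands of width 2 * half_width."""
    a = rasterize(gt, img_size, half_width)
    b = rasterize(pd, img_size, half_width)
    union = len(a | b)
    return len(a & b) / union if union else 0.0
```

The brute-force loop costs O(height × width × segments), which is fine for illustration; a practical version would draw the bands with an image library instead.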
compute_metric_single_frame
- stardust.metric.lane_detection.compute_metric_single_frame(gt_lanes, pd_lanes, IoU_thr, kpt_thr, img_size, lane_type)
Compute the metric for all lanes in a single frame.
- Args:
- gt_lanes: List
The ground truth of lanes
- pd_lanes: List
The predicted lanes
- IoU_thr: float
IoU threshold for counting a predicted lane as a true positive (TP)
- kpt_thr: float
Distance threshold for counting a predicted keypoint as a TP
- img_size: tuple
Height and width of the original image
- lane_type: str
Type of lane to compute the metric for; only '2D' is currently supported
- Returns:
- Tuple:
The metric of the frame
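How compute_metric_single_frame pairs predictions with ground truth is not documented. One plausible scheme, sketched below under that assumption, greedily matches each predicted lane to at most one ground-truth lane whose IoU is at least IoU_thr, then counts a predicted keypoint of a matched lane as a TP if it lies within kpt_thr of some ground-truth keypoint. The IoU matrix is taken as input (it could come from any lane IoU function), and `greedy_match`, `count_keypoint_tp` and `frame_counts` are hypothetical names, not part of stardust.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]  # hypothetical stand-in: (x, y) in pixels


def greedy_match(iou: List[List[float]], iou_thr: float) -> List[Tuple[int, int]]:
    """One-to-one (gt_idx, pd_idx) pairs, taking the highest-IoU candidates first."""
    candidates = sorted(
        ((v, g, p) for g, row in enumerate(iou) for p, v in enumerate(row) if v >= iou_thr),
        reverse=True,
    )
    used_gt, used_pd, pairs = set(), set(), []
    for _, g, p in candidates:
        if g not in used_gt and p not in used_pd:
            used_gt.add(g)
            used_pd.add(p)
            pairs.append((g, p))
    return pairs


def count_keypoint_tp(gt: Sequence[Point], pd: Sequence[Point], kpt_thr: float) -> int:
    """Predicted points with at least one ground-truth point within kpt_thr."""
    return sum(1 for p in pd if any(math.dist(p, g) <= kpt_thr for g in gt))


def frame_counts(gt_lanes, pd_lanes, iou, iou_thr, kpt_thr):
    """Return (lane_gt, lane_pd, lane_tp, kpt_pd, kpt_tp) for one frame."""
    pairs = greedy_match(iou, iou_thr)
    kpt_pd = sum(len(lane) for lane in pd_lanes)
    # Keypoint TPs are only counted on lanes that were matched.
    kpt_tp = sum(count_keypoint_tp(gt_lanes[g], pd_lanes[p], kpt_thr) for g, p in pairs)
    return len(gt_lanes), len(pd_lanes), len(pairs), kpt_pd, kpt_tp
```

Greedy matching is simple but not optimal; an assignment solver such as the Hungarian algorithm can find more matches when candidates conflict.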
compute_metric
- stardust.metric.lane_detection.compute_metric(data, IoU_thr=0.5, kpt_thr=3, save_path=None)
Compute lane metrics over all frames.
- Args:
- data: Generator
A generator that yields the data of every frame
- IoU_thr: float
IoU threshold for counting a predicted lane as a true positive (TP)
- kpt_thr: float
Distance threshold for counting a predicted keypoint as a TP
- save_path: str
Local path where the metric results are saved
- Returns:
- Tuple:
A pair whose first element is the metric of every single frame and whose second element is the metric aggregated over all frames
- Examples:
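```python
from stardust.metric.lane_detection import compute_metric
from stardust.rosetta.rosetta_data import RosettaData  # as in the original example

# read_rosetta is used below, but its import path is not shown on this page.
project_id = 856
input_path = "..."  # hypothetical placeholder: local path to the exported project data
json_datas = read_rosetta(project_id=project_id, input_path=input_path)
# Returns (per-frame metric, metric over all frames).
frame_metrics, total_metric = compute_metric(json_datas, IoU_thr=0.5, kpt_thr=3, save_path='local/')
```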
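Assuming the five values produced per frame are counts in the order (lane_gt, lane_pd, lane_tp, kpt_pd, kpt_tp), which the docstrings above do not confirm, frame results could be summed and turned into lane precision, recall and F1 plus keypoint precision as sketched below. The function `aggregate` and its result keys are hypothetical and are not what compute_metric returns.

```python
from typing import Dict, Iterable, Tuple

# Hypothetical per-frame counts: (lane_gt, lane_pd, lane_tp, kpt_pd, kpt_tp)
Counts = Tuple[int, int, int, int, int]


def _safe_div(num: float, den: float) -> float:
    """Division that returns 0.0 instead of raising on an empty denominator."""
    return num / den if den else 0.0


def aggregate(frames: Iterable[Counts]) -> Dict[str, float]:
    """Sum counts over all frames, then derive precision, recall and F1."""
    lane_gt = lane_pd = lane_tp = kpt_pd = kpt_tp = 0
    for g, p, t, kp, kt in frames:
        lane_gt += g
        lane_pd += p
        lane_tp += t
        kpt_pd += kp
        kpt_tp += kt
    precision = _safe_div(lane_tp, lane_pd)  # matched lanes / predicted lanes
    recall = _safe_div(lane_tp, lane_gt)     # matched lanes / ground-truth lanes
    return {
        "lane_precision": precision,
        "lane_recall": recall,
        "lane_f1": _safe_div(2 * precision * recall, precision + recall),
        "kpt_precision": _safe_div(kpt_tp, kpt_pd),
    }
```

Summing counts before dividing weights every lane equally (micro-averaging); averaging per-frame ratios instead would weight every frame equally. Returning 0.0 for an empty denominator is a convention chosen here for simplicity.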