# Source code for stardust.utils.convert

import glob
import json
import os
import re
import shutil
from numbers import Number
from pathlib import Path
from typing import Iterable

from stardust.components.frame import *
from stardust.file.download_files import load_dataset
from stardust.components.annotations import *
from stardust.components.attachment.image import *
from stardust.components.attachment.pointcloud import *
from stardust.components.camera.camera import *
from stardust.rosetta.rosetta_data import RosettaData as Export
from stardust.ms.ms import *

# Public entry points exported by ``from stardust.utils.convert import *``.
__all__ = ["read_rosetta", "read_ms"]


def _serialize(obj):
    dic = dict()
    if hasattr(obj, "__dict__"):
        for key, val in obj.__dict__.items():
            if str(key).startswith("__") and str(key).endswith("__"):
                continue
            if str(key) in ("id", "shapely_polygon"):
                continue
            dic[key] = _serialize(val)
    elif isinstance(obj, (str, Number)):
        return obj
    elif isinstance(obj, Dict):
        for key, val in obj.items():
            dic[key] = _serialize(val)
    elif isinstance(obj, Iterable):
        return [_serialize(val) for val in obj]
    return dic


class Convertor:
    """
    Convert Rosetta / MorningStar task JSON into stardust ``Frame`` objects.

    The converter is stateful. While one frame is processed, the current
    annotation container lives in ``self.label_result``, and the
    ``factory_*`` helpers register every parsed object into it.
    """

    def __init__(self) -> None:
        self.project_id = None
        self.input_path = None
        # Matches file names ending in "_<frame number>.json".
        self.pattern = re.compile(r'_(\d+)\.json')
        # Per-frame state, (re)assigned while converting.
        self.file_name = None
        self.frame_num = 0
        self.label_result = None
        self.task_info = None
        self.media = None

    def yield_rosetta_json(self):
        """
        Yield the parsed JSON files downloaded from Rosetta for this project.

        Files are searched under ``<input_path>/<project_id>/json/*/*.json``
        first. If none are found, ``<input_path>/<project_id>/json/*.json`` is
        searched instead. Files that are not valid JSON are reported with
        ``print`` and skipped; any other error propagates.

        Before each yield, ``self.file_name`` is set to the file path.
        ``self.frame_num`` is set to the trailing ``_<n>`` number of the file
        name, or 0 when the name has no such suffix.

        Yields:
            The decoded JSON content of each file.
        """
        # Without recursive=True, "**" acts like "*": exactly one sub-directory level.
        if not (rosetta_json_lst := glob.glob(os.path.join(self.input_path, f"{self.project_id}/json/**/*.json"))):
            rosetta_json_lst = glob.glob(os.path.join(self.input_path, f"{self.project_id}/json/*.json"))

        for file in rosetta_json_lst:
            try:
                with open(file, 'r', encoding='utf-8') as jf:
                    data = json.load(jf)
            except json.decoder.JSONDecodeError as e:
                print(e)
                continue

            self.file_name = str(file)
            # Names without a "_<n>" suffix get frame 0. The old unconditional
            # rsplit(...)[1] raised IndexError for them.
            match = self.pattern.search(Path(file).name)
            self.frame_num = int(match.group(1)) if match else 0
            yield data

    def factory_children(self, children: List[Dict],
                         parent_id: str = None):
        """
        Process the children of an annotation instance.

        ``input`` children are registered in ``self.label_result.input_lst``.
        ``slotChildren`` and ``slot`` children are dispatched to
        ``factory_slot_children`` / ``factory_slots``, which record the ids
        they create in the returned mapping.

        Args:
            children: Rosetta ``children`` list.
            parent_id: id of the owning slot; passed on to created objects.

        Returns:
            A tuple ``(children_lst, label_kind)``. ``children_lst`` maps child
            id to its type name. ``label_kind`` is the value of the input named
            "type"/"类型"/"类别", or ``None`` if there is none.

        Raises:
            ValueError: for an unknown child type.
        """
        children_lst = dict()
        label_kind = None
        for child in children:
            child_type = child['type']
            if child_type == "input":
                input_obj = Input.gen_input(child, parent_id)
                children_lst[input_obj.id] = "input"
                if input_obj.name in ("type", "类型", "类别"):
                    label_kind = input_obj.value
                self.label_result.input_lst.setdefault(input_obj.id, input_obj)
            elif child_type == "slotChildren":
                self.factory_slot_children(child['slotsChildren'], parent_id, children_lst)
            elif child_type == "slot":
                self.factory_slots(child['slots'], parent_id, children_lst)
            else:
                raise ValueError(f"unsupported child type: {child_type!r}")
        return children_lst, label_kind

    def factory_slot(self, slot: Dict,
                     children_lst: Dict = None,
                     parent_id: str = None,
                     label_kind: str = None,
                     team_id: int = None) -> Optional[str]:
        """
        Build one annotation object from a Rosetta slot and register it.

        The object is stored in the matching ``self.label_result.*_lst`` dict,
        keyed by its id.

        Args:
            slot: Rosetta slot; ``slot['type']`` selects the object kind.
            children_lst: children mapping produced by ``factory_children``
                (callers pass an empty list when there are no children).
            parent_id: id of the owning object, if any.
            label_kind: value of the "type" input found among the children.
            team_id: Rosetta ``teamId`` of the slot.

        Returns:
            The id of the registered object. Returns ``None`` when nothing was
            registered: a cuboid or unsupported slot, or a ``gen_*`` factory
            that returned a falsy value.
        """
        obj_type = slot['type']
        if obj_type == "box3d":  # 3D box
            box3d = Box3D.gen_box3d(slot, children_lst, parent_id, label_kind, team_id=team_id)
            if not box3d:
                return None
            self.label_result.box3d_lst.setdefault(box3d.id, box3d)
            return box3d.id
        elif obj_type == "box2d":  # 2D box
            box2d = Box2D.gen_box(slot, children_lst, parent_id, label_kind, team_id=team_id)
            if not box2d:
                return None
            self.label_result.box2d_lst.setdefault(box2d.id, box2d)
            return box2d.id
        elif obj_type == "cuboid":  # projected cuboid: not handled yet
            return None
        elif obj_type == "line":  # 2D line
            line = Line.gen_line(slot, children_lst, parent_id, label_kind, team_id=team_id)
            if not line:
                return None
            self.label_result.line_lst.setdefault(line.id, line)
            return line.id
        elif obj_type == "polygon":  # polygon
            polygon = Polygon.gen_polygon(slot, children_lst, parent_id, label_kind, team_id=team_id)
            if not polygon:
                return None
            self.label_result.polygon_lst.setdefault(polygon.id, polygon)
            return polygon.id
        elif obj_type == "point":  # 2D key point
            point = Point.gen_point(slot, children_lst, parent_id, label_kind, team_id=team_id)
            if not point:
                return None
            self.label_result.key_point_lst.setdefault(point.id, point)
            # Return the id like every other kind so callers can record it.
            return point.id
        elif obj_type == "text":  # text annotation
            text = Text.gen_text(slot, children_lst, parent_id, label_kind, team_id=team_id)
            if not text:
                return None
            self.label_result.text_lst.setdefault(text.id, text)
            return text.id
        else:
            # Message reads "unsupported structure <type>"; reported, not raised.
            print(f'不支持的结构{obj_type}')
            return None

    def factory_slots(self, slot_lst: List[Dict],
                      parent_id: str = None,
                      children_lst: Dict = None) -> None:
        """
        Process a list of standalone slots (slots without children).

        Args:
            slot_lst: Rosetta ``slots`` list.
            parent_id: id of the owning object, if any.
            children_lst: optional dict that receives ``{slot_id: slot type}``
                for every object that was registered.
        """
        # A fresh dict per call: the old shared ``{}`` default leaked ids
        # between unrelated calls.
        if children_lst is None:
            children_lst = {}
        for slot in slot_lst:
            slot_id = self.factory_slot(slot, list(), parent_id=parent_id, team_id=slot.get("teamId", None))
            if slot_id:
                # Was ``children_lst[slot_id]: slot['type']``, which is an
                # annotation statement and never stored anything.
                children_lst[slot_id] = slot['type']

    def factory_slot_children(self, sc_lst: List[Dict],
                              parent_id: str = None,
                              children_lst: Dict = None) -> None:
        """
        Process annotation instances made of a slot plus optional children.

        Args:
            sc_lst: Rosetta ``slotsChildren`` list; each item has a ``slot`` and
                optionally a ``children`` list.
            parent_id: id of the owning object, if any.
            children_lst: optional dict that receives ``{slot_id: slot type}``
                for every instance that has children.
        """
        # A fresh dict per call: the old shared ``{}`` default accumulated ids
        # across every call made without an explicit mapping.
        if children_lst is None:
            children_lst = {}
        for sc in sc_lst:
            slot = sc['slot']
            slot_id = slot['id']
            team_id = slot.get("teamId", None)
            if sc.get("children"):
                _children_lst, label_kind = self.factory_children(sc['children'], slot_id)
                self.factory_slot(slot, _children_lst, parent_id, label_kind, team_id=team_id)
                if slot_id:
                    children_lst[slot_id] = slot['type']
            else:
                self.factory_slot(slot, list(), parent_id, team_id=team_id)

    def factory_anns(self, ann_lst: List) -> None:
        """
        Process a top-level annotation list into ``self.label_result``.

        Handled ``type`` values are ``input``, ``slotChildren`` and ``slot``.
        ``childrenOnly`` and any other type are ignored.

        Args:
            ann_lst: Rosetta annotation result list.
        """
        for ann in ann_lst:
            ann_type = ann['type']
            if ann_type == "input":
                input_obj = Input.gen_input(ann)
                self.label_result.input_lst.setdefault(input_obj.id, input_obj)
            elif ann_type == "slotChildren":
                self.factory_slot_children(ann['slotsChildren'])
            elif ann_type == "slot":
                self.factory_slots(ann['slots'])
            elif ann_type == "childrenOnly":
                pass

    def factory_label_file(self, label_type: str,
                           label_file,
                           **kwargs) -> "Media":
        """
        Build the media (attachment) description of a single frame.

        Args:
            label_type: attachment type; values containing "POINTCLOUD" or
                "IMAGE" are handled.
            label_file: attachment, either a dict with a ``url`` key (point
                clouds may also carry ``imageSources``) or a plain url string.
            **kwargs: ``size`` - optional dict with ``width``/``height`` for
                images.

        Returns:
            A ``Media`` instance. For any other attachment type an empty dict
            is returned.
        """
        is_mapping = isinstance(label_file, dict)
        if "POINTCLOUD" in label_type:
            pcd_url = label_file['url'] if is_mapping else label_file
            meta_point = PointCloud(
                uri=pcd_url,
                name=None,
                file_path=Path(pcd_url)
            )
            meta_images = list()
            # Only dict attachments can carry images. On a plain url string the
            # old ``in`` check was a substring test.
            if is_mapping:
                for image in label_file.get('imageSources') or []:
                    meta_images.append(Image.gen_image(image))
            media = Media(point_cloud=meta_point, image=meta_images)
        elif "IMAGE" in label_type:
            image_url = label_file['url'] if is_mapping else label_file
            size = kwargs.get("size")
            media = Media(
                point_cloud=None,
                image=Image(uri=image_url,
                            width=size.get("width") if size else None,
                            height=size.get("height") if size else None
                            )
            )
        else:
            media = {}

        return media

    def factory_rosetta(self, data: Dict) -> "Frame":
        """
        Convert one task record (a single frame) into a ``Frame``.

        Reads from ``data``:

        - ``taskParams.record.attachment`` and ``attachmentType`` (media);
        - ``taskParams.record.metadata``, which is optional (``size``,
          ``preprocessedData.annotations``);
        - ``taskId``, ``projectId`` and ``poolId`` (task info);
        - ``result.annotations`` (manual annotations).

        Args:
            data: Rosetta (or MorningStar) task record.

        Returns:
            A ``Frame``. Its ``prediction`` is ``None`` when the record has no
            metadata.
        """
        record = data['taskParams']['record']
        label_file = record['attachment']
        label_type = record['attachmentType']
        # Metadata is optional; read it once and guard every use consistently.
        metadata = record.get('metadata')

        # Rosetta task information
        self.task_info = TaskInfo(
            task_id=data['taskId'],
            project_id=data['projectId'],
            pool_id=data['poolId'],
            frame_num=self.frame_num
        )

        # Attachment / media description
        self.media = self.factory_label_file(label_type=label_type,
                                             label_file=label_file,
                                             size=metadata.get("size") if metadata else None
                                             )

        # Manual annotations: the factory_* helpers register into self.label_result.
        self.label_result = Annotation()
        self.factory_anns(data['result']['annotations'])
        annotation = self.label_result

        # Pre-annotation (algorithm) results go into a separate container.
        self.label_result = Prediction()
        if not metadata:
            prediction = None
        else:
            pres_data = metadata.get("preprocessedData")
            if pres_data and pres_data.get("annotations"):
                self.factory_anns(pres_data['annotations'])
            prediction = self.label_result

        return Frame(media=self.media,
                     task_info=self.task_info,
                     annotation=annotation,
                     prediction=prediction
                     )

    def convert_rosetta(self) -> Iterable["Frame"]:
        """
        Yield one converted ``Frame`` per downloaded Rosetta JSON file.

        Returns:
            A generator of ``Frame`` objects.
        """
        for json_data in self.yield_rosetta_json():
            yield self.factory_rosetta(json_data)

    def export_rosetta(self, pool_lst=None, split_name='old', env_name='top') -> None:
        """
        Export Rosetta project data into ``self.input_path``.

        Args:
            pool_lst: pool ids to export; ``None`` means an empty list.
            split_name: forwarded to ``Export.export``.
            env_name: forwarded to the ``Export`` constructor.
        """
        # Avoid a shared mutable default list.
        if pool_lst is None:
            pool_lst = []
        Export(self.project_id, self.input_path, env_name).export(pool_lst, split_name=split_name)

    def read_rosetta(self, project_id,
                     input_path,
                     pool_lst,
                     export_type=None,
                     **kwargs):
        """
        Export a Rosetta project and convert it frame by frame.

        This is a generator: validation, removal of the previous export and
        the new export only run once iteration starts. To produce other
        formats, see the stardust conversion package.

        Args:
            project_id: Rosetta project id (must be a non-zero int).
            input_path: directory where the exported Rosetta JSON is stored.
            pool_lst: pool ids of the project to export.
            export_type: ``"json"`` yields plain dicts (via ``_serialize``);
                anything else yields ``Frame`` objects.
            **kwargs: ``split`` and ``env_name`` are forwarded to
                ``export_rosetta`` as ``split_name`` / ``env_name``.

        Yields:
            One converted frame per exported JSON file. Falsy frames are skipped.
        """
        assert project_id and isinstance(project_id, int)

        self.project_id = project_id
        self.input_path = input_path

        # An existing export of this project is deleted before exporting again
        # (presumably so stale files are not re-read). The message reads
        # "start fetching annotation results".
        project_dir = os.path.join(self.input_path, str(self.project_id))
        if os.path.exists(os.path.join(project_dir, "json")):
            print("开始获取标注结果")
            shutil.rmtree(project_dir)

        # NOTE(review): when "split"/"env_name" are omitted, None is passed,
        # which overrides export_rosetta's 'old'/'top' defaults. Confirm this
        # is intended.
        self.export_rosetta(pool_lst, split_name=kwargs.get("split"), env_name=kwargs.get("env_name"))

        for frame_obj in self.convert_rosetta():
            if not frame_obj:
                continue
            yield _serialize(frame_obj) if export_type == "json" else frame_obj

    def read_ms(self, export_type=None, **kwargs):
        """
        Read a data set from MorningStar and convert it frame by frame.

        Args:
            export_type: ``"json"`` yields plain dicts (via ``_serialize``);
                anything else yields ``Frame`` objects.
            **kwargs: forwarded to ``MS().export_dataset``.

        Yields:
            One converted frame per MorningStar instance. Inside each returned
            batch, the frame number is the instance's index.

        Open questions (from the original author):
            - there is no task info;
            - there is no annotation file;
            - the relation between frames and annotation files is unclear.
        """
        ms_data = MS().export_dataset(**kwargs)
        assert ms_data

        # The data set arrives as batches of instances; each instance is one frame.
        for instance_lst in ms_data:
            for index, instance in enumerate(instance_lst):
                self.frame_num = index
                res = self.factory_rosetta(instance)
                yield _serialize(res) if export_type == "json" else res


def read_rosetta(project_id: int = None,
                 input_path: str = None,
                 pool_lst: Optional[List[int]] = None,
                 export_type=None,
                 **kwargs):
    """
    Export a Rosetta project and convert it.

    The output is either SDK ``Frame`` objects (the default) or JSON-ready
    dicts. To produce other formats, see the stardust conversion package.

    Args:
        project_id: Rosetta project id.
        input_path: path where the JSON data exported from Rosetta is stored.
        pool_lst: pool ids from your Rosetta project.
        export_type: expected output type; ``"json"`` for dicts, anything
            else for ``Frame`` objects.
        **kwargs: forwarded to ``Convertor.read_rosetta`` (e.g. ``split``,
            ``env_name``).

    Returns:
        A generator of converted frames.
    """
    # Checked eagerly here; the generator below only starts work when iterated.
    assert project_id and input_path
    return Convertor().read_rosetta(project_id, input_path, pool_lst, export_type, **kwargs)
def read_ms(*args, **kwargs):
    """
    Export MorningStar data.

    All arguments are forwarded to ``Convertor.read_ms``. Its first positional
    argument is ``export_type``; the keyword arguments go on to
    ``MS().export_dataset``.

    Args:
        dataset_id: Data set ID
        version_num: Version number
        slice_id: Slice ID
        page_no: Slice paging
        page_size: Amount of data per page
        export_type: ``"json"`` to yield dicts instead of ``Frame`` objects

    Returns:
        generator object

    Examples:
        .. code-block:: python

            from stardust.utils.convert import read_ms
            from stardust.convertion.to_pandaset import to_pandaset

            # Export the data set
            gen_data = read_ms(
                dataset_id=351787480925605888,
                version_num=18
            )
            # Write it out with to_pandaset
            to_pandaset(gen_data, export_path="Input save directory")
    """
    return Convertor().read_ms(*args, **kwargs)
if __name__ == '__main__':
    # Manual run: export a Rosetta project and convert every frame to a dict.
    current_dir = os.getcwd()
    base_dir = os.path.dirname(os.path.dirname(current_dir))
    project_id = 1354
    pool_lst = [33750, ]
    # Rosetta data is saved under <input_path>/<project_id>/json/*.json
    input_path = f"{base_dir}/data/"
    os.makedirs(input_path, exist_ok=True)
    gen_data = read_rosetta(project_id=project_id,
                            input_path=input_path,
                            pool_lst=pool_lst,
                            export_type="json"
                            )
    os.makedirs(os.path.join(input_path, "output"), exist_ok=True)
    # The generator is lazy: iterating it performs the export and conversion.
    for _ in gen_data:
        pass