from typing import Iterable
from numbers import Number
import re
from pathlib import Path
import json
import glob
import shutil
from stardust.components.frame import *
from stardust.file.download_files import load_dataset
from stardust.components.annotations import *
from stardust.components.attachment.image import *
from stardust.components.attachment.pointcloud import *
from stardust.components.camera.camera import *
from stardust.rosetta.rosetta_data import RosettaData as Export
from stardust.ms.ms import *
__all__ = [
    "read_rosetta", "read_ms"
]
def _serialize(obj):
    dic = dict()
    if hasattr(obj, "__dict__"):
        for key, val in obj.__dict__.items():
            if str(key).startswith("__") and str(key).endswith("__"):
                continue
            if str(key) in ("id", "shapely_polygon"):
                continue
            dic[key] = _serialize(val)
    elif isinstance(obj, (str, Number)):
        return obj
    elif isinstance(obj, Dict):
        for key, val in obj.items():
            dic[key] = _serialize(val)
    elif isinstance(obj, Iterable):
        return [_serialize(val) for val in obj]
    return dic
class Convertor:
    """Convert Rosetta / MorningStar export records into ``Frame`` objects.

    An instance is stateful and meant to be consumed sequentially: the
    ``factory_*`` methods collect results into ``self.label_result`` and
    ``factory_rosetta`` reads ``self.frame_num``, which is set by whichever
    method iterates the raw records (``yield_rosetta_json`` or ``read_ms``).
    """

    def __init__(self) -> None:
        self.project_id = None
        self.input_path = None
        # Captures the trailing index of file names such as "xxx_12.json".
        self.pattern = re.compile(r'_(\d+)\.json')
        # Per-record state, filled in while records are processed.
        self.file_name = None
        self.frame_num = 0
        self.label_result = None
        self.media = None
        self.task_info = None

    def yield_rosetta_json(self):
        """Yield the parsed content of every exported Rosetta JSON file.

        Files are looked up under ``<input_path>/<project_id>/json``: first
        one directory level down, then directly in that folder if nothing was
        found. Before each yield ``self.file_name`` is set to the file path
        and ``self.frame_num`` to the trailing ``_<digits>`` index of the
        file name (``0`` when the name carries no such index).

        Files that are not valid JSON are reported with ``print`` and skipped.

        Yields:
            The object decoded from each JSON file.
        """
        # glob is called without recursive=True, so "**" behaves like "*" and
        # matches exactly one sub-directory level.
        if not (rosetta_json_lst := glob.glob(os.path.join(self.input_path, f"{self.project_id}/json/**/*.json"))):
            rosetta_json_lst = glob.glob(os.path.join(self.input_path, f"{self.project_id}/json/*.json"))
        for file in rosetta_json_lst:
            try:
                with open(file, 'r', encoding='utf-8') as jf:
                    payload = json.load(jf)
            except json.decoder.JSONDecodeError as e:
                print(e)
                continue
            self.file_name = str(file)
            # Use the captured group rather than splitting the stem, so names
            # without "_" fall back to 0 instead of raising IndexError.
            index_match = self.pattern.search(Path(file).name)
            self.frame_num = int(index_match.group(1)) if index_match else 0
            yield payload

    def factory_children(self, children: "rosetta children",
                         parent_id: str = None):
        """Process the ``children`` entries of an annotation (recursive).

        Args:
            children:
                rosetta children
            parent_id:
                id of the slot the children belong to

        Returns:
            A tuple ``(children_lst, label_kind)``: ``children_lst`` maps the
            id of each registered child to its type, ``label_kind`` is the
            value of the input named "type" / "类型" / "类别" (``None`` when
            there is none).

        Raises:
            ValueError: a child has a type other than "input",
                "slotChildren" or "slot".
        """
        children_lst = dict()
        label_kind = None
        for child in children:
            child_type = child['type']
            if child_type == "input":
                input_obj = Input.gen_input(child, parent_id)
                children_lst[input_obj.id] = "input"
                if input_obj.name in ("type", "类型", "类别"):
                    label_kind = input_obj.value
                self.label_result.input_lst.setdefault(input_obj.id, input_obj)
            elif child_type == "slotChildren":
                self.factory_slot_children(child['slotsChildren'], parent_id, children_lst)
            elif child_type == "slot":
                self.factory_slots(child['slots'], parent_id, children_lst)
            else:
                raise ValueError(f"unsupported child type: {child_type!r}")
        return children_lst, label_kind

    def _store_label(self, bucket_name: str, label) -> Optional[str]:
        """Store *label* in ``self.label_result.<bucket_name>`` and return its id.

        Returns ``None`` without storing anything when *label* is falsy. An
        entry already stored under the same id is kept (``setdefault``).
        """
        if not label:
            return None
        getattr(self.label_result, bucket_name).setdefault(label.id, label)
        return label.id

    def factory_slot(self, slot: "rosetta slot",
                     children_lst: Dict = None,
                     parent_id: str = None,
                     label_kind: str = None,
                     team_id: int = None) -> Optional[str]:
        """Build the annotation object for one slot and store it.

        Args:
            slot:
                rosetta slot
            children_lst:
                rosetta children
            parent_id:
                id of the parent element
            label_kind:
                label category taken from the slot's children, if any
            team_id:
                team id attached to the slot, if any

        Returns:
            The id of the stored object, or ``None`` when nothing was stored
            ("cuboid" slots, unsupported types, or a falsy factory result).
        """
        obj_type = slot['type']
        if obj_type == "box3d":  # 3D box
            return self._store_label(
                "box3d_lst", Box3D.gen_box3d(slot, children_lst, parent_id, label_kind, team_id=team_id))
        if obj_type == "box2d":  # 2D box
            return self._store_label(
                "box2d_lst", Box2D.gen_box(slot, children_lst, parent_id, label_kind, team_id=team_id))
        if obj_type == "cuboid":  # projected cuboid: not handled
            return None
        if obj_type == "line":  # 2D line
            return self._store_label(
                "line_lst", Line.gen_line(slot, children_lst, parent_id, label_kind, team_id=team_id))
        if obj_type == "polygon":  # polygon
            return self._store_label(
                "polygon_lst", Polygon.gen_polygon(slot, children_lst, parent_id, label_kind, team_id=team_id))
        if obj_type == "point":  # 2D key point
            return self._store_label(
                "key_point_lst", Point.gen_point(slot, children_lst, parent_id, label_kind, team_id=team_id))
        if obj_type == "text":  # text annotation
            return self._store_label(
                "text_lst", Text.gen_text(slot, children_lst, parent_id, label_kind, team_id=team_id))
        # Unsupported structures are reported and skipped rather than raised.
        print(f'不支持的结构{obj_type}')
        return None

    def factory_slots(self, slot_lst: "rosetta的slots json",
                      parent_id: str = None,
                      children_lst: Dict = None) -> None:
        """Process a list of slots and register each stored one as a child.

        Args:
            slot_lst:
                rosetta slots
            parent_id:
                id of the parent element
            children_lst:
                mapping updated in place with ``{slot id: slot type}`` for
                every slot that was stored; a fresh dict is used when omitted
        """
        if children_lst is None:
            children_lst = {}
        for slot in slot_lst:
            # An empty list is passed as the slot's own children, as before.
            slot_id = self.factory_slot(slot, list(), parent_id=parent_id, team_id=slot.get("teamId", None))
            if slot_id:
                children_lst[slot_id] = slot['type']

    def factory_slot_children(self, sc_lst: "slotsChildren or slots",
                              parent_id: str = None,
                              children_lst: Dict = None) -> None:
        """Process ``slotsChildren`` entries (a slot plus optional children).

        Args:
            sc_lst:
                rosetta slotsChildren entries
            parent_id:
                id of the parent element
            children_lst:
                mapping updated in place with ``{slot id: slot type}`` for
                every entry that has children; a fresh dict is used when
                omitted
        """
        if children_lst is None:
            children_lst = {}
        for sc in sc_lst:
            slot = sc['slot']
            slot_id = slot['id']
            if sc.get("children"):
                nested_children, label_kind = self.factory_children(sc['children'], slot_id)
                self.factory_slot(slot, nested_children, parent_id, label_kind,
                                  team_id=slot.get("teamId", None))
                if slot_id:
                    children_lst[slot_id] = slot['type']
            else:
                self.factory_slot(slot, list(), parent_id, team_id=slot.get("teamId", None))

    def factory_anns(self, ann_lst: List) -> None:
        """Dispatch every top-level annotation entry to its handler.

        Entries of type "childrenOnly", and entries of any other unlisted
        type, are ignored.

        Args:
            ann_lst:
                list of annotation results
        """
        for ann in ann_lst:
            ann_type = ann['type']
            if ann_type == "input":
                input_obj = Input.gen_input(ann)
                self.label_result.input_lst.setdefault(input_obj.id, input_obj)
            elif ann_type == "slotChildren":
                self.factory_slot_children(ann['slotsChildren'])
            elif ann_type == "slot":
                self.factory_slots(ann['slots'])

    def factory_label_file(self, label_type: str,
                           label_file,
                           **kwargs) -> "Media":
        """Build the media description of a single record.

        Args:
            label_type:
                attachment type; handled when it contains "POINTCLOUD" or
                "IMAGE"
            label_file:
                the attachment: either a dict with a "url" key (and, for
                point clouds, optionally "imageSources") or the url itself
            **kwargs:
                ``size``: optional dict with "width" / "height", used for
                image attachments

        Returns:
            A ``Media`` instance, or an empty dict for any other
            ``label_type``.
        """
        is_mapping = isinstance(label_file, dict)
        url = label_file['url'] if is_mapping else label_file
        if "POINTCLOUD" in label_type:
            meta_point = PointCloud(
                uri=url,
                name=None,
                file_path=Path(url)
            )
            # Only a dict attachment can carry image sources; for a plain url
            # string an `in` test would be a substring match.
            image_sources = (label_file.get('imageSources') or []) if is_mapping else []
            meta_images = [Image.gen_image(image) for image in image_sources]
            return Media(point_cloud=meta_point, image=meta_images)
        if "IMAGE" in label_type:
            size = kwargs.get("size") or {}
            return Media(
                point_cloud=None,
                image=Image(uri=url,
                            width=size.get("width"),
                            height=size.get("height"))
            )
        return {}

    def factory_rosetta(self, data: Dict) -> "Frame":
        """Convert one raw record into a ``Frame``.

        Reads ``self.frame_num``, which the caller must have set for this
        record, and overwrites ``self.task_info``, ``self.media`` and
        ``self.label_result``.

        Args:
            data: rosetta data

        Returns:
            A ``Frame`` instance; its ``prediction`` is ``None`` when the
            record has no (or empty) metadata.
        """
        record = data['taskParams']['record']
        # The metadata key may be missing or empty.
        metadata = record.get('metadata')
        # Rosetta task information
        self.task_info = TaskInfo(
            task_id=data['taskId'],
            project_id=data['projectId'],
            pool_id=data['poolId'],
            frame_num=self.frame_num
        )
        # Annotated file
        self.media = self.factory_label_file(label_type=record['attachmentType'],
                                             label_file=record['attachment'],
                                             size=metadata.get("size") if metadata else None)
        # Manual annotation results: the factory_* methods fill label_result.
        self.label_result = Annotation()
        self.factory_anns(data['result']['annotations'])
        annotation = self.label_result
        # Pre-annotation results are collected the same way into a
        # Prediction container.
        self.label_result = Prediction()
        if not metadata:
            prediction = None
        else:
            pres_data = metadata.get("preprocessedData")
            if pres_data and pres_data.get("annotations"):
                self.factory_anns(pres_data['annotations'])
            prediction = self.label_result
        return Frame(media=self.media,
                     task_info=self.task_info,
                     annotation=annotation,
                     prediction=prediction
                     )

    def convert_rosetta(self) -> Iterable["Frame"]:
        """Yield one ``Frame`` per exported Rosetta JSON file."""
        for json_data in self.yield_rosetta_json():
            yield self.factory_rosetta(json_data)

    # NOTE(review): the default list is shared between calls; that is only
    # safe while Export.export does not mutate it. It is kept so that the
    # value handed to Export.export stays unchanged.
    def export_rosetta(self, pool_lst=[], split_name='old', env_name='top') -> None:
        """Export the Rosetta project data into ``self.input_path``.

        Args:
            pool_lst:
                a list of pools to export
            split_name:
                forwarded to ``Export.export`` as ``split_name``
            env_name:
                forwarded to the ``Export`` constructor
        """
        Export(self.project_id, self.input_path, env_name).export(pool_lst, split_name=split_name)

    def read_rosetta(self, project_id,
                     input_path,
                     pool_lst,
                     export_type=None,
                     **kwargs):
        """Export a Rosetta project and yield its records one by one.

        This is a generator: nothing, the argument check included, runs
        until the first item is requested. If
        ``<input_path>/<project_id>/json`` already exists, the whole
        ``<input_path>/<project_id>`` directory is deleted before exporting.

        If you want to export other formats, please refer to
        stardust.conversion.

        Args:
            project_id:
                rosetta project id (a non-zero int)
            input_path:
                path to store json data exported from Rosetta
            pool_lst:
                pool list from your rosetta project
            export_type:
                "json" to yield plain serialized dicts; anything else yields
                ``Frame`` objects
            **kwargs:
                ``split`` and ``env_name`` are forwarded to
                ``export_rosetta``

        Yields:
            One ``Frame`` (or its serialized form) per exported JSON file.
        """
        assert project_id and isinstance(project_id, int)
        self.project_id = project_id
        self.input_path = input_path
        # Export the data from Rosetta, dropping any previous export first.
        project_dir = os.path.join(self.input_path, str(self.project_id))
        if os.path.exists(os.path.join(project_dir, "json")):
            print("开始获取标注结果")
            shutil.rmtree(project_dir)
        # NOTE(review): when "split" / "env_name" are not supplied, None is
        # forwarded instead of export_rosetta's own defaults - confirm that
        # this is intended.
        self.export_rosetta(pool_lst, split_name=kwargs.get("split"), env_name=kwargs.get("env_name"))
        for frame_obj in self.convert_rosetta():
            if not frame_obj:
                continue
            if export_type == "json":
                yield _serialize(frame_obj)
            else:
                yield frame_obj

    def read_ms(self, export_type=None, **kwargs):
        """Read a dataset from MorningStar and yield its records one by one.

        Known gaps noted by the original author: no task info, no annotation
        file, and the relation between frames and annotation files.

        Args:
            export_type:
                "json" to yield plain serialized dicts; anything else yields
                ``Frame`` objects
            **kwargs:
                forwarded to ``MS().export_dataset``

        Yields:
            One ``Frame`` (or its serialized form) per instance.
        """
        # Fetch the dataset from MorningStar.
        ms_data = MS().export_dataset(**kwargs)
        assert ms_data
        # Each element is a list of instances; the position inside its list
        # is used as the frame number.
        for instance_lst in ms_data:
            for index, instance in enumerate(instance_lst):
                self.frame_num = index
                res = self.factory_rosetta(instance)
                if export_type == "json":
                    yield _serialize(res)
                else:
                    yield res
def read_rosetta(
        project_id: int = None,
        input_path: str = None,
        pool_lst: Optional[List[int]] = None,
        export_type=None,
        **kwargs):
    """Export a Rosetta project and return a generator over its records.

    The output can be SDK objects or json; SDK is the default. If you want
    to export other formats, please refer to stardust.conversion.

    Args:
        project_id:
            rosetta project
        input_path:
            path to store json data exported from Rosetta
        pool_lst:
            pool list from your rosetta project
        export_type:
            expected export data type, SDK or json
        **kwargs:
            forwarded unchanged to ``Convertor.read_rosetta``

    Returns:
        The generator created by ``Convertor().read_rosetta``.

    Raises:
        AssertionError: ``project_id`` or ``input_path`` is missing.
    """
    assert project_id and input_path, "project_id and input_path are required"
    return Convertor().read_rosetta(project_id, input_path, pool_lst, export_type, **kwargs)
def read_ms(*args, **kwargs):
    """Export MorningStar data.

    All arguments are forwarded unchanged to ``Convertor().read_ms``.

    Args:
        export_type:
            "json" to get plain serialized dicts instead of frame objects
        dataset_id:
            Data set ID
        version_num:
            Version number
        slice_id:
            Slice ID
        page_no:
            Slice paging
        page_size:
            Amount of data per page

    Returns:
        generator object

    Examples:
        .. code-block:: python

            from stardust.utils.convert import read_ms
            from stardust.convertion.to_pandaset import to_pandaset

            # Derived data
            gen_data = read_ms(
                dataset_id=351787480925605888,
                version_num=18
            )

            # scale format
            to_pandaset(gen_data, export_path="Input save directory")
    """
    return Convertor().read_ms(*args, **kwargs)
if __name__ == '__main__':
    # Manual smoke run: export one pool of one project and drain the
    # resulting generator without keeping any output.
    from pprint import pprint

    project_id = 1354
    pool_lst = [33750]

    # The data folder sits two levels above the current working directory;
    # the Rosetta export ends up below it as <project_id>/json/*.json.
    base_dir = os.path.dirname(os.path.dirname(os.getcwd()))
    input_path = f"{base_dir}/data/"
    os.makedirs(input_path, exist_ok=True)

    # read_rosetta returns a generator, so the export itself only starts once
    # the loop below requests the first item.
    gen_data = read_rosetta(
        project_id=project_id,
        input_path=input_path,
        pool_lst=pool_lst,
        export_type="json",
    )
    os.makedirs(os.path.join(input_path, "output"), exist_ok=True)

    for _frame in gen_data:
        pass