import os
import re
import json
import asyncio
from pathlib import Path
from loguru import logger
from urllib.parse import quote, unquote
from typing import List, Sequence, Dict, Tuple
import aiohttp
import uvloop
from tqdm import tqdm
from stardust.components.attachment.image import ALL_IMAGE
from stardust.components.attachment.pointcloud import ALL_POINTCLOUD
from stardust.rosetta.rosetta_data import RosettaData
from stardust.components.frame import Frame
# Install uvloop's event loop policy at import time so that event loops created
# afterwards (e.g. by asyncio.run in load_dataset) are uvloop loops.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
# Add a loguru file sink: messages go to "app.log" (relative to the working
# directory) with rotation configured at "500 MB".
logger.add("app.log", rotation="500 MB")
# [docs]
class Downloader:
    """Concurrently download a list of ``(url, local_path)`` pairs over HTTP.

    Files that already exist locally with a non-zero size are skipped.
    Each remaining URL is attempted up to ``self.attempts`` times; URLs that
    still fail are collected in ``self.err_list`` and written to
    ``download_error.txt`` in the current working directory.
    """

    def __init__(self, urls: List[Tuple[str, str]], save_path: str = None):
        """
        downloader
        Args:
            urls:
                Sequence of ``(url, local_file_path)`` pairs to download
            save_path:
                Save directory (stored on the instance; the methods of this
                class take the target path from each ``urls`` pair instead)
        """
        self.urls = urls
        self.save_path = save_path
        # Number of URLs scheduled concurrently per batch.
        self.slice_num = 1000
        # Maximum number of attempts per URL.
        self.attempts = 3
        self.bar = tqdm(total=len(self.urls))
        self.err_list = list()

    @staticmethod
    def is_empty(filename):
        """Return True when ``filename`` has a non-zero size.

        NOTE: despite its name, this reports that the file is NOT empty.
        The name is kept for backward compatibility; ``async_download`` uses
        it to decide that a file was already downloaded.

        Raises:
            OSError: if ``filename`` does not exist or cannot be stat'ed.
        """
        return os.path.getsize(filename) != 0

    @staticmethod
    async def write_local(result, local_path):
        """Write the bytes ``result`` to ``local_path`` and return True.

        Missing parent directories are created. The data is first written to
        ``<local_path>.part`` and then moved into place, so a failed write
        never leaves a truncated file at ``local_path`` (which a later run
        would otherwise treat as an already-downloaded file).

        Raises:
            OSError: if the directory or file cannot be written.
        """
        parent_dir = os.path.dirname(local_path)
        # A bare file name has no directory part; os.makedirs('') would raise.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        tmp_path = f"{local_path}.part"
        try:
            with open(tmp_path, 'wb') as wp:
                wp.write(result)
            os.replace(tmp_path, local_path)
        except OSError:
            # Best-effort cleanup of the partial file; the original error is
            # the one worth reporting, so a failed cleanup is ignored.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise
        return True

    async def fetch(self, session, url, local_path):
        """Download ``url`` into ``local_path``, retrying on failure.

        Args:
            session:
                Object whose ``get(url)`` is used as an async context manager
                yielding the response (an ``aiohttp.ClientSession`` when
                called from ``async_download``)
            url:
                URL to request
            local_path:
                Destination file path
        Returns:
            True once the file has been written; False after
            ``self.attempts`` unsuccessful attempts (the URL is then
            recorded in ``self.err_list``).
        """
        for attempt in range(self.attempts):
            try:
                async with session.get(url) as response:
                    if response.status != 200:
                        print(f"Failed to download {url} status_code:{response.status}")
                        continue
                    content = await response.content.read()
                    written = await self.write_local(content, local_path)
                    if written:
                        self.bar.update(1)
                        if attempt != 0:
                            print(f'{attempt} Retry successfully {url}')
                        return True
            except Exception as e:
                # Any request or write error counts as one failed attempt;
                # pause briefly before the next one.
                print(f'Failed to download {url}: {e}')
                await asyncio.sleep(0.3)
        self.err_list.append(f"Failed to download {url}")
        return False

    async def async_download(self):
        """Download every pair in ``self.urls`` in batches of ``self.slice_num``.

        Existing non-empty target files are skipped (and still counted on the
        progress bar). When any URL ultimately fails, the failures are written
        to ``download_error.txt`` in the current working directory. The
        progress bar is always closed, even if an error propagates.
        """
        try:
            async with aiohttp.ClientSession() as session:
                for start in range(0, len(self.urls), self.slice_num):
                    tasks = []
                    for url, file_save_path in self.urls[start: start + self.slice_num]:
                        # NOTE(review): with safe=':/' characters such as
                        # '?', '=', '&' and '%' are percent-encoded as well,
                        # so URLs carrying a query string or already-encoded
                        # characters are altered here - confirm this is intended.
                        url = quote(url, safe=':/')
                        if os.path.exists(file_save_path) and self.is_empty(file_save_path):
                            self.bar.update(1)
                            continue
                        tasks.append(asyncio.create_task(self.fetch(session, url, file_save_path)))
                    if tasks:
                        await asyncio.wait(tasks)
            if self.err_list:
                error_txt = os.path.join(os.getcwd(), 'download_error.txt')
                with open(error_txt, 'w', encoding='utf-8') as wp:
                    for message in self.err_list:
                        wp.write(message)
                        wp.write("\n")
                print(f'{len(self.err_list)} urls failed to download')
        finally:
            self.bar.close()
# [docs]
def attachment_name_factory(att, save_path, project_id, task_id, frame_num, image_url_lst, pointcloud_url_lst,
                            name=None):
    """Queue one attachment URL for download together with its local target path.

    ``oss://stardust-data/`` URLs are rewritten to the matching
    ``https://stardust-data.oss-cn-hangzhou.aliyuncs.com/`` address. The file
    suffix of the URL then decides where the ``(url, local_path)`` pair goes:

    * suffix in ``ALL_IMAGE``: appended to ``image_url_lst`` with the target
      ``{save_path}/{project_id}/images[/{name}]/{task_id}_{frame_num}{suffix}``
    * suffix in ``ALL_POINTCLOUD``: appended to ``pointcloud_url_lst`` with the
      target ``{save_path}/{project_id}/pcds/{task_id}_{frame_num}{suffix}``

    Args:
        att:
            Attachment URL
        save_path:
            Root save directory
        project_id:
            Project ID (used as a directory name)
        task_id:
            Task ID (first part of the file name)
        frame_num:
            Frame number (second part of the file name)
        image_url_lst:
            List that receives image download pairs (mutated in place)
        pointcloud_url_lst:
            List that receives point cloud download pairs (mutated in place)
        name:
            Optional sub-directory under ``images``; ignored for point clouds
    Raises:
        TypeError: if the suffix is neither an image nor a point cloud suffix.
    """
    oss_prefix = "oss://stardust-data/"
    if att.startswith(oss_prefix):
        att = att.replace(oss_prefix, "https://stardust-data.oss-cn-hangzhou.aliyuncs.com/")

    suffix = Path(att).suffix

    def local_target(*sub_dirs):
        # Built lazily so an unsupported suffix raises before any path work.
        return os.path.join(save_path, str(project_id), *sub_dirs, f"{task_id}_{frame_num}{suffix}")

    if suffix in ALL_IMAGE:
        sub_dirs = ("images", name) if name else ("images",)
        image_url_lst.append((att, local_target(*sub_dirs)))
        return

    if suffix in ALL_POINTCLOUD:
        pointcloud_url_lst.append((att, local_target("pcds")))
        return

    raise TypeError("暂不支持的数据集")
# [docs]
def _collect_frame_attachments(frame, save_path, image_url_lst, pointcloud_url_lst):
    """Queue the point cloud and image attachments referenced by one Frame object."""
    project_id = frame.task_info.project_id  # rosetta project ID
    task_id = frame.task_info.task_id  # task ID
    frame_num = frame.task_info.frame_num  # frame number
    media = frame.media  # annotated media

    # Point cloud first, when the media carries one with a URI.
    if hasattr(media, "point_cloud") and media.point_cloud and media.point_cloud.uri:
        attachment_name_factory(media.point_cloud.uri, save_path, project_id, task_id, frame_num, image_url_lst,
                                pointcloud_url_lst)

    # Then the image(s): either a sequence of camera images or a single image.
    if hasattr(media, "image") and media.image:
        if isinstance(media.image, Sequence):
            for img in media.image:
                attachment_name_factory(img.uri, save_path, project_id, task_id, frame_num, image_url_lst,
                                        pointcloud_url_lst, name=img.camera_param.name)
        else:
            attachment_name_factory(media.image.uri, save_path, project_id, task_id, frame_num, image_url_lst,
                                    pointcloud_url_lst)


def _collect_project_attachments(project_id, save_path, pool_lst, image_url_lst, pointcloud_url_lst):
    """Export a project via RosettaData, queue its attachments and re-save each task JSON under ``jsons``."""
    RosettaData(project_id, save_path).export(pool_lst)

    # Directory that receives the renamed copies of the task JSON files.
    jsons_save_path = os.path.join(save_path, str(project_id), "jsons")
    os.makedirs(jsons_save_path, exist_ok=True)

    # Walk the task JSON files under {save_path}/{project_id}/json and extract
    # the download links from each of them.
    for file in Path(save_path).joinpath(str(project_id), "json").glob("*.json"):
        with open(file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        attachment = data['taskParams']['record']['attachment']
        task_id = data['taskId']

        # Frame number: the last four characters of the file name (extension
        # removed) when they are all digits, otherwise '0001'.
        frame_num = str(file).rsplit(".", 1)[0][-4:]
        if not frame_num.isdigit():
            frame_num = '0001'

        if isinstance(attachment, dict):
            # Main attachment plus optional per-camera image sources.
            attachment_name_factory(attachment['url'], save_path, project_id, task_id, frame_num, image_url_lst,
                                    pointcloud_url_lst)
            for source in attachment.get("imageSources", []):
                attachment_name_factory(source['url'], save_path, project_id, task_id, frame_num, image_url_lst,
                                        pointcloud_url_lst, source['name'])
        elif isinstance(attachment, str):
            attachment_name_factory(attachment, save_path, project_id, task_id, frame_num, image_url_lst,
                                    pointcloud_url_lst)

        with open(f'{jsons_save_path}/{task_id}_{frame_num}.json', 'w', encoding='utf-8') as new_json:
            json.dump(data, new_json, ensure_ascii=False)


def load_dataset(frame: Frame = None,
                 project_id: int = None,
                 save_path=None,
                 **kwargs) -> str:
    """
    Download the media files of a dataset into save_path.

    Either a Frame object or a project_id must be given. When a frame is
    given, the attachments referenced by that frame are downloaded and
    project_id is taken from the frame. Otherwise the project is exported
    through RosettaData and the attachments referenced by the task JSON
    files found under {save_path}/{project_id}/json are downloaded. Files
    that already exist locally with a non-zero size are not downloaded again.

    Args:
        frame:
            Frame object
        project_id:
            rosetta project ID (must be an int; used only when frame is not given)
        save_path:
            Save directory (required; created if missing)
        pool_lst:
            pool list passed to RosettaData(...).export (keyword argument,
            defaults to an empty list; used only with project_id)
    Returns:
        save_path, as passed in
    Raises:
        AssertionError: if save_path is empty, or project_id is not an int
        TypeError: if neither frame nor project_id is given, or an
            attachment has an unsupported file suffix
    Outputs written by this function (the RosettaData export may write
    additional files that are not visible from here):
        save_path
        -{project_id}
        --jsons
            {task_id}_{frame_num}.json        (project_id mode only)
        --images
        ---{camera_name}                      (only when a camera name is known)
            {task_id}_{frame_num}{suffix}
        --pcds
            {task_id}_{frame_num}{suffix}
    """
    assert save_path
    os.makedirs(save_path, exist_ok=True)

    # Download pairs (url, local_path) collected for each media kind.
    image_url_lst = list()
    pointcloud_url_lst = list()

    if frame:
        _collect_frame_attachments(frame, save_path, image_url_lst, pointcloud_url_lst)
    elif project_id is not None:
        assert isinstance(project_id, int)
        pool_lst = kwargs.get("pool_lst", [])  # pools of the rosetta project
        _collect_project_attachments(project_id, save_path, pool_lst, image_url_lst, pointcloud_url_lst)
    else:
        raise TypeError("必要的参数缺失")

    # Point clouds first, then images; an empty list needs no downloader,
    # HTTP session or progress bar.
    for url_lst in (pointcloud_url_lst, image_url_lst):
        if url_lst:
            asyncio.run(Downloader(url_lst).async_download())
    return save_path
if __name__ == '__main__':
    # Manual run: export project 1529 (no pool filter) into a developer-local
    # data directory and download its media files.
    load_dataset(
        project_id=1529,
        save_path="/Users/mac/Documents/pyproject/sd_sdk/stardust_sdk/Stardust_SDK/data/",
        pool_lst=[],
    )