环境配置
conda create --name FlashOcc python=3.8.5
conda activate FlashOcc
pip install torch==1.10.0+cu111 torchvision==0.11.0+cu111 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install mmcv-full==1.5.3
pip install mmdet==2.25.1
pip install mmsegmentation==0.25.0
sudo apt-get install python3-dev
sudo apt-get install libevent-dev
sudo apt-get install build-essential
export PATH=/usr/local/cuda/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
export CUDA_ROOT=/usr/local/cuda
pip install pycuda
pip install lyft_dataset_sdk
pip install networkx==2.2
pip install numba==0.53.0
pip install numpy==1.23.5
pip install nuscenes-devkit
pip install plyfile
pip install scikit-image
pip install tensorboard
pip install trimesh==2.35.39
pip install setuptools==59.5.0
pip install yapf==0.40.1
cd Path_to_FlashOcc
git clone git@github.com:Yzichen/FlashOCC.git
cd Path_to_FlashOcc/FlashOCC
git clone https://github.com/open-mmlab/mmdetection3d.git
cd Path_to_FlashOcc/FlashOCC/mmdetection3d
git checkout v1.0.0rc4
pip install -v -e .
cd Path_to_FlashOcc/FlashOCC/projects
pip install -v -e .
先按照 FlashOCC/doc/install.md(Yzichen/FlashOCC 仓库 master 分支)把环境配好。注意 git clone 出来的目录名是 FlashOCC(大小写敏感)。
不过我的 Ubuntu 22 系统装的是 CUDA 11.7,并不支持 11.1,所以最终用的是 cu117 版本的 PyTorch(见下文“踩坑”部分)。
数据集下载
按照 FlashOCC/doc/nuscenes_det.md(Yzichen/FlashOCC 仓库 master 分支)准备 nuScenes 数据集。
如果只是验证的话,下载Full dataset(v1.0)中的mini版本即可。下载完成之后解压,并放到下面的目录中:
└── Path_to_FlashOcc/
└── data
└── nuscenes
├── v1.0-mini
├── sweeps
└── samples
然后修改 tools/create_data_bevdet.py,把 trainval 改成 mini(见代码中的注释 from trainval to mini):
def add_ann_adj_info(extra_tag, version='v1.0-mini', dataroot='./data/nuscenes/'):
    """Attach annotations, scene info and Occ3D paths to the BEVDet info files.

    Loads ``<dataroot>/<extra_tag>_infos_{train,val}.pkl``, enriches every
    entry of ``infos`` in place and writes each pickle back to the same path.

    Args:
        extra_tag: Prefix of the info files, e.g. ``'bevdetv2-nuscenes'``.
        version: nuScenes version passed to ``NuScenes``. Defaults to the
            mini split; pass ``'v1.0-trainval'`` for the full dataset.
        dataroot: nuScenes root directory.
    """
    root = dataroot.rstrip('/')
    nuscenes = NuScenes(version, dataroot)
    for split in ['train', 'val']:
        info_path = '%s/%s_infos_%s.pkl' % (root, extra_tag, split)
        with open(info_path, 'rb') as f:
            dataset = pickle.load(f)
        infos = dataset['infos']
        for idx, info in enumerate(infos):
            if idx % 10 == 0:
                print('%d/%d' % (idx, len(infos)))
            # collect the sample's annotations, each tagged with its velocity
            sample = nuscenes.get('sample', info['token'])
            ann_infos = []
            for ann in sample['anns']:
                ann_info = nuscenes.get('sample_annotation', ann)
                velocity = nuscenes.box_velocity(ann_info['token'])
                if np.any(np.isnan(velocity)):
                    # undefined (NaN) velocity is replaced by zeros
                    velocity = np.zeros(3)
                ann_info['velocity'] = velocity
                ann_infos.append(ann_info)
            # get_gt (defined elsewhere in this file) is called with the raw
            # annotations attached; its result then replaces them.
            info['ann_infos'] = ann_infos
            info['ann_infos'] = get_gt(info)
            info['scene_token'] = sample['scene_token']
            scene = nuscenes.get('scene', sample['scene_token'])
            info['scene_name'] = scene['name']
            # Occ3D ground truth lives at <dataroot>/gts/<scene_name>/<sample_token>
            info['occ_path'] = '%s/gts/%s/%s' % (root, scene['name'], info['token'])
        with open(info_path, 'wb') as fid:
            pickle.dump(dataset, fid)
if __name__ == '__main__':
    # Build the BEVDet-v2 style info files for the nuScenes mini split.
    dataset = 'nuscenes'
    extra_tag = 'bevdetv2-nuscenes'
    root_path = 'data/nuscenes'
    version = 'v1.0'
    # from trainval to mini: use the '-trainval' suffix for the full dataset
    train_version = '{}-mini'.format(version)
    nuscenes_data_prep(root_path=root_path, info_prefix=extra_tag,
                       version=train_version, max_sweeps=0)
    print('add_ann_infos')
    add_ann_adj_info(extra_tag)
然后运行该脚本,会新生成 2 个文件:
└── Path_to_FlashOcc/
└── data
└── nuscenes
├── v1.0-mini(existing)
├── sweeps (existing)
├── samples (existing)
├── bevdetv2-nuscenes_infos_train.pkl (new)
└── bevdetv2-nuscenes_infos_val.pkl (new)
然后去下载 Prediction 对应的 GT:
Occupancy3D-nuScenes-mini - Google 云端硬盘
然后去下载Occupancy对应的GT:
然后去下载模型:
flashocc-r50-256x704.pth - Google 云端硬盘
运行Demo
bash tools/dist_test.sh projects/configs/flashocc/flashocc-r50.py ckpts/flashocc-r50-256x704.pth 4 --eval map
下载的模型放在 ckpts/ 目录下;上面的 4 是显卡数,请改成实际的显卡数量。
occ3d -> occ3d panoptic
因为使用的是 mini 数据集,所以需要修改把 occ3d 标注转换成 occ3d panoptic 的脚本 get_instance_info.py,测试可用的代码附在本文最后。
踩坑
CUSOLVER_STATUS_INTERNAL_ERROR
RuntimeError: cusolver error: CUSOLVER_STATUS_INTERNAL_ERROR, when calling `cusolverDnCreate(handle)`
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 1704876) of binary: /home/zhan/anaconda3/envs/FlashOcc/bin/python
这是 CUDA 的 cuSOLVER 在执行矩阵求逆时初始化失败,运行下面的命令自检:
python -c "import torch; print(torch.version, torch.version.cuda); a=torch.eye(4,device='cuda').double().unsqueeze(0); print(torch.inverse(a))"
报错:
(FlashOcc) ➜ FlashOCC git:(master) ✗ python -c "import torch; print(torch.version, torch.version.cuda); a=torch.eye(4,device='cuda').double().unsqueeze(0); print(torch.inverse(a))"
<module 'torch.version' from '/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/version.py'> 11.1
Traceback (most recent call last):
File "<string>", line 1, in <module>
RuntimeError: cusolver error: CUSOLVER_STATUS_INTERNAL_ERROR, when calling `cusolverDnCreate(handle)`
如果报错和我一样,你的显卡可能是 Ada 架构的,比如 RTX 40 系显卡。
根据网上的一些信息,出现这个错误的基本都是 40 系显卡,比如论文代码仓库中的这个 Issue:RuntimeError: cusolver error: CUSOLVER_STATUS_INTERNAL_ERROR, when calling cusolverDnCreate(handle) · Issue #19 · Yzichen/FlashOCC。
二楼提供了一种解决方法:
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 -f https://download.pytorch.org/whl/torch_stable.html
pip install kornia==0.5.10
安装完再试一下:
(FlashOcc) ➜ FlashOCC git:(master) ✗ python -c "import torch; a=torch.eye(4,device='cuda').double().unsqueeze(0); print(torch.inverse(a))"
tensor([[[1., 0., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 1., 0.],
[0., 0., 0., 1.]]], device='cuda:0', dtype=torch.float64)
他妈的神医啊!
继续运行Demo
ImportError: /home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/mmcv/_ext.cpython-38-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops7resize_4callERKNS_6TensorEN3c108ArrayRefIlEENS5_8optionalINS5_12MemoryFormatEEE
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 1717570) of binary: /home/zhan/anaconda3/envs/FlashOcc/bin/python
Traceback (most recent call last):
File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/runpy.py", line 194, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/distributed/launch.py", line 195, in <module>
main()
File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/distributed/launch.py", line 191, in main
launch(args)
File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/distributed/launch.py", line 176, in launch
run(args)
File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/distributed/run.py", line 753, in run
elastic_launch(
File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 246, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
看起来是 torch 与 mmcv 的二进制不匹配:我重新装了 torch,但没有重新装 mmcv。从源码重新编译一下 mmcv 看看:
# 清理库和缓存
pip uninstall -y mmcv mmcv-full mmcv-lite
pip cache purge
export MMCV_WITH_OPS=1
export FORCE_CUDA=1
export TORCH_CUDA_ARCH_LIST=8.9
pip install -U pip setuptools wheel ninja
pip install -v --no-binary mmcv-full mmcv-full==1.5.3
pip install mmdet==2.25.1 mmsegmentation==0.25.0
cd YourFlashOCCPath/FlashOCC/mmdetection3d && pip install -v -e .
cd YourFlashOCCPath/FlashOCC/projects && pip install -v -e .
报错:
ValueError: Unknown CUDA arch (8.9) or GPU not supported
error: subprocess-exited-with-error
...
...
...
Failed to build mmcv-full
ERROR: Failed to build installable wheels for some pyproject.toml based projects (mmcv-full)
原因是 TORCH_CUDA_ARCH_LIST=8.9,而 torch 1.13.1 的 cpp_extension 不认识 8.9 这个架构,wdnmd。
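请出 NVIDIA 的 PTX JIT。PTX 全称 Parallel Thread Execution,是 GPU 的中间代码;JIT 全称 Just-In-Time,即时编译。事先编译好的 GPU 二进制只能在编译时指定的架构(以及同一主版本号下更新的小版本)上运行,换个架构就可能用不了;而带上 PTX 后,驱动可以在运行时把 PTX 即时编译成当前 GPU 可用的机器码。所以下面改用 TORCH_CUDA_ARCH_LIST="8.6+PTX":torch 1.13.1 认识 8.6,会生成 sm_86 的二进制并附带 PTX。8.9(Ada)与 8.6 同属主版本 8,sm_86 的二进制可以直接在上面运行,PTX 则留作兜底。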
# 清理当前环境变量和卸载库
unset TORCH_CUDA_ARCH_LIST
pip uninstall -y mmcv mmcv-full mmcv-lite
# 重新设置编译架构
export MMCV_WITH_OPS=1
export FORCE_CUDA=1
export TORCH_CUDA_ARCH_LIST="8.6+PTX"
# 安装 mmcv-full
pip install -U pip setuptools wheel ninja
pip install -v --no-cache-dir --no-build-isolation mmcv-full==1.5.3
# 验证 mmcv
python -c "import torch,mmcv; print(torch.__version__, torch.version.cuda); print(mmcv.version, mmcv.__file__); from mmcv.ops import nms_match; print('mmcv ops ok')"
如果输出:
1.13.1+cu117 11.7
<module 'mmcv.version' from '/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/mmcv/version.py'> /home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/mmcv/__init__.py
mmcv ops ok
那 mmcv 就好了。
继续运行demo+1
报错+1
ImportError: /home/zhan/Projects/FlashOCC/projects/mmdet3d_plugin/ops/bev_pool/bev_pool_ext.cpython-38-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops5zeros4callEN3c108ArrayRefIlEENS2_8optionalINS2_10ScalarTypeEEENS5_INS2_6LayoutEEENS5_INS2_6DeviceEEENS5_IbEE
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 1725473) of binary: /home/zhan/anaconda3/envs/FlashOcc/bin/python
Traceback (most recent call last):
File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/runpy.py", line 194, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/distributed/launch.py", line 195, in <module>
main()
File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/distributed/launch.py", line 191, in main
launch(args)
File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/distributed/launch.py", line 176, in launch
run(args)
File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/distributed/run.py", line 753, in run
elastic_launch(
File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 246, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
新装的 torch 和 FlashOCC 插件里之前编译好的扩展(.so)又不匹配了,删掉后重新编译:
# 删库
rm -f /home/zhan/Projects/FlashOCC/projects/mmdet3d_plugin/ops/bev_pool/bev_pool_ext*.so
rm -f /home/zhan/Projects/FlashOCC/projects/mmdet3d_plugin/ops/bev_pool_v2/bev_pool_v2_ext*.so
rm -f /home/zhan/Projects/FlashOCC/projects/mmdet3d_plugin/ops/nearest_assign/nearest_assign_ext*.so
rm -rf /home/zhan/Projects/FlashOCC/projects/build
rm -rf flashocc_plugin.egg-info
# 重新编译安装插件
cd /home/zhan/Projects/FlashOCC/projects
export FORCE_CUDA=1
export TORCH_CUDA_ARCH_LIST="8.6+PTX"
pip install -v -e .
然后验证:
(FlashOcc) ➜ FlashOCC git:(master) ✗ python -c "from mmdet3d_plugin.ops import bev_pool, bev_pool_v2, nearest_assign; print('plugin ops ok')"
Using /home/zhan/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/zhan/.cache/torch_extensions/py38_cu117/dvr/build.ninja...
Building extension module dvr...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module dvr...
/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3190.)
return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
plugin ops ok
验证通过,再运行 Demo:
[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 81/81, 16.0 task/s, elapsed: 5s, ETA: 0smetric = map
Starting Evaluation...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 81/81 [00:00<00:00, 283.86it/s]
===> per class IoU of 81 samples:
===> others - IoU = 27.98
===> barrier - IoU = nan
===> bicycle - IoU = 0.92
===> bus - IoU = 29.8
===> car - IoU = 38.42
===> construction_vehicle - IoU = 0.0
===> motorcycle - IoU = 7.0
===> pedestrian - IoU = 12.56
===> traffic_cone - IoU = 2.75
===> trailer - IoU = nan
===> truck - IoU = 19.56
===> driveable_surface - IoU = 72.53
===> other_flat - IoU = 0.0
===> sidewalk - IoU = 42.74
===> terrain - IoU = 34.21
===> manmade - IoU = 43.75
===> vegetation - IoU = 29.1
===> mIoU of 81 samples: 24.09
{'mIoU': array([0.28 , nan, 0.009, 0.298, 0.384, 0. , 0.07 , 0.126, 0.028,
nan, 0.196, 0.725, 0. , 0.427, 0.342, 0.437, 0.291, 0.869])}
此贴完结。下面附上在 mini 数据集上测试可用的 get_instance_info.py:
get_instance_info.py
import os
import tqdm
import glob
import pickle
import argparse
import numpy as np
import torch
import multiprocessing
from pyquaternion import Quaternion
from nuscenes.utils.data_classes import Box
from nuscenes.utils.geometry_utils import points_in_box
# Command-line options. --version is validated up front so a typo fails with
# a usage message instead of after the whole module has been set up.
parser = argparse.ArgumentParser(
    description='Add per-voxel instance ids to Occ3D-nuScenes labels '
                '(occ3d -> occ3d_panoptic).')
parser.add_argument('--nusc-root', default='data/nuscenes',
                    help='directory searched for the default info .pkl files')
parser.add_argument('--occ3d-root', default='data/nuscenes/occ3d',
                    help='directory searched recursively for Occ3D labels.npz files')
parser.add_argument('--output-dir', default='data/nuscenes/occ3d_panoptic',
                    help='output root; the last three path components of each '
                         'input labels.npz are mirrored under it')
parser.add_argument('--version', default='v1.0-trainval',
                    choices=['v1.0-trainval', 'v1.0-mini', 'v1.0-test'],
                    help='trainval/mini process the train and val infos, '
                         'test processes the test infos')
parser.add_argument('--train-info', default=None,
                    help='explicit train info .pkl (skips the search in --nusc-root)')
parser.add_argument('--val-info', default=None,
                    help='explicit val info .pkl (skips the search in --nusc-root)')
parser.add_argument('--test-info', default=None,
                    help='explicit test info .pkl (skips the search in --nusc-root)')
args = parser.parse_args()
# Index every Occ3D ground-truth file by sample token. Files are found
# recursively as .../<sample_token>/labels.npz, so the token is the name of
# the file's parent directory (a later duplicate overrides an earlier one).
token2path = {
    os.path.basename(os.path.dirname(label_file)): label_file
    for label_file in glob.glob(
        os.path.join(args.occ3d_root, '**', 'labels.npz'), recursive=True)
}
if not token2path:
    raise FileNotFoundError(
        'No labels.npz found under occ3d-root: {}. '
        'Please check --occ3d-root.'.format(args.occ3d_root))
# Occ3D semantic classes: a voxel's label is its index in this list. The
# last entry, 'free' (index 17 == num_classes - 1), marks empty space.
occ_class_names = [
    'others',                # 0
    'barrier',               # 1
    'bicycle',               # 2
    'bus',                   # 3
    'car',                   # 4
    'construction_vehicle',  # 5
    'motorcycle',            # 6
    'pedestrian',            # 7
    'traffic_cone',          # 8
    'trailer',               # 9
    'truck',                 # 10
    'driveable_surface',     # 11
    'other_flat',            # 12
    'sidewalk',              # 13
    'terrain',               # 14
    'manmade',               # 15
    'vegetation',            # 16
    'free',                  # 17
]

# Box-annotated ("thing") classes. These are skipped by the
# one-instance-per-class pass in process_add_instance_info; their voxels get
# instance ids from the annotated boxes instead.
det_class_names = [
    'car',
    'truck',
    'trailer',
    'bus',
    'construction_vehicle',
    'bicycle',
    'motorcycle',
    'pedestrian',
    'traffic_cone',
    'barrier',
]
def convert_to_nusc_box(bboxes, lift_center=False, wlh_margin=0.0):
    """Convert rows of a box array into nuScenes ``Box`` objects.

    Each row is read as ``[x, y, z, s0, s1, s2, yaw, ...]``: indices 0-2 give
    the centre, 3-5 the size passed to ``Box`` and 6 the yaw. The yaw is
    mapped to ``-yaw - pi / 2`` about the z axis and the inverse of that
    quaternion is used as the box orientation.

    Args:
        bboxes: array with one box per row (rows are copied, not modified).
        lift_center: if True, raise z by half of element 5 (the third size
            component) before building the box.
        wlh_margin: amount added to each of the three size components.

    Returns:
        list of ``Box``, one per row, in row order.
    """
    def _row_to_box(row):
        params = row.copy()
        if lift_center:
            params[2] += params[5] * 0.5
        heading = Quaternion(axis=[0, 0, 1], radians=-params[6] - np.pi / 2).inverse
        # a margin of 0.8 in pc range is roughly 2 voxels in the occ grid;
        # enlarging the box lets it include voxels on the edge
        grown_size = [params[3] + wlh_margin,
                      params[4] + wlh_margin,
                      params[5] + wlh_margin]
        return Box(
            center=[params[0], params[1], params[2]],
            size=grown_size,
            orientation=heading,
        )

    return [_row_to_box(bboxes[row_idx]) for row_idx in range(bboxes.shape[0])]
def meshgrid3d(occ_size, pc_range):  # points in ego coord
    """Return the centre coordinates of every voxel of a regular 3-D grid.

    Args:
        occ_size: ``(W, H, D)``, the number of voxels along x, y and z.
        pc_range: ``[x_min, y_min, z_min, x_max, y_max, z_max]``, the extent
            covered by the grid.

    Returns:
        float tensor of shape ``(W, H, D, 3)``; entry ``[i, j, k]`` is the
        ``(x, y, z)`` centre of voxel ``(i, j, k)``.
    """
    def _axis_centres(count, low, high):
        # normalised centres (k + 0.5) / count, mapped linearly onto [low, high]
        unit = torch.linspace(0.5, count - 0.5, count) / count
        return unit * (high - low) + low

    n_x, n_y, n_z = occ_size
    centres_x = _axis_centres(n_x, pc_range[0], pc_range[3])
    centres_y = _axis_centres(n_y, pc_range[1], pc_range[4])
    centres_z = _axis_centres(n_z, pc_range[2], pc_range[5])
    grid_x, grid_y, grid_z = torch.meshgrid(centres_x, centres_y, centres_z, indexing='ij')
    return torch.stack((grid_x, grid_y, grid_z), dim=-1)
def process_add_instance_info(sample):
    """Compute per-voxel instance ids for one sample and save them with its Occ3D labels.

    The Occ3D ``labels.npz`` for ``sample['token']`` is looked up in the
    module-level ``token2path`` index. Instance ids are then assigned to the
    non-free voxels in three passes (0 means "no instance"):

    1. Boxes: each annotated box whose name is an occupancy class is moved
       from the lidar frame to the ego frame. Voxels of that class inside it
       that are not yet assigned get a new id. Boxes that capture at least
       one voxel are kept, enlarged by 0.8 in w/l/h, for pass 3.
    2. Stuff: every other non-free class present in the frame (one not in
       ``det_class_names``) becomes a single instance.
    3. Leftovers: voxels still unassigned take the id of the kept box of the
       same class that contains them and whose centre is closest.

    The output is written under ``args.output_dir``, mirroring the last three
    path components of the input file. It contains ``semantics``,
    ``mask_lidar`` and ``mask_camera`` plus the new ``instances`` array
    (uint8, same shape as ``semantics``).

    Args:
        sample: one entry of an info pickle's ``infos`` list. Reads ``token``,
            ``gt_boxes``, ``gt_names``, ``lidar2ego_rotation`` and
            ``lidar2ego_translation``.

    Raises:
        FileNotFoundError: if no Occ3D labels were indexed for the sample token.
    """
    # Occ3D grid: 200 x 200 x 16 voxels over this range (0.4 per voxel per axis)
    point_cloud_range = [-40, -40, -1.0, 40, 40, 5.4]
    occ_size = [200, 200, 16]
    num_classes = 18  # label num_classes - 1 (17) is 'free'

    token = sample['token']
    occ_gt_path = token2path.get(token)
    if occ_gt_path is None:
        raise FileNotFoundError(
            'No Occ3D labels.npz found for sample token {} under occ3d-root: '
            '{}'.format(token, args.occ3d_root))

    # Only semantic and mask information is kept in the output. Read it now
    # so the .npz file handle is closed immediately.
    retain_keys = ['semantics', 'mask_lidar', 'mask_camera']
    with np.load(occ_gt_path) as occ_labels:
        new_occ_labels = {k: occ_labels[k] for k in retain_keys}
    occ_gt = new_occ_labels['semantics']  # read-only below

    gt_boxes = sample['gt_boxes']
    gt_names = sample['gt_names']
    bboxes = convert_to_nusc_box(gt_boxes)

    # NOTE(review): ids are stored as uint8, so more than 255 instances in one
    # frame would overflow; this is not checked here.
    instance_gt = np.zeros(occ_gt.shape).astype(np.uint8)
    instance_id = 1

    pts = meshgrid3d(occ_size, point_cloud_range).numpy()

    # filter out free voxels to accelerate
    valid_idx = np.where(occ_gt < num_classes - 1)
    flatten_occ_gt = occ_gt[valid_idx]
    flatten_inst_gt = instance_gt[valid_idx]
    flatten_pts = pts[valid_idx]

    # --- pass 1: one instance per annotated box ---------------------------
    instance_boxes = []
    instance_class_ids = []
    for i in range(len(gt_names)):
        if gt_names[i] not in occ_class_names:
            continue
        occ_tag_id = occ_class_names.index(gt_names[i])

        # Move box from the lidar frame to the ego vehicle frame
        bbox = bboxes[i]
        bbox.rotate(Quaternion(sample['lidar2ego_rotation']))
        bbox.translate(np.array(sample['lidar2ego_translation']))

        mask = points_in_box(bbox, flatten_pts.transpose(1, 0))
        # ignore voxels not belonging to this class
        mask[mask] = (flatten_occ_gt[mask] == occ_tag_id)
        # ignore voxels already occupied
        mask[mask] = (flatten_inst_gt[mask] == 0)

        # only instance with at least 1 voxel will be recorded
        if mask.sum() > 0:
            flatten_inst_gt[mask] = instance_id
            instance_id += 1
            # enlarge boxes to include voxels on the edge (used in pass 3)
            new_box = bbox.copy()
            new_box.wlh = new_box.wlh + 0.8
            instance_boxes.append(new_box)
            instance_class_ids.append(occ_tag_id)

    # --- pass 2: classes that should be viewed as one instance --------------
    all_class_ids_unique = np.unique(occ_gt)
    for i, class_name in enumerate(occ_class_names):
        if class_name in det_class_names or class_name == 'free' or i not in all_class_ids_unique:
            continue
        flatten_inst_gt[flatten_occ_gt == i] = instance_id
        instance_id += 1

    # --- pass 3: post process uncovered non-free voxels ---------------------
    uncover_idx = np.where(flatten_inst_gt == 0)
    uncover_pts = flatten_pts[uncover_idx]
    uncover_inst_gt = np.zeros_like(uncover_pts[..., 0]).astype(np.uint8)
    uncover_occ_gt = flatten_occ_gt[uncover_idx]

    # distance from each voxel to the centre of its current nearest box
    uncover_inst_dist = np.ones_like(uncover_pts[..., 0]) * 1e8
    for i, box in enumerate(instance_boxes):
        # box i was given id i + 1 in pass 1 (non-background ids start at 1)
        inst_id = i + 1
        class_id = instance_class_ids[i]
        mask = points_in_box(box, uncover_pts.transpose(1, 0))
        # mask voxels not belonging to this class
        mask[uncover_occ_gt != class_id] = False
        dist = np.sum((box.center - uncover_pts) ** 2, axis=-1)
        # Ignore voxels already assigned to a closer box. This only turns
        # True entries into False, so it is equivalent to
        # `mask[mask & (dist >= uncover_inst_dist)] = False`.
        mask[dist >= uncover_inst_dist] = False
        # only voxels inside the box with no closer same-class box are updated
        uncover_inst_dist[mask] = dist[mask]
        uncover_inst_gt[mask] = inst_id

    flatten_inst_gt[uncover_idx] = uncover_inst_gt
    instance_gt[valid_idx] = flatten_inst_gt

    # --- save, mirroring the last three path components of the input ------
    # NOTE: reads the global `args`; passing the output dir explicitly would be cleaner.
    data_path = os.path.sep.join(occ_gt_path.split(os.path.sep)[-3:])
    save_path = os.path.join(args.output_dir, data_path)
    save_dir = os.path.dirname(save_path)
    os.makedirs(save_dir, exist_ok=True)

    if np.unique(instance_gt).shape[0] != instance_gt.max() + 1:
        print('warning: some instance masks are covered by following ones %s' % (save_dir))

    new_occ_labels['instances'] = instance_gt
    np.savez_compressed(save_path, **new_occ_labels)
def add_instance_info(sample_infos, num_workers=None):
    """Run process_add_instance_info on every sample of an info pickle in parallel.

    Args:
        sample_infos: loaded info pickle; every entry of its ``'infos'`` list
            is processed.
        num_workers: number of worker processes. Defaults to
            ``multiprocessing.cpu_count()`` (all CPUs), as before.
    """
    os.makedirs(args.output_dir, exist_ok=True)
    infos = sample_infos['infos']
    if num_workers is None:
        num_workers = multiprocessing.cpu_count()

    # If a worker raises, the exception propagates out of imap and the Pool
    # context manager terminates the remaining workers. On success,
    # close()/join() let them exit cleanly.
    with multiprocessing.Pool(num_workers) as pool, \
            tqdm.tqdm(total=len(infos)) as pbar:
        for _ in pool.imap(process_add_instance_info, infos):
            pbar.update(1)
        pool.close()
        pool.join()
def resolve_info_file(user_path, candidates, root=None):
    """Return the path of the info pickle to use.

    An explicitly supplied ``user_path`` wins and must exist. Otherwise each
    file name in ``candidates`` is tried, in order, inside ``root`` and the
    first one that exists is returned.

    Args:
        user_path: path given on the command line, or None.
        candidates: file names to try when ``user_path`` is None.
        root: directory searched for the candidates; defaults to
            ``args.nusc_root``.

    Returns:
        the resolved path.

    Raises:
        FileNotFoundError: if ``user_path`` is given but does not exist, or
            none of the candidates exist under ``root``.
    """
    if user_path is not None:
        if not os.path.exists(user_path):
            raise FileNotFoundError('Cannot find info file: {}'.format(user_path))
        return user_path
    if root is None:
        root = args.nusc_root
    for filename in candidates:
        path = os.path.join(root, filename)
        if os.path.exists(path):
            return path
    raise FileNotFoundError(
        'Cannot find info file under {}. Tried: {}'.format(
            root, ', '.join(candidates)))
if __name__ == '__main__':
    # Choose the info pickles for the requested split. Explicit --*-info
    # paths win; otherwise candidate names are tried in --nusc-root in order.
    if args.version in ['v1.0-trainval', 'v1.0-mini']:
        info_files = [
            resolve_info_file(
                args.train_info,
                ['bevdetv2-nuscenes_infos_train.pkl', 'nuscenes_infos_train_sweep.pkl']),
            resolve_info_file(
                args.val_info,
                ['bevdetv2-nuscenes_infos_val.pkl', 'nuscenes_infos_val_sweep.pkl']),
        ]
    elif args.version == 'v1.0-test':
        info_files = [
            resolve_info_file(
                args.test_info,
                ['nuscenes_infos_test_sweep.pkl', 'bevdetv2-nuscenes_infos_test.pkl']),
        ]
    else:
        raise ValueError('Unsupported --version: {}'.format(args.version))

    for info_file in info_files:
        # pickle can execute arbitrary code on load: only use info files you
        # generated yourself. The file is closed right after reading.
        with open(info_file, 'rb') as f:
            sample_infos = pickle.load(f)
        add_instance_info(sample_infos)
评论