Environment setup

conda create --name FlashOcc python=3.8.5
conda activate FlashOcc
pip install torch==1.10.0+cu111 torchvision==0.11.0+cu111 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install mmcv-full==1.5.3
pip install mmdet==2.25.1
pip install mmsegmentation==0.25.0

sudo apt-get install python3-dev 
sudo apt-get install libevent-dev
sudo apt-get install build-essential
export PATH=/usr/local/cuda/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
export CUDA_ROOT=/usr/local/cuda
pip install pycuda

pip install lyft_dataset_sdk
pip install networkx==2.2
pip install numba==0.53.0
pip install numpy==1.23.5
pip install nuscenes-devkit
pip install plyfile
pip install scikit-image
pip install tensorboard
pip install trimesh==2.35.39
pip install setuptools==59.5.0
pip install yapf==0.40.1

cd Path_to_FlashOcc
git clone git@github.com:Yzichen/FlashOCC.git

cd Path_to_FlashOcc/FlashOcc
git clone https://github.com/open-mmlab/mmdetection3d.git

cd Path_to_FlashOcc/FlashOcc/mmdetection3d
git checkout v1.0.0rc4
pip install -v -e . 

cd Path_to_FlashOcc/FlashOcc/projects
pip install -v -e . 

Start by setting up the environment as described in FlashOCC/doc/install.md at master · Yzichen/FlashOCC (the commands above).

However, my Ubuntu 22 system has CUDA 11.7 installed, which doesn't support 11.1, so I used the cu117 build of Torch.
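To make the naming explicit, here is a minimal sketch of how a CUDA toolkit version maps to the suffix on the torch wheel. `cuda_wheel_tag` and `torch_requirement` are hypothetical helpers of mine, not part of pip or PyTorch, and whether a wheel actually exists for a given combination still has to be checked against the PyTorch wheel index.

```python
def cuda_wheel_tag(cuda_version: str) -> str:
    """Turn a CUDA toolkit version such as '11.7' into a wheel tag such as 'cu117'."""
    major, minor = cuda_version.split(".")[:2]
    return f"cu{int(major)}{int(minor)}"


def torch_requirement(torch_version: str, cuda_version: str) -> str:
    """Build the pip requirement string for a CUDA-specific torch wheel."""
    return f"torch=={torch_version}+{cuda_wheel_tag(cuda_version)}"


print(torch_requirement("1.13.1", "11.7"))  # torch==1.13.1+cu117
```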

Downloading the dataset

Follow FlashOCC/doc/nuscenes_det.md at master · Yzichen/FlashOCC to prepare the nuScenes dataset.

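If you only want to validate the pipeline, the mini split of the Full dataset (v1.0) is enough. After downloading, extract it and place it in the following directory: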

└── Path_to_FlashOcc/
    └── data
        └── nuscenes
            ├── v1.0-mini
            ├── sweeps  
            └── samples 
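Before running anything, it is cheap to confirm that the layout above is really in place. The following is a small sketch; `missing_nuscenes_dirs` is a hypothetical helper name, and the three directory names are simply the ones from the tree above.

```python
from pathlib import Path

# Sub-directories from the layout above.
EXPECTED_SUBDIRS = ("v1.0-mini", "sweeps", "samples")


def missing_nuscenes_dirs(dataroot):
    """Return the expected sub-directories that are absent under dataroot."""
    root = Path(dataroot)
    return [name for name in EXPECTED_SUBDIRS if not (root / name).is_dir()]


if __name__ == "__main__":
    missing = missing_nuscenes_dirs("data/nuscenes")
    print("missing:", missing if missing else "nothing")
```

Run it from Path_to_FlashOcc so that the relative path data/nuscenes resolves.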

Then modify tools/create_data_bevdet.py:

def add_ann_adj_info(extra_tag):
    # from trainval to mini
    nuscenes_version = 'v1.0-mini'
    dataroot = './data/nuscenes/'
    nuscenes = NuScenes(nuscenes_version, dataroot)
    for set in ['train', 'val']:
        dataset = pickle.load(
            open('%s/%s_infos_%s.pkl' % (dataroot, extra_tag, set), 'rb'))
        for id in range(len(dataset['infos'])):
            if id % 10 == 0:
                print('%d/%d' % (id, len(dataset['infos'])))
            info = dataset['infos'][id]
            # get sweep adjacent frame info
            sample = nuscenes.get('sample', info['token'])
            ann_infos = list()
            for ann in sample['anns']:
                ann_info = nuscenes.get('sample_annotation', ann)
                velocity = nuscenes.box_velocity(ann_info['token'])
                if np.any(np.isnan(velocity)):
                    velocity = np.zeros(3)
                ann_info['velocity'] = velocity
                ann_infos.append(ann_info)
            dataset['infos'][id]['ann_infos'] = ann_infos
            dataset['infos'][id]['ann_infos'] = get_gt(dataset['infos'][id])
            dataset['infos'][id]['scene_token'] = sample['scene_token']

            scene = nuscenes.get('scene', sample['scene_token'])
            dataset['infos'][id]['scene_name'] = scene['name']
            dataset['infos'][id]['occ_path'] = \
                './data/nuscenes/gts/%s/%s'%(scene['name'], info['token'])
        with open('%s/%s_infos_%s.pkl' % (dataroot, extra_tag, set),
                  'wb') as fid:
            pickle.dump(dataset, fid)


if __name__ == '__main__':
    dataset = 'nuscenes'
    version = 'v1.0'
    # from trainval to mini
    train_version = f'{version}-mini'
    root_path = 'data/nuscenes'
    extra_tag = 'bevdetv2-nuscenes'
    nuscenes_data_prep(
        root_path=root_path,
        info_prefix=extra_tag,
        version=train_version,
        max_sweeps=0)

    print('add_ann_infos')
    add_ann_adj_info(extra_tag)

Then run the script; it generates two new files:

└── Path_to_FlashOcc/
    └── data
        └── nuscenes
            ├── v1.0-mini(existing)
            ├── sweeps  (existing)
            ├── samples (existing)
            ├── bevdetv2-nuscenes_infos_train.pkl (new)
            └── bevdetv2-nuscenes_infos_val.pkl (new)
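To check that the two files were written correctly, you can load them and look at what is inside. This is a rough sketch that assumes only what the script above shows, namely that each pickle is a dict with an 'infos' list of per-sample dicts; `summarize_infos` is a hypothetical helper.

```python
import os
import pickle


def summarize_infos(pkl_path):
    """Return (number of samples, sorted keys of the first sample) for an infos file."""
    with open(pkl_path, "rb") as f:
        dataset = pickle.load(f)
    infos = dataset["infos"]
    first_keys = sorted(infos[0].keys()) if infos else []
    return len(infos), first_keys


if __name__ == "__main__":
    for split in ("train", "val"):
        path = f"data/nuscenes/bevdetv2-nuscenes_infos_{split}.pkl"
        if os.path.exists(path):
            print(split, summarize_infos(path))
```

After add_ann_adj_info has run, the key list should include ann_infos, scene_token, scene_name and occ_path.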

Then download the GT for prediction:

Occupancy3D-nuScenes-mini - Google Drive

Then download the GT for occupancy:

gts.tar.gz - Google Drive

Then download the model:

flashocc-r50-256x704.pth - Google Drive

Running the demo

bash tools/dist_test.sh projects/configs/flashocc/flashocc-r50.py  ckpts/flashocc-r50-256x704.pth 4 --eval map

The 4 above is the number of GPUs; change it to however many you actually have.

occ3d -> occ3d panoptic

Because I'm using the mini dataset, the code needs to be modified; a tested, working version is at the end of this post.

Pitfalls

  1. CUSOLVER_STATUS_INTERNAL_ERROR

RuntimeError: cusolver error: CUSOLVER_STATUS_INTERNAL_ERROR, when calling `cusolverDnCreate(handle)`
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 1704876) of binary: /home/zhan/anaconda3/envs/FlashOcc/bin/python

CUDA/cuSOLVER fails to initialize when performing a matrix inversion. Run the following command as a self-check:

python -c "import torch; print(torch.version, torch.version.cuda); a=torch.eye(4,device='cuda').double().unsqueeze(0); print(torch.inverse(a))"

It fails with:

(FlashOcc) ➜  FlashOCC git:(master) ✗ python -c "import torch; print(torch.version, torch.version.cuda); a=torch.eye(4,device='cuda').double().unsqueeze(0); print(torch.inverse(a))"
<module 'torch.version' from '/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/version.py'> 11.1
Traceback (most recent call last):
  File "<string>", line 1, in <module>
RuntimeError: cusolver error: CUSOLVER_STATUS_INTERNAL_ERROR, when calling `cusolverDnCreate(handle)`

If you get the same error as I did, you may have an Ada-architecture GPU, such as an RTX 40-series card.
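If you are not sure which architecture your card is, the compute capability tells you. Here is a small lookup sketch; the table is partial and `arch_name` is a hypothetical helper. On a machine with a working torch install, `torch.cuda.get_device_capability()` returns the (major, minor) tuple for the current GPU; it is hard-coded below so the snippet runs without a GPU.

```python
# Compute capability -> NVIDIA architecture name (a partial table).
ARCH_BY_CAPABILITY = {
    (7, 0): "Volta",
    (7, 5): "Turing",
    (8, 0): "Ampere",
    (8, 6): "Ampere",
    (8, 9): "Ada Lovelace",
    (9, 0): "Hopper",
}


def arch_name(capability):
    """Look up the architecture for a (major, minor) compute capability."""
    return ARCH_BY_CAPABILITY.get(tuple(capability), "unknown")


# Hard-coded here; normally this tuple would come from
# torch.cuda.get_device_capability().
print(arch_name((8, 9)))  # Ada Lovelace
```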

From what I found online, everyone hitting this error has a 40-series card, for example this issue in the paper's code repository: RuntimeError: cusolver error: CUSOLVER_STATUS_INTERNAL_ERROR, when calling cusolverDnCreate(handle) · Issue #19 · Yzichen/FlashOCC

It doesn't seem entirely unsolvable, though: RuntimeError: cusolver error: CUSOLVER_STATUS_INTERNAL_ERROR, when calling `cusolverDnCreate(handle)` · Issue #164 · nerdyrodent/VQGAN-CLIP

The second reply offers a fix:

pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 -f https://download.pytorch.org/whl/torch_stable.html
pip install kornia==0.5.10

After installing, try again:

(FlashOcc) ➜  FlashOCC git:(master) ✗ python -c "import torch; a=torch.eye(4,device='cuda').double().unsqueeze(0); print(torch.inverse(a))"

tensor([[[1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 0., 1.]]], device='cuda:0', dtype=torch.float64)

Damn, that's a miracle cure!

  2. Running the demo again

ImportError: /home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/mmcv/_ext.cpython-38-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops7resize_4callERKNS_6TensorEN3c108ArrayRefIlEENS5_8optionalINS5_12MemoryFormatEEE
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 1717570) of binary: /home/zhan/anaconda3/envs/FlashOcc/bin/python
Traceback (most recent call last):
  File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/distributed/launch.py", line 195, in <module>
    main()
  File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/distributed/launch.py", line 191, in main
    launch(args)
  File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/distributed/launch.py", line 176, in launch
    run(args)
  File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/distributed/run.py", line 753, in run
    elastic_launch(
  File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 246, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
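The undefined symbol in the ImportError above looks like noise, but it is just a mangled C++ name. The sketch below decodes the leading namespace part by hand. It only handles the simple `_ZN<len><id>...E` prefix of the Itanium mangling scheme and ignores templates, substitutions and the argument list; `nested_name` is a hypothetical helper, and a real demangler such as c++filt does this properly.

```python
def nested_name(mangled):
    """Extract the leading nested name from an Itanium-mangled C++ symbol."""
    if not mangled.startswith("_ZN"):
        raise ValueError("not a nested Itanium-mangled name")
    pos, parts = 3, []
    # Each component is a decimal length followed by that many characters.
    while pos < len(mangled) and mangled[pos].isdigit():
        end = pos
        while end < len(mangled) and mangled[end].isdigit():
            end += 1
        length = int(mangled[pos:end])
        parts.append(mangled[end:end + length])
        pos = end + length
    return "::".join(parts)


symbol = ("_ZN2at4_ops7resize_4callERKNS_6TensorEN3c108ArrayRefIlEE"
          "NS5_8optionalINS5_12MemoryFormatEEE")
print(nested_name(symbol))  # at::_ops::resize_::call
```

So mmcv's `_ext` library is looking for a function in the `at::_ops` namespace, which it expects the installed torch libraries to export. That is consistent with the extension having been compiled against a different torch build than the one now installed.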

This looks like a binary mismatch between torch and mmcv: I reinstalled torch but not mmcv. Let's try rebuilding mmcv from source:

# Remove the packages and purge the pip cache
pip uninstall -y mmcv mmcv-full mmcv-lite
pip cache purge

export MMCV_WITH_OPS=1
export FORCE_CUDA=1
export TORCH_CUDA_ARCH_LIST=8.9
pip install -U pip setuptools wheel ninja
pip install -v --no-binary mmcv-full mmcv-full==1.5.3

pip install mmdet==2.25.1 mmsegmentation==0.25.0
cd YourFlashOCCPath/mmdetection3d && pip install -v -e .
cd YourFlashOCCPath/FlashOCC/projects && pip install -v -e .

It fails with:

ValueError: Unknown CUDA arch (8.9) or GPU not supported
error: subprocess-exited-with-error
...
...
...
Failed to build mmcv-full
ERROR: Failed to build installable wheels for some pyproject.toml based projects (mmcv-full)

TORCH_CUDA_ARCH_LIST is set to 8.9, but the cpp_extension in torch 1.13.1 doesn't recognize 8.9 as an arch. Ugh.

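Time to bring in NVIDIA's PTX JIT. PTX stands for Parallel Thread Execution and JIT for Just-in-Time. A precompiled wheel stops working once the architecture changes, whereas PTX is compiled just in time into code that the current hardware can run.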
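The idea can be sketched as a tiny rule: pick the newest architecture the build toolchain knows that is not newer than the GPU, and ask for PTX on top so the driver can JIT it for the actual hardware. `fallback_arch_list` is a hypothetical helper, and the `KNOWN` list is an assumption for illustration, not the real list inside torch 1.13.1.

```python
def fallback_arch_list(device_cc, known_archs):
    """Pick the newest known arch not newer than the device and request PTX.

    device_cc   -- (major, minor) compute capability of the GPU
    known_archs -- archs the build toolchain accepts (assumed for this sketch)
    """
    candidates = [cc for cc in known_archs if cc <= device_cc]
    if not candidates:
        raise ValueError(f"no known arch is usable on {device_cc}")
    major, minor = max(candidates)
    return f"{major}.{minor}+PTX"


# Assumed list for illustration; check what your torch version accepts.
KNOWN = [(7, 0), (7, 5), (8, 0), (8, 6)]
print(fallback_arch_list((8, 9), KNOWN))  # 8.6+PTX
```

For an 8.9 card this gives 8.6+PTX, which is the value used in the commands that follow.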

# Clear the environment variable and uninstall the packages
unset TORCH_CUDA_ARCH_LIST
pip uninstall -y mmcv mmcv-full mmcv-lite

# Set the build architecture again
export MMCV_WITH_OPS=1
export FORCE_CUDA=1
export TORCH_CUDA_ARCH_LIST="8.6+PTX"

# Install mmcv-full
pip install -U pip setuptools wheel ninja
pip install -v --no-cache-dir --no-build-isolation mmcv-full==1.5.3

# Verify mmcv
python -c "import torch,mmcv; print(torch.__version__, torch.version.cuda); print(mmcv.version, mmcv.__file__); from mmcv.ops import nms_match; print('mmcv ops ok')"

If the output is:

1.13.1+cu117 11.7
<module 'mmcv.version' from '/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/mmcv/version.py'> /home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/mmcv/__init__.py
mmcv ops ok

then mmcv is working.

  3. Running the demo yet again

Another error:

ImportError: /home/zhan/Projects/FlashOCC/projects/mmdet3d_plugin/ops/bev_pool/bev_pool_ext.cpython-38-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops5zeros4callEN3c108ArrayRefIlEENS2_8optionalINS2_10ScalarTypeEEENS5_INS2_6LayoutEEENS5_INS2_6DeviceEEENS5_IbEE
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 1725473) of binary: /home/zhan/anaconda3/envs/FlashOcc/bin/python
Traceback (most recent call last):
  File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/distributed/launch.py", line 195, in <module>
    main()
  File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/distributed/launch.py", line 191, in main
    launch(args)
  File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/distributed/launch.py", line 176, in launch
    run(args)
  File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/distributed/run.py", line 753, in run
    elastic_launch(
  File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 246, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 

The newly installed torch now clashes with FlashOcc's own bundled libraries as well.
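Before deleting anything, it helps to see which compiled extensions are actually lying around. A minimal sketch, where `find_compiled_extensions` is a hypothetical helper and the ops path is the plugin directory from the error message, relative to the repository root:

```python
from pathlib import Path


def find_compiled_extensions(root):
    """Return every *.so file below root, sorted, as paths relative to root."""
    root = Path(root)
    if not root.is_dir():
        return []
    return sorted(p.relative_to(root).as_posix() for p in root.rglob("*.so"))


if __name__ == "__main__":
    for so_file in find_compiled_extensions("projects/mmdet3d_plugin/ops"):
        print(so_file)
```

Anything it lists was built against the old torch and needs to go.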

# Delete the stale build artifacts
rm -f /home/zhan/Projects/FlashOCC/projects/mmdet3d_plugin/ops/bev_pool/bev_pool_ext*.so
rm -f /home/zhan/Projects/FlashOCC/projects/mmdet3d_plugin/ops/bev_pool_v2/bev_pool_v2_ext*.so
rm -f /home/zhan/Projects/FlashOCC/projects/mmdet3d_plugin/ops/nearest_assign/nearest_assign_ext*.so
rm -rf /home/zhan/Projects/FlashOCC/projects/build
rm -rf flashocc_plugin.egg-info

# Rebuild and reinstall the plugin
cd /home/zhan/Projects/FlashOCC/projects
export FORCE_CUDA=1
export TORCH_CUDA_ARCH_LIST="8.6+PTX"
pip install -v -e .

Then verify:

(FlashOcc) ➜  FlashOCC git:(master) ✗ python -c "from mmdet3d_plugin.ops import bev_pool, bev_pool_v2, nearest_assign; print('plugin ops ok')"

Using /home/zhan/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/zhan/.cache/torch_extensions/py38_cu117/dvr/build.ninja...
Building extension module dvr...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module dvr...
/home/zhan/anaconda3/envs/FlashOcc/lib/python3.8/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3190.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
plugin ops ok

That passes, so run the demo:

[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 81/81, 16.0 task/s, elapsed: 5s, ETA:     0s
metric =  map

Starting Evaluation...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 81/81 [00:00<00:00, 283.86it/s]
===> per class IoU of 81 samples:
===> others - IoU = 27.98
===> barrier - IoU = nan
===> bicycle - IoU = 0.92
===> bus - IoU = 29.8
===> car - IoU = 38.42
===> construction_vehicle - IoU = 0.0
===> motorcycle - IoU = 7.0
===> pedestrian - IoU = 12.56
===> traffic_cone - IoU = 2.75
===> trailer - IoU = nan
===> truck - IoU = 19.56
===> driveable_surface - IoU = 72.53
===> other_flat - IoU = 0.0
===> sidewalk - IoU = 42.74
===> terrain - IoU = 34.21
===> manmade - IoU = 43.75
===> vegetation - IoU = 29.1
===> mIoU of 81 samples: 24.09
{'mIoU': array([0.28 ,   nan, 0.009, 0.298, 0.384, 0.   , 0.07 , 0.126, 0.028,
         nan, 0.196, 0.725, 0.   , 0.427, 0.342, 0.437, 0.291, 0.869])}
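As a sanity check on the log, the reported mIoU can be reproduced from the printed per-class numbers if you average over the 17 listed classes and skip the NaN entries. I'm inferring that averaging rule from the numbers rather than from the evaluation code; `nan_mean` is a hypothetical helper.

```python
import math

nan = float("nan")
# Per-class IoU values copied from the log above ('free' is not listed there).
per_class_iou = [
    27.98, nan, 0.92, 29.8, 38.42, 0.0, 7.0, 12.56, 2.75,
    nan, 19.56, 72.53, 0.0, 42.74, 34.21, 43.75, 29.1,
]


def nan_mean(values):
    """Average the values, skipping NaN entries."""
    valid = [v for v in values if not math.isnan(v)]
    return sum(valid) / len(valid)


print(round(nan_mean(per_class_iou), 2))  # 24.09
```

The two NaN classes (barrier and trailer) are simply left out of the average, presumably because they have no ground-truth voxels in these 81 mini samples.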

That's a wrap for this post.

get_instance_info.py

import os
import tqdm
import glob
import pickle
import argparse
import numpy as np
import torch
import multiprocessing
from pyquaternion import Quaternion
from nuscenes.utils.data_classes import Box
from nuscenes.utils.geometry_utils import points_in_box


parser = argparse.ArgumentParser()
parser.add_argument('--nusc-root', default='data/nuscenes')
parser.add_argument('--occ3d-root', default='data/nuscenes/occ3d')
parser.add_argument('--output-dir', default='data/nuscenes/occ3d_panoptic')
parser.add_argument('--version', default='v1.0-trainval')
parser.add_argument('--train-info', default=None)
parser.add_argument('--val-info', default=None)
parser.add_argument('--test-info', default=None)
args = parser.parse_args()

token2path = {}
for gt_path in glob.glob(os.path.join(args.occ3d_root, '**', 'labels.npz'), recursive=True):
    token = os.path.basename(os.path.dirname(gt_path))
    token2path[token] = gt_path

if len(token2path) == 0:
    raise FileNotFoundError(
        'No labels.npz found under occ3d-root: {}. '
        'Please check --occ3d-root.'.format(args.occ3d_root))

occ_class_names = [
    'others', 'barrier', 'bicycle', 'bus', 'car', 'construction_vehicle',
    'motorcycle', 'pedestrian', 'traffic_cone', 'trailer', 'truck',
    'driveable_surface', 'other_flat', 'sidewalk',
    'terrain', 'manmade', 'vegetation', 'free'
]

det_class_names = [
    'car', 'truck', 'trailer', 'bus', 'construction_vehicle',
    'bicycle', 'motorcycle', 'pedestrian', 'traffic_cone', 'barrier'
]


def convert_to_nusc_box(bboxes, lift_center=False, wlh_margin=0.0):
    results = []
    for q in range(bboxes.shape[0]):

        bbox = bboxes[q].copy()
        if lift_center:
            bbox[2] += bbox[5] * 0.5

        bbox_yaw = -bbox[6] - np.pi / 2
        orientation = Quaternion(axis=[0, 0, 1], radians=bbox_yaw).inverse

        box = Box(
            center=[bbox[0], bbox[1], bbox[2]],
            # 0.8 in pc range is roughly 2 voxels in occ grid
            # enlarge bbox to include voxels on the edge
            size=[bbox[3]+wlh_margin, bbox[4]+wlh_margin, bbox[5]+wlh_margin],
            orientation=orientation,
        )

        results.append(box)

    return results


def meshgrid3d(occ_size, pc_range):  # points in ego coord
    W, H, D = occ_size
    
    xs = torch.linspace(0.5, W - 0.5, W).view(W, 1, 1).expand(W, H, D) / W
    ys = torch.linspace(0.5, H - 0.5, H).view(1, H, 1).expand(W, H, D) / H
    zs = torch.linspace(0.5, D - 0.5, D).view(1, 1, D).expand(W, H, D) / D
    xs = xs * (pc_range[3] - pc_range[0]) + pc_range[0]
    ys = ys * (pc_range[4] - pc_range[1]) + pc_range[1]
    zs = zs * (pc_range[5] - pc_range[2]) + pc_range[2]
    xyz = torch.stack((xs, ys, zs), -1)

    return xyz


def process_add_instance_info(sample):
    point_cloud_range = [-40, -40, -1.0, 40, 40, 5.4]
    occ_size = [200, 200, 16]
    num_classes = 18
    
    occ_gt_path = token2path[sample['token']]
    occ_labels = np.load(occ_gt_path)
    
    occ_gt = occ_labels['semantics']
    gt_boxes = sample['gt_boxes']
    gt_names = sample['gt_names']
    
    bboxes = convert_to_nusc_box(gt_boxes)
    
    instance_gt = np.zeros(occ_gt.shape).astype(np.uint8)
    instance_id = 1
    
    pts = meshgrid3d(occ_size, point_cloud_range).numpy()
    
    # filter out free voxels to accelerate
    valid_idx = np.where(occ_gt < num_classes - 1)
    flatten_occ_gt = occ_gt[valid_idx]
    flatten_inst_gt = instance_gt[valid_idx]
    flatten_pts = pts[valid_idx]
    
    instance_boxes = []
    instance_class_ids = []
    
    for i in range(len(gt_names)):
        if gt_names[i] not in occ_class_names:
            continue
        occ_tag_id = occ_class_names.index(gt_names[i])
            
        # Move box to ego vehicle coord system
        bbox = bboxes[i]
        bbox.rotate(Quaternion(sample['lidar2ego_rotation']))
        bbox.translate(np.array(sample['lidar2ego_translation']))
        
        mask = points_in_box(bbox, flatten_pts.transpose(1, 0))
        
        # ignore voxels not belonging to this class
        mask[mask] = (flatten_occ_gt[mask] == occ_tag_id)
        # ignore voxels already occupied
        mask[mask] = (flatten_inst_gt[mask] == 0)
        
        # only instance with at least 1 voxel will be recorded
        if mask.sum() > 0:
            flatten_inst_gt[mask] = instance_id
            instance_id += 1
            
            # enlarge boxes to include voxels on the edge
            new_box = bbox.copy()
            new_box.wlh = new_box.wlh + 0.8
            
            instance_boxes.append(new_box)
            instance_class_ids.append(occ_tag_id)
    
    # classes that should be viewed as one instance
    all_class_ids_unique = np.unique(occ_gt)
    for i, class_name in enumerate(occ_class_names):
        if class_name in det_class_names or class_name == 'free' or i not in all_class_ids_unique:
            continue
        flatten_inst_gt[flatten_occ_gt == i] = instance_id
        instance_id += 1
    
    # post-process non-free voxels that are not yet covered by any instance
    uncover_idx = np.where(flatten_inst_gt == 0)
    uncover_pts = flatten_pts[uncover_idx]
    uncover_inst_gt = np.zeros_like(uncover_pts[..., 0]).astype(np.uint8)
    unconver_occ_gt = flatten_occ_gt[uncover_idx]
    
    # uncover_inst_dist records the dist between each voxel and its current nearest bbox's center
    uncover_inst_dist = np.ones_like(uncover_pts[..., 0]) * 1e8
    for i, box in enumerate(instance_boxes):
        # important, non-background inst id starts from 1
        inst_id = i + 1
        class_id = instance_class_ids[i]
        mask = points_in_box(box, uncover_pts.transpose(1, 0))
        # mask voxels not belonging to this class
        mask[unconver_occ_gt != class_id] = False
        dist = np.sum((box.center - uncover_pts) ** 2, axis=-1)
        # voxels that have already been assigned to a closer box's instance should be ignored
        # voxels that are not inside the box should be ignored
        # `mask[(dist >= uncover_inst_dist)]=False` is right, as it only transforms True masks into False without converting False into True
        # to give readers a more clear understanding, the most standard writing is `mask[mask & (dist >= uncover_inst_dist)]=False`
        mask[dist >= uncover_inst_dist] = False
        # mask[mask & (dist >= uncover_inst_dist)]=False
        
        # important: only voxels inside the box (mask = True) and having no closer identical-class box need to update dist
        uncover_inst_dist[mask] = dist[mask]
        uncover_inst_gt[mask] = inst_id
        
    flatten_inst_gt[uncover_idx] = uncover_inst_gt
    
    instance_gt[valid_idx] = flatten_inst_gt
    # not using this checking function yet
    # assert (instance_gt == 0).sum() - (occ_gt == num_classes-1).sum() < 100, "too many non-free voxels are not assigned to any instance in %s"%(occ_gt_path)
    # global max_margin
    # if max_margin < (instance_gt == 0).sum() - (occ_gt == num_classes-1).sum():
    #     print("###### new max margin: ", max(max_margin, (instance_gt == 0).sum() - (occ_gt == num_classes-1).sum()))
    # max_margin = max(max_margin, (instance_gt == 0).sum() - (occ_gt == num_classes-1).sum())
    
    # save to original path
    data_split = occ_gt_path.split(os.path.sep)[-3:]
    data_path = os.path.sep.join(data_split)
    
    ##### Warning: relying on args.xxx (a global variable) here is strongly discouraged
    save_path = os.path.join(args.output_dir, data_path)
    
    save_dir = os.path.split(save_path)[0]
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    
    if np.unique(instance_gt).shape[0] != instance_gt.max()+1:
        print('warning: some instance masks are covered by following ones %s'%(save_dir))
    
    # only the semantic and mask arrays need to be kept
    retain_keys = ['semantics', 'mask_lidar', 'mask_camera']   
    new_occ_labels = {k: occ_labels[k] for k in retain_keys}
    new_occ_labels['instances'] = instance_gt
    np.savez_compressed(save_path, **new_occ_labels)


def add_instance_info(sample_infos):
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    
    # all cpus participate in multi processing
    pool = multiprocessing.Pool(multiprocessing.cpu_count())
    with tqdm.tqdm(total=len(sample_infos['infos'])) as pbar:
        for _ in pool.imap(process_add_instance_info, sample_infos['infos']):
            pbar.update(1)
    
    pool.close()
    pool.join()


def resolve_info_file(user_path, candidates):
    if user_path is not None:
        if not os.path.exists(user_path):
            raise FileNotFoundError('Cannot find info file: {}'.format(user_path))
        return user_path

    for filename in candidates:
        path = os.path.join(args.nusc_root, filename)
        if os.path.exists(path):
            return path

    raise FileNotFoundError(
        'Cannot find info file under {}. Tried: {}'.format(
            args.nusc_root, ', '.join(candidates)))


if __name__ == '__main__':
    if args.version in ['v1.0-trainval', 'v1.0-mini']:
        train_info = resolve_info_file(
            args.train_info,
            ['bevdetv2-nuscenes_infos_train.pkl', 'nuscenes_infos_train_sweep.pkl'])
        val_info = resolve_info_file(
            args.val_info,
            ['bevdetv2-nuscenes_infos_val.pkl', 'nuscenes_infos_val_sweep.pkl'])

        sample_infos = pickle.load(open(train_info, 'rb'))
        add_instance_info(sample_infos)

        sample_infos = pickle.load(open(val_info, 'rb'))
        add_instance_info(sample_infos)

    elif args.version == 'v1.0-test':
        test_info = resolve_info_file(
            args.test_info,
            ['nuscenes_infos_test_sweep.pkl', 'bevdetv2-nuscenes_infos_test.pkl'])
        sample_infos = pickle.load(open(test_info, 'rb'))
        add_instance_info(sample_infos)

    else:
        raise ValueError('Unsupported --version: {}'.format(args.version))
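One detail in the script worth spelling out is how meshgrid3d places points: every voxel index is mapped to the centre of its cell in the ego frame. Below is a plain-Python sketch of the same formula for a single voxel, using the occ_size and point_cloud_range hard-coded in process_add_instance_info; `voxel_center` is a hypothetical helper and is not part of the script.

```python
def voxel_center(ix, iy, iz,
                 occ_size=(200, 200, 16),
                 pc_range=(-40.0, -40.0, -1.0, 40.0, 40.0, 5.4)):
    """Ego-frame centre of voxel (ix, iy, iz), same formula as meshgrid3d."""
    center = []
    for axis, idx in enumerate((ix, iy, iz)):
        lo, hi = pc_range[axis], pc_range[axis + 3]
        # (idx + 0.5) / size is the normalised centre of the cell in [0, 1].
        center.append((idx + 0.5) / occ_size[axis] * (hi - lo) + lo)
    return tuple(center)


print(voxel_center(0, 0, 0))
print(voxel_center(199, 199, 15))
```

With these settings each voxel is 0.4 m on a side, so the first centre comes out at roughly (-39.8, -39.8, -0.8) and the last at roughly (39.8, 39.8, 5.2). That is also why the 0.8 box margin corresponds to about two voxels.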