# Shortcuts
#
# Source code for mmseg.models.decode_heads.vpd_depth_head

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer, build_norm_layer, build_upsample_layer
from mmengine.model import BaseModule
from torch import Tensor

from mmseg.registry import MODELS
from mmseg.utils import SampleList
from ..utils import resize
from .decode_head import BaseDecodeHead


class VPDDepthDecoder(BaseModule):
    """VPD Depth Decoder class.

    The decoder applies, in order: a stack of stride-2 deconvolution layers
    (each followed by BatchNorm and ReLU), a 3x3 Conv-BN-ReLU projection to
    ``out_channels``, and two successive 2x bilinear upsampling steps.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        num_deconv_layers (int): Number of deconvolution layers.
        num_deconv_filters (List[int]): List of output channels for
            deconvolution layers. Must provide at least
            ``num_deconv_layers`` entries; only the first
            ``num_deconv_layers`` entries are used.
        init_cfg (Optional[Union[Dict, List[Dict]]], optional): Configuration
            for weight initialization. Defaults to Normal for Conv2d and
            ConvTranspose2d layers.

    Raises:
        ValueError: If ``num_deconv_layers`` is larger than the number of
            entries in ``num_deconv_filters``.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 num_deconv_layers: int,
                 num_deconv_filters: List[int],
                 init_cfg: Optional[Union[Dict, List[Dict]]] = dict(
                     type='Normal',
                     std=0.001,
                     layer=['Conv2d', 'ConvTranspose2d'])):
        super().__init__(init_cfg=init_cfg)

        # Fail early with a readable message instead of an IndexError raised
        # from inside the layer-building loop.
        if num_deconv_layers > len(num_deconv_filters):
            raise ValueError(
                f'num_deconv_layers ({num_deconv_layers}) must not exceed '
                f'the number of entries in num_deconv_filters '
                f'({len(num_deconv_filters)})')

        self.in_channels = in_channels

        self.deconv_layers = self._make_deconv_layer(
            num_deconv_layers,
            num_deconv_filters,
        )

        # The projection conv must consume exactly what the deconv stack
        # emits: the filter count of the last deconv layer that was actually
        # built, or the raw input width when no deconv layer exists. Using
        # ``num_deconv_filters[-1]`` unconditionally breaks whenever
        # ``num_deconv_layers != len(num_deconv_filters)``.
        if num_deconv_layers > 0:
            deconv_out_channels = num_deconv_filters[num_deconv_layers - 1]
        else:
            deconv_out_channels = in_channels

        conv_layers = [
            build_conv_layer(
                dict(type='Conv2d'),
                in_channels=deconv_out_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding=1),
            # build_norm_layer returns a (name, layer) pair; keep the layer.
            build_norm_layer(dict(type='BN'), out_channels)[1],
            nn.ReLU(inplace=True),
        ]
        self.conv_layers = nn.Sequential(*conv_layers)

        self.up_sample = nn.Upsample(
            scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass through the decoder network.

        Args:
            x (Tensor): Input feature map with ``in_channels`` channels.

        Returns:
            Tensor: Decoded feature map with ``out_channels`` channels.
        """
        out = self.deconv_layers(x)
        out = self.conv_layers(out)

        # The same 2x bilinear upsampler is applied twice.
        out = self.up_sample(out)
        out = self.up_sample(out)

        return out

    def _make_deconv_layer(self, num_layers, num_deconv_filters):
        """Make deconv layers.

        Args:
            num_layers (int): Number of deconv-BN-ReLU groups to build.
            num_deconv_filters (Sequence[int]): Output channels of each
                deconv layer; entry ``i`` is used for layer ``i``.

        Returns:
            nn.Sequential: The stacked deconvolution layers. It is empty when
            ``num_layers`` is not positive.
        """
        layers = []
        in_channels = self.in_channels
        for i in range(num_layers):
            num_channels = num_deconv_filters[i]
            layers.append(
                build_upsample_layer(
                    dict(type='deconv'),
                    in_channels=in_channels,
                    out_channels=num_channels,
                    kernel_size=2,
                    stride=2,
                    padding=0,
                    output_padding=0,
                    bias=False))
            layers.append(nn.BatchNorm2d(num_channels))
            layers.append(nn.ReLU(inplace=True))
            # The next deconv layer consumes this layer's output.
            in_channels = num_channels

        return nn.Sequential(*layers)


@MODELS.register_module()
class VPDDepthHead(BaseDecodeHead):
    """Depth Prediction Head for VPD.

    .. _`VPD`: https://arxiv.org/abs/2303.02153

    Args:
        max_depth (float): Maximum depth value. Defaults to 10.0.
        in_channels (Sequence[int]): Number of input channels for each
            convolutional layer.
        embed_dim (int): Dimension of embedding. Defaults to 192.
        feature_dim (int): Dimension of aggregated feature. Defaults to 1536.
        num_deconv_layers (int): Number of deconvolution layers in the
            decoder. Defaults to 3.
        num_deconv_filters (Sequence[int]): Number of filters for each deconv
            layer. Defaults to (32, 32, 32).
        fmap_border (Union[int, Sequence[int]]): Feature map border for
            cropping. Defaults to 0.
        align_corners (bool): Flag for align_corners in interpolation.
            Defaults to False.
        loss_decode (dict | Sequence[dict]): Configurations for the loss
            function(s). Defaults to dict(type='SiLogLoss').
        init_cfg (dict): Initialization configurations. Defaults to
            dict(type='TruncNormal', std=0.02, layer=['Conv2d', 'Linear']).
    """

    num_classes = 1
    out_channels = 1
    input_transform = None

    def __init__(
        self,
        max_depth: float = 10.0,
        in_channels: Sequence[int] = [320, 640, 1280, 1280],
        embed_dim: int = 192,
        feature_dim: int = 1536,
        num_deconv_layers: int = 3,
        num_deconv_filters: Sequence[int] = (32, 32, 32),
        fmap_border: Union[int, Sequence[int]] = 0,
        align_corners: bool = False,
        loss_decode: dict = dict(type='SiLogLoss'),
        init_cfg=dict(
            type='TruncNormal', std=0.02, layer=['Conv2d', 'Linear']),
    ):
        # Deliberately start the lookup *after* BaseDecodeHead in the MRO, so
        # BaseDecodeHead.__init__ itself is not executed; everything this
        # head needs is set up below.
        super(BaseDecodeHead, self).__init__(init_cfg=init_cfg)

        # initialize parameters
        self.in_channels = in_channels
        self.max_depth = max_depth
        self.align_corners = align_corners

        # feature map border: a single int applies to both spatial axes
        if isinstance(fmap_border, int):
            fmap_border = (fmap_border, fmap_border)
        self.fmap_border = fmap_border

        # define network layers
        # conv1 downsamples the first input level with two stride-2 convs.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels[0], in_channels[0], 3, stride=2, padding=1),
            nn.GroupNorm(16, in_channels[0]),
            nn.ReLU(),
            nn.Conv2d(in_channels[0], in_channels[0], 3, stride=2, padding=1),
        )
        # conv2 downsamples the second input level with one stride-2 conv.
        self.conv2 = nn.Conv2d(
            in_channels[1], in_channels[1], 3, stride=2, padding=1)

        # All four levels are concatenated channel-wise in forward(), hence
        # sum(in_channels) input channels for the 1x1 aggregation conv.
        self.conv_aggregation = nn.Sequential(
            nn.Conv2d(sum(in_channels), feature_dim, 1),
            nn.GroupNorm(16, feature_dim),
            nn.ReLU(),
        )

        # NOTE: conv_aggregation emits ``feature_dim`` channels while the
        # decoder is built for ``embed_dim * 8`` input channels, so the two
        # values must match (they do for the defaults: 1536 == 192 * 8).
        self.decoder = VPDDepthDecoder(
            in_channels=embed_dim * 8,
            out_channels=embed_dim,
            num_deconv_layers=num_deconv_layers,
            num_deconv_filters=num_deconv_filters)

        self.depth_pred_layer = nn.Sequential(
            nn.Conv2d(
                embed_dim, embed_dim, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=False),
            nn.Conv2d(embed_dim, 1, kernel_size=3, stride=1, padding=1))

        # build loss
        if isinstance(loss_decode, dict):
            self.loss_decode = MODELS.build(loss_decode)
        elif isinstance(loss_decode, (list, tuple)):
            self.loss_decode = nn.ModuleList()
            for loss in loss_decode:
                self.loss_decode.append(MODELS.build(loss))
        else:
            # Adjacent literals keep the message on one clean line; a
            # backslash continuation inside the string would leak stray
            # characters into the message.
            raise TypeError('loss_decode must be a dict or sequence of dict, '
                            f'but got {type(loss_decode)}')

    def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
        """Stack the ground-truth depth maps of a batch along a new dim 0.

        Args:
            batch_data_samples (SampleList): Data samples, each providing
                ``gt_depth_map.data``.

        Returns:
            Tensor: The stacked ground-truth depth maps.
        """
        gt_depth_maps = [
            data_sample.gt_depth_map.data for data_sample in batch_data_samples
        ]
        return torch.stack(gt_depth_maps, dim=0)

    def forward(self, x):
        """Predict a depth map from four feature levels.

        Args:
            x (Sequence[Tensor]): Four feature maps, indexed ``x[0]`` to
                ``x[3]``.

        Returns:
            Tensor: Predicted depth, ``sigmoid(out) * max_depth``.
        """
        # Upsample the last level 2x and merge it with the third level.
        x = [
            x[0], x[1],
            torch.cat([x[2], F.interpolate(x[3], scale_factor=2)], dim=1)
        ]
        # Downsample the first two levels, then concatenate all channels.
        x = torch.cat([self.conv1(x[0]), self.conv2(x[1]), x[2]], dim=1)
        x = self.conv_aggregation(x)

        # Crop ``fmap_border`` rows/columns off the bottom/right edges.
        x = x[:, :, :x.size(2) - self.fmap_border[0], :x.size(3) -
              self.fmap_border[1]].contiguous()
        x = self.decoder(x)
        out = self.depth_pred_layer(x)

        # Squash to (0, 1) and scale to the configured depth range.
        depth = torch.sigmoid(out) * self.max_depth

        return depth

    def loss_by_feat(self, pred_depth_map: Tensor,
                     batch_data_samples: SampleList) -> dict:
        """Compute depth estimation loss.

        Args:
            pred_depth_map (Tensor): The output from decode head forward
                function.
            batch_data_samples (List[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_depth_map`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        gt_depth_map = self._stack_batch_gt(batch_data_samples)
        loss = dict()
        # Bring the prediction to the ground-truth spatial size.
        pred_depth_map = resize(
            input=pred_depth_map,
            size=gt_depth_map.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)

        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode
        for loss_decode in losses_decode:
            # Losses sharing a ``loss_name`` are summed under one key.
            if loss_decode.loss_name not in loss:
                loss[loss_decode.loss_name] = loss_decode(
                    pred_depth_map, gt_depth_map)
            else:
                loss[loss_decode.loss_name] += loss_decode(
                    pred_depth_map, gt_depth_map)

        return loss