📊
🧬
🔍
⚗️
💡
🎯
🚀

DiffPose代码笔记

详解论文与代码对应关系

Posted by CYXYZ on February 26, 2024

DiffPose notebook

预处理X-ray

通过CT可以得到虚拟X射线图像,称为DRRs(数字重建投影)。为了对真实X射线和CT图像模拟产生的X射线关系进行操作,需要对X射线做预处理转变为DRRs图像。

预处理步骤:

  1. 从每个边缘裁剪50像素,以去除由准直器(collimator)引起的影响。 准直器是一种医学成像和放射线治疗领域的重要组件。其主要功能是限制放射线束的大小和形状,并将其定向到一个特定的区域。作用有三点:限定成像的空间范围(限定断层层厚)。降低患者的表面辐射剂量。减少进人探测器的散射线。

  2. 反转成像方程以恢复线积分射线图(line integral radiograph)。 比尔-朗伯定律:描述透过介质光强衰减的物理定律。 \begin{equation} I_f[i,j]=I_0\exp(-L[i,j]), \end{equation} 文章进一步解释了如何通过公式进行反转计算线积分\(L\): \begin{equation} L[i,j]=\log I_0-\log I_f[i,j], \end{equation} \(I_0\)被近似为\(I[i,j]\)的最大值,假设这个代表一个没有与体积相交就达到探测器平面的射线。

  3. 将图像重新归一化到[0, 1]范围内。

注意,文章中提取真实X光数据集的代码是:

specimen = DeepFluoroDataset(
    1,
    filename=filename,
    preprocess=True,  # Set as True to preprocess images
)
processed_xray, _ = specimen[0]

上面如果preprocess=True,那么将得到预处理真实X光之后的数据集,而preprocess=False,将得到真实的X光数据集。

若使用DRR来生成数据,文章中对应的代码部分是:

height = 1536 
dx = 0.194
drr = DRR(
    specimen.volume,
    specimen.spacing,
    sdr=specimen.focal_len / 2,
    height=height,
    delx=dx,
    x0=specimen.x0,
    y0=specimen.y0,
    reverse_x_axis=True,
    patch_size=384,
).to(device)

设置specimen之后,就能使用这个代码生成DRR对象。使用下面的语句可以得到通过DRR产生的不同姿势的x光图像。

pred_xray = drr(pose.to(device))

这样产生的pred_xray和processed_xray仍然有一定差距,因此,可以在生成DRR之前对CT进行预处理,将空气、软组织和骨头进行分割。

使用bone_attenuation_multiplier = 1.0(默认值)会将空气像素的值设为0。 增加bone_attenuation_multiplier会使骨头的密度高于软组织(即增加DRR中的对比度)。

drr = DRR(
    specimen.volume,
    specimen.spacing,
    sdr=specimen.focal_len / 2,
    height=height,
    delx=dx,
    x0=specimen.x0,
    y0=specimen.y0,
    reverse_x_axis=True,
    patch_size=359,
).to(device)

_, pose = specimen[0]
pred_xray = drr(
    pose.to(device),
    bone_attenuation_multiplier=2.5,  # Set the bone attenuation multiplier
)

通过如下代码绘制:

plt.figure(constrained_layout=True)
plt.subplot(121)
plt.title("DRR")
plt.imshow(pred_xray.squeeze().cpu().numpy(), cmap="gray")
plt.subplot(122)
plt.title("Processed X-ray")
plt.imshow(processed_xray.squeeze(), cmap="gray")
plt.show()

loading-ag-491

这样最后的图片大致为这样。

相机姿态

文章使用不同旋转平移的相机姿态SE3来表示相机。在这里,相机的姿态由三个旋转参数和三个平移参数组成。文中提到,这些参数是从独立的正态分布中采样的,这意味着每个参数都是从一个具有足够方差的正态分布中随机选择的,以便捕捉到从等轴中心(isocenter)发生的广泛变化。文章中对应的代码如下:

@beartype
def get_random_offset(batch_size: int, device) -> RigidTransform:
    r1 = torch.distributions.Normal(0, 0.2).sample((batch_size,))
    r2 = torch.distributions.Normal(0, 0.1).sample((batch_size,))
    r3 = torch.distributions.Normal(0, 0.25).sample((batch_size,))
    t1 = torch.distributions.Normal(10, 70).sample((batch_size,))
    t2 = torch.distributions.Normal(250, 90).sample((batch_size,))
    t3 = torch.distributions.Normal(5, 50).sample((batch_size,))
    log_R_vee = torch.stack([r1, r2, r3], dim=1).to(device)
    log_t_vee = torch.stack([t1, t2, t3], dim=1).to(device)
    return convert(
        log_R_vee,
        log_t_vee,
        parameterization="se3_log_map",
    )

相机的内外参详解

参考一文带你搞懂相机内参外参(Intrinsics & Extrinsics) - 知乎 (zhihu.com)

讲的蛮清楚的,最终就是:

内参矩阵:

相机坐标系与像素坐标系的相互变换。3D <—> 2D 。当然从像素坐标系到图像坐标系与之前做智能车接收到的map数据相同,就是一个是真实世界中的距离,一个是像素世界中的距离,中间有一个分辨率的影响。

外参矩阵:

相机坐标系与世界坐标系的相互变换。3D <—> 3D

fiducials

搞了好久终于能够可视化fiducials了,主要就是利用数据集的get_2d_groundtruth函数,但是由于训练创建的drr图像是256*256的,但是fiducials仍然保持之前1436的尺度不变,因此绘制出来的fidicials在尺度上有一些问题。这时候对fidicials进行下采样把尺度变为256的即可在原先的图片上绘制fiducials。

train.py的结构

load函数

加载3D数据集

# 加载数据集
specimen = DeepFluoroDataset(id_number)
# 将等中心姿态移至设备
isocenter_pose = specimen.isocenter_pose.to(device)

这里的等中心姿态是指X光的C臂机拍摄时的中心角度

利用3D数据集生成DRR图像

subsample = (1536 - 100) / height
delx = 0.194 * subsample
drr = DRR(
    specimen.volume,
    specimen.spacing,
    specimen.focal_len / 2,
    height,
    delx,
    x0=specimen.x0,
    y0=specimen.y0,
    reverse_x_axis=True,
).to(device)

定义Transform操作

  1. Lambda函数: 它对图像进行归一化处理,将像素值缩放到0到1的范围内。具体来说,它通过计算每个像素值与图像中最小像素值之间的差异,并将结果除以像素值范围(最大像素值减去最小像素值再加上一个小的常数epsilon),以确保最终的范围在0到1之间。
  2. Resize函数: 该函数将图像调整为指定的大小(size x size),同时使用抗锯齿算法(antialias=True)以提高图像质量。
  3. Normalize函数: 该函数对图像进行标准化处理,使用给定的均值和标准差。在这里,图像被标准化为具有均值0.3080和标准差0.1494的正态分布。

main函数

训练参数设置

  1. id_number: 要处理的病例或数据的编号。
  2. height: 图像的高度。默认为256像素。
  3. restart: 用于指定是否从之前的训练中恢复。如果提供了路径,则会加载模型和优化器状态,以便从上次训练中断的地方继续训练。
  4. pretrained_checkpoint: 可选参数,用于指定预训练模型的路径。如果提供了路径,则会加载预训练模型的权重,以供进一步训练或使用。
  5. model_name: 使用的模型的名称,默认为”resnet18”。在这里,可能是指定了使用的姿态估计模型的架构。
  6. parameterization: 参数化方法,默认为”se3_log_map”。可能是指定了姿态参数化的方法。
  7. convention: 坐标系约定,如果有的话。在这里,默认为None。
  8. lr: 学习率,默认为0.001。
  9. batch_size: 批量大小,默认为4。
  10. n_epochs: 训练的总轮数,默认为1000。
  11. n_batches_per_epoch: 每个轮次中的批次数量,默认为100。

模型参数设置

model_params = {
        "model_name": model_name,
        "parameterization": parameterization,
        "convention": convention,
        "norm_layer": "groupnorm",
    }
  1. model_name指定了要使用的模型的名称例如“resnet-18”
  2. "parameterization"键对应的值是"parameterization"制定了姿态参数化的方法,例如”se3_log_map”
  3. "convention" 键对应的值是 convention 参数,指定了坐标系的约定

模型构建

model = PoseRegressor(**model_params)

上述是一个姿态回归模型,用于姿态回归任务

代码为:

class PoseRegressor(torch.nn.Module):
    """
    A PoseRegressor is comprised of a pretrained backbone model that extracts features
    from an input X-ray and two linear layers that decode these features into rotational
    and translational camera pose parameters, respectively.
    """

    def __init__(
        self,
        model_name,
        parameterization,
        convention=None,
        pretrained=False,
        **kwargs,
    ):
        super().__init__()

        self.parameterization = parameterization
        self.convention = convention
        n_angular_components = N_ANGULAR_COMPONENTS[parameterization]

        # Get the size of the output from the backbone
        self.backbone = timm.create_model(
            model_name,
            pretrained,
            num_classes=0,
            in_chans=1,
            **kwargs,
        )
        output = self.backbone(torch.randn(1, 1, 256, 256)).shape[-1]
        self.xyz_regression = torch.nn.Linear(output, 3)
        self.rot_regression = torch.nn.Linear(output, n_angular_components)

    def forward(self, x):
        x = self.backbone(x)
        rot = self.rot_regression(x)
        xyz = self.xyz_regression(x)
        return convert(
            [rot, xyz],
            input_parameterization=self.parameterization,
            output_parameterization="se3_exp_map",
            input_convention=self.convention,
        )

网络组件:

*Backbone 模型*: 这个模型是一个预训练的卷积神经网络,用于从输入的 X 射线图像中提取特征。它是通过 timm.create_model 方法创建的,根据指定的 model_name(例如 “resnet18”)和其他参数(如 pretrained)加载预训练的模型。

self.backbone = timm.create_model(
            model_name,
            pretrained,
            num_classes=0,
            in_chans=1,
            **kwargs,
        )
        output = self.backbone(torch.randn(1, 1, 256, 256)).shape[-1]

线性回归器: 模型的输出经过两个线性层,将提取的特征映射到旋转和平移的姿态参数。

self.xyz_regression = torch.nn.Linear(output, 3)
self.rot_regression = torch.nn.Linear(output, n_angular_components)

其中,rot_regression 线性层将特征映射到旋转参数,而 xyz_regression 线性层将特征映射到平移参数。

Forward 方法:

def forward(self, x):
        x = self.backbone(x)
        rot = self.rot_regression(x)
        xyz = self.xyz_regression(x)
        return convert(
            [rot, xyz],
            input_parameterization=self.parameterization,
            output_parameterization="se3_exp_map",
            input_convention=self.convention,
        )

forward 方法中,输入的图像 x 首先通过 Backbone 模型提取特征,然后这些特征分别传递给 rot_regressionxyz_regression 线性层进行姿态参数的回归。最后,通过 convert 函数将回归得到的姿态参数从指定的参数化形式转换为另一种参数化形式就是从下面转,最后转出来一个se3的。

参数化:

N_ANGULAR_COMPONENTS = {
    "axis_angle": 3,
    "euler_angles": 3,
    "se3_log_map": 3,
    "quaternion": 4,
    "rotation_6d": 6,
    "rotation_10d": 10,
    "quaternion_adjugate": 10,
}

模型支持多种参数化方法,通过 parameterization 参数指定。在此处提供了一个字典 N_ANGULAR_COMPONENTS,用于指定不同参数化方法下旋转参数的维度。

优化器设置:

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

学习率调度器设置:

scheduler = WarmupCosineSchedule(
        optimizer,
        5 * n_batches_per_epoch,
        n_epochs * n_batches_per_epoch - 5 * n_batches_per_epoch,
    )

这行代码创建了一个学习率调度器,采用 Warmup Cosine 学习率调度策略。它会在优化器的基础上进行操作,根据训练的迭代轮次来调整学习率。Warmup Cosine 调度器在训练初期先进行学习率的线性增加(warmup),然后按照余弦函数衰减学习率。

train函数

MultiscaleNormalizedCrossCorrelation2d类

用于在多个尺度上计算两个图像批次之间的标准化交叉相关性(Normalized Cross Correlation,NCC)。

class MultiscaleNormalizedCrossCorrelation2d(torch.nn.Module):
    """Compute Normalized Cross Correlation between two batches of images at multiple scales."""

    def __init__(self, patch_sizes=[None], patch_weights=[1.0], eps=1e-5):
        super().__init__()

        assert len(patch_sizes) == len(patch_weights), "Each scale must have a weight"
        self.nccs = [
            NormalizedCrossCorrelation2d(patch_size) for patch_size in patch_sizes
        ]
        self.patch_weights = patch_weights

    def forward(self, x1, x2):
        scores = []
        for weight, ncc in zip(self.patch_weights, self.nccs):
            scores.append(weight * ncc(x1, x2))
        return torch.stack(scores, dim=0).sum(dim=0)

初始化

  • patch_sizes 是一个列表,包含了多个尺度下要用于计算 NCC 的图像块(patch)的尺寸。如果某个尺度的值为 None,则表示使用完整的输入图像。
  • patch_weights 是一个列表,包含了每个尺度的权重,用于加权不同尺度的 NCC 分数。
  • eps 是一个小的正数,用于避免除以零的情况。

forward

  • 接受两个输入张量 x1x2,分别代表两个图像批次。
  • 在每个尺度下,通过调用 NormalizedCrossCorrelation2d 类的实例计算 NCC,并根据对应的权重加权。
  • 将加权的 NCC 分数在所有尺度上求和,并返回结果。

计算NCC的类

class NormalizedCrossCorrelation2d(torch.nn.Module):
    """Compute Normalized Cross Correlation between two batches of images."""

    def __init__(self, patch_size=None, eps=1e-5):
        super().__init__()
        self.patch_size = patch_size
        self.eps = eps

    def forward(self, x1, x2):
        if self.patch_size is not None:
            x1 = to_patches(x1, self.patch_size)
            x2 = to_patches(x2, self.patch_size)
        assert x1.shape == x2.shape, "Input images must be the same size"
        _, c, h, w = x1.shape
        x1, x2 = self.norm(x1), self.norm(x2)
        score = torch.einsum("b...,b...->b", x1, x2)
        score /= c * h * w
        return score

    def norm(self, x):
        mu = x.mean(dim=[-1, -2], keepdim=True)
        var = x.var(dim=[-1, -2], keepdim=True, correction=0) + self.eps
        std = var.sqrt()
        return (x - mu) / std

初始化

  • patch_size 是一个整数,表示要用于计算 NCC 的图像块(patch)的大小。如果为 None,则表示使用整个输入图像。
  • eps 是一个小的正数,用于避免除以零的情况。

**forward **

  • 接受两个输入张量 x1x2,分别代表两个图像批次。
  • 如果指定了 patch_size,则将输入图像切分成大小为 patch_size 的图像块。
  • 通过norm函数对输入图像进行归一化处理,计算归一化后的均值和标准差。
  • 使用归一化后的图像计算 NCC 分数。这里使用了 PyTorch 的 torch.einsum 函数,按照公式 \(\sum_{h,w,c}(x1 - \bar{x1})(x2 - \bar{x2})\) 计算了 NCC 分数。
  • 最后,将计算得到的 NCC 分数进行标准化,以确保其值范围在 -1 到 1 之间。

上面的步骤计算的是NCC分数。实际是求两个零均值化的图像对应位置像素值的乘积之和,即求两个图像之间的相关性。

具体来说,torch.stack(scores, dim=0)scores 列表中的张量按照第一个维度(即尺度维度)进行堆叠,生成了一个形状为 (num_scales, batch_size) 的张量,其中 num_scales 是尺度的数量,就是 patch_sizes 列表的长度。

batch_size 是批次大小。然后,通过 sum(dim=0) 对所有尺度的 NCC 分数进行求和,得到了最终的结果,即多尺度 NCC 分数。forward 方法的输入参数 x1x2 是两个批次的图像,它们的形状通常为 (batch_size, channels, height, width),其中 batch_size 表示每个批次中图像的数量。在 forward 方法中,对每个尺度的 NCC 操作都会对两个批次的图像进行处理。因此,在循环中计算每个尺度下的 NCC 分数时,实际上是对每个批次的图像进行了 NCC 计算。这就意味着 forward 方法的输出的第一个维度(dim=0)对应于尺度(即 patch_sizes 的长度),而第二个维度(dim=1)对应于批次大小(即 batch_size)。

GeodesicSE3类

该模块用于计算两个刚体变换(Rigid Transform)在 SE(3) 对数空间(log-space)中的距离。

# %% ../notebooks/api/04_metrics.ipynb 11
class GeodesicSE3(torch.nn.Module):
    """Calculate the distance between transforms in the log-space of SE(3)."""

    def __init__(self):
        super().__init__()

    @beartype
    @jaxtyped
    def forward(
        self,
        pose_1: RigidTransform,
        pose_2: RigidTransform,
    ) -> Float[torch.Tensor, "b"]:
        # 计算相对变换,并将其转换为 SE(3) 对数空间中的对应向量,.norm表示计算两个刚体变换之间的范数,即距离
        return pose_2.compose(pose_1.inverse()).get_se3_log().norm(dim=1)

forward 方法中,定义了计算两个刚体变换之间距离的操作。这个方法接受两个参数 pose_1pose_2,它们分别表示两个刚体变换。然后,它计算了 pose_1pose_2 的相对变换,即 $pose_2 \circ pose_1^{-1}$,并将其转换为 SE(3) 对数空间中的对应向量。最后,它计算了这个向量的范数(norm),以表示两个刚体变换之间的距离。

DoubleGeodesic类

该模块用于计算两个 SE(3) 变换矩阵之间的角度和平移的测地线距离。

@beartype
class DoubleGeodesic(torch.nn.Module):
    """Calculate the angular and translational geodesics between two SE(3) transformation matrices."""

    def __init__(
        self,
        sdr: float,  # Source-to-detector radius
        eps: float = 1e-4,  # Avoid overflows in sqrt
    ):
        super().__init__()
        self.sdr = sdr
        self.eps = eps

        self.rotation = GeodesicSO3()
        self.translation = GeodesicTranslation()

    @beartype
    @jaxtyped
    def forward(self, pose_1: RigidTransform, pose_2: RigidTransform):
        angular_geodesic = self.sdr * self.rotation(pose_1, pose_2)
        translation_geodesic = self.translation(pose_1, pose_2)
        double_geodesic = (
            (angular_geodesic).square() + translation_geodesic.square() + self.eps
        ).sqrt()
        return angular_geodesic, translation_geodesic, double_geodesic
  • __init__ 方法中,初始化了这个模块。它接受两个参数 sdreps,分别表示源到探测器的半径和避免 sqrt 过程中的溢出。然后,它创建了两个子模块 rotationtranslation,分别用于计算角度和平移的测地线距离。
  • forward 方法中,定义了计算角度和平移测地线距离的操作。它接受两个参数 pose_1pose_2,分别表示两个 SE(3) 变换矩阵。然后,它使用子模块 rotation 计算角度测地线距离,并使用子模块 translation 计算平移测地线距离。最后,它计算了角度和平移测地线距离的平方和,并对其取平方根,得到两个变换矩阵之间的测地线距离。
#创建1-10均匀分布数
contrast_distribution = torch.distributions.Uniform(1.0, 10.0)
#设置最佳loss
best_loss = torch.inf

训练部分

# 对每个 epoch 进行循环
for epoch in range(n_epochs+1):
    losses = []  # 初始化损失列表,用于记录每个 mini-batch 的损失值
    # 对每个 mini-batch 进行循环
    for _ in (itr := tqdm(range(n_batches_per_epoch), leave=False)):
        # 从对比度分布中随机采样一个对比度值
        contrast = contrast_distribution.sample().item()
        # 获取随机偏移量
        offset = get_random_offset(batch_size, device)
        # 构造姿态变换
        pose = isocenter_pose.compose(offset)
        # 生成数字重建图像并应用转换
        img = drr(None, None, None, pose=pose, bone_attenuation_multiplier=contrast)
        img = transforms(img)

        # 将图像输入模型,得到预测的偏移量
        pred_offset = model(img)
        # 根据预测的偏移量构造预测的姿态变换
        pred_pose = isocenter_pose.compose(pred_offset)
        # 生成预测的数字重建图像并应用转换
        pred_img = drr(None, None, None, pose=pred_pose)
        pred_img = transforms(pred_img)

        # 计算损失函数,其中包括 NCC、测地线距离等指标
        ncc = metric(pred_img, img)
        log_geodesic = geodesic(pred_pose, pose)
        geodesic_rot, geodesic_xyz, double_geodesic = double(pred_pose, pose)
        loss = 1 - ncc + 1e-2 * (log_geodesic + double_geodesic)

        print("loss:", loss)

        # 如果损失为 NaN,则保存模型参数和相关信息,并抛出 RuntimeError
        if loss.isnan().any():
            print("Aaaaaaand we've crashed...")
            print(ncc)
            print(log_geodesic)
            print(geodesic_rot)
            print(geodesic_xyz)
            print(double_geodesic)
            print(pose.get_matrix())
            print(pred_pose.get_matrix())

            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "height": drr.detector.height,
                    "epoch": epoch,
                    "batch_size": batch_size,
                    "n_epochs": n_epochs,
                    "n_batches_per_epoch": n_batches_per_epoch,
                    "pose": pose.get_matrix().cpu(),
                    "pred_pose": pred_pose.get_matrix().cpu(),
                    "img": img.cpu(),
                    "pred_img": pred_img.cpu()
                    **model_params,
                },
                f"checkpoints_new/specimen_{id_number:02d}_crashed.ckpt",
            )
            raise RuntimeError("NaN loss")

        # 执行反向传播、梯度裁剪和参数更新操作
        optimizer.zero_grad()
        loss.mean().backward()
        adaptive_clip_grad_(model.parameters())
        optimizer.step()
        scheduler.step()

        # 记录损失值到列表中
        losses.append(loss.mean().item())

        # 更新进度条,显示当前 epoch 的进度和相关指标
        itr.set_description(f"Epoch [{epoch}/{n_epochs}]")
        itr.set_postfix(
            geodesic_rot=geodesic_rot.mean().item(),
            geodesic_xyz=geodesic_xyz.mean().item(),
            geodesic_dou=double_geodesic.mean().item(),
            geodesic_se3=log_geodesic.mean().item(),
            loss=loss.mean().item(),
            ncc=ncc.mean().item(),
        )

register.py函数

这里只列举出前一部分未涉及的代码

Registration类

Transforms类

from torchvision.transforms import Compose, Lambda, Normalize, Resize


class Transforms:
    def __init__(
        self,
        size: int,  # Dimension to resize image
        eps: float = 1e-6,
    ):
        """Transform X-rays and DRRs before inputting to CNN."""
        self.transforms = Compose(
            [
                Lambda(lambda x: (x - x.min()) / (x.max() - x.min() + eps)),
                Resize((size, size), antialias=True),
                Normalize(mean=0.3080, std=0.1494),
            ]
        )

    def __call__(self, x):
        return self.transforms(x)

这段代码定义了一个名为 Transforms 的类,用于定义图像转换操作,以便在将X射线图像和DRR(Digital Reconstructed Radiograph)输入到卷积神经网络(CNN)之前对它们进行预处理。

  • __init__ 方法中,初始化了 Transforms 类。它接受两个参数:size 表示要调整图像的尺寸,eps 是一个非常小的值,用于避免除以零的情况。在这个方法中,创建了一个 torchvision.transforms.Compose 对象,用于将一系列图像转换操作串联起来。这些转换操作包括:
    • Lambda函数,用于将图像进行最小-最大归一化,即将像素值缩放到 [0, 1] 的范围内。
    • Resize操作,将图像调整为指定的大小 (size, size)。如果设置了 antialias=True,则进行抗锯齿处理。
    • Normalize操作,用于对图像进行标准化,使其具有零均值和单位标准差。这里的 meanstd 分别指定了归一化的均值和标准差。
  • __call__ 方法中,定义了对图像进行转换的操作。这个方法接受一个图像 x 作为输入,并将其传递给之前定义的图像转换操作。最后返回转换后的图像。

总的来说,这段代码的目的是提供一个图像转换类,用于将X射线图像和DRR进行预处理,以便在CNN中进行训练或推断。

initialize_registration方法

SparseRegistration类
class SparseRegistration(torch.nn.Module):
    def __init__(
        self,
        drr: DRR,
        pose: RigidTransform,
        parameterization: str,
        convention: str = None,
        features=None,  # 用于计算mNCC的偏差估计
        n_patches: int = None,  # 如果n_patches为None,则渲染整个图像
        patch_size: int = 13,
    ):
        super().__init__()
        self.drr = drr

        # Parse the input pose
        # 解析输入的姿态
        rotation, translation = convert(
            pose,
            input_parameterization="se3_exp_map",
            output_parameterization=parameterization,
            output_convention=convention,
        )
        self.parameterization = parameterization
        self.convention = convention
        self.rotation = torch.nn.Parameter(rotation)
        self.translation = torch.nn.Parameter(translation)

        # Crop pixels off the edge such that pixels don't fall outside the image
        # 裁剪边缘像素,以防止像素超出图像范围
        self.n_patches = n_patches
        self.patch_size = patch_size
        self.patch_radius = self.patch_size // 2 + 1
        self.height = self.drr.detector.height
        self.width = self.drr.detector.width
        self.f_height = self.height - 2 * self.patch_radius
        self.f_width = self.width - 2 * self.patch_radius

        # Define the distribution over patch centers
        # 定义patch中心的分布

        "如果 features 是 None,即没有提供特征信息,则将 features 初始化为一个形状为 (self.height, self.width) 的张量,其中所有元素的值都设为 1 / (self.height * self.width)。这表示所有像素的概率都是相等的,即所有像素都有相同的概率被选中作为渲染的中心。接下来,使用torch.distributions.categorical.Categorical 创建一个分布对象,该对象表示离散的分类分布。分布的概率由 features 张量中的值确定。为了创建这个分布,首先将 features 张量进行裁剪,以剔除边缘的像素,然后将其展平为一维张量。最终,patch_centers 是一个离散的分类分布,其值表示了渲染中心像素的概率分布。在后续的渲染过程中,可以根据这个分布来选择像素作为渲染的中心位置。"
        if features is None:
            features = torch.ones(
                self.height, self.width, device=self.rotation.device
            ) / (self.height * self.width)
        self.patch_centers = torch.distributions.categorical.Categorical(
            probs=features.squeeze()[
                self.patch_radius : -self.patch_radius,
                self.patch_radius : -self.patch_radius,
            ].flatten()
        )

    def forward(self, n_patches=None, patch_size=None):
        # Parse initial density
        # 解析初始密度
        if not hasattr(self.drr, "density"):
            self.drr.set_bone_attenuation_multiplier(
                self.drr.bone_attenuation_multiplier
            )

        if n_patches is not None or patch_size is not None:
            self.n_patches = n_patches
            self.patch_size = patch_size

        # Make the mask for sparse rendering
        # 创建稀疏渲染的掩码
        if self.n_patches is None:
            mask = torch.ones(
                1,
                self.height,
                self.width,
                dtype=torch.bool,
                device=self.rotation.device,
            )
        else:
            mask = torch.zeros(
                self.n_patches,
                self.height,
                self.width,
                dtype=torch.bool,
                device=self.rotation.device,
            )
            radius = self.patch_size // 2
            idxs = self.patch_centers.sample(sample_shape=torch.Size([self.n_patches]))
            idxs, jdxs = (
                idxs // self.f_height + self.patch_radius,
                idxs % self.f_width + self.patch_radius,
            )

            idx = torch.arange(-radius, radius + 1, device=self.rotation.device)
            patches = torch.cartesian_prod(idx, idx).expand(self.n_patches, -1, -1)
            patches = patches + torch.stack([idxs, jdxs], dim=-1).unsqueeze(1)
            patches = torch.concat(
                [
                    torch.arange(self.n_patches, device=self.rotation.device)
                    .unsqueeze(-1)
                    .expand(-1, self.patch_size**2)
                    .unsqueeze(-1),
                    patches,
                ],
                dim=-1,
            )
            mask[
                patches[..., 0],
                patches[..., 1],
                patches[..., 2],
            ] = True

        # Get the source and target
        # 获取源图像和目标图像
        pose = convert(
            [self.rotation, self.translation],
            input_parameterization=self.parameterization,
            output_parameterization="se3_exp_map",
            input_convention=self.convention,
        )
        source, target = make_xrays(
            pose,
            self.drr.detector.source,
            self.drr.detector.target,
        )

        # Render the sparse image
        # 渲染稀疏图像
        target = target[mask.any(dim=0).view(1, -1)]
        img = siddon_raycast(source, target, self.drr.density, self.drr.spacing)
        if self.n_patches is None:
            img = self.drr.reshape_transform(img, batch_size=len(self.rotation))
        return img, mask

    def get_current_pose(self):
        return convert(
            [self.rotation, self.translation],
            input_parameterization=self.parameterization,
            output_parameterization="se3_exp_map",
            input_convention=self.convention,
        )

这段代码定义了一个名为 SparseRegistration 的模块,用于执行稀疏图像配准(registration)。稀疏图像配准是将两幅图像在局部区域内对齐的过程。

具体来说:

  • SparseRegistration 类的 __init__ 方法用于初始化模块。它接受以下参数:
    • drr:DRR(Digital Reconstructed Radiograph)对象,表示数字重建放射图。
    • pose:RigidTransform 对象,表示相机姿态。
    • parameterization:姿态参数化方式的字符串表示。
    • convention:姿态表示的约定方式的字符串表示。
    • features:用于计算 mNCC(modified Normalized Cross-Correlation)偏差估计的特征。
    • n_patches:稀疏渲染中的 patch 数量。
    • patch_size:patch 的大小。
  • forward 方法用于执行前向传播。它接受以下参数:
    • n_patches:可选参数,用于覆盖初始化时指定的 n_patches
    • patch_size:可选参数,用于覆盖初始化时指定的 patch_size

forward 方法中,主要执行以下操作:

  • 解析初始密度和设置稀疏渲染的掩码。
  • 获取源图像和目标图像。
  • 渲染稀疏图像,并返回渲染后的图像和掩码。

此外,get_current_pose 方法用于获取当前的姿态。