我正在阅读的一篇论文中包含了一个机器学习模型,我想在PyTorch中复现这个模型。
基本上,输入被等分为n个大小相同的向量,每个向量都会被分别传递到一个独立的局部模型中。所有局部模型的输出随后会被连接起来,并通过下一层处理。(此处x与我的问题无关,因此我们忽略它)
目前我已构建如下代码:
import torch.nn as nn
class GlobalModel(nn.Module):
def __init__(self, n_local_models):
super(GlobalModel, self).__init__()
self.local_models = [LocalModel() for _ in range(n_local_models)]
self.linear = nn.Linear(100, 100)
self.activation = nn.ReLU()
其中LocalModel
是另一个torch.nn.Module类。线性层的大小只是一个临时设定,稍后我会根据局部模型动态调整。
我的问题是,如何最好地编写一个forward()函数,使得它能并行运行所有局部模型,在将它们的输出连接起来之后,再传递给线性层和激活函数进行处理。目前我能想到的唯一实现方式是遍历局部模型列表,并逐个顺序执行它们。但这似乎效率较低,我感觉应该存在一个更为优雅的解决方案。