我一直在阅读一篇论文,其中包含了一个机器学习模型,我想在PyTorch中复现该模型:
[由于此处为文本描述,无法显示图片,请参考问题中的模型结构图]
简而言之,输入数据被等分为n个大小相同的向量,并且每个向量被分别传递给一个独立的局部模型。所有局部模型的输出随后被连接在一起,并通过下一个层进行处理。(此处x与我的问题无关,我们可以暂时忽略它)
目前我提出了以下代码:
class GlobalModel(torch.nn.Module):
def __init__(self, n_local_models):
super(GlobalModel, self).__init__()
self.local_models = [LocalModel() for _ in range(n_local_models)]
self.linear = torch.nn.Linear(100, 100)
self.activation = torch.nn.ReLU()
这里的LocalModel
是一个预先定义好的torch.nn.Module
类实例。线性层的尺寸只是一个暂定值,后续我会根据局部模型动态调整。
我的问题是,如何编写一个最优化的forward()
函数,使得它可以并行地运行所有局部模型,然后将它们的结果拼接起来,再传入线性层和激活函数进行处理?目前我能想到的唯一实现方式是遍历局部模型列表,逐个执行这些模型。但这似乎效率较低,我觉得应该有一个更为优雅高效的解决方案。