PyTorch Model Config & Freeze

An example of reconfiguring (replacing) part of a model in PyTorch

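import torch.nn as nn
# model: assumed to be an EfficientNet from efficientnet_pytorch (the attribute names match its layout).
# Conv2dSame: assumed to be a Conv2d with TensorFlow-style "same" padding; the post does not define it.

for child in model.children():
    if isinstance(child, nn.ModuleList):  # the list of blocks (model._blocks)
        for gct, gchild in enumerate(child):
            # With kernel_size=1, dilation=2 has no effect; if groups keeps the nn.Conv2d default of 1, these are plain 1x1 convs, not depthwise.
            if gct == 4:
                gchild._depthwise_conv = Conv2dSame(in_channels=192, out_channels=192, kernel_size=1, dilation=2, bias=False)
            elif gct == 11:
                gchild._depthwise_conv = Conv2dSame(in_channels=288, out_channels=288, kernel_size=1, dilation=2, bias=False)
# Replace the head conv and the BatchNorm after it; the classifier (model._fc) may also need rebuilding for 1024 input features.
model._conv_head = Conv2dSame(in_channels=640, out_channels=1024, kernel_size=1, bias=False)
model._bn1 = nn.BatchNorm2d(1024, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)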
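
A self-contained sketch of the same layer-swapping technique follows, runnable without pretrained weights. ToyBlock and ToyNet are hypothetical stand-ins that only mimic the attribute names used above. In place of Conv2dSame, the new layer uses PyTorch's built-in nn.Conv2d(padding="same") (available since PyTorch 1.9, stride 1 only), and it keeps groups equal to the channel count so that it stays a depthwise convolution.

```python
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical block exposing a _depthwise_conv attribute, like the blocks above."""

    def __init__(self, channels: int):
        super().__init__()
        self._depthwise_conv = nn.Conv2d(
            channels, channels, kernel_size=3, padding=1, groups=channels, bias=False
        )

    def forward(self, x):
        return self._depthwise_conv(x)


class ToyNet(nn.Module):
    """Hypothetical stand-in: a ModuleList of blocks, then a 1x1 head conv and BatchNorm."""

    def __init__(self, channels: int = 16, head_channels: int = 32):
        super().__init__()
        self._blocks = nn.ModuleList([ToyBlock(channels) for _ in range(6)])
        self._conv_head = nn.Conv2d(channels, head_channels, kernel_size=1, bias=False)
        self._bn1 = nn.BatchNorm2d(head_channels)

    def forward(self, x):
        for block in self._blocks:
            x = block(x)
        return self._bn1(self._conv_head(x))


model = ToyNet()

# Assigning an nn.Module to an attribute registers it as a submodule, replacing the old one.
# groups=16 keeps the conv depthwise; padding="same" keeps H and W unchanged.
model._blocks[4]._depthwise_conv = nn.Conv2d(
    16, 16, kernel_size=3, dilation=2, padding="same", groups=16, bias=False
)

# Widening the head conv means the BatchNorm that follows must be rebuilt to match.
model._conv_head = nn.Conv2d(16, 64, kernel_size=1, bias=False)
model._bn1 = nn.BatchNorm2d(64, eps=1e-3, momentum=0.01)

out = model(torch.randn(2, 16, 8, 8))
print(out.shape)  # torch.Size([2, 64, 8, 8])
```

Because nn.Module.__setattr__ registers any module assigned as an attribute, the replacement appears in model.modules() and model.state_dict() automatically. Pretrained weights for the replaced layers are lost, however: the new layers start from random initialization.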


An example of freezing part of a model in PyTorch

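import torch.nn as nn

def freeze_model(model):
    # children() visits only direct submodules; use model.modules() to also reach BatchNorm layers nested inside blocks.
    for child in model.children():
        if isinstance(child, nn.BatchNorm2d):
            for param in child.parameters():
                param.requires_grad = False  # freezes weight and bias; running stats still update in train mode
    return model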
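
The function above only sets requires_grad on top-level BatchNorm2d layers. A minimal sketch of a stricter variant follows, assuming the goal is to freeze every BatchNorm2d layer completely. It walks model.modules() so nested layers are included, puts each BatchNorm in eval mode so running_mean and running_var stop updating, and passes only the still-trainable parameters to the optimizer. freeze_batchnorm and the small Sequential model are hypothetical examples.

```python
import torch
import torch.nn as nn


def freeze_batchnorm(model: nn.Module) -> nn.Module:
    """Hypothetical helper: freeze every BatchNorm2d layer, however deeply nested."""
    for module in model.modules():  # modules() recurses; children() does not
        if isinstance(module, nn.BatchNorm2d):
            for param in module.parameters():
                param.requires_grad = False  # no gradients for weight and bias
            module.eval()  # stop updating running_mean / running_var
    return model


model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(8, 8, 3, padding=1, bias=False), nn.BatchNorm2d(8)),  # nested BN
)

model.train()
freeze_batchnorm(model)  # call after model.train(), which would put BN back in training mode

# Hand only the still-trainable parameters to the optimizer.
optimizer = torch.optim.SGD((p for p in model.parameters() if p.requires_grad), lr=0.1)

before = model[3][1].running_mean.clone()
loss = model(torch.randn(4, 3, 8, 8)).sum()
loss.backward()
optimizer.step()

frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(frozen)  # 32: weight + bias of two BatchNorm2d(8) layers
print(torch.equal(before, model[3][1].running_mean))  # True
```

Calling model.train() later switches the BatchNorm layers back to training mode, so a helper like this has to be re-applied after every train() call.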
