The Magic Paintbrush: Neural Style Transfer in Computer Vision
2024-05-11
- The stylized image after transfer
from torchvision.models import vgg19
from torchvision import transforms
from PIL import Image
from matplotlib import pyplot as plt
import torch
import torch.optim as optim
import numpy as np
# Load a VGG19 pretrained on ImageNet and freeze its weights:
# we only optimize the target image's pixels, never the network.
model = vgg19(pretrained=True)
for p in model.parameters():
    p.requires_grad_(False)
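A note in passing: pretrained=True is deprecated on torchvision 0.13 and later. If you are on a recent release, the equivalent call with the newer weights enum (a minimal sketch, assuming torchvision >= 0.13) is:

from torchvision.models import VGG19_Weights
model = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)  # same ImageNet weights, newer API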
def load_img(path, max_size=400, shape=None):
    img = Image.open(path).convert('RGB')
    # Cap the longer side at max_size so the optimization stays fast
    if max(img.size) > max_size:
        size = max_size
    else:
        size = max(img.size)
    if shape is not None:
        size = shape
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        # Normalize with the ImageNet statistics VGG19 was trained on
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))
    ])
    # Drop any alpha channel and add a batch dimension
    img = transform(img)[:3, :, :].unsqueeze(0)
    return img
# Load the content image, then resize the style image to the same spatial size
content = load_img('QQ图片20220207183552.jpg')
style = load_img('201609281450067826.jpg', shape=content.shape[-2:])
Printing the model shows the index of every layer inside features:
print(model)
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace=True)
    (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (24): ReLU(inplace=True)
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): ReLU(inplace=True)
    (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): ReLU(inplace=True)
    (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (33): ReLU(inplace=True)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): ReLU(inplace=True)
    (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
The indices inside features are exactly what we use below: 0, 5, 10, 19 and 28 are the first conv layers of each block (conv1_1 through conv5_1, used for style), and 21 is conv4_2 (used for content).
def get_features(img):
    # Map indices in model.features to the conv layers tapped by Gatys et al.
    layers = {'0': 'conv1_1',
              '5': 'conv2_1',
              '10': 'conv3_1',
              '19': 'conv4_1',
              '21': 'conv4_2',  # content layer
              '28': 'conv5_1'}
    features = {}
    x = img
    # Run the image through the convolutional stack, caching the
    # activations of the layers listed above
    for name, layer in model.features._modules.items():
        x = layer(x)
        if name in layers:
            features[name] = x
    return features
# Features of the fixed content and style images are computed once, up front
content_features = get_features(content)
style_features = get_features(style)
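As a quick sanity check (my addition, not in the original run), you can print each cached activation's shape; the channel counts should line up with the conv layers in the printout above, and the spatial size halves after each pooling block:

for name, feat in content_features.items():
    print(name, tuple(feat.shape))  # e.g. '0' -> (1, 64, H, W); '28' -> (1, 512, H/16, W/16)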
def gram_matrix(tensor):
    # Flatten each of the d feature maps to a row vector, then take
    # all pairwise inner products: a d x d matrix of channel correlations
    _, d, h, w = tensor.size()
    tensor = tensor.view(d, h * w)
    gram = torch.mm(tensor, tensor.t())
    return gram
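In matrix form this is exactly what torch.mm(tensor, tensor.t()) computes: with F the d × (h·w) matrix whose rows are the flattened feature maps,

    G = F F^{\mathsf{T}} \in \mathbb{R}^{d \times d}, \qquad G_{ij} = \sum_{k=1}^{hw} F_{ik} F_{jk}

Entry G_ij measures how strongly channels i and j fire together, and this texture statistic is what stands in for "style".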
# Precompute the style image's Gram matrices; these are the fixed style targets
style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}
# Start from a copy of the content image and optimize its pixels directly
target = content.clone().requires_grad_(True)
def im_convert(tensor):
    # Undo the ImageNet normalization and reorder to H x W x C for matplotlib
    img = tensor.clone().detach()
    img = img.numpy().squeeze()
    img = img.transpose(1, 2, 0)
    img = img * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    img = img.clip(0, 1)
    return img
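A small addition that is not in the original post: im_convert also makes it easy to preview the two inputs side by side before starting the optimization.

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
ax1.imshow(im_convert(content))
ax1.set_title('content')
ax2.imshow(im_convert(style))
ax2.set_title('style')
plt.show()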
# Per-layer style weights: earlier layers capture fine textures,
# so they are weighted more heavily than the deeper, more abstract ones
style_weights = {
    '0': 1,
    '5': 0.8,
    '10': 0.5,
    '19': 0.3,
    '28': 0.1,
}
# Relative weights of the two loss terms: style dominates by design
content_weight = 1
style_weight = 1e6
show_every = 100
optimizer = optim.Adam([target], lr=0.003)
steps = 5000
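In the notation of Gatys et al., the loop below minimizes the weighted sum of the two objectives, with the weights chosen above:

    \mathcal{L}_{\text{total}} = \alpha \, \mathcal{L}_{\text{content}} + \beta \, \mathcal{L}_{\text{style}}, \qquad \alpha = 1, \ \beta = 10^{6}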
for ii in range(steps):
    target_features = get_features(target)
    # Content loss: MSE between target and content activations at conv4_2
    content_loss = torch.mean((content_features['21'] - target_features['21'])**2)
    style_loss = 0
    # Add each style layer's Gram-matrix loss, normalized by feature-map size
    for layer in style_weights:
        target_feature = target_features[layer]
        target_gram = gram_matrix(target_feature)
        _, d, h, w = target_feature.shape
        style_gram = style_grams[layer]
        layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram)**2)
        style_loss += layer_style_loss / (d * h * w)
    total_loss = content_weight * content_loss + style_weight * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    # Report progress and show the intermediate image every show_every steps
    if ii % show_every == 0:
        print('Total Loss:', total_loss.item())
        plt.imshow(im_convert(target))
        plt.show()
plt.imshow(im_convert(target))
plt.show()
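Finally, since im_convert already returns a de-normalized array in [0, 1], the result can be written straight to disk; the filename below is just a placeholder of my choosing:

plt.imsave('style_transfer_result.png', im_convert(target))  # hypothetical output path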