Build A PyTorch Style Transfer Web App With Streamlit
In this tutorial we build an interactive deep learning app with Streamlit and PyTorch to apply style transfer.
This tutorial demonstrates how easily interactive web applications can be built with Streamlit. Streamlit lets you create apps for your machine learning or deep learning projects with simple Python scripts. See the official Streamlit website for more info.
You can find the code on GitHub: https://github.com/patrickloeber/pytorch-examples.
The style transfer code is based on the Fast Neural Style example from the official PyTorch examples repo.
Installation
It is recommended to create and activate a virtual environment before installing the dependencies:
pip install streamlit
pip install torch torchvision
Usage
Download the pretrained models:
python download_saved_models.py
After downloading, move the saved_models folder into the neural_style folder. Then run:
streamlit run main.py
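If the app cannot find a model at startup, a quick check like the following can confirm that the pretrained weights ended up where the app looks for them. This is only an illustrative sketch: the helper name `missing_models` is hypothetical and not part of the tutorial repo, and it assumes the four style names and the `.pth` extension used by the app in this tutorial.

```python
from pathlib import Path

# Style names as used by the app's sidebar in this tutorial
STYLES = ("candy", "mosaic", "rain_princess", "udnie")


def missing_models(model_dir, styles=STYLES):
    """Return the expected .pth files that are not present in model_dir.

    Hypothetical helper for illustration; it is not part of the tutorial repo.
    """
    model_dir = Path(model_dir)
    return [f"{name}.pth" for name in styles
            if not (model_dir / f"{name}.pth").is_file()]


if __name__ == "__main__":
    # Assumes the script is run from inside the neural_style folder
    missing = missing_models("saved_models")
    if missing:
        print("Missing model files:", ", ".join(missing))
    else:
        print("All pretrained models found.")
```

An empty result means every model file the app can select is in place.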
The PyTorch functions
To make use of Streamlit's [caching](https://docs.streamlit.io/en/latest/caching.html), we split the original stylize() function into two separate functions, one for loading the model and one for applying the style transfer:
# style.py
import re

import streamlit as st
import torch
from torchvision import transforms

# local modules from the fast neural style example
import utils
from transformer_net import TransformerNet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

@st.cache
def load_model(model_path):
    print('load model')
    with torch.no_grad():
        style_model = TransformerNet()
        state_dict = torch.load(model_path)
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        style_model.to(device)
        style_model.eval()
        return style_model

@st.cache
def stylize(style_model, content_image, output_image):
    # load the content image and scale pixel values to the 0-255 range
    content_image = utils.load_image(content_image)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    content_image = content_transform(content_image)
    # add a batch dimension and move the tensor to the selected device
    content_image = content_image.unsqueeze(0).to(device)

    with torch.no_grad():
        output = style_model(content_image).cpu()
    utils.save_image(output_image, output[0])
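The key-filtering step in `load_model` can be tried in isolation. The following sketch applies the same regular expression to a plain dictionary. The key names are made-up examples in the style of a checkpoint, and `drop_running_stats` is a hypothetical helper name, not something from the original code.

```python
import re

# Same pattern as in load_model: keys ending in in<digits>.running_mean / running_var
RUNNING_STATS = re.compile(r'in\d+\.running_(mean|var)$')


def drop_running_stats(state_dict):
    """Return a copy of state_dict without the deprecated InstanceNorm running_* keys."""
    return {k: v for k, v in state_dict.items() if not RUNNING_STATS.search(k)}


# Made-up keys for illustration; the values would be tensors in a real checkpoint
example = {
    "in1.weight": 0,
    "in1.running_mean": 0,
    "in1.running_var": 0,
    "res1.in2.running_mean": 0,
    "conv1.conv2d.weight": 0,
}

print(sorted(drop_running_stats(example)))
# ['conv1.conv2d.weight', 'in1.weight']
```

Unlike the loop in `load_model`, this variant builds a new dictionary instead of deleting keys in place, which avoids having to copy the key list before iterating.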
The Streamlit App
Implementing the web app is straightforward and takes only about 30 lines:
# main.py
import streamlit as st
from PIL import Image

import style

st.title('PyTorch Style Transfer')

img = st.sidebar.selectbox(
    'Select Image',
    ('amber.jpg', 'cat.png')
)

style_name = st.sidebar.selectbox(
    'Select Style',
    ('candy', 'mosaic', 'rain_princess', 'udnie')
)

model_path = "saved_models/" + style_name + ".pth"
input_image = "images/content-images/" + img
output_image = "images/output-images/" + style_name + "-" + img

st.write('### Source image:')
image = Image.open(input_image)
st.image(image, width=400)  # image: PIL image

clicked = st.button('Stylize')

if clicked:
    model = style.load_model(model_path)
    style.stylize(model, input_image, output_image)

    st.write('### Output image:')
    image = Image.open(output_image)
    st.image(image, width=400)
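Streamlit reruns the script from top to bottom on every widget interaction, which is why caching `load_model` matters: without it, each click would load the weights from disk again. The effect can be sketched without Streamlit by using `functools.lru_cache` from the standard library as a stand-in for `st.cache`. This is an analogy under that assumption, not Streamlit's actual caching mechanism, and `fake_load_model` and `rerun_script` are hypothetical placeholders for the real loader and app script.

```python
from functools import lru_cache

load_calls = []  # records every time the expensive load actually runs


@lru_cache(maxsize=None)
def fake_load_model(model_path):
    """Placeholder for an expensive model load; returns a dummy object."""
    load_calls.append(model_path)
    return {"path": model_path}


def rerun_script(style_name):
    """Mimic one top-to-bottom rerun of the app for a selected style."""
    return fake_load_model("saved_models/" + style_name + ".pth")


# Three reruns with the same style, then one with a different style
for name in ("candy", "candy", "candy", "mosaic"):
    rerun_script(name)

print(load_calls)
# ['saved_models/candy.pth', 'saved_models/mosaic.pth']
```

The loader body runs once per distinct model path, no matter how many reruns request it.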