Friday, February 4, 2022

[SOLVED] AWS EC2 causing issue with Streamlit ML App

Issue

On my local machine this issue doesn't happen, and my app works fine.

However, when I run the app on an AWS EC2 instance, it gives me an error regarding a matplotlib import. Below the matplotlib import, I have matplotlib.use('TkAgg'). When the code is like this, the Streamlit app gives me this error (only on the EC2 instance):

ImportError: Cannot load backend 'TkAgg' which requires the 'tk' interactive framework, as 'headless' is currently running

Traceback:
File "/home/ubuntu/anaconda3/envs/streamlit/lib/python3.6/site-packages/streamlit/script_runner.py", line 332, in _run_script
    exec(code, module.__dict__)
File "/home/ubuntu/extremely_unnecessary/app.py", line 16, in <module>
    matplotlib.use('TkAgg')
File "/home/ubuntu/anaconda3/envs/streamlit/lib/python3.6/site-packages/matplotlib/__init__.py", line 1171, in use
    plt.switch_backend(name)
File "/home/ubuntu/anaconda3/envs/streamlit/lib/python3.6/site-packages/matplotlib/pyplot.py", line 287, in switch_backend
    newbackend, required_framework, current_framework))

After doing some research, I tried changing the offending line to matplotlib.use('agg'). With this change the app starts without the import error, but only one of the models works when selected.

The app is hosted here: http://54.193.229.139:8501/ The way it works is you upload an image, then select a pretrained model from the drop down menu to apply "style transfer" to the image you uploaded.

For some bizarre reason, the 12th model in the list (chicken-strawberries-market-069_10000.pth) works, but none of the other models do. Again, this only happens on the EC2 instance - even when I use matplotlib.use('agg'), all the models work when running the streamlit app locally.

I also tried using some other variations including matplotlib.use('GTK3Agg') and matplotlib.use('WebAgg'), which give me various other error messages.
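For reference, here is a minimal sketch of how a non-interactive backend is typically selected on a headless server. Agg renders to an in-memory raster and needs no display, whereas TkAgg and GTK3Agg need a GUI toolkit and a display. The render_png helper and the sample values are hypothetical and not part of the app.

```python
import io

import matplotlib

# Select the non-interactive Agg backend; it needs no display or GUI toolkit.
matplotlib.use("Agg")
import matplotlib.pyplot as plt


def render_png(values):
    """Plot `values` and return the figure as PNG bytes (hypothetical helper)."""
    fig, ax = plt.subplots()
    ax.plot(values)
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png")
    plt.close(fig)  # release the figure so repeated reruns do not accumulate them
    return buffer.getvalue()


png_bytes = render_png([1, 4, 9, 16])
```

In a Streamlit script the figure would normally be passed to st.pyplot(fig) rather than saved by hand. The point of the sketch is only that Agg works without a display.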

Does anyone know how to fix this so I can get all the models working on the EC2 instance?

Edit: I've started receiving a new error message, and I'm working on changing the code now. I use CUDA on my local GPU, so apparently I have to make some CPU-specific changes for it to work on the Ubuntu server. Not sure why the chicken strawberries model works though...

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
Traceback:
File "/home/ubuntu/anaconda3/envs/streamlit/lib/python3.6/site-packages/streamlit/script_runner.py", line 332, in _run_script
    exec(code, module.__dict__)
File "/home/ubuntu/extremely_unnecessary/app.py", line 91, in <module>
    main()
File "/home/ubuntu/extremely_unnecessary/app.py", line 54, in main
    transformer.load_state_dict(torch.load(checkpoint))
File "/home/ubuntu/anaconda3/envs/streamlit/lib/python3.6/site-packages/torch/serialization.py", line 595, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File "/home/ubuntu/anaconda3/envs/streamlit/lib/python3.6/site-packages/torch/serialization.py", line 774, in _legacy_load
    result = unpickler.load()
File "/home/ubuntu/anaconda3/envs/streamlit/lib/python3.6/site-packages/torch/serialization.py", line 730, in persistent_load
    deserialized_objects[root_key] = restore_location(obj, location)
File "/home/ubuntu/anaconda3/envs/streamlit/lib/python3.6/site-packages/torch/serialization.py", line 175, in default_restore_location
    result = fn(storage, location)
File "/home/ubuntu/anaconda3/envs/streamlit/lib/python3.6/site-packages/torch/serialization.py", line 151, in _cuda_deserialize
    device = validate_cuda_device(location)
File "/home/ubuntu/anaconda3/envs/streamlit/lib/python3.6/site-packages/torch/serialization.py", line 135, in validate_cuda_device
    raise RuntimeError('Attempting to deserialize object on a CUDA '

The code for the app:

import matplotlib.pyplot as plt
from PIL import Image
from torchvision.utils import save_image
import tqdm
import streamlit as st
from models import TransformerNet
from utils import *
import torch
import numpy as np
from torch.autograd import Variable
import argparse
import tkinter as tk
import os
import cv2
import matplotlib
matplotlib.use('agg')


def main():

    uploaded_file = st.file_uploader(
        "Choose an image", type=['jpg', 'png', 'webm', 'mp4', 'gif', 'jpeg'])
    if uploaded_file is not None:
        st.image(uploaded_file, width=200)

    folder = os.path.abspath(os.getcwd())
    folder = folder + '/models'

    fnames = []

    for basename in os.listdir(folder):
        print(basename)
        fname = os.path.join(folder, basename)

        if fname.endswith('.pth'):
            fnames.append(fname)

    checkpoint = st.selectbox('Select a pretrained model', fnames)

    os.makedirs("images/outputs", exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # device = torch.device("cpu")
    transform = style_transform()

    # Define model and load model checkpoint
    transformer = TransformerNet().to(device)
    transformer.load_state_dict(torch.load(checkpoint))
    transformer.eval()

    # Prepare input
    image_tensor = Variable(transform(Image.open(
        uploaded_file).convert('RGB'))).to(device)
    image_tensor = image_tensor.unsqueeze(0)

    # Stylize image
    with torch.no_grad():
        stylized_image = denormalize(transformer(image_tensor)).cpu()

    fn = str(np.random.randint(0, 100)) + 'image.jpg'
    save_image(stylized_image, f"images/outputs/stylized-{fn}")

    st.image(f"images/outputs/stylized-{fn}")


if __name__ == "__main__":
    main()



Solution

It turns out all I needed to do was apply the suggestion from the error message. On the line that loads the checkpoint (line 54 of app.py in the traceback), I changed this:

transformer.load_state_dict(torch.load(checkpoint))

to this:

transformer.load_state_dict(torch.load(
    checkpoint, map_location=torch.device('cpu')))

And it works!
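A slightly more general variant, sketched under the assumption that the same script should run on both a GPU workstation and a CPU-only EC2 instance, passes the already-computed device as map_location instead of hard-coding 'cpu'. The helper name load_state_dict_on_available_device and the nn.Linear stand-in model are hypothetical; in the app, the model would be TransformerNet and the path would be the selected .pth file.

```python
import os
import tempfile

import torch
from torch import nn


def load_state_dict_on_available_device(model, checkpoint_path):
    """Load weights onto the GPU if one is available, otherwise onto the CPU."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location remaps storages saved on a CUDA device onto `device`,
    # which avoids the "Attempting to deserialize object on a CUDA device" error.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device).eval(), device


# Stand-in model and checkpoint, used only to make the sketch runnable.
demo_model = nn.Linear(4, 2)
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "demo.pth")
    torch.save(demo_model.state_dict(), demo_path)
    restored, device = load_state_dict_on_available_device(
        nn.Linear(4, 2), demo_path)
```

One possible explanation for the single working model, which is a guess and not something the post confirms, is that its checkpoint was saved from CPU tensors. Such a file loads on a CPU-only machine without map_location, while checkpoints saved from CUDA tensors do not.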



Answered By - Nick
Answer Checked By - Mary Flores (WPSolving Volunteer)