Embedding machine learning models into a GUI (Gradio & Streamlit)
Overview
In the ever-evolving landscape of data science and machine learning, the ability to effectively showcase and deploy predictive models is becoming increasingly essential. Two powerful tools, Gradio and Streamlit, have emerged as go-to platforms for building interactive and user-friendly applications that allow users to interact with machine learning models effortlessly. In this article, we’ll explore the journey of embedding machine learning models in both Gradio and Streamlit, demonstrating how these frameworks can transform complex algorithms into accessible, real-world applications.
Prerequisites
To build the web apps in this article, you will need the following:
- Familiarity with Python: A foundational understanding of the Python programming language is essential, as both Gradio and Streamlit are Python-based frameworks.
- Basic Machine Learning Knowledge: While not mandatory, a basic grasp of machine learning concepts will help you understand the models you’re working with and how to present their predictions effectively.
- Python Environment: Set up an isolated Python environment with a tool such as venv or conda to manage dependencies and avoid conflicts.
- Installed Libraries: Make sure you have the required libraries, such as Gradio, Streamlit, pandas, and scikit-learn, installed in your Python environment.
- Trained Machine Learning Model: Have a trained machine learning model ready for deployment.
Introduction
This project uses two models that were trained beforehand and saved with pickle. One was trained on Telco customer churn data from Vodafone, a telecommunications company; the other was trained on store sales data from Corporation Favorita, a grocery store chain in Ecuador.
We'll use the churn model, saved as churn_prediction_model.pkl, to build a web app with Streamlit, and the store sales model, saved as best_model.pkl, to build a web app with Gradio.
So let’s dive in and begin!
Python Environment
Creating a Virtual Environment on Windows:
1. Open Command Prompt or PowerShell.
2. Navigate to the directory where you want to create the virtual environment using the cd command.
3. Run the following command to create a virtual environment named “myenv”:
python -m venv myenv
4. To activate the virtual environment, run:
myenv\Scripts\activate
Creating a Virtual Environment on macOS:
1. Open Terminal.
2. Navigate to the directory where you want to create the virtual environment using the cd command.
3. Run the following command to create a virtual environment named “myenv”:
python3 -m venv myenv
4. To activate the virtual environment, run:
source myenv/bin/activate
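Before installing anything, it can help to confirm that the interpreter you are running really belongs to the activated environment. Inside an environment created with venv, sys.prefix points at the environment folder while sys.base_prefix still points at the base Python installation, so comparing the two is a quick check. This is a minimal sketch: the function name in_virtualenv is our own, and it only detects venv-style environments (a conda environment is a full Python installation, so the check reports False there).

```python
import sys


def in_virtualenv() -> bool:
    """Return True when running inside a venv-style virtual environment."""
    # venv sets sys.prefix to the environment directory, while
    # sys.base_prefix keeps pointing at the Python it was created from.
    return sys.prefix != sys.base_prefix


if __name__ == "__main__":
    if in_virtualenv():
        print(f"Virtual environment active: {sys.prefix}")
    else:
        print("No virtual environment detected; activate myenv first.")
```

Save it under any name you like (for example check_env.py, a hypothetical file name) and run it after activating myenv.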
Install dependencies
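We install the libraries we'll use from a requirements.txt file. With the virtual environment activated, run:
pip install -r requirements.txt
(If your ARIMA model was built with statsmodels, as its forecast(steps=...) call suggests, add statsmodels to the file too: the library that created a pickled model must be installed to load it.) The requirements.txt file contains: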
gradio
streamlit
scikit-learn==1.0.2
joblib==1.2.0
numpy==1.23.1
pandas==1.4.3
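Once the requirements are installed, you may want to confirm that the pinned versions are the ones actually present, since a model pickled with one scikit-learn version is not guaranteed to load correctly with another. The sketch below uses importlib.metadata from the standard library; check_pins is a hypothetical helper, and it only understands plain name==version lines (not >= ranges, extras, or environment markers).

```python
from importlib.metadata import PackageNotFoundError, version

# The same list as the requirements.txt contents above.
REQUIREMENTS = """\
gradio
streamlit
scikit-learn==1.0.2
joblib==1.2.0
numpy==1.23.1
pandas==1.4.3
"""


def check_pins(lines):
    """Map each package name to (pinned_version, installed_version).

    pinned_version is None for unpinned lines; installed_version is None
    when the package is not installed in the current environment.
    """
    report = {}
    for raw in lines:
        line = raw.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and comments
        name, _, pinned = line.partition("==")
        name = name.strip()
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None
        report[name] = (pinned.strip() or None, installed)
    return report


if __name__ == "__main__":
    for name, (pinned, installed) in check_pins(REQUIREMENTS.splitlines()).items():
        ok = installed is not None and pinned in (None, installed)
        print(f"{'OK' if ok else 'CHECK':5} {name}: pinned={pinned}, installed={installed}")
```

To check the real file instead, open requirements.txt and pass the file object to check_pins; iterating over it yields one line at a time.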
Building a web app to predict Customer Churn using Streamlit
This app asks the user to enter customer attributes, such as whether the customer has a partner, which internet service they use, and the total charges they pay the telecom company, in order to predict whether the customer will churn.
Let’s import the libraries for our Streamlit web app and then load the trained model.
import streamlit as st
import joblib
import pandas as pd
# Load the trained model
pipeline = joblib.load('churn_prediction_model.pkl')
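joblib.load reads pickle-based files, so load a .pkl model only from a source you trust: unpickling can execute arbitrary code. Also keep in mind that Streamlit reruns the whole script from top to bottom on every widget interaction, so a top-level joblib.load call runs again each time a user ticks a checkbox. Streamlit's st.cache_resource decorator is the built-in way to load such an object once; the standard-library sketch below shows the same idea with functools.lru_cache so it can run without Streamlit. The helper names save_model and load_model, and the plain dict standing in for the trained pipeline, are hypothetical.

```python
import functools
import os
import pickle
import tempfile


def save_model(model, path):
    """Serialize an object to disk with pickle."""
    with open(path, "wb") as f:
        pickle.dump(model, f)


@functools.lru_cache(maxsize=None)
def load_model(path):
    """Load a pickled object once per path; later calls return the cached object.

    Only load files you trust: unpickling can execute arbitrary code.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as folder:
        path = os.path.join(folder, "toy_model.pkl")
        # Hypothetical stand-in for the trained pipeline: a plain dict.
        save_model({"name": "toy-churn-model", "threshold": 0.5}, path)
        first = load_model(path)
        second = load_model(path)  # served from the cache, the file is not reopened
        print(first["name"])    # toy-churn-model
        print(first is second)  # True
```

In the Streamlit app, the equivalent would be a small loader function decorated with @st.cache_resource that wraps joblib.load. Neither cache notices if the file changes on disk, so restart the app after saving a new model.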
Now for the main part: creating the app.
Design the User Interface
- Add a styled sidebar that explains the purpose of the app and how to use it.
- Add the company logo (a .png image) to make the app more welcoming.
- Add a title to the app and change its font style.
- Create the user input widgets (a slider for tenure, number inputs for the charges, and checkboxes for the yes/no features) and collect the inputs in a DataFrame.
- Create a Predict button that, when clicked, displays the user inputs in a DataFrame and predicts churn using the pipeline loaded from the saved model.
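The DataFrame in the listing below fixes three one-hot columns at 0, so a customer without phone service is never flagged as having “No phone service”. As a sketch of how the encoding step could derive that column instead, the helper below turns the set of ticked checkbox labels into one feature row, in the same column order as the article's DataFrame. CHECKBOX_COLUMNS, FEATURE_ORDER, and build_feature_row are hypothetical names; the column names are copied from the article's code and must match whatever your own pipeline was trained on.

```python
# Checkbox label shown in the app -> one-hot column expected by the pipeline
# (names copied from the article's DataFrame).
CHECKBOX_COLUMNS = {
    "Senior Citizen": "SeniorCitizen_1",
    "Partner": "Partner_Yes",
    "Dependents": "Dependents_Yes",
    "Phone Service": "PhoneService_Yes",
    "Multiple Lines": "MultipleLines_Yes",
    "Fiber Optic Internet": "InternetService_Fiber optic",
    "Online Security": "OnlineSecurity_Yes",
    "Online Backup": "OnlineBackup_Yes",
    "Device Protection": "DeviceProtection_Yes",
    "Tech Support": "TechSupport_Yes",
    "Streaming TV": "StreamingTV_Yes",
    "Streaming Movies": "StreamingMovies_Yes",
    "One Year Contract": "Contract_One year",
    "Two Year Contract": "Contract_Two year",
    "Paperless Billing": "PaperlessBilling_Yes",
    "Electronic Check": "PaymentMethod_Electronic check",
    "Mailed Check": "PaymentMethod_Mailed check",
}

# Column order of the article's DataFrame; recent scikit-learn versions reject
# inputs whose feature names or order differ from those seen during fitting.
FEATURE_ORDER = [
    "SeniorCitizen_1", "Partner_Yes", "Dependents_Yes", "PhoneService_Yes",
    "MultipleLines_No phone service", "MultipleLines_Yes",
    "InternetService_Fiber optic", "InternetService_No",
    "OnlineSecurity_Yes", "OnlineBackup_Yes", "DeviceProtection_Yes",
    "TechSupport_Yes", "StreamingTV_Yes", "StreamingMovies_Yes",
    "Contract_One year", "Contract_Two year", "PaperlessBilling_Yes",
    "PaymentMethod_Credit card (automatic)", "PaymentMethod_Electronic check",
    "PaymentMethod_Mailed check", "tenure", "MonthlyCharges", "TotalCharges",
]


def build_feature_row(checked, tenure, monthly_charges, total_charges):
    """Encode the ticked checkbox labels and numeric inputs as one feature row."""
    row = {column: int(label in checked) for label, column in CHECKBOX_COLUMNS.items()}
    has_phone = "Phone Service" in checked
    # Derived rather than fixed at 0: no phone service also rules out multiple lines.
    row["MultipleLines_No phone service"] = 0 if has_phone else 1
    if not has_phone:
        row["MultipleLines_Yes"] = 0
    # Still fixed at 0, as in the article (no widget sets these options).
    row["InternetService_No"] = 0
    row["PaymentMethod_Credit card (automatic)"] = 0
    row["tenure"] = tenure
    row["MonthlyCharges"] = monthly_charges
    row["TotalCharges"] = total_charges
    return {column: row[column] for column in FEATURE_ORDER}


if __name__ == "__main__":
    example = build_feature_row({"Partner", "Paperless Billing"}, 12, 70.0, 840.0)
    print(example["MultipleLines_No phone service"])  # 1, since Phone Service is unticked
    print(len(example))  # 23
```

Inside main(), the ticked labels could be collected with {label for label in CHECKBOX_COLUMNS if st.checkbox(label)}, and pd.DataFrame([row]) then gives a one-row DataFrame whose columns follow the dict's key order.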
# Create the Streamlit app
def main():
    # Set the background color of the sidebar. Streamlit has no ".sidebar" class,
    # so this targets the sidebar's data-testid attribute instead (an internal
    # detail that may change between Streamlit versions). The HTML is kept
    # flush-left so Markdown does not render it as an indented code block.
    st.sidebar.markdown(
        """
<style>
[data-testid="stSidebar"] {
    background-color: darkblue;
    color: white;
    padding: 20px;
}
</style>
""",
        unsafe_allow_html=True
    )

    # Add image of the company logo
    logo_image = "images/vodafone.png"
    st.image(logo_image, use_column_width=True)

    # Create a side widget for app purpose
    st.sidebar.title("Churn Prediction App")
    st.sidebar.write("Welcome to the Churn Prediction App! This app is designed to help you predict customer churn using a machine learning model. Here's how to use the app:")
    st.sidebar.subheader("Step 1: Input Customer Information")
    st.sidebar.write("Use the input widgets in the main panel to provide customer information, including tenure, monthly charges, total charges, and more. Select options for relevant features using the checkboxes.")
    st.sidebar.subheader("Step 2: Predict Customer Churn")
    st.sidebar.write("Once you've input all the necessary information, click the 'Predict' button below the input widgets. The app will use its machine learning model to predict whether the customer is likely to churn or stay.")
    st.sidebar.subheader("Step 3: Interpret Results")
    st.sidebar.write("After clicking the 'Predict' button, the app will display the prediction result. If the prediction indicates that the customer is likely to stay, the app will show 'Customer is likely to stay.' If the prediction indicates churn, the app will show 'Customer is likely to churn.'")
    st.sidebar.subheader("Step 4: Explore Further")
    st.sidebar.write("Feel free to adjust the input values and options to see how different factors impact the prediction outcome. The app is designed to assist you in understanding and analyzing customer churn.")
    st.sidebar.subheader("About the App")
    st.sidebar.write("This app is powered by a machine learning model that has been trained on customer data to predict churn. It's intended for educational and demonstration purposes. For any questions or assistance, please reach out to our team.")

    # Add the app title with a custom font
    st.markdown("<h1 style='font-family: Times New Roman, Times, serif; font-size:48px; font-style:italic;'>Customer Churn Prediction</h1>", unsafe_allow_html=True)

    # Create input widgets for user input
    tenure = st.slider("Tenure", 1, 72, 36)
    monthly_charges = st.number_input("Monthly Charges", min_value=0.0, max_value=1000.0, value=50.0)
    total_charges = st.number_input("Total Charges", min_value=0.0, max_value=10000.0, value=1500.0)

    # Create a DataFrame with user input; each checkbox becomes a 0/1 one-hot column
    user_data = pd.DataFrame({
        'SeniorCitizen_1': [1 if st.checkbox("Senior Citizen") else 0],
        'Partner_Yes': [1 if st.checkbox("Partner") else 0],
        'Dependents_Yes': [1 if st.checkbox("Dependents") else 0],
        'PhoneService_Yes': [1 if st.checkbox("Phone Service") else 0],
        'MultipleLines_No phone service': [0],  # fixed at 0: no widget sets it
        'MultipleLines_Yes': [1 if st.checkbox("Multiple Lines") else 0],
        'InternetService_Fiber optic': [1 if st.checkbox("Fiber Optic Internet") else 0],
        'InternetService_No': [0],  # fixed at 0: no widget sets it
        'OnlineSecurity_Yes': [1 if st.checkbox("Online Security") else 0],
        'OnlineBackup_Yes': [1 if st.checkbox("Online Backup") else 0],
        'DeviceProtection_Yes': [1 if st.checkbox("Device Protection") else 0],
        'TechSupport_Yes': [1 if st.checkbox("Tech Support") else 0],
        'StreamingTV_Yes': [1 if st.checkbox("Streaming TV") else 0],
        'StreamingMovies_Yes': [1 if st.checkbox("Streaming Movies") else 0],
        'Contract_One year': [1 if st.checkbox("One Year Contract") else 0],
        'Contract_Two year': [1 if st.checkbox("Two Year Contract") else 0],
        'PaperlessBilling_Yes': [1 if st.checkbox("Paperless Billing") else 0],
        'PaymentMethod_Credit card (automatic)': [0],  # fixed at 0: no widget sets it
        'PaymentMethod_Electronic check': [1 if st.checkbox("Electronic Check") else 0],
        'PaymentMethod_Mailed check': [1 if st.checkbox("Mailed Check") else 0],
        'tenure': [tenure],
        'MonthlyCharges': [monthly_charges],
        'TotalCharges': [total_charges]
    })

    # Create "Predict" button
    predict_button = st.button("Predict")

    # Perform prediction when the button is clicked
    if predict_button:
        # Display user input
        st.write("User Input:")
        st.dataframe(user_data)
        # Make prediction
        prediction = pipeline.predict(user_data)
        # Display prediction result
        if prediction[0] == 0:
            st.write("Prediction: Customer is likely to stay.")
        else:
            st.write("Prediction: Customer is likely to churn.")


if __name__ == "__main__":
    main()
The web app looks like this
Building a web app to forecast store sales using Gradio
This app asks the user to enter the date, store number, promotion status, year, cluster, and daily oil price in order to forecast sales.
First, we import necessary libraries.
# Import libraries
import gradio as gr
import pandas as pd
import joblib
from datetime import datetime
Thereafter, we load our trained model.
# Load the ARIMA model
arima_model = joblib.load('best_model.pkl')
Next, we define the forecasting function that receives the user inputs, and the Gradio interface that passes each input component, in order, to one of the function's arguments.
def sales_forecast(date, store_nbr, onpromotion, year, cluster, dcoilwtico):
    # Parse the date input (expected format: YYYY-MM-DD) into a datetime object
    parsed_date = datetime.strptime(date, "%Y-%m-%d")
    # Collect the user inputs in a DataFrame
    user_input = pd.DataFrame({
        'Date': [parsed_date],
        'Store Number': [store_nbr],
        'On promotion': [onpromotion],
        'Year': [year],
        'Cluster': [cluster],
        'Daily Oil Price': [dcoilwtico]
    })
    # Make a one-step-ahead forecast with the ARIMA model.
    # Note: a plain ARIMA model forecasts from its own history, so user_input is
    # not passed to it and every request returns the same value. If your model
    # was fit with exogenous regressors, pass them via forecast(steps=1, exog=...).
    forecast = arima_model.forecast(steps=1)
    # forecast may be a NumPy array or a pandas Series; take the first value by position
    return float(pd.Series(forecast).iloc[0])


# Create the interface. gr.inputs and gr.outputs were removed in Gradio 4,
# so the components are used directly (this also works in Gradio 3.x).
iface = gr.Interface(
    fn=sales_forecast,
    inputs=[
        gr.Textbox(label="Date (YYYY-MM-DD)"),
        gr.Number(label="Store Number"),
        gr.Number(label="On promotion"),
        gr.Number(label="Year"),
        gr.Number(label="Cluster"),
        gr.Number(label="Daily Oil Price")
    ],
    outputs=gr.Textbox(label="Forecasted sales"),
    title="<div style='text-align: center;'>"
          "<span style='font-family: \"Times New Roman\", sans-serif; font-style: italic; font-weight: bold;'>CORPORATION FAVORITA</span>"
          "<p style='font-size: 16px;'>Predict sales across our different stores</p>"
          "</div>",
    # theme="light" is not one of Gradio's built-in theme names, so it is dropped.
    # Light/dark mode follows the browser; append ?__theme=light to the URL to force light mode.
)

if __name__ == "__main__":
    iface.launch()
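Because forecast(steps=1) always predicts the single period right after the training data, the date typed into the app does not change the result. One way to make the date matter, assuming the model was fit on a daily series with no gaps, is to turn the requested date into a forecast horizon and read the last value of a multi-step forecast. In the sketch below, LAST_TRAINING_DATE and steps_until are hypothetical; set the date to the final date in your own training data.

```python
from datetime import date, datetime

# Hypothetical placeholder: replace with the last date of the series the model was fit on.
LAST_TRAINING_DATE = date(2017, 8, 15)


def steps_until(target, last_observed=LAST_TRAINING_DATE):
    """Number of daily steps from the last observed date to a 'YYYY-MM-DD' target."""
    target_date = datetime.strptime(target, "%Y-%m-%d").date()
    steps = (target_date - last_observed).days
    if steps < 1:
        raise ValueError(f"{target} is not after the training data ({last_observed}).")
    return steps


if __name__ == "__main__":
    print(steps_until("2017-08-20"))  # 5
```

Inside sales_forecast, this would become forecast = arima_model.forecast(steps=steps_until(date)), where date is the raw string from the Date textbox, followed by taking the last value, for example pd.Series(forecast).iloc[-1]. If the training index has gaps or a non-daily frequency, the step count has to be computed in the model's own periods instead.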
The Gradio web app looks like this
Clicking the Flag button in the app saves the current inputs together with the output (the forecast) to a .csv file; the folder it is written to is set by the flagging_dir argument of gr.Interface.
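The flagged samples can be reviewed later with the standard csv module. The column names are read from the file's header row, so the sketch does not depend on the exact layout, which can differ between Gradio versions. The path flagged/log.csv and the helper name read_flagged are assumptions; check your flagging directory for the actual file name.

```python
import csv
from pathlib import Path

# Hypothetical location of Gradio's flagging log; adjust to your setup.
FLAG_FILE = Path("flagged/log.csv")


def read_flagged(path):
    """Return each flagged sample as a dict keyed by the CSV header row."""
    with open(path, newline="", encoding="utf-8") as f:
        return list(csv.DictReader(f))


if __name__ == "__main__":
    if FLAG_FILE.exists():
        for record in read_flagged(FLAG_FILE):
            print(record)
    else:
        print(f"No flagged samples found at {FLAG_FILE}")
```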
In conclusion, the process of embedding machine learning models into web applications using Gradio and Streamlit offers a powerful way to interact with and showcase your models to a broader audience. Whether you choose Gradio’s intuitive and rapid interface-building approach or Streamlit’s versatile and customizable platform, both tools empower data scientists to create user-friendly applications that bridge the gap between complex models and end-users.
By following the steps I have outlined in this article and exploring other examples, you can unlock the potential of your machine learning projects and make them accessible to a wider audience. As we have seen, thanks to Gradio and Streamlit, building a machine learning web app doesn’t have to be difficult; it can be a rewarding journey that brings your models to life in an interactive and engaging way.
Appreciation
Sincere appreciation to Azubi Africa, through their Azubi Data Science program, for the valuable skills I have acquired. Spare a few minutes to check out Azubi Africa and learn about their programs.