A full factorial design in Python from Beginning to End
Introduction
Commercial DoE and statistical software is incredibly powerful, but it also comes with a hefty price tag. In addition, such software can be pretty daunting because of its many features, 80% of which you probably don’t need.
Therefore, I experimented with Python instead, and here is the first result: how you can use it to perform a 2-level full factorial design. Feel free to use and adapt the functions here to your needs.
Packages
```python
from pyDOE2 import fullfact
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import itertools
import math
from statsmodels.formula.api import ols
import statsmodels.api as sm
from mpl_toolkits.mplot3d import Axes3D
```
We start by loading all the required packages.
`from pyDOE2 import fullfact`: `pyDOE2` is a library for designing experiments, and `fullfact` is a function used to generate full factorial designs. It returns a matrix in which each row represents an experimental run and each column represents a factor; the entries are the factor levels (0, 1, ...), and together the rows cover all possible combinations of factor levels.
`import numpy as np`: `numpy` is a fundamental package for scientific computing in Python. It's commonly used for numerical computations.
`import pandas as pd`: `pandas` provides data structures and data analysis tools in Python. It's particularly well suited for handling structured data, like experimental results stored in tables (DataFrames), and is used for data manipulation and analysis.
`import matplotlib.pyplot as plt`: `matplotlib` is a widely used plotting library for Python, capable of creating static, interactive, and animated visualizations. Its `pyplot` submodule, imported as `plt`, is used here for plotting graphs such as bar charts and contour plots.
`import itertools`: The `itertools` module provides a collection of tools for handling iterators. Within the context of experimental design, it can be used to generate combinations; here it generates all pairs of factors for the interaction plots.
`import math`: The `math` module provides access to mathematical functions.
`from statsmodels.formula.api import ols`: `statsmodels` is a Python module that provides classes and functions for the estimation of many different statistical models, as well as for conducting statistical tests and statistical data exploration. The `ols` function from the `formula.api` submodule is used for fitting Ordinary Least Squares regression models.
`import statsmodels.api as sm`: This import statement also brings in `statsmodels`, but gives access to a broader range of statistical models, tests, and tools beyond OLS. Here it provides the ANOVA function (`sm.stats.anova_lm`), the Q-Q plot (`sm.qqplot`), and `sm.add_constant`.
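`from mpl_toolkits.mplot3d import Axes3D`: `mpl_toolkits.mplot3d` is a submodule of `matplotlib` that provides basic 3D plotting capabilities, like 3D scatter plots and surface plots. `Axes3D` is the 3D axes class used for 3D plotting. In recent Matplotlib versions (3.2 and later) the `projection='3d'` option works without this import, so it is optional there but harmless.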
These packages collectively provide a robust toolkit for conducting factorial design experiments, from the initial design phase (using `pyDOE2`), through data manipulation and analysis (with `pandas`, `numpy`, and `statsmodels`), to the final visualization of results and diagnostic plots (using `matplotlib` and `mpl_toolkits.mplot3d`).
Experimental plan
```python
def create_full_factorial_design(factors, randomize=False):
    # Create a 2-level full factorial design
    design = fullfact([2]*len(factors))
    # Convert levels from 0/1 to -1/+1
    design = 2*design - 1
    # Convert design matrix to a DataFrame
    df = pd.DataFrame(design, columns=factors)
    # Randomize the design if needed
    if randomize:
        df = df.sample(frac=1).reset_index(drop=True)
    return df
```
The next step is the creation of our experimental design plan. The Python function `create_full_factorial_design` above creates a 2-level full factorial design plan. It is based on the `pyDOE2` package but designed to be more user-friendly. Here is how it works:
Function Definition:
The function `create_full_factorial_design` is defined with two parameters: `factors` and `randomize`. `factors` is a list of strings representing the names of the factors in the experiment. `randomize` is a boolean value indicating whether the order of runs in the factorial design should be randomized. The default value is `False`, meaning no randomization unless specified.
Creating the Full Factorial Design:
The function starts by creating a 2-level full factorial design matrix using the `fullfact` function. The `fullfact` function requires a list where each element represents the number of levels for a factor. Since this is a 2-level design, `[2]*len(factors)` creates a list with an element `2` for each factor, indicating two levels for each factor.
The resulting design matrix, `design`, contains rows representing the different runs or experiments, and columns representing the factors. Initially, the levels are represented as `0` and `1`.
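To see the row order that `fullfact` produces without installing anything, the following standard-library sketch builds the same kind of 0-based level matrix. `full_factorial` is a hypothetical helper, not part of pyDOE2; it assumes, as `fullfact` does, that the first factor varies fastest, and it returns integers where `fullfact` returns floats.

```python
from itertools import product

# Hypothetical helper (not a pyDOE2 API): mimics the row order of fullfact()
def full_factorial(levels):
    """All level combinations, 0-based, with the first factor varying fastest."""
    # product() varies its LAST argument fastest, so feed the factors in
    # reverse order and flip every resulting tuple back.
    reversed_ranges = [range(n) for n in reversed(levels)]
    return [list(reversed(combo)) for combo in product(*reversed_ranges)]

runs = full_factorial([2, 2, 2, 2])
print(len(runs))  # 16 runs for four 2-level factors
print(runs[:4])   # [[0, 0, 0, 0], [1, 0, 0, 0], [0, 1, 0, 0], [1, 1, 0, 0]]
```

The same function also handles mixed level counts, e.g. `full_factorial([2, 3])` gives six runs.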
Converting Levels to -1 and +1:
The code `design = 2*design - 1` transforms the level coding from `0/1` to `-1/+1`, which is a common practice in factorial designs to facilitate the analysis. This step changes the lower level (`0`) to `-1` and the upper level (`1`) to `+1`.
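The -1/+1 coding makes the analysis easier because every factor column is balanced (as many -1 as +1 entries) and any two columns are orthogonal, so each effect can be estimated independently of the others. A small standard-library check on a hypothetical 2^3 design, applying the same `2*x - 1` transformation:

```python
from itertools import product

# Illustrative sketch: a hypothetical 2^3 design in 0/1 coding (eight runs)
design01 = [list(row) for row in product([0, 1], repeat=3)]

# Same transformation as `design = 2*design - 1`
coded = [[2 * x - 1 for x in row] for row in design01]
columns = list(zip(*coded))

# Balanced: each column holds four -1s and four +1s, so it sums to zero
print([sum(col) for col in columns])  # [0, 0, 0]

# Orthogonal: the element-wise product of two columns sums to zero
print(sum(a * b for a, b in zip(columns[0], columns[1])))  # 0
```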
Converting to DataFrame:
The design matrix is then converted into a pandas DataFrame, `df`, with columns named according to the `factors` list. This step facilitates further data manipulation and analysis within the Python ecosystem, especially with pandas.
Randomizing the Design:
If the `randomize` parameter is set to `True`, the function randomizes the order of the runs in the design matrix. This is done using the `sample` method with `frac=1`, which returns all rows in shuffled order. The `reset_index(drop=True)` call then renumbers the rows without adding the old index as a column, so the DataFrame keeps its original structure, just in a randomized order. Because no `random_state` is passed, each call produces a different order.
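If the randomized order should be reproducible (for example, to regenerate the same run sheet later), a seed can be fixed. In pandas this would be `df.sample(frac=1, random_state=42)`; the standard-library sketch below shows the same idea on a list of run indices. The helper name `randomized_run_order` and the seed value are illustrative assumptions.

```python
import random

# Hypothetical helper: reproducible shuffling of run indices
def randomized_run_order(n_runs, seed=None):
    """Return run indices 0..n_runs-1 in shuffled order; a fixed seed repeats it."""
    order = list(range(n_runs))
    random.Random(seed).shuffle(order)
    return order

first = randomized_run_order(16, seed=42)
second = randomized_run_order(16, seed=42)
print(first == second)                   # True: same seed, same order
print(sorted(first) == list(range(16)))  # True: every run appears exactly once
```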
Returning the Design Matrix:
Finally, the function returns the DataFrame `df`, which represents the 2-level full factorial design matrix with the specified factors as columns and the runs as rows, optionally randomized based on the `randomize` parameter.
```python
factors = ['T', 'P', 'C', 'RPM']
df = create_full_factorial_design(factors, randomize=False)

# Save the design to an Excel file (writing .xlsx needs an engine such as openpyxl)
df.to_excel('full_factorial_design_filtration_rate.xlsx', index=False)
```
The code above demonstrates how the function `create_full_factorial_design` is used in practice to generate a full factorial design matrix for an experiment. This specific example involves four factors: Temperature (T), Pressure (P), Concentration (C), and Revolutions per minute (RPM). The experimental design relates to the filtration rate example that was discussed in earlier articles. Here's a step-by-step explanation of the code:
Defining Factors:
The first line, `factors = ['T', 'P', 'C', 'RPM']`, defines a list of factor names. Each element in the list is a string that represents a different factor to be included in the factorial design. The choice of factors depends on what the experimenter believes could influence the response variable, which is the filtration rate in this example.
Generating the Design Matrix:
The second line, `df = create_full_factorial_design(factors, randomize=False)`, calls the previously discussed function `create_full_factorial_design` with the list of factors as its first argument. The second argument, `randomize=False`, specifies that the order of the runs (or experimental conditions) in the design matrix should not be randomized. This means the experiments will be listed in standard order, with the first factor alternating fastest.
The function returns a pandas DataFrame `df` that contains the full factorial design matrix. Each row in this DataFrame represents an experimental run with specific levels (either -1 or +1) for each of the factors 'T', 'P', 'C', and 'RPM'. The design covers all 2^4 = 16 possible combinations of these levels.
Exporting the Design Matrix to Excel:
The final line, `df.to_excel('full_factorial_design_filtration_rate.xlsx', index=False)`, uses the `to_excel` method of the pandas DataFrame to save the design matrix to an Excel file named 'full_factorial_design_filtration_rate.xlsx'. The `index=False` parameter prevents pandas from writing the DataFrame index as a separate column in the Excel file. This results in a cleaner design matrix that contains only the factor columns.
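The experimental plan created this way can now be executed in the lab, one run at a time, with the measured filtration rate recorded in a separate column of the Excel file. In practice, performing the runs in randomized order (`randomize=True`) helps protect the results against time-related effects such as drift.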
Visualization
After performing the experiments and adding the results (the filtration rate is our only response variable in this example), we start analyzing the data with some simple visualizations.
Main effects
```python
def main_effects_plot(excel_file, result_column):
    # Load data from excel
    df = pd.read_excel(excel_file)
    # Identify factors by excluding the result column
    factors = [col for col in df.columns if col != result_column]
    # Calculate main effects
    main_effects = {}
    for factor in factors:
        mean_plus = df[df[factor] == 1][result_column].mean()
        mean_minus = df[df[factor] == -1][result_column].mean()
        main_effects[factor] = mean_plus - mean_minus
    # Plot main effects
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.bar(main_effects.keys(), main_effects.values())
    # Annotate the bars with their values
    for factor, value in main_effects.items():
        ax.text(factor, value + 0.01 * abs(value), '{:.2f}'.format(value), ha='center', va='bottom')
    ax.set_ylabel('Main Effect')
    ax.set_title('Main Effects Plot')
    plt.xticks(rotation=45, ha="right")
    plt.show()
```
The function `main_effects_plot` above loads experimental data from an Excel file, calculates the main effect of each factor on a specified result, and plots these main effects in a bar chart. Here's a detailed explanation of how the function operates:
Function Definition:
`main_effects_plot` is defined with two parameters: `excel_file`, which is expected to be the path to an Excel file containing the experimental data, and `result_column`, which is the name of the column in the Excel file that contains the result or response variable of the experiment.
Loading Data:
The line `df = pd.read_excel(excel_file)` uses pandas' `read_excel` function to load the experimental data from the specified Excel file into a DataFrame `df`. This DataFrame includes both the factors and the result column.
Identifying Factors:
Factors are identified by excluding the result column from the list of DataFrame columns. This is done using a list comprehension: `[col for col in df.columns if col != result_column]`. The resulting list, `factors`, contains the names of all columns except the result column, on the assumption that these are the factors of the experiment. The Excel file should therefore contain no additional columns, such as notes or a run-order column.
Calculating Main Effects:
The function then iterates over each factor to calculate its main effect. The main effect of a factor is the difference between the mean of the result variable when the factor is at its high level (coded as `+1`) and the mean when it is at its low level (coded as `-1`).
For each factor, `mean_plus` is the mean of the result column where the factor level is `+1`, and `mean_minus` is the mean of the result column where the factor level is `-1`. The difference `mean_plus - mean_minus` represents the main effect of the factor, which is stored in a dictionary `main_effects` with the factor names as keys.
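The main-effect calculation can be checked by hand on a tiny example. The sketch below applies the same definition (mean at +1 minus mean at -1) with the standard library only; `main_effect` is a hypothetical helper, and the four response values are made up for illustration, not the article's results.

```python
from statistics import mean

# Hypothetical helper: same definition as in main_effects_plot
def main_effect(levels, responses):
    """Mean response at +1 minus mean response at -1."""
    high = [y for x, y in zip(levels, responses) if x == 1]
    low = [y for x, y in zip(levels, responses) if x == -1]
    return mean(high) - mean(low)

# Made-up 2^2 data: coded column of one factor and the measured responses
A = [-1, 1, -1, 1]
y = [45, 71, 48, 65]
print(main_effect(A, y))  # 21.5, i.e. (71 + 65)/2 - (45 + 48)/2
```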
Plotting Main Effects:
A bar chart is created to visualize the main effects using `matplotlib`. The line `fig, ax = plt.subplots(figsize=(5, 5))` initializes a figure and a single subplot with a specified size.
The main effects are plotted as bars with `ax.bar(main_effects.keys(), main_effects.values())`, where the keys of the `main_effects` dictionary (factor names) are used as the bar labels and the values (main effects) determine the height of the bars.
Annotating the Bars:
Each bar is annotated with its value using a loop that goes through each `factor, value` pair in the `main_effects` dictionary. The `ax.text` method places a text label (formatted to two decimal places) at the tip of each bar to indicate the magnitude of the main effect.
Customizing the Plot:
The y-axis is labeled as 'Main Effect', and a title 'Main Effects Plot' is set for the plot.
The x-tick labels, which represent the factor names, are rotated by 45 degrees and right-aligned (`plt.xticks(rotation=45, ha="right")`) for better readability.
Displaying the Plot:
Finally, `plt.show()` displays the plot with the main effects of each factor. This visual representation shows which factors have the largest impact on the result variable, indicated by the height and direction of the bars (positive or negative main effects). Whether an effect is statistically significant is tested later with the ANOVA.
```python
excel_file = 'full_factorial_design_filtration_rate_results.xlsx'
result_column = 'Filtration_rate'
main_effects_plot(excel_file, result_column)
```
The code above shows how the `main_effects_plot` function is used.
Specifying the Excel File:
`excel_file = 'full_factorial_design_filtration_rate_results.xlsx'` assigns the file name of the Excel file to the variable `excel_file`. This Excel file is expected to contain the results of the full factorial design experiment, including the levels of each factor for every experimental run and the corresponding values of 'Filtration_rate'.
Identifying the Result Column:
`result_column = 'Filtration_rate'` sets the name of the column in the Excel file that contains the response variable. In this case, the Excel file has a column named 'Filtration_rate' containing the result values we are interested in.
Executing the Main Effects Plot Function:
`main_effects_plot(excel_file, result_column)` calls the `main_effects_plot` function, passing the path to the Excel file and the name of the result column as arguments. The function then executes the steps described above and displays a bar chart of the main effects. Each bar represents a factor, and the height of the bar indicates the magnitude of the factor's main effect on 'Filtration_rate'. Positive values indicate an increase in 'Filtration_rate' when the factor is at its high level, while negative values indicate a decrease.
Interactions
```python
def interaction_point_plot(excel_file, result_column):
    df = pd.read_excel(excel_file)
    factors = [col for col in df.columns if col != result_column]
    interactions = list(itertools.combinations(factors, 2))

    # Calculate number of rows and columns for subplots
    cols = 3
    rows = math.ceil(len(interactions) / cols)
    fig, axs = plt.subplots(rows, cols, figsize=(10, 10*rows/3))
    # Keep axs two-dimensional even when there is only one row
    axs = np.atleast_2d(axs)

    for idx, interaction in enumerate(interactions):
        row = idx // cols
        col = idx % cols
        ax = axs[row, col]
        for level in [-1, 1]:
            subset = df[df[interaction[0]] == level]
            # groupby sorts the levels, so x values and means stay paired correctly
            means = subset.groupby(interaction[1])[result_column].mean()
            ax.plot(means.index, means.values, 'o-', label=f'{interaction[0]} = {level}')
        ax.set_title(f'{interaction[0]} x {interaction[1]}')
        ax.set_xlabel(interaction[1])
        ax.legend()
        ax.grid(True)

    # Handle any remaining axes (if there's no data to plot in them)
    for idx in range(len(interactions), rows*cols):
        row = idx // cols
        col = idx % cols
        axs[row, col].axis('off')

    plt.tight_layout()
    plt.show()
```
After plotting the main effects, we use `interaction_point_plot` to explore potential two-way interactions in our factorial design. This function requires the same two parameters as the function for the main effects plot: `excel_file` and `result_column`.
Here is how the function works.
Loading Data and Identifying Factors:
`pd.read_excel(excel_file)`: This line reads the experimental data from the specified Excel file into a DataFrame.
Factor identification: A list comprehension creates a list of factor names, which includes all columns in the DataFrame except the one specified in `result_column`.
Generating Factor Combinations for Interactions:
`interactions = list(itertools.combinations(factors, 2))`: This line uses the `combinations` function from Python's `itertools` module to generate all possible unique pairs of factors. These pairs represent the 2-way interactions we aim to analyze.
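For the four factors of this example, `itertools.combinations` yields six unordered pairs, each exactly once, which is why the interaction plot has six panels. A quick standard-library check (the factor names are those of the design; the results file may name the concentration column differently):

```python
from itertools import combinations
from math import comb

factors = ['T', 'P', 'C', 'RPM']
pairs = list(combinations(factors, 2))

print(len(pairs))  # 6
print(pairs[:3])   # [('T', 'P'), ('T', 'C'), ('T', 'RPM')]
# k factors give k*(k-1)/2 two-factor interactions
print(comb(len(factors), 2) == len(pairs))  # True
```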
Setting Up the Plot Layout:
The code calculates the number of rows needed for subplotting, ensuring all interaction pairs have a dedicated plot.
`plt.subplots(rows, cols, figsize=(10, 10*rows/3))`: Initializes a grid of subplots. The figure height grows with the number of rows (10/3 inches per row), so each panel keeps a readable size regardless of how many rows are needed.
Plotting Each Interaction Pair:
The function iterates over each interaction pair, plotting them in their respective subplot locations.
Within each plot, it further iterates over the two levels (low and high, represented by [-1, 1]) of the first factor in the pair.
The mean of the result column for each level of the second factor is calculated using `subset.groupby(interaction[1])[result_column].mean()`. Each of these means is the average response at one combination of factor levels.
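Each point in an interaction panel is one of these cell means. The standard-library sketch below computes them for one pair of factors; `cell_means` is a hypothetical helper and the data are made up. Sorting the keys mirrors what `groupby` does by default and keeps each mean paired with the right factor level.

```python
from collections import defaultdict
from statistics import mean

# Hypothetical helper: mean response for every (level of a, level of b) cell
def cell_means(a, b, y):
    groups = defaultdict(list)
    for ai, bi, yi in zip(a, b, y):
        groups[(ai, bi)].append(yi)
    return {key: mean(vals) for key, vals in sorted(groups.items())}

# Made-up 2^2 data with one run per cell
T = [-1, 1, -1, 1]
RPM = [-1, -1, 1, 1]
y = [45, 71, 48, 65]
print(cell_means(T, RPM, y))
# {(-1, -1): 45, (-1, 1): 48, (1, -1): 71, (1, 1): 65}
```

In the plot, the line for `T = -1` connects the cells `(-1, -1)` and `(-1, 1)`, and the line for `T = 1` connects `(1, -1)` and `(1, 1)`.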
Enhancing Plot Readability:
Titles, legends, and grid lines are added to each subplot for clarity.
Extra subplots (if any) are disabled using `axs[row, col].axis('off')` to maintain a clean layout.
Final Adjustments and Display:
`plt.tight_layout()` adjusts the spacing between plots to avoid overlapping elements. `plt.show()` displays the complete set of interaction plots, offering a visual representation of how each pair of factors interacts and affects the response variable.
```python
interaction_point_plot(excel_file, result_column)
```
The function is used in the same way as the `main_effects_plot` function above. It creates an interaction plot in which the relationship between factors is indicated by the behavior of the lines. Parallel lines imply no interaction between the two factors, while diverging, converging, or crossing lines indicate a possible interaction, whose significance is then tested in the ANOVA.
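The visual rule can be made numeric. A common definition of the two-factor interaction effect is half the difference between the effect of the second factor at the high and at the low level of the first; it is exactly zero when the two lines are parallel. `interaction_effect` is a hypothetical standard-library helper, and the data are made up for illustration.

```python
from statistics import mean

# Hypothetical helper: two-factor interaction effect for -1/+1 coded data
def interaction_effect(a, b, y):
    """Half the difference between b's effect at a = +1 and at a = -1."""
    def effect_of_b(a_level):
        high = [yi for ai, bi, yi in zip(a, b, y) if ai == a_level and bi == 1]
        low = [yi for ai, bi, yi in zip(a, b, y) if ai == a_level and bi == -1]
        return mean(high) - mean(low)
    return (effect_of_b(1) - effect_of_b(-1)) / 2

# Made-up 2^2 data (one run per combination)
T = [-1, 1, -1, 1]
RPM = [-1, -1, 1, 1]
y = [45, 71, 48, 65]
print(interaction_effect(T, RPM, y))  # -4.5: RPM effect is +3 at T=-1 but -6 at T=+1

# Parallel lines (same RPM effect at both T levels) give exactly zero
print(interaction_effect(T, RPM, [10, 20, 15, 25]))  # 0.0
```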
Model building / ANOVA
```python
df = pd.read_excel(excel_file)

# Fit the model
formula = 'Filtration_rate ~ T + CoF + P + RPM + T:CoF + T:RPM'
model = ols(formula, data=df).fit()

# Perform ANOVA and print the results
anova_table = sm.stats.anova_lm(model, typ=1)
print(anova_table)
```
After getting an overview of the data, we can start the model building process with the code above. The process involves fitting a linear model to the data and then conducting an ANOVA to test the significance of each factor and interaction. This lets us check the hypotheses we formed during the visualization step. Here's a breakdown:
Loading Data:
`df = pd.read_excel(excel_file)` loads the experimental data from the Excel file (specified by the variable `excel_file`) into a pandas DataFrame `df`.
Fitting the Linear Model:
The line `formula = 'Filtration_rate ~ T + CoF + P + RPM + T:CoF + T:RPM'` defines the model formula. This formula specifies that 'Filtration_rate' is modeled as a function of the factors 'T', 'CoF', 'P' and 'RPM', along with the interactions 'T:CoF' and 'T:RPM'. Here, 'CoF' is the concentration factor, which the results file evidently stores under this name rather than 'C'. The '+' symbol adds terms to the model, and ':' specifies an interaction between factors.
`model = ols(formula, data=df).fit()` fits an Ordinary Least Squares (OLS) regression model using the specified formula. The `ols` function from the statsmodels library is used here, with `data=df` indicating that the data for the model comes from the DataFrame `df`. The `.fit()` method fits the model to the data and returns the fitted model object `model`.
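Because the coded columns are balanced and orthogonal, the OLS coefficient of each term reduces to sum(x*y)/N, which is half of that term's effect (an effect spans two coded units, from -1 to +1). The coefficients in the model summary should therefore be half the bar heights in the main effects plot. `coded_coefficient` is a hypothetical standard-library helper, and the data are made up.

```python
# Hypothetical helper: OLS coefficient of one balanced, orthogonal -1/+1 column
def coded_coefficient(x, y):
    """Least-squares coefficient when X'X = N*I (orthogonal -1/+1 design)."""
    return sum(xi * yi for xi, yi in zip(x, y)) / len(y)

# Made-up 2^2 data
A = [-1, 1, -1, 1]
y = [45, 71, 48, 65]
print(coded_coefficient(A, y))             # 10.75, half of the main effect 21.5
print(coded_coefficient([1, 1, 1, 1], y))  # 57.25, the intercept equals the mean
```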
Performing ANOVA:
`anova_table = sm.stats.anova_lm(model, typ=1)` performs an ANOVA on the fitted model `model` using Type I (sequential) sums of squares. This function is part of the statsmodels library (imported as `sm`). The `anova_lm` function computes the ANOVA table for the model, which includes the degrees of freedom, sum of squares, mean square, F-value, and p-value for each factor and interaction in the model. Because the -1/+1 columns of a complete 2-level full factorial are orthogonal, Type I, II, and III sums of squares coincide here, so the order of the terms in the formula does not affect the table.
The ANOVA table, stored in `anova_table`, helps determine the statistical significance of each factor's and interaction's effect on the response variable. A low p-value (typically < 0.05) indicates that a factor or an interaction has a statistically significant effect on the response variable.
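For a balanced 2-level design, the sum of squares of an effect (1 degree of freedom) can be computed directly from the effect estimate as N * effect^2 / 4, and its F-value is that sum of squares divided by the residual mean square from the ANOVA table. A standard-library sketch of these two formulas; the helper names are hypothetical and the numbers are illustrative only.

```python
# Hypothetical helpers for a balanced 2-level design
def effect_sum_of_squares(effect, n_runs):
    """Sum of squares of a 1-df effect: N * effect**2 / 4."""
    return n_runs * effect ** 2 / 4

def f_value(ss_effect, ss_residual, df_residual):
    """F statistic of a 1-df effect: its mean square over the residual mean square."""
    return ss_effect / (ss_residual / df_residual)

# Illustrative: an effect of 21.5 estimated from 16 runs
ss = effect_sum_of_squares(21.5, 16)
print(ss)  # 1849.0
# Hypothetical residual SS of 90 with 9 residual degrees of freedom
print(f_value(ss, 90.0, 9))
```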
Iteration:
The process described here is iterative. The formula in step 2 is refined until it contains only significant terms (p-value < 0.05). The visualizations performed earlier give a good indication of which terms might be significant and which might not.
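One step of this refinement can be sketched as "drop the least significant term, refit, repeat". The helpers below are hypothetical and work on plain Python data; in statsmodels' ANOVA output the p-values sit in the `PR(>F)` column, so a mapping could be obtained, for example, with `anova_table['PR(>F)'].drop('Residual').to_dict()`. Only one term is dropped per step, because the remaining p-values change after each refit.

```python
# Hypothetical helpers for one backward-elimination step
def drop_least_significant(terms, pvalues, alpha=0.05):
    """Remove the term with the largest p-value >= alpha; return terms unchanged otherwise."""
    candidates = {t: p for t, p in pvalues.items() if t in terms and p >= alpha}
    if not candidates:
        return list(terms)
    worst = max(candidates, key=candidates.get)
    return [t for t in terms if t != worst]

def to_formula(response, terms):
    return f"{response} ~ {' + '.join(terms)}"

# Made-up p-values, for illustration only
terms = ['T', 'CoF', 'P', 'RPM', 'T:CoF', 'T:RPM']
pvalues = {'T': 0.001, 'CoF': 0.002, 'P': 0.2, 'RPM': 0.001, 'T:CoF': 0.0005, 'T:RPM': 0.001}
reduced = drop_least_significant(terms, pvalues)
print(to_formula('Filtration_rate', reduced))
# Filtration_rate ~ T + CoF + RPM + T:CoF + T:RPM
```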
Model control
The model control (diagnostics) step follows the model building step. The function `diagnostic_plots` creates three diagnostic plots to assess the fit of our model. Here's a breakdown of each section of the code:
```python
def diagnostic_plots(model):
    # Extract residuals and predicted values from the model
    residuals = model.resid
    predicted = model.fittedvalues

    # Create subplots
    fig, axs = plt.subplots(1, 3, figsize=(10, 5))

    # Residuals vs Predicted
    axs[0].scatter(predicted, residuals, edgecolors='k', facecolors='none')
    axs[0].axhline(y=0, color='k', linestyle='dashed', linewidth=1)
    axs[0].set_title('Residuals vs. Predicted')
    axs[0].set_xlabel('Predicted values')
    axs[0].set_ylabel('Residuals')

    # Residuals vs. Runs (row order of the data; this is the order of data
    # collection only if the rows were recorded in execution order)
    axs[1].scatter(range(len(residuals)), residuals, edgecolors='k', facecolors='none')
    axs[1].axhline(y=0, color='k', linestyle='dashed', linewidth=1)
    axs[1].set_title('Residuals vs. Run')
    axs[1].set_xlabel('Run')
    axs[1].set_ylabel('Residuals')

    # Q-Q plot
    sm.qqplot(residuals, line='45', fit=True, ax=axs[2])
    axs[2].set_title('Q-Q Plot')

    plt.tight_layout()
    plt.show()
```
Function Definition:
`diagnostic_plots(model)` defines the function with a single parameter, `model`, which is the linear model fitted during the model building step. It contains all the information needed to create the diagnostic plots.
Extracting Residuals and Predicted Values:
`residuals = model.resid` extracts the residuals from the model, which are the differences between the observed and predicted values. `predicted = model.fittedvalues` extracts the predicted (fitted) values, i.e., the values the fitted model predicts for each run.
Setting Up Subplots:
`fig, axs = plt.subplots(1, 3, figsize=(10, 5))` initializes a figure with a 1x3 grid of subplots (i.e., three plots in a row) and sets the figure size.
Plotting Residuals vs. Predicted:
The first subplot (`axs[0]`) is a scatter plot of the residuals against the predicted values. Ideally, the points should be randomly dispersed around the horizontal line `y=0`, without any discernible pattern. `axs[0].axhline(y=0, ...)` adds a horizontal dashed line at `y=0` to help visualize the zero residual line.
Plotting Residuals vs. Runs:
The second subplot (`axs[1]`) is a scatter plot of the residuals against the order of data collection (the "runs"). This plot checks for patterns in the residuals that may suggest dependence on the order of data collection, which could indicate time-related trends or serial correlation.
The residuals are plotted against `range(len(residuals))`, i.e., against their row position in the data. This matches the actual run order only if the rows in the Excel file are listed in the order in which the experiments were performed.
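If the execution order differs from the row order (for example, because the runs were performed in randomized order), the residuals can be re-sorted by the actual order before plotting. The sketch assumes a hypothetical list `run_order` giving the execution position of each row; if it were stored as an extra column in the results file, `main_effects_plot` and `interaction_point_plot` would treat it as a factor unless it is excluded.

```python
# Hypothetical helper: order residuals by when each run was actually performed
def residuals_in_run_order(residuals, run_order):
    """run_order[i] is the execution position of row i (hypothetical input)."""
    pairs = sorted(zip(run_order, residuals), key=lambda pair: pair[0])
    return [residual for _, residual in pairs]

print(residuals_in_run_order([0.5, -1.0, 2.0], [3, 1, 2]))  # [-1.0, 2.0, 0.5]
```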
Q-Q Plot:
The third subplot (`axs[2]`) is a Quantile-Quantile (Q-Q) plot generated by `sm.qqplot(residuals, ...)`, used to assess the normality of the residuals. With `fit=True`, the residuals are standardized using the fitted location and scale, so points falling along the 45-degree reference line (`line='45'`) suggest that the residuals are normally distributed.
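Under the hood, a normal Q-Q plot pairs the sorted residuals with quantiles of the standard normal distribution. The standard-library sketch below uses the plotting positions i/(n+1); libraries differ slightly in this convention, so the points need not match `sm.qqplot` exactly. `normal_qq_points` is a hypothetical helper.

```python
from statistics import NormalDist

# Hypothetical helper: (theoretical quantile, observed residual) pairs
def normal_qq_points(residuals):
    n = len(residuals)
    observed = sorted(residuals)
    # Plotting positions i/(n+1) for i = 1..n
    theoretical = [NormalDist().inv_cdf(i / (n + 1)) for i in range(1, n + 1)]
    return list(zip(theoretical, observed))

for q, r in normal_qq_points([0.3, -0.2, 0.1]):
    print(f"{q:+.3f}  {r:+.1f}")
```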
Final Adjustments and Display:
`plt.tight_layout()` adjusts the spacing between the subplots to prevent overlap. `plt.show()` displays the figure with the three diagnostic plots.
```python
diagnostic_plots(model)
```
The call `diagnostic_plots(model)` passes the fitted model object to the previously defined function, which extracts the residuals and predicted values to generate the three diagnostic plots. If we are not satisfied with the outcome of the diagnostic plots, we may have to transform the data, add interaction terms, or fit a higher-order model (e.g., add a quadratic term). The latter requires additional experiments, because with only two levels per factor a quadratic term cannot be estimated.
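The reason a quadratic term needs extra runs can be seen directly: with only the levels -1 and +1, the squared factor equals 1 in every run and is therefore indistinguishable from the intercept. Adding center points (level 0) is one common way to give it a second value:

```python
# With two levels, x**2 is constant, so it is confounded with the intercept
two_level = [-1, 1]
print({x ** 2 for x in two_level})    # {1}

# A center point (0) gives the squared term a second distinct value
with_center = [-1, 0, 1]
print({x ** 2 for x in with_center})  # {0, 1}
```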
Concluding the design
If the model control was successful and we are happy with the result, we can draw a conclusion and prepare our results for a presentation or a report. The following steps might help us with that.
Model summary
```python
print(model.summary())
```
The code print(model.summary()) displays a comprehensive summary containing statistical and diagnostic information for the regression model stored in model. This summary includes key performance metrics of the model and detailed insights into the statistical significance of each model coefficient, reflecting the influence of individual factors and interactions of factors on the filtration rate.
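Two of the first numbers to look at in the summary are R-squared and adjusted R-squared. The standard-library sketch below shows how they are defined for a model with an intercept; `r_squared` is a hypothetical helper, `n_params` counts the model terms excluding the intercept, and the input values are illustrative.

```python
# Hypothetical helper: R-squared and adjusted R-squared for a model with intercept
def r_squared(observed, fitted, n_params):
    n = len(observed)
    y_bar = sum(observed) / n
    ss_res = sum((y - f) ** 2 for y, f in zip(observed, fitted))
    ss_tot = sum((y - y_bar) ** 2 for y in observed)
    r2 = 1 - ss_res / ss_tot
    # Penalizes terms that do not improve the fit enough
    r2_adj = 1 - (1 - r2) * (n - 1) / (n - n_params - 1)
    return r2, r2_adj

r2, r2_adj = r_squared([1, 2, 3, 4], [1.1, 1.9, 3.2, 3.8], n_params=1)
print(round(r2, 4), round(r2_adj, 4))  # 0.98 0.97
```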
3D plots
```python
def plot_3D_surface(model, data, x_name, y_name, z_name, held_factor, held_value, title):
    x_range = np.linspace(-1, 1, 100)
    y_range = np.linspace(-1, 1, 100)
    x_grid, y_grid = np.meshgrid(x_range, y_range)

    df_pred = pd.DataFrame({
        x_name: x_grid.ravel(),
        y_name: y_grid.ravel(),
        held_factor: [held_value] * x_grid.size
    })
    # Hold any other factor column of `data` at the design center (0) so that every
    # variable in the model formula has a value; unused columns are ignored
    for col in data.columns:
        if col not in df_pred.columns and col != z_name:
            df_pred[col] = 0
    # Not needed for formula-based models (the formula supplies the intercept); harmless
    df_pred = sm.add_constant(df_pred)
    Z = model.predict(df_pred).values.reshape(x_grid.shape)

    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(111, projection='3d')

    # Plotting the surface
    ax.plot_surface(x_grid, y_grid, Z, cmap='viridis', alpha=0.6)

    # Plotting the measured points
    mask = data[held_factor] == held_value
    ax.scatter(data[mask][x_name], data[mask][y_name], data[mask][z_name], color='r', marker='o')

    ax.set_xlabel(x_name)
    ax.set_ylabel(y_name)
    ax.set_zlabel(z_name)
    ax.set_title(title)
    plt.show()
```
The `plot_3D_surface` function creates a 3D surface plot that visualizes the relationship between two factors and one response variable, holding other factors constant. Here's how it works:
Function Parameters:
`model`: The fitted regression model object used for predictions.
`data`: The DataFrame containing the observed data.
`x_name`, `y_name`, `z_name`: The names of the columns in `data` representing the two predictors (x and y) and the response variable (z), respectively.
`held_factor`: The name of another factor in the model that is held constant for this visualization.
`held_value`: The specific value at which `held_factor` is held constant.
`title`: The title for the plot.
Generating Prediction Grids:
`x_range` and `y_range` create arrays of 100 points each, ranging from -1 to 1, representing the coded range of values for the two predictors. `np.meshgrid` generates two 2D grid arrays from `x_range` and `y_range`, which together form a grid of (x, y) coordinates.
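The grid logic used here (and in the contour plot) relies on `ravel()` and `reshape()` using the same row-by-row order, so each prediction lands back in its own grid cell. The standard-library stand-ins below mimic `np.meshgrid` with its default `indexing='xy'` (rows follow the y values, columns the x values); the helper names are hypothetical, and the "prediction" `x + 10*y` is a made-up placeholder.

```python
# Hypothetical stdlib stand-ins for np.meshgrid, ravel() and reshape()
def meshgrid(xs, ys):
    """Like np.meshgrid(xs, ys): rows follow ys, columns follow xs."""
    x_grid = [[x for x in xs] for _ in ys]
    y_grid = [[y for _ in xs] for y in ys]
    return x_grid, y_grid

def ravel(grid):
    """Flatten row by row, like NumPy's default (C-order) ravel()."""
    return [v for row in grid for v in row]

def reshape(flat, n_rows, n_cols):
    """Inverse of ravel() for a grid with n_rows rows and n_cols columns."""
    return [flat[r * n_cols:(r + 1) * n_cols] for r in range(n_rows)]

xs, ys = [-1, 0, 1], [-1, 1]
x_grid, y_grid = meshgrid(xs, ys)
points = list(zip(ravel(x_grid), ravel(y_grid)))  # every (x, y) pair exactly once
# One value per point, reshaped back into the grid layout
z = reshape([x + 10 * y for x, y in points], len(ys), len(xs))
print(z)  # [[-11, -10, -9], [9, 10, 11]]
```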
Preparing Prediction Data:
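A new DataFrame `df_pred` is created with columns for `x_name`, `y_name`, and `held_factor`, populated with the grid values and the constant `held_value`, respectively. `sm.add_constant(df_pred)` adds a constant column named `const` (by default it leaves the DataFrame unchanged if a nonzero constant column, such as the held factor, is already present). For a model fitted through the formula interface, this step is not actually required: the formula adds the intercept itself, and columns that do not appear in the formula are ignored.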
Making Predictions:
`model.predict(df_pred)` uses the fitted model to predict the response variable `Z` over the grid of (x, y) values, holding `held_factor` constant at `held_value`. The predictions are reshaped to match the grid's shape for plotting.
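Conceptually, `model.predict` evaluates the fitted equation at each grid point: the intercept plus each coefficient times its coded factor value, where an interaction term such as `T:RPM` uses the product of the two factor values. The sketch below does this for one point with hypothetical coefficients; with a fitted statsmodels model, the real ones would come from `model.params` (whose keys include `'Intercept'` for a formula with an intercept). `predict_coded` is a hypothetical helper.

```python
# Hypothetical helper: evaluate an intercept + main effects + interactions model
def predict_coded(coefs, point):
    """coefs maps 'Intercept', 'T', 'T:RPM', ... to coefficients; point maps factors to coded values."""
    total = 0.0
    for term, b in coefs.items():
        if term == 'Intercept':
            total += b
            continue
        value = 1.0
        for factor in term.split(':'):  # an interaction uses the product of its factors
            value *= point[factor]
        total += b * value
    return total

# Hypothetical coefficients, not fitted results
coefs = {'Intercept': 70.0, 'T': 10.0, 'RPM': 7.0, 'CoF': 5.0, 'T:CoF': -9.0, 'T:RPM': 8.0}
print(predict_coded(coefs, {'T': 1, 'RPM': -1, 'CoF': 1}))  # 61.0
```

In this made-up model, the slope along T is 10 - 9*CoF + 8*RPM, so switching CoF from -1 to +1 changes it by -18; that is the kind of shift the two 3D plots make visible.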
Creating the 3D Plot:
A 3D subplot is initialized with `projection='3d'`. `ax.plot_surface` plots the predicted response surface over the (x, y) grid, using a color map (`cmap='viridis'`) and setting the transparency with `alpha`.
Overlaying Observed Data Points:
The `mask` is used to filter rows in `data` where `held_factor` equals `held_value`. `ax.scatter` plots these observed data points on the 3D plot, making it easier to compare the predicted surface with actual observations.
Customizing the Plot:
Axis labels are set to the names of the x, y, and z variables, and a title is added to the plot.
Displaying the Plot:
`plt.show()` renders the figure, displaying the 3D surface plot with the observed data points overlaid.
```python
# 'Filtration_rate' is the response (dependent variable); 'CoF' is held at +1 and at -1
plot_3D_surface(model, df, 'T', 'RPM', 'Filtration_rate', 'CoF', +1,
                "3D Surface of Interaction between T and RPM with CoF=+1")
plot_3D_surface(model, df, 'T', 'RPM', 'Filtration_rate', 'CoF', -1,
                "3D Surface of Interaction between T and RPM with CoF=-1")
```
This code block demonstrates the use of the `plot_3D_surface` function to visualize the interaction effects between two factors, 'T' and 'RPM', on the response variable 'Filtration_rate', while holding another factor, 'CoF', at two different levels: +1 and -1. Here's a breakdown:
First Plot:
The function `plot_3D_surface` is called with the fitted model `model`, the data `df`, the names of the two interacting factors 'T' and 'RPM', and the response variable 'Filtration_rate'. 'CoF' is specified as the held factor, with its value set to +1.
The title "3D Surface of Interaction between T and RPM with CoF=+1" is set.
Second Plot:
A similar call to `plot_3D_surface` is made, but this time 'CoF' is held at -1.
By creating these two plots, the code block aims to visually compare how the interaction between 'T' and 'RPM' influences the 'Filtration_rate' under two different conditions of 'CoF'.
Contour plots
```python
# Function to create contour plot (uses the fitted `model` from the ANOVA step)
def plot_contour(x_name, y_name, held_factor, held_value, title):
    x_range = np.linspace(-1, 1, 100)
    y_range = np.linspace(-1, 1, 100)
    x_grid, y_grid = np.meshgrid(x_range, y_range)

    # Every variable in the model formula needs a column here; if the model
    # contains further factors, add them to this dictionary (e.g. at 0)
    predictions = model.predict(pd.DataFrame({
        x_name: x_grid.ravel(),
        y_name: y_grid.ravel(),
        held_factor: held_value,
    }))
    Z = predictions.values.reshape(x_grid.shape)

    plt.figure(figsize=(7, 5))
    contour = plt.contourf(x_grid, y_grid, Z, 20, cmap='viridis')
    plt.colorbar(contour)
    plt.title(title)
    plt.xlabel(x_name)
    plt.ylabel(y_name)
    plt.show()
```
The function `plot_contour` creates a so-called contour plot to visualize the interaction effects between two factors on a response variable, with one additional factor held constant at a specified value. A contour plot shows the same information as a 3D surface plot, but projected onto the plane of the two factors, with colors indicating the level of the response. Here's a breakdown of how the function operates:
Function Definition:
`plot_contour` is defined with five parameters: `x_name` and `y_name` for the names of the two factors to be plotted on the x and y axes, `held_factor` for the name of the factor that is held constant, `held_value` for the value at which `held_factor` is held constant, and `title` for the plot's title. Unlike `plot_3D_surface`, it takes no `model` argument and uses the fitted `model` defined earlier in the script.
Setting Up the Grid:
`x_range` and `y_range` create arrays of 100 points each, spanning from -1 to 1, which represent the coded range of values for the two factors. `np.meshgrid` is used to generate two 2D grid arrays from `x_range` and `y_range`, which will be used for plotting.
Generating Predictions:
A new DataFrame is created on the fly within the `model.predict` call, containing columns for `x_name`, `y_name`, and `held_factor`, populated with the grid values and the constant `held_value`, respectively. `model.predict` then predicts the response variable over the grid, based on the fitted model. These predictions are stored in `predictions`.
Reshaping Predictions:
The predicted values in `predictions` are reshaped to match the shape of `x_grid` and `y_grid`, resulting in a 2D array `Z` that contains the predicted response values across the grid.
Creating the Contour Plot:
`plt.figure` initializes a new figure with the specified dimensions. `plt.contourf` creates a filled contour plot on the grid, with `Z` providing the height (value) at each point. The `20` argument sets the approximate number of contour levels (Matplotlib picks about this many evenly spaced "nice" values), and `cmap='viridis'` sets the color map. `plt.colorbar` adds a color bar to the side of the plot to indicate the scale of the response variable.
Customizing the Plot:
The plot is titled using `plt.title`, and the x and y axes are labeled with `plt.xlabel` and `plt.ylabel`, corresponding to the names of the two factors.
Displaying the Plot:
`plt.show()` renders and displays the contour plot.
```python
plot_contour('T', 'RPM', 'CoF', +1, "Contour Plot of Interaction between T and RPM with CoF=+1")
plot_contour('T', 'RPM', 'CoF', -1, "Contour Plot of Interaction between T and RPM with CoF=-1")
```
These two lines of code use the `plot_contour` function to generate contour plots:
First Contour Plot:
`plot_contour('T', 'RPM', 'CoF', +1, "Contour Plot of Interaction between T and RPM with CoF=+1")` generates a contour plot where 'T' is plotted on the x-axis, 'RPM' on the y-axis, and 'CoF' is held at a value of +1. The title of the plot is "Contour Plot of Interaction between T and RPM with CoF=+1".
Second Contour Plot:
`plot_contour('T', 'RPM', 'CoF', -1, "Contour Plot of Interaction between T and RPM with CoF=-1")` creates a similar plot but with 'CoF' held at a value of -1. The title of the plot is "Contour Plot of Interaction between T and RPM with CoF=-1".
Conclusion
In summary, this blog post showcased how Python and its libraries offer a straightforward and cost-effective solution for conducting Design of Experiments (DoE) and statistical analysis. The illustrated functions enable you to easily apply these techniques to your projects.
A great benefit of Python is its versatility, which extends far beyond what we've explored here to a broad range of statistical, machine learning, and data visualization capabilities. This makes Python an invaluable tool for researchers and engineers, allowing for sophisticated analyses and insightful visualizations that can drive decision-making and innovation.
I will be exploring this further in the future. So stay tuned!