On this blog, we have discussed many examples of pre-processing techniques specifically suited to spectral data. Spectroscopy-specific pre-processing often involves functions that are uncommon in other machine learning communities, and therefore not included in standard libraries, for instance the (excellent!) scikit-learn. Integrating custom functions in a way that makes them compatible with the scikit-learn API is nevertheless possible. In this post, we’ll discuss a specific example of that: we’ll wrap Multiplicative Scatter Correction (or MSC) in a custom class that can be dropped into a scikit-learn pipeline.
There are (at least) two important reasons for wanting to create custom classes:
- If the custom pre-processing function you want to add depends on the training data, it needs to be fitted to the training data only, and then used to transform the test data (or any other subsequent data) before inference.
- If you are training models for production, it is far better to include custom functions in a single scikit-learn pipeline, which keeps things simple when exporting and loading models for later inference.
Let me explain.
A short digression
Consider the first reason in the list above. MSC is a function that, in general, depends on the training data (for a refresher on the MSC, take a look at our introductory post on the subject). MSC, as the name implies, aims at removing unwanted scatter effects from the spectra. For that purpose, it requires a reference spectrum which is, ideally, a spectrum that is free of scatter effects.
Now, generally such an ideal spectrum is not available, but we can assume that the average spectrum of all the samples in our dataset is a good approximation to the ideal spectrum. The expectation is that, if scatter effects are randomly different from sample to sample, they will average out if we take the mean spectrum.
This is excellent news for the applicability of the method, but it makes MSC dependent on the training spectra through their average. Hence, if we were to first apply MSC and then split the data into a train and a test set, we would create a data leakage problem: some features of the test set would be indirectly included in the training set via the prior average.
Including MSC in a custom pipeline, and then fitting the pipeline on the training set only (or, fold by fold, in cross-validation), avoids the issue.
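To make the point concrete, here is a minimal sketch of where the leakage creeps in. The arrays here are hypothetical random ‘spectra’, for illustration only:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical spectra: 50 samples x 600 wavelengths (random, for illustration only)
rng = np.random.default_rng(42)
X = rng.random((50, 600))

# WRONG: the MSC reference is computed on ALL samples, so information
# from the future test set leaks into the correction of the training set
reference_leaky = X.mean(axis=0)

# RIGHT: split first, then compute the reference on the training split only
X_train, X_test = train_test_split(X, test_size=0.3, random_state=42)
reference_ok = X_train.mean(axis=0)  # this is what fitting a pipeline on X_train does
```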
A second example is that of Savitzky-Golay smoothing. SG smoothing is done on each spectrum independently. For this reason we could (correctly) apply it before splitting the data: no leakage problem would be created. However, including SG smoothing in a pipeline makes things easier in production, as per ‘Reason #2’ mentioned above. A sketch of that follows.
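Since SG smoothing learns nothing from the data, a full custom class is not even needed: scikit-learn’s FunctionTransformer is enough. Here is a minimal sketch (the window length and polynomial order below are just illustrative values):

```python
from scipy.signal import savgol_filter
from sklearn.preprocessing import FunctionTransformer
from sklearn.pipeline import Pipeline
from sklearn.cross_decomposition import PLSRegression

# FunctionTransformer is enough here: SG smoothing has no fit step,
# the window length and polynomial order are fixed hyperparameters
sg = FunctionTransformer(savgol_filter,
                         kw_args={'window_length': 11, 'polyorder': 2, 'axis': 1})

pipe_sg = Pipeline([('sg', sg), ('pls', PLSRegression(n_components=5))])
```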
Developing custom classes
OK, the digression is over, let’s get into the main topic.
The great folks at scikit-learn provide a guide on the APIs of scikit-learn objects, where we can find information on how to build a custom function compatible with the rest of their API. That document is the basis for our discussion here. But first, a couple of pointers if you need to refresh some of the topics before starting:
- More info on MSC is covered in this Nirpy post. Here we are going to recast the same code from that post into a form that can be inserted into a scikit-learn pipeline.
- Of course, you may also want to glance at scikit-learn Pipelines before starting.
In the scikit-learn terminology, MSC is a transformer, i.e. a function that modifies the data ‘in a supervised or unsupervised way’ before an estimator (for instance a PLS regressor) can be fitted.
In order to work with the scikit-learn API, a transformer needs to have, at the very least, a fit function and a transform function. These functions will be wrapped into a Python class. In addition to the documentation linked above, excellent examples of how to do that are provided on the scikit-learn GitHub page.
Using those examples as a template, here’s a custom MSC class which is compatible with the scikit-learn API. I’m going to write the class first, then explain its main components, and finally wrap everything up (including the necessary imports) into a Python script.
Here’s the class.
```python
class msc(TransformerMixin, BaseEstimator):

    def __init__(self, reference=None):
        self.reference = reference

    def fit(self, X, y=None):
        # Input validation (dense arrays only: the operations
        # below are not defined for sparse input)
        X = self._validate_data(X)

        # Mean centre correction (on a copy, so the caller's array is untouched)
        X = X - X.mean(axis=1)[:, np.newaxis]

        # If the reference doesn't exist, set it to the average spectrum
        if self.reference is None:
            self.reference = X.mean(axis=0)

        self.is_fitted_ = True

        # Fit function must always return self
        return self

    def transform(self, X, y=None):
        check_is_fitted(self)

        # Input validation
        X = check_array(X)

        # Mean centre correction
        X = X - X.mean(axis=1)[:, np.newaxis]

        # MSC correction
        Xmsc = np.zeros_like(X)
        for i in range(X.shape[0]):
            # Run a linear regression of each spectrum against the reference
            fit = np.polyfit(self.reference, X[i, :], 1, full=True)
            # Apply correction: remove the intercept, divide by the slope
            Xmsc[i, :] = (X[i, :] - fit[0][1]) / fit[0][0]

        return Xmsc
```
Following scikit-learn conventions, the class parameters are stored in the __init__ function. The only argument in our case is the (optional) reference array. As discussed above, in practice the reference array is estimated from the average of the fitted spectra, but it can also be passed as an external array. We declare this dependency explicitly in the __init__ function.
Then comes the fit function. Since this is a transformer rather than a predictive model, there is nothing to fit in the usual sense; instead, we use the fit function to produce a mean-centred version of our array and (if required) calculate the reference array as the average of the input spectra.
Finally, the transform function is where all the action is. It can only be called after the transformer has been fitted, otherwise check_is_fitted will raise an error. The transform function also calculates a mean-centred version of the spectra, then applies the MSC correction proper.
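Before dropping the class into a pipeline, here’s a quick sanity check of its behaviour on its own. The data below are hypothetical random numbers; a real reference spectrum could also be passed via msc(reference=...):

```python
import numpy as np

# Hypothetical spectra: 10 samples x 100 wavelengths
rng = np.random.default_rng(0)
X_demo = rng.random((10, 100))

corrector = msc()
X_demo_msc = corrector.fit(X_demo).transform(X_demo)

print(corrector.reference.shape)  # (100,) -> the mean-centred average spectrum
print(X_demo_msc.shape)           # (10, 100) -> the corrected spectra
```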
MSC pipeline in action
We are now ready to look at a usage example. Here we 1) import the relevant libraries, 2) import some data, 3) define a pipeline containing MSC correction and PLS regression, and 4) estimate the optimal number of latent variables with a GridSearchCV.
Let’s begin with the imports (note that some of the imports are actually needed for the custom class written above):
```python
import pandas as pd
import numpy as np

from sklearn.pipeline import Pipeline
from sklearn.cross_decomposition import PLSRegression
from sklearn.model_selection import GridSearchCV
from sklearn.utils.validation import check_array, check_is_fitted
from sklearn.base import BaseEstimator, TransformerMixin
```
Now we load some data from our GitHub repo and define the arrays
```python
# Read data
url = 'https://raw.githubusercontent.com/nevernervous78/nirpyresearch/master/data/peach_spectra_brix.csv'
data = pd.read_csv(url)

# Define arrays
X = data.values[:,1:].astype("float32")
y = data["Brix"].values.astype("float32")
```
Finally, we define and fit a pipeline
```python
# GridSearchCV to estimate the optimal number of latent variables in cross-validation
pipe = Pipeline([('msc', msc()), ('pls', PLSRegression())])
parameters = {'pls__n_components': np.arange(1, 11, 1)}
plscv = GridSearchCV(pipe, parameters, scoring='neg_mean_squared_error')

# Fit the instance
plscv.fit(X, y)

# Print the result
print(plscv.best_estimator_)
```
GridSearchCV looks for the number of latent variables that minimises the mean squared error in cross-validation. Running the code above, you should find that n=5 is the optimal number of latent variables.
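If you want to inspect the outcome of the search directly, the standard GridSearchCV attributes are available (remember that best_score_ is the negative MSE, given the scoring choice above):

```python
print(plscv.best_params_)   # e.g. {'pls__n_components': 5}
print(-plscv.best_score_)   # cross-validated mean squared error of the best model
```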
Note that, should you wish to use this model to make inferences on a new (hypothetical) test set X_test, you can simply call plscv.predict(X_test) and the MSC correction will be applied in the background.
In addition, should you wish to export this model, the fitted MSC transformer (including the reference array) will be saved as part of the pipeline.
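As a sketch of that last point, the fitted pipeline can be persisted with joblib (the file name below is just an example; note that the msc class definition must be importable wherever the model is reloaded, since pickling stores a reference to the class, not its code):

```python
import joblib

# Save the best fitted pipeline, MSC reference array included
joblib.dump(plscv.best_estimator_, 'msc_pls_pipeline.joblib')

# Later, e.g. in an inference script (the msc class must be importable here too)
model = joblib.load('msc_pls_pipeline.joblib')
# y_pred = model.predict(X_test)
```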
That’s all we have time for today. As always, thanks for reading and until next time.
Daniel
Feature image by Carlos / Saigon – Vietnam from Pixabay.