Note
Click here to download the full example code
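Datasets can often contain components that require different feature extraction and processing pipelines. This scenario might occur when, for example, the dataset consists of heterogeneous data types (e.g. raster images and text captions), or when it is stored in a pandas DataFrame and different columns require different processing pipelines.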
This example demonstrates how to use sklearn.compose.ColumnTransformer
on a dataset containing different types of features. We use the 20-newsgroups dataset and compute standard bag-of-words features for the subject line and body in separate pipelines as well as ad hoc features on the body. We combine them (with weights) using a ColumnTransformer and finally train a classifier on the combined set of features.
The choice of features is not particularly helpful, but serves to illustrate the technique.
Out:
precision recall f1-score support 0 0.96 0.62 0.76 494 1 0.25 0.84 0.39 76 micro avg 0.65 0.65 0.65 570 macro avg 0.61 0.73 0.57 570 weighted avg 0.87 0.65 0.71 570
# Author: Matt Terry <[email protected]>
#
# License: BSD 3 clause
from __future__ import print_function

import numpy as np

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.datasets import fetch_20newsgroups
from sklearn.datasets.twenty_newsgroups import strip_newsgroup_footer
from sklearn.datasets.twenty_newsgroups import strip_newsgroup_quoting
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction import DictVectorizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import classification_report
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.svm import LinearSVC


class TextStats(BaseEstimator, TransformerMixin):
    """Extract ad hoc per-document statistics for DictVectorizer.

    Stateless transformer: ``fit`` learns nothing and ``transform`` maps
    each text to a small dict of numeric features.
    """

    def fit(self, x, y=None):
        """Do nothing and return ``self`` (there is nothing to learn)."""
        return self

    def transform(self, posts):
        """Compute simple statistics for every post.

        Parameters
        ----------
        posts : iterable of str
            The texts to describe.

        Returns
        -------
        list of dict
            One dict per post with keys ``'length'`` (number of characters)
            and ``'num_sentences'`` (number of ``'.'`` characters, used as
            a crude sentence count).
        """
        return [{'length': len(text),
                 'num_sentences': text.count('.')}
                for text in posts]


class SubjectBodyExtractor(BaseEstimator, TransformerMixin):
    """Extract the subject & body from a usenet post in a single pass.

    Takes a sequence of strings and produces an object-dtype array of shape
    ``(n_posts, 2)``: column 0 holds the `subject` and column 1 holds the
    `body` of each post.
    """

    def fit(self, x, y=None):
        """Do nothing and return ``self`` (there is nothing to learn)."""
        return self

    def transform(self, posts):
        """Split every post into its subject line and cleaned body.

        Parameters
        ----------
        posts : sequence of str
            Raw posts. The headers are taken to be everything before the
            first blank line (``'\\n\\n'``); the rest is the body.

        Returns
        -------
        features : ndarray of shape (len(posts), 2), dtype=object
            ``features[i, 0]`` is the text following ``'Subject:'`` on the
            first header line that starts with that prefix (``''`` if no
            such line exists); ``features[i, 1]`` is the body after it has
            been passed through ``strip_newsgroup_footer`` and
            ``strip_newsgroup_quoting``.
        """
        # construct object dtype array with two columns
        # first column = 'subject' and second column = 'body'
        features = np.empty(shape=(len(posts), 2), dtype=object)
        prefix = 'Subject:'

        for i, text in enumerate(posts):
            headers, _, bod = text.partition('\n\n')

            # Clean the body with scikit-learn's newsgroup helpers.
            bod = strip_newsgroup_footer(bod)
            bod = strip_newsgroup_quoting(bod)
            features[i, 1] = bod

            # Default to an empty subject when no 'Subject:' header exists.
            sub = ''
            for line in headers.split('\n'):
                if line.startswith(prefix):
                    sub = line[len(prefix):]
                    break
            features[i, 0] = sub

        return features


pipeline = Pipeline([
    # Extract the subject & body
    ('subjectbody', SubjectBodyExtractor()),

    # Use ColumnTransformer to combine the features from subject and body
    ('union', ColumnTransformer(
        [
            # Pulling features from the post's subject line (first column)
            ('subject', TfidfVectorizer(min_df=50), 0),

            # Pipeline for standard bag-of-words model for body (second column)
            ('body_bow', Pipeline([
                ('tfidf', TfidfVectorizer()),
                ('best', TruncatedSVD(n_components=50)),
            ]), 1),

            # Pipeline for pulling ad hoc features from post's body
            ('body_stats', Pipeline([
                ('stats', TextStats()),  # returns a list of dicts
                ('vect', DictVectorizer()),  # list of dicts -> feature matrix
            ]), 1),
        ],

        # weight components in ColumnTransformer
        transformer_weights={
            'subject': 0.8,
            'body_bow': 0.5,
            'body_stats': 1.0,
        }
    )),

    # Use a SVC classifier on the combined features
    ('svc', LinearSVC()),
])

# limit the list of categories to make running this example faster.
categories = ['alt.atheism', 'talk.religion.misc']
train = fetch_20newsgroups(random_state=1,
                           subset='train',
                           categories=categories,
                           )
test = fetch_20newsgroups(random_state=1,
                          subset='test',
                          categories=categories,
                          )

pipeline.fit(train.data, train.target)
y = pipeline.predict(test.data)

# classification_report expects (y_true, y_pred): pass the ground-truth labels
# first so precision/recall/support are reported against the true classes.
print(classification_report(test.target, y))
Total running time of the script: ( 0 minutes 1.374 seconds)
Gallery generated by Sphinx-Gallery
© 2007–2018 The scikit-learn developers
Licensed under the 3-clause BSD License.
http://scikit-learn.org/stable/auto_examples/compose/plot_column_transformer.html