class sklearn.model_selection.StratifiedShuffleSplit(n_splits=10, test_size='default', train_size=None, random_state=None)
Stratified ShuffleSplit cross-validator
Provides train/test indices to split data into train/test sets.
This cross-validation object combines StratifiedKFold and ShuffleSplit and returns stratified randomized folds. The folds are made by preserving the percentage of samples for each class.
Note: like the ShuffleSplit strategy, stratified random splits do not guarantee that all folds will be different, although this is still very likely for sizeable datasets.
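As a minimal sketch of what "preserving the percentage of samples for each class" means in practice, the snippet below splits a toy imbalanced data set and reports the share of the minority class in each part. The data, the 80/20 class ratio, and the variable names are illustrative assumptions, not part of the original documentation.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# Toy imbalanced labels (assumed for illustration): 80 samples of class 0, 20 of class 1.
y = np.array([0] * 80 + [1] * 20)
X = np.zeros((len(y), 1))  # feature values do not influence how the split is drawn

sss = StratifiedShuffleSplit(n_splits=3, test_size=0.25, random_state=0)

test_fractions = []
for train_index, test_index in sss.split(X, y):
    # Share of class 1 in each part; the full data set has a share of 0.20.
    train_frac = y[train_index].mean()
    test_frac = y[test_index].mean()
    test_fractions.append(test_frac)
    print("class-1 share, train: %.2f  test: %.2f" % (train_frac, test_frac))
```

When the class counts divide evenly, as they do here, each train and test set reproduces the overall class ratio; with less convenient counts the ratios can only be approximately preserved.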
Read more in the User Guide.
Parameters: |
|
---|
>>> import numpy as np
>>> from sklearn.model_selection import StratifiedShuffleSplit
>>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [1, 2], [3, 4]])
>>> y = np.array([0, 0, 0, 1, 1, 1])
>>> sss = StratifiedShuffleSplit(n_splits=5, test_size=0.5, random_state=0)
>>> sss.get_n_splits(X, y)
5
>>> print(sss)
StratifiedShuffleSplit(n_splits=5, random_state=0, ...)
>>> for train_index, test_index in sss.split(X, y):
...     print("TRAIN:", train_index, "TEST:", test_index)
...     X_train, X_test = X[train_index], X[test_index]
...     y_train, y_test = y[train_index], y[test_index]
TRAIN: [5 2 3] TEST: [4 1 0]
TRAIN: [5 1 4] TEST: [0 2 3]
TRAIN: [5 0 2] TEST: [4 3 1]
TRAIN: [4 1 0] TEST: [2 3 5]
TRAIN: [0 5 1] TEST: [3 4 2]
get_n_splits([X, y, groups]) | Returns the number of splitting iterations in the cross-validator. |
split(X, y[, groups]) | Generate indices to split data into training and test set. |
__init__(n_splits=10, test_size='default', train_size=None, random_state=None)
get_n_splits(X=None, y=None, groups=None)
Returns the number of splitting iterations in the cross-validator
Parameters: |
|
---|---|
Returns: |
|
split(X, y, groups=None)
Generate indices to split data into training and test set.
Parameters: |
|
---|---|
Yields: |
|
Randomized CV splitters may return different results for each call of split. You can make the results identical by setting random_state to an integer.
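The effect of an integer random_state can be checked with a short sketch like the one below. The toy arrays and the helper collect_splits are hypothetical names introduced only for this illustration.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# Toy data (assumed for illustration): 12 samples, two balanced classes.
X = np.arange(24).reshape(12, 2)
y = np.array([0] * 6 + [1] * 6)

def collect_splits(splitter):
    """Hypothetical helper: materialise every (train, test) pair as plain lists."""
    return [(train.tolist(), test.tolist()) for train, test in splitter.split(X, y)]

# With an integer seed, every call of split restarts from the same random state.
seeded = StratifiedShuffleSplit(n_splits=4, test_size=0.25, random_state=42)
first_call = collect_splits(seeded)
second_call = collect_splits(seeded)

print("identical across calls:", first_call == second_call)
```

Leaving random_state=None instead draws from NumPy's global random state, so two calls of split would usually produce different index sets.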
© 2007–2018 The scikit-learn developers
Licensed under the 3-clause BSD License.
http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedShuffleSplit.html