Multivariate statistics with MNE


MNE: Demo of multivariate statistics (decoding / MVPA)


Author: Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>


In [1]: # show plots inline in the page
%matplotlib inline
import matplotlib.pyplot as plt  # plt is used for the plotting calls below


First, load the mne package:


In [2]: import mne


We set the log-level to 'WARNING' so the output is less verbose


In [3]: mne.set_log_level('WARNING')



Access raw data


Now we locate the dataset. If you don't already have it, it will be downloaded automatically (be patient, it is approx. 2 GB).


In [4]: data_path = '/Users/alex/Sync/karolinska_teaching/'

raw_fname = data_path + '/meg/workshop_visual_sss.fif'


print raw_fname

/Users/alex/Sync/karolinska_teaching//meg/workshop_visual_sss.fif


Read data from file:


In [5]: raw = mne.fiff.Raw(raw_fname)

print raw

<Raw  |  n_channels x n_times : 322 x 988000>


Note: we now need to fix a few small acquisition problems (the BIO channels were not assigned their correct channel types).


In [6]: def fix_info(raw):

    raw.info['chs'][raw.ch_names.index('BIO001')]['kind'] = mne.fiff.constants.FIFF.FIFFV_EOG_CH

    raw.info['chs'][raw.ch_names.index('BIO002')]['kind'] = mne.fiff.constants.FIFF.FIFFV_EOG_CH

    raw.info['chs'][raw.ch_names.index('BIO003')]['kind'] = mne.fiff.constants.FIFF.FIFFV_ECG_CH


fix_info(raw)


In [7]: print raw.info['bads']

[]



Define and read epochs


First extract events:

In [8]: events = mne.find_events(raw, stim_channel='STI101', verbose=True)

Reading 0 ... 987999  =      0.000 ...   987.999 secs...

[done]

1473 events found

Events id: [10 11 12 13 14 16 17 20 21 22 23 24 26 27]


Look at the design in a graphical way:


In [9]: plt.plot((events[:, 0] - raw.first_samp) / raw.info['sfreq'], events[:, 2], '.');

plt.xlabel('time (s)')

plt.axis([0.1168, 41.4003, 9.7295, 17.3887])


Out[9]: [0.1168, 41.4003, 9.7295, 17.3887]

[Figure: event trigger codes plotted against time]

From raw to epochs


Define epochs parameters:


In [10]: event_id = dict(left=14, right=24)  # event trigger and conditions

tmin = -0.3  # start of each epoch

tmax = 0.5  # end of each epoch

baseline = (-0.3, -0.15)


reject = dict(grad=4000e-13, eog=150e-6)


picks = mne.fiff.pick_types(raw.info, meg='grad', eeg=False, eog=True, stim=False, exclude='bads')


epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True, picks=picks, baseline=baseline, reject=reject, preload=True)  # with preload

print epochs

<Epochs  |  n_events : 111 (all good), tmin : -0.3 (s), tmax : 0.5 (s), baseline : (-0.3, -0.15),'left': 62, 'right': 49>


In [11]: epochs.plot_drop_log()

[Figure: epoch drop log]

Out[11]: 92.464358452138498


Look at the ERFs and the contrast between the left and right responses


In [12]: evoked_left = epochs['left'].average()

evoked_right = epochs['right'].average()

evoked_contrast = evoked_left - evoked_right


In [13]: fig = evoked_left.plot()

fig = evoked_right.plot()

fig = evoked_contrast.plot()

[Figures: evoked responses for the left and right conditions and their contrast]

Plot some topographies


In [14]: import numpy as np

times = np.linspace(-0.1, 0.5, 10)

fig = evoked_left.plot_topomap(times=times, ch_type='grad')

fig = evoked_right.plot_topomap(times=times, ch_type='grad')

fig = evoked_contrast.plot_topomap(times=times, ch_type='grad')

[Figures: topographies for the left and right conditions and their contrast]

Now let's see if we can classify single trials with an SVM


To keep the chance level at 50% accuracy, equalize the epoch count in each condition:


In [15]: epochs_list = [epochs[k] for k in event_id]

mne.epochs.equalize_epoch_counts(epochs_list)


A classifier takes as input a feature vector x and returns a class label y (e.g. -1 or 1). Here x will be the data at one time point on all gradiometers (hence the term multivariate). We work with all sensors jointly and try to find a discriminative pattern between the two conditions in order to predict the class of each trial.


In [16]: n_times = len(epochs.times)


# Take only the data channels (here the gradiometers)

data_picks = mne.fiff.pick_types(epochs.info, meg=True, exclude='bads')


# Make arrays X and y such that :

# X is 3d with X.shape[0] is the total number of epochs to classify

# y is filled with integers coding for the class to predict

# We must have X.shape[0] equal to y.shape[0]


X = [e.get_data()[:, data_picks, :] for e in epochs_list]

y = [k * np.ones(len(this_X)) for k, this_X in enumerate(X)]

X = np.concatenate(X)

y = np.concatenate(y)


In [17]: print X.shape, y.shape

print y

(98, 204, 801) (98,)

[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.

  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.

  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  1.  1.  1.  1.  1.

  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.

  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.

  1.  1.  1.  1.  1.  1.  1.  1.]


For classification we will use the scikit-learn package (http://scikit-learn.org/)


In [18]: from sklearn.svm import SVC

from sklearn.cross_validation import cross_val_score, ShuffleSplit


# Define an SVM classifier (SVC) with a linear kernel

clf = SVC(C=1, kernel='linear')


Define a Monte-Carlo (shuffle-split) cross-validation generator to reduce the variance of the score estimate:


In [19]: cv = ShuffleSplit(len(X), 10, test_size=0.2, random_state=42)


The goal is to learn on 80% of the epochs and to evaluate on the remaining 20% of trials whether we can predict accurately.


In [20]: X_2d = X.reshape(len(X), -1)

X_2d = X_2d / np.std(X_2d)

scores_full = cross_val_score(clf, X_2d, y, cv=cv, n_jobs=1)

print "Classification score: %s (std. %s)"

% (np.mean(scores_full), np.std(scores_full))

Classification score: 0.92 (std. 0.0458257569496)


It's also possible to run the same decoder at each time point, to learn when in time the two conditions can best be discriminated:


In [21]: scores = np.empty(n_times)

std_scores = np.empty(n_times)


for t in range(n_times):

    Xt = X[:, :, t].copy()  # copy so that the standardization below does not modify X in place

    # Standardize features

    Xt -= Xt.mean(axis=0)

    Xt /= Xt.std(axis=0)

    # Run cross-validation

    scores_t = cross_val_score(clf, Xt, y, cv=cv, n_jobs=1)

    scores[t] = scores_t.mean()

    std_scores[t] = scores_t.std()


A bit of rescaling


In [22]: times = 1e3 * epochs.times # to have times in ms

scores *= 100  # make it percentage accuracy

std_scores *= 100


Now a bit of plotting


In [23]: plt.plot(times, scores, label="Classif. score")

plt.axhline(50, color='k', linestyle='--', label="Chance level")

plt.axvline(0, color='r', label='stim onset')

plt.axhline(100 * np.mean(scores_full), color='g', label='Accuracy full epoch')

plt.legend()

hyp_limits = (scores - std_scores, scores + std_scores)

plt.fill_between(times, hyp_limits[0], y2=hyp_limits[1], color='b', alpha=0.5)

plt.xlabel('Times (ms)')

plt.ylabel('CV classification score (% correct)')

plt.ylim([30, 100])

plt.title('Sensor space decoding')


Out[23]: <matplotlib.text.Text at 0x1114c3810>

[Figure: sensor-space decoding score over time]

Look at generalization over time


We can test how well the "decodability" generalizes over time, i.e. whether a classifier trained at one time point can also discriminate the conditions at other time points.

Have a look at http://martinos.org/mne/dev/auto_examples/decoding/plot_decoding_time_generalization.html to get an idea of what to expect.
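
To make explicit what "time generalization" means, here is a minimal, illustrative sketch (my own illustration, not the MNE implementation). It reuses the X, y, clf and cv objects defined above, takes a single train/test split, scores plain accuracy instead of ROC AUC, and uses a naive (slow) double loop; the mne.decoding.time_generalization call below does the real work, with proper cross-validation and in parallel.

# Illustrative sketch only: train at one time point, test at every other one
train_idx, test_idx = next(iter(cv))  # a single split from the generator above

X_scaled = X / np.std(X)  # rescale as before to avoid numerical issues
X_train, X_test = X_scaled[train_idx], X_scaled[test_idx]
y_train, y_test = y[train_idx], y[test_idx]

gen_scores = np.zeros((n_times, n_times))
for t_train in range(n_times):
    clf.fit(X_train[:, :, t_train], y_train)   # train at time t_train ...
    for t_test in range(n_times):
        # ... and evaluate at time t_test (accuracy here, not ROC AUC)
        gen_scores[t_train, t_test] = clf.score(X_test[:, :, t_test], y_test)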


In [24]: from mne.decoding import time_generalization


# Compute the Area Under the Curve (AUC) of the Receiver Operating
# Characteristic (ROC) for the time generalization. A perfect decoding
# would lead to AUCs of 1; chance level is at 0.5.
# The default classifier is a linear SVM (C=1) after feature scaling.

scores = time_generalization(epochs_list, clf=None, cv=5, scoring="roc_auc", shuffle=True, n_jobs=4)

/Users/alex/local/lib/python/site-packages/sklearn/cross_validation.py:1212: DeprecationWarning: check_cv will return indices instead of boolean masks from 0.17

  'masks from 0.17', DeprecationWarning)

/Users/alex/local/lib/python/site-packages/sklearn/cross_validation.py:62: DeprecationWarning: The indices parameter is deprecated and will be removed (assumed True) in 0.17

  stacklevel=1)


Now visualize


In [25]: times = 1e3 * epochs.times # convert times to ms


plt.imshow(scores, interpolation='nearest', origin='lower',

           extent=[times[0], times[-1], times[0], times[-1]],

           vmin=0., vmax=1.)

plt.xlabel('Times Test (ms)')

plt.ylabel('Times Train (ms)')

plt.title('Time generalization (%s vs. %s)' % tuple(event_id.keys()))

plt.axvline(0, color='k')

plt.axhline(0, color='k')

plt.colorbar()


Out[25]: <matplotlib.colorbar.Colorbar instance at 0x113b335a8>

[Figure: time generalization matrix (training time vs. testing time)]

In [26]: plt.plot(times, np.diag(scores), label="Classif. score")

plt.axhline(0.5, color='k', linestyle='--', label="Chance level")

plt.axvline(0, color='r', label='stim onset')

plt.legend()

plt.xlabel('Time (ms)')

plt.ylabel('ROC classification score')

plt.title('Decoding (%s vs. %s)' % tuple(event_id.keys()))


Out[26]: <matplotlib.text.Text at 0x113b39890>

[Figure: decoding score over time (diagonal of the generalization matrix)]

Exercise


Can you improve the performance using full epochs and Common Spatial Patterns (CSP), as used by most BCI systems?

Have a look at the example:

http://martinos.org/mne/dev/auto_examples/decoding/plot_decoding_csp_space.html
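
As a starting point, here is a minimal sketch of one possible approach, not a definitive solution. It reuses the full-epoch array X and labels y built above, it assumes the CSP class from mne.decoding shown in the linked example is available in your MNE version, and the number of CSP components and the linear SVM are arbitrary choices.

from mne.decoding import CSP
from sklearn.svm import SVC
from sklearn.cross_validation import ShuffleSplit

csp = CSP(n_components=4)        # number of spatial filters kept (arbitrary)
svc = SVC(C=1, kernel='linear')
cv_csp = ShuffleSplit(len(y), 10, test_size=0.2, random_state=42)

csp_scores = []
for train_idx, test_idx in cv_csp:
    # Fit the spatial filters on the training epochs only, then project both
    # sets onto the CSP components (log-variance features per epoch).
    csp.fit(X[train_idx], y[train_idx])
    svc.fit(csp.transform(X[train_idx]), y[train_idx])
    csp_scores.append(svc.score(csp.transform(X[test_idx]), y[test_idx]))

print "CSP + linear SVM accuracy: %s" % np.mean(csp_scores)

Depending on the scale of your data, you may also want to standardize the CSP features before feeding them to the SVM.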


In []: