Author : Alexandre Gramfort
# add plot inline in the page
%matplotlib inline
First, load the mne package:
import matplotlib.pyplot as plt
import mne
We set the log level to 'WARNING' so that the output is less verbose:
# From here on MNE only reports warnings and errors.
_LOG_LEVEL = 'WARNING'
mne.set_log_level(_LOG_LEVEL)
Now we locate the data. This tutorial reads a local recording (`workshop_visual_sss.fif`) that is not downloaded automatically: set `data_path` below to the folder where you stored the workshop data.
import os

# NOTE: machine-specific location of the workshop data -- edit this to point
# to your own copy.
data_path = '/Users/alex/Sync/karolinska_teaching/'
# os.path.join inserts exactly one separator between the parts, instead of the
# doubled '//' that string concatenation produced with a trailing-slash dir.
raw_fname = os.path.join(data_path, 'meg', 'workshop_visual_sss.fif')
print(raw_fname)
Read data from file:
# Open the FIF recording and display its summary.
raw = mne.fiff.Raw(raw_fname)
print(raw)
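Note: we first need to fix a tiny acquisition problem. The BIO channels were saved with the wrong channel types: BIO001 and BIO002 are EOG channels and BIO003 is an ECG channel.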
def fix_info(raw, channel_kinds=None):
    """Set the channel kind of selected channels in ``raw.info``, in place.

    By default this corrects the BIO channels of this recording: BIO001 and
    BIO002 become EOG channels and BIO003 becomes an ECG channel.

    Parameters
    ----------
    raw : Raw
        Recording whose ``raw.info['chs']`` entries are modified in place.
    channel_kinds : dict | None
        Mapping from channel name to the FIFF channel-kind constant to
        assign. ``None`` (default) uses BIO001/BIO002 -> EOG and
        BIO003 -> ECG.

    Raises
    ------
    ValueError
        Raised by ``raw.ch_names.index`` when a requested channel name is not
        present. Nothing is modified in that case.
    """
    if channel_kinds is None:
        FIFF = mne.fiff.constants.FIFF
        channel_kinds = {'BIO001': FIFF.FIFFV_EOG_CH,
                         'BIO002': FIFF.FIFFV_EOG_CH,
                         'BIO003': FIFF.FIFFV_ECG_CH}
    # Resolve every channel index before touching info, so a missing name
    # raises without leaving the channel list half-updated.
    updates = [(raw.ch_names.index(name), kind)
               for name, kind in channel_kinds.items()]
    for idx, kind in updates:
        raw.info['chs'][idx]['kind'] = kind
fix_info(raw)  # retype the BIO channels (EOG / ECG) in place
bads = raw.info['bads']
print(bads)
First extract events:
# Read the trigger events from stimulus channel STI101. verbose=True is
# presumably there so the event summary is printed despite the global
# 'WARNING' log level.
events = mne.find_events(raw,
                         stim_channel='STI101',
                         verbose=True)
Look at the experimental design graphically (event time vs. trigger value):
# Scatter each event's time (in s, relative to the first recorded sample)
# against its trigger value to see how the conditions are spread over the run.
plt.figure()  # own figure, so nothing is drawn onto a previously open one
event_times = (events[:, 0] - raw.first_samp) / raw.info['sfreq']
plt.plot(event_times, events[:, 2], '.')
plt.xlabel('time (s)')
plt.ylabel('event id')
# No fixed axis limits: every trigger value (both 14 and 24 are used below)
# and the whole recording stay visible.
Define epochs parameters:
# --- Epoching parameters ---
event_id = dict(left=14, right=24)  # condition name -> trigger value
tmin, tmax = -0.3, 0.5  # start / end of each epoch relative to its event
baseline = (-0.3, -0.15)  # interval used for baseline correction
reject = dict(grad=4000e-13, eog=150e-6)  # rejection threshold per channel type

# Gradiometers plus EOG (presumably kept so the eog rejection threshold has
# channels to act on); channels marked as bad are excluded.
picks = mne.fiff.pick_types(raw.info, meg='grad', eeg=False, eog=True,
                            stim=False, exclude='bads')

# preload=True: epoch data are read right away rather than on demand.
epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
                    baseline=baseline, picks=picks, proj=True,
                    reject=reject, preload=True)
print(epochs)
epochs.plot_drop_log()
Look at the ERFs (event-related fields) of the left and right conditions, and at the contrast between them:
# Average each condition into an evoked response, then take their difference.
evoked_left, evoked_right = [epochs[cond].average() for cond in ('left', 'right')]
evoked_contrast = evoked_left - evoked_right

for evoked in (evoked_left, evoked_right, evoked_contrast):
    fig = evoked.plot()
Plot some topographies
import numpy as np

# Ten evenly spaced latencies from -100 ms to 500 ms, one topomap per latency.
times = np.linspace(-0.1, 0.5, num=10)
for evoked in (evoked_left, evoked_right, evoked_contrast):
    fig = evoked.plot_topomap(times=times, ch_type='grad')
Equalize the number of epochs in each condition, so that chance level is exactly 50% accuracy:
# One Epochs object per condition (in the iteration order of event_id), then
# drop epochs until every condition has the same count.
epochs_list = []
for condition in event_id:
    epochs_list.append(epochs[condition])
mne.epochs.equalize_epoch_counts(epochs_list)
A classifier takes as input a feature vector $x$ and returns a class label $y$ (here $y \in \{0, 1\}$, one value per condition). Here $x$ will be the data at one time point on all gradiometers (hence the term multivariate). We work with all sensors jointly and try to find a discriminative pattern between the 2 conditions to predict the class.
n_times = len(epochs.times)
# Data channels only (the gradiometers: epochs were picked with meg='grad').
data_picks = mne.fiff.pick_types(epochs.info, meg=True, exclude='bads')

# Design matrix and labels:
# - X stacks the epochs of all conditions along axis 0,
#   giving (n_epochs, n_channels, n_times);
# - y holds the class index (0, 1, ...) of each epoch in the same order,
#   so X.shape[0] == y.shape[0].
X_per_class = [cond_epochs.get_data()[:, data_picks, :]
               for cond_epochs in epochs_list]
y_per_class = [np.ones(len(x_c)) * label
               for label, x_c in enumerate(X_per_class)]
X = np.concatenate(X_per_class)
y = np.concatenate(y_per_class)
print('%s %s' % (X.shape, y.shape))
print(y)
For classification we will use the scikit-learn package (http://scikit-learn.org/)
from sklearn.cross_validation import ShuffleSplit, cross_val_score
from sklearn.svm import SVC

# Support vector classifier: linear kernel, regularization parameter C=1.
clf = SVC(kernel='linear', C=1)
Define a Monte Carlo cross-validation generator (repeated random train/test splits, to reduce the variance of the score estimate):
# 10 random splits over all epochs, each holding out 20% for testing; the
# fixed seed makes the splits reproducible.
n_epochs_total = len(X)
cv = ShuffleSplit(n_epochs_total, 10, test_size=0.2, random_state=42)
The goal is to train on 80% of the epochs and measure, on the remaining 20%, how accurately we can predict the condition.
# Full-epoch decoding: flatten each epoch (channels x times) into one feature
# vector, then divide by the single global standard deviation of the data.
X_flat = X.reshape(len(X), -1)
X_2d = X_flat / np.std(X_flat)
scores_full = cross_val_score(clf, X_2d, y, cv=cv, n_jobs=1)
mean_score, std_score = np.mean(scores_full), np.std(scores_full)
print("Classification score: %s (std. %s)" % (mean_score, std_score))
It's also possible to run the same decoder at each time point, to find out when in time the conditions can best be classified:
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Per-time-point decoding. Standardization is part of the estimator, so the
# scaler's mean/std are estimated on each training fold only (no test-fold
# leakage). It also works on a copy: X[:, :, t] is only a view of X, and
# z-scoring it in place would overwrite X itself.
time_clf = Pipeline([('scaler', StandardScaler()), ('svc', clf)])

scores = np.empty(n_times)
std_scores = np.empty(n_times)
for t in range(n_times):
    # Features: all data channels at time sample t.
    scores_t = cross_val_score(time_clf, X[:, :, t], y, cv=cv, n_jobs=1)
    scores[t] = scores_t.mean()
    std_scores[t] = scores_t.std()
A bit of rescaling
# Express time in milliseconds and scores as percentage accuracy.
times = epochs.times * 1e3
for arr in (scores, std_scores):
    arr *= 100
Now a bit of plotting
# Decoding accuracy over time, with a +/- 1 std band across CV splits.
# Start a fresh figure: otherwise, in a script, the curve is drawn onto the
# last figure that happens to be open (e.g. a topomap figure).
plt.figure()
plt.plot(times, scores, label="Classif. score")
plt.axhline(50, color='k', linestyle='--', label="Chance level")
plt.axvline(0, color='r', label='stim onset')
plt.axhline(100 * np.mean(scores_full), color='g', label='Accuracy full epoch')
plt.legend()
hyp_limits = (scores - std_scores, scores + std_scores)
plt.fill_between(times, hyp_limits[0], y2=hyp_limits[1], color='b', alpha=0.5)
plt.xlabel('Times (ms)')
plt.ylabel('CV classification score (% correct)')
plt.ylim([30, 100])
plt.title('Sensor space decoding')
We can also test how well the "decodability" generalizes over time: train the classifier at one time point and test it at every other time point.
Have a look at : http://martinos.org/mne/dev/auto_examples/decoding/plot_decoding_time_generalization.html
to get an idea of what to expect.
from mne.decoding import time_generalization

# Time generalization: a classifier trained at each time point is tested at
# every time point, giving a (train time x test time) matrix of ROC AUC
# (area under the receiver operating characteristic curve) scores.
# A perfect decoding gives an AUC of 1; chance level is 0.5.
# clf=None selects the default classifier: a linear SVM (C=1) applied after
# feature scaling.
tg_params = dict(clf=None, cv=5, scoring="roc_auc", shuffle=True, n_jobs=4)
scores = time_generalization(epochs_list, **tg_params)
Now visualize
times = 1e3 * epochs.times  # convert times to ms
contrast_name = '%s vs. %s' % tuple(event_id.keys())

# Full train-time x test-time matrix of AUC scores, in its own figure.
plt.figure()
plt.imshow(scores, interpolation='nearest', origin='lower',
           extent=[times[0], times[-1], times[0], times[-1]],
           vmin=0., vmax=1.)
plt.xlabel('Times Test (ms)')
plt.ylabel('Times Train (ms)')
plt.title('Time generalization (%s)' % contrast_name)
plt.axvline(0, color='k')
plt.axhline(0, color='k')
plt.colorbar()

# Diagonal of the matrix (train and test at the same time point). A new
# figure is required: otherwise the curve is drawn on top of the image.
plt.figure()
plt.plot(times, np.diag(scores), label="Classif. score")
plt.axhline(0.5, color='k', linestyle='--', label="Chance level")
plt.axvline(0, color='r', label='stim onset')
plt.legend()
plt.xlabel('Time (ms)')
plt.ylabel('ROC classification score')
plt.title('Decoding (%s)' % contrast_name)
Have a look at the example
http://martinos.org/mne/dev/auto_examples/decoding/plot_decoding_csp_space.html