A new minor TorchEEG release brings new trainers, models, utils, datasets, and transforms to TorchEEG. The TorchEEG team has experimented extensively with generative models and with cross-subject algorithms based on domain adaptation, both cutting-edge research directions in deep-learning-oriented EEG analysis. The related APIs are now open source! We have also incorporated more user feedback to improve TorchEEG's usability.
## Trainer
We offer a variety of trainers for generative models:
- DDPMTrainer and CDDPMTrainer
```python
# unconditional diffusion model
unet = BUNet(in_channels=4)
trainer = DDPMTrainer(unet)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)

# class-conditional diffusion model
unet = BCUNet(in_channels=4, num_classes=2)
trainer = CDDPMTrainer(unet)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)
```
- GANTrainer and CGANTrainer
```python
# vanilla GAN
g_model = BGenerator(in_channels=128)
d_model = BDiscriminator(in_channels=4)
trainer = GANTrainer(g_model, d_model)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)

# conditional GAN
g_model = BCGenerator(in_channels=128, num_classes=2)
d_model = BCDiscriminator(in_channels=4, num_classes=2)
trainer = CGANTrainer(g_model, d_model)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)
```
- VAETrainer and CVAETrainer
```python
# BEncoder and BDecoder are introduced in the Model section below;
# CVAETrainer is used analogously with BCEncoder and BCDecoder
encoder = BEncoder(in_channels=4)
decoder = BDecoder(in_channels=64, out_channels=4)
trainer = VAETrainer(encoder, decoder)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)
```
- GlowTrainer
```python
model = BGlow(in_channels=4)
trainer = GlowTrainer(model)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)
```
Domain adaptation algorithms that show good experimental performance have also been added. The `extractor` and `classifier` used below are ordinary PyTorch modules; a minimal sketch follows the examples:
- DANTrainer
```python
trainer = DANTrainer(extractor, classifier)
trainer.fit(source_loader, target_loader, val_loader)
trainer.test(test_loader)
```
- ADATrainer
```python
trainer = ADATrainer(extractor, classifier)
trainer.fit(source_loader, target_loader, val_loader)
trainer.test(test_loader)
```
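As a point of reference, here is a minimal sketch of an `extractor`/`classifier` pair. The shapes are hypothetical (a flattened 4 x 9 x 9 grid-form EEG sample), and any compatible PyTorch modules can be substituted:

```python
import torch.nn as nn

# hypothetical feature extractor: flattens a (4, 9, 9) grid-form EEG sample
# and maps it to a 64-dimensional feature vector
extractor = nn.Sequential(nn.Flatten(), nn.Linear(4 * 9 * 9, 64), nn.ReLU())

# hypothetical classification head on top of the extracted features
classifier = nn.Linear(64, 2)
```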
## Model
The model zoo is complemented with a CNN-based hybrid model:
* SSTEmotionNet
```python
# 32 temporal channels plus 4 spectral channels on a 16x16 grid
eeg = torch.randn(2, 32 + 4, 16, 16)
model = SSTEmotionNet(temporal_in_channels=32,
                      spectral_in_channels=4,
                      grid_size=(16, 16),
                      num_classes=2)
pred = model(eeg)
```
Several transformer-based models have been added:
- ViT
```python
# ViT operates on grid-form EEG of shape (batch, chunk_size, height, width)
eeg = torch.randn(2, 128, 9, 9)
model = ViT(chunk_size=128,
            grid_size=(9, 9),
            t_patch_size=32,
            num_classes=2)
pred = model(eeg)
```
- VanillaTransformer
```python
eeg = torch.randn(1, 32, 128)
model = VanillaTransformer(chunk_size=128,
                           t_patch_size=32,
                           hid_channels=32,
                           depth=3,
                           heads=4,
                           head_channels=64,
                           mlp_channels=64,
                           num_classes=2)
pred = model(eeg)
```
- ArjunViT
```python
eeg = torch.randn(1, 32, 128)
model = ArjunViT(chunk_size=128,
                 t_patch_size=32,
                 num_classes=2)
pred = model(eeg)
```
And a GNN-based model:
* LGGNet
```python
eeg = torch.rand(2, 1, 32, 128)
model = LGGNet(DEAP_GENERAL_REGION_LIST,
               num_electrodes=32,
               chunk_size=128)
pred = model(eeg)
```
Generative model baselines have been added, ready to be adapted to EEG generation:
* BUNet and BCUNet for DDPM
```python
# unconditional UNet: t is the diffusion timestep
unet = BUNet()
eeg = torch.randn(2, 4, 9, 9)
t = torch.randint(low=1, high=1000, size=(2, ))
fake_X = unet(eeg, t)

# conditional UNet: y holds the class label for each sample in the batch
unet = BCUNet(num_classes=2)
eeg = torch.randn(2, 4, 9, 9)
t = torch.randint(low=1, high=1000, size=(2, ))
y = torch.randint(low=0, high=2, size=(2, ))
fake_X = unet(eeg, t, y)
```
* BGenerator, BDiscriminator, BCGenerator, and BCDiscriminator for GAN and CGAN
```python
# GAN: sample a latent vector z and generate a fake EEG grid
g_model = BGenerator(in_channels=128)
d_model = BDiscriminator(in_channels=4)
z = torch.normal(mean=0, std=1, size=(1, 128))
fake_X = g_model(z)
disc_X = d_model(fake_X)

# CGAN: condition both the generator and the discriminator on the label y
g_model = BCGenerator(in_channels=128, num_classes=3)
d_model = BCDiscriminator(in_channels=4, num_classes=3)
z = torch.normal(mean=0, std=1, size=(1, 128))
y = torch.randint(low=0, high=3, size=(1, ))
fake_X = g_model(z, y)
disc_X = d_model(fake_X, y)
```
* BEncoder, BDecoder, BCEncoder, and BCDecoder for VAE and CVAE
```python
# VAE: encode to a Gaussian posterior, sample z with the
# reparameterization trick, then decode
encoder = BEncoder(in_channels=4)
decoder = BDecoder(in_channels=64, out_channels=4)
eeg = torch.randn(1, 4, 9, 9)
mu, logvar = encoder(eeg)
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = eps * std + mu
fake_X = decoder(z)

# CVAE: the encoder and decoder are additionally conditioned on the label y
encoder = BCEncoder(in_channels=4, num_classes=3)
decoder = BCDecoder(in_channels=64, out_channels=4, num_classes=3)
y = torch.randint(low=0, high=3, size=(1, ))
eeg = torch.randn(1, 4, 9, 9)
mu, logvar = encoder(eeg, y)
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = eps * std + mu
fake_X = decoder(z, y)
```
* BGlow for Glow
```python
model = BGlow()
# the forward pass returns the latent z, the negative log-likelihood loss,
# and class logits
mock_eeg = torch.randn(2, 4, 32, 32)
z, nll_loss, y_logits = model(mock_eeg)
loss = nll_loss.mean()
# the reverse pass samples a fake EEG grid from the flow
fake_X = model(temperature=1.0, reverse=True)
```
## Utils
Added a new method for visualizing the adjacency matrix:
* plot_adj_connectivity
```python
adj = torch.randn(62, 62)
plot_adj_connectivity(adj,
                      SEED_CHANNEL_LIST,
                      SEED_GENERAL_REGION_LIST,
                      num_connectivity=60,
                      linewidth=1.5)
```
## Datasets
Additional datasets are supported, including feature datasets for which the publisher provides pre-extracted features:
- SEEDIVDataset
```python
dataset = SEEDIVDataset()
```
- SEEDIVFeatureDataset
```python
dataset = SEEDIVFeatureDataset()
```
- SEEDFeatureDataset
```python
dataset = SEEDFeatureDataset()
```
- MPEDFeatureDataset
```python
dataset = MPEDFeatureDataset()
```
Hook functions applied to each trial have been added; a sketch of a custom hook follows the examples:
- before_trial_normalize
```python
dataset = DEAPDataset(before_trial=before_trial_normalize)
```
- after_trial_normalize
```python
dataset = DEAPDataset(after_trial=after_trial_normalize)
```
- after_trial_moving_avg
```python
dataset = DEAPDataset(after_trial=after_trial_moving_avg)
```
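Custom hooks can be supplied in the same way. Below is a minimal sketch, assuming a `before_trial` hook receives the raw trial signal as a numpy array and returns the processed array (the hook name and the z-score normalization are illustrative, not part of the library):

```python
import numpy as np

def before_trial_zscore(eeg):
    # hypothetical hook: z-score each channel within the trial
    mean = eeg.mean(axis=-1, keepdims=True)
    std = eeg.std(axis=-1, keepdims=True)
    return (eeg - mean) / (std + 1e-6)

dataset = DEAPDataset(before_trial=before_trial_zscore)
```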
## Transforms
New transformation functions have been added:
- Downsample
```python
eeg = np.random.randn(32, 128)
transformed_eeg = Downsample(num_points=32, axis=-1)(eeg=eeg)
```
Concatenate-related transforms can now be mixed with Compose, as shown after the examples:
- Concatenate
```python
eeg = np.random.randn(32, 128)
transformed_eeg = Concatenate([BandSkewness(), BandBinPower()])(eeg=eeg)
```
- MapChunk
```python
eeg = np.random.randn(64, 1000)
transformed_eeg = MapChunk(BandDifferentialEntropy(),
                           chunk_size=250,
                           overlap=0)(eeg=eeg)
```
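For instance, here is a minimal sketch of nesting these transforms inside a `Compose` pipeline (the trailing `ToTensor` step is illustrative):

```python
transform = Compose([
    MapChunk(BandDifferentialEntropy(),
             chunk_size=250,
             overlap=0),
    ToTensor()
])
eeg = np.random.randn(64, 1000)
transformed_eeg = transform(eeg=eeg)
```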
## Breaking Changes
- The transform `ConcatenateChunk` has been renamed to `MapChunk`
- The parameter `frequency` has been renamed to `sampling_rate` in `RandomFrequencyShift`
- `MSRN` has been removed from `torcheeg.models`
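A minimal migration sketch for the renames (the surrounding arguments are illustrative):

```python
# before: transform = ConcatenateChunk(BandDifferentialEntropy(), chunk_size=250, overlap=0)
transform = MapChunk(BandDifferentialEntropy(), chunk_size=250, overlap=0)

# before: transform = RandomFrequencyShift(frequency=128)
transform = RandomFrequencyShift(sampling_rate=128)
```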