A quick tutorial on how to use ESPnet
Basic directory structure
Data format: dump/raw/<subset-name>/
Task definition: espnet2/tasks/enh.py
Model definition: espnet2/enh/
Model configuration: conf/tuning/xxx.yaml
Data format
***** Data directory used for training/evaluation 👆
***** SLURM config (cmd_backend: 'local' / 'slurm' /...)
***** Environment config file (export XXX=YYY)
***** Common script used for speech enhancement
***** Entry script for running the recipe
** Python scripts used for training/inference
** Used for toolkit installation
```sh
cd <espnet-root-dir>/egs2/urgent24/enh1
./run.sh --stage 1 --stop-stage 1
mkdir -p dump/raw
cp -r data/* dump/raw/
```
Model configuration
***** Data preparation scripts (used by run.sh)
***** Directory generated by local/data.sh
***** Data directory used for training/evaluation
***** Exp directory for model training/inference
****** Length stats used for minibatch construction
***** SLURM config file ('local', 'slurm', etc.)
***** Environment config file (export XXX=YYY)
***** Common script used for speech enhancement
***** Entry script for running the recipe
** Python scripts used for training/inference
** Used for toolkit installation
```sh
cd <espnet-root-dir>/egs2/urgent24/enh1
./run.sh --stage 5 --stop-stage 5 --nj 8
```
Model configuration
***** Data preparation scripts (used by run.sh)
***** Directory generated by local/data.sh
***** Data directory used for training/evaluation
***** Exp directory for model training/inference
****** Trained model directory
***** SLURM config file ('local', 'slurm', etc.)
***** Environment config file (export XXX=YYY)
***** Common script used for speech enhancement
***** Entry script for running the recipe
** Python scripts used for training/inference
** Used for toolkit installation
```sh
cd <espnet-root-dir>/egs2/urgent24/enh1
./run.sh --stage 6 --stop-stage 6 --enh_config conf/tuning/xxx.yaml --num_nodes 1 --ngpu 1
```
Model configuration
***** Data preparation scripts (used by run.sh)
***** Directory generated by local/data.sh
***** Data directory used for training/evaluation
***** Exp directory for model training/inference
****** Trained model directory
****** Training curves
****** Latest checkpoint
****** Best checkpoint (when finished)
****** Training log
***** SLURM config file ('local', 'slurm', etc.)
***** Environment config file (export XXX=YYY)
***** Common script used for speech enhancement
***** Entry script for running the recipe
** Python scripts used for training/inference
** Used for toolkit installation
Model configuration
***** Data preparation scripts (used by run.sh)
***** Directory generated by local/data.sh
***** Data directory used for training/evaluation
***** Exp directory for model training/inference
****** Trained model directory
******* Enhanced audios (spk1.scp)
***** SLURM config file ('local', 'slurm', etc.)
***** Environment config file (export XXX=YYY)
***** Common script used for speech enhancement
***** Entry script for running the recipe
** Python scripts used for training/inference
** Used for toolkit installation
```sh
cd <espnet-root-dir>/egs2/urgent24/enh1
./run.sh --stage 7 --stop-stage 7 --enh_config conf/tuning/xxx.yaml --inference_nj 8 --gpu_inference true
```
Model configuration
***** Data preparation scripts (used by run.sh)
***** Directory generated by local/data.sh
***** Data directory used for training/evaluation
***** Exp directory for model training/inference
****** Trained model directory
******* Enhanced audios (spk1.scp)
******** Enhanced audio directory
******** list of all enhanced audios
***** SLURM config file ('local', 'slurm', etc.)
***** Environment config file (export XXX=YYY)
***** Common script used for speech enhancement
***** Entry script for running the recipe
** Python scripts used for training/inference
** Used for toolkit installation
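After stage 7 finishes, the generated spk1.scp lists all enhanced audios and can be iterated over programmatically. Below is a small sketch; `exp/xxx/enhanced_test/spk1.scp` is a placeholder path, since the actual experiment and subset directory names depend on your configuration.

```python
import soundfile as sf

# Placeholder path: the actual enhanced directory name depends on the
# experiment/config and the name of the evaluation subset.
scp_path = "exp/xxx/enhanced_test/spk1.scp"

with open(scp_path) as f:
    for line in f:
        uid, wav_path = line.strip().split(maxsplit=1)
        audio, fs = sf.read(wav_path)
        print(uid, fs, audio.shape)
```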
Task definition
*** Speech enhancement model definition 👉 Model definition
**** Available encoders (STFT, Conv1d, etc.)
**** Available separators (BSRNN, TF-GridNet, etc.)
**** Available decoders (iSTFT, ConvTranspose1d, etc.)
**** Detailed layer definitions used in separators
*** Loss functions
**** Criterion functions (time / time-frequency domain)
**** Wrappers (fixed permutation, PIT, etc.)
*** Common framework of speech enhancement models
** Used for toolkit installation
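To get a feel for these building blocks, the sketch below instantiates an STFT encoder/decoder pair and passes a random waveform through them. The constructor arguments are assumptions based on espnet2/enh/encoder/stft_encoder.py and espnet2/enh/decoder/stft_decoder.py; check those files for the exact signatures.

```python
import torch

from espnet2.enh.decoder.stft_decoder import STFTDecoder
from espnet2.enh.encoder.stft_encoder import STFTEncoder

encoder = STFTEncoder(n_fft=512, hop_length=128)
decoder = STFTDecoder(n_fft=512, hop_length=128)

speech = torch.randn(1, 16000)        # (batch, num_samples): 1 second at 16 kHz
ilens = torch.tensor([16000])

spec, flens = encoder(speech, ilens)  # complex spectrum: (batch, frames, freq)
wav, wlens = decoder(spec, flens)     # waveform reconstructed by iSTFT
print(spec.shape, wav.shape)
```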
In speech enhancement recipes, we need to prepare the following seven files:
In each data file, the data format follows the Kaldi convention, i.e., a table-style format with the first column as the utterance ID (key) and the remaining columns as the value.
File name | Example content | Note |
---|---|---|
wav.scp | uid1 /path/to/noisy/uid1.flac | Essential |
spk1.scp | uid1 /path/to/clean/uid1.flac | Needed in stage 8 |
text | uid1 It is also very valuable. | Needed in stage 8+ |
utt2spk | uid1 sid1 | Essential |
spk2utt | sid1 uid1 uid7 uid16 uid99 | Essential |
utt2fs | uid1 32000 | Essential for multi-fs data |
utt2category | uid1 1ch_32000Hz | Essential for multi-fs data |
Note that utt2fs is required to run model inference in stage 7 of run.sh so that the enhanced audio can be stored in the correct sampling rate.
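The spk2utt file is simply the inverse mapping of utt2spk. If you already have utt2spk, a small sketch like the following can derive spk2utt from it (Kaldi-style recipes typically also ship an equivalent utils/utt2spk_to_spk2utt.pl):

```python
from collections import defaultdict

spk2utt = defaultdict(list)
with open("utt2spk") as f:  # path relative to the data directory
    for line in f:
        uid, sid = line.strip().split(maxsplit=1)
        spk2utt[sid].append(uid)

with open("spk2utt", "w") as f:
    for sid in sorted(spk2utt):
        f.write(f"{sid} {' '.join(sorted(spk2utt[sid]))}\n")
```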
For the officially released validation/test subset (downloaded audios), you will need to generate the above data files manually. To manually generate data files for an audio directory `audios/`, you could run the following commands (assuming all file names are unique):
```sh
mkdir -p dump/raw/<subset-name>
find audios/ -iname '*.flac' | awk -F'[/.]' '{print($(NF-1)" "$0)}' | \
    sort -u > dump/raw/<subset-name>/wav.scp
find audios/ -iname '*.flac' | awk -F'[/.]' '{print($(NF-1)" "$(NF-1))}' | \
    sort -u > dump/raw/<subset-name>/utt2spk
find audios/ -iname '*.flac' | awk -F'[/.]' '{print($(NF-1)" "$(NF-1))}' | \
    sort -u > dump/raw/<subset-name>/spk2utt
python -c '
import soundfile as sf

with open("dump/raw/<subset-name>/utt2fs", "w") as f1:
    with open("dump/raw/<subset-name>/utt2category", "w") as f2:
        with open("dump/raw/<subset-name>/wav.scp", "r") as f3:
            for line in f3:
                uid, path = line.strip().split(maxsplit=1)
                info = sf.info(path)
                f1.write(f"{uid} {info.samplerate}\n")
                f2.write(f"{uid} {info.channels}ch_{info.samplerate}Hz\n")
'
```
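Since all of these files are keyed by utterance ID, it is worth checking that they stay consistent. The following is a small sketch (the checked file set and the `<subset-name>` placeholder mirror the table above) that verifies every prepared file covers the same set of IDs as wav.scp:

```python
import os

subset_dir = "dump/raw/<subset-name>"  # replace with the actual subset directory

def read_keys(name):
    with open(os.path.join(subset_dir, name)) as f:
        return {line.split(maxsplit=1)[0] for line in f if line.strip()}

ref = read_keys("wav.scp")
for name in ["utt2spk", "utt2fs", "utt2category"]:  # add spk1.scp / text if present
    path = os.path.join(subset_dir, name)
    if not os.path.exists(path):
        print(f"{name}: missing")
        continue
    keys = read_keys(name)
    print(f"{name}: {'OK' if keys == ref else 'MISMATCH'} ({len(keys)} utterances)")
```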
The speech enhancement task espnet2.tasks.enh.EnhancementTask
is a high-level class used for model training as well as for loading pre-trained models.
It generally defines the supported

- model architectures (encoder_choices, separator_choices, decoder_choices)
- loss functions (loss_wrapper_choices, criterion_choices)
  - The wrapper is used mainly for handling the permutation problem in speech separation tasks. For single-speaker tasks, please simply use `wrapper: fixed_order`.
- preprocessing (preprocessor_choices)
  - used for preprocessing each input sample during both training and inference stages, e.g., for data augmentation / dynamic mixing, and normalization

and the arguments used for training.
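For lower-level access than the SeparateSpeech wrapper shown next, the task class can also rebuild a trained model directly from its training config and checkpoint via AbsTask.build_model_from_file. The snippet below is a sketch; `exp/xxx` stands for your actual experiment directory.

```python
from espnet2.tasks.enh import EnhancementTask

# Rebuild the trained ESPnetEnhancementModel from a config + checkpoint pair.
model, train_args = EnhancementTask.build_model_from_file(
    config_file="exp/xxx/config.yaml",
    model_file="exp/xxx/valid.loss.best.pth",
    device="cpu",
)
model.eval()
print(type(model).__name__)
```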
To use a pre-trained speech enhancement model for enhancing audios, you can run the following script:
```python
import soundfile as sf

from espnet2.bin.enh_inference import SeparateSpeech

model = SeparateSpeech(
    train_config="exp/xxx/config.yaml",
    model_file="exp/xxx/valid.loss.best.pth",
    normalize_output_wav=True,
    device="cuda",
)
audio, fs = sf.read("/path/to/noisy/utt1.flac")
enhanced = model(audio[None, :], fs=fs)[0]
```
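To store the result, you can write the output back to disk. Continuing the snippet above, and assuming each element returned by the model is a (batch, num_samples) array:

```python
# enhanced has shape (1, num_samples); take the single batch entry and keep
# the original sampling rate so the written file matches the input.
sf.write("/path/to/enhanced/utt1.flac", enhanced[0], fs)
```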
All speech enhancement models generally share the same template defined in espnet2.enh.espnet_model.ESPnetEnhancementModel
, which consists of three modules: encoder → separator → decoder.
The main enhancement function is achieved by the separator module, which can be of various architectures, e.g., BSRNN, TF-GridNet, and so on.
To add your own model, you can follow the instructions in egs2/TEMPLATE/enh1/README.md.
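As a concrete starting point, the sketch below outlines what a minimal custom separator could look like. It assumes the AbsSeparator interface (a forward returning the processed feature(s), the lengths, and an auxiliary dict, plus a num_spk property); check espnet2/enh/separator/abs_separator.py and an existing separator such as bsrnn_separator.py for the exact signature, and note that a new separator also has to be registered in separator_choices in espnet2/tasks/enh.py.

```python
from collections import OrderedDict

import torch

from espnet2.enh.separator.abs_separator import AbsSeparator


class ToyMaskSeparator(AbsSeparator):
    """Illustrative single-speaker masking separator (sketch only)."""

    def __init__(self, input_dim: int, hidden_dim: int = 128):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, input_dim),
            torch.nn.Sigmoid(),
        )

    def forward(self, input, ilens, additional=None):
        # input: (batch, frames, freq) encoder output; may be a complex spectrum.
        mask = self.net(input.abs())   # estimate a magnitude mask
        masked = [input * mask]        # one output stream per speaker
        others = OrderedDict(mask_spk1=mask)
        return masked, ilens, others

    @property
    def num_spk(self):
        return 1
```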
The configuration (YAML) file is defined in conf/tuning/. A typical config file consists of four parts:
1. Basic hyperparameters (1/2)
```yaml
max_epoch: 100             # Max training epochs
batch_type: folded         # See espnet2/samplers/build_batch_sampler.py
batch_size: 4              # Batch size (number of chunks per minibatch when using the chunk iterator)
iterator_type: chunk       # Using chunk-based iterator for data loading
chunk_length: 200          # Each training sample will be segmented into overlapped 4-sec (= 200 / 50) chunks
chunk_default_fs: 50       # Used for automatically scaling the chunk length based on the input sampling rate
num_iters_per_epoch: 8000  # Number of samples per epoch (each sample can generate >1 chunks when using the chunk iterator)
num_workers: 4             # Number of parallel workers for data loading
grad_clip: 5.0             # Gradient clipping threshold
optim: adam                # Optimizer type (See espnet2/tasks/abs_task.py)
optim_conf:                # Optimizer configuration
  lr: 1.0e-03              # See the docstring of torch.optim.Adam
  eps: 1.0e-08
  weight_decay: 1.0e-05
patience: 40               # Patience for early stopping. If the validation performance does not improve for 40 consecutive epochs, the training will be ended.
```
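As a quick sanity check of the chunk settings above, the scaling implied by the comments (chunk_length is defined relative to chunk_default_fs) can be reproduced as follows. The helper is only illustrative, not an ESPnet API:

```python
def scaled_chunk_samples(chunk_length: int, chunk_default_fs: int, fs: int) -> int:
    """Illustrative helper: chunk size in samples for a given sampling rate."""
    return int(chunk_length / chunk_default_fs * fs)

# 200 / 50 = 4-second chunks, independent of the input sampling rate:
print(scaled_chunk_samples(200, 50, 16000))  # 64000 samples
print(scaled_chunk_samples(200, 50, 48000))  # 192000 samples
```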
1. Basic hyperparameters (2/2)
```yaml
val_scheduler_criterion:   # Validation scheduler configuration. Used to collect stats for certain epoch-based optimizers that require the validation performance per epoch.
- valid                    # Name of the phase for stats recording. Can be 'train' or 'valid'.
- loss                     # Name of the stat. Can be any key name in the `stats` dict returned by ESPnetEnhancementModel.forward in espnet2/enh/espnet_model.py.
best_model_criterion:      # Configuration of recording stats for determining the best-performing checkpoint.
- - valid                  # Name 1 of the phase for stats recording. Can be 'train' or 'valid'.
  - loss                   # Name of the stat. Same as above.
  - min                    # The best model 1 is determined by the minimum value of the stat. Can be 'min' or 'max'.
- - valid                  # Name 2 of the phase for stats recording. Can be 'train' or 'valid'.
  - acc                    # Name of the stat. Same as above.
  - max                    # The best model 2 is determined by the maximum value of the stat. Can be 'min' or 'max'.
keep_nbest_models: 1       # Number of best models to keep. The best model is determined by `best_model_criterion`.
scheduler: steplr          # Scheduler type. See scheduler_classes in espnet2/tasks/abs_task.py.
scheduler_conf:            # Scheduler configuration
  step_size: 2             # See the docstring of torch.optim.lr_scheduler.StepLR
  gamma: 0.99
allow_multi_rates: true    # Whether to allow loading audios of different sampling rates (if true, special treatment is required in the preprocessor to make sure data of different sampling rates are grouped in different categories, and thus different minibatches)
```
2. Preprocessor configuration
```yaml
preprocessor: enh            # Preprocessor type. See preprocessor_choices in espnet2/tasks/enh.py.
                             # The 'enh' preprocessor allows automatic processing of different sampling rates based on the utt2fs and utt2category files.
force_single_channel: true   # Whether to force the input audio to be single-channel in the preprocessor.
channel_reordering: true     # Whether to reorder the multiple channels of the input audio in the preprocessor.
categories:                  # List of all possible categories used in dump/raw/*/utt2category files.
                             # This is used by the preprocessor to group samples according to their category.
- 1ch_8000Hz                 # For this challenge, the category format can be simply '1ch_{fs}Hz'.
- 1ch_16000Hz
- 1ch_22050Hz
- 1ch_24000Hz
- 1ch_32000Hz
- 1ch_44100Hz
- 1ch_48000Hz
num_spk: 1                   # Number of speakers in the input audio. Should be 1 for single-speaker tasks.
```
3. Model configuration (1/2)
```yaml
model_conf:                        # Configuration for espnet2/enh/espnet_model.ESPnetEnhancementModel
  normalize_variance_per_ch: true  # Whether to normalize the variance of (each channel of) the input speech
  always_forward_in_48k: false     # Whether to always upsample the input speech to 48 kHz for model processing.
                                   # The model output will be downsampled back to its original sampling rate.
                                   # When processing data of different sampling rates, this must be true for speech enhancement models that only support one sampling rate.
                                   # For sampling-frequency-independent (SFI) models such as BSRNN and TF-GridNet-v3, this can be false for more efficient processing.
  categories:                      # Must have the same value as in the preprocessor config (see last slide).
  - 1ch_8000Hz
  - 1ch_16000Hz
  - 1ch_22050Hz
  - 1ch_24000Hz
  - 1ch_32000Hz
  - 1ch_44100Hz
  - 1ch_48000Hz
```
3. Model configuration (2/2)
```yaml
encoder: stft                # Encoder type. See encoder_choices in espnet2/tasks/enh.py.
encoder_conf:                # Encoder configuration.
  n_fft: 960                 # FFT size for each window.
  hop_length: 480            # Hop size for the sliding window.
  use_builtin_complex: true  # Whether to use PyTorch's builtin complex tensor type.
  default_fs: 48000          # Used for automatically adjusting the STFT window/hop sizes based on the input sampling rate.
                             # This should be used together with frequency-domain SFI models such as BSRNN and TF-GridNet-v3.
decoder: stft                # Decoder type. See decoder_choices in espnet2/tasks/enh.py.
decoder_conf:                # Decoder configuration.
  n_fft: 960                 # Same as above. Usually should have the same config as in encoder_conf.
  hop_length: 480
  default_fs: 48000
separator: bsrnn             # Separator type. See separator_choices in espnet2/tasks/enh.py.
separator_conf:              # Separator configuration.
  num_spk: 1                 # Number of speakers to separate. Should be 1 for single-speaker tasks.
  ...                        # See espnet2/enh/separator/bsrnn_separator.py.
```
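The default_fs mechanism above can be illustrated with a small helper (illustrative only; the actual adjustment happens inside the STFT encoder/decoder): with n_fft=960 and default_fs=48000, the window always spans 20 ms, so a 16 kHz input would be processed with a proportionally smaller FFT size and hop.

```python
def scaled_stft_size(size: int, default_fs: int, fs: int) -> int:
    """Illustrative helper: scale an STFT size to the input sampling rate."""
    return int(size * fs / default_fs)

print(scaled_stft_size(960, 48000, 16000))  # 320-sample window (20 ms)
print(scaled_stft_size(480, 48000, 16000))  # 160-sample hop (10 ms)
```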
4. Training criterion configuration
```yaml
criterions:            # Training criterion configuration.
- name: mr_l1_tfd      # The first criterion. See criterion_choices in espnet2/tasks/enh.py.
  conf:                # Criterion configuration.
    window_sz: [256, 512, 768, 1024]  # See MultiResL1SpecLoss in espnet2/enh/loss/criterions/time_domain.py.
    hop_sz: null
    eps: 1.0e-8
    time_domain_weight: 0.5
    normalize_variance: true
  wrapper: fixed_order # Wrapper type for handling the permutation problem. See loss_wrapper_choices in espnet2/tasks/enh.py.
  wrapper_conf:        # Wrapper configuration.
    weight: 1.0        # Loss weight for the first criterion. Can be used for balancing multiple criteria.
- name: si_snr         # The second criterion.
  conf:                # Criterion configuration.
    ...                # See SISNRLoss in espnet2/enh/loss/criterions/time_domain.py.
```
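To see how a criterion/wrapper pair from this list behaves in isolation, the sketch below wraps an SI-SNR criterion with the fixed-order solver. The constructor and call signatures are assumptions based on espnet2/enh/loss/wrappers/fixed_order.py and espnet2/enh/loss/criterions/time_domain.py, so double-check them there.

```python
import torch

from espnet2.enh.loss.criterions.time_domain import SISNRLoss
from espnet2.enh.loss.wrappers.fixed_order import FixedOrderSolver

# Wrap one criterion with the fixed-order solver (no permutation search),
# mirroring `wrapper: fixed_order` in the YAML above.
wrapper = FixedOrderSolver(criterion=SISNRLoss(), weight=1.0)

ref = [torch.randn(2, 16000)]  # list over speakers, each (batch, num_samples)
inf = [torch.randn(2, 16000)]
loss, stats, others = wrapper(ref, inf)
print(loss, stats)
```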
Training:

- run.sh → enh.sh (stage 6) → espnet2/bin/enh_train.py
  - → espnet2/tasks/enh.py EnhancementTask & espnet2/tasks/abs_task.py AbsTask.main
  - → espnet2/train/trainer.py Trainer.run
    - ESPnetDataset
    - EnhPreprocessor (If used)
    - ChunkIterFactory.build_iter (If used)
  - → espnet2/enh/espnet_model.py ESPnetEnhancementModel.forward
    - STFTEncoder.forward (If used)
    - BSRNNSeparator.forward (If used)
    - STFTDecoder.forward (If used)
    - FixedOrderSolver.forward (If used)
    - SISNRLoss.forward (If used)
Inference:

- run.sh → enh.sh (stage 7) → espnet2/bin/enh_inference.py SeparateSpeech
  - → EnhancementTask & AbsTask.build_model_from_file
  - → EnhancementTask.build_streaming_iterator
    - IterableESPnetDataset
    - EnhPreprocessor (If used)
  - → espnet2/enh/espnet_model.py ESPnetEnhancementModel.forward
    - STFTEncoder.forward (If used)
    - BSRNNSeparator.forward (If used)
    - STFTDecoder.forward (If used)