class: center, middle, theYellowBackground
background: blue

.huge.bold[A quick tutorial on how to use ESPnet]

---

# Table of Contents

1. Basic directory structure
   * [Recipe](#espnet_recipe)
   * [Framework](#espnet_framework)
2. [Data format](#data_format): `dump/raw/<subset>/`
3. [Task definition](#task_definition): `espnet2/tasks/enh.py`
4. [Model definition](#model_definition): `espnet2/enh/`
5. [Model configuration](#model_configuration): `conf/tuning/xxx.yaml`
6. [Script flow of training and inference](#script_flow)

---
name: espnet_recipe

# Basic directory structure - Recipe

.left-column-38[
📁 [espnet-root-dir]
├─ 📁 egs2
│  └─ 📁 urgent24
│     └─ 📁 enh1
│        ├─ 📁 conf
│        ├─ 📁 local
│        ├─ .red[📁 data]
│        ├─ .red[📁 dump]
│        ├─ 📄 cmd.sh
│        ├─ 📄 path.sh
│        ├─ 📄 enh.sh
│        └─ 📄 run.sh
├─ 📁 espnet2
└─ 📁 tools
]
.right-column-61[
* Root directory
** Directory containing all recipes
*** Recipe for a specific dataset
**** Speech enhancement recipe
***** Model configurations (YAML files)
***** Data preparation scripts (used by run.sh)
***** .red[Directory generated by local/data.sh] 👉 Data format
***** .red[Data directory used for training/evaluation] 👆
***** SLURM config (cmd_backend: 'local' / 'slurm' / ...)
***** Environment config file (export XXX=YYY)
***** Common script used for speech enhancement
***** .red[Entry script for running the recipe]
** Python scripts used for training/inference
** Used for toolkit installation
]
.full-column[
```bash
cd <espnet-root-dir>/egs2/urgent24/enh1
./run.sh --stage 1 --stop-stage 1
mkdir -p dump/raw
cp -r data/* dump/raw/
```
]

---

# Basic directory structure - Recipe (Cont'd)

.left-column-48[
📁 [espnet-root-dir]
├─ 📁 egs2
│  └─ 📁 urgent24
│     └─ 📁 enh1
│        ├─ 📁 conf
│        ├─ 📁 local
│        ├─ 📁 data
│        ├─ 📁 dump
│        ├─ .red[📁 exp]
│        │  └─ .red[📁 enh_stats_16k]
│        ├─ 📄 cmd.sh
│        ├─ 📄 path.sh
│        ├─ 📄 enh.sh
│        └─ 📄 run.sh
├─ 📁 espnet2
└─ 📁 tools
]
.right-column-51[
* Root directory
** Directory containing all recipes
*** Recipe for a specific dataset
**** Speech enhancement recipe
***** .red[Model configurations] 👉 Model configuration
***** Data preparation scripts (used by run.sh)
***** Directory generated by local/data.sh
***** Data directory used for training/evaluation
***** .red[Exp directory for model training/inference]
****** .red[Length stats used for minibatch construction]
***** SLURM config file ('local', 'slurm', etc.)
***** Environment config file (export XXX=YYY)
***** Common script used for speech enhancement
***** .red[Entry script for running the recipe]
** Python scripts used for training/inference
** Used for toolkit installation
]
.full-column[
```bash
cd <espnet-root-dir>/egs2/urgent24/enh1
./run.sh --stage 5 --stop-stage 5 --nj 8
```
]

---

# Basic directory structure - Recipe (Cont'd)

.left-column-48[
📁 [espnet-root-dir]
├─ 📁 egs2
│  └─ 📁 urgent24
│     └─ 📁 enh1
│        ├─ 📁 conf
│        ├─ 📁 local
│        ├─ 📁 data
│        ├─ 📁 dump
│        ├─ .red[📁 exp]
│        │  └─ .red[📁 enh_train_enh_xxx_raw]
│        ├─ 📄 cmd.sh
│        ├─ 📄 path.sh
│        ├─ 📄 enh.sh
│        └─ 📄 run.sh
├─ 📁 espnet2
└─ 📁 tools
]
.right-column-51[
* Root directory
** Directory containing all recipes
*** Recipe for a specific dataset
**** Speech enhancement recipe
***** .red[Model configurations] 👉 Model configuration
***** Data preparation scripts (used by run.sh)
***** Directory generated by local/data.sh
***** Data directory used for training/evaluation
***** .red[Exp directory for model training/inference]
****** Trained model directory
***** SLURM config file ('local', 'slurm', etc.)
***** Environment config file (export XXX=YYY)
***** Common script used for speech enhancement
***** .red[Entry script for running the recipe]
** Python scripts used for training/inference
** Used for toolkit installation
]
.full-column[
```bash
cd <espnet-root-dir>/egs2/urgent24/enh1
./run.sh --stage 6 --stop-stage 6 --enh_config conf/tuning/xxx.yaml --num_nodes 1 --ngpu 1
```
]

---

# Basic directory structure - Recipe (Cont'd)

.left-column-48[
📁 [espnet-root-dir]
├─ 📁 egs2
│  └─ 📁 urgent24
│     └─ 📁 enh1
│        ├─ 📁 conf
│        ├─ 📁 local
│        ├─ 📁 data
│        ├─ 📁 dump
│        ├─ .red[📁 exp]
│        │  └─ .red[📁 enh_train_enh_xxx_raw]
│        │     ├─ 📁 images
│        │     ├─ 📄 1epoch.pth
│        │     ├─ 📄 valid.loss.best.pth
│        │     └─ 📄 train.log
│        ├─ 📄 cmd.sh
│        ├─ 📄 path.sh
│        ├─ 📄 enh.sh
│        └─ 📄 run.sh
├─ 📁 espnet2
└─ 📁 tools
]
.right-column-51[
* Root directory
** Directory containing all recipes
*** Recipe for a specific dataset
**** Speech enhancement recipe
***** .red[Model configurations] 👉 Model configuration
***** Data preparation scripts (used by run.sh)
***** Directory generated by local/data.sh
***** Data directory used for training/evaluation
***** .red[Exp directory for model training/inference]
****** Trained model directory
******* Training curves
******* Latest checkpoint
******* Best checkpoint (when finished)
******* Training log
***** SLURM config file ('local', 'slurm', etc.)
***** Environment config file (export XXX=YYY)
***** Common script used for speech enhancement
***** .red[Entry script for running the recipe]
** Python scripts used for training/inference
** Used for toolkit installation
]

---

# Basic directory structure - Recipe (Cont'd)

.left-column-48[
📁 [espnet-root-dir]
├─ 📁 egs2
│  └─ 📁 urgent24
│     └─ 📁 enh1
│        ├─ 📁 conf
│        ├─ 📁 local
│        ├─ 📁 data
│        ├─ 📁 dump
│        ├─ .red[📁 exp]
│        │  └─ .red[📁 enh_train_enh_xxx_raw]
│        │     └─ .red[📁 enhanced_validation]
│        ├─ 📄 cmd.sh
│        ├─ 📄 path.sh
│        ├─ 📄 enh.sh
│        └─ 📄 run.sh
├─ 📁 espnet2
└─ 📁 tools
]
.right-column-51[
* Root directory
** Directory containing all recipes
*** Recipe for a specific dataset
**** Speech enhancement recipe
***** .red[Model configurations] 👉 Model configuration
***** Data preparation scripts (used by run.sh)
***** Directory generated by local/data.sh
***** Data directory used for training/evaluation
***** .red[Exp directory for model training/inference]
****** Trained model directory
******* Enhanced audios (spk1.scp)
***** SLURM config file ('local', 'slurm', etc.)
***** Environment config file (export XXX=YYY)
***** Common script used for speech enhancement
***** .red[Entry script for running the recipe]
** Python scripts used for training/inference
** Used for toolkit installation
]
.full-column[
```bash
cd <espnet-root-dir>/egs2/urgent24/enh1
./run.sh --stage 7 --stop-stage 7 --enh_config conf/tuning/xxx.yaml --inference_nj 8 --gpu_inference true
```
]

---

# Basic directory structure - Recipe (Cont'd)

.left-column-48[
📁 [espnet-root-dir]
├─ 📁 egs2
│  └─ 📁 urgent24
│     └─ 📁 enh1
│        ├─ 📁 conf
│        ├─ 📁 local
│        ├─ 📁 data
│        ├─ 📁 dump
│        ├─ .red[📁 exp]
│        │  └─ .red[📁 enh_train_enh_xxx_raw]
│        │     └─ .red[📁 enhanced_validation]
│        │        ├─ 📁 logdir
│        │        └─ 📄 spk1.scp
│        ├─ 📄 cmd.sh
│        ├─ 📄 path.sh
│        ├─ 📄 enh.sh
│        └─ 📄 run.sh
├─ 📁 espnet2
└─ 📁 tools
]
.right-column-51[
* Root directory
** Directory containing all recipes
*** Recipe for a specific dataset
**** Speech enhancement recipe
***** .red[Model configurations] 👉 Model configuration
***** Data preparation scripts (used by run.sh)
***** Directory generated by local/data.sh
***** Data directory used for training/evaluation
***** .red[Exp directory for model training/inference]
****** Trained model directory
******* Enhanced audios (spk1.scp)
******** Enhanced audio directory
******** List of all enhanced audios
***** SLURM config file ('local', 'slurm', etc.)
***** Environment config file (export XXX=YYY)
***** Common script used for speech enhancement
***** .red[Entry script for running the recipe]
** Python scripts used for training/inference
** Used for toolkit installation
]

---
name: espnet_framework

# Basic directory structure - Framework

.left-column-38[
📁 [espnet-root-dir]
├─ 📁 egs2
├─ 📁 espnet2
│  ├─ 📁 bin
│  │  ├─ 📄 enh_train.py
│  │  ├─ 📄 enh_inference.py
│  │  └─ 📄 enh_scoring.py
│  ├─ 📁 tasks
│  │  ├─ 📄 abs_task.py
│  │  └─ 📄 enh.py
│  └─ 📁 enh
│     ├─ 📁 encoder
│     ├─ 📁 separator
│     ├─ 📁 decoder
│     ├─ 📁 layers
│     ├─ 📁 loss
│     │  ├─ 📁 criterions
│     │  └─ 📁 wrappers
│     └─ 📄 espnet_model.py
└─ 📁 tools
]
.right-column-61[
* Root directory
** Directory containing all recipes
** Python scripts used for training/inference
*** Entry scripts (used by enh.sh)
**** Entry script for training
**** .red[Entry script for inference (enhancing a subset)]
**** Entry script for scoring (evaluating enhanced audios)
*** Task definition
**** Abstract task definition (shared by all tasks)
**** .red[Speech enhancement task definition] 👉 Task definition
*** Speech enhancement model definition 👉 Model definition
**** Available encoders (STFT, Conv1d, etc.)
**** Available separators (BSRNN, TF-GridNet, etc.)
**** Available decoders (iSTFT, ConvTranspose1d, etc.)
**** Detailed layer definitions used in separators
**** Loss functions
***** Criterion functions (time / time-frequency domain)
***** Wrappers (fixed permutation, PIT, etc.)
**** .red[Common framework of speech enhancement models]
** Used for toolkit installation
]

---
name: data_format

# Data format

In speech enhancement recipes, we need to prepare the following seven files:

.left-column-40[
📁 [espnet-root-dir]
└─ 📁 egs2/urgent24/enh1
   ├─ 📁 data
   └─ 📁 dump/raw
      ├─ 📁 train
      │  ├─ 📄 wav.scp
      │  ├─ 📄 spk1.scp
      │  ├─ 📄 text
      │  ├─ 📄 utt2spk
      │  ├─ 📄 spk2utt
      │  ├─ 📄 utt2fs
      │  └─ 📄 utt2category
      ├─ 📁 validation
      └─ 📁 test
]
.right-column-59[
* Root directory
** Recipe for the URGENT 2024 Challenge
*** Directory generated by local/data.sh
*** Data directory used for training/evaluation
**** Training set
***** .red[List of degraded speech (model input)]
***** .red[Corresponding clean speech (label)]
***** .red[Corresponding transcripts (for ASR evaluation)]
***** .red[Corresponding speaker IDs]
***** .red[Generated by utils/utt2spk_to_spk2utt.pl]
***** .red[Corresponding sampling rate in Hz]
***** .red[Corresponding category (for minibatch construction)]
**** Validation set
**** Test set
]
.full-column[
In each data file, the data format follows the Kaldi convention, i.e., a table-style format with the first column as the utterance ID (key) and the remaining columns as the value.
]

---

# Data format (Cont'd)

| File name | Example content | Note |
|---|---|---|
| wav.scp | `uid1 /path/to/noisy/uid1.flac` | Essential |
| spk1.scp | `uid1 /path/to/clean/uid1.flac` | Needed in stage 8 |
| text | `uid1 It is also very valuable.` | Needed in stage 8+ |
| utt2spk | `uid1 sid1` | Essential |
| spk2utt | `sid1 uid1 uid7 uid16 uid99` | Essential |
| utt2fs | `uid1 32000` | Essential for multi-fs data |
| utt2category | `uid1 1ch_32000Hz` | Essential for multi-fs data |

> Note that `utt2fs` is required to run model inference in stage 7 of `run.sh` so that the enhanced audio can be stored in the correct sampling rate.
>
> For the officially released validation/test subsets (downloaded audios), you will need to generate the above data files manually.

---

# Data format (Cont'd)

To manually generate the data files for an audio directory `audios`, you could run the following commands (assuming all file names are unique):
```bash
mkdir -p dump/raw/<subset>
find audios/ -iname '*.flac' | awk -F'[/.]' '{print($(NF-1)" "$0)}' | \
    sort -u > dump/raw/<subset>/wav.scp
find audios/ -iname '*.flac' | awk -F'[/.]' '{print($(NF-1)" "$(NF-1))}' | \
    sort -u > dump/raw/<subset>/utt2spk
find audios/ -iname '*.flac' | awk -F'[/.]' '{print($(NF-1)" "$(NF-1))}' | \
    sort -u > dump/raw/<subset>/spk2utt
python -c '
import soundfile as sf
with open("dump/raw/<subset>/utt2fs", "w") as f1:
    with open("dump/raw/<subset>/utt2category", "w") as f2:
        with open("dump/raw/<subset>/wav.scp", "r") as f3:
            for line in f3:
                uid, path = line.strip().split(maxsplit=1)
                info = sf.info(path)
                f1.write(f"{uid} {info.samplerate}\n")
                f2.write(f"{uid} {info.channels}ch_{info.samplerate}Hz\n")
'
```
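
---

# Data format (Cont'd)

After generating the files, it is a good idea to verify that they all cover exactly the same set of utterance IDs. Below is a minimal sanity-check sketch (not part of the recipe; `<subset>` is a placeholder, and `spk2utt` is skipped because it is keyed by speaker ID rather than utterance ID):

```python
# Check that the manually generated data files share the same utterance IDs.
from pathlib import Path

subset_dir = Path("dump/raw/<subset>")  # adjust to the actual subset directory
keys = {}
for name in ("wav.scp", "utt2spk", "utt2fs", "utt2category"):
    with open(subset_dir / name) as f:
        keys[name] = {line.split(maxsplit=1)[0] for line in f if line.strip()}

reference = keys["wav.scp"]
for name, ids in keys.items():
    assert ids == reference, f"{name} does not match wav.scp: {ids ^ reference}"
print(f"OK: {len(reference)} utterances listed consistently in all files")
```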
---
name: task_definition

# Task definition

The speech enhancement task `espnet2.tasks.enh.EnhancementTask` is a high-level class used for model training as well as for loading pre-trained models. It generally defines the supported

* .darkred[model architectures] (encoder_choices, separator_choices, decoder_choices)
* .darkred[loss functions] (loss_wrapper_choices, criterion_choices)
  * The wrapper is used mainly for handling the permutation problem in speech separation tasks.
  * For single-speaker tasks, please simply use `wrapper: fixed_order`.
* .darkred[preprocessing] (preprocessor_choices)
  * used for preprocessing each input sample during both training and inference stages
  * e.g., for data augmentation / dynamic mixing, and normalization
* and the .darkred[arguments] used for training.
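
For reference, the training entry script `espnet2/bin/enh_train.py` is essentially a thin wrapper that hands the command-line arguments to this task class. A simplified sketch (the actual file may contain a few more lines):

```python
# Simplified sketch of espnet2/bin/enh_train.py
from espnet2.tasks.enh import EnhancementTask


def main(cmd=None):
    EnhancementTask.main(cmd=cmd)


if __name__ == "__main__":
    main()
```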

---

# Task definition (Cont'd)

To use a pre-trained speech enhancement model for enhancing audios, you can run the following script:

```python
import soundfile as sf
from espnet2.bin.enh_inference import SeparateSpeech

model = SeparateSpeech(
    train_config="exp/xxx/config.yaml",
    model_file="exp/xxx/valid.loss.best.pth",
    normalize_output_wav=True,
    device="cuda",
)
audio, fs = sf.read("/path/to/noisy/utt1.flac")
enhanced = model(audio[None, :], fs=fs)[0]
```
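
The returned `enhanced` should be a NumPy array of shape `(batch, num_samples)` for the first (and only) speaker. To listen to the result, you can write it back to disk, e.g. (a usage note, not part of the original script):

```python
sf.write("enhanced_utt1.flac", enhanced[0], fs)
```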

---
name: model_definition

# Model definition

All speech enhancement models generally share the same template defined in `espnet2.enh.espnet_model.ESPnetEnhancementModel`, which consists of three modules: encoder → separator → decoder.

The main enhancement function is achieved by the separator module, which can use various architectures, e.g., BSRNN, TF-GridNet, and so on.

To add your own model, you can follow the instructions in `egs2/TEMPLATE/enh1/README.md`.
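
Conceptually, the forward pass chains the three modules as sketched below (a simplified illustration assuming the `(feature, lengths)` interfaces used in `espnet2/enh`, not the literal ESPnet code):

```python
import torch


def enhance(model, speech_mix: torch.Tensor, speech_lengths: torch.Tensor):
    # model is assumed to expose encoder, separator, and decoder submodules
    feature, f_lens = model.encoder(speech_mix, speech_lengths)    # e.g., STFT
    masked, f_lens, _ = model.separator(feature, f_lens)           # one entry per speaker
    return [model.decoder(m, speech_lengths)[0] for m in masked]   # back to waveform, e.g., iSTFT
```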

---
name: model_configuration

# Model configuration

The configuration (YAML) file is defined in `conf/tuning/`. A typical config file consists of four parts:

**1.** Basic hyperparameters (1/2)

```yaml
max_epoch: 100         # Max training epochs
batch_type: folded     # See espnet2/samplers/build_batch_sampler.py
batch_size: 4          # Batch size (number of chunks per minibatch when using the chunk iterator)
iterator_type: chunk   # Using chunk-based iterator for data loading
chunk_length: 200      # Each training sample will be segmented into overlapped 4-sec (= 200 / 50) chunks
chunk_default_fs: 50   # Used for automatically scaling the chunk length based on the input sampling rate
num_iters_per_epoch: 8000  # Number of samples per epoch (each sample can generate >1 chunks when using the chunk iterator)
num_workers: 4         # Number of parallel workers for data loading
grad_clip: 5.0         # Gradient clipping threshold
optim: adam            # Optimizer type (See espnet2/tasks/abs_task.py)
optim_conf:            # Optimizer configuration
    lr: 1.0e-03        # See the docstring of torch.optim.Adam
    eps: 1.0e-08
    weight_decay: 1.0e-05
patience: 40           # Patience for early stopping. If the validation performance does not improve for 40 consecutive epochs, the training will be ended.
```
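
As a quick illustration of the chunk scaling described in the comments above (see `espnet2/iterators/chunk_iter_factory.py` for the actual logic), the chunk duration is `chunk_length / chunk_default_fs` seconds, and the number of samples per chunk scales with the input sampling rate:

```python
chunk_length, chunk_default_fs = 200, 50   # values from the config above
seconds = chunk_length / chunk_default_fs  # 4.0 s
for fs in (16000, 48000):
    n_samples = int(seconds * fs)
    print(f"{fs} Hz -> {n_samples} samples per chunk ({seconds:.1f} s)")
```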

---

# Model configuration

The configuration (YAML) file is defined in `conf/tuning/`. A typical config file consists of four parts:

**1.** Basic hyperparameters (2/2)

```yaml
val_scheduler_criterion:  # Validation scheduler configuration. Used to collect stats for certain epoch-based optimizers that require the validation performance per epoch.
- valid                   # Name of the phase for stats recording. Can be 'train' or 'valid'.
- loss                    # Name of the stats. Can be any key name in the `stats` dict returned by ESPnetEnhancementModel.forward in espnet2/enh/espnet_model.py.
best_model_criterion:     # Configuration of recording stats for determining the best-performing checkpoint.
- - valid                 # Name 1 of the phase for stats recording. Can be 'train' or 'valid'.
  - loss                  # Name of the stat. Same as above.
  - min                   # The best model 1 is determined by the minimum value of the stat. Can be 'min' or 'max'.
- - valid                 # Name 2 of the phase for stats recording. Can be 'train' or 'valid'.
  - acc                   # Name of the stat. Same as above.
  - max                   # The best model 2 is determined by the maximum value of the stat. Can be 'min' or 'max'.
keep_nbest_models: 1      # Number of best models to keep. The best model is determined by `best_model_criterion`.
scheduler: steplr         # Scheduler type. See scheduler_classes in espnet2/tasks/abs_task.py.
scheduler_conf:           # Scheduler configuration
    step_size: 2          # See the docstring of torch.optim.lr_scheduler.StepLR
    gamma: 0.99
allow_multi_rates: true   # Whether to allow loading audios of different sampling rates (If true, special treatment is required in the preprocessor to make sure data of different sampling rates are grouped in different categories (thus minibatches))
```

---

# Model configuration

The configuration (YAML) file is defined in `conf/tuning/`. A typical config file consists of four parts:

**2.** Preprocessor configuration

```yaml
preprocessor: enh           # Preprocessor type. See preprocessor_choices in espnet2/tasks/enh.py. The 'enh' preprocessor allows automatic processing of different sampling rates based on the utt2fs and utt2category files.
force_single_channel: true  # Whether to force the input audio to be single-channel in the preprocessor.
channel_reordering: true    # Whether to reorder the multiple channels of the input audio in the preprocessor.
categories:                 # List of all possible categories used in dump/raw/*/utt2category files. This is used by the preprocessor to group samples according to their category.
    - 1ch_8000Hz            # For this challenge, the category format can be simply '1ch_{fs}Hz'.
    - 1ch_16000Hz
    - 1ch_22050Hz
    - 1ch_24000Hz
    - 1ch_32000Hz
    - 1ch_44100Hz
    - 1ch_48000Hz
num_spk: 1                  # Number of speakers in the input audio. Should be 1 for single-speaker tasks.
```

---

# Model configuration

The configuration (YAML) file is defined in `conf/tuning/`. A typical config file consists of four parts:

**3.** Model configuration (1/2)

```yaml
model_conf:                          # Configuration for espnet2/enh/espnet_model.ESPnetEnhancementModel
    normalize_variance_per_ch: true  # Whether to normalize the variance of (each channel of) the input speech
    always_forward_in_48k: false     # Whether to always upsample the input speech to 48 kHz for model processing. The model output will be downsampled back to its original sampling rate.
                                     # When processing data of different sampling rates, this must be true for speech enhancement models that only support one sampling rate.
                                     # For sampling-frequency-independent (SFI) models such as BSRNN and TF-GridNet-v3, this can be false for more efficient processing.
    categories:                      # Must have the same value as in the preprocessor config (see last slide).
        - 1ch_8000Hz
        - 1ch_16000Hz
        - 1ch_22050Hz
        - 1ch_24000Hz
        - 1ch_32000Hz
        - 1ch_44100Hz
        - 1ch_48000Hz
```

---

# Model configuration

The configuration (YAML) file is defined in `conf/tuning/`. A typical config file consists of four parts:

**3.** Model configuration (2/2)

```yaml
encoder: stft                  # Encoder type. See encoder_choices in espnet2/tasks/enh.py.
encoder_conf:                  # Encoder configuration.
    n_fft: 960                 # FFT size for each window.
    hop_length: 480            # Hop size for the sliding window.
    use_builtin_complex: true  # Whether to use PyTorch's builtin complex tensor type.
    default_fs: 48000          # Used for automatically adjusting the STFT window/hop sizes based on the input sampling rate.
                               # This should be used together with frequency-domain SFI models such as BSRNN and TF-GridNet-v3.
decoder: stft                  # Decoder type. See decoder_choices in espnet2/tasks/enh.py.
decoder_conf:                  # Decoder configuration.
    n_fft: 960                 # Same as above. Usually should have the same config as in encoder_conf.
    hop_length: 480
    default_fs: 48000
separator: bsrnn               # Separator type. See separator_choices in espnet2/tasks/enh.py.
separator_conf:                # Separator configuration.
    num_spk: 1                 # Number of speakers to separate. Should be 1 for single-speaker tasks.
    ...                        # See espnet2/enh/separator/bsrnn_separator.py.
```

---

# Model configuration

The configuration (YAML) file is defined in `conf/tuning/`. A typical config file consists of four parts:

**4.** Training criterion configuration

```yaml
criterions:                 # Training criterion configuration.
  - name: mr_l1_tfd         # The first criterion. See criterion_choices in espnet2/tasks/enh.py.
    conf:                   # Criterion configuration.
      window_sz: [256, 512, 768, 1024]  # See MultiResL1SpecLoss in espnet2/enh/loss/criterions/time_domain.py.
      hop_sz: null
      eps: 1.0e-8
      time_domain_weight: 0.5
      normalize_variance: true
    wrapper: fixed_order    # Wrapper type for handling the permutation problem. See loss_wrapper_choices in espnet2/tasks/enh.py.
    wrapper_conf:           # Wrapper configuration.
      weight: 1.0           # Loss weight for the first criterion. Can be used for balancing multiple criteria.
  - name: si_snr            # The second criterion.
    conf:                   # Criterion configuration.
      ...                   # See SISNRLoss in espnet2/enh/loss/criterions/time_domain.py.
```
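
When several criteria are configured, each wrapper returns its own loss and the final training loss is their weighted sum using the `weight` values above (a simplified illustration, not the exact ESPnet code):

```python
# Hypothetical per-criterion loss values for one minibatch
losses = {"mr_l1_tfd": 0.82, "si_snr": -12.3}
weights = {"mr_l1_tfd": 1.0, "si_snr": 1.0}  # "weight" in each wrapper_conf
total_loss = sum(weights[name] * loss for name, loss in losses.items())
print(total_loss)
```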

---
name: script_flow

# Script flow of training and inference

**Training:**

run.sh → enh.sh (stage 6) → espnet2/bin/enh_train.py
→ espnet2/tasks/enh.py `EnhancementTask` & espnet2/tasks/abs_task.py `AbsTask.main`
→ espnet2/train/trainer.py `Trainer.run`
* → espnet2/train/dataset.py `ESPnetDataset`
* → espnet2/train/preprocessor.py `EnhPreprocessor` (If used)
* → espnet2/iterators/chunk_iter_factory.py `ChunkIterFactory.build_iter` (If used)

→ espnet2/enh/espnet_model.py `ESPnetEnhancementModel.forward`
* → espnet2/enh/encoder/stft_encoder.py `STFTEncoder.forward` (If used)
* → espnet2/enh/separator/bsrnn_separator.py `BSRNNSeparator.forward` (If used)
* → espnet2/enh/decoder/stft_decoder.py `STFTDecoder.forward` (If used)
* → espnet2/enh/loss/wrappers/fixed_order.py `FixedOrderSolver.forward` (If used)
* → espnet2/enh/loss/criterions/time_domain.py `SISNRLoss.forward` (If used)

---

# Script flow of training and inference (Cont'd)

**Inference:**

run.sh → enh.sh (stage 7) → espnet2/bin/enh_inference.py `SeparateSpeech`
* → espnet2/tasks/enh.py `EnhancementTask` & espnet2/tasks/abs_task.py `AbsTask.build_model_from_file`
* → espnet2/tasks/enh.py `EnhancementTask.build_streaming_iterator`
* → espnet2/train/iterable_dataset.py `IterableESPnetDataset`
* → espnet2/train/preprocessor.py `EnhPreprocessor` (If used)

→ espnet2/enh/espnet_model.py `ESPnetEnhancementModel.forward`
* → espnet2/enh/encoder/stft_encoder.py `STFTEncoder.forward` (If used)
* → espnet2/enh/separator/bsrnn_separator.py `BSRNNSeparator.forward` (If used)
* → espnet2/enh/decoder/stft_decoder.py `STFTDecoder.forward` (If used)
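
---

# Script flow of training and inference (Cont'd)

For a quick manual check outside of `enh.sh`, you can reproduce the inference flow with the Python API shown earlier by looping over a `wav.scp` file. This is a minimal sketch (the `exp/xxx` model directory and output paths are placeholders):

```python
import os
import soundfile as sf
from espnet2.bin.enh_inference import SeparateSpeech

model = SeparateSpeech(
    train_config="exp/xxx/config.yaml",
    model_file="exp/xxx/valid.loss.best.pth",
    normalize_output_wav=True,
    device="cuda",
)

os.makedirs("enhanced", exist_ok=True)
with open("dump/raw/validation/wav.scp") as f:
    for line in f:
        uid, path = line.strip().split(maxsplit=1)
        audio, fs = sf.read(path)
        enhanced = model(audio[None, :], fs=fs)[0]  # first (and only) speaker
        sf.write(f"enhanced/{uid}.flac", enhanced[0], fs)
```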