From f0170d62604ce06b0a3610b85ad2b3ba3f82e88f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 8 Jul 2022 17:42:49 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- ludwig/constants.py | 14 ++++++---- ludwig/features/audio_feature.py | 48 +++++++++++++++++++------------- ludwig/schema/preprocessing.py | 13 ++++----- 3 files changed, 43 insertions(+), 32 deletions(-) diff --git a/ludwig/constants.py b/ludwig/constants.py index ba604772368..c562bc2ca94 100644 --- a/ludwig/constants.py +++ b/ludwig/constants.py @@ -183,12 +183,14 @@ CONTINUE_PROMPT = "Do you want to continue? " DEFAULT_AUDIO_TENSOR_LENGTH = 70000 -AUDIO_FEATURE_KEYS = ["type", - "window_length_in_s", - "window_shift_in_s", - "num_fft_points", - "window_type", - "num_filter_bands"] +AUDIO_FEATURE_KEYS = [ + "type", + "window_length_in_s", + "window_shift_in_s", + "num_fft_points", + "window_type", + "num_filter_bands", +] MODEL_TYPE = "model_type" MODEL_ECD = "ecd" diff --git a/ludwig/features/audio_feature.py b/ludwig/features/audio_feature.py index b6de486ed01..c0557fe7e1b 100644 --- a/ludwig/features/audio_feature.py +++ b/ludwig/features/audio_feature.py @@ -15,22 +15,24 @@ # ============================================================================== import logging import os -from typing import Any, Dict, List, Union, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch import torchaudio -from ludwig.constants import (AUDIO, - BACKFILL, - COLUMN, - NAME, - PREPROCESSING, - PROC_COLUMN, - SRC, - TIED, - TYPE, - AUDIO_FEATURE_KEYS) +from ludwig.constants import ( + AUDIO, + AUDIO_FEATURE_KEYS, + BACKFILL, + COLUMN, + NAME, + PREPROCESSING, + PROC_COLUMN, + SRC, + TIED, + TYPE, +) from ludwig.features.base_feature import BaseFeatureMixin from ludwig.features.sequence_feature import SequenceInputFeature from ludwig.schema.features.audio_feature import AudioInputFeatureConfig @@ -60,8 +62,11 @@ class _AudioPreprocessing(torch.nn.Module): def __init__(self, metadata: Dict[str, Any]): super().__init__() - self.audio_feature_dict = {key: value for key, value in metadata["preprocessing"].items() - if key in AUDIO_FEATURE_KEYS and value is not None} + self.audio_feature_dict = { + key: value + for key, value in metadata["preprocessing"].items() + if key in AUDIO_FEATURE_KEYS and value is not None + } self.feature_dim = metadata["feature_dim"] self.max_length = metadata["max_length"] self.padding_value = metadata["preprocessing"]["padding_value"] @@ -135,12 +140,14 @@ def _get_feature_dim(preprocessing_parameters, sampling_rate_in_hz): if feature_type == "raw": feature_dim = 1 elif feature_type == "stft_phase": - feature_dim_symmetric = get_length_in_samp(preprocessing_parameters["window_length_in_s"], - sampling_rate_in_hz) + feature_dim_symmetric = get_length_in_samp( + preprocessing_parameters["window_length_in_s"], sampling_rate_in_hz + ) feature_dim = 2 * get_non_symmetric_length(feature_dim_symmetric) elif feature_type in ["stft", "group_delay"]: - feature_dim_symmetric = get_length_in_samp(preprocessing_parameters["window_length_in_s"], - sampling_rate_in_hz) + feature_dim_symmetric = get_length_in_samp( + preprocessing_parameters["window_length_in_s"], sampling_rate_in_hz + ) feature_dim = get_non_symmetric_length(feature_dim_symmetric) elif feature_type == "fbank": feature_dim = preprocessing_parameters["num_filter_bands"] @@ -391,8 +398,11 @@ def add_feature_data( feature_dim = metadata[name]["feature_dim"] max_length = metadata[name]["max_length"] - audio_feature_dict = {key: value for key, value in preprocessing_parameters.items() - if key in AUDIO_FEATURE_KEYS and value is not None} + audio_feature_dict = { + key: value + for key, value in preprocessing_parameters.items() + if key in AUDIO_FEATURE_KEYS and value is not None + } audio_file_length_limit_in_s = preprocessing_parameters["audio_file_length_limit_in_s"] if num_audio_utterances == 0: diff --git a/ludwig/schema/preprocessing.py b/ludwig/schema/preprocessing.py index f085d451082..3551778483e 100644 --- a/ludwig/schema/preprocessing.py +++ b/ludwig/schema/preprocessing.py @@ -560,36 +560,35 @@ class AudioPreprocessingConfig(schema_utils.BaseMarshmallowConfig): type: str = schema_utils.StringOptions( ["fbank", "group_delay", "raw", "stft", "stft_phase"], default="fbank", - description="Defines the type of audio feature to be used." + description="Defines the type of audio feature to be used.", ) window_length_in_s: float = schema_utils.NonNegativeFloat( default=0.04, description="Defines the window length used for the short time Fourier transformation. This is only needed if " - "the audio_feature_type is 'raw'.", + "the audio_feature_type is 'raw'.", ) window_shift_in_s: float = schema_utils.NonNegativeFloat( default=0.02, description="Defines the window shift used for the short time Fourier transformation (also called " - "hop_length). This is only needed if the audio_feature_type is 'raw'. " + "hop_length). This is only needed if the audio_feature_type is 'raw'. ", ) num_fft_points: float = schema_utils.NonNegativeFloat( - default=None, - description="Defines the number of fft points used for the short time Fourier transformation" + default=None, description="Defines the number of fft points used for the short time Fourier transformation" ) window_type: str = schema_utils.StringOptions( ["bartlett", "blackman", "hamming", "hann"], default="hamming", - description="Defines the type window the signal is weighted before the short time Fourier transformation." + description="Defines the type window the signal is weighted before the short time Fourier transformation.", ) num_filter_bands: int = schema_utils.PositiveInteger( default=80, description="Defines the number of filters used in the filterbank. Only needed if audio_feature_type " - "is 'fbank'" + "is 'fbank'", )