mirror of https://github.com/rhasspy/piper
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
558 lines
25 KiB
Plaintext
558 lines
25 KiB
Plaintext
{
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0,
|
|
"metadata": {
|
|
"colab": {
|
|
"provenance": [],
|
|
"gpuType": "T4",
|
|
"authorship_tag": "ABX9TyMAPvo6Syxu5wDRkSmySUxq",
|
|
"include_colab_link": true
|
|
},
|
|
"kernelspec": {
|
|
"name": "python3",
|
|
"display_name": "Python 3"
|
|
},
|
|
"language_info": {
|
|
"name": "python"
|
|
},
|
|
"accelerator": "GPU"
|
|
},
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "view-in-github",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"<a href=\"https://colab.research.google.com/github/rmcpantoja/piper/blob/master/notebooks/piper_inference_(ONNX).ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"# <font color=\"pink\"> **[Piper](https://github.com/rhasspy/piper) inferencing notebook.**\n",
|
|
"## ![Piper logo](https://contribute.rhasspy.org/img/logo.png)\n",
|
|
"\n",
|
|
"---\n",
|
|
"\n",
|
|
"- Notebook made by [rmcpantoja](http://github.com/rmcpantoja)\n",
|
|
"- Collaborator: [Xx_Nessu_xX](https://fakeyou.com/profile/Xx_Nessu_xX)"
|
|
],
|
|
"metadata": {
|
|
"id": "eK3nmYDB6C1a"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"# First steps"
|
|
],
|
|
"metadata": {
|
|
"id": "9wIvcSmOby84"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"#@title Install software and settings\n",
|
|
"#@markdown The speech synthesizer and other important dependencies will be installed in this cell. But first, some settings:\n",
|
|
"\n",
|
|
"#@markdown #### Enhable Enhanced Accessibility?\n",
|
|
"#@markdown This Enhanced Accessibility functionality is designed for the visually impaired, in which most of the interface can be used by voice guides.\n",
|
|
"enhanced_accessibility = True #@param {type:\"boolean\"}\n",
|
|
"#@markdown ---\n",
|
|
"\n",
|
|
"#@markdown #### Please select your language:\n",
|
|
"lang_select = \"English\" #@param [\"English\", \"Spanish\"]\n",
|
|
"if lang_select == \"English\":\n",
|
|
" lang = \"en\"\n",
|
|
"elif lang_select == \"Spanish\":\n",
|
|
" lang = \"es\"\n",
|
|
"else:\n",
|
|
" raise Exception(\"Language not supported.\")\n",
|
|
"#@markdown ---\n",
|
|
"#@markdown #### Do you want to use the GPU for inference?\n",
|
|
"\n",
|
|
"#@markdown The GPU can be enabled in the edit/notebook settings menu, and this step must be done before connecting to a runtime. The GPU can lead to a higher response speed in inference, but you can use the CPU, for example, if your colab runtime to use GPU's has been ended.\n",
|
|
"use_gpu = True #@param {type:\"boolean\"}\n",
|
|
"\n",
|
|
"if enhanced_accessibility:\n",
|
|
" from google.colab import output\n",
|
|
" guideurl = f\"https://github.com/rmcpantoja/piper/blob/master/notebooks/wav/{lang}\"\n",
|
|
" def playaudio(filename, extension = \"wav\"):\n",
|
|
" return output.eval_js(f'new Audio(\"{guideurl}/{filename}.{extension}?raw=true\").play()')\n",
|
|
"\n",
|
|
"%cd /content\n",
|
|
"print(\"Installing...\")\n",
|
|
"if enhanced_accessibility:\n",
|
|
" playaudio(\"installing\")\n",
|
|
"!git clone -q https://github.com/rmcpantoja/piper\n",
|
|
"%cd /content/piper/src/python\n",
|
|
"#!pip install -q -r requirements.txt\n",
|
|
"!pip install -q cython>=0.29.0 piper-phonemize==1.1.0 librosa>=0.9.2 numpy>=1.19.0 onnxruntime>=1.11.0 pytorch-lightning==1.7.0 torch==1.11.0\n",
|
|
"!pip install -q onnxruntime-gpu\n",
|
|
"!bash build_monotonic_align.sh\n",
|
|
"import os\n",
|
|
"if not os.path.exists(\"/content/piper/src/python/lng\"):\n",
|
|
" !cp -r \"/content/piper/notebooks/lng\" /content/piper/src/python/lng\n",
|
|
"import sys\n",
|
|
"sys.path.append('/content/piper/notebooks')\n",
|
|
"from translator import *\n",
|
|
"lan = Translator()\n",
|
|
"print(\"Checking GPU...\")\n",
|
|
"gpu_info = !nvidia-smi\n",
|
|
"if use_gpu and any('not found' in info for info in gpu_info[0].split(':')):\n",
|
|
" if enhanced_accessibility:\n",
|
|
" playaudio(\"nogpu\")\n",
|
|
" raise Exception(lan.translate(lang, \"The Use GPU checkbox is checked, but you don't have a GPU runtime.\"))\n",
|
|
"elif not use_gpu and not any('not found' in info for info in gpu_info[0].split(':')):\n",
|
|
" if enhanced_accessibility:\n",
|
|
" playaudio(\"gpuavailable\")\n",
|
|
" raise Exception(lan.translate(lang, \"The Use GPU checkbox is unchecked, however you are using a GPU runtime environment. We recommend you check the checkbox to use GPU to take advantage of it.\"))\n",
|
|
"\n",
|
|
"if enhanced_accessibility:\n",
|
|
" playaudio(\"installed\")\n",
|
|
"print(\"Success!\")"
|
|
],
|
|
"metadata": {
|
|
"cellView": "form",
|
|
"id": "v8b_PEtXb8co"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"#@title Download your exported model\n",
|
|
"%cd /content/piper/src/python\n",
|
|
"import os\n",
|
|
"#@markdown #### ID or link of the voice package (tar.gz format):\n",
|
|
"package_url_or_id = \"\" #@param {type:\"string\"}\n",
|
|
"#@markdown ---\n",
|
|
"if package_url_or_id == \"\" or package_url_or_id == \"http\" or package_url_or_id == \"1\":\n",
|
|
" if enhanced_accessibility:\n",
|
|
" playaudio(\"noid\")\n",
|
|
" raise Exception(lan.translate(lang, \"Invalid link or ID!\"))\n",
|
|
"print(\"Downloading voice package...\")\n",
|
|
"if enhanced_accessibility:\n",
|
|
" playaudio(\"downloading\")\n",
|
|
"if package_url_or_id.startswith(\"1\"):\n",
|
|
" !gdown -q \"{package_url_or_id}\" -O \"voice_package.tar.gz\"\n",
|
|
"elif package_url_or_id.startswith(\"https://drive.google.com/file/d/\"):\n",
|
|
" !gdown -q \"{package_url_or_id}\" -O \"voice_package.tar.gz\" --fuzzy\n",
|
|
"else:\n",
|
|
" !wget -q \"{package_url_or_id}\" -O \"voice_package.tar.gz\"\n",
|
|
"if os.path.exists(\"/content/piper/src/python/voice_package.tar.gz\"):\n",
|
|
" !tar -xf voice_package.tar.gz\n",
|
|
" print(\"Voice package downloaded!\")\n",
|
|
" if enhanced_accessibility:\n",
|
|
" playaudio(\"downloaded\")\n",
|
|
"else:\n",
|
|
" if enhanced_accessibility:\n",
|
|
" playaudio(\"dwnerror\")\n",
|
|
" raise Exception(lan.translate(lang, \"Model failed to download!\"))"
|
|
],
|
|
"metadata": {
|
|
"cellView": "form",
|
|
"id": "ykIYmVXccg6s"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"# inferencing"
|
|
],
|
|
"metadata": {
|
|
"id": "MRvkYJF6g5FT"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"#@title run inference\n",
|
|
"#@markdown #### before you enjoy... Some notes!\n",
|
|
"#@markdown * You can run the cell to download voice packs and download voices you want at any time, even if you run this cell!\n",
|
|
"#@markdown * When you download a new voice, run this cell again and you will now be able to toggle between all the ones you download. Incredible, right?\n",
|
|
"\n",
|
|
"#@markdown Enjoy!!\n",
|
|
"\n",
|
|
"%cd /content/piper/src/python\n",
|
|
"# original: infer_onnx.py\n",
|
|
"import json\n",
|
|
"import logging\n",
|
|
"import math\n",
|
|
"import sys\n",
|
|
"from pathlib import Path\n",
|
|
"from enum import Enum\n",
|
|
"from typing import Iterable, List, Optional, Union\n",
|
|
"import numpy as np\n",
|
|
"import onnxruntime\n",
|
|
"from piper_train.vits.utils import audio_float_to_int16\n",
|
|
"import glob\n",
|
|
"import ipywidgets as widgets\n",
|
|
"from IPython.display import display, Audio, Markdown, clear_output\n",
|
|
"from piper_phonemize import phonemize_codepoints, phonemize_espeak, tashkeel_run\n",
|
|
"\n",
|
|
"_LOGGER = logging.getLogger(\"piper_train.infer_onnx\")\n",
|
|
"\n",
|
|
"def detect_onnx_models(path):\n",
|
|
" onnx_models = glob.glob(path + '/*.onnx')\n",
|
|
" if len(onnx_models) > 1:\n",
|
|
" return onnx_models\n",
|
|
" elif len(onnx_models) == 1:\n",
|
|
" return onnx_models[0]\n",
|
|
" else:\n",
|
|
" return None\n",
|
|
"\n",
|
|
"\n",
|
|
"def main():\n",
|
|
" \"\"\"Main entry point\"\"\"\n",
|
|
" models_path = \"/content/piper/src/python\"\n",
|
|
" logging.basicConfig(level=logging.DEBUG)\n",
|
|
" providers = [\n",
|
|
" \"CPUExecutionProvider\"\n",
|
|
" if use_gpu is False\n",
|
|
" else (\"CUDAExecutionProvider\", {\"cudnn_conv_algo_search\": \"DEFAULT\"})\n",
|
|
" ]\n",
|
|
" sess_options = onnxruntime.SessionOptions()\n",
|
|
" model = None\n",
|
|
" onnx_models = detect_onnx_models(models_path)\n",
|
|
" speaker_selection = widgets.Dropdown(\n",
|
|
" options=[],\n",
|
|
" description=f'{lan.translate(lang, \"Select speaker\")}:',\n",
|
|
" layout={'visibility': 'hidden'}\n",
|
|
" )\n",
|
|
" if onnx_models is None:\n",
|
|
" if enhanced_accessibility:\n",
|
|
" playaudio(\"novoices\")\n",
|
|
" raise Exception(lan.translate(lang, \"No downloaded voice packages!\"))\n",
|
|
" elif isinstance(onnx_models, str):\n",
|
|
" onnx_model = onnx_models\n",
|
|
" model, config = load_onnx(onnx_model, sess_options, providers)\n",
|
|
" if config[\"num_speakers\"] > 1:\n",
|
|
" speaker_selection.options = config[\"speaker_id_map\"].values()\n",
|
|
" speaker_selection.layout.visibility = 'visible'\n",
|
|
" preview_sid = 0\n",
|
|
" if enhanced_accessibility:\n",
|
|
" playaudio(\"multispeaker\")\n",
|
|
" else:\n",
|
|
" speaker_selection.layout.visibility = 'hidden'\n",
|
|
" preview_sid = None\n",
|
|
"\n",
|
|
" if enhanced_accessibility:\n",
|
|
" inferencing(\n",
|
|
" model,\n",
|
|
" config,\n",
|
|
" preview_sid,\n",
|
|
" lan.translate(\n",
|
|
" config[\"espeak\"][\"voice\"][:2],\n",
|
|
" \"Interface openned. Write your texts, configure the different synthesis options or download all the voices you want. Enjoy!\"\n",
|
|
" )\n",
|
|
" )\n",
|
|
" else:\n",
|
|
" voice_model_names = []\n",
|
|
" for current in onnx_models:\n",
|
|
" voice_struct = current.split(\"/\")[5]\n",
|
|
" voice_model_names.append(voice_struct)\n",
|
|
" if enhanced_accessibility:\n",
|
|
" playaudio(\"selectmodel\")\n",
|
|
" selection = widgets.Dropdown(\n",
|
|
" options=voice_model_names,\n",
|
|
" description=f'{lan.translate(lang, \"Select voice package\")}:',\n",
|
|
" )\n",
|
|
" load_btn = widgets.Button(\n",
|
|
" description=lan.translate(lang, \"Load it!\")\n",
|
|
" )\n",
|
|
" config = None\n",
|
|
" def load_model(button):\n",
|
|
" nonlocal config\n",
|
|
" global onnx_model\n",
|
|
" nonlocal model\n",
|
|
" nonlocal models_path\n",
|
|
" selected_voice = selection.value\n",
|
|
" onnx_model = f\"{models_path}/{selected_voice}\"\n",
|
|
" model, config = load_onnx(onnx_model, sess_options, providers)\n",
|
|
" if enhanced_accessibility:\n",
|
|
" playaudio(\"loaded\")\n",
|
|
" if config[\"num_speakers\"] > 1:\n",
|
|
" speaker_selection.options = config[\"speaker_id_map\"].values()\n",
|
|
" speaker_selection.layout.visibility = 'visible'\n",
|
|
" if enhanced_accessibility:\n",
|
|
" playaudio(\"multispeaker\")\n",
|
|
" else:\n",
|
|
" speaker_selection.layout.visibility = 'hidden'\n",
|
|
"\n",
|
|
" load_btn.on_click(load_model)\n",
|
|
" display(selection, load_btn)\n",
|
|
" display(speaker_selection)\n",
|
|
" speed_slider = widgets.FloatSlider(\n",
|
|
" value=1,\n",
|
|
" min=0.25,\n",
|
|
" max=4,\n",
|
|
" step=0.1,\n",
|
|
" description=lan.translate(lang, \"Rate scale\"),\n",
|
|
" orientation='horizontal',\n",
|
|
" )\n",
|
|
" noise_scale_slider = widgets.FloatSlider(\n",
|
|
" value=0.667,\n",
|
|
" min=0.25,\n",
|
|
" max=4,\n",
|
|
" step=0.1,\n",
|
|
" description=lan.translate(lang, \"Phoneme noise scale\"),\n",
|
|
" orientation='horizontal',\n",
|
|
" )\n",
|
|
" noise_scale_w_slider = widgets.FloatSlider(\n",
|
|
" value=1,\n",
|
|
" min=0.25,\n",
|
|
" max=4,\n",
|
|
" step=0.1,\n",
|
|
" description=lan.translate(lang, \"Phoneme stressing scale\"),\n",
|
|
" orientation='horizontal',\n",
|
|
" )\n",
|
|
" play = widgets.Checkbox(\n",
|
|
" value=True,\n",
|
|
" description=lan.translate(lang, \"Auto-play\"),\n",
|
|
" disabled=False\n",
|
|
" )\n",
|
|
" text_input = widgets.Text(\n",
|
|
" value='',\n",
|
|
" placeholder=f'{lan.translate(lang, \"Enter your text here\")}:',\n",
|
|
" description=lan.translate(lang, \"Text to synthesize\"),\n",
|
|
" layout=widgets.Layout(width='80%')\n",
|
|
" )\n",
|
|
" synthesize_button = widgets.Button(\n",
|
|
" description=lan.translate(lang, \"Synthesize\"),\n",
|
|
" button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n",
|
|
" tooltip=lan.translate(lang, \"Click here to synthesize the text.\"),\n",
|
|
" icon='check'\n",
|
|
" )\n",
|
|
" close_button = widgets.Button(\n",
|
|
" description=lan.translate(lang, \"Exit\"),\n",
|
|
" tooltip=lan.translate(lang, \"Closes this GUI.\"),\n",
|
|
" icon='check'\n",
|
|
" )\n",
|
|
"\n",
|
|
" def on_synthesize_button_clicked(b):\n",
|
|
" if model is None:\n",
|
|
" if enhanced_accessibility:\n",
|
|
" playaudio(\"nomodel\")\n",
|
|
" raise Exception(lan.translate(lang, \"You have not loaded any model from the list!\"))\n",
|
|
" text = text_input.value\n",
|
|
" if config[\"num_speakers\"] > 1:\n",
|
|
" sid = speaker_selection.value\n",
|
|
" else:\n",
|
|
" sid = None\n",
|
|
" rate = speed_slider.value\n",
|
|
" noise_scale = noise_scale_slider.value\n",
|
|
" noise_scale_w = noise_scale_w_slider.value\n",
|
|
" auto_play = play.value\n",
|
|
" inferencing(model, config, sid, text, rate, noise_scale, noise_scale_w, auto_play)\n",
|
|
"\n",
|
|
" def on_close_button_clicked(b):\n",
|
|
" clear_output()\n",
|
|
" if enhanced_accessibility:\n",
|
|
" playaudio(\"exit\")\n",
|
|
"\n",
|
|
" synthesize_button.on_click(on_synthesize_button_clicked)\n",
|
|
" close_button.on_click(on_close_button_clicked)\n",
|
|
" display(text_input)\n",
|
|
" display(speed_slider)\n",
|
|
" display(noise_scale_slider)\n",
|
|
" display(noise_scale_w_slider)\n",
|
|
" display(play)\n",
|
|
" display(synthesize_button)\n",
|
|
" display(close_button)\n",
|
|
"\n",
|
|
"def load_onnx(model, sess_options, providers = [\"CPUExecutionProvider\"]):\n",
|
|
" _LOGGER.debug(\"Loading model from %s\", model)\n",
|
|
" config = load_config(model)\n",
|
|
" model = onnxruntime.InferenceSession(\n",
|
|
" str(model),\n",
|
|
" sess_options=sess_options,\n",
|
|
" providers= providers\n",
|
|
" )\n",
|
|
" _LOGGER.info(\"Loaded model from %s\", model)\n",
|
|
" return model, config\n",
|
|
"\n",
|
|
"def load_config(model):\n",
|
|
" with open(f\"{model}.json\", \"r\") as file:\n",
|
|
" config = json.load(file)\n",
|
|
" return config\n",
|
|
"PAD = \"_\" # padding (0)\n",
|
|
"BOS = \"^\" # beginning of sentence\n",
|
|
"EOS = \"$\" # end of sentence\n",
|
|
"\n",
|
|
"class PhonemeType(str, Enum):\n",
|
|
" ESPEAK = \"espeak\"\n",
|
|
" TEXT = \"text\"\n",
|
|
"\n",
|
|
"def phonemize(config, text: str) -> List[List[str]]:\n",
|
|
" \"\"\"Text to phonemes grouped by sentence.\"\"\"\n",
|
|
" if config[\"phoneme_type\"] == PhonemeType.ESPEAK:\n",
|
|
" if config[\"espeak\"][\"voice\"] == \"ar\":\n",
|
|
" # Arabic diacritization\n",
|
|
" # https://github.com/mush42/libtashkeel/\n",
|
|
" text = tashkeel_run(text)\n",
|
|
" return phonemize_espeak(text, config[\"espeak\"][\"voice\"])\n",
|
|
" if config[\"phoneme_type\"] == PhonemeType.TEXT:\n",
|
|
" return phonemize_codepoints(text)\n",
|
|
" raise ValueError(f'Unexpected phoneme type: {config[\"phoneme_type\"]}')\n",
|
|
"\n",
|
|
"def phonemes_to_ids(config, phonemes: List[str]) -> List[int]:\n",
|
|
" \"\"\"Phonemes to ids.\"\"\"\n",
|
|
" id_map = config[\"phoneme_id_map\"]\n",
|
|
" ids: List[int] = list(id_map[BOS])\n",
|
|
" for phoneme in phonemes:\n",
|
|
" if phoneme not in id_map:\n",
|
|
" print(\"Missing phoneme from id map: %s\", phoneme)\n",
|
|
" continue\n",
|
|
" ids.extend(id_map[phoneme])\n",
|
|
" ids.extend(id_map[PAD])\n",
|
|
" ids.extend(id_map[EOS])\n",
|
|
" return ids\n",
|
|
"\n",
|
|
"def inferencing(model, config, sid, line, length_scale = 1, noise_scale = 0.667, noise_scale_w = 0.8, auto_play=True):\n",
|
|
" audios = []\n",
|
|
" if config[\"phoneme_type\"] == \"PhonemeType.ESPEAK\":\n",
|
|
" config[\"phoneme_type\"] = \"espeak\"\n",
|
|
" text = phonemize(config, line)\n",
|
|
" for phonemes in text:\n",
|
|
" phoneme_ids = phonemes_to_ids(config, phonemes)\n",
|
|
" num_speakers = config[\"num_speakers\"]\n",
|
|
" if num_speakers == 1:\n",
|
|
" speaker_id = None # for now\n",
|
|
" else:\n",
|
|
" speaker_id = sid\n",
|
|
" text = np.expand_dims(np.array(phoneme_ids, dtype=np.int64), 0)\n",
|
|
" text_lengths = np.array([text.shape[1]], dtype=np.int64)\n",
|
|
" scales = np.array(\n",
|
|
" [noise_scale, length_scale, noise_scale_w],\n",
|
|
" dtype=np.float32,\n",
|
|
" )\n",
|
|
" sid = None\n",
|
|
" if speaker_id is not None:\n",
|
|
" sid = np.array([speaker_id], dtype=np.int64)\n",
|
|
" audio = model.run(\n",
|
|
" None,\n",
|
|
" {\n",
|
|
" \"input\": text,\n",
|
|
" \"input_lengths\": text_lengths,\n",
|
|
" \"scales\": scales,\n",
|
|
" \"sid\": sid,\n",
|
|
" },\n",
|
|
" )[0].squeeze((0, 1))\n",
|
|
" audio = audio_float_to_int16(audio.squeeze())\n",
|
|
" audios.append(audio)\n",
|
|
" merged_audio = np.concatenate(audios)\n",
|
|
" sample_rate = config[\"audio\"][\"sample_rate\"]\n",
|
|
" display(Markdown(f\"{line}\"))\n",
|
|
" display(Audio(merged_audio, rate=sample_rate, autoplay=auto_play))\n",
|
|
"\n",
|
|
"def denoise(\n",
|
|
" audio: np.ndarray, bias_spec: np.ndarray, denoiser_strength: float\n",
|
|
") -> np.ndarray:\n",
|
|
" audio_spec, audio_angles = transform(audio)\n",
|
|
"\n",
|
|
" a = bias_spec.shape[-1]\n",
|
|
" b = audio_spec.shape[-1]\n",
|
|
" repeats = max(1, math.ceil(b / a))\n",
|
|
" bias_spec_repeat = np.repeat(bias_spec, repeats, axis=-1)[..., :b]\n",
|
|
"\n",
|
|
" audio_spec_denoised = audio_spec - (bias_spec_repeat * denoiser_strength)\n",
|
|
" audio_spec_denoised = np.clip(audio_spec_denoised, a_min=0.0, a_max=None)\n",
|
|
" audio_denoised = inverse(audio_spec_denoised, audio_angles)\n",
|
|
"\n",
|
|
" return audio_denoised\n",
|
|
"\n",
|
|
"\n",
|
|
"def stft(x, fft_size, hopsamp):\n",
|
|
" \"\"\"Compute and return the STFT of the supplied time domain signal x.\n",
|
|
" Args:\n",
|
|
" x (1-dim Numpy array): A time domain signal.\n",
|
|
" fft_size (int): FFT size. Should be a power of 2, otherwise DFT will be used.\n",
|
|
" hopsamp (int):\n",
|
|
" Returns:\n",
|
|
" The STFT. The rows are the time slices and columns are the frequency bins.\n",
|
|
" \"\"\"\n",
|
|
" window = np.hanning(fft_size)\n",
|
|
" fft_size = int(fft_size)\n",
|
|
" hopsamp = int(hopsamp)\n",
|
|
" return np.array(\n",
|
|
" [\n",
|
|
" np.fft.rfft(window * x[i : i + fft_size])\n",
|
|
" for i in range(0, len(x) - fft_size, hopsamp)\n",
|
|
" ]\n",
|
|
" )\n",
|
|
"\n",
|
|
"\n",
|
|
"def istft(X, fft_size, hopsamp):\n",
|
|
" \"\"\"Invert a STFT into a time domain signal.\n",
|
|
" Args:\n",
|
|
" X (2-dim Numpy array): Input spectrogram. The rows are the time slices and columns are the frequency bins.\n",
|
|
" fft_size (int):\n",
|
|
" hopsamp (int): The hop size, in samples.\n",
|
|
" Returns:\n",
|
|
" The inverse STFT.\n",
|
|
" \"\"\"\n",
|
|
" fft_size = int(fft_size)\n",
|
|
" hopsamp = int(hopsamp)\n",
|
|
" window = np.hanning(fft_size)\n",
|
|
" time_slices = X.shape[0]\n",
|
|
" len_samples = int(time_slices * hopsamp + fft_size)\n",
|
|
" x = np.zeros(len_samples)\n",
|
|
" for n, i in enumerate(range(0, len(x) - fft_size, hopsamp)):\n",
|
|
" x[i : i + fft_size] += window * np.real(np.fft.irfft(X[n]))\n",
|
|
" return x\n",
|
|
"\n",
|
|
"\n",
|
|
"def inverse(magnitude, phase):\n",
|
|
" recombine_magnitude_phase = np.concatenate(\n",
|
|
" [magnitude * np.cos(phase), magnitude * np.sin(phase)], axis=1\n",
|
|
" )\n",
|
|
"\n",
|
|
" x_org = recombine_magnitude_phase\n",
|
|
" n_b, n_f, n_t = x_org.shape # pylint: disable=unpacking-non-sequence\n",
|
|
" x = np.empty([n_b, n_f // 2, n_t], dtype=np.complex64)\n",
|
|
" x.real = x_org[:, : n_f // 2]\n",
|
|
" x.imag = x_org[:, n_f // 2 :]\n",
|
|
" inverse_transform = []\n",
|
|
" for y in x:\n",
|
|
" y_ = istft(y.T, fft_size=1024, hopsamp=256)\n",
|
|
" inverse_transform.append(y_[None, :])\n",
|
|
"\n",
|
|
" inverse_transform = np.concatenate(inverse_transform, 0)\n",
|
|
"\n",
|
|
" return inverse_transform\n",
|
|
"\n",
|
|
"\n",
|
|
"def transform(input_data):\n",
|
|
" x = input_data\n",
|
|
" real_part = []\n",
|
|
" imag_part = []\n",
|
|
" for y in x:\n",
|
|
" y_ = stft(y, fft_size=1024, hopsamp=256).T\n",
|
|
" real_part.append(y_.real[None, :, :]) # pylint: disable=unsubscriptable-object\n",
|
|
" imag_part.append(y_.imag[None, :, :]) # pylint: disable=unsubscriptable-object\n",
|
|
" real_part = np.concatenate(real_part, 0)\n",
|
|
" imag_part = np.concatenate(imag_part, 0)\n",
|
|
"\n",
|
|
" magnitude = np.sqrt(real_part**2 + imag_part**2)\n",
|
|
" phase = np.arctan2(imag_part.data, real_part.data)\n",
|
|
"\n",
|
|
" return magnitude, phase\n",
|
|
"\n",
|
|
"main()"
|
|
],
|
|
"metadata": {
|
|
"id": "hcKk8M2ug8kM",
|
|
"cellView": "form"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
}
|
|
]
|
|
} |