-
Notifications
You must be signed in to change notification settings - Fork 6.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* init lid rerank * init lid rerank * add greedy ctc score
- Loading branch information
1 parent
920a548
commit c214511
Showing
20 changed files
with
1,376 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
# N-best Re-ranking for Multilingual LID+ASR | ||
This project provides N-best re-ranking, a simple inference procedure, for improving multilingual speech recognition (ASR) "in the wild" where models are expected to first predict language identity (LID) before transcribing. Our method considers N-best LID predictions for each utterance, runs the corresponding ASR in N different languages, and then uses external features over the candidate transcriptions to determine re-rank. | ||
|
||
The workflow is as follows: 1) run LID+ASR inference (MMS and Whisper are supported), 2) compute external re-ranking features, 3) tune feature coefficients on dev set, and 4) apply on test set. | ||
|
||
For more information about our method, please refer to the paper: "Improving Multilingual ASR in the Wild Using Simple N-best Re-ranking". | ||
|
||
## 1) Commands to Run LID+ASR Inference | ||
|
||
### Data Prep | ||
Prepare a text file with one path to a wav file in each line: | ||
``` | ||
#/path/to/wav/list | ||
/path/to/audio1.wav | ||
/path/to/audio2.wav | ||
/path/to/audio3.wav | ||
``` | ||
|
||
The following workflow also assumes that LID and ASR references are available (at least for the dev set). We use [3-letter iso codes](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l4017_langs.html) for both Whisper and MMS. | ||
|
||
Next run either Whisper or MMS based LID+ASR. | ||
|
||
### Whisper | ||
Refer to the [Whisper documentation](https://github.com/openai/whisper) for installation instructions. | ||
|
||
First run LID: | ||
``` | ||
python whisper/infer_lid.py --wavs "path/to/wav/list" --dst "path/to/lid/results" --model large-v2 --n 10 | ||
``` | ||
Note that the size of the N-best list is set as 10 here. | ||
|
||
Then run ASR, using the top-N LID predictions: | ||
``` | ||
python whisper/infer_asr.py --wavs "path/to/wav/list" --lids "path/to/lid/results"/nbest_lid --dst "path/to/asr/results" --model large-v2 | ||
``` | ||
|
||
### MMS | ||
Refer to the [Fairseq documentation](https://github.com/facebookresearch/fairseq/tree/main) for installation instructions. | ||
|
||
Prepare data and models following the [instructions from the MMS repository](https://github.com/facebookresearch/fairseq/tree/main/examples/mms). Note that the MMS backend expects a slightly different wav list format, which can be obtained via: | ||
``` | ||
python mms/format_wav_list.py --src "/path/to/wav/list" --dst "/path/to/wav/manifest.tsv" | ||
``` | ||
Note that MMS also expects LID references in a file named `"/path/to/wav/manifest.lang"`. | ||
|
||
Then run LID: | ||
``` | ||
cd "path/to/fairseq/dir" | ||
PYTHONPATH='.' python3 examples/mms/lid/infer.py "path/to/dict/dir" --path "path/to/model" --task audio_classification --infer-manifest "path/to/wav/manifest.tsv" --output-path "path/to/lid/results" --top-k 10 | ||
``` | ||
Note that the size of the N-best list is set as 10 here. | ||
|
||
Then run ASR, using the top-N LID predictions. Since MMS uses language-specific parameters, we've parallelized inference across languages: | ||
``` | ||
#Split data by language | ||
python mms/split_by_lang.py --wavs_tsv "/path/to/wav/manifest.tsv" --lid_preds "path/to/lid/results"predictions.txt --dst "path/to/data/split" | ||
#Write language-specific ASR python commands to an executable file | ||
mms/make_parallel_single_runs.py --dump "path/to/data/split" --model "path/to/model" --dst "path/to/asr/results" --fairseq_dir "path/to/fairseq/dir" > run.sh | ||
#Running each language sequentially (you can also parallelize this) | ||
. ./run.sh | ||
#Merge language-specific results back to original order | ||
python mms/merge_by_run.py --dump "path/to/data/split" --exp "path/to/asr/results" | ||
``` | ||
|
||
## 2) Commands to Compute External Re-ranking Features | ||
|
||
### MaLA - Large Language Model | ||
``` | ||
python mala/infer.py --txt "path/to/asr/results"/nbest_asr_hyp --dst "path/to/lm/results" | ||
``` | ||
|
||
### NLLB - Written LID Model | ||
Download the model from the [official source](https://github.com/facebookresearch/fairseq/tree/nllb#lid-model). | ||
|
||
``` | ||
python nllb/infer.py --txt "path/to/asr/results"/nbest_asr_hyp --dst "path/to/wlid/results" --model "path/to/nllb/model" | ||
``` | ||
|
||
### MMS-Zeroshot - U-roman Acoustic Model | ||
Download the model from the [official source](https://huggingface.co/spaces/mms-meta/mms-zeroshot/tree/main). | ||
|
||
First run u-romanization on the N-best ASR hypotheses: | ||
``` | ||
python mms-zs/uromanize.py --txt "path/to/asr/results"/nbest_asr_hyp --lid "path/to/lid/results"/nbest_lid --dst "path/to/uasr/results" --model "path/to/mms-zeroshot" | ||
``` | ||
|
||
Then compute the forced alignment score using the MMS-Zeroshot model: | ||
``` | ||
python mms-zs/falign.py --uroman_txt "path/to/uasr/results"/nbest_asr_hyp_uroman --wav "path/to/wav/list" --dst "path/to/uasr/results" --model "path/to/mms-zeroshot" | ||
``` | ||
|
||
## 3) Commands to Tune Feature Coefficients | ||
``` | ||
python rerank/tune_coefficients.py --slid "path/to/lid/results"/slid_score --asr "path/to/asr/results"/asr_score --wlid "path/to/wlid/results"/wlid_score --lm "path/to/lm/results"/lm_score --uasr "path/to/uasr/results"/uasr_score --dst "path/to/rerank/results" --ref_lid "ground-truth/lid" --nbest_lid "path/to/lid/results"/nbest_lid --ref_asr "ground-truth/asr" --nbest_asr "path/to/asr/results"/nbest_asr_hyp | ||
``` | ||
|
||
## 4) Commands to Apply on Test Set | ||
``` | ||
python rerank/rerank.py --slid "path/to/lid/results"/slid_score --asr "path/to/asr/results"/asr_score --wlid "path/to/wlid/results"/wlid_score --lm "path/to/lm/results"/lm_score --uasr "path/to/uasr/results"/uasr_score --dst "path/to/rerank/results" --ref_lid "ground-truth/lid" --nbest_lid "path/to/lid/results"/nbest_lid --ref_asr "ground-truth/asr" --nbest_asr "path/to/asr/results"/nbest_asr_hyp --w "path/to/rerank/results"/best_coefficients | ||
``` | ||
|
||
The re-ranked LID and ASR will be in `"path/to/rerank/results"/reranked_1best_lid` and `"path/to/rerank/results"/reranked_1best_asr_hyp` respectively. | ||
|
||
# Citation | ||
``` | ||
@article{yan2024wild, | ||
title={Improving Multilingual ASR in the Wild Using Simple N-best Re-ranking}, | ||
author={Brian Yan, Vineel Pratap, Shinji Watanabe, Michael Auli}, | ||
journal={arXiv}, | ||
year={2024} | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
adx | ||
bod | ||
cmn | ||
dzo | ||
jpn | ||
khg | ||
khm | ||
lao | ||
mya | ||
tha | ||
yue |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from transformers import AutoTokenizer, AutoModelForCausalLM | ||
from peft import PeftModel | ||
from tqdm import tqdm | ||
import argparse | ||
import os | ||
import torch | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--txt", type=str) | ||
parser.add_argument("--dst", type=str) | ||
parser.add_argument("--gpu", type=int, default=1) | ||
args = parser.parse_args() | ||
|
||
if __name__ == "__main__": | ||
if not os.path.exists(args.dst): | ||
os.makedirs(args.dst) | ||
|
||
base_model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf') | ||
base_model.resize_token_embeddings(260164) | ||
tokenizer = AutoTokenizer.from_pretrained('MaLA-LM/mala-500') | ||
if args.gpu == 1: | ||
model = PeftModel.from_pretrained(base_model, 'MaLA-LM/mala-500').to("cuda") | ||
else: | ||
model = PeftModel.from_pretrained(base_model, 'MaLA-LM/mala-500') | ||
model.eval() | ||
|
||
txts = [x.strip() for x in open(args.txt, "r").readlines()] | ||
|
||
with open(args.dst + "/lm_score", "w", buffering=1) as f: | ||
for t in tqdm(txts): | ||
input_tokens = tokenizer("", add_special_tokens=True, return_tensors='pt').input_ids | ||
if len(t) > 0: | ||
output_tokens = tokenizer(t, add_special_tokens=False, return_tensors='pt').input_ids | ||
tokens = torch.cat([input_tokens, output_tokens], dim=1) | ||
length = output_tokens.shape[-1] | ||
else: | ||
tokens = input_tokens | ||
length = 0 | ||
|
||
if args.gpu == 1: | ||
tokens = tokens.to("cuda") | ||
|
||
with torch.no_grad(): | ||
outputs = model(tokens) | ||
logits = outputs.logits | ||
|
||
log_sum = 0 | ||
for i in range(tokens.shape[-1] - 1): | ||
past_tok, current_tok = i, i + 1 | ||
token_logit = logits[0, past_tok, :] | ||
token_log_probs = torch.nn.functional.log_softmax(token_logit, dim=-1) | ||
log_token_prob = token_log_probs[tokens[0, current_tok]].item() | ||
log_sum += log_token_prob | ||
|
||
f.write(str(log_sum) + "\n") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import os | ||
import tempfile | ||
import re | ||
import librosa | ||
import torch | ||
import json | ||
import numpy as np | ||
import argparse | ||
from tqdm import tqdm | ||
import math | ||
|
||
from transformers import Wav2Vec2ForCTC, AutoProcessor | ||
|
||
from lib import falign_ext | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--uroman_txt", type=str) | ||
parser.add_argument("--wav", type=str) | ||
parser.add_argument("--dst", type=str) | ||
parser.add_argument("--model", type=str) | ||
parser.add_argument("--n", type=int, default=10) | ||
args = parser.parse_args() | ||
|
||
ASR_SAMPLING_RATE = 16_000 | ||
|
||
MODEL_ID = "/upload/mms_zs" | ||
|
||
processor = AutoProcessor.from_pretrained(args.model+MODEL_ID) | ||
model = Wav2Vec2ForCTC.from_pretrained(args.model+MODEL_ID) | ||
|
||
token_file = args.model+"/upload/mms_zs/tokens.txt" | ||
|
||
if __name__ == "__main__": | ||
if not os.path.exists(args.dst): | ||
os.makedirs(args.dst) | ||
|
||
tokens = [x.strip() for x in open(token_file, "r").readlines()] | ||
|
||
txts = [x.strip() for x in open(args.uroman_txt, "r").readlines()] | ||
wavs = [x.strip() for x in open(args.wav, "r").readlines()] | ||
assert len(txts) == args.n * len(wavs) | ||
|
||
if torch.cuda.is_available(): | ||
device = torch.device("cuda") | ||
elif ( | ||
hasattr(torch.backends, "mps") | ||
and torch.backends.mps.is_available() | ||
and torch.backends.mps.is_built() | ||
): | ||
device = torch.device("mps") | ||
else: | ||
device = torch.device("cpu") | ||
|
||
model.to(device) | ||
|
||
# clear it | ||
with open(args.dst + "/uasr_score", "w") as f1: | ||
pass | ||
|
||
for i, w in tqdm(enumerate(wavs)): | ||
assert isinstance(w, str) | ||
audio_samples = librosa.load(w, sr=ASR_SAMPLING_RATE, mono=True)[0] | ||
|
||
inputs = processor( | ||
audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt" | ||
) | ||
inputs = inputs.to(device) | ||
|
||
with torch.no_grad(): | ||
outputs = model(**inputs).logits | ||
|
||
emissions = outputs.log_softmax(dim=-1).squeeze() | ||
|
||
for j in range(args.n): | ||
idx = (args.n * i) + j | ||
chars = txts[idx].split() | ||
token_sequence = [tokens.index(x) for x in chars] | ||
|
||
try: | ||
_, alphas, _ = falign_ext.falign(emissions, torch.tensor(token_sequence, device=device).int(), False) | ||
aligned_alpha = max(alphas[-1]).item() | ||
except: | ||
aligned_alpha = math.log(0.000000001) | ||
|
||
with open(args.dst + "/uasr_score", "a") as f1: | ||
f1.write(str(aligned_alpha) + "\n") | ||
f1.flush() |
Oops, something went wrong.