Skip to content

Commit

Permalink
Add LID rerank for MMS (#5545)
Browse files Browse the repository at this point in the history
* init lid rerank

* init lid rerank

* add greedy ctc score
  • Loading branch information
brianyan918 authored Sep 27, 2024
1 parent 920a548 commit c214511
Show file tree
Hide file tree
Showing 20 changed files with 1,376 additions and 4 deletions.
115 changes: 115 additions & 0 deletions examples/mms/lid_rerank/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# N-best Re-ranking for Multilingual LID+ASR
This project provides N-best re-ranking, a simple inference procedure, for improving multilingual speech recognition (ASR) "in the wild" where models are expected to first predict language identity (LID) before transcribing. Our method considers N-best LID predictions for each utterance, runs the corresponding ASR in N different languages, and then uses external features over the candidate transcriptions to determine re-rank.

The workflow is as follows: 1) run LID+ASR inference (MMS and Whisper are supported), 2) compute external re-ranking features, 3) tune feature coefficients on dev set, and 4) apply on test set.

For more information about our method, please refer to the paper: "Improving Multilingual ASR in the Wild Using Simple N-best Re-ranking".

## 1) Commands to Run LID+ASR Inference

### Data Prep
Prepare a text file with one path to a wav file in each line:
```
#/path/to/wav/list
/path/to/audio1.wav
/path/to/audio2.wav
/path/to/audio3.wav
```

The following workflow also assumes that LID and ASR references are available (at least for the dev set). We use [3-letter iso codes](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l4017_langs.html) for both Whisper and MMS.

Next run either Whisper or MMS based LID+ASR.

### Whisper
Refer to the [Whisper documentation](https://github.com/openai/whisper) for installation instructions.

First run LID:
```
python whisper/infer_lid.py --wavs "path/to/wav/list" --dst "path/to/lid/results" --model large-v2 --n 10
```
Note that the size of the N-best list is set as 10 here.

Then run ASR, using the top-N LID predictions:
```
python whisper/infer_asr.py --wavs "path/to/wav/list" --lids "path/to/lid/results"/nbest_lid --dst "path/to/asr/results" --model large-v2
```

### MMS
Refer to the [Fairseq documentation](https://github.com/facebookresearch/fairseq/tree/main) for installation instructions.

Prepare data and models following the [instructions from the MMS repository](https://github.com/facebookresearch/fairseq/tree/main/examples/mms). Note that the MMS backend expects a slightly different wav list format, which can be obtained via:
```
python mms/format_wav_list.py --src "/path/to/wav/list" --dst "/path/to/wav/manifest.tsv"
```
Note that MMS also expects LID references in a file named `"/path/to/wav/manifest.lang"`.

Then run LID:
```
cd "path/to/fairseq/dir"
PYTHONPATH='.' python3 examples/mms/lid/infer.py "path/to/dict/dir" --path "path/to/model" --task audio_classification --infer-manifest "path/to/wav/manifest.tsv" --output-path "path/to/lid/results" --top-k 10
```
Note that the size of the N-best list is set as 10 here.

Then run ASR, using the top-N LID predictions. Since MMS uses language-specific parameters, we've parallelized inference across languages:
```
#Split data by language
python mms/split_by_lang.py --wavs_tsv "/path/to/wav/manifest.tsv" --lid_preds "path/to/lid/results"predictions.txt --dst "path/to/data/split"
#Write language-specific ASR python commands to an executable file
mms/make_parallel_single_runs.py --dump "path/to/data/split" --model "path/to/model" --dst "path/to/asr/results" --fairseq_dir "path/to/fairseq/dir" > run.sh
#Running each language sequentially (you can also parallelize this)
. ./run.sh
#Merge language-specific results back to original order
python mms/merge_by_run.py --dump "path/to/data/split" --exp "path/to/asr/results"
```

## 2) Commands to Compute External Re-ranking Features

### MaLA - Large Language Model
```
python mala/infer.py --txt "path/to/asr/results"/nbest_asr_hyp --dst "path/to/lm/results"
```

### NLLB - Written LID Model
Download the model from the [official source](https://github.com/facebookresearch/fairseq/tree/nllb#lid-model).

```
python nllb/infer.py --txt "path/to/asr/results"/nbest_asr_hyp --dst "path/to/wlid/results" --model "path/to/nllb/model"
```

### MMS-Zeroshot - U-roman Acoustic Model
Download the model from the [official source](https://huggingface.co/spaces/mms-meta/mms-zeroshot/tree/main).

First run u-romanization on the N-best ASR hypotheses:
```
python mms-zs/uromanize.py --txt "path/to/asr/results"/nbest_asr_hyp --lid "path/to/lid/results"/nbest_lid --dst "path/to/uasr/results" --model "path/to/mms-zeroshot"
```

Then compute the forced alignment score using the MMS-Zeroshot model:
```
python mms-zs/falign.py --uroman_txt "path/to/uasr/results"/nbest_asr_hyp_uroman --wav "path/to/wav/list" --dst "path/to/uasr/results" --model "path/to/mms-zeroshot"
```

## 3) Commands to Tune Feature Coefficients
```
python rerank/tune_coefficients.py --slid "path/to/lid/results"/slid_score --asr "path/to/asr/results"/asr_score --wlid "path/to/wlid/results"/wlid_score --lm "path/to/lm/results"/lm_score --uasr "path/to/uasr/results"/uasr_score --dst "path/to/rerank/results" --ref_lid "ground-truth/lid" --nbest_lid "path/to/lid/results"/nbest_lid --ref_asr "ground-truth/asr" --nbest_asr "path/to/asr/results"/nbest_asr_hyp
```

## 4) Commands to Apply on Test Set
```
python rerank/rerank.py --slid "path/to/lid/results"/slid_score --asr "path/to/asr/results"/asr_score --wlid "path/to/wlid/results"/wlid_score --lm "path/to/lm/results"/lm_score --uasr "path/to/uasr/results"/uasr_score --dst "path/to/rerank/results" --ref_lid "ground-truth/lid" --nbest_lid "path/to/lid/results"/nbest_lid --ref_asr "ground-truth/asr" --nbest_asr "path/to/asr/results"/nbest_asr_hyp --w "path/to/rerank/results"/best_coefficients
```

The re-ranked LID and ASR will be in `"path/to/rerank/results"/reranked_1best_lid` and `"path/to/rerank/results"/reranked_1best_asr_hyp` respectively.

# Citation
```
@article{yan2024wild,
title={Improving Multilingual ASR in the Wild Using Simple N-best Re-ranking},
author={Brian Yan, Vineel Pratap, Shinji Watanabe, Michael Auli},
journal={arXiv},
year={2024}
}
```
11 changes: 11 additions & 0 deletions examples/mms/lid_rerank/cer_langs.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
adx
bod
cmn
dzo
jpn
khg
khm
lao
mya
tha
yue
55 changes: 55 additions & 0 deletions examples/mms/lid_rerank/mala/infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from tqdm import tqdm
import argparse
import os
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--txt", type=str)
parser.add_argument("--dst", type=str)
parser.add_argument("--gpu", type=int, default=1)
args = parser.parse_args()

if __name__ == "__main__":
if not os.path.exists(args.dst):
os.makedirs(args.dst)

base_model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')
base_model.resize_token_embeddings(260164)
tokenizer = AutoTokenizer.from_pretrained('MaLA-LM/mala-500')
if args.gpu == 1:
model = PeftModel.from_pretrained(base_model, 'MaLA-LM/mala-500').to("cuda")
else:
model = PeftModel.from_pretrained(base_model, 'MaLA-LM/mala-500')
model.eval()

txts = [x.strip() for x in open(args.txt, "r").readlines()]

with open(args.dst + "/lm_score", "w", buffering=1) as f:
for t in tqdm(txts):
input_tokens = tokenizer("", add_special_tokens=True, return_tensors='pt').input_ids
if len(t) > 0:
output_tokens = tokenizer(t, add_special_tokens=False, return_tensors='pt').input_ids
tokens = torch.cat([input_tokens, output_tokens], dim=1)
length = output_tokens.shape[-1]
else:
tokens = input_tokens
length = 0

if args.gpu == 1:
tokens = tokens.to("cuda")

with torch.no_grad():
outputs = model(tokens)
logits = outputs.logits

log_sum = 0
for i in range(tokens.shape[-1] - 1):
past_tok, current_tok = i, i + 1
token_logit = logits[0, past_tok, :]
token_log_probs = torch.nn.functional.log_softmax(token_logit, dim=-1)
log_token_prob = token_log_probs[tokens[0, current_tok]].item()
log_sum += log_token_prob

f.write(str(log_sum) + "\n")
87 changes: 87 additions & 0 deletions examples/mms/lid_rerank/mms-zs/falign.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import os
import tempfile
import re
import librosa
import torch
import json
import numpy as np
import argparse
from tqdm import tqdm
import math

from transformers import Wav2Vec2ForCTC, AutoProcessor

from lib import falign_ext

parser = argparse.ArgumentParser()
parser.add_argument("--uroman_txt", type=str)
parser.add_argument("--wav", type=str)
parser.add_argument("--dst", type=str)
parser.add_argument("--model", type=str)
parser.add_argument("--n", type=int, default=10)
args = parser.parse_args()

ASR_SAMPLING_RATE = 16_000

MODEL_ID = "/upload/mms_zs"

processor = AutoProcessor.from_pretrained(args.model+MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(args.model+MODEL_ID)

token_file = args.model+"/upload/mms_zs/tokens.txt"

if __name__ == "__main__":
if not os.path.exists(args.dst):
os.makedirs(args.dst)

tokens = [x.strip() for x in open(token_file, "r").readlines()]

txts = [x.strip() for x in open(args.uroman_txt, "r").readlines()]
wavs = [x.strip() for x in open(args.wav, "r").readlines()]
assert len(txts) == args.n * len(wavs)

if torch.cuda.is_available():
device = torch.device("cuda")
elif (
hasattr(torch.backends, "mps")
and torch.backends.mps.is_available()
and torch.backends.mps.is_built()
):
device = torch.device("mps")
else:
device = torch.device("cpu")

model.to(device)

# clear it
with open(args.dst + "/uasr_score", "w") as f1:
pass

for i, w in tqdm(enumerate(wavs)):
assert isinstance(w, str)
audio_samples = librosa.load(w, sr=ASR_SAMPLING_RATE, mono=True)[0]

inputs = processor(
audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
)
inputs = inputs.to(device)

with torch.no_grad():
outputs = model(**inputs).logits

emissions = outputs.log_softmax(dim=-1).squeeze()

for j in range(args.n):
idx = (args.n * i) + j
chars = txts[idx].split()
token_sequence = [tokens.index(x) for x in chars]

try:
_, alphas, _ = falign_ext.falign(emissions, torch.tensor(token_sequence, device=device).int(), False)
aligned_alpha = max(alphas[-1]).item()
except:
aligned_alpha = math.log(0.000000001)

with open(args.dst + "/uasr_score", "a") as f1:
f1.write(str(aligned_alpha) + "\n")
f1.flush()
Loading

0 comments on commit c214511

Please sign in to comment.