Speech Processing (STT) - wav2vec (Huggingface)
!pip install transformers datasets jiwer
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from datasets import load_dataset
import soundfile as sf
import torch
from jiwer import wer
WER (word error rate): how many words the model gets wrong during inference; lower is better. Unlike plain accuracy (correct/total), WER is edit-distance based: WER = (S + D + I) / N, where S, D, and I are the word substitutions, deletions, and insertions needed to turn the hypothesis into the reference of N words.
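As a quick illustration (made-up English toy strings, not from the dataset used below), jiwer scores a hypothesis against a reference directly:

from jiwer import wer
print(wer("the cat sat on the mat", "the cat sat on mat"))  # one deletion over six reference words -> ~0.167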
processor = Wav2Vec2Processor.from_pretrained("kresnik/wav2vec2-large-xlsr-korean")  # feature extractor + tokenizer
model = Wav2Vec2ForCTC.from_pretrained("kresnik/wav2vec2-large-xlsr-korean").to("cuda")  # XLSR model fine-tuned on Korean
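The loading above hard-codes "cuda" and will fail without a GPU. A minimal sketch of a CPU fallback (the rest of this walkthrough keeps the hard-coded "cuda" for brevity):

import torch
device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU when no GPU is present
model = Wav2Vec2ForCTC.from_pretrained("kresnik/wav2vec2-large-xlsr-korean").to(device)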
ds = load_dataset("kresnik/zeroth_korean", "clean")  # Zeroth-Korean speech corpus, "clean" configuration
test_ds = ds['test']  # evaluate on the test split
type(test_ds)
datasets.arrow_dataset.Dataset
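Each row is a plain dict; peeking at one shows the fields the mapping functions below rely on ('file' and 'text' are the ones used here; any other keys depend on the dataset version):

print(test_ds[0].keys())  # expect at least 'file' (audio path) and 'text' (reference transcription)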
def map_to_array(batch):
    # read the raw waveform (and its sample rate) from the audio file path
    speech, sr = sf.read(batch['file'])
    batch['speech'] = speech
    return batch
test_ds = test_ds.map(map_to_array)
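map_to_array keeps whatever sample rate sf.read returns and never checks it; the processor call below hard-codes sampling_rate=16000, which happens to match Zeroth-Korean. If you swap in audio at a different rate, a resampling step is needed. A sketch using librosa (an assumption; it is not imported above):

import librosa

def map_to_array_resampled(batch):
    speech, sr = sf.read(batch['file'])
    if sr != 16000:
        # resample to the 16 kHz the wav2vec2 processor expects
        speech = librosa.resample(speech, orig_sr=sr, target_sr=16000)
    batch['speech'] = speech
    return batch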
def map_to_pred(batch):
    inputs = processor(batch["speech"], sampling_rate=16000, return_tensors="pt", padding="longest")
    input_values = inputs.input_values.to("cuda")  # normalized raw waveform (wav2vec2 takes raw audio, not a mel spectrogram)
    with torch.no_grad():  # no gradient tracking needed for inference
        logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)  # greedy CTC decoding: pick the most likely token at each frame
    transcription = processor.batch_decode(predicted_ids)
    batch["transcription"] = transcription
    return batch
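Before mapping over the whole test set, it can be worth sanity-checking a single utterance by hand (a sketch reusing the objects above; index 0 is arbitrary):

sample = test_ds[0]
inputs = processor(sample["speech"], sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    logits = model(inputs.input_values.to("cuda")).logits
pred_ids = torch.argmax(logits, dim=-1)
print(processor.decode(pred_ids[0]))  # model's greedy transcription
print(sample["text"])                 # ground-truth reference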
result = test_ds.map(map_to_pred, batched=True, batch_size=16)
print("WER", wer(result["text"], result["transcription"])) # 원래 정답과 예상치 비교(에러 확률; 낮을수록 좋음)
WER 0.04773377503388044
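Since Korean word boundaries depend on spacing, character error rate (CER) is often reported alongside WER. Recent versions of jiwer also expose a cer function with the same calling convention (a sketch over the same result columns):

from jiwer import cer
print("CER", cer(result["text"], result["transcription"]))  # character-level error rate; also lower-is-better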