wav2vec 2.0 (Hugging Face)

In [1]:
!pip install transformers datasets jiwer
In [2]:
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from datasets import load_dataset
import soundfile as sf
import torch
from jiwer import wer

WER (word error rate): how many words the model gets wrong at inference time; lower is better. Unlike plain accuracy (correct / total), WER counts substitutions S, deletions D, and insertions I against the number of reference words N: WER = (S + D + I) / N, so it can even exceed 1.
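
As a quick sanity check (a toy example added here, not from the original notebook), jiwer can score a hand-written reference/hypothesis pair directly:

from jiwer import wer

reference = "the cat sat on the mat"  # 6 reference words
hypothesis = "the cat sit on mat"     # 1 substitution (sat -> sit), 1 deletion (the)
print(wer(reference, hypothesis))     # (1 + 1 + 0) / 6 ≈ 0.333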

In [3]:
processor = Wav2Vec2Processor.from_pretrained("kresnik/wav2vec2-large-xlsr-korean")
model = Wav2Vec2ForCTC.from_pretrained("kresnik/wav2vec2-large-xlsr-korean").to("cuda")
ds = load_dataset("kresnik/zeroth_korean", "clean")
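
This checkpoint was fine-tuned on 16 kHz audio, and the processor records the rate it expects; a quick check (not part of the original run):

print(processor.feature_extractor.sampling_rate)  # 16000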
In [4]:
test_ds = ds['test']
In [5]:
type(test_ds)
Out [5]:
datasets.arrow_dataset.Dataset
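
Printing the split (also not part of the original run) summarizes the Arrow-backed table; the column names below are indicative, with "file" holding audio paths and "text" the reference transcripts used later for scoring:

print(test_ds)
# Dataset({
#     features: ['file', 'text', ...],
#     num_rows: 457
# })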
In [6]:
def map_to_array(batch):
    # read the waveform from disk; sf.read returns (samples, sampling_rate)
    speech, sr = sf.read(batch['file'])
    batch['speech'] = speech
    return batch
In [7]:
test_ds = test_ds.map(map_to_array)
Out [7]:
  0%|          | 0/457 [00:00<?, ?ex/s]
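
As an aside (a sketch assuming a recent datasets release with the Audio feature), the manual soundfile pass can be replaced by casting the path column, letting the library decode and resample lazily:

from datasets import Audio

# cast the "file" column of paths to an Audio feature decoded at 16 kHz
test_ds = test_ds.cast_column("file", Audio(sampling_rate=16000))
speech = test_ds[0]["file"]["array"]  # NumPy waveform of the first example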
In [8]:
def map_to_pred(batch):
    inputs = processor(batch["speech"], sampling_rate=16000, return_tensors="pt", padding="longest")
    input_values = inputs.input_values.to("cuda")  # normalized raw waveform (wav2vec 2.0 consumes raw audio, not a spectrogram)
    with torch.no_grad():  # inference only, so skip gradient tracking
        logits = model(input_values).logits
    # greedy CTC decoding: take the most likely token per frame, then let the
    # processor collapse repeated tokens and strip CTC blanks
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)
    batch["transcription"] = transcription
    return batch

result = test_ds.map(map_to_pred, batched=True, batch_size=16)
Out [8]:
  0%|          | 0/29 [00:00<?, ?ba/s]
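
To transcribe a single recording outside the batched map (a minimal sketch reusing the processor and model loaded above; sample.wav is a hypothetical 16 kHz mono file), the same steps look like:

speech, sr = sf.read("sample.wav")  # hypothetical input file
inputs = processor(speech, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    logits = model(inputs.input_values.to("cuda")).logits
pred_ids = torch.argmax(logits, dim=-1)
print(processor.batch_decode(pred_ids)[0])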
In [9]:
print("WER", wer(result["text"], result["transcription"])) # 원래 정답과 예상치 비교(에러 확률; 낮을수록 좋음)
Out [9]:
WER 0.04773377503388044
