Commit 3e65fa85 authored by Alessio Brutti's avatar Alessio Brutti
Browse files

parameterization of folders and model file

parent 48f42b43
......@@ -7,8 +7,15 @@ Models are trained on VoxCeleb2 and are available in model/weights.h5
It finds all the wav files in a given folder (hard-coded) and produces a pickle dataset with a list of arrays containing the VGGvox embeddings for each file.
To Do:
- parameterize: input folder, output folder, file with models
Usage:
python src/extract_features.py -h
--input_dir: folder containing the wav files to process. It assumes that files are grouped by session or speaker.
--output_dir: folder where features are stored in a pickle file (one for each session/speaker)
--model: file with model weights.
This code is derived from the python version of VGGvox available here:
https://github.com/linhdvu14/vggvox-speaker-identification
import os
import numpy as np
from model import vggvox_model
#from model import vggvox_model
import utils
from scipy.spatial.distance import cdist, euclidean, cosine
import glob
import pickle as pkl
import argparse
###PARAMETERS
# Signal processing constants fed to utils.get_fft_spectrum / build_buckets.
SAMPLE_RATE = 16000          # Hz; VGGvox models are trained on 16 kHz audio
PREEMPHASIS_ALPHA = 0.97     # standard speech pre-emphasis coefficient
FRAME_LEN = 0.025            # analysis window length in seconds (was 0.05)
FRAME_STEP = 0.005           # hop between frames in seconds (was 0.01)
NUM_FFT = 512                # FFT size; also the model's expected input height
BUCKET_STEP = 1              # bucket granularity (seconds) for build_buckets
MAX_SEC = 10                 # longest utterance bucket considered

# Command-line arguments: input/output folders and the weights file.
parser = argparse.ArgumentParser(description='Parameters for vggvox feature extraction. It assumes that files are grouped by sessions/speaker')
parser.add_argument('--input_dir', type=str, required=True, help='Folder containing the wav file to process')
parser.add_argument('--output_dir', type=str, required=True, help='Folder where features are stored')
parser.add_argument('--model', type=str, default='model/weights.h5', help='file where model weights are stored')
args = parser.parse_args()

# Model: build the VGGvox architecture and load the trained weights.
# (WEIGHTS_FILE comes solely from --model; the previous hard-coded path
# "model/weights.h5" is now just the argparse default above.)
WEIGHTS_FILE = args.model
print("Loading model weights from [{}]....".format(WEIGHTS_FILE))
model = vggvox_model(NUM_FFT)
model.load_weights(WEIGHTS_FILE)
......@@ -26,10 +34,11 @@ model.summary()
buckets = utils.build_buckets(MAX_SEC, BUCKET_STEP, FRAME_STEP)

# Each sub-folder of input_dir is one session/speaker; all its wav files are
# embedded and dumped together into <output_dir>/<session_name>.pkl.
for foldername in sorted(glob.glob(args.input_dir + '/*')):
    print(foldername)
    # basename (not split('/')[1]) so this works when input_dir itself
    # contains path separators, e.g. data/wavs/ -> session folder name.
    session_name = os.path.basename(foldername.rstrip('/'))
    embs = []
    for wav_file in sorted(glob.glob(foldername + '/*.wav')):
        print('Extracting features from file %s' % wav_file)
        # NOTE(review): this line sits in a gap of the visible diff; it is
        # reconstructed from the commented-out example below — verify.
        signal = utils.get_fft_spectrum(wav_file, buckets, SAMPLE_RATE, NUM_FFT, FRAME_LEN, FRAME_STEP, PREEMPHASIS_ALPHA)
        # Add batch and channel dims for the conv net, then drop them again.
        embs.append(np.squeeze(model.predict(signal.reshape(1, *signal.shape, 1))))
    print(len(embs))
    print('Dumping vggvox embeddings')
    # 'with' guarantees the pickle file is closed even if dump raises.
    with open(os.path.join(args.output_dir, session_name + '.pkl'), 'wb') as fdump:
        pkl.dump(embs, fdump, protocol=pkl.HIGHEST_PROTOCOL)
#wav_file='Tosca_segments/ANTS_1193130232_59/054.wav'
#print('Extracting features from file %s'%wav_file)
#signal = utils.get_fft_spectrum(wav_file, buckets,SAMPLE_RATE,NUM_FFT,FRAME_LEN,FRAME_STEP,PREEMPHASIS_ALPHA)
#print(signal.shape)
#print('Embeddings...')
#embs = np.squeeze(model.predict(signal.reshape(1,*signal.shape,1)))
#print(embs.shape)
#print(embs.reshape(-1,1).shape)
#print(cdist(embs.reshape(-1,1).T, embs.reshape(-1,1).T, metric=COST_METRIC))
#print(cosine(embs1,embs2))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment