Commit 3196fe90 authored by Alessio Brutti's avatar Alessio Brutti
Browse files

second commit, minor changes

parent c7757109
import numpy as np
import pickle as pkl
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import roc_curve
import os
import sys
import argparse
import glob
def load_features(input_folder):
# A lot of hardcoded stuff for librispeech format, ID probably should have been saved in extract features
feat=[]
ids=[]
for filename in sorted(glob.glob(args.input_folder+'/*')):
#print(filename)
ids.append(filename.split('/')[-1].split('-')[0])
ifile=open(filename,'rb')
feat.append(pkl.load(ifile))
return ids, feat
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Compute similarity and EER on a set of embeddings")
parser.add_argument("-i", "--input_folder", type=str, help="Embeddings folder")
args = parser.parse_args()
ids, feats = load_features(args.input_folder)
d = cosine_similarity(np.stack(feats))
dup=np.triu(d+1,1) ##I take only the upper part
y_score=d[dup>0]
#del d
#del dup
labels=np.array([i==j for i in ids for j in ids], dtype=int ).reshape(-1,len(ids))
dup=np.triu(labels+1,1) ##I take only the upper part
y_true=labels[dup>0]
fpr, tpr, threshold = roc_curve(y_true, y_score)
fnr = 1 - tpr
EER = fpr[np.nanargmin(np.abs((fnr - fpr)))]
print(EER)
EER = fnr[np.nanargmin(np.absolute((fnr - fpr)))]
print(EER)
......@@ -47,13 +47,12 @@ for foldername in sorted(glob.glob(args.input_dir+'/*/')):
print(audio_file)
signal = utils.get_fft_spectrum(audio_file, buckets,SAMPLE_RATE,NUM_FFT,FRAME_LEN,FRAME_STEP,PREEMPHASIS_ALPHA)
feats.append(np.squeeze(model.predict(signal.reshape(1,*signal.shape,1))))
if not os.path.isdir(args.output_dir+'/'+sessionname):
os.mkdir(args.output_dir+'/'+sessionname)
fdump=open(audio_file.replace(args.input_dir,args.output_dir).replace('wav','pkl'),'wb')
pkl.dump(feats[-1], fdump, protocol=pkl.HIGHEST_PROTOCOL)
fdump.close()
#fdump=open(args.output_dir+'/'+audio_file.split('/')[-1].split('.')[0]+'.pkl','wb')
#pkl.dump(embs, fdump, protocol=pkl.HIGHEST_PROTOCOL)
#fdump.close()
d = cosine_similarity(np.stack(feats))
print(d.shape)
print(d)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment