From: Mart Lubbers Date: Sun, 30 Apr 2017 19:36:28 +0000 (+0200) Subject: add prediction X-Git-Url: https://git.martlubbers.net/?a=commitdiff_plain;h=8fdbbd451426bebb71e686f19bcf146769d02968;p=asr1617data.git add prediction --- diff --git a/experiments.py b/experiments.py index 699c68c..911de6b 100644 --- a/experiments.py +++ b/experiments.py @@ -126,7 +126,7 @@ if __name__ == '__main__': for winlen, winstep in ((0.025, 0.01), (0.1, 0.04), (0.2, 0.08)): for name, model in (('simple', simplemodel), ('bottle', bottlemodel)): m = run('mfcc', winlen, winstep, model, name) - fproot = 'model_{}_{}_{}.json'.format(winlen, winstep, name) + fproot = 'model_{}_{}_{}'.format(winlen, winstep, name) with open('{}.json'.format(fproot), 'w') as f: f.write(m.to_json()) m.save_weights('{}.hdf5'.format(fproot)) diff --git a/predict.py b/predict.py index 962b9fd..b561c65 100644 --- a/predict.py +++ b/predict.py @@ -1,26 +1,45 @@ import numpy as np import sys +import pympi import scipy.io.wavfile as wav import numpy as np -from python_speech_featuresimport mfcc +from python_speech_features import mfcc from keras.models import model_from_json modelfile = sys.argv[1] -hdf5file = '{}.hdf5'.format(modelfile[-4:]) +hdf5file = '{}.hdf5'.format(modelfile[:-5]) -with open(modelfile, 'r') as f: +with open(modelfile, 'r', encoding='utf-8') as f: json = f.read() model = model_from_json(json) model.load_weights(hdf5file) +(_, winlen, winstep, _) = modelfile.split('_') +winlen = float(winlen) +winstep = float(winstep) + model.compile( loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy']) -(rate, sig) = wav.read(sys.stdin.buffer) -data = mfcc(sig, rate, winlen -for i in model.predict(dat, batch_size=32, verbose=0): - print(i[0]) +(rate, sig) = wav.read(sys.argv[2], mmap=True) +data = mfcc(sig, rate, winlen, winstep, numcep=13, appendEnergy=True) +tgob = pympi.TextGrid(xmax=winstep*len(data)) +tier = tgob.add_tier('lyrics') + +time = 0.0 +lastlabel = False +lasttime = 0.0 +for i in model.predict(data, batch_size=32, verbose=0): +# print('{}\t{}'.format(time, i)) + label = i > 0.5 + if label != lastlabel and time-lasttime > 0.5: + tier.add_interval(lasttime, time, '*' if lastlabel else '') + lastlabel = label + lasttime = time + + time += winstep +tgob.to_file('/dev/stdout')