# testing.py
  1. from __future__ import division
  2. from pydub import AudioSegment
  3. from dejavu.decoder import path_to_songname
  4. from dejavu import Dejavu
  5. from dejavu.fingerprint import *
  6. import traceback
  7. import fnmatch
  8. import os, re, ast
  9. import subprocess
  10. import random
  11. import logging
  12. def set_seed(seed=None):
  13. """
  14. `seed` as None means that the sampling will be random.
  15. Setting your own seed means that you can produce the
  16. same experiment over and over.
  17. """
  18. if seed != None:
  19. random.seed(seed)
  20. def get_files_recursive(src, fmt):
  21. """
  22. `src` is the source directory.
  23. `fmt` is the extension, ie ".mp3" or "mp3", etc.
  24. """
  25. for root, dirnames, filenames in os.walk(src):
  26. for filename in fnmatch.filter(filenames, '*' + fmt):
  27. yield os.path.join(root, filename)
  28. def get_length_audio(audiopath, extension):
  29. """
  30. Returns length of audio in seconds.
  31. Returns None if format isn't supported or in case of error.
  32. """
  33. try:
  34. audio = AudioSegment.from_file(audiopath, extension.replace(".", ""))
  35. except:
  36. print "Error in get_length_audio(): %s" % traceback.format_exc()
  37. return None
  38. return int(len(audio) / 1000.0)
  39. def get_starttime(length, nseconds, padding):
  40. """
  41. `length` is total audio length in seconds
  42. `nseconds` is amount of time to sample in seconds
  43. `padding` is off-limits seconds at beginning and ending
  44. """
  45. maximum = length - padding - nseconds
  46. if padding > maximum:
  47. return 0
  48. return random.randint(padding, maximum)
  49. def generate_test_files(src, dest, nseconds, fmts=[".mp3", ".wav"], padding=10):
  50. """
  51. Generates a test file for each file recursively in `src` directory
  52. of given format using `nseconds` sampled from the audio file.
  53. Results are written to `dest` directory.
  54. `padding` is the number of off-limit seconds and the beginning and
  55. end of a track that won't be sampled in testing. Often you want to
  56. avoid silence, etc.
  57. """
  58. # create directories if necessary
  59. for directory in [src, dest]:
  60. try:
  61. os.stat(directory)
  62. except:
  63. os.mkdir(directory)
  64. # find files recursively of a given file format
  65. for fmt in fmts:
  66. testsources = get_files_recursive(src, fmt)
  67. for audiosource in testsources:
  68. print "audiosource:", audiosource
  69. filename, extension = os.path.splitext(os.path.basename(audiosource))
  70. length = get_length_audio(audiosource, extension)
  71. starttime = get_starttime(length, nseconds, padding)
  72. test_file_name = "%s_%s_%ssec.%s" % (
  73. os.path.join(dest, filename), starttime,
  74. nseconds, extension.replace(".", ""))
  75. subprocess.check_output([
  76. "ffmpeg", "-y",
  77. "-ss", "%d" % starttime,
  78. '-t' , "%d" % nseconds,
  79. "-i", audiosource,
  80. test_file_name])
  81. def log_msg(msg, log=True, silent=False):
  82. if log:
  83. logging.debug(msg)
  84. if not silent:
  85. print msg
  86. def autolabel(rects, ax):
  87. # attach some text labels
  88. for rect in rects:
  89. height = rect.get_height()
  90. ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height,
  91. '%d' % int(height), ha='center', va='bottom')
  92. def autolabeldoubles(rects, ax):
  93. # attach some text labels
  94. for rect in rects:
  95. height = rect.get_height()
  96. ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height,
  97. '%s' % round(float(height), 3), ha='center', va='bottom')
  98. class DejavuTest(object):
  99. def __init__(self, folder, seconds):
  100. super(DejavuTest, self).__init__()
  101. self.test_folder = folder
  102. self.test_seconds = seconds
  103. self.test_songs = []
  104. print "test_seconds", self.test_seconds
  105. self.test_files = [
  106. f for f in os.listdir(self.test_folder)
  107. if os.path.isfile(os.path.join(self.test_folder, f))
  108. and re.findall("[0-9]*sec", f)[0] in self.test_seconds]
  109. print "test_files", self.test_files
  110. self.n_columns = len(self.test_seconds)
  111. self.n_lines = int(len(self.test_files) / self.n_columns)
  112. print "columns:", self.n_columns
  113. print "length of test files:", len(self.test_files)
  114. print "lines:", self.n_lines
  115. # variable match results (yes, no, invalid)
  116. self.result_match = [[0 for x in xrange(self.n_columns)] for x in xrange(self.n_lines)]
  117. print "result_match matrix:", self.result_match
  118. # variable match precision (if matched in the corrected time)
  119. self.result_matching_times = [[0 for x in xrange(self.n_columns)] for x in xrange(self.n_lines)]
  120. # variable mahing time (query time)
  121. self.result_query_duration = [[0 for x in xrange(self.n_columns)] for x in xrange(self.n_lines)]
  122. # variable confidence
  123. self.result_match_confidence = [[0 for x in xrange(self.n_columns)] for x in xrange(self.n_lines)]
  124. self.begin()
  125. def get_column_id (self, secs):
  126. for i, sec in enumerate(self.test_seconds):
  127. if secs == sec:
  128. return i
  129. def get_line_id (self, song):
  130. for i, s in enumerate(self.test_songs):
  131. if song == s:
  132. return i
  133. self.test_songs.append(song)
  134. return len(self.test_songs) - 1
  135. def create_plots(self, name, results, results_folder):
  136. for sec in range(0, len(self.test_seconds)):
  137. ind = np.arange(self.n_lines) #
  138. width = 0.25 # the width of the bars
  139. fig = plt.figure()
  140. ax = fig.add_subplot(111)
  141. ax.set_xlim([-1 * width, 2 * width])
  142. means_dvj = [x[0] for x in results[sec]]
  143. rects1 = ax.bar(ind, means_dvj, width, color='r')
  144. # add some
  145. ax.set_ylabel(name)
  146. ax.set_title("%s %s Results" % (self.test_seconds[sec], name))
  147. ax.set_xticks(ind + width)
  148. labels = [0 for x in range(0, self.n_lines)]
  149. for x in range(0, self.n_lines):
  150. labels[x] = "song %s" % (x+1)
  151. ax.set_xticklabels(labels)
  152. box = ax.get_position()
  153. ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])
  154. #ax.legend( (rects1[0]), ('Dejavu'), loc='center left', bbox_to_anchor=(1, 0.5))
  155. if name == 'Confidence':
  156. autolabel(rects1, ax)
  157. else:
  158. autolabeldoubles(rects1, ax)
  159. plt.grid()
  160. fig_name = os.path.join(results_folder, "%s_%s.png" % (name, self.test_seconds[sec]))
  161. fig.savefig(fig_name)
  162. def begin(self):
  163. for f in self.test_files:
  164. log_msg('--------------------------------------------------')
  165. log_msg('file: %s' % f)
  166. # get column
  167. col = self.get_column_id(re.findall("[0-9]*sec", f)[0])
  168. # format: XXXX_offset_length.mp3
  169. song = path_to_songname(f).split("_")[0]
  170. line = self.get_line_id(song)
  171. result = subprocess.check_output([
  172. "python",
  173. "dejavu.py",
  174. '-r',
  175. 'file',
  176. self.test_folder + "/" + f])
  177. if result.strip() == "None":
  178. log_msg('No match')
  179. self.result_match[line][col] = 'no'
  180. self.result_matching_times[line][col] = 0
  181. self.result_query_duration[line][col] = 0
  182. self.result_match_confidence[line][col] = 0
  183. else:
  184. result = result.strip()
  185. result = result.replace(" \'", ' "')
  186. result = result.replace("{\'", '{"')
  187. result = result.replace("\':", '":')
  188. result = result.replace("\',", '",')
  189. # which song did we predict?
  190. result = ast.literal_eval(result)
  191. song_result = result["song_name"]
  192. log_msg('song: %s' % song)
  193. log_msg('song_result: %s' % song_result)
  194. if song_result != song:
  195. log_msg('invalid match')
  196. self.result_match[line][col] = 'invalid'
  197. self.result_matching_times[line][col] = 0
  198. self.result_query_duration[line][col] = 0
  199. self.result_match_confidence[line][col] = 0
  200. else:
  201. log_msg('correct match')
  202. print self.result_match
  203. self.result_match[line][col] = 'yes'
  204. self.result_query_duration[line][col] = round(result[Dejavu.MATCH_TIME],3)
  205. self.result_match_confidence[line][col] = result[Dejavu.CONFIDENCE]
  206. song_start_time = re.findall("\_[^\_]+",f)
  207. song_start_time = song_start_time[0].lstrip("_ ")
  208. result_start_time = round((result[Dejavu.OFFSET] * DEFAULT_WINDOW_SIZE *
  209. DEFAULT_OVERLAP_RATIO) / (DEFAULT_FS), 0)
  210. self.result_matching_times[line][col] = int(result_start_time) - int(song_start_time)
  211. if (abs(self.result_matching_times[line][col]) == 1):
  212. self.result_matching_times[line][col] = 0
  213. log_msg('query duration: %s' % round(result[Dejavu.MATCH_TIME],3))
  214. log_msg('confidence: %s' % result[Dejavu.CONFIDENCE])
  215. log_msg('song start_time: %s' % song_start_time)
  216. log_msg('result start time: %s' % result_start_time)
  217. if (self.result_matching_times[line][col] == 0):
  218. log_msg('accurate match')
  219. else:
  220. log_msg('inaccurate match')
  221. log_msg('--------------------------------------------------\n')