Commit 9b432819 authored by Lukáš Lalinský

All fingerprints in a track must be above a certain similarity threshold when compared to each other

parent ae34d121
......@@ -6,7 +6,7 @@ from sqlalchemy import sql
from acoustid import tables as schema
from acoustid.data.fingerprint import lookup_fingerprint, insert_fingerprint, inc_fingerprint_submission_count
from acoustid.data.musicbrainz import find_puid_mbids, resolve_mbid_redirect
from acoustid.data.track import insert_track, insert_mbid, insert_puid, merge_tracks, insert_track_meta
from acoustid.data.track import insert_track, insert_mbid, insert_puid, merge_tracks, insert_track_meta, can_add_fp_to_track
logger = logging.getLogger(__name__)
TRACK_MERGE_TRESHOLD = 0.7
......@@ -71,19 +71,20 @@ def import_submission(conn, submission):
match = matches[0]
logger.debug("Matches %d results, the top result %s with track %d is %d%% similar",
len(matches), match['id'], match['track_id'], match['score'] * 100)
fingerprint['track_id'] = match['track_id']
if match['score'] > FINGERPRINT_MERGE_TRESHOLD:
fingerprint['id'] = match['id']
all_track_ids = set([match['track_id']])
for m in matches:
if m['track_id'] not in all_track_ids:
logger.debug("Fingerprint %d with track %d is %d%% similar",
m['id'], m['track_id'], m['score'] * 100)
all_track_ids.add(m['track_id'])
if len(all_track_ids) > 1:
fingerprint['track_id'] = min(all_track_ids)
all_track_ids.remove(fingerprint['track_id'])
merge_tracks(conn, fingerprint['track_id'], list(all_track_ids))
if can_add_fp_to_track(conn, match['track_id'], submission['fingerprint']):
fingerprint['track_id'] = match['track_id']
all_track_ids = set([match['track_id']])
for m in matches:
if m['track_id'] not in all_track_ids:
logger.debug("Fingerprint %d with track %d is %d%% similar",
m['id'], m['track_id'], m['score'] * 100)
all_track_ids.add(m['track_id'])
if len(all_track_ids) > 1 and fingerprint['track_id'] == match['track_id']:
fingerprint['track_id'] = min(all_track_ids)
all_track_ids.remove(fingerprint['track_id'])
merge_tracks(conn, fingerprint['track_id'], list(all_track_ids))
if not fingerprint['track_id']:
fingerprint['track_id'] = insert_track(conn)
logger.info('Added new track %d', fingerprint['track_id'])
......
......@@ -201,3 +201,39 @@ def calculate_fingerprint_similarity_matrix(conn, track_ids):
result.setdefault(fp2_id, {})[fp2_id] = 1.0
return result
def can_merge_tracks(conn, track_ids):
    """Group the given tracks by pairwise fingerprint similarity.

    The query self-joins the fingerprint table (``fp1.id < fp2.id`` and
    ``fp1.track_id < fp2.track_id``), restricts both sides to ``track_ids``
    and, for every resulting pair of tracks, asks the database for the
    minimum ``acoustid_compare2`` score over the joined fingerprints.
    Track pairs whose minimum score is below 0.3 are ignored; each
    remaining pair puts the higher-id track into the group of the lower-id
    track.

    Returns a list of sets of track ids, one set per group.  Tracks that
    were not linked to any other track do not appear in the result.

    NOTE(review): when a track qualifies against several lower-id tracks,
    the last qualifying pair overwrites its earlier group assignment --
    confirm this is the intended grouping.
    """
    left = schema.fingerprint.alias('fp1')
    right = schema.fingerprint.alias('fp2')

    # Only consider each fingerprint/track pair once, lower ids on the left.
    pair_source = left.join(
        right,
        sql.and_(left.c.id < right.c.id, left.c.track_id < right.c.track_id))
    both_selected = sql.and_(
        left.c.track_id.in_(track_ids),
        right.c.track_id.in_(track_ids))

    pair_columns = [left.c.track_id, right.c.track_id]
    worst_score = sql.func.min(
        sql.func.acoustid_compare2(left.c.fingerprint, right.c.fingerprint))
    query = sql.select(
        pair_columns + [worst_score], both_selected, from_obj=pair_source)
    query = query.group_by(*pair_columns).order_by(*pair_columns)

    # Maps a track id to the id of the group it was attached to.
    group_of = {}
    for low_track, high_track, score in conn.execute(query):
        if score < 0.3:
            continue
        # Join the group the lower track already belongs to, if any;
        # otherwise the lower track itself names the group.
        group_of[high_track] = group_of.get(low_track, low_track)

    # Collect the members of every group, always including the group id.
    members = {}
    for track, group in group_of.items():
        members.setdefault(group, set([group])).add(track)
    return [members[group] for group in set(group_of.values())]
def can_add_fp_to_track(conn, track_id, fingerprint):
    """Check whether ``fingerprint`` may be attached to track ``track_id``.

    The database computes the minimum ``acoustid_compare2`` score between
    ``fingerprint`` and every fingerprint row whose ``track_id`` matches.

    Returns ``True`` when that minimum score is at least 0.3 and ``False``
    otherwise.  ``False`` is also returned when the query yields no score
    (``min()`` over zero rows is NULL, i.e. the track has no fingerprints).
    """
    cond = schema.fingerprint.c.track_id == track_id
    query = sql.select([
        sql.func.min(sql.func.acoustid_compare2(schema.fingerprint.c.fingerprint, fingerprint)),
    ], cond, from_obj=schema.fingerprint)
    score = conn.execute(query).scalar()
    # A NULL aggregate comes back as None.  Handle it explicitly instead of
    # relying on Python 2's "None sorts below numbers" ordering, which raises
    # TypeError on Python 3; the outcome (False) is the same as before.
    if score is None or score < 0.3:
        return False
    return True
......@@ -21,6 +21,8 @@ from acoustid.data.track import (
merge_missing_mbids, insert_track, merge_tracks,
merge_mbids,
calculate_fingerprint_similarity_matrix,
can_merge_tracks,
can_add_fp_to_track,
)
......@@ -132,3 +134,28 @@ INSERT INTO fingerprint (fingerprint, length, track_id)
assert_almost_equal(0.94152, matrix[3][1])
assert_almost_equal(0.938414, matrix[3][2])
@with_database
def test_can_merge_tracks(conn):
    # Three tracks with one fingerprint each: tracks 1 and 2 use the
    # TEST_1A / TEST_1B fingerprints, track 3 uses TEST_2.
    fixture_sql = """
INSERT INTO fingerprint (fingerprint, length, track_id)
VALUES (%(fp1)s, %(len1)s, 1), (%(fp2)s, %(len2)s, 2),
(%(fp3)s, %(len3)s, 3);
"""
    fixture_params = {
        'fp1': TEST_1A_FP_RAW, 'len1': TEST_1A_LENGTH,
        'fp2': TEST_1B_FP_RAW, 'len2': TEST_1B_LENGTH,
        'fp3': TEST_2_FP_RAW, 'len3': TEST_2_LENGTH,
    }
    prepare_database(conn, fixture_sql, fixture_params)

    # Only tracks 1 and 2 are expected to end up in a merge group.
    expected_groups = [set([1, 2])]
    assert_equal(expected_groups, can_merge_tracks(conn, [1, 2, 3]))
@with_database
def test_can_add_fp_to_track(conn):
    # Track 1 holds a single fingerprint (TEST_1A).
    fixture_sql = """
INSERT INTO fingerprint (fingerprint, length, track_id)
VALUES (%(fp1)s, %(len1)s, 1);
"""
    fixture_params = {'fp1': TEST_1A_FP_RAW, 'len1': TEST_1A_LENGTH}
    prepare_database(conn, fixture_sql, fixture_params)

    # TEST_2 is expected to be rejected for track 1 ...
    rejected = can_add_fp_to_track(conn, 1, TEST_2_FP_RAW)
    assert_equal(False, rejected)

    # ... while TEST_1B is expected to be accepted.
    accepted = can_add_fp_to_track(conn, 1, TEST_1B_FP_RAW)
    assert_equal(True, accepted)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment