Commit 75b4a67c authored by Lukáš Lalinský

Multi-DB setup

parent bb009801
[database]
two_phase_commit=yes
[database:default]
host=127.0.0.1
port=5432
port=15432
user=acoustid
name=acoustid_test
password=acoustid
[database:slow]
host=127.0.0.1
port=15432
user=acoustid
name=acoustid_slow_test
password=acoustid
[logging]
level=WARNING
level.sqlalchemy=WARNING
syslog=yes
syslog_facility=local1
[index]
host=127.0.0.1
port=16080
[redis]
host=127.0.0.1
port=6379
port=16379
[website]
base_url=http://acoustid.org/
......
[database]
name=acoustid
user=acoustid
password=XXX
superuser=postgres
host=localhost
port=5432
[database_slow]
name=acoustid
user=acoustid
password=XXX
host=localhost
port=5432
......
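Note: the two excerpts above show the new multi-database layout — a shared [database] section for cross-cutting options such as two_phase_commit, plus one sub-section per engine. A minimal sketch of reading that layout with the same RawConfigParser used by Config.read() further down; the file path is illustrative and only the option names shown above are assumed:

import ConfigParser  # Python 2 stdlib module, as used in acoustid.config

parser = ConfigParser.RawConfigParser()
parser.read('acoustid-test.conf')  # hypothetical path

use_two_phase_commit = parser.getboolean('database', 'two_phase_commit')
databases = {}
for name in ('default', 'slow'):
    section = 'database:{0}'.format(name)
    databases[name] = {
        'host': parser.get(section, 'host'),
        'port': parser.getint(section, 'port'),
        'user': parser.get(section, 'user'),
        'name': parser.get(section, 'name'),
        'password': parser.get(section, 'password'),
    }
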
......@@ -8,18 +8,15 @@ import time
import operator
from typing import Type
from acoustid import const
from acoustid.db import DatabaseContext
from acoustid.const import MAX_REQUESTS_PER_SECOND
from acoustid.handler import Handler
from acoustid.models import Application, Account, Submission, SubmissionResult, Track
from acoustid.data.track import lookup_mbids, resolve_track_gid, lookup_meta_ids
from acoustid.data.musicbrainz import lookup_metadata
from acoustid.data.submission import insert_submission, lookup_submission_status
from acoustid.data.fingerprint import decode_fingerprint, FingerprintSearcher
from acoustid.data.format import find_or_insert_format
from acoustid.data.application import lookup_application_id_by_apikey
from acoustid.data.account import lookup_account_id_by_apikey
from acoustid.data.source import find_or_insert_source
from acoustid.data.meta import insert_meta, lookup_meta
from acoustid.data.foreignid import find_or_insert_foreignid
from acoustid.data.meta import lookup_meta
from acoustid.data.stats import update_lookup_counter, update_user_agent_counter, update_lookup_avg_time
from acoustid.ratelimiter import RateLimiter
from werkzeug.utils import cached_property
......@@ -103,6 +100,7 @@ class APIHandler(Handler):
else:
connect = server.engine.connect
handler = cls(connect=connect)
handler.server = server
handler.index = server.index
handler.redis = server.redis
handler.config = server.config
......@@ -564,6 +562,7 @@ class LookupHandler(APIHandler):
matches = [(0, track_id, p['track_gid'], 1.0)]
else:
matches = searcher.search(p['fingerprint'], p['duration'])
print(repr(matches))
all_matches.append(matches)
response = {}
if params.batch:
......@@ -589,37 +588,121 @@ class LookupHandler(APIHandler):
return response
class APIHandlerWithORM(APIHandler):
params_class = None # type: Type[APIHandlerParams]
def __init__(self, server):
self.server = server
@property
def index(self):
return self.server.index
@property
def redis(self):
return self.server.redis
@property
def config(self):
return self.server.config
@classmethod
def create_from_server(cls, server):
return cls(server=server)
def _rate_limit(self, user_ip, application_id):
ip_rate_limit = self.config.rate_limiter.ips.get(user_ip, MAX_REQUESTS_PER_SECOND)
if self.rate_limiter.limit('ip', user_ip, ip_rate_limit):
if application_id == DEMO_APPLICATION_ID:
raise errors.TooManyRequests(ip_rate_limit)
if application_id is not None:
application_rate_limit = self.config.rate_limiter.applications.get(application_id)
if application_rate_limit is not None:
if self.rate_limiter.limit('app', application_id, application_rate_limit):
if application_id == DEMO_APPLICATION_ID:
raise errors.TooManyRequests(application_rate_limit)
def handle(self, req):
params = self.params_class(self.config)
if req.access_route:
self.user_ip = req.access_route[0]
else:
self.user_ip = req.remote_addr
self.is_secure = req.is_secure
self.user_agent = req.user_agent
self.rate_limiter = RateLimiter(self.redis, 'rl')
try:
with DatabaseContext(self.server) as db:
self.db = db
try:
try:
params.parse(req.values, db)
self._rate_limit(self.user_ip, getattr(params, 'application_id', None))
return self._ok(self._handle_internal(params), params.format)
except errors.WebServiceError:
raise
except Exception:
logger.exception('Error while handling API request')
raise errors.InternalError()
finally:
self.db = None
except errors.WebServiceError as e:
logger.warning("WS error: %s", e.message)
return self._error(e.code, e.message, params.format, status=e.status)
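Note: for orientation, a hedged sketch of a subclass of the new APIHandlerWithORM; the PingHandler/PingHandlerParams names are hypothetical and the block assumes the same module scope as the handlers above. Only params_class and _handle_internal need to be supplied — handle() takes care of opening the DatabaseContext, parsing parameters, rate limiting and error translation:

class PingHandlerParams(APIHandlerParams):

    def parse(self, values, db):
        super(PingHandlerParams, self).parse(values, db)
        self._parse_client(values, db.session.connection(mapper=Application))


class PingHandler(APIHandlerWithORM):

    params_class = PingHandlerParams

    def _handle_internal(self, params):
        # self.db is the DatabaseContext bound by handle() for this request
        return {'status': 'ok'}
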
def get_submission_status(db, submission_ids):
submissions = (
db.session.query(SubmissionResult.submission_id, SubmissionResult.track_id)
.filter(SubmissionResult.submission_id.in_(submission_ids))
)
track_ids = {submission_id: track_id for (submission_id, track_id) in submissions}
tracks = (
db.session.query(Track.id, Track.gid)
.filter(Track.id.in_(track_ids.values()))
)
track_gids = {track_id: track_gid for (track_id, track_gid) in tracks}
response = {'submissions': []}
for submission_id in submission_ids:
submission_response = {'id': submission_id, 'status': 'pending'}
track_id = track_ids.get(submission_id)
if track_id is not None:
track_gid = track_gids.get(track_id)
if track_gid is not None:
submission_response.update({
'status': 'imported',
'response': {'id': track_gid},
})
response['submissions'].append(submission_response)
return response
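Note: get_submission_status() reports a submission as imported only once a SubmissionResult row exists and its track_id resolves to a Track gid; otherwise it stays pending. A hedged example of the returned structure, with made-up ids and UUID:

{
    'submissions': [
        {'id': 101, 'status': 'imported', 'response': {'id': '9e5fab42-7e31-4a3b-8a44-4a6f0a3c5f10'}},
        {'id': 102, 'status': 'pending'},
    ]
}
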
class SubmissionStatusHandlerParams(APIHandlerParams):
def parse(self, values, conn):
super(SubmissionStatusHandlerParams, self).parse(values, conn)
self._parse_client(values, conn)
def parse(self, values, db):
super(SubmissionStatusHandlerParams, self).parse(values, db)
self._parse_client(values, db.session.connection(mapper=Application))
self.ids = values.getlist('id', type=int)
class SubmissionStatusHandler(APIHandler):
class SubmissionStatusHandler(APIHandlerWithORM):
params_class = SubmissionStatusHandlerParams
def _handle_internal(self, params):
response = {'submissions': [{'id': id, 'status': 'pending'} for id in params.ids]}
tracks = lookup_submission_status(self.conn, params.ids)
for submission in response['submissions']:
id = submission['id']
track_gid = tracks.get(id)
if track_gid is not None:
submission['status'] = 'imported'
submission['result'] = {'id': track_gid}
return response
return get_submission_status(self.db, params.ids)
class SubmitHandlerParams(APIHandlerParams):
def _parse_user(self, values, conn):
def _parse_user(self, values, db):
account_apikey = values.get('user')
if not account_apikey:
raise errors.MissingParameterError('user')
self.account_id = lookup_account_id_by_apikey(conn, account_apikey)
self.account_id = db.session.query(Account.id).filter(Account.apikey == account_apikey).scalar()
if not self.account_id:
raise errors.InvalidUserAPIKeyError()
......@@ -640,8 +723,8 @@ class SubmitHandlerParams(APIHandlerParams):
p['foreignid'] = values.get('foreignid' + suffix)
if p['foreignid'] and not is_foreignid(p['foreignid']):
raise errors.InvalidForeignIDError('foreignid' + suffix)
p['mbids'] = values.getlist('mbid' + suffix)
if p['mbids'] and not all(map(is_uuid, p['mbids'])):
p['mbid'] = values.get('mbid' + suffix)
if p['mbid'] and not is_uuid(p['mbid']):
raise errors.InvalidUUIDError('mbid' + suffix)
self._parse_duration_and_format(p, values, suffix)
fingerprint_string = values.get('fingerprint' + suffix)
......@@ -662,10 +745,10 @@ class SubmitHandlerParams(APIHandlerParams):
p['year'] = values.get('year' + suffix, type=int)
self.submissions.append(p)
def parse(self, values, conn):
super(SubmitHandlerParams, self).parse(values, conn)
self._parse_client(values, conn)
self._parse_user(values, conn)
def parse(self, values, db):
super(SubmitHandlerParams, self).parse(values, db)
self._parse_client(values, db.session.connection(mapper=Application))
self._parse_user(values, db)
self.wait = values.get('wait', type=int, default=0)
self.submissions = []
suffixes = list(iter_args_suffixes(values, 'fingerprint'))
......@@ -679,73 +762,45 @@ class SubmitHandlerParams(APIHandlerParams):
raise
class SubmitHandler(APIHandler):
class SubmitHandler(APIHandlerWithORM):
params_class = SubmitHandlerParams
meta_fields = ('track', 'artist', 'album', 'album_artist', 'track_no',
'disc_no', 'year')
def _handle_internal(self, params):
response = {'submissions': []}
ids = set()
with self.conn.begin():
source_id = find_or_insert_source(self.conn, params.application_id, params.account_id, params.application_version)
format_ids = {}
for p in params.submissions:
if p['format']:
if p['format'] not in format_ids:
format_ids[p['format']] = find_or_insert_format(self.conn, p['format'])
p['format_id'] = format_ids[p['format']]
for p in params.submissions:
mbids = p['mbids'] or [None]
for mbid in mbids:
values = {
'mbid': mbid or None,
'puid': p['puid'] or None,
'bitrate': p['bitrate'] or None,
'fingerprint': p['fingerprint'],
'length': p['duration'],
'format_id': p.get('format_id'),
'source_id': source_id
}
meta_values = dict((n, p[n] or None) for n in self.meta_fields)
if any(meta_values.itervalues()):
values['meta_id'] = insert_meta(self.conn, meta_values)
if p['foreignid']:
values['foreignid_id'] = find_or_insert_foreignid(self.conn, p['foreignid'])
id = insert_submission(self.conn, values)
ids.add(id)
submission = {'id': id, 'status': 'pending'}
if p['index']:
submission['index'] = p['index']
response['submissions'].append(submission)
if self.redis is not None:
self.redis.publish('channel.submissions', json.dumps(list(ids)))
clients_waiting_key = 'submission.waiting'
clients_waiting = self.redis.incr(clients_waiting_key) - 1
try:
max_wait = 10
self.redis.expire(clients_waiting_key, max_wait)
tracks = {}
remaining = min(max(0, max_wait - 2 ** clients_waiting), params.wait)
logger.debug('starting to wait at %f %d', remaining, clients_waiting)
while remaining > 0 and ids:
logger.debug('waiting %f seconds', remaining)
time.sleep(0.5) # XXX replace with LISTEN/NOTIFY
remaining -= 0.5
tracks = lookup_submission_status(self.conn, ids)
if not tracks:
continue
for submission in response['submissions']:
id = submission['id']
track_gid = tracks.get(id)
if track_gid is not None:
submission['status'] = 'imported'
submission['result'] = {'id': track_gid}
ids.remove(id)
finally:
self.redis.decr(clients_waiting_key)
ids = {}
for p in params.submissions:
submission = Submission()
submission.account_id = params.account_id
submission.application_id = params.application_id
submission.application_version = params.application_version
submission.fingerprint = p['fingerprint']
submission.duration = p['duration']
submission.mbid = p['mbid'] or None
submission.puid = p['puid'] or None
submission.foreignid = p['foreignid'] or None
submission.bitrate = p['bitrate'] or None
submission.format = p['format'] or None
submission.track = p['track'] or None
submission.artist = p['artist'] or None
submission.album = p['album'] or None
submission.album_artist = p['album_artist'] or None
submission.track_no = p['track_no'] or None
submission.disc_no = p['disc_no'] or None
submission.year = p['year'] or None
self.db.session.add(submission)
self.db.session.flush()
ids[submission.id] = p['index']
self.db.session.commit()
self.redis.publish('channel.submissions', json.dumps(list(ids.keys())))
response = get_submission_status(self.db, list(ids.keys()))
for submission_response in response['submissions']:
submission_id = submission_response['id']
index = ids[submission_id]
if index:
submission_response['index'] = index
return response
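Note: the submit path now goes through the ORM — each Submission is added and flushed so its id is assigned from the new sequence, the id-to-index map is published to Redis, and get_submission_status() builds the response. A hedged, stand-alone sketch of that insert path; the script variable and field values are illustrative:

from acoustid.db import DatabaseContext
from acoustid.models import Submission

# `script` is assumed to be an acoustid.script.Script instance
with DatabaseContext(script) as db:
    submission = Submission()
    submission.account_id = 1
    submission.application_id = 1
    submission.fingerprint = [1, 2, 3]
    submission.duration = 215
    db.session.add(submission)
    db.session.flush()    # assigns submission.id from submission_id_seq
    db.session.commit()
    print(submission.id)
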
......@@ -18,8 +18,10 @@ def read_env_item(obj, key, name, convert=None):
value = None
if name in os.environ:
value = os.environ[name]
if name + '_FILE' in os.environ:
logger.info('Reading config value from environment variable %s', name)
elif name + '_FILE' in os.environ:
value = open(os.environ[name + '_FILE']).read().strip()
logger.info('Reading config value from environment variable %s', name + '_FILE')
if value is not None:
if convert is not None:
value = convert(value)
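Note: read_env_item now also accepts a *_FILE companion variable, so secrets can be mounted as files (e.g. Docker/Kubernetes secrets) instead of being passed inline. A hedged illustration, assuming the ACOUSTID_ prefix and per-database variable names introduced later in this commit; the path is made up:

import os

# Either form works; the direct variable wins if both are set.
os.environ['ACOUSTID_DEFAULT_PASSWORD'] = 'acoustid'                          # read directly
os.environ['ACOUSTID_SLOW_PASSWORD_FILE'] = '/run/secrets/slow_db_password'   # read from the file
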
......@@ -39,6 +41,35 @@ class BaseConfig(object):
pass
class DatabasesConfig(BaseConfig):
def __init__(self):
self.databases = {
'default': DatabaseConfig(),
'slow': DatabaseConfig(),
}
self.use_two_phase_commit = False
def create_engines(self, **kwargs):
engines = {}
for name, db_config in self.databases.items():
engines[name] = db_config.create_engine(**kwargs)
return engines
def read_section(self, parser, section):
if parser.has_option(section, 'two_phase_commit'):
self.use_two_phase_commit = parser.getboolean(section, 'two_phase_commit')
for name, sub_config in self.databases.items():
sub_section = '{}:{}'.format(section, name)
sub_config.read_section(parser, sub_section)
def read_env(self, prefix):
read_env_item(self, 'use_two_phase_commit', prefix + 'TWO_PHASE_COMMIT', convert=str_to_bool)
for name, sub_config in self.databases.items():
sub_prefix = prefix + name.upper() + '_'
sub_config.read_env(sub_prefix)
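Note: because read_env fans out per database name, each engine gets its own variable namespace. A hedged sketch, assuming the top-level prefix stays ACOUSTID_ (see Config.read_env further down) and that DatabaseConfig.create_engine exists as implied by create_engines above:

from acoustid.config import DatabasesConfig

# ACOUSTID_TWO_PHASE_COMMIT                      -> use_two_phase_commit
# ACOUSTID_DEFAULT_HOST/PORT/NAME/USER/PASSWORD  -> databases['default']
# ACOUSTID_SLOW_HOST/PORT/NAME/USER/PASSWORD     -> databases['slow']
config = DatabasesConfig()
config.read_env('ACOUSTID_')
engines = config.create_engines()   # {'default': Engine, 'slow': Engine}
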
class DatabaseConfig(BaseConfig):
def __init__(self):
......@@ -103,21 +134,21 @@ class DatabaseConfig(BaseConfig):
if parser.has_option(section, 'password'):
self.password = parser.get(section, 'password')
if parser.has_option(section, 'pool_size'):
self.password = parser.getint(section, 'pool_size')
self.pool_size = parser.getint(section, 'pool_size')
if parser.has_option(section, 'pool_recycle'):
self.password = parser.getint(section, 'pool_recycle')
self.pool_recycle = parser.getint(section, 'pool_recycle')
if parser.has_option(section, 'pool_pre_ping'):
self.password = parser.getboolean(section, 'pool_pre_ping')
self.pool_pre_ping = parser.getboolean(section, 'pool_pre_ping')
def read_env(self, prefix):
read_env_item(self, 'name', prefix + 'POSTGRES_DB')
read_env_item(self, 'host', prefix + 'POSTGRES_HOST')
read_env_item(self, 'port', prefix + 'POSTGRES_PORT', convert=int)
read_env_item(self, 'user', prefix + 'POSTGRES_USER')
read_env_item(self, 'password', prefix + 'POSTGRES_PASSWORD')
read_env_item(self, 'pool_size', prefix + 'POSTGRES_POOL_SIZE', convert=int)
read_env_item(self, 'pool_recycle', prefix + 'POSTGRES_POOL_RECYCLE', convert=int)
read_env_item(self, 'pool_pre_ping', prefix + 'POSTGRES_POOL_PRE_PING', convert=str_to_bool)
read_env_item(self, 'name', prefix + 'NAME')
read_env_item(self, 'host', prefix + 'HOST')
read_env_item(self, 'port', prefix + 'PORT', convert=int)
read_env_item(self, 'user', prefix + 'USER')
read_env_item(self, 'password', prefix + 'PASSWORD')
read_env_item(self, 'pool_size', prefix + 'POOL_SIZE', convert=int)
read_env_item(self, 'pool_recycle', prefix + 'POOL_RECYCLE', convert=int)
read_env_item(self, 'pool_pre_ping', prefix + 'POOL_PRE_PING', convert=str_to_bool)
class IndexConfig(BaseConfig):
......@@ -342,7 +373,7 @@ class RateLimiterConfig(BaseConfig):
class Config(object):
def __init__(self):
self.database = DatabaseConfig()
self.databases = DatabasesConfig()
self.logging = LoggingConfig()
self.website = WebSiteConfig()
self.index = IndexConfig()
......@@ -357,7 +388,7 @@ class Config(object):
logger.info("Loading configuration file %s", path)
parser = ConfigParser.RawConfigParser()
parser.read(path)
self.database.read(parser, 'database')
self.databases.read(parser, 'database')
self.logging.read(parser, 'logging')
self.website.read(parser, 'website')
self.index.read(parser, 'index')
......@@ -373,7 +404,7 @@ class Config(object):
prefix = 'ACOUSTID_TEST_'
else:
prefix = 'ACOUSTID_'
self.database.read_env(prefix)
self.databases.read_env(prefix)
self.logging.read_env(prefix)
self.website.read_env(prefix)
self.index.read_env(prefix)
......
......@@ -21,3 +21,6 @@ FINGERPRINT_MAX_LENGTH_DIFF = 7
FINGERPRINT_MAX_ALLOWED_LENGTH_DIFF = 30
MAX_REQUESTS_PER_SECOND = 3
MAX_FOREIGNID_NAMESPACE_LENGTH = 10
MAX_FOREIGNID_VALUE_LENGTH = 64
......@@ -28,6 +28,7 @@ def insert_submission(conn, data):
'source_id': data.get('source_id'),
'format_id': data.get('format_id'),
'meta_id': data.get('meta_id'),
'foreignid': data.get('foreignid'),
'foreignid_id': data.get('foreignid_id'),
})
id = conn.execute(insert_stmt).inserted_primary_key[0]
......
from sqlalchemy.orm import sessionmaker
from acoustid.tables import metadata
Session = sessionmaker()
def get_bind_args(engines):
binds = {}
for table in metadata.sorted_tables:
bind_key = table.info.get('bind_key', 'default')
if bind_key != 'default':
binds[table] = engines[bind_key]
return {'bind': engines['default'], 'binds': binds}
def get_session_args(script):
kwargs = {'twophase': script.config.databases.use_two_phase_commit}
kwargs.update(get_bind_args(script.db_engines))
return kwargs
class DatabaseContext(object):
def __init__(self, bind):
self.session = Session(bind=bind)
def __init__(self, script):
self.session = Session(**get_session_args(script))
def __enter__(self):
return self
......
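Note: routing between the two engines is driven by a bind_key entry in each table's info dict — anything marked 'slow' is bound to the slow engine, everything else to the default one, and the session can span both with two-phase commit. A hedged sketch of a table opting in; the table name is hypothetical:

from sqlalchemy import Column, Integer, MetaData, Table

metadata = MetaData()

example_stats = Table(
    'example_stats', metadata,
    Column('id', Integer, primary_key=True),
    info={'bind_key': 'slow'},   # picked up by get_bind_args() and bound to engines['slow']
)
# Resulting session arguments, roughly:
# Session(bind=engines['default'], binds={example_stats: engines['slow']}, twophase=use_two_phase_commit)
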
......@@ -79,3 +79,11 @@ class StatsLookups(Base):
__table__ = tables.stats_lookups
application = relationship('Application')
class Submission(Base):
__table__ = tables.submission
class SubmissionResult(Base):
__table__ = tables.submission_result
......@@ -11,11 +11,26 @@ from optparse import OptionParser
from acoustid.config import Config
from acoustid.indexclient import IndexClientPool
from acoustid.utils import LocalSysLogHandler
from acoustid.db import DatabaseContext
from acoustid._release import GIT_RELEASE
logger = logging.getLogger(__name__)
class ScriptContext(object):
def __init__(self, db, redis, index):
self.db = db
self.redis = redis
self.index = index
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.db.close()
class Script(object):
def __init__(self, config_path, tests=False):
......@@ -23,25 +38,30 @@ class Script(object):
if config_path:
self.config.read(config_path)
self.config.read_env(tests=tests)
if tests:
self.engine = sqlalchemy.create_engine(self.config.database.create_url(),
poolclass=sqlalchemy.pool.AssertionPool)
else:
self.engine = sqlalchemy.create_engine(self.config.database.create_url())
create_engine_kwargs = {'poolclass': sqlalchemy.pool.AssertionPool} if tests else {}
self.db_engines = self.config.databases.create_engines(**create_engine_kwargs)
if not self.config.index.host:
self.index = None
else:
self.index = IndexClientPool(host=self.config.index.host,
port=self.config.index.port,
recycle=60)
if not self.config.redis.host:
self.redis = None
else:
self.redis = Redis(host=self.config.redis.host,
port=self.config.redis.port)
self._console_logging_configured = False
self.setup_logging()
@property
def engine(self):
return self.db_engines['default']
def setup_logging(self):
for logger_name, level in sorted(self.config.logging.levels.items()):
logging.getLogger(logger_name).setLevel(level)
......@@ -66,6 +86,10 @@ class Script(object):
def setup_sentry(self):
sentry_sdk.init(self.config.sentry.script_dsn, release=GIT_RELEASE)
def context(self):
db = DatabaseContext(self).session
return ScriptContext(db=db, redis=self.redis, index=self.index)
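Note: scripts now obtain a session, Redis and the index pool from one place. A hedged usage sketch; the config path and query are illustrative:

from acoustid.script import Script
from acoustid.models import Submission

script = Script('/etc/acoustid/acoustid.conf')  # hypothetical path
with script.context() as ctx:
    pending = ctx.db.query(Submission).filter_by(handled=False).count()
    print(pending)
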
def run_script(func, option_cb=None, master_only=False):
parser = OptionParser()
......
......@@ -9,7 +9,7 @@ from cStringIO import StringIO
from werkzeug.exceptions import HTTPException
from werkzeug.routing import Map, Rule, Submount
from werkzeug.wrappers import Request
from werkzeug.contrib.fixers import ProxyFix
from werkzeug.middleware.proxy_fix import ProxyFix
from acoustid.script import Script
from acoustid._release import GIT_RELEASE
import acoustid.api.v1
......@@ -53,12 +53,15 @@ admin_url_rules = [
class Server(Script):
def __init__(self, config_path):
super(Server, self).__init__(config_path)
def __init__(self, config_path, **kwargs):
super(Server, self).__init__(config_path, **kwargs)
url_rules = api_url_rules + admin_url_rules
self.url_map = Map(url_rules, strict_slashes=False)
def __call__(self, environ, start_response):
return self.wsgi_app(environ, start_response)
def wsgi_app(self, environ, start_response):
urls = self.url_map.bind_to_environ(environ)
handler = None
try:
......@@ -112,16 +115,16 @@ def add_cors_headers(app):
return wrapped_app
def make_application(config_path):
def make_application(config_path, **kwargs):
"""Construct a WSGI application for the AcoustID server
:param config_path: path to the server configuration file
"""
server = Server(config_path)
server = Server(config_path, **kwargs)
server.setup_sentry()
app = GzipRequestMiddleware(server)
app = ProxyFix(app)
app = SentryWsgiMiddleware(app)
app = replace_double_slashes(app)
app = add_cors_headers(app)
return server, app
server.wsgi_app = GzipRequestMiddleware(server.wsgi_app)
server.wsgi_app = ProxyFix(server.wsgi_app)
server.wsgi_app = SentryWsgiMiddleware(server.wsgi_app)
server.wsgi_app = replace_double_slashes(server.wsgi_app)
server.wsgi_app = add_cors_headers(server.wsgi_app)
return server
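Note: make_application no longer returns a (server, app) pair — the middleware chain is folded into server.wsgi_app and the Server itself is the WSGI callable. A hedged sketch of a WSGI entry point, assuming the module path acoustid.server and an illustrative config path:

from acoustid.server import make_application

application = make_application('/etc/acoustid/acoustid.conf')
# `application` is the Server instance; __call__ delegates to the wrapped
# wsgi_app, so it can be handed straight to gunicorn/uwsgi.
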
import sqlalchemy
import sqlalchemy.event
from sqlalchemy import (
MetaData, Table, Column, Index,
MetaData, Table, Column, Index, Sequence,
ForeignKey, CheckConstraint,
Integer, String, DateTime, Boolean, Date, Text, SmallInteger, BigInteger, CHAR,
DDL, sql,
......@@ -103,7 +104,7 @@ source = Table('source', metadata,
Index('source_idx_uniq', 'application_id', 'account_id', 'version', unique=True),
)
submission = Table('submission', metadata,
submission_old = Table('submission_old', metadata,
Column('id', Integer, primary_key=True),
Column('fingerprint', ARRAY(Integer), nullable=False),
Column('length', SmallInteger, CheckConstraint('length>0'), nullable=False),
......@@ -118,7 +119,55 @@ submission = Table('submission', metadata,
Column('foreignid_id', Integer, ForeignKey('foreignid.id')),
)
Index('submission_idx_handled', submission.c.id, postgresql_where=submission.c.handled == False) # noqa: E712
Index('submission_idx_handled', submission_old.c.id, postgresql_where=submission_old.c.handled == False) # noqa: E712
submission_id_seq = Sequence('submission_id_seq', metadata=metadata)
submission = Table('submission', metadata,
Column('id', Integer, submission_id_seq, server_default=submission_id_seq.next_value(), primary_key=True),
Column('created', DateTime(timezone=True), server_default=sql.func.current_timestamp(), nullable=False),
Column('handled', Boolean, default=False, server_default=sql.false()),
Column('account_id', Integer, nullable=False), # ForeignKey('account.id')
Column('application_id', Integer, nullable=False), # ForeignKey('application.id')
Column('application_version', String),
Column('fingerprint', ARRAY(Integer), nullable=False),