Commit 9b8e46e3 authored by Josh Elser's avatar Josh Elser

Extract certain kwargs from Connection into OpenConnectionRequest.info

Also adds a unit test to show that the extra args are extracted

Fixes #10
parent a5f24c38
......@@ -88,6 +88,12 @@ SQLSTATE_ERROR_CLASSES = [
('INT', errors.InternalError), # Phoenix internal error
]
# Relevant properties as defined by https://calcite.apache.org/avatica/docs/client_reference.html
OPEN_CONNECTION_PROPERTIES = (
'user', # User for the database connection
'password', # Password for the user
)
def raise_sql_error(code, sqlstate, message):
for prefix, error_class in SQLSTATE_ERROR_CLASSES:
......@@ -316,7 +322,9 @@ class AvaticaClient(object):
request = requests_pb2.OpenConnectionRequest()
request.connection_id = connectionId
if info is not None:
request.info = info
# Info is a list of repeated pairs, setting a dict directly fails
for k in info:
request.info[k] = info[k]
response_data = self._apply(request)
response = responses_pb2.OpenConnectionResponse()
......
......@@ -16,6 +16,7 @@ import logging
import uuid
import weakref
from phoenixdb import errors
from phoenixdb.avatica import OPEN_CONNECTION_PROPERTIES
from phoenixdb.cursor import Cursor
from phoenixdb.errors import OperationalError, NotSupportedError, ProgrammingError
......@@ -34,8 +35,17 @@ class Connection(object):
self._client = client
self._closed = False
self._cursors = []
# Extract properties to pass to OpenConnectionRequest
self._connection_args = {}
# The rest of the kwargs
self._filtered_args = {}
for k in kwargs:
if k in OPEN_CONNECTION_PROPERTIES:
self._connection_args[k] = kwargs[k]
else:
self._filtered_args[k] = kwargs[k]
self.open()
self.set_session(**kwargs)
self.set_session(**self._filtered_args)
def __del__(self):
if not self._closed:
......@@ -51,7 +61,7 @@ class Connection(object):
def open(self):
"""Opens the connection."""
self._id = str(uuid.uuid4())
self._client.openConnection(self._id)
self._client.openConnection(self._id, info=self._connection_args)
def close(self):
"""Closes the connection.
......
......@@ -8,6 +8,7 @@ from phoenixdb.tests import TEST_DB_URL
class PhoenixDatabaseAPI20Test(dbapi20.DatabaseAPI20Test):
driver = phoenixdb
connect_args = (TEST_DB_URL, )
connect_kw_args = {}
ddl1 = 'create table %sbooze (name varchar(20) primary key)' % dbapi20.DatabaseAPI20Test.table_prefix
ddl2 = 'create table %sbarflys (name varchar(20) primary key, drink varchar(30))' % dbapi20.DatabaseAPI20Test.table_prefix
......@@ -36,6 +37,7 @@ class PhoenixDatabaseAPI20Test(dbapi20.DatabaseAPI20Test):
con.close()
def test_autocommit(self):
self.connect_kw_args = {}
con = dbapi20.DatabaseAPI20Test._connect(self)
self.assertFalse(con.autocommit)
con.autocommit = True
......@@ -45,6 +47,7 @@ class PhoenixDatabaseAPI20Test(dbapi20.DatabaseAPI20Test):
con.close()
def test_readonly(self):
self.connect_kw_args = {}
con = dbapi20.DatabaseAPI20Test._connect(self)
self.assertFalse(con.readonly)
con.readonly = True
......@@ -106,3 +109,18 @@ class PhoenixDatabaseAPI20Test(dbapi20.DatabaseAPI20Test):
self.failUnless(cur.rowcount in (-1,1))
finally:
con.close()
def test_credentials(self):
self.connect_kw_args['user'] = 'SCOTT'
self.connect_kw_args['password'] = 'TIGER'
self.connect_kw_args['readonly'] = 'True'
con = self._connect()
try:
self.assertEqual(con._connection_args, {'user':'SCOTT', 'password':'TIGER'},
'Should have extract user and password'
)
self.assertEqual(con._filtered_args, {'readonly':'True'},
'Should have not extracted foo'
)
finally:
con.close()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment