Commit 6fd27d10 authored by Lukáš Lalinský's avatar Lukáš Lalinský

Merged in elserj/python-phoenixdb/user_and_password (pull request #2)

Extract certain kwargs from Connection into OpenConnectionRequest.info
parents 36041d50 0bdcc50d
...@@ -88,6 +88,12 @@ SQLSTATE_ERROR_CLASSES = [ ...@@ -88,6 +88,12 @@ SQLSTATE_ERROR_CLASSES = [
('INT', errors.InternalError), # Phoenix internal error ('INT', errors.InternalError), # Phoenix internal error
] ]
# Relevant properties as defined by https://calcite.apache.org/avatica/docs/client_reference.html
OPEN_CONNECTION_PROPERTIES = (
'user', # User for the database connection
'password', # Password for the user
)
def raise_sql_error(code, sqlstate, message): def raise_sql_error(code, sqlstate, message):
for prefix, error_class in SQLSTATE_ERROR_CLASSES: for prefix, error_class in SQLSTATE_ERROR_CLASSES:
...@@ -316,7 +322,9 @@ class AvaticaClient(object): ...@@ -316,7 +322,9 @@ class AvaticaClient(object):
request = requests_pb2.OpenConnectionRequest() request = requests_pb2.OpenConnectionRequest()
request.connection_id = connectionId request.connection_id = connectionId
if info is not None: if info is not None:
request.info = info # Info is a list of repeated pairs, setting a dict directly fails
for k, v in info.items():
request.info[k] = v
response_data = self._apply(request) response_data = self._apply(request)
response = responses_pb2.OpenConnectionResponse() response = responses_pb2.OpenConnectionResponse()
......
...@@ -16,6 +16,7 @@ import logging ...@@ -16,6 +16,7 @@ import logging
import uuid import uuid
import weakref import weakref
from phoenixdb import errors from phoenixdb import errors
from phoenixdb.avatica import OPEN_CONNECTION_PROPERTIES
from phoenixdb.cursor import Cursor from phoenixdb.cursor import Cursor
from phoenixdb.errors import OperationalError, NotSupportedError, ProgrammingError from phoenixdb.errors import OperationalError, NotSupportedError, ProgrammingError
...@@ -34,8 +35,17 @@ class Connection(object): ...@@ -34,8 +35,17 @@ class Connection(object):
self._client = client self._client = client
self._closed = False self._closed = False
self._cursors = [] self._cursors = []
# Extract properties to pass to OpenConnectionRequest
self._connection_args = {}
# The rest of the kwargs
self._filtered_args = {}
for k in kwargs:
if k in OPEN_CONNECTION_PROPERTIES:
self._connection_args[k] = kwargs[k]
else:
self._filtered_args[k] = kwargs[k]
self.open() self.open()
self.set_session(**kwargs) self.set_session(**self._filtered_args)
def __del__(self): def __del__(self):
if not self._closed: if not self._closed:
...@@ -51,7 +61,7 @@ class Connection(object): ...@@ -51,7 +61,7 @@ class Connection(object):
def open(self): def open(self):
"""Opens the connection.""" """Opens the connection."""
self._id = str(uuid.uuid4()) self._id = str(uuid.uuid4())
self._client.openConnection(self._id) self._client.openConnection(self._id, info=self._connection_args)
def close(self): def close(self):
"""Closes the connection. """Closes the connection.
......
import unittest
import phoenixdb
from . import dbapi20
from phoenixdb.tests import TEST_DB_URL
@unittest.skipIf(TEST_DB_URL is None, "these tests require the PHOENIXDB_TEST_DB_URL environment variable set to a clean database")
class PhoenixConnectionTest(unittest.TestCase):
def _connect(self, connect_kw_args):
try:
r = phoenixdb.connect(
*(TEST_DB_URL, ), **connect_kw_args
)
except AttributeError:
self.fail("Failed to connect")
return r
def test_connection_credentials(self):
connect_kw_args = {'user':'SCOTT', 'password':'TIGER', 'readonly':'True'}
con = self._connect(connect_kw_args)
try:
self.assertEqual(con._connection_args, {'user':'SCOTT', 'password':'TIGER'},
'Should have extract user and password'
)
self.assertEqual(con._filtered_args, {'readonly':'True'},
'Should have not extracted foo'
)
finally:
con.close()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment