Commit e3a8e9e5 authored by Lukáš Lalinský's avatar Lukáš Lalinský

Add decimal support

parent 624ce428
......@@ -21,6 +21,7 @@ import pprint
import json
import logging
import urlparse
from decimal import Decimal
from HTMLParser import HTMLParser
from phoenixdb import errors
......@@ -138,12 +139,24 @@ class AvaticaClient(object):
def _apply(self, request_data, expected_response_type=None):
logger.debug("Sending request\n%s", pprint.pformat(request_data))
class FakeFloat(float):
# XXX there has to be a better way to do this
def __init__(self, value):
self.value = value
def __repr__(self):
return str(self.value)
def default(obj):
if isinstance(obj, Decimal):
return FakeFloat(obj)
raise TypeError
if self.version >= AVATICA_1_4_0_INCUBATING:
body = json.dumps(request_data)
body = json.dumps(request_data, default=default)
headers = {'content-type': 'application/json'}
else:
body = None
headers = {'request': json.dumps(request_data)}
headers = {'request': json.dumps(request_data, default=default)}
try:
self.connection.request('POST', self.url.path, body=body, headers=headers)
......@@ -159,8 +172,9 @@ class AvaticaClient(object):
parse_error_page(response_body)
raise errors.InterfaceError('RPC request returned invalid status code', response.status)
noop = lambda x: x
try:
response_data = json.loads(response_body)
response_data = json.loads(response_body, parse_float=noop)
except ValueError as e:
logger.debug("Received response\n%s", response_body)
raise errors.InterfaceError('valid JSON document', cause=e)
......
......@@ -14,6 +14,7 @@
import logging
import collections
from decimal import Decimal
from phoenixdb.errors import OperationalError, NotSupportedError, ProgrammingError
__all__ = ['Cursor', 'ColumnDescription']
......@@ -49,6 +50,7 @@ class Cursor(object):
self._connection = connection
self._id = id
self._signature = None
self._data_types = []
self._frame = None
self._pos = None
self._closed = False
......@@ -89,6 +91,7 @@ class Cursor(object):
self._connection._client.closeStatement(self._connection._id, self._id)
self._id = None
self._signature = None
self._data_types = []
self._frame = None
self._pos = None
self._closed = True
......@@ -120,6 +123,18 @@ class Cursor(object):
self._connection._client.closeStatement(self._connection._id, self._id)
self._id = id
def _set_signature(self, signature):
self._signature = signature
self._data_types = []
if signature is None:
return
identity = lambda value: value
for i, column in enumerate(signature['columns']):
if column['columnClassName'] == 'java.math.BigDecimal':
self._data_types.append((i, Decimal))
elif column['columnClassName'] == 'java.lang.Float' or column['columnClassName'] == 'java.lang.Double':
self._data_types.append((i, float))
def _set_frame(self, frame):
self._frame = frame
self._pos = None
......@@ -146,14 +161,14 @@ class Cursor(object):
result = results[0]
if result['ownStatement']:
self._set_id(result['statementId'])
self._set_signature(result['signature'])
self._set_frame(result['firstFrame'])
self._signature = result['signature']
self._updatecount = result['updateCount']
else:
statement = self._connection._client.prepare(self._connection._id, self._id,
operation, maxRowCount=self.itersize)
self._set_id(statement['id'])
self._signature = statement['signature']
self._set_signature(statement['signature'])
frame = self._connection._client.fetch(self._connection._id, self._id,
parameters, fetchMaxRowCount=self.itersize)
self._set_frame(frame)
......@@ -183,6 +198,10 @@ class Cursor(object):
self._pos = None
if not self._frame['done']:
self._fetch_next_frame()
for i, data_type in self._data_types:
value = row[i]
if value is not None:
row[i] = data_type(value)
return row
def fetchmany(self, size=None):
......
import unittest
import phoenixdb
from decimal import Decimal
from phoenixdb.tests import DatabaseTestCase
......@@ -77,9 +78,25 @@ class TypesTest(DatabaseTestCase):
def test_unsigned_double(self):
self.checkFloatType("unsigned_double", 0, 1.7976931348623158E+308)
@unittest.skip("not implemented")
def test_decimal(self):
assert False
self.createTable("phoenixdb_test_tbl1", "id integer primary key, val decimal(8,3)")
with self.conn.cursor() as cursor:
cursor.execute("UPSERT INTO phoenixdb_test_tbl1 VALUES (1, 33333.333)")
cursor.execute("UPSERT INTO phoenixdb_test_tbl1 VALUES (2, NULL)")
cursor.execute("UPSERT INTO phoenixdb_test_tbl1 VALUES (3, ?)", [33333.333])
cursor.execute("UPSERT INTO phoenixdb_test_tbl1 VALUES (4, ?)", [Decimal('33333.333')])
cursor.execute("UPSERT INTO phoenixdb_test_tbl1 VALUES (5, ?)", [None])
cursor.execute("SELECT id, val FROM phoenixdb_test_tbl1 ORDER BY id")
self.assertEqual(cursor.description[1].type_code, phoenixdb.NUMBER)
rows = cursor.fetchall()
self.assertEqual([r[0] for r in rows], [1, 2, 3, 4, 5])
self.assertEqual(rows[0][1], Decimal('33333.333'))
self.assertEqual(rows[1][1], None)
self.assertEqual(rows[2][1], Decimal('33333.333'))
self.assertEqual(rows[3][1], Decimal('33333.333'))
self.assertEqual(rows[4][1], None)
self.assertRaises(self.conn.DatabaseError, cursor.execute, "UPSERT INTO phoenixdb_test_tbl1 VALUES (100, ?)", [Decimal('1234567890')])
self.assertRaises(self.conn.DatabaseError, cursor.execute, "UPSERT INTO phoenixdb_test_tbl1 VALUES (101, ?)", [Decimal('123456.789')])
def test_boolean(self):
self.createTable("phoenixdb_test_tbl1", "id integer primary key, val boolean")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment