299 lines
10 KiB
Python
299 lines
10 KiB
Python
|
#!/usr/bin/env python -O
|
||
|
""" Script to test database capabilities and the DB-API interface
|
||
|
for functionality and memory leaks.
|
||
|
|
||
|
Adapted from a script by M-A Lemburg.
|
||
|
|
||
|
"""
|
||
|
import sys
|
||
|
from time import time
|
||
|
try:
|
||
|
import unittest2 as unittest
|
||
|
except ImportError:
|
||
|
import unittest
|
||
|
|
||
|
# True exactly when the running interpreter is a Python 2.x release.
PY2 = (2,) <= sys.version_info < (3,)
|
||
|
|
||
|
class DatabaseTest(unittest.TestCase):
    """Reusable DB-API capability tests.

    Concrete subclasses must set ``db_module`` to the DB-API module under
    test and may override the connection arguments and
    ``create_table_extra``.  Parameterized statements are built with ``%s``
    placeholders, so the driver must accept that paramstyle.
    """

    db_module = None  # DB-API module under test; subclasses must set this
    connect_args = ()  # positional arguments for db_module.connect()
    connect_kwargs = dict(use_unicode=True, charset="utf8")  # keyword arguments for db_module.connect()
    create_table_extra = "ENGINE=INNODB CHARACTER SET UTF8"  # appended to every CREATE TABLE
    rows = 10  # number of rows the data tests insert
    debug = False  # print intermediate data; check_data_integrity also keeps its table
    leak_test = True  # run the gc-based orphan check in tearDown

    def setUp(self):
        """Open a connection and cursor and build the shared BLOB fixtures."""
        db = self.db_module.connect(*self.connect_args, **self.connect_kwargs)
        self.connection = db
        self.cursor = db.cursor()
        # Every code point 0..255, repeated 100 times.
        self.BLOBText = ''.join([chr(i) for i in range(256)] * 100)
        # NOTE(review): 16834 looks like a typo for 16384 (2**14); kept as-is
        # to avoid changing the test data.
        if PY2:
            self.BLOBUText = unicode().join(unichr(i) for i in range(16834))
        else:
            self.BLOBUText = "".join(chr(i) for i in range(16834))
        # Every byte value 0..255, repeated 16 times (4096 bytes).
        data = bytearray(range(256)) * 16
        self.BLOBBinary = self.db_module.Binary(data)

    def tearDown(self):
        """Optionally check that releasing the cursor/connection leaves no garbage.

        With ``leak_test`` enabled, the cursor and then the connection are
        deleted, and ``gc.collect()`` must find no unreachable objects after
        either step.  With it disabled, nothing is released here.
        """
        if not self.leak_test:
            return
        import gc
        del self.cursor
        orphans = gc.collect()
        self.assertFalse(orphans, "%d orphaned objects found after deleting cursor" % orphans)

        del self.connection
        orphans = gc.collect()
        self.assertFalse(orphans, "%d orphaned objects found after deleting connection" % orphans)

    def table_exists(self, name):
        """Return True if selecting from *name* succeeds, False if it raises."""
        try:
            self.cursor.execute('select * from %s where 1=0' % name)
        except Exception:
            return False
        else:
            return True

    def quote_identifier(self, ident):
        """Wrap *ident* in double quotes; override for other quoting styles."""
        return '"%s"' % ident

    def new_table_name(self):
        """Return a quoted ``tb<hex>`` table name that does not exist yet.

        The hex suffix starts at ``id(self.cursor)`` and is incremented until
        ``table_exists`` reports the name as free.
        """
        i = id(self.cursor)
        while True:
            name = self.quote_identifier('tb%08x' % i)
            if not self.table_exists(name):
                return name
            i = i + 1

    def create_table(self, columndefs):
        """Create a fresh table from the column definitions in *columndefs*.

        The new (quoted) table name is stored in ``self.table``; the
        definitions are joined with commas and ``create_table_extra`` is
        appended to the statement.
        """
        self.table = self.new_table_name()
        self.cursor.execute('CREATE TABLE %s (%s) %s' %
                            (self.table,
                             ',\n'.join(columndefs),
                             self.create_table_extra))

    def _insert_statement(self, ncols):
        """Return an INSERT into ``self.table`` with *ncols* ``%s`` placeholders."""
        return 'INSERT INTO %s VALUES (%s)' % (self.table, ','.join(['%s'] * ncols))

    def _make_rows(self, generator, ncols):
        """Return ``self.rows`` rows, each a list of ``generator(row, col)`` values."""
        return [[generator(i, j) for j in range(ncols)] for i in range(self.rows)]

    def _assert_rows_match(self, fetched, generator, ncols):
        """Assert *fetched* has ``self.rows`` rows whose cells equal ``generator(row, col)``."""
        self.assertEqual(len(fetched), self.rows)
        for i in range(self.rows):
            for j in range(ncols):
                self.assertEqual(fetched[i][j], generator(i, j))

    def _assert_truncation_reported(self, insert, message):
        """Call *insert*; fail with *message* unless it raises Warning or DataError.

        The transaction is rolled back afterwards so the next attempt starts
        from an empty table.
        """
        try:
            insert()
        except Warning:
            if self.debug:
                print(self.cursor.messages)
        except self.connection.DataError:
            pass
        else:
            self.fail(message)
        self.connection.rollback()

    def check_data_integrity(self, columndefs, generator):
        """Round-trip generated values through a fresh table and compare them.

        generator must be a function taking arguments (row_number,
        col_number) and returning a suitable data object for insertion into
        the table; every value read back must compare equal to it.  The
        table is dropped afterwards unless ``debug`` is set.
        """
        ncols = len(columndefs)
        self.create_table(columndefs)
        # Guard everything after CREATE TABLE so a failing insert or
        # assertion still drops the table.
        try:
            # insert
            data = self._make_rows(generator, ncols)
            if self.debug:
                print(data)
            self.cursor.executemany(self._insert_statement(ncols), data)
            self.connection.commit()
            # verify
            self.cursor.execute('select * from %s' % self.table)
            fetched = self.cursor.fetchall()
            if self.debug:
                print(fetched)
            self._assert_rows_match(fetched, generator, ncols)
        finally:
            if not self.debug:
                self.cursor.execute('drop table %s' % (self.table))

    def test_transactions(self):
        """An uncommitted DELETE is visible, then undone by rollback()."""
        columndefs = ('col1 INT', 'col2 VARCHAR(255)')
        ncols = len(columndefs)

        def generator(row, col):
            if col == 0:
                return row
            return ('%i' % (row % 10)) * 255

        self.create_table(columndefs)
        try:
            self.cursor.executemany(self._insert_statement(ncols),
                                    self._make_rows(generator, ncols))
            # verify
            self.connection.commit()
            self.cursor.execute('select * from %s' % self.table)
            self._assert_rows_match(self.cursor.fetchall(), generator, ncols)

            delete_statement = 'delete from %s where col1=%%s' % self.table
            self.cursor.execute(delete_statement, (0,))
            self.cursor.execute('select col1 from %s where col1=%s' % (self.table, 0))
            self.assertFalse(self.cursor.fetchall(), "DELETE didn't work")

            self.connection.rollback()
            self.cursor.execute('select col1 from %s where col1=%s' % (self.table, 0))
            self.assertEqual(len(self.cursor.fetchall()), 1, "ROLLBACK didn't work")
        finally:
            self.cursor.execute('drop table %s' % (self.table))

    def test_truncation(self):
        """Over-long VARCHAR values must raise a Warning or DataError."""
        columndefs = ('col1 INT', 'col2 VARCHAR(255)')
        ncols = len(columndefs)

        def generator(row, col):
            if col == 0:
                return row
            # String lengths are centred around the 255-character column limit.
            return ('%i' % (row % 10)) * ((255 - self.rows // 2) + row)

        self.create_table(columndefs)
        insert_statement = self._insert_statement(ncols)
        try:
            self._assert_truncation_reported(
                lambda: self.cursor.execute(insert_statement, (0, '0' * 256)),
                "Over-long column did not generate warnings/exception with single insert")

            def insert_one_by_one():
                for i in range(self.rows):
                    row = tuple(generator(i, j) for j in range(ncols))
                    self.cursor.execute(insert_statement, row)

            self._assert_truncation_reported(
                insert_one_by_one,
                "Over-long columns did not generate warnings/exception with execute()")

            self._assert_truncation_reported(
                lambda: self.cursor.executemany(
                    insert_statement, self._make_rows(generator, ncols)),
                "Over-long columns did not generate warnings/exception with executemany()")
        finally:
            self.cursor.execute('drop table %s' % (self.table))

    def test_CHAR(self):
        # Character data
        def generator(row, col):
            return ('%i' % ((row + col) % 10)) * 255
        self.check_data_integrity(
            ('col1 char(255)', 'col2 char(255)'),
            generator)

    def test_INT(self):
        # Number data
        def generator(row, col):
            return row * row
        self.check_data_integrity(
            ('col1 INT',),
            generator)

    def test_DECIMAL(self):
        # DECIMAL
        from decimal import Decimal

        def generator(row, col):
            return Decimal("%d.%02d" % (row, col))
        self.check_data_integrity(
            ('col1 DECIMAL(5,2)',),
            generator)

    def test_DATE(self):
        ticks = time()

        def generator(row, col):
            return self.db_module.DateFromTicks(ticks + row * 86400 - col * 1313)
        self.check_data_integrity(
            ('col1 DATE',),
            generator)

    def test_TIME(self):
        ticks = time()

        def generator(row, col):
            return self.db_module.TimeFromTicks(ticks + row * 86400 - col * 1313)
        self.check_data_integrity(
            ('col1 TIME',),
            generator)

    def test_DATETIME(self):
        ticks = time()

        def generator(row, col):
            return self.db_module.TimestampFromTicks(ticks + row * 86400 - col * 1313)
        self.check_data_integrity(
            ('col1 DATETIME',),
            generator)

    def test_TIMESTAMP(self):
        ticks = time()

        def generator(row, col):
            return self.db_module.TimestampFromTicks(ticks + row * 86400 - col * 1313)
        self.check_data_integrity(
            ('col1 TIMESTAMP',),
            generator)

    def test_fractional_TIMESTAMP(self):
        ticks = time()

        def generator(row, col):
            return self.db_module.TimestampFromTicks(ticks + row * 86400 - col * 1313 + row * 0.7 * col / 3.0)
        self.check_data_integrity(
            ('col1 TIMESTAMP',),
            generator)

    def test_LONG(self):
        def generator(row, col):
            if col == 0:
                return row
            # The full unicode fixture built in setUp.
            return self.BLOBUText
        self.check_data_integrity(
            ('col1 INT', 'col2 LONG'),
            generator)

    def test_TEXT(self):
        def generator(row, col):
            if col == 0:
                return row
            # First 5192 characters of the unicode fixture built in setUp.
            return self.BLOBUText[:5192]
        self.check_data_integrity(
            ('col1 INT', 'col2 TEXT'),
            generator)

    def test_LONG_BYTE(self):
        def generator(row, col):
            if col == 0:
                return row
            # 4096-byte binary fixture (bytes 0..255 x 16) built in setUp.
            return self.BLOBBinary
        self.check_data_integrity(
            ('col1 INT', 'col2 LONG BYTE'),
            generator)

    def test_BLOB(self):
        def generator(row, col):
            if col == 0:
                return row
            # 4096-byte binary fixture (bytes 0..255 x 16) built in setUp.
            return self.BLOBBinary
        self.check_data_integrity(
            ('col1 INT', 'col2 BLOB'),
            generator)
|