577 lines
24 KiB
Python
577 lines
24 KiB
Python
|
import datetime
|
||
|
import sys
|
||
|
import time
|
||
|
import unittest2
|
||
|
import pymysql
|
||
|
from pymysql.tests import base
|
||
|
from pymysql._compat import text_type
|
||
|
|
||
|
|
||
|
class TempUser(object):
    """Context manager that creates a temporary MySQL account for a test.

    On construction it issues ``CREATE USER`` (optionally with a password or an
    authentication plugin) and ``GRANT SELECT ON <db>.*``.  On exit it revokes
    the grant and drops the user, but only for the steps that actually
    succeeded at construction time.

    All values are interpolated into the SQL text without escaping; this is
    meant for trusted test fixtures only.

    :param c: cursor used to run the statements.
    :param user: account name as written in SQL, e.g. ``'name@localhost'``.
    :param db: database the SELECT privilege is granted on.
    :param auth: plugin name for ``IDENTIFIED WITH``; ignored if ``password``
        is given.
    :param authdata: optional ``AS '...'`` data, only used together with ``auth``.
    :param password: password for ``IDENTIFIED BY``.
    """

    def __init__(self, c, user, db, auth=None, authdata=None, password=None):
        self._c = c
        self._user = user
        self._db = db
        create = "CREATE USER " + user
        if password is not None:
            create += " IDENTIFIED BY '%s'" % password
        elif auth is not None:
            create += " IDENTIFIED WITH %s" % auth
            if authdata is not None:
                create += " AS '%s'" % authdata
        try:
            c.execute(create)
            self._created = True
        except pymysql.err.InternalError:
            # already exists - TODO need to check the same plugin applies
            self._created = False
        try:
            c.execute("GRANT SELECT ON %s.* TO %s" % (db, user))
            self._grant = True
        except pymysql.err.InternalError:
            self._grant = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Drop the user even if REVOKE fails, so a failing cleanup step does
        # not leave the temporary account behind for later tests.
        try:
            if self._grant:
                self._c.execute("REVOKE SELECT ON %s.* FROM %s" % (self._db, self._user))
        finally:
            if self._created:
                self._c.execute("DROP USER %s" % self._user)
|
||
|
|
||
|
|
||
|
class TestAuthentication(base.PyMySQLTestCase):
    """Tests for the authentication plugins supported by the client.

    The class body runs once at import time.  It lists the server's active
    authentication plugins with ``SHOW PLUGINS`` and records which ones exist;
    the ``*_found`` flags then drive the skip decorators below.
    """

    socket_auth = False
    socket_found = False
    two_questions_found = False
    three_attempts_found = False
    pam_found = False
    mysql_old_password_found = False
    sha256_password_found = False

    import os
    osuser = os.environ.get('USER')

    # socket auth requires the current user and for the connection to be a socket
    # rest do grants @localhost due to incomplete logic - TODO change to @% then
    db = base.PyMySQLTestCase.databases[0].copy()

    socket_auth = (db.get('unix_socket') is not None
                   and db.get('host') in ('localhost', '127.0.0.1'))

    # Probe connection used only for the plugin scan; closed once it is done.
    _probe_conn = pymysql.connect(**db)
    cur = _probe_conn.cursor()
    # Tests connect with their own ``user=...`` on top of ``self.db``.
    del db['user']
    cur.execute("SHOW PLUGINS")
    for r in cur:
        if (r[1], r[2]) != (u'ACTIVE', u'AUTHENTICATION'):
            continue
        if r[3] == u'auth_socket.so':
            socket_plugin_name = r[0]
            socket_found = True
        elif r[3] == u'dialog_examples.so':
            if r[0] == 'two_questions':
                two_questions_found = True
            elif r[0] == 'three_attempts':
                three_attempts_found = True
        elif r[0] == u'pam':
            pam_found = True
            pam_plugin_name = r[3].split('.')[0]
            if pam_plugin_name == 'auth_pam':
                pam_plugin_name = 'pam'
            # MySQL: authentication_pam
            # https://dev.mysql.com/doc/refman/5.5/en/pam-authentication-plugin.html

            # MariaDB: pam
            # https://mariadb.com/kb/en/mariadb/pam-authentication-plugin/

            # Names differ but functionality is close
        elif r[0] == u'mysql_old_password':
            mysql_old_password_found = True
        elif r[0] == u'sha256_password':
            sha256_password_found = True
        #else:
        #    print("plugin: %r" % r[0])
    _probe_conn.close()

    def test_plugin(self):
        # Bit of an assumption that the current user is a native password
        self.assertEqual('mysql_native_password', self.connections[0]._auth_plugin_name)

    @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
    @unittest2.skipIf(socket_found, "socket plugin already installed")
    def testSocketAuthInstallPlugin(self):
        # needs plugin. lets install it.
        cur = self.connections[0].cursor()
        try:
            cur.execute("install plugin auth_socket soname 'auth_socket.so'")
            self.socket_plugin_name = 'auth_socket'
        except pymysql.err.InternalError:
            try:
                # fallback install syntax; the plugin it provides is named unix_socket
                cur.execute("install soname 'auth_socket'")
                self.socket_plugin_name = 'unix_socket'
            except pymysql.err.InternalError:
                TestAuthentication.socket_found = False
                raise unittest2.SkipTest('we couldn\'t install the socket plugin')
        TestAuthentication.socket_found = True
        # Only installation problems become skips: failures of the real test
        # propagate, and the plugin is always uninstalled afterwards.
        try:
            self.realtestSocketAuth()
        finally:
            cur.execute("uninstall plugin %s" % self.socket_plugin_name)

    @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
    @unittest2.skipUnless(socket_found, "no socket plugin")
    def testSocketAuth(self):
        self.realtestSocketAuth()

    def realtestSocketAuth(self):
        with TempUser(self.connections[0].cursor(), TestAuthentication.osuser + '@localhost',
                      self.databases[0]['db'], self.socket_plugin_name):
            c = pymysql.connect(user=TestAuthentication.osuser, **self.db)
            c.close()

    class Dialog(object):
        """Dialog-plugin handler answering prompts from the class-level map ``m``.

        When the class attribute ``fail`` is set, each new instance answers its
        first prompt with a wrong password.
        """
        fail = False

        def __init__(self, con):
            self.fail = TestAuthentication.Dialog.fail

        def prompt(self, echo, prompt):
            if self.fail:
                self.fail = False
                return b'bad guess at a password'
            return self.m.get(prompt)

    class DialogHandler(object):
        """Low-level dialog handler that reads prompt packets and writes answers."""

        def __init__(self, con):
            self.con = con

        def authenticate(self, pkt):
            while True:
                flag = pkt.read_uint8()
                echo = (flag & 0x06) == 0x02
                last = (flag & 0x01) == 0x01
                prompt = pkt.read_all()

                if prompt == b'Password, please:':
                    self.con.write_packet(b'stillnotverysecret\0')
                else:
                    self.con.write_packet(b'no idea what to do with this prompt\0')
                pkt = self.con._read_packet()
                pkt.check_error()
                if pkt.is_ok_packet() or last:
                    break
            return pkt

    class DefectiveHandler(object):
        """Handler that deliberately lacks any authentication method."""

        def __init__(self, con):
            self.con = con

    @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
    @unittest2.skipIf(two_questions_found, "two_questions plugin already installed")
    def testDialogAuthTwoQuestionsInstallPlugin(self):
        # needs plugin. lets install it.
        cur = self.connections[0].cursor()
        try:
            cur.execute("install plugin two_questions soname 'dialog_examples.so'")
        except pymysql.err.InternalError:
            raise unittest2.SkipTest('we couldn\'t install the two_questions plugin')
        TestAuthentication.two_questions_found = True
        try:
            self.realTestDialogAuthTwoQuestions()
        finally:
            cur.execute("uninstall plugin two_questions")

    @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
    @unittest2.skipUnless(two_questions_found, "no two questions auth plugin")
    def testDialogAuthTwoQuestions(self):
        self.realTestDialogAuthTwoQuestions()

    def realTestDialogAuthTwoQuestions(self):
        TestAuthentication.Dialog.fail = False
        TestAuthentication.Dialog.m = {b'Password, please:': b'notverysecret',
                                       b'Are you sure ?': b'yes, of course'}
        with TempUser(self.connections[0].cursor(), 'pymysql_2q@localhost',
                      self.databases[0]['db'], 'two_questions', 'notverysecret'):
            with self.assertRaises(pymysql.err.OperationalError):
                pymysql.connect(user='pymysql_2q', **self.db)
            conn = pymysql.connect(user='pymysql_2q',
                                   auth_plugin_map={b'dialog': TestAuthentication.Dialog},
                                   **self.db)
            conn.close()

    @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
    @unittest2.skipIf(three_attempts_found, "three_attempts plugin already installed")
    def testDialogAuthThreeAttemptsQuestionsInstallPlugin(self):
        # needs plugin. lets install it.
        cur = self.connections[0].cursor()
        try:
            cur.execute("install plugin three_attempts soname 'dialog_examples.so'")
        except pymysql.err.InternalError:
            raise unittest2.SkipTest('we couldn\'t install the three_attempts plugin')
        TestAuthentication.three_attempts_found = True
        try:
            self.realTestDialogAuthThreeAttempts()
        finally:
            cur.execute("uninstall plugin three_attempts")

    @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
    @unittest2.skipUnless(three_attempts_found, "no three attempts plugin")
    def testDialogAuthThreeAttempts(self):
        self.realTestDialogAuthThreeAttempts()

    def realTestDialogAuthThreeAttempts(self):
        TestAuthentication.Dialog.m = {b'Password, please:': b'stillnotverysecret'}
        TestAuthentication.Dialog.fail = True  # fail just once. We've got three attempts after all
        with TempUser(self.connections[0].cursor(), 'pymysql_3a@localhost',
                      self.databases[0]['db'], 'three_attempts', 'stillnotverysecret'):
            pymysql.connect(user='pymysql_3a',
                            auth_plugin_map={b'dialog': TestAuthentication.Dialog},
                            **self.db).close()
            pymysql.connect(user='pymysql_3a',
                            auth_plugin_map={b'dialog': TestAuthentication.DialogHandler},
                            **self.db).close()
            with self.assertRaises(pymysql.err.OperationalError):
                pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': object}, **self.db)

            with self.assertRaises(pymysql.err.OperationalError):
                pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.DefectiveHandler}, **self.db)
            with self.assertRaises(pymysql.err.OperationalError):
                pymysql.connect(user='pymysql_3a', auth_plugin_map={b'notdialogplugin': TestAuthentication.Dialog}, **self.db)
            TestAuthentication.Dialog.m = {b'Password, please:': b'I do not know'}
            with self.assertRaises(pymysql.err.OperationalError):
                pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
            TestAuthentication.Dialog.m = {b'Password, please:': None}
            with self.assertRaises(pymysql.err.OperationalError):
                pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)

    @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
    @unittest2.skipIf(pam_found, "pam plugin already installed")
    @unittest2.skipIf(os.environ.get('PASSWORD') is None, "PASSWORD env var required")
    @unittest2.skipIf(os.environ.get('PAMSERVICE') is None, "PAMSERVICE env var required")
    def testPamAuthInstallPlugin(self):
        # needs plugin. lets install it.
        cur = self.connections[0].cursor()
        try:
            cur.execute("install plugin pam soname 'auth_pam.so'")
        except pymysql.err.InternalError:
            raise unittest2.SkipTest('we couldn\'t install the auth_pam plugin')
        TestAuthentication.pam_found = True
        try:
            self.realTestPamAuth()
        finally:
            cur.execute("uninstall plugin pam")

    @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
    @unittest2.skipUnless(pam_found, "no pam plugin")
    @unittest2.skipIf(os.environ.get('PASSWORD') is None, "PASSWORD env var required")
    @unittest2.skipIf(os.environ.get('PAMSERVICE') is None, "PAMSERVICE env var required")
    def testPamAuth(self):
        self.realTestPamAuth()

    def realTestPamAuth(self):
        """Log in as the OS user through PAM, temporarily replacing that account.

        Any existing ``<osuser>@localhost`` account is dropped first.  Its first
        ``SHOW GRANTS`` row is replayed afterwards, even if the test fails.
        """
        db = self.db.copy()
        import os
        db['password'] = os.environ.get('PASSWORD')
        cur = self.connections[0].cursor()
        try:
            cur.execute('show grants for ' + TestAuthentication.osuser + '@localhost')
            grants = cur.fetchone()[0]
            cur.execute('drop user ' + TestAuthentication.osuser + '@localhost')
        except pymysql.OperationalError as e:
            # assuming the user doesn't exist which is ok too
            self.assertEqual(1045, e.args[0])
            grants = None
        try:
            with TempUser(cur, TestAuthentication.osuser + '@localhost',
                          self.databases[0]['db'], 'pam', os.environ.get('PAMSERVICE')):
                try:
                    c = pymysql.connect(user=TestAuthentication.osuser, **db)
                    c.close()
                    db['password'] = 'very bad guess at password'
                    with self.assertRaises(pymysql.err.OperationalError):
                        pymysql.connect(user=TestAuthentication.osuser,
                                        auth_plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler},
                                        **db)
                except pymysql.OperationalError as e:
                    self.assertEqual(1045, e.args[0])
                    # we had 'bad guess at password' work with pam. Well at least we get a permission denied here
                    with self.assertRaises(pymysql.err.OperationalError):
                        pymysql.connect(user=TestAuthentication.osuser,
                                        auth_plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler},
                                        **self.db)
        finally:
            if grants:
                # recreate the user
                cur.execute(grants)

    # select old_password("crummy p\tassword");
    #| old_password("crummy p\tassword") |
    #| 2a01785203b08770                  |
    @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
    @unittest2.skipUnless(mysql_old_password_found, "no mysql_old_password plugin")
    def testMySQLOldPasswordAuth(self):
        if self.mysql_server_is(self.connections[0], (5, 7, 0)):
            raise unittest2.SkipTest('Old passwords aren\'t supported in 5.7')
        # pymysql.err.OperationalError: (1045, "Access denied for user 'old_pass_user'@'localhost' (using password: YES)")
        # from login in MySQL-5.6
        if self.mysql_server_is(self.connections[0], (5, 6, 0)):
            raise unittest2.SkipTest('Old passwords don\'t authenticate in 5.6')
        db = self.db.copy()
        db['password'] = "crummy p\tassword"
        with self.connections[0] as c:
            # deprecated in 5.6
            if sys.version_info[0:2] >= (3, 2) and self.mysql_server_is(self.connections[0], (5, 6, 0)):
                with self.assertWarns(pymysql.err.Warning):
                    c.execute("SELECT OLD_PASSWORD('%s')" % db['password'])
            else:
                c.execute("SELECT OLD_PASSWORD('%s')" % db['password'])
            v = c.fetchone()[0]
            self.assertEqual(v, '2a01785203b08770')
            # only works in MariaDB and MySQL-5.6 - can't separate out by version
            #if self.mysql_server_is(self.connections[0], (5, 5, 0)):
            #    with TempUser(c, 'old_pass_user@localhost',
            #                  self.databases[0]['db'], 'mysql_old_password', '2a01785203b08770') as u:
            #        cur = pymysql.connect(user='old_pass_user', **db).cursor()
            #        cur.execute("SELECT VERSION()")
            c.execute("SELECT @@secure_auth")
            secure_auth_setting = c.fetchone()[0]
            c.execute('set old_passwords=1')
            # secure_auth is a GLOBAL setting, so it must be restored even if
            # anything below fails.
            try:
                # pymysql.err.Warning: 'pre-4.1 password hash' is deprecated and will be removed in a future release. Please use post-4.1 password hash instead
                if sys.version_info[0:2] >= (3, 2) and self.mysql_server_is(self.connections[0], (5, 6, 0)):
                    with self.assertWarns(pymysql.err.Warning):
                        c.execute('set global secure_auth=0')
                else:
                    c.execute('set global secure_auth=0')
                with TempUser(c, 'old_pass_user@localhost',
                              self.databases[0]['db'], password=db['password']):
                    old_conn = pymysql.connect(user='old_pass_user', **db)
                    try:
                        old_conn.cursor().execute("SELECT VERSION()")
                    finally:
                        old_conn.close()
            finally:
                c.execute('set global secure_auth=%r' % secure_auth_setting)

    @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
    @unittest2.skipUnless(sha256_password_found, "no sha256 password authentication plugin found")
    def testAuthSHA256(self):
        c = self.connections[0].cursor()
        with TempUser(c, 'pymysql_sha256@localhost',
                      self.databases[0]['db'], 'sha256_password'):
            if self.mysql_server_is(self.connections[0], (5, 7, 0)):
                c.execute("SET PASSWORD FOR 'pymysql_sha256'@'localhost' ='Sh@256Pa33'")
            else:
                c.execute('SET old_passwords = 2')
                c.execute("SET PASSWORD FOR 'pymysql_sha256'@'localhost' = PASSWORD('Sh@256Pa33')")
            db = self.db.copy()
            db['password'] = "Sh@256Pa33"
            # not implemented yet so throws error.  Connect as the account created
            # above (previously 'pymysql_256', a non-existent user, which failed
            # for the wrong reason).
            with self.assertRaises(pymysql.err.OperationalError):
                pymysql.connect(user='pymysql_sha256', **db)
|
||
|
|
||
|
class TestConnection(base.PyMySQLTestCase):
    """Connection-level tests: charsets, large packets, autocommit, db switching,
    timeouts, init commands, option files, context use and deferred connect."""

    def test_utf8mb4(self):
        """This test requires MySQL >= 5.5"""
        arg = self.databases[0].copy()
        arg['charset'] = 'utf8mb4'
        conn = pymysql.connect(**arg)
        conn.close()

    def test_largedata(self):
        """Large query and response (>=16MB)"""
        cur = self.connections[0].cursor()
        cur.execute("SELECT @@max_allowed_packet")
        if cur.fetchone()[0] < 16*1024*1024 + 10:
            # Report a skip instead of silently passing without testing anything.
            raise unittest2.SkipTest("Set max_allowed_packet to bigger than 17MB")
        t = 'a' * (16*1024*1024)
        cur.execute("SELECT '" + t + "'")
        # assertTrue rather than assertEqual: a diff of two 16MB strings on
        # failure would be enormous.
        self.assertTrue(cur.fetchone()[0] == t, "large result did not round-trip")

    def test_autocommit(self):
        con = self.connections[0]
        self.assertFalse(con.get_autocommit())

        cur = con.cursor()
        cur.execute("SET AUTOCOMMIT=1")
        self.assertTrue(con.get_autocommit())

        con.autocommit(False)
        self.assertFalse(con.get_autocommit())
        cur.execute("SELECT @@AUTOCOMMIT")
        self.assertEqual(cur.fetchone()[0], 0)

    def test_select_db(self):
        con = self.connections[0]
        current_db = self.databases[0]['db']
        other_db = self.databases[1]['db']

        cur = con.cursor()
        cur.execute('SELECT database()')
        self.assertEqual(cur.fetchone()[0], current_db)

        con.select_db(other_db)
        cur.execute('SELECT database()')
        self.assertEqual(cur.fetchone()[0], other_db)

    def test_connection_gone_away(self):
        """
        http://dev.mysql.com/doc/refman/5.0/en/gone-away.html
        http://dev.mysql.com/doc/refman/5.0/en/error-messages-client.html#error_cr_server_gone_error
        """
        con = self.connections[0]
        cur = con.cursor()
        cur.execute("SET wait_timeout=1")
        time.sleep(2)
        with self.assertRaises(pymysql.OperationalError) as cm:
            cur.execute("SELECT 1+1")
        # error occures while reading, not writing because of socket buffer.
        #self.assertEqual(cm.exception.args[0], 2006)
        self.assertIn(cm.exception.args[0], (2006, 2013))

    def test_init_command(self):
        conn = pymysql.connect(
            init_command='SELECT "bar"; SELECT "baz"',
            **self.databases[0]
        )
        c = conn.cursor()
        c.execute('select "foobar";')
        self.assertEqual(('foobar',), c.fetchone())
        conn.close()
        with self.assertRaises(pymysql.err.Error):
            conn.ping(reconnect=False)

    def test_read_default_group(self):
        conn = pymysql.connect(
            read_default_group='client',
            **self.databases[0]
        )
        self.assertTrue(conn.open)
        conn.close()

    def test_context(self):
        with self.assertRaises(ValueError):
            c = pymysql.connect(**self.databases[0])
            with c as cur:
                cur.execute('create table test ( a int )')
                c.begin()
                cur.execute('insert into test values ((1))')
                raise ValueError('pseudo abort')
                c.commit()
        c = pymysql.connect(**self.databases[0])
        with c as cur:
            cur.execute('select count(*) from test')
            self.assertEqual(0, cur.fetchone()[0])
            cur.execute('insert into test values ((1))')
        with c as cur:
            cur.execute('select count(*) from test')
            self.assertEqual(1, cur.fetchone()[0])
            cur.execute('drop table test')

    def test_set_charset(self):
        c = pymysql.connect(**self.databases[0])
        c.set_charset('utf8')
        # TODO validate setting here
        c.close()

    def test_defer_connect(self):
        import socket
        for db in self.databases:
            d = db.copy()
            # Decide the transport before creating a socket, so no unix socket
            # is created (and leaked) when the config has no 'unix_socket'.
            unix_socket = d.get('unix_socket')
            if unix_socket is not None:
                sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
                sock.connect(unix_socket)
            else:
                sock = socket.create_connection(
                    (d.get('host', 'localhost'), d.get('port', 3306)))
            for k in ['unix_socket', 'host', 'port']:
                d.pop(k, None)

            try:
                c = pymysql.connect(defer_connect=True, **d)
                self.assertFalse(c.open)
                c.connect(sock)
                c.close()
            finally:
                sock.close()

    @unittest2.skipUnless(sys.version_info[0:2] >= (3, 2), "required py-3.2")
    def test_no_delay_warning(self):
        current_db = self.databases[0].copy()
        current_db['no_delay'] = True
        with self.assertWarns(DeprecationWarning):
            conn = pymysql.connect(**current_db)
        conn.close()
|
||
|
|
||
|
|
||
|
# A custom type and function to escape it
|
||
|
class Foo(object):
    """Minimal custom type used to exercise user-supplied encoders."""

    # Text produced for every instance when encoded by the custom encoder.
    value = "bar"
|
||
|
|
||
|
|
||
|
def escape_foo(x, d):
    """Return ``x.value`` verbatim as the escaped form; ``d`` is not used."""
    escaped = getattr(x, 'value')
    return escaped
|
||
|
|
||
|
|
||
|
class TestEscape(base.PyMySQLTestCase):
    """Tests for ``Connection.escape`` with default and custom encoder mappings,
    plus cursor behaviour around multi-statement results."""

    def test_escape_string(self):
        con = self.connections[0]
        cur = con.cursor()

        self.assertEqual(con.escape("foo'bar"), "'foo\\'bar'")
        # added NO_AUTO_CREATE_USER as not including it in 5.7 generates warnings
        cur.execute("SET sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'")
        self.assertEqual(con.escape("foo'bar"), "'foo''bar'")

    def test_escape_builtin_encoders(self):
        con = self.connections[0]

        val = datetime.datetime(2012, 3, 4, 5, 6)
        self.assertEqual(con.escape(val, con.encoders), "'2012-03-04 05:06:00'")

    def test_escape_custom_object(self):
        con = self.connections[0]

        mapping = {Foo: escape_foo}
        self.assertEqual(con.escape(Foo(), mapping), "bar")

    def test_escape_fallback_encoder(self):
        con = self.connections[0]

        class Custom(str):
            pass

        mapping = {text_type: pymysql.escape_string}
        self.assertEqual(con.escape(Custom('foobar'), mapping), "'foobar'")

    def test_escape_no_default(self):
        con = self.connections[0]

        self.assertRaises(TypeError, con.escape, 42, {})

    def test_escape_dict_value(self):
        con = self.connections[0]

        mapping = con.encoders.copy()
        mapping[Foo] = escape_foo
        self.assertEqual(con.escape({'foo': Foo()}, mapping), {'foo': "bar"})

    def test_escape_list_item(self):
        con = self.connections[0]

        mapping = con.encoders.copy()
        mapping[Foo] = escape_foo
        self.assertEqual(con.escape([Foo()], mapping), "(bar)")

    def test_previous_cursor_not_closed(self):
        con = self.connections[0]
        cur1 = con.cursor()
        cur1.execute("SELECT 1; SELECT 2")
        cur2 = con.cursor()
        cur2.execute("SELECT 3")
        self.assertEqual(cur2.fetchone()[0], 3)

    def test_commit_during_multi_result(self):
        con = self.connections[0]
        cur = con.cursor()
        cur.execute("SELECT 1; SELECT 2")
        con.commit()
        cur.execute("SELECT 3")
        self.assertEqual(cur.fetchone()[0], 3)