SP/web2py/gluon/contrib/pymysql/tests/test_connection.py
Saturneic 064f602b1a Add.
2018-10-25 23:33:13 +08:00

577 lines
24 KiB
Python

import datetime
import sys
import time
import unittest2
import pymysql
from pymysql.tests import base
from pymysql._compat import text_type
class TempUser:
    """Context manager that provisions a throw-away MySQL account.

    On construction it issues ``CREATE USER`` followed by
    ``GRANT SELECT ON <db>.*``.  On exit it revokes the grant and drops the
    user, but only for the steps that actually succeeded during construction.

    The SQL is assembled by plain string formatting, so *user*, *db*, *auth*,
    *authdata* and *password* must be trusted, test-controlled values.

    :param c: cursor-like object exposing ``execute(sql)``.
    :param user: account name, e.g. ``'name@localhost'``.
    :param db: database on which ``SELECT`` is granted.
    :param auth: authentication plugin name (``IDENTIFIED WITH``); ignored
        when *password* is given.
    :param authdata: plugin auth string (``AS '...'``); only used together
        with *auth*.
    :param password: plain password (``IDENTIFIED BY``); takes precedence
        over *auth*.
    """

    def __init__(self, c, user, db, auth=None, authdata=None, password=None):
        self._c = c
        self._user = user
        self._db = db

        create = "CREATE USER " + user
        if password is not None:
            create += " IDENTIFIED BY '%s'" % password
        elif auth is not None:
            create += " IDENTIFIED WITH %s" % auth
            if authdata is not None:
                create += " AS '%s'" % authdata

        try:
            c.execute(create)
            self._created = True
        except pymysql.err.InternalError:
            # already exists - TODO need to check the same plugin applies
            self._created = False
        try:
            c.execute("GRANT SELECT ON %s.* TO %s" % (db, user))
            self._grant = True
        except pymysql.err.InternalError:
            self._grant = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Undo in reverse order.  The DROP lives in ``finally`` so that a
        # failing REVOKE cannot leave the temporary account behind; the
        # REVOKE error (if any) still propagates to the caller.
        try:
            if self._grant:
                self._c.execute(
                    "REVOKE SELECT ON %s.* FROM %s" % (self._db, self._user))
        finally:
            if self._created:
                self._c.execute("DROP USER %s" % self._user)
class TestAuthentication(base.PyMySQLTestCase):
    """Tests for server-side authentication plugins.

    The class body runs at import time: it opens a connection using the first
    configured database, runs ``SHOW PLUGINS`` and records which
    authentication plugins are active.  Those ``*_found`` flags feed the
    ``skipIf`` / ``skipUnless`` decorators below, so they must be computed
    before the test methods are defined.
    """

    socket_auth = False
    socket_found = False
    two_questions_found = False
    three_attempts_found = False
    pam_found = False
    mysql_old_password_found = False
    sha256_password_found = False

    # Imported in the class body so the decorators below can use ``os``;
    # methods that need it import it again locally.
    import os
    osuser = os.environ.get('USER')

    # socket auth requires the current user and for the connection to be a socket
    # rest do grants @localhost due to incomplete logic - TODO change to @% then
    db = base.PyMySQLTestCase.databases[0].copy()

    socket_auth = db.get('unix_socket') is not None \
        and db.get('host') in ('localhost', '127.0.0.1')

    cur = pymysql.connect(**db).cursor()
    # ``db`` is reused by the tests as ``self.db`` with an explicit ``user=``
    # keyword, so the configured user is removed once the probe is connected.
    del db['user']
    cur.execute("SHOW PLUGINS")
    # Row layout used below: r[0] plugin name, r[1] status, r[2] type,
    # r[3] library file.
    for r in cur:
        if (r[1], r[2]) != (u'ACTIVE', u'AUTHENTICATION'):
            continue
        if r[3] == u'auth_socket.so':
            socket_plugin_name = r[0]
            socket_found = True
        elif r[3] == u'dialog_examples.so':
            if r[0] == 'two_questions':
                two_questions_found = True
            elif r[0] == 'three_attempts':
                three_attempts_found = True
        elif r[0] == u'pam':
            pam_found = True
            pam_plugin_name = r[3].split('.')[0]
            if pam_plugin_name == 'auth_pam':
                pam_plugin_name = 'pam'
            # MySQL: authentication_pam
            # https://dev.mysql.com/doc/refman/5.5/en/pam-authentication-plugin.html
            # MariaDB: pam
            # https://mariadb.com/kb/en/mariadb/pam-authentication-plugin/
            # Names differ but functionality is close
        elif r[0] == u'mysql_old_password':
            mysql_old_password_found = True
        elif r[0] == u'sha256_password':
            sha256_password_found = True
        #else:
        #    print("plugin: %r" % r[0])

    def test_plugin(self):
        """The default test connection authenticates with the native plugin."""
        # Bit of an assumption that the current user is a native password
        self.assertEqual('mysql_native_password', self.connections[0]._auth_plugin_name)

    @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
    @unittest2.skipIf(socket_found, "socket plugin already installed")
    def testSocketAuthInstallPlugin(self):
        """Install the socket plugin, run the socket test, then uninstall it.

        Tries the ``install plugin ... soname`` form first and falls back to
        ``install soname``; skips when neither succeeds.
        """
        # needs plugin. lets install it.
        cur = self.connections[0].cursor()
        try:
            cur.execute("install plugin auth_socket soname 'auth_socket.so'")
            TestAuthentication.socket_found = True
            self.socket_plugin_name = 'auth_socket'
            self.realtestSocketAuth()
        except pymysql.err.InternalError:
            try:
                cur.execute("install soname 'auth_socket'")
                TestAuthentication.socket_found = True
                self.socket_plugin_name = 'unix_socket'
                self.realtestSocketAuth()
            except pymysql.err.InternalError:
                TestAuthentication.socket_found = False
                raise unittest2.SkipTest('we couldn\'t install the socket plugin')
        finally:
            if TestAuthentication.socket_found:
                cur.execute("uninstall plugin %s" % self.socket_plugin_name)

    @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
    @unittest2.skipUnless(socket_found, "no socket plugin")
    def testSocketAuth(self):
        """Run the socket auth test against an already installed plugin."""
        self.realtestSocketAuth()

    def realtestSocketAuth(self):
        """Create an account for the OS user and connect as that user."""
        with TempUser(self.connections[0].cursor(), TestAuthentication.osuser + '@localhost',
                      self.databases[0]['db'], self.socket_plugin_name) as u:
            c = pymysql.connect(user=TestAuthentication.osuser, **self.db)
            # Close explicitly so the extra connection does not outlive the
            # temporary account it was authenticated with.
            c.close()

    class Dialog(object):
        """Prompt-style ``dialog`` plugin handler.

        Answers are looked up in the class attribute ``m`` (prompt -> reply),
        which each test assigns before connecting.  When the class-level
        ``fail`` flag is set, the first prompt of an instance is answered with
        a wrong password.
        """
        fail=False

        def __init__(self, con):
            # Snapshot the class-level switch so each handler fails at most once.
            self.fail=TestAuthentication.Dialog.fail

        def prompt(self, echo, prompt):
            if self.fail:
                self.fail=False
                return b'bad guess at a password'
            return self.m.get(prompt)

    class DialogHandler(object):
        """``dialog`` plugin handler that drives the packet exchange itself."""

        def __init__(self, con):
            self.con=con

        def authenticate(self, pkt):
            """Answer prompts until an OK packet or the last prompt; return the final packet."""
            while True:
                flag = pkt.read_uint8()
                echo = (flag & 0x06) == 0x02
                last = (flag & 0x01) == 0x01
                prompt = pkt.read_all()

                if prompt == b'Password, please:':
                    self.con.write_packet(b'stillnotverysecret\0')
                else:
                    self.con.write_packet(b'no idea what to do with this prompt\0')
                pkt = self.con._read_packet()
                pkt.check_error()
                if pkt.is_ok_packet() or last:
                    break
            return pkt

    class DefectiveHandler(object):
        """Handler with neither ``prompt`` nor ``authenticate``; the tests expect connecting with it to fail."""

        def __init__(self, con):
            self.con=con

    @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
    @unittest2.skipIf(two_questions_found, "two_questions plugin already installed")
    def testDialogAuthTwoQuestionsInstallPlugin(self):
        """Install ``two_questions``, run its test, then uninstall it."""
        # needs plugin. lets install it.
        cur = self.connections[0].cursor()
        try:
            cur.execute("install plugin two_questions soname 'dialog_examples.so'")
            TestAuthentication.two_questions_found = True
            self.realTestDialogAuthTwoQuestions()
        except pymysql.err.InternalError:
            raise unittest2.SkipTest('we couldn\'t install the two_questions plugin')
        finally:
            if TestAuthentication.two_questions_found:
                cur.execute("uninstall plugin two_questions")

    @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
    @unittest2.skipUnless(two_questions_found, "no two questions auth plugin")
    def testDialogAuthTwoQuestions(self):
        """Run the two-questions test against an already installed plugin."""
        self.realTestDialogAuthTwoQuestions()

    def realTestDialogAuthTwoQuestions(self):
        """Connecting without a dialog handler fails; with ``Dialog`` it succeeds."""
        TestAuthentication.Dialog.fail=False
        TestAuthentication.Dialog.m = {b'Password, please:': b'notverysecret',
                                       b'Are you sure ?': b'yes, of course'}
        with TempUser(self.connections[0].cursor(), 'pymysql_2q@localhost',
                      self.databases[0]['db'], 'two_questions', 'notverysecret') as u:
            with self.assertRaises(pymysql.err.OperationalError):
                pymysql.connect(user='pymysql_2q', **self.db)
            pymysql.connect(user='pymysql_2q', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)

    @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
    @unittest2.skipIf(three_attempts_found, "three_attempts plugin already installed")
    def testDialogAuthThreeAttemptsQuestionsInstallPlugin(self):
        """Install ``three_attempts``, run its test, then uninstall it."""
        # needs plugin. lets install it.
        cur = self.connections[0].cursor()
        try:
            cur.execute("install plugin three_attempts soname 'dialog_examples.so'")
            TestAuthentication.three_attempts_found = True
            self.realTestDialogAuthThreeAttempts()
        except pymysql.err.InternalError:
            raise unittest2.SkipTest('we couldn\'t install the three_attempts plugin')
        finally:
            if TestAuthentication.three_attempts_found:
                cur.execute("uninstall plugin three_attempts")

    @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
    @unittest2.skipUnless(three_attempts_found, "no three attempts plugin")
    def testDialogAuthThreeAttempts(self):
        """Run the three-attempts test against an already installed plugin."""
        self.realTestDialogAuthThreeAttempts()

    def realTestDialogAuthThreeAttempts(self):
        """Exercise good, defective and wrong-answer handlers for ``three_attempts``."""
        TestAuthentication.Dialog.m = {b'Password, please:': b'stillnotverysecret'}
        TestAuthentication.Dialog.fail=True   # fail just once. We've got three attempts after all
        with TempUser(self.connections[0].cursor(), 'pymysql_3a@localhost',
                      self.databases[0]['db'], 'three_attempts', 'stillnotverysecret') as u:
            pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
            pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.DialogHandler}, **self.db)
            with self.assertRaises(pymysql.err.OperationalError):
                pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': object}, **self.db)

            with self.assertRaises(pymysql.err.OperationalError):
                pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.DefectiveHandler}, **self.db)
            with self.assertRaises(pymysql.err.OperationalError):
                pymysql.connect(user='pymysql_3a', auth_plugin_map={b'notdialogplugin': TestAuthentication.Dialog}, **self.db)
            TestAuthentication.Dialog.m = {b'Password, please:': b'I do not know'}
            with self.assertRaises(pymysql.err.OperationalError):
                pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
            TestAuthentication.Dialog.m = {b'Password, please:': None}
            with self.assertRaises(pymysql.err.OperationalError):
                pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)

    @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
    @unittest2.skipIf(pam_found, "pam plugin already installed")
    @unittest2.skipIf(os.environ.get('PASSWORD') is None, "PASSWORD env var required")
    @unittest2.skipIf(os.environ.get('PAMSERVICE') is None, "PAMSERVICE env var required")
    def testPamAuthInstallPlugin(self):
        """Install the ``pam`` plugin, run its test, then uninstall it."""
        # needs plugin. lets install it.
        cur = self.connections[0].cursor()
        try:
            cur.execute("install plugin pam soname 'auth_pam.so'")
            TestAuthentication.pam_found = True
            self.realTestPamAuth()
        except pymysql.err.InternalError:
            raise unittest2.SkipTest('we couldn\'t install the auth_pam plugin')
        finally:
            if TestAuthentication.pam_found:
                cur.execute("uninstall plugin pam")

    @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
    @unittest2.skipUnless(pam_found, "no pam plugin")
    @unittest2.skipIf(os.environ.get('PASSWORD') is None, "PASSWORD env var required")
    @unittest2.skipIf(os.environ.get('PAMSERVICE') is None, "PAMSERVICE env var required")
    def testPamAuth(self):
        """Run the PAM test against an already installed plugin."""
        self.realTestPamAuth()

    def realTestPamAuth(self):
        """Authenticate the OS user through PAM.

        Any existing account for the OS user is dropped first and its first
        ``SHOW GRANTS`` row is replayed at the end.
        """
        db = self.db.copy()
        # The class-body ``import os`` only binds a class attribute, so the
        # module has to be imported again here.
        import os
        db['password'] = os.environ.get('PASSWORD')
        cur = self.connections[0].cursor()
        try:
            cur.execute('show grants for ' + TestAuthentication.osuser + '@localhost')
            grants = cur.fetchone()[0]
            cur.execute('drop user ' + TestAuthentication.osuser + '@localhost')
        except pymysql.OperationalError as e:
            # assuming the user doesn't exist which is ok too
            self.assertEqual(1045, e.args[0])
            grants = None
        with TempUser(cur, TestAuthentication.osuser + '@localhost',
                      self.databases[0]['db'], 'pam', os.environ.get('PAMSERVICE')) as u:
            try:
                c = pymysql.connect(user=TestAuthentication.osuser, **db)
                # Only the handshake matters; release the connection at once.
                c.close()
                db['password'] = 'very bad guess at password'
                with self.assertRaises(pymysql.err.OperationalError):
                    pymysql.connect(user=TestAuthentication.osuser,
                                    auth_plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler},
                                    **self.db)
            except pymysql.OperationalError as e:
                self.assertEqual(1045, e.args[0])
                # we had 'bad guess at password' work with pam. Well at least we get a permission denied here
                with self.assertRaises(pymysql.err.OperationalError):
                    pymysql.connect(user=TestAuthentication.osuser,
                                    auth_plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler},
                                    **self.db)
        if grants:
            # recreate the user
            cur.execute(grants)

    # select old_password("crummy p\tassword");
    #| old_password("crummy p\tassword") |
    #| 2a01785203b08770 |
    @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
    @unittest2.skipUnless(mysql_old_password_found, "no mysql_old_password plugin")
    def testMySQLOldPasswordAuth(self):
        """Check the pre-4.1 password hash and log in with an old-style password.

        ``secure_auth`` is switched off globally for the login and restored to
        its previous value afterwards.
        """
        if self.mysql_server_is(self.connections[0], (5, 7, 0)):
            raise unittest2.SkipTest('Old passwords aren\'t supported in 5.7')
        # pymysql.err.OperationalError: (1045, "Access denied for user 'old_pass_user'@'localhost' (using password: YES)")
        # from login in MySQL-5.6
        if self.mysql_server_is(self.connections[0], (5, 6, 0)):
            raise unittest2.SkipTest('Old passwords don\'t authenticate in 5.6')
        db = self.db.copy()
        db['password'] = "crummy p\tassword"
        with self.connections[0] as c:
            # deprecated in 5.6
            if sys.version_info[0:2] >= (3,2) and self.mysql_server_is(self.connections[0], (5, 6, 0)):
                with self.assertWarns(pymysql.err.Warning) as cm:
                    c.execute("SELECT OLD_PASSWORD('%s')" % db['password'])
            else:
                c.execute("SELECT OLD_PASSWORD('%s')" % db['password'])
            v = c.fetchone()[0]
            self.assertEqual(v, '2a01785203b08770')
            # only works in MariaDB and MySQL-5.6 - can't separate out by version
            #if self.mysql_server_is(self.connections[0], (5, 5, 0)):
            #    with TempUser(c, 'old_pass_user@localhost',
            #                  self.databases[0]['db'], 'mysql_old_password', '2a01785203b08770') as u:
            #        cur = pymysql.connect(user='old_pass_user', **db).cursor()
            #        cur.execute("SELECT VERSION()")
            c.execute("SELECT @@secure_auth")
            secure_auth_setting = c.fetchone()[0]
            c.execute('set old_passwords=1')
            # pymysql.err.Warning: 'pre-4.1 password hash' is deprecated and will be removed in a future release. Please use post-4.1 password hash instead
            if sys.version_info[0:2] >= (3,2) and self.mysql_server_is(self.connections[0], (5, 6, 0)):
                with self.assertWarns(pymysql.err.Warning) as cm:
                    c.execute('set global secure_auth=0')
            else:
                c.execute('set global secure_auth=0')
            with TempUser(c, 'old_pass_user@localhost',
                          self.databases[0]['db'], password=db['password']) as u:
                cur = pymysql.connect(user='old_pass_user', **db).cursor()
                cur.execute("SELECT VERSION()")
            c.execute('set global secure_auth=%r' % secure_auth_setting)

    @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
    @unittest2.skipUnless(sha256_password_found, "no sha256 password authentication plugin found")
    def testAuthSHA256(self):
        """Connecting as a ``sha256_password`` account is expected to fail for now."""
        c = self.connections[0].cursor()
        with TempUser(c, 'pymysql_sha256@localhost',
                      self.databases[0]['db'], 'sha256_password') as u:
            if self.mysql_server_is(self.connections[0], (5, 7, 0)):
                c.execute("SET PASSWORD FOR 'pymysql_sha256'@'localhost' ='Sh@256Pa33'")
            else:
                c.execute('SET old_passwords = 2')
                c.execute("SET PASSWORD FOR 'pymysql_sha256'@'localhost' = PASSWORD('Sh@256Pa33')")
            db = self.db.copy()
            db['password'] = "Sh@256Pa33"
            # not implemented yet so throws error
            # Connect as the account created above.  A different user name
            # would fail only because that user does not exist, without ever
            # reaching the sha256 code path.
            with self.assertRaises(pymysql.err.OperationalError):
                pymysql.connect(user='pymysql_sha256', **db)
class TestConnection(base.PyMySQLTestCase):
    """Connection-level behaviour: charset, autocommit, timeouts, context use."""

    def test_utf8mb4(self):
        """This test requires MySQL >= 5.5"""
        arg = self.databases[0].copy()
        arg['charset'] = 'utf8mb4'
        conn = pymysql.connect(**arg)
        # Connecting is the whole test; do not leave the connection open.
        conn.close()

    def test_largedata(self):
        """Large query and response (>=16MB)"""
        cur = self.connections[0].cursor()
        cur.execute("SELECT @@max_allowed_packet")
        if cur.fetchone()[0] < 16*1024*1024 + 10:
            # Report as skipped rather than silently passing.
            self.skipTest("Set max_allowed_packet to bigger than 17MB")
        t = 'a' * (16*1024*1024)
        cur.execute("SELECT '" + t + "'")
        # assertTrue rather than assertEqual: a failure must not try to render
        # a diff of two 16MB strings, and unlike a bare assert it is not
        # stripped under ``python -O``.
        self.assertTrue(cur.fetchone()[0] == t)

    def test_autocommit(self):
        """``get_autocommit`` tracks both SQL and API changes of the mode."""
        con = self.connections[0]
        self.assertFalse(con.get_autocommit())

        cur = con.cursor()
        cur.execute("SET AUTOCOMMIT=1")
        self.assertTrue(con.get_autocommit())

        con.autocommit(False)
        self.assertFalse(con.get_autocommit())
        cur.execute("SELECT @@AUTOCOMMIT")
        self.assertEqual(cur.fetchone()[0], 0)

    def test_select_db(self):
        """``select_db`` switches the connection's current database."""
        con = self.connections[0]
        current_db = self.databases[0]['db']
        other_db = self.databases[1]['db']

        cur = con.cursor()
        cur.execute('SELECT database()')
        self.assertEqual(cur.fetchone()[0], current_db)

        con.select_db(other_db)
        cur.execute('SELECT database()')
        self.assertEqual(cur.fetchone()[0], other_db)

    def test_connection_gone_away(self):
        """
        http://dev.mysql.com/doc/refman/5.0/en/gone-away.html
        http://dev.mysql.com/doc/refman/5.0/en/error-messages-client.html#error_cr_server_gone_error
        """
        con = self.connections[0]
        cur = con.cursor()
        cur.execute("SET wait_timeout=1")
        time.sleep(2)
        with self.assertRaises(pymysql.OperationalError) as cm:
            cur.execute("SELECT 1+1")
        # error occurs while reading, not writing because of socket buffer.
        #self.assertEqual(cm.exception.args[0], 2006)
        self.assertIn(cm.exception.args[0], (2006, 2013))

    def test_init_command(self):
        """A multi-statement ``init_command`` leaves the connection usable."""
        conn = pymysql.connect(
            init_command='SELECT "bar"; SELECT "baz"',
            **self.databases[0]
        )
        c = conn.cursor()
        c.execute('select "foobar";')
        self.assertEqual(('foobar',), c.fetchone())
        conn.close()
        with self.assertRaises(pymysql.err.Error):
            conn.ping(reconnect=False)

    def test_read_default_group(self):
        """Connecting with ``read_default_group`` yields an open connection."""
        conn = pymysql.connect(
            read_default_group='client',
            **self.databases[0]
        )
        try:
            self.assertTrue(conn.open)
        finally:
            conn.close()

    def test_context(self):
        """An exception inside ``with connection`` discards the pending insert.

        The ``ValueError`` is raised before any commit, so the second
        connection must see an empty table.
        """
        with self.assertRaises(ValueError):
            c = pymysql.connect(**self.databases[0])
            with c as cur:
                cur.execute('create table test ( a int )')
                c.begin()
                cur.execute('insert into test values ((1))')
                raise ValueError('pseudo abort')
        c = pymysql.connect(**self.databases[0])
        with c as cur:
            cur.execute('select count(*) from test')
            self.assertEqual(0, cur.fetchone()[0])
            cur.execute('insert into test values ((1))')
        with c as cur:
            cur.execute('select count(*) from test')
            self.assertEqual(1,cur.fetchone()[0])
            cur.execute('drop table test')

    def test_set_charset(self):
        """``set_charset`` can be called on a live connection."""
        c = pymysql.connect(**self.databases[0])
        try:
            c.set_charset('utf8')
            # TODO validate setting here
        finally:
            c.close()

    def test_defer_connect(self):
        """``defer_connect=True`` stays closed until ``connect(sock)`` is called."""
        import socket
        for db in self.databases:
            d = db.copy()
            # Choose the transport before creating a socket: this avoids
            # leaking a half-built AF_UNIX socket when no unix_socket is
            # configured, and avoids touching AF_UNIX on platforms without it.
            if 'unix_socket' in d:
                sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
                sock.connect(d['unix_socket'])
            else:
                sock = socket.create_connection(
                    (d.get('host', 'localhost'), d.get('port', 3306)))
            try:
                # The socket is supplied by hand, so strip the addressing keys.
                for k in ['unix_socket', 'host', 'port']:
                    d.pop(k, None)

                c = pymysql.connect(defer_connect=True, **d)
                self.assertFalse(c.open)
                c.connect(sock)
                c.close()
            finally:
                sock.close()

    @unittest2.skipUnless(sys.version_info[0:2] >= (3,2), "required py-3.2")
    def test_no_delay_warning(self):
        """Passing ``no_delay`` emits a ``DeprecationWarning``."""
        current_db = self.databases[0].copy()
        current_db['no_delay'] = True
        with self.assertWarns(DeprecationWarning) as cm:
            conn = pymysql.connect(**current_db)
        conn.close()
# A custom type and function to escape it
class Foo(object):
    """Minimal custom type used to exercise user-supplied escape functions."""

    # Payload that ``escape_foo`` returns as the escaped representation.
    value = "bar"
def escape_foo(x, d):
    """Return ``x.value`` unchanged as the escaped form of *x*.

    *d* is accepted so the function matches the two-argument encoder call
    shape; it is not used here.
    """
    escaped = x.value
    return escaped
class TestEscape(base.PyMySQLTestCase):
    """Tests for ``Connection.escape`` and multi-result cursor handling."""

    def test_escape_string(self):
        """Quote escaping follows the session's ``sql_mode``."""
        conn = self.connections[0]
        cursor = conn.cursor()

        escaped = conn.escape("foo'bar")
        self.assertEqual(escaped, "'foo\\'bar'")

        # added NO_AUTO_CREATE_USER as not including it in 5.7 generates warnings
        cursor.execute("SET sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'")
        escaped = conn.escape("foo'bar")
        self.assertEqual(escaped, "'foo''bar'")

    def test_escape_builtin_encoders(self):
        """The connection's own encoders format a ``datetime``."""
        conn = self.connections[0]
        cursor = conn.cursor()

        moment = datetime.datetime(2012, 3, 4, 5, 6)
        escaped = conn.escape(moment, conn.encoders)
        self.assertEqual(escaped, "'2012-03-04 05:06:00'")

    def test_escape_custom_object(self):
        """A mapping entry for a custom type is used for its instances."""
        conn = self.connections[0]
        cursor = conn.cursor()

        encoders = {Foo: escape_foo}
        escaped = conn.escape(Foo(), encoders)
        self.assertEqual(escaped, "bar")

    def test_escape_fallback_encoder(self):
        """A ``str`` subclass is escaped via the ``text_type`` encoder."""
        conn = self.connections[0]
        cursor = conn.cursor()

        class Custom(str):
            pass

        encoders = {text_type: pymysql.escape_string}
        escaped = conn.escape(Custom('foobar'), encoders)
        self.assertEqual(escaped, "'foobar'")

    def test_escape_no_default(self):
        """Escaping with an empty mapping raises ``TypeError``."""
        conn = self.connections[0]
        cursor = conn.cursor()

        with self.assertRaises(TypeError):
            conn.escape(42, {})

    def test_escape_dict_value(self):
        """Dict values are escaped through the supplied mapping."""
        conn = self.connections[0]
        cursor = conn.cursor()

        encoders = conn.encoders.copy()
        encoders[Foo] = escape_foo
        escaped = conn.escape({'foo': Foo()}, encoders)
        self.assertEqual(escaped, {'foo': "bar"})

    def test_escape_list_item(self):
        """List items are escaped and wrapped in parentheses."""
        conn = self.connections[0]
        cursor = conn.cursor()

        encoders = conn.encoders.copy()
        encoders[Foo] = escape_foo
        escaped = conn.escape([Foo()], encoders)
        self.assertEqual(escaped, "(bar)")

    def test_previous_cursor_not_closed(self):
        """A new cursor works while an earlier one has unread results."""
        conn = self.connections[0]

        first = conn.cursor()
        first.execute("SELECT 1; SELECT 2")

        second = conn.cursor()
        second.execute("SELECT 3")
        row = second.fetchone()
        self.assertEqual(row[0], 3)

    def test_commit_during_multi_result(self):
        """Committing with unread result sets leaves the cursor usable."""
        conn = self.connections[0]
        cursor = conn.cursor()

        cursor.execute("SELECT 1; SELECT 2")
        conn.commit()

        cursor.execute("SELECT 3")
        row = cursor.fetchone()
        self.assertEqual(row[0], 3)