1500 lines
54 KiB
Python
1500 lines
54 KiB
Python
|
# Python implementation of the MySQL client-server protocol
|
||
|
# http://dev.mysql.com/doc/internals/en/client-server-protocol.html
|
||
|
# Error codes:
|
||
|
# http://dev.mysql.com/doc/refman/5.5/en/error-messages-client.html
|
||
|
from __future__ import print_function
|
||
|
from ._compat import PY2, range_type, text_type, str_type, JYTHON, IRONPYTHON
|
||
|
|
||
|
import errno
|
||
|
from functools import partial
|
||
|
import hashlib
|
||
|
import io
|
||
|
import os
|
||
|
import socket
|
||
|
import struct
|
||
|
import sys
|
||
|
import traceback
|
||
|
import warnings
|
||
|
|
||
|
from .charset import MBLENGTH, charset_by_name, charset_by_id
|
||
|
from .constants import CLIENT, COMMAND, FIELD_TYPE, SERVER_STATUS
|
||
|
from .converters import escape_item, escape_string, through, conversions as _conv
|
||
|
from .cursors import Cursor
|
||
|
from .optionfile import Parser
|
||
|
from .util import byte2int, int2byte
|
||
|
from . import err
|
||
|
|
||
|
# TLS support is optional: when the ssl module cannot be imported the name
# ``ssl`` is bound to None so later code can still reference it.
try:
    import ssl
except ImportError:
    ssl = None

# True only when the ssl module was imported successfully.
SSL_ENABLED = ssl is not None
|
||
|
|
||
|
try:
    import getpass
    DEFAULT_USER = getpass.getuser()
    del getpass
except (ImportError, KeyError, OSError):
    # No login name could be determined for the current process:
    # - ImportError: getpass (or a platform module it needs) is unavailable.
    # - KeyError: there's no entry in OS database for a current user.
    # - OSError: what getpass.getuser() raises for that case on Python 3.13+.
    DEFAULT_USER = None


# When True, diagnostic output is printed by several functions in this module.
DEBUG = False

# (major, minor) of the running interpreter.
_py_version = sys.version_info[:2]
|
||
|
|
||
|
|
||
|
# Pick the factory that wraps a connected socket in a readable file-like
# object.  socket.makefile() in Python 2 is not usable because it is very
# inefficient and behaves badly with timeouts.
# XXX: ._socketio doesn't work under IronPython.
if _py_version == (2, 7) and not IRONPYTHON:
    # The read method of the file-like returned by sock.makefile() is very
    # slow, so the io-based implementation copied from Python 3 is used.
    from ._socketio import SocketIO

    def _makefile(sock, mode):
        """Return an io.BufferedReader over a SocketIO built from ``sock``."""
        raw = SocketIO(sock, mode)
        return io.BufferedReader(raw)
elif _py_version == (2, 6):
    # Python 2.6 doesn't have a fast io module, so a minimal reader is
    # defined here.
    class SockFile(object):
        """Minimal file-like reader built directly on socket.recv()."""

        def __init__(self, sock):
            self._sock = sock

        def read(self, n):
            """Receive until ``n`` bytes are collected or recv() returns nothing.

            The result may be shorter than ``n`` when recv() yields no data.
            """
            buf = self._sock.recv(n)
            while len(buf) != n:
                chunk = self._sock.recv(n - len(buf))
                if not chunk:
                    break
                buf += chunk
            return buf

    def _makefile(sock, mode):
        """Return a SockFile for ``sock``; only mode 'rb' is accepted."""
        assert mode == 'rb'
        return SockFile(sock)
else:
    # socket.makefile in Python 3 is nice.
    def _makefile(sock, mode):
        """Return ``sock.makefile(mode)``."""
        return sock.makefile(mode)
|
||
|
|
||
|
|
||
|
# FIELD_TYPE codes grouped as text-like column types.
# NOTE(review): presumably consulted when deciding how to decode column
# values; the consumer is outside this part of the file - confirm.
TEXT_TYPES = set((
    FIELD_TYPE.BIT,
    FIELD_TYPE.BLOB,
    FIELD_TYPE.LONG_BLOB,
    FIELD_TYPE.MEDIUM_BLOB,
    FIELD_TYPE.STRING,
    FIELD_TYPE.TINY_BLOB,
    FIELD_TYPE.VAR_STRING,
    FIELD_TYPE.VARCHAR,
    FIELD_TYPE.GEOMETRY,
))
|
||
|
|
||
|
# Factory for SHA-1 hash objects; initial data may be passed as an argument.
sha_new = partial(hashlib.new, 'sha1')

# First-byte markers compared against in
# MysqlPacket.read_length_encoded_integer().
NULL_COLUMN = 0xfb
UNSIGNED_CHAR_COLUMN = 0xfb
UNSIGNED_SHORT_COLUMN = 0xfc
UNSIGNED_INT24_COLUMN = 0xfd
UNSIGNED_INT64_COLUMN = 0xfe

# Charset used when the caller does not specify one.
DEFAULT_CHARSET = 'latin1'

# Largest value that fits in a 3-byte length field.
MAX_PACKET_LEN = (1 << 24) - 1
|
||
|
|
||
|
|
||
|
def dump_packet(data):  # pragma: no cover
    """Print a debugging dump of ``data`` to stdout.

    Prints the packet length, up to five calling frames, and then the first
    256 bytes of ``data`` in rows of 16 as hex followed by a printable
    rendering (bytes outside 65..122 are shown as '.').
    """
    def is_ascii(ch):
        code = byte2int(ch)
        if code < 65 or code > 122:
            return '.'
        if isinstance(ch, int):
            return chr(ch)
        return ch

    try:
        print("packet length:", len(data))
        for depth in range(1, 6):
            frame = sys._getframe(depth)
            print("call[%d]: %s (line %d)" % (depth, frame.f_code.co_name, frame.f_lineno))
        print("-" * 66)
    except ValueError:
        # sys._getframe() raises ValueError when the call stack is shallower
        # than the requested depth; the rest of the header is skipped then.
        pass

    limit = min(len(data), 256)
    rows = [data[start:start+16] for start in range_type(0, limit, 16)]
    for row in rows:
        hex_part = ' '.join("{:02X}".format(byte2int(x)) for x in row)
        padding = ' ' * (16 - len(row)) + ' ' * 2
        text_part = ''.join("{}".format(is_ascii(x)) for x in row)
        print(hex_part + padding + text_part)
    print("-" * 66)
    print()
|
||
|
|
||
|
|
||
|
def _scramble(password, message):
    """Return the scrambled password for ``message``.

    The result is SHA1(password) XOR SHA1(message + SHA1(SHA1(password))).

    :param password: password bytes; an empty/None password yields b''.
    :param message: bytes mixed into the hash together with the password.
    :return: the scrambled bytes, or b'' when no password is given.
    """
    if not password:
        return b''
    # The password is deliberately not printed, even when DEBUG is enabled:
    # debug output must never leak credentials.
    stage1 = sha_new(password).digest()
    stage2 = sha_new(stage1).digest()
    salted = sha_new()
    salted.update(message)
    salted.update(stage2)
    return _my_crypt(salted.digest(), stage1)
|
||
|
|
||
|
|
||
|
def _my_crypt(message1, message2):
    """XOR ``message1`` with the leading bytes of ``message2``.

    :param message1: bytes; its length determines the result length.
    :param message2: bytes at least as long as ``message1``.
    :return: bytes of ``len(message1)`` where byte i is
        ``message1[i] ^ message2[i]``.
    :raises struct.error: if ``message2`` is shorter than ``message1``.
    """
    left = bytearray(message1)
    right = bytearray(message2)
    if len(right) < len(left):
        # Fail loudly instead of letting zip() truncate the result silently.
        raise struct.error(
            "message2 (%d bytes) is shorter than message1 (%d bytes)"
            % (len(right), len(left)))
    # bytearray yields ints on both Python 2 and 3, so one pass builds the
    # result without per-byte struct calls or repeated concatenation.
    return bytes(bytearray(a ^ b for a, b in zip(left, right)))
|
||
|
|
||
|
# old_passwords support ported from libmysql/password.c
SCRAMBLE_LENGTH_323 = 8


class RandStruct_323(object):
    """Two-seed pseudo random generator used by the old password scramble."""

    def __init__(self, seed1, seed2):
        modulus = 0x3FFFFFFF
        self.max_value = modulus
        self.seed1 = seed1 % modulus
        self.seed2 = seed2 % modulus

    def my_rnd(self):
        """Advance both seeds and return ``seed1 / max_value`` as a float."""
        modulus = self.max_value
        first = (self.seed1 * 3 + self.seed2) % modulus
        # The second seed is updated from the *new* first seed.
        second = (first + self.seed2 + 33) % modulus
        self.seed1 = first
        self.seed2 = second
        return float(first) / float(modulus)
|
||
|
|
||
|
|
||
|
def _scramble_323(password, message):
    """Return the old-style (pre-4.1 ``old_passwords``) scramble.

    Both the password and the first SCRAMBLE_LENGTH_323 bytes of ``message``
    are hashed with _hash_password_323; the XOR of the two hashes seeds a
    RandStruct_323 generator whose output forms the scrambled bytes.
    """
    pass_hi, pass_lo = struct.unpack(">LL", _hash_password_323(password))
    msg_hi, msg_lo = struct.unpack(
        ">LL", _hash_password_323(message[:SCRAMBLE_LENGTH_323]))

    generator = RandStruct_323(pass_hi ^ msg_hi, pass_lo ^ msg_lo)
    count = min(SCRAMBLE_LENGTH_323, len(message))
    raw = b''.join(
        [int2byte(int(generator.my_rnd() * 31) + 64) for _ in range_type(count)])
    # One more generator step yields the byte every output byte is XORed with.
    extra = int2byte(int(generator.my_rnd() * 31))
    return b''.join(
        [int2byte(byte2int(c) ^ byte2int(extra)) for c in raw])
|
||
|
|
||
|
|
||
|
def _hash_password_323(password):
    """Return the 8-byte old-style hash of ``password``.

    Spaces and tabs in the password are skipped.  The result is two
    big-endian unsigned 32-bit values, each with its top bit cleared.
    """
    nr = 1345345333
    add = 7
    nr2 = 0x12345671

    # Iterating the password yields ints on Python 3 and 1-char strings on
    # Python 2.7, so both spellings of space and tab are listed.
    skipped = (' ', '\t', 32, 9)
    codes = [byte2int(ch) for ch in password if ch not in skipped]
    for code in codes:
        mixed = ((((nr & 63) + add) * code) + (nr << 8)) & 0xFFFFFFFF
        nr = nr ^ mixed
        nr2 = (nr2 + ((nr2 << 8) ^ nr)) & 0xFFFFFFFF
        add = (add + code) & 0xFFFFFFFF

    sign_mask = (1 << 31) - 1  # kill sign bits
    return struct.pack(">LL", nr & sign_mask, nr2 & sign_mask)
|
||
|
|
||
|
|
||
|
def pack_int24(n):
    """Return the three low-order bytes of ``n``, little-endian."""
    packed = struct.pack('<I', n)
    return packed[:3]
|
||
|
|
||
|
# https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger
def lenenc_int(i):
    """Encode the non-negative integer ``i`` as a LengthEncodedInteger.

    Values below 0xfb take one byte; larger values get a marker byte
    (0xfc, 0xfd or 0xfe) followed by 2, 3 or 8 little-endian bytes.

    :raises ValueError: if ``i`` is negative or not below 2**64.
    """
    if i < 0:
        raise ValueError("Encoding %d is less than 0 - no representation in LengthEncodedInteger" % i)
    if i < 0xfb:
        return int2byte(i)
    if i < (1 << 16):
        return b'\xfc' + struct.pack('<H', i)
    if i < (1 << 24):
        return b'\xfd' + struct.pack('<I', i)[:3]
    if i < (1 << 64):
        return b'\xfe' + struct.pack('<Q', i)
    raise ValueError("Encoding %x is larger than %x - no representation in LengthEncodedInteger" % (i, (1 << 64)))
|
||
|
|
||
|
class MysqlPacket(object):
    """Representation of a MySQL response packet.

    Provides an interface for reading/parsing the packet results.  The
    payload is kept in ``_data`` and ``_position`` is the read cursor into it.
    """
    __slots__ = ('_position', '_data')

    def __init__(self, data, encoding):
        # ``encoding`` is accepted for the benefit of subclasses; it is not
        # used by this base class.
        self._position = 0
        self._data = data

    def get_all_data(self):
        """Return the complete payload regardless of the cursor position."""
        return self._data

    def read(self, size):
        """Read the first 'size' bytes in packet and advance cursor past them.

        :raises AssertionError: if fewer than ``size`` bytes remain.
        """
        result = self._data[self._position:(self._position+size)]
        if len(result) != size:
            error = ('Result length not requested length:\n'
                     'Expected=%s. Actual=%s. Position: %s. Data Length: %s'
                     % (size, len(result), self._position, len(self._data)))
            if DEBUG:
                print(error)
                self.dump()
            raise AssertionError(error)
        self._position += size
        return result

    def read_all(self):
        """Read all remaining data in the packet.

        (Subsequent read() will return errors.)
        """
        result = self._data[self._position:]
        self._position = None  # ensure no subsequent read()
        return result

    def advance(self, length):
        """Advance the cursor in data buffer 'length' bytes.

        :raises Exception: if the new position falls outside the payload.
        """
        new_position = self._position + length
        if new_position < 0 or new_position > len(self._data):
            raise Exception('Invalid advance amount (%s) for cursor. '
                            'Position=%s' % (length, new_position))
        self._position = new_position

    def rewind(self, position=0):
        """Set the position of the data buffer cursor to 'position'.

        :raises Exception: if ``position`` falls outside the payload.
        """
        if position < 0 or position > len(self._data):
            raise Exception("Invalid position to rewind cursor to: %s." % position)
        self._position = position

    def get_bytes(self, position, length=1):
        """Get 'length' bytes starting at 'position'.

        Position is start of payload (first four packet header bytes are not
        included) starting at index '0'.

        No error checking is done.  If requesting outside end of buffer
        an empty string (or string shorter than 'length') may be returned!
        """
        return self._data[position:(position+length)]

    # Indexing the payload yields a 1-character string on Python 2 and an
    # int on Python 3, so the variant is chosen once at class creation.
    if PY2:
        def read_uint8(self):
            """Read one unsigned byte and advance the cursor."""
            result = ord(self._data[self._position])
            self._position += 1
            return result
    else:
        def read_uint8(self):
            """Read one unsigned byte and advance the cursor."""
            result = self._data[self._position]
            self._position += 1
            return result

    def read_uint16(self):
        """Read a little-endian unsigned 16-bit integer and advance by 2."""
        result = struct.unpack_from('<H', self._data, self._position)[0]
        self._position += 2
        return result

    def read_uint24(self):
        """Read a little-endian unsigned 24-bit integer and advance by 3."""
        low, high = struct.unpack_from('<HB', self._data, self._position)
        self._position += 3
        return low + (high << 16)

    def read_uint32(self):
        """Read a little-endian unsigned 32-bit integer and advance by 4."""
        result = struct.unpack_from('<I', self._data, self._position)[0]
        self._position += 4
        return result

    def read_uint64(self):
        """Read a little-endian unsigned 64-bit integer and advance by 8."""
        result = struct.unpack_from('<Q', self._data, self._position)[0]
        self._position += 8
        return result

    def read_string(self):
        """Read up to the next NUL byte and advance past it.

        Returns None (leaving the cursor untouched) when no NUL byte is
        found at or after the cursor.
        """
        end_pos = self._data.find(b'\0', self._position)
        if end_pos < 0:
            return None
        result = self._data[self._position:end_pos]
        self._position = end_pos + 1
        return result

    def read_length_encoded_integer(self):
        """Read a 'Length Coded Binary' number from the data buffer.

        Length coded numbers can be anywhere from 1 to 9 bytes depending
        on the value of the first byte.  Returns None for the NULL marker
        (NULL_COLUMN) and also for a first byte that matches no known marker.
        """
        c = self.read_uint8()
        if c == NULL_COLUMN:
            return None
        if c < UNSIGNED_CHAR_COLUMN:
            return c
        elif c == UNSIGNED_SHORT_COLUMN:
            return self.read_uint16()
        elif c == UNSIGNED_INT24_COLUMN:
            return self.read_uint24()
        elif c == UNSIGNED_INT64_COLUMN:
            return self.read_uint64()

    def read_length_coded_string(self):
        """Read a 'Length Coded String' from the data buffer.

        A 'Length Coded String' consists first of a length coded
        (unsigned, positive) integer represented in 1-9 bytes followed by
        that many bytes of binary data.  (For example "cat" would be "3cat".)
        Returns None when the length prefix decodes to None.
        """
        length = self.read_length_encoded_integer()
        if length is None:
            return None
        return self.read(length)

    def read_struct(self, fmt):
        """Unpack ``fmt`` at the cursor, advance by its size, return the tuple."""
        s = struct.Struct(fmt)
        result = s.unpack_from(self._data, self._position)
        self._position += s.size
        return result

    def is_ok_packet(self):
        # https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
        return self._data[0:1] == b'\0' and len(self._data) >= 7

    def is_eof_packet(self):
        # http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet
        # Caution: \xFE may be LengthEncodedInteger.
        # If \xFE is LengthEncodedInteger header, 8bytes followed.
        return self._data[0:1] == b'\xfe' and len(self._data) < 9

    def is_auth_switch_request(self):
        # http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
        return self._data[0:1] == b'\xfe'

    def is_resultset_packet(self):
        """Return True when the first byte (the field count) is in 1..250.

        An empty payload yields False, like the other ``is_*`` predicates.
        """
        # Comparing 1-byte slices avoids ord(), which raises TypeError on
        # the empty slice produced by an empty payload.
        return b'\x01' <= self._data[0:1] <= b'\xfa'

    def is_load_local_packet(self):
        """Return True when the first byte is 0xfb."""
        return self._data[0:1] == b'\xfb'

    def is_error_packet(self):
        """Return True when the first byte is 0xff."""
        return self._data[0:1] == b'\xff'

    def check_error(self):
        """Raise the matching MySQL exception if this is an error packet."""
        if self.is_error_packet():
            self.rewind()
            self.advance(1)  # field_count == error (we already know that)
            # Named ``error_code`` so the imported ``errno`` module is not
            # shadowed; the value is only used for the debug print.
            error_code = self.read_uint16()
            if DEBUG: print("errno =", error_code)
            err.raise_mysql_exception(self._data)

    def dump(self):
        """Print a debugging dump of the payload via dump_packet()."""
        dump_packet(self._data)
|
||
|
|
||
|
|
||
|
class FieldDescriptorPacket(MysqlPacket):
    """A MysqlPacket that represents a specific column's metadata in the result.

    Parsing is automatically done and the results are exported via public
    attributes on the class such as: db, table_name, name, length, type_code.
    """

    def __init__(self, data, encoding):
        MysqlPacket.__init__(self, data, encoding)
        self._parse_field_descriptor(encoding)

    def _parse_field_descriptor(self, encoding):
        """Parse the 'Field Descriptor' (Metadata) packet.

        This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0).
        """
        next_string = self.read_length_coded_string
        # catalog and db are kept as raw bytes; the remaining names are
        # decoded with ``encoding``.
        self.catalog = next_string()
        self.db = next_string()
        self.table_name = next_string().decode(encoding)
        self.org_table = next_string().decode(encoding)
        self.name = next_string().decode(encoding)
        self.org_name = next_string().decode(encoding)
        fixed_fields = self.read_struct('<xHIBHBxx')
        (self.charsetnr, self.length, self.type_code,
         self.flags, self.scale) = fixed_fields
        # 'default' is a length coded binary and is still in the buffer?
        # not used for normal result sets...

    def description(self):
        """Provides a 7-item tuple compatible with the Python PEP249 DB Spec."""
        display_size = None  # TODO: display_length; should this be self.length?
        internal_size = self.get_column_length()
        precision = self.get_column_length()  # TODO: why!?!?
        null_ok = self.flags % 2 == 0
        return (
            self.name,
            self.type_code,
            display_size,
            internal_size,
            precision,
            self.scale,
            null_ok)

    def get_column_length(self):
        """Return ``length``, divided by the charset's MBLENGTH for VAR_STRING."""
        if self.type_code != FIELD_TYPE.VAR_STRING:
            return self.length
        mblen = MBLENGTH.get(self.charsetnr, 1)
        return self.length // mblen

    def __str__(self):
        details = (self.__class__, self.db, self.table_name, self.name,
                   self.type_code, self.flags)
        return '%s %r.%r.%r, type=%s, flags=%x' % details
|
||
|
|
||
|
|
||
|
class OKPacketWrapper(object):
    """
    OK Packet Wrapper. It uses an existing packet object, and wraps
    around it, exposing useful variables while still providing access
    to the original packet objects variables and methods.

    Attributes set on construction: packet, affected_rows, insert_id,
    server_status, warning_count, message and has_next.
    """

    def __init__(self, from_packet):
        """Parse ``from_packet`` as an OK packet.

        :raises ValueError: if ``from_packet.is_ok_packet()`` is false.
        """
        if not from_packet.is_ok_packet():
            raise ValueError('Cannot create ' + str(self.__class__.__name__) +
                             ' object from invalid packet type')

        self.packet = from_packet
        self.packet.advance(1)  # skip the leading header byte

        self.affected_rows = self.packet.read_length_encoded_integer()
        self.insert_id = self.packet.read_length_encoded_integer()
        # Read through self.packet directly, like the sibling wrappers,
        # rather than through the __getattr__ proxy.
        self.server_status, self.warning_count = self.packet.read_struct('<HH')
        self.message = self.packet.read_all()
        self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS

    def __getattr__(self, key):
        """Delegate unknown attributes to the wrapped packet."""
        # __getattr__ only runs when normal lookup failed; if 'packet'
        # itself is missing (instance not initialised yet), looking it up
        # again would recurse without end.
        if key == 'packet':
            raise AttributeError(key)
        return getattr(self.packet, key)
|
||
|
|
||
|
|
||
|
class EOFPacketWrapper(object):
    """
    EOF Packet Wrapper. It uses an existing packet object, and wraps
    around it, exposing useful variables while still providing access
    to the original packet objects variables and methods.

    Attributes set on construction: packet, warning_count, server_status
    and has_next.
    """

    def __init__(self, from_packet):
        """Parse ``from_packet`` as an EOF packet.

        :raises ValueError: if ``from_packet.is_eof_packet()`` is false.
        """
        if not from_packet.is_eof_packet():
            raise ValueError(
                "Cannot create '{0}' object from invalid packet type".format(
                    self.__class__))

        self.packet = from_packet
        # '<xhh': skip the header byte, then two little-endian 16-bit values.
        self.warning_count, self.server_status = self.packet.read_struct('<xhh')
        if DEBUG: print("server_status=", self.server_status)
        self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS

    def __getattr__(self, key):
        """Delegate unknown attributes to the wrapped packet."""
        # __getattr__ only runs when normal lookup failed; if 'packet'
        # itself is missing, looking it up again would recurse without end.
        if key == 'packet':
            raise AttributeError(key)
        return getattr(self.packet, key)
|
||
|
|
||
|
|
||
|
class LoadLocalPacketWrapper(object):
    """
    Load Local Packet Wrapper. It uses an existing packet object, and wraps
    around it, exposing useful variables while still providing access
    to the original packet objects variables and methods.

    Attributes set on construction: packet and filename.
    """

    def __init__(self, from_packet):
        """Parse ``from_packet`` as a load-local packet.

        :raises ValueError: if ``from_packet.is_load_local_packet()`` is false.
        """
        if not from_packet.is_load_local_packet():
            raise ValueError(
                "Cannot create '{0}' object from invalid packet type".format(
                    self.__class__))

        self.packet = from_packet
        # Everything after the first byte of the payload is the file name.
        self.filename = self.packet.get_all_data()[1:]
        if DEBUG: print("filename=", self.filename)

    def __getattr__(self, key):
        """Delegate unknown attributes to the wrapped packet."""
        # __getattr__ only runs when normal lookup failed; if 'packet'
        # itself is missing, looking it up again would recurse without end.
        if key == 'packet':
            raise AttributeError(key)
        return getattr(self.packet, key)
|
||
|
|
||
|
|
||
|
class Connection(object):
|
||
|
"""
|
||
|
Representation of a socket with a mysql server.
|
||
|
|
||
|
The proper way to get an instance of this class is to call
|
||
|
connect().
|
||
|
"""
|
||
|
|
||
|
_sock = None
|
||
|
_auth_plugin_name = ''
|
||
|
|
||
|
def __init__(self, host=None, user=None, password="",
             database=None, port=0, unix_socket=None,
             charset='', sql_mode=None,
             read_default_file=None, conv=None, use_unicode=None,
             client_flag=0, cursorclass=Cursor, init_command=None,
             connect_timeout=None, ssl=None, read_default_group=None,
             compress=None, named_pipe=None, no_delay=None,
             autocommit=False, db=None, passwd=None, local_infile=False,
             max_allowed_packet=16*1024*1024, defer_connect=False,
             auth_plugin_map=None, read_timeout=None, write_timeout=None):
    """
    Establish a connection to the MySQL database. Accepts several
    arguments:

    host: Host where the database server is located
    user: Username to log in as
    password: Password to use.
    database: Database to use, None to not use a particular one.
    port: MySQL port to use, default is usually OK. (default: 3306)
    unix_socket: Optionally, you can use a unix socket rather than TCP/IP.
    charset: Charset you want to use.
    sql_mode: Default SQL_MODE to use.
    read_default_file:
        Specifies my.cnf file to read these parameters from under the [client] section.
    conv:
        Conversion dictionary to use instead of the default one.
        This is used to provide custom marshalling and unmarshaling of types.
        See converters.
    use_unicode:
        Whether or not to default to unicode strings.
        This option defaults to true for Py3k.
    client_flag: Custom flags to send to MySQL. Find potential values in constants.CLIENT.
    cursorclass: Custom cursor class to use.
    init_command: Initial SQL statement to run when connection is established.
    connect_timeout: Timeout before throwing an exception when connecting.
    ssl:
        A dict of arguments similar to mysql_ssl_set()'s parameters
        (the keys read are 'ca', 'capath', 'check_hostname', 'cert', 'key'
        and 'cipher'), or an ssl.SSLContext instance.
    read_default_group: Group to read from in the configuration file.
    compress: Not supported
    named_pipe: Not supported
    no_delay: Deprecated; passing any value other than None emits a DeprecationWarning.
    autocommit: Autocommit mode. None means use server default. (default: False)
    local_infile: Boolean to enable the use of LOAD DATA LOCAL command. (default: False)
    max_allowed_packet: Max size of packet sent to server in bytes. (default: 16MB)
        Only used to limit size of "LOAD LOCAL INFILE" data packet smaller than default (16KB).
    defer_connect: Don't explicitly connect on construction - wait for connect call.
        (default: False)
    auth_plugin_map: A dict of plugin names to a class that processes that plugin.
        The class will take the Connection object as the argument to the constructor.
        The class needs an authenticate method taking an authentication packet as
        an argument.  For the dialog plugin, a prompt(echo, prompt) method can be used
        (if no authenticate method) for returning a string from the user. (experimental)
        None (the default) means an empty mapping.
    read_timeout: Stored for later socket reads; must be greater than 0 when given.
    write_timeout: Stored for later socket writes; must be greater than 0 when given.
    db: Alias for database. (for compatibility to MySQLdb)
    passwd: Alias for password. (for compatibility to MySQLdb)

    Raises NotImplementedError for compress/named_pipe or when ssl is
    requested without the ssl module, and ValueError for a non-positive
    read_timeout or write_timeout.
    """
    if no_delay is not None:
        warnings.warn("no_delay option is deprecated", DeprecationWarning)

    if use_unicode is None and sys.version_info[0] > 2:
        use_unicode = True

    # MySQLdb-compatible aliases only apply when the primary name is unset.
    if db is not None and database is None:
        database = db
    if passwd is not None and not password:
        password = passwd

    if compress or named_pipe:
        raise NotImplementedError("compress and named_pipe arguments are not supported")

    if local_infile:
        client_flag |= CLIENT.LOCAL_FILES

    self.ssl = False
    if ssl:
        if not SSL_ENABLED:
            raise NotImplementedError("ssl module not found")
        self.ssl = True
        client_flag |= CLIENT.SSL
        self.ctx = self._create_ssl_ctx(ssl)

    # A group without a file means "use the platform's default option file".
    if read_default_group and not read_default_file:
        if sys.platform.startswith("win"):
            read_default_file = "c:\\my.ini"
        else:
            read_default_file = "/etc/my.cnf"

    if read_default_file:
        if not read_default_group:
            read_default_group = "client"

        cfg = Parser()
        cfg.read(os.path.expanduser(read_default_file))

        def _config(key, arg):
            # An explicit (truthy) argument always wins over the file;
            # any lookup failure falls back to the argument as given.
            if arg:
                return arg
            try:
                return cfg.get(read_default_group, key)
            except Exception:
                return arg

        user = _config("user", user)
        password = _config("password", password)
        host = _config("host", host)
        database = _config("database", database)
        unix_socket = _config("socket", unix_socket)
        port = int(_config("port", port))
        charset = _config("default-character-set", charset)

    self.host = host or "localhost"
    self.port = port or 3306
    self.user = user or DEFAULT_USER
    self.password = password or ""
    self.db = database
    self.unix_socket = unix_socket
    # Zero is rejected as well, so the messages say "> 0".
    if read_timeout is not None and read_timeout <= 0:
        raise ValueError("read_timeout should be > 0")
    self._read_timeout = read_timeout
    if write_timeout is not None and write_timeout <= 0:
        raise ValueError("write_timeout should be > 0")
    self._write_timeout = write_timeout
    if charset:
        self.charset = charset
        self.use_unicode = True
    else:
        self.charset = DEFAULT_CHARSET
        self.use_unicode = False

    # An explicit use_unicode overrides the charset-derived default.
    if use_unicode is not None:
        self.use_unicode = use_unicode

    self.encoding = charset_by_name(self.charset).encoding

    client_flag |= CLIENT.CAPABILITIES
    if self.db:
        client_flag |= CLIENT.CONNECT_WITH_DB
    self.client_flag = client_flag

    self.cursorclass = cursorclass
    self.connect_timeout = connect_timeout

    self._result = None
    self._affected_rows = 0
    self.host_info = "Not connected"

    #: specified autocommit mode. None means use server default.
    self.autocommit_mode = autocommit

    if conv is None:
        conv = _conv
    # Need for MySQLdb compatibility: int keys are decoders, all other
    # keys are encoders.
    self.encoders = dict((k, v) for (k, v) in conv.items() if type(k) is not int)
    self.decoders = dict((k, v) for (k, v) in conv.items() if type(k) is int)
    self.sql_mode = sql_mode
    self.init_command = init_command
    self.max_allowed_packet = max_allowed_packet
    # A fresh dict per connection; a mutable default in the signature would
    # be shared by every Connection created without this argument.
    if auth_plugin_map is None:
        auth_plugin_map = {}
    self._auth_plugin_map = auth_plugin_map
    if defer_connect:
        self._sock = None
    else:
        self.connect()
|
||
|
|
||
|
def _create_ssl_ctx(self, sslp):
    """Return an ssl.SSLContext built from ``sslp``.

    ``sslp`` is either an ssl.SSLContext, which is returned unchanged, or
    a dict that may contain 'ca', 'capath', 'check_hostname', 'cert',
    'key' and 'cipher'.  Without 'ca' and 'capath' the server certificate
    is not verified and the host name is not checked.
    """
    if isinstance(sslp, ssl.SSLContext):
        return sslp

    cafile = sslp.get('ca')
    capath = sslp.get('capath')
    verify = not (cafile is None and capath is None)

    context = ssl.create_default_context(cafile=cafile, capath=capath)
    # check_hostname is assigned before verify_mode, as ssl requires it to
    # be off before verification can be disabled.
    if verify:
        context.check_hostname = sslp.get('check_hostname', True)
        context.verify_mode = ssl.CERT_REQUIRED
    else:
        context.check_hostname = False
        context.verify_mode = ssl.CERT_NONE

    if 'cert' in sslp:
        context.load_cert_chain(sslp['cert'], keyfile=sslp.get('key'))
    if 'cipher' in sslp:
        context.set_ciphers(sslp['cipher'])

    context.options |= ssl.OP_NO_SSLv2
    context.options |= ssl.OP_NO_SSLv3
    return context
|
||
|
|
||
|
def close(self):
    """Send the quit message and close the socket.

    Raises err.Error("Already closed") when there is no socket.  A failure
    while sending the quit message is ignored; the socket is closed and
    the connection state is reset either way.
    """
    if self._sock is None:
        raise err.Error("Already closed")
    quit_packet = struct.pack('<iB', 1, COMMAND.COM_QUIT)
    try:
        self._write_bytes(quit_packet)
    except Exception:
        pass
    finally:
        closing = self._sock
        self._sock = None
        self._rfile = None
        closing.close()
|
||
|
|
||
|
@property
def open(self):
    """True while the connection holds a socket, False otherwise."""
    has_socket = self._sock is not None
    return has_socket
|
||
|
|
||
|
def __del__(self):
    """Best-effort socket cleanup when the connection is garbage collected."""
    if not self._sock:
        return
    try:
        self._sock.close()
    except:
        # Nothing useful can be done about a failure inside a finalizer.
        pass
    self._sock = None
    self._rfile = None
|
||
|
|
||
|
def autocommit(self, value):
    """Set the autocommit mode.

    ``value`` is normalised with bool() and stored in ``autocommit_mode``.
    The mode is sent to the server only when it differs from what
    get_autocommit() reports.
    """
    mode = bool(value)
    self.autocommit_mode = mode
    # Compare the normalised flag: a truthy non-bool such as 2 would
    # otherwise differ from True and cause a needless round trip.
    if mode != self.get_autocommit():
        self._send_autocommit_mode()
|
||
|
|
||
|
def get_autocommit(self):
    """Return True when SERVER_STATUS_AUTOCOMMIT is set in ``server_status``."""
    autocommit_bit = self.server_status & SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT
    return bool(autocommit_bit)
|
||
|
|
||
|
def _read_ok_packet(self):
    """Read the next packet and return it wrapped as an OKPacketWrapper.

    Updates ``server_status`` from the packet.  Raises
    err.OperationalError (2014) when the packet is not an OK packet.
    """
    packet = self._read_packet()
    if packet.is_ok_packet():
        wrapped = OKPacketWrapper(packet)
        self.server_status = wrapped.server_status
        return wrapped
    raise err.OperationalError(2014, "Command Out of Sync")
|
||
|
|
||
|
def _send_autocommit_mode(self):
    """Set whether or not to commit after every execute()"""
    statement = "SET AUTOCOMMIT = %s" % self.escape(self.autocommit_mode)
    self._execute_command(COMMAND.COM_QUERY, statement)
    self._read_ok_packet()
|
||
|
|
||
|
def begin(self):
    """Begin transaction."""
    statement = "BEGIN"
    self._execute_command(COMMAND.COM_QUERY, statement)
    self._read_ok_packet()
|
||
|
|
||
|
def commit(self):
    """Commit changes to stable storage"""
    statement = "COMMIT"
    self._execute_command(COMMAND.COM_QUERY, statement)
    self._read_ok_packet()
|
||
|
|
||
|
def rollback(self):
    """Roll back the current transaction"""
    statement = "ROLLBACK"
    self._execute_command(COMMAND.COM_QUERY, statement)
    self._read_ok_packet()
|
||
|
|
||
|
def show_warnings(self):
    """Run SHOW WARNINGS and return the rows of its result."""
    self._execute_command(COMMAND.COM_QUERY, "SHOW WARNINGS")
    warnings_result = MySQLResult(self)
    warnings_result.read()
    return warnings_result.rows
|
||
|
|
||
|
def select_db(self, db):
    """Set current db"""
    command = COMMAND.COM_INIT_DB
    self._execute_command(command, db)
    self._read_ok_packet()
|
||
|
|
||
|
def escape(self, obj, mapping=None):
    """Escape whatever value you pass to it.

    Strings are quoted and escaped via escape_string(); everything else is
    handed to escape_item() with the connection charset and ``mapping``.

    Non-standard, for internal use; do not use this in your applications.
    """
    if not isinstance(obj, str_type):
        return escape_item(obj, self.charset, mapping=mapping)
    return "'" + self.escape_string(obj) + "'"
|
||
|
|
||
|
def literal(self, obj):
    """Alias for escape() that uses this connection's encoders as mapping.

    Non-standard, for internal use; do not use this in your applications.
    """
    encoders = self.encoders
    return self.escape(obj, encoders)
|
||
|
|
||
|
def escape_string(self, s):
    """Escape ``s`` for use inside a quoted SQL string.

    When the server status has NO_BACKSLASH_ESCAPES set, single quotes are
    doubled; otherwise the module-level escape_string() is used.
    """
    no_backslash = (self.server_status &
                    SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES)
    if not no_backslash:
        return escape_string(s)
    return s.replace("'", "''")
|
||
|
|
||
|
def cursor(self, cursor=None):
    """Create a new cursor to execute queries with.

    ``cursor`` is an optional cursor class; ``cursorclass`` is used when
    it is not given.
    """
    factory = cursor if cursor else self.cursorclass
    return factory(self)
|
||
|
|
||
|
def __enter__(self):
    """Context manager that returns a Cursor"""
    new_cursor = self.cursor()
    return new_cursor
|
||
|
|
||
|
def __exit__(self, exc, value, traceback):
    """On successful exit, commit. On exception, rollback"""
    finish = self.rollback if exc else self.commit
    finish()
|
||
|
|
||
|
# The following methods are INTERNAL USE ONLY (called from Cursor)
|
||
|
def query(self, sql, unbuffered=False):
    """Send ``sql`` as a COM_QUERY and return the affected row count.

    Text is encoded with the connection encoding first (except on Jython
    and IronPython); on Python 3 the 'surrogateescape' handler is used.
    The count is also stored in ``_affected_rows``.
    """
    must_encode = isinstance(sql, text_type) and not (JYTHON or IRONPYTHON)
    if must_encode:
        if PY2:
            encode_args = (self.encoding,)
        else:
            encode_args = (self.encoding, 'surrogateescape')
        sql = sql.encode(*encode_args)
    self._execute_command(COMMAND.COM_QUERY, sql)
    row_count = self._read_query_result(unbuffered=unbuffered)
    self._affected_rows = row_count
    return row_count
|
||
|
|
||
|
def next_result(self, unbuffered=False):
    """Read the next query result and return its affected row count.

    The count is also stored in ``_affected_rows``.
    """
    row_count = self._read_query_result(unbuffered=unbuffered)
    self._affected_rows = row_count
    return row_count
|
||
|
|
||
|
def affected_rows(self):
    """Return the row count recorded by the last query() or next_result()."""
    row_count = self._affected_rows
    return row_count
|
||
|
|
||
|
def kill(self, thread_id):
    """Send COM_PROCESS_KILL for ``thread_id`` and return the OK packet."""
    # The thread id travels as a little-endian unsigned 32-bit integer.
    thread_arg = struct.pack('<I', thread_id)
    self._execute_command(COMMAND.COM_PROCESS_KILL, thread_arg)
    return self._read_ok_packet()
|
||
|
|
||
|
def ping(self, reconnect=True):
    """Check if the server is alive.

    With ``reconnect`` true, a closed connection is re-opened first and a
    failed ping triggers one reconnect followed by a final ping.  With
    ``reconnect`` false, a closed connection raises
    err.Error("Already closed") and a failed ping re-raises its error.
    """
    if self._sock is None:
        if not reconnect:
            raise err.Error("Already closed")
        self.connect()
        # Freshly connected: a failure below must not reconnect again.
        reconnect = False
    try:
        self._execute_command(COMMAND.COM_PING, "")
        return self._read_ok_packet()
    except Exception:
        if not reconnect:
            raise
        self.connect()
        return self.ping(False)
|
||
|
|
||
|
def set_charset(self, charset):
    """Switch the connection to ``charset`` via SET NAMES.

    ``charset`` and ``encoding`` are updated only after the server's reply
    has been read.
    """
    # Make sure charset is supported.
    new_encoding = charset_by_name(charset).encoding

    statement = "SET NAMES %s" % self.escape(charset)
    self._execute_command(COMMAND.COM_QUERY, statement)
    self._read_packet()
    self.charset, self.encoding = charset, new_encoding
|
||
|
|
||
|
def connect(self, sock=None):
    """Open the connection and run the handshake.

    If ``sock`` is None a new socket is created: a UNIX-domain socket when
    ``self.unix_socket`` is set and ``self.host`` is 'localhost' or
    '127.0.0.1', otherwise a TCP connection to ``(self.host, self.port)``.
    A socket passed in by the caller is used as is.

    Once the socket is stored on the instance, the server greeting is
    read, authentication is performed, and the optional ``sql_mode``,
    ``init_command`` and ``autocommit_mode`` settings are applied.

    :param sock: optional socket to use instead of opening one.
    :raise err.OperationalError: (code 2003) when the failure is an
        OSError/IOError/socket.error; the original exception and a
        formatted traceback are attached to it.  Any other exception is
        re-raised unchanged after the socket has been closed.
    """
    try:
        if sock is None:
            if self.unix_socket and self.host in ('localhost', '127.0.0.1'):
                sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
                sock.settimeout(self.connect_timeout)
                sock.connect(self.unix_socket)
                self.host_info = "Localhost via UNIX socket"
                if DEBUG: print('connected using unix_socket')
            else:
                # Retry the TCP connect for as long as it is merely
                # interrupted by a signal (EINTR).
                while True:
                    try:
                        sock = socket.create_connection(
                            (self.host, self.port), self.connect_timeout)
                        break
                    except (OSError, IOError) as e:
                        if e.errno == errno.EINTR:
                            continue
                        raise
                self.host_info = "socket %s:%d" % (self.host, self.port)
                if DEBUG: print('connected using socket')
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            # The connect timeout only covers connecting; per-operation
            # timeouts are set in _read_bytes()/_write_bytes().
            sock.settimeout(None)
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
        self._sock = sock
        self._rfile = _makefile(sock, 'rb')
        self._next_seq_id = 0

        self._get_server_information()
        self._request_authentication()

        if self.sql_mode is not None:
            c = self.cursor()
            c.execute("SET sql_mode=%s", (self.sql_mode,))

        if self.init_command is not None:
            c = self.cursor()
            c.execute(self.init_command)
            c.close()
            self.commit()

        if self.autocommit_mode is not None:
            self.autocommit(self.autocommit_mode)
    except BaseException as e:
        self._rfile = None
        if sock is not None:
            # Best-effort close; a failure here must not replace the
            # error that is already being handled.
            try:
                sock.close()
            except:
                pass

        if isinstance(e, (OSError, IOError, socket.error)):
            exc = err.OperationalError(
                    2003,
                    "Can't connect to MySQL server on %r (%s)" % (
                        self.host, e))
            # Keep original exception and traceback to investigate error.
            exc.original_exception = e
            exc.traceback = traceback.format_exc()
            if DEBUG: print(exc.traceback)
            raise exc

        # If e is neither a DatabaseError nor an IOError, it's a bug.
        # But raising AssertionError would hide the original error,
        # so just re-raise it.
        raise
|
||
|
|
||
|
def write_packet(self, payload):
    """Frame ``payload`` as one MySQL packet and send it.

    The payload length (``pack_int24``) and the current sequence id are
    prepended; afterwards the sequence id is advanced, wrapping at 256.
    """
    # Internal note: code that builds a packet manually and calls
    # _write_bytes() directly must set self._next_seq_id properly itself.
    header = pack_int24(len(payload)) + int2byte(self._next_seq_id)
    frame = header + payload
    if DEBUG:
        dump_packet(frame)
    self._write_bytes(frame)
    self._next_seq_id = (self._next_seq_id + 1) % 256
|
||
|
|
||
|
def _read_packet(self, packet_type=MysqlPacket):
    """Read one logical MySQL packet from the network.

    A payload that is split over several wire packets (a fragment that
    announces a length of 0xffffff is followed by another one) is
    reassembled before it is wrapped.

    :param packet_type: class used to wrap the payload; it is called as
        ``packet_type(payload, self.encoding)``.
    :return: the ``packet_type`` instance, after its ``check_error()``
        has been called.
    :raise err.InternalError: if a packet carries an unexpected sequence
        number.
    """
    # Fragments are collected and joined once at the end.  Appending to a
    # bytes object instead would copy everything received so far again for
    # every additional fragment.
    fragments = []
    while True:
        packet_header = self._read_bytes(4)
        if DEBUG:
            dump_packet(packet_header)

        # Header: 3-byte little-endian payload length + 1-byte sequence id.
        btrl, btrh, packet_number = struct.unpack('<HBB', packet_header)
        bytes_to_read = btrl + (btrh << 16)
        if packet_number != self._next_seq_id:
            raise err.InternalError("Packet sequence number wrong - got %d expected %d" %
                                    (packet_number, self._next_seq_id))
        self._next_seq_id = (self._next_seq_id + 1) % 256

        recv_data = self._read_bytes(bytes_to_read)
        if DEBUG:
            dump_packet(recv_data)
        fragments.append(recv_data)
        # https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html
        if bytes_to_read == 0xffffff:
            continue
        if bytes_to_read < MAX_PACKET_LEN:
            break

    packet = packet_type(b''.join(fragments), self.encoding)
    packet.check_error()
    return packet
|
||
|
|
||
|
def _read_bytes(self, num_bytes):
    """Read exactly ``num_bytes`` bytes from the server.

    The socket timeout is set to ``self._read_timeout`` first.  A read
    interrupted by a signal (EINTR) is retried.

    :raise err.OperationalError: (code 2013) on any other I/O error, or
        when fewer than ``num_bytes`` bytes could be read.
    """
    self._sock.settimeout(self._read_timeout)
    while True:
        try:
            data = self._rfile.read(num_bytes)
        except (IOError, OSError) as exc:
            if exc.errno == errno.EINTR:
                continue
            raise err.OperationalError(
                2013,
                "Lost connection to MySQL server during query (%s)" % (exc,))
        else:
            break
    if len(data) >= num_bytes:
        return data
    raise err.OperationalError(
        2013, "Lost connection to MySQL server during query")
|
||
|
|
||
|
def _write_bytes(self, data):
    """Send ``data`` in full over the socket.

    The socket timeout is set to ``self._write_timeout`` first.

    :raise err.OperationalError: (code 2006) if sending raises IOError.
    """
    sock = self._sock
    sock.settimeout(self._write_timeout)
    try:
        sock.sendall(data)
    except IOError as exc:
        message = "MySQL server has gone away (%r)" % (exc,)
        raise err.OperationalError(2006, message)
|
||
|
|
||
|
def _read_query_result(self, unbuffered=False):
    """Read the server's response to the command that was just sent.

    :param unbuffered: when true the result is only initialised with
        ``init_unbuffered_query()``; otherwise it is read completely with
        ``read()``.
    :return: the ``affected_rows`` value of the result.

    The result is stored in ``self._result``; ``self.server_status`` is
    updated when the result carries a server status.
    """
    # Build the result outside the ``try`` so the cleanup handler below
    # can never run into an unbound ``result`` and hide the real error.
    result = MySQLResult(self)
    if unbuffered:
        try:
            result.init_unbuffered_query()
        except BaseException:
            # Detach the half-initialised result so that it is not
            # drained later, then let the error propagate unchanged.
            result.unbuffered_active = False
            result.connection = None
            raise
    else:
        result.read()
    self._result = result
    if result.server_status is not None:
        self.server_status = result.server_status
    return result.affected_rows
|
||
|
|
||
|
def insert_id(self):
    """Return the ``insert_id`` of the last result, or 0 if there is none."""
    return self._result.insert_id if self._result else 0
|
||
|
|
||
|
def _execute_command(self, command, sql):
    """Send a command packet (command byte followed by ``sql``).

    A result still pending from the previous command is drained first.
    A payload too large for a single packet is continued in follow-up
    packets sent with ``write_packet()``.

    :param command: command code, packed as one unsigned byte.
    :param sql: command payload; text is encoded with ``self.encoding``.
    :raise err.InterfaceError: if there is no socket.
    """
    if not self._sock:
        raise err.InterfaceError("(0, '')")

    # If the last query was unbuffered, make sure it finishes before
    # sending new commands
    if self._result is not None:
        if self._result.unbuffered_active:
            warnings.warn("Previous unbuffered result was left incomplete")
            self._result._finish_unbuffered_query()
        # Consume any further results of a multi-result response.
        while self._result.has_next:
            self.next_result()
        self._result = None

    if isinstance(sql, text_type):
        sql = sql.encode(self.encoding)

    packet_size = min(MAX_PACKET_LEN, len(sql) + 1)  # +1 is for command

    # tiny optimization: build first packet manually instead of
    # calling self.write_packet()
    # '<i' packs packet_size as 4 little-endian bytes.
    # NOTE(review): this presumably relies on packet_size fitting in 3
    # bytes, so that the 4th byte is 0 and doubles as sequence id 0 --
    # MAX_PACKET_LEN is defined outside this view; confirm.
    prelude = struct.pack('<iB', packet_size, command)
    packet = prelude + sql[:packet_size-1]
    self._write_bytes(packet)
    if DEBUG: dump_packet(packet)
    self._next_seq_id = 1

    if packet_size < MAX_PACKET_LEN:
        return

    # The first packet was full: send the remainder in follow-up packets.
    # A chunk of exactly MAX_PACKET_LEN bytes is always followed by one
    # more packet, which is empty if nothing is left.
    sql = sql[packet_size-1:]
    while True:
        packet_size = min(MAX_PACKET_LEN, len(sql))
        self.write_packet(sql[:packet_size])
        sql = sql[packet_size:]
        if not sql and packet_size < MAX_PACKET_LEN:
            break
|
||
|
|
||
|
def _request_authentication(self):
    """Build and send the handshake response, then read the server's reply.

    When SSL is configured and offered by the server, the fixed-size
    prefix is sent first on its own and the socket is wrapped with
    ``self.ctx`` before the full response is sent.  If the server answers
    with an auth-switch request, the exchange continues through
    ``_process_auth()`` or, failing that, with the legacy scramble.

    :raise ValueError: if ``self.user`` is None.
    """
    # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
    if int(self.server_version.split('.', 1)[0]) >= 5:
        self.client_flag |= CLIENT.MULTI_RESULTS

    if self.user is None:
        raise ValueError("Did not specify a username")

    charset_id = charset_by_name(self.charset).id
    if isinstance(self.user, text_type):
        self.user = self.user.encode(self.encoding)

    # Fixed-size prefix: client flags, a 32-bit value of 1, the charset id
    # and 23 bytes of zero padding.
    # NOTE(review): per the HandshakeResponse layout linked above the
    # second field is presumably the max packet size -- confirm.
    data_init = struct.pack('<iIB23s', self.client_flag, 1, charset_id, b'')

    if self.ssl and self.server_capabilities & CLIENT.SSL:
        self.write_packet(data_init)

        # Everything after this point goes over the wrapped socket, so the
        # read file object has to be rebuilt on top of it.
        self._sock = self.ctx.wrap_socket(self._sock, server_hostname=self.host)
        self._rfile = _makefile(self._sock, 'rb')

    data = data_init + self.user + b'\0'

    # The auth response stays empty unless the server announced no plugin
    # or mysql_native_password.
    authresp = b''
    if self._auth_plugin_name in ('', 'mysql_native_password'):
        authresp = _scramble(self.password.encode('latin1'), self.salt)

    # How the auth response is framed depends on the server capabilities.
    if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
        data += lenenc_int(len(authresp)) + authresp
    elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
        data += struct.pack('B', len(authresp)) + authresp
    else:  # pragma: no cover - not testing against servers without secure auth (>=5.0)
        data += authresp + b'\0'

    if self.db and self.server_capabilities & CLIENT.CONNECT_WITH_DB:
        if isinstance(self.db, text_type):
            self.db = self.db.encode(self.encoding)
        data += self.db + b'\0'

    if self.server_capabilities & CLIENT.PLUGIN_AUTH:
        name = self._auth_plugin_name
        if isinstance(name, text_type):
            name = name.encode('ascii')
        data += name + b'\0'

    self.write_packet(data)
    auth_packet = self._read_packet()

    # if authentication method isn't accepted the first byte
    # will have the octet 254
    if auth_packet.is_auth_switch_request():
        # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
        auth_packet.read_uint8() # 0xfe packet identifier
        plugin_name = auth_packet.read_string()
        if self.server_capabilities & CLIENT.PLUGIN_AUTH and plugin_name is not None:
            auth_packet = self._process_auth(plugin_name, auth_packet)
        else:
            # send legacy handshake
            data = _scramble_323(self.password.encode('latin1'), self.salt) + b'\0'
            self.write_packet(data)
            auth_packet = self._read_packet()
|
||
|
|
||
|
def _process_auth(self, plugin_name, auth_packet):
    """Answer an auth-switch request for ``plugin_name``.

    A handler class registered in ``self._auth_plugin_map`` (looked up by
    the raw name first, then by its ASCII-decoded form) gets the first
    chance via ``handler.authenticate(auth_packet)``.  Otherwise the
    built-in handling for ``mysql_native_password``,
    ``mysql_old_password``, ``mysql_clear_password`` and ``dialog`` is
    used.

    :param plugin_name: plugin name as bytes, as read from the packet.
    :param auth_packet: the auth-switch packet; its remaining data is
        consumed here.
    :return: the result of ``handler.authenticate()``, or the last packet
        read from the server.
    :raise err.OperationalError: (code 2059 or 2061) when the plugin is
        not configured or the handler is unusable.
    """
    plugin_class = self._auth_plugin_map.get(plugin_name)
    if not plugin_class:
        plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii'))
    if plugin_class:
        try:
            handler = plugin_class(self)
            return handler.authenticate(auth_packet)
        except AttributeError:
            # For 'dialog' a missing authenticate() is not an error here:
            # execution falls through to the built-in dialog loop below,
            # which uses handler.prompt() instead.
            if plugin_name != b'dialog':
                raise err.OperationalError(2059, "Authentication plugin '%s'" \
                          " not loaded: - %r missing authenticate method" % (plugin_name, plugin_class))
        except TypeError:
            raise err.OperationalError(2059, "Authentication plugin '%s'" \
                " not loaded: - %r cannot be constructed with connection object" % (plugin_name, plugin_class))
    else:
        handler = None
    if plugin_name == b"mysql_native_password":
        # https://dev.mysql.com/doc/internals/en/secure-password-authentication.html#packet-Authentication::Native41
        data = _scramble(self.password.encode('latin1'), auth_packet.read_all()) + b'\0'
    elif plugin_name == b"mysql_old_password":
        # https://dev.mysql.com/doc/internals/en/old-password-authentication.html
        data = _scramble_323(self.password.encode('latin1'), auth_packet.read_all()) + b'\0'
    elif plugin_name == b"mysql_clear_password":
        # https://dev.mysql.com/doc/internals/en/clear-text-authentication.html
        data = self.password.encode('latin1') + b'\0'
    elif plugin_name == b"dialog":
        pkt = auth_packet
        # Answer prompts until the server sends OK or a prompt was flagged
        # as the last one.
        while True:
            flag = pkt.read_uint8()
            echo = (flag & 0x06) == 0x02
            last = (flag & 0x01) == 0x01
            prompt = pkt.read_all()

            if prompt == b"Password: ":
                self.write_packet(self.password.encode('latin1') + b'\0')
            elif handler:
                # Placeholder that ends up in the error message if
                # handler.prompt() itself raises TypeError.
                resp = 'no response - TypeError within plugin.prompt method'
                try:
                    resp = handler.prompt(echo, prompt)
                    self.write_packet(resp + b'\0')
                except AttributeError:
                    raise err.OperationalError(2059, "Authentication plugin '%s'" \
                              " not loaded: - %r missing prompt method" % (plugin_name, handler))
                except TypeError:
                    raise err.OperationalError(2061, "Authentication plugin '%s'" \
                              " %r didn't respond with string. Returned '%r' to prompt %r" % (plugin_name, handler, resp, prompt))
            else:
                raise err.OperationalError(2059, "Authentication plugin '%s' (%r) not configured" % (plugin_name, handler))
            pkt = self._read_packet()
            pkt.check_error()
            if pkt.is_ok_packet() or last:
                break
        return pkt
    else:
        raise err.OperationalError(2059, "Authentication plugin '%s' not configured" % plugin_name)

    # Shared tail of the three single-response plugins above.
    self.write_packet(data)
    pkt = self._read_packet()
    pkt.check_error()
    return pkt
|
||
|
|
||
|
# _mysql support
|
||
|
def thread_id(self):
    """Return the first element of ``self.server_thread_id``.

    ``_get_server_information()`` stores the unpacked thread id there as
    a one-element tuple.
    """
    server_thread = self.server_thread_id[0]
    return server_thread
|
||
|
|
||
|
def character_set_name(self):
    """Return the charset name currently stored in ``self.charset``."""
    name = self.charset
    return name
|
||
|
|
||
|
def get_host_info(self):
    """Return the ``host_info`` description string set by ``connect()``."""
    info = self.host_info
    return info
|
||
|
|
||
|
def get_proto_info(self):
    """Return the protocol version read from the server greeting."""
    version = self.protocol_version
    return version
|
||
|
|
||
|
def _get_server_information(self):
    """Read the server greeting packet and store what it announces.

    Sets ``protocol_version``, ``server_version``, ``server_thread_id``
    (a one-element tuple), ``salt`` and ``server_capabilities``.  When
    the packet is long enough it also sets ``server_language``,
    ``server_charset`` and ``server_status``, extends the salt, and sets
    ``_auth_plugin_name``.
    """
    # ``i`` is the read offset into ``data`` throughout.
    i = 0
    packet = self._read_packet()
    data = packet.get_all_data()

    if DEBUG: dump_packet(data)
    self.protocol_version = byte2int(data[i:i+1])
    i += 1

    # NUL-terminated server version string.
    server_end = data.find(b'\0', i)
    self.server_version = data[i:server_end].decode('latin1')
    i = server_end + 1

    self.server_thread_id = struct.unpack('<I', data[i:i+4])
    i += 4

    # First 8 bytes of the salt.
    self.salt = data[i:i+8]
    i += 9  # 8 + 1(filler)

    # Lower 16 bits of the capability flags.
    self.server_capabilities = struct.unpack('<H', data[i:i+2])[0]
    i += 2

    if len(data) >= i + 6:
        lang, stat, cap_h, salt_len = struct.unpack('<BHHB', data[i:i+6])
        i += 6
        self.server_language = lang
        self.server_charset = charset_by_id(lang).name

        self.server_status = stat
        if DEBUG: print("server_status: %x" % stat)

        # Upper 16 bits of the capability flags.
        self.server_capabilities |= cap_h << 16
        if DEBUG: print("salt_len:", salt_len)
        salt_len = max(12, salt_len - 9)

    # reserved
    i += 10

    # NOTE(review): salt_len is only assigned inside the branch above; if
    # that branch is skipped for a short greeting, this check would hit an
    # unbound name -- confirm the nesting and whether it can happen.
    if len(data) >= i + salt_len:
        # salt_len includes auth_plugin_data_part_1 and filler
        self.salt += data[i:i+salt_len]
        i += salt_len

    i+=1
    # AUTH PLUGIN NAME may appear here.
    if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= i:
        # Due to Bug#59453 the auth-plugin-name is missing the terminating
        # NUL-char in versions prior to 5.5.10 and 5.6.2.
        # ref: https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
        # Version checks are not used because MariaDB has the fix but
        # reports version numbers earlier than those two.
        server_end = data.find(b'\0', i)
        if server_end < 0: # pragma: no cover - very specific upstream bug
            # No \0 found and this is the last field, so take it all.
            self._auth_plugin_name = data[i:].decode('latin1')
        else:
            self._auth_plugin_name = data[i:server_end].decode('latin1')
|
||
|
|
||
|
def get_server_info(self):
    """Return the server version string read from the server greeting."""
    version = self.server_version
    return version
|
||
|
|
||
|
# Aliases for the exception classes defined in the ``err`` module.
# NOTE(review): presumably these are class attributes of the enclosing
# connection class (its header is outside this view), so the exceptions
# can be reached through a connection object -- confirm.
Warning = err.Warning
Error = err.Error
InterfaceError = err.InterfaceError
DatabaseError = err.DatabaseError
DataError = err.DataError
OperationalError = err.OperationalError
IntegrityError = err.IntegrityError
InternalError = err.InternalError
ProgrammingError = err.ProgrammingError
NotSupportedError = err.NotSupportedError
|
||
|
|
||
|
|
||
|
class MySQLResult(object):
    """Result of one command, read from a connection.

    A buffered result is read completely by ``read()``.  An unbuffered
    one is started by ``init_unbuffered_query()`` and then consumed row
    by row with ``_read_rowdata_packet_unbuffered()``.
    """

    def __init__(self, connection):
        """
        :type connection: Connection
        """
        self.connection = connection
        # Filled in from the OK packet, if the server sends one.
        self.affected_rows = None
        self.insert_id = None
        self.server_status = None
        self.warning_count = 0
        self.message = None
        # Filled in for results that carry rows.
        self.field_count = 0
        self.description = None
        self.rows = None
        self.has_next = None
        self.unbuffered_active = False

    def __del__(self):
        # Drain an unbuffered result that was dropped before its end.
        if self.unbuffered_active:
            self._finish_unbuffered_query()

    def read(self):
        """Read the whole result; the connection reference is always dropped."""
        try:
            head = self.connection._read_packet()

            if head.is_ok_packet():
                consume = self._read_ok_packet
            elif head.is_load_local_packet():
                consume = self._read_load_local_packet
            else:
                consume = self._read_result_packet
            consume(head)
        finally:
            self.connection = None

    def init_unbuffered_query(self):
        """Read the result header only; rows are fetched later on demand."""
        self.unbuffered_active = True
        head = self.connection._read_packet()

        if head.is_ok_packet():
            self._read_ok_packet(head)
        elif head.is_load_local_packet():
            self._read_load_local_packet(head)
        else:
            self.field_count = head.read_length_encoded_integer()
            self._get_descriptions()

            # Apparently, MySQLdb picks this number because it's the maximum
            # value of a 64bit unsigned integer. Since we're emulating MySQLdb,
            # we set it to this instead of None, which would be preferred.
            self.affected_rows = 18446744073709551615
            return

        # No rows follow an OK or LOAD LOCAL response: finished already.
        self.unbuffered_active = False
        self.connection = None

    def _read_ok_packet(self, first_packet):
        """Copy the fields of an OK packet onto this result."""
        ok_packet = OKPacketWrapper(first_packet)
        for attr in ('affected_rows', 'insert_id', 'server_status',
                     'warning_count', 'message', 'has_next'):
            setattr(self, attr, getattr(ok_packet, attr))

    def _read_load_local_packet(self, first_packet):
        """Send the requested local file, then read the closing OK packet."""
        requested = LoadLocalPacketWrapper(first_packet).filename
        sender = LoadLocalFile(requested, self.connection)
        try:
            sender.send_data()
        except:
            self.connection._read_packet()  # skip ok packet
            raise

        reply = self.connection._read_packet()
        if not reply.is_ok_packet():  # pragma: no cover - upstream induced protocol error
            raise err.OperationalError(2014, "Commands Out of Sync")
        self._read_ok_packet(reply)

    def _check_packet_is_eof(self, packet):
        """Return True for an EOF packet, recording its warnings/has_next."""
        if packet.is_eof_packet():
            #TODO: Support CLIENT.DEPRECATE_EOF
            # 1) Add DEPRECATE_EOF to CAPABILITIES
            # 2) Mask CAPABILITIES with server_capabilities
            # 3) if server_capabilities & CLIENT.DEPRECATE_EOF: use OKPacketWrapper instead of EOFPacketWrapper
            eof = EOFPacketWrapper(packet)
            self.warning_count = eof.warning_count
            self.has_next = eof.has_next
            return True
        return False

    def _read_result_packet(self, first_packet):
        """Read the column descriptions and then every row."""
        self.field_count = first_packet.read_length_encoded_integer()
        self._get_descriptions()
        self._read_rowdata_packet()

    def _read_rowdata_packet_unbuffered(self):
        """Return the next row of an unbuffered result, or None when done."""
        # Check if in an active query
        if not self.unbuffered_active:
            return

        packet = self.connection._read_packet()
        if not self._check_packet_is_eof(packet):
            row = self._read_row_from_packet(packet)
            self.affected_rows = 1
            # rows should be a tuple of rows for MySQL-python compatibility.
            self.rows = (row,)
            return row

        # EOF
        self.unbuffered_active = False
        self.connection = None
        self.rows = None
        return

    def _finish_unbuffered_query(self):
        """Discard the remaining packets of an unbuffered result."""
        # After much reading on the MySQL protocol, it appears that there is,
        # in fact, no way to stop MySQL from sending all the data after
        # executing a query, so we just spin, and wait for an EOF packet.
        while self.unbuffered_active:
            packet = self.connection._read_packet()
            if not self._check_packet_is_eof(packet):
                continue
            self.unbuffered_active = False
            self.connection = None  # release reference to kill cyclic reference.

    def _read_rowdata_packet(self):
        """Read a rowdata packet for each data row in the result set."""
        fetched = []
        packet = self.connection._read_packet()
        while not self._check_packet_is_eof(packet):
            fetched.append(self._read_row_from_packet(packet))
            packet = self.connection._read_packet()
        self.connection = None  # release reference to kill cyclic reference.

        self.affected_rows = len(fetched)
        self.rows = tuple(fetched)

    def _read_row_from_packet(self, packet):
        """Decode one row packet into a tuple using ``self.converters``."""
        values = []
        for codec, convert in self.converters:
            try:
                value = packet.read_length_coded_string()
            except IndexError:
                # No more columns in this row
                # See https://github.com/PyMySQL/PyMySQL/pull/434
                break
            if value is None:
                # None is passed through without decoding or conversion.
                values.append(value)
                continue
            if codec is not None:
                value = value.decode(codec)
            if DEBUG:
                print("DEBUG: DATA = ", value)
            if convert is not None:
                value = convert(value)
            values.append(value)
        return tuple(values)

    def _get_descriptions(self):
        """Read a column descriptor packet for each column in the result."""
        self.fields = []
        self.converters = []
        use_unicode = self.connection.use_unicode
        conn_encoding = self.connection.encoding
        columns = []

        for _ in range_type(self.field_count):
            field = self.connection._read_packet(FieldDescriptorPacket)
            self.fields.append(field)
            columns.append(field.description())
            field_type = field.type_code
            if not use_unicode:
                encoding = None
            elif field_type == FIELD_TYPE.JSON:
                # When SELECT from JSON column: charset = binary
                # When SELECT CAST(... AS JSON): charset = connection encoding
                # This behavior is different from TEXT / BLOB.
                # We should decode result by connection encoding regardless charsetnr.
                # See https://github.com/PyMySQL/PyMySQL/issues/488
                encoding = conn_encoding  # SELECT CAST(... AS JSON)
            elif field_type in TEXT_TYPES:
                # TEXTs with charset=binary (63) means BINARY types.
                encoding = None if field.charsetnr == 63 else conn_encoding
            else:
                # Integers, Dates and Times, and other basic data is encoded in ascii
                encoding = 'ascii'
            converter = self.connection.decoders.get(field_type)
            if converter is through:
                converter = None
            if DEBUG:
                print("DEBUG: field={}, converter={}".format(field, converter))
            self.converters.append((encoding, converter))

        eof_packet = self.connection._read_packet()
        assert eof_packet.is_eof_packet(), 'Protocol error, expecting EOF'
        self.description = tuple(columns)
|
||
|
|
||
|
|
||
|
class LoadLocalFile(object):
    """Sends the contents of a local file to the server over a connection."""

    def __init__(self, filename, connection):
        self.filename = filename
        self.connection = connection

    def send_data(self):
        """Send data packets from the local file to the server"""
        conn = self.connection
        if not conn._sock:
            raise err.InterfaceError("(0, '')")

        try:
            with open(self.filename, 'rb') as source:
                chunk_size = min(conn.max_allowed_packet, 16*1024)  # 16KB is efficient enough
                # Read fixed-size chunks until read() returns no more data.
                for chunk in iter(lambda: source.read(chunk_size), b''):
                    conn.write_packet(chunk)
        except IOError:
            raise err.OperationalError(1017, "Can't find file '{0}'".format(self.filename))
        finally:
            # send the empty packet to signify we are done sending data
            conn.write_packet(b'')
|