SP/web2py/gluon/packages/dal/pydal/adapters/postgres.py

285 lines
9.7 KiB
Python
Raw Permalink Normal View History

2018-10-25 15:33:07 +00:00
import re
from .._compat import PY2, with_metaclass, iterkeys, to_unicode, long
from .._globals import IDENTITY, THREAD_LOCAL
from ..drivers import psycopg2_adapt
from ..helpers.classes import ConnectionConfigurationMixin
from .base import SQLAdapter
from . import AdapterMeta, adapters, with_connection, with_connection_or_raise
class PostgreMeta(AdapterMeta):
    """Metaclass that dispatches the generic Postgre adapters to a driver.

    Calling ``Postgre``, ``PostgreNew`` or ``PostgreBoolean`` does not build
    that class directly. The adapter registered as ``<scheme>:<driver>`` is
    instantiated instead. The driver is the one named in the URI
    (``postgres:pg8000://...``) when it is available; otherwise it is the
    first available driver in ``cls.drivers``, falling back to
    ``cls.drivers[0]``.
    """

    def __call__(cls, *args, **kwargs):
        # Concrete (driver-specific) adapters are constructed as-is.
        if cls not in [Postgre, PostgreNew, PostgreBoolean]:
            return AdapterMeta.__call__(cls, *args, **kwargs)

        installed = set(iterkeys(kwargs['db']._drivers_available))
        usable = [name for name in cls.drivers if name in installed]

        prefix = kwargs['uri'].split('://', 1)[0].split(':')
        scheme = prefix[0]
        requested = prefix[1] if len(prefix) > 1 else None

        if requested and requested in usable:
            chosen = requested
        elif usable:
            chosen = usable[0]
        else:
            chosen = cls.drivers[0]

        target = adapters._registry_[scheme + ":" + chosen]
        return AdapterMeta.__call__(target, *args, **kwargs)
@adapters.register_for('postgres')
class Postgre(
    with_metaclass(PostgreMeta, ConnectionConfigurationMixin, SQLAdapter)
):
    """Generic PostgreSQL adapter (``postgres://`` URIs).

    Instantiation is redirected by ``PostgreMeta`` to the adapter registered
    as ``postgres:<driver>``. This class holds the URI parsing, insert/lastrowid
    handling and two-phase-commit helpers shared by those subclasses.
    """

    dbengine = 'postgres'
    drivers = ('psycopg2', 'pg8000')
    support_distributed_transaction = True
    # user[:password]@host[:port]/db[?sslmode=MODE][?unix_socket=PATH]
    # host may be empty (socket-only URIs) or a bracketed IPv6 literal.
    # Raw strings avoid invalid-escape warnings; the sslmode value stops at
    # the next '?' so a following ?unix_socket= option is still recognized.
    REGEX_URI = re.compile(
        r'^(?P<user>[^:@]+)(\:(?P<password>[^@]*))?@(?P<host>\[[^/]+\]|' +
        r'[^\:@]*)(\:(?P<port>[0-9]+))?/(?P<db>[^\?]+)' +
        r'(\?sslmode=(?P<sslmode>[^\?]+))?(\?unix_socket=(?P<socket>.+))?$')

    def __init__(self, db, uri, pool_size=0, folder=None, db_codec='UTF-8',
                 credential_decoder=IDENTITY, driver_args=None,
                 adapter_args=None, do_connect=True, srid=4326,
                 after_connection=None):
        """Store ``srid`` and delegate to the base adapter constructor.

        ``driver_args`` / ``adapter_args`` default to fresh dicts per call.
        ``_initialize_`` updates ``self.driver_args`` in place. If the base
        class keeps the passed dict (presumably it does), a shared ``{}``
        default would leak connection settings between adapters.
        """
        if driver_args is None:
            driver_args = {}
        if adapter_args is None:
            adapter_args = {}
        self.srid = srid
        super(Postgre, self).__init__(
            db, uri, pool_size, folder, db_codec, credential_decoder,
            driver_args, adapter_args, do_connect, after_connection)

    def _initialize_(self, do_connect):
        """Parse ``self.uri`` into ``self.driver_args`` and reset insert state.

        Raises:
            SyntaxError: if the URI does not match ``REGEX_URI`` or lacks a
                user, a host (or unix socket), or a database name.
        """
        super(Postgre, self)._initialize_(do_connect)
        ruri = self.uri.split('://', 1)[1]
        m = self.REGEX_URI.match(ruri)
        if not m:
            raise SyntaxError("Invalid URI string in DAL")
        user = self.credential_decoder(m.group('user'))
        if not user:
            raise SyntaxError('User required')
        password = self.credential_decoder(m.group('password'))
        if not password:
            password = ''
        host = m.group('host')
        socket = m.group('socket')
        if not host and not socket:
            raise SyntaxError('Host name required')
        db = m.group('db')
        if not db and not socket:
            raise SyntaxError('Database name required')
        port = int(m.group('port') or '5432')
        sslmode = m.group('sslmode')
        if socket:
            # The socket path is handed to the driver as ``host``.
            self.driver_args.update(
                user=user, host=socket, port=port, password=password)
            if db:
                self.driver_args['database'] = db
        else:
            self.driver_args.update(
                database=db, user=user, host=host, port=port,
                password=password)
        if sslmode:
            self.driver_args['sslmode'] = sslmode
        # Record the name/version of the driver module in use; self.driver is
        # presumably chosen by the base _initialize_ from the URI.
        if self.driver:
            self.__version__ = "%s %s" % (self.driver.__name__,
                                          self.driver.__version__)
        else:
            self.__version__ = None
        THREAD_LOCAL._pydal_last_insert_ = None
        self._mock_reconnect()

    def _get_json_dialect(self):
        """Return the dialect class used once JSON support is enabled."""
        from ..dialects.postgre import PostgreDialectJSON
        return PostgreDialectJSON

    def _get_json_parser(self):
        """Return the parser class used once JSON support is enabled."""
        from ..parsers.postgre import PostgreAutoJSONParser
        return PostgreAutoJSONParser

    @property
    def _last_insert(self):
        """``(id_field, 1)`` after an insert with a returned id, else None.

        Stored on THREAD_LOCAL (presumably a ``threading.local``) rather
        than on the adapter instance.
        """
        return THREAD_LOCAL._pydal_last_insert_

    @_last_insert.setter
    def _last_insert(self, value):
        THREAD_LOCAL._pydal_last_insert_ = value

    def connector(self):
        """Open a new DB-API connection with the parsed ``driver_args``."""
        return self.driver.connect(**self.driver_args)

    def after_connection(self):
        """Force UTF8 client encoding and standard-conforming strings."""
        self.execute("SET CLIENT_ENCODING TO 'UTF8'")
        self.execute("SET standard_conforming_strings=on;")

    def _configure_on_first_reconnect(self):
        # _config_json is provided by the driver-specific subclasses.
        self._config_json()

    def lastrowid(self, table):
        """Return the id generated by the last insert into *table*.

        When ``_insert`` asked the dialect to return the id column
        (presumably via RETURNING), that value is read from the cursor.
        Otherwise the table's sequence is queried with ``currval``.
        """
        if self._last_insert:
            return long(self.cursor.fetchone()[0])
        sequence_name = table._sequence_name
        self.execute("SELECT currval(%s);" % self.adapt(sequence_name))
        return long(self.cursor.fetchone()[0])

    def _insert(self, table, fields):
        """Build the INSERT statement for *fields* (``(field, value)`` pairs).

        Tables with an ``_id`` field request that column back from the
        dialect and flag ``_last_insert`` so ``lastrowid`` reads it from the
        cursor. An empty *fields* produces the dialect's empty insert.
        """
        self._last_insert = None
        if fields:
            retval = None
            if hasattr(table, '_id'):
                self._last_insert = (table._id, 1)
                retval = table._id._rname
            return self.dialect.insert(
                table._rname,
                ','.join(el[0]._rname for el in fields),
                ','.join(self.expand(v, f.type) for f, v in fields),
                retval)
        return self.dialect.insert_empty(table._rname)

    @staticmethod
    def _xid_literal(key):
        """Return *key* escaped for use inside a single-quoted SQL literal."""
        return ("%s" % key).replace("'", "''")

    @with_connection
    def prepare(self, key):
        """PREPARE TRANSACTION under transaction id *key*."""
        self.execute("PREPARE TRANSACTION '%s';" % self._xid_literal(key))

    @with_connection
    def commit_prepared(self, key):
        """COMMIT PREPARED the transaction with id *key*."""
        self.execute("COMMIT PREPARED '%s';" % self._xid_literal(key))

    @with_connection
    def rollback_prepared(self, key):
        """ROLLBACK PREPARED the transaction with id *key*."""
        self.execute("ROLLBACK PREPARED '%s';" % self._xid_literal(key))
@adapters.register_for('postgres:psycopg2')
class PostgrePsyco(Postgre):
    """PostgreSQL adapter backed by the psycopg2 driver."""

    drivers = ('psycopg2',)

    @staticmethod
    def _version_tuple(version):
        """Return the leading dotted number in *version* as a comparable tuple.

        ``'2.8.6 (dt dec pq3 ext)'`` becomes ``(2, 8, 6)``. Trailing zero
        components are dropped so ``'2.5'`` and ``'2.5.0'`` compare equal.
        Text with no leading number gives ``()``, which sorts below any real
        version. Plain string comparison is avoided because it misorders
        multi-digit components (``'2.0.9' >= '2.0.12'`` is True).
        """
        found = re.match(r'\d+(?:\.\d+)*', str(version).strip())
        parts = [int(p) for p in found.group(0).split('.')] if found else []
        while parts and parts[-1] == 0:
            parts.pop()
        return tuple(parts)

    def _config_json(self):
        """Enable the JSON dialect (and parser) when driver/server allow it.

        Requires psycopg2 >= 2.0.12 and a server reporting
        ``server_version >= 90200``. The JSON parser additionally needs
        psycopg2 >= 2.5.0.
        """
        driver_version = self._version_tuple(self.driver.__version__)
        use_json = driver_version >= self._version_tuple("2.0.12") and \
            self.connection.server_version >= 90200
        if use_json:
            self.dialect = self._get_json_dialect()(self)
            if driver_version >= self._version_tuple('2.5.0'):
                self.parser = self._get_json_parser()(self)

    def adapt(self, obj):
        """Return *obj* quoted as an SQL literal via psycopg2's adapters."""
        adapted = psycopg2_adapt(obj)
        # deal with new relic Connection Wrapper (newrelic>=2.10.0.8)
        cxn = getattr(self.connection, '__wrapped__', self.connection)
        adapted.prepare(cxn)
        rv = adapted.getquoted()
        # getquoted() may yield bytes on Python 3; callers build text SQL.
        if not PY2 and isinstance(rv, bytes):
            return rv.decode('utf-8')
        return rv
@adapters.register_for('postgres:pg8000')
class PostgrePG8000(Postgre):
    """PostgreSQL adapter backed by the pg8000 driver."""

    drivers = ('pg8000',)

    @staticmethod
    def _version_tuple(version):
        """Return the leading dotted number in *version* as a comparable tuple.

        ``'1.10.2'`` becomes ``(1, 10, 2)``. Trailing zero components are
        dropped so ``'9.2'`` and ``'9.2.0'`` compare equal. Text with no
        leading number gives ``()``, which sorts below any real version.
        Plain string comparison is avoided because it misorders multi-digit
        components (``'1.9.0' >= '1.10.2'`` and ``'10.5' < '9.2.0'``).
        """
        found = re.match(r'\d+(?:\.\d+)*', str(version).strip())
        parts = [int(p) for p in found.group(0).split('.')] if found else []
        while parts and parts[-1] == 0:
            parts.pop()
        return tuple(parts)

    def _config_json(self):
        """Enable the JSON dialect for servers >= 9.2.0.

        The JSON parser is also installed when pg8000 >= 1.10.2.
        """
        # _server_version is a private pg8000 attribute; only its str() form
        # is relied on here — assumes it starts with the dotted version.
        server_version = self._version_tuple(self.connection._server_version)
        if server_version >= self._version_tuple("9.2.0"):
            self.dialect = self._get_json_dialect()(self)
            driver_version = self._version_tuple(self.driver.__version__)
            if driver_version >= self._version_tuple('1.10.2'):
                self.parser = self._get_json_parser()(self)

    def adapt(self, obj):
        """Return string *obj* as a single-quoted SQL literal.

        ``%`` is doubled as well as ``'``. Presumably this is because
        pg8000 treats ``%`` as a parameter marker (TODO confirm).
        """
        return "'%s'" % obj.replace("%", "%%").replace("'", "''")

    @with_connection_or_raise
    def execute(self, *args, **kwargs):
        """Execute a statement; on Python 2 the SQL text is made unicode."""
        if PY2:
            args = list(args)
            args[0] = to_unicode(args[0])
        return super(PostgrePG8000, self).execute(*args, **kwargs)
@adapters.register_for('postgres2')
class PostgreNew(Postgre):
    """``postgres2`` adapter: swaps in the Arrays JSON dialect and New parser.

    The driver-specific ``_config_json`` implementations instantiate the
    classes returned here when JSON support is enabled.
    """

    def _get_json_parser(self):
        """Parser class for ``postgres2`` connections."""
        from ..parsers.postgre import PostgreNewAutoJSONParser as parser_cls
        return parser_cls

    def _get_json_dialect(self):
        """Dialect class for ``postgres2`` connections."""
        from ..dialects.postgre import PostgreDialectArraysJSON as dialect_cls
        return dialect_cls
@adapters.register_for('postgres2:psycopg2')
class PostgrePsycoNew(PostgrePsyco, PostgreNew):
    """``postgres2`` over psycopg2.

    Connection handling (``_config_json``, ``adapt``) comes from
    ``PostgrePsyco``; the JSON dialect/parser classes come from
    ``PostgreNew`` via the MRO.
    """
@adapters.register_for('postgres2:pg8000')
class PostgrePG8000New(PostgrePG8000, PostgreNew):
    """``postgres2`` over pg8000.

    Connection handling (``_config_json``, ``adapt``, ``execute``) comes
    from ``PostgrePG8000``; the JSON dialect/parser classes come from
    ``PostgreNew`` via the MRO.
    """
@adapters.register_for('postgres3')
class PostgreBoolean(PostgreNew):
    """``postgres3`` adapter: ``postgres2`` with the Boolean JSON variants.

    The driver-specific ``_config_json`` implementations instantiate the
    classes returned here when JSON support is enabled.
    """

    def _get_json_parser(self):
        """Parser class for ``postgres3`` connections."""
        from ..parsers.postgre import PostgreBooleanAutoJSONParser as parser_cls
        return parser_cls

    def _get_json_dialect(self):
        """Dialect class for ``postgres3`` connections."""
        from ..dialects.postgre import PostgreDialectBooleanJSON as dialect_cls
        return dialect_cls
@adapters.register_for('postgres3:psycopg2')
class PostgrePsycoBoolean(PostgrePsycoNew, PostgreBoolean):
    """``postgres3`` over psycopg2.

    psycopg2 connection handling comes from ``PostgrePsycoNew``; the
    Boolean JSON dialect/parser classes come from ``PostgreBoolean``
    via the MRO.
    """
@adapters.register_for('postgres3:pg8000')
class PostgrePG8000Boolean(PostgrePG8000New, PostgreBoolean):
    """``postgres3`` over pg8000.

    pg8000 connection handling comes from ``PostgrePG8000New``; the
    Boolean JSON dialect/parser classes come from ``PostgreBoolean``
    via the MRO.
    """
@adapters.register_for('jdbc:postgres')
class JDBCPostgre(Postgre):
    """PostgreSQL adapter for ``jdbc:postgres://`` URIs via zxJDBC.

    It connects with a positional DSN ``(jdbc_url, user, password)`` instead
    of the keyword ``driver_args`` built by ``Postgre``.
    """

    drivers = ('zxJDBC',)
    # user[:password]@host[:port]/db. Unlike the parent, host is mandatory
    # and no query-string options are parsed. Raw strings avoid
    # invalid-escape warnings without changing the pattern.
    REGEX_URI = re.compile(
        r'^(?P<user>[^:@]+)(\:(?P<password>[^@]*))?@(?P<host>\[[^/]+\]|' +
        r'[^\:/]+)(\:(?P<port>[0-9]+))?/(?P<db>.+)$')

    @staticmethod
    def _version_tuple(version):
        """Return the leading dotted number in *version* as a comparable tuple.

        ``'9.6.3'`` becomes ``(9, 6, 3)``. Trailing zero components are
        dropped so ``'9.2'`` and ``'9.2.0'`` compare equal. Text with no
        leading number gives ``()``, which sorts below any real version.
        Plain string comparison is avoided because ``'10.5' < '9.2.0'``.
        """
        found = re.match(r'\d+(?:\.\d+)*', str(version).strip())
        parts = [int(p) for p in found.group(0).split('.')] if found else []
        while parts and parts[-1] == 0:
            parts.pop()
        return tuple(parts)

    def _initialize_(self, do_connect):
        """Parse ``self.uri`` into ``self.dsn`` and reset insert state.

        Raises:
            SyntaxError: if the URI does not match ``REGEX_URI`` or lacks a
                user, host or database name.
        """
        # super(Postgre, self) bypasses Postgre._initialize_, which would
        # fill keyword driver_args, and runs the next implementation in the
        # MRO; this method builds self.dsn instead.
        super(Postgre, self)._initialize_(do_connect)
        ruri = self.uri.split('://', 1)[1]
        m = self.REGEX_URI.match(ruri)
        if not m:
            raise SyntaxError("Invalid URI string in DAL")
        user = self.credential_decoder(m.group('user'))
        if not user:
            raise SyntaxError('User required')
        password = self.credential_decoder(m.group('password'))
        if not password:
            password = ''
        host = m.group('host')
        if not host:
            raise SyntaxError('Host name required')
        db = m.group('db')
        if not db:
            raise SyntaxError('Database name required')
        port = m.group('port') or '5432'
        self.dsn = (
            'jdbc:postgresql://%s:%s/%s' % (host, port, db), user, password)
        # Record the name/version of the driver module in use; self.driver is
        # presumably chosen by the base _initialize_.
        if self.driver:
            self.__version__ = "%s %s" % (self.driver.__name__,
                                          self.driver.__version__)
        else:
            self.__version__ = None
        THREAD_LOCAL._pydal_last_insert_ = None
        self._mock_reconnect()

    def connector(self):
        """Open a connection from the positional DSN plus ``driver_args``."""
        return self.driver.connect(*self.dsn, **self.driver_args)

    def after_connection(self):
        """Set UTF8/UNICODE client encoding and open a transaction."""
        self.connection.set_client_encoding('UTF8')
        self.execute('BEGIN;')
        self.execute("SET CLIENT_ENCODING TO 'UNICODE';")

    def _config_json(self):
        """Enable the JSON dialect when the server version is >= 9.2.0."""
        # Assumes connection.dbversion is the server version string
        # (e.g. '9.6.3') — TODO confirm for the zxJDBC version in use.
        server_version = self._version_tuple(self.connection.dbversion)
        if server_version >= self._version_tuple("9.2.0"):
            self.dialect = self._get_json_dialect()(self)