966 lines
36 KiB
Python
966 lines
36 KiB
Python
import copy
|
|
import sys
|
|
import types
|
|
from collections import defaultdict
|
|
from contextlib import contextmanager
|
|
from .._compat import PY2, with_metaclass, iterkeys, iteritems, hashlib_md5, \
|
|
integer_types, basestring
|
|
from .._globals import IDENTITY
|
|
from ..connection import ConnectionPool
|
|
from ..exceptions import NotOnNOSQLError
|
|
from ..helpers.classes import Reference, ExecutionHandler, SQLCustomType, \
|
|
SQLALL, NullDriver
|
|
from ..helpers.methods import use_common_filters, xorify, merge_tablemaps
|
|
from ..helpers.regex import REGEX_SELECT_AS_PARSER, REGEX_TABLE_DOT_FIELD
|
|
from ..migrator import Migrator
|
|
from ..objects import Table, Field, Expression, Query, Rows, IterRows, \
|
|
LazySet, LazyReferenceGetter, VirtualCommand, Select
|
|
from ..utils import deprecated
|
|
from . import AdapterMeta, with_connection, with_connection_or_raise
|
|
|
|
|
|
# Types that BaseAdapter.represent treats as "call me to obtain the value":
# an object of one of these types is invoked with no arguments before it is
# handed to the representer.
CALLABLETYPES = (
    types.LambdaType,
    types.FunctionType,
    types.BuiltinFunctionType,
    types.MethodType,
    types.BuiltinMethodType,
)
|
|
|
|
|
|
class BaseAdapter(with_metaclass(AdapterMeta, ConnectionPool)):
    """Driver-independent base class shared by every pyDAL adapter.

    It stores the connection settings and picks a driver from ``drivers``.
    It also resolves which tables a query touches, applies user-defined
    and multi-tenant common filters, and turns raw driver rows into
    ``db.Row`` / ``db.Rows`` objects. Statement building and execution are
    left to subclasses.
    """

    dbengine = "None"
    # driver names this adapter can use; the first available one wins
    drivers = ()
    uploads_in_blob = False
    support_distributed_transaction = False

    def __init__(self, db, uri, pool_size=0, folder=None, db_codec='UTF-8',
                 credential_decoder=IDENTITY, driver_args=None,
                 adapter_args=None, do_connect=True, after_connection=None,
                 entity_quoting=False):
        """Store connection settings, select a driver and maybe connect.

        :param db: the owning DAL instance.
        :param uri: connection URI; its scheme may name a driver
            (``engine:driver://...``).
        :param driver_args: keyword arguments for the driver; a fresh
            empty dict is used when omitted.
        :param adapter_args: adapter options (e.g. ``'driver'``); a fresh
            empty dict is used when omitted.
        :param do_connect: when true, ``reconnect()`` is called at the end.
        :param entity_quoting: accepted for signature compatibility; not
            used in this method.
        """
        super(BaseAdapter, self).__init__()
        self._load_dependencies()
        self.db = db
        self.uri = uri
        self.pool_size = pool_size
        self.folder = folder
        self.db_codec = db_codec
        self.credential_decoder = credential_decoder
        # A fresh dict per instance: a shared ``{}`` default would let
        # entries written by one adapter leak into every other adapter.
        self.driver_args = driver_args if driver_args is not None else {}
        self.adapter_args = adapter_args if adapter_args is not None else {}
        self.expand = self._expand
        self._after_connection = after_connection
        self.connection = None
        self.find_driver()
        self._initialize_(do_connect)
        if do_connect:
            self.reconnect()

    def _load_dependencies(self):
        """Attach the dialect, parser and representer registered for us."""
        from ..dialects import dialects
        from ..parsers import parsers
        from ..representers import representers
        self.dialect = dialects.get_for(self)
        self.parser = parsers.get_for(self)
        self.representer = representers.get_for(self)

    def _initialize_(self, do_connect):
        """Overridable hook run by ``__init__`` before connecting."""
        self._find_work_folder()

    @property
    def types(self):
        """Field-type to database-type mapping provided by the dialect."""
        return self.dialect.types

    @property
    def _available_drivers(self):
        """Names from ``drivers`` that the DAL reports as importable.

        The order of ``drivers`` is preserved.
        """
        available = self.db._drivers_available
        # direct mapping membership instead of scanning its keys
        return [driver for driver in self.drivers if driver in available]

    def _driver_from_uri(self):
        """Return the driver named in the URI scheme, or None.

        For ``'postgres:psycopg2://...'`` this yields ``'psycopg2'``.
        """
        rv = None
        if self.uri:
            items = self.uri.split('://', 1)[0].split(':')
            rv = items[1] if len(items) > 1 else None
        return rv

    def find_driver(self):
        """Set ``self.driver`` and ``self.driver_name``.

        Nothing happens if ``self.driver`` is already set. A driver
        requested through the URI or ``adapter_args['driver']`` must be
        available. Otherwise the first available entry of ``drivers`` is
        used.

        :raises RuntimeError: if the requested driver, or every supported
            driver, is unavailable.
        """
        if getattr(self, 'driver', None) is not None:
            return
        requested_driver = self._driver_from_uri() or \
            self.adapter_args.get('driver')
        available = self._available_drivers
        if requested_driver:
            if requested_driver in available:
                self.driver_name = requested_driver
                self.driver = self.db._drivers_available[requested_driver]
            else:
                raise RuntimeError(
                    'Driver %s is not available' % requested_driver)
        elif available:
            self.driver_name = available[0]
            self.driver = self.db._drivers_available[self.driver_name]
        else:
            raise RuntimeError(
                "No driver of supported ones %s is available" %
                str(self.drivers))

    def connector(self):
        """Open a new driver connection using ``driver_args``."""
        return self.driver.connect(self.driver_args)

    def test_connection(self):
        """Connection health check; a no-op at this level."""
        pass

    @with_connection
    def close_connection(self):
        """Close the current connection and forget it."""
        rv = self.connection.close()
        self.connection = None
        return rv

    def tables(self, *queries):
        """Map tablename -> table for every table referenced by *queries*.

        Fields contribute their table directly. Expressions and queries
        are walked recursively through their ``first``/``second``
        operands.

        :raises ValueError: if two different table objects share a name.
        """
        tables = dict()
        for query in queries:
            if isinstance(query, Field):
                key = query.tablename
                if tables.get(key, query.table) is not query.table:
                    raise ValueError('Name conflict in table list: %s' % key)
                tables[key] = query.table
            elif isinstance(query, (Expression, Query)):
                operands = [
                    x for x in (query.first, query.second) if x is not None]
                tables = merge_tablemaps(tables, self.tables(*operands))
        return tables

    def get_table(self, *queries):
        """Return the single table referenced by *queries*.

        :raises RuntimeError: if zero or more than one table is referenced.
        """
        tablemap = self.tables(*queries)
        if len(tablemap) == 1:
            return tablemap.popitem()[1]
        elif len(tablemap) < 1:
            raise RuntimeError("No table selected")
        else:
            raise RuntimeError(
                "Too many tables selected (%s)" % str(list(tablemap)))

    def common_filter(self, query, tablist):
        """AND the per-table common filters into *query*.

        For each table (given as an object or a name), two filters are
        added:

        - the user-provided ``_common_filter(query)``, if any;
        - for tables containing the request-tenant field, a filter
          ``tenant_field == default``, provided the field's default is
          not None.
        """
        tenant_fieldname = self.db._request_tenant
        for table in tablist:
            if isinstance(table, basestring):
                table = self.db[table]
            # deal with user provided filters
            if table._common_filter is not None:
                query = query & table._common_filter(query)
            # deal with multi_tenant filters
            if tenant_fieldname in table:
                default = table[tenant_fieldname].default
                if default is not None:
                    newquery = table[tenant_fieldname] == default
                    if query is None:
                        query = newquery
                    else:
                        query = query & newquery
        return query

    def _expand(self, expression, field_type=None, colnames=False,
                query_env={}):
        """Render *expression* as a string; subclasses override this."""
        return str(expression)

    def expand_all(self, fields, tabledict):
        """Normalize a select field list.

        The elements of *fields* are handled as follows:

        - ``SQLALL`` expands to all fields of its table.
        - ``'table.field'`` strings resolve to the DAL field.
        - Other strings become raw ``Expression`` objects.
        - Anything else is kept as-is.

        If the result is empty, every field of every table in *tabledict*
        is used.
        """
        new_fields = []
        append = new_fields.append
        for item in fields:
            if isinstance(item, SQLALL):
                new_fields += item._table
            elif isinstance(item, str):
                m = REGEX_TABLE_DOT_FIELD.match(item)
                if m:
                    tablename, fieldname = m.groups()
                    append(self.db[tablename][fieldname])
                else:
                    # default arg binds the current string (no late binding)
                    append(Expression(self.db, lambda item=item: item))
            else:
                append(item)
        # ## if no fields specified take them all from the requested tables
        if not new_fields:
            for table in tabledict.values():
                for field in table:
                    append(field)
        return new_fields

    def parse_value(self, value, field_itype, field_type, blob_decode=True):
        """Convert one raw driver value for a column of type *field_type*.

        ``SQLCustomType`` values go through their decoder. None and
        non-string types are returned unchanged, as are blobs when
        *blob_decode* is false. Everything else is delegated to
        ``self.parser``.
        """
        # [Note - gi0baro] I think next if block can be (should be?) avoided
        # NOTE(review): on Python 3 ``str`` has no ``.decode`` so this is a
        # silent no-op there; only Python 2 byte strings are decoded.
        if field_type != 'blob' and isinstance(value, str):
            try:
                value = value.decode(self.db._db_codec)
            except Exception:
                pass
        if PY2 and isinstance(value, unicode):
            value = value.encode('utf-8')
        if isinstance(field_type, SQLCustomType):
            value = field_type.decoder(value)
        if not isinstance(field_type, str) or value is None:
            return value
        elif field_type == 'blob' and not blob_decode:
            return value
        else:
            return self.parser.parse(value, field_itype, field_type)

    def _add_operators_to_parsed_row(self, rid, table, row):
        """Attach ``db.record_operators`` (and lazy-ref getter) to *row*."""
        for key, record_operator in iteritems(self.db.record_operators):
            setattr(row, key, record_operator(row, table, rid))
        if table._db._lazy_tables:
            row['__get_lazy_reference__'] = LazyReferenceGetter(table, rid)

    def _add_reference_sets_to_parsed_row(self, rid, table, tablename, row):
        """Add a ``LazySet`` per referencing field, keyed by referee name.

        Existing keys and a key equal to *tablename* are not overwritten.
        """
        for rfield in table._referenced_by:
            referee_link = self.db._referee_name and self.db._referee_name % \
                dict(table=rfield.tablename, field=rfield.name)
            if referee_link and referee_link not in row and \
                    referee_link != tablename:
                row[referee_link] = LazySet(rfield, rid)

    def _regex_select_as_parser(self, colname):
        """Search *colname* with ``REGEX_SELECT_AS_PARSER``; match or None."""
        return REGEX_SELECT_AS_PARSER.search(colname)

    def _parse(self, row, tmps, fields, colnames, blob_decode,
               cacheable, fields_virtual, fields_lazy):
        """Turn one raw driver row into a ``db.Row``.

        Real columns are grouped under their table name. Other columns go
        into ``_extra``; a column whose name matches the select-as regex
        is also stored at top level under the captured name. Virtual and
        lazy fields are computed last. Failures from missing fields in a
        partial select are ignored.
        """
        new_row = defaultdict(self.db.Row)
        extras = self.db.Row()
        #: let's loop over columns
        for (j, colname) in enumerate(colnames):
            value = row[j]
            tmp = tmps[j]
            tablename = None
            #: do we have a real column?
            if tmp:
                (tablename, fieldname, table, field, ft, fit) = tmp
                colset = new_row[tablename]
                #: parse value
                value = self.parse_value(value, fit, ft, blob_decode)
                if field.filter_out:
                    value = field.filter_out(value)
                colset[fieldname] = value
                #! backward compatibility
                if ft == 'id' and fieldname != 'id' and \
                        'id' not in table.fields:
                    colset['id'] = value
                #: additional parsing for 'id' fields
                if ft == 'id' and not cacheable:
                    self._add_operators_to_parsed_row(value, table, colset)
                    self._add_reference_sets_to_parsed_row(
                        value, table, tablename, colset)
            #: otherwise we set the value in extras
            else:
                value = self.parse_value(
                    value, fields[j]._itype, fields[j].type, blob_decode)
                extras[colname] = value
                new_column_name = self._regex_select_as_parser(colname)
                if new_column_name is not None:
                    column_name = new_column_name.groups(0)
                    new_row[column_name[0]] = value
        #: add extras if needed (eg. operations results)
        if extras:
            new_row['_extra'] = extras
        #: add virtuals
        new_row = self.db.Row(**new_row)
        for tablename in fields_virtual.keys():
            for f, v in fields_virtual[tablename][1]:
                try:
                    new_row[tablename][f] = v.f(new_row)
                except (AttributeError, KeyError):
                    pass  # not enough fields to define virtual field
            for f, v in fields_lazy[tablename][1]:
                try:
                    new_row[tablename][f] = v.handler(v.f, new_row)
                except (AttributeError, KeyError):
                    pass  # not enough fields to define virtual field
        return new_row

    def _parse_expand_colnames(self, fieldlist):
        """
        - Expand a list of fields into a list of
          (tablename, fieldname, table_obj, field_obj, field_type,
          field_itype), with None for entries that are not ``Field``s
        - Create per-table lists of (name, field) for virtual fields and
          for virtual methods (lazy fields)

        Returns ``(fields_virtual, fields_lazy, tmps)``.
        """
        fields_virtual = {}
        fields_lazy = {}
        tmps = []
        for field in fieldlist:
            if not isinstance(field, Field):
                tmps.append(None)
                continue
            table = field.table
            tablename, fieldname = table._tablename, field.name
            ft = field.type
            fit = field._itype
            tmps.append((tablename, fieldname, table, field, ft, fit))
            if tablename not in fields_virtual:
                fields_virtual[tablename] = (table, [
                    (f.name, f) for f in table._virtual_fields
                ])
                fields_lazy[tablename] = (table, [
                    (f.name, f) for f in table._virtual_methods
                ])
        return (fields_virtual, fields_lazy, tmps)

    def parse(self, rows, fields, colnames, blob_decode=True, cacheable=False):
        """Parse all raw *rows* into a ``db.Rows`` object.

        Old-style ``table.virtualfields`` are applied afterwards; failures
        from partial selects are ignored.
        """
        (fields_virtual, fields_lazy, tmps) = \
            self._parse_expand_colnames(fields)
        new_rows = [
            self._parse(
                row, tmps, fields, colnames, blob_decode, cacheable,
                fields_virtual, fields_lazy)
            for row in rows
        ]
        rowsobj = self.db.Rows(self.db, new_rows, colnames, rawrows=rows,
                               fields=fields)
        # Old style virtual fields
        for tablename, tmp in fields_virtual.items():
            table = tmp[0]
            # ## old style virtual fields
            for item in table.virtualfields:
                try:
                    rowsobj = rowsobj.setvirtualfields(**{tablename: item})
                except (KeyError, AttributeError):
                    # to avoid breaking virtualfields when partial select
                    pass
        return rowsobj

    def iterparse(self, sql, fields, colnames, blob_decode=True,
                  cacheable=False):
        """
        Iterator to parse one row at a time.
        It doesn't support the old style virtual fields
        """
        return IterRows(self.db, sql, fields, colnames, blob_decode, cacheable)

    def adapt(self, value):
        """Adapt a value for the backend; identity at this level."""
        return value

    def represent(self, obj, field_type):
        """Represent *obj* for *field_type* via ``self.representer``.

        Objects of a type in ``CALLABLETYPES`` are called first and their
        result is represented.
        """
        if isinstance(obj, CALLABLETYPES):
            obj = obj()
        return self.representer.represent(obj, field_type)

    def _drop_table_cleanup(self, table):
        """Remove *table* from the DAL registry and drop references to it."""
        del self.db[table._tablename]
        del self.db.tables[self.db.tables.index(table._tablename)]
        self.db._remove_references_to(table)

    def drop_table(self, table, mode=''):
        """Forget *table*; at this level only the DAL-side cleanup runs."""
        self._drop_table_cleanup(table)

    def rowslice(self, rows, minimum=0, maximum=None):
        """Slice fetched rows; identity here (backends may override)."""
        return rows

    def sqlsafe_table(self, tablename, original_tablename=None):
        """Return a safe table name; identity at this level."""
        return tablename

    def sqlsafe_field(self, fieldname):
        """Return a safe field name; identity at this level."""
        return fieldname
|
|
|
|
|
|
class DebugHandler(ExecutionHandler):
    """Execution handler that logs every command before it is executed.

    ``SQLAdapter`` puts it at the front of its handler list when
    ``db._debug`` is set.
    """

    def before_execute(self, command):
        """Log *command* at DEBUG level on the DAL logger."""
        # Pass the command as a logging argument so string formatting only
        # happens if the DEBUG record is actually emitted (and a tuple-like
        # command cannot break the ``%`` operator).
        self.adapter.db.logger.debug('SQL: %s', command)
|
|
|
|
|
|
class SQLAdapter(BaseAdapter):
    """Base class for adapters backed by SQL databases.

    Statements are built with ``self.dialect`` and run on ``self.cursor``.
    Each statement passes through ``execute``, which notifies the
    configured execution handlers before and after running it.
    """

    commit_on_alter_table = False
    # [Note - gi0baro] can_select_for_update should be deprecated and removed
    can_select_for_update = True
    execution_handlers = []
    migrator_cls = Migrator

    def __init__(self, *args, **kwargs):
        super(SQLAdapter, self).__init__(*args, **kwargs)
        # ``adapter_args['migrator']`` can override the migrator class
        migrator_cls = self.adapter_args.get('migrator', self.migrator_cls)
        self.migrator = migrator_cls(self)
        # copy so that inserting DebugHandler does not touch the db's list
        self.execution_handlers = list(self.db.execution_handlers)
        if self.db._debug:
            self.execution_handlers.insert(0, DebugHandler)

    def test_connection(self):
        """Run a trivial query to check the connection."""
        self.execute('SELECT 1;')

    def represent(self, obj, field_type):
        """Fields and expressions are already SQL: use their string form."""
        if isinstance(obj, (Expression, Field)):
            return str(obj)
        return super(SQLAdapter, self).represent(obj, field_type)

    def adapt(self, obj):
        """Quote *obj* as a SQL string literal (single quotes doubled)."""
        return "'%s'" % obj.replace("'", "''")

    def smart_adapt(self, obj):
        """Emit numbers verbatim; quote everything else via ``adapt``."""
        if isinstance(obj, (int, float)):
            return str(obj)
        return self.adapt(str(obj))

    def fetchall(self):
        return self.cursor.fetchall()

    def fetchone(self):
        return self.cursor.fetchone()

    def _build_handlers_for_execution(self):
        """Instantiate one handler per configured handler class."""
        return [handler_class(self)
                for handler_class in self.execution_handlers]

    def filter_sql_command(self, command):
        """Hook to rewrite a command before execution; identity here."""
        return command

    @with_connection_or_raise
    def execute(self, *args, **kwargs):
        """Execute ``args[0]`` on the cursor, notifying handlers.

        Additional positional and keyword arguments are forwarded to
        ``cursor.execute``.
        """
        command = self.filter_sql_command(args[0])
        handlers = self._build_handlers_for_execution()
        for handler in handlers:
            handler.before_execute(command)
        rv = self.cursor.execute(command, *args[1:], **kwargs)
        for handler in handlers:
            handler.after_execute(command)
        return rv

    def _expand(self, expression, field_type=None, colnames=False,
                query_env={}):
        """Render *expression* as SQL text.

        Fields render as their sqlsafe name, or as their long name when
        *colnames* is true. A non-textual field compared as ``'string'``
        is cast to text. Expressions and queries call their ``op``.
        Other values are represented for *field_type*; lists and tuples
        become comma-joined, and bools become the dialect's true/false
        expressions.
        """
        if isinstance(expression, Field):
            if not colnames:
                rv = expression.sqlsafe
            else:
                rv = expression.longname
            if field_type == 'string' and expression.type not in (
                    'string', 'text', 'json', 'jsonb', 'password'):
                rv = self.dialect.cast(rv, self.types['text'], query_env)
        elif isinstance(expression, (Expression, Query)):
            first = expression.first
            second = expression.second
            op = expression.op
            # Copy so the per-call query_env is not written into the
            # expression's own (possibly shared) optional_args dict.
            optional_args = dict(expression.optional_args or {})
            optional_args['query_env'] = query_env
            if second is not None:
                rv = op(first, second, **optional_args)
            elif first is not None:
                rv = op(first, **optional_args)
            elif isinstance(op, str):
                if op.endswith(';'):
                    op = op[:-1]
                rv = '(%s)' % op
            else:
                rv = op()
        elif field_type:
            rv = self.represent(expression, field_type)
        elif isinstance(expression, (list, tuple)):
            rv = ','.join(self.represent(item, field_type)
                          for item in expression)
        elif isinstance(expression, bool):
            rv = self.dialect.true_exp if expression else \
                self.dialect.false_exp
        else:
            rv = expression
        return str(rv)

    def _expand_for_index(self, expression, field_type=None, colnames=False,
                          query_env={}):
        """Like ``_expand`` but fields render as their raw ``_rname``."""
        if isinstance(expression, Field):
            return expression._rname
        return self._expand(expression, field_type, colnames, query_env)

    @contextmanager
    def index_expander(self):
        """Temporarily switch ``self.expand`` to ``_expand_for_index``.

        The normal expander is restored even if the with-body raises.
        """
        self.expand = self._expand_for_index
        try:
            yield
        finally:
            self.expand = self._expand

    def lastrowid(self, table):
        return self.cursor.lastrowid

    def _insert(self, table, fields):
        """Build the INSERT for (field, value) pairs; empty insert if none."""
        if fields:
            return self.dialect.insert(
                table._rname,
                ','.join(el[0]._rname for el in fields),
                ','.join(self.expand(v, f.type) for f, v in fields))
        return self.dialect.insert_empty(table._rname)

    def insert(self, table, fields):
        """Insert a row and return its id.

        The return value depends on the table:

        - Keyed tables get a dict of the primary-key values that were
          provided.
        - Integer ids are wrapped in a ``Reference``.
        - Any other ``lastrowid`` value is returned unchanged.

        Execution errors go to ``table._on_insert_error`` when defined;
        otherwise they are re-raised.
        """
        query = self._insert(table, fields)
        try:
            self.execute(query)
        except Exception as e:
            if hasattr(table, '_on_insert_error'):
                return table._on_insert_error(table, fields, e)
            raise
        if hasattr(table, '_primarykey'):
            pkdict = {
                field.name: value for field, value in fields
                if field.name in table._primarykey}
            if pkdict:
                return pkdict
        new_id = self.lastrowid(table)
        if hasattr(table, '_primarykey') and len(table._primarykey) == 1:
            new_id = {table._primarykey[0]: new_id}
        if not isinstance(new_id, integer_types):
            return new_id
        rid = Reference(new_id)
        (rid._table, rid._record) = (table, None)
        return rid

    def _update(self, table, query, fields):
        """Build the UPDATE statement, applying common filters."""
        sql_q = ''
        query_env = dict(current_scope=[table._tablename])
        if query:
            if use_common_filters(query):
                query = self.common_filter(query, [table])
            sql_q = self.expand(query, query_env=query_env)
        sql_v = ','.join([
            '%s=%s' % (field._rname,
                       self.expand(value, field.type, query_env=query_env))
            for (field, value) in fields])
        return self.dialect.update(table, sql_v, sql_q)

    def update(self, table, query, fields):
        """Run an UPDATE and return ``cursor.rowcount`` (None if unknown).

        Execution errors go to ``table._on_update_error`` when defined;
        otherwise they are re-raised.
        """
        sql = self._update(table, query, fields)
        try:
            self.execute(sql)
        except Exception as e:
            if hasattr(table, '_on_update_error'):
                return table._on_update_error(table, query, fields, e)
            raise
        try:
            return self.cursor.rowcount
        except Exception:
            return None

    def _delete(self, table, query):
        """Build the DELETE statement, applying common filters."""
        sql_q = ''
        query_env = dict(current_scope=[table._tablename])
        if query:
            if use_common_filters(query):
                query = self.common_filter(query, [table])
            sql_q = self.expand(query, query_env=query_env)
        return self.dialect.delete(table, sql_q)

    def delete(self, table, query):
        """Run a DELETE and return ``cursor.rowcount`` (None if unknown)."""
        sql = self._delete(table, query)
        self.execute(sql)
        try:
            return self.cursor.rowcount
        except Exception:
            return None

    def _colexpand(self, field, query_env):
        return self.expand(field, colnames=True, query_env=query_env)

    def _geoexpand(self, field, query_env):
        """Expand a select field; geo* fields are wrapped in st_astext()."""
        if isinstance(field, Field) and isinstance(field.type, str) and \
                field.type.startswith('geo'):
            field = field.st_astext()
        return self.expand(field, query_env=query_env)

    def _build_joins_for_select(self, tablenames, param):
        """Analyse a ``join``/``left`` argument for ``_select_wcols``.

        Returns a 7-tuple, in this order:

        - ``join_tables``: names of bare tables.
        - ``join_on``: the ON expressions.
        - ``tables_to_merge``: other tables referenced in the ON clauses.
        - ``join_on_tables``: names of the tables being joined.
        - ``important_tablenames``: the three groups above, combined.
        - ``excluded``: entries of *tablenames* not in that combined set.
        - ``tablemap``: tablename -> table for the joined tables.
        """
        if not isinstance(param, (tuple, list)):
            param = [param]
        tablemap = {}
        for item in param:
            if isinstance(item, Expression):
                item = item.first
            key = item._tablename
            if tablemap.get(key, item) is not item:
                raise ValueError('Name conflict in table list: %s' % key)
            tablemap[key] = item
        join_tables = [
            t._tablename for t in param if not isinstance(t, Expression)
        ]
        join_on = [t for t in param if isinstance(t, Expression)]
        tables_to_merge = {}
        for t in join_on:
            tables_to_merge = merge_tablemaps(tables_to_merge, self.tables(t))
        join_on_tables = [t.first._tablename for t in join_on]
        for t in join_on_tables:
            if t in tables_to_merge:
                tables_to_merge.pop(t)
        important_tablenames = join_tables + join_on_tables + \
            list(tables_to_merge)
        excluded = [
            t for t in tablenames if t not in important_tablenames
        ]
        return (
            join_tables, join_on, tables_to_merge, join_on_tables,
            important_tablenames, excluded, tablemap
        )

    def _select_wcols(self, query, fields, left=False, join=False,
                      distinct=False, orderby=False, groupby=False,
                      having=False, limitby=False, orderby_on_limitby=True,
                      for_update=False, outer_scoped=[], required=None,
                      cache=None, cacheable=None, processor=None):
        """Build a SELECT and return ``(colnames, sql)``.

        ``required``, ``cache``, ``cacheable`` and ``processor`` are
        accepted so callers can pass a full attributes dict, but they are
        not used here.

        :raises SyntaxError: when no table is selected, or when
            ``for_update`` is requested on an adapter that does not support
            it.
        """
        #: parse tablemap
        tablemap = self.tables(query)
        #: apply common filters if needed
        if use_common_filters(query):
            query = self.common_filter(query, list(tablemap.values()))
        #: auto-adjust tables
        tablemap = merge_tablemaps(tablemap, self.tables(*fields))
        #: remove outer scoped tables if needed
        for item in outer_scoped:
            # FIXME: check for name conflicts
            tablemap.pop(item, None)
        if len(tablemap) < 1:
            raise SyntaxError('Set: no tables selected')
        query_tables = list(tablemap)
        #: check for_update argument
        # [Note - gi0baro] I think this should be removed since useless?
        # should affect only NoSQL?
        if self.can_select_for_update is False and for_update is True:
            raise SyntaxError('invalid select attribute: for_update')
        #: build joins (inner, left outer) and table names
        if join:
            (
                # FIXME? ijoin_tables is never used
                ijoin_tables, ijoin_on, itables_to_merge, ijoin_on_tables,
                iimportant_tablenames, iexcluded, itablemap
            ) = self._build_joins_for_select(tablemap, join)
            tablemap = merge_tablemaps(tablemap, itables_to_merge)
            tablemap = merge_tablemaps(tablemap, itablemap)
        if left:
            (
                join_tables, join_on, tables_to_merge, join_on_tables,
                important_tablenames, excluded, jtablemap
            ) = self._build_joins_for_select(tablemap, left)
            tablemap = merge_tablemaps(tablemap, tables_to_merge)
            tablemap = merge_tablemaps(tablemap, jtablemap)
        current_scope = outer_scoped + list(tablemap)
        query_env = dict(current_scope=current_scope,
                         parent_scope=outer_scoped)
        #: prepare columns and expand fields
        colnames = [self._colexpand(x, query_env) for x in fields]
        sql_fields = ', '.join(self._geoexpand(x, query_env) for x in fields)
        table_alias = lambda name: tablemap[name].query_name(outer_scoped)[0]
        if join and not left:
            cross_joins = iexcluded + list(itables_to_merge)
            tokens = [table_alias(cross_joins[0])]
            tokens += [self.dialect.cross_join(table_alias(t), query_env)
                       for t in cross_joins[1:]]
            tokens += [self.dialect.join(t, query_env) for t in ijoin_on]
            sql_t = ' '.join(tokens)
        elif not join and left:
            cross_joins = excluded + list(tables_to_merge)
            tokens = [table_alias(cross_joins[0])]
            tokens += [self.dialect.cross_join(table_alias(t), query_env)
                       for t in cross_joins[1:]]
            # FIXME: WTF? This is not correct syntax at least on PostgreSQL
            if join_tables:
                tokens.append(self.dialect.left_join(','.join([
                    table_alias(t) for t in join_tables]), query_env))
            tokens += [self.dialect.left_join(t, query_env) for t in join_on]
            sql_t = ' '.join(tokens)
        elif join and left:
            all_tables_in_query = (
                important_tablenames + iimportant_tablenames + query_tables)
            tables_in_joinon = set(join_on_tables + ijoin_on_tables)
            # Keep first-seen order instead of iterating a set. Set order of
            # strings varies between processes, which made the generated SQL
            # (also used as a cache key) nondeterministic.
            tables_not_in_joinon = []
            for t in all_tables_in_query:
                if t not in tables_in_joinon and \
                        t not in tables_not_in_joinon:
                    tables_not_in_joinon.append(t)
            tokens = [table_alias(tables_not_in_joinon[0])]
            tokens += [self.dialect.cross_join(table_alias(t), query_env)
                       for t in tables_not_in_joinon[1:]]
            tokens += [self.dialect.join(t, query_env) for t in ijoin_on]
            # FIXME: WTF? This is not correct syntax at least on PostgreSQL
            if join_tables:
                tokens.append(self.dialect.left_join(','.join([
                    table_alias(t) for t in join_tables]), query_env))
            tokens += [self.dialect.left_join(t, query_env) for t in join_on]
            sql_t = ' '.join(tokens)
        else:
            sql_t = ', '.join(table_alias(t) for t in query_tables)
        #: expand query if needed
        if query:
            query = self.expand(query, query_env=query_env)
        if having:
            having = self.expand(having, query_env=query_env)
        #: groupby
        sql_grp = groupby
        if groupby:
            if isinstance(groupby, (list, tuple)):
                groupby = xorify(groupby)
            sql_grp = self.expand(groupby, query_env=query_env)
        #: orderby
        sql_ord = False
        if orderby:
            if isinstance(orderby, (list, tuple)):
                orderby = xorify(orderby)
            if str(orderby) == '<random>':
                sql_ord = self.dialect.random
            else:
                sql_ord = self.expand(orderby, query_env=query_env)
        #: set default orderby if missing
        if (limitby and not groupby and query_tables and orderby_on_limitby and
                not orderby):
            sql_ord = ', '.join([
                tablemap[t][x].sqlsafe
                for t in query_tables if not isinstance(tablemap[t], Select)
                for x in (hasattr(tablemap[t], '_primarykey') and
                          tablemap[t]._primarykey or ['_id'])
            ])
        #: build sql using dialect
        return colnames, self.dialect.select(
            sql_fields, sql_t, query, sql_grp, having, sql_ord, limitby,
            distinct, for_update and self.can_select_for_update
        )

    def _select(self, query, fields, attributes):
        """Return only the SQL of the SELECT for *attributes*."""
        return self._select_wcols(query, fields, **attributes)[1]

    def nested_select(self, query, fields, attributes):
        return Select(self.db, query, fields, attributes)

    def _select_aux_execute(self, sql):
        self.execute(sql)
        return self.cursor.fetchall()

    def _cache_params(self, cache, sql, suffix=''):
        """Unpack a select ``cache`` argument into ``(model, expire, key)``.

        *cache* is either a ``(cache_model, time_expire)`` pair or a dict
        with ``'model'``, ``'expiration'`` and an optional ``'key'``. If no
        key is given, it is the md5 of ``uri + '/' + sql + suffix``.
        """
        if isinstance(cache, dict):
            cache_model = cache['model']
            time_expire = cache['expiration']
            key = cache.get('key')
        else:
            (cache_model, time_expire) = cache
            key = None
        if not key:
            key = hashlib_md5(self.uri + '/' + sql + suffix).hexdigest()
        return cache_model, time_expire, key

    def _select_aux(self, sql, fields, attributes, colnames):
        """Fetch rows for *sql* (optionally cached) and run the processor.

        Rows are sliced by ``rowslice`` from ``limitby[0]``. The processor
        defaults to ``self.parse``.
        """
        cache = attributes.get('cache', None)
        if not cache:
            rows = self._select_aux_execute(sql)
        else:
            cache_model, time_expire, key = self._cache_params(
                cache, sql, '/rows')
            rows = cache_model(
                key,
                lambda self=self, sql=sql: self._select_aux_execute(sql),
                time_expire)
        if isinstance(rows, tuple):
            rows = list(rows)
        limitby = attributes.get('limitby', None) or (0,)
        rows = self.rowslice(rows, limitby[0], None)
        processor = attributes.get('processor', self.parse)
        cacheable = attributes.get('cacheable', False)
        return processor(rows, fields, colnames, cacheable=cacheable)

    def _cached_select(self, cache, sql, fields, attributes, colnames):
        """Cache the fully processed result of *sql*.

        Accepts the same tuple or dict *cache* forms as ``_select_aux``.
        """
        # Drop 'cache' from a private copy so the inner _select_aux does
        # not cache the raw rows as well, without mutating the caller's dict.
        attributes = dict(attributes)
        attributes.pop('cache', None)
        cache_model, time_expire, key = self._cache_params(cache, sql)
        args = (sql, fields, attributes, colnames)
        ret = cache_model(
            key,
            lambda self=self, args=args: self._select_aux(*args),
            time_expire)
        ret._restore_fields(fields)
        return ret

    def select(self, query, fields, attributes):
        """Run a SELECT.

        The processed result is cached when both ``cache`` and
        ``cacheable`` are set.
        """
        colnames, sql = self._select_wcols(query, fields, **attributes)
        cache = attributes.get('cache', None)
        if cache and attributes.get('cacheable', False):
            return self._cached_select(
                cache, sql, fields, attributes, colnames)
        return self._select_aux(sql, fields, attributes, colnames)

    def iterselect(self, query, fields, attributes):
        """Like ``select`` but returns an iterator that parses row by row."""
        colnames, sql = self._select_wcols(query, fields, **attributes)
        cacheable = attributes.get('cacheable', False)
        return self.iterparse(sql, fields, colnames, cacheable=cacheable)

    def _count(self, query, distinct=None):
        """Build the SELECT COUNT statement for *query*."""
        tablemap = self.tables(query)
        tablenames = list(tablemap)
        tables = list(tablemap.values())
        query_env = dict(current_scope=tablenames)
        sql_q = ''
        if query:
            if use_common_filters(query):
                query = self.common_filter(query, tables)
            sql_q = self.expand(query, query_env=query_env)
        sql_t = ','.join(self.table_alias(t, []) for t in tables)
        sql_fields = '*'
        if distinct:
            if isinstance(distinct, (list, tuple)):
                distinct = xorify(distinct)
            sql_fields = self.expand(distinct, query_env=query_env)
        return self.dialect.select(
            self.dialect.count(sql_fields, distinct), sql_t, sql_q
        )

    def count(self, query, distinct=None):
        self.execute(self._count(query, distinct))
        return self.cursor.fetchone()[0]

    def bulk_insert(self, table, items):
        """Insert each item with ``insert``; returns the list of ids."""
        return [self.insert(table, item) for item in items]

    def create_table(self, *args, **kwargs):
        return self.migrator.create_table(*args, **kwargs)

    def _drop_table_cleanup(self, table):
        """DAL-side cleanup plus removal of the migration (.table) file."""
        super(SQLAdapter, self)._drop_table_cleanup(table)
        if table._dbt:
            self.migrator.file_delete(table._dbt)
            self.migrator.log('success!\n', table)

    def drop_table(self, table, mode=''):
        """Execute the dialect's DROP statements, commit, then clean up."""
        queries = self.dialect.drop_table(table, mode)
        for query in queries:
            if table._dbt:
                self.migrator.log(query + '\n', table)
            self.execute(query)
        self.commit()
        self._drop_table_cleanup(table)

    @deprecated('drop', 'drop_table', 'SQLAdapter')
    def drop(self, table, mode=''):
        # forward the caller's mode (it used to be silently replaced by '')
        return self.drop_table(table, mode=mode)

    def truncate(self, table, mode=''):
        """Execute the dialect's TRUNCATE statements, logging each one."""
        queries = self.dialect.truncate(table, mode)
        for query in queries:
            self.migrator.log(query + '\n', table)
            self.execute(query)
        self.migrator.log('success!\n', table)

    def create_index(self, table, index_name, *fields, **kwargs):
        """Create an index and commit; roll back on failure.

        :raises RuntimeError: wrapping the driver error and the SQL.
        """
        expressions = [
            field._rname if isinstance(field, Field) else field
            for field in fields]
        sql = self.dialect.create_index(
            index_name, table, expressions, **kwargs)
        try:
            self.execute(sql)
            self.commit()
        except Exception as e:
            self.rollback()
            err = 'Error creating index %s\n  Driver error: %s\n' + \
                '  SQL instruction: %s'
            raise RuntimeError(err % (index_name, str(e), sql))
        return True

    def drop_index(self, table, index_name):
        """Drop an index and commit; roll back on failure.

        :raises RuntimeError: wrapping the driver error.
        """
        sql = self.dialect.drop_index(index_name, table)
        try:
            self.execute(sql)
            self.commit()
        except Exception as e:
            self.rollback()
            err = 'Error dropping index %s\n  Driver error: %s'
            raise RuntimeError(err % (index_name, str(e)))
        return True

    def distributed_transaction_begin(self, key):
        pass

    @with_connection
    def commit(self):
        return self.connection.commit()

    @with_connection
    def rollback(self):
        return self.connection.rollback()

    @with_connection
    def prepare(self, key):
        self.connection.prepare()

    @with_connection
    def commit_prepared(self, key):
        self.connection.commit()

    @with_connection
    def rollback_prepared(self, key):
        self.connection.rollback()

    def create_sequence_and_triggers(self, query, table, **args):
        self.execute(query)

    def sqlsafe_table(self, tablename, original_tablename=None):
        """Quote *tablename*, or alias *original_tablename* as it."""
        if original_tablename is not None:
            return self.dialect.alias(original_tablename, tablename)
        return self.dialect.quote(tablename)

    def sqlsafe_field(self, fieldname):
        return self.dialect.quote(fieldname)

    def table_alias(self, tbl, current_scope=[]):
        """Return the FROM-clause name of *tbl* (a table or tablename)."""
        if isinstance(tbl, basestring):
            tbl = self.db[tbl]
        return tbl.query_name(current_scope)[0]

    def id_query(self, table):
        """Query matching every row: first primary key (or id) not NULL."""
        pkeys = getattr(table, '_primarykey', None)
        # ``!= None`` is intentional: it builds a DAL Query (IS NOT NULL),
        # whereas ``is not None`` would just evaluate to True.
        if pkeys:
            return table[pkeys[0]] != None  # noqa: E711
        return table._id != None  # noqa: E711
|
|
|
|
|
|
class NoSQLAdapter(BaseAdapter):
    """Base class for non-relational adapters.

    Transaction methods are no-ops here. ``create_table`` only records
    not-null and unique field names. Nested selects raise
    ``NotOnNOSQLError``.
    """

    can_select_for_update = False

    def commit(self):
        pass

    def rollback(self):
        pass

    def prepare(self, key=None):
        # accept the transaction key like SQLAdapter.prepare(key) so
        # callers can use either adapter interchangeably
        pass

    def commit_prepared(self, key):
        pass

    def rollback_prepared(self, key):
        pass

    def id_query(self, table):
        """Query matching every row with a positive id."""
        return table._id > 0

    def create_table(self, table, migrate=True, fake_migrate=False,
                     polymodel=None):
        """Record ``_notnulls`` / ``_uniques`` on *table*; no migration.

        ``migrate``, ``fake_migrate`` and ``polymodel`` are accepted for
        interface compatibility and not used here.
        """
        table._dbt = None
        table._notnulls = [
            name for name in table.fields if table[name].notnull]
        # this is unnecessary if the fields are indexed and unique
        table._uniques = [
            name for name in table.fields if table[name].unique]

    def drop_table(self, table, mode=''):
        """Drop the backend collection for *table*, then clean up the DAL."""
        ctable = self.connection[table._tablename]
        ctable.drop()
        self._drop_table_cleanup(table)

    @deprecated('drop', 'drop_table', 'NoSQLAdapter')
    def drop(self, table, mode=''):
        # forward the caller's mode (it used to be silently replaced by '')
        return self.drop_table(table, mode=mode)

    def _select(self, *args, **kwargs):
        raise NotOnNOSQLError(
            "Nested queries are not supported on NoSQL databases")

    def nested_select(self, *args, **kwargs):
        raise NotOnNOSQLError(
            "Nested queries are not supported on NoSQL databases")
|
|
|
|
|
|
class NullAdapter(BaseAdapter):
    """Adapter whose connections come from ``NullDriver``.

    Presumably used when no real database is wanted (TODO confirm against
    callers). Only a ``CommonDialect`` is loaded here: no parser or
    representer is attached. Driver discovery is skipped.
    """

    def _load_dependencies(self):
        # generic dialect only; self.parser / self.representer stay unset
        from ..dialects.base import CommonDialect as _Dialect
        self.dialect = _Dialect(self)

    def find_driver(self):
        # nothing to discover: connector() below never touches self.driver
        return None

    def connector(self):
        null_driver = NullDriver()
        return null_driver
|