import copy
import sys
import types
from collections import defaultdict
from contextlib import contextmanager
from .._compat import PY2, with_metaclass, iterkeys, iteritems, hashlib_md5, \
    integer_types, basestring
from .._globals import IDENTITY
from ..connection import ConnectionPool
from ..exceptions import NotOnNOSQLError
from ..helpers.classes import Reference, ExecutionHandler, SQLCustomType, \
    SQLALL, NullDriver
from ..helpers.methods import use_common_filters, xorify, merge_tablemaps
from ..helpers.regex import REGEX_SELECT_AS_PARSER, REGEX_TABLE_DOT_FIELD
from ..migrator import Migrator
from ..objects import Table, Field, Expression, Query, Rows, IterRows, \
    LazySet, LazyReferenceGetter, VirtualCommand, Select
from ..utils import deprecated
from . import AdapterMeta, with_connection, with_connection_or_raise


# Types for which ``BaseAdapter.represent`` calls the object to obtain the
# actual value before representing it.
CALLABLETYPES = (
    types.LambdaType, types.FunctionType, types.BuiltinFunctionType,
    types.MethodType, types.BuiltinMethodType)


class BaseAdapter(with_metaclass(AdapterMeta, ConnectionPool)):
    """Common base for all adapters: driver lookup, row parsing and
    table bookkeeping shared by the SQL and NoSQL flavours below."""

    dbengine = "None"
    drivers = ()
    uploads_in_blob = False
    support_distributed_transaction = False

    def __init__(self, db, uri, pool_size=0, folder=None, db_codec='UTF-8',
                 credential_decoder=IDENTITY, driver_args=None,
                 adapter_args=None, do_connect=True, after_connection=None,
                 entity_quoting=False):
        """Store the connection settings, pick a driver and optionally
        connect.

        ``driver_args`` and ``adapter_args`` default to a fresh empty dict
        per instance; a dict passed by the caller is stored as-is (not
        copied). ``entity_quoting`` is accepted but not used by this method.
        """
        super(BaseAdapter, self).__init__()
        self._load_dependencies()
        self.db = db
        self.uri = uri
        self.pool_size = pool_size
        self.folder = folder
        self.db_codec = db_codec
        self.credential_decoder = credential_decoder
        # The dicts are stored on the instance, so a shared mutable default
        # would leak state between adapters: build a new one when omitted.
        self.driver_args = {} if driver_args is None else driver_args
        self.adapter_args = {} if adapter_args is None else adapter_args
        self.expand = self._expand
        self._after_connection = after_connection
        self.connection = None
        self.find_driver()
        self._initialize_(do_connect)
        if do_connect:
            self.reconnect()

    def _load_dependencies(self):
        """Resolve the dialect, parser and representer for this adapter."""
        # Imported here rather than at module level.
        from ..dialects import dialects
        from ..parsers import parsers
        from ..representers import representers
        self.dialect = dialects.get_for(self)
        self.parser = parsers.get_for(self)
        self.representer = representers.get_for(self)

    def _initialize_(self, do_connect):
        """Post-constructor hook; the base version only finds the work
        folder. ``do_connect`` is not used here."""
        self._find_work_folder()

    @property
    def types(self):
        """The type mapping exposed by the dialect."""
        return self.dialect.types

    @property
    def _available_drivers(self):
        """Drivers listed in ``self.drivers`` that ``db`` reports as
        available, in ``self.drivers`` order."""
        return [
            driver for driver in self.drivers
            if driver in iterkeys(self.db._drivers_available)]

    def _driver_from_uri(self):
        """Return the driver named in the URI scheme (``engine:driver://``)
        or None when the scheme has no second component or there is no
        URI."""
        rv = None
        if self.uri:
            items = self.uri.split('://', 1)[0].split(':')
            rv = items[1] if len(items) > 1 else None
        return rv

    def find_driver(self):
        """Set ``self.driver`` and ``self.driver_name``.

        Does nothing when ``self.driver`` is already set. A driver requested
        through the URI or ``adapter_args['driver']`` takes precedence over
        the first available one.

        Raises:
            RuntimeError: the requested driver is unavailable, or no
                supported driver is available at all.
        """
        if getattr(self, 'driver', None) is not None:
            return
        requested_driver = self._driver_from_uri() or \
            self.adapter_args.get('driver')
        if requested_driver:
            if requested_driver in self._available_drivers:
                self.driver_name = requested_driver
                self.driver = self.db._drivers_available[requested_driver]
            else:
                raise RuntimeError(
                    'Driver %s is not available' % requested_driver)
        elif self._available_drivers:
            self.driver_name = self._available_drivers[0]
            self.driver = self.db._drivers_available[self.driver_name]
        else:
            raise RuntimeError(
                "No driver of supported ones %s is available" %
                str(self.drivers))

    def connector(self):
        """Open and return a new driver connection."""
        # NOTE(review): driver_args is passed as a single positional
        # argument, not unpacked - confirm this is what drivers expect.
        return self.driver.connect(self.driver_args)

    def test_connection(self):
        """Connection check hook; no-op in the base class."""
        pass

    @with_connection
    def close_connection(self):
        """Close the current connection and forget it; returns whatever
        the driver's ``close()`` returns."""
        rv = self.connection.close()
        self.connection = None
        return rv

    def tables(self, *queries):
        """Collect a ``{tablename: table}`` map of every table referenced by
        the given fields/expressions/queries (recursing into operands).

        Raises:
            ValueError: two different tables share the same name.
        """
        tables = dict()
        for query in queries:
            if isinstance(query, Field):
                key = query.tablename
                if tables.get(key, query.table) is not query.table:
                    raise ValueError('Name conflict in table list: %s' % key)
                tables[key] = query.table
            elif isinstance(query, (Expression, Query)):
                tmp = [x for x in (query.first, query.second) if x is not None]
                tables = merge_tablemaps(tables, self.tables(*tmp))
        return tables

    def get_table(self, *queries):
        """Return the single table referenced by ``queries``.

        Raises:
            RuntimeError: zero or more than one table is referenced.
        """
        tablemap = self.tables(*queries)
        if len(tablemap) == 1:
            return tablemap.popitem()[1]
        elif len(tablemap) < 1:
            raise RuntimeError("No table selected")
        else:
            raise RuntimeError(
                "Too many tables selected (%s)" % str(list(tablemap)))

    def common_filter(self, query, tablist):
        """AND onto ``query`` each table's common filter and, where the
        table has the tenant field with a non-None default, the
        multi-tenant condition. ``tablist`` items may be tables or table
        names."""
        tenant_fieldname = self.db._request_tenant
        for table in tablist:
            if isinstance(table, basestring):
                table = self.db[table]
            # deal with user provided filters
            if table._common_filter is not None:
                query = query & table._common_filter(query)
            # deal with multi_tenant filters
            if tenant_fieldname in table:
                default = table[tenant_fieldname].default
                if default is not None:
                    newquery = table[tenant_fieldname] == default
                    if query is None:
                        query = newquery
                    else:
                        query = query & newquery
        return query

    def _expand(self, expression, field_type=None, colnames=False,
                query_env={}):
        """Base expansion: just the string form of ``expression``."""
        return str(expression)

    def expand_all(self, fields, tabledict):
        """Normalise a select field list.

        ``SQLALL`` items are replaced by all fields of their table,
        ``'table.field'`` strings by the matching Field, other strings by a
        raw Expression. When the result is empty, every field of every table
        in ``tabledict`` is used.
        """
        new_fields = []
        append = new_fields.append
        for item in fields:
            if isinstance(item, SQLALL):
                # in-place extend, so ``append`` keeps pointing at the list
                new_fields += item._table
            elif isinstance(item, str):
                m = REGEX_TABLE_DOT_FIELD.match(item)
                if m:
                    tablename, fieldname = m.groups()
                    append(self.db[tablename][fieldname])
                else:
                    # bind ``item`` now to avoid the late-binding closure trap
                    append(Expression(self.db, lambda item=item: item))
            else:
                append(item)
        # ## if no fields specified take them all from the requested tables
        if not new_fields:
            for table in tabledict.values():
                for field in table:
                    append(field)
        return new_fields

    def parse_value(self, value, field_itype, field_type, blob_decode=True):
        """Convert a raw driver value into its Python representation.

        Custom types go through their own decoder; None values, non-string
        field types and (when ``blob_decode`` is false) blobs are returned
        without going through the parser.
        """
        # [Note - gi0baro] I think next if block can be (should be?) avoided
        if field_type != 'blob' and isinstance(value, str):
            try:
                value = value.decode(self.db._db_codec)
            except Exception:
                # Deliberately best-effort: e.g. Python 3 ``str`` has no
                # ``decode`` and the value is then used unchanged.
                pass
        if PY2 and isinstance(value, unicode):
            value = value.encode('utf-8')
        if isinstance(field_type, SQLCustomType):
            value = field_type.decoder(value)
        if not isinstance(field_type, str) or value is None:
            return value
        elif field_type == 'blob' and not blob_decode:
            return value
        else:
            return self.parser.parse(value, field_itype, field_type)

    def _add_operators_to_parsed_row(self, rid, table, row):
        """Attach the db's record operators to ``row`` and, when lazy tables
        are enabled, the lazy reference getter."""
        for key, record_operator in iteritems(self.db.record_operators):
            setattr(row, key, record_operator(row, table, rid))
        if table._db._lazy_tables:
            row['__get_lazy_reference__'] = LazyReferenceGetter(table, rid)

    def _add_reference_sets_to_parsed_row(self, rid, table, tablename, row):
        """Add a LazySet to ``row`` for each field referencing ``table``,
        keyed by the db's referee-name pattern; existing keys and a key
        equal to ``tablename`` are left alone."""
        for rfield in table._referenced_by:
            referee_link = self.db._referee_name and self.db._referee_name % \
                dict(table=rfield.tablename, field=rfield.name)
            if referee_link and referee_link not in row and \
                    referee_link != tablename:
                row[referee_link] = LazySet(rfield, rid)

    def _regex_select_as_parser(self, colname):
        """Search ``colname`` with ``REGEX_SELECT_AS_PARSER``; returns the
        match object or None."""
        return REGEX_SELECT_AS_PARSER.search(colname)

    def _parse(self, row, tmps, fields, colnames, blob_decode, cacheable,
               fields_virtual, fields_lazy):
        """Build one ``db.Row`` from a raw result row.

        ``tmps``, ``fields_virtual`` and ``fields_lazy`` are the values
        produced by ``_parse_expand_colnames``. Values of real columns are
        grouped by table name; everything else goes under ``'_extra'``.
        """
        new_row = defaultdict(self.db.Row)
        extras = self.db.Row()
        #: let's loop over columns
        for (j, colname) in enumerate(colnames):
            value = row[j]
            tmp = tmps[j]
            tablename = None
            #: do we have a real column?
            if tmp:
                (tablename, fieldname, table, field, ft, fit) = tmp
                colset = new_row[tablename]
                #: parse value
                value = self.parse_value(value, fit, ft, blob_decode)
                if field.filter_out:
                    value = field.filter_out(value)
                colset[fieldname] = value
                #! backward compatibility
                if ft == 'id' and fieldname != 'id' and \
                        'id' not in table.fields:
                    colset['id'] = value
                #: additional parsing for 'id' fields
                if ft == 'id' and not cacheable:
                    self._add_operators_to_parsed_row(value, table, colset)
                    self._add_reference_sets_to_parsed_row(
                        value, table, tablename, colset)
            #: otherwise we set the value in extras
            else:
                value = self.parse_value(
                    value, fields[j]._itype, fields[j].type, blob_decode)
                extras[colname] = value
                new_column_name = self._regex_select_as_parser(colname)
                if new_column_name is not None:
                    column_name = new_column_name.groups(0)
                    new_row[column_name[0]] = value
        #: add extras if needed (eg. operations results)
        if extras:
            new_row['_extra'] = extras
        #: add virtuals
        new_row = self.db.Row(**new_row)
        for tablename in fields_virtual.keys():
            for f, v in fields_virtual[tablename][1]:
                try:
                    new_row[tablename][f] = v.f(new_row)
                except (AttributeError, KeyError):
                    pass  # not enough fields to define virtual field
            for f, v in fields_lazy[tablename][1]:
                try:
                    new_row[tablename][f] = v.handler(v.f, new_row)
                except (AttributeError, KeyError):
                    pass  # not enough fields to define virtual field
        return new_row

    def _parse_expand_colnames(self, fieldlist):
        """
        - Expand a list of colnames into a list of
          (tablename, fieldname, table_obj, field_obj, field_type,
          field_itype), with None for items that are not Field instances
        - Create a list of table for virtual/lazy fields

        Returns the tuple (fields_virtual, fields_lazy, tmps).
        """
        fields_virtual = {}
        fields_lazy = {}
        tmps = []
        for field in fieldlist:
            if not isinstance(field, Field):
                tmps.append(None)
                continue
            table = field.table
            tablename, fieldname = table._tablename, field.name
            ft = field.type
            fit = field._itype
            tmps.append((tablename, fieldname, table, field, ft, fit))
            if tablename not in fields_virtual:
                fields_virtual[tablename] = (table, [
                    (f.name, f) for f in table._virtual_fields
                ])
                fields_lazy[tablename] = (table, [
                    (f.name, f) for f in table._virtual_methods
                ])
        return (fields_virtual, fields_lazy, tmps)

    def parse(self, rows, fields, colnames, blob_decode=True, cacheable=False):
        """Parse all raw ``rows`` into a ``db.Rows`` object and apply the
        old style virtual fields of the involved tables."""
        (fields_virtual, fields_lazy, tmps) = \
            self._parse_expand_colnames(fields)
        new_rows = [
            self._parse(
                row, tmps, fields, colnames, blob_decode, cacheable,
                fields_virtual, fields_lazy)
            for row in rows
        ]
        rowsobj = self.db.Rows(self.db, new_rows, colnames, rawrows=rows,
                               fields=fields)
        # Old style virtual fields
        for tablename, tmp in fields_virtual.items():
            table = tmp[0]
            # ## old style virtual fields
            for item in table.virtualfields:
                try:
                    rowsobj = rowsobj.setvirtualfields(**{tablename: item})
                except (KeyError, AttributeError):
                    # to avoid breaking virtualfields when partial select
                    pass
        return rowsobj

    def iterparse(self, sql, fields, colnames, blob_decode=True,
                  cacheable=False):
        """
        Iterator to parse one row at a time.
        It doesn't support the old style virtual fields
        """
        return IterRows(self.db, sql, fields, colnames, blob_decode, cacheable)

    def adapt(self, value):
        """Base adaptation: the value unchanged."""
        return value

    def represent(self, obj, field_type):
        """Represent ``obj`` for ``field_type``; callables (see
        ``CALLABLETYPES``) are called first and their result is used."""
        if isinstance(obj, CALLABLETYPES):
            obj = obj()
        return self.representer.represent(obj, field_type)

    def _drop_table_cleanup(self, table):
        """Remove ``table`` from the db object, from its table list and from
        the references held by other tables."""
        del self.db[table._tablename]
        del self.db.tables[self.db.tables.index(table._tablename)]
        self.db._remove_references_to(table)

    def drop_table(self, table, mode=''):
        """Base drop: only the bookkeeping cleanup; ``mode`` is unused."""
        self._drop_table_cleanup(table)

    def rowslice(self, rows, minimum=0, maximum=None):
        """Base implementation returns ``rows`` untouched."""
        return rows

    def sqlsafe_table(self, tablename, original_tablename=None):
        """Base implementation returns ``tablename`` untouched."""
        return tablename

    def sqlsafe_field(self, fieldname):
        """Base implementation returns ``fieldname`` untouched."""
        return fieldname


class DebugHandler(ExecutionHandler):
    """Execution handler logging every command at debug level."""

    def before_execute(self, command):
        self.adapter.db.logger.debug('SQL: %s' % command)


class SQLAdapter(BaseAdapter):
    """Adapter base for SQL backends: statement building via the dialect,
    execution through the driver cursor, transactions and migrations."""

    commit_on_alter_table = False
    # [Note - gi0baro] can_select_for_update should be deprecated and removed
    can_select_for_update = True
    # Replaced per instance in __init__ with a copy of the db's handlers.
    execution_handlers = []
    migrator_cls = Migrator

    def __init__(self, *args, **kwargs):
        """Same arguments as ``BaseAdapter``; additionally builds the
        migrator (overridable through ``adapter_args['migrator']``) and the
        per-instance execution handler list."""
        super(SQLAdapter, self).__init__(*args, **kwargs)
        migrator_cls = self.adapter_args.get('migrator', self.migrator_cls)
        self.migrator = migrator_cls(self)
        self.execution_handlers = list(self.db.execution_handlers)
        if self.db._debug:
            self.execution_handlers.insert(0, DebugHandler)

    def test_connection(self):
        """Run a trivial statement to check the connection."""
        self.execute('SELECT 1;')

    def represent(self, obj, field_type):
        """Expressions and fields are represented by their string form;
        anything else is delegated to ``BaseAdapter.represent``."""
        if isinstance(obj, (Expression, Field)):
            return str(obj)
        return super(SQLAdapter, self).represent(obj, field_type)

    def adapt(self, obj):
        """Wrap a string in single quotes, doubling embedded quotes."""
        return "'%s'" % obj.replace("'", "''")

    def smart_adapt(self, obj):
        """Numbers as plain text, anything else quoted via ``adapt``."""
        if isinstance(obj, (int, float)):
            return str(obj)
        return self.adapt(str(obj))

    def fetchall(self):
        return self.cursor.fetchall()

    def fetchone(self):
        return self.cursor.fetchone()

    def _build_handlers_for_execution(self):
        """Instantiate every registered execution handler class."""
        return [
            handler_class(self) for handler_class in self.execution_handlers]

    def filter_sql_command(self, command):
        """Hook to rewrite a command before execution; identity here."""
        return command

    @with_connection_or_raise
    def execute(self, *args, **kwargs):
        """Execute ``args[0]`` (after ``filter_sql_command``) on the cursor,
        calling the handlers' ``before_execute``/``after_execute`` around
        it. Remaining arguments are forwarded to the cursor."""
        command = self.filter_sql_command(args[0])
        handlers = self._build_handlers_for_execution()
        for handler in handlers:
            handler.before_execute(command)
        rv = self.cursor.execute(command, *args[1:], **kwargs)
        for handler in handlers:
            handler.after_execute(command)
        return rv

    def _expand(self, expression, field_type=None, colnames=False,
                query_env={}):
        """Render a field, expression, query or plain value as a string.

        Fields use ``longname`` when ``colnames`` is true, ``sqlsafe``
        otherwise, and are cast to text when a 'string' is requested from a
        field whose type is not one of the text-like types listed below.
        """
        if isinstance(expression, Field):
            if not colnames:
                rv = expression.sqlsafe
            else:
                rv = expression.longname
            if field_type == 'string' and expression.type not in (
                    'string', 'text', 'json', 'jsonb', 'password'):
                rv = self.dialect.cast(rv, self.types['text'], query_env)
        elif isinstance(expression, (Expression, Query)):
            first = expression.first
            second = expression.second
            op = expression.op
            # NOTE(review): when optional_args is a non-empty dict it is the
            # expression's own dict that receives 'query_env' here.
            optional_args = expression.optional_args or {}
            optional_args['query_env'] = query_env
            if second is not None:
                rv = op(first, second, **optional_args)
            elif first is not None:
                rv = op(first, **optional_args)
            elif isinstance(op, str):
                # raw string: drop one trailing ';' and parenthesise
                if op.endswith(';'):
                    op = op[:-1]
                rv = '(%s)' % op
            else:
                rv = op()
        elif field_type:
            rv = self.represent(expression, field_type)
        elif isinstance(expression, (list, tuple)):
            rv = ','.join(self.represent(item, field_type)
                          for item in expression)
        elif isinstance(expression, bool):
            rv = self.dialect.true_exp if expression else \
                self.dialect.false_exp
        else:
            rv = expression
        return str(rv)

    def _expand_for_index(self, expression, field_type=None, colnames=False,
                          query_env={}):
        """Like ``_expand`` but fields are rendered as their ``_rname``."""
        if isinstance(expression, Field):
            return expression._rname
        return self._expand(expression, field_type, colnames, query_env)

    @contextmanager
    def index_expander(self):
        """Context manager switching ``self.expand`` to the index variant
        for the duration of the ``with`` block."""
        self.expand = self._expand_for_index
        try:
            yield
        finally:
            # restore even if the body raised, otherwise every later
            # statement would be expanded with the index expander
            self.expand = self._expand

    def lastrowid(self, table):
        return self.cursor.lastrowid

    def _insert(self, table, fields):
        """Build the INSERT statement for ``fields``, a sequence of
        (field, value) pairs; an empty sequence yields the dialect's
        empty insert."""
        if fields:
            return self.dialect.insert(
                table._rname,
                ','.join(el[0]._rname for el in fields),
                ','.join(self.expand(v, f.type) for f, v in fields))
        return self.dialect.insert_empty(table._rname)

    def insert(self, table, fields):
        """Insert a record and return its identifier.

        Returns the dict of primary key values found in ``fields`` for keyed
        tables, otherwise the last row id: wrapped in a dict for tables with
        a single ``_primarykey``, as a ``Reference`` when it is an integer,
        or unchanged. On execution errors the table's ``_on_insert_error``
        hook, when present, decides the result; otherwise the error is
        re-raised.
        """
        query = self._insert(table, fields)
        try:
            self.execute(query)
        except Exception as e:
            if hasattr(table, '_on_insert_error'):
                return table._on_insert_error(table, fields, e)
            raise
        if hasattr(table, '_primarykey'):
            pkdict = dict(
                (k[0].name, k[1]) for k in fields
                if k[0].name in table._primarykey)
            if pkdict:
                return pkdict
        new_id = self.lastrowid(table)
        if hasattr(table, '_primarykey') and len(table._primarykey) == 1:
            new_id = {table._primarykey[0]: new_id}
        if not isinstance(new_id, integer_types):
            return new_id
        rid = Reference(new_id)
        (rid._table, rid._record) = (table, None)
        return rid

    def _update(self, table, query, fields):
        """Build the UPDATE statement for (field, value) pairs, applying
        common filters to ``query`` when they are enabled for it."""
        sql_q = ''
        query_env = dict(current_scope=[table._tablename])
        if query:
            if use_common_filters(query):
                query = self.common_filter(query, [table])
            sql_q = self.expand(query, query_env=query_env)
        sql_v = ','.join([
            '%s=%s' % (field._rname,
                       self.expand(value, field.type, query_env=query_env))
            for (field, value) in fields])
        return self.dialect.update(table, sql_v, sql_q)

    def update(self, table, query, fields):
        """Run an UPDATE and return the cursor's row count (None when it
        cannot be read). On execution errors the table's
        ``_on_update_error`` hook, when present, decides the result;
        otherwise the error is re-raised."""
        sql = self._update(table, query, fields)
        try:
            self.execute(sql)
        except Exception as e:
            if hasattr(table, '_on_update_error'):
                return table._on_update_error(table, query, fields, e)
            raise
        try:
            return self.cursor.rowcount
        except Exception:
            return None

    def _delete(self, table, query):
        """Build the DELETE statement, applying common filters to ``query``
        when they are enabled for it."""
        sql_q = ''
        query_env = dict(current_scope=[table._tablename])
        if query:
            if use_common_filters(query):
                query = self.common_filter(query, [table])
            sql_q = self.expand(query, query_env=query_env)
        return self.dialect.delete(table, sql_q)

    def delete(self, table, query):
        """Run a DELETE and return the cursor's row count (None when it
        cannot be read)."""
        sql = self._delete(table, query)
        self.execute(sql)
        try:
            return self.cursor.rowcount
        except Exception:
            return None

    def _colexpand(self, field, query_env):
        """Expand ``field`` in column-name mode."""
        return self.expand(field, colnames=True, query_env=query_env)

    def _geoexpand(self, field, query_env):
        """Expand ``field``; Field objects whose type starts with 'geo' are
        first converted with ``st_astext()``."""
        if isinstance(field.type, str) and field.type.startswith('geo') and \
                isinstance(field, Field):
            field = field.st_astext()
        return self.expand(field, query_env=query_env)

    def _build_joins_for_select(self, tablenames, param):
        """Analyse a ``join``/``left`` select argument.

        ``param`` is a table, an ON expression, or a list/tuple of them.

        Returns the tuple (join_tables, join_on, tables_to_merge,
        join_on_tables, important_tablenames, excluded, tablemap).

        Raises:
            ValueError: two different tables share the same name.
        """
        if not isinstance(param, (tuple, list)):
            param = [param]
        tablemap = {}
        for item in param:
            if isinstance(item, Expression):
                item = item.first
            key = item._tablename
            if tablemap.get(key, item) is not item:
                raise ValueError('Name conflict in table list: %s' % key)
            tablemap[key] = item
        join_tables = [
            t._tablename for t in param if not isinstance(t, Expression)
        ]
        join_on = [t for t in param if isinstance(t, Expression)]
        tables_to_merge = {}
        for t in join_on:
            tables_to_merge = merge_tablemaps(tables_to_merge, self.tables(t))
        join_on_tables = [t.first._tablename for t in join_on]
        for t in join_on_tables:
            if t in tables_to_merge:
                tables_to_merge.pop(t)
        important_tablenames = join_tables + join_on_tables + \
            list(tables_to_merge)
        excluded = [
            t for t in tablenames if t not in important_tablenames
        ]
        return (
            join_tables, join_on, tables_to_merge, join_on_tables,
            important_tablenames, excluded, tablemap
        )

    def _select_wcols(self, query, fields, left=False, join=False,
                      distinct=False, orderby=False, groupby=False,
                      having=False, limitby=False, orderby_on_limitby=True,
                      for_update=False, outer_scoped=[], required=None,
                      cache=None, cacheable=None, processor=None):
        """Build a SELECT statement.

        Returns the tuple (colnames, sql). ``required``, ``cache``,
        ``cacheable`` and ``processor`` are accepted so that the select
        attributes can be passed as keywords, but are not used here.

        Raises:
            SyntaxError: no table is selected, or ``for_update`` is
                requested on an adapter that does not support it.
        """
        #: parse tablemap
        tablemap = self.tables(query)
        #: apply common filters if needed
        if use_common_filters(query):
            query = self.common_filter(query, list(tablemap.values()))
        #: auto-adjust tables
        tablemap = merge_tablemaps(tablemap, self.tables(*fields))
        #: remove outer scoped tables if needed
        for item in outer_scoped:
            # FIXME: check for name conflicts
            tablemap.pop(item, None)
        if len(tablemap) < 1:
            raise SyntaxError('Set: no tables selected')
        query_tables = list(tablemap)
        #: check for_update argument
        # [Note - gi0baro] I think this should be removed since useless?
        #                  should affect only NoSQL?
        if self.can_select_for_update is False and for_update is True:
            raise SyntaxError('invalid select attribute: for_update')
        #: build joins (inner, left outer) and table names
        if join:
            (
                # FIXME? ijoin_tables is never used
                ijoin_tables, ijoin_on, itables_to_merge, ijoin_on_tables,
                iimportant_tablenames, iexcluded, itablemap
            ) = self._build_joins_for_select(tablemap, join)
            tablemap = merge_tablemaps(tablemap, itables_to_merge)
            tablemap = merge_tablemaps(tablemap, itablemap)
        if left:
            (
                join_tables, join_on, tables_to_merge, join_on_tables,
                important_tablenames, excluded, jtablemap
            ) = self._build_joins_for_select(tablemap, left)
            tablemap = merge_tablemaps(tablemap, tables_to_merge)
            tablemap = merge_tablemaps(tablemap, jtablemap)
        current_scope = outer_scoped + list(tablemap)
        query_env = dict(current_scope=current_scope,
                         parent_scope=outer_scoped)
        #: prepare columns and expand fields
        colnames = [self._colexpand(x, query_env) for x in fields]
        sql_fields = ', '.join(self._geoexpand(x, query_env) for x in fields)
        table_alias = lambda name: tablemap[name].query_name(outer_scoped)[0]
        if join and not left:
            cross_joins = iexcluded + list(itables_to_merge)
            tokens = [table_alias(cross_joins[0])]
            tokens += [self.dialect.cross_join(table_alias(t), query_env)
                       for t in cross_joins[1:]]
            tokens += [self.dialect.join(t, query_env) for t in ijoin_on]
            sql_t = ' '.join(tokens)
        elif not join and left:
            cross_joins = excluded + list(tables_to_merge)
            tokens = [table_alias(cross_joins[0])]
            tokens += [self.dialect.cross_join(table_alias(t), query_env)
                       for t in cross_joins[1:]]
            # FIXME: WTF? This is not correct syntax at least on PostgreSQL
            if join_tables:
                tokens.append(self.dialect.left_join(','.join([table_alias(t)
                              for t in join_tables]), query_env))
            tokens += [self.dialect.left_join(t, query_env) for t in join_on]
            sql_t = ' '.join(tokens)
        elif join and left:
            all_tables_in_query = set(
                important_tablenames + iimportant_tablenames + query_tables)
            tables_in_joinon = set(join_on_tables + ijoin_on_tables)
            tables_not_in_joinon = \
                list(all_tables_in_query.difference(tables_in_joinon))
            tokens = [table_alias(tables_not_in_joinon[0])]
            tokens += [self.dialect.cross_join(table_alias(t), query_env)
                       for t in tables_not_in_joinon[1:]]
            tokens += [self.dialect.join(t, query_env) for t in ijoin_on]
            # FIXME: WTF? This is not correct syntax at least on PostgreSQL
            if join_tables:
                tokens.append(self.dialect.left_join(','.join([table_alias(t)
                              for t in join_tables]), query_env))
            tokens += [self.dialect.left_join(t, query_env) for t in join_on]
            sql_t = ' '.join(tokens)
        else:
            sql_t = ', '.join(table_alias(t) for t in query_tables)
        #: expand query if needed
        if query:
            query = self.expand(query, query_env=query_env)
        if having:
            having = self.expand(having, query_env=query_env)
        #: groupby
        sql_grp = groupby
        if groupby:
            if isinstance(groupby, (list, tuple)):
                groupby = xorify(groupby)
            sql_grp = self.expand(groupby, query_env=query_env)
        #: orderby
        sql_ord = False
        if orderby:
            if isinstance(orderby, (list, tuple)):
                orderby = xorify(orderby)
            if str(orderby) == '<random>':
                sql_ord = self.dialect.random
            else:
                sql_ord = self.expand(orderby, query_env=query_env)
        #: set default orderby if missing
        if (limitby and not groupby and query_tables and orderby_on_limitby and
                not orderby):
            sql_ord = ', '.join([
                tablemap[t][x].sqlsafe
                for t in query_tables if not isinstance(tablemap[t], Select)
                for x in (hasattr(tablemap[t], '_primarykey') and
                          tablemap[t]._primarykey or ['_id'])
            ])
        #: build sql using dialect
        return colnames, self.dialect.select(
            sql_fields, sql_t, query, sql_grp, having, sql_ord, limitby,
            distinct, for_update and self.can_select_for_update
        )

    def _select(self, query, fields, attributes):
        """Return only the SQL text of the select."""
        return self._select_wcols(query, fields, **attributes)[1]

    def nested_select(self, query, fields, attributes):
        """Return a ``Select`` object wrapping the given select."""
        return Select(self.db, query, fields, attributes)

    def _select_aux_execute(self, sql):
        """Execute ``sql`` and fetch every row."""
        self.execute(sql)
        return self.cursor.fetchall()

    def _select_aux(self, sql, fields, attributes, colnames):
        """Execute a built select (through the raw-rows cache when
        ``attributes['cache']`` is set) and hand the rows to the processor.

        ``cache`` is either a ``(cache_model, time_expire)`` pair or a dict
        with 'model', 'expiration' and an optional 'key'.
        """
        cache = attributes.get('cache', None)
        if not cache:
            rows = self._select_aux_execute(sql)
        else:
            key = None
            if isinstance(cache, dict):
                cache_model = cache['model']
                time_expire = cache['expiration']
                key = cache.get('key')
            else:
                (cache_model, time_expire) = cache
            if not key:
                # default key: md5 of uri + statement
                key = self.uri + '/' + sql + '/rows'
                key = hashlib_md5(key).hexdigest()
            rows = cache_model(
                key,
                lambda self=self, sql=sql: self._select_aux_execute(sql),
                time_expire)
        if isinstance(rows, tuple):
            rows = list(rows)
        limitby = attributes.get('limitby', None) or (0,)
        rows = self.rowslice(rows, limitby[0], None)
        processor = attributes.get('processor', self.parse)
        cacheable = attributes.get('cacheable', False)
        return processor(rows, fields, colnames, cacheable=cacheable)

    def _cached_select(self, cache, sql, fields, attributes, colnames):
        """Cache the fully parsed result of a select.

        ``cache`` is unpacked as a ``(cache_model, time_expire)`` pair.
        Note that the 'cache' key is removed from the caller's
        ``attributes`` dict.
        """
        del attributes['cache']
        (cache_model, time_expire) = cache
        key = self.uri + '/' + sql
        key = hashlib_md5(key).hexdigest()
        args = (sql, fields, attributes, colnames)
        ret = cache_model(
            key,
            lambda self=self, args=args: self._select_aux(*args),
            time_expire)
        ret._restore_fields(fields)
        return ret

    def select(self, query, fields, attributes):
        """Build and run a select; the parsed result is cached when both
        'cache' and 'cacheable' attributes are set."""
        colnames, sql = self._select_wcols(query, fields, **attributes)
        cache = attributes.get('cache', None)
        if cache and attributes.get('cacheable', False):
            return self._cached_select(
                cache, sql, fields, attributes, colnames)
        return self._select_aux(sql, fields, attributes, colnames)

    def iterselect(self, query, fields, attributes):
        """Build a select and return a row-at-a-time iterator for it."""
        colnames, sql = self._select_wcols(query, fields, **attributes)
        cacheable = attributes.get('cacheable', False)
        return self.iterparse(sql, fields, colnames, cacheable=cacheable)

    def _count(self, query, distinct=None):
        """Build the COUNT statement for ``query``; ``distinct`` may be an
        expression or a list/tuple of them."""
        tablemap = self.tables(query)
        tablenames = list(tablemap)
        tables = list(tablemap.values())
        query_env = dict(current_scope=tablenames)
        sql_q = ''
        if query:
            if use_common_filters(query):
                query = self.common_filter(query, tables)
            sql_q = self.expand(query, query_env=query_env)
        sql_t = ','.join(self.table_alias(t, []) for t in tables)
        sql_fields = '*'
        if distinct:
            if isinstance(distinct, (list, tuple)):
                distinct = xorify(distinct)
            sql_fields = self.expand(distinct, query_env=query_env)
        return self.dialect.select(
            self.dialect.count(sql_fields, distinct), sql_t, sql_q
        )

    def count(self, query, distinct=None):
        """Run the COUNT statement and return the first column of the
        first row."""
        self.execute(self._count(query, distinct))
        return self.cursor.fetchone()[0]

    def bulk_insert(self, table, items):
        """Insert the items one by one; returns the list of ``insert``
        results."""
        return [self.insert(table, item) for item in items]

    def create_table(self, *args, **kwargs):
        """Delegate to the migrator."""
        return self.migrator.create_table(*args, **kwargs)

    def _drop_table_cleanup(self, table):
        """Base cleanup, then delete the table's ``_dbt`` file and log the
        success when the table has one."""
        super(SQLAdapter, self)._drop_table_cleanup(table)
        if table._dbt:
            self.migrator.file_delete(table._dbt)
            self.migrator.log('success!\n', table)

    def drop_table(self, table, mode=''):
        """Execute the dialect's drop statements, commit and clean up."""
        queries = self.dialect.drop_table(table, mode)
        for query in queries:
            if table._dbt:
                self.migrator.log(query + '\n', table)
            self.execute(query)
        self.commit()
        self._drop_table_cleanup(table)

    @deprecated('drop', 'drop_table', 'SQLAdapter')
    def drop(self, table, mode=''):
        """Deprecated alias of ``drop_table``."""
        # forward the caller's mode instead of silently resetting it to ''
        return self.drop_table(table, mode=mode)

    def truncate(self, table, mode=''):
        """Execute and log the dialect's truncate statements."""
        queries = self.dialect.truncate(table, mode)
        for query in queries:
            self.migrator.log(query + '\n', table)
            self.execute(query)
        self.migrator.log('success!\n', table)

    def create_index(self, table, index_name, *fields, **kwargs):
        """Create an index over ``fields`` (Field objects or raw
        expressions) and commit.

        Raises:
            RuntimeError: the statement failed; the transaction is rolled
                back first.
        """
        expressions = [
            field._rname if isinstance(field, Field) else field
            for field in fields]
        sql = self.dialect.create_index(
            index_name, table, expressions, **kwargs)
        try:
            self.execute(sql)
            self.commit()
        except Exception as e:
            self.rollback()
            err = 'Error creating index %s\n Driver error: %s\n' + \
                ' SQL instruction: %s'
            raise RuntimeError(err % (index_name, str(e), sql))
        return True

    def drop_index(self, table, index_name):
        """Drop an index and commit.

        Raises:
            RuntimeError: the statement failed; the transaction is rolled
                back first.
        """
        sql = self.dialect.drop_index(index_name, table)
        try:
            self.execute(sql)
            self.commit()
        except Exception as e:
            self.rollback()
            err = 'Error dropping index %s\n Driver error: %s'
            raise RuntimeError(err % (index_name, str(e)))
        return True

    def distributed_transaction_begin(self, key):
        pass

    @with_connection
    def commit(self):
        return self.connection.commit()

    @with_connection
    def rollback(self):
        return self.connection.rollback()

    # NOTE(review): the three *_prepared helpers below ignore ``key`` in
    # this base class.
    @with_connection
    def prepare(self, key):
        self.connection.prepare()

    @with_connection
    def commit_prepared(self, key):
        self.connection.commit()

    @with_connection
    def rollback_prepared(self, key):
        self.connection.rollback()

    def create_sequence_and_triggers(self, query, table, **args):
        """Base implementation just executes ``query``."""
        self.execute(query)

    def sqlsafe_table(self, tablename, original_tablename=None):
        """Quoted table name, or the dialect's alias form when the original
        table name is given."""
        if original_tablename is not None:
            return self.dialect.alias(original_tablename, tablename)
        return self.dialect.quote(tablename)

    def sqlsafe_field(self, fieldname):
        return self.dialect.quote(fieldname)

    def table_alias(self, tbl, current_scope=[]):
        """First element of the table's ``query_name`` for the given scope;
        ``tbl`` may be a table or a table name."""
        if isinstance(tbl, basestring):
            tbl = self.db[tbl]
        return tbl.query_name(current_scope)[0]

    def id_query(self, table):
        """Return the ``!= None`` comparison on the first primary key of a
        keyed table, or on ``_id`` otherwise."""
        # ``!= None`` is intentional: the comparison result is returned,
        # so it must not be rewritten as ``is not None``.
        pkeys = getattr(table, '_primarykey', None)
        if pkeys:
            return table[pkeys[0]] != None
        return table._id != None


class NoSQLAdapter(BaseAdapter):
    """Adapter base for NoSQL backends: transactions are no-ops and nested
    selects are rejected."""

    can_select_for_update = False

    def commit(self):
        pass

    def rollback(self):
        pass

    def prepare(self, key=None):
        # ``key`` added (optional) for parity with SQLAdapter.prepare and
        # the sibling commit_prepared/rollback_prepared signatures.
        pass

    def commit_prepared(self, key):
        pass

    def rollback_prepared(self, key):
        pass

    def id_query(self, table):
        return table._id > 0

    def create_table(self, table, migrate=True, fake_migrate=False,
                     polymodel=None):
        """Only record on the table which fields are notnull/unique;
        ``migrate``, ``fake_migrate`` and ``polymodel`` are unused here."""
        table._dbt = None
        table._notnulls = [
            field_name for field_name in table.fields
            if table[field_name].notnull]
        # this is unnecessary if the fields are indexed and unique
        table._uniques = [
            field_name for field_name in table.fields
            if table[field_name].unique]

    def drop_table(self, table, mode=''):
        """Drop the table's object on the connection, then clean up;
        ``mode`` is unused."""
        ctable = self.connection[table._tablename]
        ctable.drop()
        self._drop_table_cleanup(table)

    # NOTE(review): the third argument says 'SQLAdapter' although this is
    # NoSQLAdapter - looks like a copy/paste slip, confirm against
    # ``deprecated`` before changing the emitted message.
    @deprecated('drop', 'drop_table', 'SQLAdapter')
    def drop(self, table, mode=''):
        """Deprecated alias of ``drop_table``."""
        # forward the caller's mode instead of silently resetting it to ''
        return self.drop_table(table, mode=mode)

    def _select(self, *args, **kwargs):
        raise NotOnNOSQLError(
            "Nested queries are not supported on NoSQL databases")

    def nested_select(self, *args, **kwargs):
        raise NotOnNOSQLError(
            "Nested queries are not supported on NoSQL databases")


class NullAdapter(BaseAdapter):
    """Adapter using only the common dialect and a ``NullDriver``
    connection."""

    def _load_dependencies(self):
        from ..dialects.base import CommonDialect
        self.dialect = CommonDialect(self)

    def find_driver(self):
        pass

    def connector(self):
        return NullDriver()