SP/web2py/gluon/packages/dal/pydal/adapters/mongo.py

951 lines
37 KiB
Python
Raw Permalink Normal View History

2018-10-25 15:33:07 +00:00
import copy
import random
from datetime import datetime
from .._compat import basestring, long
from ..exceptions import NotOnNOSQLError
from ..helpers.classes import (
FakeCursor, Reference, SQLALL, ConnectionConfigurationMixin)
from ..helpers.methods import use_common_filters, xorify
from ..objects import Field, Row, Query, Expression
from .base import NoSQLAdapter
from . import adapters
try:
from bson import Binary
from bson.binary import USER_DEFINED_SUBTYPE
except:
class Binary(object):
pass
USER_DEFINED_SUBTYPE = 0
@adapters.register_for('mongodb')
class Mongo(ConnectionConfigurationMixin, NoSQLAdapter):
dbengine = 'mongodb'
drivers = ('pymongo',)
def find_driver(self):
super(Mongo, self).find_driver()
#: ensure pymongo version >= 3.0
if 'fake_version' in self.driver_args:
version = self.driver_args['fake_version']
else:
from pymongo import version
if int(version.split('.')[0]) < 3:
raise RuntimeError(
"pydal requires pymongo version >= 3.0, found '%s'" % version)
def _initialize_(self, do_connect):
super(Mongo, self)._initialize_(do_connect)
#: uri parse
from pymongo import uri_parser
m = uri_parser.parse_uri(self.uri)
if isinstance(m, tuple):
m = {"database": m[1]}
if m.get('database') is None:
raise SyntaxError("Database is required!")
self._driver_db = m['database']
#: mongodb imports and utils
from bson.objectid import ObjectId
from bson.son import SON
from pymongo.write_concern import WriteConcern
self.epoch = datetime.fromtimestamp(0)
self.SON = SON
self.ObjectId = ObjectId
self.WriteConcern = WriteConcern
#: options
self.db_codec = 'UTF-8'
# this is the minimum amount of replicates that it should wait
# for on insert/update
self.minimumreplication = self.adapter_args.get(
'minimumreplication', 0)
# by default all inserts and selects are performed asynchronous,
# but now the default is
# synchronous, except when overruled by either this default or
# function parameter
self.safe = 1 if self.adapter_args.get('safe', True) else 0
self._mock_reconnect()
def connector(self):
conn = self.driver.MongoClient(self.uri, w=self.safe)[self._driver_db]
conn.cursor = lambda: FakeCursor()
conn.close = lambda: None
conn.commit = lambda: None
return conn
def _configure_on_first_reconnect(self):
#: server version
self._server_version = self.connection.command(
"serverStatus")['version']
self.server_version = tuple(
[int(x) for x in self._server_version.split('.')])
self.server_version_major = (
self.server_version[0] + self.server_version[1] / 10.0)
def object_id(self, arg=None):
""" Convert input to a valid Mongodb ObjectId instance
self.object_id("<random>") -> ObjectId (not unique) instance """
if not arg:
arg = 0
if isinstance(arg, basestring):
# we assume an integer as default input
rawhex = len(arg.replace("0x", "").replace("L", "")) == 24
if arg.isdigit() and (not rawhex):
arg = int(arg)
elif arg == "<random>":
arg = int("0x%s" % "".join([
random.choice("0123456789abcdef")
for x in range(24)]), 0)
elif arg.isalnum():
if not arg.startswith("0x"):
arg = "0x%s" % arg
try:
arg = int(arg, 0)
except ValueError as e:
raise ValueError(
"invalid objectid argument string: %s" % e)
else:
raise ValueError("Invalid objectid argument string. " +
"Requires an integer or base 16 value")
elif isinstance(arg, self.ObjectId):
return arg
elif isinstance(arg, (Row, Reference)):
return self.object_id(long(arg['id']))
elif not isinstance(arg, (int, long)):
raise TypeError(
"object_id argument must be of type ObjectId or an objectid " +
"representable integer (type %s)" % type(arg))
hexvalue = hex(arg)[2:].rstrip('L').zfill(24)
return self.ObjectId(hexvalue)
def _get_collection(self, tablename, safe=None):
ctable = self.connection[tablename]
if safe is not None and safe != self.safe:
wc = self.WriteConcern(w=self._get_safe(safe))
ctable = ctable.with_options(write_concern=wc)
return ctable
def _get_safe(self, val=None):
if val is None:
return self.safe
return 1 if val else 0
def _regex_select_as_parser(self, colname):
return self.dialect.REGEX_SELECT_AS_PARSER.search(colname)
@staticmethod
def _parse_data(expression, attribute, value=None):
if isinstance(expression, (list, tuple)):
ret = False
for e in expression:
ret = Mongo._parse_data(e, attribute, value) or ret
return ret
if value is not None:
try:
expression._parse_data[attribute] = value
except AttributeError:
return None
try:
return expression._parse_data[attribute]
except (AttributeError, TypeError):
return None
def _expand(self, expression, field_type=None, query_env={}):
if isinstance(expression, Field):
if expression.type == 'id':
result = "_id"
else:
result = expression.name
if self._parse_data(expression, 'pipeline'):
# field names as part of expressions need to start with '$'
result = '$' + result
elif isinstance(expression, (Expression, Query)):
first = expression.first
second = expression.second
if isinstance(first, Field) and "reference" in first.type:
# cast to Mongo ObjectId
if isinstance(second, (tuple, list, set)):
second = [
self.object_id(item) for item in expression.second]
else:
second = self.object_id(expression.second)
op = expression.op
optional_args = expression.optional_args or {}
optional_args['query_env'] = query_env
if second is not None:
result = op(first, second, **optional_args)
elif first is not None:
result = op(first, **optional_args)
elif isinstance(op, str):
result = op
else:
result = op(**optional_args)
elif isinstance(expression, Expansion):
expression.query = (self.expand(expression.query, field_type,
query_env=query_env))
result = expression
elif isinstance(expression, (list, tuple)):
result = [self.represent(item, field_type) for item in expression]
elif field_type:
result = self.represent(expression, field_type)
else:
result = expression
return result
def represent(self, obj, field_type):
if isinstance(obj, self.ObjectId):
return obj
return super(Mongo, self).represent(obj, field_type)
def truncate(self, table, mode, safe=None):
ctable = self.connection[table._tablename]
ctable.delete_many({})
def count(self, query, distinct=None, snapshot=True):
if not isinstance(query, Query):
raise SyntaxError("Type '%s' not supported in count" % type(query))
distinct_fields = []
if distinct is True:
distinct_fields = [x for x in query.first.table if x.name != 'id']
elif distinct:
if isinstance(distinct, Field):
distinct_fields = [distinct]
else:
while (isinstance(distinct, Expression) and
isinstance(distinct.second, Field)):
distinct_fields += [distinct.second]
distinct = distinct.first
if isinstance(distinct, Field):
distinct_fields += [distinct]
distinct = True
expanded = Expansion(
self, 'count', query, fields=distinct_fields, distinct=distinct)
ctable = expanded.get_collection()
if not expanded.pipeline:
return ctable.count(filter=expanded.query_dict)
for record in ctable.aggregate(expanded.pipeline):
return record['count']
return 0
def select(self, query, fields, attributes, snapshot=False):
attributes['snapshot'] = snapshot
return self.__select(query, fields, **attributes)
def __select(self, query, fields, left=False, join=False, distinct=False,
orderby=False, groupby=False, having=False, limitby=False,
orderby_on_limitby=True, for_update=False, outer_scoped=[],
required=None, cache=None, cacheable=None, processor=None,
snapshot=False):
new_fields = []
for item in fields:
if isinstance(item, SQLALL):
new_fields += item._table
else:
new_fields.append(item)
fields = new_fields
tablename = self.get_table(query, *fields)._tablename
if for_update:
self.db.logger.warning(
"Attribute 'for_update' unsupported by MongoDB")
if join or left:
raise NotOnNOSQLError("Joins not supported on NoSQL databases")
if required or cache or cacheable:
self.db.logger.warning(
"Attributes 'required', 'cache' and 'cacheable' are" +
" unsupported by MongoDB")
if limitby and orderby_on_limitby and not orderby:
if groupby:
orderby = groupby
else:
table = self.db[tablename]
orderby = [
table[x] for x in (
hasattr(table, '_primarykey') and
table._primarykey or ['_id'])]
if not orderby:
mongosort_list = []
else:
if snapshot:
raise RuntimeError(
"snapshot and orderby are mutually exclusive")
if isinstance(orderby, (list, tuple)):
orderby = xorify(orderby)
if str(orderby) == '<random>':
# !!!! need to add 'random'
mongosort_list = self.dialect.random
else:
mongosort_list = []
for f in self.expand(orderby).split(','):
include = 1
if f.startswith('-'):
include = -1
f = f[1:]
if f.startswith('$'):
f = f[1:]
mongosort_list.append((f, include))
expanded = Expansion(
self, 'select', query, fields or self.db[tablename],
groupby=groupby, distinct=distinct, having=having)
ctable = self.connection[tablename]
modifiers = {'snapshot': snapshot}
if not expanded.pipeline:
if limitby:
limitby_skip, limitby_limit = limitby[0], int(limitby[1]) - 1
else:
limitby_skip = limitby_limit = 0
mongo_list_dicts = ctable.find(
expanded.query_dict, expanded.field_dicts, skip=limitby_skip,
limit=limitby_limit, sort=mongosort_list, modifiers=modifiers)
null_rows = []
else:
if mongosort_list:
sortby_dict = self.SON()
for f in mongosort_list:
sortby_dict[f[0]] = f[1]
expanded.pipeline.append({'$sort': sortby_dict})
if limitby and limitby[1]:
expanded.pipeline.append({'$limit': limitby[1]})
if limitby and limitby[0]:
expanded.pipeline.append({'$skip': limitby[0]})
mongo_list_dicts = ctable.aggregate(expanded.pipeline)
null_rows = [(None,)]
rows = []
# populate row in proper order
# Here we replace ._id with .id to follow the standard naming
colnames = []
newnames = []
for field in expanded.fields:
if hasattr(field, "tablename"):
if field.name in ('id', '_id'):
# Mongodb reserved uuid key
colname = (tablename + '.' + 'id', '_id')
else:
colname = (field.longname, field.name)
elif not isinstance(query, Expression):
colname = (field.name, field.name)
colnames.append(colname[1])
newnames.append(colname[0])
for record in mongo_list_dicts:
row = []
for colname in colnames:
try:
value = record[colname]
except:
value = None
if self.server_version_major < 2.6:
# '$size' not present in server versions < 2.6
if isinstance(value, list) and '$addToSet' in colname:
value = len(value)
row.append(value)
rows.append(row)
if not rows:
rows = null_rows
processor = processor or self.parse
result = processor(rows, fields, newnames, blob_decode=True)
return result
def check_notnull(self, table, values):
for fieldname in table._notnulls:
if fieldname not in values or values[fieldname] is None:
raise Exception("NOT NULL constraint failed: %s" % fieldname)
def check_unique(self, table, values):
if len(table._uniques) > 0:
db = table._db
unique_queries = []
for fieldname in table._uniques:
if fieldname in values:
value = values[fieldname]
else:
value = table[fieldname].default
unique_queries.append(
Query(db, self.dialect.eq, table[fieldname], value))
if len(unique_queries) > 0:
unique_query = unique_queries[0]
# if more than one field, build a query of ORs
for query in unique_queries[1:]:
unique_query = Query(
db, self.dialect._or, unique_query, query)
if self.count(unique_query, distinct=False) != 0:
for query in unique_queries:
if self.count(query, distinct=False) != 0:
# one of the 'OR' queries failed, see which one
raise Exception(
"NOT UNIQUE constraint failed: %s" %
query.first.name)
def insert(self, table, fields, safe=None):
"""Safe determines whether a asynchronous request is done or a
synchronous action is done
For safety, we use by default synchronous requests"""
values = {}
safe = self._get_safe(safe)
ctable = self._get_collection(table._tablename, safe)
for k, v in fields:
if k.name not in ["id", "safe"]:
fieldname = k.name
fieldtype = table[k.name].type
values[fieldname] = self.represent(v, fieldtype)
# validate notnulls
try:
self.check_notnull(table, values)
except Exception as e:
if hasattr(table, '_on_insert_error'):
return table._on_insert_error(table, fields, e)
raise e
# validate uniques
try:
self.check_unique(table, values)
except Exception as e:
if hasattr(table, '_on_insert_error'):
return table._on_insert_error(table, fields, e)
raise e
# perform the insert
result = ctable.insert_one(values)
if result.acknowledged:
Oid = result.inserted_id
rid = Reference(long(str(Oid), 16))
(rid._table, rid._record) = (table, None)
return rid
else:
return None
def update(self, table, query, fields, safe=None):
# return amount of adjusted rows or zero, but no exceptions
# @ related not finding the result
if not isinstance(query, Query):
raise RuntimeError("Not implemented")
safe = self._get_safe(safe)
if safe:
amount = 0
else:
amount = self.count(query, distinct=False)
if amount == 0:
return amount
expanded = Expansion(self, 'update', query, fields)
ctable = expanded.get_collection(safe)
if expanded.pipeline:
try:
for doc in ctable.aggregate(expanded.pipeline):
result = ctable.replace_one({'_id': doc['_id']}, doc)
if safe and result.acknowledged:
amount += result.matched_count
return amount
except Exception as e:
# TODO Reverse update query to verify that the query succeeded
raise RuntimeError(
"uncaught exception when updating rows: %s" % e)
try:
result = ctable.update_many(
filter=expanded.query_dict,
update={'$set': expanded.field_dicts})
if safe and result.acknowledged:
amount = result.matched_count
return amount
except Exception as e:
# TODO Reverse update query to verify that the query succeeded
raise RuntimeError(
"uncaught exception when updating rows: %s" % e)
def delete(self, table, query, safe=None):
if not isinstance(query, Query):
raise RuntimeError("query type %s is not supported" % type(query))
safe = self._get_safe(safe)
expanded = Expansion(self, 'delete', query)
ctable = expanded.get_collection(safe)
if expanded.pipeline:
deleted = [x['_id'] for x in ctable.aggregate(expanded.pipeline)]
else:
deleted = [x['_id'] for x in ctable.find(expanded.query_dict)]
# find references to deleted items
db = self.db
cascade = []
set_null = []
for field in table._referenced_by:
if field.type == 'reference ' + table._tablename:
if field.ondelete == 'CASCADE':
cascade.append(field)
if field.ondelete == 'SET NULL':
set_null.append(field)
cascade_list = []
set_null_list = []
for field in table._referenced_by_list:
if field.type == 'list:reference ' + table._tablename:
if field.ondelete == 'CASCADE':
cascade_list.append(field)
if field.ondelete == 'SET NULL':
set_null_list.append(field)
# perform delete
result = ctable.delete_many({"_id": {"$in": deleted}})
if result.acknowledged:
amount = result.deleted_count
else:
amount = len(deleted)
# clean up any references
if amount and deleted:
# ::TODO:: test if deleted references cascade
def remove_from_list(field, deleted, safe):
for delete in deleted:
modify = {field.name: delete}
dtable = self._get_collection(field.tablename, safe)
dtable.update_many(
filter=modify, update={'$pull': modify})
# for cascaded items, if the reference is the only item in the
# list, then remove the entire record, else delete reference
# from the list
for field in cascade_list:
for delete in deleted:
modify = {field.name: [delete]}
dtable = self._get_collection(field.tablename, safe)
dtable.delete_many(filter=modify)
remove_from_list(field, deleted, safe)
for field in set_null_list:
remove_from_list(field, deleted, safe)
for field in cascade:
db(field.belongs(deleted)).delete()
for field in set_null:
db(field.belongs(deleted)).update(**{field.name: None})
return amount
def bulk_insert(self, table, items):
return [self.insert(table, item) for item in items]
class Expansion(object):
"""
Class to encapsulate a pydal expression and track the parse
expansion and its results.
Two different MongoDB mechanisms are targeted here. If the query
is sufficiently simple, then simple queries are generated. The
bulk of the complexity here is however to support more complex
queries that are targeted to the MongoDB Aggregation Pipeline.
This class supports four operations: 'count', 'select', 'update'
and 'delete'.
Behavior varies somewhat for each operation type. However
building each pipeline stage is shared where the behavior is the
same (or similar) for the different operations.
In general an attempt is made to build the query without using the
pipeline, and if that fails then the query is rebuilt with the
pipeline.
QUERY constructed in _build_pipeline_query():
$project : used to calculate expressions if needed
$match: filters out records
FIELDS constructed in _expand_fields():
FIELDS:COUNT
$group : filter for distinct if needed
$group: count the records remaining
FIELDS:SELECT
$group : implement aggregations if needed
$project: implement expressions (etc) for select
FIELDS:UPDATE
$project: implement expressions (etc) for update
HAVING constructed in _add_having():
$project : used to calculate expressions
$match: filters out records
$project : used to filter out previous expression fields
"""
def __init__(self, adapter, crud, query, fields=(), tablename=None,
groupby=None, distinct=False, having=None):
self.adapter = adapter
self.NULL_QUERY = {'_id': {
'$gt': self.adapter.ObjectId('000000000000000000000000')}}
self._parse_data = {'pipeline': False, 'need_group':
bool(groupby or distinct or having)}
self.crud = crud
self.having = having
self.distinct = distinct
if not groupby and distinct:
if distinct is True:
# groupby gets all fields
self.groupby = fields
else:
self.groupby = distinct
else:
self.groupby = groupby
if crud == 'update':
self.values = [(f[0], self.annotate_expression(f[1]))
for f in (fields or [])]
self.fields = [f[0] for f in self.values]
else:
self.fields = [self.annotate_expression(f)
for f in (fields or [])]
self.tablename = (tablename or
adapter.get_table(query, *self.fields)._tablename)
if use_common_filters(query):
query = adapter.common_filter(query, [self.tablename])
self.query = self.annotate_expression(query)
# expand the query
self.pipeline = []
self.query_dict = adapter.expand(self.query)
self.field_dicts = adapter.SON()
self.field_groups = adapter.SON()
self.field_groups['_id'] = adapter.SON()
if self._parse_data['pipeline']:
# if the query needs the aggregation engine, set that up
self._build_pipeline_query()
# expand the fields for the aggregation engine
self._expand_fields(None)
else:
# expand the fields
try:
if not self._parse_data['need_group']:
self._expand_fields(self._fields_loop_abort)
else:
self._parse_data['pipeline'] = True
raise StopIteration
except StopIteration:
# if the fields needs the aggregation engine, set that up
self.field_dicts = adapter.SON()
if self.query_dict:
if self.query_dict != self.NULL_QUERY:
self.pipeline = [{'$match': self.query_dict}]
self.query_dict = {}
# expand the fields for the aggregation engine
self._expand_fields(None)
if not self._parse_data['pipeline']:
if crud == 'update':
# do not update id fields
for fieldname in ("_id", "id"):
if fieldname in self.field_dicts:
del self.field_dicts[fieldname]
else:
if crud == 'update':
self._add_all_fields_projection(self.field_dicts)
self.field_dicts = adapter.SON()
elif crud == 'select':
if self._parse_data['need_group']:
if not self.groupby:
# no groupby, aggregate all records
self.field_groups['_id'] = None
# id has no value after aggregations
self.field_dicts['_id'] = False
self.pipeline.append({'$group': self.field_groups})
if self.field_dicts:
self.pipeline.append({'$project': self.field_dicts})
self.field_dicts = adapter.SON()
self._add_having()
elif crud == 'count':
if self._parse_data['need_group']:
self.pipeline.append({'$group': self.field_groups})
self.pipeline.append(
{'$group': {"_id": None, 'count': {"$sum": 1}}})
#elif crud == 'delete':
# pass
@property
def dialect(self):
return self.adapter.dialect
def _build_pipeline_query(self):
# search for anything needing the $match stage.
# currently only '$regex' requires the match stage
def parse_need_match_stage(items, parent, parent_key):
need_match = False
non_matched_indices = []
if isinstance(items, list):
indices = range(len(items))
elif isinstance(items, dict):
indices = items.keys()
else:
return
for i in indices:
if parse_need_match_stage(items[i], items, i):
need_match = True
elif i not in [self.dialect.REGEXP_MARK1,
self.dialect.REGEXP_MARK2]:
non_matched_indices.append(i)
if i == self.dialect.REGEXP_MARK1:
need_match = True
self.query_dict['project'].update(items[i])
parent[parent_key] = items[self.dialect.REGEXP_MARK2]
if need_match:
for i in non_matched_indices:
name = str(items[i])
self.query_dict['project'][name] = items[i]
items[i] = {name: True}
if parent is None and self.query_dict['project']:
self.query_dict['match'] = items
return need_match
expanded = self.adapter.expand(self.query)
if self.dialect.REGEXP_MARK1 in expanded:
# the REGEXP_MARK is at the top of the tree, so can just split
# the regex over a '$project' and a '$match'
self.query_dict = None
match = expanded[self.dialect.REGEXP_MARK2]
project = expanded[self.dialect.REGEXP_MARK1]
else:
self.query_dict = {'project': {}, 'match': {}}
if parse_need_match_stage(expanded, None, None):
project = self.query_dict['project']
match = self.query_dict['match']
else:
project = {'__query__': expanded}
match = {'__query__': True}
if self.crud in ['select', 'update']:
self._add_all_fields_projection(project)
else:
self.pipeline.append({'$project': project})
self.pipeline.append({'$match': match})
self.query_dict = None
def _expand_fields(self, mid_loop):
if self.crud == 'update':
mid_loop = mid_loop or self._fields_loop_update_pipeline
for field, value in self.values:
self._expand_field(field, value, mid_loop)
elif self.crud in ['select', 'count']:
mid_loop = mid_loop or self._fields_loop_select_pipeline
for field in self.fields:
self._expand_field(field, field, mid_loop)
elif self.fields:
raise RuntimeError(self.crud + " not supported with fields")
def _expand_field(self, field, value, mid_loop):
expanded = {}
if isinstance(field, Field):
expanded = self.adapter.expand(value, field.type)
elif isinstance(field, (Expression, Query)):
expanded = self.adapter.expand(field)
field.name = str(expanded)
else:
raise RuntimeError("%s not supported with fields" % type(field))
if mid_loop:
expanded = mid_loop(expanded, field, value)
self.field_dicts[field.name] = expanded
def _fields_loop_abort(self, expanded, *args):
# if we need the aggregation engine, then start over
if self._parse_data['pipeline']:
raise StopIteration()
return expanded
def _fields_loop_update_pipeline(self, expanded, field, value):
if not isinstance(value, Expression):
if self.adapter.server_version_major >= 2.6:
expanded = {'$literal': expanded}
# '$literal' not present in server versions < 2.6
elif field.type in ['string', 'text', 'password']:
expanded = {'$concat': [expanded]}
elif field.type in ['integer', 'bigint', 'float', 'double']:
expanded = {'$add': [expanded]}
elif field.type == 'boolean':
expanded = {'$and': [expanded]}
elif field.type in ['date', 'time', 'datetime']:
expanded = {'$add': [expanded]}
else:
raise RuntimeError(
"updating with expressions not supported for field type " +
"'%s' in MongoDB version < 2.6" % field.type)
return expanded
def _fields_loop_select_pipeline(self, expanded, field, value):
# search for anything needing $group
def parse_groups(items, parent, parent_key):
for item in items:
if isinstance(items[item], list):
for list_item in items[item]:
if isinstance(list_item, dict):
parse_groups(list_item, items[item],
items[item].index(list_item))
elif isinstance(items[item], dict):
parse_groups(items[item], items, item)
if item == self.dialect.GROUP_MARK:
name = str(items)
self.field_groups[name] = items[item]
parent[parent_key] = '$' + name
return items
if self.dialect.AS_MARK in field.name:
# The AS_MARK in the field name is used by base to alias the
# result, we don't actually need the AS_MARK in the parse tree
# so we remove it here.
if isinstance(expanded, list):
# AS mark is first element in list, drop it
expanded = expanded[1]
elif self.dialect.AS_MARK in expanded:
# AS mark is element in dict, drop it
del expanded[self.dialect.AS_MARK]
else:
# ::TODO:: should be possible to do this...
raise SyntaxError("AS() not at top of parse tree")
if self.dialect.GROUP_MARK in expanded:
# the GROUP_MARK is at the top of the tree, so can just pass
# the group result straight through the '$project' stage
self.field_groups[field.name] = expanded[self.dialect.GROUP_MARK]
expanded = 1
elif self.dialect.GROUP_MARK in field.name:
# the GROUP_MARK is not at the top of the tree, so we need to
# pass the group results through to a '$project' stage.
expanded = parse_groups(expanded, None, None)
elif self._parse_data['need_group']:
if field in self.groupby:
# this is a 'groupby' field
self.field_groups['_id'][field.name] = expanded
expanded = '$_id.' + field.name
else:
raise SyntaxError("field '%s' not in groupby" % field)
return expanded
def _add_all_fields_projection(self, fields):
for fieldname in self.adapter.db[self.tablename].fields:
# add all fields to projection to pass them through
if fieldname not in fields and fieldname not in ("_id", "id"):
fields[fieldname] = 1
self.pipeline.append({'$project': fields})
def _add_having(self):
if not self.having:
return
self._expand_field(
self.having, None, self._fields_loop_select_pipeline)
fields = {'__having__': self.field_dicts[self.having.name]}
for fieldname in self.pipeline[-1]['$project']:
# add all fields to projection to pass them through
if fieldname not in fields and fieldname not in ("_id", "id"):
fields[fieldname] = 1
self.pipeline.append({'$project': copy.copy(fields)})
self.pipeline.append({'$match': {'__having__': True}})
del fields['__having__']
self.pipeline.append({'$project': fields})
def annotate_expression(self, expression):
def mark_has_field(expression):
if not isinstance(expression, (Expression, Query)):
return False
first_has_field = mark_has_field(expression.first)
second_has_field = mark_has_field(expression.second)
expression.has_field = (isinstance(expression, Field) or
first_has_field or second_has_field)
return expression.has_field
def add_parse_data(child, parent):
if isinstance(child, (Expression, Query)):
child.parse_root = parent.parse_root
child.parse_parent = parent
child.parse_depth = parent.parse_depth + 1
child._parse_data = parent._parse_data
add_parse_data(child.first, child)
add_parse_data(child.second, child)
elif isinstance(child, (list, tuple)):
for c in child:
add_parse_data(c, parent)
if isinstance(expression, (Expression, Query)):
expression.parse_root = expression
expression.parse_depth = -1
expression._parse_data = self._parse_data
add_parse_data(expression, expression)
mark_has_field(expression)
return expression
def get_collection(self, safe=None):
return self.adapter._get_collection(self.tablename, safe)
class MongoBlob(Binary):
MONGO_BLOB_BYTES = USER_DEFINED_SUBTYPE
MONGO_BLOB_NON_UTF8_STR = USER_DEFINED_SUBTYPE + 1
def __new__(cls, value):
# return None and Binary() unmolested
if value is None or isinstance(value, Binary):
return value
# bytearray is marked as MONGO_BLOB_BYTES
if isinstance(value, bytearray):
return Binary.__new__(
cls, bytes(value), MongoBlob.MONGO_BLOB_BYTES)
# return non-strings as Binary(), eg: PY3 bytes()
if not isinstance(value, basestring):
return Binary(value)
# if string is encodable as UTF-8, then return as string
try:
value.encode('utf-8')
return value
except UnicodeDecodeError:
# string which can not be UTF-8 encoded, eg: pickle strings
return Binary.__new__(
cls, value, MongoBlob.MONGO_BLOB_NON_UTF8_STR)
def __repr__(self):
return repr(MongoBlob.decode(self))
@staticmethod
def decode(value):
if isinstance(value, Binary):
if value.subtype == MongoBlob.MONGO_BLOB_BYTES:
return bytearray(value)
if value.subtype == MongoBlob.MONGO_BLOB_NON_UTF8_STR:
return str(value)
return value