
525 lines
20 KiB
Raw Permalink Normal View History

2018-10-25 15:33:07 +00:00
import functools
import re

from .._compat import PY2, basestring
from ..adapters.mongo import Mongo
from ..exceptions import NotOnNOSQLError
from ..objects import Field
from .base import NoSQLDialect
from . import dialects
# Aggregate function name -> MongoDB accumulator operator.
# Looked up by the ``what`` argument of the dialect's ``aggregate`` method.
_aggregate_map = {
    'SUM': '$sum',
    'MAX': '$max',
    'MIN': '$min',
    'AVG': '$avg',
}
# Date-part name -> MongoDB date operator.
# Looked up by the ``what`` argument of the dialect's ``extract`` method.
_extract_map = {
    'dayofyear': '$dayOfYear',
    'day': '$dayOfMonth',
    'dayofweek': '$dayOfWeek',
    'year': '$year',
    'month': '$month',
    'week': '$week',
    'hour': '$hour',
    'minute': '$minute',
    'second': '$second',
    'millisecond': '$millisecond',
    'string': '$dateToString',
}
def needs_aggregation_pipeline(f):
    """Decorate a dialect method so its operands are flagged for the pipeline.

    Before delegating to ``f`` the wrapper calls
    ``self.adapter._parse_data(operand, 'pipeline', True)`` for ``first``
    and, when at least one extra positional argument is given, for the first
    of them.

    Args:
        f: method to wrap; invoked as ``f(self, first, *args, **kwargs)``.

    Returns:
        The wrapper, which returns whatever ``f`` returns.
    """
    @functools.wraps(f)
    def wrap(self, first, *args, **kwargs):
        flag = self.adapter._parse_data
        flag(first, 'pipeline', True)
        if args:
            # Only the first extra positional operand is flagged.
            flag(args[0], 'pipeline', True)
        return f(self, first, *args, **kwargs)
    return wrap
def validate_second(f):
    """Decorate a method so that a missing or ``None`` second operand fails.

    The check is positional: the wrapper expects ``(self, first, second)``
    as the first three positional arguments, so a ``second`` supplied by
    keyword is treated as missing.

    Args:
        f: method to wrap; invoked as ``f(*args, **kwargs)``.

    Returns:
        The wrapper, which returns whatever ``f`` returns.

    Raises:
        RuntimeError: if fewer than three positional arguments are given or
            the third one is ``None``.
    """
    @functools.wraps(f)
    def wrap(*args, **kwargs):
        if len(args) < 3 or args[2] is None:
            # args[1] is the first operand, named in the error message.
            raise RuntimeError("Cannot compare %s with None" % args[1])
        return f(*args, **kwargs)
    return wrap
def check_fields_for_cmp(f):
    """Decorate a comparison method, supplying its ``pipeline`` keyword.

    ``pipeline`` is decided as follows:

    * ``True`` if ``self.adapter._parse_data((first, second), 'pipeline')``
      is already truthy;
    * ``True`` if ``first`` is not a ``Field`` or ``self._has_field(second)``
      is truthy -- in that case the flag is also stored through
      ``self.adapter._parse_data((first, second), 'pipeline', True)``;
    * ``False`` otherwise.

    Args:
        f: method to wrap; invoked as
            ``f(self, first, second, *args, pipeline=pipeline, **kwargs)``.

    Returns:
        The wrapper, which returns whatever ``f`` returns.
    """
    @functools.wraps(f)
    def wrap(self, first, second=None, *args, **kwargs):
        operands = (first, second)
        if self.adapter._parse_data(operands, 'pipeline'):
            pipeline = True
        elif not isinstance(first, Field) or self._has_field(second):
            pipeline = True
            self.adapter._parse_data(operands, 'pipeline', True)
        else:
            pipeline = False
        return f(self, first, second, *args, pipeline=pipeline, **kwargs)
    return wrap
class MongoDialect(NoSQLDialect):
    """Build MongoDB query / aggregation fragments for query operators.

    Each method returns the fragment (usually a ``dict``) for one operator,
    expanding its operands through ``self.expand``. ``query_env`` is only
    forwarded to ``self.expand`` / sibling methods and never modified here.
    """
    # NOTE(review): ``self.GROUP_MARK``, ``self.REGEXP_MARK1`` and
    # ``self.REGEXP_MARK2`` are used below but are not defined in this class
    # body as seen here -- presumably inherited or defined elsewhere; confirm.
    # NOTE(review): no method below is decorated with the module-level
    # ``needs_aggregation_pipeline`` / ``validate_second`` /
    # ``check_fields_for_cmp`` helpers, and the imported ``Mongo`` and
    # ``dialects`` names are unused in this view; confirm against upstream
    # whether decorator lines are missing.

    # Key under which ``_as`` stores an alias inside an expanded expression.
    AS_MARK = "__#AS#__"
    # Matches "'<AS_MARK>': '<alias>'" and captures the alias.
    REGEX_SELECT_AS_PARSER = re.compile("\\'" + AS_MARK + "\\': \\'(\\S+)\\'")

    @staticmethod
    def _has_field(expression):
        """Return ``expression.has_field``, or ``False`` if it is absent."""
        try:
            return expression.has_field
        except AttributeError:
            return False

    def invert(self, first, query_env={}):
        """Return the expansion of ``first`` prefixed with ``'-'``."""
        return '-%s' % self.expand(first, query_env=query_env)

    def _not(self, val, query_env={}):
        """Return the negation of the expanded form of ``val``.

        The expansion is expected to be a one-key dict ``{op_k: op_body}``:

        * list body (``$and`` / ``$or``): De Morgan is applied recursively
          to ``val.first`` and ``val.second``;
        * dict body that is exactly ``{'$ne': x}``: collapses to
          ``{op_k: x}``;
        * any other dict body: wrapped as ``{op_k: {'$not': body}}``;
        * non-dict body (plain value): ``{op_k: {'$ne': body}}``.
        """
        op = self.expand(val, query_env=query_env)
        op_k = list(op)[0]
        op_body = op[op_k]
        rv = None
        if type(op_body) is list:
            # apply De Morgan law for and/or
            # not(A and B) -> not(A) or not(B)
            # not(A or B)  -> not(A) and not(B)
            not_op = '$and' if op_k == '$or' else '$or'
            rv = {not_op: [self._not(val.first, query_env),
                           self._not(val.second, query_env)]}
        else:
            try:
                sub_ops = list(op_body.keys())
                if len(sub_ops) == 1 and sub_ops[0] == '$ne':
                    rv = {op_k: op_body['$ne']}
            except AttributeError:
                # op_body has no .keys(): it is a plain value, not a dict
                rv = {op_k: {'$ne': op_body}}
            if rv is None:
                rv = {op_k: {'$not': op_body}}
        return rv

    def _and(self, first, second, query_env={}):
        """Return ``first AND second``.

        A boolean ``second`` short-circuits: ``True`` yields the expansion of
        ``first`` alone, ``False`` yields ``self.ne(first, first)``.
        """
        # pymongo expects: .find({'$and': [{'x':'1'}, {'y':'2'}]})
        if isinstance(second, bool):
            if second:
                return self.expand(first, query_env=query_env)
            return self.ne(first, first)
        return {'$and': [self.expand(first, query_env=query_env),
                         self.expand(second, query_env=query_env)]}

    def _or(self, first, second, query_env={}):
        """Return ``first OR second``.

        A boolean ``second`` short-circuits: ``False`` yields the expansion
        of ``first`` alone, ``True`` yields the literal ``True``.
        """
        # pymongo expects: .find({'$or': [{'name':'1'}, {'name':'2'}]})
        if isinstance(second, bool):
            if not second:
                return self.expand(first, query_env=query_env)
            return True
        return {'$or': [self.expand(first, query_env=query_env),
                        self.expand(second, query_env=query_env)]}

    def belongs(self, first, second, query_env={}):
        """Return ``{first: {'$in': [...]}}`` for the items of ``second``.

        Raises:
            RuntimeError: if ``second`` is a ``str`` (nested query).
        """
        if isinstance(second, str):
            # this is broken, the only way second is a string is if it has
            # been converted to SQL.  This no worky.  This might be made to
            # work if adapter._select did not return SQL.
            raise RuntimeError("nested queries not supported")
        items = [self.expand(item, first.type, query_env=query_env)
                 for item in second]
        return {self.expand(first, query_env=query_env): {"$in": items}}

    def _cmp_ops_aggregation_pipeline(self, op, first, second, query_env={}):
        """Return the aggregation form ``{op: [first, second]}``.

        ``second`` is expanded with the type of ``first``, or with ``None``
        when ``first`` has no ``type`` attribute.
        """
        operand_type = getattr(first, 'type', None)
        return {op: [self.expand(first, query_env=query_env),
                     self.expand(second, operand_type,
                                 query_env=query_env)]}

    def eq(self, first, second=None, pipeline=False, query_env={}):
        """Return ``first == second`` (aggregation form when ``pipeline``)."""
        if pipeline:
            return self._cmp_ops_aggregation_pipeline('$eq', first, second,
                                                      query_env)
        return {self.expand(first, query_env=query_env):
                self.expand(second, first.type, query_env=query_env)}

    def ne(self, first, second=None, pipeline=False, query_env={}):
        """Return ``first != second`` (aggregation form when ``pipeline``)."""
        if pipeline:
            return self._cmp_ops_aggregation_pipeline('$ne', first, second,
                                                      query_env)
        return {self.expand(first, query_env=query_env):
                {'$ne': self.expand(second, first.type, query_env=query_env)}}

    def lt(self, first, second=None, pipeline=False, query_env={}):
        """Return ``first < second`` (aggregation form when ``pipeline``)."""
        if pipeline:
            return self._cmp_ops_aggregation_pipeline('$lt', first, second,
                                                      query_env)
        return {self.expand(first, query_env=query_env):
                {'$lt': self.expand(second, first.type, query_env=query_env)}}

    def lte(self, first, second=None, pipeline=False, query_env={}):
        """Return ``first <= second`` (aggregation form when ``pipeline``)."""
        if pipeline:
            return self._cmp_ops_aggregation_pipeline('$lte', first, second,
                                                      query_env)
        return {self.expand(first, query_env=query_env):
                {'$lte': self.expand(second, first.type,
                                     query_env=query_env)}}

    def gt(self, first, second=None, pipeline=False, query_env={}):
        """Return ``first > second`` (aggregation form when ``pipeline``)."""
        if pipeline:
            return self._cmp_ops_aggregation_pipeline('$gt', first, second,
                                                      query_env)
        return {self.expand(first, query_env=query_env):
                {'$gt': self.expand(second, first.type, query_env=query_env)}}

    def gte(self, first, second=None, pipeline=False, query_env={}):
        """Return ``first >= second`` (aggregation form when ``pipeline``)."""
        if pipeline:
            return self._cmp_ops_aggregation_pipeline('$gte', first, second,
                                                      query_env)
        return {self.expand(first, query_env=query_env):
                {'$gte': self.expand(second, first.type,
                                     query_env=query_env)}}

    def add(self, first, second, query_env={}):
        """Return ``first + second``.

        Uses ``$concat`` when either operand's ``type`` is ``string``,
        ``text`` or ``password``; ``$add`` otherwise.
        """
        text_types = ('string', 'text', 'password')
        # getattr: an operand may be a plain value without a ``type``.
        is_text = any(getattr(operand, 'type', None) in text_types
                      for operand in (first, second))
        op_code = '$concat' if is_text else '$add'
        return {op_code: [self.expand(first, query_env=query_env),
                          self.expand(second, first.type,
                                      query_env=query_env)]}

    def sub(self, first, second, query_env={}):
        """Return ``{'$subtract': [first, second]}``."""
        return {'$subtract': [
            self.expand(first, query_env=query_env),
            self.expand(second, first.type, query_env=query_env)]}

    def mul(self, first, second, query_env={}):
        """Return ``{'$multiply': [first, second]}``."""
        return {'$multiply': [
            self.expand(first, query_env=query_env),
            self.expand(second, first.type, query_env=query_env)]}

    def div(self, first, second, query_env={}):
        """Return ``{'$divide': [first, second]}``."""
        return {'$divide': [
            self.expand(first, query_env=query_env),
            self.expand(second, first.type, query_env=query_env)]}

    def mod(self, first, second, query_env={}):
        """Return ``{'$mod': [first, second]}``."""
        return {'$mod': [
            self.expand(first, query_env=query_env),
            self.expand(second, first.type, query_env=query_env)]}

    def aggregate(self, first, what, query_env={}):
        """Return the fragment for aggregate ``what`` applied to ``first``.

        ``'ABS'`` is emulated with ``$cond`` and returned directly. Any other
        name is looked up in ``_aggregate_map``; the result is wrapped under
        ``self.GROUP_MARK`` after flagging ``first`` with ``'need_group'``.

        Raises:
            NotImplementedError: if ``what`` is not a known aggregate.
        """
        if what == 'ABS':
            return {
                "$cond": [
                    {"$lt": [self.expand(first, query_env=query_env), 0]},
                    {"$subtract": [
                        0, self.expand(first, query_env=query_env)]},
                    self.expand(first, query_env=query_env)]}
        try:
            # Only the lookup is guarded, so a KeyError raised while
            # expanding ``first`` is not misreported as "not implemented".
            operator = _aggregate_map[what]
        except KeyError:
            raise NotImplementedError("'%s' not implemented" % what)
        expanded = {operator: self.expand(first, query_env=query_env)}
        self.adapter._parse_data(first, 'need_group', True)
        return {self.GROUP_MARK: expanded}

    def count(self, first, distinct=None, query_env={}):
        """Return the fragment counting rows, or distinct values of ``first``.

        ``first`` is always flagged with ``'need_group'``.
        """
        self.adapter._parse_data(first, 'need_group', True)
        if distinct:
            ret = {self.GROUP_MARK: {"$addToSet":
                                     self.expand(first,
                                                 query_env=query_env)}}
            if self.adapter.server_version_major >= 2.6:
                # '$size' not present in server versions < 2.6
                ret = {'$size': ret}
            return ret
        return {self.GROUP_MARK: {"$sum": 1}}

    def extract(self, first, what, query_env={}):
        """Return the date-part operator ``what`` applied to ``first``.

        Raises:
            NotImplementedError: if ``what`` is not in ``_extract_map``.
        """
        try:
            # Only the lookup is guarded (see ``aggregate``).
            operator = _extract_map[what]
        except KeyError:
            raise NotImplementedError("EXTRACT(%s) not implemented" % what)
        return {operator: self.expand(first, query_env=query_env)}

    def epoch(self, first, query_env={}):
        """Return ``(first - self.adapter.epoch) / 1000``."""
        return {"$divide": [
            {"$subtract": [self.expand(first, query_env=query_env),
                           self.adapter.epoch]}, 1000]}

    def case(self, query, true_false, query_env={}):
        """Return ``$cond`` choosing ``true_false[0]`` or ``true_false[1]``."""
        return {"$cond": [
            self.expand(query, query_env=query_env),
            self.expand(true_false[0], query_env=query_env),
            self.expand(true_false[1], query_env=query_env)]}

    def _as(self, first, second, query_env={}):
        """Attach alias ``second`` to the expansion of ``first``.

        For a ``Field`` a two-item list ``[{AS_MARK: alias}, expansion]`` is
        returned; otherwise the alias is stored in the expanded dict itself.
        """
        # put the AS_MARK into the structure.  The 'AS' name will be parsed
        # later from the string of the field name.
        if isinstance(first, Field):
            return [{self.AS_MARK: second},
                    self.expand(first, query_env=query_env)]
        result = self.expand(first, query_env=query_env)
        result[self.AS_MARK] = second
        return result

    # We could implement an option that simulates a full featured SQL
    # database. But I think the option should be set explicit or
    # implemented as another library.
    def on(self, first, second, query_env={}):
        """Joins are not available: always raises ``NotOnNOSQLError``."""
        raise NotOnNOSQLError()

    def comma(self, first, second, query_env={}):
        """Return both expansions joined by a comma."""
        # returns field name lists, to be separated via split(',')
        return '%s,%s' % (self.expand(first, query_env=query_env),
                          self.expand(second, query_env=query_env))

    # TODO verify full compatibility with official SQL Like operator
    def _build_like_regex(self, first, second, case_sensitive=True,
                          escape=None, ends_with=False, starts_with=False,
                          whole_string=True, like_wildcards=False,
                          query_env={}):
        """Translate a LIKE-style match of ``first`` against ``second``.

        Args:
            first: expression being matched.
            second: pattern; expanded as a ``'string'``.
            case_sensitive: forwarded to ``regexp``.
            escape: character protecting a literal ``%`` / ``_``.
            ends_with: anchor the pattern at the end only.
            starts_with: anchor the pattern at the start only (takes
                precedence over ``ends_with``).
            whole_string: anchor at both ends when neither of the above.
            like_wildcards: treat ``%`` and ``_`` as SQL wildcards.
            query_env: forwarded to ``expand`` / ``regexp``.

        Returns:
            The expanded pattern itself when no regex is needed, otherwise
            the result of ``self.regexp``.
        """
        base = self.expand(second, 'string', query_env=query_env)
        # ``and`` binds tighter than ``or``: the wildcard test only applies
        # to the ``like_wildcards`` term.
        need_regex = (whole_string or not case_sensitive or starts_with or
                      ends_with or like_wildcards and
                      ('_' in base or '%' in base))
        if not need_regex:
            return base

        # NOTE(review): since Python 3.7 re.escape() no longer escapes '%',
        # so the '\\%' replacements below may never match there -- confirm
        # wildcard handling on Python 3.
        expr = re.escape(base)
        if like_wildcards:
            if escape:
                # protect % and _ which are escaped
                expr = expr.replace(escape+'\\%', '%')
                if PY2:
                    expr = expr.replace(escape+'\\_', '_')
                elif escape+'_' in expr:
                    # park escaped underscores behind a token that does not
                    # occur in the expression; they are put back below
                    set_aside = str(self.adapter.object_id('<random>'))
                    while set_aside in expr:
                        set_aside = str(self.adapter.object_id('<random>'))
                    expr = expr.replace(escape+'_', set_aside)
                else:
                    set_aside = None
            expr = expr.replace('\\%', '.*')
            if PY2:
                expr = expr.replace('\\_', '.')
            else:
                expr = expr.replace('_', '.')
            if escape:
                # convert to protected % and _
                expr = expr.replace('%', '\\%')
                if PY2:
                    expr = expr.replace('_', '\\_')
                elif set_aside:
                    expr = expr.replace(set_aside, '_')

        if starts_with:
            pattern = '^%s'
        elif ends_with:
            pattern = '%s$'
        elif whole_string:
            pattern = '^%s$'
        else:
            pattern = '%s'
        return self.regexp(first, pattern % expr, case_sensitive, query_env)

    def like(self, first, second, case_sensitive=True, escape=None,
             query_env={}):
        """Return a LIKE match with ``%`` / ``_`` wildcards enabled."""
        return self._build_like_regex(
            first, second, case_sensitive=case_sensitive, escape=escape,
            like_wildcards=True, query_env=query_env)

    def ilike(self, first, second, escape=None, query_env={}):
        """Return a case-insensitive ``like``."""
        return self.like(first, second, case_sensitive=False, escape=escape,
                         query_env=query_env)

    def startswith(self, first, second, query_env={}):
        """Return a match of ``second`` anchored at the start of ``first``."""
        return self._build_like_regex(first, second, starts_with=True,
                                      query_env=query_env)

    def endswith(self, first, second, query_env={}):
        """Return a match of ``second`` anchored at the end of ``first``."""
        return self._build_like_regex(first, second, ends_with=True,
                                      query_env=query_env)

    # TODO verify full compatibility with official oracle contains operator
    def contains(self, first, second, case_sensitive=True, query_env={}):
        """Return the fragment testing that ``first`` contains ``second``.

        * ``ObjectId``: direct equality on the expanded ``first``.
        * ``Field``: a ``$where`` JavaScript ``indexOf`` test, for string /
          text fields (against list:string, string or text ``first``) and
          for integer / bigint fields.
        * string or int: a regex built by ``_build_like_regex``; the whole
          string must match only when ``first`` is a ``list:string`` field.

        Raises:
            NotImplementedError: for any other combination of types.
        """
        # NOTE(review): the $where branches interpolate field names into
        # JavaScript source without escaping.
        if isinstance(second, self.adapter.ObjectId):
            ret = {self.expand(first, query_env=query_env): second}
        elif isinstance(second, Field):
            if second.type in ['string', 'text']:
                if isinstance(first, Field):
                    if first.type in ['list:string', 'string', 'text']:
                        ret = {
                            '$where': "this.%s.indexOf(this.%s) > -1" % (
                                first.name, second.name)}
                    else:
                        raise NotImplementedError(
                            "field.CONTAINS() not implemented for field " +
                            "type of '%s'" % first.type)
                else:
                    raise NotImplementedError(
                        "x.CONTAINS() not implemented for x type of '%s'" %
                        type(first))
            elif second.type in ['integer', 'bigint']:
                ret = {
                    '$where': "this.%s.indexOf(this.%s + '') > -1" % (
                        first.name, second.name)}
            else:
                raise NotImplementedError(
                    "CONTAINS(field) not implemented for field type '%s'" %
                    second.type)
        elif isinstance(second, (basestring, int)):
            whole_string = isinstance(first, Field) and \
                first.type == 'list:string'
            ret = self._build_like_regex(
                first, second, case_sensitive=case_sensitive,
                whole_string=whole_string, query_env=query_env)
            # first.type in ('string', 'text', 'json', 'upload')
            # or first.type.startswith('list:'):
        else:
            raise NotImplementedError(
                "CONTAINS() not implemented for type '%s'" % type(second))
        return ret

    def substring(self, field, parameters, query_env={}):
        """Return ``{'$substr': [field, start, length]}``.

        ``parameters`` is a ``(pos0, length)`` pair; ``pos0`` is shifted by
        -1 and a non-string ``length`` means "rest of the string" (-1).
        """
        def parse_parameters(pos0, length):
            """
            The expression object can return these as string based
            expressions. We can't use that so we have to tease it apart.

            These are the possibilities:

              pos0 = '(%s - %d)' % (self.len(), abs(start) - 1)
              pos0 = start + 1

              length = self.len()
              length = '(%s - %d - %s)' % (self.len(), abs(stop) - 1, pos0)
              length = '(%s - %s)' % (stop + 1, pos0)

            Two of these five require the length of the string which is not
            supported by Mongo, so for now these cause an Exception and
            won't reach here.

            If this were to ever be supported it may require a change to
            Expression.__getitem__ so that it either returned the base
            expression to be expanded here, or converted length to a string
            to be parsed back to a call to STRLEN()
            """
            if isinstance(length, basestring):
                # NOTE(review): eval() of an expression string; acceptable
                # only while ``length`` is generated internally as described
                # above -- never route user input here.
                return (pos0 - 1, eval(length))
            # take the rest of the string
            return (pos0 - 1, -1)

        parameters = parse_parameters(*parameters)
        return {'$substr': [self.expand(field, query_env=query_env),
                            parameters[0], parameters[1]]}

    def lower(self, first, query_env={}):
        """Return ``{'$toLower': first}``."""
        return {'$toLower': self.expand(first, query_env=query_env)}

    def upper(self, first, query_env={}):
        """Return ``{'$toUpper': first}``."""
        return {'$toUpper': self.expand(first, query_env=query_env)}

    def regexp(self, first, second, case_sensitive=True, query_env={}):
        """ MongoDB provides regular expression capabilities for pattern
            matching strings in queries. MongoDB uses Perl compatible
            regular expressions (i.e. 'PCRE') version 8.36 with UTF-8 support.

            Numeric fields are matched through a ``$where`` JavaScript test.
            When the operands are already flagged for the pipeline the two
            halves are returned under ``REGEXP_MARK1`` / ``REGEXP_MARK2``;
            otherwise ``{first: {'$regex': ...}}`` is returned.
        """
        if (isinstance(first, Field) and
                first.type in ['integer', 'bigint', 'float', 'double']):
            # NOTE(review): pattern and field name are interpolated into
            # JavaScript source without escaping.
            return {
                '$where': "RegExp('%s').test(this.%s + '')" % (
                    self.expand(second, 'string', query_env=query_env),
                    first.name)}
        expanded_first = self.expand(first, query_env=query_env)
        regex_second = {'$regex': self.expand(second, 'string',
                                              query_env=query_env)}
        if not case_sensitive:
            regex_second['$options'] = 'i'
        if (self.adapter._parse_data((first, second), 'pipeline')):
            name = str(expanded_first)
            return {
                self.REGEXP_MARK1: {name: expanded_first},
                self.REGEXP_MARK2: {name: regex_second}}
        try:
            return {expanded_first: regex_second}
        except TypeError:
            # if first is not hashable, then will need the pipeline
            self.adapter._parse_data((first, second), 'pipeline', True)
            return {}

    def length(self, first, query_env={}):
        """
        Mongo has committed $strLenBytes, $strLenCP, and $substrCP to $project
        aggregation stage in dev branch V3.3.4

            $project: {
                byteLength: {$strLenBytes: "$string"},
                cpLength: {$strLenCP: "$string"}
                byteSubstr: {$substrBytes: ["$string", 0, 4]},
                cpSubstr: {$substrCP: ["$string", 0, 4]}
            }

        Not implemented yet: always raises ``NotImplementedError``.
        """
        if self.adapter.server_version_major <= 3.2:
            # $strLenBytes not supported by mongo before version 3.4
            raise NotImplementedError()
        # implement here :-)
        raise NotImplementedError()

    def coalesce(self, first, second, query_env={}):
        """Return nested ``$ifNull`` picking the first non-null operand.

        ``second`` is a sequence of fallbacks; more than one fallback is
        folded recursively into a single nested expression.
        """
        if len(second) > 1:
            # forward query_env so nested operands expand in the same
            # environment as the outer ones
            second = [self.coalesce(second[0], second[1:],
                                    query_env=query_env)]
        return {"$ifNull": [self.expand(first, query_env=query_env),
                            self.expand(second[0], query_env=query_env)]}

    def random(self):
        """
        Mongo has released the '$sample' pipeline stage in V3.2

        Not implemented yet: always raises ``NotImplementedError``.
        """
        if self.adapter.server_version_major <= 3.0:
            # '$sample' not present until server version 3.2
            raise NotImplementedError()
        # implement here :-)
        raise NotImplementedError()