SP/web2py/gluon/contrib/memdb.py
Saturneic 064f602b1a Add.
2018-10-25 23:33:13 +08:00

906 lines
28 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This file is part of web2py Web Framework (Copyrighted, 2007-2009).
Developed by Massimo Di Pierro <mdipierro@cs.depaul.edu> and
Robin B <robi123@gmail.com>.
License: LGPLv3
"""
__all__ = ['MEMDB', 'Field']
import re
import sys
import os
import types
import datetime
import thread
import cStringIO
import csv
import copy
import gluon.validators as validators
from gluon.utils import web2py_uuid
from gluon.storage import Storage
from gluon import SQLTABLE
import random
# Type translation table, keyed by dialect name (only 'memcache' is defined).
# Each web2py field type name maps to the Python type used for it.  Table._create
# looks field types up here and rejects any type whose entry is missing or falsy,
# so the None entries below mark names that are not usable as field types.
SQL_DIALECTS = {'memcache': {
    'boolean': bool,
    'string': unicode,
    'text': unicode,
    'password': unicode,
    'blob': unicode,
    'upload': unicode,
    'integer': long,
    'double': float,
    'date': datetime.date,
    'time': datetime.time,
    'datetime': datetime.datetime,
    'id': int,
    'reference': int,
    'lower': None,
    'upper': None,
    'is null': 'IS NULL',
    'is not null': 'IS NOT NULL',
    'extract': None,
    'left join': None,
    }}
def cleanup(text):
    """
    Validate a table or field name and hand it back unchanged.

    Raises SyntaxError when the name contains any character other than
    ASCII letters, digits or the underscore.
    """
    offending = re.search('[^0-9a-zA-Z_]', text)
    if offending is not None:
        raise SyntaxError('Can\'t cleanup \'%s\': only [0-9a-zA-Z_] allowed in table and field names' % text)
    return text
def assert_filter_fields(*fields):
    """
    Reject fields that cannot be used for filtering or sorting.

    Every argument that is a Field or an Expression is inspected; a
    SyntaxError is raised for the first one whose type is 'text' or 'blob'.
    Arguments of any other kind are ignored.
    """
    unindexed_types = ('text', 'blob')
    for candidate in fields:
        if not isinstance(candidate, (Field, Expression)):
            continue
        if candidate.type in unindexed_types:
            raise SyntaxError('AppEngine does not index by: %s'
                              % candidate.type)
def dateobj_to_datetime(object):
    """
    Convert a date or a time into a datetime; return anything else unchanged.

    - a datetime.datetime is returned as is
    - a datetime.date becomes midnight of that day
    - a datetime.time becomes that time on 1970-01-01 (tzinfo is dropped)
    """
    # datetime.datetime is a subclass of datetime.date, so it has to be
    # recognised first; otherwise its time part would be thrown away.
    if isinstance(object, datetime.datetime):
        return object
    if isinstance(object, datetime.date):
        return datetime.datetime(object.year, object.month, object.day)
    if isinstance(object, datetime.time):
        return datetime.datetime(
            1970,
            1,
            1,
            object.hour,
            object.minute,
            object.second,
            object.microsecond,
            )
    return object
def sqlhtml_validators(field_type, length):
    """
    Return the default validator(s) for a field of the given type.

    Only the first nine characters of field_type are used for the lookup,
    so 'reference <table>' resolves to the 'reference' entry.  Types without
    an entry yield an empty list.
    """
    defaults_by_type = {
        'boolean': [],
        'string': validators.IS_LENGTH(length),
        'text': [],
        'password': validators.IS_LENGTH(length),
        'blob': [],
        'upload': [],
        'double': validators.IS_FLOAT_IN_RANGE(-1e100, 1e100),
        'integer': validators.IS_INT_IN_RANGE(-1e100, 1e100),
        'date': validators.IS_DATE(),
        'time': validators.IS_TIME(),
        'datetime': validators.IS_DATETIME(),
        'reference': validators.IS_INT_IN_RANGE(0, 1e100),
        }
    type_key = field_type[:9]
    return defaults_by_type.get(type_key, [])
class DALStorage(dict):
    """
    a dict whose items are also reachable as attributes (d.a is d['a']);
    an item that already exists cannot be replaced through attribute syntax
    """

    def __getattr__(self, key):
        # only called when normal attribute lookup fails; a missing key
        # therefore surfaces as KeyError
        return self[key]

    def __setattr__(self, key, value):
        if key not in self:
            self[key] = value
            return
        raise SyntaxError(
            'Object \'%s\'exists and cannot be redefined' % key)

    def __repr__(self):
        return ''.join(('<DALStorage ', dict.__repr__(self), '>'))
class SQLCallableList(list):
    """
    a list that returns a shallow copy of itself when called
    """

    def __call__(self):
        duplicate = copy.copy(self)
        return duplicate
class MEMDB(DALStorage):
    """
    an instance of this class represents a database connection

    Example::

       db=MEMDB(Client())
       db.define_table('tablename',Field('fieldname1'),
                                   Field('fieldname2'))

    The connection state (_dbname, _lastsql, tables, _translator, client)
    and every defined table are stored as items of this dict.
    """

    def __init__(self, client):
        """
        client is the object tables use for storage; Table calls its
        get, set and delete methods.
        """
        self._dbname = 'memdb'
        self['_lastsql'] = ''
        self.tables = SQLCallableList()
        self._translator = SQL_DIALECTS['memcache']
        self.client = client

    def define_table(
        self,
        tablename,
        *fields,
        **args
        ):
        """
        Create a Table called tablename from the given fields, register it
        on this connection and return it.

        Extra keyword arguments (e.g. migrate=...) are accepted and ignored.

        Raises SyntaxError if the name is empty, contains illegal characters,
        starts with an underscore, collides with an attribute or item of the
        connection, or is already defined.
        """
        tablename = cleanup(tablename)
        # the emptiness test comes first so tablename[0] cannot raise IndexError
        if not tablename or tablename[0] == '_' or tablename in dir(self):
            raise SyntaxError('invalid table name: %s' % tablename)
        if tablename in self.tables:
            raise SyntaxError('table already defined: %s' % tablename)
        if tablename in self:
            # connection state such as 'tables' and 'client' lives in this
            # dict too; a table with that name would silently overwrite it
            raise SyntaxError('invalid table name: %s' % tablename)
        self.tables.append(tablename)
        try:
            t = self[tablename] = Table(self, tablename, *fields)
            t._create()
        except Exception:
            # do not leave a half-defined table behind, so the caller can
            # retry with a corrected definition
            self.tables.remove(tablename)
            self.pop(tablename, None)
            raise
        return t

    def __call__(self, where=''):
        """Return the Set of records selected by where."""
        return Set(self, where)
class SQLALL(object):
    """
    placeholder holding a reference to a table; Table stores one instance
    of it under its ALL attribute
    """

    def __init__(self, table):
        (self.table,) = (table,)
class Table(DALStorage):
    """
    an instance of this class represents a database table

    Example::

        db=MEMDB(Client())
        db.define_table('users',Field('name'))
        db.users.insert(name='me')

    Records are stored one per key through the connection's client; the
    key is derived from the table name and the record id.
    """

    def __init__(
        self,
        db,
        tablename,
        *fields
        ):
        self._db = db
        self._tablename = tablename
        self.fields = SQLCallableList()
        self._referenced_by = []
        # every table gets an implicit leading 'id' field
        fields = list(fields)
        fields.insert(0, Field('id', 'id'))
        for field in fields:
            self.fields.append(field.name)
            self[field.name] = field
            field._tablename = self._tablename
            field._table = self
            field._db = self._db
        self.ALL = SQLALL(self)

    def _create(self):
        """
        Validate the field types and bind the table to the client.

        Raises SyntaxError for an unknown field type, for a reference to
        nothing or to an undefined table, and when the referenced table has
        a field named like this table.
        """
        translator = self._db._translator
        for k in self.fields:
            field = self[k]
            if field.type[:2] == 'id':
                continue
            if field.type[:9] == 'reference':
                referenced = field.type[10:].strip()
                if not referenced:
                    raise SyntaxError('Table %s: reference \'%s\' to nothing!' % (
                        self._tablename, k))
                if not referenced in self._db:
                    raise SyntaxError(
                        'Table: table %s does not exist' % referenced)
                referee = self._db[referenced]
                # NOTE(review): the 'reference' entry of the memcache dialect
                # is `int`, so this applies int() to the referenced table's
                # client object; the result is never used.  Unless the client
                # supports int() this raises TypeError, which would make
                # reference fields unusable -- confirm intent before changing.
                translator[field.type[:9]](referee._tableobj)
                if self._tablename in referee.fields:  # ## THIS IS OK
                    raise SyntaxError('Field: table \'%s\' has same name as a field '
                                      'in referenced table \'%s\'' % (
                                          self._tablename, referenced))
                referee._referenced_by.append((self._tablename,
                                               field.name))
            elif not field.type in translator\
                 or not translator[field.type]:
                raise SyntaxError('Field: unknown field type %s' % field.type)
        self._tableobj = self._db.client
        return None

    def create(self):
        # nothing to do, here for backward compatibility
        pass

    def drop(self):
        # NOTE(review): this builds the query `id > 0`, but Query only accepts
        # equality on id and raises SyntaxError for any other operator, so
        # as written this call cannot succeed -- confirm intended behavior.
        self._db(self.id > 0).delete()

    def insert(self, **fields):
        """
        Store a new record and return its id as a long.

        A fresh random id is tried up to 3 times; RuntimeError is raised if
        none of the attempts finds a free id and stores successfully.
        """
        # Checks 3 times that the id is new. 3 times is enough!
        for attempt in range(3):
            id = self._create_id()
            if self.get(id) is None and self.update(id, **fields):
                return long(id)
        raise RuntimeError("Too many ID conflicts")

    def get(self, id):
        """Return the stored record as a Storage, or None if there is none."""
        val = self._tableobj.get(self._id_to_key(id))
        if val:
            return Storage(val)
        else:
            return None

    def update(self, id, **fields):
        """
        Store fields as the complete record for id and return the result of
        the client's set().

        Fields of the table that are not passed and declare a default other
        than None are filled in with that default.  Passing a name that is
        not a field of the table raises KeyError.
        """
        # iterate over the table's own field names: looping over the passed
        # values could never find a missing field, so defaults were skipped
        for name in self.fields:
            if name not in fields and self[name].default is not None:
                fields[name] = self[name].default
        for name in list(fields):
            fields[name] = obj_represent(fields[name],
                                         self[name].type, self._db)
        return self._tableobj.set(self._id_to_key(id), fields)

    def delete(self, id):
        """Remove the record stored under id; returns the client's result."""
        return self._tableobj.delete(self._id_to_key(id))

    def _id_to_key(self, id):
        # storage key of a record: namespaced by table name
        return '__memdb__/t/%s/k/%s' % (self._tablename, str(id))

    def _create_id(self):
        # a web2py uuid with the dashes removed, read as a hexadecimal number
        return long(web2py_uuid().replace('-', ''), 16)

    def __str__(self):
        return self._tablename

    def __call__(self, id, **kwargs):
        """
        Return the record with the given id, or None if it does not exist or
        if any keyword argument differs from the stored value.
        """
        record = self.get(id)
        if record is None:
            return None
        if kwargs and any(record[key] != kwargs[key] for key in kwargs):
            return None
        return record
class Expression(object):
    """
    a named, typed expression; comparisons build Query objects and
    arithmetic builds new Expression objects from the textual form
    """

    def __init__(
        self,
        name,
        type='string',
        db=None,
        ):
        (self.name, self.type, self._db) = (name, type, db)

    def __str__(self):
        return self.name

    def __or__(self, other):  # for use in sortby
        assert_filter_fields(self, other)
        return Expression(self.name + '|' + other.name, None, None)

    def __invert__(self):
        assert_filter_fields(self)
        return Expression('-' + self.name, self.type, None)

    # for use in Query

    def __eq__(self, value):
        return Query(self, '=', value)

    def __ne__(self, value):
        return Query(self, '!=', value)

    def __lt__(self, value):
        return Query(self, '<', value)

    def __le__(self, value):
        return Query(self, '<=', value)

    def __gt__(self, value):
        return Query(self, '>', value)

    def __ge__(self, value):
        return Query(self, '>=', value)

    # like() and belongs() are not implemented for this backend

    # for use in both Query and sortby

    def __add__(self, other):
        return Expression('%s+%s' % (self, other), 'float', None)

    def __sub__(self, other):
        return Expression('%s-%s' % (self, other), 'float', None)

    def __mul__(self, other):
        return Expression('%s*%s' % (self, other), 'float', None)

    def __div__(self, other):
        return Expression('%s/%s' % (self, other), 'float', None)

    # the `/` operator dispatches to __truediv__ when true division is in
    # effect (`from __future__ import division`); without this alias the
    # operator would raise TypeError there
    __truediv__ = __div__
class Field(Expression):
    """
    an instance of this class represents a database field

    example::

        a = Field(name, 'string', length=32, required=False,
                     default=None, requires=IS_NOT_EMPTY(), notnull=False,
                     unique=False, uploadfield=True)

    to be used as argument of MEMDB.define_table

    allowed field types:
    string, boolean, integer, double, text, blob,
    date, time, datetime, upload, password

    a Table instance may be passed as type; it is turned into
    'reference <tablename>'.

    length defaults to 512 when not given.
    fields should have a default or they will be required in SQLFORMs
    the requires argument are used to validate the field input in SQLFORMs
    (when left at its default, validators are chosen from the field type;
    None means no validators)

    Raises SyntaxError for an empty name, a name with illegal characters,
    a name starting with an underscore or a name that is an attribute of
    Table.
    """

    def __init__(
        self,
        fieldname,
        type='string',
        length=None,
        default=None,
        required=False,
        requires=sqlhtml_validators,
        ondelete='CASCADE',
        notnull=False,
        unique=False,
        uploadfield=True,
        ):
        self.name = cleanup(fieldname)
        # the emptiness test comes first so fieldname[0] cannot raise IndexError
        if not fieldname or fieldname[0] == '_' or fieldname in dir(Table):
            raise SyntaxError('Field: invalid field name: %s' % fieldname)
        if isinstance(type, Table):
            type = 'reference ' + type._tablename
        if not length:
            length = 512
        self.type = type  # 'string', 'integer'
        self.length = length  # the length of the string
        self.default = default  # default value for field
        self.required = required  # is this field required
        self.ondelete = ondelete.upper()  # this is for reference fields only
        self.notnull = notnull
        self.unique = unique
        self.uploadfield = uploadfield
        # sqlhtml_validators doubles as the "not specified" sentinel, so it
        # is recognised by identity rather than through a validator's __eq__
        if requires is sqlhtml_validators:
            requires = sqlhtml_validators(type, length)
        elif requires is None:
            requires = []
        self.requires = requires  # list of validators

    def formatter(self, value):
        """
        Pass value through the formatter() of each validator that has one,
        last validator first, and return the result.  None, and any value
        when there are no validators, is returned unchanged.
        """
        if value is None or not self.requires:
            return value
        if isinstance(self.requires, (list, tuple)):
            # copy into a list: a tuple has no reverse() and the stored
            # sequence must not be reordered
            requires = list(self.requires)
        else:
            requires = [self.requires]
        for item in reversed(requires):
            if hasattr(item, 'formatter'):
                value = item.formatter(value)
        return value

    def __str__(self):
        # _tablename is assigned by Table.__init__ when the field is attached
        return '%s.%s' % (self._tablename, self.name)
# make the Field class reachable as an attribute of MEMDB
MEMDB.Field = Field  # ## required by gluon/globals.py session.connect
def _split_date_text(text):
    """Parse 'YYYY-MM-DD' into a (year, month, day) tuple of ints."""
    (year, month, day) = [int(x) for x in text.strip().split('-')]
    return (year, month, day)


def _split_time_text(text):
    """
    Parse 'HH:MM' or 'HH:MM:SS[.ffffff]' into an
    (hour, minute, second, microsecond) tuple of ints.
    """
    parts = text.strip().split(':')[:3]
    if len(parts) == 2:
        parts.append('0')  # seconds are optional
    (hour, minute, second) = parts
    # split on the dot instead of going through float() so that no
    # rounding error creeps into the microseconds
    (whole, dot, fraction) = second.partition('.')
    microsecond = int((fraction + '000000')[:6]) if fraction else 0
    return (int(hour), int(minute), int(whole), microsecond)


def obj_represent(object, fieldtype, db):
    """
    Coerce a value to the Python type used to store the given field type.

    - 'date': text 'YYYY-MM-DD' becomes a datetime.date
    - 'time': text 'HH:MM[:SS[.ffffff]]' becomes a datetime.time
    - 'datetime': text 'YYYY-MM-DD[ HH:MM[:SS[.ffffff]]]' or a datetime.date
      becomes a datetime.datetime (midnight when no time is given)
    - 'integer': anything that is not already a long is passed to long()

    None, values that already have the target type and values of any other
    field type are returned unchanged.  db is accepted but not used.
    Raises ValueError when text cannot be parsed.
    """
    if object is None:
        return object
    if fieldtype == 'date' and not isinstance(object, datetime.date):
        object = datetime.date(*_split_date_text(str(object)))
    elif fieldtype == 'time' and not isinstance(object, datetime.time):
        object = datetime.time(*_split_time_text(str(object)))
    elif fieldtype == 'datetime' and not isinstance(object,
                                                    datetime.datetime):
        if isinstance(object, datetime.date):
            # a plain date has no time part to parse
            object = datetime.datetime(object.year, object.month,
                                       object.day)
        else:
            text = str(object)
            date_part = _split_date_text(text[:10])
            time_text = text[11:].strip()
            if time_text:
                time_part = _split_time_text(time_text)
            else:
                time_part = (0, 0, 0, 0)
            object = datetime.datetime(*(date_part + time_part))
    elif fieldtype == 'integer' and not isinstance(object, long):
        object = long(object)
    return object
class QueryException:
    """
    simple attribute holder: each keyword argument given to the
    constructor becomes an attribute of the instance
    """

    def __init__(self, **a):
        for (attribute, content) in a.items():
            setattr(self, attribute, content)
class Query(object):
    """
    A query object necessary to define a set.
    It can be stored or can be passed to MEMDB.__call__() to obtain a Set

    Only one form is supported: equality on the id field of a table.

    Example:
    query=db.users.id==user_id
    set=db(query)
    records=set.select()
    """

    def __init__(
        self,
        left,
        op=None,
        right=None,
        ):
        """
        Raises SyntaxError when right is a Field or an Expression, when id
        is compared with anything but '=', and for every query that is not
        on an id field.
        """
        # keep the operands: __str__ reads self.left
        (self.left, self.op, self.right) = (left, op, right)
        if isinstance(right, (Field, Expression)):
            raise SyntaxError(
                'Query: right side of filter must be a value or entity')
        if isinstance(left, Field) and left.name == 'id':
            if op == '=':
                # a falsy right-hand side (None, '', 0) is looked up as id 0
                self.get_one = QueryException(
                    tablename=left._tablename, id=long(right or 0))
                return
            else:
                raise SyntaxError('only equality by id is supported')
        raise SyntaxError('not supported')

    def __str__(self):
        return str(self.left)
class Set(object):
    """
    As Set represents a set of records in the database,
    the records are identified by the where=Query(...) object.
    normally the Set is generated by MEMDB.__call__(Query(...))

    given a set, for example
       set=db(db.users.id==user_id)
    you can:
       set.update(name='Massimo')
       set.delete()
       set.select()
       set.count()

    select, update and delete only work for a set built from an
    "id equals value" query; anything else raises.
    """

    def __init__(self, db, where=None):
        self._db = db
        self._tables = []
        self.filters = []
        if hasattr(where, 'get_all'):
            self.where = where
            self._tables.insert(0, where.get_all)
        elif hasattr(where, 'get_one') and isinstance(where.get_one,
                                                      QueryException):
            # single record lookup: keep only the (tablename, id) holder
            self.where = where.get_one
        else:
            # find out which tables are involved
            if isinstance(where, Query):
                self.filters = where.left
            self.where = where
            self._tables = [field._tablename for (field, op, val) in
                            self.filters]

    def __call__(self, where):
        """Return the subset obtained by adding where to this set's query."""
        if isinstance(self.where, QueryException) or isinstance(where,
                                                                QueryException):
            raise SyntaxError('neither self.where nor where can be a QueryException instance')
        if self.where:
            return Set(self._db, self.where & where)
        else:
            return Set(self._db, where)

    def _get_table_or_raise(self):
        tablenames = list(set(self._tables))  # unique
        if len(tablenames) < 1:
            raise SyntaxError('Set: no tables selected')
        if len(tablenames) > 1:
            raise SyntaxError('Set: no join in appengine')
        return self._db[tablenames[0]]._tableobj

    def _getitem_exception(self):
        """
        Fetch the single record named by self.where and return
        (item, fields, tablename, id); item is None if nothing is stored.
        Also sets self.colnames to the 'table.field' name of every field.
        """
        (tablename, id) = (self.where.tablename, self.where.id)
        fields = self._db[tablename].fields
        self.colnames = ['%s.%s' % (tablename, t) for t in fields]
        item = self._db[tablename].get(id)
        return (item, fields, tablename, id)

    def _select_except(self):
        (item, fields, tablename, id) = self._getitem_exception()
        if not item:
            # an empty Rows rather than a bare list, so the result of
            # select() offers the same methods whether or not a record exists
            return Rows(self._db, [], *self.colnames)
        new_item = []
        for t in fields:
            if t == 'id':
                # the id is part of the storage key, not of the stored value
                new_item.append(long(id))
            else:
                new_item.append(getattr(item, t))
        r = [new_item]
        return Rows(self._db, r, *self.colnames)

    def select(self, *fields, **attributes):
        """
        Always returns a Rows object, even if it may be empty

        The arguments are ignored.  Raises SyntaxError unless the set was
        built from an "id equals value" query.
        """
        if isinstance(self.where, QueryException):
            return self._select_except()
        else:
            raise SyntaxError('select arguments not supported')

    def count(self):
        return len(self.select())

    def delete(self):
        """Delete the record of the set, if it exists."""
        if isinstance(self.where, QueryException):
            (item, fields, tablename, id) = self._getitem_exception()
            if not item:
                return
            self._db[tablename].delete(id)
        else:
            raise Exception('deletion not implemented')

    def update(self, **update_fields):
        """Overwrite the given fields of the record of the set, if it exists."""
        if isinstance(self.where, QueryException):
            (item, fields, tablename, id) = self._getitem_exception()
            if not item:
                return
            for (key, value) in update_fields.items():
                setattr(item, key, value)
            self._db[tablename].update(id, **item)
        else:
            raise Exception('update not implemented')
def update_record(t, s, id, a):
    """
    Apply the changes in a to one record.

    t  -- the in-memory row, updated item by item
    s  -- the table object; its get() supplies the stored record and its
          update() writes the modified record back
    id -- id of the record
    a  -- mapping of field name to new value
    """
    stored = s.get(id)
    for key in a:
        new_value = a[key]
        t[key] = new_value
        setattr(stored, key, new_value)
    s.update(id, **stored)
class Rows(object):
    """
    A wrapper for the return value of a select. It basically represents a table.
    It has an iterator and each row is represented as a dictionary.

    response is a list of rows, each a sequence of values in the order of
    colnames; colnames are 'table.field' strings.
    """

    # ## this class still needs some work to care for ID/OID

    def __init__(
        self,
        db,
        response,
        *colnames
        ):
        self._db = db
        self.colnames = colnames
        self.response = response

    def __len__(self):
        return len(self.response)

    @staticmethod
    def _ymd(text):
        # 'YYYY-MM-DD' -> (year, month, day)
        (y, m, d) = [int(x) for x in text.strip().split('-')]
        return (y, m, d)

    @staticmethod
    def _hms(text):
        # 'HH:MM' or 'HH:MM:SS' -> (hour, minute, second)
        time_items = [int(x) for x in text.strip().split(':')[:3]]
        if len(time_items) != 3:
            time_items = time_items + [0]
        (h, mi, s) = time_items
        return (h, mi, s)

    def __getitem__(self, i):
        """
        Return row i as a DALStorage.

        Values are grouped per table (row[tablename][fieldname]); when only
        one table is involved, that table's storage is returned directly.
        Columns whose name is not of the form 'table.field' are collected
        under row['_extra'].  Text stored for boolean, date, time and
        datetime fields is converted to the matching Python type.

        Raises SyntaxError for an index outside 0..len-1 (negative indexes
        are not supported) and when the width of the data does not match
        the number of column names.
        """
        if i >= len(self.response) or i < 0:
            raise SyntaxError('Rows: no such row: %i' % i)
        if len(self.response[0]) != len(self.colnames):
            raise SyntaxError('Rows: internal error')
        row = DALStorage()
        for (j, colname) in enumerate(self.colnames):
            value = self.response[i][j]
            if isinstance(value, unicode):
                value = value.encode('utf-8')
            packed = colname.split('.')
            try:
                (tablename, fieldname) = packed
            except ValueError:
                # not exactly 'table.field': keep the value under its
                # full column name
                if not '_extra' in row:
                    row['_extra'] = DALStorage()
                row['_extra'][colname] = value
                continue
            table = self._db[tablename]
            field = table[fieldname]
            if not tablename in row:
                row[tablename] = DALStorage()
            if field.type[:9] == 'reference':
                converted = value
            elif field.type == 'boolean' and value is not None:
                converted = bool(value == True or value == 'T')
            elif field.type == 'date' and value is not None\
                 and not isinstance(value, datetime.date):
                converted = datetime.date(*self._ymd(str(value)))
            elif field.type == 'time' and value is not None\
                 and not isinstance(value, datetime.time):
                converted = datetime.time(*self._hms(str(value)))
            elif field.type == 'datetime' and value is not None\
                 and not isinstance(value, datetime.datetime):
                text = str(value)
                converted = datetime.datetime(
                    *(self._ymd(text[:10]) + self._hms(text[11:])))
            else:
                converted = value
            row[tablename][fieldname] = converted
            if fieldname == 'id':
                id = row[tablename].id
                # bind the current row, table and id as defaults so the
                # callable keeps referring to them after the loop moves on
                row[tablename].update_record = lambda t = row[tablename], \
                    s = self._db[tablename], id = id, **a: update_record(t,
                        s, id, a)
                for (referee_table, referee_name) in \
                    table._referenced_by:
                    s = self._db[referee_table][referee_name]
                    row[tablename][referee_table] = Set(self._db, s
                                                        == id)
        if len(row) == 1:
            return list(row.values())[0]
        return row

    def __iter__(self):
        """
        iterator over records
        """
        for i in xrange(len(self)):
            yield self[i]

    def __str__(self):
        """
        serializes the table into a csv file
        """
        s = cStringIO.StringIO()
        writer = csv.writer(s)
        writer.writerow(self.colnames)
        c = len(self.colnames)
        for i in xrange(len(self)):
            row = [self.response[i][j] for j in xrange(c)]
            for k in xrange(c):
                if isinstance(row[k], unicode):
                    row[k] = row[k].encode('utf-8')
            writer.writerow(row)
        return s.getvalue()

    def xml(self):
        """
        serializes the table using SQLTABLE (if present)
        """
        return SQLTABLE(self).xml()
def test_all():
    """
    How to run from web2py dir:
     export PYTHONPATH=.:YOUR_PLATFORMS_APPENGINE_PATH
     python gluon/contrib/memdb.py

    Setup the UTC timezone and database stubs

    >>> import os
    >>> os.environ['TZ'] = 'UTC'
    >>> import time
    >>> if hasattr(time, 'tzset'):
    ...     time.tzset()
    >>>
    >>> from google.appengine.api import apiproxy_stub_map
    >>> from google.appengine.api.memcache import memcache_stub
    >>> apiproxy_stub_map.apiproxy = apiproxy_stub_map.APIProxyStubMap()
    >>> apiproxy_stub_map.apiproxy.RegisterStub('memcache', memcache_stub.MemcacheServiceStub())

    Create a table with all possible field types

    >>> from google.appengine.api.memcache import Client
    >>> db=MEMDB(Client())
    >>> tmp=db.define_table('users', Field('stringf','string',length=32,required=True), Field('booleanf','boolean',default=False), Field('passwordf','password',notnull=True), Field('blobf','blob'), Field('uploadf','upload'), Field('integerf','integer',unique=True), Field('doublef','double',unique=True,notnull=True), Field('datef','date',default=datetime.date.today()), Field('timef','time'), Field('datetimef','datetime'), migrate='test_user.table')

    Insert a field

    >>> user_id = db.users.insert(stringf='a',booleanf=True,passwordf='p',blobf='0A', uploadf=None, integerf=5,doublef=3.14, datef=datetime.date(2001,1,1), timef=datetime.time(12,30,15), datetimef=datetime.datetime(2002,2,2,12,30,15))
    >>> user_id != None
    True

    Select all

    # >>> all = db().select(db.users.ALL)

    Drop the table

    # >>> db.users.drop()

    Select many entities

    >>> tmp = db.define_table(\"posts\", Field('body','text'), Field('total','integer'), Field('created_at','datetime'))
    >>> many = 20 #2010 # more than 1000 single fetch limit (it can be slow)
    >>> few = 5
    >>> most = many - few
    >>> 0 < few < most < many
    True
    >>> for i in range(many):
    ...     f=db.posts.insert(body='', total=i,created_at=datetime.datetime(2008, 7, 6, 14, 15, 42, i))
    >>>

    # test timezones

    >>> class TZOffset(datetime.tzinfo):
    ...     def __init__(self,offset=0):
    ...         self.offset = offset
    ...     def utcoffset(self, dt): return datetime.timedelta(hours=self.offset)
    ...     def dst(self, dt): return datetime.timedelta(0)
    ...     def tzname(self, dt): return 'UTC' + str(self.offset)
    ...
    >>> SERVER_OFFSET = -8
    >>>
    >>> stamp = datetime.datetime(2008, 7, 6, 14, 15, 42, 828201)
    >>> post_id = db.posts.insert(created_at=stamp,body='body1')
    >>> naive_stamp = db(db.posts.id==post_id).select()[0].created_at
    >>> utc_stamp=naive_stamp.replace(tzinfo=TZOffset())
    >>> server_stamp = utc_stamp.astimezone(TZOffset(SERVER_OFFSET))
    >>> stamp == naive_stamp
    True
    >>> utc_stamp == server_stamp
    True
    >>> rows = db(db.posts.id==post_id).select()
    >>> len(rows) == 1
    True
    >>> rows[0].body == 'body1'
    True
    >>> db(db.posts.id==post_id).delete()
    >>> rows = db(db.posts.id==post_id).select()
    >>> len(rows) == 0
    True

    >>> id = db.posts.insert(total='0')   # coerce str to integer
    >>> rows = db(db.posts.id==id).select()
    >>> len(rows) == 1
    True
    >>> rows[0].total == 0
    True

    Examples of insert, select, update, delete

    >>> tmp=db.define_table('person', Field('name'), Field('birth','date'), migrate='test_person.table')
    >>> marco_id=db.person.insert(name=\"Marco\",birth='2005-06-22')
    >>> person_id=db.person.insert(name=\"Massimo\",birth='1971-12-21')
    >>> me=db(db.person.id==person_id).select()[0] # test select
    >>> me.name
    'Massimo'
    >>> db(db.person.id==person_id).update(name='massimo') # test update
    >>> me = db(db.person.id==person_id).select()[0]
    >>> me.name
    'massimo'
    >>> str(me.birth)
    '1971-12-21'

    # resave date to ensure it comes back the same

    >>> me=db(db.person.id==person_id).update(birth=me.birth) # test update
    >>> me = db(db.person.id==person_id).select()[0]
    >>> me.birth
    datetime.date(1971, 12, 21)
    >>> db(db.person.id==marco_id).delete() # test delete
    >>> len(db(db.person.id==marco_id).select())
    0

    Update a single record

    >>> me.update_record(name=\"Max\")
    >>> me.name
    'Max'
    >>> me = db(db.person.id == person_id).select()[0]
    >>> me.name
    'Max'
    """
# aliases: the same classes exposed under SQL*-prefixed names
SQLField = Field
SQLTable = Table
SQLXorable = Expression
SQLQuery = Query
SQLSet = Set
SQLRows = Rows
SQLStorage = DALStorage

# when executed as a script, run the doctests found in this module
if __name__ == '__main__':
    import doctest
    doctest.testmod()