You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

503 lines
15 KiB

  1. # Copyright (c) 2013, Web Notes Technologies Pvt. Ltd.
  2. # MIT License. See license.txt
  3. # Database Module
  4. # --------------------
  5. from __future__ import unicode_literals
  6. import MySQLdb
  7. import warnings
  8. import webnotes
  9. from webnotes import conf
  10. import datetime
  11. class Database:
  12. """
  13. Open a database connection with the given parmeters, if use_default is True, use the
  14. login details from `conf.py`. This is called by the request handler and is accessible using
  15. the `conn` global variable. the `sql` method is also global to run queries
  16. """
  17. def __init__(self, host=None, user=None, password=None, ac_name=None, use_default = 0):
  18. self.host = host or conf.db_host or 'localhost'
  19. self.user = user or conf.db_name
  20. if ac_name:
  21. self.user = self.get_db_login(ac_name) or conf.db_name
  22. if use_default:
  23. self.user = conf.db_name
  24. self.transaction_writes = 0
  25. self.auto_commit_on_many_writes = 0
  26. self.password = password or webnotes.get_db_password(self.user)
  27. self.connect()
  28. if self.user != 'root':
  29. self.use(self.user)
  30. def get_db_login(self, ac_name):
  31. return ac_name
  32. def connect(self):
  33. """
  34. Connect to a database
  35. """
  36. warnings.filterwarnings('ignore', category=MySQLdb.Warning)
  37. self._conn = MySQLdb.connect(user=self.user, host=self.host, passwd=self.password,
  38. use_unicode=True, charset='utf8')
  39. self._conn.converter[246]=float
  40. self._cursor = self._conn.cursor()
  41. webnotes.local.rollback_observers = []
  42. def use(self, db_name):
  43. """
  44. `USE` db_name
  45. """
  46. self._conn.select_db(db_name)
  47. self.cur_db_name = db_name
  48. def validate_query(self, q):
  49. cmd = q.strip().lower().split()[0]
  50. if cmd in ['alter', 'drop', 'truncate'] and webnotes.user.name != 'Administrator':
  51. webnotes.msgprint('Not allowed to execute query')
  52. raise Exception
  53. def sql(self, query, values=(), as_dict = 0, as_list = 0, formatted = 0,
  54. debug=0, ignore_ddl=0, as_utf8=0, auto_commit=0, update=None):
  55. """
  56. * Execute a `query`, with given `values`
  57. * returns as a dictionary if as_dict = 1
  58. * returns as a list of lists (with cleaned up dates) if as_list = 1
  59. """
  60. # in transaction validations
  61. self.check_transaction_status(query)
  62. # autocommit
  63. if auto_commit: self.commit()
  64. # execute
  65. try:
  66. if values!=():
  67. if isinstance(values, dict):
  68. values = dict(values)
  69. if debug:
  70. try:
  71. self.explain_query(query, values)
  72. webnotes.errprint(query % values)
  73. except TypeError:
  74. webnotes.errprint([query, values])
  75. if (conf.get("logging") or False)==2:
  76. webnotes.log("<<<< query")
  77. webnotes.log(query)
  78. webnotes.log("with values:")
  79. webnotes.log(values)
  80. webnotes.log(">>>>")
  81. self._cursor.execute(query, values)
  82. else:
  83. if debug:
  84. self.explain_query(query)
  85. webnotes.errprint(query)
  86. if (conf.get("logging") or False)==2:
  87. webnotes.log("<<<< query")
  88. webnotes.log(query)
  89. webnotes.log(">>>>")
  90. self._cursor.execute(query)
  91. except Exception, e:
  92. # ignore data definition errors
  93. if ignore_ddl and e.args[0] in (1146,1054,1091):
  94. pass
  95. else:
  96. raise
  97. if auto_commit: self.commit()
  98. # scrub output if required
  99. if as_dict:
  100. ret = self.fetch_as_dict(formatted, as_utf8)
  101. if update:
  102. for r in ret:
  103. r.update(update)
  104. return ret
  105. elif as_list:
  106. return self.convert_to_lists(self._cursor.fetchall(), formatted, as_utf8)
  107. elif as_utf8:
  108. return self.convert_to_lists(self._cursor.fetchall(), formatted, as_utf8)
  109. else:
  110. return self._cursor.fetchall()
  111. def explain_query(self, query, values=None):
  112. try:
  113. webnotes.errprint("--- query explain ---")
  114. if values is None:
  115. self._cursor.execute("explain " + query)
  116. else:
  117. self._cursor.execute("explain " + query, values)
  118. import json
  119. webnotes.errprint(json.dumps(self.fetch_as_dict(), indent=1))
  120. webnotes.errprint("--- query explain end ---")
  121. except:
  122. webnotes.errprint("error in query explain")
  123. def sql_list(self, query, values=(), debug=False):
  124. return [r[0] for r in self.sql(query, values, debug=debug)]
  125. def sql_ddl(self, query, values=()):
  126. self.commit()
  127. self.sql(query)
  128. def check_transaction_status(self, query):
  129. if self.transaction_writes and query and query.strip().split()[0].lower() in ['start', 'alter', 'drop', 'create', "begin"]:
  130. raise Exception, 'This statement can cause implicit commit'
  131. if query and query.strip().lower() in ('commit', 'rollback'):
  132. self.transaction_writes = 0
  133. if query[:6].lower() in ['update', 'insert']:
  134. self.transaction_writes += 1
  135. if not webnotes.flags.in_test and self.transaction_writes > 10000:
  136. if self.auto_commit_on_many_writes:
  137. webnotes.conn.commit()
  138. webnotes.conn.begin()
  139. else:
  140. webnotes.msgprint('A very long query was encountered. If you are trying to import data, please do so using smaller files')
  141. raise Exception, 'Bad Query!!! Too many writes'
  142. def fetch_as_dict(self, formatted=0, as_utf8=0):
  143. result = self._cursor.fetchall()
  144. ret = []
  145. needs_formatting = self.needs_formatting(result, formatted)
  146. for r in result:
  147. row_dict = webnotes._dict({})
  148. for i in range(len(r)):
  149. if needs_formatting:
  150. val = self.convert_to_simple_type(r[i], formatted)
  151. else:
  152. val = r[i]
  153. if as_utf8 and type(val) is unicode:
  154. val = val.encode('utf-8')
  155. row_dict[self._cursor.description[i][0]] = val
  156. ret.append(row_dict)
  157. return ret
  158. def needs_formatting(self, result, formatted):
  159. if result and result[0]:
  160. for v in result[0]:
  161. if isinstance(v, (datetime.date, datetime.timedelta, datetime.datetime, long)):
  162. return True
  163. if formatted and isinstance(v, (int, float)):
  164. return True
  165. return False
  166. def get_description(self):
  167. return self._cursor.description
  168. def convert_to_simple_type(self, v, formatted=0):
  169. from webnotes.utils import formatdate, fmt_money
  170. if isinstance(v, (datetime.date, datetime.timedelta, datetime.datetime, long)):
  171. if isinstance(v, datetime.date):
  172. v = unicode(v)
  173. if formatted:
  174. v = formatdate(v)
  175. # time
  176. elif isinstance(v, (datetime.timedelta, datetime.datetime)):
  177. v = unicode(v)
  178. # long
  179. elif isinstance(v, long):
  180. v=int(v)
  181. # convert to strings... (if formatted)
  182. if formatted:
  183. if isinstance(v, float):
  184. v=fmt_money(v)
  185. elif isinstance(v, int):
  186. v = unicode(v)
  187. return v
  188. def convert_to_lists(self, res, formatted=0, as_utf8=0):
  189. nres = []
  190. needs_formatting = self.needs_formatting(res, formatted)
  191. for r in res:
  192. nr = []
  193. for c in r:
  194. if needs_formatting:
  195. val = self.convert_to_simple_type(c, formatted)
  196. else:
  197. val = c
  198. if as_utf8 and type(val) is unicode:
  199. val = val.encode('utf-8')
  200. nr.append(val)
  201. nres.append(nr)
  202. return nres
  203. def convert_to_utf8(self, res, formatted=0):
  204. nres = []
  205. for r in res:
  206. nr = []
  207. for c in r:
  208. if type(c) is unicode:
  209. c = c.encode('utf-8')
  210. nr.append(self.convert_to_simple_type(c, formatted))
  211. nres.append(nr)
  212. return nres
  213. def build_conditions(self, filters):
  214. def _build_condition(key):
  215. """
  216. filter's key is passed by map function
  217. build conditions like:
  218. * ifnull(`fieldname`, default_value) = %(fieldname)s
  219. * `fieldname` [=, !=, >, >=, <, <=] %(fieldname)s
  220. """
  221. _operator = "="
  222. value = filters.get(key)
  223. if isinstance(value, (list, tuple)):
  224. _operator = value[0]
  225. filters[key] = value[1]
  226. if _operator not in ["=", "!=", ">", ">=", "<", "<=", "like"]:
  227. _operator = "="
  228. if "[" in key:
  229. split_key = key.split("[")
  230. return "ifnull(`" + split_key[0] + "`, " + split_key[1][:-1] + ") " \
  231. + _operator + " %(" + key + ")s"
  232. else:
  233. return "`" + key + "` " + _operator + " %(" + key + ")s"
  234. if isinstance(filters, basestring):
  235. filters = { "name": filters }
  236. conditions = map(_build_condition, filters)
  237. return " and ".join(conditions), filters
  238. def get(self, doctype, filters=None, as_dict=True):
  239. return self.get_value(doctype, filters, "*", as_dict=as_dict)
  240. def get_value(self, doctype, filters=None, fieldname="name", ignore=None, as_dict=False, debug=False):
  241. """Get a single / multiple value from a record.
  242. For Single DocType, let filters be = None"""
  243. ret = self.get_values(doctype, filters, fieldname, ignore, as_dict, debug)
  244. return ret and ((len(ret[0]) > 1 or as_dict) and ret[0] or ret[0][0]) or None
  245. def get_values(self, doctype, filters=None, fieldname="name", ignore=None, as_dict=False, debug=False):
  246. if isinstance(filters, list):
  247. return self.get_value_for_many_names(doctype, filters, fieldname, debug=debug)
  248. fields = fieldname
  249. if fieldname!="*":
  250. if isinstance(fieldname, basestring):
  251. fields = [fieldname]
  252. else:
  253. fields = fieldname
  254. if (filters is not None) and (filters!=doctype or doctype=="DocType"):
  255. try:
  256. return self.get_values_from_table(fields, filters, doctype, as_dict, debug)
  257. except Exception, e:
  258. if ignore and e.args[0] in (1146, 1054):
  259. # table or column not found, return None
  260. return None
  261. elif (not ignore) and e.args[0]==1146:
  262. # table not found, look in singles
  263. pass
  264. else:
  265. raise
  266. return self.get_values_from_single(fields, filters, doctype, as_dict, debug)
  267. def get_values_from_single(self, fields, filters, doctype, as_dict=False, debug=False):
  268. if fields=="*" or isinstance(filters, dict):
  269. r = self.sql("""select field, value from tabSingles where doctype=%s""", doctype)
  270. # check if single doc matches with filters
  271. values = webnotes._dict(r)
  272. if isinstance(filters, dict):
  273. for key, value in filters.items():
  274. if values.get(key) != value:
  275. return []
  276. if as_dict:
  277. return values and [values] or []
  278. if isinstance(fields, list):
  279. return [map(lambda d: values.get(d), fields)]
  280. else:
  281. r = self.sql("""select field, value
  282. from tabSingles where field in (%s) and doctype=%s""" \
  283. % (', '.join(['%s'] * len(fields)), '%s'),
  284. tuple(fields) + (doctype,), as_dict=False, debug=debug)
  285. if as_dict:
  286. return r and [webnotes._dict(r)] or []
  287. else:
  288. return r and [[i[1] for i in r]] or []
  289. def get_values_from_table(self, fields, filters, doctype, as_dict, debug):
  290. fl = []
  291. if isinstance(fields, (list, tuple)):
  292. for f in fields:
  293. if "(" in f: # function
  294. fl.append(f)
  295. else:
  296. fl.append("`" + f + "`")
  297. fl = ", ".join(fields)
  298. else:
  299. fl = fields
  300. if fields=="*":
  301. as_dict = True
  302. conditions, filters = self.build_conditions(filters)
  303. r = self.sql("select %s from `tab%s` where %s" % (fl, doctype,
  304. conditions), filters, as_dict=as_dict, debug=debug)
  305. return r
  306. def get_value_for_many_names(self, doctype, names, field, debug=False):
  307. names = filter(None, names)
  308. if names:
  309. return dict(self.sql("select name, `%s` from `tab%s` where name in (%s)" \
  310. % (field, doctype, ", ".join(["%s"]*len(names))), names, debug=debug))
  311. else:
  312. return {}
  313. def set_value(self, dt, dn, field, val, modified=None, modified_by=None):
  314. from webnotes.utils import now
  315. if dn and dt!=dn:
  316. self.sql("""update `tab%s` set `%s`=%s, modified=%s, modified_by=%s
  317. where name=%s""" % (dt, field, "%s", "%s", "%s", "%s"),
  318. (val, modified or now(), modified_by or webnotes.session["user"], dn))
  319. else:
  320. if self.sql("select value from tabSingles where field=%s and doctype=%s", (field, dt)):
  321. self.sql("""update tabSingles set value=%s where field=%s and doctype=%s""",
  322. (val, field, dt))
  323. else:
  324. self.sql("""insert into tabSingles(doctype, field, value)
  325. values (%s, %s, %s)""", (dt, field, val, ))
  326. if field!="modified":
  327. self.set_value(dt, dn, "modified", modified or now())
  328. def set_in_doc(self, doc, field, val):
  329. self.set(doc, field, val)
  330. def set(self, doc, field, val):
  331. from webnotes.utils import now
  332. doc.modified = now()
  333. doc.modified_by = webnotes.session["user"]
  334. self.set_value(doc.doctype, doc.name, field, val, doc.modified, doc.modified_by)
  335. doc.fields[field] = val
  336. def touch(self, doctype, docname):
  337. from webnotes.utils import now
  338. webnotes.conn.sql("""update `tab{doctype}` set `modified`=%s
  339. where name=%s""".format(doctype=doctype), (now(), docname))
  340. def set_global(self, key, val, user='__global'):
  341. self.set_default(key, val, user)
  342. def get_global(self, key, user='__global'):
  343. return self.get_default(key, user)
  344. def set_default(self, key, val, parent="Control Panel"):
  345. """set control panel default (tabDefaultVal)"""
  346. import webnotes.defaults
  347. webnotes.defaults.set_default(key, val, parent)
  348. def add_default(self, key, val, parent="Control Panel"):
  349. import webnotes.defaults
  350. webnotes.defaults.add_default(key, val, parent)
  351. def get_default(self, key, parent="Control Panel"):
  352. """get default value"""
  353. import webnotes.defaults
  354. d = webnotes.defaults.get_defaults(parent).get(key)
  355. return isinstance(d, list) and d[0] or d
  356. def get_defaults_as_list(self, key, parent="Control Panel"):
  357. import webnotes.defaults
  358. d = webnotes.defaults.get_default(key, parent)
  359. return isinstance(d, basestring) and [d] or d
  360. def get_defaults(self, key=None, parent="Control Panel"):
  361. """get all defaults"""
  362. import webnotes.defaults
  363. if key:
  364. return webnotes.defaults.get_defaults(parent).get(key)
  365. else:
  366. return webnotes.defaults.get_defaults(parent)
  367. def begin(self):
  368. return # not required
  369. def commit(self):
  370. self.sql("commit")
  371. webnotes.local.rollback_observers = []
  372. def rollback(self):
  373. self.sql("rollback")
  374. for obj in webnotes.local.rollback_observers:
  375. if hasattr(obj, "on_rollback"):
  376. obj.on_rollback()
  377. webnotes.local.rollback_observers = []
  378. def field_exists(self, dt, fn):
  379. return self.sql("select name from tabDocField where fieldname=%s and parent=%s", (dt, fn))
  380. def table_exists(self, tablename):
  381. return tablename in [d[0] for d in self.sql("show tables")]
  382. def exists(self, dt, dn=None):
  383. if isinstance(dt, basestring):
  384. if dt==dn:
  385. return True # single always exists (!)
  386. try:
  387. return self.sql('select name from `tab%s` where name=%s' % (dt, '%s'), dn)
  388. except:
  389. return None
  390. elif isinstance(dt, dict) and dt.get('doctype'):
  391. try:
  392. conditions = []
  393. for d in dt:
  394. if d == 'doctype': continue
  395. conditions.append('`%s` = "%s"' % (d, dt[d].replace('"', '\"')))
  396. return self.sql('select name from `tab%s` where %s' % \
  397. (dt['doctype'], " and ".join(conditions)))
  398. except:
  399. return None
  400. def count(self, dt, filters=None):
  401. if filters:
  402. conditions, filters = self.build_conditions(filters)
  403. return webnotes.conn.sql("""select count(*)
  404. from `tab%s` where %s""" % (dt, conditions), filters)[0][0]
  405. else:
  406. return webnotes.conn.sql("""select count(*)
  407. from `tab%s`""" % (dt,))[0][0]
  408. def get_table_columns(self, doctype):
  409. return [r[0] for r in self.sql("DESC `tab%s`" % doctype)]
  410. def close(self):
  411. if self._conn:
  412. self._cursor.close()
  413. self._conn.close()
  414. self._cursor = None
  415. self._conn = None