Ви не можете вибрати більше 25 тем Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.
 
 
 
 
 
 

383 рядки
12 KiB

  1. import re
  2. from typing import List, Tuple, Union
  3. import psycopg2
  4. import psycopg2.extensions
  5. from psycopg2.extensions import ISOLATION_LEVEL_REPEATABLE_READ
  6. from psycopg2.errorcodes import STRING_DATA_RIGHT_TRUNCATION
  7. import frappe
  8. from frappe.database.database import Database
  9. from frappe.database.postgres.schema import PostgresTable
  10. from frappe.utils import cstr, get_table_name
  11. # cast decimals as floats
  12. DEC2FLOAT = psycopg2.extensions.new_type(
  13. psycopg2.extensions.DECIMAL.values,
  14. 'DEC2FLOAT',
  15. lambda value, curs: float(value) if value is not None else None)
  16. psycopg2.extensions.register_type(DEC2FLOAT)
  17. class PostgresDatabase(Database):
  18. ProgrammingError = psycopg2.ProgrammingError
  19. TableMissingError = psycopg2.ProgrammingError
  20. OperationalError = psycopg2.OperationalError
  21. InternalError = psycopg2.InternalError
  22. SQLError = psycopg2.ProgrammingError
  23. DataError = psycopg2.DataError
  24. InterfaceError = psycopg2.InterfaceError
  25. REGEX_CHARACTER = '~'
  26. def setup_type_map(self):
  27. self.db_type = 'postgres'
  28. self.type_map = {
  29. 'Currency': ('decimal', '21,9'),
  30. 'Int': ('bigint', None),
  31. 'Long Int': ('bigint', None),
  32. 'Float': ('decimal', '21,9'),
  33. 'Percent': ('decimal', '21,9'),
  34. 'Check': ('smallint', None),
  35. 'Small Text': ('text', ''),
  36. 'Long Text': ('text', ''),
  37. 'Code': ('text', ''),
  38. 'Text Editor': ('text', ''),
  39. 'Markdown Editor': ('text', ''),
  40. 'HTML Editor': ('text', ''),
  41. 'Date': ('date', ''),
  42. 'Datetime': ('timestamp', None),
  43. 'Time': ('time', '6'),
  44. 'Text': ('text', ''),
  45. 'Data': ('varchar', self.VARCHAR_LEN),
  46. 'Link': ('varchar', self.VARCHAR_LEN),
  47. 'Dynamic Link': ('varchar', self.VARCHAR_LEN),
  48. 'Password': ('text', ''),
  49. 'Select': ('varchar', self.VARCHAR_LEN),
  50. 'Rating': ('decimal', '3,2'),
  51. 'Read Only': ('varchar', self.VARCHAR_LEN),
  52. 'Attach': ('text', ''),
  53. 'Attach Image': ('text', ''),
  54. 'Signature': ('text', ''),
  55. 'Color': ('varchar', self.VARCHAR_LEN),
  56. 'Barcode': ('text', ''),
  57. 'Geolocation': ('text', ''),
  58. 'Duration': ('decimal', '21,9'),
  59. 'Icon': ('varchar', self.VARCHAR_LEN),
  60. 'Autocomplete': ('varchar', self.VARCHAR_LEN),
  61. }
  62. def get_connection(self):
  63. conn = psycopg2.connect("host='{}' dbname='{}' user='{}' password='{}' port={}".format(
  64. self.host, self.user, self.user, self.password, self.port
  65. ))
  66. conn.set_isolation_level(ISOLATION_LEVEL_REPEATABLE_READ)
  67. return conn
  68. def escape(self, s, percent=True):
  69. """Escape quotes and percent in given string."""
  70. if isinstance(s, bytes):
  71. s = s.decode('utf-8')
  72. # MariaDB's driver treats None as an empty string
  73. # So Postgres should do the same
  74. if s is None:
  75. s = ''
  76. if percent:
  77. s = s.replace("%", "%%")
  78. s = s.encode('utf-8')
  79. return str(psycopg2.extensions.QuotedString(s))
  80. def get_database_size(self):
  81. ''''Returns database size in MB'''
  82. db_size = self.sql("SELECT (pg_database_size(%s) / 1024 / 1024) as database_size",
  83. self.db_name, as_dict=True)
  84. return db_size[0].get('database_size')
  85. # pylint: disable=W0221
  86. def sql(self, query, values=(), *args, **kwargs):
  87. return super(PostgresDatabase, self).sql(
  88. modify_query(query),
  89. modify_values(values),
  90. *args,
  91. **kwargs
  92. )
  93. def get_tables(self, cached=True):
  94. return [d[0] for d in self.sql("""select table_name
  95. from information_schema.tables
  96. where table_catalog='{0}'
  97. and table_type = 'BASE TABLE'
  98. and table_schema='{1}'""".format(frappe.conf.db_name, frappe.conf.get("db_schema", "public")))]
  99. def format_date(self, date):
  100. if not date:
  101. return '0001-01-01'
  102. if not isinstance(date, str):
  103. date = date.strftime('%Y-%m-%d')
  104. return date
  105. # column type
  106. @staticmethod
  107. def is_type_number(code):
  108. return code == psycopg2.NUMBER
  109. @staticmethod
  110. def is_type_datetime(code):
  111. return code == psycopg2.DATETIME
  112. # exception type
  113. @staticmethod
  114. def is_deadlocked(e):
  115. return e.pgcode == '40P01'
  116. @staticmethod
  117. def is_timedout(e):
  118. # http://initd.org/psycopg/docs/extensions.html?highlight=datatype#psycopg2.extensions.QueryCanceledError
  119. return isinstance(e, psycopg2.extensions.QueryCanceledError)
  120. @staticmethod
  121. def is_syntax_error(e):
  122. return isinstance(e, psycopg2.errors.SyntaxError)
  123. @staticmethod
  124. def is_table_missing(e):
  125. return getattr(e, 'pgcode', None) == '42P01'
  126. @staticmethod
  127. def is_missing_table(e):
  128. return PostgresDatabase.is_table_missing(e)
  129. @staticmethod
  130. def is_missing_column(e):
  131. return getattr(e, 'pgcode', None) == '42703'
  132. @staticmethod
  133. def is_access_denied(e):
  134. return e.pgcode == '42501'
  135. @staticmethod
  136. def cant_drop_field_or_key(e):
  137. return e.pgcode.startswith('23')
  138. @staticmethod
  139. def is_duplicate_entry(e):
  140. return e.pgcode == '23505'
  141. @staticmethod
  142. def is_primary_key_violation(e):
  143. return getattr(e, "pgcode", None) == '23505' and '_pkey' in cstr(e.args[0])
  144. @staticmethod
  145. def is_unique_key_violation(e):
  146. return getattr(e, "pgcode", None) == '23505' and '_key' in cstr(e.args[0])
  147. @staticmethod
  148. def is_duplicate_fieldname(e):
  149. return e.pgcode == '42701'
  150. @staticmethod
  151. def is_data_too_long(e):
  152. return e.pgcode == STRING_DATA_RIGHT_TRUNCATION
  153. def rename_table(self, old_name: str, new_name: str) -> Union[List, Tuple]:
  154. old_name = get_table_name(old_name)
  155. new_name = get_table_name(new_name)
  156. return self.sql(f"ALTER TABLE `{old_name}` RENAME TO `{new_name}`")
  157. def describe(self, doctype: str)-> Union[List, Tuple]:
  158. table_name = get_table_name(doctype)
  159. return self.sql(f"SELECT COLUMN_NAME FROM information_schema.COLUMNS WHERE TABLE_NAME = '{table_name}'")
  160. def change_column_type(self, doctype: str, column: str, type: str, nullable: bool = False) -> Union[List, Tuple]:
  161. table_name = get_table_name(doctype)
  162. null_constraint = "SET NOT NULL" if not nullable else "DROP NOT NULL"
  163. return self.sql(f"""ALTER TABLE "{table_name}"
  164. ALTER COLUMN "{column}" TYPE {type},
  165. ALTER COLUMN "{column}" {null_constraint}""")
  166. def create_auth_table(self):
  167. self.sql_ddl("""create table if not exists "__Auth" (
  168. "doctype" VARCHAR(140) NOT NULL,
  169. "name" VARCHAR(255) NOT NULL,
  170. "fieldname" VARCHAR(140) NOT NULL,
  171. "password" TEXT NOT NULL,
  172. "encrypted" INT NOT NULL DEFAULT 0,
  173. PRIMARY KEY ("doctype", "name", "fieldname")
  174. )""")
  175. def create_global_search_table(self):
  176. if not '__global_search' in self.get_tables():
  177. self.sql('''create table "__global_search"(
  178. doctype varchar(100),
  179. name varchar({0}),
  180. title varchar({0}),
  181. content text,
  182. route varchar({0}),
  183. published int not null default 0,
  184. unique (doctype, name))'''.format(self.VARCHAR_LEN))
  185. def create_user_settings_table(self):
  186. self.sql_ddl("""create table if not exists "__UserSettings" (
  187. "user" VARCHAR(180) NOT NULL,
  188. "doctype" VARCHAR(180) NOT NULL,
  189. "data" TEXT,
  190. UNIQUE ("user", "doctype")
  191. )""")
  192. def create_help_table(self):
  193. self.sql('''CREATE TABLE "help"(
  194. "path" varchar(255),
  195. "content" text,
  196. "title" text,
  197. "intro" text,
  198. "full_path" text)''')
  199. self.sql('''CREATE INDEX IF NOT EXISTS "help_index" ON "help" ("path")''')
  200. def updatedb(self, doctype, meta=None):
  201. """
  202. Syncs a `DocType` to the table
  203. * creates if required
  204. * updates columns
  205. * updates indices
  206. """
  207. res = self.sql("select issingle from `tabDocType` where name='{}'".format(doctype))
  208. if not res:
  209. raise Exception('Wrong doctype {0} in updatedb'.format(doctype))
  210. if not res[0][0]:
  211. db_table = PostgresTable(doctype, meta)
  212. db_table.validate()
  213. self.commit()
  214. db_table.sync()
  215. self.begin()
  216. @staticmethod
  217. def get_on_duplicate_update(key='name'):
  218. if isinstance(key, list):
  219. key = '", "'.join(key)
  220. return 'ON CONFLICT ("{key}") DO UPDATE SET '.format(
  221. key=key
  222. )
  223. def check_implicit_commit(self, query):
  224. pass # postgres can run DDL in transactions without implicit commits
  225. def has_index(self, table_name, index_name):
  226. return self.sql("""SELECT 1 FROM pg_indexes WHERE tablename='{table_name}'
  227. and indexname='{index_name}' limit 1""".format(table_name=table_name, index_name=index_name))
  228. def add_index(self, doctype: str, fields: List, index_name: str = None):
  229. """Creates an index with given fields if not already created.
  230. Index name will be `fieldname1_fieldname2_index`"""
  231. table_name = get_table_name(doctype)
  232. index_name = index_name or self.get_index_name(fields)
  233. fields_str = '", "'.join(re.sub(r"\(.*\)", "", field) for field in fields)
  234. self.sql_ddl(f'CREATE INDEX IF NOT EXISTS "{index_name}" ON `{table_name}` ("{fields_str}")')
  235. def add_unique(self, doctype, fields, constraint_name=None):
  236. if isinstance(fields, str):
  237. fields = [fields]
  238. if not constraint_name:
  239. constraint_name = "unique_" + "_".join(fields)
  240. if not self.sql("""
  241. SELECT CONSTRAINT_NAME
  242. FROM information_schema.TABLE_CONSTRAINTS
  243. WHERE table_name=%s
  244. AND constraint_type='UNIQUE'
  245. AND CONSTRAINT_NAME=%s""",
  246. ('tab' + doctype, constraint_name)):
  247. self.commit()
  248. self.sql("""ALTER TABLE `tab%s`
  249. ADD CONSTRAINT %s UNIQUE (%s)""" % (doctype, constraint_name, ", ".join(fields)))
  250. def get_table_columns_description(self, table_name):
  251. """Returns list of column and its description"""
  252. # pylint: disable=W1401
  253. return self.sql('''
  254. SELECT a.column_name AS name,
  255. CASE LOWER(a.data_type)
  256. WHEN 'character varying' THEN CONCAT('varchar(', a.character_maximum_length ,')')
  257. WHEN 'timestamp without time zone' THEN 'timestamp'
  258. ELSE a.data_type
  259. END AS type,
  260. BOOL_OR(b.index) AS index,
  261. SPLIT_PART(COALESCE(a.column_default, NULL), '::', 1) AS default,
  262. BOOL_OR(b.unique) AS unique
  263. FROM information_schema.columns a
  264. LEFT JOIN
  265. (SELECT indexdef, tablename,
  266. indexdef LIKE '%UNIQUE INDEX%' AS unique,
  267. indexdef NOT LIKE '%UNIQUE INDEX%' AS index
  268. FROM pg_indexes
  269. WHERE tablename='{table_name}') b
  270. ON SUBSTRING(b.indexdef, '(.*)') LIKE CONCAT('%', a.column_name, '%')
  271. WHERE a.table_name = '{table_name}'
  272. GROUP BY a.column_name, a.data_type, a.column_default, a.character_maximum_length;
  273. '''.format(table_name=table_name), as_dict=1)
  274. def get_database_list(self, target):
  275. return [d[0] for d in self.sql("SELECT datname FROM pg_database;")]
  276. def modify_query(query):
  277. """"Modifies query according to the requirements of postgres"""
  278. # replace ` with " for definitions
  279. query = str(query)
  280. query = query.replace('`', '"')
  281. query = replace_locate_with_strpos(query)
  282. # select from requires ""
  283. if re.search('from tab', query, flags=re.IGNORECASE):
  284. query = re.sub(r'from tab([\w-]*)', r'from "tab\1"', query, flags=re.IGNORECASE)
  285. # only find int (with/without signs), ignore decimals (with/without signs), ignore hashes (which start with numbers),
  286. # drop .0 from decimals and add quotes around them
  287. #
  288. # >>> query = "c='abcd' , a >= 45, b = -45.0, c = 40, d=4500.0, e=3500.53, f=40psdfsd, g=9092094312, h=12.00023"
  289. # >>> re.sub(r"([=><]+)\s*(?!\d+[a-zA-Z])(?![+-]?\d+\.\d\d+)([+-]?\d+)(\.0)?", r"\1 '\2'", query)
  290. # "c='abcd' , a >= '45', b = '-45', c = '40', d= '4500', e=3500.53, f=40psdfsd, g= '9092094312', h=12.00023
  291. query = re.sub(r"([=><]+)\s*(?!\d+[a-zA-Z])(?![+-]?\d+\.\d\d+)([+-]?\d+)(\.0)?", r"\1 '\2'", query)
  292. return query
  293. def modify_values(values):
  294. def stringify_value(value):
  295. if isinstance(value, int):
  296. value = str(value)
  297. elif isinstance(value, float):
  298. truncated_float = int(value)
  299. if value == truncated_float:
  300. value = str(truncated_float)
  301. return value
  302. if not values:
  303. return values
  304. if isinstance(values, dict):
  305. for k, v in values.items():
  306. values[k] = stringify_value(v)
  307. elif isinstance(values, (tuple, list)):
  308. new_values = []
  309. for val in values:
  310. new_values.append(stringify_value(val))
  311. values = new_values
  312. else:
  313. values = stringify_value(values)
  314. return values
  315. def replace_locate_with_strpos(query):
  316. # strpos is the locate equivalent in postgres
  317. if re.search(r'locate\(', query, flags=re.IGNORECASE):
  318. query = re.sub(r'locate\(([^,]+),([^)]+)(\)?)\)', r'strpos(\2\3, \1)', query, flags=re.IGNORECASE)
  319. return query