Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.
 
 
 
 
 
 

346 Zeilen
11 KiB

  1. import re
  2. from typing import List, Tuple, Union
  3. import psycopg2
  4. import psycopg2.extensions
  5. from psycopg2.extensions import ISOLATION_LEVEL_REPEATABLE_READ
  6. from psycopg2.errorcodes import STRING_DATA_RIGHT_TRUNCATION
  7. import frappe
  8. from frappe.database.database import Database
  9. from frappe.database.postgres.schema import PostgresTable
  10. from frappe.utils import cstr, get_table_name
  11. # cast decimals as floats
  12. DEC2FLOAT = psycopg2.extensions.new_type(
  13. psycopg2.extensions.DECIMAL.values,
  14. 'DEC2FLOAT',
  15. lambda value, curs: float(value) if value is not None else None)
  16. psycopg2.extensions.register_type(DEC2FLOAT)
  17. class PostgresDatabase(Database):
  18. ProgrammingError = psycopg2.ProgrammingError
  19. TableMissingError = psycopg2.ProgrammingError
  20. OperationalError = psycopg2.OperationalError
  21. InternalError = psycopg2.InternalError
  22. SQLError = psycopg2.ProgrammingError
  23. DataError = psycopg2.DataError
  24. InterfaceError = psycopg2.InterfaceError
  25. REGEX_CHARACTER = '~'
  26. def setup_type_map(self):
  27. self.db_type = 'postgres'
  28. self.type_map = {
  29. 'Currency': ('decimal', '21,9'),
  30. 'Int': ('bigint', None),
  31. 'Long Int': ('bigint', None),
  32. 'Float': ('decimal', '21,9'),
  33. 'Percent': ('decimal', '21,9'),
  34. 'Check': ('smallint', None),
  35. 'Small Text': ('text', ''),
  36. 'Long Text': ('text', ''),
  37. 'Code': ('text', ''),
  38. 'Text Editor': ('text', ''),
  39. 'Markdown Editor': ('text', ''),
  40. 'HTML Editor': ('text', ''),
  41. 'Date': ('date', ''),
  42. 'Datetime': ('timestamp', None),
  43. 'Time': ('time', '6'),
  44. 'Text': ('text', ''),
  45. 'Data': ('varchar', self.VARCHAR_LEN),
  46. 'Link': ('varchar', self.VARCHAR_LEN),
  47. 'Dynamic Link': ('varchar', self.VARCHAR_LEN),
  48. 'Password': ('text', ''),
  49. 'Select': ('varchar', self.VARCHAR_LEN),
  50. 'Rating': ('decimal', '3,2'),
  51. 'Read Only': ('varchar', self.VARCHAR_LEN),
  52. 'Attach': ('text', ''),
  53. 'Attach Image': ('text', ''),
  54. 'Signature': ('text', ''),
  55. 'Color': ('varchar', self.VARCHAR_LEN),
  56. 'Barcode': ('text', ''),
  57. 'Geolocation': ('text', ''),
  58. 'Duration': ('decimal', '21,9'),
  59. 'Icon': ('varchar', self.VARCHAR_LEN)
  60. }
  61. def get_connection(self):
  62. conn = psycopg2.connect("host='{}' dbname='{}' user='{}' password='{}' port={}".format(
  63. self.host, self.user, self.user, self.password, self.port
  64. ))
  65. conn.set_isolation_level(ISOLATION_LEVEL_REPEATABLE_READ)
  66. return conn
  67. def escape(self, s, percent=True):
  68. """Escape quotes and percent in given string."""
  69. if isinstance(s, bytes):
  70. s = s.decode('utf-8')
  71. # MariaDB's driver treats None as an empty string
  72. # So Postgres should do the same
  73. if s is None:
  74. s = ''
  75. if percent:
  76. s = s.replace("%", "%%")
  77. s = s.encode('utf-8')
  78. return str(psycopg2.extensions.QuotedString(s))
  79. def get_database_size(self):
  80. ''''Returns database size in MB'''
  81. db_size = self.sql("SELECT (pg_database_size(%s) / 1024 / 1024) as database_size",
  82. self.db_name, as_dict=True)
  83. return db_size[0].get('database_size')
  84. # pylint: disable=W0221
  85. def sql(self, *args, **kwargs):
  86. if args:
  87. # since tuple is immutable
  88. args = list(args)
  89. args[0] = modify_query(args[0])
  90. args = tuple(args)
  91. elif kwargs.get('query'):
  92. kwargs['query'] = modify_query(kwargs.get('query'))
  93. return super(PostgresDatabase, self).sql(*args, **kwargs)
  94. def get_tables(self, cached=True):
  95. return [d[0] for d in self.sql("""select table_name
  96. from information_schema.tables
  97. where table_catalog='{0}'
  98. and table_type = 'BASE TABLE'
  99. and table_schema='{1}'""".format(frappe.conf.db_name, frappe.conf.get("db_schema", "public")))]
  100. def format_date(self, date):
  101. if not date:
  102. return '0001-01-01'
  103. if not isinstance(date, str):
  104. date = date.strftime('%Y-%m-%d')
  105. return date
  106. # column type
  107. @staticmethod
  108. def is_type_number(code):
  109. return code == psycopg2.NUMBER
  110. @staticmethod
  111. def is_type_datetime(code):
  112. return code == psycopg2.DATETIME
  113. # exception type
  114. @staticmethod
  115. def is_deadlocked(e):
  116. return e.pgcode == '40P01'
  117. @staticmethod
  118. def is_timedout(e):
  119. # http://initd.org/psycopg/docs/extensions.html?highlight=datatype#psycopg2.extensions.QueryCanceledError
  120. return isinstance(e, psycopg2.extensions.QueryCanceledError)
  121. @staticmethod
  122. def is_syntax_error(e):
  123. return isinstance(e, psycopg2.errors.SyntaxError)
  124. @staticmethod
  125. def is_table_missing(e):
  126. return getattr(e, 'pgcode', None) == '42P01'
  127. @staticmethod
  128. def is_missing_column(e):
  129. return getattr(e, 'pgcode', None) == '42703'
  130. @staticmethod
  131. def is_access_denied(e):
  132. return e.pgcode == '42501'
  133. @staticmethod
  134. def cant_drop_field_or_key(e):
  135. return e.pgcode.startswith('23')
  136. @staticmethod
  137. def is_duplicate_entry(e):
  138. return e.pgcode == '23505'
  139. @staticmethod
  140. def is_primary_key_violation(e):
  141. return e.pgcode == '23505' and '_pkey' in cstr(e.args[0])
  142. @staticmethod
  143. def is_unique_key_violation(e):
  144. return e.pgcode == '23505' and '_key' in cstr(e.args[0])
  145. @staticmethod
  146. def is_duplicate_fieldname(e):
  147. return e.pgcode == '42701'
  148. @staticmethod
  149. def is_data_too_long(e):
  150. return e.pgcode == STRING_DATA_RIGHT_TRUNCATION
  151. def rename_table(self, old_name: str, new_name: str) -> Union[List, Tuple]:
  152. old_name = get_table_name(old_name)
  153. new_name = get_table_name(new_name)
  154. return self.sql(f"ALTER TABLE `{old_name}` RENAME TO `{new_name}`")
  155. def describe(self, doctype: str)-> Union[List, Tuple]:
  156. table_name = get_table_name(doctype)
  157. return self.sql(f"SELECT COLUMN_NAME FROM information_schema.COLUMNS WHERE TABLE_NAME = '{table_name}'")
  158. def change_column_type(self, doctype: str, column: str, type: str, nullable: bool = False) -> Union[List, Tuple]:
  159. table_name = get_table_name(doctype)
  160. null_constraint = "SET NOT NULL" if not nullable else "DROP NOT NULL"
  161. return self.sql(f"""ALTER TABLE "{table_name}"
  162. ALTER COLUMN "{column}" TYPE {type},
  163. ALTER COLUMN "{column}" {null_constraint}""")
  164. def create_auth_table(self):
  165. self.sql_ddl("""create table if not exists "__Auth" (
  166. "doctype" VARCHAR(140) NOT NULL,
  167. "name" VARCHAR(255) NOT NULL,
  168. "fieldname" VARCHAR(140) NOT NULL,
  169. "password" TEXT NOT NULL,
  170. "encrypted" INT NOT NULL DEFAULT 0,
  171. PRIMARY KEY ("doctype", "name", "fieldname")
  172. )""")
  173. def create_global_search_table(self):
  174. if not '__global_search' in self.get_tables():
  175. self.sql('''create table "__global_search"(
  176. doctype varchar(100),
  177. name varchar({0}),
  178. title varchar({0}),
  179. content text,
  180. route varchar({0}),
  181. published int not null default 0,
  182. unique (doctype, name))'''.format(self.VARCHAR_LEN))
  183. def create_user_settings_table(self):
  184. self.sql_ddl("""create table if not exists "__UserSettings" (
  185. "user" VARCHAR(180) NOT NULL,
  186. "doctype" VARCHAR(180) NOT NULL,
  187. "data" TEXT,
  188. UNIQUE ("user", "doctype")
  189. )""")
  190. def create_help_table(self):
  191. self.sql('''CREATE TABLE "help"(
  192. "path" varchar(255),
  193. "content" text,
  194. "title" text,
  195. "intro" text,
  196. "full_path" text)''')
  197. self.sql('''CREATE INDEX IF NOT EXISTS "help_index" ON "help" ("path")''')
  198. def updatedb(self, doctype, meta=None):
  199. """
  200. Syncs a `DocType` to the table
  201. * creates if required
  202. * updates columns
  203. * updates indices
  204. """
  205. res = self.sql("select issingle from `tabDocType` where name='{}'".format(doctype))
  206. if not res:
  207. raise Exception('Wrong doctype {0} in updatedb'.format(doctype))
  208. if not res[0][0]:
  209. db_table = PostgresTable(doctype, meta)
  210. db_table.validate()
  211. self.commit()
  212. db_table.sync()
  213. self.begin()
  214. @staticmethod
  215. def get_on_duplicate_update(key='name'):
  216. if isinstance(key, list):
  217. key = '", "'.join(key)
  218. return 'ON CONFLICT ("{key}") DO UPDATE SET '.format(
  219. key=key
  220. )
  221. def check_implicit_commit(self, query):
  222. pass # postgres can run DDL in transactions without implicit commits
  223. def has_index(self, table_name, index_name):
  224. return self.sql("""SELECT 1 FROM pg_indexes WHERE tablename='{table_name}'
  225. and indexname='{index_name}' limit 1""".format(table_name=table_name, index_name=index_name))
  226. def add_index(self, doctype: str, fields: List, index_name: str = None):
  227. """Creates an index with given fields if not already created.
  228. Index name will be `fieldname1_fieldname2_index`"""
  229. table_name = get_table_name(doctype)
  230. index_name = index_name or self.get_index_name(fields)
  231. fields_str = '", "'.join(re.sub(r"\(.*\)", "", field) for field in fields)
  232. self.sql_ddl(f'CREATE INDEX IF NOT EXISTS "{index_name}" ON `{table_name}` ("{fields_str}")')
  233. def add_unique(self, doctype, fields, constraint_name=None):
  234. if isinstance(fields, str):
  235. fields = [fields]
  236. if not constraint_name:
  237. constraint_name = "unique_" + "_".join(fields)
  238. if not self.sql("""
  239. SELECT CONSTRAINT_NAME
  240. FROM information_schema.TABLE_CONSTRAINTS
  241. WHERE table_name=%s
  242. AND constraint_type='UNIQUE'
  243. AND CONSTRAINT_NAME=%s""",
  244. ('tab' + doctype, constraint_name)):
  245. self.commit()
  246. self.sql("""ALTER TABLE `tab%s`
  247. ADD CONSTRAINT %s UNIQUE (%s)""" % (doctype, constraint_name, ", ".join(fields)))
  248. def get_table_columns_description(self, table_name):
  249. """Returns list of column and its description"""
  250. # pylint: disable=W1401
  251. return self.sql('''
  252. SELECT a.column_name AS name,
  253. CASE LOWER(a.data_type)
  254. WHEN 'character varying' THEN CONCAT('varchar(', a.character_maximum_length ,')')
  255. WHEN 'timestamp without time zone' THEN 'timestamp'
  256. ELSE a.data_type
  257. END AS type,
  258. BOOL_OR(b.index) AS index,
  259. SPLIT_PART(COALESCE(a.column_default, NULL), '::', 1) AS default,
  260. BOOL_OR(b.unique) AS unique
  261. FROM information_schema.columns a
  262. LEFT JOIN
  263. (SELECT indexdef, tablename,
  264. indexdef LIKE '%UNIQUE INDEX%' AS unique,
  265. indexdef NOT LIKE '%UNIQUE INDEX%' AS index
  266. FROM pg_indexes
  267. WHERE tablename='{table_name}') b
  268. ON SUBSTRING(b.indexdef, '(.*)') LIKE CONCAT('%', a.column_name, '%')
  269. WHERE a.table_name = '{table_name}'
  270. GROUP BY a.column_name, a.data_type, a.column_default, a.character_maximum_length;
  271. '''.format(table_name=table_name), as_dict=1)
  272. def get_database_list(self, target):
  273. return [d[0] for d in self.sql("SELECT datname FROM pg_database;")]
  274. def modify_query(query):
  275. """"Modifies query according to the requirements of postgres"""
  276. # replace ` with " for definitions
  277. query = str(query)
  278. query = query.replace('`', '"')
  279. query = replace_locate_with_strpos(query)
  280. # select from requires ""
  281. if re.search('from tab', query, flags=re.IGNORECASE):
  282. query = re.sub('from tab([a-zA-Z]*)', r'from "tab\1"', query, flags=re.IGNORECASE)
  283. return query
  284. def replace_locate_with_strpos(query):
  285. # strpos is the locate equivalent in postgres
  286. if re.search(r'locate\(', query, flags=re.IGNORECASE):
  287. query = re.sub(r'locate\(([^,]+),([^)]+)\)', r'strpos(\2, \1)', query, flags=re.IGNORECASE)
  288. return query