You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

235 lines
8.2 KiB

  1. import unittest
  2. from typing import Callable
  3. import frappe
  4. from frappe.query_builder.custom import ConstantColumn
  5. from frappe.query_builder.functions import Coalesce, GroupConcat, Match, CombineDatetime
  6. from frappe.query_builder.utils import db_type_is
  7. from frappe.query_builder import Case
  8. def run_only_if(dbtype: db_type_is) -> Callable:
  9. return unittest.skipIf(
  10. db_type_is(frappe.conf.db_type) != dbtype, f"Only runs for {dbtype.value}"
  11. )
  12. @run_only_if(db_type_is.MARIADB)
  13. class TestCustomFunctionsMariaDB(unittest.TestCase):
  14. def test_concat(self):
  15. self.assertEqual("GROUP_CONCAT('Notes')", GroupConcat("Notes").get_sql())
  16. def test_match(self):
  17. query = Match("Notes").Against("text")
  18. self.assertEqual(
  19. " MATCH('Notes') AGAINST ('+text*' IN BOOLEAN MODE)", query.get_sql()
  20. )
  21. def test_constant_column(self):
  22. query = frappe.qb.from_("DocType").select(
  23. "name", ConstantColumn("John").as_("User")
  24. )
  25. self.assertEqual(
  26. query.get_sql(), "SELECT `name`,'John' `User` FROM `tabDocType`"
  27. )
  28. def test_timestamp(self):
  29. note = frappe.qb.DocType("Note")
  30. self.assertEqual("TIMESTAMP(posting_date,posting_time)", CombineDatetime(note.posting_date, note.posting_time).get_sql())
  31. self.assertEqual("TIMESTAMP('2021-01-01','00:00:21')", CombineDatetime("2021-01-01", "00:00:21").get_sql())
  32. todo = frappe.qb.DocType("ToDo")
  33. select_query = (frappe.qb
  34. .from_(note)
  35. .join(todo).on(todo.refernce_name == note.name)
  36. .select(CombineDatetime(note.posting_date, note.posting_time)))
  37. self.assertIn("select timestamp(`tabnote`.`posting_date`,`tabnote`.`posting_time`)", str(select_query).lower())
  38. select_query = select_query.orderby(CombineDatetime(note.posting_date, note.posting_time))
  39. self.assertIn("order by timestamp(`tabnote`.`posting_date`,`tabnote`.`posting_time`)", str(select_query).lower())
  40. select_query = select_query.where(CombineDatetime(note.posting_date, note.posting_time) >= CombineDatetime("2021-01-01", "00:00:01"))
  41. self.assertIn("timestamp(`tabnote`.`posting_date`,`tabnote`.`posting_time`)>=timestamp('2021-01-01','00:00:01')", str(select_query).lower())
  42. select_query = select_query.select(CombineDatetime(note.posting_date, note.posting_time, alias="timestamp"))
  43. self.assertIn("timestamp(`tabnote`.`posting_date`,`tabnote`.`posting_time`) `timestamp`", str(select_query).lower())
  44. @run_only_if(db_type_is.POSTGRES)
  45. class TestCustomFunctionsPostgres(unittest.TestCase):
  46. def test_concat(self):
  47. self.assertEqual("STRING_AGG('Notes',',')", GroupConcat("Notes").get_sql())
  48. def test_match(self):
  49. query = Match("Notes").Against("text")
  50. self.assertEqual(
  51. "TO_TSVECTOR('Notes') @@ PLAINTO_TSQUERY('text')", query.get_sql()
  52. )
  53. def test_constant_column(self):
  54. query = frappe.qb.from_("DocType").select(
  55. "name", ConstantColumn("John").as_("User")
  56. )
  57. self.assertEqual(
  58. query.get_sql(), 'SELECT "name",\'John\' "User" FROM "tabDocType"'
  59. )
  60. def test_timestamp(self):
  61. note = frappe.qb.DocType("Note")
  62. self.assertEqual("posting_date+posting_time", CombineDatetime(note.posting_date, note.posting_time).get_sql())
  63. self.assertEqual("CAST('2021-01-01' AS DATE)+CAST('00:00:21' AS TIME)", CombineDatetime("2021-01-01", "00:00:21").get_sql())
  64. todo = frappe.qb.DocType("ToDo")
  65. select_query = (frappe.qb
  66. .from_(note)
  67. .join(todo).on(todo.refernce_name == note.name)
  68. .select(CombineDatetime(note.posting_date, note.posting_time)))
  69. self.assertIn('select "tabnote"."posting_date"+"tabnote"."posting_time"', str(select_query).lower())
  70. select_query = select_query.orderby(CombineDatetime(note.posting_date, note.posting_time))
  71. self.assertIn('order by "tabnote"."posting_date"+"tabnote"."posting_time"', str(select_query).lower())
  72. select_query = select_query.where(
  73. CombineDatetime(note.posting_date, note.posting_time) >= CombineDatetime('2021-01-01', '00:00:01')
  74. )
  75. self.assertIn("""where "tabnote"."posting_date"+"tabnote"."posting_time">=cast('2021-01-01' as date)+cast('00:00:01' as time)""",
  76. str(select_query).lower())
  77. select_query = select_query.select(CombineDatetime(note.posting_date, note.posting_time, alias="timestamp"))
  78. self.assertIn('"tabnote"."posting_date"+"tabnote"."posting_time" "timestamp"', str(select_query).lower())
  79. class TestBuilderBase(object):
  80. def test_adding_tabs(self):
  81. self.assertEqual("tabNotes", frappe.qb.DocType("Notes").get_sql())
  82. self.assertEqual("__Auth", frappe.qb.DocType("__Auth").get_sql())
  83. self.assertEqual("Notes", frappe.qb.Table("Notes").get_sql())
  84. def test_run_patcher(self):
  85. query = frappe.qb.from_("ToDo").select("*").limit(1)
  86. data = query.run(as_dict=True)
  87. self.assertTrue("run" in dir(query))
  88. self.assertIsInstance(query.run, Callable)
  89. self.assertIsInstance(data, list)
  90. class TestParameterization(unittest.TestCase):
  91. def test_where_conditions(self):
  92. DocType = frappe.qb.DocType("DocType")
  93. query = (
  94. frappe.qb.from_(DocType)
  95. .select(DocType.name)
  96. .where((DocType.owner == "Administrator' --"))
  97. )
  98. self.assertTrue("walk" in dir(query))
  99. query, params = query.walk()
  100. self.assertIn("%(param1)s", query)
  101. self.assertIn("param1", params)
  102. self.assertEqual(params["param1"], "Administrator' --")
  103. def test_set_cnoditions(self):
  104. DocType = frappe.qb.DocType("DocType")
  105. query = frappe.qb.update(DocType).set(DocType.value, "some_value")
  106. self.assertTrue("walk" in dir(query))
  107. query, params = query.walk()
  108. self.assertIn("%(param1)s", query)
  109. self.assertIn("param1", params)
  110. self.assertEqual(params["param1"], "some_value")
  111. def test_where_conditions_functions(self):
  112. DocType = frappe.qb.DocType("DocType")
  113. query = (
  114. frappe.qb.from_(DocType)
  115. .select(DocType.name)
  116. .where(Coalesce(DocType.search_fields == "subject"))
  117. )
  118. self.assertTrue("walk" in dir(query))
  119. query, params = query.walk()
  120. self.assertIn("%(param1)s", query)
  121. self.assertIn("param1", params)
  122. self.assertEqual(params["param1"], "subject")
  123. def test_case(self):
  124. DocType = frappe.qb.DocType("DocType")
  125. query = (
  126. frappe.qb.from_(DocType)
  127. .select(
  128. Case()
  129. .when(DocType.search_fields == "value", "other_value")
  130. .when(Coalesce(DocType.search_fields == "subject_in_function"), "true_value")
  131. .else_("Overdue")
  132. )
  133. )
  134. self.assertTrue("walk" in dir(query))
  135. query, params = query.walk()
  136. self.assertIn("%(param1)s", query)
  137. self.assertIn("param1", params)
  138. self.assertEqual(params["param1"], "value")
  139. self.assertEqual(params["param2"], "other_value")
  140. self.assertEqual(params["param3"], "subject_in_function")
  141. self.assertEqual(params["param4"], "true_value")
  142. self.assertEqual(params["param5"], "Overdue")
  143. def test_case_in_update(self):
  144. DocType = frappe.qb.DocType("DocType")
  145. query = (
  146. frappe.qb.update(DocType)
  147. .set(
  148. "parent",
  149. Case()
  150. .when(DocType.search_fields == "value", "other_value")
  151. .when(Coalesce(DocType.search_fields == "subject_in_function"), "true_value")
  152. .else_("Overdue")
  153. )
  154. )
  155. self.assertTrue("walk" in dir(query))
  156. query, params = query.walk()
  157. self.assertIn("%(param1)s", query)
  158. self.assertIn("param1", params)
  159. self.assertEqual(params["param1"], "value")
  160. self.assertEqual(params["param2"], "other_value")
  161. self.assertEqual(params["param3"], "subject_in_function")
  162. self.assertEqual(params["param4"], "true_value")
  163. self.assertEqual(params["param5"], "Overdue")
  164. @run_only_if(db_type_is.MARIADB)
  165. class TestBuilderMaria(unittest.TestCase, TestBuilderBase):
  166. def test_adding_tabs_in_from(self):
  167. self.assertEqual(
  168. "SELECT * FROM `tabNotes`", frappe.qb.from_("Notes").select("*").get_sql()
  169. )
  170. self.assertEqual(
  171. "SELECT * FROM `__Auth`", frappe.qb.from_("__Auth").select("*").get_sql()
  172. )
  173. @run_only_if(db_type_is.POSTGRES)
  174. class TestBuilderPostgres(unittest.TestCase, TestBuilderBase):
  175. def test_adding_tabs_in_from(self):
  176. self.assertEqual(
  177. 'SELECT * FROM "tabNotes"', frappe.qb.from_("Notes").select("*").get_sql()
  178. )
  179. self.assertEqual(
  180. 'SELECT * FROM "__Auth"', frappe.qb.from_("__Auth").select("*").get_sql()
  181. )
  182. def test_replace_tables(self):
  183. info_schema = frappe.qb.Schema("information_schema")
  184. self.assertEqual(
  185. 'SELECT * FROM "pg_stat_all_tables"',
  186. frappe.qb.from_(info_schema.tables).select("*").get_sql(),
  187. )
  188. def test_replace_fields_post(self):
  189. self.assertEqual("relname", frappe.qb.Field("table_name").get_sql())