You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

170 lines
4.7 KiB

  1. import copy
  2. import datetime
  3. import signal
  4. import unittest
  5. from contextlib import contextmanager
  6. import frappe
  7. from frappe.model.base_document import BaseDocument
  8. from frappe.utils import cint
  9. datetime_like_types = (datetime.datetime, datetime.date, datetime.time, datetime.timedelta)
  10. class FrappeTestCase(unittest.TestCase):
  11. """Base test class for Frappe tests.
  12. If you specify `setUpClass` then make sure to call `super().setUpClass`
  13. otherwise this class will become ineffective.
  14. """
  15. TEST_SITE = "test_site"
  16. SHOW_TRANSACTION_COMMIT_WARNINGS = False
  17. @classmethod
  18. def setUpClass(cls) -> None:
  19. cls.TEST_SITE = getattr(frappe.local, "site", None) or cls.TEST_SITE
  20. # flush changes done so far to avoid flake
  21. frappe.db.commit()
  22. frappe.db.begin()
  23. if cls.SHOW_TRANSACTION_COMMIT_WARNINGS:
  24. frappe.db.add_before_commit(_commit_watcher)
  25. # enqueue teardown actions (executed in LIFO order)
  26. cls.addClassCleanup(_restore_thread_locals, copy.deepcopy(frappe.local.flags))
  27. cls.addClassCleanup(_rollback_db)
  28. return super().setUpClass()
  29. # --- Frappe Framework specific assertions
  30. def assertDocumentEqual(self, expected, actual):
  31. """Compare a (partial) expected document with actual Document."""
  32. if isinstance(expected, BaseDocument):
  33. expected = expected.as_dict()
  34. for field, value in expected.items():
  35. if isinstance(value, list):
  36. actual_child_docs = actual.get(field)
  37. self.assertEqual(len(value), len(actual_child_docs), msg=f"{field} length should be same")
  38. for exp_child, actual_child in zip(value, actual_child_docs):
  39. self.assertDocumentEqual(exp_child, actual_child)
  40. else:
  41. self._compare_field(value, actual.get(field), actual, field)
  42. def _compare_field(self, expected, actual, doc, field):
  43. msg = f"{field} should be same."
  44. if isinstance(expected, float):
  45. precision = doc.precision(field)
  46. self.assertAlmostEqual(expected, actual, f"{field} should be same to {precision} digits")
  47. elif isinstance(expected, (bool, int)):
  48. self.assertEqual(expected, cint(actual), msg=msg)
  49. elif isinstance(expected, datetime_like_types):
  50. self.assertEqual(str(expected), str(actual), msg=msg)
  51. else:
  52. self.assertEqual(expected, actual, msg=msg)
  53. @contextmanager
  54. def assertQueryCount(self, count):
  55. def _sql_with_count(*args, **kwargs):
  56. frappe.db.sql_query_count += 1
  57. return orig_sql(*args, **kwargs)
  58. try:
  59. orig_sql = frappe.db.sql
  60. frappe.db.sql_query_count = 0
  61. frappe.db.sql = _sql_with_count
  62. yield
  63. self.assertLessEqual(frappe.db.sql_query_count, count)
  64. finally:
  65. frappe.db.sql = orig_sql
  66. def _commit_watcher():
  67. import traceback
  68. print("Warning:, transaction committed during tests.")
  69. traceback.print_stack(limit=5)
  70. def _rollback_db():
  71. frappe.local.before_commit = []
  72. frappe.local.rollback_observers = []
  73. frappe.db.value_cache = {}
  74. frappe.db.rollback()
  75. def _restore_thread_locals(flags):
  76. frappe.local.flags = flags
  77. frappe.local.error_log = []
  78. frappe.local.message_log = []
  79. frappe.local.debug_log = []
  80. frappe.local.realtime_log = []
  81. frappe.local.conf = frappe._dict(frappe.get_site_config())
  82. frappe.local.cache = {}
  83. frappe.local.lang = "en"
  84. frappe.local.lang_full_dict = None
  85. @contextmanager
  86. def change_settings(doctype, settings_dict):
  87. """A context manager to ensure that settings are changed before running
  88. function and restored after running it regardless of exceptions occured.
  89. This is useful in tests where you want to make changes in a function but
  90. don't retain those changes.
  91. import and use as decorator to cover full function or using `with` statement.
  92. example:
  93. @change_settings("Print Settings", {"send_print_as_pdf": 1})
  94. def test_case(self):
  95. ...
  96. """
  97. try:
  98. settings = frappe.get_doc(doctype)
  99. # remember setting
  100. previous_settings = copy.deepcopy(settings_dict)
  101. for key in previous_settings:
  102. previous_settings[key] = getattr(settings, key)
  103. # change setting
  104. for key, value in settings_dict.items():
  105. setattr(settings, key, value)
  106. settings.save()
  107. # singles are cached by default, clear to avoid flake
  108. frappe.db.value_cache[settings] = {}
  109. yield # yield control to calling function
  110. finally:
  111. # restore settings
  112. settings = frappe.get_doc(doctype)
  113. for key, value in previous_settings.items():
  114. setattr(settings, key, value)
  115. settings.save()
  116. def timeout(seconds=30, error_message="Test timed out."):
  117. """Timeout decorator to ensure a test doesn't run for too long.
  118. adapted from https://stackoverflow.com/a/2282656"""
  119. def decorator(func):
  120. def _handle_timeout(signum, frame):
  121. raise Exception(error_message)
  122. def wrapper(*args, **kwargs):
  123. signal.signal(signal.SIGALRM, _handle_timeout)
  124. signal.alarm(seconds)
  125. try:
  126. result = func(*args, **kwargs)
  127. finally:
  128. signal.alarm(0)
  129. return result
  130. return wrapper
  131. return decorator