Du kan inte välja fler än 25 ämnen. Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.
 
 
 
 
 
 

283 rader
9.2 KiB

  1. import json
  2. import os
  3. import re
  4. import sys
  5. import time
  6. import unittest
  7. import click
  8. import frappe
  9. import requests
  10. from .test_runner import (SLOW_TEST_THRESHOLD, make_test_records, set_test_email_config)
  11. click_ctx = click.get_current_context(True)
  12. if click_ctx:
  13. click_ctx.color = True
  14. class ParallelTestRunner():
  15. def __init__(self, app, site, build_number=1, total_builds=1, with_coverage=False):
  16. self.app = app
  17. self.site = site
  18. self.with_coverage = with_coverage
  19. self.build_number = frappe.utils.cint(build_number) or 1
  20. self.total_builds = frappe.utils.cint(total_builds)
  21. self.setup_test_site()
  22. self.run_tests()
  23. def setup_test_site(self):
  24. frappe.init(site=self.site)
  25. if not frappe.db:
  26. frappe.connect()
  27. frappe.flags.in_test = True
  28. frappe.clear_cache()
  29. frappe.utils.scheduler.disable_scheduler()
  30. set_test_email_config()
  31. self.before_test_setup()
  32. def before_test_setup(self):
  33. start_time = time.time()
  34. for fn in frappe.get_hooks("before_tests", app_name=self.app):
  35. frappe.get_attr(fn)()
  36. test_module = frappe.get_module(f'{self.app}.tests')
  37. if hasattr(test_module, "global_test_dependencies"):
  38. for doctype in test_module.global_test_dependencies:
  39. make_test_records(doctype)
  40. elapsed = time.time() - start_time
  41. elapsed = click.style(f' ({elapsed:.03}s)', fg='red')
  42. click.echo(f'Before Test {elapsed}')
  43. def run_tests(self):
  44. self.test_result = ParallelTestResult(stream=sys.stderr, descriptions=True, verbosity=2)
  45. self.start_coverage()
  46. for test_file_info in self.get_test_file_list():
  47. self.run_tests_for_file(test_file_info)
  48. self.save_coverage()
  49. self.print_result()
  50. def run_tests_for_file(self, file_info):
  51. if not file_info: return
  52. frappe.set_user('Administrator')
  53. path, filename = file_info
  54. module = self.get_module(path, filename)
  55. self.create_test_dependency_records(module, path, filename)
  56. test_suite = unittest.TestSuite()
  57. module_test_cases = unittest.TestLoader().loadTestsFromModule(module)
  58. test_suite.addTest(module_test_cases)
  59. test_suite(self.test_result)
  60. def create_test_dependency_records(self, module, path, filename):
  61. if hasattr(module, "test_dependencies"):
  62. for doctype in module.test_dependencies:
  63. make_test_records(doctype)
  64. if os.path.basename(os.path.dirname(path)) == "doctype":
  65. # test_data_migration_connector.py > data_migration_connector.json
  66. test_record_filename = re.sub('^test_', '', filename).replace(".py", ".json")
  67. test_record_file_path = os.path.join(path, test_record_filename)
  68. if os.path.exists(test_record_file_path):
  69. with open(test_record_file_path, 'r') as f:
  70. doc = json.loads(f.read())
  71. doctype = doc["name"]
  72. make_test_records(doctype)
  73. def get_module(self, path, filename):
  74. app_path = frappe.get_pymodule_path(self.app)
  75. relative_path = os.path.relpath(path, app_path)
  76. if relative_path == '.':
  77. module_name = self.app
  78. else:
  79. relative_path = relative_path.replace('/', '.')
  80. module_name = os.path.splitext(filename)[0]
  81. module_name = f'{self.app}.{relative_path}.{module_name}'
  82. return frappe.get_module(module_name)
  83. def print_result(self):
  84. self.test_result.printErrors()
  85. click.echo(self.test_result)
  86. if self.test_result.failures or self.test_result.errors:
  87. if os.environ.get('CI'):
  88. sys.exit(1)
  89. def start_coverage(self):
  90. if self.with_coverage:
  91. from coverage import Coverage
  92. from frappe.utils import get_bench_path
  93. # Generate coverage report only for app that is being tested
  94. source_path = os.path.join(get_bench_path(), 'apps', self.app)
  95. omit=['*.html', '*.js', '*.xml', '*.css', '*.less', '*.scss',
  96. '*.vue', '*/doctype/*/*_dashboard.py', '*/patches/*']
  97. if self.app == 'frappe':
  98. omit.append('*/commands/*')
  99. self.coverage = Coverage(source=[source_path], omit=omit)
  100. self.coverage.start()
  101. def save_coverage(self):
  102. if not self.with_coverage:
  103. return
  104. self.coverage.stop()
  105. self.coverage.save()
  106. def get_test_file_list(self):
  107. test_list = get_all_tests(self.app)
  108. split_size = frappe.utils.ceil(len(test_list) / self.total_builds)
  109. # [1,2,3,4,5,6] to [[1,2], [3,4], [4,6]] if split_size is 2
  110. test_chunks = [test_list[x:x+split_size] for x in range(0, len(test_list), split_size)]
  111. return test_chunks[self.build_number - 1]
  112. class ParallelTestResult(unittest.TextTestResult):
  113. def startTest(self, test):
  114. self._started_at = time.time()
  115. super(unittest.TextTestResult, self).startTest(test)
  116. test_class = unittest.util.strclass(test.__class__)
  117. if not hasattr(self, 'current_test_class') or self.current_test_class != test_class:
  118. click.echo(f"\n{unittest.util.strclass(test.__class__)}")
  119. self.current_test_class = test_class
  120. def getTestMethodName(self, test):
  121. return test._testMethodName if hasattr(test, '_testMethodName') else str(test)
  122. def addSuccess(self, test):
  123. super(unittest.TextTestResult, self).addSuccess(test)
  124. elapsed = time.time() - self._started_at
  125. threshold_passed = elapsed >= SLOW_TEST_THRESHOLD
  126. elapsed = click.style(f' ({elapsed:.03}s)', fg='red') if threshold_passed else ''
  127. click.echo(f" {click.style(' ✔ ', fg='green')} {self.getTestMethodName(test)}{elapsed}")
  128. def addError(self, test, err):
  129. super(unittest.TextTestResult, self).addError(test, err)
  130. click.echo(f" {click.style(' ✖ ', fg='red')} {self.getTestMethodName(test)}")
  131. def addFailure(self, test, err):
  132. super(unittest.TextTestResult, self).addFailure(test, err)
  133. click.echo(f" {click.style(' ✖ ', fg='red')} {self.getTestMethodName(test)}")
  134. def addSkip(self, test, reason):
  135. super(unittest.TextTestResult, self).addSkip(test, reason)
  136. click.echo(f" {click.style(' = ', fg='white')} {self.getTestMethodName(test)}")
  137. def addExpectedFailure(self, test, err):
  138. super(unittest.TextTestResult, self).addExpectedFailure(test, err)
  139. click.echo(f" {click.style(' ✖ ', fg='red')} {self.getTestMethodName(test)}")
  140. def addUnexpectedSuccess(self, test):
  141. super(unittest.TextTestResult, self).addUnexpectedSuccess(test)
  142. click.echo(f" {click.style(' ✔ ', fg='green')} {self.getTestMethodName(test)}")
  143. def printErrors(self):
  144. click.echo('\n')
  145. self.printErrorList(' ERROR ', self.errors, 'red')
  146. self.printErrorList(' FAIL ', self.failures, 'red')
  147. def printErrorList(self, flavour, errors, color):
  148. for test, err in errors:
  149. click.echo(self.separator1)
  150. click.echo(f"{click.style(flavour, bg=color)} {self.getDescription(test)}")
  151. click.echo(self.separator2)
  152. click.echo(err)
  153. def __str__(self):
  154. return f"Tests: {self.testsRun}, Failing: {len(self.failures)}, Errors: {len(self.errors)}"
  155. def get_all_tests(app):
  156. test_file_list = []
  157. for path, folders, files in os.walk(frappe.get_pymodule_path(app)):
  158. for dontwalk in ('locals', '.git', 'public', '__pycache__'):
  159. if dontwalk in folders:
  160. folders.remove(dontwalk)
  161. # for predictability
  162. folders.sort()
  163. files.sort()
  164. if os.path.sep.join(["doctype", "doctype", "boilerplate"]) in path:
  165. # in /doctype/doctype/boilerplate/
  166. continue
  167. for filename in files:
  168. if filename.startswith("test_") and filename.endswith(".py") \
  169. and filename != 'test_runner.py':
  170. test_file_list.append([path, filename])
  171. return test_file_list
  172. class ParallelTestWithOrchestrator(ParallelTestRunner):
  173. '''
  174. This can be used to balance-out test time across multiple instances
  175. This is dependent on external orchestrator which returns next test to run
  176. orchestrator endpoints
  177. - register-instance (<build_id>, <instance_id>, test_spec_list)
  178. - get-next-test-spec (<build_id>, <instance_id>)
  179. - test-completed (<build_id>, <instance_id>)
  180. '''
  181. def __init__(self, app, site, with_coverage=False):
  182. self.orchestrator_url = os.environ.get('ORCHESTRATOR_URL')
  183. if not self.orchestrator_url:
  184. click.echo('ORCHESTRATOR_URL environment variable not found!')
  185. click.echo('Pass public URL after hosting https://github.com/frappe/test-orchestrator')
  186. sys.exit(1)
  187. self.ci_build_id = os.environ.get('CI_BUILD_ID')
  188. self.ci_instance_id = os.environ.get('CI_INSTANCE_ID') or frappe.generate_hash(length=10)
  189. if not self.ci_build_id:
  190. click.echo('CI_BUILD_ID environment variable not found!')
  191. sys.exit(1)
  192. ParallelTestRunner.__init__(self, app, site, with_coverage=with_coverage)
  193. def run_tests(self):
  194. self.test_status = 'ongoing'
  195. self.register_instance()
  196. super().run_tests()
  197. def get_test_file_list(self):
  198. while self.test_status == 'ongoing':
  199. yield self.get_next_test()
  200. def register_instance(self):
  201. test_spec_list = get_all_tests(self.app)
  202. response_data = self.call_orchestrator('register-instance', data={
  203. 'test_spec_list': test_spec_list
  204. })
  205. self.is_master = response_data.get('is_master')
  206. def get_next_test(self):
  207. response_data = self.call_orchestrator('get-next-test-spec')
  208. self.test_status = response_data.get('status')
  209. return response_data.get('next_test')
  210. def print_result(self):
  211. self.call_orchestrator('test-completed')
  212. return super().print_result()
  213. def call_orchestrator(self, endpoint, data={}):
  214. # add repo token header
  215. # build id in header
  216. headers = {
  217. 'CI-BUILD-ID': self.ci_build_id,
  218. 'CI-INSTANCE-ID': self.ci_instance_id,
  219. 'REPO-TOKEN': '2948288382838DE'
  220. }
  221. url = f'{self.orchestrator_url}/{endpoint}'
  222. res = requests.get(url, json=data, headers=headers)
  223. res.raise_for_status()
  224. response_data = {}
  225. if 'application/json' in res.headers.get('content-type'):
  226. response_data = res.json()
  227. return response_data