Pārlūkot izejas kodu

Merge pull request #15516 from ankush/savepoint_wrapper

feat: savepoint contextmanager
version-14
Ankush Menat pirms 3 gadiem
committed by GitHub
vecāks
revīzija
5492e1930f
Šim parakstam datu bāzē netika atrasta zināma atslēga GPG atslēgas ID: 4AEE18F83AFDEB23
3 mainītis faili ar 64 papildinājumiem un 5 dzēšanām
  1. +2
    -0
      frappe/database/__init__.py
  2. +35
    -5
      frappe/database/database.py
  3. +27
    -0
      frappe/tests/test_db.py

+ 2
- 0
frappe/database/__init__.py Parādīt failu

@@ -4,6 +4,8 @@
# Database Module # Database Module
# -------------------- # --------------------


from frappe.database.database import savepoint

def setup_database(force, source_sql=None, verbose=None, no_mariadb_socket=False): def setup_database(force, source_sql=None, verbose=None, no_mariadb_socket=False):
import frappe import frappe
if frappe.conf.db_type == 'postgres': if frappe.conf.db_type == 'postgres':


+ 35
- 5
frappe/database/database.py Parādīt failu

@@ -4,16 +4,18 @@
# Database Module # Database Module
# -------------------- # --------------------


import datetime
import random
import re import re
import time
from typing import Dict, List, Union
import string
from contextlib import contextmanager
from time import time
from typing import Dict, List, Union, Tuple

import frappe import frappe
import datetime
import frappe.defaults import frappe.defaults
import frappe.model.meta import frappe.model.meta

from frappe import _ from frappe import _
from time import time
from frappe.utils import now, getdate, cast, get_datetime from frappe.utils import now, getdate, cast, get_datetime
from frappe.model.utils.link_count import flush_local_link_count from frappe.model.utils.link_count import flush_local_link_count
from frappe.query_builder.functions import Count from frappe.query_builder.functions import Count
@@ -811,6 +813,9 @@ class Database(object):
Avoid using savepoints when writing to filesystem.""" Avoid using savepoints when writing to filesystem."""
self.sql(f"savepoint {save_point}") self.sql(f"savepoint {save_point}")


def release_savepoint(self, save_point):
self.sql(f"release savepoint {save_point}")

def rollback(self, *, save_point=None): def rollback(self, *, save_point=None):
"""`ROLLBACK` current transaction. Optionally rollback to a known save_point.""" """`ROLLBACK` current transaction. Optionally rollback to a known save_point."""
if save_point: if save_point:
@@ -1097,3 +1102,28 @@ def enqueue_jobs_after_commit():
q.enqueue_call(execute_job, timeout=job.get("timeout"), q.enqueue_call(execute_job, timeout=job.get("timeout"),
kwargs=job.get("queue_args")) kwargs=job.get("queue_args"))
frappe.flags.enqueue_after_commit = [] frappe.flags.enqueue_after_commit = []

@contextmanager
def savepoint(catch: Union[type, Tuple[type, ...]] = Exception):
""" Wrapper for wrapping blocks of DB operations in a savepoint.

as contextmanager:

for doc in docs:
with savepoint(catch=DuplicateError):
doc.insert()

as decorator (wraps FULL function call):

@savepoint(catch=DuplicateError)
def process_doc(doc):
doc.insert()
"""
try:
savepoint = ''.join(random.sample(string.ascii_lowercase, 10))
frappe.db.savepoint(savepoint)
yield # control back to calling function
except catch:
frappe.db.rollback(save_point=savepoint)
else:
frappe.db.release_savepoint(savepoint)

+ 27
- 0
frappe/tests/test_db.py Parādīt failu

@@ -12,6 +12,7 @@ from frappe.custom.doctype.custom_field.custom_field import create_custom_field
from frappe.utils import random_string from frappe.utils import random_string
from frappe.utils.testutils import clear_custom_fields from frappe.utils.testutils import clear_custom_fields
from frappe.query_builder import Field from frappe.query_builder import Field
from frappe.database import savepoint


from .test_query_builder import run_only_if, db_type_is from .test_query_builder import run_only_if, db_type_is
from frappe.query_builder.functions import Concat_ws from frappe.query_builder.functions import Concat_ws
@@ -267,6 +268,32 @@ class TestDB(unittest.TestCase):
for d in created_docs: for d in created_docs:
self.assertTrue(frappe.db.exists("ToDo", d)) self.assertTrue(frappe.db.exists("ToDo", d))


def test_savepoints_wrapper(self):
frappe.db.rollback()

class SpecificExc(Exception):
pass

created_docs = []
failed_docs = []

for _ in range(5):
with savepoint(catch=SpecificExc):
doc_kept = frappe.get_doc(doctype="ToDo", description="nope").save()
created_docs.append(doc_kept.name)

with savepoint(catch=SpecificExc):
doc_gone = frappe.get_doc(doctype="ToDo", description="nope").save()
failed_docs.append(doc_gone.name)
raise SpecificExc

frappe.db.commit()

for d in failed_docs:
self.assertFalse(frappe.db.exists("ToDo", d))
for d in created_docs:
self.assertTrue(frappe.db.exists("ToDo", d))



@run_only_if(db_type_is.MARIADB) @run_only_if(db_type_is.MARIADB)
class TestDDLCommandsMaria(unittest.TestCase): class TestDDLCommandsMaria(unittest.TestCase):


Notiek ielāde…
Atcelt
Saglabāt