Browse Source

Merge pull request #15560 from gavindsouza/set_value-refactor

refactor: frappe.db.set_value
version-14
mergify[bot] 3 years ago
committed by GitHub
parent
commit
819202f5da
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 572 additions and 207 deletions
  1. +4
    -2
      frappe/__init__.py
  2. +0
    -1
      frappe/client.py
  3. +3
    -2
      frappe/commands/utils.py
  4. +2
    -25
      frappe/core/doctype/file/test_file.py
  5. +59
    -42
      frappe/database/database.py
  6. +17
    -4
      frappe/query_builder/__init__.py
  7. +16
    -4
      frappe/query_builder/builder.py
  8. +69
    -17
      frappe/query_builder/terms.py
  9. +7
    -7
      frappe/query_builder/utils.py
  10. +10
    -7
      frappe/test_runner.py
  11. +146
    -24
      frappe/tests/test_db.py
  12. +70
    -13
      frappe/tests/test_query_builder.py
  13. +71
    -16
      frappe/tests/test_utils.py
  14. +1
    -0
      frappe/tests/test_website.py
  15. +2
    -2
      frappe/utils/__init__.py
  16. +78
    -31
      frappe/utils/data.py
  17. +7
    -2
      frappe/utils/formatters.py
  18. +8
    -6
      frappe/utils/response.py
  19. +2
    -2
      frappe/website/utils.py

+ 4
- 2
frappe/__init__.py View File

@@ -143,6 +143,8 @@ lang = local("lang")
# This if block is never executed when running the code. It is only used for # This if block is never executed when running the code. It is only used for
# telling static code analyzer where to find dynamically defined attributes. # telling static code analyzer where to find dynamically defined attributes.
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from frappe.utils.redis_wrapper import RedisWrapper

from frappe.database.mariadb.database import MariaDBDatabase from frappe.database.mariadb.database import MariaDBDatabase
from frappe.database.postgres.database import PostgresDatabase from frappe.database.postgres.database import PostgresDatabase
from frappe.query_builder.builder import MariaDB, Postgres from frappe.query_builder.builder import MariaDB, Postgres
@@ -150,6 +152,7 @@ if typing.TYPE_CHECKING:
db: typing.Union[MariaDBDatabase, PostgresDatabase] db: typing.Union[MariaDBDatabase, PostgresDatabase]
qb: typing.Union[MariaDB, Postgres] qb: typing.Union[MariaDB, Postgres]



# end: static analysis hack # end: static analysis hack


def init(site, sites_path=None, new_site=False): def init(site, sites_path=None, new_site=False):
@@ -311,9 +314,8 @@ def destroy():


release_local(local) release_local(local)


# memcache
redis_server = None redis_server = None
def cache():
def cache() -> "RedisWrapper":
"""Returns redis connection.""" """Returns redis connection."""
global redis_server global redis_server
if not redis_server: if not redis_server:


+ 0
- 1
frappe/client.py View File

@@ -99,7 +99,6 @@ def get_value(doctype, fieldname, filters=None, as_dict=True, debug=False, paren
if not filters: if not filters:
filters = None filters = None



if frappe.get_meta(doctype).issingle: if frappe.get_meta(doctype).issingle:
value = frappe.db.get_values_from_single(fields, filters, doctype, as_dict=as_dict, debug=debug) value = frappe.db.get_values_from_single(fields, filters, doctype, as_dict=as_dict, debug=debug)
else: else:


+ 3
- 2
frappe/commands/utils.py View File

@@ -623,6 +623,7 @@ def transform_database(context, table, engine, row_format, failfast):
@click.command('run-tests') @click.command('run-tests')
@click.option('--app', help="For App") @click.option('--app', help="For App")
@click.option('--doctype', help="For DocType") @click.option('--doctype', help="For DocType")
@click.option('--case', help="Select particular TestCase")
@click.option('--doctype-list-path', help="Path to .txt file for list of doctypes. Example erpnext/tests/server/agriculture.txt") @click.option('--doctype-list-path', help="Path to .txt file for list of doctypes. Example erpnext/tests/server/agriculture.txt")
@click.option('--test', multiple=True, help="Specific test") @click.option('--test', multiple=True, help="Specific test")
@click.option('--ui-tests', is_flag=True, default=False, help="Run UI Tests") @click.option('--ui-tests', is_flag=True, default=False, help="Run UI Tests")
@@ -636,7 +637,7 @@ def transform_database(context, table, engine, row_format, failfast):
@pass_context @pass_context
def run_tests(context, app=None, module=None, doctype=None, test=(), profile=False, def run_tests(context, app=None, module=None, doctype=None, test=(), profile=False,
coverage=False, junit_xml_output=False, ui_tests = False, doctype_list_path=None, coverage=False, junit_xml_output=False, ui_tests = False, doctype_list_path=None,
skip_test_records=False, skip_before_tests=False, failfast=False):
skip_test_records=False, skip_before_tests=False, failfast=False, case=None):


with CodeCoverage(coverage, app): with CodeCoverage(coverage, app):
import frappe.test_runner import frappe.test_runner
@@ -658,7 +659,7 @@ def run_tests(context, app=None, module=None, doctype=None, test=(), profile=Fal


ret = frappe.test_runner.main(app, module, doctype, context.verbose, tests=tests, ret = frappe.test_runner.main(app, module, doctype, context.verbose, tests=tests,
force=context.force, profile=profile, junit_xml_output=junit_xml_output, force=context.force, profile=profile, junit_xml_output=junit_xml_output,
ui_tests=ui_tests, doctype_list_path=doctype_list_path, failfast=failfast)
ui_tests=ui_tests, doctype_list_path=doctype_list_path, failfast=failfast, case=case)


if len(ret.failures) == 0 and len(ret.errors) == 0: if len(ret.failures) == 0 and len(ret.errors) == 0:
ret = 0 ret = 0


+ 2
- 25
frappe/core/doctype/file/test_file.py View File

@@ -7,9 +7,8 @@ import frappe
import os import os
import unittest import unittest
from frappe import _ from frappe import _
from frappe.core.doctype.file.file import get_attached_images, move_file, get_files_in_folder, unzip_file
from frappe.core.doctype.file.file import File, get_attached_images, move_file, get_files_in_folder, unzip_file
from frappe.utils import get_files_path from frappe.utils import get_files_path
# test_records = frappe.get_test_records('File')


test_content1 = 'Hello' test_content1 = 'Hello'
test_content2 = 'Hello World' test_content2 = 'Hello World'
@@ -24,8 +23,6 @@ def make_test_doc():




class TestSimpleFile(unittest.TestCase): class TestSimpleFile(unittest.TestCase):


def setUp(self): def setUp(self):
self.attached_to_doctype, self.attached_to_docname = make_test_doc() self.attached_to_doctype, self.attached_to_docname = make_test_doc()
self.test_content = test_content1 self.test_content = test_content1
@@ -38,21 +35,13 @@ class TestSimpleFile(unittest.TestCase):
_file.save() _file.save()
self.saved_file_url = _file.file_url self.saved_file_url = _file.file_url



def test_save(self): def test_save(self):
_file = frappe.get_doc("File", {"file_url": self.saved_file_url}) _file = frappe.get_doc("File", {"file_url": self.saved_file_url})
content = _file.get_content() content = _file.get_content()
self.assertEqual(content, self.test_content) self.assertEqual(content, self.test_content)




def tearDown(self):
# File gets deleted on rollback, so blank
pass


class TestBase64File(unittest.TestCase): class TestBase64File(unittest.TestCase):


def setUp(self): def setUp(self):
self.attached_to_doctype, self.attached_to_docname = make_test_doc() self.attached_to_doctype, self.attached_to_docname = make_test_doc()
self.test_content = base64.b64encode(test_content1.encode('utf-8')) self.test_content = base64.b64encode(test_content1.encode('utf-8'))
@@ -66,18 +55,12 @@ class TestBase64File(unittest.TestCase):
_file.save() _file.save()
self.saved_file_url = _file.file_url self.saved_file_url = _file.file_url



def test_saved_content(self): def test_saved_content(self):
_file = frappe.get_doc("File", {"file_url": self.saved_file_url}) _file = frappe.get_doc("File", {"file_url": self.saved_file_url})
content = _file.get_content() content = _file.get_content()
self.assertEqual(content, test_content1) self.assertEqual(content, test_content1)




def tearDown(self):
# File gets deleted on rollback, so blank
pass


class TestSameFileName(unittest.TestCase): class TestSameFileName(unittest.TestCase):
def test_saved_content(self): def test_saved_content(self):
self.attached_to_doctype, self.attached_to_docname = make_test_doc() self.attached_to_doctype, self.attached_to_docname = make_test_doc()
@@ -130,8 +113,6 @@ class TestSameFileName(unittest.TestCase):




class TestSameContent(unittest.TestCase): class TestSameContent(unittest.TestCase):


def setUp(self): def setUp(self):
self.attached_to_doctype1, self.attached_to_docname1 = make_test_doc() self.attached_to_doctype1, self.attached_to_docname1 = make_test_doc()
self.attached_to_doctype2, self.attached_to_docname2 = make_test_doc() self.attached_to_doctype2, self.attached_to_docname2 = make_test_doc()
@@ -186,10 +167,6 @@ class TestSameContent(unittest.TestCase):
limit_property.delete() limit_property.delete()
frappe.clear_cache(doctype='ToDo') frappe.clear_cache(doctype='ToDo')


def tearDown(self):
# File gets deleted on rollback, so blank
pass



class TestFile(unittest.TestCase): class TestFile(unittest.TestCase):
def setUp(self): def setUp(self):
@@ -398,7 +375,7 @@ class TestFile(unittest.TestCase):


def test_make_thumbnail(self): def test_make_thumbnail(self):
# test web image # test web image
test_file = frappe.get_doc({
test_file: File = frappe.get_doc({
"doctype": "File", "doctype": "File",
"file_name": 'logo', "file_name": 'logo',
"file_url": frappe.utils.get_url('/_test/assets/image.jpg'), "file_url": frappe.utils.get_url('/_test/assets/image.jpg'),


+ 59
- 42
frappe/database/database.py View File

@@ -10,19 +10,20 @@ import re
import string import string
from contextlib import contextmanager from contextlib import contextmanager
from time import time from time import time
from typing import Dict, List, Union, Tuple
from typing import Dict, List, Tuple, Union

from pypika.terms import Criterion, NullValue, PseudoColumn


import frappe import frappe
import frappe.defaults import frappe.defaults
import frappe.model.meta import frappe.model.meta
from frappe import _ from frappe import _
from frappe.utils import now, getdate, cast, get_datetime
from frappe.model.utils.link_count import flush_local_link_count from frappe.model.utils.link_count import flush_local_link_count
from frappe.query_builder.functions import Count from frappe.query_builder.functions import Count
from frappe.query_builder.functions import Min, Max, Avg, Sum
from frappe.query_builder.utils import Column
from frappe.query_builder.utils import DocType
from frappe.utils import cast, get_datetime, getdate, now, sbool

from .query import Query from .query import Query
from pypika.terms import Criterion, PseudoColumn




class Database(object): class Database(object):
@@ -557,7 +558,21 @@ class Database(object):
def get_list(*args, **kwargs): def get_list(*args, **kwargs):
return frappe.get_list(*args, **kwargs) return frappe.get_list(*args, **kwargs)


def get_single_value(self, doctype, fieldname, cache=False):
def set_single_value(self, doctype, fieldname, value, *args, **kwargs):
"""Set field value of Single DocType.

:param doctype: DocType of the single object
:param fieldname: `fieldname` of the property
:param value: `value` of the property

Example:

# Update the `deny_multiple_sessions` field in System Settings DocType.
company = frappe.db.set_single_value("System Settings", "deny_multiple_sessions", True)
"""
return self.set_value(doctype, doctype, fieldname, value, *args, **kwargs)

def get_single_value(self, doctype, fieldname, cache=True):
"""Get property of Single DocType. Cache locally by default """Get property of Single DocType. Cache locally by default


:param doctype: DocType of the single object whose value is requested :param doctype: DocType of the single object whose value is requested
@@ -572,7 +587,7 @@ class Database(object):
if not doctype in self.value_cache: if not doctype in self.value_cache:
self.value_cache[doctype] = {} self.value_cache[doctype] = {}


if fieldname in self.value_cache[doctype]:
if cache and fieldname in self.value_cache[doctype]:
return self.value_cache[doctype][fieldname] return self.value_cache[doctype][fieldname]


val = self.query.get_sql( val = self.query.get_sql(
@@ -679,53 +694,55 @@ class Database(object):
:param debug: Print the query in the developer / js console. :param debug: Print the query in the developer / js console.
:param for_update: Will add a row-level lock to the value that is being set so that it can be released on commit. :param for_update: Will add a row-level lock to the value that is being set so that it can be released on commit.
""" """
if not modified:
modified = now()
if not modified_by:
modified_by = frappe.session.user
is_single_doctype = not (dn and dt != dn)
to_update = field if isinstance(field, dict) else {field: val}


to_update = {}
if update_modified: if update_modified:
to_update = {"modified": modified, "modified_by": modified_by}
modified = modified or now()
modified_by = modified_by or frappe.session.user
to_update.update({"modified": modified, "modified_by": modified_by})

if is_single_doctype:
frappe.db.delete(
"Singles",
filters={"field": ("in", tuple(to_update)), "doctype": dt}, debug=debug
)

singles_data = ((dt, key, sbool(value)) for key, value in to_update.items())
query = (
frappe.qb.into("Singles")
.columns("doctype", "field", "value")
.insert(*singles_data)
).run(debug=debug)
frappe.clear_document_cache(dt, dt)


if isinstance(field, dict):
to_update.update(field)
else: else:
to_update.update({field: val})
table = DocType(dt)


if dn and dt!=dn:
# with table
set_values = []
for key in to_update:
set_values.append('`{0}`=%({0})s'.format(key))
if for_update:
docnames = tuple(
self.get_values(dt, dn, "name", debug=debug, for_update=for_update, pluck=True)
) or (NullValue(),)
query = frappe.qb.update(table).where(table.name.isin(docnames))


for name in self.get_values(dt, dn, 'name', for_update=for_update, debug=debug):
values = dict(name=name[0])
values.update(to_update)
for docname in docnames:
frappe.clear_document_cache(dt, docname)


self.sql("""update `tab{0}`
set {1} where name=%(name)s""".format(dt, ', '.join(set_values)),
values, debug=debug)
else:
query = self.query.build_conditions(table=dt, filters=dn, update=True)
# TODO: Fix this; doesn't work rn - gavin@frappe.io
# frappe.cache().hdel_keys(dt, "document_cache")
# Workaround: clear all document caches
frappe.cache().delete_value('document_cache')


frappe.clear_document_cache(dt, values['name'])
else:
# for singles
keys = list(to_update)
self.sql('''
delete from `tabSingles`
where field in ({0}) and
doctype=%s'''.format(', '.join(['%s']*len(keys))),
list(keys) + [dt], debug=debug)
for key, value in to_update.items():
self.sql('''insert into `tabSingles` (doctype, field, value) values (%s, %s, %s)''',
(dt, key, value), debug=debug)

frappe.clear_document_cache(dt, dn)
for column, value in to_update.items():
query = query.set(column, value)

query.run(debug=debug)


if dt in self.value_cache: if dt in self.value_cache:
del self.value_cache[dt] del self.value_cache[dt]



@staticmethod @staticmethod
def set(doc, field, val): def set(doc, field, val):
"""Set value in document. **Avoid**""" """Set value in document. **Avoid**"""


+ 17
- 4
frappe/query_builder/__init__.py View File

@@ -1,8 +1,21 @@
from frappe.query_builder.terms import ParameterizedValueWrapper, ParameterizedFunction
import pypika
import pypika.terms
from pypika import *
from pypika import Field
from pypika.utils import ignore_copy

from frappe.query_builder.terms import ParameterizedFunction, ParameterizedValueWrapper
from frappe.query_builder.utils import (
Column,
DocType,
get_query_builder,
patch_query_aggregation,
patch_query_execute,
)


pypika.terms.ValueWrapper = ParameterizedValueWrapper pypika.terms.ValueWrapper = ParameterizedValueWrapper
pypika.terms.Function = ParameterizedFunction pypika.terms.Function = ParameterizedFunction


from pypika import *
from frappe.query_builder.utils import Column, DocType, get_query_builder, patch_query_execute, patch_query_aggregation
# * Overrides the field() method and replaces it with the a `PseudoColumn` 'field' for consistency
pypika.queries.Selectable.__getattr__ = ignore_copy(lambda table, x: Field(x, table=table))
pypika.queries.Selectable.__getitem__ = ignore_copy(lambda table, x: Field(x, table=table))
pypika.queries.Selectable.field = pypika.terms.PseudoColumn("field")

+ 16
- 4
frappe/query_builder/builder.py View File

@@ -1,8 +1,12 @@
from pypika import MySQLQuery, Order, PostgreSQLQuery, terms from pypika import MySQLQuery, Order, PostgreSQLQuery, terms
from pypika.queries import Schema, Table
from frappe.utils import get_table_name
from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder
from pypika.queries import QueryBuilder, Schema, Table
from pypika.terms import Function from pypika.terms import Function


from frappe.query_builder.terms import ParameterizedValueWrapper
from frappe.utils import get_table_name


class Base: class Base:
terms = terms terms = terms
desc = Order.desc desc = Order.desc
@@ -19,13 +23,13 @@ class Base:
return Table(table_name, *args, **kwargs) return Table(table_name, *args, **kwargs)


@classmethod @classmethod
def into(cls, table, *args, **kwargs):
def into(cls, table, *args, **kwargs) -> QueryBuilder:
if isinstance(table, str): if isinstance(table, str):
table = cls.DocType(table) table = cls.DocType(table)
return super().into(table, *args, **kwargs) return super().into(table, *args, **kwargs)


@classmethod @classmethod
def update(cls, table, *args, **kwargs):
def update(cls, table, *args, **kwargs) -> QueryBuilder:
if isinstance(table, str): if isinstance(table, str):
table = cls.DocType(table) table = cls.DocType(table)
return super().update(table, *args, **kwargs) return super().update(table, *args, **kwargs)
@@ -34,6 +38,10 @@ class Base:
class MariaDB(Base, MySQLQuery): class MariaDB(Base, MySQLQuery):
Field = terms.Field Field = terms.Field


@classmethod
def _builder(cls, *args, **kwargs) -> "MySQLQueryBuilder":
return super()._builder(*args, wrapper_cls=ParameterizedValueWrapper, **kwargs)

@classmethod @classmethod
def from_(cls, table, *args, **kwargs): def from_(cls, table, *args, **kwargs):
if isinstance(table, str): if isinstance(table, str):
@@ -53,6 +61,10 @@ class Postgres(Base, PostgreSQLQuery):
# they are two different objects. The quick fix used here is to replace the # they are two different objects. The quick fix used here is to replace the
# Field names in the "Field" function. # Field names in the "Field" function.


@classmethod
def _builder(cls, *args, **kwargs) -> "PostgreSQLQueryBuilder":
return super()._builder(*args, wrapper_cls=ParameterizedValueWrapper, **kwargs)

@classmethod @classmethod
def Field(cls, field_name, *args, **kwargs): def Field(cls, field_name, *args, **kwargs):
if field_name in cls.field_translation: if field_name in cls.field_translation:


+ 69
- 17
frappe/query_builder/terms.py View File

@@ -1,33 +1,76 @@
from datetime import timedelta
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from frappe.utils.data import format_timedelta


from pypika.terms import Function, ValueWrapper from pypika.terms import Function, ValueWrapper
from pypika.utils import format_alias_sql from pypika.utils import format_alias_sql




class NamedParameterWrapper():
def __init__(self, parameters: Dict[str, Any]):
self.parameters = parameters
class NamedParameterWrapper:
"""Utility class to hold parameter values and keys"""


def update_parameters(self, param_key: Any, param_value: Any, **kwargs):
def __init__(self) -> None:
self.parameters = {}

def get_sql(self, param_value: Any, **kwargs) -> str:
"""returns SQL for a parameter, while adding the real value in a dict

Args:
param_value (Any): Value of the parameter

Returns:
str: parameter used in the SQL query
"""
param_key = f"%(param{len(self.parameters) + 1})s"
self.parameters[param_key[2:-2]] = param_value self.parameters[param_key[2:-2]] = param_value
return param_key


def get_sql(self, **kwargs):
return f'%(param{len(self.parameters) + 1})s'
def get_parameters(self) -> Dict[str, Any]:
"""get dict with parameters and values

Returns:
Dict[str, Any]: parameter dict
"""
return self.parameters




class ParameterizedValueWrapper(ValueWrapper): class ParameterizedValueWrapper(ValueWrapper):
def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = "'", param_wrapper= None, **kwargs: Any) -> str:
if param_wrapper is None:
sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs)
return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs)
"""
Class to monkey patch ValueWrapper

Adds functionality to parameterize queries when a `param wrapper` is passed in get_sql()
"""

def get_sql(
self,
quote_char: Optional[str] = None,
secondary_quote_char: str = "'",
param_wrapper: Optional[NamedParameterWrapper] = None,
**kwargs: Any,
) -> str:
if param_wrapper and isinstance(self.value, str):
# add quotes if it's a string value
value_sql = self.get_value_sql(quote_char=quote_char, **kwargs)
sql = param_wrapper.get_sql(param_value=value_sql, **kwargs)
else: else:
value_sql = self.get_value_sql(quote_char=quote_char, **kwargs) if not isinstance(self.value,int) else self.value
param_sql = param_wrapper.get_sql(**kwargs)
param_wrapper.update_parameters(param_key=param_sql, param_value=value_sql, **kwargs)
return format_alias_sql(param_sql, self.alias, quote_char=quote_char, **kwargs)
# * BUG: pypika doesen't parse timedeltas
if isinstance(self.value, timedelta):
self.value = format_timedelta(self.value)
sql = self.get_value_sql(
quote_char=quote_char,
secondary_quote_char=secondary_quote_char,
**kwargs,
)
return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs)




class ParameterizedFunction(Function): class ParameterizedFunction(Function):
"""
Class to monkey patch pypika.terms.Functions

Only to pass `param_wrapper` in `get_function_sql`.
"""

def get_sql(self, **kwargs: Any) -> str: def get_sql(self, **kwargs: Any) -> str:
with_alias = kwargs.pop("with_alias", False) with_alias = kwargs.pop("with_alias", False)
with_namespace = kwargs.pop("with_namespace", False) with_namespace = kwargs.pop("with_namespace", False)
@@ -35,15 +78,24 @@ class ParameterizedFunction(Function):
dialect = kwargs.pop("dialect", None) dialect = kwargs.pop("dialect", None)
param_wrapper = kwargs.pop("param_wrapper", None) param_wrapper = kwargs.pop("param_wrapper", None)


function_sql = self.get_function_sql(with_namespace=with_namespace, quote_char=quote_char, param_wrapper=param_wrapper, dialect=dialect)
function_sql = self.get_function_sql(
with_namespace=with_namespace,
quote_char=quote_char,
param_wrapper=param_wrapper,
dialect=dialect,
)


if self.schema is not None: if self.schema is not None:
function_sql = "{schema}.{function}".format( function_sql = "{schema}.{function}".format(
schema=self.schema.get_sql(quote_char=quote_char, dialect=dialect, **kwargs),
schema=self.schema.get_sql(
quote_char=quote_char, dialect=dialect, **kwargs
),
function=function_sql, function=function_sql,
) )


if with_alias: if with_alias:
return format_alias_sql(function_sql, self.alias, quote_char=quote_char, **kwargs)
return format_alias_sql(
function_sql, self.alias, quote_char=quote_char, **kwargs
)


return function_sql return function_sql

+ 7
- 7
frappe/query_builder/utils.py View File

@@ -1,16 +1,16 @@
from enum import Enum from enum import Enum
from typing import Any, Callable, Dict, Union, get_type_hints
from importlib import import_module from importlib import import_module
from typing import Any, Callable, Dict, Union, get_type_hints


from pypika import Query from pypika import Query
from pypika.queries import Column from pypika.queries import Column
from pypika.terms import PseudoColumn


import frappe import frappe
from frappe.query_builder.terms import NamedParameterWrapper


from .builder import MariaDB, Postgres from .builder import MariaDB, Postgres
from pypika.terms import PseudoColumn


from frappe.query_builder.terms import NamedParameterWrapper


class db_type_is(Enum): class db_type_is(Enum):
MARIADB = "mariadb" MARIADB = "mariadb"
@@ -59,11 +59,11 @@ def patch_query_execute():
return frappe.db.sql(query, params, *args, **kwargs) # nosemgrep return frappe.db.sql(query, params, *args, **kwargs) # nosemgrep


def prepare_query(query): def prepare_query(query):
params = {}
query = query.get_sql(param_wrapper = NamedParameterWrapper(params))
param_collector = NamedParameterWrapper()
query = query.get_sql(param_wrapper=param_collector)
if frappe.flags.in_safe_exec and not query.lower().strip().startswith("select"): if frappe.flags.in_safe_exec and not query.lower().strip().startswith("select"):
raise frappe.PermissionError('Only SELECT SQL allowed in scripting') raise frappe.PermissionError('Only SELECT SQL allowed in scripting')
return query, params
return query, param_collector.get_parameters()


query_class = get_attr(str(frappe.qb).split("'")[1]) query_class = get_attr(str(frappe.qb).split("'")[1])
builder_class = get_type_hints(query_class._builder).get('return') builder_class = get_type_hints(query_class._builder).get('return')
@@ -78,7 +78,7 @@ def patch_query_execute():
def patch_query_aggregation(): def patch_query_aggregation():
"""Patch aggregation functions to frappe.qb """Patch aggregation functions to frappe.qb
""" """
from frappe.query_builder.functions import _max, _min, _avg, _sum
from frappe.query_builder.functions import _avg, _max, _min, _sum


frappe.qb.max = _max frappe.qb.max = _max
frappe.qb.min = _min frappe.qb.min = _min


+ 10
- 7
frappe/test_runner.py View File

@@ -30,7 +30,7 @@ def xmlrunner_wrapper(output):


def main(app=None, module=None, doctype=None, verbose=False, tests=(), def main(app=None, module=None, doctype=None, verbose=False, tests=(),
force=False, profile=False, junit_xml_output=None, ui_tests=False, force=False, profile=False, junit_xml_output=None, ui_tests=False,
doctype_list_path=None, skip_test_records=False, failfast=False):
doctype_list_path=None, skip_test_records=False, failfast=False, case=None):
global unittest_runner global unittest_runner


if doctype_list_path: if doctype_list_path:
@@ -76,7 +76,7 @@ def main(app=None, module=None, doctype=None, verbose=False, tests=(),
if doctype: if doctype:
ret = run_tests_for_doctype(doctype, verbose, tests, force, profile, failfast=failfast, junit_xml_output=junit_xml_output) ret = run_tests_for_doctype(doctype, verbose, tests, force, profile, failfast=failfast, junit_xml_output=junit_xml_output)
elif module: elif module:
ret = run_tests_for_module(module, verbose, tests, profile, failfast=failfast, junit_xml_output=junit_xml_output)
ret = run_tests_for_module(module, verbose, tests, profile, failfast=failfast, junit_xml_output=junit_xml_output, case=case)
else: else:
ret = run_all_tests(app, verbose, profile, ui_tests, failfast=failfast, junit_xml_output=junit_xml_output) ret = run_all_tests(app, verbose, profile, ui_tests, failfast=failfast, junit_xml_output=junit_xml_output)


@@ -182,16 +182,16 @@ def run_tests_for_doctype(doctypes, verbose=False, tests=(), force=False, profil


return _run_unittest(modules, verbose=verbose, tests=tests, profile=profile, failfast=failfast, junit_xml_output=junit_xml_output) return _run_unittest(modules, verbose=verbose, tests=tests, profile=profile, failfast=failfast, junit_xml_output=junit_xml_output)


def run_tests_for_module(module, verbose=False, tests=(), profile=False, failfast=False, junit_xml_output=False):
def run_tests_for_module(module, verbose=False, tests=(), profile=False, failfast=False, junit_xml_output=False, case=None):
module = importlib.import_module(module) module = importlib.import_module(module)
if hasattr(module, "test_dependencies"): if hasattr(module, "test_dependencies"):
for doctype in module.test_dependencies: for doctype in module.test_dependencies:
make_test_records(doctype, verbose=verbose) make_test_records(doctype, verbose=verbose)


frappe.db.commit() frappe.db.commit()
return _run_unittest(module, verbose=verbose, tests=tests, profile=profile, failfast=failfast, junit_xml_output=junit_xml_output)
return _run_unittest(module, verbose=verbose, tests=tests, profile=profile, failfast=failfast, junit_xml_output=junit_xml_output, case=case)


def _run_unittest(modules, verbose=False, tests=(), profile=False, failfast=False, junit_xml_output=False):
def _run_unittest(modules, verbose=False, tests=(), profile=False, failfast=False, junit_xml_output=False, case=None):
frappe.db.begin() frappe.db.begin()


test_suite = unittest.TestSuite() test_suite = unittest.TestSuite()
@@ -200,7 +200,10 @@ def _run_unittest(modules, verbose=False, tests=(), profile=False, failfast=Fals
modules = [modules] modules = [modules]


for module in modules: for module in modules:
module_test_cases = unittest.TestLoader().loadTestsFromModule(module)
if case:
module_test_cases = unittest.TestLoader().loadTestsFromTestCase(getattr(module, case))
else:
module_test_cases = unittest.TestLoader().loadTestsFromModule(module)
if tests: if tests:
for each in module_test_cases: for each in module_test_cases:
for test_case in each.__dict__["_tests"]: for test_case in each.__dict__["_tests"]:
@@ -337,7 +340,7 @@ def make_test_records_for_doctype(doctype, verbose=0, force=False):
elif hasattr(test_module, "test_records"): elif hasattr(test_module, "test_records"):
if doctype in frappe.local.test_objects: if doctype in frappe.local.test_objects:
frappe.local.test_objects[doctype] += make_test_objects(doctype, test_module.test_records, verbose, force) frappe.local.test_objects[doctype] += make_test_objects(doctype, test_module.test_records, verbose, force)
else:
else:
frappe.local.test_objects[doctype] = make_test_objects(doctype, test_module.test_records, verbose, force) frappe.local.test_objects[doctype] = make_test_objects(doctype, test_module.test_records, verbose, force)


else: else:


+ 146
- 24
frappe/tests/test_db.py View File

@@ -1,21 +1,21 @@
# -*- coding: utf-8 -*-

# Copyright (c) 2015, Frappe Technologies Pvt. Ltd. and Contributors
# Copyright (c) 2022, Frappe Technologies Pvt. Ltd. and Contributors
# License: MIT. See LICENSE # License: MIT. See LICENSE


import datetime
import inspect
import unittest import unittest
from random import choice from random import choice
import datetime
from unittest.mock import patch


import frappe import frappe
from frappe.custom.doctype.custom_field.custom_field import create_custom_field from frappe.custom.doctype.custom_field.custom_field import create_custom_field
from frappe.utils import random_string
from frappe.utils.testutils import clear_custom_fields
from frappe.query_builder import Field
from frappe.database import savepoint from frappe.database import savepoint
from .test_query_builder import run_only_if, db_type_is
from frappe.database.database import Database
from frappe.query_builder import Field
from frappe.query_builder.functions import Concat_ws from frappe.query_builder.functions import Concat_ws
from frappe.tests.test_query_builder import db_type_is, run_only_if
from frappe.utils import add_days, now, random_string
from frappe.utils.testutils import clear_custom_fields




class TestDB(unittest.TestCase): class TestDB(unittest.TestCase):
@@ -84,20 +84,6 @@ class TestDB(unittest.TestCase):
), ),
) )


def test_set_value(self):
todo1 = frappe.get_doc(dict(doctype='ToDo', description = 'test_set_value 1')).insert()
todo2 = frappe.get_doc(dict(doctype='ToDo', description = 'test_set_value 2')).insert()

frappe.db.set_value('ToDo', todo1.name, 'description', 'test_set_value change 1')
self.assertEqual(frappe.db.get_value('ToDo', todo1.name, 'description'), 'test_set_value change 1')

# multiple set-value
frappe.db.set_value('ToDo', dict(description=('like', '%test_set_value%')),
'description', 'change 2')

self.assertEqual(frappe.db.get_value('ToDo', todo1.name, 'description'), 'change 2')
self.assertEqual(frappe.db.get_value('ToDo', todo2.name, 'description'), 'change 2')

def test_escape(self): def test_escape(self):
frappe.db.escape("香港濟生堂製藥有限公司 - IT".encode("utf-8")) frappe.db.escape("香港濟生堂製藥有限公司 - IT".encode("utf-8"))


@@ -246,7 +232,6 @@ class TestDB(unittest.TestCase):
frappe.delete_doc(test_doctype, doc) frappe.delete_doc(test_doctype, doc)
clear_custom_fields(test_doctype) clear_custom_fields(test_doctype)



def test_savepoints(self): def test_savepoints(self):
frappe.db.rollback() frappe.db.rollback()
save_point = "todonope" save_point = "todonope"
@@ -365,6 +350,143 @@ class TestDDLCommandsMaria(unittest.TestCase):
self.assertEquals(len(indexs_in_table), 2) self.assertEquals(len(indexs_in_table), 2)




class TestDBSetValue(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.todo1 = frappe.get_doc(doctype="ToDo", description="test_set_value 1").insert()
cls.todo2 = frappe.get_doc(doctype="ToDo", description="test_set_value 2").insert()

def test_update_single_doctype_field(self):
value = frappe.db.get_single_value("System Settings", "deny_multiple_sessions")
changed_value = not value

frappe.db.set_value("System Settings", "System Settings", "deny_multiple_sessions", changed_value)
current_value = frappe.db.get_single_value("System Settings", "deny_multiple_sessions")
self.assertEqual(current_value, changed_value)

changed_value = not current_value
frappe.db.set_value("System Settings", None, "deny_multiple_sessions", changed_value)
current_value = frappe.db.get_single_value("System Settings", "deny_multiple_sessions")
self.assertEqual(current_value, changed_value)

changed_value = not current_value
frappe.db.set_single_value("System Settings", "deny_multiple_sessions", changed_value)
current_value = frappe.db.get_single_value("System Settings", "deny_multiple_sessions")
self.assertEqual(current_value, changed_value)

def test_update_single_row_single_column(self):
frappe.db.set_value("ToDo", self.todo1.name, "description", "test_set_value change 1")
updated_value = frappe.db.get_value("ToDo", self.todo1.name, "description")
self.assertEqual(updated_value, "test_set_value change 1")

def test_update_single_row_multiple_columns(self):
description, status = "Upated by test_update_single_row_multiple_columns", "Closed"

frappe.db.set_value("ToDo", self.todo1.name, {
"description": description,
"status": status,
}, update_modified=False)

updated_desciption, updated_status = frappe.db.get_value("ToDo",
filters={"name": self.todo1.name},
fieldname=["description", "status"]
)

self.assertEqual(description, updated_desciption)
self.assertEqual(status, updated_status)

def test_update_multiple_rows_single_column(self):
frappe.db.set_value("ToDo", {"description": ("like", "%test_set_value%")}, "description", "change 2")

self.assertEqual(frappe.db.get_value("ToDo", self.todo1.name, "description"), "change 2")
self.assertEqual(frappe.db.get_value("ToDo", self.todo2.name, "description"), "change 2")

def test_update_multiple_rows_multiple_columns(self):
todos_to_update = frappe.get_all("ToDo", filters={
"description": ("like", "%test_set_value%"),
"status": ("!=", "Closed")
}, pluck="name")

frappe.db.set_value("ToDo", {
"description": ("like", "%test_set_value%"),
"status": ("!=", "Closed")
}, {
"status": "Closed",
"priority": "High"
})

test_result = frappe.get_all("ToDo", filters={"name": ("in", todos_to_update)}, fields=["status", "priority"])

self.assertTrue(all(x for x in test_result if x["status"] == "Closed"))
self.assertTrue(all(x for x in test_result if x["priority"] == "High"))

def test_update_modified_options(self):
self.todo2.reload()

todo = self.todo2
updated_description = f"{todo.description} - by `test_update_modified_options`"
custom_modified = datetime.datetime.fromisoformat(add_days(now(), 10))
custom_modified_by = "user_that_doesnt_exist@example.com"

frappe.db.set_value("ToDo", todo.name, "description", updated_description, update_modified=False)
self.assertEqual(updated_description, frappe.db.get_value("ToDo", todo.name, "description"))
self.assertEqual(todo.modified, frappe.db.get_value("ToDo", todo.name, "modified"))

frappe.db.set_value("ToDo", todo.name, "description", "test_set_value change 1", modified=custom_modified, modified_by=custom_modified_by)
self.assertTupleEqual(
(custom_modified, custom_modified_by),
frappe.db.get_value("ToDo", todo.name, ["modified", "modified_by"])
)

def test_for_update(self):
self.todo1.reload()

with patch.object(Database, "sql") as sql_called:
frappe.db.set_value(
self.todo1.doctype,
self.todo1.name,
"description",
f"{self.todo1.description}-edit by `test_for_update`"
)
first_query = sql_called.call_args_list[0].args[0]
second_query = sql_called.call_args_list[1].args[0]

self.assertTrue(sql_called.call_count == 2)
self.assertTrue("FOR UPDATE" in first_query)
if frappe.conf.db_type == "postgres":
from frappe.database.postgres.database import modify_query
self.assertTrue(modify_query("UPDATE `tabToDo` SET") in second_query)
if frappe.conf.db_type == "mariadb":
self.assertTrue("UPDATE `tabToDo` SET" in second_query)

def test_cleared_cache(self):
self.todo2.reload()

with patch.object(frappe, "clear_document_cache") as clear_cache:
frappe.db.set_value(
self.todo2.doctype,
self.todo2.name,
"description",
f"{self.todo2.description}-edit by `test_cleared_cache`"
)
clear_cache.assert_called()

def test_update_alias(self):
args = (self.todo1.doctype, self.todo1.name, "description", "Updated by `test_update_alias`")
kwargs = {"for_update": False, "modified": None, "modified_by": None, "update_modified": True, "debug": False}

self.assertTrue("return self.set_value(" in inspect.getsource(frappe.db.update))

with patch.object(Database, "set_value") as set_value:
frappe.db.update(*args, **kwargs)
set_value.assert_called_once()
set_value.assert_called_with(*args, **kwargs)

@classmethod
def tearDownClass(cls):
frappe.db.rollback()


@run_only_if(db_type_is.POSTGRES) @run_only_if(db_type_is.POSTGRES)
class TestDDLCommandsPost(unittest.TestCase): class TestDDLCommandsPost(unittest.TestCase):
test_table_name = "TestNotes" test_table_name = "TestNotes"


+ 70
- 13
frappe/tests/test_query_builder.py View File

@@ -5,7 +5,7 @@ import frappe
from frappe.query_builder.custom import ConstantColumn from frappe.query_builder.custom import ConstantColumn
from frappe.query_builder.functions import Coalesce, GroupConcat, Match from frappe.query_builder.functions import Coalesce, GroupConcat, Match
from frappe.query_builder.utils import db_type_is from frappe.query_builder.utils import db_type_is
from frappe.query_builder import Case


def run_only_if(dbtype: db_type_is) -> Callable: def run_only_if(dbtype: db_type_is) -> Callable:
return unittest.skipIf( return unittest.skipIf(
@@ -25,8 +25,14 @@ class TestCustomFunctionsMariaDB(unittest.TestCase):
) )


def test_constant_column(self): def test_constant_column(self):
query = frappe.qb.from_("DocType").select("name", ConstantColumn("John").as_("User"))
self.assertEqual(query.get_sql(), "SELECT `name`,'John' `User` FROM `tabDocType`")
query = frappe.qb.from_("DocType").select(
"name", ConstantColumn("John").as_("User")
)
self.assertEqual(
query.get_sql(), "SELECT `name`,'John' `User` FROM `tabDocType`"
)


@run_only_if(db_type_is.POSTGRES) @run_only_if(db_type_is.POSTGRES)
class TestCustomFunctionsPostgres(unittest.TestCase): class TestCustomFunctionsPostgres(unittest.TestCase):
def test_concat(self): def test_concat(self):
@@ -39,8 +45,13 @@ class TestCustomFunctionsPostgres(unittest.TestCase):
) )


def test_constant_column(self): def test_constant_column(self):
query = frappe.qb.from_("DocType").select("name", ConstantColumn("John").as_("User"))
self.assertEqual(query.get_sql(), 'SELECT "name",\'John\' "User" FROM "tabDocType"')
query = frappe.qb.from_("DocType").select(
"name", ConstantColumn("John").as_("User")
)
self.assertEqual(
query.get_sql(), 'SELECT "name",\'John\' "User" FROM "tabDocType"'
)



class TestBuilderBase(object): class TestBuilderBase(object):
def test_adding_tabs(self): def test_adding_tabs(self):
@@ -55,23 +66,68 @@ class TestBuilderBase(object):
self.assertIsInstance(query.run, Callable) self.assertIsInstance(query.run, Callable)
self.assertIsInstance(data, list) self.assertIsInstance(data, list)


def test_walk(self):
DocType = frappe.qb.DocType('DocType')

class TestParameterization(unittest.TestCase):
def test_where_conditions(self):
DocType = frappe.qb.DocType("DocType")
query = ( query = (
frappe.qb.from_(DocType) frappe.qb.from_(DocType)
.select(DocType.name) .select(DocType.name)
.where((DocType.owner == "Administrator' --")
& (Coalesce(DocType.search_fields == "subject"))
.where((DocType.owner == "Administrator' --"))
)
self.assertTrue("walk" in dir(query))
query, params = query.walk()

self.assertIn("%(param1)s", query)
self.assertIn("param1", params)
self.assertEqual(params["param1"], "Administrator' --")

def test_set_cnoditions(self):
DocType = frappe.qb.DocType("DocType")
query = frappe.qb.update(DocType).set(DocType.value, "some_value")

self.assertTrue("walk" in dir(query))
query, params = query.walk()

self.assertIn("%(param1)s", query)
self.assertIn("param1", params)
self.assertEqual(params["param1"], "some_value")

def test_where_conditions_functions(self):
DocType = frappe.qb.DocType("DocType")
query = (
frappe.qb.from_(DocType)
.select(DocType.name)
.where(Coalesce(DocType.search_fields == "subject"))
)

self.assertTrue("walk" in dir(query))
query, params = query.walk()

self.assertIn("%(param1)s", query)
self.assertIn("param1", params)
self.assertEqual(params["param1"], "subject")

def test_case(self):
DocType = frappe.qb.DocType("DocType")
query = (
frappe.qb.from_(DocType)
.select(
Case()
.when(DocType.search_fields == "value", "other_value")
.when(Coalesce(DocType.search_fields == "subject_in_function"), "true_value")
) )
) )

self.assertTrue("walk" in dir(query)) self.assertTrue("walk" in dir(query))
query, params = query.walk() query, params = query.walk()


self.assertIn("%(param1)s", query) self.assertIn("%(param1)s", query)
self.assertIn("%(param2)s", query)
self.assertIn("param1",params)
self.assertEqual(params["param1"],"Administrator' --")
self.assertEqual(params["param2"],"subject")
self.assertIn("param1", params)
self.assertEqual(params["param1"], "value")
self.assertEqual(params["param2"], "other_value")
self.assertEqual(params["param3"], "subject_in_function")
self.assertEqual(params["param4"], "true_value")




@run_only_if(db_type_is.MARIADB) @run_only_if(db_type_is.MARIADB)
@@ -84,6 +140,7 @@ class TestBuilderMaria(unittest.TestCase, TestBuilderBase):
"SELECT * FROM `__Auth`", frappe.qb.from_("__Auth").select("*").get_sql() "SELECT * FROM `__Auth`", frappe.qb.from_("__Auth").select("*").get_sql()
) )



@run_only_if(db_type_is.POSTGRES) @run_only_if(db_type_is.POSTGRES)
class TestBuilderPostgres(unittest.TestCase, TestBuilderBase): class TestBuilderPostgres(unittest.TestCase, TestBuilderBase):
def test_adding_tabs_in_from(self): def test_adding_tabs_in_from(self):


+ 71
- 16
frappe/tests/test_utils.py View File

@@ -1,22 +1,28 @@
# Copyright (c) 2015, Frappe Technologies Pvt. Ltd. and Contributors
# Copyright (c) 2022, Frappe Technologies Pvt. Ltd. and Contributors
# License: MIT. See LICENSE # License: MIT. See LICENSE
import unittest
import frappe


from frappe.utils import evaluate_filters, money_in_words, scrub_urls, get_url
from frappe.utils import validate_url, validate_email_address
from frappe.utils import ceil, floor
from frappe.utils.data import cast, validate_python_code
from frappe.utils.diff import get_version_diff, version_query, _get_value_from_version

from PIL import Image
from frappe.utils.image import strip_exif_data, optimize_image
import io import io
import json
import unittest
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from enum import Enum
from mimetypes import guess_type from mimetypes import guess_type
from datetime import datetime, timedelta, date

from unittest.mock import patch from unittest.mock import patch


import pytz
from PIL import Image

import frappe
from frappe.utils import ceil, evaluate_filters, floor, format_timedelta
from frappe.utils import get_url, money_in_words, parse_timedelta, scrub_urls
from frappe.utils import validate_email_address, validate_url
from frappe.utils.data import cast, validate_python_code
from frappe.utils.diff import _get_value_from_version, get_version_diff, version_query
from frappe.utils.image import optimize_image, strip_exif_data
from frappe.utils.response import json_handler


class TestFilters(unittest.TestCase): class TestFilters(unittest.TestCase):
def test_simple_dict(self): def test_simple_dict(self):
self.assertTrue(evaluate_filters({'doctype': 'User', 'status': 'Open'}, {'status': 'Open'})) self.assertTrue(evaluate_filters({'doctype': 'User', 'status': 'Open'}, {'status': 'Open'}))
@@ -273,9 +279,7 @@ class TestPythonExpressions(unittest.TestCase):
for expr in invalid_expressions: for expr in invalid_expressions:
self.assertRaises(frappe.ValidationError, validate_python_code, expr) self.assertRaises(frappe.ValidationError, validate_python_code, expr)



class TestDiffUtils(unittest.TestCase): class TestDiffUtils(unittest.TestCase):

@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.doc = frappe.get_doc(doctype="Client Script", dt="Client Script") cls.doc = frappe.get_doc(doctype="Client Script", dt="Client Script")
@@ -330,8 +334,59 @@ class TestDateUtils(unittest.TestCase):
self.assertEqual(frappe.utils.get_last_day_of_week("2020-12-28"), self.assertEqual(frappe.utils.get_last_day_of_week("2020-12-28"),
frappe.utils.getdate("2021-01-02")) frappe.utils.getdate("2021-01-02"))


class TestXlsxUtils(unittest.TestCase):
class TestResponse(unittest.TestCase):
def test_json_handler(self):
class TEST(Enum):
ABC = "!@)@)!"
BCE = "ENJD"

GOOD_OBJECT = {
"time_types": [
date(year=2020, month=12, day=2),
datetime(year=2020, month=12, day=2, hour=23, minute=23, second=23, microsecond=23, tzinfo=pytz.utc),
time(hour=23, minute=23, second=23, microsecond=23, tzinfo=pytz.utc),
timedelta(days=10, hours=12, minutes=120, seconds=10),
],
"float": [
Decimal(29.21),
],
"doc": [
frappe.get_doc("System Settings"),
],
"iter": [
{1, 2, 3},
(1, 2, 3),
"abcdef",
],
"string": "abcdef"
}

BAD_OBJECT = {"Enum": TEST}

processed_object = json.loads(json.dumps(GOOD_OBJECT, default=json_handler))

self.assertTrue(all([isinstance(x, str) for x in processed_object["time_types"]]))
self.assertTrue(all([isinstance(x, float) for x in processed_object["float"]]))
self.assertTrue(all([isinstance(x, (list, str)) for x in processed_object["iter"]]))
self.assertIsInstance(processed_object["string"], str)
with self.assertRaises(TypeError):
json.dumps(BAD_OBJECT, default=json_handler)

class TestTimeDeltaUtils(unittest.TestCase):
def test_format_timedelta(self):
self.assertEqual(format_timedelta(timedelta(seconds=0)), "0:00:00")
self.assertEqual(format_timedelta(timedelta(hours=10)), "10:00:00")
self.assertEqual(format_timedelta(timedelta(hours=100)), "100:00:00")
self.assertEqual(format_timedelta(timedelta(seconds=100, microseconds=129)), "0:01:40.000129")
self.assertEqual(format_timedelta(timedelta(seconds=100, microseconds=12212199129)), "3:25:12.199129")

def test_parse_timedelta(self):
self.assertEqual(parse_timedelta("0:0:0"), timedelta(seconds=0))
self.assertEqual(parse_timedelta("10:0:0"), timedelta(hours=10))
self.assertEqual(parse_timedelta("7 days, 0:32:18.192221"), timedelta(days=7, seconds=1938, microseconds=192221))
self.assertEqual(parse_timedelta("7 days, 0:32:18"), timedelta(days=7, seconds=1938))


class TestXlsxUtils(unittest.TestCase):
def test_unescape(self): def test_unescape(self):
from frappe.utils.xlsxutils import handle_html from frappe.utils.xlsxutils import handle_html




+ 1
- 0
frappe/tests/test_website.py View File

@@ -20,6 +20,7 @@ class TestWebsite(unittest.TestCase):
doctype='User', doctype='User',
email='test-user-for-home-page@example.com', email='test-user-for-home-page@example.com',
first_name='test')).insert(ignore_if_duplicate=True) first_name='test')).insert(ignore_if_duplicate=True)
user.reload()


role = frappe.get_doc(dict( role = frappe.get_doc(dict(
doctype = 'Role', doctype = 'Role',


+ 2
- 2
frappe/utils/__init__.py View File

@@ -1,4 +1,4 @@
# Copyright (c) 2015, Frappe Technologies Pvt. Ltd. and Contributors
# Copyright (c) 2022, Frappe Technologies Pvt. Ltd. and Contributors
# License: MIT. See LICENSE # License: MIT. See LICENSE


import functools import functools
@@ -56,8 +56,8 @@ def get_email_address(user=None):
def get_formatted_email(user, mail=None): def get_formatted_email(user, mail=None):
"""get Email Address of user formatted as: `John Doe <johndoe@example.com>`""" """get Email Address of user formatted as: `John Doe <johndoe@example.com>`"""
fullname = get_fullname(user) fullname = get_fullname(user)

method = get_hook_method('get_sender_details') method = get_hook_method('get_sender_details')

if method: if method:
sender_name, mail = method() sender_name, mail = method()
# if method exists but sender_name is "" # if method exists but sender_name is ""


+ 78
- 31
frappe/utils/data.py View File

@@ -1,17 +1,22 @@
# Copyright (c) 2015, Frappe Technologies Pvt. Ltd. and Contributors
# Copyright (c) 2022, Frappe Technologies Pvt. Ltd. and Contributors
# License: MIT. See LICENSE # License: MIT. See LICENSE


from typing import Optional
import frappe
import operator
import json
import base64 import base64
import re, datetime, math, time
import datetime
import json
import math
import operator
import re
import time
from code import compile_command from code import compile_command
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.parse import quote, urljoin from urllib.parse import quote, urljoin
from frappe.desk.utils import slug
from click import secho from click import secho
from enum import Enum

import frappe
from frappe.desk.utils import slug


DATE_FORMAT = "%Y-%m-%d" DATE_FORMAT = "%Y-%m-%d"
TIME_FORMAT = "%H:%M:%S.%f" TIME_FORMAT = "%H:%M:%S.%f"
@@ -99,11 +104,17 @@ def get_timedelta(time: Optional[str] = None) -> Optional[datetime.timedelta]:
datetime.timedelta: Timedelta object equivalent of the passed `time` string datetime.timedelta: Timedelta object equivalent of the passed `time` string
""" """
from dateutil import parser from dateutil import parser
from dateutil.parser import ParserError


time = time or "0:0:0" time = time or "0:0:0"


try: try:
t = parser.parse(time)
try:
t = parser.parse(time)
except ParserError as e:
if "day" in e.args[1]:
from frappe.utils import parse_timedelta
return parse_timedelta(time)
return datetime.timedelta( return datetime.timedelta(
hours=t.hour, minutes=t.minute, seconds=t.second, microseconds=t.microsecond hours=t.hour, minutes=t.minute, seconds=t.second, microseconds=t.microsecond
) )
@@ -201,7 +212,7 @@ def get_time_zone():
return frappe.cache().get_value("time_zone", _get_time_zone) return frappe.cache().get_value("time_zone", _get_time_zone)


def convert_utc_to_timezone(utc_timestamp, time_zone): def convert_utc_to_timezone(utc_timestamp, time_zone):
from pytz import timezone, UnknownTimeZoneError
from pytz import UnknownTimeZoneError, timezone
utcnow = timezone('UTC').localize(utc_timestamp) utcnow = timezone('UTC').localize(utc_timestamp)
try: try:
return utcnow.astimezone(timezone(time_zone)) return utcnow.astimezone(timezone(time_zone))
@@ -327,7 +338,7 @@ def get_time(time_str):
return time_str return time_str
else: else:
if isinstance(time_str, datetime.timedelta): if isinstance(time_str, datetime.timedelta):
time_str = str(time_str)
return format_timedelta(time_str)
return parser.parse(time_str).time() return parser.parse(time_str).time()


def get_datetime_str(datetime_obj): def get_datetime_str(datetime_obj):
@@ -610,7 +621,7 @@ def cast(fieldtype, value=None):
value = flt(value) value = flt(value)


elif fieldtype in ("Int", "Check"): elif fieldtype in ("Int", "Check"):
value = cint(value)
value = cint(sbool(value))


elif fieldtype in ("Data", "Text", "Small Text", "Long Text", elif fieldtype in ("Data", "Text", "Small Text", "Long Text",
"Text Editor", "Select", "Link", "Dynamic Link"): "Text Editor", "Select", "Link", "Dynamic Link"):
@@ -726,7 +737,7 @@ def ceil(s):
def cstr(s, encoding='utf-8'): def cstr(s, encoding='utf-8'):
return frappe.as_unicode(s, encoding) return frappe.as_unicode(s, encoding)


def sbool(x):
def sbool(x: str) -> Union[bool, Any]:
"""Converts str object to Boolean if possible. """Converts str object to Boolean if possible.
Example: Example:
"true" becomes True "true" becomes True
@@ -737,12 +748,15 @@ def sbool(x):
x (str): String to be converted to Bool x (str): String to be converted to Bool


Returns: Returns:
object: Returns Boolean or type(x)
object: Returns Boolean or x
""" """
from distutils.util import strtobool

try: try:
return bool(strtobool(x))
val = x.lower()
if val in ('true', '1'):
return True
elif val in ('false', '0'):
return False
return x
except Exception: except Exception:
return x return x


@@ -917,13 +931,13 @@ number_format_info = {
"#.########": (".", "", 8) "#.########": (".", "", 8)
} }


def get_number_format_info(format):
def get_number_format_info(format: str) -> Tuple[str, str, int]:
return number_format_info.get(format) or (".", ",", 2) return number_format_info.get(format) or (".", ",", 2)


# #
# convert currency to words # convert currency to words
# #
def money_in_words(number, main_currency = None, fraction_currency=None):
def money_in_words(number: str, main_currency: Optional[str] = None, fraction_currency: Optional[str] = None):
""" """
Returns string in words with currency and fraction currency. Returns string in words with currency and fraction currency.
""" """
@@ -1009,9 +1023,11 @@ def is_image(filepath):


def get_thumbnail_base64_for_image(src): def get_thumbnail_base64_for_image(src):
from os.path import exists as file_exists from os.path import exists as file_exists

from PIL import Image from PIL import Image

from frappe import cache, safe_decode
from frappe.core.doctype.file.file import get_local_image from frappe.core.doctype.file.file import get_local_image
from frappe import safe_decode, cache


if not src: if not src:
frappe.throw('Invalid source for image: {0}'.format(src)) frappe.throw('Invalid source for image: {0}'.format(src))
@@ -1302,7 +1318,7 @@ operator_map = {
"None": lambda a, b: (not a) and True or False "None": lambda a, b: (not a) and True or False
} }


def evaluate_filters(doc, filters):
def evaluate_filters(doc, filters: Union[Dict, List, Tuple]):
'''Returns true if doc matches filters''' '''Returns true if doc matches filters'''
if isinstance(filters, dict): if isinstance(filters, dict):
for key, value in filters.items(): for key, value in filters.items():
@@ -1319,7 +1335,7 @@ def evaluate_filters(doc, filters):
return True return True




def compare(val1, condition, val2, fieldtype=None):
def compare(val1: Any, condition: str, val2: Any, fieldtype: Optional[str] = None):
ret = False ret = False
if fieldtype: if fieldtype:
val2 = cast(fieldtype, val2) val2 = cast(fieldtype, val2)
@@ -1328,7 +1344,7 @@ def compare(val1, condition, val2, fieldtype=None):


return ret return ret


def get_filter(doctype, f, filters_config=None):
def get_filter(doctype: str, f: Union[Dict, List, Tuple], filters_config=None) -> "frappe._dict":
"""Returns a _dict like """Returns a _dict like


{ {
@@ -1415,8 +1431,10 @@ def make_filter_dict(filters):
return _filter return _filter


def sanitize_column(column_name): def sanitize_column(column_name):
from frappe import _
import sqlparse import sqlparse

from frappe import _

regex = re.compile("^.*[,'();].*") regex = re.compile("^.*[,'();].*")
column_name = sqlparse.format(column_name, strip_comments=True, keyword_case="lower") column_name = sqlparse.format(column_name, strip_comments=True, keyword_case="lower")
blacklisted_keywords = ['select', 'create', 'insert', 'delete', 'drop', 'update', 'case', 'and', 'or'] blacklisted_keywords = ['select', 'create', 'insert', 'delete', 'drop', 'update', 'case', 'and', 'or']
@@ -1492,9 +1510,10 @@ def strip(val, chars=None):
return (val or "").replace("\ufeff", "").replace("\u200b", "").strip(chars) return (val or "").replace("\ufeff", "").replace("\u200b", "").strip(chars)


def to_markdown(html): def to_markdown(html):
from html2text import html2text
from html.parser import HTMLParser from html.parser import HTMLParser


from html2text import html2text

text = None text = None
try: try:
text = html2text(html or '') text = html2text(html or '')
@@ -1504,7 +1523,8 @@ def to_markdown(html):
return text return text


def md_to_html(markdown_text): def md_to_html(markdown_text):
from markdown2 import markdown as _markdown, MarkdownError
from markdown2 import MarkdownError
from markdown2 import markdown as _markdown


extras = { extras = {
'fenced-code-blocks': None, 'fenced-code-blocks': None,
@@ -1529,14 +1549,14 @@ def md_to_html(markdown_text):
def markdown(markdown_text): def markdown(markdown_text):
return md_to_html(markdown_text) return md_to_html(markdown_text)


def is_subset(list_a, list_b):
def is_subset(list_a: List, list_b: List) -> bool:
'''Returns whether list_a is a subset of list_b''' '''Returns whether list_a is a subset of list_b'''
return len(list(set(list_a) & set(list_b))) == len(list_a) return len(list(set(list_a) & set(list_b))) == len(list_a)


def generate_hash(*args, **kwargs):
def generate_hash(*args, **kwargs) -> str:
return frappe.generate_hash(*args, **kwargs) return frappe.generate_hash(*args, **kwargs)


def guess_date_format(date_string):
def guess_date_format(date_string: str) -> str:
DATE_FORMATS = [ DATE_FORMATS = [
r"%d/%b/%y", r"%d/%b/%y",
r"%d-%m-%Y", r"%d-%m-%Y",
@@ -1611,13 +1631,13 @@ def guess_date_format(date_string):
if date_format and time_format: if date_format and time_format:
return (date_format + ' ' + time_format).strip() return (date_format + ' ' + time_format).strip()


def validate_json_string(string):
def validate_json_string(string: str) -> None:
try: try:
json.loads(string) json.loads(string)
except (TypeError, ValueError): except (TypeError, ValueError):
raise frappe.ValidationError raise frappe.ValidationError


def get_user_info_for_avatar(user_id):
def get_user_info_for_avatar(user_id: str) -> Dict:
user_info = { user_info = {
"email": user_id, "email": user_id,
"image": "", "image": "",
@@ -1664,3 +1684,30 @@ class UnicodeWithAttrs(str):
def __init__(self, text): def __init__(self, text):
self.toc_html = text.toc_html self.toc_html = text.toc_html
self.metadata = text.metadata self.metadata = text.metadata


def format_timedelta(o: datetime.timedelta) -> str:
# mariadb allows a wide diff range - https://mariadb.com/kb/en/time/
# but frappe doesnt - i think via babel : only allows 0..23 range for hour
total_seconds = o.total_seconds()
hours, remainder = divmod(total_seconds, 3600)
minutes, seconds = divmod(remainder, 60)
rounded_seconds = round(seconds, 6)
int_seconds = int(seconds)

if rounded_seconds == int_seconds:
seconds = int_seconds
else:
seconds = rounded_seconds

return "{:01}:{:02}:{:02}".format(int(hours), int(minutes), seconds)


def parse_timedelta(s: str) -> datetime.timedelta:
# ref: https://stackoverflow.com/a/21074460/10309266
if 'day' in s:
m = re.match(r"(?P<days>[-\d]+) day[s]*, (?P<hours>\d+):(?P<minutes>\d+):(?P<seconds>\d[\.\d+]*)", s)
else:
m = re.match(r"(?P<hours>\d+):(?P<minutes>\d+):(?P<seconds>\d[\.\d+]*)", s)

return datetime.timedelta(**{key: float(val) for key, val in m.groupdict().items()})

+ 7
- 2
frappe/utils/formatters.py View File

@@ -3,9 +3,11 @@


import frappe import frappe
import datetime import datetime
from frappe.utils import formatdate, fmt_money, flt, cstr, cint, format_datetime, format_time, format_duration
from frappe.utils import formatdate, fmt_money, flt, cstr, cint, format_datetime, format_time, format_duration, format_timedelta
from frappe.model.meta import get_field_currency, get_field_precision from frappe.model.meta import get_field_currency, get_field_precision
import re import re
from dateutil.parser import ParserError



def format_value(value, df=None, doc=None, currency=None, translated=False, format=None): def format_value(value, df=None, doc=None, currency=None, translated=False, format=None):
'''Format value based on given fieldtype, document reference, currency reference. '''Format value based on given fieldtype, document reference, currency reference.
@@ -47,7 +49,10 @@ def format_value(value, df=None, doc=None, currency=None, translated=False, form
return format_datetime(value) return format_datetime(value)


elif df.get("fieldtype")=="Time": elif df.get("fieldtype")=="Time":
return format_time(value)
try:
return format_time(value)
except ParserError:
return format_timedelta(value)


elif value==0 and df.get("fieldtype") in ("Int", "Float", "Currency", "Percent") and df.get("print_hide_if_no_value"): elif value==0 and df.get("fieldtype") in ("Int", "Float", "Currency", "Percent") and df.get("print_hide_if_no_value"):
# this is required to show 0 as blank in table columns # this is required to show 0 as blank in table columns


+ 8
- 6
frappe/utils/response.py View File

@@ -1,4 +1,4 @@
# Copyright (c) 2015, Frappe Technologies Pvt. Ltd. and Contributors
# Copyright (c) 2022, Frappe Technologies Pvt. Ltd. and Contributors
# License: MIT. See LICENSE # License: MIT. See LICENSE


import json import json
@@ -16,7 +16,7 @@ from werkzeug.local import LocalProxy
from werkzeug.wsgi import wrap_file from werkzeug.wsgi import wrap_file
from werkzeug.wrappers import Response from werkzeug.wrappers import Response
from werkzeug.exceptions import NotFound, Forbidden from werkzeug.exceptions import NotFound, Forbidden
from frappe.utils import cint
from frappe.utils import cint, format_timedelta
from urllib.parse import quote from urllib.parse import quote
from frappe.core.doctype.access_log.access_log import make_access_log from frappe.core.doctype.access_log.access_log import make_access_log


@@ -122,12 +122,14 @@ def make_logs(response = None):


def json_handler(obj): def json_handler(obj):
"""serialize non-serializable data for json""" """serialize non-serializable data for json"""
# serialize date
import collections.abc
from collections.abc import Iterable


if isinstance(obj, (datetime.date, datetime.timedelta, datetime.datetime, datetime.time)):
if isinstance(obj, (datetime.date, datetime.datetime, datetime.time)):
return str(obj) return str(obj)


elif isinstance(obj, datetime.timedelta):
return format_timedelta(obj)

elif isinstance(obj, decimal.Decimal): elif isinstance(obj, decimal.Decimal):
return float(obj) return float(obj)


@@ -138,7 +140,7 @@ def json_handler(obj):
doc = obj.as_dict(no_nulls=True) doc = obj.as_dict(no_nulls=True)
return doc return doc


elif isinstance(obj, collections.abc.Iterable):
elif isinstance(obj, Iterable):
return list(obj) return list(obj)


elif type(obj)==type or isinstance(obj, Exception): elif type(obj)==type or isinstance(obj, Exception):


+ 2
- 2
frappe/website/utils.py View File

@@ -88,7 +88,7 @@ def get_home_page():


# portal default # portal default
if not home_page: if not home_page:
home_page = frappe.db.get_value("Portal Settings", None, "default_portal_home")
home_page = frappe.db.get_single_value("Portal Settings", "default_portal_home")


# by hooks # by hooks
if not home_page: if not home_page:
@@ -96,7 +96,7 @@ def get_home_page():


# global # global
if not home_page: if not home_page:
home_page = frappe.db.get_value("Website Settings", None, "home_page")
home_page = frappe.db.get_single_value("Website Settings", "home_page")


if not home_page: if not home_page:
home_page = "login" if frappe.session.user == 'Guest' else "me" home_page = "login" if frappe.session.user == 'Guest' else "me"


Loading…
Cancel
Save