refactor: frappe.db.set_valueversion-14
@@ -143,6 +143,8 @@ lang = local("lang") | |||||
# This if block is never executed when running the code. It is only used for | # This if block is never executed when running the code. It is only used for | ||||
# telling static code analyzer where to find dynamically defined attributes. | # telling static code analyzer where to find dynamically defined attributes. | ||||
if typing.TYPE_CHECKING: | if typing.TYPE_CHECKING: | ||||
from frappe.utils.redis_wrapper import RedisWrapper | |||||
from frappe.database.mariadb.database import MariaDBDatabase | from frappe.database.mariadb.database import MariaDBDatabase | ||||
from frappe.database.postgres.database import PostgresDatabase | from frappe.database.postgres.database import PostgresDatabase | ||||
from frappe.query_builder.builder import MariaDB, Postgres | from frappe.query_builder.builder import MariaDB, Postgres | ||||
@@ -150,6 +152,7 @@ if typing.TYPE_CHECKING: | |||||
db: typing.Union[MariaDBDatabase, PostgresDatabase] | db: typing.Union[MariaDBDatabase, PostgresDatabase] | ||||
qb: typing.Union[MariaDB, Postgres] | qb: typing.Union[MariaDB, Postgres] | ||||
# end: static analysis hack | # end: static analysis hack | ||||
def init(site, sites_path=None, new_site=False): | def init(site, sites_path=None, new_site=False): | ||||
@@ -311,9 +314,8 @@ def destroy(): | |||||
release_local(local) | release_local(local) | ||||
# memcache | |||||
redis_server = None | redis_server = None | ||||
def cache(): | |||||
def cache() -> "RedisWrapper": | |||||
"""Returns redis connection.""" | """Returns redis connection.""" | ||||
global redis_server | global redis_server | ||||
if not redis_server: | if not redis_server: | ||||
@@ -99,7 +99,6 @@ def get_value(doctype, fieldname, filters=None, as_dict=True, debug=False, paren | |||||
if not filters: | if not filters: | ||||
filters = None | filters = None | ||||
if frappe.get_meta(doctype).issingle: | if frappe.get_meta(doctype).issingle: | ||||
value = frappe.db.get_values_from_single(fields, filters, doctype, as_dict=as_dict, debug=debug) | value = frappe.db.get_values_from_single(fields, filters, doctype, as_dict=as_dict, debug=debug) | ||||
else: | else: | ||||
@@ -623,6 +623,7 @@ def transform_database(context, table, engine, row_format, failfast): | |||||
@click.command('run-tests') | @click.command('run-tests') | ||||
@click.option('--app', help="For App") | @click.option('--app', help="For App") | ||||
@click.option('--doctype', help="For DocType") | @click.option('--doctype', help="For DocType") | ||||
@click.option('--case', help="Select particular TestCase") | |||||
@click.option('--doctype-list-path', help="Path to .txt file for list of doctypes. Example erpnext/tests/server/agriculture.txt") | @click.option('--doctype-list-path', help="Path to .txt file for list of doctypes. Example erpnext/tests/server/agriculture.txt") | ||||
@click.option('--test', multiple=True, help="Specific test") | @click.option('--test', multiple=True, help="Specific test") | ||||
@click.option('--ui-tests', is_flag=True, default=False, help="Run UI Tests") | @click.option('--ui-tests', is_flag=True, default=False, help="Run UI Tests") | ||||
@@ -636,7 +637,7 @@ def transform_database(context, table, engine, row_format, failfast): | |||||
@pass_context | @pass_context | ||||
def run_tests(context, app=None, module=None, doctype=None, test=(), profile=False, | def run_tests(context, app=None, module=None, doctype=None, test=(), profile=False, | ||||
coverage=False, junit_xml_output=False, ui_tests = False, doctype_list_path=None, | coverage=False, junit_xml_output=False, ui_tests = False, doctype_list_path=None, | ||||
skip_test_records=False, skip_before_tests=False, failfast=False): | |||||
skip_test_records=False, skip_before_tests=False, failfast=False, case=None): | |||||
with CodeCoverage(coverage, app): | with CodeCoverage(coverage, app): | ||||
import frappe.test_runner | import frappe.test_runner | ||||
@@ -658,7 +659,7 @@ def run_tests(context, app=None, module=None, doctype=None, test=(), profile=Fal | |||||
ret = frappe.test_runner.main(app, module, doctype, context.verbose, tests=tests, | ret = frappe.test_runner.main(app, module, doctype, context.verbose, tests=tests, | ||||
force=context.force, profile=profile, junit_xml_output=junit_xml_output, | force=context.force, profile=profile, junit_xml_output=junit_xml_output, | ||||
ui_tests=ui_tests, doctype_list_path=doctype_list_path, failfast=failfast) | |||||
ui_tests=ui_tests, doctype_list_path=doctype_list_path, failfast=failfast, case=case) | |||||
if len(ret.failures) == 0 and len(ret.errors) == 0: | if len(ret.failures) == 0 and len(ret.errors) == 0: | ||||
ret = 0 | ret = 0 | ||||
@@ -7,9 +7,8 @@ import frappe | |||||
import os | import os | ||||
import unittest | import unittest | ||||
from frappe import _ | from frappe import _ | ||||
from frappe.core.doctype.file.file import get_attached_images, move_file, get_files_in_folder, unzip_file | |||||
from frappe.core.doctype.file.file import File, get_attached_images, move_file, get_files_in_folder, unzip_file | |||||
from frappe.utils import get_files_path | from frappe.utils import get_files_path | ||||
# test_records = frappe.get_test_records('File') | |||||
test_content1 = 'Hello' | test_content1 = 'Hello' | ||||
test_content2 = 'Hello World' | test_content2 = 'Hello World' | ||||
@@ -24,8 +23,6 @@ def make_test_doc(): | |||||
class TestSimpleFile(unittest.TestCase): | class TestSimpleFile(unittest.TestCase): | ||||
def setUp(self): | def setUp(self): | ||||
self.attached_to_doctype, self.attached_to_docname = make_test_doc() | self.attached_to_doctype, self.attached_to_docname = make_test_doc() | ||||
self.test_content = test_content1 | self.test_content = test_content1 | ||||
@@ -38,21 +35,13 @@ class TestSimpleFile(unittest.TestCase): | |||||
_file.save() | _file.save() | ||||
self.saved_file_url = _file.file_url | self.saved_file_url = _file.file_url | ||||
def test_save(self): | def test_save(self): | ||||
_file = frappe.get_doc("File", {"file_url": self.saved_file_url}) | _file = frappe.get_doc("File", {"file_url": self.saved_file_url}) | ||||
content = _file.get_content() | content = _file.get_content() | ||||
self.assertEqual(content, self.test_content) | self.assertEqual(content, self.test_content) | ||||
def tearDown(self): | |||||
# File gets deleted on rollback, so blank | |||||
pass | |||||
class TestBase64File(unittest.TestCase): | class TestBase64File(unittest.TestCase): | ||||
def setUp(self): | def setUp(self): | ||||
self.attached_to_doctype, self.attached_to_docname = make_test_doc() | self.attached_to_doctype, self.attached_to_docname = make_test_doc() | ||||
self.test_content = base64.b64encode(test_content1.encode('utf-8')) | self.test_content = base64.b64encode(test_content1.encode('utf-8')) | ||||
@@ -66,18 +55,12 @@ class TestBase64File(unittest.TestCase): | |||||
_file.save() | _file.save() | ||||
self.saved_file_url = _file.file_url | self.saved_file_url = _file.file_url | ||||
def test_saved_content(self): | def test_saved_content(self): | ||||
_file = frappe.get_doc("File", {"file_url": self.saved_file_url}) | _file = frappe.get_doc("File", {"file_url": self.saved_file_url}) | ||||
content = _file.get_content() | content = _file.get_content() | ||||
self.assertEqual(content, test_content1) | self.assertEqual(content, test_content1) | ||||
def tearDown(self): | |||||
# File gets deleted on rollback, so blank | |||||
pass | |||||
class TestSameFileName(unittest.TestCase): | class TestSameFileName(unittest.TestCase): | ||||
def test_saved_content(self): | def test_saved_content(self): | ||||
self.attached_to_doctype, self.attached_to_docname = make_test_doc() | self.attached_to_doctype, self.attached_to_docname = make_test_doc() | ||||
@@ -130,8 +113,6 @@ class TestSameFileName(unittest.TestCase): | |||||
class TestSameContent(unittest.TestCase): | class TestSameContent(unittest.TestCase): | ||||
def setUp(self): | def setUp(self): | ||||
self.attached_to_doctype1, self.attached_to_docname1 = make_test_doc() | self.attached_to_doctype1, self.attached_to_docname1 = make_test_doc() | ||||
self.attached_to_doctype2, self.attached_to_docname2 = make_test_doc() | self.attached_to_doctype2, self.attached_to_docname2 = make_test_doc() | ||||
@@ -186,10 +167,6 @@ class TestSameContent(unittest.TestCase): | |||||
limit_property.delete() | limit_property.delete() | ||||
frappe.clear_cache(doctype='ToDo') | frappe.clear_cache(doctype='ToDo') | ||||
def tearDown(self): | |||||
# File gets deleted on rollback, so blank | |||||
pass | |||||
class TestFile(unittest.TestCase): | class TestFile(unittest.TestCase): | ||||
def setUp(self): | def setUp(self): | ||||
@@ -398,7 +375,7 @@ class TestFile(unittest.TestCase): | |||||
def test_make_thumbnail(self): | def test_make_thumbnail(self): | ||||
# test web image | # test web image | ||||
test_file = frappe.get_doc({ | |||||
test_file: File = frappe.get_doc({ | |||||
"doctype": "File", | "doctype": "File", | ||||
"file_name": 'logo', | "file_name": 'logo', | ||||
"file_url": frappe.utils.get_url('/_test/assets/image.jpg'), | "file_url": frappe.utils.get_url('/_test/assets/image.jpg'), | ||||
@@ -10,19 +10,20 @@ import re | |||||
import string | import string | ||||
from contextlib import contextmanager | from contextlib import contextmanager | ||||
from time import time | from time import time | ||||
from typing import Dict, List, Union, Tuple | |||||
from typing import Dict, List, Tuple, Union | |||||
from pypika.terms import Criterion, NullValue, PseudoColumn | |||||
import frappe | import frappe | ||||
import frappe.defaults | import frappe.defaults | ||||
import frappe.model.meta | import frappe.model.meta | ||||
from frappe import _ | from frappe import _ | ||||
from frappe.utils import now, getdate, cast, get_datetime | |||||
from frappe.model.utils.link_count import flush_local_link_count | from frappe.model.utils.link_count import flush_local_link_count | ||||
from frappe.query_builder.functions import Count | from frappe.query_builder.functions import Count | ||||
from frappe.query_builder.functions import Min, Max, Avg, Sum | |||||
from frappe.query_builder.utils import Column | |||||
from frappe.query_builder.utils import DocType | |||||
from frappe.utils import cast, get_datetime, getdate, now, sbool | |||||
from .query import Query | from .query import Query | ||||
from pypika.terms import Criterion, PseudoColumn | |||||
class Database(object): | class Database(object): | ||||
@@ -557,7 +558,21 @@ class Database(object): | |||||
def get_list(*args, **kwargs): | def get_list(*args, **kwargs): | ||||
return frappe.get_list(*args, **kwargs) | return frappe.get_list(*args, **kwargs) | ||||
def get_single_value(self, doctype, fieldname, cache=False): | |||||
def set_single_value(self, doctype, fieldname, value, *args, **kwargs): | |||||
"""Set field value of Single DocType. | |||||
:param doctype: DocType of the single object | |||||
:param fieldname: `fieldname` of the property | |||||
:param value: `value` of the property | |||||
Example: | |||||
# Update the `deny_multiple_sessions` field in System Settings DocType. | |||||
company = frappe.db.set_single_value("System Settings", "deny_multiple_sessions", True) | |||||
""" | |||||
return self.set_value(doctype, doctype, fieldname, value, *args, **kwargs) | |||||
def get_single_value(self, doctype, fieldname, cache=True): | |||||
"""Get property of Single DocType. Cache locally by default | """Get property of Single DocType. Cache locally by default | ||||
:param doctype: DocType of the single object whose value is requested | :param doctype: DocType of the single object whose value is requested | ||||
@@ -572,7 +587,7 @@ class Database(object): | |||||
if not doctype in self.value_cache: | if not doctype in self.value_cache: | ||||
self.value_cache[doctype] = {} | self.value_cache[doctype] = {} | ||||
if fieldname in self.value_cache[doctype]: | |||||
if cache and fieldname in self.value_cache[doctype]: | |||||
return self.value_cache[doctype][fieldname] | return self.value_cache[doctype][fieldname] | ||||
val = self.query.get_sql( | val = self.query.get_sql( | ||||
@@ -679,53 +694,55 @@ class Database(object): | |||||
:param debug: Print the query in the developer / js console. | :param debug: Print the query in the developer / js console. | ||||
:param for_update: Will add a row-level lock to the value that is being set so that it can be released on commit. | :param for_update: Will add a row-level lock to the value that is being set so that it can be released on commit. | ||||
""" | """ | ||||
if not modified: | |||||
modified = now() | |||||
if not modified_by: | |||||
modified_by = frappe.session.user | |||||
is_single_doctype = not (dn and dt != dn) | |||||
to_update = field if isinstance(field, dict) else {field: val} | |||||
to_update = {} | |||||
if update_modified: | if update_modified: | ||||
to_update = {"modified": modified, "modified_by": modified_by} | |||||
modified = modified or now() | |||||
modified_by = modified_by or frappe.session.user | |||||
to_update.update({"modified": modified, "modified_by": modified_by}) | |||||
if is_single_doctype: | |||||
frappe.db.delete( | |||||
"Singles", | |||||
filters={"field": ("in", tuple(to_update)), "doctype": dt}, debug=debug | |||||
) | |||||
singles_data = ((dt, key, sbool(value)) for key, value in to_update.items()) | |||||
query = ( | |||||
frappe.qb.into("Singles") | |||||
.columns("doctype", "field", "value") | |||||
.insert(*singles_data) | |||||
).run(debug=debug) | |||||
frappe.clear_document_cache(dt, dt) | |||||
if isinstance(field, dict): | |||||
to_update.update(field) | |||||
else: | else: | ||||
to_update.update({field: val}) | |||||
table = DocType(dt) | |||||
if dn and dt!=dn: | |||||
# with table | |||||
set_values = [] | |||||
for key in to_update: | |||||
set_values.append('`{0}`=%({0})s'.format(key)) | |||||
if for_update: | |||||
docnames = tuple( | |||||
self.get_values(dt, dn, "name", debug=debug, for_update=for_update, pluck=True) | |||||
) or (NullValue(),) | |||||
query = frappe.qb.update(table).where(table.name.isin(docnames)) | |||||
for name in self.get_values(dt, dn, 'name', for_update=for_update, debug=debug): | |||||
values = dict(name=name[0]) | |||||
values.update(to_update) | |||||
for docname in docnames: | |||||
frappe.clear_document_cache(dt, docname) | |||||
self.sql("""update `tab{0}` | |||||
set {1} where name=%(name)s""".format(dt, ', '.join(set_values)), | |||||
values, debug=debug) | |||||
else: | |||||
query = self.query.build_conditions(table=dt, filters=dn, update=True) | |||||
# TODO: Fix this; doesn't work rn - gavin@frappe.io | |||||
# frappe.cache().hdel_keys(dt, "document_cache") | |||||
# Workaround: clear all document caches | |||||
frappe.cache().delete_value('document_cache') | |||||
frappe.clear_document_cache(dt, values['name']) | |||||
else: | |||||
# for singles | |||||
keys = list(to_update) | |||||
self.sql(''' | |||||
delete from `tabSingles` | |||||
where field in ({0}) and | |||||
doctype=%s'''.format(', '.join(['%s']*len(keys))), | |||||
list(keys) + [dt], debug=debug) | |||||
for key, value in to_update.items(): | |||||
self.sql('''insert into `tabSingles` (doctype, field, value) values (%s, %s, %s)''', | |||||
(dt, key, value), debug=debug) | |||||
frappe.clear_document_cache(dt, dn) | |||||
for column, value in to_update.items(): | |||||
query = query.set(column, value) | |||||
query.run(debug=debug) | |||||
if dt in self.value_cache: | if dt in self.value_cache: | ||||
del self.value_cache[dt] | del self.value_cache[dt] | ||||
@staticmethod | @staticmethod | ||||
def set(doc, field, val): | def set(doc, field, val): | ||||
"""Set value in document. **Avoid**""" | """Set value in document. **Avoid**""" | ||||
@@ -1,8 +1,21 @@ | |||||
from frappe.query_builder.terms import ParameterizedValueWrapper, ParameterizedFunction | |||||
import pypika | |||||
import pypika.terms | |||||
from pypika import * | |||||
from pypika import Field | |||||
from pypika.utils import ignore_copy | |||||
from frappe.query_builder.terms import ParameterizedFunction, ParameterizedValueWrapper | |||||
from frappe.query_builder.utils import ( | |||||
Column, | |||||
DocType, | |||||
get_query_builder, | |||||
patch_query_aggregation, | |||||
patch_query_execute, | |||||
) | |||||
pypika.terms.ValueWrapper = ParameterizedValueWrapper | pypika.terms.ValueWrapper = ParameterizedValueWrapper | ||||
pypika.terms.Function = ParameterizedFunction | pypika.terms.Function = ParameterizedFunction | ||||
from pypika import * | |||||
from frappe.query_builder.utils import Column, DocType, get_query_builder, patch_query_execute, patch_query_aggregation | |||||
# * Overrides the field() method and replaces it with the a `PseudoColumn` 'field' for consistency | |||||
pypika.queries.Selectable.__getattr__ = ignore_copy(lambda table, x: Field(x, table=table)) | |||||
pypika.queries.Selectable.__getitem__ = ignore_copy(lambda table, x: Field(x, table=table)) | |||||
pypika.queries.Selectable.field = pypika.terms.PseudoColumn("field") |
@@ -1,8 +1,12 @@ | |||||
from pypika import MySQLQuery, Order, PostgreSQLQuery, terms | from pypika import MySQLQuery, Order, PostgreSQLQuery, terms | ||||
from pypika.queries import Schema, Table | |||||
from frappe.utils import get_table_name | |||||
from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder | |||||
from pypika.queries import QueryBuilder, Schema, Table | |||||
from pypika.terms import Function | from pypika.terms import Function | ||||
from frappe.query_builder.terms import ParameterizedValueWrapper | |||||
from frappe.utils import get_table_name | |||||
class Base: | class Base: | ||||
terms = terms | terms = terms | ||||
desc = Order.desc | desc = Order.desc | ||||
@@ -19,13 +23,13 @@ class Base: | |||||
return Table(table_name, *args, **kwargs) | return Table(table_name, *args, **kwargs) | ||||
@classmethod | @classmethod | ||||
def into(cls, table, *args, **kwargs): | |||||
def into(cls, table, *args, **kwargs) -> QueryBuilder: | |||||
if isinstance(table, str): | if isinstance(table, str): | ||||
table = cls.DocType(table) | table = cls.DocType(table) | ||||
return super().into(table, *args, **kwargs) | return super().into(table, *args, **kwargs) | ||||
@classmethod | @classmethod | ||||
def update(cls, table, *args, **kwargs): | |||||
def update(cls, table, *args, **kwargs) -> QueryBuilder: | |||||
if isinstance(table, str): | if isinstance(table, str): | ||||
table = cls.DocType(table) | table = cls.DocType(table) | ||||
return super().update(table, *args, **kwargs) | return super().update(table, *args, **kwargs) | ||||
@@ -34,6 +38,10 @@ class Base: | |||||
class MariaDB(Base, MySQLQuery): | class MariaDB(Base, MySQLQuery): | ||||
Field = terms.Field | Field = terms.Field | ||||
@classmethod | |||||
def _builder(cls, *args, **kwargs) -> "MySQLQueryBuilder": | |||||
return super()._builder(*args, wrapper_cls=ParameterizedValueWrapper, **kwargs) | |||||
@classmethod | @classmethod | ||||
def from_(cls, table, *args, **kwargs): | def from_(cls, table, *args, **kwargs): | ||||
if isinstance(table, str): | if isinstance(table, str): | ||||
@@ -53,6 +61,10 @@ class Postgres(Base, PostgreSQLQuery): | |||||
# they are two different objects. The quick fix used here is to replace the | # they are two different objects. The quick fix used here is to replace the | ||||
# Field names in the "Field" function. | # Field names in the "Field" function. | ||||
@classmethod | |||||
def _builder(cls, *args, **kwargs) -> "PostgreSQLQueryBuilder": | |||||
return super()._builder(*args, wrapper_cls=ParameterizedValueWrapper, **kwargs) | |||||
@classmethod | @classmethod | ||||
def Field(cls, field_name, *args, **kwargs): | def Field(cls, field_name, *args, **kwargs): | ||||
if field_name in cls.field_translation: | if field_name in cls.field_translation: | ||||
@@ -1,33 +1,76 @@ | |||||
from datetime import timedelta | |||||
from typing import Any, Dict, Optional | from typing import Any, Dict, Optional | ||||
from frappe.utils.data import format_timedelta | |||||
from pypika.terms import Function, ValueWrapper | from pypika.terms import Function, ValueWrapper | ||||
from pypika.utils import format_alias_sql | from pypika.utils import format_alias_sql | ||||
class NamedParameterWrapper(): | |||||
def __init__(self, parameters: Dict[str, Any]): | |||||
self.parameters = parameters | |||||
class NamedParameterWrapper: | |||||
"""Utility class to hold parameter values and keys""" | |||||
def update_parameters(self, param_key: Any, param_value: Any, **kwargs): | |||||
def __init__(self) -> None: | |||||
self.parameters = {} | |||||
def get_sql(self, param_value: Any, **kwargs) -> str: | |||||
"""returns SQL for a parameter, while adding the real value in a dict | |||||
Args: | |||||
param_value (Any): Value of the parameter | |||||
Returns: | |||||
str: parameter used in the SQL query | |||||
""" | |||||
param_key = f"%(param{len(self.parameters) + 1})s" | |||||
self.parameters[param_key[2:-2]] = param_value | self.parameters[param_key[2:-2]] = param_value | ||||
return param_key | |||||
def get_sql(self, **kwargs): | |||||
return f'%(param{len(self.parameters) + 1})s' | |||||
def get_parameters(self) -> Dict[str, Any]: | |||||
"""get dict with parameters and values | |||||
Returns: | |||||
Dict[str, Any]: parameter dict | |||||
""" | |||||
return self.parameters | |||||
class ParameterizedValueWrapper(ValueWrapper): | class ParameterizedValueWrapper(ValueWrapper): | ||||
def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = "'", param_wrapper= None, **kwargs: Any) -> str: | |||||
if param_wrapper is None: | |||||
sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs) | |||||
return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs) | |||||
""" | |||||
Class to monkey patch ValueWrapper | |||||
Adds functionality to parameterize queries when a `param wrapper` is passed in get_sql() | |||||
""" | |||||
def get_sql( | |||||
self, | |||||
quote_char: Optional[str] = None, | |||||
secondary_quote_char: str = "'", | |||||
param_wrapper: Optional[NamedParameterWrapper] = None, | |||||
**kwargs: Any, | |||||
) -> str: | |||||
if param_wrapper and isinstance(self.value, str): | |||||
# add quotes if it's a string value | |||||
value_sql = self.get_value_sql(quote_char=quote_char, **kwargs) | |||||
sql = param_wrapper.get_sql(param_value=value_sql, **kwargs) | |||||
else: | else: | ||||
value_sql = self.get_value_sql(quote_char=quote_char, **kwargs) if not isinstance(self.value,int) else self.value | |||||
param_sql = param_wrapper.get_sql(**kwargs) | |||||
param_wrapper.update_parameters(param_key=param_sql, param_value=value_sql, **kwargs) | |||||
return format_alias_sql(param_sql, self.alias, quote_char=quote_char, **kwargs) | |||||
# * BUG: pypika doesen't parse timedeltas | |||||
if isinstance(self.value, timedelta): | |||||
self.value = format_timedelta(self.value) | |||||
sql = self.get_value_sql( | |||||
quote_char=quote_char, | |||||
secondary_quote_char=secondary_quote_char, | |||||
**kwargs, | |||||
) | |||||
return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs) | |||||
class ParameterizedFunction(Function): | class ParameterizedFunction(Function): | ||||
""" | |||||
Class to monkey patch pypika.terms.Functions | |||||
Only to pass `param_wrapper` in `get_function_sql`. | |||||
""" | |||||
def get_sql(self, **kwargs: Any) -> str: | def get_sql(self, **kwargs: Any) -> str: | ||||
with_alias = kwargs.pop("with_alias", False) | with_alias = kwargs.pop("with_alias", False) | ||||
with_namespace = kwargs.pop("with_namespace", False) | with_namespace = kwargs.pop("with_namespace", False) | ||||
@@ -35,15 +78,24 @@ class ParameterizedFunction(Function): | |||||
dialect = kwargs.pop("dialect", None) | dialect = kwargs.pop("dialect", None) | ||||
param_wrapper = kwargs.pop("param_wrapper", None) | param_wrapper = kwargs.pop("param_wrapper", None) | ||||
function_sql = self.get_function_sql(with_namespace=with_namespace, quote_char=quote_char, param_wrapper=param_wrapper, dialect=dialect) | |||||
function_sql = self.get_function_sql( | |||||
with_namespace=with_namespace, | |||||
quote_char=quote_char, | |||||
param_wrapper=param_wrapper, | |||||
dialect=dialect, | |||||
) | |||||
if self.schema is not None: | if self.schema is not None: | ||||
function_sql = "{schema}.{function}".format( | function_sql = "{schema}.{function}".format( | ||||
schema=self.schema.get_sql(quote_char=quote_char, dialect=dialect, **kwargs), | |||||
schema=self.schema.get_sql( | |||||
quote_char=quote_char, dialect=dialect, **kwargs | |||||
), | |||||
function=function_sql, | function=function_sql, | ||||
) | ) | ||||
if with_alias: | if with_alias: | ||||
return format_alias_sql(function_sql, self.alias, quote_char=quote_char, **kwargs) | |||||
return format_alias_sql( | |||||
function_sql, self.alias, quote_char=quote_char, **kwargs | |||||
) | |||||
return function_sql | return function_sql |
@@ -1,16 +1,16 @@ | |||||
from enum import Enum | from enum import Enum | ||||
from typing import Any, Callable, Dict, Union, get_type_hints | |||||
from importlib import import_module | from importlib import import_module | ||||
from typing import Any, Callable, Dict, Union, get_type_hints | |||||
from pypika import Query | from pypika import Query | ||||
from pypika.queries import Column | from pypika.queries import Column | ||||
from pypika.terms import PseudoColumn | |||||
import frappe | import frappe | ||||
from frappe.query_builder.terms import NamedParameterWrapper | |||||
from .builder import MariaDB, Postgres | from .builder import MariaDB, Postgres | ||||
from pypika.terms import PseudoColumn | |||||
from frappe.query_builder.terms import NamedParameterWrapper | |||||
class db_type_is(Enum): | class db_type_is(Enum): | ||||
MARIADB = "mariadb" | MARIADB = "mariadb" | ||||
@@ -59,11 +59,11 @@ def patch_query_execute(): | |||||
return frappe.db.sql(query, params, *args, **kwargs) # nosemgrep | return frappe.db.sql(query, params, *args, **kwargs) # nosemgrep | ||||
def prepare_query(query): | def prepare_query(query): | ||||
params = {} | |||||
query = query.get_sql(param_wrapper = NamedParameterWrapper(params)) | |||||
param_collector = NamedParameterWrapper() | |||||
query = query.get_sql(param_wrapper=param_collector) | |||||
if frappe.flags.in_safe_exec and not query.lower().strip().startswith("select"): | if frappe.flags.in_safe_exec and not query.lower().strip().startswith("select"): | ||||
raise frappe.PermissionError('Only SELECT SQL allowed in scripting') | raise frappe.PermissionError('Only SELECT SQL allowed in scripting') | ||||
return query, params | |||||
return query, param_collector.get_parameters() | |||||
query_class = get_attr(str(frappe.qb).split("'")[1]) | query_class = get_attr(str(frappe.qb).split("'")[1]) | ||||
builder_class = get_type_hints(query_class._builder).get('return') | builder_class = get_type_hints(query_class._builder).get('return') | ||||
@@ -78,7 +78,7 @@ def patch_query_execute(): | |||||
def patch_query_aggregation(): | def patch_query_aggregation(): | ||||
"""Patch aggregation functions to frappe.qb | """Patch aggregation functions to frappe.qb | ||||
""" | """ | ||||
from frappe.query_builder.functions import _max, _min, _avg, _sum | |||||
from frappe.query_builder.functions import _avg, _max, _min, _sum | |||||
frappe.qb.max = _max | frappe.qb.max = _max | ||||
frappe.qb.min = _min | frappe.qb.min = _min | ||||
@@ -30,7 +30,7 @@ def xmlrunner_wrapper(output): | |||||
def main(app=None, module=None, doctype=None, verbose=False, tests=(), | def main(app=None, module=None, doctype=None, verbose=False, tests=(), | ||||
force=False, profile=False, junit_xml_output=None, ui_tests=False, | force=False, profile=False, junit_xml_output=None, ui_tests=False, | ||||
doctype_list_path=None, skip_test_records=False, failfast=False): | |||||
doctype_list_path=None, skip_test_records=False, failfast=False, case=None): | |||||
global unittest_runner | global unittest_runner | ||||
if doctype_list_path: | if doctype_list_path: | ||||
@@ -76,7 +76,7 @@ def main(app=None, module=None, doctype=None, verbose=False, tests=(), | |||||
if doctype: | if doctype: | ||||
ret = run_tests_for_doctype(doctype, verbose, tests, force, profile, failfast=failfast, junit_xml_output=junit_xml_output) | ret = run_tests_for_doctype(doctype, verbose, tests, force, profile, failfast=failfast, junit_xml_output=junit_xml_output) | ||||
elif module: | elif module: | ||||
ret = run_tests_for_module(module, verbose, tests, profile, failfast=failfast, junit_xml_output=junit_xml_output) | |||||
ret = run_tests_for_module(module, verbose, tests, profile, failfast=failfast, junit_xml_output=junit_xml_output, case=case) | |||||
else: | else: | ||||
ret = run_all_tests(app, verbose, profile, ui_tests, failfast=failfast, junit_xml_output=junit_xml_output) | ret = run_all_tests(app, verbose, profile, ui_tests, failfast=failfast, junit_xml_output=junit_xml_output) | ||||
@@ -182,16 +182,16 @@ def run_tests_for_doctype(doctypes, verbose=False, tests=(), force=False, profil | |||||
return _run_unittest(modules, verbose=verbose, tests=tests, profile=profile, failfast=failfast, junit_xml_output=junit_xml_output) | return _run_unittest(modules, verbose=verbose, tests=tests, profile=profile, failfast=failfast, junit_xml_output=junit_xml_output) | ||||
def run_tests_for_module(module, verbose=False, tests=(), profile=False, failfast=False, junit_xml_output=False): | |||||
def run_tests_for_module(module, verbose=False, tests=(), profile=False, failfast=False, junit_xml_output=False, case=None): | |||||
module = importlib.import_module(module) | module = importlib.import_module(module) | ||||
if hasattr(module, "test_dependencies"): | if hasattr(module, "test_dependencies"): | ||||
for doctype in module.test_dependencies: | for doctype in module.test_dependencies: | ||||
make_test_records(doctype, verbose=verbose) | make_test_records(doctype, verbose=verbose) | ||||
frappe.db.commit() | frappe.db.commit() | ||||
return _run_unittest(module, verbose=verbose, tests=tests, profile=profile, failfast=failfast, junit_xml_output=junit_xml_output) | |||||
return _run_unittest(module, verbose=verbose, tests=tests, profile=profile, failfast=failfast, junit_xml_output=junit_xml_output, case=case) | |||||
def _run_unittest(modules, verbose=False, tests=(), profile=False, failfast=False, junit_xml_output=False): | |||||
def _run_unittest(modules, verbose=False, tests=(), profile=False, failfast=False, junit_xml_output=False, case=None): | |||||
frappe.db.begin() | frappe.db.begin() | ||||
test_suite = unittest.TestSuite() | test_suite = unittest.TestSuite() | ||||
@@ -200,7 +200,10 @@ def _run_unittest(modules, verbose=False, tests=(), profile=False, failfast=Fals | |||||
modules = [modules] | modules = [modules] | ||||
for module in modules: | for module in modules: | ||||
module_test_cases = unittest.TestLoader().loadTestsFromModule(module) | |||||
if case: | |||||
module_test_cases = unittest.TestLoader().loadTestsFromTestCase(getattr(module, case)) | |||||
else: | |||||
module_test_cases = unittest.TestLoader().loadTestsFromModule(module) | |||||
if tests: | if tests: | ||||
for each in module_test_cases: | for each in module_test_cases: | ||||
for test_case in each.__dict__["_tests"]: | for test_case in each.__dict__["_tests"]: | ||||
@@ -337,7 +340,7 @@ def make_test_records_for_doctype(doctype, verbose=0, force=False): | |||||
elif hasattr(test_module, "test_records"): | elif hasattr(test_module, "test_records"): | ||||
if doctype in frappe.local.test_objects: | if doctype in frappe.local.test_objects: | ||||
frappe.local.test_objects[doctype] += make_test_objects(doctype, test_module.test_records, verbose, force) | frappe.local.test_objects[doctype] += make_test_objects(doctype, test_module.test_records, verbose, force) | ||||
else: | |||||
else: | |||||
frappe.local.test_objects[doctype] = make_test_objects(doctype, test_module.test_records, verbose, force) | frappe.local.test_objects[doctype] = make_test_objects(doctype, test_module.test_records, verbose, force) | ||||
else: | else: | ||||
@@ -1,21 +1,21 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# Copyright (c) 2015, Frappe Technologies Pvt. Ltd. and Contributors | |||||
# Copyright (c) 2022, Frappe Technologies Pvt. Ltd. and Contributors | |||||
# License: MIT. See LICENSE | # License: MIT. See LICENSE | ||||
import datetime | |||||
import inspect | |||||
import unittest | import unittest | ||||
from random import choice | from random import choice | ||||
import datetime | |||||
from unittest.mock import patch | |||||
import frappe | import frappe | ||||
from frappe.custom.doctype.custom_field.custom_field import create_custom_field | from frappe.custom.doctype.custom_field.custom_field import create_custom_field | ||||
from frappe.utils import random_string | |||||
from frappe.utils.testutils import clear_custom_fields | |||||
from frappe.query_builder import Field | |||||
from frappe.database import savepoint | from frappe.database import savepoint | ||||
from .test_query_builder import run_only_if, db_type_is | |||||
from frappe.database.database import Database | |||||
from frappe.query_builder import Field | |||||
from frappe.query_builder.functions import Concat_ws | from frappe.query_builder.functions import Concat_ws | ||||
from frappe.tests.test_query_builder import db_type_is, run_only_if | |||||
from frappe.utils import add_days, now, random_string | |||||
from frappe.utils.testutils import clear_custom_fields | |||||
class TestDB(unittest.TestCase): | class TestDB(unittest.TestCase): | ||||
@@ -84,20 +84,6 @@ class TestDB(unittest.TestCase): | |||||
), | ), | ||||
) | ) | ||||
def test_set_value(self): | |||||
todo1 = frappe.get_doc(dict(doctype='ToDo', description = 'test_set_value 1')).insert() | |||||
todo2 = frappe.get_doc(dict(doctype='ToDo', description = 'test_set_value 2')).insert() | |||||
frappe.db.set_value('ToDo', todo1.name, 'description', 'test_set_value change 1') | |||||
self.assertEqual(frappe.db.get_value('ToDo', todo1.name, 'description'), 'test_set_value change 1') | |||||
# multiple set-value | |||||
frappe.db.set_value('ToDo', dict(description=('like', '%test_set_value%')), | |||||
'description', 'change 2') | |||||
self.assertEqual(frappe.db.get_value('ToDo', todo1.name, 'description'), 'change 2') | |||||
self.assertEqual(frappe.db.get_value('ToDo', todo2.name, 'description'), 'change 2') | |||||
def test_escape(self): | def test_escape(self): | ||||
frappe.db.escape("香港濟生堂製藥有限公司 - IT".encode("utf-8")) | frappe.db.escape("香港濟生堂製藥有限公司 - IT".encode("utf-8")) | ||||
@@ -246,7 +232,6 @@ class TestDB(unittest.TestCase): | |||||
frappe.delete_doc(test_doctype, doc) | frappe.delete_doc(test_doctype, doc) | ||||
clear_custom_fields(test_doctype) | clear_custom_fields(test_doctype) | ||||
def test_savepoints(self): | def test_savepoints(self): | ||||
frappe.db.rollback() | frappe.db.rollback() | ||||
save_point = "todonope" | save_point = "todonope" | ||||
@@ -365,6 +350,143 @@ class TestDDLCommandsMaria(unittest.TestCase): | |||||
self.assertEquals(len(indexs_in_table), 2) | self.assertEquals(len(indexs_in_table), 2) | ||||
class TestDBSetValue(unittest.TestCase): | |||||
@classmethod | |||||
def setUpClass(cls): | |||||
cls.todo1 = frappe.get_doc(doctype="ToDo", description="test_set_value 1").insert() | |||||
cls.todo2 = frappe.get_doc(doctype="ToDo", description="test_set_value 2").insert() | |||||
def test_update_single_doctype_field(self): | |||||
value = frappe.db.get_single_value("System Settings", "deny_multiple_sessions") | |||||
changed_value = not value | |||||
frappe.db.set_value("System Settings", "System Settings", "deny_multiple_sessions", changed_value) | |||||
current_value = frappe.db.get_single_value("System Settings", "deny_multiple_sessions") | |||||
self.assertEqual(current_value, changed_value) | |||||
changed_value = not current_value | |||||
frappe.db.set_value("System Settings", None, "deny_multiple_sessions", changed_value) | |||||
current_value = frappe.db.get_single_value("System Settings", "deny_multiple_sessions") | |||||
self.assertEqual(current_value, changed_value) | |||||
changed_value = not current_value | |||||
frappe.db.set_single_value("System Settings", "deny_multiple_sessions", changed_value) | |||||
current_value = frappe.db.get_single_value("System Settings", "deny_multiple_sessions") | |||||
self.assertEqual(current_value, changed_value) | |||||
def test_update_single_row_single_column(self): | |||||
frappe.db.set_value("ToDo", self.todo1.name, "description", "test_set_value change 1") | |||||
updated_value = frappe.db.get_value("ToDo", self.todo1.name, "description") | |||||
self.assertEqual(updated_value, "test_set_value change 1") | |||||
def test_update_single_row_multiple_columns(self): | |||||
description, status = "Upated by test_update_single_row_multiple_columns", "Closed" | |||||
frappe.db.set_value("ToDo", self.todo1.name, { | |||||
"description": description, | |||||
"status": status, | |||||
}, update_modified=False) | |||||
updated_desciption, updated_status = frappe.db.get_value("ToDo", | |||||
filters={"name": self.todo1.name}, | |||||
fieldname=["description", "status"] | |||||
) | |||||
self.assertEqual(description, updated_desciption) | |||||
self.assertEqual(status, updated_status) | |||||
def test_update_multiple_rows_single_column(self): | |||||
frappe.db.set_value("ToDo", {"description": ("like", "%test_set_value%")}, "description", "change 2") | |||||
self.assertEqual(frappe.db.get_value("ToDo", self.todo1.name, "description"), "change 2") | |||||
self.assertEqual(frappe.db.get_value("ToDo", self.todo2.name, "description"), "change 2") | |||||
def test_update_multiple_rows_multiple_columns(self): | |||||
todos_to_update = frappe.get_all("ToDo", filters={ | |||||
"description": ("like", "%test_set_value%"), | |||||
"status": ("!=", "Closed") | |||||
}, pluck="name") | |||||
frappe.db.set_value("ToDo", { | |||||
"description": ("like", "%test_set_value%"), | |||||
"status": ("!=", "Closed") | |||||
}, { | |||||
"status": "Closed", | |||||
"priority": "High" | |||||
}) | |||||
test_result = frappe.get_all("ToDo", filters={"name": ("in", todos_to_update)}, fields=["status", "priority"]) | |||||
self.assertTrue(all(x for x in test_result if x["status"] == "Closed")) | |||||
self.assertTrue(all(x for x in test_result if x["priority"] == "High")) | |||||
def test_update_modified_options(self): | |||||
self.todo2.reload() | |||||
todo = self.todo2 | |||||
updated_description = f"{todo.description} - by `test_update_modified_options`" | |||||
custom_modified = datetime.datetime.fromisoformat(add_days(now(), 10)) | |||||
custom_modified_by = "user_that_doesnt_exist@example.com" | |||||
frappe.db.set_value("ToDo", todo.name, "description", updated_description, update_modified=False) | |||||
self.assertEqual(updated_description, frappe.db.get_value("ToDo", todo.name, "description")) | |||||
self.assertEqual(todo.modified, frappe.db.get_value("ToDo", todo.name, "modified")) | |||||
frappe.db.set_value("ToDo", todo.name, "description", "test_set_value change 1", modified=custom_modified, modified_by=custom_modified_by) | |||||
self.assertTupleEqual( | |||||
(custom_modified, custom_modified_by), | |||||
frappe.db.get_value("ToDo", todo.name, ["modified", "modified_by"]) | |||||
) | |||||
def test_for_update(self): | |||||
self.todo1.reload() | |||||
with patch.object(Database, "sql") as sql_called: | |||||
frappe.db.set_value( | |||||
self.todo1.doctype, | |||||
self.todo1.name, | |||||
"description", | |||||
f"{self.todo1.description}-edit by `test_for_update`" | |||||
) | |||||
first_query = sql_called.call_args_list[0].args[0] | |||||
second_query = sql_called.call_args_list[1].args[0] | |||||
self.assertTrue(sql_called.call_count == 2) | |||||
self.assertTrue("FOR UPDATE" in first_query) | |||||
if frappe.conf.db_type == "postgres": | |||||
from frappe.database.postgres.database import modify_query | |||||
self.assertTrue(modify_query("UPDATE `tabToDo` SET") in second_query) | |||||
if frappe.conf.db_type == "mariadb": | |||||
self.assertTrue("UPDATE `tabToDo` SET" in second_query) | |||||
def test_cleared_cache(self): | |||||
self.todo2.reload() | |||||
with patch.object(frappe, "clear_document_cache") as clear_cache: | |||||
frappe.db.set_value( | |||||
self.todo2.doctype, | |||||
self.todo2.name, | |||||
"description", | |||||
f"{self.todo2.description}-edit by `test_cleared_cache`" | |||||
) | |||||
clear_cache.assert_called() | |||||
def test_update_alias(self): | |||||
args = (self.todo1.doctype, self.todo1.name, "description", "Updated by `test_update_alias`") | |||||
kwargs = {"for_update": False, "modified": None, "modified_by": None, "update_modified": True, "debug": False} | |||||
self.assertTrue("return self.set_value(" in inspect.getsource(frappe.db.update)) | |||||
with patch.object(Database, "set_value") as set_value: | |||||
frappe.db.update(*args, **kwargs) | |||||
set_value.assert_called_once() | |||||
set_value.assert_called_with(*args, **kwargs) | |||||
@classmethod | |||||
def tearDownClass(cls): | |||||
frappe.db.rollback() | |||||
@run_only_if(db_type_is.POSTGRES) | @run_only_if(db_type_is.POSTGRES) | ||||
class TestDDLCommandsPost(unittest.TestCase): | class TestDDLCommandsPost(unittest.TestCase): | ||||
test_table_name = "TestNotes" | test_table_name = "TestNotes" | ||||
@@ -5,7 +5,7 @@ import frappe | |||||
from frappe.query_builder.custom import ConstantColumn | from frappe.query_builder.custom import ConstantColumn | ||||
from frappe.query_builder.functions import Coalesce, GroupConcat, Match | from frappe.query_builder.functions import Coalesce, GroupConcat, Match | ||||
from frappe.query_builder.utils import db_type_is | from frappe.query_builder.utils import db_type_is | ||||
from frappe.query_builder import Case | |||||
def run_only_if(dbtype: db_type_is) -> Callable: | def run_only_if(dbtype: db_type_is) -> Callable: | ||||
return unittest.skipIf( | return unittest.skipIf( | ||||
@@ -25,8 +25,14 @@ class TestCustomFunctionsMariaDB(unittest.TestCase): | |||||
) | ) | ||||
def test_constant_column(self): | def test_constant_column(self): | ||||
query = frappe.qb.from_("DocType").select("name", ConstantColumn("John").as_("User")) | |||||
self.assertEqual(query.get_sql(), "SELECT `name`,'John' `User` FROM `tabDocType`") | |||||
query = frappe.qb.from_("DocType").select( | |||||
"name", ConstantColumn("John").as_("User") | |||||
) | |||||
self.assertEqual( | |||||
query.get_sql(), "SELECT `name`,'John' `User` FROM `tabDocType`" | |||||
) | |||||
@run_only_if(db_type_is.POSTGRES) | @run_only_if(db_type_is.POSTGRES) | ||||
class TestCustomFunctionsPostgres(unittest.TestCase): | class TestCustomFunctionsPostgres(unittest.TestCase): | ||||
def test_concat(self): | def test_concat(self): | ||||
@@ -39,8 +45,13 @@ class TestCustomFunctionsPostgres(unittest.TestCase): | |||||
) | ) | ||||
def test_constant_column(self): | def test_constant_column(self): | ||||
query = frappe.qb.from_("DocType").select("name", ConstantColumn("John").as_("User")) | |||||
self.assertEqual(query.get_sql(), 'SELECT "name",\'John\' "User" FROM "tabDocType"') | |||||
query = frappe.qb.from_("DocType").select( | |||||
"name", ConstantColumn("John").as_("User") | |||||
) | |||||
self.assertEqual( | |||||
query.get_sql(), 'SELECT "name",\'John\' "User" FROM "tabDocType"' | |||||
) | |||||
class TestBuilderBase(object): | class TestBuilderBase(object): | ||||
def test_adding_tabs(self): | def test_adding_tabs(self): | ||||
@@ -55,23 +66,68 @@ class TestBuilderBase(object): | |||||
self.assertIsInstance(query.run, Callable) | self.assertIsInstance(query.run, Callable) | ||||
self.assertIsInstance(data, list) | self.assertIsInstance(data, list) | ||||
def test_walk(self): | |||||
DocType = frappe.qb.DocType('DocType') | |||||
class TestParameterization(unittest.TestCase): | |||||
def test_where_conditions(self): | |||||
DocType = frappe.qb.DocType("DocType") | |||||
query = ( | query = ( | ||||
frappe.qb.from_(DocType) | frappe.qb.from_(DocType) | ||||
.select(DocType.name) | .select(DocType.name) | ||||
.where((DocType.owner == "Administrator' --") | |||||
& (Coalesce(DocType.search_fields == "subject")) | |||||
.where((DocType.owner == "Administrator' --")) | |||||
) | |||||
self.assertTrue("walk" in dir(query)) | |||||
query, params = query.walk() | |||||
self.assertIn("%(param1)s", query) | |||||
self.assertIn("param1", params) | |||||
self.assertEqual(params["param1"], "Administrator' --") | |||||
def test_set_cnoditions(self): | |||||
DocType = frappe.qb.DocType("DocType") | |||||
query = frappe.qb.update(DocType).set(DocType.value, "some_value") | |||||
self.assertTrue("walk" in dir(query)) | |||||
query, params = query.walk() | |||||
self.assertIn("%(param1)s", query) | |||||
self.assertIn("param1", params) | |||||
self.assertEqual(params["param1"], "some_value") | |||||
def test_where_conditions_functions(self): | |||||
DocType = frappe.qb.DocType("DocType") | |||||
query = ( | |||||
frappe.qb.from_(DocType) | |||||
.select(DocType.name) | |||||
.where(Coalesce(DocType.search_fields == "subject")) | |||||
) | |||||
self.assertTrue("walk" in dir(query)) | |||||
query, params = query.walk() | |||||
self.assertIn("%(param1)s", query) | |||||
self.assertIn("param1", params) | |||||
self.assertEqual(params["param1"], "subject") | |||||
def test_case(self): | |||||
DocType = frappe.qb.DocType("DocType") | |||||
query = ( | |||||
frappe.qb.from_(DocType) | |||||
.select( | |||||
Case() | |||||
.when(DocType.search_fields == "value", "other_value") | |||||
.when(Coalesce(DocType.search_fields == "subject_in_function"), "true_value") | |||||
) | ) | ||||
) | ) | ||||
self.assertTrue("walk" in dir(query)) | self.assertTrue("walk" in dir(query)) | ||||
query, params = query.walk() | query, params = query.walk() | ||||
self.assertIn("%(param1)s", query) | self.assertIn("%(param1)s", query) | ||||
self.assertIn("%(param2)s", query) | |||||
self.assertIn("param1",params) | |||||
self.assertEqual(params["param1"],"Administrator' --") | |||||
self.assertEqual(params["param2"],"subject") | |||||
self.assertIn("param1", params) | |||||
self.assertEqual(params["param1"], "value") | |||||
self.assertEqual(params["param2"], "other_value") | |||||
self.assertEqual(params["param3"], "subject_in_function") | |||||
self.assertEqual(params["param4"], "true_value") | |||||
@run_only_if(db_type_is.MARIADB) | @run_only_if(db_type_is.MARIADB) | ||||
@@ -84,6 +140,7 @@ class TestBuilderMaria(unittest.TestCase, TestBuilderBase): | |||||
"SELECT * FROM `__Auth`", frappe.qb.from_("__Auth").select("*").get_sql() | "SELECT * FROM `__Auth`", frappe.qb.from_("__Auth").select("*").get_sql() | ||||
) | ) | ||||
@run_only_if(db_type_is.POSTGRES) | @run_only_if(db_type_is.POSTGRES) | ||||
class TestBuilderPostgres(unittest.TestCase, TestBuilderBase): | class TestBuilderPostgres(unittest.TestCase, TestBuilderBase): | ||||
def test_adding_tabs_in_from(self): | def test_adding_tabs_in_from(self): | ||||
@@ -1,22 +1,28 @@ | |||||
# Copyright (c) 2015, Frappe Technologies Pvt. Ltd. and Contributors | |||||
# Copyright (c) 2022, Frappe Technologies Pvt. Ltd. and Contributors | |||||
# License: MIT. See LICENSE | # License: MIT. See LICENSE | ||||
import unittest | |||||
import frappe | |||||
from frappe.utils import evaluate_filters, money_in_words, scrub_urls, get_url | |||||
from frappe.utils import validate_url, validate_email_address | |||||
from frappe.utils import ceil, floor | |||||
from frappe.utils.data import cast, validate_python_code | |||||
from frappe.utils.diff import get_version_diff, version_query, _get_value_from_version | |||||
from PIL import Image | |||||
from frappe.utils.image import strip_exif_data, optimize_image | |||||
import io | import io | ||||
import json | |||||
import unittest | |||||
from datetime import date, datetime, time, timedelta | |||||
from decimal import Decimal | |||||
from enum import Enum | |||||
from mimetypes import guess_type | from mimetypes import guess_type | ||||
from datetime import datetime, timedelta, date | |||||
from unittest.mock import patch | from unittest.mock import patch | ||||
import pytz | |||||
from PIL import Image | |||||
import frappe | |||||
from frappe.utils import ceil, evaluate_filters, floor, format_timedelta | |||||
from frappe.utils import get_url, money_in_words, parse_timedelta, scrub_urls | |||||
from frappe.utils import validate_email_address, validate_url | |||||
from frappe.utils.data import cast, validate_python_code | |||||
from frappe.utils.diff import _get_value_from_version, get_version_diff, version_query | |||||
from frappe.utils.image import optimize_image, strip_exif_data | |||||
from frappe.utils.response import json_handler | |||||
class TestFilters(unittest.TestCase): | class TestFilters(unittest.TestCase): | ||||
def test_simple_dict(self): | def test_simple_dict(self): | ||||
self.assertTrue(evaluate_filters({'doctype': 'User', 'status': 'Open'}, {'status': 'Open'})) | self.assertTrue(evaluate_filters({'doctype': 'User', 'status': 'Open'}, {'status': 'Open'})) | ||||
@@ -273,9 +279,7 @@ class TestPythonExpressions(unittest.TestCase): | |||||
for expr in invalid_expressions: | for expr in invalid_expressions: | ||||
self.assertRaises(frappe.ValidationError, validate_python_code, expr) | self.assertRaises(frappe.ValidationError, validate_python_code, expr) | ||||
class TestDiffUtils(unittest.TestCase): | class TestDiffUtils(unittest.TestCase): | ||||
@classmethod | @classmethod | ||||
def setUpClass(cls): | def setUpClass(cls): | ||||
cls.doc = frappe.get_doc(doctype="Client Script", dt="Client Script") | cls.doc = frappe.get_doc(doctype="Client Script", dt="Client Script") | ||||
@@ -330,8 +334,59 @@ class TestDateUtils(unittest.TestCase): | |||||
self.assertEqual(frappe.utils.get_last_day_of_week("2020-12-28"), | self.assertEqual(frappe.utils.get_last_day_of_week("2020-12-28"), | ||||
frappe.utils.getdate("2021-01-02")) | frappe.utils.getdate("2021-01-02")) | ||||
class TestXlsxUtils(unittest.TestCase): | |||||
class TestResponse(unittest.TestCase): | |||||
def test_json_handler(self): | |||||
class TEST(Enum): | |||||
ABC = "!@)@)!" | |||||
BCE = "ENJD" | |||||
GOOD_OBJECT = { | |||||
"time_types": [ | |||||
date(year=2020, month=12, day=2), | |||||
datetime(year=2020, month=12, day=2, hour=23, minute=23, second=23, microsecond=23, tzinfo=pytz.utc), | |||||
time(hour=23, minute=23, second=23, microsecond=23, tzinfo=pytz.utc), | |||||
timedelta(days=10, hours=12, minutes=120, seconds=10), | |||||
], | |||||
"float": [ | |||||
Decimal(29.21), | |||||
], | |||||
"doc": [ | |||||
frappe.get_doc("System Settings"), | |||||
], | |||||
"iter": [ | |||||
{1, 2, 3}, | |||||
(1, 2, 3), | |||||
"abcdef", | |||||
], | |||||
"string": "abcdef" | |||||
} | |||||
BAD_OBJECT = {"Enum": TEST} | |||||
processed_object = json.loads(json.dumps(GOOD_OBJECT, default=json_handler)) | |||||
self.assertTrue(all([isinstance(x, str) for x in processed_object["time_types"]])) | |||||
self.assertTrue(all([isinstance(x, float) for x in processed_object["float"]])) | |||||
self.assertTrue(all([isinstance(x, (list, str)) for x in processed_object["iter"]])) | |||||
self.assertIsInstance(processed_object["string"], str) | |||||
with self.assertRaises(TypeError): | |||||
json.dumps(BAD_OBJECT, default=json_handler) | |||||
class TestTimeDeltaUtils(unittest.TestCase): | |||||
def test_format_timedelta(self): | |||||
self.assertEqual(format_timedelta(timedelta(seconds=0)), "0:00:00") | |||||
self.assertEqual(format_timedelta(timedelta(hours=10)), "10:00:00") | |||||
self.assertEqual(format_timedelta(timedelta(hours=100)), "100:00:00") | |||||
self.assertEqual(format_timedelta(timedelta(seconds=100, microseconds=129)), "0:01:40.000129") | |||||
self.assertEqual(format_timedelta(timedelta(seconds=100, microseconds=12212199129)), "3:25:12.199129") | |||||
def test_parse_timedelta(self): | |||||
self.assertEqual(parse_timedelta("0:0:0"), timedelta(seconds=0)) | |||||
self.assertEqual(parse_timedelta("10:0:0"), timedelta(hours=10)) | |||||
self.assertEqual(parse_timedelta("7 days, 0:32:18.192221"), timedelta(days=7, seconds=1938, microseconds=192221)) | |||||
self.assertEqual(parse_timedelta("7 days, 0:32:18"), timedelta(days=7, seconds=1938)) | |||||
class TestXlsxUtils(unittest.TestCase): | |||||
def test_unescape(self): | def test_unescape(self): | ||||
from frappe.utils.xlsxutils import handle_html | from frappe.utils.xlsxutils import handle_html | ||||
@@ -20,6 +20,7 @@ class TestWebsite(unittest.TestCase): | |||||
doctype='User', | doctype='User', | ||||
email='test-user-for-home-page@example.com', | email='test-user-for-home-page@example.com', | ||||
first_name='test')).insert(ignore_if_duplicate=True) | first_name='test')).insert(ignore_if_duplicate=True) | ||||
user.reload() | |||||
role = frappe.get_doc(dict( | role = frappe.get_doc(dict( | ||||
doctype = 'Role', | doctype = 'Role', | ||||
@@ -1,4 +1,4 @@ | |||||
# Copyright (c) 2015, Frappe Technologies Pvt. Ltd. and Contributors | |||||
# Copyright (c) 2022, Frappe Technologies Pvt. Ltd. and Contributors | |||||
# License: MIT. See LICENSE | # License: MIT. See LICENSE | ||||
import functools | import functools | ||||
@@ -56,8 +56,8 @@ def get_email_address(user=None): | |||||
def get_formatted_email(user, mail=None): | def get_formatted_email(user, mail=None): | ||||
"""get Email Address of user formatted as: `John Doe <johndoe@example.com>`""" | """get Email Address of user formatted as: `John Doe <johndoe@example.com>`""" | ||||
fullname = get_fullname(user) | fullname = get_fullname(user) | ||||
method = get_hook_method('get_sender_details') | method = get_hook_method('get_sender_details') | ||||
if method: | if method: | ||||
sender_name, mail = method() | sender_name, mail = method() | ||||
# if method exists but sender_name is "" | # if method exists but sender_name is "" | ||||
@@ -1,17 +1,22 @@ | |||||
# Copyright (c) 2015, Frappe Technologies Pvt. Ltd. and Contributors | |||||
# Copyright (c) 2022, Frappe Technologies Pvt. Ltd. and Contributors | |||||
# License: MIT. See LICENSE | # License: MIT. See LICENSE | ||||
from typing import Optional | |||||
import frappe | |||||
import operator | |||||
import json | |||||
import base64 | import base64 | ||||
import re, datetime, math, time | |||||
import datetime | |||||
import json | |||||
import math | |||||
import operator | |||||
import re | |||||
import time | |||||
from code import compile_command | from code import compile_command | ||||
from enum import Enum | |||||
from typing import Any, Dict, List, Optional, Tuple, Union | |||||
from urllib.parse import quote, urljoin | from urllib.parse import quote, urljoin | ||||
from frappe.desk.utils import slug | |||||
from click import secho | from click import secho | ||||
from enum import Enum | |||||
import frappe | |||||
from frappe.desk.utils import slug | |||||
DATE_FORMAT = "%Y-%m-%d" | DATE_FORMAT = "%Y-%m-%d" | ||||
TIME_FORMAT = "%H:%M:%S.%f" | TIME_FORMAT = "%H:%M:%S.%f" | ||||
@@ -99,11 +104,17 @@ def get_timedelta(time: Optional[str] = None) -> Optional[datetime.timedelta]: | |||||
datetime.timedelta: Timedelta object equivalent of the passed `time` string | datetime.timedelta: Timedelta object equivalent of the passed `time` string | ||||
""" | """ | ||||
from dateutil import parser | from dateutil import parser | ||||
from dateutil.parser import ParserError | |||||
time = time or "0:0:0" | time = time or "0:0:0" | ||||
try: | try: | ||||
t = parser.parse(time) | |||||
try: | |||||
t = parser.parse(time) | |||||
except ParserError as e: | |||||
if "day" in e.args[1]: | |||||
from frappe.utils import parse_timedelta | |||||
return parse_timedelta(time) | |||||
return datetime.timedelta( | return datetime.timedelta( | ||||
hours=t.hour, minutes=t.minute, seconds=t.second, microseconds=t.microsecond | hours=t.hour, minutes=t.minute, seconds=t.second, microseconds=t.microsecond | ||||
) | ) | ||||
@@ -201,7 +212,7 @@ def get_time_zone(): | |||||
return frappe.cache().get_value("time_zone", _get_time_zone) | return frappe.cache().get_value("time_zone", _get_time_zone) | ||||
def convert_utc_to_timezone(utc_timestamp, time_zone): | def convert_utc_to_timezone(utc_timestamp, time_zone): | ||||
from pytz import timezone, UnknownTimeZoneError | |||||
from pytz import UnknownTimeZoneError, timezone | |||||
utcnow = timezone('UTC').localize(utc_timestamp) | utcnow = timezone('UTC').localize(utc_timestamp) | ||||
try: | try: | ||||
return utcnow.astimezone(timezone(time_zone)) | return utcnow.astimezone(timezone(time_zone)) | ||||
@@ -327,7 +338,7 @@ def get_time(time_str): | |||||
return time_str | return time_str | ||||
else: | else: | ||||
if isinstance(time_str, datetime.timedelta): | if isinstance(time_str, datetime.timedelta): | ||||
time_str = str(time_str) | |||||
return format_timedelta(time_str) | |||||
return parser.parse(time_str).time() | return parser.parse(time_str).time() | ||||
def get_datetime_str(datetime_obj): | def get_datetime_str(datetime_obj): | ||||
@@ -610,7 +621,7 @@ def cast(fieldtype, value=None): | |||||
value = flt(value) | value = flt(value) | ||||
elif fieldtype in ("Int", "Check"): | elif fieldtype in ("Int", "Check"): | ||||
value = cint(value) | |||||
value = cint(sbool(value)) | |||||
elif fieldtype in ("Data", "Text", "Small Text", "Long Text", | elif fieldtype in ("Data", "Text", "Small Text", "Long Text", | ||||
"Text Editor", "Select", "Link", "Dynamic Link"): | "Text Editor", "Select", "Link", "Dynamic Link"): | ||||
@@ -726,7 +737,7 @@ def ceil(s): | |||||
def cstr(s, encoding='utf-8'): | def cstr(s, encoding='utf-8'): | ||||
return frappe.as_unicode(s, encoding) | return frappe.as_unicode(s, encoding) | ||||
def sbool(x): | |||||
def sbool(x: str) -> Union[bool, Any]: | |||||
"""Converts str object to Boolean if possible. | """Converts str object to Boolean if possible. | ||||
Example: | Example: | ||||
"true" becomes True | "true" becomes True | ||||
@@ -737,12 +748,15 @@ def sbool(x): | |||||
x (str): String to be converted to Bool | x (str): String to be converted to Bool | ||||
Returns: | Returns: | ||||
object: Returns Boolean or type(x) | |||||
object: Returns Boolean or x | |||||
""" | """ | ||||
from distutils.util import strtobool | |||||
try: | try: | ||||
return bool(strtobool(x)) | |||||
val = x.lower() | |||||
if val in ('true', '1'): | |||||
return True | |||||
elif val in ('false', '0'): | |||||
return False | |||||
return x | |||||
except Exception: | except Exception: | ||||
return x | return x | ||||
@@ -917,13 +931,13 @@ number_format_info = { | |||||
"#.########": (".", "", 8) | "#.########": (".", "", 8) | ||||
} | } | ||||
def get_number_format_info(format): | |||||
def get_number_format_info(format: str) -> Tuple[str, str, int]: | |||||
return number_format_info.get(format) or (".", ",", 2) | return number_format_info.get(format) or (".", ",", 2) | ||||
# | # | ||||
# convert currency to words | # convert currency to words | ||||
# | # | ||||
def money_in_words(number, main_currency = None, fraction_currency=None): | |||||
def money_in_words(number: str, main_currency: Optional[str] = None, fraction_currency: Optional[str] = None): | |||||
""" | """ | ||||
Returns string in words with currency and fraction currency. | Returns string in words with currency and fraction currency. | ||||
""" | """ | ||||
@@ -1009,9 +1023,11 @@ def is_image(filepath): | |||||
def get_thumbnail_base64_for_image(src): | def get_thumbnail_base64_for_image(src): | ||||
from os.path import exists as file_exists | from os.path import exists as file_exists | ||||
from PIL import Image | from PIL import Image | ||||
from frappe import cache, safe_decode | |||||
from frappe.core.doctype.file.file import get_local_image | from frappe.core.doctype.file.file import get_local_image | ||||
from frappe import safe_decode, cache | |||||
if not src: | if not src: | ||||
frappe.throw('Invalid source for image: {0}'.format(src)) | frappe.throw('Invalid source for image: {0}'.format(src)) | ||||
@@ -1302,7 +1318,7 @@ operator_map = { | |||||
"None": lambda a, b: (not a) and True or False | "None": lambda a, b: (not a) and True or False | ||||
} | } | ||||
def evaluate_filters(doc, filters): | |||||
def evaluate_filters(doc, filters: Union[Dict, List, Tuple]): | |||||
'''Returns true if doc matches filters''' | '''Returns true if doc matches filters''' | ||||
if isinstance(filters, dict): | if isinstance(filters, dict): | ||||
for key, value in filters.items(): | for key, value in filters.items(): | ||||
@@ -1319,7 +1335,7 @@ def evaluate_filters(doc, filters): | |||||
return True | return True | ||||
def compare(val1, condition, val2, fieldtype=None): | |||||
def compare(val1: Any, condition: str, val2: Any, fieldtype: Optional[str] = None): | |||||
ret = False | ret = False | ||||
if fieldtype: | if fieldtype: | ||||
val2 = cast(fieldtype, val2) | val2 = cast(fieldtype, val2) | ||||
@@ -1328,7 +1344,7 @@ def compare(val1, condition, val2, fieldtype=None): | |||||
return ret | return ret | ||||
def get_filter(doctype, f, filters_config=None): | |||||
def get_filter(doctype: str, f: Union[Dict, List, Tuple], filters_config=None) -> "frappe._dict": | |||||
"""Returns a _dict like | """Returns a _dict like | ||||
{ | { | ||||
@@ -1415,8 +1431,10 @@ def make_filter_dict(filters): | |||||
return _filter | return _filter | ||||
def sanitize_column(column_name): | def sanitize_column(column_name): | ||||
from frappe import _ | |||||
import sqlparse | import sqlparse | ||||
from frappe import _ | |||||
regex = re.compile("^.*[,'();].*") | regex = re.compile("^.*[,'();].*") | ||||
column_name = sqlparse.format(column_name, strip_comments=True, keyword_case="lower") | column_name = sqlparse.format(column_name, strip_comments=True, keyword_case="lower") | ||||
blacklisted_keywords = ['select', 'create', 'insert', 'delete', 'drop', 'update', 'case', 'and', 'or'] | blacklisted_keywords = ['select', 'create', 'insert', 'delete', 'drop', 'update', 'case', 'and', 'or'] | ||||
@@ -1492,9 +1510,10 @@ def strip(val, chars=None): | |||||
return (val or "").replace("\ufeff", "").replace("\u200b", "").strip(chars) | return (val or "").replace("\ufeff", "").replace("\u200b", "").strip(chars) | ||||
def to_markdown(html): | def to_markdown(html): | ||||
from html2text import html2text | |||||
from html.parser import HTMLParser | from html.parser import HTMLParser | ||||
from html2text import html2text | |||||
text = None | text = None | ||||
try: | try: | ||||
text = html2text(html or '') | text = html2text(html or '') | ||||
@@ -1504,7 +1523,8 @@ def to_markdown(html): | |||||
return text | return text | ||||
def md_to_html(markdown_text): | def md_to_html(markdown_text): | ||||
from markdown2 import markdown as _markdown, MarkdownError | |||||
from markdown2 import MarkdownError | |||||
from markdown2 import markdown as _markdown | |||||
extras = { | extras = { | ||||
'fenced-code-blocks': None, | 'fenced-code-blocks': None, | ||||
@@ -1529,14 +1549,14 @@ def md_to_html(markdown_text): | |||||
def markdown(markdown_text): | def markdown(markdown_text): | ||||
return md_to_html(markdown_text) | return md_to_html(markdown_text) | ||||
def is_subset(list_a, list_b): | |||||
def is_subset(list_a: List, list_b: List) -> bool: | |||||
'''Returns whether list_a is a subset of list_b''' | '''Returns whether list_a is a subset of list_b''' | ||||
return len(list(set(list_a) & set(list_b))) == len(list_a) | return len(list(set(list_a) & set(list_b))) == len(list_a) | ||||
def generate_hash(*args, **kwargs): | |||||
def generate_hash(*args, **kwargs) -> str: | |||||
return frappe.generate_hash(*args, **kwargs) | return frappe.generate_hash(*args, **kwargs) | ||||
def guess_date_format(date_string): | |||||
def guess_date_format(date_string: str) -> str: | |||||
DATE_FORMATS = [ | DATE_FORMATS = [ | ||||
r"%d/%b/%y", | r"%d/%b/%y", | ||||
r"%d-%m-%Y", | r"%d-%m-%Y", | ||||
@@ -1611,13 +1631,13 @@ def guess_date_format(date_string): | |||||
if date_format and time_format: | if date_format and time_format: | ||||
return (date_format + ' ' + time_format).strip() | return (date_format + ' ' + time_format).strip() | ||||
def validate_json_string(string): | |||||
def validate_json_string(string: str) -> None: | |||||
try: | try: | ||||
json.loads(string) | json.loads(string) | ||||
except (TypeError, ValueError): | except (TypeError, ValueError): | ||||
raise frappe.ValidationError | raise frappe.ValidationError | ||||
def get_user_info_for_avatar(user_id): | |||||
def get_user_info_for_avatar(user_id: str) -> Dict: | |||||
user_info = { | user_info = { | ||||
"email": user_id, | "email": user_id, | ||||
"image": "", | "image": "", | ||||
@@ -1664,3 +1684,30 @@ class UnicodeWithAttrs(str): | |||||
def __init__(self, text): | def __init__(self, text): | ||||
self.toc_html = text.toc_html | self.toc_html = text.toc_html | ||||
self.metadata = text.metadata | self.metadata = text.metadata | ||||
def format_timedelta(o: datetime.timedelta) -> str: | |||||
# mariadb allows a wide diff range - https://mariadb.com/kb/en/time/ | |||||
# but frappe doesnt - i think via babel : only allows 0..23 range for hour | |||||
total_seconds = o.total_seconds() | |||||
hours, remainder = divmod(total_seconds, 3600) | |||||
minutes, seconds = divmod(remainder, 60) | |||||
rounded_seconds = round(seconds, 6) | |||||
int_seconds = int(seconds) | |||||
if rounded_seconds == int_seconds: | |||||
seconds = int_seconds | |||||
else: | |||||
seconds = rounded_seconds | |||||
return "{:01}:{:02}:{:02}".format(int(hours), int(minutes), seconds) | |||||
def parse_timedelta(s: str) -> datetime.timedelta: | |||||
# ref: https://stackoverflow.com/a/21074460/10309266 | |||||
if 'day' in s: | |||||
m = re.match(r"(?P<days>[-\d]+) day[s]*, (?P<hours>\d+):(?P<minutes>\d+):(?P<seconds>\d[\.\d+]*)", s) | |||||
else: | |||||
m = re.match(r"(?P<hours>\d+):(?P<minutes>\d+):(?P<seconds>\d[\.\d+]*)", s) | |||||
return datetime.timedelta(**{key: float(val) for key, val in m.groupdict().items()}) |
@@ -3,9 +3,11 @@ | |||||
import frappe | import frappe | ||||
import datetime | import datetime | ||||
from frappe.utils import formatdate, fmt_money, flt, cstr, cint, format_datetime, format_time, format_duration | |||||
from frappe.utils import formatdate, fmt_money, flt, cstr, cint, format_datetime, format_time, format_duration, format_timedelta | |||||
from frappe.model.meta import get_field_currency, get_field_precision | from frappe.model.meta import get_field_currency, get_field_precision | ||||
import re | import re | ||||
from dateutil.parser import ParserError | |||||
def format_value(value, df=None, doc=None, currency=None, translated=False, format=None): | def format_value(value, df=None, doc=None, currency=None, translated=False, format=None): | ||||
'''Format value based on given fieldtype, document reference, currency reference. | '''Format value based on given fieldtype, document reference, currency reference. | ||||
@@ -47,7 +49,10 @@ def format_value(value, df=None, doc=None, currency=None, translated=False, form | |||||
return format_datetime(value) | return format_datetime(value) | ||||
elif df.get("fieldtype")=="Time": | elif df.get("fieldtype")=="Time": | ||||
return format_time(value) | |||||
try: | |||||
return format_time(value) | |||||
except ParserError: | |||||
return format_timedelta(value) | |||||
elif value==0 and df.get("fieldtype") in ("Int", "Float", "Currency", "Percent") and df.get("print_hide_if_no_value"): | elif value==0 and df.get("fieldtype") in ("Int", "Float", "Currency", "Percent") and df.get("print_hide_if_no_value"): | ||||
# this is required to show 0 as blank in table columns | # this is required to show 0 as blank in table columns | ||||
@@ -1,4 +1,4 @@ | |||||
# Copyright (c) 2015, Frappe Technologies Pvt. Ltd. and Contributors | |||||
# Copyright (c) 2022, Frappe Technologies Pvt. Ltd. and Contributors | |||||
# License: MIT. See LICENSE | # License: MIT. See LICENSE | ||||
import json | import json | ||||
@@ -16,7 +16,7 @@ from werkzeug.local import LocalProxy | |||||
from werkzeug.wsgi import wrap_file | from werkzeug.wsgi import wrap_file | ||||
from werkzeug.wrappers import Response | from werkzeug.wrappers import Response | ||||
from werkzeug.exceptions import NotFound, Forbidden | from werkzeug.exceptions import NotFound, Forbidden | ||||
from frappe.utils import cint | |||||
from frappe.utils import cint, format_timedelta | |||||
from urllib.parse import quote | from urllib.parse import quote | ||||
from frappe.core.doctype.access_log.access_log import make_access_log | from frappe.core.doctype.access_log.access_log import make_access_log | ||||
@@ -122,12 +122,14 @@ def make_logs(response = None): | |||||
def json_handler(obj): | def json_handler(obj): | ||||
"""serialize non-serializable data for json""" | """serialize non-serializable data for json""" | ||||
# serialize date | |||||
import collections.abc | |||||
from collections.abc import Iterable | |||||
if isinstance(obj, (datetime.date, datetime.timedelta, datetime.datetime, datetime.time)): | |||||
if isinstance(obj, (datetime.date, datetime.datetime, datetime.time)): | |||||
return str(obj) | return str(obj) | ||||
elif isinstance(obj, datetime.timedelta): | |||||
return format_timedelta(obj) | |||||
elif isinstance(obj, decimal.Decimal): | elif isinstance(obj, decimal.Decimal): | ||||
return float(obj) | return float(obj) | ||||
@@ -138,7 +140,7 @@ def json_handler(obj): | |||||
doc = obj.as_dict(no_nulls=True) | doc = obj.as_dict(no_nulls=True) | ||||
return doc | return doc | ||||
elif isinstance(obj, collections.abc.Iterable): | |||||
elif isinstance(obj, Iterable): | |||||
return list(obj) | return list(obj) | ||||
elif type(obj)==type or isinstance(obj, Exception): | elif type(obj)==type or isinstance(obj, Exception): | ||||
@@ -88,7 +88,7 @@ def get_home_page(): | |||||
# portal default | # portal default | ||||
if not home_page: | if not home_page: | ||||
home_page = frappe.db.get_value("Portal Settings", None, "default_portal_home") | |||||
home_page = frappe.db.get_single_value("Portal Settings", "default_portal_home") | |||||
# by hooks | # by hooks | ||||
if not home_page: | if not home_page: | ||||
@@ -96,7 +96,7 @@ def get_home_page(): | |||||
# global | # global | ||||
if not home_page: | if not home_page: | ||||
home_page = frappe.db.get_value("Website Settings", None, "home_page") | |||||
home_page = frappe.db.get_single_value("Website Settings", "home_page") | |||||
if not home_page: | if not home_page: | ||||
home_page = "login" if frappe.session.user == 'Guest' else "me" | home_page = "login" if frappe.session.user == 'Guest' else "me" | ||||