Browse Source

chore: rename cast_autoincremented_name to cast_name

version-14
phot0n 3 years ago
parent
commit
4ea87fd9cc
2 changed files with 6 additions and 6 deletions
  1. +5
    -5
      frappe/model/db_query.py
  2. +1
    -1
      frappe/tests/test_db_query.py

+ 5
- 5
frappe/model/db_query.py View File

@@ -164,7 +164,7 @@ class DatabaseQuery(object):


# left join parent, child tables # left join parent, child tables
for child in self.tables[1:]: for child in self.tables[1:]:
parent_name = self.cast_autoincremented_name(f"{self.tables[0]}.name")
parent_name = self.cast_name(f"{self.tables[0]}.name")
args.tables += f" {self.join} {child} on ({child}.parent = {parent_name})" args.tables += f" {self.join} {child} on ({child}.parent = {parent_name})"


if self.grouped_or_conditions: if self.grouped_or_conditions:
@@ -327,7 +327,7 @@ class DatabaseQuery(object):
func_found = False func_found = False
for func in sql_functions: for func in sql_functions:
if func in field.lower(): if func in field.lower():
self.fields[i] = self.cast_autoincremented_name(field, func)
self.fields[i] = self.cast_name(field, func)
func_found = True func_found = True
break break


@@ -343,7 +343,7 @@ class DatabaseQuery(object):
if table_name not in self.tables: if table_name not in self.tables:
self.append_table(table_name) self.append_table(table_name)


def cast_autoincremented_name(self, column: str, sql_function: str = "",) -> str:
def cast_name(self, column: str, sql_function: str = "",) -> str:
if frappe.db.db_type == "postgres": if frappe.db.db_type == "postgres":
if "name" in column.lower(): if "name" in column.lower():
if "cast(" not in column.lower() or "::" not in column: if "cast(" not in column.lower() or "::" not in column:
@@ -477,9 +477,9 @@ class DatabaseQuery(object):
self.append_table(tname) self.append_table(tname)


if 'ifnull(' in f.fieldname: if 'ifnull(' in f.fieldname:
column_name = self.cast_autoincremented_name(f.fieldname, "ifnull(")
column_name = self.cast_name(f.fieldname, "ifnull(")
else: else:
column_name = self.cast_autoincremented_name(f"{tname}.{f.fieldname}")
column_name = self.cast_name(f"{tname}.{f.fieldname}")


if f.operator.lower() in additional_filters_config: if f.operator.lower() in additional_filters_config:
f.update(get_additional_filter_field(additional_filters_config, f, f.value)) f.update(get_additional_filter_field(additional_filters_config, f, f.value))


+ 1
- 1
frappe/tests/test_db_query.py View File

@@ -494,7 +494,7 @@ class TestReportview(unittest.TestCase):
response = execute_cmd("frappe.desk.reportview.get") response = execute_cmd("frappe.desk.reportview.get")
self.assertListEqual(response["keys"], ["field_label", "field_name", "_aggregate_column", 'columns']) self.assertListEqual(response["keys"], ["field_label", "field_name", "_aggregate_column", 'columns'])


def test_cast_autoincremented_name(self):
def test_cast_name(self):
from frappe.core.doctype.doctype.test_doctype import new_doctype from frappe.core.doctype.doctype.test_doctype import new_doctype


dt = new_doctype("autoinc_dt_test", autoincremented=True).insert(ignore_permissions=True) dt = new_doctype("autoinc_dt_test", autoincremented=True).insert(ignore_permissions=True)


Loading…
Cancel
Save