Coverage for glotter / auto_gen_test.py: 99%
227 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-10 16:13 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-10 16:13 +0000
1from functools import partial
2from typing import Annotated, Any, Callable, ClassVar, Dict, List, Optional, Tuple
4from pydantic import (
5 BaseModel,
6 Field,
7 ValidationError,
8 ValidationInfo,
9 field_validator,
10 model_validator,
11)
13from glotter.errors import (
14 get_error_details,
15 raise_simple_validation_error,
16 raise_validation_errors,
17 validate_str_dict,
18 validate_str_list,
19)
20from glotter.utils import indent, quote
# (actual variable expression, expected variable expression) pair produced by a transformation
_VarPairT = Tuple[str, str]

# Transformation selected by a plain string: (actual_var, expected_var) -> new pair
TransformationScalarFuncT = Callable[[str, str], _VarPairT]
# Transformation selected by a single-item dict: (values, actual_var, expected_var) -> new pair
TransformationDictFuncT = Callable[[List[str], str, str], _VarPairT]
class AutoGenParam(BaseModel):
    """Object used to auto-generate a test parameter"""

    name: str = ""
    input: Optional[str] = None
    expected: Any

    @field_validator("expected")
    @classmethod
    def validate_expected(cls, value):
        """
        Validate expected value

        The expected value must be one of:

        - a string
        - a list of strings
        - a single-item dictionary whose key is "exec" or "string" (value must be a
          non-empty string) or "self" (value is not checked)

        :param value: Expected value
        :return: Original expected value
        :raises: :exc:`ValidationError` if invalid expected value
        """

        if isinstance(value, dict):
            if not value:
                raise_simple_validation_error(cls, "Too few items", value)

            if len(value) > 1:
                raise_simple_validation_error(cls, "Too many items", value)

            # Exactly one item remains at this point
            key, item = next(iter(value.items()))
            if key in ("exec", "string"):
                if not isinstance(item, str):
                    raise_simple_validation_error(
                        cls, "Input should be a valid string", item, (key,)
                    )

                if not item:
                    raise_simple_validation_error(cls, "Value must not be empty", item, (key,))
            elif key != "self":
                raise_simple_validation_error(cls, 'Invalid "expected" type', item)
        elif isinstance(value, list):
            validate_str_list(cls, value)
        elif not isinstance(value, str):
            raise_simple_validation_error(
                cls, "Input should be a valid string, list, or dictionary", value
            )

        return value

    def get_pytest_param(self) -> str:
        """
        Get pytest parameter string

        :return: pytest parameter string if name is not empty, empty string otherwise
        """

        if not self.name:
            return ""

        input_param = self.input
        if isinstance(input_param, str):
            input_param = quote(input_param)

        expected_output = self.expected
        if isinstance(expected_output, str):
            expected_output = quote(expected_output)
        elif isinstance(expected_output, dict) and "string" in expected_output:
            # A "string" reference is emitted as the name of a module-level constant
            expected_output = self.get_constant_variable_name()

        return f"pytest.param({input_param}, {expected_output}, id={quote(self.name)}),\n"

    def get_constant_variable_name(self) -> str:
        """
        Get constant variable name

        :return: upper-cased "string" reference if the expected value is a dictionary
            with a "string" key, empty string otherwise
        """

        variable_name = ""
        if isinstance(self.expected, dict) and "string" in self.expected:
            variable_name = self.expected["string"].upper()

        return variable_name
105def _append_method_to_actual(method: str, actual_var: str, expected_var) -> Tuple[str, str]:
106 return f"{actual_var}.{method}()", expected_var
109def _append_method_to_expected(method: str, actual_var: str, expected_var: str) -> Tuple[str, str]:
110 return actual_var, f"{expected_var}.{method}()"
113def _remove_chars(values: List[str], actual_var: str, expected_var: str) -> Tuple[str, str]:
114 for value in values:
115 actual_var += f'.replace({quote(value)}, "")'
117 return actual_var, expected_var
120def _strip_chars(values: List[str], actual_var: str, expected_var: str) -> Tuple[str, str]:
121 for value in values:
122 actual_var += f".strip({quote(value)})"
124 return actual_var, expected_var
127def _unique_sort(actual_var, expected_var):
128 return f"sorted(set({actual_var}))", f"sorted(set({expected_var}))"
class AutoGenTest(BaseModel):
    """Object used to auto-generate a test"""

    VALID_NAME_REGEX: ClassVar[str] = "^[a-zA-Z][0-9a-zA-Z_]*$"

    name: Annotated[str, Field(strict=True, min_length=1, pattern=VALID_NAME_REGEX)]
    # Declared before "params": validate_params reads it through ``info.data``, which pydantic
    # only populates with fields that have already been validated
    requires_parameters: bool = False
    inputs: Annotated[List[str], Field(strict=True, min_length=1)] = Field(
        ["Input"], validate_default=True
    )
    params: Annotated[List[AutoGenParam], Field(strict=True, min_length=1)] = Field(
        None, validate_default=True
    )
    strings: Dict[str, str] = {}
    transformations: List[Any] = []
    repeat: int = 1

    # Transformations specified as a plain string, e.g. "strip"
    SCALAR_TRANSFORMATION_FUNCS: ClassVar[Dict[str, TransformationScalarFuncT]] = {
        "strip": partial(_append_method_to_actual, "strip"),
        "splitlines": partial(_append_method_to_actual, "splitlines"),
        "lower": partial(_append_method_to_actual, "lower"),
        "any_order": _unique_sort,
        "strip_expected": partial(_append_method_to_expected, "strip"),
        "splitlines_expected": partial(_append_method_to_expected, "splitlines"),
    }
    # Transformations specified as a single-item dictionary mapping the transformation
    # name to a list of strings, e.g. {"remove": [...]}
    DICT_TRANSFORMATION_FUNCS: ClassVar[Dict[str, TransformationDictFuncT]] = {
        "remove": _remove_chars,
        "strip": _strip_chars,
    }

    @field_validator("inputs", mode="before")
    @classmethod
    def validate_inputs(cls, values):
        """
        Validate each input

        :param values: Inputs to validate
        :return: Original inputs
        :raises: :exc:`ValidationError` if input invalid
        """

        validate_str_list(cls, values)
        return values

    @field_validator("params", mode="before")
    @classmethod
    def validate_params(cls, values, info: ValidationInfo):
        """
        Validate each parameter

        :param values: Parameters to validate
        :param info: Test item
        :return: Original parameters
        :raises: :exc:`ValidationError` if project requires parameters but no input, no name,
            or empty name. Also, raised if no expected output
        """

        # A missing "params" arrives here as the None default (validate_default=True), and
        # the raw input may be some other non-list. Iterating such a value raises a plain
        # TypeError, which pydantic does not turn into a ValidationError, so leave it to the
        # field's own strict list validation to reject.
        if not isinstance(values, list):
            return values

        errors = []
        field_is_required = "Field is required when parameters required"
        if info.data.get("requires_parameters"):
            for index, value in enumerate(values):
                if not isinstance(value, dict):
                    errors.append(
                        get_error_details("Input should be a valid dictionary", (index,), value)
                    )
                    continue

                if "name" not in value:
                    errors.append(get_error_details(field_is_required, (index, "name"), value))
                elif isinstance(value["name"], str) and not value["name"]:
                    errors.append(
                        get_error_details(
                            "Value must not be empty when parameters required",
                            (index, "name"),
                            value,
                        )
                    )

                if "input" not in value:
                    errors.append(get_error_details(field_is_required, (index, "input"), value))

                if "expected" not in value:
                    errors.append(
                        get_error_details(field_is_required, (index, "expected"), value)
                    )

        if errors:
            # Also collect the errors inside each parameter so they are reported together
            for index, value in enumerate(values):
                if not isinstance(value, dict):
                    continue

                try:
                    AutoGenParam.model_validate(value)
                except ValidationError as exc:
                    for err in exc.errors():
                        loc = (index,) + tuple(err.get("loc", ()))
                        msg = err.get("msg") or str(err.get("type", "value_error"))
                        input_val = err.get("input", value)
                        errors.append(get_error_details(msg, loc, input_val))

            raise_validation_errors(cls, errors)

        return values

    @field_validator("strings", mode="before")
    @classmethod
    def validate_strings(cls, values):
        """
        Validate each string

        :param values: Strings to validate
        :return: Original strings
        :raises: :exc:`ValidationError` if strings invalid
        """

        validate_str_dict(cls, values)
        return values

    @field_validator("transformations", mode="before")
    @classmethod
    def validate_transformation(cls, values):
        """
        Validate each transformation

        A transformation is either a string naming a scalar transformation, or a
        single-item dictionary mapping a dictionary transformation name to a list of
        strings.

        :param values: Transformations to validate
        :return: Original values
        :raises: :exc:`ValidationError` if invalid transformation
        """

        if not isinstance(values, list):
            raise_simple_validation_error(cls, "Input should be a valid list", values)

        errors = []
        for index, value in enumerate(values):
            if isinstance(value, str):
                if value not in cls.SCALAR_TRANSFORMATION_FUNCS:
                    errors.append(
                        get_error_details(f'Invalid transformation "{value}"', (index,), value)
                    )
            elif isinstance(value, dict):
                # Only one transformation name per item is supported, and transform_vars
                # unpacks exactly one item, so reject multi-key dictionaries explicitly
                if len(value) > 1:
                    errors.append(get_error_details("Too many items", (index,), value))
                    continue

                # An empty dictionary yields "" and is reported as an invalid transformation
                key = str(next(iter(value), ""))
                if key not in cls.DICT_TRANSFORMATION_FUNCS:
                    errors.append(
                        get_error_details(f'Invalid transformation "{key}"', (index,), value)
                    )
                else:
                    errors += validate_str_list(cls, value[key], (index, key), raise_exc=False)
            else:
                errors.append(
                    get_error_details(
                        "Input should be a valid string or dictionary", (index,), value
                    )
                )

        if errors:
            raise_validation_errors(cls, errors)

        return values

    @model_validator(mode="after")
    def validate_test_strings(self):
        """
        Validate that each parameter's expected "string" refers to an entry in ``strings``

        :return: This test item
        :raises: :exc:`ValidationError` if a parameter refers to a non-existent string
        """

        errors = []
        if self.requires_parameters:
            for index, param in enumerate(self.params):
                expected = param.expected
                if (
                    isinstance(expected, dict)
                    and "string" in expected
                    and expected["string"] not in self.strings
                ):
                    expected_string = expected["string"]
                    errors.append(
                        get_error_details(
                            f"Refers to a non-existent string {expected_string}",
                            loc=("params", index, "expected", "string"),
                            input=expected_string,
                        )
                    )

        if errors:
            raise_validation_errors(self.__class__, errors)

        return self

    def transform_vars(self) -> Tuple[str, str]:
        """
        Transform variables using the specified transformations

        :return: Transformed actual and expected variables
        """

        actual_var = "actual"
        expected_var = "expected"
        for transformation in self.transformations:
            if isinstance(transformation, str):
                actual_var, expected_var = self.SCALAR_TRANSFORMATION_FUNCS[transformation](
                    actual_var, expected_var
                )
            else:
                # Dictionary transformations hold exactly one item (see validate_transformation)
                key, item = next(iter(transformation.items()))
                actual_var, expected_var = self.DICT_TRANSFORMATION_FUNCS[key](
                    item, actual_var, expected_var
                )

        return actual_var, expected_var

    def get_constant_variables(self) -> str:
        """
        Get constant variables

        :return: One ``NAME = "value"`` line per distinct "string" reference used by the
            parameters, or an empty string if parameters are not required
        """

        constant_variables = {}
        if self.requires_parameters:
            for param in self.params:
                variable_name = param.get_constant_variable_name()
                if variable_name and variable_name not in constant_variables:
                    constant_variables[variable_name] = self.strings[param.expected["string"]]

        return "".join(f"{name} = {quote(value)}\n" for name, value in constant_variables.items())

    def get_pytest_params(self) -> str:
        """
        Get pytest parameters

        :return: pytest parametrize decorators, or an empty string if parameters are not
            required
        """

        if not self.requires_parameters:
            return ""

        # The first entry is stripped of its indentation because the template supplies it
        pytest_params = "".join(
            indent(param.get_pytest_param(), 8) for param in self.params
        ).strip()
        test_code = f"""\
@pytest.mark.parametrize(
    ("in_params", "expected"),
    [
        {pytest_params}
    ]
)
"""
        if self.repeat > 1:
            test_code += f"""\
@pytest.mark.parametrize("repeat", range(1, {self.repeat + 1}), ids=lambda x: f"repeat{{x}}")
"""

        return test_code

    def get_test_function_and_run(self, project_name_underscores: str) -> str:
        """
        Get test function and run command

        :param project_name_underscores: Project name with underscores between each word
        :return: Test function and run command
        """

        func_params = ""
        run_param = ""
        if self.repeat > 1:
            func_params += "repeat, "

        if self.requires_parameters:
            func_params += "in_params, expected, "
            run_param += "params=in_params"

        return f"""\
def test_{self.name}({func_params}{project_name_underscores}):
    actual = {project_name_underscores}.run({run_param})
"""

    def get_expected_output(self, project_name_underscores: str) -> str:
        """
        Get test code that gets the expected output

        :param project_name_underscores: Project name with underscores between each word
        :return: Test code that gets the expected output, or an empty string if parameters
            are required (the expected value then comes from the pytest parameters)
        """

        if self.requires_parameters:
            return ""

        expected_output = self.params[0].expected
        if isinstance(expected_output, str):
            expected_output = quote(expected_output)
        elif isinstance(expected_output, dict):
            return _get_expected_file(project_name_underscores, expected_output)

        return f"expected = {expected_output}\n"

    def generate_test(self, project_name_underscores: str) -> str:
        """
        Generate test code

        :param project_name_underscores: Project name with underscores between each word
        :return: Test code
        """

        test_code = "@project_test(PROJECT_NAME)\n"
        test_code += self.get_pytest_params()
        test_code += self.get_test_function_and_run(project_name_underscores)
        test_code += indent(self.get_expected_output(project_name_underscores), 4)
        actual_var, expected_var = self.transform_vars()
        test_code += indent(_get_assert(actual_var, expected_var, self.params[0].expected), 4)
        return test_code
445def _get_expected_file(project_name_underscores: str, expected_output: Dict[str, str]) -> str:
446 if "exec" in expected_output:
447 script = quote(expected_output["exec"])
448 return f"expected = {project_name_underscores}.exec({script})\n"
450 test_code = f"""\
451with open({project_name_underscores}.full_path, "r", encoding="utf-8") as file:
452 expected = file.read()
453"""
455 if "self" in expected_output: 455 ↛ 464line 455 didn't jump to line 464 because the condition on line 455 was always true
456 test_code += """\
457diff_len = len(actual) - len(expected)
458if diff_len > 0:
459 expected += "\\n"
460elif diff_len < 0:
461 actual += "\\n"
462"""
464 return test_code
467def _get_assert(actual_var: str, expected_var: str, expected_output) -> str:
468 if isinstance(expected_output, list):
469 return f"""\
470actual_list = {actual_var}
471expected_list = {expected_var}
472assert len(actual_list) == len(expected_list), "Length not equal"
473for index in range(len(expected_list)):
474 assert actual_list[index] == expected_list[index], f"Item {{index + 1}} is not equal"
475"""
477 test_code = ""
478 if actual_var != "actual":
479 test_code += f"actual = {actual_var}\n"
481 if expected_var != "expected":
482 test_code += f"expected = {expected_var}\n"
484 return f"{test_code}assert actual == expected\n"
class AutoGenUseTests(BaseModel):
    """Object used to specify what tests to use"""

    name: str
    search: Annotated[str, Field(strict=True, pattern="^[0-9a-zA-Z_]*$")] = ""
    replace: Annotated[str, Field(strict=True, pattern="^[0-9a-zA-Z_]*$")] = ""

    @model_validator(mode="before")
    @classmethod
    def validate_search_with_replace(cls, values):
        """
        Validate that if either search or replace is specified, both must be specified

        :param values: Values to validate
        :return: Original values
        :raises: :exc:`ValidationError` if only one of search or replace is specified
        """

        # A "before" model validator receives the raw input, which need not be a dict.
        # ``in`` would do a substring test on a string or raise a plain TypeError on e.g.
        # None, so leave non-dict input to pydantic's own model validation.
        if not isinstance(values, dict):
            return values

        has_search = "search" in values
        has_replace = "replace" in values
        if has_search and not has_replace:
            raise_simple_validation_error(
                cls, '"search" item specified without "replace" item', values
            )

        if has_replace and not has_search:
            raise_simple_validation_error(
                cls, '"replace" item specified without "search" item', values
            )

        return values