Coverage for glotter/auto_gen_test.py: 99%
220 statements
« prev ^ index » next coverage.py v7.10.7, created at 2026-03-01 21:54 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2026-03-01 21:54 +0000
1from functools import partial
2from typing import Annotated, Any, Callable, ClassVar, Dict, List, Optional, Tuple
4from pydantic import (
5 BaseModel,
6 Field,
7 ValidationError,
8 ValidationInfo,
9 field_validator,
10 model_validator,
11)
13from glotter.errors import (
14 get_error_details,
15 raise_simple_validation_error,
16 raise_validation_errors,
17 validate_str_dict,
18 validate_str_list,
19)
20from glotter.utils import indent, quote
# Scalar transformation: takes the actual and expected expression strings
# (e.g. "actual", "expected.strip()") and returns the transformed pair
TransformationScalarFuncT = Callable[[str, str], Tuple[str, str]]
# Dictionary transformation: like the scalar form, but first receives the list of
# values given under the transformation's key
TransformationDictFuncT = Callable[[List[str], str, str], Tuple[str, str]]
class AutoGenParam(BaseModel):
    """Object used to auto-generate a test parameter"""

    name: str = ""
    input: Optional[str] = None
    expected: Any

    @field_validator("expected")
    @classmethod
    def validate_expected(cls, value):
        """
        Check that the expected value has one of the supported forms

        Supported forms are a string, a list (checked by ``validate_str_list``), or a
        dictionary with exactly one item whose key is "self", or is "exec"/"string"
        mapped to a non-empty string.

        :param value: Expected value
        :return: The expected value, unchanged
        :raises: :exc:`ValidationError` if the expected value has an unsupported form
        """

        if isinstance(value, list):
            validate_str_list(cls, value)
        elif isinstance(value, dict):
            item_count = len(value)
            if item_count == 0:
                raise_simple_validation_error(cls, "Too few items", value)

            if item_count > 1:
                raise_simple_validation_error(cls, "Too many items", value)

            ((key, item),) = value.items()
            if key not in ("exec", "string", "self"):
                raise_simple_validation_error(cls, 'Invalid "expected" type', item)
            elif key != "self":
                # "exec" and "string" both carry a payload that must be a non-empty string
                if not isinstance(item, str):
                    raise_simple_validation_error(
                        cls, "Input should be a valid string", item, (key,)
                    )

                if not item:
                    raise_simple_validation_error(cls, "Value must not be empty", item, (key,))
        elif not isinstance(value, str):
            raise_simple_validation_error(
                cls, "Input should be a valid string, list, or dictionary", value
            )

        return value

    def get_pytest_param(self) -> str:
        """
        Build the ``pytest.param(...)`` line for this parameter

        :return: pytest parameter string if name is not empty, empty string otherwise
        """

        if not self.name:
            return ""

        in_arg = quote(self.input) if isinstance(self.input, str) else self.input

        if isinstance(self.expected, str):
            out_arg = quote(self.expected)
        elif isinstance(self.expected, dict) and "string" in self.expected:
            # Refer to the module-level constant instead of inlining the string
            out_arg = self.get_constant_variable_name()
        else:
            out_arg = self.expected

        return f"pytest.param({in_arg}, {out_arg}, id={quote(self.name)}),\n"

    def get_constant_variable_name(self) -> str:
        """
        Get the constant variable name for a ``{"string": name}`` expected value

        :return: upper-cased string name, or an empty string for any other expected value
        """

        expected = self.expected
        if not (isinstance(expected, dict) and "string" in expected):
            return ""

        return expected["string"].upper()
105def _append_method_to_actual(method: str, actual_var: str, expected_var) -> Tuple[str, str]:
106 return f"{actual_var}.{method}()", expected_var
109def _append_method_to_expected(method: str, actual_var: str, expected_var: str) -> Tuple[str, str]:
110 return actual_var, f"{expected_var}.{method}()"
def _remove_chars(values: List[str], actual_var: str, expected_var: str) -> Tuple[str, str]:
    """
    Append a ``.replace(value, "")`` call to the actual expression for each value

    :param values: Substrings to remove from the actual output, in order
    :param actual_var: Actual-value expression to extend
    :param expected_var: Expected-value expression, returned unchanged
    :return: Tuple of the extended actual expression and the expected expression
    """

    replace_calls = "".join(f'.replace({quote(value)}, "")' for value in values)
    return actual_var + replace_calls, expected_var
def _strip_chars(values: List[str], actual_var: str, expected_var: str) -> Tuple[str, str]:
    """
    Append a ``.strip(value)`` call to the actual expression for each value

    :param values: Character sets to strip from the actual output, in order
    :param actual_var: Actual-value expression to extend
    :param expected_var: Expected-value expression, returned unchanged
    :return: Tuple of the extended actual expression and the expected expression
    """

    strip_calls = "".join(f".strip({quote(value)})" for value in values)
    return actual_var + strip_calls, expected_var
127def _unique_sort(actual_var, expected_var):
128 return f"sorted(set({actual_var}))", f"sorted(set({expected_var}))"
class AutoGenTest(BaseModel):
    """Object used to auto-generate a test"""

    name: Annotated[str, Field(strict=True, min_length=1, pattern="^[a-zA-Z][0-9a-zA-Z_]*$")]
    requires_parameters: bool = False
    inputs: Annotated[List[str], Field(strict=True, min_length=1)] = Field(
        ["Input"], validate_default=True
    )
    # The None default is validated (validate_default=True), so a missing "params" item
    # fails validation instead of silently defaulting
    params: Annotated[List[AutoGenParam], Field(strict=True, min_length=1)] = Field(
        None, validate_default=True
    )
    strings: Dict[str, str] = {}
    transformations: List[Any] = []

    # Transformations written as a bare string in "transformations"
    SCALAR_TRANSFORMATION_FUNCS: ClassVar[Dict[str, TransformationScalarFuncT]] = {
        "strip": partial(_append_method_to_actual, "strip"),
        "splitlines": partial(_append_method_to_actual, "splitlines"),
        "lower": partial(_append_method_to_actual, "lower"),
        "any_order": _unique_sort,
        "strip_expected": partial(_append_method_to_expected, "strip"),
        "splitlines_expected": partial(_append_method_to_expected, "splitlines"),
    }
    # Transformations written as a single-item dictionary: {name: [values]}
    DICT_TRANSFORMATION_FUNCS: ClassVar[Dict[str, TransformationDictFuncT]] = {
        "remove": _remove_chars,
        "strip": _strip_chars,
    }

    @field_validator("inputs", mode="before")
    @classmethod
    def validate_inputs(cls, values):
        """
        Validate each input

        :param values: Inputs to validate
        :return: Original inputs
        :raises: :exc:`ValidationError` if input invalid
        """

        validate_str_list(cls, values)
        return values

    @field_validator("params", mode="before")
    @classmethod
    def validate_params(cls, values, info: ValidationInfo):
        """
        Validate each parameter

        :param values: Parameters to validate
        :param info: Validation info; ``info.data`` holds the previously validated fields
        :return: Original parameters
        :raises: :exc:`ValidationError` if project requires parameters but no input, no name,
            or empty name. Also, raised if no expected output
        """

        # This runs before type validation, so values may be the None default or any other
        # raw type. enumerate() would raise a bare TypeError (not a ValidationError) on None,
        # so leave non-lists to the strict List[...] check, which reports them properly.
        if not isinstance(values, list):
            return values

        errors = []
        field_is_required = "Field is required when parameters required"
        if info.data.get("requires_parameters"):
            for index, value in enumerate(values):
                if not isinstance(value, dict):
                    errors.append(
                        get_error_details("Input should be a valid dictionary", (index,), value)
                    )
                    continue

                if "name" not in value:
                    errors.append(get_error_details(field_is_required, (index, "name"), value))
                elif isinstance(value["name"], str) and not value["name"]:
                    errors.append(
                        get_error_details(
                            "Value must not be empty when parameters required",
                            (index, "name"),
                            value,
                        )
                    )

                if "input" not in value:
                    errors.append(get_error_details(field_is_required, (index, "input"), value))

                if "expected" not in value:
                    errors.append(
                        get_error_details(field_is_required, (index, "expected"), value)
                    )

        if errors:
            # Raising here skips pydantic's own per-item validation, so collect the inner
            # errors as well to report everything at once
            for index, value in enumerate(values):
                if not isinstance(value, dict):
                    continue

                try:
                    AutoGenParam.model_validate(value)
                except ValidationError as exc:
                    for err in exc.errors():
                        loc = (index,) + tuple(err.get("loc", ()))
                        msg = err.get("msg") or str(err.get("type", "value_error"))
                        input_val = err.get("input", value)
                        errors.append(get_error_details(msg, loc, input_val))

            raise_validation_errors(cls, errors)

        return values

    @field_validator("strings", mode="before")
    @classmethod
    def validate_strings(cls, values):
        """
        Validate each string

        :param values: Strings to validate
        :return: Original strings
        :raises: :exc:`ValidationError` if strings invalid
        """

        validate_str_dict(cls, values)
        return values

    @field_validator("transformations", mode="before")
    @classmethod
    def validate_transformation(cls, values):
        """
        Validate each transformation

        A transformation is either a string naming a scalar transformation or a
        single-item dictionary naming a dictionary transformation.

        :param values: Transformations to validate
        :return: Original values
        :raises: :exc:`ValidationError` if invalid transformation
        """

        if not isinstance(values, list):
            raise_simple_validation_error(cls, "Input should be a valid list", values)

        errors = []
        for index, value in enumerate(values):
            if isinstance(value, str):
                if value not in cls.SCALAR_TRANSFORMATION_FUNCS:
                    errors.append(
                        get_error_details(f'Invalid transformation "{value}"', (index,), value)
                    )
            elif isinstance(value, dict) and len(value) > 1:
                # str(*value) below would call str() with several arguments and raise a
                # bare TypeError, so report this as a validation error instead
                errors.append(get_error_details("Too many items", (index,), value))
            elif isinstance(value, dict):
                # An empty dictionary yields str() == "", reported as an invalid name
                key = str(*value)
                if key not in cls.DICT_TRANSFORMATION_FUNCS:
                    errors.append(
                        get_error_details(f'Invalid transformation "{key}"', (index,), value)
                    )
                else:
                    errors += validate_str_list(cls, value[key], (index, key), raise_exc=False)
            else:
                errors.append(
                    get_error_details(
                        "Input should be a valid string or dictionary", (index,), value
                    )
                )

        if errors:
            raise_validation_errors(cls, errors)

        return values

    @model_validator(mode="after")
    def validate_test_strings(self):
        """
        Validate each test string

        For parameterized tests, every ``{"string": name}`` expected value must name an
        entry of ``strings``.

        :return: This test object
        :raises: :exc:`ValidationError` if an expected value refers to a non-existent string
        """

        errors = []
        if self.requires_parameters:
            for index, param in enumerate(self.params):
                loc = ("params", index, "expected", "string")
                expected = param.expected
                if (
                    isinstance(expected, dict)
                    and "string" in expected
                    and expected["string"] not in self.strings
                ):
                    expected_string = expected["string"]
                    errors.append(
                        get_error_details(
                            f"Refers to a non-existent string {expected_string}",
                            loc=loc,
                            input=expected_string,
                        )
                    )

        if errors:
            raise_validation_errors(self.__class__, errors)

        return self

    def transform_vars(self) -> Tuple[str, str]:
        """
        Transform variables using the specified transformations

        Transformations are applied in order; each receives the expressions produced by
        the previous one.

        :return: Transformed actual and expected variables
        """

        actual_var = "actual"
        expected_var = "expected"
        for transformation in self.transformations:
            if isinstance(transformation, str):
                actual_var, expected_var = self.SCALAR_TRANSFORMATION_FUNCS[transformation](
                    actual_var, expected_var
                )
            else:
                # Validation guarantees a single-item dictionary here
                ((key, item),) = transformation.items()
                actual_var, expected_var = self.DICT_TRANSFORMATION_FUNCS[key](
                    item, actual_var, expected_var
                )

        return actual_var, expected_var

    def get_constant_variables(self) -> str:
        """
        Get constant variables

        :return: One ``NAME = "value"`` line per distinct string referenced by a parameter
        """

        constant_variables = {}
        if self.requires_parameters:
            for param in self.params:
                variable_name = param.get_constant_variable_name()
                if variable_name and variable_name not in constant_variables:
                    constant_variables[variable_name] = self.strings[param.expected["string"]]

        return "".join(f"{name} = {quote(value)}\n" for name, value in constant_variables.items())

    def get_pytest_params(self) -> str:
        """
        Get pytest parameters

        :return: ``@pytest.mark.parametrize`` decorator code, or an empty string if the
            test does not require parameters
        """

        if not self.requires_parameters:
            return ""

        pytest_params = "".join(
            indent(param.get_pytest_param(), 8) for param in self.params
        ).strip()
        return f"""\
@pytest.mark.parametrize(
    ("in_params", "expected"),
    [
        {pytest_params}
    ]
)
"""

    def get_test_function_and_run(self, project_name_underscores: str) -> str:
        """
        Get test function and run command

        :param project_name_underscores: Project name with underscores between each word
        :return: Test function and run command
        """

        func_params = ""
        run_param = ""
        if self.requires_parameters:
            func_params = "in_params, expected, "
            run_param = "params=in_params"

        return f"""\
def test_{self.name}({func_params}{project_name_underscores}):
    actual = {project_name_underscores}.run({run_param})
"""

    def get_expected_output(self, project_name_underscores: str) -> str:
        """
        Get test code that gets the expected output

        :param project_name_underscores: Project name with underscores between each word
        :return: Test code that gets the expected output, or an empty string for
            parameterized tests (the expected value then comes from the parameters)
        """

        if self.requires_parameters:
            return ""

        expected_output = self.params[0].expected
        if isinstance(expected_output, str):
            expected_output = quote(expected_output)
        elif isinstance(expected_output, dict):
            return _get_expected_file(project_name_underscores, expected_output)

        return f"expected = {expected_output}\n"

    def generate_test(self, project_name_underscores: str) -> str:
        """
        Generate test code

        :param project_name_underscores: Project name with underscores between each word
        :return: Test code
        """

        test_code = "@project_test(PROJECT_NAME)\n"
        test_code += self.get_pytest_params()
        test_code += self.get_test_function_and_run(project_name_underscores)
        test_code += indent(self.get_expected_output(project_name_underscores), 4)
        actual_var, expected_var = self.transform_vars()
        test_code += indent(_get_assert(actual_var, expected_var, self.params[0].expected), 4)
        return test_code
433def _get_expected_file(project_name_underscores: str, expected_output: Dict[str, str]) -> str:
434 if "exec" in expected_output:
435 script = quote(expected_output["exec"])
436 return f"expected = {project_name_underscores}.exec({script})\n"
438 test_code = f"""\
439with open({project_name_underscores}.full_path, "r", encoding="utf-8") as file:
440 expected = file.read()
441"""
443 if "self" in expected_output: 443 ↛ 452line 443 didn't jump to line 452 because the condition on line 443 was always true
444 test_code += """\
445diff_len = len(actual) - len(expected)
446if diff_len > 0:
447 expected += "\\n"
448elif diff_len < 0:
449 actual += "\\n"
450"""
452 return test_code
455def _get_assert(actual_var: str, expected_var: str, expected_output) -> str:
456 if isinstance(expected_output, list):
457 return f"""\
458actual_list = {actual_var}
459expected_list = {expected_var}
460assert len(actual_list) == len(expected_list), "Length not equal"
461for index in range(len(expected_list)):
462 assert actual_list[index] == expected_list[index], f"Item {{index + 1}} is not equal"
463"""
465 test_code = ""
466 if actual_var != "actual":
467 test_code += f"actual = {actual_var}\n"
469 if expected_var != "expected":
470 test_code += f"expected = {expected_var}\n"
472 return f"{test_code}assert actual == expected\n"
class AutoGenUseTests(BaseModel):
    """Object used to specify what tests to use"""

    name: str
    search: Annotated[str, Field(strict=True, pattern="^[0-9a-zA-Z_]*$")] = ""
    replace: Annotated[str, Field(strict=True, pattern="^[0-9a-zA-Z_]*$")] = ""

    @model_validator(mode="before")
    @classmethod
    def validate_search_with_replace(cls, values):
        """
        Validate that if either search or replace is specified, both must be specified

        :param values: Raw input to validate
        :return: Original values
        :raises: :exc:`ValidationError` if only one of search and replace is specified
        """

        # A "before" model validator sees the raw input, which need not be a dictionary.
        # A membership test would raise a bare TypeError on e.g. None and do a substring
        # search on a str, so leave non-dictionaries to pydantic's own type validation.
        if not isinstance(values, dict):
            return values

        has_search = "search" in values
        has_replace = "replace" in values
        if has_search and not has_replace:
            raise_simple_validation_error(
                cls, '"search" item specified without "replace" item', values
            )

        if has_replace and not has_search:
            raise_simple_validation_error(
                cls, '"replace" item specified without "search" item', values
            )

        return values