Coverage for glotter/test_generator.py: 100%
61 statements
« prev ^ index » next coverage.py v7.10.7, created at 2026-03-01 21:54 +0000
1import os
2import shutil
3import subprocess
4import tempfile
6from glotter.settings import get_settings
# Directory that receives the generated pytest modules. It is a relative path, so it
# resolves against the current working directory. generate_tests() deletes it up
# front and TestGenerator.write_tests() recreates it when writing each module.
AUTO_GEN_TEST_PATH = os.path.join("test", "generated")
def generate_tests():
    """
    Regenerate the test modules for every project listed in the settings.

    The generated-tests directory is wiped first. Test code is produced for all
    projects before any file is written. Projects whose generator returns an empty
    string (i.e. they define no tests) do not get a file.
    """
    shutil.rmtree(AUTO_GEN_TEST_PATH, ignore_errors=True)

    generators = [
        TestGenerator(project_name, project)
        for project_name, project in get_settings().projects.items()
    ]

    # Produce every module up front; nothing is written unless generation
    # completes for all projects without raising.
    generated = []
    for generator in generators:
        code = generator.generate_tests()
        if code:
            generated.append((generator, code))

    for generator, code in generated:
        generator.write_tests(code)
class TestGenerator:
    """
    Produce and write the generated pytest module for a single project.

    ``project`` is expected to expose ``words``, ``requires_parameters`` and
    ``tests``. ``tests`` is a mapping whose values provide ``generate_test(name)``
    and ``get_constant_variables()``.
    """

    # pytest would otherwise try to collect this class because its name starts with "Test"
    __test__ = False

    def __init__(self, project_name, project):
        self.project_name = project_name
        self.project = project
        # e.g. words ["hello", "world"] -> "hello_world". This becomes the fixture
        # function name and part of the generated file name.
        self.long_project_name = "_".join(self.project.words)

    def generate_tests(self):
        """
        Build the formatted source of this project's test module.

        :return: the formatted test code, or an empty string when the project has no tests
        """
        project_tests = self.project.tests
        if not project_tests:
            return ""

        sections = [
            self._get_imports(),
            self._get_constant_variables(),
            self._get_project_fixture(),
        ]
        sections.extend(
            test_obj.generate_test(self.long_project_name)
            for test_obj in project_tests.values()
        )
        return format_str("".join(sections))

    def _get_imports(self):
        """Return the import statements placed at the top of the generated module."""
        import_lines = []
        if self.project.requires_parameters:
            # pytest itself is only imported for projects that require parameters
            import_lines.append("import pytest\n")
        import_lines.append("from glotter import project_test, project_fixture\n")
        return "".join(import_lines)

    def _get_constant_variables(self):
        """
        Return the de-duplicated constant definitions contributed by the tests.

        Only projects that require parameters contribute constants; for any other
        project this returns an empty string. Falsy results from
        ``get_constant_variables()`` are skipped.
        """
        if not self.project.requires_parameters:
            return ""

        constants = {
            constant
            for constant in (
                test_obj.get_constant_variables()
                for test_obj in self.project.tests.values()
            )
            if constant
        }
        # A set has no defined order; sorting keeps the generated file deterministic
        return "".join(sorted(constants))

    def _get_project_fixture(self):
        """Return the source of the project-level fixture that builds and cleans up each param."""
        return f"""\
PROJECT_NAME="{self.project_name}"
@project_fixture(PROJECT_NAME)
def {self.long_project_name}(request):
    try:
        request.param.build()
        yield request.param
    finally:
        request.param.cleanup()
"""

    def write_tests(self, test_code):
        """Write ``test_code`` to ``test_<long_project_name>.py`` inside AUTO_GEN_TEST_PATH."""
        target_path = os.path.join(AUTO_GEN_TEST_PATH, f"test_{self.long_project_name}.py")
        os.makedirs(AUTO_GEN_TEST_PATH, exist_ok=True)
        with open(target_path, "w", encoding="utf-8") as test_file:
            test_file.write(test_code)
def format_str(test_code):
    """
    Format Python source code with ``ruff format`` and return the result.

    The code is written to a temporary ``.py`` file, ruff reformats that file in
    place (line length 100), and the formatted contents are read back. The
    temporary file is always removed afterwards.

    :param test_code: Python source code to format
    :return: the formatted source code
    :raises subprocess.CalledProcessError: if ruff exits with a non-zero status
    :raises FileNotFoundError: if the ``ruff`` executable cannot be found
    """
    # delete=False + closing before running ruff: on Windows an open
    # NamedTemporaryFile cannot be reopened by another process, so ruff would
    # fail to read/write it while we still hold it open.
    with tempfile.NamedTemporaryFile(
        mode="w", encoding="utf-8", suffix=".py", delete=False
    ) as tmp_file:
        tmp_file.write(test_code)
        tmp_path = tmp_file.name

    try:
        subprocess.run(
            ["ruff", "format", "--line-length=100", tmp_path],
            check=True,
            stdout=subprocess.DEVNULL,
        )
        with open(tmp_path, encoding="utf-8") as formatted_file:
            return formatted_file.read()
    finally:
        os.remove(tmp_path)