Coverage for glotter/test_generator.py: 100%

61 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2026-03-01 21:54 +0000

1import os 

2import shutil 

3import subprocess 

4import tempfile 

5 

6from glotter.settings import get_settings 

7 

# Relative output directory (resolved against the current working directory) that
# receives the generated pytest modules; generate_tests() removes it and
# TestGenerator.write_tests() recreates it.
AUTO_GEN_TEST_PATH = os.path.join("test", "generated")

9 

10 

def generate_tests():
    """
    Generate tests for all projects.

    The ``AUTO_GEN_TEST_PATH`` directory is removed first (removal errors are
    ignored). Test code is then rendered for every configured project, and only
    once every project has been rendered are the non-empty results written,
    one file per project.
    """
    shutil.rmtree(AUTO_GEN_TEST_PATH, ignore_errors=True)

    projects = get_settings().projects
    generators = [TestGenerator(name, project) for name, project in projects.items()]

    # Render everything before touching the filesystem: an exception raised
    # while rendering any project aborts before a single file is written.
    rendered = []
    for generator in generators:
        code = generator.generate_tests()
        if code:
            rendered.append((generator, code))

    for generator, code in rendered:
        generator.write_tests(code)

30 

31 

class TestGenerator:
    """
    Render and write the generated pytest module for a single project.

    The rendered module consists of the imports, any constant definitions
    contributed by the project's tests, a project fixture, and finally the
    code produced by each of the project's tests.
    """

    __test__ = False  # Keep pytest from collecting this class despite its "Test" prefix

    def __init__(self, project_name, project):
        self.project_name = project_name
        self.project = project
        # Used as the generated fixture's function name, passed to each test's
        # generate_test(), and embedded in the output file name.
        self.long_project_name = "_".join(self.project.words)

    def generate_tests(self):
        """
        Return the formatted source of the project's test module.

        Returns "" when the project defines no tests; otherwise the assembled
        source is passed through format_str() before being returned.
        """
        if not self.project.tests:
            return ""

        pieces = [
            self._get_imports(),
            self._get_constant_variables(),
            self._get_project_fixture(),
        ]
        pieces.extend(
            test_obj.generate_test(self.long_project_name)
            for test_obj in self.project.tests.values()
        )
        return format_str("".join(pieces))

    def _get_imports(self):
        """Return the import lines; ``pytest`` is only imported for parameterized projects."""
        lines = []
        if self.project.requires_parameters:
            lines.append("import pytest\n")
        lines.append("from glotter import project_test, project_fixture\n")
        return "".join(lines)

    def _get_constant_variables(self):
        """
        Return the de-duplicated constant definitions of all tests, in sorted order.

        Only parameterized projects contribute constants; empty contributions
        are skipped.
        """
        if not self.project.requires_parameters:
            return ""

        unique_constants = {
            constant
            for constant in (
                test_obj.get_constant_variables() for test_obj in self.project.tests.values()
            )
            if constant
        }
        return "".join(sorted(unique_constants))

    def _get_project_fixture(self):
        """
        Return the source of the project fixture shared by the generated tests.

        ``request.param.build()`` runs inside a ``try`` whose ``finally`` calls
        ``request.param.cleanup()``, so cleanup also runs when ``build()`` raises.
        """
        return f"""\
PROJECT_NAME="{self.project_name}"
@project_fixture(PROJECT_NAME)
def {self.long_project_name}(request):
    try:
        request.param.build()
        yield request.param
    finally:
        request.param.cleanup()
"""

    def write_tests(self, test_code):
        """Write ``test_code`` to ``test_<long_project_name>.py`` under AUTO_GEN_TEST_PATH."""
        os.makedirs(AUTO_GEN_TEST_PATH, exist_ok=True)
        file_name = f"test_{self.long_project_name}.py"
        target = os.path.join(AUTO_GEN_TEST_PATH, file_name)
        with open(target, "w", encoding="utf-8") as out_file:
            out_file.write(test_code)

95 

96 

def format_str(test_code):
    """
    Format Python source code with ``ruff format`` and return the result.

    The code is written to a ``.py`` file inside a private temporary directory.
    ``ruff format --line-length=100`` formats that file, and the file is then
    read back. The temporary directory is removed afterwards.

    :param test_code: Python source code to format.
    :return: The formatted source code.
    :raises subprocess.CalledProcessError: If ruff exits with a non-zero status.
    :raises FileNotFoundError: If the ``ruff`` executable cannot be found.
    """
    # A temporary *directory* is used rather than an open NamedTemporaryFile:
    # on Windows a NamedTemporaryFile cannot be reopened by name while it is
    # still open, which is exactly what the ruff subprocess needs to do.
    # Re-opening the file after ruff exits also reads the right contents
    # whether ruff rewrites the file in place or replaces it.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # An explicit .py suffix lets ruff identify the file as Python source.
        tmp_path = os.path.join(tmp_dir, "generated_test.py")
        with open(tmp_path, "w", encoding="utf-8") as tmp_file:
            tmp_file.write(test_code)

        subprocess.run(
            ["ruff", "format", "--line-length=100", tmp_path],
            check=True,
            stdout=subprocess.DEVNULL,
        )

        with open(tmp_path, encoding="utf-8") as tmp_file:
            return tmp_file.read()