Coverage for glotter / settings.py: 100%

167 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-10 16:13 +0000

1import os 

2from dataclasses import dataclass 

3from functools import cache 

4from typing import Dict, Optional 

5 

6from glotter_core.project import AcronymScheme 

7from glotter_core.settings import CoreSettingsParser 

8from pydantic import ( 

9 BaseModel, 

10 Field, 

11 ValidationError, 

12 ValidationInfo, 

13 field_validator, 

14 model_validator, 

15) 

16 

17from glotter.errors import get_error_details, raise_simple_validation_error, raise_validation_errors 

18from glotter.project import Project 

19from glotter.utils import error_and_exit, indent 

20 

21 

@cache
def get_settings():
    """
    Return the process-wide ``Settings`` instance.

    Because ``functools.cache`` wraps a zero-argument function, ``Settings()`` is
    constructed on the first call only; later calls return that same object.
    """
    settings = Settings()
    return settings

28 

29 

class Settings:
    """
    Runtime glotter configuration rooted at the current working directory.

    Wraps ``SettingsParser`` (exiting via ``error_and_exit`` with a formatted message
    on a ``ValidationError``) and keeps a registry of test functions per project type.
    """

    def __init__(self):
        self._project_root = os.getcwd()
        try:
            self._parser = SettingsParser(self._project_root)
        except ValidationError as e:
            error_and_exit(_format_validate_error(e))

        parser = self._parser
        self._projects = parser.projects
        # Fall back to the working directory when the YAML gives no source_root.
        self._source_root = parser.source_root or self._project_root
        self._test_mappings = {}

    @property
    def projects(self):
        """Validated projects from the YAML file."""
        return self._projects

    @property
    def project_root(self):
        """Directory glotter was started from (``os.getcwd()`` at construction)."""
        return self._project_root

    @property
    def source_root(self):
        """Directory containing the sources; defaults to ``project_root``."""
        return self._source_root

    @source_root.setter
    def source_root(self, value):
        # Any falsy value (None, "") resets to the project root.
        self._source_root = value if value else self._project_root

    @property
    def test_mappings(self):
        """Mapping of project type -> list of registered test functions."""
        return self._test_mappings

    def get_test_mapping_name(self, project_type):
        """Return the ``__name__`` of every test function registered for ``project_type``."""
        registered = self._test_mappings.get(project_type) or []
        return [test_func.__name__ for test_func in registered]

    def add_test_mapping(self, project_type, func):
        """
        Register ``func`` as a test for ``project_type``.

        :raises KeyError: if ``project_type`` is not one of the configured projects.
        """
        if project_type not in self._projects:
            raise KeyError(f"Project type {project_type} was not found in glotter.yml")

        self._test_mappings.setdefault(project_type, []).append(func)

    def verify_project_type(self, name):
        """Return True if the lower-cased ``name`` is a configured project."""
        return name.lower() in self.projects

78 

79 

def _format_validate_error(validation_error: ValidationError) -> str:
    """
    Render a pydantic ``ValidationError`` as a human-readable bullet list.

    Each error becomes a ``- a.b.item N:`` location line (``__root__`` parts omitted,
    integer parts shown 1-based) followed by its message indented by 4 spaces.
    """
    entries = []
    for details in validation_error.errors():
        path = ".".join(
            _format_location_item(part) for part in details["loc"] if part != "__root__"
        )
        entries.extend([f"- {path}:", indent(details["msg"], 4)])

    header = "Errors found in the following items:\n"
    return header + "\n".join(entries)

95 

96 

97def _format_location_item(location) -> str: 

98 if isinstance(location, int): 

99 return f"item {location + 1}" 

100 

101 return str(location) 

102 

103 

104def _validate_use_tests_repeat(projects: object) -> list: 

105 errors = [] 

106 

107 if not isinstance(projects, dict): 

108 return errors 

109 

110 for project_name, project in projects.items(): 

111 if not isinstance(project, dict): 

112 continue 

113 

114 use_tests = project.get("use_tests") 

115 repeat = project.get("repeat") 

116 if not isinstance(use_tests, dict) or not isinstance(repeat, dict): 

117 continue 

118 

119 use_tests_name = use_tests.get("name") 

120 target_project = projects.get(use_tests_name) 

121 if not isinstance(use_tests_name, str) or not isinstance(target_project, dict): 

122 continue 

123 

124 tests = target_project.get("tests") 

125 search = use_tests.get("search") 

126 replace = use_tests.get("replace") 

127 if not isinstance(search, str) or not isinstance(replace, str): 

128 continue 

129 

130 valid_test_names = { 

131 test_name.replace(search, replace) 

132 for test_name in tests.keys() 

133 if isinstance(test_name, str) 

134 } 

135 

136 for repeat_name in repeat: 

137 if isinstance(repeat_name, str) and repeat_name not in valid_test_names: 

138 errors.append( 

139 get_error_details( 

140 f"Refers to a non-existent test name {repeat_name}", 

141 ("projects", project_name, "repeat", repeat_name), 

142 repeat_name, 

143 ) 

144 ) 

145 

146 return errors 

147 

148 

def _convert_validation_error_to_error_details(validation_error: ValidationError) -> list:
    """
    Re-express each error of ``validation_error`` via ``get_error_details``.

    Message, location (as a tuple, empty if absent) and input are carried over so the
    result can be merged with other error details into a new ``ValidationError``.
    """
    return [
        get_error_details(err["msg"], tuple(err.get("loc", ())), err.get("input"))
        for err in validation_error.errors()
    ]

160 

161 

class SettingsConfigSettings(BaseModel):
    """
    Validated contents of the optional ``settings`` section of the glotter YAML.

    ``yml_path`` does not come from the YAML section itself; ``SettingsConfig`` injects
    it so that a relative ``source_root`` can be resolved against the YAML's directory.
    """

    acronym_scheme: AcronymScheme = Field(AcronymScheme.two_letter_limit, validate_default=True)
    yml_path: str
    source_root: Optional[str] = None

    @field_validator("acronym_scheme", mode="before")
    @classmethod
    def get_acronym_scheme(cls, value):
        """
        Lower-case string input before enum validation.

        Presumably the ``AcronymScheme`` values are lower-case, which makes the YAML
        setting case-insensitive -- confirm against glotter_core.
        """
        if isinstance(value, str):
            return value.lower()

        return value

    @field_validator("source_root", mode="after")
    @classmethod
    def get_source_root(cls, value, info: ValidationInfo):
        """
        Resolve ``source_root`` to an absolute path.

        Relative paths are joined onto the directory containing the YAML file. ``None``
        (e.g. ``source_root:`` with no value in the YAML) is returned unchanged.
        """
        # An explicit null passes Optional[str]; os.path.isabs(None) would raise a
        # TypeError, which pydantic v2 does not convert into a ValidationError.
        if value is None or os.path.isabs(value):
            return value

        # yml_path is declared before source_root, so it is in info.data unless it
        # failed its own validation -- that error is already being reported.
        yml_path = info.data.get("yml_path")
        if yml_path is None:
            return value

        yml_dir = os.path.dirname(yml_path)
        return os.path.abspath(os.path.join(yml_dir, value))

183 

184 

class SettingsConfig(BaseModel):
    """
    Top-level schema of the glotter YAML file.

    ``yml_path`` is supplied by the caller rather than the YAML and is forwarded into
    ``settings``; the configured acronym scheme is forwarded into every project.
    """

    yml_path: str
    settings: Optional[SettingsConfigSettings] = Field(None, validate_default=True)
    projects: Dict[str, Project] = {}

    @field_validator("settings", mode="before")
    @classmethod
    def get_settings(cls, value, info: ValidationInfo):
        """Inject ``yml_path`` into the settings section, creating the section if absent."""
        if value is None:
            return {"yml_path": info.data["yml_path"]}

        if isinstance(value, dict):
            return {**value, "yml_path": info.data["yml_path"]}

        # Anything else is left for SettingsConfigSettings validation to reject.
        return value

    @field_validator("projects", mode="before")
    @classmethod
    def get_projects(cls, value, info: ValidationInfo):
        """
        Attach the configured acronym scheme to every project mapping.

        Returns a new dict; the caller's mapping (the parsed YAML) is not mutated.
        Non-dict project entries are passed through for ``Project`` validation to reject.
        """
        if not isinstance(value, dict):
            raise_simple_validation_error(cls, "Input should be a valid dictionary", value)

        # "settings" is absent from info.data when it failed validation. Indexing it
        # would raise KeyError, which pydantic does not turn into a ValidationError, so
        # fall back to SettingsConfigSettings' default scheme and let the settings
        # error be reported normally.
        settings = info.data.get("settings")
        acronym_scheme = (
            settings.acronym_scheme if settings is not None else AcronymScheme.two_letter_limit
        )

        return {
            project_name: (
                {**item, "acronym_scheme": acronym_scheme} if isinstance(item, dict) else item
            )
            for project_name, item in value.items()
        }

    @model_validator(mode="after")
    def validate_projects(self):
        """
        Cross-check ``use_tests`` references between projects.

        Collects errors for references to unknown projects, to other ``use_tests``
        projects, or to projects without tests; otherwise delegates to
        ``Project.set_tests`` and collects the errors it returns.

        :raises ValidationError: (via ``raise_validation_errors``) if any were found.
        """
        projects = self.projects
        projects_with_use_tests = {
            project_name: project for project_name, project in projects.items() if project.use_tests
        }

        errors = []
        for project_name, project in projects_with_use_tests.items():
            use_tests_name = project.use_tests.name
            loc = ("projects", project_name, "use_tests")

            # Make sure "use_tests" item refers to an actual project
            if use_tests_name not in projects:
                errors.append(
                    get_error_details(
                        f"Refers to a non-existent project {use_tests_name}",
                        loc=loc,
                        input=use_tests_name,
                    )
                )
            # Make sure one "use_tests" item does not refer to another "use_tests" item
            elif use_tests_name in projects_with_use_tests:
                errors.append(
                    get_error_details(
                        f'Refers to another "use_tests" project {use_tests_name}',
                        loc=loc,
                        input=use_tests_name,
                    )
                )
            # Make sure "use_tests" item refers to a project with tests
            elif not projects[use_tests_name].tests:
                errors.append(
                    get_error_details(
                        f'Refers to project {use_tests_name}, which has no "tests" item',
                        loc=loc,
                        input=use_tests_name,
                    )
                )
            # Otherwise, set the tests that the "use_tests" item refers to with the tests renamed
            else:
                errors += project.set_tests(
                    projects[use_tests_name], loc_prefix=("projects", project_name)
                )

        if errors:
            raise_validation_errors(self.__class__, errors)

        return self

265 

266 

@dataclass(frozen=True)
class SettingsParser(CoreSettingsParser):
    """
    Parse and validate the glotter YAML found under ``project_root``.

    After construction, exposes ``acronym_scheme``, ``source_root`` and ``projects``.
    A ``ValueError`` from ``CoreSettingsParser`` is passed to ``error_and_exit``
    (presumably terminating the process); invalid YAML content raises
    ``ValidationError``.
    """

    def __init__(self, project_root):
        try:
            super().__init__(project_root)
        except ValueError as exc:
            error_and_exit(str(exc))

        # Build a single input dict: with ``**self.yml, yml_path=...`` a top-level
        # ``yml_path`` key in the YAML (or any non-string key) raised TypeError instead
        # of validating. The caller-supplied yml_path always wins.
        data = {**self.yml, "yml_path": self.yml_path}
        try:
            config = SettingsConfig.model_validate(data)
        except ValidationError as exc:
            # When field validation fails, SettingsConfig's after-model validator does
            # not run, so also check "repeat" names on the raw YAML and report both
            # sets of errors together.
            extra_errors = _validate_use_tests_repeat(self.yml.get("projects", {}))
            if extra_errors:
                combined_errors = _convert_validation_error_to_error_details(exc) + extra_errors
                raise ValidationError.from_exception_data(
                    title=SettingsConfig.__name__,
                    line_errors=combined_errors,
                ) from exc
            raise

        # The dataclass is frozen, so attributes must be set via object.__setattr__.
        object.__setattr__(self, "acronym_scheme", config.settings.acronym_scheme)
        object.__setattr__(self, "source_root", config.settings.source_root)
        object.__setattr__(self, "projects", config.projects)