Coverage for glotter/settings.py: 99%

163 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-08 17:57 +0000

1import os 

2from typing import Dict, Optional 

3from warnings import warn 

4 

5import yaml 

6from pydantic import ( 

7 BaseModel, 

8 Field, 

9 ValidationError, 

10 ValidationInfo, 

11 field_validator, 

12 model_validator, 

13) 

14 

15from glotter.errors import get_error_details, raise_simple_validation_error, raise_validation_errors 

16from glotter.project import AcronymScheme, Project 

17from glotter.singleton import Singleton 

18from glotter.utils import error_and_exit, indent 

19 

20 

class Settings(metaclass=Singleton):
    """Glotter configuration loaded from ``.glotter.yml`` under the working directory.

    The configuration is parsed once, in ``__init__``, using ``SettingsParser``.
    Instance creation is governed by the ``Singleton`` metaclass, which presumably
    hands back one shared instance. Confirm this against ``glotter.singleton``.
    """

    def __init__(self):
        root = os.getcwd()
        self._project_root = root
        try:
            self._parser = SettingsParser(root)
        except ValidationError as exc:
            # Report schema problems in readable form instead of a traceback.
            error_and_exit(_format_validate_error(exc))

        parser = self._parser
        self._projects = parser.projects
        # With no configured source root, fall back to the directory we started in.
        self._source_root = parser.source_root or root
        # Maps a project type to the list of test functions registered for it.
        self._test_mappings = {}

    @property
    def projects(self):
        """The validated ``{project name: Project}`` mapping from the config."""
        return self._projects

    @property
    def project_root(self):
        """The working directory captured when the settings were loaded."""
        return self._project_root

    @property
    def source_root(self):
        """Directory containing the sources; defaults to ``project_root``."""
        return self._source_root

    @source_root.setter
    def source_root(self, value):
        # A falsy value (None, "") resets the source root to the project root.
        self._source_root = value if value else self._project_root

    @property
    def test_mappings(self):
        """The ``{project type: [test function, ...]}`` registry."""
        return self._test_mappings

    def get_test_mapping_name(self, project_type):
        """Return the ``__name__`` of each test function registered for *project_type*.

        Returns an empty list when nothing is registered.
        """
        registered = self._test_mappings.get(project_type) or []
        return [test_func.__name__ for test_func in registered]

    def add_test_mapping(self, project_type, func):
        """Register *func* as a test for *project_type*.

        Raises:
            KeyError: if *project_type* is not one of the configured projects.
        """
        if project_type not in self._projects:
            raise KeyError(f"Project type {project_type} was not found in glotter.yml")

        self._test_mappings.setdefault(project_type, []).append(func)

    def verify_project_type(self, name):
        """Return True if the lower-cased *name* is a configured project."""
        lowered = name.lower()
        return lowered in self.projects

69 

70 

def _format_validate_error(validation_error: ValidationError) -> str:
    """Render a pydantic ``ValidationError`` as a readable list of problems.

    Each error produces two lines:
    - ``- <location>:``. The location is the error's ``loc`` parts, formatted by
      ``_format_location_item`` and joined with ``.``. Any ``"__root__"`` part is skipped.
    - The error message, passed through ``indent(..., 4)``.

    The whole list is prefixed with a fixed header line.
    """
    lines = []
    for err in validation_error.errors():
        location = ".".join(
            _format_location_item(part) for part in err["loc"] if part != "__root__"
        )
        lines.append(f"- {location}:")
        lines.append(indent(err["msg"], 4))

    header = "Errors found in the following items:\n"
    return header + "\n".join(lines)

86 

87 

88def _format_location_item(location) -> str: 

89 if isinstance(location, int): 

90 return f"item {location + 1}" 

91 

92 return str(location) 

93 

94 

class SettingsConfigSettings(BaseModel):
    """Schema of the ``settings`` section of ``.glotter.yml``.

    Fields:
        acronym_scheme: Project-naming acronym scheme; defaults to
            ``AcronymScheme.two_letter_limit``. String input is lower-cased first.
        yml_path: Path of the ``.glotter.yml`` file. ``SettingsConfig`` injects
            it; it is used to resolve a relative ``source_root``.
        source_root: Optional source directory. A relative path is resolved
            against the directory that contains ``yml_path``.
    """

    acronym_scheme: AcronymScheme = Field(AcronymScheme.two_letter_limit, validate_default=True)
    yml_path: str
    source_root: Optional[str] = None

    @field_validator("acronym_scheme", mode="before")
    @classmethod
    def get_acronym_scheme(cls, value):
        """Lower-case string input before enum conversion.

        This presumably matches the lower-case ``AcronymScheme`` values; confirm
        against ``glotter.project``. Non-string input is passed through unchanged.
        """
        if isinstance(value, str):
            return value.lower()

        return value

    @field_validator("source_root", mode="after")
    @classmethod
    def get_source_root(cls, value, info: ValidationInfo):
        """Make a relative ``source_root`` absolute, relative to the yml file's directory."""
        # An explicit ``source_root: null`` arrives here as None.
        # os.path.isabs(None) would raise TypeError, and pydantic does not turn
        # that into a ValidationError, so None is passed through untouched.
        if value is None or os.path.isabs(value):
            return value

        # yml_path is declared before source_root, so it is validated first.
        # It is missing from info.data only if its own validation failed.
        # That error is already reported, so leave the value as-is.
        yml_path = info.data.get("yml_path")
        if yml_path is None:
            return value

        yml_dir = os.path.dirname(yml_path)
        return os.path.abspath(os.path.join(yml_dir, value))

116 

117 

class SettingsConfig(BaseModel):
    """Top-level schema of ``.glotter.yml``.

    Fields:
        yml_path: Path of the config file. It is supplied by the caller rather
            than read from the YAML, and is forwarded into ``settings``.
        settings: The ``settings`` section. A default is always built, so
            ``yml_path`` still reaches it when the section is absent.
        projects: Mapping of project name to ``Project``. Each dict entry
            receives the configured ``acronym_scheme`` before validation.
    """

    yml_path: str
    settings: Optional[SettingsConfigSettings] = Field(None, validate_default=True)
    projects: Dict[str, Project] = {}

    @field_validator("settings", mode="before")
    @classmethod
    def get_settings(cls, value, info: ValidationInfo):
        """Inject ``yml_path`` into the raw settings section, creating it if missing."""
        # yml_path is declared first, so it is validated first.
        # .get() avoids a KeyError crash when yml_path itself failed validation.
        # In that case the missing value is reported as a normal validation error.
        yml_path = info.data.get("yml_path")
        if value is None:
            return {"yml_path": yml_path}

        if isinstance(value, dict):
            return {**value, "yml_path": yml_path}

        return value

    @field_validator("projects", mode="before")
    @classmethod
    def get_projects(cls, value, info: ValidationInfo):
        """Copy the configured ``acronym_scheme`` into every dict-shaped project entry.

        A new mapping is returned; the caller's input dict is not mutated.
        """
        if not isinstance(value, dict):
            raise_simple_validation_error(cls, "Input should be a valid dictionary", value)

        # "settings" is absent when the settings section failed validation.
        # Its errors are reported separately, so return the projects untouched
        # and let their own validation run instead of crashing with KeyError.
        settings = info.data.get("settings")
        if settings is None:
            return value

        acronym_scheme = settings.acronym_scheme
        # Non-dict entries are left as-is; Project validation reports them.
        # They must not stop the scheme from reaching the entries after them.
        return {
            project_name: (
                {**item, "acronym_scheme": acronym_scheme} if isinstance(item, dict) else item
            )
            for project_name, item in value.items()
        }

    @model_validator(mode="after")
    def validate_projects(self):
        """Resolve ``use_tests`` references between projects.

        A project with ``use_tests`` must refer to an existing project. That
        project must itself have no ``use_tests`` and must have ``tests``.
        When valid, the referenced project's tests are applied with
        ``set_tests``. Otherwise all problems are collected and raised together
        through ``raise_validation_errors``.
        """
        projects = self.projects
        # Defensive only: after field validation, projects is always a dict.
        if not isinstance(projects, dict):
            return self

        projects_with_use_tests = {
            project_name: project for project_name, project in projects.items() if project.use_tests
        }

        errors = []
        for project_name, project in projects_with_use_tests.items():
            use_tests_name = project.use_tests.name
            loc = ("projects", project_name, "use_tests")

            # Make sure "use_tests" item refers to an actual project
            if use_tests_name not in projects:
                errors.append(
                    get_error_details(
                        f"Refers to a non-existent project {use_tests_name}",
                        loc=loc,
                        input=use_tests_name,
                    )
                )
            # Make sure one "use_tests" item does not refer to another "use_tests" item
            elif use_tests_name in projects_with_use_tests:
                errors.append(
                    get_error_details(
                        f'Refers to another "use_tests" project {use_tests_name}',
                        loc=loc,
                        input=use_tests_name,
                    )
                )
            # Make sure "use_tests" item refers to a project with tests
            elif not projects[use_tests_name].tests:
                errors.append(
                    get_error_details(
                        f'Refers to project {use_tests_name}, which has no "tests" item',
                        loc=loc,
                        input=use_tests_name,
                    )
                )
            # Otherwise, set the tests that the "use_tests" item refers to with the tests renamed
            else:
                project.set_tests(projects[use_tests_name])

        if errors:
            raise_validation_errors(self.__class__, errors)

        return self

199 

200 

class SettingsParser:
    """Locate, load and validate ``.glotter.yml`` at or below a project root.

    Construction does all the work. If the file contents fail the
    ``SettingsConfig`` schema, pydantic's ``ValidationError`` propagates to the caller.
    """

    def __init__(self, project_root):
        self._project_root = project_root
        self._acronym_scheme = None
        self._projects = None
        self._source_root = None
        self._yml_path = self._locate_yml()

        yml = None
        if self._yml_path is not None:
            yml = self._parse_yml()
        else:
            # No config file was found. Use the project root so that yml_path
            # is still a string for SettingsConfig.
            self._yml_path = project_root
            warn(f'.glotter.yml not found in directory "{project_root}"')

        # A missing or empty file gives None; treat it as an empty config.
        if yml is None:
            yml = {}

        if not isinstance(yml, dict):
            error_and_exit(".glotter.yml does not contain a dict")

        # Validate a merged dict rather than calling SettingsConfig(**yml, yml_path=...).
        # With keyword expansion, a top-level "yml_path" key in the file, or any
        # non-string top-level key, raises TypeError instead of a ValidationError.
        # The real file location always overrides any user-supplied yml_path.
        config = SettingsConfig.model_validate({**yml, "yml_path": self._yml_path})
        self._acronym_scheme = config.settings.acronym_scheme
        self._source_root = config.settings.source_root
        self._projects = config.projects

    @property
    def project_root(self):
        """The directory the search for ``.glotter.yml`` started from."""
        return self._project_root

    @property
    def yml_path(self):
        """Absolute path of the config file found, or ``project_root`` if none was found."""
        return self._yml_path

    @property
    def source_root(self):
        """The validated ``settings.source_root`` (may be None)."""
        return self._source_root

    @property
    def acronym_scheme(self):
        """The validated ``settings.acronym_scheme``."""
        return self._acronym_scheme

    @property
    def projects(self):
        """The validated ``{project name: Project}`` mapping."""
        return self._projects

    def _parse_yml(self):
        """Read ``yml_path`` as UTF-8 and parse it with ``yaml.safe_load``."""
        with open(self._yml_path, "r", encoding="utf-8") as f:
            contents = f.read()

        return yaml.safe_load(contents)

    def _locate_yml(self):
        """Return the absolute path of the first ``.glotter.yml`` found, or None.

        The search walks ``project_root`` top-down with ``os.walk``, so a file
        directly in the root is found before one in any subdirectory.
        """
        for root, _, files in os.walk(self._project_root):
            if ".glotter.yml" in files:
                return os.path.join(os.path.abspath(root), ".glotter.yml")

        return None