Coverage for glotter/settings.py: 99%

132 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2026-03-01 21:54 +0000

1import os 

2from dataclasses import dataclass 

3from functools import cache 

4from typing import Dict, Optional 

5 

6from glotter_core.project import AcronymScheme 

7from glotter_core.settings import CoreSettingsParser 

8from pydantic import ( 

9 BaseModel, 

10 Field, 

11 ValidationError, 

12 ValidationInfo, 

13 field_validator, 

14 model_validator, 

15) 

16 

17from glotter.errors import get_error_details, raise_simple_validation_error, raise_validation_errors 

18from glotter.project import Project 

19from glotter.utils import error_and_exit, indent 

20 

21 

@cache
def get_settings():
    """
    Return the shared ``Settings`` instance.

    ``functools.cache`` memoizes this zero-argument call: the first call builds
    a ``Settings`` object and every later call returns that same object.
    """
    instance = Settings()
    return instance

28 

29 

class Settings:
    """
    Project settings for the current working directory.

    Construction parses the settings rooted at ``os.getcwd()`` with
    ``SettingsParser``. A pydantic ``ValidationError`` raised while parsing is
    formatted into a readable list and passed to ``error_and_exit``.
    """

    def __init__(self):
        self._project_root = os.getcwd()
        try:
            self._parser = SettingsParser(self._project_root)
        except ValidationError as err:
            error_and_exit(_format_validate_error(err))

        parser = self._parser
        self._projects = parser.projects
        # An empty/missing source root falls back to the project root.
        self._source_root = parser.source_root or self._project_root
        # project type -> list of registered test functions
        self._test_mappings = {}

    @property
    def projects(self):
        """Mapping of project type name to project, as parsed from the settings."""
        return self._projects

    @property
    def project_root(self):
        """Working directory captured when this object was created."""
        return self._project_root

    @property
    def source_root(self):
        """Root directory for sources; defaults to ``project_root``."""
        return self._source_root

    @source_root.setter
    def source_root(self, value):
        # Any falsy value (None, "") resets to the project root.
        self._source_root = value if value else self._project_root

    @property
    def test_mappings(self):
        """The live dict of project type -> registered test functions."""
        return self._test_mappings

    def get_test_mapping_name(self, project_type):
        """Return the ``__name__`` of every test function registered for *project_type*."""
        registered = self._test_mappings.get(project_type) or []
        return [test_func.__name__ for test_func in registered]

    def add_test_mapping(self, project_type, func):
        """
        Register *func* as a test for *project_type*.

        Raises:
            KeyError: if *project_type* is not one of the configured projects.
        """
        if project_type not in self._projects:
            raise KeyError(f"Project type {project_type} was not found in glotter.yml")
        self._test_mappings.setdefault(project_type, []).append(func)

    def verify_project_type(self, name):
        """Return True if the lower-cased *name* is a configured project type."""
        normalized = name.lower()
        return normalized in self.projects

78 

79 

def _format_validate_error(validation_error: ValidationError) -> str:
    """
    Render a pydantic ``ValidationError`` as a human-readable bullet list.

    Each error contributes two lines: ``- <dotted location>:`` and then the
    error message indented by 4. ``"__root__"`` location parts are skipped, and
    every other part is rendered with ``_format_location_item``.
    """
    lines = []
    for detail in validation_error.errors():
        path = ".".join(
            _format_location_item(part) for part in detail["loc"] if part != "__root__"
        )
        lines.append(f"- {path}:")
        lines.append(indent(detail["msg"], 4))

    header = "Errors found in the following items:\n"
    return header + "\n".join(lines)

95 

96 

97def _format_location_item(location) -> str: 

98 if isinstance(location, int): 

99 return f"item {location + 1}" 

100 

101 return str(location) 

102 

103 

class SettingsConfigSettings(BaseModel):
    """
    Validated contents of the ``settings`` section of the glotter config.

    ``yml_path`` is the path of the YAML file the settings came from; a
    relative ``source_root`` is resolved against that file's directory.
    """

    # validate_default=True makes the default value go through validation too.
    acronym_scheme: AcronymScheme = Field(AcronymScheme.two_letter_limit, validate_default=True)
    # Path of the YAML file these settings were read from. Declared before
    # source_root so it is available in ``info.data`` when source_root is validated.
    yml_path: str
    # Optional source directory; made absolute by ``get_source_root``.
    source_root: Optional[str] = None

    @field_validator("acronym_scheme", mode="before")
    @classmethod
    def get_acronym_scheme(cls, value):
        """Lower-case string input before conversion to ``AcronymScheme``."""
        # NOTE(review): presumably AcronymScheme values are lower-case, making
        # this a case-insensitive match -- confirm against glotter_core.
        if isinstance(value, str):
            return value.lower()
        return value

    @field_validator("source_root", mode="after")
    @classmethod
    def get_source_root(cls, value, info: ValidationInfo):
        """
        Resolve ``source_root`` to an absolute path.

        Absolute paths are returned unchanged; relative paths are joined onto
        the directory containing ``yml_path``. ``None`` (an explicit ``null``
        in the config) is returned unchanged: ``os.path.isabs(None)`` would
        raise ``TypeError``, which pydantic does not convert into a
        ``ValidationError``.
        """
        if value is None or os.path.isabs(value):
            return value

        yml_path = info.data.get("yml_path")
        if yml_path is None:
            # yml_path failed its own validation and that error is already
            # being reported; don't mask it with a KeyError here.
            return value

        return os.path.abspath(os.path.join(os.path.dirname(yml_path), value))

125 

126 

class SettingsConfig(BaseModel):
    """
    Top-level schema of the glotter YAML config.

    ``yml_path`` is forwarded into the ``settings`` section. The configured
    acronym scheme is copied into every dict-valued project entry before the
    entries are validated as ``Project`` models. Each ``use_tests`` reference
    is then checked against the other projects.
    """

    # Path of the YAML file. Declared first so later validators can read it
    # from ``info.data`` (pydantic validates fields in definition order).
    yml_path: str
    # validate_default=True makes ``get_settings`` run even when the section is
    # absent, so a SettingsConfigSettings carrying yml_path is always built.
    settings: Optional[SettingsConfigSettings] = Field(None, validate_default=True)
    projects: Dict[str, Project] = {}

    @field_validator("settings", mode="before")
    @classmethod
    def get_settings(cls, value, info: ValidationInfo):
        """Inject ``yml_path`` into the raw ``settings`` mapping (or create one)."""
        if value is None:
            return {"yml_path": info.data["yml_path"]}

        if isinstance(value, dict):
            return {**value, "yml_path": info.data["yml_path"]}

        # Any other type is left for SettingsConfigSettings validation to reject.
        return value

    @field_validator("projects", mode="before")
    @classmethod
    def get_projects(cls, value, info: ValidationInfo):
        """
        Copy the configured acronym scheme into each raw project entry.

        Non-mapping input is reported via ``raise_simple_validation_error``.
        Entries that are not dicts are passed through untouched so ``Project``
        validation reports them; the remaining entries are still processed.
        A new dict is returned and the input mapping is not mutated.
        """
        if not isinstance(value, dict):
            raise_simple_validation_error(cls, "Input should be a valid dictionary", value)

        settings = info.data.get("settings")
        if settings is None:
            # "settings" failed validation and its error is already being
            # reported; don't mask it with a KeyError here.
            return value

        acronym_scheme = settings.acronym_scheme
        return {
            project_name: (
                {**item, "acronym_scheme": acronym_scheme} if isinstance(item, dict) else item
            )
            for project_name, item in value.items()
        }

    @model_validator(mode="after")
    def validate_projects(self):
        """
        Check every ``use_tests`` reference and adopt the referenced tests.

        A project's ``use_tests`` must name an existing project. That project
        must not itself use ``use_tests``, and it must have a ``tests`` item.
        For each valid reference, ``Project.set_tests`` is called with the
        referenced project. All invalid references are collected and reported
        together via ``raise_validation_errors``.
        """
        projects = self.projects
        projects_with_use_tests = {
            project_name: project for project_name, project in projects.items() if project.use_tests
        }

        errors = []
        for project_name, project in projects_with_use_tests.items():
            use_tests_name = project.use_tests.name

            if use_tests_name not in projects:
                # "use_tests" must refer to an actual project
                message = f"Refers to a non-existent project {use_tests_name}"
            elif use_tests_name in projects_with_use_tests:
                # One "use_tests" item may not refer to another "use_tests" item
                message = f'Refers to another "use_tests" project {use_tests_name}'
            elif not projects[use_tests_name].tests:
                # The referenced project must actually define tests
                message = f'Refers to project {use_tests_name}, which has no "tests" item'
            else:
                # Valid reference: take the referenced project's tests
                project.set_tests(projects[use_tests_name])
                continue

            errors.append(
                get_error_details(
                    message,
                    loc=("projects", project_name, "use_tests"),
                    input=use_tests_name,
                )
            )

        if errors:
            raise_validation_errors(self.__class__, errors)

        return self

208 

209 

@dataclass(frozen=True)
class SettingsParser(CoreSettingsParser):
    """
    ``CoreSettingsParser`` extended with glotter's validated configuration.

    After the core parser has run, the mapping in ``self.yml`` (presumably the
    loaded YAML -- populated by the core parser) is validated through
    ``SettingsConfig``. The resulting acronym scheme, source root, and projects
    are then stored on the instance. A ``ValueError`` from the core parser is
    reported via ``error_and_exit``. Errors raised by ``SettingsConfig`` are not
    caught here and propagate to the caller.
    """

    def __init__(self, project_root):
        try:
            super().__init__(project_root)
        except ValueError as err:
            error_and_exit(str(err))

        validated = SettingsConfig(**self.yml, yml_path=self.yml_path)
        settings_section = validated.settings
        # frozen=True gives the dataclass a __setattr__ that raises, so the
        # attributes are set through object.__setattr__ instead.
        for attr_name, attr_value in (
            ("acronym_scheme", settings_section.acronym_scheme),
            ("source_root", settings_section.source_root),
            ("projects", validated.projects),
        ):
            object.__setattr__(self, attr_name, attr_value)