Coverage for glotter / settings.py: 100%
167 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-10 16:13 +0000
1import os
2from dataclasses import dataclass
3from functools import cache
4from typing import Dict, Optional
6from glotter_core.project import AcronymScheme
7from glotter_core.settings import CoreSettingsParser
8from pydantic import (
9 BaseModel,
10 Field,
11 ValidationError,
12 ValidationInfo,
13 field_validator,
14 model_validator,
15)
17from glotter.errors import get_error_details, raise_simple_validation_error, raise_validation_errors
18from glotter.project import Project
19from glotter.utils import error_and_exit, indent
@cache
def get_settings():
    """
    Return the shared Settings instance.

    Because ``functools.cache`` wraps this zero-argument function, ``Settings()``
    is constructed on the first call only; every later call returns that same
    object.
    """
    shared = Settings()
    return shared
class Settings:
    """
    Resolved Glotter settings for the current working directory.

    On construction the working directory becomes the project root, and
    glotter.yml is parsed via SettingsParser. A pydantic ValidationError is
    formatted and handed to error_and_exit. The instance also keeps a registry
    that maps project types to the test functions registered for them.
    """

    def __init__(self):
        root = os.getcwd()
        self._project_root = root
        try:
            self._parser = SettingsParser(root)
        except ValidationError as e:
            error_and_exit(_format_validate_error(e))

        parser = self._parser
        self._projects = parser.projects
        # An unset/empty source_root falls back to the project root.
        self._source_root = parser.source_root or root
        self._test_mappings = {}

    @property
    def projects(self):
        """Projects parsed from glotter.yml."""
        return self._projects

    @property
    def project_root(self):
        """Working directory captured when this object was created."""
        return self._project_root

    @property
    def source_root(self):
        """Root directory for sources (defaults to the project root)."""
        return self._source_root

    @source_root.setter
    def source_root(self, value):
        # Falsy values (None, "") reset to the project root.
        self._source_root = value if value else self._project_root

    @property
    def test_mappings(self):
        """Mapping of project type -> list of registered test functions."""
        return self._test_mappings

    def get_test_mapping_name(self, project_type):
        """Return the ``__name__`` of each test registered for *project_type* ([] if none)."""
        registered = self._test_mappings.get(project_type) or []
        return [fn.__name__ for fn in registered]

    def add_test_mapping(self, project_type, func):
        """
        Register *func* as a test for *project_type*.

        Raises:
            KeyError: if *project_type* is not one of the configured projects.
        """
        if project_type not in self._projects:
            raise KeyError(f"Project type {project_type} was not found in glotter.yml")

        self._test_mappings.setdefault(project_type, []).append(func)

    def verify_project_type(self, name):
        """Return True if the lower-cased *name* is a configured project type."""
        lowered = name.lower()
        return lowered in self.projects
def _format_validate_error(validation_error: ValidationError) -> str:
    """
    Render a pydantic ValidationError as a human-readable bullet list.

    Each error yields a ``- <dotted location>:`` line followed by its message
    indented by 4 (via ``indent``). ``"__root__"`` location parts are dropped,
    and the rest are rendered with ``_format_location_item``.
    """
    lines = []
    for err in validation_error.errors():
        path = ".".join(
            _format_location_item(part) for part in err["loc"] if part != "__root__"
        )
        lines.append(f"- {path}:")
        lines.append(indent(err["msg"], 4))

    header = "Errors found in the following items:\n"
    return header + "\n".join(lines)
97def _format_location_item(location) -> str:
98 if isinstance(location, int):
99 return f"item {location + 1}"
101 return str(location)
104def _validate_use_tests_repeat(projects: object) -> list:
105 errors = []
107 if not isinstance(projects, dict):
108 return errors
110 for project_name, project in projects.items():
111 if not isinstance(project, dict):
112 continue
114 use_tests = project.get("use_tests")
115 repeat = project.get("repeat")
116 if not isinstance(use_tests, dict) or not isinstance(repeat, dict):
117 continue
119 use_tests_name = use_tests.get("name")
120 target_project = projects.get(use_tests_name)
121 if not isinstance(use_tests_name, str) or not isinstance(target_project, dict):
122 continue
124 tests = target_project.get("tests")
125 search = use_tests.get("search")
126 replace = use_tests.get("replace")
127 if not isinstance(search, str) or not isinstance(replace, str):
128 continue
130 valid_test_names = {
131 test_name.replace(search, replace)
132 for test_name in tests.keys()
133 if isinstance(test_name, str)
134 }
136 for repeat_name in repeat:
137 if isinstance(repeat_name, str) and repeat_name not in valid_test_names:
138 errors.append(
139 get_error_details(
140 f"Refers to a non-existent test name {repeat_name}",
141 ("projects", project_name, "repeat", repeat_name),
142 repeat_name,
143 )
144 )
146 return errors
def _convert_validation_error_to_error_details(validation_error: ValidationError) -> list:
    """
    Re-express each entry of a pydantic ValidationError via ``get_error_details``.

    Only the message, the location (as a tuple; empty when absent) and the
    input value of each error are kept. This lets the result be concatenated
    with other error details and passed back to
    ``ValidationError.from_exception_data``.
    """
    return [
        get_error_details(err["msg"], tuple(err.get("loc", ())), err.get("input"))
        for err in validation_error.errors()
    ]
class SettingsConfigSettings(BaseModel):
    """
    Schema for the "settings" section of glotter.yml.

    ``yml_path`` is not part of the YAML section itself. SettingsConfig's
    "settings" validator injects it, and it is used here to resolve a
    relative ``source_root``.
    """

    acronym_scheme: AcronymScheme = Field(AcronymScheme.two_letter_limit, validate_default=True)
    yml_path: str
    source_root: Optional[str] = None

    @field_validator("acronym_scheme", mode="before")
    @classmethod
    def get_acronym_scheme(cls, value):
        """Lower-case string input before it is coerced to ``AcronymScheme``."""
        if isinstance(value, str):
            return value.lower()

        return value

    @field_validator("source_root", mode="after")
    @classmethod
    def get_source_root(cls, value, info: ValidationInfo):
        """
        Resolve ``source_root`` to an absolute path.

        Relative paths are interpreted relative to the directory that contains
        the YAML file. An explicit null (e.g. ``source_root:`` with no value)
        is returned as None instead of crashing inside ``os.path.isabs``.
        """
        if value is None:
            return None

        if os.path.isabs(value):
            return value

        # info.data only contains fields that validated successfully; if
        # yml_path failed, its own error is reported and we leave value as-is.
        yml_path = info.data.get("yml_path")
        if yml_path is None:
            return value

        yml_dir = os.path.dirname(yml_path)
        return os.path.abspath(os.path.join(yml_dir, value))
class SettingsConfig(BaseModel):
    """
    Top-level schema for glotter.yml.

    The caller supplies ``yml_path`` (SettingsParser passes it next to the
    YAML content), and it is copied into the "settings" section. The
    validated settings' ``acronym_scheme`` is then copied into every project
    entry before those entries are validated as ``Project`` models.
    """

    yml_path: str
    settings: Optional[SettingsConfigSettings] = Field(None, validate_default=True)
    projects: Dict[str, Project] = {}

    @field_validator("settings", mode="before")
    @classmethod
    def get_settings(cls, value, info: ValidationInfo):
        """
        Inject ``yml_path`` into the raw "settings" section.

        With ``validate_default=True`` this also runs when the key is absent,
        so a missing/null section becomes ``{"yml_path": ...}`` and gets the
        model defaults. Other non-dict values are passed through unchanged to
        normal model validation.
        """
        if value is None:
            return {"yml_path": info.data["yml_path"]}

        if isinstance(value, dict):
            return {**value, "yml_path": info.data["yml_path"]}

        return value

    @field_validator("projects", mode="before")
    @classmethod
    def get_projects(cls, value, info: ValidationInfo):
        """
        Copy the configured acronym scheme into every dict-shaped project entry.

        A new mapping is returned so the caller's raw YAML data is not
        mutated. Non-dict entries are passed through for ``Project``
        validation to report.
        """
        if not isinstance(value, dict):
            raise_simple_validation_error(cls, "Input should be a valid dictionary", value)

        # info.data only holds fields that validated successfully. If the
        # "settings" section was invalid, fall back to the settings default so
        # that error is reported instead of a KeyError escaping from here.
        settings = info.data.get("settings")
        if settings is not None:
            acronym_scheme = settings.acronym_scheme
        else:
            acronym_scheme = AcronymScheme.two_letter_limit

        return {
            project_name: (
                {**item, "acronym_scheme": acronym_scheme} if isinstance(item, dict) else item
            )
            for project_name, item in value.items()
        }

    @model_validator(mode="after")
    def validate_projects(self):
        """
        Cross-check "use_tests" references between projects.

        Each project with ``use_tests`` must refer to an existing project that
        does not itself use ``use_tests`` and that has tests. When it does,
        ``project.set_tests`` is called and any error details it returns are
        collected. All collected errors are raised together via
        ``raise_validation_errors``.
        """
        projects = self.projects
        projects_with_use_tests = {
            project_name: project for project_name, project in projects.items() if project.use_tests
        }

        errors = []
        for project_name, project in projects_with_use_tests.items():
            use_tests_name = project.use_tests.name
            loc = ("projects", project_name, "use_tests")

            # Make sure "use_tests" item refers to an actual project
            if use_tests_name not in projects:
                errors.append(
                    get_error_details(
                        f"Refers to a non-existent project {use_tests_name}",
                        loc=loc,
                        input=use_tests_name,
                    )
                )
            # Make sure one "use_tests" item does not refer to another "use_tests" item
            elif use_tests_name in projects_with_use_tests:
                errors.append(
                    get_error_details(
                        f'Refers to another "use_tests" project {use_tests_name}',
                        loc=loc,
                        input=use_tests_name,
                    )
                )
            # Make sure "use_tests" item refers to a project with tests
            elif not projects[use_tests_name].tests:
                errors.append(
                    get_error_details(
                        f'Refers to project {use_tests_name}, which has no "tests" item',
                        loc=loc,
                        input=use_tests_name,
                    )
                )
            # Otherwise, set the tests that the "use_tests" item refers to with the tests renamed
            else:
                errors += project.set_tests(
                    projects[use_tests_name], loc_prefix=("projects", project_name)
                )

        if errors:
            raise_validation_errors(self.__class__, errors)

        return self
@dataclass(frozen=True)
class SettingsParser(CoreSettingsParser):
    """
    Load glotter.yml via CoreSettingsParser and validate it with SettingsConfig.

    A ValueError from the base initializer is reported through error_and_exit.
    After validation, ``acronym_scheme``, ``source_root`` and ``projects`` are
    stored on this frozen instance via ``object.__setattr__``.

    Raises:
        ValidationError: if the YAML does not satisfy SettingsConfig. When
            _validate_use_tests_repeat also finds invalid "repeat" names,
            those are merged into a single combined ValidationError.
    """

    def __init__(self, project_root):
        try:
            super().__init__(project_root)
        except ValueError as exc:
            error_and_exit(str(exc))

        try:
            config = SettingsConfig(**self.yml, yml_path=self.yml_path)
        except ValidationError as exc:
            repeat_errors = _validate_use_tests_repeat(self.yml.get("projects", {}))
            if not repeat_errors:
                raise

            raise ValidationError.from_exception_data(
                title=SettingsConfig.__name__,
                line_errors=_convert_validation_error_to_error_details(exc) + repeat_errors,
            )

        # The dataclass is frozen, so plain attribute assignment is blocked.
        settings = config.settings
        object.__setattr__(self, "acronym_scheme", settings.acronym_scheme)
        object.__setattr__(self, "source_root", settings.source_root)
        object.__setattr__(self, "projects", config.projects)