Coverage for glotter/settings.py: 100%
159 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-12 02:25 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-12 02:25 +0000
1# pylint hates pydantic
2# pylint: disable=E0213,E0611
3from typing import Optional, Dict
4import os
5from warnings import warn
7import yaml
8from pydantic import BaseModel, validator, root_validator, ValidationError
9from pydantic.error_wrappers import ErrorWrapper
11from glotter.project import Project, AcronymScheme
12from glotter.singleton import Singleton
13from glotter.utils import error_and_exit, indent
class Settings(metaclass=Singleton):
    """Glotter configuration for the project in the current working directory.

    Created through the ``Singleton`` metaclass (presumably one shared
    instance per process -- confirm in ``glotter.singleton``). On construction
    the ``.glotter.yml`` under ``os.getcwd()`` is parsed by ``SettingsParser``;
    a pydantic ``ValidationError`` is formatted and handed to
    ``error_and_exit``. Test functions are registered per project type with
    ``add_test_mapping``.
    """

    def __init__(self):
        root = os.getcwd()
        self._project_root = root
        try:
            self._parser = SettingsParser(root)
        except ValidationError as err:
            error_and_exit(_format_validate_error(err))

        parser = self._parser
        self._projects = parser.projects
        # Without an explicit source_root setting, sources live in the project root.
        self._source_root = parser.source_root or root
        # project type -> list of registered test functions
        self._test_mappings = {}

    @property
    def projects(self):
        """Mapping of project-type name to validated ``Project``."""
        return self._projects

    @property
    def project_root(self):
        """Working directory the settings were loaded from."""
        return self._project_root

    @property
    def source_root(self):
        """Directory containing the sample sources."""
        return self._source_root

    @source_root.setter
    def source_root(self, value):
        # Any falsy value (None, "") resets to the project root.
        self._source_root = value if value else self._project_root

    @property
    def test_mappings(self):
        """Mapping of project type to the list of registered test functions."""
        return self._test_mappings

    def get_test_mapping_name(self, project_type):
        """Return the ``__name__`` of each test function registered for ``project_type``.

        Returns an empty list when nothing is registered.
        """
        registered = self._test_mappings.get(project_type) or []
        return [test_func.__name__ for test_func in registered]

    def add_test_mapping(self, project_type, func):
        """Register ``func`` as a test for ``project_type``.

        Raises:
            KeyError: if ``project_type`` is not one of the configured projects.
        """
        if project_type not in self._projects:
            raise KeyError(f"Project type {project_type} was not found in glotter.yml")
        self._test_mappings.setdefault(project_type, []).append(func)

    def verify_project_type(self, name):
        """Return whether the lower-cased ``name`` is a configured project type."""
        return name.lower() in self.projects
def _format_validate_error(validation_error: ValidationError) -> str:
    """Render a pydantic ``ValidationError`` as a human-readable bullet list.

    Each error contributes a ``- a -> b -> c:`` header built from its location.
    Pydantic's synthetic ``"__root__"`` location entries are omitted, and each
    entry is formatted by ``_format_location_item``. The header is followed by
    the error message passed through ``indent(..., 4)``.
    """
    lines = []
    for err in validation_error.errors():
        path = " -> ".join(
            _format_location_item(part) for part in err["loc"] if part != "__root__"
        )
        lines.append(f"- {path}:")
        lines.append(indent(err["msg"], 4))

    return "Errors found in the following items:\n" + "\n".join(lines)
83def _format_location_item(location) -> str:
84 if isinstance(location, int):
85 return f"item {location + 1}"
87 return str(location)
class SettingsConfigSettings(BaseModel):
    """Schema for the ``settings`` section of ``.glotter.yml``.

    ``yml_path`` is injected by ``SettingsConfig`` rather than read from the
    file; it is used to resolve a relative ``source_root``.
    """

    acronym_scheme: AcronymScheme = AcronymScheme.two_letter_limit
    yml_path: str
    source_root: Optional[str] = None

    @validator("acronym_scheme", pre=True)
    def get_acronym_scheme(cls, value):
        """Lower-case string values before they are converted to ``AcronymScheme``."""
        if isinstance(value, str):
            return value.lower()

        return value

    @validator("source_root")
    def get_source_root(cls, value, values):
        """Resolve ``source_root`` to an absolute path.

        Absolute paths are returned unchanged. Relative paths are resolved
        against the directory containing the ``.glotter.yml`` file. ``None``
        means "not set" and is returned as-is.
        """
        # pydantic v1 still runs non-``pre`` validators on an explicit ``None``
        # for ``Optional`` fields (e.g. ``source_root:`` with no value). Without
        # this guard ``os.path.isabs(None)`` raises TypeError, which surfaces as
        # a confusing validation error.
        if value is None:
            return None

        if os.path.isabs(value):
            return value

        # ``yml_path`` is absent from ``values`` only if it failed its own
        # validation. That error is already reported, so leave the value
        # untouched instead of raising an uncaught KeyError.
        yml_path = values.get("yml_path")
        if yml_path is None:
            return value

        yml_dir = os.path.dirname(yml_path)
        return os.path.abspath(os.path.join(yml_dir, value))
class SettingsConfig(BaseModel):
    """Top-level schema for the contents of ``.glotter.yml``.

    ``yml_path`` is not read from the YAML file; ``SettingsParser`` supplies it
    so nested models can resolve paths relative to the file's directory.
    """

    yml_path: str
    settings: Optional[SettingsConfigSettings] = None
    projects: Dict[str, Project] = {}

    @validator("settings", pre=True, always=True)
    def get_settings(cls, value, values):
        """Inject ``yml_path`` into the raw ``settings`` mapping.

        A missing ``settings`` section becomes a mapping holding only
        ``yml_path``, so the ``SettingsConfigSettings`` defaults always apply.
        Non-dict values are passed through for ``SettingsConfigSettings`` to
        validate.
        """
        # ``yml_path`` is absent only if it failed its own validation; ``.get``
        # turns that into a settings validation error instead of a KeyError.
        yml_path = values.get("yml_path")
        if value is None:
            return {"yml_path": yml_path}

        if isinstance(value, dict):
            return {**value, "yml_path": yml_path}

        return value

    @validator("projects", pre=True)
    def get_projects(cls, value, values):
        """Propagate the configured acronym scheme into every raw project mapping.

        Returns a new mapping; the parsed YAML passed in is not modified.

        Raises:
            ValueError: if ``projects`` is not a dict.
        """
        if not isinstance(value, dict):
            raise ValueError("value is not a valid dict")

        settings = values.get("settings")
        if settings is not None:
            acronym_scheme = settings.acronym_scheme
        else:
            # ``settings`` failed validation (that error is reported separately);
            # fall back to the field's declared default rather than raising an
            # uncaught KeyError.
            acronym_scheme = SettingsConfigSettings.__fields__["acronym_scheme"].default

        # Non-dict items are left for ``Project`` validation to reject, but
        # every other project still receives the acronym scheme.
        return {
            project_name: (
                {**item, "acronym_scheme": acronym_scheme}
                if isinstance(item, dict)
                else item
            )
            for project_name, item in value.items()
        }

    @root_validator()
    def validate_projects(cls, values):
        """Check every ``use_tests`` reference and copy the referenced tests.

        A ``use_tests`` item must name an existing project. That project must
        not itself use ``use_tests``, and it must define ``tests``. Valid
        references have their tests copied via ``Project.set_tests``.

        Raises:
            ValidationError: with one entry per bad reference, located at
                ``projects -> <name> -> use_tests``.
        """
        projects = values.get("projects")
        if not isinstance(projects, dict):
            # ``projects`` already failed validation; nothing to cross-check.
            return values

        projects_with_use_tests = {
            project_name: project
            for project_name, project in projects.items()
            if project.use_tests
        }

        errors = []
        for project_name, project in projects_with_use_tests.items():
            use_tests_name = project.use_tests.name
            if use_tests_name not in projects:
                message = f"refers to a non-existent project {use_tests_name}"
            elif use_tests_name in projects_with_use_tests:
                # Chained "use_tests" references are not allowed.
                message = f'refers to another "use_tests" project {use_tests_name}'
            elif not projects[use_tests_name].tests:
                message = f'refers to project {use_tests_name}, which has no "tests" item'
            else:
                # Valid reference: copy the referenced project's tests.
                project.set_tests(projects[use_tests_name])
                continue

            errors.append(
                ErrorWrapper(
                    ValueError(message),
                    loc=("projects", project_name, "use_tests"),
                )
            )

        if errors:
            # ValidationError subclasses ValueError in pydantic v1, so it is
            # wrapped under "__root__" and flattened; the "__root__" entry is
            # dropped when the error is formatted for display.
            raise ValidationError(errors, model=cls)

        return values
class SettingsParser:
    """Locate, load and validate a project's ``.glotter.yml``.

    All parsing happens eagerly in ``__init__``; results are exposed through
    read-only properties.
    """

    def __init__(self, project_root):
        """Parse the ``.glotter.yml`` found under ``project_root``.

        If no file is found, a warning is issued and ``yml_path`` is set to
        ``project_root`` itself. Validation then runs on an empty config, so
        every setting takes its default.

        Raises:
            ValidationError: (pydantic) propagated from ``SettingsConfig`` when
                the file contents are invalid.
        """
        self._project_root = project_root
        self._acronym_scheme = None
        self._projects = None
        self._source_root = None
        self._yml_path = self._locate_yml()

        if self._yml_path is None:
            self._yml_path = project_root
            warn(f'.glotter.yml not found in directory "{project_root}"')
            yml = None
        else:
            yml = self._parse_yml()

        # An empty file loads as None; treat it like an empty mapping.
        if yml is None:
            yml = {}

        if not isinstance(yml, dict):
            error_and_exit(".glotter.yml does not contain a dict")

        # Merge into one dict first: passing ``yml_path=`` alongside ``**yml``
        # raised TypeError ("multiple values") if the file had a top-level
        # ``yml_path`` key. The real path always wins.
        config = SettingsConfig(**{**yml, "yml_path": self._yml_path})
        self._acronym_scheme = config.settings.acronym_scheme
        self._source_root = config.settings.source_root
        self._projects = config.projects

    @property
    def project_root(self):
        """Directory the search for ``.glotter.yml`` started from."""
        return self._project_root

    @property
    def yml_path(self):
        """Path of the parsed ``.glotter.yml``, or ``project_root`` if none was found."""
        return self._yml_path

    @property
    def source_root(self):
        """Resolved ``settings.source_root`` (``None`` when not configured)."""
        return self._source_root

    @property
    def acronym_scheme(self):
        """Configured ``AcronymScheme``."""
        return self._acronym_scheme

    @property
    def projects(self):
        """Mapping of project-type name to validated ``Project``."""
        return self._projects

    def _parse_yml(self):
        """Load the YAML file at ``yml_path`` with ``yaml.safe_load``."""
        with open(self._yml_path, "r", encoding="utf-8") as f:
            contents = f.read()

        return yaml.safe_load(contents)

    def _locate_yml(self):
        """Return the absolute path of the first ``.glotter.yml`` under ``project_root``.

        The tree is walked top-down, so a file directly in ``project_root``
        wins. Subdirectories are visited in sorted order, so the result is
        deterministic when several directories contain a ``.glotter.yml``.
        Returns ``None`` if no file is found.
        """
        for root, dirs, files in os.walk(self._project_root):
            if ".glotter.yml" in files:
                return os.path.join(os.path.abspath(root), ".glotter.yml")
            # Sorting ``dirs`` in place fixes the order os.walk descends in.
            dirs.sort()

        return None