Coverage for glotter/settings.py: 100%
159 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-09-13 19:09 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-09-13 19:09 +0000
1import os
2from typing import Dict, Optional
3from warnings import warn
5import yaml
6from pydantic import BaseModel, ValidationError, root_validator, validator
7from pydantic.error_wrappers import ErrorWrapper
9from glotter.project import AcronymScheme, Project
10from glotter.singleton import Singleton
11from glotter.utils import error_and_exit, indent
class Settings(metaclass=Singleton):
    """Glotter's global settings, built from the ``.glotter.yml`` found under the cwd.

    The configuration is parsed once in ``__init__``; a schema validation failure
    is reported through ``error_and_exit``. Test functions are registered per
    project type at runtime via ``add_test_mapping``.
    """

    def __init__(self):
        self._project_root = os.getcwd()
        try:
            self._parser = SettingsParser(self._project_root)
        except ValidationError as error:
            error_and_exit(_format_validate_error(error))

        self._projects = self._parser.projects
        # Routed through the setter, which falls back to the project root when
        # no source root is configured.
        self.source_root = self._parser.source_root
        self._test_mappings = {}

    @property
    def projects(self):
        """Mapping of project name to project definition, as parsed from the settings file."""
        return self._projects

    @property
    def project_root(self):
        """Working directory at the time these settings were constructed."""
        return self._project_root

    @property
    def source_root(self):
        """Root directory for sources; ``project_root`` unless configured otherwise."""
        return self._source_root

    @source_root.setter
    def source_root(self, value):
        # Any falsy value (None, "") resets the source root to the project root.
        self._source_root = value or self._project_root

    @property
    def test_mappings(self):
        """Registered test functions, keyed by project type."""
        return self._test_mappings

    def get_test_mapping_name(self, project_type):
        """Return the ``__name__`` of each test function registered for *project_type*.

        A project type with no registrations yields an empty list.
        """
        registered = self._test_mappings.get(project_type) or ()
        return [test_func.__name__ for test_func in registered]

    def add_test_mapping(self, project_type, func):
        """Register *func* as a test for *project_type*.

        Raises:
            KeyError: if *project_type* is not one of the configured projects.
        """
        if project_type not in self._projects:
            raise KeyError(f"Project type {project_type} was not found in glotter.yml")
        self._test_mappings.setdefault(project_type, []).append(func)

    def verify_project_type(self, name):
        """Return whether the lowercased *name* is a key of ``projects``."""
        return name.lower() in self.projects
def _format_validate_error(validation_error: ValidationError) -> str:
    """Render a pydantic ``ValidationError`` as a human-readable bullet list.

    Each error becomes a ``- loc -> path:`` line, with ``__root__`` location
    segments dropped, followed by its message indented by four spaces.
    """
    header = "Errors found in the following items:\n"
    lines = []
    for error in validation_error.errors():
        path = " -> ".join(
            _format_location_item(segment)
            for segment in error["loc"]
            if segment != "__root__"
        )
        lines.append(f"- {path}:")
        lines.append(indent(error["msg"], 4))

    return header + "\n".join(lines)
81def _format_location_item(location) -> str:
82 if isinstance(location, int):
83 return f"item {location + 1}"
85 return str(location)
class SettingsConfigSettings(BaseModel):
    """Schema of the ``settings`` section of the settings file.

    Fields:
        acronym_scheme: scheme used for project acronyms (matched case-insensitively).
        yml_path: path of the settings file; relative ``source_root`` values
            are resolved against its directory.
        source_root: optional source directory, normalized to an absolute path.
    """

    acronym_scheme: AcronymScheme = AcronymScheme.two_letter_limit
    yml_path: str
    source_root: Optional[str] = None

    @validator("acronym_scheme", pre=True)
    def get_acronym_scheme(cls, raw):
        # Lowercase string input so the enum lookup is case-insensitive.
        return raw.lower() if isinstance(raw, str) else raw

    @validator("source_root")
    def get_source_root(cls, raw, values):
        # pydantic v1 skips non-pre validators for a None Optional value, so
        # ``raw`` is a string here. Absolute paths pass through unchanged.
        if not os.path.isabs(raw):
            settings_dir = os.path.dirname(values["yml_path"])
            raw = os.path.abspath(os.path.join(settings_dir, raw))
        return raw
class SettingsConfig(BaseModel):
    """Top-level schema of the settings file.

    Fields:
        yml_path: path of the settings file (supplied by the parser).
        settings: the ``settings`` section; always materialized by ``get_settings``.
        projects: mapping of project name to ``Project`` definition.
    """

    yml_path: str
    settings: Optional[SettingsConfigSettings] = None
    projects: Dict[str, Project] = {}

    @validator("settings", pre=True, always=True)
    def get_settings(cls, value, values):
        """Inject ``yml_path`` into the ``settings`` section, creating it if absent."""
        if value is None:
            return {"yml_path": values["yml_path"]}

        if isinstance(value, dict):
            return {**value, "yml_path": values["yml_path"]}

        # Non-dict values are passed through unchanged for pydantic to validate.
        return value

    @validator("projects", pre=True)
    def get_projects(cls, value, values):
        """Propagate the configured ``acronym_scheme`` into every project mapping.

        Returns a new dict so the caller's parsed YAML is not mutated.

        Raises:
            ValueError: if ``projects`` is not a mapping.
        """
        if not isinstance(value, dict):
            raise ValueError("value is not a valid dict")

        # In pydantic v1, ``values`` only holds fields that validated successfully.
        # If ``settings`` was invalid, its error is already being collected; validate
        # the projects unchanged instead of crashing with a KeyError.
        settings = values.get("settings")
        if settings is None:
            return value

        acronym_scheme = settings.acronym_scheme
        # Non-dict items are left untouched (and rejected by Project validation);
        # they must not stop the scheme from reaching the remaining projects.
        return {
            project_name: (
                {**item, "acronym_scheme": acronym_scheme} if isinstance(item, dict) else item
            )
            for project_name, item in value.items()
        }

    @root_validator()
    def validate_projects(cls, values):
        """Check every project's ``use_tests`` reference and resolve the valid ones.

        A ``use_tests`` item must name an existing project that has a ``tests``
        item and is not itself a ``use_tests`` project. For each valid reference,
        the referenced project is passed to ``project.set_tests``.

        Raises:
            ValidationError: with one error per invalid ``use_tests`` reference.
        """
        projects = values.get("projects")
        # "projects" is missing when it failed field validation; nothing to check.
        if not isinstance(projects, dict):
            return values

        projects_with_use_tests = {
            project_name: project for project_name, project in projects.items() if project.use_tests
        }

        errors = []
        for project_name, project in projects_with_use_tests.items():
            use_tests_name = project.use_tests.name
            loc = ("projects", project_name, "use_tests")

            # Make sure "use_tests" item refers to an actual project
            if use_tests_name not in projects:
                message = f"refers to a non-existent project {use_tests_name}"
            # Make sure one "use_tests" item does not refer to another "use_tests" item
            elif use_tests_name in projects_with_use_tests:
                message = f'refers to another "use_tests" project {use_tests_name}'
            # Make sure "use_tests" item refers to a project with tests
            elif not projects[use_tests_name].tests:
                message = f'refers to project {use_tests_name}, which has no "tests" item'
            # Otherwise, set the tests that the "use_tests" item refers to
            else:
                project.set_tests(projects[use_tests_name])
                continue

            errors.append(ErrorWrapper(ValueError(message), loc=loc))

        if errors:
            raise ValidationError(errors, model=cls)

        return values
class SettingsParser:
    """Locate, load, and validate ``.glotter.yml`` under a project root.

    All work happens in ``__init__``; results are exposed as read-only
    properties. Schema problems surface as the ``ValidationError`` raised by
    ``SettingsConfig``.
    """

    def __init__(self, project_root):
        self._project_root = project_root
        self._acronym_scheme = None
        self._projects = None
        self._source_root = None
        self._yml_path = self._locate_yml()

        yml = None
        if self._yml_path is not None:
            yml = self._parse_yml()
        else:
            # No settings file: record the project root as yml_path, since
            # SettingsConfig requires a string path, and warn the user.
            self._yml_path = project_root
            warn(f'.glotter.yml not found in directory "{project_root}"')

        # An empty YAML document loads as None; treat it like an empty mapping.
        if yml is None:
            yml = {}

        if not isinstance(yml, dict):
            error_and_exit(".glotter.yml does not contain a dict")

        # Merge yml_path into the mapping rather than passing it as a separate
        # keyword. A top-level "yml_path" key in the YAML would otherwise raise
        # TypeError ("got multiple values for keyword argument"), which is not a
        # ValidationError. The parser's path always takes precedence.
        config = SettingsConfig(**{**yml, "yml_path": self._yml_path})
        self._acronym_scheme = config.settings.acronym_scheme
        self._source_root = config.settings.source_root
        self._projects = config.projects

    @property
    def project_root(self):
        """Directory the search for ``.glotter.yml`` started from."""
        return self._project_root

    @property
    def yml_path(self):
        """Path of the settings file, or the project root if none was found."""
        return self._yml_path

    @property
    def source_root(self):
        """Absolute configured source root, or None if not configured."""
        return self._source_root

    @property
    def acronym_scheme(self):
        """Acronym scheme from the ``settings`` section."""
        return self._acronym_scheme

    @property
    def projects(self):
        """Validated mapping of project name to ``Project``."""
        return self._projects

    def _parse_yml(self):
        """Read the settings file and return its contents via ``yaml.safe_load``."""
        with open(self._yml_path, "r", encoding="utf-8") as f:
            contents = f.read()

        return yaml.safe_load(contents)

    def _locate_yml(self):
        """Return the absolute path of the first ``.glotter.yml`` under the project root.

        The tree is walked top-down, so the project root itself is checked
        first. Among nested directories, which file wins depends on the
        filesystem's listing order. Returns None when no file is found.
        """
        for root, _, files in os.walk(self._project_root):
            if ".glotter.yml" in files:
                path = os.path.abspath(root)
                return os.path.join(path, ".glotter.yml")

        return None