Coverage for glotter/settings.py: 99%
132 statements
« prev ^ index » next coverage.py v7.10.7, created at 2026-03-01 21:54 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2026-03-01 21:54 +0000
1import os
2from dataclasses import dataclass
3from functools import cache
4from typing import Dict, Optional
6from glotter_core.project import AcronymScheme
7from glotter_core.settings import CoreSettingsParser
8from pydantic import (
9 BaseModel,
10 Field,
11 ValidationError,
12 ValidationInfo,
13 field_validator,
14 model_validator,
15)
17from glotter.errors import get_error_details, raise_simple_validation_error, raise_validation_errors
18from glotter.project import Project
19from glotter.utils import error_and_exit, indent
@cache
def get_settings():
    """
    Return the shared Settings object.

    ``functools.cache`` memoizes the result of this zero-argument function,
    so ``Settings()`` is constructed on the first call only and every later
    call returns that same instance.
    """
    shared_settings = Settings()
    return shared_settings
class Settings:
    """
    Glotter configuration for the project rooted at the current working directory.

    Construction parses and validates the project's settings through
    ``SettingsParser``. A pydantic ``ValidationError`` is formatted and handed
    to ``error_and_exit``. Test functions can afterwards be registered per
    project type with ``add_test_mapping``.
    """

    def __init__(self):
        cwd = os.getcwd()
        self._project_root = cwd
        try:
            self._parser = SettingsParser(cwd)
        except ValidationError as exc:
            error_and_exit(_format_validate_error(exc))

        parser = self._parser
        self._projects = parser.projects
        # A falsy source root from the parser falls back to the project root.
        self._source_root = parser.source_root or cwd
        # project type -> list of registered test functions
        self._test_mappings = {}

    @property
    def projects(self):
        """Mapping of project type name to its parsed project definition."""
        return self._projects

    @property
    def project_root(self):
        """Directory that was the working directory when these settings were built."""
        return self._project_root

    @property
    def source_root(self):
        """Directory holding the sources; defaults to ``project_root``."""
        return self._source_root

    @source_root.setter
    def source_root(self, value):
        # Falsy values (e.g. None or "") reset the source root to the project root.
        self._source_root = value if value else self._project_root

    @property
    def test_mappings(self):
        """Mapping of project type to the list of test functions registered for it."""
        return self._test_mappings

    def get_test_mapping_name(self, project_type):
        """Return the ``__name__`` of each test registered for *project_type*, or ``[]``."""
        registered = self._test_mappings.get(project_type) or []
        return [test.__name__ for test in registered]

    def add_test_mapping(self, project_type, func):
        """
        Register *func* as a test for *project_type*.

        Raises:
            KeyError: if *project_type* is not one of the configured projects.
        """
        if project_type not in self._projects:
            raise KeyError(f"Project type {project_type} was not found in glotter.yml")
        self._test_mappings.setdefault(project_type, []).append(func)

    def verify_project_type(self, name):
        """Return True if the lower-cased *name* is a configured project type."""
        normalized = name.lower()
        return normalized in self.projects
def _format_validate_error(validation_error: ValidationError) -> str:
    """
    Render a pydantic ``ValidationError`` as a human-readable bullet list.

    Each error becomes a ``- a.b.c:`` line built from its location (``__root__``
    entries are skipped), followed by its message indented by four columns.
    """
    lines = []
    for err in validation_error.errors():
        dotted_path = ".".join(
            _format_location_item(part) for part in err["loc"] if part != "__root__"
        )
        lines.append(f"- {dotted_path}:")
        lines.append(indent(err["msg"], 4))

    header = "Errors found in the following items:\n"
    return header + "\n".join(lines)
97def _format_location_item(location) -> str:
98 if isinstance(location, int):
99 return f"item {location + 1}"
101 return str(location)
class SettingsConfigSettings(BaseModel):
    """
    Schema of the ``settings`` section of glotter.yml.

    Fields:
        acronym_scheme: how acronyms are handled; defaults to
            ``AcronymScheme.two_letter_limit``.
        yml_path: path of the glotter.yml file. It is injected by the parent
            model rather than read from the section, and is used to resolve a
            relative ``source_root``.
        source_root: optional source directory. A relative path is resolved
            against the directory containing ``yml_path``.
    """

    acronym_scheme: AcronymScheme = Field(AcronymScheme.two_letter_limit, validate_default=True)
    yml_path: str
    source_root: Optional[str] = None

    @field_validator("acronym_scheme", mode="before")
    @classmethod
    def get_acronym_scheme(cls, value):
        """Lower-case string input before enum coercion.

        Presumably the ``AcronymScheme`` values are lower-case; confirm.
        """
        if isinstance(value, str):
            return value.lower()

        return value

    @field_validator("source_root", mode="after")
    @classmethod
    def get_source_root(cls, value, info: ValidationInfo):
        """Return *value* as an absolute path, resolving it relative to the yml file's directory."""
        # An explicit ``source_root: null`` reaches this validator as None.
        # os.path.isabs(None) raises TypeError, which pydantic does not turn
        # into a ValidationError, so pass None through unchanged.
        if value is None or os.path.isabs(value):
            return value

        # yml_path is declared (and so validated) before source_root. It is
        # missing from info.data only if it failed validation, and that error
        # is already being reported.
        yml_path = info.data.get("yml_path")
        if yml_path is None:
            return value

        yml_dir = os.path.dirname(yml_path)
        return os.path.abspath(os.path.join(yml_dir, value))
class SettingsConfig(BaseModel):
    """
    Top-level schema of glotter.yml.

    ``yml_path`` is supplied by the caller and forwarded into the ``settings``
    section so relative paths there can be resolved. ``projects`` maps a
    project name to its ``Project`` definition.
    """

    yml_path: str
    settings: Optional[SettingsConfigSettings] = Field(None, validate_default=True)
    projects: Dict[str, Project] = {}

    @field_validator("settings", mode="before")
    @classmethod
    def get_settings(cls, value, info: ValidationInfo):
        """
        Inject ``yml_path`` into the raw ``settings`` section.

        A missing section (``None``, which is also the validated default) becomes
        a dict holding only ``yml_path``, so the section's defaults still apply.
        Non-dict input is returned untouched for normal field validation.
        """
        if value is None:
            return {"yml_path": info.data["yml_path"]}

        if isinstance(value, dict):
            return {**value, "yml_path": info.data["yml_path"]}

        return value

    @field_validator("projects", mode="before")
    @classmethod
    def get_projects(cls, value, info: ValidationInfo):
        """
        Add the configured ``acronym_scheme`` to every raw project dict.

        Returns a new dict, so the caller's parsed YAML is not mutated.
        Non-dict project entries are kept as-is so that ``Project`` validation
        reports them; they no longer stop the remaining projects from being
        processed.
        """
        if not isinstance(value, dict):
            raise_simple_validation_error(cls, "Input should be a valid dictionary", value)

        settings = info.data.get("settings")
        if settings is None:
            # "settings" failed validation. Its error is reported separately,
            # and raising KeyError here would escape pydantic's error collection.
            return value

        acronym_scheme = settings.acronym_scheme
        return {
            project_name: (
                {**item, "acronym_scheme": acronym_scheme} if isinstance(item, dict) else item
            )
            for project_name, item in value.items()
        }

    @model_validator(mode="after")
    def validate_projects(self):
        """
        Check and resolve every ``use_tests`` reference between projects.

        A ``use_tests`` entry must name an existing project. That project must
        not itself use ``use_tests``, and it must define ``tests``. Valid
        references copy the target's tests via ``set_tests``. All problems are
        collected and raised together.
        """
        projects = self.projects
        if not isinstance(projects, dict):
            return self

        projects_with_use_tests = {
            project_name: project for project_name, project in projects.items() if project.use_tests
        }

        errors = []
        for project_name, project in projects_with_use_tests.items():
            use_tests_name = project.use_tests.name
            loc = ("projects", project_name, "use_tests")

            # Make sure "use_tests" item refers to an actual project
            if use_tests_name not in projects:
                errors.append(
                    get_error_details(
                        f"Refers to a non-existent project {use_tests_name}",
                        loc=loc,
                        input=use_tests_name,
                    )
                )
            # Make sure one "use_tests" item does not refer to another "use_tests" item
            elif use_tests_name in projects_with_use_tests:
                errors.append(
                    get_error_details(
                        f'Refers to another "use_tests" project {use_tests_name}',
                        loc=loc,
                        input=use_tests_name,
                    )
                )
            # Make sure "use_tests" item refers to a project with tests
            elif not projects[use_tests_name].tests:
                errors.append(
                    get_error_details(
                        f'Refers to project {use_tests_name}, which has no "tests" item',
                        loc=loc,
                        input=use_tests_name,
                    )
                )
            # Otherwise, set the tests that the "use_tests" item refers to with the tests renamed
            else:
                project.set_tests(projects[use_tests_name])

        if errors:
            raise_validation_errors(self.__class__, errors)

        return self
@dataclass(frozen=True)
class SettingsParser(CoreSettingsParser):
    """
    Parse glotter.yml into Glotter-specific settings.

    Builds on ``CoreSettingsParser``, which from its use here exposes ``yml``
    and ``yml_path``. Adds ``acronym_scheme``, ``source_root`` and
    ``projects`` taken from a validated ``SettingsConfig``.

    Raises:
        pydantic.ValidationError: if the YAML content fails validation.
    """

    def __init__(self, project_root):
        try:
            super().__init__(project_root)
        except ValueError as exc:
            error_and_exit(str(exc))

        # Merge into one mapping instead of passing ``yml_path=`` next to
        # ``**self.yml``. That form raised TypeError ("multiple values for
        # keyword argument") when glotter.yml had a top-level ``yml_path`` key,
        # and also when ``yml`` was not a mapping. Now the computed path always
        # takes precedence, and non-mapping content is reported as a pydantic
        # ValidationError.
        raw = self.yml
        if isinstance(raw, dict):
            raw = {**raw, "yml_path": self.yml_path}
        config = SettingsConfig.model_validate(raw)

        # The dataclass is frozen, so regular attribute assignment is blocked.
        object.__setattr__(self, "acronym_scheme", config.settings.acronym_scheme)
        object.__setattr__(self, "source_root", config.settings.source_root)
        object.__setattr__(self, "projects", config.projects)