Coverage for glotter/auto_gen_test.py: 99%

186 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-08 17:57 +0000

1from functools import partial 

2from typing import Annotated, Any, Callable, ClassVar, Dict, List, Optional, Tuple 

3 

4from pydantic import ( 

5 BaseModel, 

6 Field, 

7 ValidationError, 

8 ValidationInfo, 

9 field_validator, 

10 model_validator, 

11) 

12 

13from glotter.errors import ( 

14 get_error_details, 

15 raise_simple_validation_error, 

16 raise_validation_errors, 

17 validate_str_list, 

18) 

19from glotter.utils import indent, quote 

20 

# A string-named transformation: called with the current actual and expected
# expressions, it returns the rewritten (actual, expected) pair.
TransformationScalarFuncT = Callable[[str, str], Tuple[str, str]]

# A dictionary-named transformation: called with its configured list of strings,
# then the current actual and expected expressions, it returns the rewritten pair.
TransformationDictFuncT = Callable[[List[str], str, str], Tuple[str, str]]

23 

24 

class AutoGenParam(BaseModel):
    """Object used to auto-generate a test parameter"""

    name: str = ""
    input: Optional[str] = None
    expected: Any

    @field_validator("expected")
    @classmethod
    def validate_expected(cls, value):
        """
        Validate expected value

        Accepted forms are a string, a list (checked by ``validate_str_list``), or a
        single-item dictionary whose key is ``"exec"`` (value must be a non-empty
        string) or ``"self"`` (value is not checked).

        :param value: Expected value
        :return: Original expected value
        :raises: :exc:`ValidationError` if invalid expected value
        """

        if isinstance(value, dict):
            if not value:
                raise_simple_validation_error(cls, "Too few items", value)

            if len(value) > 1:
                raise_simple_validation_error(cls, "Too many items", value)

            # The size checks above leave exactly one item to inspect
            key, item = next(iter(value.items()))
            if key == "exec":
                if not isinstance(item, str):
                    raise_simple_validation_error(
                        cls, "Input should be a valid string", item, (key,)
                    )
                if not item:
                    raise_simple_validation_error(cls, "Value must not be empty", item, (key,))
            elif key != "self":
                raise_simple_validation_error(cls, 'Invalid "expected" type', item)
        elif isinstance(value, list):
            validate_str_list(cls, value)
        elif not isinstance(value, str):
            raise_simple_validation_error(
                cls, "Input should be a valid string, list, or dictionary", value
            )

        return value

    def get_pytest_param(self) -> str:
        """
        Get pytest parameter string

        String ``input``/``expected`` values are rendered with ``quote``; any other
        value (e.g. ``None`` or a list) is rendered with its ``str()`` form.

        :return: ``pytest.param(<input>, <expected>, id=<name>),`` plus a newline if
            name is not empty, empty string otherwise
        """

        if not self.name:
            return ""

        input_param = self.input
        if isinstance(input_param, str):
            input_param = quote(input_param)

        expected_output = self.expected
        if isinstance(expected_output, str):
            expected_output = quote(expected_output)

        return f"pytest.param({input_param}, {expected_output}, id={quote(self.name)}),\n"

87 

88 

89def _append_method_to_actual(method: str, actual_var: str, expected_var) -> Tuple[str, str]: 

90 return f"{actual_var}.{method}()", expected_var 

91 

92 

93def _append_method_to_expected(method: str, actual_var: str, expected_var: str) -> Tuple[str, str]: 

94 return actual_var, f"{expected_var}.{method}()" 

95 

96 

97def _remove_chars(values: List[str], actual_var: str, expected_var: str) -> Tuple[str, str]: 

98 for value in values: 

99 actual_var += f'.replace({quote(value)}, "")' 

100 

101 return actual_var, expected_var 

102 

103 

104def _strip_chars(values: List[str], actual_var: str, expected_var: str) -> Tuple[str, str]: 

105 for value in values: 

106 actual_var += f".strip({quote(value)})" 

107 

108 return actual_var, expected_var 

109 

110 

111def _unique_sort(actual_var, expected_var): 

112 return f"sorted(set({actual_var}))", f"sorted(set({expected_var}))" 

113 

114 

class AutoGenTest(BaseModel):
    """Object used to auto-generate a test"""

    name: Annotated[str, Field(strict=True, min_length=1, pattern="^[a-zA-Z][0-9a-zA-Z_]*$")]
    requires_parameters: bool = False
    inputs: Annotated[List[str], Field(strict=True, min_length=1)] = Field(
        ["Input"], validate_default=True
    )
    # Defaults to None but is still validated (validate_default=True), so a missing
    # "params" item is reported by validate_params
    params: Annotated[List[AutoGenParam], Field(strict=True, min_length=1)] = Field(
        None, validate_default=True
    )
    transformations: List[Any] = []

    # Transformations written as a plain string. Each function takes the current
    # actual/expected expressions and returns the rewritten pair.
    SCALAR_TRANSFORMATION_FUNCS: ClassVar[Dict[str, TransformationScalarFuncT]] = {
        "strip": partial(_append_method_to_actual, "strip"),
        "splitlines": partial(_append_method_to_actual, "splitlines"),
        "lower": partial(_append_method_to_actual, "lower"),
        "any_order": _unique_sort,
        "strip_expected": partial(_append_method_to_expected, "strip"),
        "splitlines_expected": partial(_append_method_to_expected, "splitlines"),
    }
    # Transformations written as a single-item dictionary ``{name: values}``, where
    # values is checked by validate_str_list. Each function takes those values
    # followed by the current actual/expected expressions.
    DICT_TRANSFORMATION_FUNCS: ClassVar[Dict[str, TransformationDictFuncT]] = {
        "remove": _remove_chars,
        "strip": _strip_chars,
    }

    @field_validator("inputs", mode="before")
    @classmethod
    def validate_inputs(cls, values):
        """
        Validate each input

        :param values: Inputs to validate
        :return: Original inputs
        :raises: :exc:`ValidationError` if input invalid
        """

        validate_str_list(cls, values)
        return values

    @field_validator("params", mode="before")
    @classmethod
    def validate_params(cls, values, info: ValidationInfo):
        """
        Validate each parameter

        :param values: Parameters to validate
        :param info: Validation info; ``info.data`` holds the test item's fields that
            were already validated (e.g. ``requires_parameters``)
        :return: Original parameters
        :raises: :exc:`ValidationError` if parameters are not a list, or if the project
            requires parameters but a parameter is not a dictionary or has no input, no
            name, or an empty name. Also raised if no expected output
        """

        # A missing "params" item arrives here as None (the validated default).
        # Reject anything that is not a list up front, as validate_transformation
        # does, instead of crashing with a TypeError in enumerate() below.
        if not isinstance(values, list):
            raise_simple_validation_error(cls, "Input should be a valid list", values)

        # The per-parameter requirements below only apply when parameters are required
        if not info.data.get("requires_parameters"):
            return values

        errors = []
        field_is_required = "Field is required when parameters required"
        for index, value in enumerate(values):
            if not isinstance(value, dict):
                errors.append(
                    get_error_details("Input should be a valid dictionary", (index,), value)
                )
                continue

            if "name" not in value:
                errors.append(get_error_details(field_is_required, (index, "name"), value))
            elif isinstance(value["name"], str) and not value["name"]:
                errors.append(
                    get_error_details(
                        "Value must not be empty when parameters required",
                        (index, "name"),
                        value,
                    )
                )

            if "input" not in value:
                errors.append(get_error_details(field_is_required, (index, "input"), value))

            if "expected" not in value:
                errors.append(get_error_details(field_is_required, (index, "expected"), value))

        if errors:
            # This is a "before" validator, so raising here happens before pydantic
            # validates each item as an AutoGenParam. Collect those inner errors too,
            # so that everything is reported at once.
            for index, value in enumerate(values):
                if not isinstance(value, dict):
                    continue

                try:
                    AutoGenParam.model_validate(value)
                except ValidationError as exc:
                    for err in exc.errors():
                        loc = (index,) + tuple(err.get("loc", ()))
                        msg = err.get("msg") or str(err.get("type", "value_error"))
                        input_val = err.get("input", value)
                        errors.append(get_error_details(msg, loc, input_val))

            raise_validation_errors(cls, errors)

        return values

    @field_validator("transformations", mode="before")
    @classmethod
    def validate_transformation(cls, values):
        """
        Validate each transformation

        A transformation is either a string naming one of
        ``SCALAR_TRANSFORMATION_FUNCS``, or a single-item dictionary whose key names
        one of ``DICT_TRANSFORMATION_FUNCS`` and whose value passes
        ``validate_str_list``.

        :param values: Transformations to validate
        :return: Original values
        :raises: :exc:`ValidationError` if Invalid transformation
        """

        if not isinstance(values, list):
            raise_simple_validation_error(cls, "Input should be a valid list", values)

        errors = []
        for index, value in enumerate(values):
            if isinstance(value, str):
                if value not in cls.SCALAR_TRANSFORMATION_FUNCS:
                    errors.append(
                        get_error_details(f'Invalid transformation "{value}"', (index,), value)
                    )
            elif isinstance(value, dict):
                # Multi-key dictionaries used to crash with a TypeError in str(*value)
                if len(value) > 1:
                    errors.append(get_error_details("Too many items", (index,), value))
                    continue

                # An empty dictionary yields "" and is reported as an invalid transformation
                key = str(next(iter(value), ""))
                if key not in cls.DICT_TRANSFORMATION_FUNCS:
                    errors.append(
                        get_error_details(f'Invalid transformation "{key}"', (index,), value)
                    )
                else:
                    errors += validate_str_list(cls, value[key], (index, key), raise_exc=False)
            else:
                errors.append(
                    get_error_details(
                        "Input should be a valid string or dictionary", (index,), value
                    )
                )

        if errors:
            raise_validation_errors(cls, errors)

        return values

    def transform_vars(self) -> Tuple[str, str]:
        """
        Transform variables using the specified transformations

        Transformations are applied in order, each one rewriting the expressions
        produced by the previous one, starting from ``actual`` and ``expected``.

        :return: Transformed actual and expected variables
        """

        actual_var = "actual"
        expected_var = "expected"
        for transformation in self.transformations:
            if isinstance(transformation, str):
                transform_func = self.SCALAR_TRANSFORMATION_FUNCS[transformation]
                actual_var, expected_var = transform_func(actual_var, expected_var)
            else:
                # Dictionary transformations hold a single {name: values} item
                key, item = next(iter(transformation.items()))
                actual_var, expected_var = self.DICT_TRANSFORMATION_FUNCS[key](
                    item, actual_var, expected_var
                )

        return actual_var, expected_var

    def get_pytest_params(self) -> str:
        """
        Get pytest parameters

        :return: ``@pytest.mark.parametrize`` decorator code if parameters are
            required, empty string otherwise
        """

        if not self.requires_parameters:
            return ""

        # Each parameter line is indented to sit inside the list below. strip() then
        # removes the first line's indent (which the template supplies) and the
        # trailing newline.
        pytest_params = "".join(
            indent(param.get_pytest_param(), 8) for param in self.params
        ).strip()
        return f"""\
@pytest.mark.parametrize(
    ("in_params", "expected"),
    [
        {pytest_params}
    ]
)
"""

    def get_test_function_and_run(self, project_name_underscores: str) -> str:
        """
        Get test function and run command

        :param project_name_underscores: Project name with underscores between each word
        :return: Test function header and the line that runs the project
        """

        func_params = ""
        run_param = ""
        if self.requires_parameters:
            func_params = "in_params, expected, "
            run_param = "params=in_params"

        return f"""\
def test_{self.name}({func_params}{project_name_underscores}):
    actual = {project_name_underscores}.run({run_param})
"""

    def get_expected_output(self, project_name_underscores: str) -> str:
        """
        Get test code that gets the expected output

        Nothing is generated when parameters are required, because ``expected`` is
        then supplied by the pytest parameters.

        :param project_name_underscores: Project name with underscores between each word
        :return: Test code that gets the expected output
        """

        if self.requires_parameters:
            return ""

        expected_output = self.params[0].expected
        if isinstance(expected_output, str):
            expected_output = quote(expected_output)
        elif isinstance(expected_output, dict):
            return _get_expected_file(project_name_underscores, expected_output)

        return f"expected = {expected_output}\n"

    def generate_test(self, project_name_underscores: str) -> str:
        """
        Generate test code

        :param project_name_underscores: Project name with underscores between each word
        :return: Test code
        """

        test_code = "@project_test(PROJECT_NAME)\n"
        test_code += self.get_pytest_params()
        test_code += self.get_test_function_and_run(project_name_underscores)
        test_code += indent(self.get_expected_output(project_name_underscores), 4)
        actual_var, expected_var = self.transform_vars()
        # The first parameter's expected value decides whether an item-by-item
        # list comparison or a plain equality assertion is generated
        test_code += indent(_get_assert(actual_var, expected_var, self.params[0].expected), 4)
        return test_code

354 

355 

356def _get_expected_file(project_name_underscores: str, expected_output: Dict[str, str]) -> str: 

357 if "exec" in expected_output: 

358 script = quote(expected_output["exec"]) 

359 return f"expected = {project_name_underscores}.exec({script})\n" 

360 

361 test_code = f"""\ 

362with open({project_name_underscores}.full_path, "r", encoding="utf-8") as file: 

363 expected = file.read() 

364""" 

365 

366 if "self" in expected_output: 366 ↛ 375line 366 didn't jump to line 375 because the condition on line 366 was always true

367 test_code += """\ 

368diff_len = len(actual) - len(expected) 

369if diff_len > 0: 

370 expected += "\\n" 

371elif diff_len < 0: 

372 actual += "\\n" 

373""" 

374 

375 return test_code 

376 

377 

378def _get_assert(actual_var: str, expected_var: str, expected_output) -> str: 

379 if isinstance(expected_output, list): 

380 return f"""\ 

381actual_list = {actual_var} 

382expected_list = {expected_var} 

383assert len(actual_list) == len(expected_list), "Length not equal" 

384for index in range(len(expected_list)): 

385 assert actual_list[index] == expected_list[index], f"Item {{index + 1}} is not equal" 

386""" 

387 

388 test_code = "" 

389 if actual_var != "actual": 

390 test_code += f"actual = {actual_var}\n" 

391 

392 if expected_var != "expected": 

393 test_code += f"expected = {expected_var}\n" 

394 

395 return f"{test_code}assert actual == expected\n" 

396 

397 

class AutoGenUseTests(BaseModel):
    """Object used to specify what tests to use"""

    name: str
    search: Annotated[str, Field(strict=True, pattern="^[0-9a-zA-Z_]*$")] = ""
    replace: Annotated[str, Field(strict=True, pattern="^[0-9a-zA-Z_]*$")] = ""

    @model_validator(mode="before")
    @classmethod
    def validate_search_with_replace(cls, values):
        """
        Validate that if either search or replace is specified, both must be specified

        :param values: Raw values to validate
        :return: Original values
        :raises: :exc:`ValidationError` if only one of search or replace is specified
        """

        # In "before" mode the raw input may be of any type. Only dictionaries can be
        # checked for keys, so pass anything else through to pydantic's own
        # validation instead of applying `in` to it (a substring test on a str, a
        # TypeError on None).
        if not isinstance(values, dict):
            return values

        has_search = "search" in values
        has_replace = "replace" in values

        if has_search and not has_replace:
            raise_simple_validation_error(
                cls, '"search" item specified without "replace" item', values
            )

        if has_replace and not has_search:
            raise_simple_validation_error(
                cls, '"replace" item specified without "search" item', values
            )

        return values