Coverage for glotter / auto_gen_test.py: 99%

227 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-10 16:13 +0000

1from functools import partial 

2from typing import Annotated, Any, Callable, ClassVar, Dict, List, Optional, Tuple 

3 

4from pydantic import ( 

5 BaseModel, 

6 Field, 

7 ValidationError, 

8 ValidationInfo, 

9 field_validator, 

10 model_validator, 

11) 

12 

13from glotter.errors import ( 

14 get_error_details, 

15 raise_simple_validation_error, 

16 raise_validation_errors, 

17 validate_str_dict, 

18 validate_str_list, 

19) 

20from glotter.utils import indent, quote 

21 

# Every transformation returns the rewritten (actual expression, expected expression) pair
_TransformedVarsT = Tuple[str, str]

# Scalar transformation: (actual_var, expected_var) -> (actual_var, expected_var)
TransformationScalarFuncT = Callable[[str, str], _TransformedVarsT]
# Dictionary transformation: (values, actual_var, expected_var) -> (actual_var, expected_var)
TransformationDictFuncT = Callable[[List[str], str, str], _TransformedVarsT]

24 

25 

class AutoGenParam(BaseModel):
    """Object used to auto-generate a test parameter"""

    name: str = ""
    input: Optional[str] = None
    expected: Any

    @field_validator("expected")
    @classmethod
    def validate_expected(cls, value):
        """
        Validate expected value

        Accepted forms are a string, a list of strings, or a single-item dictionary
        whose key is "exec" or "string" (with a non-empty string value) or "self".

        :param value: Expected value
        :return: Original expected value
        :raises: :exc:`ValidationError` if invalid expected value
        """

        if isinstance(value, dict):
            if not value:
                raise_simple_validation_error(cls, "Too few items", value)

            if len(value) > 1:
                raise_simple_validation_error(cls, "Too many items", value)

            # Exactly one item remains at this point
            ((key, item),) = value.items()
            if key in ("exec", "string"):
                if not isinstance(item, str):
                    raise_simple_validation_error(
                        cls, "Input should be a valid string", item, (key,)
                    )

                if not item:
                    raise_simple_validation_error(cls, "Value must not be empty", item, (key,))
            elif key != "self":
                raise_simple_validation_error(cls, 'Invalid "expected" type', item)
        elif isinstance(value, list):
            validate_str_list(cls, value)
        elif not isinstance(value, str):
            raise_simple_validation_error(
                cls, "Input should be a valid string, list, or dictionary", value
            )

        return value

    def get_pytest_param(self) -> str:
        """
        Get pytest parameter string

        String inputs and string expected values are quoted. An expected value of the
        form ``{"string": name}`` is replaced by its constant variable name. Any other
        expected value is rendered with its default string form.

        :return: pytest parameter string if name is not empty, empty string otherwise
        """

        if not self.name:
            return ""

        input_param = self.input
        if isinstance(input_param, str):
            input_param = quote(input_param)

        expected_output = self.expected
        if isinstance(expected_output, str):
            expected_output = quote(expected_output)
        elif isinstance(expected_output, dict) and "string" in expected_output:
            expected_output = self.get_constant_variable_name()

        return f"pytest.param({input_param}, {expected_output}, id={quote(self.name)}),\n"

    def get_constant_variable_name(self) -> str:
        """
        Get constant variable name

        :return: upper-cased name referenced by a ``{"string": name}`` expected value,
            empty string for any other expected value
        """

        if isinstance(self.expected, dict) and "string" in self.expected:
            return self.expected["string"].upper()

        return ""

103 

104 

105def _append_method_to_actual(method: str, actual_var: str, expected_var) -> Tuple[str, str]: 

106 return f"{actual_var}.{method}()", expected_var 

107 

108 

109def _append_method_to_expected(method: str, actual_var: str, expected_var: str) -> Tuple[str, str]: 

110 return actual_var, f"{expected_var}.{method}()" 

111 

112 

113def _remove_chars(values: List[str], actual_var: str, expected_var: str) -> Tuple[str, str]: 

114 for value in values: 

115 actual_var += f'.replace({quote(value)}, "")' 

116 

117 return actual_var, expected_var 

118 

119 

120def _strip_chars(values: List[str], actual_var: str, expected_var: str) -> Tuple[str, str]: 

121 for value in values: 

122 actual_var += f".strip({quote(value)})" 

123 

124 return actual_var, expected_var 

125 

126 

127def _unique_sort(actual_var, expected_var): 

128 return f"sorted(set({actual_var}))", f"sorted(set({expected_var}))" 

129 

130 

class AutoGenTest(BaseModel):
    """Object used to auto-generate a test"""

    # The name is used to build the generated ``test_<name>`` function name
    VALID_NAME_REGEX: ClassVar[str] = "^[a-zA-Z][0-9a-zA-Z_]*$"

    name: Annotated[str, Field(strict=True, min_length=1, pattern=VALID_NAME_REGEX)]
    requires_parameters: bool = False
    inputs: Annotated[List[str], Field(strict=True, min_length=1)] = Field(
        ["Input"], validate_default=True
    )
    # A None default combined with validate_default=True means a missing "params"
    # item fails the strict list check instead of being silently accepted
    params: Annotated[List[AutoGenParam], Field(strict=True, min_length=1)] = Field(
        None, validate_default=True
    )
    strings: Dict[str, str] = {}
    transformations: List[Any] = []
    repeat: int = 1

    # Transformations written as plain strings
    SCALAR_TRANSFORMATION_FUNCS: ClassVar[Dict[str, TransformationScalarFuncT]] = {
        "strip": partial(_append_method_to_actual, "strip"),
        "splitlines": partial(_append_method_to_actual, "splitlines"),
        "lower": partial(_append_method_to_actual, "lower"),
        "any_order": _unique_sort,
        "strip_expected": partial(_append_method_to_expected, "strip"),
        "splitlines_expected": partial(_append_method_to_expected, "splitlines"),
    }
    # Transformations written as single-item dictionaries mapping a name to a list of strings
    DICT_TRANSFORMATION_FUNCS: ClassVar[Dict[str, TransformationDictFuncT]] = {
        "remove": _remove_chars,
        "strip": _strip_chars,
    }

    @field_validator("inputs", mode="before")
    @classmethod
    def validate_inputs(cls, values):
        """
        Validate each input

        :param values: Inputs to validate
        :return: Original inputs
        :raises: :exc:`ValidationError` if input invalid
        """

        validate_str_list(cls, values)
        return values

    @field_validator("params", mode="before")
    @classmethod
    def validate_params(cls, values, info: ValidationInfo):
        """
        Validate each parameter

        :param values: Parameters to validate
        :param info: Validation info; ``info.data`` holds the fields validated so far
        :return: Original parameters
        :raises: :exc:`ValidationError` if project requires parameters but no input, no name,
            or empty name. Also, raised if no expected output
        """

        # A missing "params" item arrives here as the None default, and a malformed one
        # can be any type. Iterating over it would raise TypeError, which pydantic does
        # not turn into a ValidationError. Leave non-lists to the strict list-type check
        # that runs after this "before" validator.
        if not isinstance(values, list):
            return values

        if not info.data.get("requires_parameters"):
            return values

        errors = []
        field_is_required = "Field is required when parameters required"
        for index, value in enumerate(values):
            if not isinstance(value, dict):
                errors.append(
                    get_error_details("Input should be a valid dictionary", (index,), value)
                )
                continue

            if "name" not in value:
                errors.append(get_error_details(field_is_required, (index, "name"), value))
            elif isinstance(value["name"], str) and not value["name"]:
                errors.append(
                    get_error_details(
                        "Value must not be empty when parameters required",
                        (index, "name"),
                        value,
                    )
                )

            for field in ("input", "expected"):
                if field not in value:
                    errors.append(get_error_details(field_is_required, (index, field), value))

        if errors:
            # Raising here skips normal field validation. Collect each parameter's own
            # validation errors too, so every problem is reported at once.
            for index, value in enumerate(values):
                if not isinstance(value, dict):
                    continue

                try:
                    AutoGenParam.model_validate(value)
                except ValidationError as exc:
                    for err in exc.errors():
                        loc = (index,) + tuple(err.get("loc", ()))
                        msg = err.get("msg") or str(err.get("type", "value_error"))
                        input_val = err.get("input", value)
                        errors.append(get_error_details(msg, loc, input_val))

            raise_validation_errors(cls, errors)

        return values

    @field_validator("strings", mode="before")
    @classmethod
    def validate_strings(cls, values):
        """
        Validate each string

        :param values: Strings to validate
        :return: Original strings
        :raises: :exc:`ValidationError` if strings invalid
        """

        validate_str_dict(cls, values)
        return values

    @field_validator("transformations", mode="before")
    @classmethod
    def validate_transformation(cls, values):
        """
        Validate each transformation

        A transformation is either a string key of ``SCALAR_TRANSFORMATION_FUNCS``, or a
        single-item dictionary whose key is in ``DICT_TRANSFORMATION_FUNCS`` and whose
        value is a list of strings.

        :param values: Transformations to validate
        :return: Original values
        :raises: :exc:`ValidationError` if Invalid transformation
        """

        if not isinstance(values, list):
            raise_simple_validation_error(cls, "Input should be a valid list", values)

        errors = []
        for index, value in enumerate(values):
            if isinstance(value, str):
                if value not in cls.SCALAR_TRANSFORMATION_FUNCS:
                    errors.append(
                        get_error_details(f'Invalid transformation "{value}"', (index,), value)
                    )
            elif isinstance(value, dict):
                # With more than one key, str(*value) becomes str(key1, key2, ...), which
                # raises a TypeError that pydantic does not report as a validation error
                if len(value) > 1:
                    errors.append(get_error_details("Too many items", (index,), value))
                    continue

                # An empty dictionary yields "" here and is reported as invalid below
                key = str(*value)
                if key not in cls.DICT_TRANSFORMATION_FUNCS:
                    errors.append(
                        get_error_details(f'Invalid transformation "{key}"', (index,), value)
                    )
                else:
                    errors += validate_str_list(cls, value[key], (index, key), raise_exc=False)
            else:
                errors.append(
                    get_error_details(
                        "Input should be a valid string or dictionary", (index,), value
                    )
                )

        if errors:
            raise_validation_errors(cls, errors)

        return values

    @model_validator(mode="after")
    def validate_test_strings(self):
        """
        Validate that every ``{"string": name}`` expected value refers to a key of
        ``strings`` (checked only when parameters are required)

        :return: This model
        :raises: :exc:`ValidationError` if a parameter refers to a non-existent string
        """

        errors = []
        if self.requires_parameters:
            for index, param in enumerate(self.params):
                expected = param.expected
                if (
                    isinstance(expected, dict)
                    and "string" in expected
                    and expected["string"] not in self.strings
                ):
                    expected_string = expected["string"]
                    errors.append(
                        get_error_details(
                            f"Refers to a non-existent string {expected_string}",
                            loc=("params", index, "expected", "string"),
                            input=expected_string,
                        )
                    )

        if errors:
            raise_validation_errors(self.__class__, errors)

        return self

    def transform_vars(self) -> Tuple[str, str]:
        """
        Transform variables using the specified transformations

        Transformations are applied in order, each one rewriting the expressions
        produced by the previous one.

        :return: Transformed actual and expected variables
        """

        actual_var = "actual"
        expected_var = "expected"
        for transformation in self.transformations:
            if isinstance(transformation, str):
                func = self.SCALAR_TRANSFORMATION_FUNCS[transformation]
                actual_var, expected_var = func(actual_var, expected_var)
            else:
                # validate_transformation only accepts single-item dictionaries here
                ((key, item),) = transformation.items()
                func = self.DICT_TRANSFORMATION_FUNCS[key]
                actual_var, expected_var = func(item, actual_var, expected_var)

        return actual_var, expected_var

    def get_constant_variables(self) -> str:
        """
        Get constant variables

        One ``NAME = "value"`` line is emitted for each distinct string referenced by a
        parameter's ``{"string": name}`` expected value (only when parameters are
        required).

        :return: constant variables
        """

        constant_variables = {}
        if self.requires_parameters:
            for param in self.params:
                variable_name = param.get_constant_variable_name()
                if variable_name and variable_name not in constant_variables:
                    constant_variables[variable_name] = self.strings[param.expected["string"]]

        return "".join(f"{name} = {quote(value)}\n" for name, value in constant_variables.items())

    def get_pytest_params(self) -> str:
        """
        Get pytest parameters

        :return: pytest ``parametrize`` decorators, or empty string if parameters
            are not required
        """

        if not self.requires_parameters:
            return ""

        # NOTE(review): assumes glotter.utils.indent prefixes each line with N spaces.
        # The first line's indent is stripped because the template below supplies it.
        pytest_params = "".join(
            indent(param.get_pytest_param(), 8) for param in self.params
        ).strip()
        test_code = f"""\
@pytest.mark.parametrize(
    ("in_params", "expected"),
    [
        {pytest_params}
    ]
)
"""
        if self.repeat > 1:
            test_code += f"""\
@pytest.mark.parametrize("repeat", range(1, {self.repeat + 1}), ids=lambda x: f"repeat{{x}}")
"""

        return test_code

    def get_test_function_and_run(self, project_name_underscores: str) -> str:
        """
        Get test function and run command

        :param project_name_underscores: Project name with underscores between each word
        :return: Test function and run command
        """

        func_params = ""
        run_param = ""
        if self.repeat > 1:
            func_params += "repeat, "

        if self.requires_parameters:
            func_params += "in_params, expected, "
            run_param += "params=in_params"

        return f"""\
def test_{self.name}({func_params}{project_name_underscores}):
    actual = {project_name_underscores}.run({run_param})
"""

    def get_expected_output(self, project_name_underscores: str) -> str:
        """
        Get test code that gets the expected output

        :param project_name_underscores: Project name with underscores between each word
        :return: Test code that gets the expected output; empty when parameters are
            required, since ``expected`` then comes from the pytest parameters
        """

        if self.requires_parameters:
            return ""

        expected_output = self.params[0].expected
        if isinstance(expected_output, dict):
            return _get_expected_file(project_name_underscores, expected_output)

        if isinstance(expected_output, str):
            expected_output = quote(expected_output)

        return f"expected = {expected_output}\n"

    def generate_test(self, project_name_underscores: str) -> str:
        """
        Generate test code

        :param project_name_underscores: Project name with underscores between each word
        :return: Test code
        """

        test_code = "@project_test(PROJECT_NAME)\n"
        test_code += self.get_pytest_params()
        test_code += self.get_test_function_and_run(project_name_underscores)
        test_code += indent(self.get_expected_output(project_name_underscores), 4)
        actual_var, expected_var = self.transform_vars()
        # The shape of the assertion (list vs. scalar) follows the first parameter
        test_code += indent(_get_assert(actual_var, expected_var, self.params[0].expected), 4)
        return test_code

443 

444 

445def _get_expected_file(project_name_underscores: str, expected_output: Dict[str, str]) -> str: 

446 if "exec" in expected_output: 

447 script = quote(expected_output["exec"]) 

448 return f"expected = {project_name_underscores}.exec({script})\n" 

449 

450 test_code = f"""\ 

451with open({project_name_underscores}.full_path, "r", encoding="utf-8") as file: 

452 expected = file.read() 

453""" 

454 

455 if "self" in expected_output: 455 ↛ 464line 455 didn't jump to line 464 because the condition on line 455 was always true

456 test_code += """\ 

457diff_len = len(actual) - len(expected) 

458if diff_len > 0: 

459 expected += "\\n" 

460elif diff_len < 0: 

461 actual += "\\n" 

462""" 

463 

464 return test_code 

465 

466 

467def _get_assert(actual_var: str, expected_var: str, expected_output) -> str: 

468 if isinstance(expected_output, list): 

469 return f"""\ 

470actual_list = {actual_var} 

471expected_list = {expected_var} 

472assert len(actual_list) == len(expected_list), "Length not equal" 

473for index in range(len(expected_list)): 

474 assert actual_list[index] == expected_list[index], f"Item {{index + 1}} is not equal" 

475""" 

476 

477 test_code = "" 

478 if actual_var != "actual": 

479 test_code += f"actual = {actual_var}\n" 

480 

481 if expected_var != "expected": 

482 test_code += f"expected = {expected_var}\n" 

483 

484 return f"{test_code}assert actual == expected\n" 

485 

486 

class AutoGenUseTests(BaseModel):
    """Object used to specify what tests to use"""

    name: str
    search: Annotated[str, Field(strict=True, pattern="^[0-9a-zA-Z_]*$")] = ""
    replace: Annotated[str, Field(strict=True, pattern="^[0-9a-zA-Z_]*$")] = ""

    @model_validator(mode="before")
    @classmethod
    def validate_search_with_replace(cls, values):
        """
        Validate that if either search or replace is specified, both must be specified

        :param values: Values to validate
        :return: Original values
        :raises: :exc:`ValidationError` if only one of search or replace is specified
        """

        # In "before" mode the raw input is not known to be a mapping yet. On a string,
        # "in" would do a substring test, and on None/int it raises TypeError, which
        # pydantic does not convert. Let pydantic report the type error for non-dicts.
        if not isinstance(values, dict):
            return values

        has_search = "search" in values
        has_replace = "replace" in values
        if has_search and not has_replace:
            raise_simple_validation_error(
                cls, '"search" item specified without "replace" item', values
            )

        if has_replace and not has_search:
            raise_simple_validation_error(
                cls, '"replace" item specified without "search" item', values
            )

        return values

506 raise_simple_validation_error( 

507 cls, '"search" item specified without "replace" item', values 

508 ) 

509 

510 if "search" not in values and "replace" in values: 

511 raise_simple_validation_error( 

512 cls, '"replace" item specified without "search" item', values 

513 ) 

514 

515 return values