Coverage for glotter/auto_gen_test.py: 99%

220 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2026-03-01 21:54 +0000

1from functools import partial 

2from typing import Annotated, Any, Callable, ClassVar, Dict, List, Optional, Tuple 

3 

4from pydantic import ( 

5 BaseModel, 

6 Field, 

7 ValidationError, 

8 ValidationInfo, 

9 field_validator, 

10 model_validator, 

11) 

12 

13from glotter.errors import ( 

14 get_error_details, 

15 raise_simple_validation_error, 

16 raise_validation_errors, 

17 validate_str_dict, 

18 validate_str_list, 

19) 

20from glotter.utils import indent, quote 

21 

# Scalar transformation: maps the (actual expression, expected expression) pair of
# generated-code strings to a new (actual expression, expected expression) pair
TransformationScalarFuncT = Callable[
    [str, str],
    Tuple[str, str],
]
# Dictionary transformation: like a scalar transformation, but first receives the
# list of string values given for the transformation
TransformationDictFuncT = Callable[
    [List[str], str, str],
    Tuple[str, str],
]

24 

25 

class AutoGenParam(BaseModel):
    """Object used to auto-generate a test parameter"""

    name: str = ""
    input: Optional[str] = None
    expected: Any

    @field_validator("expected")
    @classmethod
    def validate_expected(cls, value):
        """
        Validate expected value

        Accepted forms:

        - a string
        - a list (items checked with ``validate_str_list``)
        - a single-item dictionary whose key is ``"exec"`` or ``"string"`` (value must be
          a non-empty string) or ``"self"`` (value is not checked)

        :param value: Expected value
        :return: Original expected value
        :raises: :exc:`ValidationError` if invalid expected value
        """

        if isinstance(value, dict):
            if not value:
                raise_simple_validation_error(cls, "Too few items", value)

            if len(value) > 1:
                raise_simple_validation_error(cls, "Too many items", value)

            # Exactly one item remains at this point
            ((key, item),) = value.items()
            if key in ("exec", "string"):
                if not isinstance(item, str):
                    raise_simple_validation_error(
                        cls, "Input should be a valid string", item, (key,)
                    )
                if not item:
                    raise_simple_validation_error(cls, "Value must not be empty", item, (key,))
            elif key != "self":
                raise_simple_validation_error(cls, 'Invalid "expected" type', item)
        elif isinstance(value, list):
            validate_str_list(cls, value)
        elif not isinstance(value, str):
            raise_simple_validation_error(
                cls, "Input should be a valid string, list, or dictionary", value
            )

        return value

    def get_pytest_param(self) -> str:
        """
        Get pytest parameter string

        String inputs and expected values are quoted. An expected value of the form
        ``{"string": ...}`` is emitted as its constant variable name instead.

        :return: pytest parameter string if name is not empty, empty string otherwise
        """

        if not self.name:
            return ""

        input_param = self.input
        if isinstance(input_param, str):
            input_param = quote(input_param)

        expected_output = self.expected
        if isinstance(expected_output, str):
            expected_output = quote(expected_output)
        elif isinstance(expected_output, dict) and "string" in expected_output:
            expected_output = self.get_constant_variable_name()

        return f"pytest.param({input_param}, {expected_output}, id={quote(self.name)}),\n"

    def get_constant_variable_name(self) -> str:
        """
        Get constant variable name

        :return: upper-cased ``expected["string"]`` value if the expected value is a
            ``{"string": ...}`` dictionary, empty string otherwise
        """

        variable_name = ""
        if isinstance(self.expected, dict) and "string" in self.expected:
            variable_name = self.expected["string"].upper()

        return variable_name

103 

104 

105def _append_method_to_actual(method: str, actual_var: str, expected_var) -> Tuple[str, str]: 

106 return f"{actual_var}.{method}()", expected_var 

107 

108 

109def _append_method_to_expected(method: str, actual_var: str, expected_var: str) -> Tuple[str, str]: 

110 return actual_var, f"{expected_var}.{method}()" 

111 

112 

113def _remove_chars(values: List[str], actual_var: str, expected_var: str) -> Tuple[str, str]: 

114 for value in values: 

115 actual_var += f'.replace({quote(value)}, "")' 

116 

117 return actual_var, expected_var 

118 

119 

120def _strip_chars(values: List[str], actual_var: str, expected_var: str) -> Tuple[str, str]: 

121 for value in values: 

122 actual_var += f".strip({quote(value)})" 

123 

124 return actual_var, expected_var 

125 

126 

127def _unique_sort(actual_var, expected_var): 

128 return f"sorted(set({actual_var}))", f"sorted(set({expected_var}))" 

129 

130 

class AutoGenTest(BaseModel):
    """Object used to auto-generate a test"""

    name: Annotated[str, Field(strict=True, min_length=1, pattern="^[a-zA-Z][0-9a-zA-Z_]*$")]
    requires_parameters: bool = False
    inputs: Annotated[List[str], Field(strict=True, min_length=1)] = Field(
        ["Input"], validate_default=True
    )
    params: Annotated[List[AutoGenParam], Field(strict=True, min_length=1)] = Field(
        None, validate_default=True
    )
    strings: Dict[str, str] = {}
    transformations: List[Any] = []

    # Transformations written as a plain string (e.g. "strip"). Each function takes the
    # actual and expected variable expressions and returns the transformed pair.
    SCALAR_TRANSFORMATION_FUNCS: ClassVar[Dict[str, TransformationScalarFuncT]] = {
        "strip": partial(_append_method_to_actual, "strip"),
        "splitlines": partial(_append_method_to_actual, "splitlines"),
        "lower": partial(_append_method_to_actual, "lower"),
        "any_order": _unique_sort,
        "strip_expected": partial(_append_method_to_expected, "strip"),
        "splitlines_expected": partial(_append_method_to_expected, "splitlines"),
    }
    # Transformations written as a single-item dictionary (e.g. {"remove": [...]}). Each
    # function takes the list of string values, then the actual and expected variable
    # expressions, and returns the transformed pair.
    DICT_TRANSFORMATION_FUNCS: ClassVar[Dict[str, TransformationDictFuncT]] = {
        "remove": _remove_chars,
        "strip": _strip_chars,
    }

    @field_validator("inputs", mode="before")
    @classmethod
    def validate_inputs(cls, values):
        """
        Validate each input

        :param values: Inputs to validate
        :return: Original inputs
        :raises: :exc:`ValidationError` if input invalid
        """

        validate_str_list(cls, values)
        return values

    @field_validator("params", mode="before")
    @classmethod
    def validate_params(cls, values, info: ValidationInfo):
        """
        Validate each parameter

        :param values: Parameters to validate
        :param info: Validation info; ``info.data`` holds the fields validated so far
        :return: Original parameters
        :raises: :exc:`ValidationError` if project requires parameters but no input, no name,
            or empty name. Also, raised if no expected output
        """

        # This runs before type validation, so "values" may be anything, including the
        # None default used when "params" is omitted. Iterating a non-list here could
        # raise TypeError, which pydantic does not convert into a ValidationError, so
        # leave non-lists for the strict List type check to reject.
        if not isinstance(values, list):
            return values

        errors = []
        field_is_required = "Field is required when parameters required"

        if info.data.get("requires_parameters"):
            for index, value in enumerate(values):
                if not isinstance(value, dict):
                    errors.append(
                        get_error_details("Input should be a valid dictionary", (index,), value)
                    )
                    continue

                if "name" not in value:
                    errors.append(get_error_details(field_is_required, (index, "name"), value))
                elif isinstance(value["name"], str) and not value["name"]:
                    errors.append(
                        get_error_details(
                            "Value must not be empty when parameters required",
                            (index, "name"),
                            value,
                        )
                    )

                if "input" not in value:
                    errors.append(get_error_details(field_is_required, (index, "input"), value))

                if "expected" not in value:
                    errors.append(get_error_details(field_is_required, (index, "expected"), value))

        if errors:
            # Also report errors from validating each parameter on its own, so they are
            # raised together with the errors above
            for index, value in enumerate(values):
                if not isinstance(value, dict):
                    continue

                try:
                    AutoGenParam.model_validate(value)
                except ValidationError as exc:
                    for err in exc.errors():
                        loc = (index,) + tuple(err.get("loc", ()))
                        msg = err.get("msg") or str(err.get("type", "value_error"))
                        input_val = err.get("input", value)
                        errors.append(get_error_details(msg, loc, input_val))

            raise_validation_errors(cls, errors)

        return values

    @field_validator("strings", mode="before")
    @classmethod
    def validate_strings(cls, values):
        """
        Validate each string

        :param values: Strings to validate
        :return: Original strings
        :raises: :exc:`ValidationError` if strings invalid
        """

        validate_str_dict(cls, values)
        return values

    @field_validator("transformations", mode="before")
    @classmethod
    def validate_transformation(cls, values):
        """
        Validate each transformation

        A transformation is either a string naming an entry of
        ``SCALAR_TRANSFORMATION_FUNCS``, or a single-item dictionary whose key names an
        entry of ``DICT_TRANSFORMATION_FUNCS`` and whose value is checked with
        ``validate_str_list``.

        :param values: Transformations to validate
        :return: Original values
        :raises: :exc:`ValidationError` if Invalid transformation
        """

        if not isinstance(values, list):
            raise_simple_validation_error(cls, "Input should be a valid list", values)

        errors = []
        for index, value in enumerate(values):
            if isinstance(value, str):
                if value not in cls.SCALAR_TRANSFORMATION_FUNCS:
                    errors.append(
                        get_error_details(f'Invalid transformation "{value}"', (index,), value)
                    )
            elif isinstance(value, dict):
                # "str(*value)" raises TypeError for two or more keys, and pydantic does
                # not convert TypeError into a ValidationError, so reject those here
                if len(value) > 1:
                    errors.append(get_error_details("Too many items", (index,), value))
                    continue

                # An empty dictionary yields "", reported as an invalid transformation
                key = str(*value)
                if key not in cls.DICT_TRANSFORMATION_FUNCS:
                    errors.append(
                        get_error_details(f'Invalid transformation "{key}"', (index,), value)
                    )
                else:
                    errors += validate_str_list(cls, value[key], (index, key), raise_exc=False)
            else:
                errors.append(
                    get_error_details(
                        "Input should be a valid string or dictionary", (index,), value
                    )
                )

        if errors:
            raise_validation_errors(cls, errors)

        return values

    @model_validator(mode="after")
    def validate_test_strings(self):
        """
        Validate each test string

        When parameters are required, every ``{"string": ...}`` expected value must
        name a key of ``strings``.

        :return: This model
        :raises: :exc:`ValidationError` if a parameter refers to a non-existent string
        """

        errors = []
        if self.requires_parameters:
            for index, param in enumerate(self.params):
                loc = ("params", index, "expected", "string")
                expected = param.expected
                if (
                    isinstance(expected, dict)
                    and "string" in expected
                    and expected["string"] not in self.strings
                ):
                    expected_string = expected["string"]
                    errors.append(
                        get_error_details(
                            f"Refers to a non-existent string {expected_string}",
                            loc=loc,
                            input=expected_string,
                        )
                    )

        if errors:
            raise_validation_errors(self.__class__, errors)

        return self

    def transform_vars(self) -> Tuple[str, str]:
        """
        Transform variables using the specified transformations

        Transformations are applied in order, starting from the expressions
        ``actual`` and ``expected``.

        :return: Transformed actual and expected variables
        """

        actual_var = "actual"
        expected_var = "expected"
        for transformation in self.transformations:
            if isinstance(transformation, str):
                actual_var, expected_var = self.SCALAR_TRANSFORMATION_FUNCS[transformation](
                    actual_var, expected_var
                )
            else:
                # Validation only accepts single-item dictionaries here
                ((key, item),) = transformation.items()
                actual_var, expected_var = self.DICT_TRANSFORMATION_FUNCS[key](
                    item, actual_var, expected_var
                )

        return actual_var, expected_var

    def get_constant_variables(self) -> str:
        """
        Get constant variables

        :return: One ``NAME = "value"`` line per distinct string referenced by a
            parameter (only when parameters are required), empty string otherwise
        """

        constant_variables = {}
        if self.requires_parameters:
            for param in self.params:
                variable_name = param.get_constant_variable_name()
                if variable_name and variable_name not in constant_variables:
                    constant_variables[variable_name] = self.strings[param.expected["string"]]

        return "".join(f"{name} = {quote(value)}\n" for name, value in constant_variables.items())

    def get_pytest_params(self) -> str:
        """
        Get pytest parameters

        :return: ``@pytest.mark.parametrize`` decorator code if parameters are required,
            empty string otherwise
        """

        if not self.requires_parameters:
            return ""

        pytest_params = "".join(
            indent(param.get_pytest_param(), 8) for param in self.params
        ).strip()
        return f"""\
@pytest.mark.parametrize(
    ("in_params", "expected"),
    [
        {pytest_params}
    ]
)
"""

    def get_test_function_and_run(self, project_name_underscores: str) -> str:
        """
        Get test function and run command

        :param project_name_underscores: Project name with underscores between each word
        :return: Test function and run command
        """

        func_params = ""
        run_param = ""
        if self.requires_parameters:
            func_params = "in_params, expected, "
            run_param = "params=in_params"

        return f"""\
def test_{self.name}({func_params}{project_name_underscores}):
    actual = {project_name_underscores}.run({run_param})
"""

    def get_expected_output(self, project_name_underscores: str) -> str:
        """
        Get test code that gets the expected output

        Only the first parameter is used, since without required parameters there is
        no parametrization.

        :param project_name_underscores: Project name with underscores between each word
        :return: Test code that gets the expected output, or empty string if parameters
            are required (the expected value then comes from pytest parametrization)
        """

        if self.requires_parameters:
            return ""

        expected_output = self.params[0].expected
        if isinstance(expected_output, str):
            expected_output = quote(expected_output)
        elif isinstance(expected_output, dict):
            return _get_expected_file(project_name_underscores, expected_output)

        return f"expected = {expected_output}\n"

    def generate_test(self, project_name_underscores: str) -> str:
        """
        Generate test code

        :param project_name_underscores: Project name with underscores between each word
        :return: Test code
        """

        test_code = "@project_test(PROJECT_NAME)\n"
        test_code += self.get_pytest_params()
        test_code += self.get_test_function_and_run(project_name_underscores)
        test_code += indent(self.get_expected_output(project_name_underscores), 4)
        actual_var, expected_var = self.transform_vars()
        # The first parameter's expected value decides the assertion style (list or scalar)
        test_code += indent(_get_assert(actual_var, expected_var, self.params[0].expected), 4)
        return test_code

431 

432 

433def _get_expected_file(project_name_underscores: str, expected_output: Dict[str, str]) -> str: 

434 if "exec" in expected_output: 

435 script = quote(expected_output["exec"]) 

436 return f"expected = {project_name_underscores}.exec({script})\n" 

437 

438 test_code = f"""\ 

439with open({project_name_underscores}.full_path, "r", encoding="utf-8") as file: 

440 expected = file.read() 

441""" 

442 

443 if "self" in expected_output: 443 ↛ 452line 443 didn't jump to line 452 because the condition on line 443 was always true

444 test_code += """\ 

445diff_len = len(actual) - len(expected) 

446if diff_len > 0: 

447 expected += "\\n" 

448elif diff_len < 0: 

449 actual += "\\n" 

450""" 

451 

452 return test_code 

453 

454 

455def _get_assert(actual_var: str, expected_var: str, expected_output) -> str: 

456 if isinstance(expected_output, list): 

457 return f"""\ 

458actual_list = {actual_var} 

459expected_list = {expected_var} 

460assert len(actual_list) == len(expected_list), "Length not equal" 

461for index in range(len(expected_list)): 

462 assert actual_list[index] == expected_list[index], f"Item {{index + 1}} is not equal" 

463""" 

464 

465 test_code = "" 

466 if actual_var != "actual": 

467 test_code += f"actual = {actual_var}\n" 

468 

469 if expected_var != "expected": 

470 test_code += f"expected = {expected_var}\n" 

471 

472 return f"{test_code}assert actual == expected\n" 

473 

474 

class AutoGenUseTests(BaseModel):
    """Object used to specify what tests to use"""

    name: str
    search: Annotated[str, Field(strict=True, pattern="^[0-9a-zA-Z_]*$")] = ""
    replace: Annotated[str, Field(strict=True, pattern="^[0-9a-zA-Z_]*$")] = ""

    @model_validator(mode="before")
    @classmethod
    def validate_search_with_replace(cls, values):
        """
        Validate that if either search or replace is specified, both must be specified

        :param values: Values to validate
        :return: Original values
        :raises: :exc:`ValidationError` if only one of search or replace is specified
        """

        # A "before" model validator receives the raw input, which need not be a
        # dictionary. On a string, "in" would be a substring test, and on e.g. None it
        # would raise TypeError, so leave non-dictionaries to pydantic's type check.
        if not isinstance(values, dict):
            return values

        has_search = "search" in values
        has_replace = "replace" in values

        if has_search and not has_replace:
            raise_simple_validation_error(
                cls, '"search" item specified without "replace" item', values
            )

        if has_replace and not has_search:
            raise_simple_validation_error(
                cls, '"replace" item specified without "search" item', values
            )

        return values