diff --git a/.gitignore b/.gitignore
index 1bbc728..33807df 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,5 @@
 .pytest_cache/
 __pycache__/
 dist/
+/.venv/
+/.hypothesis/
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..43d7a3a
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,2 @@
+[flake8]
+max-line-length = 90
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_smoke.py b/tests/test_smoke.py
new file mode 100644
index 0000000..218447a
--- /dev/null
+++ b/tests/test_smoke.py
@@ -0,0 +1,29 @@
+
+
+from io import StringIO
+from pathlib import Path
+import subprocess
+import sys
+from typeforce._cli import main
+
+
+def test_smoke(tmp_path: Path):
+    # create venv
+    venv_path = tmp_path / '.venv'
+    cmd = [sys.executable, '-m', 'venv', str(venv_path)]
+    subprocess.run(cmd, check=True)
+
+    # install deps
+    exe_path = venv_path / 'bin' / 'python'
+    cmd = [str(exe_path), '-m', 'pip', 'install', 'astroid']
+    subprocess.run(cmd, check=True)
+
+    # run typeforce
+    stdout = StringIO()
+    code = main(['--exe', str(exe_path)], stdout)
+    assert code == 0
+
+    stdout.seek(0)
+    output = stdout.read()
+    assert 'astroid' in output
+    assert 'patched' in output
diff --git a/typeforce/__init__.py b/typeforce/__init__.py
index bab1daf..c78bb87 100644
--- a/typeforce/__init__.py
+++ b/typeforce/__init__.py
@@ -1,4 +1,4 @@
 """CLI tool to propagate py.typed into third-party libraries.
 """
 
-__version__ = '0.2.0'
+__version__ = '0.2.1'
diff --git a/typeforce/_cli.py b/typeforce/_cli.py
index b72eda3..4af794d 100644
--- a/typeforce/_cli.py
+++ b/typeforce/_cli.py
@@ -6,8 +6,14 @@
 
 def main(argv: typing.List[str], stream: typing.TextIO) -> int:
     parser = ArgumentParser()
-    parser.add_argument('--exe', default=sys.executable)
-    parser.add_argument('--dry', action='store_true')
+    parser.add_argument(
+        '--exe', default=sys.executable,
+        help='path to python executable to patch',
+    )
+    parser.add_argument(
+        '--dry', action='store_true',
+        help='don\'t patch, only print changes',
+    )
     args = parser.parse_args(argv)
     explorer = Explorer(exe=args.exe, dry=args.dry)
     for line in explorer.run():
diff --git a/typeforce/_core.py b/typeforce/_core.py
index 947798e..7c3afe4 100644
--- a/typeforce/_core.py
+++ b/typeforce/_core.py
@@ -12,17 +12,7 @@
 MARKER = 'py.typed'
 TEMPLATE = '{name:30} {status}'
 
-SCRIPT = """
-import sys
-import {name} as module
-
-for name in dir(module):
-    obj = getattr(module, name)
-    ann = getattr(obj, '__annotations__', None)
-    if ann:
-        sys.exit(21)
-sys.exit(22)
-"""
+ROOT = Path(__file__).parent
 
 
 class Module(typing.NamedTuple):
@@ -36,12 +26,16 @@ def patched(self) -> bool:
 
     @property
     def stubbed(self) -> bool:
+        if self.path.name.endswith('-stubs'):
+            return True
         new_name = self.path.name + '-stubs'
         return (self.path.parent / new_name).exists()
 
     @property
     def annotated(self) -> bool:
-        script = SCRIPT.format(name=self.path.name)
+        script_path = ROOT / '_script_inspect.py'
+        script = script_path.read_text()
+        script = script.replace('MODULE_NAME', self.path.name)
         result = subprocess.run([self.exe, '-c', script])
         assert result.returncode in (21, 22)
         return result.returncode == 21
@@ -74,10 +68,16 @@ class Explorer(typing.NamedTuple):
 
     @property
     def root(self) -> Path:
-        cmd = [self.exe, '-c', 'print(__import__("site").USER_SITE)']
+        script_path = ROOT / '_script_sites.py'
+        cmd = [self.exe, str(script_path)]
         result = subprocess.run(cmd, check=True, stdout=subprocess.PIPE)
-        path = result.stdout.decode().strip()
-        return Path(path)
+        sites = result.stdout.decode().strip().splitlines()
+        # one candidate directory per line; the first existing one wins
+        for site in sites:
+            path = Path(site)
+            if path.exists():
+                return path
+        raise LookupError('cannot find site packages directory')
 
     @property
     def modules(self) -> typing.Iterable[Module]:
diff --git a/typeforce/_script_inspect.py b/typeforce/_script_inspect.py
new file mode 100644
index 0000000..20cd2af
--- /dev/null
+++ b/typeforce/_script_inspect.py
@@ -0,0 +1,9 @@
+import sys
+import MODULE_NAME as module  # type: ignore[import]
+
+for name in dir(module):
+    obj = getattr(module, name)
+    ann = getattr(obj, '__annotations__', None)
+    if ann:
+        sys.exit(21)
+sys.exit(22)
diff --git a/typeforce/_script_sites.py b/typeforce/_script_sites.py
new file mode 100644
index 0000000..6f18721
--- /dev/null
+++ b/typeforce/_script_sites.py
@@ -0,0 +1,4 @@
+import site
+
+sites = site.getsitepackages() + [site.getusersitepackages()]
+print('\n'.join(sites))