diff --git a/conda/core/portability.py b/conda/core/portability.py index 846721f9555..23fa160d61b 100644 --- a/conda/core/portability.py +++ b/conda/core/portability.py @@ -1,5 +1,6 @@ # Copyright (C) 2012 Anaconda, Inc # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations from logging import getLogger from os.path import realpath @@ -28,6 +29,17 @@ MAX_SHEBANG_LENGTH = 127 if on_linux else 512 # Not used on Windows +# These are the most common file encodings that we run across when having to replace our +# PREFIX_PLACEHOLDER string. They apply to binary and text formats. +# More information/discussion: https://github.com/conda/conda/pull/9946 +POPULAR_ENCODINGS = ( + "utf-8", + "utf-16-le", + "utf-16-be", + "utf-32-le", + "utf-32-be", +) + class _PaddingError(Exception): pass @@ -74,58 +86,85 @@ def _update_prefix(original_data): subprocess.run(['/usr/bin/codesign', '-s', '-', '-f', realpath(path)], capture_output=True) -def replace_prefix(mode, data, placeholder, new_prefix): - if mode == FileMode.text: - if not on_win: - # if new_prefix contains spaces, it might break the shebang! - # handle this by escaping the spaces early, which will trigger a - # /usr/bin/env replacement later on - newline_pos = data.find(b"\n") - if newline_pos > -1: - shebang_line, rest_of_data = data[:newline_pos], data[newline_pos:] - shebang_placeholder = f"#!{placeholder}".encode() - if shebang_placeholder in shebang_line: - escaped_shebang = f"#!{new_prefix}".replace(" ", "\\ ").encode('utf-8') - shebang_line = shebang_line.replace(shebang_placeholder, escaped_shebang) - data = shebang_line + rest_of_data - # the rest of the file can be replaced normally - data = data.replace(placeholder.encode('utf-8'), new_prefix.encode('utf-8')) - elif mode == FileMode.binary: - data = binary_replace(data, placeholder.encode('utf-8'), new_prefix.encode('utf-8')) - else: - raise CondaIOError("Invalid mode: %r" % mode) +def replace_prefix(mode: FileMode, data: bytes, placeholder: str, new_prefix: str) -> bytes: + """ + Replaces `placeholder` text with the `new_prefix` provided. The `mode` provided can + either be text or binary. + + We use the `POPULAR_ENCODINGS` module level constant defined above to make several + passes at replacing the placeholder. We do this to account for as many encodings as + possible. If this causes any performance problems in the future, it could potentially + be removed (i.e. just using the most popular "utf-8" encoding"). + + More information/discussion available here: https://github.com/conda/conda/pull/9946 + """ + for encoding in POPULAR_ENCODINGS: + if mode == FileMode.text: + if not on_win: + # if new_prefix contains spaces, it might break the shebang! + # handle this by escaping the spaces early, which will trigger a + # /usr/bin/env replacement later on + newline_pos = data.find(b"\n") + if newline_pos > -1: + shebang_line, rest_of_data = data[:newline_pos], data[newline_pos:] + shebang_placeholder = f"#!{placeholder}".encode(encoding) + if shebang_placeholder in shebang_line: + escaped_shebang = f"#!{new_prefix}".replace(" ", "\\ ").encode(encoding) + shebang_line = shebang_line.replace(shebang_placeholder, escaped_shebang) + data = shebang_line + rest_of_data + # the rest of the file can be replaced normally + data = data.replace(placeholder.encode(encoding), new_prefix.encode(encoding)) + elif mode == FileMode.binary: + data = binary_replace( + data, placeholder.encode(encoding), new_prefix.encode(encoding), encoding=encoding + ) + else: + raise CondaIOError("Invalid mode: %r" % mode) return data -def binary_replace(data, a, b): +def binary_replace( + data: bytes, search: bytes, replacement: bytes, encoding: str = "utf-8" +) -> bytes: """ - Perform a binary replacement of `data`, where the placeholder `a` is - replaced with `b` and the remaining string is padded with null characters. + Perform a binary replacement of `data`, where the placeholder `search` is + replaced with `replacement` and the remaining string is padded with null characters. All input arguments are expected to be bytes objects. + + Parameters + ---------- + data: + The bytes object that will be searched and replaced + search: + The bytes object to find + replacement: + The bytes object that will replace `search` + encoding: str + The encoding of the expected string in the binary. """ + zeros = "\0".encode(encoding) if on_win: # on Windows for binary files, we currently only replace a pyzzer-type entry point # we skip all other prefix replacement if has_pyzzer_entry_point(data): - return replace_pyzzer_entry_point_shebang(data, a, b) + return replace_pyzzer_entry_point_shebang(data, search, replacement) else: return data def replace(match): - occurrences = match.group().count(a) - padding = (len(a) - len(b)) * occurrences + occurrences = match.group().count(search) + padding = (len(search) - len(replacement)) * occurrences if padding < 0: raise _PaddingError - return match.group().replace(a, b) + b'\0' * padding + return match.group().replace(search, replacement) + b"\0" * padding original_data_len = len(data) - pat = re.compile(re.escape(a) + b'([^\0]*?)\0') + pat = re.compile(re.escape(search) + b"(?:(?!(?:" + zeros + b")).)*" + zeros) data = pat.sub(replace, data) assert len(data) == original_data_len return data - def has_pyzzer_entry_point(data): pos = data.rfind(b'PK\x05\x06') return pos >= 0 diff --git a/tests/test_install.py b/tests/test_install.py index c7325aed7bc..693ea74ab5c 100644 --- a/tests/test_install.py +++ b/tests/test_install.py @@ -31,49 +31,45 @@ class TestBinaryReplace(unittest.TestCase): @pytest.mark.xfail(on_win, reason="binary replacement on windows skipped", strict=True) def test_simple(self): - self.assertEqual( - binary_replace(b'xxxaaaaaxyz\x00zz', b'aaaaa', b'bbbbb'), - b'xxxbbbbbxyz\x00zz') + for encoding in ["utf-8", "utf-16-le", "utf-16-be", "utf-32-le", "utf-32-be"]: + a = "aaaaa".encode(encoding) + b = "bbbb".encode(encoding) + data = "xxxaaaaaxyz\0zz".encode(encoding) + result = "xxxbbbbxyz\0\0zz".encode(encoding) + self.assertEqual(binary_replace(data, a, b, encoding=encoding), result) @pytest.mark.xfail(on_win, reason="binary replacement on windows skipped", strict=True) def test_shorter(self): self.assertEqual( - binary_replace(b'xxxaaaaaxyz\x00zz', b'aaaaa', b'bbbb'), - b'xxxbbbbxyz\x00\x00zz') + binary_replace(b"xxxaaaaaxyz\x00zz", b"aaaaa", b"bbbb"), b"xxxbbbbxyz\x00\x00zz" + ) @pytest.mark.xfail(on_win, reason="binary replacement on windows skipped", strict=True) def test_too_long(self): - self.assertRaises(_PaddingError, binary_replace, - b'xxxaaaaaxyz\x00zz', b'aaaaa', b'bbbbbbbb') + self.assertRaises( + _PaddingError, binary_replace, b"xxxaaaaaxyz\x00zz", b"aaaaa", b"bbbbbbbb" + ) @pytest.mark.xfail(on_win, reason="binary replacement on windows skipped", strict=True) def test_no_extra(self): - self.assertEqual(binary_replace(b'aaaaa\x00', b'aaaaa', b'bbbbb'), - b'bbbbb\x00') + self.assertEqual(binary_replace(b"aaaaa\x00", b"aaaaa", b"bbbbb"), b"bbbbb\x00") @pytest.mark.xfail(on_win, reason="binary replacement on windows skipped", strict=True) def test_two(self): self.assertEqual( - binary_replace(b'aaaaa\x001234aaaaacc\x00\x00', b'aaaaa', - b'bbbbb'), - b'bbbbb\x001234bbbbbcc\x00\x00') + binary_replace(b"aaaaa\x001234aaaaacc\x00\x00", b"aaaaa", b"bbbbb"), + b"bbbbb\x001234bbbbbcc\x00\x00", + ) @pytest.mark.xfail(on_win, reason="binary replacement on windows skipped", strict=True) def test_spaces(self): - self.assertEqual( - binary_replace(b' aaaa \x00', b'aaaa', b'bbbb'), - b' bbbb \x00') + self.assertEqual(binary_replace(b" aaaa \x00", b"aaaa", b"bbbb"), b" bbbb \x00") @pytest.mark.xfail(on_win, reason="binary replacement on windows skipped", strict=True) def test_multiple(self): - self.assertEqual( - binary_replace(b'aaaacaaaa\x00', b'aaaa', b'bbbb'), - b'bbbbcbbbb\x00') - self.assertEqual( - binary_replace(b'aaaacaaaa\x00', b'aaaa', b'bbb'), - b'bbbcbbb\x00\x00\x00') - self.assertRaises(_PaddingError, binary_replace, - b'aaaacaaaa\x00', b'aaaa', b'bbbbb') + self.assertEqual(binary_replace(b"aaaacaaaa\x00", b"aaaa", b"bbbb"), b"bbbbcbbbb\x00") + self.assertEqual(binary_replace(b"aaaacaaaa\x00", b"aaaa", b"bbb"), b"bbbcbbb\x00\x00\x00") + self.assertRaises(_PaddingError, binary_replace, b"aaaacaaaa\x00", b"aaaa", b"bbbbb") @pytest.mark.integration @pytest.mark.skipif(not on_win, reason="exe entry points only necessary on win")