From 5ca0241c7ac7fc07188281058e052044e8f9ec36 Mon Sep 17 00:00:00 2001 From: Harrison Mutai Date: Fri, 29 Nov 2024 16:33:02 +0000 Subject: [PATCH] fix(handoff): correct 8-bit modulo csum calculation Fix the handoff 8-bit modulo checksum calculation to ensure we never get a checksum larger than 8 bits. The previous calculation failed to truncate the sum at the final step in update_checksum Change-Id: Ice0b72eb139af90f416adeff157d337646d6201a Signed-off-by: Harrison Mutai --- tools/tlc/tests/test_transfer_list.py | 17 ++++++++++++----- tools/tlc/tlc/te.py | 5 ++++- tools/tlc/tlc/tl.py | 10 ++++++---- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/tools/tlc/tests/test_transfer_list.py b/tools/tlc/tests/test_transfer_list.py index e5f90b07c..e00280b95 100644 --- a/tools/tlc/tests/test_transfer_list.py +++ b/tools/tlc/tests/test_transfer_list.py @@ -88,12 +88,19 @@ def test_calculate_te_sum_of_bytes(tag_id, data): assert te.sum_of_bytes == csum -@pytest.mark.parametrize(("tag_id", "data"), test_entries) -def test_calculate_tl_checksum(tag_id, data): +def test_calc_tl_checksum(tmpdir, random_entries): + tl_file = tmpdir.join("tl.bin") + tl = TransferList(0x1000) - tl.add_transfer_entry(tag_id, data) - assert tl.sum_of_bytes() == 0 + for id, data in random_entries(10): + tl.add_transfer_entry(id, data) + + assert sum(tl.to_bytes()) % 256 == 0 + + # Write the transfer list to a file and check that the sum of bytes is 0 + tl.write_to_file(tl_file) + assert sum(tl_file.read_binary()) % 256 == 0 def test_empty_transfer_list_blob(tmpdir): @@ -129,7 +136,7 @@ def test_write_multiple_tes_to_file(tmpdir, random_entries): """Check that we can create a TL with multiple TE's.""" test_file = tmpdir.join("test_tl_blob.bin") tl = TransferList(0x4000) - _test_entries = random_entries() + _test_entries = list(random_entries()) for tag_id, data in _test_entries: tl.add_transfer_entry(tag_id, data) diff --git a/tools/tlc/tlc/te.py b/tools/tlc/tlc/te.py index cf7aa67ca..0b6b53269 100644 --- a/tools/tlc/tlc/te.py +++ b/tools/tlc/tlc/te.py @@ -48,7 +48,10 @@ class TransferEntry: @property def sum_of_bytes(self) -> int: - return (sum(self.header_to_bytes()) + sum(self.data)) % 256 + return sum(self.to_bytes()) % 256 + + def to_bytes(self) -> bytes: + return self.header_to_bytes() + self.data def header_to_bytes(self) -> bytes: return self.id.to_bytes(3, "little") + struct.pack( diff --git a/tools/tlc/tlc/tl.py b/tools/tlc/tlc/tl.py index 98d2205bf..b7eb48615 100644 --- a/tools/tlc/tlc/tl.py +++ b/tools/tlc/tlc/tl.py @@ -189,13 +189,15 @@ class TransferList: def update_checksum(self) -> None: """Calculates the checksum based on the sum of bytes.""" - self.checksum = 256 - ((self.sum_of_bytes() - self.checksum) % 256) + self.checksum = (256 - (self.sum_of_bytes() - self.checksum)) % 256 + assert self.checksum <= 0xFF + + def to_bytes(self) -> bytes: + return self.header_to_bytes() + b"".join([te.to_bytes() for te in self.entries]) def sum_of_bytes(self) -> int: """Sum of all bytes between the base address and the end of that last TE (modulo 0xff).""" - return ( - sum(self.header_to_bytes()) + sum(te.sum_of_bytes for te in self.entries) - ) % 256 + return (sum(self.to_bytes())) % 256 def get_entry(self, tag_id: int) -> Optional[TransferEntry]: for te in self.entries: