fix(handoff): correct 8-bit modulo csum calculation

Fix the handoff 8-bit modulo checksum calculation to ensure we never get
a checksum larger than 8 bits. The previous calculation failed to
truncate the sum at the final step in update_checksum

Change-Id: Ice0b72eb139af90f416adeff157d337646d6201a
Signed-off-by: Harrison Mutai <harrison.mutai@arm.com>
This commit is contained in:
Harrison Mutai 2024-11-29 16:33:02 +00:00 committed by J-Alves
parent 157c619786
commit 5ca0241c7a
3 changed files with 22 additions and 10 deletions

View file

@ -88,12 +88,19 @@ def test_calculate_te_sum_of_bytes(tag_id, data):
assert te.sum_of_bytes == csum assert te.sum_of_bytes == csum
@pytest.mark.parametrize(("tag_id", "data"), test_entries) def test_calc_tl_checksum(tmpdir, random_entries):
def test_calculate_tl_checksum(tag_id, data): tl_file = tmpdir.join("tl.bin")
tl = TransferList(0x1000) tl = TransferList(0x1000)
tl.add_transfer_entry(tag_id, data) for id, data in random_entries(10):
assert tl.sum_of_bytes() == 0 tl.add_transfer_entry(id, data)
assert sum(tl.to_bytes()) % 256 == 0
# Write the transfer list to a file and check that the sum of bytes is 0
tl.write_to_file(tl_file)
assert sum(tl_file.read_binary()) % 256 == 0
def test_empty_transfer_list_blob(tmpdir): def test_empty_transfer_list_blob(tmpdir):
@ -129,7 +136,7 @@ def test_write_multiple_tes_to_file(tmpdir, random_entries):
"""Check that we can create a TL with multiple TE's.""" """Check that we can create a TL with multiple TE's."""
test_file = tmpdir.join("test_tl_blob.bin") test_file = tmpdir.join("test_tl_blob.bin")
tl = TransferList(0x4000) tl = TransferList(0x4000)
_test_entries = random_entries() _test_entries = list(random_entries())
for tag_id, data in _test_entries: for tag_id, data in _test_entries:
tl.add_transfer_entry(tag_id, data) tl.add_transfer_entry(tag_id, data)

View file

@ -48,7 +48,10 @@ class TransferEntry:
@property @property
def sum_of_bytes(self) -> int: def sum_of_bytes(self) -> int:
return (sum(self.header_to_bytes()) + sum(self.data)) % 256 return sum(self.to_bytes()) % 256
def to_bytes(self) -> bytes:
return self.header_to_bytes() + self.data
def header_to_bytes(self) -> bytes: def header_to_bytes(self) -> bytes:
return self.id.to_bytes(3, "little") + struct.pack( return self.id.to_bytes(3, "little") + struct.pack(

View file

@ -189,13 +189,15 @@ class TransferList:
def update_checksum(self) -> None: def update_checksum(self) -> None:
"""Calculates the checksum based on the sum of bytes.""" """Calculates the checksum based on the sum of bytes."""
self.checksum = 256 - ((self.sum_of_bytes() - self.checksum) % 256) self.checksum = (256 - (self.sum_of_bytes() - self.checksum)) % 256
assert self.checksum <= 0xFF
def to_bytes(self) -> bytes:
return self.header_to_bytes() + b"".join([te.to_bytes() for te in self.entries])
def sum_of_bytes(self) -> int: def sum_of_bytes(self) -> int:
"""Sum of all bytes between the base address and the end of that last TE (modulo 0xff).""" """Sum of all bytes between the base address and the end of that last TE (modulo 0xff)."""
return ( return (sum(self.to_bytes())) % 256
sum(self.header_to_bytes()) + sum(te.sum_of_bytes for te in self.entries)
) % 256
def get_entry(self, tag_id: int) -> Optional[TransferEntry]: def get_entry(self, tag_id: int) -> Optional[TransferEntry]:
for te in self.entries: for te in self.entries: