fix(handoff): correct 8-bit modulo csum calculation

Fix the handoff 8-bit modulo checksum calculation to ensure we never get
a checksum larger than 8 bits. The previous calculation failed to
truncate the sum at the final step in update_checksum

Change-Id: Ice0b72eb139af90f416adeff157d337646d6201a
Signed-off-by: Harrison Mutai <harrison.mutai@arm.com>
This commit is contained in:
Harrison Mutai 2024-11-29 16:33:02 +00:00 committed by J-Alves
parent 157c619786
commit 5ca0241c7a
3 changed files with 22 additions and 10 deletions

View file

@ -88,12 +88,19 @@ def test_calculate_te_sum_of_bytes(tag_id, data):
assert te.sum_of_bytes == csum
@pytest.mark.parametrize(("tag_id", "data"), test_entries)
def test_calculate_tl_checksum(tag_id, data):
def test_calc_tl_checksum(tmpdir, random_entries):
tl_file = tmpdir.join("tl.bin")
tl = TransferList(0x1000)
tl.add_transfer_entry(tag_id, data)
assert tl.sum_of_bytes() == 0
for id, data in random_entries(10):
tl.add_transfer_entry(id, data)
assert sum(tl.to_bytes()) % 256 == 0
# Write the transfer list to a file and check that the sum of bytes is 0
tl.write_to_file(tl_file)
assert sum(tl_file.read_binary()) % 256 == 0
def test_empty_transfer_list_blob(tmpdir):
@ -129,7 +136,7 @@ def test_write_multiple_tes_to_file(tmpdir, random_entries):
"""Check that we can create a TL with multiple TE's."""
test_file = tmpdir.join("test_tl_blob.bin")
tl = TransferList(0x4000)
_test_entries = random_entries()
_test_entries = list(random_entries())
for tag_id, data in _test_entries:
tl.add_transfer_entry(tag_id, data)

View file

@ -48,7 +48,10 @@ class TransferEntry:
@property
def sum_of_bytes(self) -> int:
return (sum(self.header_to_bytes()) + sum(self.data)) % 256
return sum(self.to_bytes()) % 256
def to_bytes(self) -> bytes:
return self.header_to_bytes() + self.data
def header_to_bytes(self) -> bytes:
return self.id.to_bytes(3, "little") + struct.pack(

View file

@ -189,13 +189,15 @@ class TransferList:
def update_checksum(self) -> None:
"""Calculates the checksum based on the sum of bytes."""
self.checksum = 256 - ((self.sum_of_bytes() - self.checksum) % 256)
self.checksum = (256 - (self.sum_of_bytes() - self.checksum)) % 256
assert self.checksum <= 0xFF
def to_bytes(self) -> bytes:
return self.header_to_bytes() + b"".join([te.to_bytes() for te in self.entries])
def sum_of_bytes(self) -> int:
"""Sum of all bytes between the base address and the end of that last TE (modulo 0xff)."""
return (
sum(self.header_to_bytes()) + sum(te.sum_of_bytes for te in self.entries)
) % 256
return (sum(self.to_bytes())) % 256
def get_entry(self, tag_id: int) -> Optional[TransferEntry]:
for te in self.entries: