# QUIC Protocol Implementation 1.0
# A Python implementation of the QUIC (Quick UDP Internet Connections) protocol.
# File: test_packet.py
#!/usr/bin/env python3
"""
@file test_packet.py
@brief Unit tests for the packet module.
"""

import unittest
import sys
import os
from unittest.mock import Mock, patch

# Make the directory one level above this file importable, so that the
# project modules below (packet, frame, constants) resolve regardless of the
# working directory the tests are launched from.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from packet import PacketHeader, Packet
from frame import FrameStream
from constants import Constants


class TestPacketHeader(unittest.TestCase):
    """
    @brief Test cases for the PacketHeader class.

    setUp() builds two fixtures: a minimal header with only
    packet_number_length set, and a header with every field set to a
    non-default value, so packing and unpacking are checked in both cases.
    """

    # Single-bit boolean flags that must survive a pack/unpack round trip.
    FLAG_FIELDS = ("header_form", "fixed_bit", "spin_bit", "key_phase")

    def setUp(self):
        # Minimal header: flags and reserved bits left at their defaults.
        self.header = PacketHeader(packet_number_length=1)
        # Every field set explicitly to a non-default value.
        self.header_all_fields = PacketHeader(
            packet_number_length=3,
            header_form=True,
            fixed_bit=True,
            spin_bit=True,
            key_phase=True,
            reserved_bits=3,
        )

    def _assert_flags_match(self, actual, expected):
        """Assert that every boolean flag field of two headers is equal."""
        for field in self.FLAG_FIELDS:
            with self.subTest(field=field):
                self.assertEqual(getattr(actual, field), getattr(expected, field))

    def test_init_default_values(self):
        """Test initialization with default values"""
        self.assertEqual(self.header.packet_number_length, 1)
        self.assertFalse(self.header.header_form)
        self.assertFalse(self.header.fixed_bit)
        self.assertFalse(self.header.spin_bit)
        self.assertFalse(self.header.key_phase)
        self.assertEqual(self.header.reserved_bits, Constants.ZERO)

    def test_init_custom_values(self):
        """Test initialization with custom values"""
        header = self.header_all_fields
        self.assertEqual(header.packet_number_length, 3)
        self.assertTrue(header.header_form)
        self.assertTrue(header.fixed_bit)
        self.assertTrue(header.spin_bit)
        self.assertTrue(header.key_phase)
        self.assertEqual(header.reserved_bits, 3)

    def test_pack(self):
        """Test packing header to bytes"""
        packed_header = self.header.pack()
        self.assertEqual(len(packed_header), Constants.HEADER_LENGTH)
        # With every flag clear, only packet_number_length (1) is encoded.
        self.assertEqual(packed_header, bytes([0x01]))

    def test_pack_all_fields(self):
        """Test packing header with all fields set"""
        header = self.header_all_fields
        packed_header = header.pack()
        self.assertEqual(len(packed_header), Constants.HEADER_LENGTH)

        # Rebuild the expected single header byte from the bit positions in
        # Constants.*_SHIFT; packet_number_length is OR-ed in unshifted.
        expected_byte = (
            (int(header.header_form) << Constants.FORM_SHIFT) |
            (int(header.fixed_bit) << Constants.FIXED_SHIFT) |
            (int(header.spin_bit) << Constants.SPIN_SHIFT) |
            (header.reserved_bits << Constants.RES_SHIFT) |
            (int(header.key_phase) << Constants.KEY_SHIFT) |
            header.packet_number_length
        )
        self.assertEqual(packed_header, bytes([expected_byte]))

    def test_unpack(self):
        """Test unpacking header from bytes"""
        unpacked_header = PacketHeader.unpack(self.header.pack())

        self.assertEqual(unpacked_header.packet_number_length,
                         self.header.packet_number_length)
        self._assert_flags_match(unpacked_header, self.header)
        self.assertEqual(unpacked_header.reserved_bits, self.header.reserved_bits)

    def test_unpack_all_fields(self):
        """Test unpacking header with all fields set"""
        unpacked_header = PacketHeader.unpack(self.header_all_fields.pack())

        self.assertEqual(unpacked_header.packet_number_length,
                         self.header_all_fields.packet_number_length)
        self._assert_flags_match(unpacked_header, self.header_all_fields)

        # NOTE(review): this only checks that non-zero reserved bits stay
        # non-zero after a round trip. It does not check that they equal the
        # constructor value. Presumably PacketHeader.unpack stores them in a
        # different (e.g. still-shifted) form; confirm, and tighten to
        # assertEqual if they round-trip exactly.
        if self.header_all_fields.reserved_bits:
            self.assertNotEqual(unpacked_header.reserved_bits, 0)


class TestPacket(unittest.TestCase):
    """
    @brief Test cases for the Packet class.

    setUp() builds an empty Packet plus two StreamFrames (on streams 10 and
    20; the second carries FIN) that individual tests add to the packet or
    encode directly.
    """

    def setUp(self):
        # NOTE(review): the original values of these two fixtures were lost
        # in the documentation export. Any destination_connection_id that fits
        # in Constants.DEST_CONNECTION_ID_LENGTH bytes (8, per the comment in
        # test_pack_empty_packet) and any small packet number should work.
        # Confirm against the original source.
        self.destination_connection_id = 12345678
        self.packet_number = 1
        self.packet = Packet(self.destination_connection_id, self.packet_number)

        # Create test frames
        self.test_frames = [
            FrameStream(
                stream_id=10,
                offset=0,
                length=len(b'Frame 1'),
                fin=False,
                data=b'Frame 1',
            ),
            FrameStream(
                stream_id=20,
                offset=0,
                length=len(b'Frame 2'),
                fin=True,
                data=b'Frame 2',
            ),
        ]

    def _assert_frames_match(self, actual_frames, expected_frames):
        """Assert two frame sequences have the same length and equal fields, in order."""
        self.assertEqual(len(actual_frames), len(expected_frames))
        for index, (actual, expected) in enumerate(zip(actual_frames, expected_frames)):
            with self.subTest(frame=index):
                self.assertEqual(actual.stream_id, expected.stream_id)
                self.assertEqual(actual.offset, expected.offset)
                self.assertEqual(actual.length, expected.length)
                self.assertEqual(actual.fin, expected.fin)
                self.assertEqual(actual.data, expected.data)

    def test_init(self):
        """Test initialization"""
        self.assertEqual(self.packet.destination_connection_id, self.destination_connection_id)
        self.assertEqual(self.packet.packet_number, self.packet_number)
        self.assertEqual(self.packet.payload, [])

    def test_add_frame(self):
        """Test adding a frame to the packet"""
        for expected_count, frame in enumerate(self.test_frames, start=1):
            self.packet.add_frame(frame)
            self.assertEqual(len(self.packet.payload), expected_count)
            # Frames are appended in insertion order.
            self.assertEqual(self.packet.payload[-1], frame)

    def test_pack_empty_packet(self):
        """Test packing an empty packet"""
        packed_packet = self.packet.pack()

        # Header (1) + Dest Connection ID (8) + Packet Number (variable, but at least 1)
        min_expected_length = Constants.HEADER_LENGTH + Constants.DEST_CONNECTION_ID_LENGTH + 1
        self.assertGreaterEqual(len(packed_packet), min_expected_length)

    def test_pack_with_frames(self):
        """Test packing a packet with frames"""
        for frame in self.test_frames:
            self.packet.add_frame(frame)

        packed_packet = self.packet.pack()

        # Header + Dest Connection ID + Packet Number + Encoded frames
        min_expected_length = (Constants.HEADER_LENGTH +
                               Constants.DEST_CONNECTION_ID_LENGTH +
                               1 +  # Minimum packet number length
                               sum(len(frame.encode()) for frame in self.test_frames))

        self.assertGreaterEqual(len(packed_packet), min_expected_length)

    def test_unpack(self):
        """Test unpacking a packet"""
        for frame in self.test_frames:
            self.packet.add_frame(frame)

        unpacked_packet = Packet.unpack(self.packet.pack())

        self.assertEqual(unpacked_packet.destination_connection_id,
                         self.packet.destination_connection_id)
        # The round trip must also preserve the packet number, not just the payload.
        self.assertEqual(unpacked_packet.packet_number, self.packet.packet_number)
        self.assertEqual(len(unpacked_packet.payload), len(self.packet.payload))
        self._assert_frames_match(unpacked_packet.payload, self.test_frames)

    def test_get_frames_from_payload_bytes(self):
        """Test extracting frames from payload bytes"""
        # Concatenate the encoded frames exactly as they would appear in a payload.
        encoded_frames = b"".join(frame.encode() for frame in self.test_frames)

        decoded_frames = Packet.get_frames_from_payload_bytes(encoded_frames)

        self._assert_frames_match(decoded_frames, self.test_frames)

    def test_large_values(self):
        """Test with large values for destination_connection_id and packet_number"""
        large_dest_id = 2 ** 64 - 1  # Max value for 8 bytes
        large_packet_number = 2 ** 32 - 1  # Max value for 4 bytes

        packet = Packet(large_dest_id, large_packet_number)
        unpacked_packet = Packet.unpack(packet.pack())

        self.assertEqual(unpacked_packet.destination_connection_id, large_dest_id)
        self.assertEqual(unpacked_packet.packet_number, large_packet_number)


if __name__ == '__main__':
    # Allow running this file directly: discovers and runs the TestCase
    # classes defined in this module.
    unittest.main()