# QUIC Protocol Implementation 1.0
# A Python implementation of the QUIC (Quick UDP Internet Connections) protocol.
# Source listing: test_quic.py
#!/usr/bin/env python3
"""
@file test_quic.py
@brief Unit tests for the quic module.
"""

import unittest
import sys
import os
import socket
import time
from unittest.mock import Mock, patch, MagicMock

# Add the parent directory to sys.path so the project modules below
# (quic, stream, frame, packet, constants) can be imported when the tests
# are run from this directory.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from quic import QuicConnection
from stream import Stream
from frame import FrameStream
from packet import Packet
from constants import Constants


class TestQuicConnection(unittest.TestCase):
    """
    @brief Test cases for the QuicConnection class.

    setUp() builds a fresh QuicConnection while socket.socket is patched.
    Several tests configure or inspect sendto/recvfrom/settimeout on the
    connection's _socket attribute, so they rely on it being that mock.
    """

    def setUp(self):
        """Create the QuicConnection under test with socket.socket mocked."""
        self.connection_id = Constants.CONNECTION_ID_SENDER
        self.local_addr = Constants.ADDR_SENDER
        self.remote_addr = Constants.ADDR_RECEIVER
        with patch('socket.socket'):
            # NOTE(review): this constructor call was missing from the
            # listing (without it every test fails on self.quic_connection).
            # The argument order is inferred from the attributes asserted in
            # test_init -- confirm against QuicConnection.__init__.
            self.quic_connection = QuicConnection(
                self.connection_id, self.local_addr, self.remote_addr
            )

    def test_init(self):
        """Test initialization of QuicConnection."""
        conn = self.quic_connection
        self.assertEqual(conn._connection_id, self.connection_id)
        self.assertEqual(conn._local_addr, self.local_addr)
        self.assertEqual(conn._remote_addr, self.remote_addr)
        self.assertEqual(conn._streams, {})
        self.assertEqual(conn._active_streams_ids, [])
        self.assertEqual(conn._streams_counter, Constants.ZERO)
        self.assertEqual(conn._sent_packets_counter, Constants.ZERO)
        self.assertEqual(conn._received_packets_counter, Constants.ZERO)
        self.assertEqual(conn._packet_size, Constants.ZERO)
        self.assertTrue(conn._idle)

    def test_get_stream(self):
        """Test getting a new stream."""
        initiated_by = Constants.CONNECTION_ID_SENDER
        direction = Constants.BIDI
        stream_id = 38
        mock_stream_instance = Mock()

        # First, make sure _stream_id_generator returns our expected ID
        with patch.object(self.quic_connection, '_stream_id_generator', return_value=stream_id):
            # Then, patch the internal _get_stream_by_id method to return our mock stream
            with patch.object(self.quic_connection, '_get_stream_by_id',
                              return_value=mock_stream_instance):
                result = self.quic_connection.get_stream(initiated_by, direction)

        # Verify the result is our mock stream
        self.assertEqual(result, mock_stream_instance)

    @patch('quic.Stream')
    def test_add_stream(self, mock_stream):
        """Test adding a stream."""
        mock_stream_instance = Mock()
        mock_stream.return_value = mock_stream_instance

        stream_id = 38
        initiated_by = True
        direction = False

        with patch('quic.QuicConnection._add_stream_to_stats_dict') as mock_add_stats:
            self.quic_connection._add_stream(stream_id, initiated_by, direction)

        # Check stream was added to streams dict
        self.assertIn(stream_id, self.quic_connection._streams)

        # Check stats were updated
        mock_add_stats.assert_called_once_with(stream_id)

        # Check streams counter was incremented
        self.assertEqual(self.quic_connection._streams_counter, Constants.ONE)

    def test_stream_id_generator(self):
        """Test stream ID generation."""
        # Test client-initiated bidirectional stream
        self.quic_connection._streams_counter = 1
        stream_id = self.quic_connection._stream_id_generator(
            Constants.CONNECTION_ID_SENDER, Constants.BIDI)
        self.assertEqual(stream_id, 4)  # 1 in binary (001) + '00' = 00100 = 4

        # Test client-initiated unidirectional stream
        self.quic_connection._streams_counter = 2
        stream_id = self.quic_connection._stream_id_generator(
            Constants.CONNECTION_ID_SENDER, Constants.UNIDI)
        self.assertEqual(stream_id, 10)  # 2 in binary (010) + '10' = 01010 = 10

        # Test server-initiated bidirectional stream
        self.quic_connection._streams_counter = 3
        stream_id = self.quic_connection._stream_id_generator(
            Constants.CONNECTION_ID_RECEIVER, Constants.BIDI)
        self.assertEqual(stream_id, 13)  # 3 in binary (011) + '01' = 01101 = 13

    def test_add_stream_to_stats_dict(self):
        """Test adding a stream to the stats dictionary."""
        stream_id = 38

        with patch('time.time', return_value=12345):
            self.quic_connection._add_stream_to_stats_dict(stream_id)

        self.assertIn(stream_id, self.quic_connection._stats_dict)
        stats = self.quic_connection._stats_dict[stream_id]
        self.assertEqual(stats['total_bytes'], Constants.ZERO)
        self.assertEqual(stats['total_time'], 12345)
        self.assertEqual(stats['total_packets'], set())

    @patch('quic.Stream')
    def test_get_stream_by_id_existing(self, mock_stream):
        """Test getting an existing stream by ID."""
        stream_id = 38
        mock_stream_instance = Mock()
        self.quic_connection._streams[stream_id] = mock_stream_instance

        result = self.quic_connection._get_stream_by_id(stream_id)
        self.assertEqual(result, mock_stream_instance)

    def test_get_stream_by_id_new(self):
        """Test getting a new stream by ID."""
        stream_id = 38
        mock_stream_instance = Mock()

        # Stand-in for _add_stream: register the mock under the requested ID.
        # It is patched on the class with a plain MagicMock, so no `self` is
        # passed through to this function.
        def fake_add_stream(sid, initiated_by, direction):
            self.quic_connection._streams[sid] = mock_stream_instance

        # First mock _is_stream_id_in_dict to return False, then route
        # _add_stream to the fake above.
        with patch('quic.QuicConnection._is_stream_id_in_dict', return_value=False):
            with patch('quic.Stream.is_s_init_by_sid', return_value=True):
                with patch('quic.Stream.is_uni_by_sid', return_value=False):
                    with patch('quic.QuicConnection._add_stream',
                               side_effect=fake_add_stream) as mock_add:
                        result = self.quic_connection._get_stream_by_id(stream_id)

        self.assertEqual(result, mock_stream_instance)
        mock_add.assert_called_once_with(stream_id, True, False)

    def test_remove_stream(self):
        """Test removing a stream."""
        stream_id = 38
        mock_stream = Mock()
        self.quic_connection._streams[stream_id] = mock_stream
        self.quic_connection._active_streams_ids.append(stream_id)

        result = self.quic_connection._remove_stream(stream_id)

        self.assertEqual(result, mock_stream)
        self.assertNotIn(stream_id, self.quic_connection._active_streams_ids)
        self.assertNotIn(stream_id, self.quic_connection._streams)

    @patch('builtins.open', new_callable=unittest.mock.mock_open, read_data=b'Test file data')
    def test_add_file_to_stream(self, mock_open):
        """Test adding a file to a stream."""
        stream_id = 38
        path = "test_file.txt"

        with patch('quic.QuicConnection._add_data_to_stream') as mock_add_data:
            self.quic_connection.add_file_to_stream(stream_id, path)

        mock_open.assert_called_once_with(path, 'rb')
        mock_add_data.assert_called_once_with(stream_id, b'Test file data')

    def test_add_data_to_stream(self):
        """Test adding data to a stream."""
        stream_id = 38
        data = b'Test data'
        mock_stream = Mock()

        with patch('quic.QuicConnection._get_stream_by_id', return_value=mock_stream) as mock_get:
            with patch('quic.QuicConnection._add_active_stream_id') as mock_add_active:
                self.quic_connection._add_data_to_stream(stream_id, data)

        mock_get.assert_called_once_with(stream_id)
        mock_stream.add_data_to_stream.assert_called_once_with(data=data)
        mock_add_active.assert_called_once_with(stream_id)

    def test_add_active_stream_id(self):
        """Test adding an active stream ID."""
        stream_id = 38

        # Test adding a new active stream ID
        self.quic_connection._add_active_stream_id(stream_id)
        self.assertIn(stream_id, self.quic_connection._active_streams_ids)

        # Test adding an already active stream ID (should not duplicate)
        initial_length = len(self.quic_connection._active_streams_ids)
        self.quic_connection._add_active_stream_id(stream_id)
        self.assertEqual(len(self.quic_connection._active_streams_ids), initial_length)

    def test_is_stream_id_in_dict(self):
        """Test checking if a stream ID is in the dictionary."""
        stream_id = 38

        # Test with stream not in dictionary
        self.assertFalse(self.quic_connection._is_stream_id_in_dict(stream_id))

        # Test with stream in dictionary
        self.quic_connection._streams[stream_id] = Mock()
        self.assertTrue(self.quic_connection._is_stream_id_in_dict(stream_id))

    @patch('time.time', return_value=12345)
    def test_set_start_time(self, mock_time):
        """Test setting the start time for all streams."""
        self.quic_connection._stats_dict = {
            1: {'total_time': 0},
            2: {'total_time': 0},
        }

        self.quic_connection._set_start_time()

        for stats in self.quic_connection._stats_dict.values():
            self.assertEqual(stats['total_time'], 12345)

    @patch('quic.QuicConnection._send_packet_size')
    @patch('quic.QuicConnection._create_packet')
    @patch('quic.QuicConnection._send_packet')
    @patch('quic.QuicConnection._close_connection')
    def test_send_packets(self, mock_close, mock_send, mock_create, mock_send_size):
        """Test sending packets."""
        mock_packet = Mock()
        mock_create.return_value = mock_packet
        mock_packet.pack.return_value = b'packed packet'

        # Setup to run the loop once then exit: the first send empties the
        # list of active stream IDs.
        self.quic_connection._active_streams_ids = [38]

        def finish_after_first_send(*args, **kwargs):
            self.quic_connection._active_streams_ids = []
            return True

        mock_send.side_effect = finish_after_first_send

        with patch('time.time', return_value=12345):
            self.quic_connection.send_packets()

        mock_send_size.assert_called_once()
        mock_create.assert_called_once()
        mock_send.assert_called_once_with(b'packed packet')
        mock_close.assert_called_once()

    def test_send_packet_size(self):
        """Test sending the packet size."""
        with patch('quic.PACKET_SIZE', 1500):
            with patch('quic.QuicConnection._send_packet', return_value=True) as mock_send:
                result = self.quic_connection._send_packet_size()

        self.assertEqual(self.quic_connection._packet_size, 1500)
        mock_send.assert_called_once_with((1500).to_bytes(Constants.PACKET_SIZE_BYTES, 'big'))
        self.assertTrue(result)

    # NOTE(review): quic.Packet is patched inside this test, so if quic passes
    # the packet it creates to sys.getsizeof, that object is a Mock and
    # isinstance(x, Packet) is False -- every size then comes back as 5.
    # Confirm which sizes this test is meant to simulate.
    @patch('sys.getsizeof', side_effect=lambda x: 10 if isinstance(x, Packet) else 5)
    def test_create_packet(self, mock_getsizeof):
        """Test creating a packet."""
        self.quic_connection._packet_size = 30

        with patch('quic.Packet') as mock_packet_class:
            mock_packet_class.return_value = Mock()

            with patch('quic.QuicConnection._generate_streams_frames'):
                with patch('quic.QuicConnection._get_stream_from_active_streams',
                           return_value=None):
                    self.quic_connection._create_packet()

        # NOTE(review): assumes Constants.CONNECTION_ID_SENDER (used in setUp) is 0.
        expected_dest_conn_id = 1  # Based on connection_id=0 in setup
        mock_packet_class.assert_called_once_with(expected_dest_conn_id, Constants.ZERO)
        self.assertEqual(self.quic_connection._sent_packets_counter, 1)

    def test_generate_streams_frames(self):
        """Test generating frames for all active streams."""
        stream_id1, stream_id2 = 38, 43
        self.quic_connection._active_streams_ids = [stream_id1, stream_id2]

        mock_stream1, mock_stream2 = Mock(), Mock()
        expected_frame_size = 1500 // Constants.FRAMES_IN_PACKET

        with patch('quic.PACKET_SIZE', 1500):
            with patch('quic.QuicConnection._get_stream_by_id',
                       side_effect=[mock_stream1, mock_stream2]) as mock_get:
                self.quic_connection._generate_streams_frames()

        self.assertEqual(mock_get.call_count, 2)
        mock_stream1.generate_stream_frames.assert_called_once_with(expected_frame_size)
        mock_stream2.generate_stream_frames.assert_called_once_with(expected_frame_size)

    def test_get_stream_from_active_streams_empty(self):
        """Test getting a stream from active streams when empty."""
        self.quic_connection._active_streams_ids = []

        result = self.quic_connection._get_stream_from_active_streams()

        self.assertIsNone(result)
        self.assertFalse(self.quic_connection._idle)

    @patch('random.choice', return_value=38)
    def test_get_stream_from_active_streams(self, mock_choice):
        """Test getting a stream from active streams."""
        stream_id = 38
        mock_stream = Mock()
        self.quic_connection._streams[stream_id] = mock_stream
        self.quic_connection._active_streams_ids = [stream_id]

        result = self.quic_connection._get_stream_from_active_streams()

        self.assertEqual(result, mock_stream)

    def test_send_packet(self):
        """Test sending a packet."""
        packet = b'test packet'

        self.quic_connection._socket.sendto.return_value = len(packet)

        result = self.quic_connection._send_packet(packet)

        self.quic_connection._socket.sendto.assert_called_once_with(packet, self.remote_addr)
        self.assertTrue(result)

    @patch('quic.QuicConnection._receive_packet')
    def test_receive_packets(self, mock_receive):
        """Test receiving packets."""
        # Setup to run the loop once then exit: the first receive clears _idle.
        self.quic_connection._idle = True

        def stop_after_first_receive(*args, **kwargs):
            self.quic_connection._idle = False

        mock_receive.side_effect = stop_after_first_receive

        self.quic_connection.receive_packets()

        self.quic_connection._socket.settimeout.assert_called_once_with(Constants.TIMEOUT)
        mock_receive.assert_called_once()

    @patch('quic.QuicConnection._handle_received_packet_size')
    def test_receive_packet_size(self, mock_handle):
        """Test receiving a packet size."""
        packet = b'size packet'
        addr = ('127.0.0.1', 1234)

        self.quic_connection._socket.recvfrom.return_value = (packet, addr)

        with patch('time.time', return_value=12345):
            with patch('quic.QuicConnection._increment_received_packets_counter') as mock_increment:
                self.quic_connection._receive_packet()

        mock_handle.assert_called_once_with(packet)
        mock_increment.assert_called_once()

    @patch('quic.QuicConnection._handle_received_packet')
    def test_receive_packet_data(self, mock_handle):
        """Test receiving a data packet."""
        self.quic_connection._packet_size = 1500
        packet = b'data packet'
        addr = ('127.0.0.1', 1234)

        self.quic_connection._socket.recvfrom.return_value = (packet, addr)

        with patch('quic.QuicConnection._increment_received_packets_counter') as mock_increment:
            self.quic_connection._receive_packet()

        mock_handle.assert_called_once_with(packet)
        mock_increment.assert_called()

    def test_increment_received_packets_counter(self):
        """Test incrementing the received packets counter."""
        initial = self.quic_connection._received_packets_counter

        self.quic_connection._increment_received_packets_counter()

        self.assertEqual(self.quic_connection._received_packets_counter, initial + 1)

    def test_handle_received_packet_size(self):
        """Test handling a received packet size."""
        packet_size = (1500).to_bytes(Constants.PACKET_SIZE_BYTES, 'big')

        with patch('builtins.print') as mock_print:
            self.quic_connection._handle_received_packet_size(packet_size)

        mock_print.assert_called_once()
        self.assertEqual(self.quic_connection._packet_size, 1500)


# Allow running this test module directly (python test_quic.py).
if __name__ == '__main__':
    unittest.main()
# Doxygen cross-reference index for test_quic.py
# (line numbers refer to the original generated listing):
#   test_receive_packets(self, mock_receive) -- defined at test_quic.py:324
#   test_receive_packet_size(self, mock_handle) -- defined at test_quic.py:340
#   test_get_stream_from_active_streams_empty(self) -- defined at test_quic.py:291
#   test_set_start_time(self, mock_time) -- defined at test_quic.py:210
#   test_receive_packet_data(self, mock_handle) -- defined at test_quic.py:355
#   test_get_stream_by_id_existing(self, mock_stream) -- defined at test_quic.py:118
#   test_get_stream_from_active_streams(self, mock_choice) -- defined at test_quic.py:301
#   test_send_packets(self, mock_close, mock_send, mock_create, mock_send_size) -- defined at test_quic.py:226
#   test_create_packet(self, mock_getsizeof) -- defined at test_quic.py:260
#   test_increment_received_packets_counter(self) -- defined at test_quic.py:369
#   test_add_file_to_stream(self, mock_open) -- defined at test_quic.py:160
#   test_add_stream(self, mock_stream) -- defined at test_quic.py:67