Coverage for src/lib/network.py: 13%
171 statements
« prev ^ index » next coverage.py v7.2.7, created at 2025-03-09 17:37 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2025-03-09 17:37 +0000
2import datetime as dt
3from logging import Logger
4from socket import socket as Socket
5from ssl import SSLObject, SSLWantReadError, SSLWantWriteError, SSLError
6from select import select
7from struct import error as StructError
8from typing import Union
# Maximum number of bytes requested per sock.recv() call in Network._client_read.
CLIENT_READ_SIZE = 2 * 1024

# Building blocks of one decoded frame as produced by Network._client_read.
RawGroupType = int  # command group id (a single byte on the wire)
RawCommandType = int  # command id within its group (a single byte on the wire)
RawPayloadItemsType = list[bytes]  # payload items, left as undecoded bytes

# One decoded frame: (group, command, payload items).
# NOTE(review): _client_read appends lists ([group, command, items]) rather than
# tuples; unpacking works for either, but equality against a tuple does not.
RawItemType = tuple[RawGroupType, RawCommandType, RawPayloadItemsType]

# Every frame decoded from a single read.
RawCommandsType = list[RawItemType]
class SslHandshakeError(Exception):
    """Raised by Network._ssl_handshake when the TLS handshake fails or times out."""
class SocketReadStatus:  # pragma: no cover
    """Result of one Network._client_read call.

    Attributes:
        disconnect: set to True on read errors, EOF or malformed data
            (presumably the caller then drops the connection — confirm).
        commands: decoded frames, each appended as a
            ``[group, command, payload_items]`` list.
        msg: short reason string; stays ``'Init'`` unless a read path sets it.
    """

    disconnect: bool
    commands: RawCommandsType
    msg: str

    def __init__(self) -> None:
        # Tuple assignment binds left to right: disconnect, commands, msg.
        self.disconnect, self.commands, self.msg = False, list(), 'Init'
class Network():
    """Helpers for framed command I/O over a socket, plus a TLS handshake loop.

    One frame, as written by ``_client_write`` and parsed by ``_client_read``::

        flags (1 byte) | group (1) | command (1) | length (4, little-endian)
        | payload (``length`` bytes) | 0x00 terminator

    The payload is a sequence of items, each preceded by its byte length.
    Bit 0 of ``flags`` selects the width of those item lengths: 4 bytes
    (little-endian) when set, 1 byte otherwise.

    ``__init__`` does not assign ``_logger`` or ``_ssl_handshake_timeout``;
    presumably subclasses set them before any method here is used (TODO confirm).
    """

    _logger: Logger
    _ssl_handshake_timeout: dt.timedelta

    def __init__(self) -> None:
        pass

    def _client_read(self, sock: Socket) -> SocketReadStatus:
        """Receive whatever is pending on *sock* and decode the frames in it.

        Reads in CLIENT_READ_SIZE chunks until a short read, then parses the
        accumulated bytes.

        Returns:
            A SocketReadStatus. ``commands`` holds one
            ``[group, command, payload_items]`` list per decoded frame, with
            ``payload_items`` left as raw ``bytes``. ``disconnect``/``msg`` are
            set on timeout, connection reset, SSLWantReadError, EOF (empty
            recv) or a truncated frame. Frames decoded before a truncated one
            are kept.
        """
        self._logger.debug('_client_read(%s)', sock)

        status = SocketReadStatus()
        raw_total = self._client_recv_all(sock, status)
        self._client_parse_frames(raw_total, status)
        return status

    def _client_recv_all(self, sock: Socket, status: SocketReadStatus) -> bytes:
        """Read from *sock* until a short read; record errors and EOF in *status*."""
        chunks = []
        while True:
            try:
                raw = sock.recv(CLIENT_READ_SIZE)
            except TimeoutError as e:
                self._logger.debug('TimeoutError: %s', e)
                status.disconnect = True
                status.msg = 'TimeoutError'
                break
            except ConnectionResetError as e:
                self._logger.debug('ConnectionResetError: %s', e)
                status.disconnect = True
                status.msg = 'ConnectionResetError'
                break
            except SSLWantReadError as e:
                # NOTE(review): this (and TimeoutError above) can also fire when
                # the previous chunk was exactly CLIENT_READ_SIZE bytes and no
                # more data is pending, reporting a disconnect although data was
                # received. Confirm whether callers rely on that.
                self._logger.debug('SSLWantReadError: %s', e)
                status.disconnect = True
                status.msg = 'SSLWantReadError'
                break

            raw_len = len(raw)
            self._logger.debug('recv raw A: %d %s', raw_len, raw)
            chunks.append(raw)

            if raw_len >= CLIENT_READ_SIZE:
                # A full buffer suggests more data may be pending: read again.
                self._logger.debug(
                    'raw_len(%d) >= CLIENT_READ_SIZE(%d)', raw_len, CLIENT_READ_SIZE)
                continue
            if raw_len > 0:
                self._logger.debug('raw_len(%d) > 0', raw_len)
                break
            # recv() returned b'': the peer closed the connection.
            self._logger.debug(
                'raw_len(%d) < CLIENT_READ_SIZE(%d)', raw_len, CLIENT_READ_SIZE)
            status.disconnect = True
            status.msg = 'raw_len < CLIENT_READ_SIZE'
            break

        return b''.join(chunks)

    def _client_parse_frames(self, raw_total: bytes, status: SocketReadStatus) -> None:
        """Decode every frame in *raw_total* and append it to ``status.commands``.

        Stops at the first truncated frame, setting ``status.disconnect`` and
        ``status.msg``. Frames decoded before it are kept.
        """
        raw_total_len = len(raw_total)
        if raw_total_len > 0:
            self._logger.debug('recv raw B: %d %s', raw_total_len, raw_total)

        raw_pos = 0
        while raw_pos < raw_total_len:
            try:
                flags_i = raw_total[raw_pos]
                group = raw_total[raw_pos + 1]
                command = raw_total[raw_pos + 2]
            except IndexError as e:
                self._logger.debug('IndexError: %s', e)
                status.disconnect = True
                status.msg = 'array index out of range'
                return
            raw_pos += 3

            # int.from_bytes() silently accepts a short slice, so the 4-byte
            # length field has to be bounds-checked explicitly.
            if raw_pos + 4 > raw_total_len:
                self._logger.debug('struct.error: %s', 'truncated length field')
                status.disconnect = True
                status.msg = 'unpack error'
                return
            length = int.from_bytes(raw_total[raw_pos:raw_pos + 4], 'little')
            raw_pos += 4

            self._logger.debug('group: %d', group)
            self._logger.debug('command: %d', command)
            self._logger.debug('length: %d %s', length, type(length))

            len_size = 4 if flags_i & 1 else 1
            # Bound the payload to this frame so items cannot run into the next one.
            payload_raw = raw_total[raw_pos:raw_pos + length]
            payload_items = self._client_parse_items(payload_raw, length, len_size)
            if payload_items is None:
                self._logger.debug(
                    'truncated payload: group=%d command=%d', group, command)
                status.disconnect = True
                status.msg = 'array index out of range'
                return

            status.commands.append([group, command, payload_items])
            # Skip the payload plus the 0x00 terminator written by _client_write.
            raw_pos += length + 1

    def _client_parse_items(
            self, payload_raw: bytes, length: int, len_size: int
    ) -> Union[list[bytes], None]:
        """Split one frame payload into length-prefixed items; None if truncated."""
        payload_items = []
        pos = 0
        while pos < length:
            if pos + len_size > len(payload_raw):
                return None
            item_len = int.from_bytes(payload_raw[pos:pos + len_size], 'little')
            pos += len_size
            self._logger.debug('item len: %d %s', item_len, type(item_len))

            if pos + item_len > len(payload_raw):
                return None
            item = payload_raw[pos:pos + item_len]
            self._logger.debug('item content: %s', item)
            payload_items.append(item)
            pos += item_len
        return payload_items

    @staticmethod
    def _encode_payload_item(item) -> bytes:
        """Return the wire bytes for one payload item of _client_write."""
        if isinstance(item, str):
            return item.encode()
        if isinstance(item, (bytes, bytearray, memoryview)):
            return bytes(item)
        if isinstance(item, int):
            return item.to_bytes(4, 'little')
        raise TypeError(f'unsupported payload item type: {type(item).__name__}')

    def _client_write(self, sock: Socket, group: int, command: int,
                      data: Union[list, None] = None):
        """Encode one frame from *group*, *command* and *data* and send it on *sock*.

        Args:
            sock: socket to send on (``sendall`` is used).
            group: command group id; must fit in one unsigned byte.
            command: command id; must fit in one unsigned byte.
            data: payload items. ``str`` items are UTF-8 encoded, bytes-like
                items are sent as-is, ``int`` items become 4 little-endian
                bytes. ``None`` (the default) sends an empty payload.

        Raises:
            TypeError: an item is not ``str``, bytes-like or ``int``.
            OverflowError: *group*/*command* exceed one byte, or an ``int``
                item does not fit in 4 unsigned bytes.
        """
        if data is None:
            data = []
        self._logger.debug('_client_write(%d, %d, %s)', group, command, data)

        # Encode everything first. The length-prefix width has to be chosen from
        # encoded byte counts: len() of a str counts characters, which is less
        # than its UTF-8 size for non-ASCII text. A single pass also lets *data*
        # be a one-shot iterable.
        encoded_pairs = []
        for item in data:
            self._logger.debug('data item: %s %s', type(item), item)
            encoded_pairs.append((item, self._encode_payload_item(item)))

        flag_lengths_are_4_bytes = any(len(enc) > 255 for _, enc in encoded_pairs)
        self._logger.debug('flag_lengths_are_4_bytes: %s', flag_lengths_are_4_bytes)
        len_size = 4 if flag_lengths_are_4_bytes else 1

        payload_items = []
        for item, encoded in encoded_pairs:
            self._logger.debug('item: l=%d t=%s i=%s', len(encoded), type(item), item)
            payload_items.append(len(encoded).to_bytes(len_size, 'little'))
            payload_items.append(encoded)
        self._logger.debug('payload_items: %s', payload_items)

        payload = b''.join(payload_items)
        payload_len_i = len(payload)
        self._logger.debug('payload_len_i: %d', payload_len_i)

        flags_i = 1 if flag_lengths_are_4_bytes else 0  # bit 0: item lengths are 4 bytes
        raw = b''.join((
            flags_i.to_bytes(1, 'little'),
            group.to_bytes(1, 'little'),
            command.to_bytes(1, 'little'),
            payload_len_i.to_bytes(4, 'little'),
            payload,
            b'\x00',  # frame terminator; _client_read skips it
        ))

        self._logger.debug('sock: %s', sock)
        self._logger.debug('send raw: %d %s', len(raw), raw)
        sock.sendall(raw)

    def _ssl_handshake(self, socket_ssl: SSLObject) -> None:
        """Drive the TLS handshake on *socket_ssl* until it completes.

        On SSLWantReadError/SSLWantWriteError it waits up to 0.3 s in select()
        for the socket to become readable/writable, then retries.

        NOTE(review): select() needs an object with fileno(). ssl.SSLObject
        (memory BIO) has none, so this presumably receives an ssl.SSLSocket.
        Confirm, and fix the annotation.

        Raises:
            SslHandshakeError: on any other SSLError (chained as the cause), or
                when the handshake has not completed within
                ``self._ssl_handshake_timeout``.
        """
        from time import monotonic  # immune to wall-clock/DST jumps, unlike now()

        self._logger.debug('_ssl_handshake(%s)', socket_ssl)

        timeout_s = self._ssl_handshake_timeout.total_seconds()
        start = monotonic()
        tries = 0
        while True:
            self._logger.debug('ssl handshake: %d', tries)
            try:
                socket_ssl.do_handshake()
            except SSLWantReadError:
                select([socket_ssl], [], [], 0.3)
            except SSLWantWriteError:
                select([], [socket_ssl], [], 0.3)
            except SSLError as e:
                # Must come after the Want* handlers: both subclass SSLError.
                self._logger.error('ssl.SSLError: %s', e)
                raise SslHandshakeError(e) from e
            else:
                break

            if monotonic() - start >= timeout_s:
                raise SslHandshakeError('ssl handshake timeout')
            tries += 1

        self._logger.debug('ssl handshake done: %d', tries)