Coverage for src/lib/network.py: 13%

171 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2025-03-09 17:37 +0000

1 

2import datetime as dt 

3from logging import Logger 

4from socket import socket as Socket 

5from ssl import SSLObject, SSLWantReadError, SSLWantWriteError, SSLError 

6from select import select 

7from struct import error as StructError 

8from typing import Union 

9 

10 

# Number of bytes requested per ``recv`` call in ``Network._client_read``.
CLIENT_READ_SIZE = 2048

# Building blocks of one parsed wire message.
RawGroupType = int                 # group id (a single byte on the wire)
RawCommandType = int               # command id (a single byte on the wire)
RawPayloadItemsType = list[bytes]  # payload items, left undecoded

# One parsed message: (group, command, payload items).
# NOTE(review): ``Network._client_read`` appends ``[group, command, payload_items]``
# lists, not tuples, so this alias and the runtime value disagree.
RawItemType = tuple[RawGroupType, RawCommandType, RawPayloadItemsType]

# Every message parsed from one read.
RawCommandsType = list[RawItemType]

19 

class SslHandshakeError(Exception):
    """Raised by ``Network._ssl_handshake`` when the TLS handshake fails or times out."""

22 

class SocketReadStatus():  # pragma: no cover
    """Outcome of a single ``Network._client_read`` call.

    Attributes:
        disconnect: set to True by the reader on socket errors, on an empty
            read, or on a malformed message (presumably tells the caller to
            drop the connection).
        commands: parsed messages, each appended as
            ``[group, command, payload_items]``.
        msg: short description of why the status last changed; starts as
            ``'Init'``.
    """

    disconnect: bool
    commands: RawCommandsType
    msg: str

    def __init__(self) -> None:
        self.disconnect, self.commands, self.msg = False, [], 'Init'

33 

class Network():
    """Client framing protocol (read/write of messages) plus a TLS handshake helper.

    Wire format of one message, as produced by ``_client_write`` and parsed by
    ``_client_read``::

        flags(1) | group(1) | command(1) | payload_len(4, LE) | payload | 0x00

    ``payload`` is a sequence of ``item_len | item`` pairs. ``item_len`` is
    4 bytes little-endian when bit 0 of ``flags`` is set, otherwise 1 byte.

    NOTE(review): ``__init__`` sets neither ``_logger`` nor
    ``_ssl_handshake_timeout``; presumably a subclass or the caller assigns
    them before any method is used. Confirm.
    """

    _logger: Logger
    _ssl_handshake_timeout: dt.timedelta  # upper bound for _ssl_handshake

    def __init__(self) -> None:
        """Create the helper; it keeps no state of its own."""

    def _client_read(self, sock: Socket) -> 'SocketReadStatus':
        """Read the data available on ``sock`` and parse it into commands.

        Reading continues while each ``recv`` fills the whole
        ``CLIENT_READ_SIZE`` buffer.

        Args:
            sock: connected socket to read from.

        Returns:
            A ``SocketReadStatus``. Its ``commands`` holds one
            ``[group, command, payload_items]`` entry per fully parsed
            message, where ``payload_items`` is a list of ``bytes``.
            ``disconnect`` is set on timeout, connection reset,
            ``SSLWantReadError``, an empty read, or a truncated/malformed
            message. Commands parsed before a malformed one are kept.
        """
        self._logger.debug('_client_read(%s)', sock)

        status = SocketReadStatus()
        chunks = []

        while True:
            try:
                raw = sock.recv(CLIENT_READ_SIZE)
            except TimeoutError as e:
                self._logger.debug('TimeoutError: %s', e)
                status.disconnect = True
                status.msg = 'TimeoutError'
                break
            except ConnectionResetError as e:
                self._logger.debug('ConnectionResetError: %s', e)
                status.disconnect = True
                status.msg = 'ConnectionResetError'
                break
            except SSLWantReadError as e:
                self._logger.debug('SSLWantReadError: %s', e)
                status.disconnect = True
                status.msg = 'SSLWantReadError'
                break

            raw_len = len(raw)
            self._logger.debug('recv raw A: %d %s', raw_len, raw)
            chunks.append(raw)

            if raw_len >= CLIENT_READ_SIZE:
                # A full buffer may mean more data is pending, so read again.
                # NOTE(review): on a blocking socket, a message that is an
                # exact multiple of CLIENT_READ_SIZE blocks here until timeout.
                self._logger.debug('raw_len(%d) >= CLIENT_READ_SIZE(%d)', raw_len, CLIENT_READ_SIZE)
                continue
            if raw_len > 0:
                self._logger.debug('raw_len(%d) > 0', raw_len)
                break
            # recv() returned b'': the peer has closed the connection.
            self._logger.debug('raw_len(%d) < CLIENT_READ_SIZE(%d)', raw_len, CLIENT_READ_SIZE)
            status.disconnect = True
            status.msg = 'raw_len < CLIENT_READ_SIZE'
            break

        raw_total = b''.join(chunks)
        raw_total_len = len(raw_total)
        if raw_total_len > 0:
            self._logger.debug('recv raw B: %d %s', raw_total_len, raw_total)

        raw_pos = 0
        while raw_pos < raw_total_len:
            # Header bytes: flags, group, command.
            if raw_pos + 3 > raw_total_len:
                self._logger.debug('IndexError: truncated header at offset %d', raw_pos)
                status.disconnect = True
                status.msg = 'array index out of range'
                return status
            flags_i, group, command = raw_total[raw_pos:raw_pos + 3]
            raw_pos += 3

            lengths_are_4_bytes = (flags_i & 1) != 0

            # Payload length (4 bytes, LE). int.from_bytes never fails on short
            # input; it would silently decode a truncated value, so check bounds.
            if raw_pos + 4 > raw_total_len:
                self._logger.debug('unpack error: truncated payload length at offset %d', raw_pos)
                status.disconnect = True
                status.msg = 'unpack error'
                return status
            length = int.from_bytes(raw_total[raw_pos:raw_pos + 4], 'little')
            raw_pos += 4

            self._logger.debug('group: %d', group)
            self._logger.debug('command: %d', command)
            self._logger.debug('length: %d %s', length, type(length))

            payload_end = raw_pos + length
            payload_items = None
            if payload_end <= raw_total_len:
                # Slice exactly this message's payload so item parsing can
                # never read into the following message.
                payload_items = self._split_payload_items(raw_total[raw_pos:payload_end], lengths_are_4_bytes)
            if payload_items is None:
                self._logger.debug('IndexError: truncated or malformed payload at offset %d', raw_pos)
                status.disconnect = True
                status.msg = 'array index out of range'
                return status

            status.commands.append([group, command, payload_items])
            # Skip the payload and the trailing 0x00 terminator written by _client_write.
            raw_pos = payload_end + 1

        return status

    def _split_payload_items(self, payload_raw: bytes, lengths_are_4_bytes: bool) -> Union[list, None]:
        """Split one message payload into its length-prefixed items.

        Args:
            payload_raw: exactly the payload bytes of one message.
            lengths_are_4_bytes: True when each item length is 4 bytes (LE),
                False when it is a single byte.

        Returns:
            The list of item ``bytes``, or ``None`` if a length prefix or an
            item runs past the end of ``payload_raw``.
        """
        length_size = 4 if lengths_are_4_bytes else 1
        payload_len = len(payload_raw)
        items = []
        pos = 0
        while pos < payload_len:
            if pos + length_size > payload_len:
                return None
            item_len = int.from_bytes(payload_raw[pos:pos + length_size], 'little')
            pos += length_size
            self._logger.debug('item len: %d %s', item_len, type(item_len))

            if pos + item_len > payload_len:
                return None
            item = payload_raw[pos:pos + item_len]
            self._logger.debug('item content: %s', item)

            # TODO: Remove decode here. use Bytes everywhere and decode in server.py when needed
            items.append(item)
            pos += item_len
        return items

    def _client_write(self, sock: Socket, group: int, command: int, data: Union[list, None] = None) -> None:
        """Encode one message and send it on ``sock`` with ``sendall``.

        Args:
            sock: connected socket to write to.
            group: group id, encoded as one byte.
            command: command id, encoded as one byte.
            data: payload items. ``str`` items are UTF-8 encoded, ``bytes``
                items are sent unchanged, and ``int`` items are encoded as
                4 bytes little-endian. ``None`` (the default) sends an empty
                payload.

        Raises:
            TypeError: if an item is not ``str``, ``bytes`` or ``int``.
            OverflowError: if ``group``/``command`` does not fit in one byte
                or an ``int`` item does not fit in 4 unsigned bytes.
        """
        if data is None:
            data = []
        self._logger.debug('_client_write(%d, %d, %s)', group, command, data)

        encoded_items = []
        for item in data:
            self._logger.debug('data item: %s %s', type(item), item)
            if isinstance(item, str):
                encoded = item.encode()
            elif isinstance(item, bytes):
                encoded = item
            elif isinstance(item, int):
                encoded = item.to_bytes(4, 'little')
            else:
                raise TypeError(f'unsupported payload item type: {type(item).__name__}')
            encoded_items.append(encoded)

        # Decide the length width from the *encoded* size: a str with at most
        # 255 characters can still exceed 255 bytes once UTF-8 encoded.
        flag_lengths_are_4_bytes = any(len(encoded) > 255 for encoded in encoded_items)
        self._logger.debug('flag_lengths_are_4_bytes: %s', flag_lengths_are_4_bytes)
        length_size = 4 if flag_lengths_are_4_bytes else 1

        payload_items = []
        for encoded in encoded_items:
            payload_items.append(len(encoded).to_bytes(length_size, 'little'))
            payload_items.append(encoded)
        payload = b''.join(payload_items)
        payload_len_i = len(payload)

        self._logger.debug('payload_len_i: %d', payload_len_i)
        self._logger.debug('payload_items: %s', payload_items)

        flags_i = 1 if flag_lengths_are_4_bytes else 0  # bit 0: item lengths are 4 bytes

        raw = (
            flags_i.to_bytes(1, 'little')
            + group.to_bytes(1, 'little')
            + command.to_bytes(1, 'little')
            + payload_len_i.to_bytes(4, 'little')
            + payload
            + b'\x00'  # terminator; _client_read skips it
        )

        self._logger.debug('sock: %s', sock)
        self._logger.debug('send raw: %d %s', len(raw), raw)

        sock.sendall(raw)

    def _ssl_handshake(self, socket_ssl: SSLObject) -> None:
        """Retry ``do_handshake`` on ``socket_ssl`` until it completes.

        On ``SSLWantReadError``/``SSLWantWriteError`` the method waits up to
        0.3 s for the socket to become readable/writable, then retries.

        NOTE(review): ``select`` needs an object with ``fileno()``, which a bare
        ``ssl.SSLObject`` does not provide. Presumably an ``ssl.SSLSocket`` is
        passed here; confirm and fix the annotation.

        Raises:
            SslHandshakeError: on any other ``ssl.SSLError`` (chained as the
                cause), or once ``self._ssl_handshake_timeout`` has elapsed.
        """
        self._logger.debug('_ssl_handshake(%s)', socket_ssl)

        start = dt.datetime.now()
        tries = 0
        while True:
            try:
                self._logger.debug('ssl handshake: %d', tries)
                socket_ssl.do_handshake()
                break
            except SSLWantReadError:
                select([socket_ssl], [], [], 0.3)
            except SSLWantWriteError:
                select([], [socket_ssl], [], 0.3)
            except SSLError as e:
                # Must come after the Want* handlers: they are SSLError subclasses.
                self._logger.error('ssl.SSLError: %s', e)
                raise SslHandshakeError(e) from e

            if dt.datetime.now() - start >= self._ssl_handshake_timeout:
                raise SslHandshakeError('ssl handshake timeout')

            tries += 1

        self._logger.debug('ssl handshake done: %d', tries)