initial commit
This commit is contained in:
447
imshow/network.py
Normal file
447
imshow/network.py
Normal file
@@ -0,0 +1,447 @@
|
||||
import socket
|
||||
import zlib
|
||||
import json
|
||||
import struct
|
||||
import queue
|
||||
|
||||
from imshow import parallel
|
||||
|
||||
class SocketMessage():
|
||||
def __init__(self, msg={}):
|
||||
self.data = msg
|
||||
self.type = 'empty'
|
||||
|
||||
def encode(self):
|
||||
pass
|
||||
|
||||
def decode(self):
|
||||
pass
|
||||
|
||||
def tobytes(self):
|
||||
self.encode()
|
||||
js = {}
|
||||
js['type'] = self.type
|
||||
js['msg'] = self.data
|
||||
js = json.dumps(js)
|
||||
js = js.encode('utf-8')
|
||||
byte = zlib.compress(js)
|
||||
return byte
|
||||
|
||||
def frombytes(self, byte):
|
||||
js = zlib.decompress(byte)
|
||||
js = js.decode('utf-8')
|
||||
js = json.loads(js)
|
||||
msg = js['msg']
|
||||
self.data = msg
|
||||
self.decode()
|
||||
|
||||
class SocketServer():
|
||||
def __init__(self, ip='0.0.0.0', port=12345, message_handler=None):
|
||||
self.daemon = parallel.daemon
|
||||
self.ip = ip
|
||||
self.port = port
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.socket.bind((self.ip, self.port))
|
||||
self.socket.listen(5)
|
||||
self.terminate = False
|
||||
self.handler = message_handler
|
||||
|
||||
def handle_message(self, clientsocket, addr):
|
||||
# print('Connection Established from clinet', addr)
|
||||
if self.handler is not None:
|
||||
self.handler(clientsocket, addr)
|
||||
pass
|
||||
|
||||
def loop(self):
|
||||
while not self.terminate:
|
||||
clientsocket,addr = self.socket.accept()
|
||||
self.daemon.add_job(
|
||||
self.handle_message,
|
||||
args=[clientsocket, addr],
|
||||
name='Client[{0}]'.format(addr)
|
||||
)
|
||||
|
||||
def start(self, back=True):
|
||||
if back:
|
||||
self.daemon.add_job(self.loop, name='SocketMainLoop')
|
||||
else:
|
||||
self.loop()
|
||||
|
||||
# what should a packet header contains:
|
||||
# 1. message id
|
||||
# 2. packet id
|
||||
# 3. total packets
|
||||
# 4. total size
|
||||
|
||||
class PacketHeader():
|
||||
def __init__(self, mid=0, pid=0, pn=0, sz=0):
|
||||
self.msg_id = mid
|
||||
self.pkt_id = pid
|
||||
self.pkt_num = pn
|
||||
self.msg_sz = sz
|
||||
self.header_size = 16
|
||||
|
||||
def tobytes(self):
|
||||
b = struct.pack('LLLL', self.msg_id, self.pkt_id, self.pkt_num, self.msg_sz)
|
||||
return b
|
||||
|
||||
def frombytes(self, b):
|
||||
self.msg_id, self.pkt_id, self.pkt_num, self.msg_sz = struct.unpack('LLLL', b)
|
||||
return self
|
||||
|
||||
class Packet():
|
||||
def __init__(self, header=PacketHeader(), msg=b''):
|
||||
self.header = header
|
||||
self.msg = msg
|
||||
self.header_size = self.header.header_size
|
||||
|
||||
def frombytes(self, b):
|
||||
header = b[:self.header_size]
|
||||
self.msg = b[self.header_size:]
|
||||
self.header.frombytes(header)
|
||||
return self
|
||||
|
||||
def tobytes(self):
|
||||
msg = b''
|
||||
self.header.msg_sz = len(self.msg)
|
||||
msg += self.header.tobytes()
|
||||
msg += self.msg
|
||||
return msg
|
||||
|
||||
class PacketFactory():
|
||||
def __init__(self, max_size=8192, log=None):
|
||||
self.max_size = max_size
|
||||
self.id = 0
|
||||
self.header_size = PacketHeader().header_size
|
||||
self.log = print
|
||||
if log != None:
|
||||
self.log = log
|
||||
|
||||
def to_packets(self, msg):
|
||||
packets = []
|
||||
length = len(msg)
|
||||
capacity = self.max_size - self.header_size
|
||||
num_packets = int((length + capacity - 1) / capacity)
|
||||
for i in range(num_packets):
|
||||
header = PacketHeader(self.id, i, num_packets)
|
||||
packet = Packet(header, msg[i*capacity:(i+1)*capacity])
|
||||
packets.append(packet)
|
||||
self.id += 1
|
||||
return packets
|
||||
|
||||
def from_packets(self, packets):
|
||||
num_packet = len(packets)
|
||||
self.log('Packet Number:', num_packet)
|
||||
if num_packet == 0:
|
||||
return None
|
||||
msg_id = None
|
||||
packet_id = 0
|
||||
message = b''
|
||||
|
||||
for packet in packets:
|
||||
header = packet.header
|
||||
if msg_id is None:
|
||||
msg_id = header.msg_id
|
||||
if num_packet != header.pkt_num:
|
||||
self.log('Uncorrect Package Number')
|
||||
self.log('get {0} while it should be {1}'.format(header.pkt_num, num_packet))
|
||||
return None
|
||||
|
||||
if msg_id != header.msg_id:
|
||||
self.log('Uncorrect Message id')
|
||||
self.log('get {0} while it should be {1}'.format(header.pkt_num, msg_id))
|
||||
return None
|
||||
if packet_id != header.pkt_id:
|
||||
self.log('Uncorrect pkt id')
|
||||
self.log('get {0} while it should be {1}'.format(header.pkt_id, packet_id))
|
||||
return None
|
||||
message += packet.msg
|
||||
packet_id += 1
|
||||
return message
|
||||
|
||||
class AuthMessage(SocketMessage):
|
||||
def __init__(self, token='None'):
|
||||
super(AuthMessage, self).__init__()
|
||||
self.token = token
|
||||
self.type = 'auth'
|
||||
self.stat = 0
|
||||
# stat:
|
||||
# 0: auth client
|
||||
# 1: auth success
|
||||
# 2: auth failed.
|
||||
|
||||
def encode(self):
|
||||
self.data = {}
|
||||
self.data['token'] = self.token
|
||||
self.data['status'] = self.stat
|
||||
|
||||
def decode(self):
|
||||
self.token = self.data['token']
|
||||
self.stat = self.data['status']
|
||||
|
||||
class SocketConnection():
|
||||
def __init__(self, daemon, log_prefix=''):
|
||||
self.sock = None
|
||||
self.daemon = daemon
|
||||
self.messages = queue.Queue()
|
||||
self.terminated = False
|
||||
self.factory = PacketFactory(log=self.log)
|
||||
self.header_size = PacketHeader().header_size
|
||||
self.loglevel = 2
|
||||
self.log_prefix = log_prefix
|
||||
|
||||
# log level:
|
||||
# 0 : debug
|
||||
# 1 : message or info
|
||||
# 2 : warning
|
||||
# 3 : error
|
||||
def log(self, *args, level=0, end='\n'):
|
||||
if level >= self.loglevel:
|
||||
print(self.log_prefix, end=' ')
|
||||
for msg in args:
|
||||
print(str(msg), end=' ')
|
||||
print(end, end='')
|
||||
|
||||
def start(self):
|
||||
self.jid = self.daemon.add_job(self.recv_bare, name='connection')
|
||||
|
||||
def SetSock(self, sock):
|
||||
self.sock = sock
|
||||
self.start()
|
||||
|
||||
def connect(self, host, port):
|
||||
if self.sock is None:
|
||||
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.sock.connect((host, port))
|
||||
self.log('Sock Connected', level=1)
|
||||
else:
|
||||
self.log('Socket has already established, ingoring connect.', level=2)
|
||||
self.start()
|
||||
|
||||
def auth(self, token):
|
||||
if self.sock is None:
|
||||
self.log('Socket not established, unable to auth.', level=3)
|
||||
return None
|
||||
auth_msg = AuthMessage(token)
|
||||
msg = auth_msg.tobytes()
|
||||
self.log('Sending Authentication Message.')
|
||||
self.send(msg)
|
||||
self.log('Waiting Authentication Status.')
|
||||
msg = self.recv()
|
||||
auth_msg.frombytes(msg)
|
||||
self.log('Status Recived! auth_status =', auth_msg.__dict__)
|
||||
if auth_msg.stat == 1:
|
||||
self.log('Auth Success!', level=1)
|
||||
else:
|
||||
self.log('Auth Failed!', level=3)
|
||||
self.close()
|
||||
|
||||
def WaitAuth(self, token):
|
||||
msg = self.recv()
|
||||
auth_msg = AuthMessage()
|
||||
auth_msg.frombytes(msg)
|
||||
if auth_msg.token != token:
|
||||
self.log('Authentication Failed!', level=1)
|
||||
auth_msg = AuthMessage('InvalidAuth')
|
||||
auth_msg.stat = 2
|
||||
self.send(auth_msg.tobytes())
|
||||
self.close()
|
||||
else:
|
||||
self.log('Authentication Success!', level=1)
|
||||
auth_msg = AuthMessage('Welcome')
|
||||
auth_msg.stat = 1
|
||||
self.send(auth_msg.tobytes())
|
||||
|
||||
def send(self, msg):
|
||||
if self.terminated:
|
||||
return False
|
||||
packets = self.factory.to_packets(msg)
|
||||
self.log('spliting message to {0} packets'.format(len(packets)))
|
||||
for packet in packets:
|
||||
self.sock.send(packet.tobytes())
|
||||
return True
|
||||
|
||||
def commit_message(self, packets):
|
||||
self.log('Generating final packet...')
|
||||
full_msg = self.factory.from_packets(packets)
|
||||
if full_msg is not None:
|
||||
self.log('Valid package!')
|
||||
self.messages.put(full_msg)
|
||||
else:
|
||||
self.log('Invalid package')
|
||||
|
||||
def recv_bare(self):
|
||||
msg_id = None
|
||||
packets = []
|
||||
while not self.terminated:
|
||||
raw_msg = None
|
||||
try:
|
||||
raw_msg = self.sock.recv(8192)
|
||||
except (ConnectionAbortedError, ConnectionResetError):
|
||||
self.log('Connection Stopped')
|
||||
self.close()
|
||||
continue
|
||||
|
||||
if raw_msg is None or len(raw_msg) == 0:
|
||||
self.close()
|
||||
continue
|
||||
if len(raw_msg) < self.header_size:
|
||||
continue
|
||||
|
||||
pkt = Packet(PacketHeader())
|
||||
pkt.frombytes(raw_msg)
|
||||
packets.append(pkt)
|
||||
|
||||
if msg_id is None:
|
||||
msg_id = pkt.header.msg_id
|
||||
|
||||
if msg_id != pkt.header.msg_id:
|
||||
msg_id = pkt.header.msg_id
|
||||
packets = [pkt]
|
||||
continue
|
||||
|
||||
if pkt.header.pkt_id == pkt.header.pkt_num - 1:
|
||||
self.log('Finished')
|
||||
self.commit_message(packets)
|
||||
msg_id = None
|
||||
packets = []
|
||||
|
||||
|
||||
def recv(self):
|
||||
self.log('Getting messages')
|
||||
msg = None
|
||||
while msg is None and not self.terminated:
|
||||
try:
|
||||
msg = self.messages.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
msg = None
|
||||
continue
|
||||
return msg
|
||||
self.log('Message Get Finished')
|
||||
|
||||
def close(self):
|
||||
self.terminated = True
|
||||
self.sock.close()
|
||||
|
||||
|
||||
class HttpServer(SocketServer):
|
||||
def __init__(self):
|
||||
super(HttpServer, self).__init__(port=80)
|
||||
self.header = 'HTTP/1.1 200 OK\nServer: NaiveHttpServer\nConnection: close\nContent-Length: {0}\nContent-Type: text/html\n\n'
|
||||
|
||||
def handle_message(self, clientsocket, addr):
|
||||
# print('Connection Established from clinet', addr)
|
||||
msg = clientsocket.recv(8192)
|
||||
text = msg.decode('utf-8')
|
||||
while '\n' in text:
|
||||
text = text.replace('\n', '<br>')
|
||||
info = '<html><body>'
|
||||
info += '<h1> Hello, World </h1>'
|
||||
info += '<h2> Your IP Address & Port</h2>\n'
|
||||
info += '<p>{0}</p>'.format(str(addr))
|
||||
info += '<h2> Your Request</h2>\n'
|
||||
info += '<p>{0}</p>\n'.format(text)
|
||||
info += '</body></html>\n'
|
||||
length = len(info)
|
||||
# print('length =', length)
|
||||
header = self.header.format(length)
|
||||
msg = header + info
|
||||
# print('msg:', msg)
|
||||
clientsocket.send(msg.encode('utf-8'))
|
||||
# print('message sent!')
|
||||
clientsocket.close()
|
||||
|
||||
class BridgeConnection():
|
||||
def __init__(self, client, server, num_threads=8):
|
||||
self.client = client
|
||||
self.server = server
|
||||
self.daemon = parallel.ParallelHost(num_threads)
|
||||
self.exit = False
|
||||
self.client_terminated = False
|
||||
self.server_terminated = False
|
||||
|
||||
def is_terminated(self):
|
||||
if self.client_terminated:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def client_recv_handler(self, msg):
|
||||
# print('send message to server, size=', len(msg))
|
||||
self.server.send(msg)
|
||||
|
||||
def stop(self):
|
||||
self.daemon.stop('kill')
|
||||
self.server.close()
|
||||
print('connection terminated successfully.')
|
||||
|
||||
def client_recv(self):
|
||||
while True:
|
||||
msg = self.client.recv(8192)
|
||||
if len(msg) == 0:
|
||||
self.exit = True
|
||||
break
|
||||
# print('message received from clinet, size=', len(msg))
|
||||
self.daemon.add_job(self.client_recv_handler, args=[msg])
|
||||
self.client_terminated = True
|
||||
print('Client Recv terminated.')
|
||||
|
||||
def server_recv_handler(self, msg):
|
||||
# print('send message to client, size=', len(msg))
|
||||
self.client.send(msg)
|
||||
|
||||
def server_recv(self):
|
||||
while not self.exit:
|
||||
try:
|
||||
msg = self.server.recv(8192)
|
||||
except ConnectionAbortedError:
|
||||
break
|
||||
if len(msg) == 0:
|
||||
self.exit = True
|
||||
continue
|
||||
# print('message received from server, size=', len(msg))
|
||||
self.daemon.add_job(self.server_recv_handler, args=[msg])
|
||||
self.server_terminated = True
|
||||
print('Server Recv terminated.')
|
||||
|
||||
def run(self):
|
||||
self.daemon.add_job(self.server_recv)
|
||||
self.daemon.add_job(self.client_recv)
|
||||
while not self.is_terminated():
|
||||
time.sleep(1)
|
||||
self.stop()
|
||||
|
||||
def SendMessage(host, port, msg):
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
s.connect((host, port))
|
||||
s.send(msg)
|
||||
print('Message Sent!')
|
||||
s.close()
|
||||
|
||||
|
||||
|
||||
def portforward(dst_ip, dst_port, listen_ip, listen_port):
|
||||
def port_forward_handler(clientsocket, addr):
|
||||
print('Handling message from clinet', addr)
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
s.connect((dst_ip, dst_port))
|
||||
print('Connect to dst server success!')
|
||||
connection = BridgeConnection(clientsocket, s)
|
||||
connection.run()
|
||||
print('handler exited')
|
||||
server = SocketServer(message_handler=port_forward_handler, ip=listen_ip, port=listen_port)
|
||||
server.start()
|
||||
|
||||
# if __name__ == '__main__':
|
||||
# # let's try a port forwarding server using this socker server.
|
||||
# portforward('192.168.233.101', 22, '0.0.0.0', 30001)
|
||||
# portforward('192.168.233.102', 22, '0.0.0.0', 30002)
|
||||
|
||||
# con = console.console('PortForward')
|
||||
# con.interactive()
|
||||
|
||||
if __name__ == '__main__':
|
||||
server = HttpServer()
|
||||
server.start()
|
||||
import time
|
||||
time.sleep(10000)
|
||||
Reference in New Issue
Block a user