노트패드(Notepad)로 적다가 구글 문서에 붙여 넣었더니 탭, 한글 띄어쓰기 등이 다 깨져 버렸다.
정리하다가 귀찮아서 방치..
테스트 코드는 Python으로 작성했다. Twisted 등을 사용하려니 왠지 내가 구현하고자 하는 초점이 그게 아닌 것 같아서, 그냥 블로킹 모드 소켓에 타임아웃을 걸어서 처리했다.
WebSocket(앞으로 줄여서 ws라고 함)의 구현을 정의한 RFC 문서는 두 가지 버전이 있다.
문제는 아직도 draft 상태라서 언제 또 확 뒤집힐지 모른다는 것이다.
나는 websocket-bridge의 master 브랜치에서 구현하고 있는 hixie-76 문서를 보고 서버 측을 구현했으며, 관련 샘플 코드는 아래에서 볼 수 있다.
# -*- coding: utf-8 -*-
import sys, os, traceback
import socket
import struct
from md5 import *
class WebSocket(object):
# state
STAND_BY = 0
ESTABLISHED = 1
CLOSING = 2
CLOSED = 3
sk = None
_state = STAND_BY
req = None
class Exception(Exception):
msg = ""
addr = None
def __init__(self, addr="", msg=""):
self.addr = addr
self.msg = unicode(msg)
def __unicode__(self):
return u"[" + self.addr + u"] " + self.msg
def __str_(self):
return str(self.__unicode__().encode('utf-8'))
def __init__(self, sk):
self.sk = sk
self.sk.settimeout(0.5)
self._openning_handshake()
def _openning_handshake(self):
try:
addr = self.sk.getsockname()
print "client's openning handshake"
try:
t = self.sk.recv(1024)
if t[:4] != "GET ":
print "unknown request"
self.abort()
t_len = len(t)
while t_len < 4096 and t.find("\r\n\r\n") == -1:
t += self.sk.recv(1024)
t_len = len(t)
except socket.timeout:
self.abort()
req = {}
# 여기서 부터는 완전한 헤더가 완성됨
req['key_3'] = t[t.find('\r\n\r\n') + 4:]
header = t[:t.find('\r\n\r\n') + 2].split('\r\n')
request_line = header[0]
fields = header[1:]
print fields
fields = tuple(
map(
lambda x: (x[0].lower(), x[1]),
[map(lambda x: x.strip(), line.split(':', 1)) for line in fields if len(line) > 0]
)
)
# request line
req['resource_name'] = tuple(request_line.split(' '))[1]
if req['resource_name'][0] != '/':
print "resource var start with a soliders character"
self.abort()
#make server's openning handshake data
req['fields'] = dict(fields)
req['subprotocol'] = req['fields']['sec-websocket-protocol'] \
if req['fields'].has_key('sec-websocket-protocol') \
else None
req['subprotocol_ex'] = 'Sec-WebSocket-Protocol: %s\r\n' % req['subprotocol'] \
if req['subprotocol'] \
else ""
req['key_1'] = req['fields']['sec-websocket-key1']
req['key_2'] = req['fields']['sec-websocket-key2']
req['origin'] = req['fields']['origin']
req['host'] = req['fields']['host']
req['secure_flag'] = False
req['location'] = 'ws://' + req['host'] + req['resource_name']
key_n1 = int(filter(lambda x: x.isdigit(), req['key_1']))
key_n2 = int(filter(lambda x: x.isdigit(), req['key_2']))
sp_n1 = req['key_1'].count(' ')
sp_n2 = req['key_2'].count(' ')
if sp_n1 == 0 or sp_n2 == 0:
print "WARNING: cross-protocol intrusion detection"
self.abort()
if key_n1 % sp_n1 != 0 or key_n2 % sp_n2 != 0:
print "no WebSocket client"
self.abort()
part_n1 = key_n1 / sp_n1
part_n2 = key_n2 / sp_n2
req['challenge'] = struct.pack('>LLQ', part_n1, part_n2, int(req['key_3'].encode('hex'), 16))
if len(req['challenge']) != 16:
print "challenge failed"
self.abort()
req['response'] = md5(req['challenge']).digest()
if not req['fields'].has_key('origin'):
print "not found origin field in header"
self.abort()
base_hands = '\
HTTP/1.1 101 WebSocket Protocol Handshake\r\n\
Upgrade: WebSocket\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-Location: %(location)s\r\n\
Sec-WebSocket-Origin: %(origin)s\r\n\
%(subprotocol_ex)s\r\n\
' % req
print "request from", self.sk.getpeername(), "with \n", t
print "response to", self.sk.getpeername(), "with \n", base_hands, req['response']
self.sk.send(base_hands + req['response'])
self.sk.settimeout(60)
#self.req = req
self._state = WebSocket.ESTABLISHED
except WebSocket.Exception:
print '[', addr, ']', "WebSocket exception"
#except:
# print '[', addr, ']', "unknown exception"
def abort(self):
if self.sk:
self.sk.close()
self.sk = None
self._state = WebSocket.CLOSED
raise WebSocket.Exception()
def _closing_handshake(self):
if self.state == WebSocket.ESTABLISHED:
self.sk.send("\xff\x00")
self.sk.shutdown(socket.SHUT_WR)
self.sk.settimeout(0.5)
try:
while self.sk.recv(1024): pass # dummy op
except:
pass
self.sk.close()
self._state = WebSocket.CLOSING
def close(self):
if self.state == WebSocket.ESTABLISHED:
self._closing_handshake()
elif self.sk:
self.sk.close()
self.sk = None
self._state = WebSocket.CLOSED
@property
def state(self):
return self._state
def send(self, data):
if self.state == WebSocket.ESTABLISHED:
return self.sk.send('\x00' + data + '\xff')
return -1
def recv(self):
if self.state != WebSocket.ESTABLISHED:
return -1
ret = []
buf = self.sk.recv(1024)
if not buf: # disconnect from client
self.abort()
while 1:
type = ord(buf[0])
if type & 0x80 == 0x00: # type of text
if type != 0:
self.abort()
buf = buf[1:]
eos = buf.find('\xff')
while len(buf) < 4096 and eos == -1:
tmp = self.sk.recv(1024)
if not tmp: self.abort()
buf += tmp
eos = buf.find('\xff')
if eos == -1:
print 'overflow stream'
self.abort() # overflow stream
else:
ret.append(buf[:eos])
buf = buf[eos+1:]
if len(buf) == 0: # end of stream
return ret
elif type & 0x80 != 0x00: # type of binary
if type != 0xFF:
self.abort()
buf = buf[1:]
if ord(buf[0]) == 0x00: # closing handshake
self.close()
return -1
# get length
length = 0
recv_len = len(buf)
while recv_len < 32: # maximum
eol = False
for i, b in enumerate(buf):
b_v = ord(b) & 0x7F
length = length * 128 + b_v
if not ord(b) & 0x80:
eol = True
buf = buf[i+1:]
break
if not eol:
buf = self.sk.recv(1024)
if not buf: self.abort()
recv_len += len(buf)
if not eol: self.abort()
# get data
while len(buf) < length:
tmp = self.sk.recv(1024)
if not tmp: self.abort()
buf += tmp
ret.append(buf[:length])
buf = buf[length+1:]
if len(buf) == 0:
return ret
else:
# unreachable area
self.abort()
def main():
s=socket.socket()
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind(("", 9998))
print 'listen to', s.getsockname()
s.listen(5)
while True:
try:
print 'wait for connection..'
c=s.accept()
print c[1], "accepted"
ws = WebSocket(c[0])
while True:
print "wait for recv.."
frames = ws.recv()
if frames == -1:
print 'disconnect from', c[1]
ws.close()
break
print "echo: ", frames
for frame in frames: ws.send(frame)
except KeyboardInterrupt:
s.close()
sys.exit(0)
except WebSocket.Exception as e:
print 'ws: abort', e;
except:
c[0].close()
print traceback.print_exc()
if __name__ == "__main__":
main()
코드가 지저분한 이유는 RFC 명세에 나온 변수명을 그대로(또는 비슷하게) 가져다 썼기 때문이다.
간단한 채팅 정도야 만들 수 있지만, 이건 말 그대로 스펙을 구현해 보려는 샘플이기 때문에 실제로 사용하려면 Twisted 모듈을 사용해서 구현해야겠다. 다만 구글에서 찾아보면 이미 구현체가 많이 나오기에 딱히 직접 구현하고 싶지는 않다.. -_-
댓글을 달아 주세요