diff --git a/chat.py b/chat.py index 9701021..ce0f332 100644 --- a/chat.py +++ b/chat.py @@ -1,60 +1,87 @@ import socket import threading import protocol +import sys LOCAL_PORT = 9000 -TIMEOUT = 2 +TIMEOUT = 2 # 秒 + class StopWaitChat: def __init__(self, local_port, peer_addr): - self.sock = socket.socket( - socket.AF_INET6, - socket.SOCK_DGRAM - ) + # IPv6 UDP socket + self.sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) + + # 绑定本地 IPv6 任意地址 self.sock.bind(("::", local_port)) + + # 设置 socket 超时(用于停等协议) self.sock.settimeout(TIMEOUT) self.peer_addr = peer_addr + + # 停等协议序号 self.send_seq = 0 self.recv_seq = 0 def sender(self): + """ + 发送线程(停等协议) + """ while True: - msg = input(">> ").encode() + try: + msg = input(">> ").encode() + except EOFError: + break + pkt = protocol.make_packet( self.send_seq, protocol.TYPE_DATA, msg ) + # 停等:一直发,直到收到正确 ACK while True: self.sock.sendto(pkt, self.peer_addr) + try: data, _ = self.sock.recvfrom(1024) seq, ptype, _ = protocol.parse_packet(data) + if ptype == protocol.TYPE_ACK and seq == self.send_seq: + # 正确 ACK self.send_seq ^= 1 break + except socket.timeout: print("[timeout] retransmit") def receiver(self): + """ + 接收线程(必须捕获 timeout) + """ while True: - data, addr = self.sock.recvfrom(1024) + try: + data, addr = self.sock.recvfrom(1024) + except socket.timeout: + continue + seq, ptype, payload = protocol.parse_packet(data) if ptype == protocol.TYPE_DATA: if seq == self.recv_seq: print("\npeer:", payload.decode()) self.recv_seq ^= 1 + + # 无论是否重复,都返回 ACK ack = protocol.make_packet(seq, protocol.TYPE_ACK) self.sock.sendto(ack, addr) -if __name__ == "__main__": - import sys + +def main(): if len(sys.argv) != 3: - print("用法: python3 chat.py <对方IPv6地址> <对方端口>") - exit(1) + print("用法: python3 chat.py <对方IPv6地址或域名> <对方端口>") + sys.exit(1) peer_ip = sys.argv[1] peer_port = int(sys.argv[2]) @@ -64,5 +91,16 @@ if __name__ == "__main__": (peer_ip, peer_port) ) - threading.Thread(target=chat.receiver, daemon=True).start() + # 启动接收线程 + recv_thread = threading.Thread( + target=chat.receiver, + daemon=True + ) + recv_thread.start() + + # 主线程负责发送 chat.sender() + + +if __name__ == "__main__": + main()