Untitled
unknown
golang
2 years ago
3.9 kB
2
Indexable
Never
package udp

import (
	"errors"
	"fmt"
	"net"
	"os"
	"sync"
	"time"

	"go.dedis.ch/cs438/transport"
)

// bufSize is the size of the buffer used to read one datagram. The UDP length
// field is 16 bits, so no datagram payload can exceed this; a smaller buffer
// would let ReadFrom silently truncate large datagrams.
const bufSize = 65535

// NewUDP returns a new udp transport implementation.
func NewUDP() transport.Transport {
	return &UDP{}
}

// UDP implements a transport layer using UDP.
//
// - implements transport.Transport
type UDP struct {
	transport.Transport
}

// CreateSocket implements transport.Transport. It binds a UDP socket on
// address (for example "127.0.0.1:0" to let the system pick a free port) and
// returns it wrapped in a *Socket.
func (n *UDP) CreateSocket(address string) (transport.ClosableSocket, error) {
	conn, err := net.ListenPacket("udp", address)
	if err != nil {
		return nil, fmt.Errorf("failed listening: %w", err)
	}

	return &Socket{
		conn: conn,
		addr: conn.LocalAddr().String(),
	}, nil
}

// Packet is a mutex-protected list of packets. A Socket uses one to record
// the packets it received and one to record the packets it sent.
type Packet struct {
	sync.Mutex
	data []transport.Packet
}

// add records a copy of pkt, so later mutations by the caller do not leak in.
func (p *Packet) add(pkt transport.Packet) {
	p.Lock()
	defer p.Unlock()

	p.data = append(p.data, pkt.Copy())
}

// snapshot returns a copy of every recorded packet, in the order they were
// added, so the caller never aliases the internal slice.
func (p *Packet) snapshot() []transport.Packet {
	p.Lock()
	defer p.Unlock()

	out := make([]transport.Packet, len(p.data))
	for i, pkt := range p.data {
		out[i] = pkt.Copy()
	}
	return out
}

// Socket implements a network socket using UDP.
//
// - implements transport.Socket
// - implements transport.ClosableSocket
type Socket struct {
	transport.Socket
	transport.ClosableSocket

	addr string
	conn net.PacketConn

	ins  Packet
	outs Packet
}

// timeoutDeadline converts a relative timeout into the absolute deadline
// expected by net.PacketConn.SetReadDeadline/SetWriteDeadline. A timeout <= 0
// yields the zero time.Time, which clears any previously set deadline so the
// operation never times out.
func timeoutDeadline(timeout time.Duration) time.Time {
	if timeout <= 0 {
		return time.Time{}
	}
	return time.Now().Add(timeout)
}

// Close implements transport.Socket. It returns an error if already closed.
func (s *Socket) Close() error {
	return s.conn.Close()
}

// Send implements transport.Socket. It serializes pkt and writes it as one
// datagram to dest, a "host:port" address (IPv6 hosts in brackets, e.g.
// "[::1]:1234"; host names are resolved before the timeout starts). A
// timeout <= 0 means the write never times out; if the timeout is reached a
// transport.TimeoutError is returned. On success a copy of pkt is recorded
// and later returned by GetOuts.
func (s *Socket) Send(dest string, pkt transport.Packet, timeout time.Duration) error {
	data, err := pkt.Marshal()
	if err != nil {
		return fmt.Errorf("failed packet serialization: %w", err)
	}

	// ResolveUDPAddr handles IPv6 literals and host names, which a naive split
	// on ":" followed by net.ParseIP does not.
	addr, err := net.ResolveUDPAddr("udp", dest)
	if err != nil {
		return fmt.Errorf("failed address parsing: %w", err)
	}

	// Always (re)set the deadline: one left over from an earlier timed call
	// would otherwise make an untimed write fail spuriously.
	err = s.conn.SetWriteDeadline(timeoutDeadline(timeout))
	if err != nil {
		return err
	}

	_, err = s.conn.WriteTo(data, addr)
	if err != nil {
		if errors.Is(err, os.ErrDeadlineExceeded) {
			return transport.TimeoutError(timeout)
		}
		return fmt.Errorf("failed writing to connection: %w", err)
	}

	s.outs.add(pkt)
	return nil
}

// Recv implements transport.Socket. It blocks until a packet is received, or
// the timeout is reached. In the case the timeout is reached, return a
// TimeoutErr. A timeout <= 0 means Recv waits without a deadline. Every
// successfully decoded packet is recorded and later returned by GetIns.
func (s *Socket) Recv(timeout time.Duration) (transport.Packet, error) {
	// Always (re)set the deadline: one left over from an earlier timed call
	// would otherwise make an untimed read fail immediately.
	err := s.conn.SetReadDeadline(timeoutDeadline(timeout))
	if err != nil {
		return transport.Packet{}, err
	}

	buf := make([]byte, bufSize)
	n, _, err := s.conn.ReadFrom(buf)
	if err != nil {
		if errors.Is(err, os.ErrDeadlineExceeded) {
			return transport.Packet{}, transport.TimeoutError(timeout)
		}
		return transport.Packet{}, fmt.Errorf("failed reading from connection: %w", err)
	}

	var pkt transport.Packet
	err = pkt.Unmarshal(buf[:n])
	if err != nil {
		return transport.Packet{}, fmt.Errorf("failed unmarshalling: %w", err)
	}

	s.ins.add(pkt)
	return pkt, nil
}

// GetAddress implements transport.Socket. It returns the address assigned. Can
// be useful in the case one provided a :0 address, which makes the system use a
// random free port.
func (s *Socket) GetAddress() string {
	return s.addr
}

// GetIns implements transport.Socket. It returns copies of all packets
// received so far, in the order Recv recorded them.
func (s *Socket) GetIns() []transport.Packet {
	return s.ins.snapshot()
}

// GetOuts implements transport.Socket. It returns copies of all packets sent
// so far, in the order Send recorded them.
func (s *Socket) GetOuts() []transport.Packet {
	return s.outs.snapshot()
}