package udp
import (
"errors"
"fmt"
"go.dedis.ch/cs438/transport"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
)
// bufSize is the size of the buffer Recv reads each incoming datagram into.
// The largest UDP payload is 65507 bytes over IPv4 and 65527 bytes over IPv6
// (without jumbograms). A datagram that does not fit is truncated by the read
// and would then fail to unmarshal, so the buffer covers the full 16-bit
// datagram length.
const bufSize = 65535
// NewUDP returns a new udp transport implementation. UDP has no fields of its
// own to initialise; sockets are opened on demand through CreateSocket.
func NewUDP() transport.Transport {
	t := new(UDP)
	return t
}
// UDP implements a transport layer using UDP.
//
// - implements transport.Transport
//
// NOTE(review): NewUDP returns &UDP{}, so the embedded transport.Transport is
// left nil in this file. It appears to serve only as a marker for interface
// conformance. Any interface method that UDP does not define itself would be
// promoted from that nil interface and panic when called. A compile-time check
// such as `var _ transport.Transport = (*UDP)(nil)` would state the intent
// without that risk.
type UDP struct {
transport.Transport
}
// CreateSocket implements transport.Transport. It binds a UDP packet listener
// on address and wraps it in a Socket. The socket records the address actually
// bound, so a ":0" request reports the port the system picked.
func (n *UDP) CreateSocket(address string) (transport.ClosableSocket, error) {
	pc, err := net.ListenPacket("udp", address)
	if err != nil {
		return nil, fmt.Errorf("failed listening: %w", err)
	}

	sock := &Socket{conn: pc}
	sock.addr = pc.LocalAddr().String()
	return sock, nil
}
// Packet is a mutex-guarded list of transport packets. Despite its name it
// holds many packets, not one. Socket keeps one instance for received packets
// (ins) and one for sent packets (outs). Callers lock the embedded Mutex
// before reading or writing data.
type Packet struct {
sync.Mutex
// data holds copies of the recorded packets, in the order they were appended.
data []transport.Packet
}
// Socket implements a network socket using UDP.
//
// - implements transport.Socket
// - implements transport.ClosableSocket
//
// NOTE(review): CreateSocket sets only addr and conn, so the embedded
// interfaces stay nil. They appear to act only as conformance markers.
type Socket struct {
transport.Socket
transport.ClosableSocket
// addr is the local address reported by conn.LocalAddr when the socket was created.
addr string
// conn is the underlying UDP packet connection used by both Send and Recv.
conn net.PacketConn
// ins records a copy of every packet successfully returned by Recv.
ins Packet
// outs records a copy of every packet successfully written by Send.
outs Packet
}
// Close implements transport.Socket by closing the underlying packet
// connection. It returns an error if already closed, as reported by
// net.PacketConn.Close.
func (s *Socket) Close() error {
	return s.conn.Close()
}
// Send implements transport.Socket. It serializes pkt and writes it as a single
// UDP datagram to dest.
//
// dest must have the form "ip:port" with exactly one colon. Anything else,
// including bracketed IPv6 literals such as "[::1]:80", is rejected with a
// *net.AddrError. A positive timeout bounds the write, and an elapsed timeout
// is reported as transport.TimeoutError. A timeout <= 0 means the write has no
// deadline. On success, a copy of pkt is appended to the log returned by
// GetOuts.
func (s *Socket) Send(dest string, pkt transport.Packet, timeout time.Duration) error {
	data, err := pkt.Marshal()
	if err != nil {
		return fmt.Errorf("failed packet serialization: %w", err)
	}

	split := strings.Split(dest, ":")
	if len(split) != 2 {
		return &net.AddrError{Err: "invalid address", Addr: dest}
	}
	port, err := strconv.Atoi(split[1])
	if err != nil {
		return fmt.Errorf("failed address parsing: %w", err)
	}
	// NOTE(review): net.ParseIP returns nil for anything that is not an IP
	// literal (e.g. "localhost"), which leaves addr.IP nil. Confirm that callers
	// only pass numeric hosts.
	addr := &net.UDPAddr{
		Port: port,
		IP:   net.ParseIP(split[0]),
	}

	// Write deadlines persist on the connection across calls. When this call
	// asks for no timeout, any deadline left by an earlier Send (possibly
	// already expired) must be cleared with the zero time. Otherwise the write
	// below could fail with a spurious timeout.
	var deadline time.Time
	if timeout > 0 {
		deadline = time.Now().Add(timeout)
	}
	if err := s.conn.SetWriteDeadline(deadline); err != nil {
		return err
	}

	if _, err := s.conn.WriteTo(data, addr); err != nil {
		if errors.Is(err, os.ErrDeadlineExceeded) {
			return transport.TimeoutError(timeout)
		}
		return fmt.Errorf("failed writing to connnection: %w", err)
	}

	// Record the packet only after the write succeeded.
	s.outs.Lock()
	defer s.outs.Unlock()
	s.outs.data = append(s.outs.data, pkt.Copy())
	return nil
}
// Recv implements transport.Socket. It blocks until a packet is received, or
// the timeout is reached. In the case the timeout is reached, return a
// TimeoutErr.
//
// A timeout <= 0 means Recv waits indefinitely. Each datagram is read into a
// buffer of bufSize bytes and unmarshalled into a transport.Packet. Every
// successfully decoded packet is also appended, as a copy, to the log returned
// by GetIns.
func (s *Socket) Recv(timeout time.Duration) (transport.Packet, error) {
	// Read deadlines persist on the connection across calls. When this call
	// asks for no timeout, any deadline left by an earlier Recv must be cleared
	// with the zero time. Otherwise an already-expired deadline would make the
	// read fail immediately instead of blocking.
	var deadline time.Time
	if timeout > 0 {
		deadline = time.Now().Add(timeout)
	}
	if err := s.conn.SetReadDeadline(deadline); err != nil {
		return transport.Packet{}, err
	}

	buf := make([]byte, bufSize)
	n, _, err := s.conn.ReadFrom(buf)
	if err != nil {
		if errors.Is(err, os.ErrDeadlineExceeded) {
			return transport.Packet{}, transport.TimeoutError(timeout)
		}
		return transport.Packet{}, fmt.Errorf("failed reading from connnection: %w", err)
	}

	var pkt transport.Packet
	if err := pkt.Unmarshal(buf[:n]); err != nil {
		return transport.Packet{}, fmt.Errorf("failed unmarshalling: %w", err)
	}

	s.ins.Lock()
	defer s.ins.Unlock()
	s.ins.data = append(s.ins.data, pkt.Copy())
	return pkt, nil
}
// GetAddress implements transport.Socket. It reports the local address the
// socket was bound to, captured when the socket was created. This is how a
// caller that asked for a ":0" address learns which free port the system chose.
func (s *Socket) GetAddress() string {
	bound := s.addr
	return bound
}
// GetIns implements transport.Socket. It returns a snapshot of every packet
// received so far. Each entry is an independent copy, so callers may modify
// the result without affecting the socket's internal log. The result is never
// nil.
func (s *Socket) GetIns() []transport.Packet {
	s.ins.Lock()
	defer s.ins.Unlock()

	snapshot := make([]transport.Packet, 0, len(s.ins.data))
	for _, received := range s.ins.data {
		snapshot = append(snapshot, received.Copy())
	}
	return snapshot
}
// GetOuts implements transport.Socket. It returns a snapshot of every packet
// successfully sent so far. Each entry is an independent copy, so callers may
// modify the result without affecting the socket's internal log. The result is
// never nil.
func (s *Socket) GetOuts() []transport.Packet {
	s.outs.Lock()
	defer s.outs.Unlock()

	snapshot := make([]transport.Packet, 0, len(s.outs.data))
	for _, sent := range s.outs.data {
		snapshot = append(snapshot, sent.Copy())
	}
	return snapshot
}