Untitled

 avatar
unknown
golang
2 years ago
12 kB
3
Indexable
// Contains the implementation of an LSP client.

// heartbeat added, working on exponential backoff

package lsp

import (
	"container/list"
	"encoding/json"
	"fmt"
	"github.com/cmu440/lspnet"
	"sort" 
	"time"
)

// client holds all state for one LSP client connection. After NewClient
// returns, MainRoutine is the goroutine that mutates the bookkeeping fields;
// the exported methods and ReadRoutine communicate with it over the channels
// declared here.
type client struct {
	udpConn     *lspnet.UDPConn // UDP socket dialed to the server
	connID      int             // connection ID taken from the server's connect Ack
	seqNumCli   int             // last sequence number used by this client (starts at the ISN)
	seqNumSer   int             // sequence number of the next server data message to deliver
	msgList     *list.List      // buffered server data messages (*Message), ordered by SeqNum
	readPayload chan []byte     // MainRoutine -> Read: the next in-order payload
	writeSig    chan []byte     // allocated in NewClient
	readAckMsg  chan *Message   // ReadRoutine -> MainRoutine: non-data messages
	readDataMsg chan *Message   // ReadRoutine -> MainRoutine: data messages
	readCalled  chan bool       // Read -> MainRoutine: a Read call is waiting

	// Sliding-window bookkeeping.
	maxUnackedMessages    int        // value of Params.MaxUnackedMessages
	windowSize            int        // value of Params.WindowSize
	unackedMessageSeqList []int      // sequence numbers awaiting an Ack, kept sorted ascending
	writeQueue            *list.List // allocated in NewClient
	unwrittenList         *list.List // payloads ([]byte) accepted by Write but not yet sent

	resendList *list.List // one *MessageInfo per sent, still unacknowledged data message

	// Epoch bookkeeping.
	epochLimit         int // value of Params.EpochLimit
	epochMillis        int // value of Params.EpochMillis
	maxBackOffInterval int // value of Params.MaxBackOffInterval
	epochCount         int // epochs counted since the server was last heard from

	tryWrite    chan []byte   // Write -> MainRoutine: payload to queue for sending
	canWrite    chan *Message // allocated in NewClient
	updateWrite chan int      // allocated in NewClient

	epochTicker <-chan time.Time // fires once per epoch

	curEpochTime int // epoch counter; MainRoutine adds 1 on every tick
}


// MessageInfo is the retransmission record kept for one data message that
// has been sent but not yet acknowledged.
type MessageInfo struct {
	payload       []byte // bytes to put in the data message when it is resent
	seqNum        int    // sequence number the message was sent with
	curBackOff    int    // epochs added to lastSentEpoch to decide when the next resend is due
	lastSentEpoch int    // epoch-counter value recorded when the message was last resent
}


// NewClient creates, initiates, and returns a new client. This function
// should return after a connection with the server has been established
// (i.e., the client has received an Ack message from the server in response
// to its connection request), and should return a non-nil error if a
// connection could not be made (i.e., if after K epochs, the client still
// hasn't received an Ack message from the server in response to its K
// connection requests).
//
// The connect request is sent immediately and then once more on every epoch
// until the Ack arrives or params.EpochLimit epochs have passed. Datagrams
// that do not decode as a Message, or that are not the matching Ack, are
// ignored. On every failure path the UDP socket is closed before returning.
//
// initialSeqNum is an int representing the Initial Sequence Number (ISN) this
// client must use. You may assume that sequence numbers do not wrap around.
//
// hostport is a colon-separated string identifying the server's host address
// and port number (i.e., "localhost:9999").
func NewClient(hostport string, initialSeqNum int, params *Params) (Client, error) {
	saddr, err := lspnet.ResolveUDPAddr("udp", hostport)
	if err != nil {
		return nil, err
	}

	udpConn, err := lspnet.DialUDP("udp", nil, saddr)
	if err != nil {
		return nil, err
	}

	connectJSON, err := json.Marshal(NewConnect(initialSeqNum))
	if err != nil {
		udpConn.Close()
		return nil, err
	}

	if _, err = udpConn.Write(connectJSON); err != nil {
		udpConn.Close()
		return nil, err
	}

	// The socket read blocks, so it runs in its own goroutine and reports
	// exactly one result. The channel is buffered so the goroutine can always
	// deliver that result and exit, even if NewClient has already given up.
	type connectReply struct {
		connID int
		err    error
	}
	replyCh := make(chan connectReply, 1)
	go func() {
		buf := make([]byte, 2000)
		for {
			n, err := udpConn.Read(buf)
			if err != nil {
				// Includes the error caused by NewClient closing the socket
				// after a timeout, which is what terminates this goroutine.
				replyCh <- connectReply{err: err}
				return
			}
			var msg Message
			if err := json.Unmarshal(buf[:n], &msg); err != nil {
				continue // corrupt datagram: keep waiting for the Ack
			}
			if msg.Type == MsgAck && msg.SeqNum == initialSeqNum {
				replyCh <- connectReply{connID: msg.ConnID}
				return
			}
		}
	}()

	epoch := time.Duration(params.EpochMillis) * time.Millisecond

	// A nil channel never fires, so a non-positive epoch simply disables the
	// retry timer instead of making time.NewTicker panic.
	var retryC <-chan time.Time
	if epoch > 0 {
		retry := time.NewTicker(epoch)
		defer retry.Stop()
		retryC = retry.C
	}

	for epochs := 0; ; {
		select {
		case reply := <-replyCh:
			if reply.err != nil {
				udpConn.Close()
				return nil, reply.err
			}

			// The reader goroutine has exited by now, so ReadRoutine becomes
			// the only reader of the socket.
			cli := &client{
				udpConn:               udpConn,
				connID:                reply.connID,
				seqNumCli:             initialSeqNum,
				seqNumSer:             initialSeqNum + 1,
				msgList:               list.New(),
				readPayload:           make(chan []byte, 1),
				writeSig:              make(chan []byte),
				readAckMsg:            make(chan *Message),
				readDataMsg:           make(chan *Message),
				readCalled:            make(chan bool),
				maxUnackedMessages:    params.MaxUnackedMessages,
				windowSize:            params.WindowSize,
				epochLimit:            params.EpochLimit,
				epochMillis:           params.EpochMillis,
				maxBackOffInterval:    params.MaxBackOffInterval,
				tryWrite:              make(chan []byte),
				canWrite:              make(chan *Message),
				updateWrite:           make(chan int),
				unackedMessageSeqList: []int{},
				writeQueue:            list.New(),
				unwrittenList:         list.New(),
				epochTicker:           time.Tick(epoch),
				epochCount:            0,
				resendList:            list.New(),
				curEpochTime:          0,
			}

			go MainRoutine(cli)
			go ReadRoutine(cli)
			return cli, nil

		case <-retryC:
			epochs++
			if epochs >= params.EpochLimit {
				// Closing the socket also unblocks the reader goroutine.
				udpConn.Close()
				return nil, fmt.Errorf("lsp: no connect ack from %s after %d epochs", hostport, epochs)
			}
			if _, err := udpConn.Write(connectJSON); err != nil {
				udpConn.Close()
				return nil, err
			}
		}
	}
}

// ConnID returns the connection ID stored on the client when the connection
// was established.
func (c *client) ConnID() int {
	id := c.connID
	return id
}

// Read blocks until the next in-order payload from the server is available
// and returns it. The returned error is always nil.
func (c *client) Read() ([]byte, error) {
	// Tell the main routine a reader is waiting, then wait for its answer.
	c.readCalled <- true
	return <-c.readPayload, nil
}

// Write hands payload to the main routine, which queues it and sends it once
// the sliding window allows. It blocks until the main routine has accepted
// the payload and always returns nil.
//
// The payload is copied first: it is only transmitted (and possibly
// retransmitted) after Write has returned, so the caller must be free to
// reuse its slice without corrupting the queued message. A nil payload stays
// nil.
func (c *client) Write(payload []byte) error {
	var owned []byte
	if payload != nil {
		owned = make([]byte, len(payload))
		copy(owned, payload)
	}

	c.tryWrite <- owned
	return nil
}

// Close closes the client's UDP socket and returns the error reported by
// that close, if any.
//
// NOTE(review): the goroutines started by NewClient are not signalled here,
// and messages still queued or unacknowledged are not flushed first.
func (c *client) Close() error {
	return c.udpConn.Close()
}

// MainRoutine is the client's event loop and the only goroutine that mutates
// the client's bookkeeping. On every pass it first hands buffered server
// messages to waiting Read calls and sends as many queued payloads as the
// sliding window allows, then blocks until one of these events occurs:
//
//   - an epoch tick: send a heartbeat if no new data message was written
//     during the epoch, count the epoch, close the socket once epochLimit
//     silent epochs have accumulated, and retransmit unacknowledged messages
//     whose back-off has elapsed;
//   - a data message from ReadRoutine: buffer and acknowledge it;
//   - an Ack/CAck from ReadRoutine: retire the acknowledged messages;
//   - a Read call announcing itself;
//   - a Write call delivering a payload.
//
// Once the epoch limit has been reached the loop keeps running so that Read
// and Write callers are still serviced, but it does no further network work
// and discards newly written payloads.
//
// NOTE(review): Read and Write have no way to learn that the connection was
// lost; that needs shared state outside this function.
func MainRoutine(c *client) {
	pendingReads := 0     // Read calls that are waiting for a payload
	sendHeartBeat := true // no new data message was written since the last tick
	connLost := false     // epoch limit reached; the socket has been closed

	// windowOpen reports whether one more new data message may be sent now.
	// With nothing outstanding a send is always allowed; otherwise both the
	// unacknowledged-message cap and the sequence-number window must have
	// room for sequence number seqNumCli+1.
	windowOpen := func() bool {
		unacked := c.unackedMessageSeqList
		if len(unacked) == 0 {
			return true
		}
		return len(unacked) < c.maxUnackedMessages &&
			c.seqNumCli < unacked[0]+c.windowSize-1
	}

	for {
		// Deliver in-order messages to waiting readers. Entries below
		// seqNumSer are retransmissions of something already delivered; left
		// at the head of the list they would block delivery forever, so they
		// are dropped here.
		for {
			head := c.msgList.Front()
			if head == nil {
				break
			}
			msg := head.Value.(*Message)
			if msg.SeqNum < c.seqNumSer {
				c.msgList.Remove(head)
				continue
			}
			if msg.SeqNum != c.seqNumSer || pendingReads == 0 {
				break
			}
			c.readPayload <- msg.Payload
			c.msgList.Remove(head)
			c.seqNumSer++
			pendingReads--
		}

		// Send every queued payload the window currently has room for, not
		// just one per event.
		for !connLost && c.unwrittenList.Len() > 0 && windowOpen() {
			front := c.unwrittenList.Front()
			payload := front.Value.([]byte)

			writeMsg(c, payload) // advances c.seqNumCli
			c.unwrittenList.Remove(front)
			sendHeartBeat = false

			InsertToSeqList(c, c.seqNumCli)
			c.resendList.PushBack(&MessageInfo{
				payload:       payload,
				curBackOff:    0,
				lastSentEpoch: c.curEpochTime,
				seqNum:        c.seqNumCli,
			})
		}

		select {
		case <-c.epochTicker:
			c.curEpochTime++
			if connLost {
				continue
			}

			if sendHeartBeat {
				heartBeat(c)
			}
			sendHeartBeat = true

			c.epochCount++
			if c.epochCount >= c.epochLimit {
				// The server is presumed dead: close the socket once and
				// stop all further network activity.
				c.Close()
				connLost = true
				continue
			}

			// Retransmit every unacknowledged message whose back-off has
			// elapsed. The back-off goes 0 -> 1 -> 2 -> 4 ... and is capped
			// at maxBackOffInterval.
			for e := c.resendList.Front(); e != nil; e = e.Next() {
				info := e.Value.(*MessageInfo)
				if info.lastSentEpoch+info.curBackOff > c.curEpochTime {
					continue
				}

				rewriteMsg(c, info.payload, info.seqNum)

				if info.curBackOff == 0 {
					info.curBackOff = 1
				} else {
					info.curBackOff *= 2
				}
				if info.curBackOff > c.maxBackOffInterval {
					info.curBackOff = c.maxBackOffInterval
				}
				info.lastSentEpoch = c.curEpochTime
			}

		case msg := <-c.readDataMsg:
			InsertMsgToList(c, msg)
			WriteAck(c, msg)

			c.epochCount = 0 // the server is still alive

		case ackMsg := <-c.readAckMsg:
			switch ackMsg.Type {
			case MsgAck:
				// Any Ack proves the server is alive, not only the
				// heartbeat with sequence number 0.
				c.epochCount = 0
				if ackMsg.SeqNum != 0 {
					RemoveFromSeqList(c, ackMsg.SeqNum)
					RemoveFromResendList(c, ackMsg.SeqNum)
				}

			case MsgCAck:
				c.epochCount = 0
				// A cumulative Ack retires everything up to and including
				// its sequence number.
				if len(c.unackedMessageSeqList) > 0 {
					firstSeqInList := c.unackedMessageSeqList[0]
					for i := firstSeqInList; i <= ackMsg.SeqNum; i++ {
						RemoveFromSeqList(c, i)
						RemoveFromResendList(c, i)
					}
				}
			}

		case <-c.readCalled:
			pendingReads++

		case payload := <-c.tryWrite:
			if !connLost {
				c.unwrittenList.PushBack(payload)
			}
		}
	}
}

// ReadRoutine reads datagrams from the server and forwards each decoded
// message to the main routine: data messages on readDataMsg, everything else
// on readAckMsg.
//
// A datagram that does not decode as a Message is logged and skipped. A read
// error is logged and ends the goroutine: once the socket is closed every
// further read fails immediately, so continuing would spin and forward
// zero-valued messages.
func ReadRoutine(c *client) {
	buf := make([]byte, 2000)

	for {
		n, err := c.udpConn.Read(buf)
		if err != nil {
			fmt.Println(err.Error())
			return
		}

		// A fresh Message per datagram: the pointer is handed to another
		// goroutine and must not be overwritten by the next read.
		msgFromSer := new(Message)
		if err := json.Unmarshal(buf[:n], msgFromSer); err != nil {
			fmt.Println(err.Error())
			continue
		}

		fmt.Printf("   client read = %s\n", msgFromSer)

		if msgFromSer.Type == MsgData {
			c.readDataMsg <- msgFromSer
		} else { // MsgAck or MsgCAck
			c.readAckMsg <- msgFromSer
		}
	}
}


// PrintList is a debugging aid that prints every element of l, which must
// hold *Message values, one per line with its position, between a START and
// an END marker.
func PrintList(l *list.List) {
	fmt.Println("Client LIST:  START")
	for pos, node := 0, l.Front(); node != nil; pos, node = pos+1, node.Next() {
		fmt.Printf("Client:  %d: %s\n", pos, node.Value.(*Message))
	}
	fmt.Println("Client LIST:  END")
}

// InsertMsgToList stores msg in c.msgList, keeping the list ordered by
// ascending SeqNum.
//
// Retransmissions are not stored: a message whose SeqNum is below
// c.seqNumSer has already been delivered, and a message whose SeqNum is
// already buffered is a duplicate. Buffering either one would leave an entry
// at the head of the list that never equals the next expected sequence
// number, which blocks delivery.
func InsertMsgToList(c *client, msg *Message) {
	if msg.SeqNum < c.seqNumSer {
		return
	}

	// Scan from the back: messages mostly arrive in order, so the insertion
	// point is usually the tail.
	for e := c.msgList.Back(); e != nil; e = e.Prev() {
		held := e.Value.(*Message).SeqNum
		if held == msg.SeqNum {
			return
		}
		if held < msg.SeqNum {
			c.msgList.InsertAfter(msg, e)
			return
		}
	}
	c.msgList.PushFront(msg)
}

// WriteAck sends the server an Ack carrying msg's connection ID and sequence
// number. A marshal or socket-write failure is printed and returned; on
// success the Ack is printed and nil is returned.
func WriteAck(c *client, msg *Message) error {
	ack := NewAck(msg.ConnID, msg.SeqNum)

	encoded, err := json.Marshal(ack)
	if err != nil {
		fmt.Println(err.Error())
		return err
	}

	if _, err := c.udpConn.Write(encoded); err != nil {
		fmt.Println(err.Error())
		return err
	}

	fmt.Printf("client ack = %s\n", ack) // ?
	return nil
}

// heartBeat sends the server an Ack with sequence number 0 on the client's
// connection. Marshal and socket-write failures are printed and otherwise
// ignored.
func heartBeat(c *client) {
	beat := NewAck(c.connID, 0)

	encoded, err := json.Marshal(beat)
	if err != nil {
		fmt.Println(err.Error())
		return
	}

	fmt.Printf("Client:  write heartbeat: %s\n", beat)

	if _, err := c.udpConn.Write(encoded); err != nil {
		fmt.Println(err.Error())
	}
}


// RemoveFromSeqList removes the first occurrence of the sequence number
// index from c.unackedMessageSeqList, if present, and leaves the slice
// sorted in ascending order.
func RemoveFromSeqList(c *client, index int) {
	seqs := c.unackedMessageSeqList
	for pos := range seqs {
		if seqs[pos] != index {
			continue
		}
		// Overwrite the match with the last element and shrink by one; the
		// sort below restores the order.
		last := len(seqs) - 1
		seqs[pos] = seqs[last]
		c.unackedMessageSeqList = seqs[:last]
		break
	}

	sort.Sort(sort.IntSlice(c.unackedMessageSeqList))
}

// InsertToSeqList adds the sequence number num to c.unackedMessageSeqList
// and re-sorts the slice in ascending order.
func InsertToSeqList(c *client, num int) {
	seqs := append(c.unackedMessageSeqList, num)
	sort.Sort(sort.IntSlice(seqs))
	c.unackedMessageSeqList = seqs
}


// writeMsg advances the client's sequence number and sends payload to the
// server as a new data message with that number. The sequence number is
// advanced even if marshalling or the socket write fails; such failures are
// printed and otherwise ignored.
func writeMsg(c *client, payload []byte) {
	c.seqNumCli++
	seq, size := c.seqNumCli, len(payload)
	data := NewData(c.connID, seq, size, payload, CalculateChecksum(c.connID, seq, size, payload))

	fmt.Printf("client write = %s\n", data)

	encoded, err := json.Marshal(data)
	if err != nil {
		fmt.Println(err.Error())
		return
	}

	if _, err := c.udpConn.Write(encoded); err != nil {
		fmt.Println(err.Error())
	}
}


// rewriteMsg sends payload to the server as a data message with the given,
// already assigned sequence number; the client's own sequence counter is not
// touched. Marshal and socket-write failures are printed and otherwise
// ignored; the message is printed only after a successful write.
func rewriteMsg(c *client, payload []byte, seqNum int) {
	size := len(payload)
	data := NewData(c.connID, seqNum, size, payload, CalculateChecksum(c.connID, seqNum, size, payload))

	encoded, err := json.Marshal(data)
	if err != nil {
		fmt.Println(err.Error())
		return
	}

	if _, err := c.udpConn.Write(encoded); err != nil {
		fmt.Println(err.Error())
		return
	}

	fmt.Printf("client write = %s\n", data)
}


// RemoveFromResendList deletes every retransmission record in c.resendList
// whose sequence number equals ackMsg_SeqNum, i.e. the message has been
// acknowledged and must no longer be resent. It does nothing if no record
// matches.
func RemoveFromResendList(c *client, ackMsg_SeqNum int) {
	for e := c.resendList.Front(); e != nil; {
		// Fetch the successor before removing: Remove clears the element's
		// links, so calling Next on a removed element ends the walk early.
		next := e.Next()
		if e.Value.(*MessageInfo).seqNum == ackMsg_SeqNum {
			c.resendList.Remove(e)
		}
		e = next
	}
}






Editor is loading...