package org.kth.dks.dks_comm;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

import org.apache.log4j.Logger;
import org.kth.dks.dks_comm.ConnHandlerOut.UnackedMsg;
import org.kth.dks.dks_marshal.BootstrapMsg;
import org.kth.dks.dks_marshal.DKSMarshal;
import org.kth.dks.dks_marshal.MsgSrcDestWrapper;
import org.kth.dks.util.DKSPrintTypes;

/**
 * <p>Title: DKS</p>
 * <p>Description: DKS Middleware</p>
 * <p>Copyright: Copyright (c) 2004</p>
 * <p>Company: KTH-IMIT</p>
 *
 * <p>Non-blocking outgoing half of a DKS connection. Messages handed to
 * {@link #sendMessage}, {@link #sendAck} and {@link #connected} are queued
 * (at most {@code MAXSEND} at a time; producers block while the queue is
 * full) and serialized by a small state machine into a direct, big-endian
 * buffer, which {@link #run()} drains into the {@link SocketChannel}.
 * CONTENTS messages are numbered and remembered until they are acknowledged
 * through {@link #ackReceived} or declared failed by {@link #checkTimeouts}.</p>
 *
 * <p>Bytes produced per message:</p>
 * <ul>
 *   <li>ACK: type byte, int message id</li>
 *   <li>PRESENT: type byte, int length, flattened {@code BootstrapMsg}</li>
 *   <li>CONTENTS: type byte, transport byte, int sequence number, int length,
 *       marshalled {@code MsgSrcDestWrapper}</li>
 * </ul>
 *
 * @author Ali Ghodsi (aligh@kth.se)
 * @version 1.0
 */

public class ConnHandlerOutNB implements Runnable {

	private static Logger log = Logger.getLogger(ConnHandlerOutNB.class);

	/** Capacity of the direct staging buffer the state machine serializes into. */
	private static final int BUFFSIZE = 4096;
	/** Queue length at which {@link #addElement} blocks its caller. */
	private static final int MAXSEND = 100;
	/** Initial round-trip-time estimate, in milliseconds. */
	private final static double DEFAULTRTT = 1000.0d;
	/** An unacked message times out after RTOFACTOR times the current RTT estimate. */
	private final static double RTOFACTOR = 20.0;

	/** Current state of the serialization state machine. */
	Status state = Status.INIT;

	private SocketChannel sock;
	/** Messages waiting to be serialized (OutMsgElement). */
	private List msgQueue = Collections.synchronizedList(new LinkedList());
	/** Sent CONTENTS messages keyed by sequence number (Integer) until acked or timed out. */
	private Map unackedMsgs = Collections.synchronizedMap(new HashMap());
	/**
	 * Smoothed RTT estimate in ms, used by checkTimeouts. It is written by the
	 * unsynchronized ackReceived, so it is volatile: double reads/writes are
	 * otherwise not guaranteed to be atomic or visible across threads.
	 */
	private volatile double rtt = DEFAULTRTT;
	/** Set by end(); once true no further messages are dequeued. */
	private boolean finish = false;
	private ByteBuffer buff = ByteBuffer.allocateDirect(BUFFSIZE);
	private DKSMarshal marshal = null;
	private ListenerNB listener = null;
	private SelectionKey key = null;
	private DKSNetAddress remoteNetAddr = null;
	private ConnectionHandler connHandler = null;

	/* STATE of the message currently being serialized */
	private OutMsgElement currMsg = null;
	private byte[] varBuffArr = null;
	private int varBuffArrPos = 0;
	/** Sequence number of the last CONTENTS message sent. */
	private int msgNo = 0;
	/* STATE */

	/**
	 * Creates the outgoing handler for an already opened channel.
	 *
	 * @param l  listener providing the local address, the thread pool and close notification
	 * @param s  channel to write to
	 * @param dm marshaller for CONTENTS payloads and failure callbacks
	 * @param ch connection handler that receives statistics updates
	 */
	public ConnHandlerOutNB(ListenerNB l, SocketChannel s, DKSMarshal dm, ConnectionHandler ch) {
		this.sock = s;
		this.marshal = dm;
		this.listener = l;
		this.connHandler = ch;
		buff.order(ByteOrder.BIG_ENDIAN);
		// Start with an empty readable region; stateMachine() compacts before filling.
		buff.limit(0);
	}

	/**
	 * Opens a non-blocking channel and starts connecting it to {@code na}.
	 * In non-blocking mode the connect may still be pending when this returns;
	 * completing it (finishConnect) is presumably done by the selector loop,
	 * which is not visible here.
	 *
	 * @param na remote address to connect to
	 * @return the channel, or {@code null} if opening/connecting failed (the
	 *         partially opened channel is closed in that case)
	 */
	public static synchronized SocketChannel createConnection(DKSNetAddress na) {
		log.debug("Entering createConnection");
		SocketChannel sock = null;
		boolean ok = false;
		try {
			sock = SocketChannel.open();
			sock.configureBlocking(false);

			boolean connectedNow = sock.connect(new InetSocketAddress(na.getIP(), na.getPort()));
			if (connectedNow)
				log.debug("CONNECTED TO " + na.getIP() + ":" + na.getPort());
			else
				log.debug("CONNECTING TO " + na.getIP() + ":" + na.getPort() + " (pending)");

			ok = true;
			return sock;
		} catch (IOException ex) {
			log.error("Could not connect to " + na.getIP() + ":" + na.getPort(), ex);
		} finally {
			// Do not leak the channel if anything (checked or unchecked) went wrong.
			if (!ok && sock != null) {
				try {
					sock.close();
				} catch (IOException ignored) {
					// Already failing; a close error adds nothing useful.
				}
			}
		}
		log.warn("Returning null");
		return null;
	}

	/** Queues a PRESENT message announcing this node's listening address to the peer. */
	public synchronized void connected() {
		log.debug("Presenting myself to endpoint");
		DKSNetAddress myNA = new DKSNetAddress(listener.getHostAddress(), listener.getLocalPort());
		addElement(new OutMsgElement(myNA, ConnMessageTypes.PRESENT_MSG, null));
	}

	/** Records the peer address, passed to the listener when the connection closes. */
	public synchronized void setRemoteAddress(DKSNetAddress netAdr) {
		remoteNetAddr = netAdr;
	}

	/** Adds OP_WRITE to the key's interest set and wakes its selector (no-op without a key). */
	private void interestedWrite() {
		if (key != null) {
			key.interestOps(key.interestOps() | SelectionKey.OP_WRITE);
			key.selector().wakeup();
		}
	}

	/** Removes OP_WRITE from the key's interest set (no-op without a key). */
	private void interestedWriteNot() {
		if (key != null)
			key.interestOps(key.interestOps() & (~SelectionKey.OP_WRITE));
	}

	/** Sets the selection key used to toggle write interest. */
	public synchronized void setKey(SelectionKey key) {
		this.key = key;
	}

	/** INIT: fetches the next queued message; returns false when there is none. */
	private boolean initH() {
		currMsg = getNextElement();
		if (currMsg == null)
			return false;

		state = Status.SENDTYPE;
		return true;
	}

	/** SENDTYPE: writes the type byte and selects the next state for that type. */
	private boolean sendtypeH() {
		if (buff.remaining() < 1)
			return false;

		buff.put(currMsg.type.toByte());

		if (currMsg.type == ConnMessageTypes.CONTENTS_MSG) {
			// Assigns the sequence number and starts the RTT clock.
			bufferUnacked();
			state = Status.CONT_TRANS;
		} else if (currMsg.type == ConnMessageTypes.PRESENT_MSG) {
			BootstrapMsg m = new BootstrapMsg(currMsg.src);
			varBuffArr = m.flatten();
			state = Status.WRITE_ARRLEN;
		} else if (currMsg.type == ConnMessageTypes.ACK_MSG) {
			state = Status.ACK;
		}
		return true;
	}

	/** ACK: writes the acknowledged message id (an Integer load). */
	private boolean ackH() {
		if (buff.remaining() < 4)
			return false;
		buff.putInt(((Integer) currMsg.load).intValue());
		state = Status.INIT;
		return true;
	}

	/** CONT_TRANS: writes the transport byte of a CONTENTS message. */
	private boolean cont_transH() {
		if (buff.remaining() < 1)
			return false;
		buff.put(marshal.TRANSDEFAULT);
		state = Status.CONT_SEQNR;
		return true;
	}

	/** CONT_SEQNR: writes the sequence number and marshals the payload. */
	private boolean cont_seqnrH() {
		if (buff.remaining() < 4)
			return false;
		buff.putInt(msgNo);
		MsgSrcDestWrapper mesg = (MsgSrcDestWrapper) currMsg.load;

		varBuffArr = marshal.marshalMsgSrcDestWrapper(mesg);

		state = Status.WRITE_ARRLEN;
		return true;
	}

	/** WRITE_ARRLEN: writes the length prefix of the variable-size payload. */
	private boolean write_arrlenH() {
		if (buff.remaining() < 4)
			return false;
		buff.putInt(varBuffArr.length);
		state = Status.WRITE_ARR;
		return true;
	}

	/**
	 * WRITE_ARR: copies as much of the payload as fits; returns true only once
	 * the whole payload has been copied.
	 */
	private boolean write_arrH() {
		int act = Math.min(buff.remaining(), varBuffArr.length - varBuffArrPos);

		buff.put(varBuffArr, varBuffArrPos, act);
		varBuffArrPos += act;

		if (varBuffArrPos == varBuffArr.length) {
			state = Status.INIT;
			varBuffArrPos = 0;
			varBuffArr = null;
			return true;
		}
		return false;
	}

	/**
	 * Refills the buffer: switches it to fill mode (keeping unwritten bytes),
	 * runs handlers until one reports it cannot make progress (buffer full or
	 * queue empty), then flips it back to drain mode.
	 */
	private void stateMachine() {
		buff.compact();
		boolean cont = true;
		while (cont) {
			if (state == Status.INIT)
				cont = initH();
			else if (state == Status.SENDTYPE)
				cont = sendtypeH();
			else if (state == Status.ACK)
				cont = ackH();
			else if (state == Status.CONT_TRANS)
				cont = cont_transH();
			else if (state == Status.CONT_SEQNR)
				cont = cont_seqnrH();
			else if (state == Status.WRITE_ARRLEN)
				cont = write_arrlenH();
			else if (state == Status.WRITE_ARR)
				cont = write_arrH();
		}

		buff.flip();
	}

	/** Serializes and writes as much queued data as the socket accepts. */
	public void run() {
		write();
	}

	/**
	 * Writes until the socket stops accepting bytes or nothing is left. If
	 * data remains after a zero-byte write, OP_WRITE interest is registered so
	 * the selector retries; otherwise write interest is dropped. A failing
	 * write (IOException) reports the connection as closed to the listener.
	 */
	private synchronized void write() {
		int wrSize = 1;

		try {
			stateMachine();

			while (wrSize > 0 && buff.hasRemaining()) {
				wrSize = sock.write(buff);
				connHandler.statAddBytesSent(wrSize);

				if (!buff.hasRemaining())
					stateMachine();
			}
		} catch (IOException ex) {
			// SocketChannel.write signals a broken/closed connection by throwing,
			// never by a negative return value, so this is the close path.
			log.error("Write to " + remoteNetAddr + " failed, closing connection", ex);
			listener.connectionClosed(key, -1, remoteNetAddr);
			return;
		}

		if (wrSize == 0 && buff.hasRemaining())
			interestedWrite();
		else
			interestedWriteNot();
	}

	/** Stops dequeuing messages and releases any producer blocked in addElement. */
	synchronized void end() {
		finish = true;
		notifyAll();
	}

	/**
	 * Dequeues the next message, or returns null when the queue is empty or
	 * end() has been called. Wakes producers waiting for queue space.
	 */
	private synchronized OutMsgElement getNextElement() {
		if (finish || msgQueue.isEmpty())
			return null;

		OutMsgElement next = (OutMsgElement) msgQueue.remove(0);
		// A slot was freed: let producers blocked in addElement re-check.
		notifyAll();
		return next;
	}

	/**
	 * Enqueues a message, blocking while MAXSEND messages are pending, and
	 * registers write interest. Messages offered after end() are dropped since
	 * they would never be dequeued. An interrupt while waiting does not abort
	 * the wait but is re-asserted on the calling thread afterwards.
	 */
	private synchronized void addElement(OutMsgElement ele) {
		boolean interrupted = false;
		while (msgQueue.size() >= MAXSEND && !finish) {
			try {
				log.error("Waiting because outgoing buffer full");
				wait();
			} catch (InterruptedException ex) {
				log.error(ex + "");
				interrupted = true;
			}
		}
		if (interrupted)
			Thread.currentThread().interrupt();

		if (finish) {
			log.debug("Dropping outgoing message, handler has been ended");
			return;
		}

		msgQueue.add(ele);
		// Also wakes the selector so the new interest set takes effect.
		interestedWrite();
	}

	/** Assigns the next sequence number to the current CONTENTS message and tracks it as unacked. */
	private void bufferUnacked() {
		msgNo++;
		MsgSrcDestWrapper mesg = (MsgSrcDestWrapper) currMsg.load;
		unackedMsgs.put(new Integer(msgNo), new UnackedMsg(mesg)); // clock is read here
		connHandler.statAddMsgsUnacked(1);
	}

	/** Queues an acknowledgement for the peer's message {@code msgId}. */
	public void sendAck(int msgId) {
		addElement(new OutMsgElement(null, ConnMessageTypes.ACK_MSG, new Integer(msgId)));
	}

	/**
	 * Handles an acknowledgement from the peer: forgets the unacked message
	 * and folds its measured RTT into the smoothed estimate (weight 0.2 for the
	 * new sample). The estimate drives the timeout in checkTimeouts. Acks for
	 * unknown ids (e.g. already timed out or duplicated) are only logged.
	 */
	public void ackReceived(int msgId) {
		UnackedMsg msg = (UnackedMsg) unackedMsgs.remove(new Integer(msgId));
		if (msg == null) {
			log.warn("Received an ack for a message that does not exists :" + msgId);
			return;
		}

		// Exactly one decrement per tracked message (the other is on timeout).
		connHandler.statAddMsgsUnacked(-1);

		long newRTT = msg.calculateRTT();
		if (newRTT > 0) {
			double alpha = 0.8;
			rtt = rtt * alpha + ((double) newRTT) * (1.0 - alpha);
			log.debug("AVG=" + rtt + " NEW=" + (double) newRTT);
		}
	}

	/** Queues {@code msg} (a MsgSrcDestWrapper) for sending and schedules a write on the thread pool. */
	public void sendMessage(DKSNetAddress src, Object msg) {
		addElement(new OutMsgElement(src, ConnMessageTypes.CONTENTS_MSG, msg));
		listener.threadPool.addJob(this);
	}

	/** A queued outgoing message: its type, payload and (for PRESENT) source address. */
	class OutMsgElement {
		public ConnMessageTypes type;
		public Object load;
		public DKSNetAddress src;

		public OutMsgElement(DKSNetAddress _src, ConnMessageTypes type, Object load) {
			src = _src;
			this.type = type;
			this.load = load;
		}
	}

	/** A sent CONTENTS message awaiting acknowledgement, with its send timestamp. */
	class UnackedMsg {
		private MsgSrcDestWrapper msg;
		/** Send time in ms since the epoch; reset to 0 once the RTT has been taken. */
		private long rttStart = 0;

		UnackedMsg(MsgSrcDestWrapper msg) {
			this.msg = msg;
			rttStart = System.currentTimeMillis();
		}

		/** Returns the milliseconds elapsed since sending and resets the start time. */
		long calculateRTT() {
			long newRTT = System.currentTimeMillis() - rttStart;
			rttStart = 0;
			return newRTT;
		}
	}

	/**
	 * Fails every unacked message older than RTOFACTOR * rtt: it is removed,
	 * counted as failed and passed to the marshaller's failureHandler. The
	 * failure callbacks run after the unackedMsgs lock is released, so a
	 * callback that sends (and may block in addElement) cannot deadlock the
	 * writer thread, which needs that lock in bufferUnacked.
	 */
	public synchronized void checkTimeouts() {
		long now = System.currentTimeMillis();
		double timeout = RTOFACTOR * rtt;
		List expired = new LinkedList();

		synchronized (unackedMsgs) {
			for (Iterator it = unackedMsgs.values().iterator(); it.hasNext(); ) {
				UnackedMsg msg = (UnackedMsg) it.next();
				double elapsed = now - msg.rttStart;
				if (elapsed > timeout) {
					it.remove();
					log.warn("Timeout after " + elapsed + "msec with avg rtt=" + rtt);
					expired.add(msg);
				}
			}
		}

		for (Iterator it = expired.iterator(); it.hasNext(); ) {
			UnackedMsg msg = (UnackedMsg) it.next();
			connHandler.statAddMsgsUnacked(-1);
			connHandler.statAddMsgsFailed(1);
			marshal.failureHandler(msg.msg);
		}
	}

	/** States of the serialization state machine (type-safe enum pattern). */
	private static class Status {
		public static final Status INIT = new Status("INIT");
		public static final Status SENDTYPE = new Status("SENDTYPE");
		public static final Status ACK = new Status("ACK");
		public static final Status CONT_TRANS = new Status("CONT_TRANS");
		public static final Status CONT_SEQNR = new Status("CONT_SEQNR");
		public static final Status WRITE_ARRLEN = new Status("WRITE_ARRLEN");
		public static final Status WRITE_ARR = new Status("WRITE_ARR");
		private final String name;

		private Status(String s) {
			name = s;
		}

		public String toString() {
			return name;
		}
	}

}
