
package org.kth.dks.dks_comm;

import java.io.IOException;
import java.nio.BufferOverflowException;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;

import org.apache.log4j.Logger;
import org.kth.dks.dks_marshal.BootstrapMsg;
import org.kth.dks.dks_marshal.DKSMarshal;
import org.kth.dks.dks_marshal.DKSMessage;

/**
 * <p>Title: DKS</p>
 * <p>Description: DKS Middleware</p>
 * <p>Copyright: Copyright (c) 2004</p>
 * <p>Company: KTH-IMIT</p>
 * <p>
 * Non-blocking reader for a single incoming DKS connection. Each call to
 * {@link #run()} performs one {@code read} on the socket channel and feeds the
 * received bytes through a small state machine that decodes this big-endian
 * framing:
 * <pre>
 *   type:byte
 *   CONTENTS_MSG : transType:byte  msgID:int  size:int  payload[size]
 *   ACK_MSG      : msgID:int
 *   PRESENT_MSG  : size:int  payload[size]   (unmarshalled as a BootstrapMsg)
 * </pre>
 * Partially received frames are kept in {@code buff} / {@code varBuff} between
 * reads, so one frame may span any number of {@code run()} invocations.
 * Within this class the decoding state is only modified from the synchronized
 * {@code read()} method.
 *
 * @author Ali Ghodsi (aligh@kth.se)
 * @version 1.0
 */
public class ConnHandlerInNB implements Runnable {

	private static final Logger log = Logger.getLogger(ConnHandlerInNB.class);
	private static final int BUFFSIZE = 4096;

	/** Current decoder state; starts by expecting a message-type byte. */
	Status state = Status.INIT;

	private final DKSMarshal marshal;
	/** Address announced by the peer in its PRESENT_MSG; null until then. */
	private DKSNetAddress destRef = null;
	private SocketChannel sock;
	private ConnHandlerOutNB connOut;
	/** Staging buffer for raw socket bytes; kept in write mode between reads. */
	private ByteBuffer buff = ByteBuffer.allocateDirect(BUFFSIZE);
	private ListenerNB listener;
	private ConnectionHandler connHandler;
	private SelectionKey key;
	private DKSNetAddress remoteNetAddr = null;

	/* Per-frame decoding state. */
	ConnMessageTypes msgType;
	byte transType;
	int msgID;
	int msgSize;
	/** Payload of the block currently being assembled. */
	private byte[] varBuff = null;
	/** Number of payload bytes already copied into varBuff. */
	private int varBuffPos = 0;

	/**
	 * @param l  listener notified when the peer identifies itself or the connection closes
	 * @param s  channel to read from
	 * @param co outgoing handler used to send and receive acknowledgements
	 * @param dm marshaller that decodes and dispatches CONTENTS_MSG payloads
	 * @param ch statistics sink
	 */
	public ConnHandlerInNB(ListenerNB l, SocketChannel s, ConnHandlerOutNB co, DKSMarshal dm, ConnectionHandler ch) {
		this.sock = s;
		this.connOut = co;
		this.marshal = dm;
		this.listener = l;
		this.connHandler = ch;
		buff.order(ByteOrder.BIG_ENDIAN);
	}

	/** Sets the remote address that is reported to the listener when the connection closes. */
	public synchronized void setRemoteAddress(DKSNetAddress netAdr) {
		remoteNetAddr = netAdr;
	}

	/** Sets the selection key whose OP_READ interest is re-armed after each read. */
	public synchronized void setKey(SelectionKey key) {
		this.key = key;
	}

	/**
	 * Re-registers interest in OP_READ and wakes the selector so it notices.
	 * Does nothing if no key has been set or the key has been cancelled,
	 * because interestOps() would throw CancelledKeyException in that case.
	 */
	private void interestedRead() {
		if (key != null && key.isValid()) {
			key.interestOps(key.interestOps() | SelectionKey.OP_READ);
			key.selector().wakeup();
		}
	}

	/**
	 * INIT: reads the message-type byte and selects the next state.
	 * An unknown type is logged and the decoder stays in INIT.
	 *
	 * @return false if more input is needed
	 */
	private boolean initH() {
		if (buff.remaining() < 1)
			return false;
		byte msgTypeByte = buff.get();
		msgType = ConnMessageTypes.valueOf(msgTypeByte);
		if (msgType == ConnMessageTypes.CONTENTS_MSG) {
			state = Status.CONTENT;
		} else if (msgType == ConnMessageTypes.ACK_MSG) {
			state = Status.ACK;
		} else if (msgType == ConnMessageTypes.PRESENT_MSG) {
			state = Status.READBLOCK;
		} else {
			log.error("ERROR: could not move out of INIT state:" + msgTypeByte);
		}
		return true;
	}

	/** CONTENT: reads the transport-type byte. @return false if more input is needed */
	private boolean contentH() {
		if (buff.remaining() < 1)
			return false;
		transType = buff.get();
		state = Status.CONTENT_TYPE;
		return true;
	}

	/** CONTENT_TYPE: reads the 4-byte message id. @return false if more input is needed */
	private boolean content_typeH() {
		if (buff.remaining() < 4)
			return false;
		msgID = buff.getInt();
		state = Status.READBLOCK;
		return true;
	}

	/**
	 * ACK: reads the acknowledged message id and notifies the outgoing handler.
	 *
	 * @return false if more input is needed
	 */
	private boolean ackH() {
		if (buff.remaining() < 4)
			return false;
		msgID = buff.getInt();
		connOut.ackReceived(msgID);
		connHandler.statAddMsgsDelivered(1);
		state = Status.INIT;
		return true;
	}

	/**
	 * READBLOCK: reads the 4-byte payload length.
	 *
	 * @return false if more input is needed
	 * @throws IOException if the length is negative, meaning the stream is
	 *         corrupt and cannot be resynchronised
	 */
	private boolean readblockH() throws IOException {
		if (buff.remaining() < 4)
			return false;
		msgSize = buff.getInt();
		if (msgSize < 0)
			throw new IOException("Corrupt stream: negative block size " + msgSize);
		state = Status.READBLOCK_DATA;
		return true;
	}

	/**
	 * READBLOCK_DATA: copies as much of the payload as is available into
	 * varBuff. When the payload is complete, it is dispatched according to
	 * msgType. A zero-length payload completes immediately, without waiting
	 * for further input.
	 *
	 * @return false if more input is needed
	 */
	private boolean readblock_dataH() {
		int missing = msgSize - varBuffPos;
		if (missing > 0 && buff.remaining() == 0)
			return false;

		if (varBuffPos == 0)
			varBuff = new byte[msgSize];

		int rSz = Math.min(buff.remaining(), missing);
		buff.get(varBuff, varBuffPos, rSz);
		varBuffPos += rSz;

		if (varBuffPos == msgSize) {
			varBuffPos = 0;
			if (msgType == ConnMessageTypes.PRESENT_MSG) {
				fin_presH();
			} else if (msgType == ConnMessageTypes.CONTENTS_MSG) {
				fin_contH();
			}
		}
		return true;
	}

	/**
	 * Completes a CONTENTS_MSG. It acknowledges the message to the sender,
	 * hands the payload to the marshaller and returns the decoder to INIT.
	 */
	private void fin_contH() {
		connOut.sendAck(msgID);
		if (!marshal.unmarshalDispatch(transType, varBuff, destRef))
			log.warn("Msg not accepted, CommunicationBuffer overflow");
		varBuff = null;
		connHandler.statAddMsgsReceived(1);
		state = Status.INIT;
	}

	/**
	 * Completes a PRESENT_MSG. It unmarshals the BootstrapMsg, records the
	 * peer's address and tells the listener which node is on this connection.
	 */
	private void fin_presH() {
		BootstrapMsg m = (BootstrapMsg) DKSMessage.unmarshal(varBuff);
		varBuff = null;
		destRef = m.getNetAddress();
		listener.identifiedNode(this, destRef);
		state = Status.INIT;
	}

	/**
	 * Runs state handlers until one reports that it needs more input than
	 * is currently available in buff (which must be in read mode).
	 *
	 * @throws IOException if the stream is found to be corrupt
	 */
	private void stateMachine() throws IOException {
		boolean cont = true;
		while (cont) {
			if (state == Status.INIT)
				cont = initH();
			else if (state == Status.CONTENT)
				cont = contentH();
			else if (state == Status.CONTENT_TYPE)
				cont = content_typeH();
			else if (state == Status.ACK)
				cont = ackH();
			else if (state == Status.READBLOCK)
				cont = readblockH();
			else if (state == Status.READBLOCK_DATA)
				cont = readblock_dataH();
		}
	}

	public void run() {
		read();
	}

	/**
	 * Performs one read from the channel and decodes what it can.
	 * <p>
	 * If the read hits end-of-stream, fails with an I/O error or finds a
	 * corrupt frame, the listener is told the connection is closed.
	 * Otherwise OP_READ interest is re-armed.
	 */
	private synchronized void read() {
		int bytesRead;
		try {
			bytesRead = sock.read(buff);
			if (bytesRead > 0) {
				// Only count real data; -1 (EOF) must not be added to the stats.
				connHandler.statAddBytesReceived(bytesRead);
				buff.flip();
				stateMachine();
				// Keep any incomplete header bytes for the next read.
				buff.compact();
			}
		} catch (IOException ex) {
			log.error("IOException inside ConnHandlerIn, closing connection:" + ex, ex);
			// A failed or corrupt channel stays readable in the selector.
			// Re-arming it would loop on the same error, so report it as closed.
			bytesRead = -1;
		}

		if (bytesRead < 0)
			listener.connectionClosed(key, bytesRead, remoteNetAddr);
		else
			interestedRead();

		if (key != null && log.isDebugEnabled()) {
			SocketChannel ch = (SocketChannel) key.channel();
			log.debug("Got out " + bytesRead + " key:" + ch.socket().getInetAddress()
					+ ":" + ch.socket().getPort());
		}
	}

	/** Decoder states (type-safe enum pattern). */
	private static class Status {
		public static final Status INIT = new Status("INIT");
		public static final Status CONTENT = new Status("CONTENT");
		public static final Status CONTENT_TYPE = new Status("CONTENT_TYPE");
		public static final Status ACK = new Status("ACK");
		public static final Status READBLOCK = new Status("READBLOCK");
		public static final Status READBLOCK_DATA = new Status("READBLOCK_DATA");
		private final String name;
		private Status(String s) { name = s; }
		public String toString() { return name; }
	}

}
