
package org.kth.dks.dks_comm;
import java.io.IOException;
import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.NetworkInterface;
import java.net.ServerSocket;
import java.net.SocketTimeoutException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Timer;
import java.util.TimerTask;
import java.util.Vector;
import java.util.Map.Entry;

import org.apache.log4j.Logger;
import org.kth.dks.dks_marshal.DKSMarshal;
import org.kth.dks.dks_marshal.MsgSrcDestWrapper;
import org.kth.dks.util.AsyncOperation;
import org.kth.dks.util.AtomicBoolean;
import org.kth.dks.util.Pair;
import org.kth.dks.util.Triple;

/**
 * <p>Title: DKS</p>
 * <p>Description: DKS Middleware</p>
 * <p>Copyright: Copyright (c) 2004</p>
 * <p>Company: KTH-IMIT</p>
 *
 * <p>Non-blocking {@link Listener}. A single thread (this object) owns a
 * {@link Selector}: it accepts inbound connections, completes outbound
 * connects that other threads queued through
 * {@link #createConnection(DKSNetAddress)}, and hands read/write readiness
 * over to the shared {@link ThreadPool}. A timer periodically asks the
 * selector thread to schedule timeout checks for every known outgoing
 * handler.</p>
 *
 * <p>The thread is started by the constructor and runs until
 * {@link #end()} is called.</p>
 *
 * @author Ali Ghodsi (aligh@kth.se)
 * @version 1.0
 */

class ListenerNB extends Thread implements Listener {

    private static final Logger log = Logger.getLogger(ListenerNB.class);

    /** Accept backlog requested when binding the server socket. */
    private static final int BACKLOG = 20;

    /** Period, in milliseconds, between two timeout sweeps. */
    private static final long TIMEOUT_CHECK_PERIOD_MS = 8000;

    /** Set by {@link #end()}; volatile because it is read by the selector thread. */
    private volatile boolean finish = false;

    private ConnectionManager a_conMan;
    private ConnectionHandler connHandler;
    private final ServerSocket serverSocket;

    private ServerSocketChannel serverChannel;
    private Selector selector;

    /** DKSNetAddress -> Pair(ConnHandlerInNB, ConnHandlerOutNB) for connections with a known peer address. */
    private Map socketMap = Collections.synchronizedMap( new HashMap() );

    /** ConnHandlerInNB -> Pair(in, out) for accepted connections whose peer has not identified itself yet. */
    private Map anonNodes = Collections.synchronizedMap( new HashMap() );

    /** Triple(dest, SocketChannel, AsyncOperation) entries waiting to be registered by the selector thread. */
    private List connectList = Collections.synchronizedList( new LinkedList() );

    /** Raised by the timer, consumed by the selector thread. */
    private AtomicBoolean checkTimeouts = new AtomicBoolean();

    ThreadPool threadPool = null;
    private final Timer rttTimer;

    //new
    public DKSMarshal dksMarshal = null;

    /**
     * Creates and starts a listener bound to the wildcard address.
     *
     * @param _port local port to listen on
     * @param cm    connection manager supplying the marshaller for new handlers
     * @param dksm  marshaller exposed through {@link #dksMarshal}
     * @param ch    connection handler notified about connection statistics
     * @throws IOException if the channel or selector cannot be opened or bound
     */
    public ListenerNB( int _port, ConnectionManager cm, DKSMarshal dksm, ConnectionHandler ch ) throws IOException
    {
        this(_port, cm, dksm, ch, null);
    }

    /**
     * Creates and starts a listener.
     *
     * @param _port    local port to listen on
     * @param cm       connection manager supplying the marshaller for new handlers
     * @param dksm     marshaller exposed through {@link #dksMarshal}
     * @param ch       connection handler notified about connection statistics
     * @param bindAddr local address to bind to, or {@code null} for the wildcard address
     * @throws IOException if the channel or selector cannot be opened or bound
     */
    public ListenerNB( int _port, ConnectionManager cm, DKSMarshal dksm, ConnectionHandler ch, InetAddress bindAddr ) throws IOException
    {
        dksMarshal = dksm;
        a_conMan = cm;
        connHandler = ch;
        serverChannel = ServerSocketChannel.open();
        serverSocket = serverChannel.socket();
        try {
            if (bindAddr != null) {
                serverSocket.bind(new InetSocketAddress(bindAddr, _port), BACKLOG);
            } else {
                serverSocket.bind(new InetSocketAddress(_port), BACKLOG);
            }
            serverChannel.configureBlocking(false);

            selector = Selector.open();
            serverChannel.register(selector, SelectionKey.OP_ACCEPT);

            serverSocket.setSoTimeout(1000);
        } catch (IOException ex) {
            // Do not leak the channel (and possibly the selector) when setup fails,
            // e.g. because the port is already in use.
            closeChannels();
            throw ex;
        }
        this.setName(ListenerNB.class.getName());

        threadPool = ThreadPool.getInstance();

        TimerTask timeoutDetector = new TimerTask() {
            public void run() {
                checkTimeouts.set(true);
                selector.wakeup();
            }
        };

        rttTimer = new Timer(ListenerNB.class.getName() + ".RTTTimer");
        rttTimer.schedule(timeoutDetector, 0, TIMEOUT_CHECK_PERIOD_MS);

        this.start();
    }

    /**
     * Requests shutdown: flags the selector loop to stop, wakes it up so the
     * flag is actually observed, ends the thread pool and cancels the timer.
     * The selector thread closes the selector and the listening channel on
     * its way out.
     */
    public void end() {
        finish = true;
        // Without this the loop thread could stay parked in select() indefinitely.
        selector.wakeup();
        threadPool.end();
        rttTimer.cancel();
    }

    /**
     * Builds the incoming/outgoing handler pair for a channel and updates the
     * open/total connection statistics.
     *
     * @param newSock channel the handlers operate on
     * @return {@code Pair(ConnHandlerInNB, ConnHandlerOutNB)}
     */
    public Pair createHandlers(SocketChannel newSock) {
        ConnHandlerOutNB cOut = new ConnHandlerOutNB(this, newSock, a_conMan.getDKSMarshal(), connHandler);
        ConnHandlerInNB cIn   = new ConnHandlerInNB(this, newSock, cOut, a_conMan.getDKSMarshal(), connHandler);
        Pair connPair = new Pair(cIn, cOut);
        connHandler.statAddOpenConnection(1);
        connHandler.statAddTotalConnection(1);
        return connPair;
    }

    /**
     * Queues one thread-pool job per address-identified connection that runs
     * the outgoing handler's timeout check.
     */
    public void prepareTimeoutEvents() {
        // Views of a synchronized map must be traversed while holding the map's
        // lock; take a snapshot so jobs are submitted outside the lock.
        Object[] pairs;
        synchronized (socketMap) {
            pairs = socketMap.values().toArray();
        }
        for (int i = 0; i < pairs.length; i++) {
            final ConnHandlerOutNB cout = (ConnHandlerOutNB) ((Pair) pairs[i]).second();
            threadPool.addJob(new Runnable() {
                public void run() {
                    cout.checkTimeouts();
                }
            });
        }
    }

    /**
     * Moves an accepted connection from the anonymous set to the address map
     * once its peer address is known.
     *
     * @param connId incoming handler of the connection
     * @param na     address the peer identified itself with
     */
    public synchronized void identifiedNode(ConnHandlerInNB connId, DKSNetAddress na) {
        Pair pair = (Pair) anonNodes.remove(connId);
        if (pair == null) {
            log.error("Identification of node failed: "+na);
            return;
        }
        ((ConnHandlerInNB)pair.first()).setRemoteAddress(na);
        ((ConnHandlerOutNB)pair.second()).setRemoteAddress(na);
        socketMap.put(na, pair);
    }

    /**
     * Forgets a connection: cancels its key, removes it from the bookkeeping
     * maps and decrements the open-connection statistic.
     *
     * @param key     selection key of the dead connection
     * @param status  status code, only used for logging
     * @param remAddr peer address, or {@code null} if the peer never identified itself
     */
    public synchronized void connectionClosed(SelectionKey key, int status, DKSNetAddress remAddr) {
        key.cancel();
        if (remAddr != null) {
            log.error( "Removing dead connection (status:"+status+"), DKSNetAddress established");
            socketMap.remove(remAddr);
        } else {
            log.error( "Removing dead connection (status:"+status+"), , DKSNetAddress not established");
            // The attachment is a Pair only for fully registered connections.
            Object attachment = key.attachment();
            if (attachment instanceof Pair) {
                Object cin = ((Pair) attachment).first();
                if (cin != null) {
                    anonNodes.remove(cin);
                }
            }
        }
        connHandler.statAddOpenConnection(-1);
    }

    /**
     * Intentionally empty. Timeout sweeps are scheduled by
     * {@link #prepareTimeoutEvents()}.
     */
    public void checkTimeouts() {

    }

    /**
     * Selector loop. Runs until {@link #end()} is called, then fails any
     * connect request still queued and closes the selector and the listening
     * channel.
     */
    public void run() {

        while ( !finish ) {
            try {
                registerNewConnections();

                int num = selector.select();

                log.debug("Selector returned:"+num);

                if (checkTimeouts.compareAndSet(true, false)) {
                    prepareTimeoutEvents();
                }

                if (num == 0) {
                    continue;
                }

                for (Iterator it = selector.selectedKeys().iterator(); it.hasNext(); ) {
                    SelectionKey currKey = (SelectionKey) it.next();
                    // Remove first so a failure while handling the key cannot leave
                    // it stuck in the selected set.
                    it.remove();
                    dispatch(currKey);
                }
            }
            catch (SocketTimeoutException ignored) {
                // nothing to do, just loop again
            }
            catch (Exception ex1) {
                log.error("Selector loop iteration failed", ex1);
            }
        }

        failPendingConnections();
        closeChannels();
    } //run

    /**
     * Handles every ready operation of one selected key. Failures are
     * confined to this key so the remaining selected keys are still served.
     */
    private void dispatch(SelectionKey currKey) {
        try {
            if (currKey.isValid() && currKey.isAcceptable()) {
                acceptConnection(currKey);
            }
            if (currKey.isValid() && currKey.isConnectable()) {
                completeConnection(currKey);
            }
            if (currKey.isValid() && currKey.isWritable()) {
                Pair connPair = (Pair) currKey.attachment();
                ConnHandlerOutNB cOut = (ConnHandlerOutNB) connPair.second();
                // Mask the interest before handing the handler to the pool.
                currKey.interestOps(currKey.interestOps() & (~SelectionKey.OP_WRITE));
                threadPool.addJob(cOut);
            }
            if (currKey.isValid() && currKey.isReadable()) {
                Pair connPair = (Pair) currKey.attachment();
                ConnHandlerInNB cIn = (ConnHandlerInNB) connPair.first();
                currKey.interestOps(currKey.interestOps() & (~SelectionKey.OP_READ));
                threadPool.addJob(cIn);
            }
        } catch (java.nio.channels.CancelledKeyException ex) {
            // connectionClosed() may cancel the key from another thread between
            // the isValid() check and the readiness query.
            log.debug("Key cancelled while being processed: "+ex);
        } catch (Exception ex) {
            log.error("Failed to process selected key", ex);
        }
    }

    /** Accepts one pending inbound connection and registers it as anonymous. */
    private void acceptConnection(SelectionKey currKey) throws IOException {
        SocketChannel newSock = ((ServerSocketChannel) currKey.channel()).accept();
        if (newSock == null) {
            // Non-blocking accept: no connection was actually pending.
            return;
        }

        log.debug("Accepted from:"+newSock.socket().getRemoteSocketAddress());

        Pair pair = null;
        try {
            newSock.configureBlocking(false);
            pair = createHandlers(newSock);
            anonNodes.put(pair.first(), pair);

            SelectionKey key = newSock.register(selector,
                    SelectionKey.OP_READ | SelectionKey.OP_WRITE, pair);
            ((ConnHandlerInNB)pair.first()).setKey(key);
            ((ConnHandlerOutNB)pair.second()).setKey(key);
        } catch (IOException ex) {
            // Undo the bookkeeping done so far and release the socket.
            if (pair != null) {
                anonNodes.remove(pair.first());
                connHandler.statAddOpenConnection(-1);
            }
            closeQuietly(newSock);
            throw ex;
        }
    }

    /**
     * Finishes an outbound connect and reports the outcome to the thread
     * waiting in {@link #createConnection(DKSNetAddress)}.
     */
    private void completeConnection(SelectionKey currKey) {
        SocketChannel sock = (SocketChannel) currKey.channel();
        boolean doneConnect = false;
        try {
            doneConnect = sock.finishConnect();
        } catch (IOException ex) {
            log.warn("IOException thrown when connecting to "+sock.socket()+"\n"+ex);
        }

        Triple triple = (Triple) currKey.attachment();
        AsyncOperation syn = (AsyncOperation) triple.third();

        if (doneConnect) {
            Pair pair = new Pair(triple.first(), triple.second());
            currKey.attach(pair);
            currKey.interestOps(SelectionKey.OP_READ | SelectionKey.OP_WRITE);

            log.debug("Properly connected to endpoint");

            ((ConnHandlerInNB)triple.first()).setKey(currKey);
            ((ConnHandlerOutNB)triple.second()).setKey(currKey);

            ((ConnHandlerOutNB)triple.second()).connected();

            syn.complete(Boolean.TRUE);
        } else {
            currKey.cancel();
            // Clean up before waking the caller so it never observes the stale entry.
            discardOutgoing(sock, triple.second());
            syn.complete(Boolean.FALSE);
        }
    }

    /**
     * Undoes the bookkeeping of an outbound connection that never got
     * established: drops its address-map entry, closes the channel and
     * reverts the open-connection statistic.
     *
     * @param sock channel of the failed connection
     * @param cOut outgoing handler created for it; used to find the map entry by identity
     */
    private void discardOutgoing(SocketChannel sock, Object cOut) {
        synchronized (socketMap) {
            for (Iterator it = socketMap.values().iterator(); it.hasNext(); ) {
                if (((Pair) it.next()).second() == cOut) {
                    it.remove();
                    break;
                }
            }
        }
        closeQuietly(sock);
        connHandler.statAddOpenConnection(-1);
    }

    /**
     * Registers every queued outbound connection with the selector. Must run
     * on the selector thread.
     */
    private void registerNewConnections() {
        while (!connectList.isEmpty()) {

            Triple destSockSyn = (Triple) connectList.remove(0);
            DKSNetAddress dest = (DKSNetAddress) destSockSyn.first();
            SocketChannel sock = (SocketChannel) destSockSyn.second();
            AsyncOperation syn = (AsyncOperation) destSockSyn.third();

            Pair pair = createHandlers(sock);

            ((ConnHandlerInNB)pair.first()).setRemoteAddress(dest);
            ((ConnHandlerOutNB)pair.second()).setRemoteAddress(dest);

            Triple triple = new Triple(pair.first(), pair.second(), syn);

            socketMap.put(dest, pair);

            try {
                sock.register(selector, SelectionKey.OP_CONNECT, triple);

                log.debug("everything ok here");

            } catch (ClosedChannelException ex) {
                log.error("Could not register outgoing connection to "+dest, ex);
                discardOutgoing(sock, pair.second());
                syn.complete(Boolean.FALSE);
            }
        }
    }

    /** Fails every connect request still queued so that no caller blocks forever after shutdown. */
    private void failPendingConnections() {
        while (!connectList.isEmpty()) {
            Triple pending = (Triple) connectList.remove(0);
            closeQuietly((SocketChannel) pending.second());
            ((AsyncOperation) pending.third()).complete(Boolean.FALSE);
        }
    }

    /**
     * Closes the selector and the listening channel. Channels of established
     * connections are not closed here.
     */
    private void closeChannels() {
        try {
            if (selector != null) {
                selector.close();
            }
        } catch (IOException ex) {
            log.warn("Failed to close selector", ex);
        }
        try {
            serverChannel.close();
        } catch (IOException ex) {
            log.warn("Failed to close server channel", ex);
        }
    }

    /** Closes a channel, logging instead of propagating a failure. */
    private void closeQuietly(SocketChannel sock) {
        try {
            sock.close();
        } catch (IOException ex) {
            log.warn("Failed to close channel "+sock, ex);
        }
    }

    /**
     * Opens a connection to {@code dest} and blocks until the selector thread
     * reports whether the connect succeeded.
     *
     * @param dest address to connect to
     * @return {@code true} if the connection was established; {@code false}
     *         on failure, on interruption, or if the listener has been ended
     */
    public boolean createConnection(DKSNetAddress dest) {
        if (finish) {
            return false;
        }

        SocketChannel sock = ConnHandlerOutNB.createConnection(dest);
        if (sock == null) {
            return false;
        }

        AsyncOperation sync = AsyncOperation.start();

        Triple destSockSync = new Triple(dest, sock, sync);

        connectList.add(destSockSync);

        selector.wakeup();

        // The listener may have been ended while the request was being queued.
        // If the selector thread did not pick it up, nobody will ever complete it.
        if (finish && connectList.remove(destSockSync)) {
            closeQuietly(sock);
            return false;
        }

        /* TODO join future */
        try {
            Boolean b = (Boolean)sync.waitOn();
            return b.booleanValue();

        } catch (InterruptedException ex) {
            // Preserve the interrupt for the caller.
            Thread.currentThread().interrupt();
            log.warn("Interrupted while waiting for connection to "+dest, ex);
            return false;
        } catch (Exception ex) {
            log.error("Failed waiting for connection to "+dest, ex);
            return false;
        }

    }

    /**
     * Enqueues a message on the connection to its destination, creating the
     * connection first if none is known.
     *
     * @param triple message together with its source and destination
     * @return {@code true} if the message was handed to the outgoing handler;
     *         {@code false} if no connection is available or the listener has been ended
     */
    public boolean send(MsgSrcDestWrapper triple)
    {
        if (finish) {
            return false;
        }

        DKSNetAddress src = triple.getSrc().getDKSNetAddress();
        DKSNetAddress dest = triple.getDest().getDKSNetAddress();

        // Single lookup instead of containsKey()+get(): the entry can be removed
        // by connectionClosed() between the two calls.
        Pair p = (Pair) socketMap.get(dest);
        if (p == null) {

            log.debug("we have to create the connection first");

            if (!createConnection(dest)) {
                return false;
            }

            log.debug("connection created");

            p = (Pair) socketMap.get(dest);
            if (p == null) {
                log.warn("Connection to "+dest+" disappeared before the message could be enqueued");
                return false;
            }
        }

        ConnHandlerOutNB o = (ConnHandlerOutNB) p.second();

        o.sendMessage(src, triple);

        log.debug("message enqueued sent");

        return true;
    }


    /**
     * @return the local port the server socket is bound to
     */
    public int getLocalPort() {
        return serverSocket.getLocalPort();
    }

    /**
     * Returns a textual IP address for this listener. If the server socket is
     * bound to a specific address, that address is returned. If it is bound
     * to the wildcard address, the network interfaces are scanned for an IPv4
     * address that passes the prefix filter below (the last match wins), with
     * {@link InetAddress#getLocalHost()} as fallback.
     *
     * @return the address, or {@code null} if it could not be determined
     */
    public String getHostAddress() {
        InetAddress bound = serverSocket.getInetAddress();
        if (bound == null) {
            return null;
        }
        if (!bound.isAnyLocalAddress()) {
            return bound.getHostAddress();
        }

        try
        {
            String currIp = null;
            Enumeration cards = NetworkInterface.getNetworkInterfaces();
            while (cards.hasMoreElements()) {
                NetworkInterface currNet = (NetworkInterface)cards.nextElement();
                Enumeration ads = currNet.getInetAddresses();

                while (ads.hasMoreElements()) {
                    InetAddress adr = (InetAddress) ads.nextElement();

                    String c = adr.getHostAddress();
                    // NOTE(review): these prefixes reject all of 169.*, 10.* and 192.*,
                    // which is broader than the link-local/private ranges - confirm intent.
                    if (adr instanceof Inet4Address && !c.equals("127.0.0.1") &&
                            !c.startsWith("169.") &&
                            !c.startsWith("10.") &&
                            !c.startsWith("192.")) {
                        currIp = c;
                    }
                }
            }
            if (currIp == null) {
                InetAddress ia = InetAddress.getLocalHost();
                currIp = ia.getHostAddress();
            }
            return currIp;
        } catch (Exception ex) {
            log.error("Could not determine local host address", ex);
        }

        return null;
    }

} //Listener class
