package org.planx.xmlstore.routing.messaging;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.Serializable;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketAddress;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import java.util.Timer;
import java.util.TimerTask;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.mina.core.RuntimeIoException;
import org.apache.mina.core.buffer.AbstractIoBuffer;
import org.apache.mina.core.future.ConnectFuture;
import org.apache.mina.core.future.WriteFuture;
import org.apache.mina.core.service.IoAcceptor;
import org.apache.mina.core.service.IoHandler;
import org.apache.mina.core.service.IoHandlerAdapter;
import org.apache.mina.core.session.IoSession;
import org.apache.mina.filter.codec.ProtocolCodecFilter;
import org.apache.mina.filter.codec.serialization.ObjectSerializationCodecFactory;
import org.apache.mina.filter.logging.LoggingFilter;
import org.apache.mina.transport.socket.nio.NioSocketAcceptor;
import org.apache.mina.transport.socket.nio.NioSocketConnector;

/**
 * Listens for incoming UDP messages and provides a framework for sending messages
 * and responding to received messages.
 * Two threads are started: One that listens for incoming messages and one that
 * handles timeout events.
 **/
public class MessageServer {
    
    private static final Log log = LogFactory.getLog(MessageServer.class);
    
    /** a single instance to define a codec for serialize and unserialize */
    protected ProtocolCodecFilter codec;

    // TODO 20100622 bleny should be configurable 
    /** maximum size of a serializable object of a message (2 * 10 MiB) */
    protected static final int MAX_MESSAGE_SIZE = 2 * 10485760; // 20 MiB

    /** Maximum time, in milliseconds, to wait for an outgoing connection. */
    private static final int CONNECT_TIMEOUT_MILLIS = 10 * 1000;

    protected IoAcceptor acceptor;
    
    private static final Random random = new Random();
    private final MessageFactory factory;
    private final long timeout;
    private final Timer timer;
    /**
     * Cleared under the lock in {@link #close()}; volatile because
     * {@link #reply} reads it without taking the lock.
     */
    private volatile boolean isRunning = true;
    /** Receivers awaiting a reply, keyed by communication id. Guarded by {@code this}. */
    private final Map<Integer, Receiver> receivers;
    /** Pending timeout tasks, keyed by communication id. Guarded by {@code this}. */
    private final Map<Integer, TimerTask> tasks;

    /**
     * Constructs a MessageServer listening on the specified TCP port using the
     * specified MessageFactory for interpreting incoming messages.
     * <p>
     * Note: despite its name, {@code udpPort} is bound with a MINA
     * {@link NioSocketAcceptor}, i.e. it is a TCP port.
     *
     * @param udpPort The (TCP) port on which to listen for incoming messages
     * @param factory Factory for creating Message and Receiver objects
     * @param timeout The timeout period in milliseconds after which a pending
     *                request with no reply is reported to its Receiver
     *
     * @throws SocketException if the acceptor could not bind to the specified
     *                         local port (the original IOException is the cause)
     **/
    public MessageServer(int udpPort, MessageFactory factory, long timeout)
                                                    throws SocketException {
        this.factory = factory;
        this.timeout = timeout;
        timer = new Timer(true); // daemon: does not keep the JVM alive
        receivers = new HashMap<Integer, Receiver>();
        tasks = new HashMap<Integer, TimerTask>();
        
        // NOTE(review): this codec uses Java native deserialization on bytes read
        // from the network, so a peer can make us deserialize any Serializable
        // class, not only byte[]. Consider restricting accepted classes if peers
        // are not trusted.
        ObjectSerializationCodecFactory serialization = 
            new ObjectSerializationCodecFactory();
        serialization.setDecoderMaxObjectSize(MAX_MESSAGE_SIZE);
        codec = new ProtocolCodecFilter(serialization);
        
        // acceptor is like a ServerSocket, it will receive messages
        acceptor = new NioSocketAcceptor();

        // acceptor will (de)serialize messages, then log them
        acceptor.getFilterChain().addLast("codec", codec);
        acceptor.getFilterChain().addLast("logger", new LoggingFilter());
        
        // received messages will be handled by a MyHandler instance
        IoHandler handler = new MyHandler(this);
        acceptor.setHandler(handler);
        
        try {
            acceptor.bind(new InetSocketAddress(udpPort));
        } catch (IOException e) {
            log.error("unable to bind acceptor on port " + udpPort, e);
            // Release what was already allocated so a failed construction
            // does not leak MINA or timer threads.
            acceptor.dispose();
            timer.cancel();
            SocketException se = new SocketException(
                    "unable to bind acceptor on port " + udpPort + ": " + e);
            se.initCause(e);
            throw se;
        }
    }
    
    /** this class handles all the messages received by the DHT */
    protected class MyHandler extends IoHandlerAdapter {

        protected MessageServer messageServer;
        
        public MyHandler(MessageServer messageServer) {
            this.messageServer = messageServer;
        }
        
        /**
         * Decodes and dispatches a received message.
         * <p>
         * Messages are byte arrays produced by
         * {@link MessageServer#sendMessage(int, Message, InetAddress, int)}:
         * a 4-byte communication id, a 1-byte message code, then the message body.
         * If the factory supplies a Receiver for the message code it is used;
         * otherwise the Receiver registered by {@link MessageServer#send} for
         * that communication id (if any) is claimed and its timeout cancelled.
         */
        @Override
        public void messageReceived(IoSession session, Object msg)
                throws Exception {
            if (!(msg instanceof byte[])) {
                log.error("ignoring unexpected message of type "
                        + (msg == null ? "null" : msg.getClass().getName()));
                return;
            }
            byte[] payload = (byte[]) msg;
            log.debug("message received (" + payload.length + " bytes)");

            int comm;
            byte messCode;
            Message message;
            DataInputStream din =
                new DataInputStream(new ByteArrayInputStream(payload));
            try {
                comm = din.readInt();
                messCode = din.readByte();
                message = factory.createMessage(messCode, din);
            } finally {
                din.close();
            }

            // Create Receiver if one is supported; otherwise claim the one
            // registered for this communication id, if any.
            Receiver recv = factory.createReceiver(messCode, messageServer);
            if (recv == null) {
                recv = claimReceiver(comm);
            }

            // Invoke Receiver (outside any lock) if one was found
            if (recv != null) {
                recv.receive(message, comm);
            }
        }
        
        /** Logs network errors reported by MINA, including the stack trace. */
        @Override
        public void exceptionCaught(IoSession session, Throwable cause)
                throws Exception {
            log.error("exception caught network " + cause, cause);
        }
    }

    /**
     * Sends the specified Message and calls the specified Receiver when a reply
     * for the message is received (or its timeout when none arrives in time).
     * If <code>recv</code> is <code>null</code> any reply is ignored. Returns
     * a communication id, unique among pending requests, which identifies a reply.
     * <p>
     * The receiver may be invoked (by a reply or a timeout) before this method
     * returns.
     *
     * @throws IllegalStateException if the server has been closed
     * @throws IOException if the message could not be encoded, connected or written
     **/
    public int send(Message message, InetAddress ip,
                    int port, Receiver recv) throws IOException {
        int comm;
        // Only the bookkeeping is done under the lock; the network I/O below
        // may block for seconds and must not stall reply dispatch or timeouts.
        synchronized (this) {
            if (!isRunning) throw new IllegalStateException("MessageServer not running");
            comm = nextCommId();
            if (recv != null) {
                Integer key = Integer.valueOf(comm);
                receivers.put(key, recv);
                TimerTask task = new TimeoutTask(comm, recv);
                timer.schedule(task, timeout);
                tasks.put(key, task);
            }
        }
        sendMessage(comm, message, ip, port);
        return comm;
    }

    /**
     * Sends a reply to the message with the specified communication id.
     *
     * @throws IllegalStateException if the server has been closed
     * @throws IOException if the message could not be encoded, connected or written
     **/
    public void reply(int comm, Message message, InetAddress ip,
                                                int port) throws IOException {
        if (!isRunning) throw new IllegalStateException("MessageServer not running");
        sendMessage(comm, message, ip, port);
    }

    /**
     * Draws a random communication id not used by any pending request, so a
     * collision cannot overwrite another request's receiver. Caller must hold
     * the lock on {@code this}.
     */
    private int nextCommId() {
        int comm;
        do {
            comm = random.nextInt();
        } while (receivers.containsKey(Integer.valueOf(comm)));
        return comm;
    }

    /** Serializes the wire frame: comm id (int), message code (byte), body. */
    private static byte[] encode(int comm, Message message) throws IOException {
        ByteArrayOutputStream bout = new ByteArrayOutputStream();
        DataOutputStream dout = new DataOutputStream(bout);
        dout.writeInt(comm);
        dout.writeByte(message.code());
        message.toStream(dout);
        dout.close();
        return bout.toByteArray();
    }
    
    /**
     * Opens a fresh connection to {@code ip:port}, writes the encoded message,
     * waits for the write to complete, then disposes the connector.
     *
     * @throws IOException if encoding, connecting or writing fails
     */
    private void sendMessage(int comm, Message message, InetAddress ip, int port)
                                                              throws IOException {
        byte[] data = encode(comm, message);

        NioSocketConnector connector = new NioSocketConnector();
        connector.setConnectTimeoutMillis(CONNECT_TIMEOUT_MILLIS);
        connector.getFilterChain().addLast("codec", codec);
        connector.getFilterChain().addLast("logger", new LoggingFilter());        
        connector.setHandler(new IoHandlerAdapter());
        
        SocketAddress dest = new InetSocketAddress(ip, port);
        try {
            ConnectFuture connectFuture = connector.connect(dest);
            connectFuture.awaitUninterruptibly();
            // getSession() throws RuntimeIoException if the connection failed
            IoSession session = connectFuture.getSession();
            log.debug("send " + data.length + " bytes");
            WriteFuture writeFuture = session.write(data);
            // need to wait, or large values won't be well sent 
            writeFuture.awaitUninterruptibly();
            if (!writeFuture.isWritten()) {
                Throwable cause = writeFuture.getException();
                log.error("unable to send message to " + dest, cause);
                throw new IOException("unable to send message to " + dest, cause);
            }
        } catch (RuntimeIoException e) {
            log.error("unable to connect to " + dest, e);
            throw new IOException(e);
        } finally {
            connector.dispose();
        }
    }

    /**
     * Claims the Receiver registered for {@code comm}, removing it and
     * cancelling its timeout task.
     *
     * @return the registered Receiver, or {@code null} if none is pending
     */
    private synchronized Receiver claimReceiver(int comm) {
        Integer key = Integer.valueOf(comm);
        Receiver recv = receivers.remove(key);
        TimerTask task = tasks.remove(key);
        if (task != null) {
            task.cancel();
        }
        return recv;
    }

    /**
     * Removes the bookkeeping for {@code comm}.
     *
     * @return {@code true} if a Receiver was still registered, i.e. no reply
     *         has claimed it yet
     */
    private synchronized boolean unregister(int comm) {
        Integer key = Integer.valueOf(comm);
        tasks.remove(key);
        return receivers.remove(key) != null;
    }

    /**
     * Stops the server.
     * <p>
     * Marks it as not running, cancels the timeout timer, and drops all pending
     * receivers without notifying them. It then disposes the acceptor.
     *
     * @throws IllegalStateException if the server was already closed
     **/
    public void close() {
        synchronized (this) {
            if (!isRunning) throw new IllegalStateException("MessageServer not running");
            isRunning = false;
            timer.cancel();
            receivers.clear();
            tasks.clear();
        }
        // Disposed outside the lock: dispose() may wait for MINA threads, and
        // those threads take this lock when dispatching replies.
        acceptor.dispose();
    }

    /**
     * Task run on the timer thread when no reply arrived for a request in time.
     * <p>
     * A reply cancels this task via {@link MessageServer#claimReceiver}. If the
     * reply and the timeout race, only the one that removes the Receiver from
     * the map notifies it, so {@code receive} and {@code timeout} are never
     * both delivered for the same request.
     **/
    class TimeoutTask extends TimerTask {
        private final int comm;
        private final Receiver recv;

        public TimeoutTask(int comm, Receiver recv) {
            this.comm = comm;
            this.recv = recv;
        }

        @Override
        public void run() {
            if (!unregister(comm)) {
                return; // a reply already claimed this receiver
            }
            try {
                recv.timeout(comm);
            } catch (IOException e) {
                log.error("receiver timeout handling failed for comm " + comm, e);
            } catch (RuntimeException e) {
                // An exception escaping run() would terminate the shared Timer,
                // making every later send() fail.
                log.error("receiver timeout handling failed for comm " + comm, e);
            }
        }
    }
}
