package org.kth.dks.planetlab;

import java.net.InetAddress;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Random;

import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;
import org.kth.dks.DKSObject;
import org.kth.dks.dks_comm.ConnectionManager;
import org.kth.dks.dks_comm.DKSOverlayAddress;
import org.kth.dks.dks_comm.DKSRef;
import org.kth.dks.dks_dht.DKSDHTCallback;
import org.kth.dks.dks_dht.DKSDHTImpl;
import org.kth.dks.dks_exceptions.DKSIdentifierAlreadyTaken;
import org.kth.dks.dks_exceptions.DKSRefNoResponse;
import org.kth.dks.dks_exceptions.DKSTooManyRestartJoins;
import org.kth.dks.dks_marshal.DKSMessage;
import org.kth.dks.dks_node.DKSNode;
import org.kth.dks.planetlab.messages.KeepAliveMsg;
import org.kth.dks.planetlab.messages.LiveNodesMsg;
import org.kth.dks.planetlab.messages.PingMsg;
import org.kth.dks.planetlab.messages.PongMsg;
import org.kth.dks.planetlab.messages.StrechMeasurementMsg;
import org.kth.dks.util.AsyncOperation;
import org.kth.dks.util.OperationType;

/**
 * Stretch-measurement driver for a DKS overlay deployed on PlanetLab.
 * <p>
 * On start-up the node checks whether it can open a connection to the central
 * MySQL database. If it can, it acts as a <em>directly connected node</em>
 * (DCN): it joins the ring through an active DCN listed in the database (or
 * creates a new ring if none is listed) and writes its own keep-alives and
 * measurements, as well as those forwarded by other nodes, straight into
 * MySQL. Otherwise it acts as a <em>non-connected node</em>: it asks the
 * bootstrap node for a list of live DCNs, joins through one of them and
 * forwards its keep-alives and measurements to that DCN.
 * <p>
 * Afterwards every node loops forever: it sleeps a random interval, looks up
 * a random key through DKS, pings the responsible node directly and records
 * the ratio of the DKS lookup time to the direct round-trip time ("stretch").
 * <p>
 * The {@code *MsgHandler} methods are registered by method name in
 * {@link #installMessageHandlers()}, so their names and signatures must not
 * change. They are invoked asynchronously with respect to the measurement
 * thread (the non-connected join path blocks waiting for a handler to deliver
 * a reply).
 */
public class PlanetLabDKS implements DKSDHTCallback {

	private static final Logger log = Logger.getLogger(PlanetLabDKS.class);

	/* identity of this node, filled in from the command line by checkUsage() */
	private static long myID;

	private static int myPort;

	private static String myIP;

	private static String myHostname;

	private static long myLongIp;

	/** Node that non-connected nodes ask for the list of live DCNs. */
	private static final String bootstrapIp = "193.10.64.35";

	private static final long bootstrapID = Long.parseLong("0b4c3243", 16);

	/** A node counts as alive if its LAST_SEEN is at most this many seconds old. */
	private static final int aliveSeconds = 900; /* 15 x 60 */

	/** Interval handed to KeepAliveThread (5 x 60, presumably seconds). */
	private static final int keepAliveInterval = 300; /* 5 x 60 */

	/** Seconds to wait for a PONG, or for a live-nodes reply, before giving up. */
	private static final int pingTimeout = 60;

	/* bounds (seconds) of the random sleep between measurements; max is exclusive */
	private static final int minRandomSleep = 10;

	private static final int maxRandomSleep = 20;

	/** System property that may override the JDBC URL. */
	private static final String MYSQL_URL_PROPERTY = "org.kth.dks.planetlab.mysql.url";

	/*
	 * Default JDBC URL. NOTE(review): this embeds credentials in source; prefer
	 * supplying them through MYSQL_URL_PROPERTY.
	 */
	private static final String DEFAULT_MYSQL_URL = "jdbc:mysql://193.10.67.72/dks?user=dks&password=DKS";

	private ConnectionManager cm = null;

	private DKSDHTImpl dks = null;

	private DKSOverlayAddress myOA = null;

	/** This node's own reference. */
	private DKSRef myDKSRef = null;

	private DKSRef bootDKSRef = null;

	/**
	 * The directly connected node this node joined through. Only set when this
	 * node is not directly connected; keep-alives and measurements go here.
	 */
	private volatile DKSRef myDCNDKSRef = null;

	/** True when this node writes to MySQL itself; also read by handler threads. */
	private volatile boolean direct = false;

	private PreparedStatement psConnSelect = null;

	private PreparedStatement psConnReplace = null;

	private PreparedStatement psConnIncLoad = null;

	private PreparedStatement psStrech = null;

	/** Guards {@link #message}; the live-nodes handler notifies on it. */
	private final Object lock = new Object();

	/** Latest LiveNodesMsg REPLY, handed from the handler to the joining thread. */
	private LiveNodesMsg message = null;

	private Random random = null;

	/** Node returned by the most recent random lookup. */
	private DKSRef currentResponsible = null;

	/**
	 * Entry point.
	 *
	 * @param args {@code <hostname> <bindIP> <bindPort> <dksID>}, where the
	 *            DKS id is given in hexadecimal
	 */
	public static void main(String[] args) {
		checkUsage(args);

		PropertyConfigurator.configure(System
				.getProperty("org.apache.log4j.config.file"));

		log.info("log4j properly configured");

		PlanetLabDKS object = new PlanetLabDKS();
		object.go();
	}

	/**
	 * Starts the DKS node, then takes the directly or non-directly connected
	 * role depending on whether MySQL is reachable. Exits the JVM if the node
	 * cannot be started; otherwise never returns.
	 */
	private void go() {
		/* start the DKS node */
		try {
			cm = ConnectionManager.getInstanceMultiHome(myPort, InetAddress
					.getByName(myIP));

			cm.getWebServer().setHostname(myHostname);
			cm.getWebServer().setIp(myIP);
			cm.getWebServer().setPort(String.valueOf(myPort));

			myDKSRef = DKSRef.valueOf("dksref://" + myIP + ":" + myPort + "/0/"
					+ myID + "/0/0");
			/* NOTE(review): assumes the bootstrap listens on the same port as this node */
			bootDKSRef = DKSRef.valueOf("dksref://" + bootstrapIp + ":"
					+ myPort + "/0/" + bootstrapID + "/0/0");

			myOA = new DKSOverlayAddress("DKSOverlay://0/" + myID + "/0");
			dks = new DKSDHTImpl(cm, myOA, this);
		} catch (Exception e) {
			log.error(e.getMessage(), e);
			System.exit(1);
		}

		random = new Random();

		/* check direct MySQL connectivity */
		Connection connection = checkMySQLConnectivity();
		if (connection != null) {
			actAsDirectlyConnectedNode(connection);
		} else {
			actAsNonConnectedNode();
		}
	}

	/**
	 * Runs as a directly connected node: prepares the SQL statements, joins the
	 * ring through a random active DCN (or creates the ring if none is active),
	 * records itself in MySQL and starts measuring.
	 *
	 * @param connection open MySQL connection used for the lifetime of the node
	 */
	private void actAsDirectlyConnectedNode(Connection connection) {
		log.debug("ACTING AS DIRECT-CONNECTED");

		direct = true;

		/* prepare MySQL statements */
		prepareMySQLStatements(connection);

		/* install handlers for messages */
		installMessageHandlers();

		/* get active nodes */
		ArrayList<DKSRef> activeNodes = getActiveNodes(aliveSeconds);
		if (activeNodes == null) {
			System.exit(1);
		}

		if (activeNodes.isEmpty()) {
			/* start a new ring */
			dks.create();
		} else {
			/* try the candidates in random order until one lets us join */
			boolean joined = false;
			while (!joined && !activeNodes.isEmpty()) {
				DKSRef ref = activeNodes.remove(random.nextInt(activeNodes.size()));
				try {
					dks.join(ref);
					joined = true;
				} catch (DKSTooManyRestartJoins e) {
					log.error(e.getMessage());
				} catch (DKSIdentifierAlreadyTaken e) {
					log.error(e.getMessage());
				} catch (DKSRefNoResponse e) {
					log.error(e.getMessage());
				}
			}

			if (!joined) {
				log.fatal("Could not contact any active nodes");
				System.exit(1);
			}
		}

		log.info("COMPLETED JOIN");

		/* report that I am active */
		keepAlive(myLongIp, myHostname, myIP, myPort, myID, true);

		work();
	}

	/**
	 * Runs as a non-connected node: obtains live DCNs from the bootstrap node,
	 * joins through one of them, asks that DCN to increment its load, reports
	 * itself alive through it and starts measuring. Retries until it joins.
	 */
	private void actAsNonConnectedNode() {
		log.debug("ACTING AS NON-CONNECTED");

		direct = false;

		installMessageHandlers();
		LiveNodesMsg request = new LiveNodesMsg(LiveNodesMsg.Type.REQUEST, 0,
				null);

		DKSRef joinedThrough = null;
		do {
			ArrayList<DKSRef> refs = requestLiveNodes(request);

			/* try the candidates in random order until one lets us join */
			while (joinedThrough == null && !refs.isEmpty()) {
				DKSRef ref = refs.remove(random.nextInt(refs.size()));
				try {
					dks.join(ref);
					joinedThrough = ref;
				} catch (DKSTooManyRestartJoins e) {
					log.error(e.getMessage());
				} catch (DKSIdentifierAlreadyTaken e) {
					log.error(e.getMessage());
				} catch (DKSRefNoResponse e) {
					log.error(e.getMessage());
				}
			}
			/* if not joined, ask for a fresh list and retry */
		} while (joinedThrough == null);

		/* from now on keep-alives and measurements are forwarded to this DCN */
		myDCNDKSRef = joinedThrough;

		/* we joined, now we tell this guy to increment its load */
		ArrayList<DKSRef> list = new ArrayList<DKSRef>();
		list.add(joinedThrough);
		LiveNodesMsg choice = new LiveNodesMsg(LiveNodesMsg.Type.CHOICE, 1,
				list);
		dks.send(joinedThrough, choice);

		log.info("COMPLETED JOIN");

		/* report that I am active */
		keepAlive(myLongIp, myHostname, myIP, myPort, myID, false);

		work();
	}

	/**
	 * Asks the bootstrap node for live directly connected nodes until a
	 * non-empty list arrives. Re-sends the request if no reply arrives within
	 * {@link #pingTimeout} seconds, and sleeps a random interval before asking
	 * again after an empty reply.
	 *
	 * @param request the REQUEST message to send
	 * @return a private, mutable copy of the non-empty list of refs received
	 */
	private ArrayList<DKSRef> requestLiveNodes(LiveNodesMsg request) {
		while (true) {
			synchronized (lock) {
				/* drop any stale reply so we only consume an answer to this request */
				message = null;
			}
			dks.send(bootDKSRef, request);

			LiveNodesMsg reply;
			synchronized (lock) {
				long timeoutMs = pingTimeout * 1000L;
				long deadline = System.currentTimeMillis() + timeoutMs;
				long remaining = timeoutMs;
				while (message == null && remaining > 0) {
					try {
						lock.wait(remaining);
					} catch (InterruptedException e) {
						/* there is no shutdown protocol; keep waiting for the reply */
					}
					remaining = deadline - System.currentTimeMillis();
				}
				reply = message;
				message = null;
			}

			if (reply == null) {
				log.warn("No LiveNodesMsg reply from bootstrap, asking again");
				continue;
			}

			ArrayList<DKSRef> refs = reply.getRefs();
			if (refs != null && !refs.isEmpty()) {
				/* copy so that removing tried candidates never touches the message */
				return new ArrayList<DKSRef>(refs);
			}

			log.info("Bootstrap reported no live nodes, asking again later");
			sleepRandom();
		}
	}

	/**
	 * Starts the keep-alive thread, then forever: sleeps a random interval,
	 * performs a random DKS lookup, pings the responsible node directly and
	 * stores the two round-trip times.
	 */
	private void work() {
		/* start keep alive thread */
		new KeepAliveThread(keepAliveInterval, this).start();

		/* take and store measurements */
		while (true) {
			sleepRandom();

			int dksRtt = sendRandomLookup();
			DKSRef target = currentResponsible;
			if (target == null || target.equals(myDKSRef)) {
				/*
				 * Measuring against ourselves is meaningless; retry after the next
				 * sleep instead of spinning (e.g. on a single-node ring).
				 */
				continue;
			}

			int directRtt = sendDirectPing(target);
			if (directRtt == -1) {
				continue;
			}

			storeMeasurement(myLongIp, HostUtils.getIpAsLong(target.getIP()),
					dksRtt, directRtt, myDKSRef, myHostname);
		}
	}

	/**
	 * Looks up a uniformly random key in {@code [0, DKSNode.N)} and records the
	 * responsible node in {@link #currentResponsible}.
	 *
	 * @return wall-clock duration of the lookup in milliseconds
	 */
	private int sendRandomLookup() {
		/* clear the sign bit: nextLong() % N is negative half of the time */
		long key = (random.nextLong() & Long.MAX_VALUE) % DKSNode.N;
		log.info("SENDING RANDOM LOOKUP (" + key + ")");
		long sentTime = System.currentTimeMillis();
		currentResponsible = dks.findResponsible(key);
		long receivedTime = System.currentTimeMillis();
		return (int) (receivedTime - sentTime);
	}

	/**
	 * Sends a PING directly to {@code target} and waits up to
	 * {@link #pingTimeout} seconds for the matching PONG.
	 *
	 * @param target node to ping
	 * @return round-trip time in milliseconds, or -1 on timeout or failure
	 */
	public int sendDirectPing(DKSRef target) {
		long sentTime = System.currentTimeMillis();

		AsyncOperation pingOp = AsyncOperation.start(OperationType.FINDTYPE);

		PingMsg pingMsg = new PingMsg(sentTime, pingOp.getKey());

		dks.send(target, pingMsg);

		log.info("Sent PING to " + target.getIP());

		Integer rtt = null;
		try {
			/* completed by pongMsgHandler with the measured RTT */
			rtt = (Integer) pingOp.waitOn(pingTimeout * 1000);
		} catch (Exception e) {
			pingOp.cancel();
			log.error("***PING TIMEOUT***");
			return -1;
		}

		if (rtt == null) {
			return -1;
		}

		return rtt.intValue();
	}

	/**
	 * Prepares the SQL statements used by a directly connected node. Exits the
	 * JVM if any statement cannot be prepared.
	 *
	 * @param connection open MySQL connection
	 */
	private void prepareMySQLStatements(Connection connection) {
		try {
			/*
			 * TIMEDIFF(a, b) is a - b, so the age of LAST_SEEN is
			 * TIMEDIFF(now(), LAST_SEEN). LOAD is a MySQL reserved word and is
			 * therefore back-quoted.
			 */
			psConnSelect = connection
					.prepareStatement("SELECT c.IP, c.PORT, c.ID FROM connectivity c"
							+ " WHERE TIME_TO_SEC(TIMEDIFF(now(), c.LAST_SEEN)) <= ? "
							+ " AND c.DC=true ORDER BY c.`LOAD` ASC");

			psConnReplace = connection
					.prepareStatement("INSERT INTO connectivity(LONGIP, HOSTNAME, "
							+ "IP, PORT, ID, LAST_SEEN, DC) VALUES(?, ?, ?, ?, ?, now(), ?)"
							+ " ON DUPLICATE KEY UPDATE LAST_SEEN=now()");

			psConnIncLoad = connection.prepareStatement("UPDATE connectivity "
					+ "SET `LOAD`=`LOAD`+1 WHERE LONGIP = ?");

			psStrech = connection
					.prepareStatement("INSERT INTO address_strech "
							+ "VALUES(?, ?, ?, ?, ?)");
		} catch (SQLException e) {
			log.error("Exception preparing statements: " + e.getMessage(), e);
			System.exit(1);
		}
	}

	/**
	 * Registers the planetlab message types and binds each one to the handler
	 * method of this object with the given name.
	 */
	private void installMessageHandlers() {
		DKSMessage.addMessageTypePrefixed(PingMsg.NAME,
				"planetlab.messages.PingMsg");
		dks.myDKSImpl.getDKSMarshal().addMsgHandlerPrefixed(myOA,
				"planetlab.messages.PingMsg", "planetlab.PlanetLabDKS",
				"pingMsgHandler", this);

		DKSMessage.addMessageTypePrefixed(PongMsg.NAME,
				"planetlab.messages.PongMsg");
		dks.myDKSImpl.getDKSMarshal().addMsgHandlerPrefixed(myOA,
				"planetlab.messages.PongMsg", "planetlab.PlanetLabDKS",
				"pongMsgHandler", this);

		DKSMessage.addMessageTypePrefixed(StrechMeasurementMsg.NAME,
				"planetlab.messages.StrechMeasurementMsg");
		dks.myDKSImpl.getDKSMarshal().addMsgHandlerPrefixed(myOA,
				"planetlab.messages.StrechMeasurementMsg",
				"planetlab.PlanetLabDKS", "strechMeasurementMsgHandler", this);

		DKSMessage.addMessageTypePrefixed(LiveNodesMsg.NAME,
				"planetlab.messages.LiveNodesMsg");
		dks.myDKSImpl.getDKSMarshal().addMsgHandlerPrefixed(myOA,
				"planetlab.messages.LiveNodesMsg", "planetlab.PlanetLabDKS",
				"liveNodesMsgHandler", this);

		DKSMessage.addMessageTypePrefixed(KeepAliveMsg.NAME,
				"planetlab.messages.KeepAliveMsg");
		dks.myDKSImpl.getDKSMarshal().addMsgHandlerPrefixed(myOA,
				"planetlab.messages.KeepAliveMsg", "planetlab.PlanetLabDKS",
				"keepAliveMsgHandler", this);
	}

	/**
	 * Records one measurement: inserts it into MySQL when directly connected,
	 * otherwise forwards it to the DCN this node joined through.
	 *
	 * @param fromIp measuring node's IP as a long
	 * @param toIp measured node's IP as a long
	 * @param dksRttMs DKS lookup time in milliseconds
	 * @param ipRttMs direct ping round-trip time in milliseconds
	 * @param from reference of the measuring node
	 * @param hostname hostname of the measuring node
	 */
	private synchronized void storeMeasurement(long fromIp, long toIp,
			int dksRttMs, int ipRttMs, DKSRef from, String hostname) {

		if (direct) {
			/* RTTs come from currentTimeMillis(); a sub-ms ping reads as 0 ms */
			double strech = ((double) dksRttMs) / ((double) Math.max(ipRttMs, 1));
			try {
				psStrech.setLong(1, fromIp);
				psStrech.setLong(2, toIp);
				psStrech.setInt(3, dksRttMs);
				psStrech.setInt(4, ipRttMs);
				psStrech.setDouble(5, strech);

				psStrech.executeUpdate();
			} catch (SQLException e) {
				log.error(e.getMessage(), e);
			}
		} else {
			/* send the measurement to directly connected node */
			DKSRef dcn = myDCNDKSRef;
			if (dcn == null) {
				log.error("No directly connected node to forward measurement to");
				return;
			}
			StrechMeasurementMsg msg = new StrechMeasurementMsg(hostname, from,
					fromIp, toIp, dksRttMs, ipRttMs);
			dks.send(dcn, msg);
		}
	}

	/**
	 * Marks a node alive: upserts its row in MySQL when directly connected,
	 * otherwise forwards a KeepAliveMsg to the DCN this node joined through.
	 *
	 * @param longIp node IP as a long
	 * @param hostname node hostname
	 * @param ip node IP as text
	 * @param port node port
	 * @param id node DKS identifier
	 * @param dc whether that node is directly connected
	 */
	private synchronized void keepAlive(long longIp, String hostname,
			String ip, int port, long id, boolean dc) {
		if (direct) {
			/* directly update MySQL */
			try {
				psConnReplace.setLong(1, longIp);
				psConnReplace.setString(2, hostname);
				psConnReplace.setString(3, ip);
				psConnReplace.setInt(4, port);
				psConnReplace.setLong(5, id);
				psConnReplace.setBoolean(6, dc);

				psConnReplace.executeUpdate();
			} catch (SQLException e) {
				log.error(e.getMessage(), e);
			}
		} else {
			/* send a message to directly connected node */
			DKSRef dcn = myDCNDKSRef;
			if (dcn == null) {
				log.error("No directly connected node to send keep-alive to");
				return;
			}
			KeepAliveMsg msg = new KeepAliveMsg(longIp, hostname, ip, port, id,
					dc);
			dks.send(dcn, msg);
		}
	}

	/** Reports this node alive; called periodically by KeepAliveThread. */
	public void sendKeepAlive() {
		keepAlive(myLongIp, myHostname, myIP, myPort, myID, direct);
	}

	/**
	 * Returns the directly connected nodes seen within the last
	 * {@code seconds} seconds, least loaded first.
	 *
	 * @param seconds maximum age of LAST_SEEN
	 * @return the nodes found, or {@code null} on any error
	 */
	private synchronized ArrayList<DKSRef> getActiveNodes(int seconds) {
		ResultSet rs = null;
		try {
			psConnSelect.setInt(1, seconds);
			rs = psConnSelect.executeQuery();

			ArrayList<DKSRef> activeNodes = new ArrayList<DKSRef>();
			/* rs is never null according to javadoc */
			while (rs.next()) {
				String ip = rs.getString(1);
				int port = rs.getInt(2);
				long id = rs.getLong(3);

				activeNodes.add(DKSRef.valueOf("dksref://" + ip + ":" + port
						+ "/0/" + id + "/0/0"));
			}
			return activeNodes;
		} catch (Exception e) {
			log.error(e.getMessage(), e);
			return null;
		} finally {
			if (rs != null) {
				try {
					rs.close();
				} catch (SQLException e) {
					log.warn("Could not close result set: " + e.getMessage());
				}
			}
		}
	}

	/**
	 * Increments the LOAD counter of the node identified by {@code ref}'s IP.
	 *
	 * @param ref node whose load to increment
	 */
	private synchronized void incrementLoad(DKSRef ref) {
		long longIp = HostUtils.getIpAsLong(ref.getIP());
		try {
			psConnIncLoad.setLong(1, longIp);
			psConnIncLoad.executeUpdate();
		} catch (SQLException e) {
			log.error(e.getMessage(), e);
		}
	}

	/** Answers a PING with a PONG carrying the same timestamp and message id. */
	public void pingMsgHandler(DKSRef source, PingMsg msg) {
		dks.myDKSImpl.send(source, new PongMsg(msg.getTimestamp(), msg
				.getMsgId()));
		log.info("Replied to PING from " + source.getIP());
	}

	/** Completes the pending ping operation with the measured RTT in ms. */
	public void pongMsgHandler(DKSRef source, PongMsg msg) {
		long rtt = System.currentTimeMillis() - msg.getTimestamp();
		AsyncOperation.complete(msg.getMsgId(), Integer.valueOf((int) rtt));
		log.info("Received PONG from " + source.getIP());
	}

	/** Stores (or forwards) a measurement sent by another node. */
	public void strechMeasurementMsgHandler(DKSRef source,
			StrechMeasurementMsg msg) {
		storeMeasurement(msg.getFromIp(), msg.getToIp(), msg.getDksRttMs(), msg
				.getIpRttMs(), source, msg.getFromHostname());
	}

	/**
	 * DCN side: answers REQUESTs with the active nodes and counts CHOICEs as
	 * load. Non-connected side: hands a REPLY to the thread waiting in
	 * {@link #requestLiveNodes(LiveNodesMsg)}.
	 */
	public void liveNodesMsgHandler(DKSRef source, LiveNodesMsg msg) {
		log.debug("***LiveNodesMsg received from " + source);
		if (direct) {
			if (msg.getType() == LiveNodesMsg.Type.REQUEST) {
				ArrayList<DKSRef> refs = getActiveNodes(aliveSeconds);
				if (refs == null) {
					/* database error already logged; reply with an empty list */
					refs = new ArrayList<DKSRef>();
				}

				LiveNodesMsg reply = new LiveNodesMsg(LiveNodesMsg.Type.REPLY,
						refs.size(), refs);

				dks.send(source, reply);
			} else if (msg.getType() == LiveNodesMsg.Type.CHOICE) {
				if (msg.getSize() > 0) {
					DKSRef ref = msg.getRefs().get(0);
					incrementLoad(ref);
				}
			} else {
				log.fatal("DCN received REPLY msg.");
			}
		} else {
			if (msg.getType() == LiveNodesMsg.Type.REPLY) {
				synchronized (lock) {
					message = msg;
					lock.notifyAll();
				}
			} else {
				log.fatal("IDC received REQUEST or CHOICE msg.");
			}
		}
	}

	/** Records a keep-alive forwarded by a non-connected node. */
	public void keepAliveMsgHandler(DKSRef source, KeepAliveMsg msg) {
		if (direct) {
			keepAlive(msg.getLongIp(), msg.getHostname(), msg.getIp(), msg
					.getPort(), msg.getId(), msg.isDc());
		} else {
			log.fatal("IDC received KEEP_ALIVE message!");
		}
	}

	public void dhtBroadcastCallback(DKSObject value) {
	}

	public DKSMessage dhtRouteCallback(long identifier, DKSMessage msg) {
		storeRoutedMeasurement(msg);
		return null;
	}

	public void dhtRouteCallbackAsync(long identifier, DKSMessage msg) {
		storeRoutedMeasurement(msg);
	}

	/**
	 * Shared body of the DHT route callbacks: stores a routed
	 * StrechMeasurementMsg and logs anything else as unknown.
	 */
	private void storeRoutedMeasurement(DKSMessage msg) {
		if (msg instanceof StrechMeasurementMsg) {
			StrechMeasurementMsg strechMsg = (StrechMeasurementMsg) msg;
			storeMeasurement(strechMsg.getFromIp(), strechMsg.getToIp(),
					strechMsg.getDksRttMs(), strechMsg.getIpRttMs(), strechMsg
							.getFromDKSRef(), strechMsg.getFromHostname());
		} else {
			log.fatal("*********Received unknown message********");
		}
	}

	/**
	 * Tries to open a MySQL connection, using the URL from system property
	 * {@link #MYSQL_URL_PROPERTY} or {@link #DEFAULT_MYSQL_URL}.
	 *
	 * @return the open connection, or {@code null} if MySQL is unreachable or
	 *         the driver cannot be loaded
	 */
	private Connection checkMySQLConnectivity() {
		try {
			Class.forName("com.mysql.jdbc.Driver").newInstance();
			String url = System.getProperty(MYSQL_URL_PROPERTY, DEFAULT_MYSQL_URL);
			return DriverManager.getConnection(url);
		} catch (SQLException ex) {
			log.error(ex.getMessage());
		} catch (IllegalAccessException ex) {
			log.error("Can't load com.mysql.jdbc.Driver class!");
		} catch (ClassNotFoundException ex) {
			log.error("Can't load com.mysql.jdbc.Driver class!");
		} catch (InstantiationException ex) {
			log.error("Can't load com.mysql.jdbc.Driver class!");
		}
		return null;
	}

	/** Sleeps a random whole number of seconds in [minRandomSleep, maxRandomSleep). */
	private void sleepRandom() {
		int rand = random.nextInt(maxRandomSleep - minRandomSleep) + minRandomSleep;
		try {
			log.info("SLEEPING(" + rand + ") secs");
			Thread.sleep(rand * 1000L);
		} catch (InterruptedException e) {
			/*
			 * Deliberately ignored: the measurement loop has no shutdown path, and
			 * re-asserting the flag would make every later sleep return at once.
			 */
		}
	}

	/**
	 * Parses {@code <hostname> <bindIP> <bindPort> <dksID>} into the static
	 * identity fields; prints usage and exits on a wrong count or bad number.
	 */
	private static void checkUsage(String[] args) {
		if (args.length == 4) {
			try {
				myHostname = args[0];
				myIP = args[1];
				myPort = Integer.parseInt(args[2]);
				myID = Long.parseLong(args[3], 16);
				myLongIp = HostUtils.getIpAsLong(myIP);
				return;
			} catch (NumberFormatException e) {
				System.err.println("Invalid number: " + e.getMessage());
			}
		}
		System.err
				.println("Usage: PlanetLabDKS <hostname> <bindIP> <bindPort> <dksID>");
		System.exit(1);
	}
}
