package explorviz.hpc_monitoring.filter.reduction;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import com.lmax.disruptor.EventHandler;
import com.lmax.disruptor.RingBuffer;
import com.lmax.disruptor.dsl.Disruptor;

import explorviz.hpc_monitoring.disruptor.RecordEvent;
import explorviz.hpc_monitoring.filter.counting.CountingThroughputFilter;
import explorviz.hpc_monitoring.reader.IPeriodicTimeSignalReceiver;
import explorviz.hpc_monitoring.reader.TimeReader;
import explorviz.hpc_monitoring.record.IRecord;
import explorviz.hpc_monitoring.record.Trace;
import explorviz.hpc_monitoring.record.events.AbstractOperationEvent;

public class TracePatternSummarizationFilter implements
		EventHandler<RecordEvent>, IPeriodicTimeSignalReceiver {
	private final long maxCollectionDuration;

	private final Map<Trace, TraceAggregationBuffer> trace2buffer = new TreeMap<Trace, TraceAggregationBuffer>(
			new TraceComperator());

	private static final CountingThroughputFilter counter = new CountingThroughputFilter(
			"Reduced trace results per second");

	private final RingBuffer<RecordEvent> ringBuffer;

	public TracePatternSummarizationFilter(final long maxCollectionDuration,
			final EventHandler<RecordEvent> endReceiver) {
		this.maxCollectionDuration = maxCollectionDuration;

		final ExecutorService exec = Executors.newCachedThreadPool();
		final Disruptor<RecordEvent> disruptor = new Disruptor<RecordEvent>(
				RecordEvent.EVENT_FACTORY, 128, exec);

		@SuppressWarnings("unchecked")
		final EventHandler<RecordEvent>[] eventHandlers = new EventHandler[1];
		eventHandlers[0] = endReceiver;
		if (endReceiver != null) {
			disruptor.handleEventsWith(eventHandlers);
		}
		ringBuffer = disruptor.start();
		new TimeReader(1 * 1000, this).start();
	}

	@Override
	public void periodicTimeSignal(final long timestamp) {
		processTimeoutQueue(timestamp);
	}

	private void processTimeoutQueue(final long timestamp) {
		final long bufferTimeout = timestamp - maxCollectionDuration;
		final List<Trace> toRemove = new ArrayList<Trace>();
		for (final TraceAggregationBuffer traceBuffer : trace2buffer.values()) {
			if (traceBuffer.getBufferCreatedTimestamp() <= bufferTimeout) {
				final Trace aggregatedTrace = traceBuffer.getAggregatedTrace();
				sendOutTrace(aggregatedTrace);
				toRemove.add(aggregatedTrace);
			}
		}
		for (final Trace traceEventRecords : toRemove) {
			trace2buffer.remove(traceEventRecords);
		}
	}

	private void sendOutTrace(final Trace aggregatedTrace) {
		counter.inputObjects(aggregatedTrace);
		putInRingBuffer(aggregatedTrace);
	}

	private void putInRingBuffer(final IRecord record) {
		final long hiseq = ringBuffer.next();
		final RecordEvent valueEvent = ringBuffer.get(hiseq);
		valueEvent.setValue(record);
		ringBuffer.publish(hiseq);
	}

	@Override
	public void onEvent(final RecordEvent event, final long sequence,
			final boolean endOfBatch) throws Exception {
		final IRecord value = event.getValue();
		if (value instanceof Trace) {
			insertIntoBuffer((Trace) value);
		}
	}

	private void insertIntoBuffer(final Trace trace) {
		TraceAggregationBuffer traceAggregationBuffer = trace2buffer.get(trace);
		if (traceAggregationBuffer == null) {
			traceAggregationBuffer = new TraceAggregationBuffer(
					System.nanoTime());
			trace2buffer.put(trace, traceAggregationBuffer);
		}
		traceAggregationBuffer.insertTrace(trace);
	}

	public void terminate(final boolean error) {
		for (final TraceAggregationBuffer traceBuffer : trace2buffer.values()) {
			sendOutTrace(traceBuffer.getAggregatedTrace());
		}
		trace2buffer.clear();
	}

	private static final class TraceComperator implements Comparator<Trace> {

		@Override
		public int compare(final Trace t1, final Trace t2) {
			final AbstractOperationEvent[] recordsT1 = t1.getTraceEvents();
			final AbstractOperationEvent[] recordsT2 = t2.getTraceEvents();

			if ((recordsT1.length - recordsT2.length) != 0) {
				return recordsT1.length - recordsT2.length;
			}

			final int cmpHostnames = t1.getTraceMetadata().getHostname()
					.compareTo(t2.getTraceMetadata().getHostname());
			if (cmpHostnames != 0) {
				return cmpHostnames;
			}

			// TODO deep check records
			return 0;
		}
	}
}
