package explorviz.hpc_monitoring.filter.reconstruction;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import com.lmax.disruptor.EventHandler;
import com.lmax.disruptor.RingBuffer;
import com.lmax.disruptor.dsl.Disruptor;

import explorviz.hpc_monitoring.disruptor.RecordEvent;
import explorviz.hpc_monitoring.filter.counting.CountingThroughputFilter;
import explorviz.hpc_monitoring.filter.reduction.TracePatternSummarizationFilter;
import explorviz.hpc_monitoring.reader.IPeriodicTimeSignalReceiver;
import explorviz.hpc_monitoring.reader.TimeReader;
import explorviz.hpc_monitoring.record.IRecord;
import explorviz.hpc_monitoring.record.Trace;
import explorviz.hpc_monitoring.record.TraceMetadata;
import explorviz.hpc_monitoring.record.events.AbstractOperationEvent;
import gnu.trove.iterator.TLongObjectIterator;
import gnu.trove.map.hash.TLongObjectHashMap;

/**
 * Reassembles incoming monitoring records into complete {@link Trace}s.
 * <p>
 * {@link TraceMetadata} and {@link AbstractOperationEvent} records are grouped
 * per trace id in a {@link TraceBuffer}. As soon as a buffer reports that it is
 * finished, the reconstructed trace is counted and published into a downstream
 * ring buffer consumed by a {@link TracePatternSummarizationFilter}. On every
 * periodic time signal, buffers whose newest logging timestamp is older than
 * {@code maxTraceTimeout} are dropped as invalid.
 * <p>
 * Thread-safety: {@link #onEvent} is driven by the upstream disruptor, while
 * {@link #periodicTimeSignal} is driven by the {@link TimeReader} started in the
 * constructor (presumably on its own thread). Every access to the trace map is
 * therefore guarded by a private lock.
 */
public final class TraceReconstructionFilter implements
		EventHandler<RecordEvent>, IPeriodicTimeSignalReceiver {
	private static final CountingThroughputFilter counter = new CountingThroughputFilter(
			"Reconstructed traces per second");

	/** Capacity of the downstream ring buffer (Disruptor requires a power of two). */
	private static final int RING_BUFFER_SIZE = 16384;

	/**
	 * Window handed to the downstream {@link TracePatternSummarizationFilter}.
	 * The value is 2 * 10^9, presumably nanoseconds (2 s); confirm against that class.
	 */
	private static final int SUMMARIZATION_WINDOW = 2 * 1000 * 1000 * 1000;

	/** Period handed to the {@link TimeReader} (2000, presumably milliseconds). */
	private static final int TIME_SIGNAL_PERIOD = 2 * 1000;

	private final long maxTraceTimeout;

	/** Guards {@link #traceId2trace} against concurrent event and timer callbacks. */
	private final Object lock = new Object();

	/** Partially reconstructed traces keyed by trace id; only touched under {@link #lock}. */
	private final TLongObjectHashMap<TraceBuffer> traceId2trace = new TLongObjectHashMap<TraceBuffer>(
			1024);
	private final RingBuffer<RecordEvent> ringBuffer;

	/**
	 * Creates the filter, wires up the downstream summarization stage and starts
	 * the periodic timeout check.
	 *
	 * @param maxTraceTimeout
	 *            how far (in the time signal's unit) the newest record of a
	 *            trace may lag behind the current time before the trace is
	 *            dropped as invalid
	 * @param endReceiver
	 *            handler passed on to the downstream
	 *            {@link TracePatternSummarizationFilter}
	 */
	public TraceReconstructionFilter(final long maxTraceTimeout,
			final EventHandler<RecordEvent> endReceiver) {
		this.maxTraceTimeout = maxTraceTimeout;

		// NOTE(review): neither this executor nor the disruptor is ever shut
		// down; terminate() does not stop them.
		final ExecutorService exec = Executors.newCachedThreadPool();
		final Disruptor<RecordEvent> disruptor = new Disruptor<RecordEvent>(
				RecordEvent.EVENT_FACTORY, RING_BUFFER_SIZE, exec);

		// Generic array creation is unavoidable here. The array only ever holds
		// EventHandler<RecordEvent> instances, so the unchecked conversion is safe.
		@SuppressWarnings("unchecked")
		final EventHandler<RecordEvent>[] eventHandlers = new EventHandler[1];
		eventHandlers[0] = new TracePatternSummarizationFilter(
				SUMMARIZATION_WINDOW, endReceiver);
		disruptor.handleEventsWith(eventHandlers);
		ringBuffer = disruptor.start();

		// Started last so the periodic callback only sees fully assigned fields.
		new TimeReader(TIME_SIGNAL_PERIOD, this).start();
	}

	@Override
	public void periodicTimeSignal(final long timestamp) {
		checkForTimeouts(timestamp);
	}

	/**
	 * Drops every buffered trace whose newest record is at least
	 * {@code maxTraceTimeout} older than {@code timestamp}.
	 *
	 * @param timestamp
	 *            the current time as delivered by the time signal
	 */
	private void checkForTimeouts(final long timestamp) {
		final long traceTimeout = timestamp - maxTraceTimeout;
		synchronized (lock) {
			final TLongObjectIterator<TraceBuffer> iterator = traceId2trace
					.iterator();
			while (iterator.hasNext()) {
				// Trove iterators must be advanced BEFORE key()/value() are read.
				// A fresh iterator does not point at a valid entry yet.
				iterator.advance();
				final TraceBuffer traceBuffer = iterator.value();
				if (traceBuffer.getMaxLoggingTimestamp() <= traceTimeout) {
					sendOutInvalidTrace(traceBuffer.toTrace());
					iterator.remove();
				}
			}
		}
	}

	/** Counts a completed trace and forwards it downstream. */
	private void sendOutValidTrace(final Trace trace) {
		counter.inputObjects(trace);
		putInRingBuffer(trace);
	}

	/**
	 * Reports a trace that could not be completed. Currently it is only logged
	 * to stdout, not forwarded.
	 */
	private void sendOutInvalidTrace(final Trace trace) {
		// TODO: count and forward invalid traces via putInRingBuffer once the
		// downstream stages can handle them.
		System.out.println("Invalid trace: "
				+ trace.getTraceMetadata().getTraceId());
	}

	/**
	 * Publishes a record into the downstream ring buffer. The claimed sequence
	 * is always published, even if populating the slot fails, so the ring
	 * buffer cannot stall on an unpublished slot.
	 */
	private void putInRingBuffer(final IRecord record) {
		final long hiseq = ringBuffer.next();
		try {
			final RecordEvent valueEvent = ringBuffer.get(hiseq);
			valueEvent.setValue(record);
		} finally {
			ringBuffer.publish(hiseq);
		}
	}

	/**
	 * Adds an incoming record to the buffer of its trace. When that completes
	 * the trace, the trace is forwarded downstream. Records of other types are
	 * ignored.
	 */
	@Override
	public void onEvent(final RecordEvent event, final long sequence,
			final boolean endOfBatch) throws Exception {
		final IRecord record = event.getValue();
		TraceBuffer finishedBuffer = null;

		synchronized (lock) {
			if (record instanceof TraceMetadata) {
				final TraceMetadata traceMetadata = (TraceMetadata) record;
				getBufferForTraceId(traceMetadata.getTraceId()).setTrace(
						traceMetadata);
			} else if (record instanceof AbstractOperationEvent) {
				final AbstractOperationEvent operationEvent = (AbstractOperationEvent) record;
				final long traceId = operationEvent.getTraceId();
				final TraceBuffer traceBuffer = getBufferForTraceId(traceId);
				traceBuffer.insertEvent(operationEvent);

				if (traceBuffer.isFinished()) {
					traceId2trace.remove(traceId);
					finishedBuffer = traceBuffer;
				}
			}
		}

		// Publish outside the lock, because ringBuffer.next() may wait while the
		// downstream ring is full. The buffer has already been removed from the
		// map, so no other callback can reach it.
		if (finishedBuffer != null) {
			sendOutValidTrace(finishedBuffer.toTrace());
		}
	}

	/**
	 * Returns the buffer for {@code traceId}, creating and registering an empty
	 * one if none exists yet. Callers must hold {@link #lock}.
	 */
	private TraceBuffer getBufferForTraceId(final long traceId) {
		TraceBuffer traceBuffer = traceId2trace.get(traceId);
		if (traceBuffer == null) {
			traceBuffer = new TraceBuffer();
			traceId2trace.put(traceId, traceBuffer);
		}
		return traceBuffer;
	}

	/**
	 * Reports every still-incomplete trace as invalid and forgets all buffered
	 * state.
	 */
	public void terminate() {
		synchronized (lock) {
			final TLongObjectIterator<TraceBuffer> iterator = traceId2trace
					.iterator();
			while (iterator.hasNext()) {
				iterator.advance();
				sendOutInvalidTrace(iterator.value().toTrace());
			}
			traceId2trace.clear();
		}
	}
}