package explorviz.hpc_monitoring.plugin;

import java.util.*;
import java.util.concurrent.TimeUnit;
import kieker.analysis.IProjectContext;
import kieker.analysis.plugin.annotation.*;
import kieker.analysis.plugin.filter.AbstractFilterPlugin;
import kieker.common.configuration.Configuration;
import explorviz.hpc_monitoring.record.TraceEventRecords;
import explorviz.hpc_monitoring.record.events.AbstractOperationEvent;

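/**
 * This filter collects incoming traces for a configurable amount of time.
 * All traces representing the same series of events are used to calculate
 * statistical information, such as the average runtime of this kind of trace.
 * Only one specimen of these traces, carrying the aggregated information, is
 * forwarded by this filter.
 * 
 * Statistical outliers regarding the runtime of a trace are meant to be
 * treated specially: they are sent out as they are and are not mixed with
 * others. (NOTE: this is not implemented yet. The configured maximal deviation
 * is read but not used, so currently all matching traces are aggregated. Traces
 * are currently considered "the same" when they have the same number of
 * operations and the same hostname.)
 * 
 * 
 * @author Florian Biss
 * 
 * @since 1.8
 */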

@Plugin(description = "This filter tries to aggregate similar TraceEventRecordss into a single trace.", outputPorts = { @OutputPort(name = TraceEventRecordAggregationFilter.OUTPUT_PORT_NAME_TRACES, description = "Output port for the processed traces", eventTypes = { TraceEventRecords.class }) }, configuration = {
        @Property(name = TraceEventRecordAggregationFilter.CONFIG_PROPERTY_NAME_TIMEUNIT, defaultValue = TraceEventRecordAggregationFilter.CONFIG_PROPERTY_VALUE_TIMEUNIT),
        @Property(name = TraceEventRecordAggregationFilter.CONFIG_PROPERTY_NAME_MAX_COLLECTION_DURATION, defaultValue = TraceEventRecordAggregationFilter.CONFIG_PROPERTY_VALUE_MAX_COLLECTION_DURATION),
        @Property(name = TraceEventRecordAggregationFilter.CONFIG_PROPERTY_NAME_MAX_DEVIATION, defaultValue = TraceEventRecordAggregationFilter.CONFIG_PROPERTY_VALUE_MAX_DEVIATION) })
public class TraceEventRecordAggregationFilter extends AbstractFilterPlugin {
    /**
     * The name of the output port delivering the valid traces.
     */
    public static final String OUTPUT_PORT_NAME_TRACES = "tracesOut";

    /**
     * The name of the input port receiving the trace records.
     */
    public static final String INPUT_PORT_NAME_TRACES = "tracesIn";

    /**
     * The name of the property determining the time unit in which
     * {@link #CONFIG_PROPERTY_NAME_MAX_COLLECTION_DURATION} is given.
     */
    public static final String CONFIG_PROPERTY_NAME_TIMEUNIT = "timeunit";

    /**
     * Clock input for timeout handling.
     */
    public static final String INPUT_PORT_NAME_TIME_EVENT = "timestamp";

    /**
     * The default value of the time unit property (nanoseconds).
     */
    public static final String CONFIG_PROPERTY_VALUE_TIMEUNIT = "NANOSECONDS"; // TimeUnit.NANOSECONDS.name()

    /**
     * The name of the property determining the maximal trace timeout.
     */
    public static final String CONFIG_PROPERTY_NAME_MAX_COLLECTION_DURATION = "maxCollectionDuration";

    /**
     * The default value of the property determining the maximal trace timeout
     * (4 seconds when interpreted in the default unit, nanoseconds).
     */
    public static final String CONFIG_PROPERTY_VALUE_MAX_COLLECTION_DURATION = "4000000000";

    /**
     * The name of the property determining the maximal runtime deviation
     * factor.
     * 
     * Outliers are indicated by
     * <code>|runtime - averageRuntime| > deviationFactor * standardDeviation</code>
     * .
     * Use negative number to aggregate all traces.
     */
    public static final String CONFIG_PROPERTY_NAME_MAX_DEVIATION = "maxDeviation";

    /**
     * The default value of the property determining the maximal runtime
     * deviation factor.
     * Default is -1, i.e. all traces are aggregated and no outliers are
     * singled out.
     */
    public static final String CONFIG_PROPERTY_VALUE_MAX_DEVIATION = "-1";

    /** Time unit of the monitored records, taken from the project context. */
    private final TimeUnit timeunit;

    /** Maximal collection duration of a buffer, expressed in {@link #timeunit}. */
    private final long maxCollectionDuration;

    /**
     * Configured outlier factor. NOTE(review): it is read and reported by
     * {@link #getCurrentConfiguration()} but not yet used by the aggregation.
     */
    private final long maxDeviation;

    /**
     * Buffers of the trace groups currently being collected, keyed by the
     * first trace of each group (see {@link TraceComparator}). Guarded by
     * {@code this}.
     */
    private final Map<TraceEventRecords, TraceAggregationBuffer> trace2buffer;

    /**
     * Creates a new aggregation filter.
     * 
     * @param configuration
     *            the filter configuration (time unit, maximal collection
     *            duration and maximal deviation)
     * @param projectContext
     *            the project context, used to determine the time unit of the
     *            monitored records
     */
    public TraceEventRecordAggregationFilter(final Configuration configuration,
            final IProjectContext projectContext) {
        super(configuration, projectContext);

        timeunit = parseTimeUnit(
                projectContext.getProperty(IProjectContext.CONFIG_PROPERTY_NAME_RECORDS_TIME_UNIT),
                TimeUnit.NANOSECONDS);

        // The configured duration is expressed in the configured unit (falling
        // back to the record unit); convert it once into the record unit.
        final TimeUnit configTimeunit = parseTimeUnit(
                configuration.getStringProperty(CONFIG_PROPERTY_NAME_TIMEUNIT),
                timeunit);

        maxDeviation = configuration
                .getLongProperty(CONFIG_PROPERTY_NAME_MAX_DEVIATION);

        maxCollectionDuration = timeunit.convert(configuration
                .getLongProperty(CONFIG_PROPERTY_NAME_MAX_COLLECTION_DURATION),
                configTimeunit);
        trace2buffer = new TreeMap<TraceEventRecords, TraceAggregationBuffer>(
                new TraceComparator());
    }

    /**
     * Parses a {@link TimeUnit} name.
     * 
     * @param name
     *            the unit name, may be {@code null}
     * @param fallback
     *            the unit to use if {@code name} is {@code null} or not a
     *            valid {@link TimeUnit} constant
     * @return the parsed unit or {@code fallback}
     */
    private static TimeUnit parseTimeUnit(final String name,
            final TimeUnit fallback) {
        if (name == null) {
            return fallback;
        }
        try {
            return TimeUnit.valueOf(name);
        }
        catch (final IllegalArgumentException ex) {
            return fallback; // unknown unit name: keep the fallback
        }
    }

    /**
     * Receives a trace and adds it to the buffer of its group.
     * 
     * @param event
     *            the incoming trace
     */
    @InputPort(name = INPUT_PORT_NAME_TRACES, description = "Collect identical traces and aggregate them.", eventTypes = { TraceEventRecords.class })
    public void newEvent(final TraceEventRecords event) {
        synchronized (this) {
            insertIntoBuffer(event);
        }
    }

    /**
     * Adds a trace to the buffer of its group, creating the buffer on first
     * use. Must only be called while holding the lock on {@code this}.
     * 
     * @param trace
     *            the trace to insert
     */
    private void insertIntoBuffer(final TraceEventRecords trace) {
        TraceAggregationBuffer traceBuffer = trace2buffer.get(trace);
        if (traceBuffer == null) { // first trace of this group
            // NOTE(review): System.nanoTime() has an arbitrary origin, yet it
            // is compared against the timestamps received on the time event
            // port. Confirm that the timer source uses the same clock.
            traceBuffer = new TraceAggregationBuffer(System.nanoTime());
            trace2buffer.put(trace, traceBuffer);
        }
        traceBuffer.insertTrace(trace);
    }

    /**
     * Receives a clock signal and flushes all buffers that have timed out.
     * 
     * @param timestamp
     *            the current time as reported by the timer source
     */
    @InputPort(name = INPUT_PORT_NAME_TIME_EVENT, description = "Time signal for timeouts", eventTypes = { Long.class })
    public void newEvent(final Long timestamp) {
        synchronized (this) {
            processTimeoutQueue(timestamp);
        }
    }

    /**
     * {@inheritDoc}
     * 
     * Delivers the aggregated trace of every remaining buffer and clears all
     * buffers.
     */
    @Override
    public void terminate(final boolean error) {
        synchronized (this) {
            for (final TraceAggregationBuffer traceBuffer : trace2buffer
                    .values()) {
                super.deliver(OUTPUT_PORT_NAME_TRACES,
                        traceBuffer.getAggregatedTrace());
            }
            trace2buffer.clear();
        }
    }

    /**
     * Delivers and removes every buffer created at or before
     * {@code timestamp - maxCollectionDuration}. Must only be called while
     * holding the lock on {@code this}.
     * 
     * @param timestamp
     *            the current time
     */
    private void processTimeoutQueue(final long timestamp) {
        final long bufferTimeout = timestamp - maxCollectionDuration;
        // Remove through the iterator instead of re-looking up each key via
        // the comparator after the loop.
        final Iterator<TraceAggregationBuffer> buffers = trace2buffer.values()
                .iterator();
        while (buffers.hasNext()) {
            final TraceAggregationBuffer traceBuffer = buffers.next();
            if (traceBuffer.getBufferCreatedTimestamp() <= bufferTimeout) {
                super.deliver(OUTPUT_PORT_NAME_TRACES,
                        traceBuffer.getAggregatedTrace());
                buffers.remove();
            }
        }
    }

    /**
     * {@inheritDoc}
     * 
     * The reported time unit is the record time unit, and the collection
     * duration is expressed in that unit.
     */
    @Override
    public Configuration getCurrentConfiguration() {
        final Configuration configuration = new Configuration();
        configuration.setProperty(CONFIG_PROPERTY_NAME_TIMEUNIT,
                timeunit.name());
        configuration.setProperty(CONFIG_PROPERTY_NAME_MAX_COLLECTION_DURATION,
                String.valueOf(maxCollectionDuration));
        configuration.setProperty(CONFIG_PROPERTY_NAME_MAX_DEVIATION,
                String.valueOf(maxDeviation));
        return configuration;
    }

    /**
     * Accumulates the traces of one group. The first inserted trace becomes
     * the accumulator; the runtimes of later traces are merged into it in
     * place.
     */
    private static final class TraceAggregationBuffer {
        private TraceEventRecords accumulator;

        private final long bufferCreatedTimestamp;

        public TraceAggregationBuffer(final long bufferCreatedTimestamp) {
            this.bufferCreatedTimestamp = bufferCreatedTimestamp;
        }

        public long getBufferCreatedTimestamp() {
            return bufferCreatedTimestamp;
        }

        /**
         * @return the accumulated trace, or {@code null} if nothing has been
         *         inserted yet
         */
        public TraceEventRecords getAggregatedTrace() {
            return accumulator;
        }

        public void insertTrace(final TraceEventRecords trace) {
            aggregate(trace);
        }

        /**
         * Merges {@code trace} into the accumulator. Relies on the map's
         * comparator having grouped only traces with the same number of
         * operations, so the arrays have equal length.
         */
        private void aggregate(final TraceEventRecords trace) {
            if (accumulator == null) {
                accumulator = trace;
            }
            else {
                final AbstractOperationEvent[] aggregatedRecords = accumulator
                        .getTraceOperations();
                final AbstractOperationEvent[] records = trace
                        .getTraceOperations();
                for (int i = 0; i < aggregatedRecords.length; i++) {
                    aggregatedRecords[i].getRuntime().merge(
                            records[i].getRuntime());
                }

                accumulator.getRuntime().merge(trace.getRuntime());
            }
        }
    }

    /**
     * Orders traces by number of operations, then by hostname. Traces that
     * compare equal are aggregated together.
     */
    private static final class TraceComparator implements
            Comparator<TraceEventRecords> {

        @Override
        public int compare(final TraceEventRecords t1,
                final TraceEventRecords t2) {
            final AbstractOperationEvent[] recordsT1 = t1.getTraceOperations();
            final AbstractOperationEvent[] recordsT2 = t2.getTraceOperations();

            // Array lengths are non-negative, so this subtraction cannot overflow.
            final int cmpLength = recordsT1.length - recordsT2.length;
            if (cmpLength != 0) {
                return cmpLength;
            }

            final int cmpHostnames = t1.getTrace().getHostname()
                    .compareTo(t2.getTrace().getHostname());
            if (cmpHostnames != 0) {
                return cmpHostnames;
            }

            // TODO deep check records: currently traces with the same number of
            // operations on the same host are treated as identical.
            return 0;
        }
    }

}
