package com.ovopark.training.enhancer.subject.tracing;

import com.ovopark.training.enhancer.utils.EhContextUtil;
import org.slf4j.MDC;

import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadPoolExecutor;

/**
 * Propagates the submitting thread's trace id to tasks executed on a thread pool.
 *
 * <p>Each wrapped task captures the trace id (via {@link EhContextUtil#getTraceId()}) at the time
 * it is wrapped. When the task runs, it publishes that id to the SLF4J {@link MDC} under the
 * {@code traceId} and {@code requestId} keys. The executing thread's previous MDC contents are
 * restored when the task finishes, so pooled worker threads do not leak one task's trace id into
 * the next task they run.
 */
public class TracingHelper {

    /** MDC key under which the captured trace id is published. */
    private static final String TRACE_ID_KEY = "traceId";

    /** MDC key for the request id; currently populated with the same value as the trace id. */
    private static final String REQUEST_ID_KEY = "requestId";

    /**
     * Returns the shared instance.
     *
     * @return the singleton held by {@link Holder}
     */
    public static TracingHelper getInstance() {
        return Holder.INSTANCE;
    }

    /**
     * Wraps {@code poolExecutor} so that tasks submitted through the returned executor carry the
     * submitting thread's trace id into the MDC of the thread that runs them.
     *
     * <p>NOTE(review): assumes {@code WrappingExecutorService} routes every submitted task through
     * the {@code wrap} overrides below before delegating to {@link #delegate()} — confirm against
     * its implementation.
     *
     * @param poolExecutor the pool that actually runs the tasks; must not be {@code null}
     * @return an executor backed by {@code poolExecutor}
     * @throws NullPointerException if {@code poolExecutor} is {@code null}
     */
    public static Executor executorService(ThreadPoolExecutor poolExecutor) {
        Objects.requireNonNull(poolExecutor, "poolExecutor");
        class CurrentTraceContextExecutorService extends WrappingExecutorService {

            @Override
            protected ExecutorService delegate() {
                return poolExecutor;
            }

            @Override
            protected <C> Callable<C> wrap(Callable<C> task) {
                return TracingHelper.getInstance().wrap(task);
            }

            @Override
            protected Runnable wrap(Runnable task) {
                return TracingHelper.getInstance().wrap(task);
            }
        }
        return new CurrentTraceContextExecutorService();
    }

    /**
     * Wraps {@code task} so it runs with the caller's current trace id in the MDC.
     *
     * @param task the task to wrap; must not be {@code null}
     * @return a runnable that publishes the captured trace id, runs {@code task}, then restores
     *     the executing thread's previous MDC contents
     * @throws NullPointerException if {@code task} is {@code null} (fail fast at submission time
     *     instead of inside the worker thread)
     */
    private Runnable wrap(Runnable task) {
        Objects.requireNonNull(task, "task");
        // Captured now, on the submitting thread, because the worker thread's context differs.
        final EhTraceContext invocationContext = get();
        class CurrentTraceContextRunnable implements Runnable {
            @Override
            public void run() {
                final Map<String, String> previous = processTrace(invocationContext);
                try {
                    task.run();
                } finally {
                    // Pooled threads are reused: never leave this task's ids behind.
                    restoreTrace(previous);
                }
            }
        }
        return new CurrentTraceContextRunnable();
    }

    /**
     * Snapshots the calling thread's trace context.
     *
     * @return a new context holding the value of {@link EhContextUtil#getTraceId()}
     */
    private EhTraceContext get() {
        EhTraceContext ehTraceContext = new EhTraceContext();
        ehTraceContext.setTraceId(EhContextUtil.getTraceId());
        return ehTraceContext;
    }

    /**
     * Wraps {@code task} so it runs with the caller's current trace id in the MDC.
     *
     * @param task the task to wrap; must not be {@code null}
     * @param <C> the task's result type
     * @return a callable that publishes the captured trace id, calls {@code task}, then restores
     *     the executing thread's previous MDC contents
     * @throws NullPointerException if {@code task} is {@code null} (fail fast at submission time
     *     instead of inside the worker thread)
     */
    private <C> Callable<C> wrap(Callable<C> task) {
        Objects.requireNonNull(task, "task");
        // Captured now, on the submitting thread, because the worker thread's context differs.
        final EhTraceContext invocationContext = get();
        class CurrentTraceContextCallable implements Callable<C> {
            @Override
            public C call() throws Exception {
                final Map<String, String> previous = processTrace(invocationContext);
                try {
                    return task.call();
                } finally {
                    // Pooled threads are reused: never leave this task's ids behind.
                    restoreTrace(previous);
                }
            }
        }
        return new CurrentTraceContextCallable();
    }

    /**
     * Publishes the captured trace id to the current thread's MDC.
     *
     * @param invocationContext the context captured when the task was wrapped
     * @return a copy of the MDC as it was before publishing ({@code null} if there was none), to
     *     be handed to {@link #restoreTrace(Map)} once the task completes
     */
    private Map<String, String> processTrace(EhTraceContext invocationContext) {
        final Map<String, String> previous = MDC.getCopyOfContextMap();
        final String traceId = invocationContext.getTraceId();
        MDC.put(TRACE_ID_KEY, traceId);
        MDC.put(REQUEST_ID_KEY, traceId);
        return previous;
    }

    /**
     * Restores the current thread's MDC to a snapshot taken by {@link #processTrace}.
     *
     * @param previous the snapshot to restore; {@code null} means the MDC was empty, so it is
     *     cleared
     */
    private static void restoreTrace(Map<String, String> previous) {
        if (previous == null) {
            MDC.clear();
        } else {
            MDC.setContextMap(previous);
        }
    }

    /** Initialization-on-demand holder for the singleton instance. */
    static class Holder {
        private static final TracingHelper INSTANCE = new TracingHelper();
    }
}
