package com.ovopark.training.enhancer.subject.tracing;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.function.Consumer;

import org.slf4j.MDC;

import com.ovopark.training.enhancer.utils.EhContextUtil;

/**
 * Propagates the submitting thread's logging/trace context into tasks run by a thread pool.
 *
 * <p>When a task is wrapped, the current SLF4J {@link MDC} map and the {@link EhContextUtil} values
 * are copied on the submitting thread. When the task runs, that copy is installed on the executing
 * thread. Afterwards, the executing thread's previous context is put back. Restoring, rather than
 * simply clearing, matters when the task runs on the submitting thread itself (for example under
 * {@link ThreadPoolExecutor.CallerRunsPolicy}); clearing would wipe the caller's own context.
 */
public class TracingHelper {

    /**
     * Returns the shared instance.
     *
     * @return the singleton held by {@code Holder}
     */
    public static TracingHelper getInstance() {
        return Holder.INSTANCE;
    }

    /**
     * Wraps {@code poolExecutor} so that submitted tasks inherit the submitter's trace context.
     * Equivalent to {@code executorService(poolExecutor, null)}.
     *
     * @param poolExecutor the pool that actually runs the tasks
     * @return a context-propagating view of {@code poolExecutor}
     */
    public static Executor executorService(ThreadPoolExecutor poolExecutor) {
        return executorService(poolExecutor, null);
    }

    /**
     * Wraps {@code poolExecutor} so that submitted tasks inherit the submitter's trace context.
     * Failures escaping a task can optionally be reported to a callback.
     *
     * <p>Tasks are decorated through the {@code wrap} hooks of {@code WrappingExecutorService}.
     * Which submission methods route through those hooks is defined by that class (not visible
     * here).
     *
     * @param poolExecutor the pool that actually runs the tasks
     * @param exceptionCallback receives any throwable escaping a task before it is rethrown; may be
     *     {@code null} to disable notification
     * @return a context-propagating view of {@code poolExecutor}
     */
    public static Executor executorService(ThreadPoolExecutor poolExecutor, Consumer<Throwable> exceptionCallback) {
        class CurrentTraceContextExecutorService extends WrappingExecutorService {

            @Override
            protected ExecutorService delegate() {
                return poolExecutor;
            }

            @Override
            protected <C> Callable<C> wrap(Callable<C> task) {
                return TracingHelper.getInstance().wrap(task, exceptionCallback);
            }

            @Override
            protected Runnable wrap(Runnable task) {
                return TracingHelper.getInstance().wrap(task, exceptionCallback);
            }
        }
        return new CurrentTraceContextExecutorService();
    }

    /**
     * Wraps {@code task} so that it runs with the context captured at wrap time.
     *
     * @param task the task to decorate
     * @param exceptionCallback optional failure listener; may be {@code null}
     * @return a runnable that installs the captured context, runs {@code task}, then restores the
     *     executing thread's previous context
     */
    private Runnable wrap(Runnable task, Consumer<Throwable> exceptionCallback) {
        // Capture on the submitting thread; the wrapper runs later on the executing thread.
        ContextSnapshot captured = ContextSnapshot.capture();
        class CurrentTraceContextRunnable implements Runnable {
            @Override
            public void run() {
                ContextSnapshot previous = captured.attach();
                try {
                    task.run();
                } catch (Throwable failure) {
                    notifyFailure(exceptionCallback, failure);
                    // Precise rethrow: task.run() throws no checked exceptions, so this compiles.
                    throw failure;
                } finally {
                    previous.restore();
                }
            }
        }
        return new CurrentTraceContextRunnable();
    }

    /**
     * Wraps {@code task} so that it runs with the context captured at wrap time.
     *
     * @param task the task to decorate
     * @param exceptionCallback optional failure listener; may be {@code null}
     * @param <C> the task's result type
     * @return a callable that installs the captured context, calls {@code task}, then restores the
     *     executing thread's previous context
     */
    private <C> Callable<C> wrap(Callable<C> task, Consumer<Throwable> exceptionCallback) {
        // Capture on the submitting thread; the wrapper runs later on the executing thread.
        ContextSnapshot captured = ContextSnapshot.capture();
        class CurrentTraceContextCallable implements Callable<C> {
            @Override
            public C call() throws Exception {
                ContextSnapshot previous = captured.attach();
                try {
                    return task.call();
                } catch (Throwable failure) {
                    notifyFailure(exceptionCallback, failure);
                    // Precise rethrow: only Exception (declared) or unchecked types can reach here.
                    throw failure;
                } finally {
                    previous.restore();
                }
            }
        }
        return new CurrentTraceContextCallable();
    }

    /**
     * Reports {@code failure} to {@code callback}, if one was supplied.
     *
     * <p>A {@link RuntimeException} thrown by the callback is attached to {@code failure} as a
     * suppressed exception. This keeps a misbehaving callback from replacing the task's own failure.
     *
     * @param callback the listener, or {@code null}
     * @param failure the throwable that escaped the task
     */
    private static void notifyFailure(Consumer<Throwable> callback, Throwable failure) {
        if (callback == null) {
            return;
        }
        try {
            callback.accept(failure);
        } catch (RuntimeException callbackFailure) {
            // addSuppressed(this) throws IllegalArgumentException, e.g. if the callback rethrows.
            if (callbackFailure != failure) {
                failure.addSuppressed(callbackFailure);
            }
        }
    }

    /**
     * A copy of one thread's MDC map and {@link EhContextUtil} values.
     *
     * <p>Either map may be {@code null} when the source thread had none. The maps are never mutated
     * after capture.
     */
    private static final class ContextSnapshot {
        private final Map<String, String> mdc;
        private final HashMap<String, Object> eh;

        private ContextSnapshot(Map<String, String> mdc, HashMap<String, Object> eh) {
            this.mdc = mdc;
            this.eh = eh;
        }

        /**
         * Copies the current thread's MDC and {@link EhContextUtil} state.
         *
         * @return a detached snapshot of the current thread's context
         */
        static ContextSnapshot capture() {
            Map<String, String> mdcCopy = MDC.getCopyOfContextMap();
            HashMap<String, Object> ehValues = EhContextUtil.getAll();
            // Copy defensively: it is not visible here whether getAll() returns a live map.
            HashMap<String, Object> ehCopy = ehValues == null ? null : new HashMap<>(ehValues);
            return new ContextSnapshot(mdcCopy, ehCopy);
        }

        /**
         * Installs this snapshot on the current thread.
         *
         * @return the state the current thread had just before installation, for {@link #restore()}
         */
        ContextSnapshot attach() {
            ContextSnapshot previous = capture();
            if (mdc != null) {
                // Install a fresh copy so the task cannot mutate this (possibly reused) snapshot.
                MDC.setContextMap(new HashMap<>(mdc));
            }
            if (eh != null) {
                EhContextUtil.putAll(new HashMap<>(eh));
            }
            return previous;
        }

        /**
         * Makes this snapshot the current thread's context.
         *
         * <p>Both stores are cleared first. Non-empty maps are then reinstalled, so an empty or
         * absent snapshot leaves the thread cleared, as the original pool-thread cleanup did.
         */
        void restore() {
            EhContextUtil.singleClear();
            MDC.clear();
            if (eh != null && !eh.isEmpty()) {
                EhContextUtil.putAll(new HashMap<>(eh));
            }
            if (mdc != null && !mdc.isEmpty()) {
                MDC.setContextMap(mdc);
            }
        }
    }

    static class Holder {
        private static final TracingHelper INSTANCE = new TracingHelper();
    }
}
