package com.wss.common.results;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.ObjectWriter;
import com.google.gson.Gson;
import com.wss.common.logging.LogContext;
import com.wss.common.logging.LogUtils;
import com.wss.common.results.enums.*;
import com.wss.common.results.scaResultsDTO.ScaResObjDTO;
import com.wss.common.results.scaResultsDTO.ScaResultsDTO;
import com.wss.common.results.scaTypesDTO.ScaTypesDTO;
import lombok.Getter;
import org.apache.commons.lang3.EnumUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.InputStreamReader;
import java.io.Reader;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.Instant;
import java.time.ZoneId;
import java.time.ZoneOffset;
import java.time.format.DateTimeFormatter;
import java.util.*;
import java.util.stream.Collectors;

/**
 * Collects the SCA result entries of one repository scan and turns them into a validated
 * {@link ScaResultsDTO} plus a flat map of summary key/value pairs.
 *
 * <p>Entries come from two places: failures recorded in memory through
 * {@link #addScmFailure(StageEnum, LevelEnum, String, int)}, and the {@code mend-sca-results.json}
 * file located in the repository directory. {@link #loadScaResults()} reads the file, appends the
 * recorded failures, drops invalid entries and merges entries flagged for merging.
 * {@link #prepareScaPairs(Map)} then derives the summary pairs exposed by {@code getScmPairs()}.
 *
 * <p>Creating an instance sets the {@code MEND_SCA_RESULTS_FILE} system property to the results
 * file path. The class performs no synchronization of its own.
 *
 * @author Arye.Hochman
 */
public class ScaRes {
    public static final List<String> supportedPMs = List.of("maven", "npm", "nuget", "swift", "python", "pip", "pipenv",
            "poetry", "ruby"); // TODO: Add sbt when the new resolver will be default

    /** Maximum number of characters kept in a single result message. */
    private static final int RES_MSG_LIMIT = 400;
    /** Marker appended to a message that was cut down to {@link #RES_MSG_LIMIT}. */
    private static final String TRUNCATION_SUFFIX = "...";
    private static final ZoneId UTC = ZoneOffset.UTC;
    /** Formatter for failure timestamps; DateTimeFormatter is immutable, so one shared instance is enough. */
    private static final DateTimeFormatter TIMESTAMP_FORMAT =
            DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSSSSSSSS'Z'").withZone(UTC);
    private static final Logger logger = LoggerFactory.getLogger(ScaRes.class);

    @Getter
    private final boolean scaResEnabled;
    // Per-instance on purpose: it is the key under which this instance's SCM failures are stored.
    private final String repoDir;
    @Getter
    private final String resFile;
    @Getter
    private ScaResultsDTO scaResults;
    @Getter
    private final List<String> failedPms;
    private ScaTypesDTO scaTypes;
    private final List<ScaResObjDTO> scmFailures;
    @Getter
    private final Map<String, String> scmPairs;
    private final LogContext logCtx;

    /**
     * @param logContext  context passed to {@link LogUtils#formatLogMessage} for every log line
     * @param enabled     value of the "enabled" flag when the {@code DETAILED_SCA_RESULTS_INFO}
     *                    environment variable is not set; when it is set, the variable wins
     * @param repoDirPath repository directory; the results file is expected directly inside it
     */
    private ScaRes(LogContext logContext, boolean enabled, String repoDirPath) {
        String newResEnv = System.getenv("DETAILED_SCA_RESULTS_INFO");
        scaResEnabled = newResEnv == null ? enabled : Boolean.parseBoolean(newResEnv);
        repoDir = repoDirPath;
        resFile = repoDir + File.separator + "mend-sca-results.json";
        System.setProperty("MEND_SCA_RESULTS_FILE", resFile);
        scaResults = new ScaResultsDTO();
        failedPms = new ArrayList<>();
        scaTypes = new ScaTypesDTO();
        scmFailures = new ArrayList<>();
        scmPairs = new HashMap<>();
        logCtx = logContext;
    }

    /**
     * Creates a new results collector.
     *
     * @param logContext  context used when formatting log messages
     * @param enabled     default for the "enabled" flag (overridden by {@code DETAILED_SCA_RESULTS_INFO})
     * @param repoDirPath repository directory that holds the results file
     * @return a fresh instance with empty results
     */
    public static ScaRes initResults(LogContext logContext, boolean enabled, String repoDirPath) {
        return new ScaRes(logContext, enabled, repoDirPath);
    }

    /**
     * Records an SCM-side stage failure. The entry is kept in memory and is appended to the
     * results under the repository directory by the next {@link #loadScaResults()} call.
     *
     * @param stage    stage in which the failure happened
     * @param level    level of the failure
     * @param resMsg   message; shortened by {@link #limitResMsg(String)}
     * @param exitCode stored under the {@code exitCode} key only when greater than zero
     */
    public void addScmFailure(StageEnum stage, LevelEnum level, String resMsg, int exitCode) {
        scmFailures.add(new ScaResObjDTO(
                TIMESTAMP_FORMAT.format(Instant.now()),
                TypeEnum.GENERAL.toString(),
                "",
                stage.toString(),
                false,
                level.toString(),
                ResTypeEnum.STAGE_FAILURE.toString(),
                limitResMsg(resMsg),
                ToolEnum.SCM.toString(),
                false,
                exitCode > 0 ? new HashMap<>(Map.of("exitCode", String.valueOf(exitCode))) : new HashMap<>()
        ));
    }

    /**
     * Shortens a message to at most {@value #RES_MSG_LIMIT} characters, ending it with
     * {@code "..."} when it had to be cut.
     *
     * @param resMsg message to limit; may be {@code null}
     * @return the message unchanged when it is {@code null} or within the limit, otherwise the
     *         shortened message
     */
    public String limitResMsg(String resMsg) {
        if (resMsg == null || resMsg.length() <= RES_MSG_LIMIT) {
            return resMsg;
        }
        int cut = RES_MSG_LIMIT - TRUNCATION_SUFFIX.length();
        // Do not leave half of a surrogate pair at the end of the kept text.
        if (Character.isHighSurrogate(resMsg.charAt(cut - 1))) {
            cut--;
        }
        return resMsg.substring(0, cut) + TRUNCATION_SUFFIX;
    }

    /**
     * Adds (or overwrites) a single key/value pair in the summary map.
     *
     * @param key   pair name
     * @param value pair value
     */
    public void addEchoLog(String key, String value) {
        scmPairs.put(key, value);
    }

    /**
     * Builds the final results: reads (and deletes) the results file when it exists, appends the
     * recorded SCM failures under the repository directory, removes invalid entries and merges
     * entries flagged for merging.
     *
     * @return the processed results, the same object later returned by {@code getScaResults()}
     */
    public ScaResultsDTO loadScaResults() {
        loadResFile();

        if (!scmFailures.isEmpty()) {
            logger.debug(logWrapper("Appends the SCM failures to the result list"));
            // Copy the entries instead of storing the internal list itself: sharing the list made a
            // second call append the list to itself and duplicate every failure.
            scaResults.getResults().computeIfAbsent(repoDir, key -> new ArrayList<>()).addAll(scmFailures);
            scmFailures.clear();
        }

        validateScaResults();
        mergeResults();

        // Serializing the whole result set is only worth it when the line is actually written.
        if (logger.isDebugEnabled()) {
            ObjectWriter ow = new ObjectMapper().writer();
            try {
                logger.debug(logWrapper("SCA_RESULTS_JSON=" + ow.writeValueAsString(scaResults)));
            } catch (Exception e) {
                logger.error(logWrapper("error while printing the sca results json"), e);
            }
        }

        return scaResults;
    }

    /**
     * Replaces the current results with the content of the results file, if the file exists, and
     * then deletes the file. An unreadable, malformed or empty file yields empty results.
     */
    private void loadResFile() {
        Path resFilePath = Paths.get(resFile);
        if (!Files.exists(resFilePath)) {
            return;
        }

        logger.debug(logWrapper("Trying to load the results file: {}"), resFile);
        try {
            ScaResultsDTO loaded = new Gson().fromJson(Files.readString(resFilePath), ScaResultsDTO.class);
            // Gson returns null for empty input; keep the non-null invariant of scaResults.
            scaResults = loaded != null ? loaded : new ScaResultsDTO();
        } catch (Exception e) {
            logger.error(logWrapper("Failed to load the results file {}"), resFile, e);
            scaResults = new ScaResultsDTO();
        }

        try {
            Files.deleteIfExists(resFilePath);
        } catch (Exception e) {
            logger.error(logWrapper("Failed to delete the results file {} - {}"), resFile, e.getMessage());
        }
    }

    /**
     * Loads the detection types and removes everything that does not pass validation: invalid
     * result entries, paths left without entries, and total counters keyed by an unknown stage.
     */
    private void validateScaResults() {
        loadDetectionTypes();

        scaResults.getResults().entrySet().removeIf(path -> {
            List<ScaResObjDTO> pathResults = path.getValue();
            if (pathResults == null) {
                return true;
            }
            pathResults.removeIf(resObj -> !isValidResObj(resObj));
            return pathResults.isEmpty();
        });

        scaResults.getTotalSuccess().entrySet().removeIf(stage -> !isValidStageSum(stage.getKey(), "totalSuccess"));
        scaResults.getTotalFail().entrySet().removeIf(stage -> !isValidStageSum(stage.getKey(), "totalFail"));
    }

    /**
     * Reads {@code detection-types.json} from the classpath into {@link #scaTypes}. On any failure
     * the previously held types are kept.
     */
    private void loadDetectionTypes() {
        ClassLoader classloader = Thread.currentThread().getContextClassLoader();
        if (classloader == null) {
            // A thread may have no context class loader; fall back to the one that loaded this class.
            classloader = ScaRes.class.getClassLoader();
        }
        URL typesUrl = classloader.getResource("detection-types.json");
        if (typesUrl == null) {
            logger.debug(logWrapper("Failed to find the detection-types.json file"));
            return;
        }
        // Explicit charset: the default InputStreamReader charset depends on the platform.
        try (Reader reader = new InputStreamReader(typesUrl.openStream(), StandardCharsets.UTF_8)) {
            ScaTypesDTO loaded = new Gson().fromJson(reader, ScaTypesDTO.class);
            if (loaded != null) { // null means the resource was empty
                scaTypes = loaded;
            }
        } catch (Exception e) {
            logger.debug(logWrapper("Failed to load the detection-types.json file: {}"), e.getMessage());
        }
    }

    /**
     * Validates one result entry: type, stage, level and tool must be known values, and a package
     * manager value must be present (it may be empty). An over-long message of a valid entry is
     * shortened in place.
     *
     * <p>Side effect: the package manager of every non-INFO entry is added to {@link #failedPms}.
     * This happens before validation, so a rejected entry still marks its package manager.
     *
     * @param resObj entry to check; {@code null} is rejected
     * @return {@code true} when the entry should be kept
     */
    private boolean isValidResObj(ScaResObjDTO resObj) {
        if (resObj == null) {
            logger.debug(logWrapper("Invalid null result entry"));
            return false;
        }

        String pm = resObj.getPm();
        if (pm != null && !pm.isEmpty() && !LevelEnum.INFO.name().equals(resObj.getLevel()) && !failedPms.contains(pm)) {
            failedPms.add(pm);
        }

        String errType = "";
        if (pm == null) {
            // The summary step dereferences pm, so an entry without it cannot be processed.
            errType = "pm: \"" + pm;
        } else if (!EnumUtils.isValidEnum(TypeEnum.class, resObj.getType())) {
            errType = "type: \"" + resObj.getType();
        } else if (!isKnownStage(resObj.getStage())) {
            errType = "stage: \"" + resObj.getStage();
        } else if (!EnumUtils.isValidEnum(LevelEnum.class, resObj.getLevel())) {
            errType = "level: \"" + resObj.getLevel();
        } else if (!EnumUtils.isValidEnum(ToolEnum.class, resObj.getTool())) {
            // NOTE: the result type is intentionally not validated here (that check is disabled).
            errType = "tool: \"" + resObj.getTool();
        } else if (resObj.getResMsg() != null && resObj.getResMsg().length() > RES_MSG_LIMIT) {
            resObj.setResMsg(limitResMsg(resObj.getResMsg()));
        }

        if (!errType.isEmpty()) {
            logger.debug(logWrapper("Invalid {}\" for: {}"), errType, new Gson().toJson(resObj));
            return false;
        }
        return true;
    }

    /**
     * Checks the stage key of a total counter.
     *
     * @param stage stage name used as the counter key
     * @param field name of the counter map, used only in the log message
     * @return {@code true} when the stage is known
     */
    private boolean isValidStageSum(String stage, String field) {
        if (!isKnownStage(stage)) {
            logger.debug(logWrapper("Invalid {} stage: \"{}\""), field, stage);
            return false;
        }
        return true;
    }

    /**
     * A stage is known when it is a {@link StageEnum} constant and, if the loaded detection types
     * list any stages, it is also listed there.
     */
    private boolean isKnownStage(String stage) {
        return EnumUtils.isValidEnum(StageEnum.class, stage)
                && (scaTypes.getStages().isEmpty() || scaTypes.getStages().containsKey(stage));
    }

    // Merge several messages from the same flow with the property merge=true, into a single message with the level of the highest among them
    // Prevents a situation where essential problems in the flow are reported as info (which is not displayed by default) -
    // because there are fallbacks and the like after them, that can be successful,
    // and only the last problem in the flow, which is sometimes marginal, is reported as an error / warn
    private void mergeResults() {
        for (Map.Entry<String, List<ScaResObjDTO>> path : scaResults.getResults().entrySet()) {
            List<ScaResObjDTO> pathResults = path.getValue();
            // The grouping is fully collected before mergeGroup starts removing entries from the list.
            pathResults.stream()
                    .collect(Collectors.groupingBy(r -> Arrays.asList(r.getStage(), r.getTool())))
                    .forEach((group, results) -> mergeGroup(pathResults, results));
        }
    }

    /**
     * Folds the merge-flagged entries of one (stage, tool) group into the last of them and removes
     * the others from {@code pathResults}. Groups with fewer than two such entries are untouched.
     *
     * @param pathResults  list the group's entries live in; merged-away entries are removed from it
     * @param groupResults entries of the path sharing the same stage and tool
     */
    private void mergeGroup(List<ScaResObjDTO> pathResults, List<ScaResObjDTO> groupResults) {
        List<ScaResObjDTO> merge = groupResults.stream().filter(ScaResObjDTO::isMerge).toList();
        if (merge.size() < 2) {
            return;
        }

        ScaResObjDTO lastRes = merge.get(merge.size() - 1);

        // The level declared first in LevelEnum wins (presumably the most severe one - confirm
        // against the enum declaration order).
        LevelEnum mergedLevel = merge.stream()
                .map(res -> LevelEnum.valueOf(res.getLevel()))
                .min(Comparator.naturalOrder())
                .orElseThrow();
        lastRes.setLevel(mergedLevel.name());

        String resMsg = "<ul><li>" + merge.stream().map(ScaResObjDTO::getResMsg).collect(Collectors.joining("</li><li>")) + "</li></ul>";
        lastRes.setResMsg(resMsg);

        boolean multipleTypes = merge.stream().map(ScaResObjDTO::getResType).distinct().count() > 1;
        if (multipleTypes) {
            lastRes.setResType(ResTypeEnum.MULTIPLE_TYPES.name());
        }

        for (int i = 0; i < merge.size() - 1; i++) {
            pathResults.remove(merge.get(i));
        }
    }

    /**
     * Fills the summary map: copies {@code baseLog}, adds one {@code <PM>_TAGS} pair per package
     * manager that has a tags entry, and adds the per-stage counters derived from the results.
     * A package manager with an empty tag collection gets {@code NO_TAGS} added to that collection.
     *
     * @param baseLog pairs to start from; {@code null} is treated as empty
     */
    public void prepareScaPairs(Map<String, String> baseLog) {
        if (baseLog != null) {
            scmPairs.putAll(baseLog);
        }
        scaResults.getTags().forEach((pm, tags) -> {
            if (tags.isEmpty()) {
                tags.add("NO_TAGS");
            }
            // Locale.ROOT keeps the key independent of the default locale.
            scmPairs.put(pm.toUpperCase(Locale.ROOT) + "_TAGS", String.join(",", tags));
        });
        addSumPairs();
    }

    /**
     * Adds the counters of every result entry to the summary map, a {@code NO_TAGS} tags pair for
     * package managers that have none, and a {@code <PM>_FAILURE_HOSTRULES} pair for every package
     * manager with an ERROR-level PRE_STEP failure. Any exception stops the summary and is logged
     * at debug level.
     */
    private void addSumPairs() {
        try {
            scaResults.getResults().forEach((path, pathResults) -> {
                List<String> pmPreStepFailures = new ArrayList<>();
                for (ScaResObjDTO resObj : pathResults) {
                    boolean hasPm = !resObj.getPm().isEmpty();
                    String pmPrefix = hasPm ? resObj.getPm().toUpperCase(Locale.ROOT) + "_" : "GENERAL_";
                    addStageSum(pmPrefix, resObj, pmPreStepFailures);
                    if (hasPm) {
                        scmPairs.putIfAbsent(pmPrefix + "TAGS", "NO_TAGS");
                    }
                }
                // empty hostRules relation to failure
                // (a connectivity-counter based variant of this flag existed but is disabled)
                pmPreStepFailures.forEach(pm -> {
                    boolean thereHR = scaResults.getTags().containsKey(pm) && scaResults.getTags().get(pm).contains("HOST_RULES");
                    scmPairs.put(pm.toUpperCase(Locale.ROOT) + "_FAILURE_HOSTRULES", Boolean.toString(thereHR));
                });
            });
        } catch (Exception e) {
            logger.debug(logWrapper("sca res map creation error: {}"), e.getMessage(), e);
        }
    }

    /**
     * Counts one entry. A successful entry increments {@code <prefix><stage>_SUCCESS}; a failed
     * one increments both {@code <prefix><stage>_<level>} and {@code <prefix><stage>_<resType>}.
     *
     * @param pmPrefix          key prefix of the entry's package manager, ending with {@code _}
     * @param resObj            entry to count
     * @param pmPreStepFailures receives the entry's package manager (once) when the entry is an
     *                          ERROR-level PRE_STEP failure of a named package manager
     */
    private void addStageSum(String pmPrefix, ScaResObjDTO resObj, List<String> pmPreStepFailures) {
        String stagePrefix = pmPrefix + resObj.getStage() + "_";
        preparePmStageLevelSumMaps(stagePrefix);

        if (resObj.isSuccess()) {
            incrementCounter(stagePrefix + "SUCCESS");
            return;
        }

        incrementCounter(stagePrefix + resObj.getLevel());
        if (resObj.getLevel().equals(LevelEnum.ERROR.toString()) && resObj.getStage().equals(StageEnum.PRE_STEP.toString())
                && !resObj.getPm().isEmpty() && !pmPreStepFailures.contains(resObj.getPm())) {
            pmPreStepFailures.add(resObj.getPm());
        }
        incrementCounter(stagePrefix + resObj.getResType());
    }

    /**
     * Adds one to the numeric pair stored under {@code key}, starting from zero when it is absent.
     *
     * @throws NumberFormatException when the stored value is not an integer
     */
    private void incrementCounter(String key) {
        String current = scmPairs.get(key);
        int count = current == null ? 0 : Integer.parseInt(current);
        scmPairs.put(key, String.valueOf(count + 1));
    }

    /**
     * Makes sure a counter exists (initially {@code 0}) for every level and for {@code SUCCESS}
     * under the given prefix, so that all of them are reported even when nothing was counted.
     */
    private void preparePmStageLevelSumMaps(String resPrefix) {
        for (LevelEnum level : LevelEnum.values()) {
            scmPairs.putIfAbsent(resPrefix + level, String.valueOf(0));
        }
        scmPairs.putIfAbsent(resPrefix + "SUCCESS", String.valueOf(0));
    }

    /** Formats a log message (or message pattern) with this instance's log context. */
    private String logWrapper(String message) {
        return LogUtils.formatLogMessage(logCtx, message);
    }
}
