/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.expression.function;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.function.BiPredicate;
import java.util.function.BinaryOperator;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.exception.ExpressionEvaluationException;
import org.opensearch.sql.expression.function.PPLTypeChecker;
import shaded.com.google.common.annotations.VisibleForTesting;

/**
 * Static helpers that coerce function arguments ({@link RexNode}s) to the {@link ExprType}s a
 * function expects.
 *
 * <p>Coercion relies on two mechanisms:
 * <ul>
 *   <li>the widening hierarchy exposed by {@link ExprType#getParent()}, measured by
 *       {@link #distance(List, List)};</li>
 *   <li>a small ordered list of common-type rules: DATE + TIME resolve to TIMESTAMP, and
 *       STRING + a number type resolve to DOUBLE.</li>
 * </ul>
 */
public final class CoercionUtils {
    private static final Set<ExprType> NUMBER_TYPES = ExprCoreType.numberTypes();

    /**
     * Rules consulted in order by {@link #resolveCommonType}; the first rule whose predicate
     * matches determines the common type.
     */
    private static final List<CoercionRule> COMMON_COERCION_RULES = List.of(
            CoercionRule.of(
                    (left, right) -> areDateAndTime(left, right),
                    (left, right) -> ExprCoreType.TIMESTAMP),
            CoercionRule.of(
                    (left, right) -> hasString(left, right) && hasNumber(left, right),
                    (left, right) -> ExprCoreType.DOUBLE));

    /** Sentinel distance meaning "no widening path exists between the two types". */
    private static final int IMPOSSIBLE_WIDENING = Integer.MAX_VALUE;

    /** Distance of a signature whose parameter types all equal the argument types. */
    private static final int TYPE_EQUAL = 0;

    /**
     * Casts {@code arguments} to one of the parameter signatures reported by {@code typeChecker}.
     *
     * <p>A signature that matches the argument types exactly is used immediately. Otherwise every
     * signature with a finite widening distance is ranked, and the head of the ranking is used.
     *
     * @param builder     builder used to create cast expressions
     * @param typeChecker supplies the candidate parameter-type signatures
     * @param arguments   the actual call arguments
     * @return the (possibly cast) arguments, or {@code null} if no signature applies or a cast fails
     */
    @Nullable
    public static List<RexNode> castArguments(
            RexBuilder builder, PPLTypeChecker typeChecker, List<RexNode> arguments) {
        List<List<ExprType>> paramTypeCombinations = typeChecker.getParameterTypes();
        List<ExprType> sourceTypes = arguments.stream()
                .map(node -> OpenSearchTypeFactory.convertRelDataTypeToExprType(node.getType()))
                .collect(Collectors.toList());
        // The comparator is reversed, so peek() yields the candidate with the LARGEST finite
        // distance (the widest signature).
        // NOTE(review): most overload resolvers pick the closest signature; confirm that choosing
        // the widest one is intended before changing this ordering.
        PriorityQueue<Pair<List<ExprType>, Integer>> rankedSignatures =
                new PriorityQueue<>((left, right) -> Integer.compare(right.getValue(), left.getValue()));
        for (List<ExprType> paramTypes : paramTypeCombinations) {
            int distance = distance(sourceTypes, paramTypes);
            if (distance == TYPE_EQUAL) {
                return castArguments(builder, paramTypes, arguments);
            }
            if (distance != IMPOSSIBLE_WIDENING) {
                rankedSignatures.add(Pair.of(paramTypes, distance));
            }
        }
        Pair<List<ExprType>, Integer> selected = rankedSignatures.peek();
        return selected == null ? null : castArguments(builder, selected.getKey(), arguments);
    }

    /**
     * Casts every argument to the widest type among all argument types.
     *
     * @param builder   builder used to create cast expressions
     * @param arguments the arguments to widen
     * @return the widened arguments, or {@code null} if there are no arguments, no common widest
     *     type exists, or any single argument cannot be cast to it
     */
    @Nullable
    public static List<RexNode> widenArguments(RexBuilder builder, List<RexNode> arguments) {
        ExprType widestType = findWidestType(arguments);
        if (widestType == null) {
            return null;
        }
        List<RexNode> widened = new ArrayList<>(arguments.size());
        for (RexNode arg : arguments) {
            RexNode castedArg = cast(builder, widestType, arg);
            // cast() returns null when an argument cannot reach widestType. Report that as an
            // overall failure (like castArguments does) instead of returning a list with null holes.
            if (castedArg == null) {
                return null;
            }
            widened.add(castedArg);
        }
        return widened;
    }

    /**
     * Casts each argument to the parameter type at the same position.
     *
     * @return the cast arguments, or {@code null} on an arity mismatch or if any cast fails
     */
    @Nullable
    private static List<RexNode> castArguments(
            RexBuilder builder, List<ExprType> paramTypes, List<RexNode> arguments) {
        if (paramTypes.size() != arguments.size()) {
            return null;
        }
        List<RexNode> castedArguments = new ArrayList<>(arguments.size());
        for (int i = 0; i < paramTypes.size(); i++) {
            RexNode castedArg = cast(builder, paramTypes.get(i), arguments.get(i));
            if (castedArg == null) {
                return null;
            }
            castedArguments.add(castedArg);
        }
        return castedArguments;
    }

    /**
     * Casts {@code arg} to {@code targetType}.
     *
     * <p>The argument is returned unchanged when {@link ExprType#shouldCast} says no cast is
     * needed. Otherwise it is cast directly if it widens to the target. Failing that, it is cast to
     * the common type of its own type and the target, if a rule provides one.
     *
     * @return the (possibly cast) argument, or {@code null} if no cast target could be found
     */
    @Nullable
    private static RexNode cast(RexBuilder builder, ExprType targetType, RexNode arg) {
        ExprType argType = OpenSearchTypeFactory.convertRelDataTypeToExprType(arg.getType());
        if (!argType.shouldCast(targetType)) {
            return arg;
        }
        if (distance(argType, targetType) != IMPOSSIBLE_WIDENING) {
            return builder.makeCast(OpenSearchTypeFactory.convertExprTypeToRelDataType(targetType), arg, true, true);
        }
        return resolveCommonType(argType, targetType)
                .map(commonType -> builder.makeCast(
                        OpenSearchTypeFactory.convertExprTypeToRelDataType(commonType), arg, true, true))
                .orElse(null);
    }

    /**
     * Folds the argument types left to right into a single widest type.
     *
     * <p>Each step prefers a common-type rule and otherwise falls back to {@link #max}.
     *
     * @return the widest type, or {@code null} if {@code arguments} is empty or two types have no
     *     common/max type
     */
    @Nullable
    private static ExprType findWidestType(List<RexNode> arguments) {
        if (arguments.isEmpty()) {
            return null;
        }
        ExprType widestType = OpenSearchTypeFactory.convertRelDataTypeToExprType(arguments.get(0).getType());
        for (int i = 1; i < arguments.size(); i++) {
            ExprType current = widestType;
            ExprType type = OpenSearchTypeFactory.convertRelDataTypeToExprType(arguments.get(i).getType());
            try {
                widestType = resolveCommonType(current, type).orElseGet(() -> max(current, type));
            } catch (ExpressionEvaluationException e) {
                // max() throws when neither type widens to the other: no widest type exists.
                return null;
            }
        }
        return widestType;
    }

    /** Returns whether the pair is exactly {DATE, TIME} in either order. */
    private static boolean areDateAndTime(ExprType type1, ExprType type2) {
        return (type1 == ExprCoreType.DATE && type2 == ExprCoreType.TIME)
                || (type1 == ExprCoreType.TIME && type2 == ExprCoreType.DATE);
    }

    /**
     * Applies the common-coercion rules in order and returns the type chosen by the first matching
     * rule, or empty if none matches.
     */
    @VisibleForTesting
    public static Optional<ExprType> resolveCommonType(ExprType left, ExprType right) {
        return COMMON_COERCION_RULES.stream()
                .map(rule -> rule.apply(left, right))
                .flatMap(Optional::stream)
                .findFirst();
    }

    /** Returns whether any node's type converts to {@link ExprCoreType#STRING}. */
    public static boolean hasString(List<RexNode> rexNodeList) {
        return rexNodeList.stream()
                .map(RexNode::getType)
                .map(OpenSearchTypeFactory::convertRelDataTypeToExprType)
                .anyMatch(t -> t == ExprCoreType.STRING);
    }

    private static boolean hasString(ExprType left, ExprType right) {
        return left == ExprCoreType.STRING || right == ExprCoreType.STRING;
    }

    private static boolean hasNumber(ExprType left, ExprType right) {
        return NUMBER_TYPES.contains(left) || NUMBER_TYPES.contains(right);
    }

    /** Widening distance from {@code type1} to {@code type2}, or {@link #IMPOSSIBLE_WIDENING}. */
    private static int distance(ExprType type1, ExprType type2) {
        return distance(type1, type2, TYPE_EQUAL);
    }

    /**
     * Returns the shortest number of parent hops from {@code type1} to {@code type2}.
     *
     * <p>The search walks {@link ExprType#getParent()} recursively, starting from {@code depth}
     * hops already taken.
     */
    private static int distance(ExprType type1, ExprType type2, int depth) {
        if (type1 == type2) {
            return depth;
        }
        if (type1 == ExprCoreType.UNKNOWN) {
            return IMPOSSIBLE_WIDENING;
        }
        if (type1 == ExprCoreType.STRING && type2 == ExprCoreType.DOUBLE) {
            // NOTE(review): this returns 1 regardless of the hops already taken (`depth`), so any
            // path that reaches STRING also scores 1 to DOUBLE. Confirm this is intended.
            return 1;
        }
        // A type with no reported parents has no widening path rather than failing the search.
        return type1.getParent().stream()
                .mapToInt(parent -> distance(parent, type2, depth + 1))
                .min()
                .orElse(IMPOSSIBLE_WIDENING);
    }

    /**
     * Returns the wider of two types.
     *
     * <p>Returns {@code type2} if {@code type1} widens to it, including when they are equal.
     * Otherwise returns {@code type1}, provided {@code type2} widens to it.
     *
     * @throws ExpressionEvaluationException if neither type widens to the other
     */
    public static ExprType max(ExprType type1, ExprType type2) {
        int type1To2 = distance(type1, type2);
        int type2To1 = distance(type2, type1);
        if (type1To2 == IMPOSSIBLE_WIDENING && type2To1 == IMPOSSIBLE_WIDENING) {
            throw new ExpressionEvaluationException(String.format("no max type of %s and %s ", type1, type2));
        }
        return type1To2 == IMPOSSIBLE_WIDENING ? type1 : type2;
    }

    /**
     * Returns the summed per-position widening distance from {@code sourceTypes} to
     * {@code targetTypes}.
     *
     * @return the total distance, or {@code Integer.MAX_VALUE} if the lists differ in size or any
     *     position cannot be widened
     */
    public static int distance(List<ExprType> sourceTypes, List<ExprType> targetTypes) {
        if (sourceTypes.size() != targetTypes.size()) {
            return IMPOSSIBLE_WIDENING;
        }
        int totalDistance = TYPE_EQUAL;
        for (int i = 0; i < sourceTypes.size(); i++) {
            int distance = distance(sourceTypes.get(i), targetTypes.get(i));
            if (distance == IMPOSSIBLE_WIDENING) {
                return IMPOSSIBLE_WIDENING;
            }
            totalDistance += distance;
        }
        return totalDistance;
    }

    /** A common-type rule: if {@code predicate} matches a pair, {@code resolver} picks the type. */
    private record CoercionRule(BiPredicate<ExprType, ExprType> predicate, BinaryOperator<ExprType> resolver) {
        Optional<ExprType> apply(ExprType left, ExprType right) {
            return predicate.test(left, right) ? Optional.of(resolver.apply(left, right)) : Optional.empty();
        }

        static CoercionRule of(BiPredicate<ExprType, ExprType> predicate, BinaryOperator<ExprType> resolver) {
            return new CoercionRule(predicate, resolver);
        }
    }
}

