package dev.langchain4j.store.embedding.filter.parser.sql;

import dev.langchain4j.Experimental;
import dev.langchain4j.internal.Exceptions;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.FilterParser;
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThan;
import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThanOrEqualTo;
import dev.langchain4j.store.embedding.filter.comparison.IsIn;
import dev.langchain4j.store.embedding.filter.comparison.IsLessThan;
import dev.langchain4j.store.embedding.filter.comparison.IsLessThanOrEqualTo;
import dev.langchain4j.store.embedding.filter.comparison.IsNotEqualTo;
import dev.langchain4j.store.embedding.filter.comparison.IsNotIn;
import dev.langchain4j.store.embedding.filter.logical.And;
import dev.langchain4j.store.embedding.filter.logical.Not;
import dev.langchain4j.store.embedding.filter.logical.Or;
import java.net.URLEncoder;
import java.time.Clock;
import java.time.LocalDateTime;
import java.time.temporal.IsoFields;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.BinaryExpression;
import net.sf.jsqlparser.expression.DoubleValue;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
import net.sf.jsqlparser.expression.ExtractExpression;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.NotExpression;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.SignedExpression;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.TimeKeyExpression;
import net.sf.jsqlparser.expression.operators.arithmetic.Addition;
import net.sf.jsqlparser.expression.operators.arithmetic.Division;
import net.sf.jsqlparser.expression.operators.arithmetic.Multiplication;
import net.sf.jsqlparser.expression.operators.arithmetic.Subtraction;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.Between;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.select.PlainSelect;

@Experimental
/* loaded from: input_file:dev/langchain4j/store/embedding/filter/parser/sql/SqlFilterParser.class */
public class SqlFilterParser implements FilterParser {
    private final LocalDateTime localDateTime;

    public SqlFilterParser() {
        this(Clock.systemDefaultZone());
    }

    public SqlFilterParser(Clock clock) {
        this.localDateTime = LocalDateTime.now((Clock) ValidationUtils.ensureNotNull(clock, "clock"));
    }

    public Filter parse(String str) {
        if (!str.toUpperCase().startsWith("SELECT")) {
            str = "SELECT * FROM fake_table WHERE " + str;
        }
        try {
            return mapParenthesis(CCJSqlParserUtil.parse(str).getWhere());
        } catch (JSQLParserException e) {
            throw new RuntimeException((Throwable) e);
        }
    }

    private Filter mapParenthesis(Expression expression) {
        if (expression instanceof BinaryExpression) {
            return mapBinaryExpression((BinaryExpression) expression);
        }
        if (expression instanceof NotExpression) {
            return new Not(mapParenthesis(((NotExpression) expression).getExpression()));
        }
        if (expression instanceof Parenthesis) {
            return mapParenthesis(((Parenthesis) expression).getExpression());
        }
        if (expression instanceof InExpression) {
            return mapInExpression((InExpression) expression);
        }
        if (expression instanceof Between) {
            return mapBetween((Between) expression);
        }
        throw Exceptions.illegalArgument("Unsupported expression: '%s'%s", new Object[]{expression, createGithubIssueLink(expression)});
    }

    private static String createGithubIssueLink(Expression expression) {
        try {
            return ". Please click the following link to open an issue on our GitHub: https://github.com/langchain4j/langchain4j/issues/new?labels=SqlFilterParser&title=SqlFilterParser:%20Support%20new%20expression%20type&body=" + URLEncoder.encode(expression.toString(), "UTF-8");
        } catch (Exception e) {
            return "";
        }
    }

    private Filter mapBinaryExpression(BinaryExpression binaryExpression) {
        if (binaryExpression instanceof AndExpression) {
            return new And(mapParenthesis(binaryExpression.getLeftExpression()), mapParenthesis(binaryExpression.getRightExpression()));
        }
        if (binaryExpression instanceof OrExpression) {
            return new Or(mapParenthesis(binaryExpression.getLeftExpression()), mapParenthesis(binaryExpression.getRightExpression()));
        }
        if (binaryExpression instanceof EqualsTo) {
            return new IsEqualTo(getKey(binaryExpression), getValue(binaryExpression));
        }
        if (binaryExpression instanceof NotEqualsTo) {
            return new IsNotEqualTo(getKey(binaryExpression), getValue(binaryExpression));
        }
        if (binaryExpression instanceof GreaterThan) {
            return new IsGreaterThan(getKey(binaryExpression), getValue(binaryExpression));
        }
        if (binaryExpression instanceof GreaterThanEquals) {
            return new IsGreaterThanOrEqualTo(getKey(binaryExpression), getValue(binaryExpression));
        }
        if (binaryExpression instanceof MinorThan) {
            return new IsLessThan(getKey(binaryExpression), getValue(binaryExpression));
        }
        if (binaryExpression instanceof MinorThanEquals) {
            return new IsLessThanOrEqualTo(getKey(binaryExpression), getValue(binaryExpression));
        }
        throw Exceptions.illegalArgument("Unsupported expression: '%s'%s", new Object[]{binaryExpression, createGithubIssueLink(binaryExpression)});
    }

    private Filter mapInExpression(InExpression inExpression) {
        String columnName = inExpression.getLeftExpression().getColumnName();
        final ArrayList arrayList = new ArrayList();
        inExpression.getRightExpression().accept(new ExpressionVisitorAdapter() { // from class: dev.langchain4j.store.embedding.filter.parser.sql.SqlFilterParser.1
            public void visit(StringValue stringValue) {
                arrayList.add(stringValue.getValue());
            }

            public void visit(LongValue longValue) {
                arrayList.add(Long.valueOf(longValue.getValue()));
            }

            public void visit(DoubleValue doubleValue) {
                arrayList.add(Double.valueOf(doubleValue.getValue()));
            }
        });
        return inExpression.isNot() ? new IsNotIn(columnName, arrayList) : new IsIn(columnName, arrayList);
    }

    private Filter mapBetween(Between between) {
        String columnName = between.getLeftExpression().getColumnName();
        return new IsGreaterThanOrEqualTo(columnName, getValue(between.getBetweenExpressionStart())).and(new IsLessThanOrEqualTo(columnName, getValue(between.getBetweenExpressionEnd())));
    }

    private String getKey(BinaryExpression binaryExpression) {
        return binaryExpression.getLeftExpression().getColumnName();
    }

    private Comparable<?> getValue(BinaryExpression binaryExpression) {
        return getValue(binaryExpression.getRightExpression());
    }

    private Comparable<?> getValue(Expression expression) {
        if (expression instanceof StringValue) {
            return ((StringValue) expression).getValue();
        }
        if (expression instanceof LongValue) {
            return Long.valueOf(((LongValue) expression).getValue());
        }
        if (expression instanceof DoubleValue) {
            return Double.valueOf(((DoubleValue) expression).getValue());
        }
        if (expression instanceof SignedExpression) {
            SignedExpression signedExpression = (SignedExpression) expression;
            if (signedExpression.getSign() == '-') {
                if (signedExpression.getExpression() instanceof LongValue) {
                    return Long.valueOf(Long.parseLong("-" + signedExpression.getExpression().toString()));
                }
                if (signedExpression.getExpression() instanceof DoubleValue) {
                    return Double.valueOf(Double.parseDouble("-" + signedExpression.getExpression().toString()));
                }
            }
        } else if (expression instanceof Function) {
            Function function = (Function) expression;
            if (function.getName().equalsIgnoreCase("YEAR")) {
                ExpressionList parameters = function.getParameters();
                if (parameters.size() == 1 && (parameters.get(0) instanceof Function) && ((Function) parameters.get(0)).getName().equalsIgnoreCase("CURDATE")) {
                    return Long.valueOf(currentYear());
                }
            } else if (function.getName().equalsIgnoreCase("MONTH")) {
                ExpressionList parameters2 = function.getParameters();
                if (parameters2.size() == 1 && (parameters2.get(0) instanceof Function) && ((Function) parameters2.get(0)).getName().equalsIgnoreCase("CURDATE")) {
                    return Long.valueOf(currentMonth());
                }
            }
        } else if (expression instanceof ExtractExpression) {
            ExtractExpression extractExpression = (ExtractExpression) expression;
            if (extractExpression.getExpression() instanceof TimeKeyExpression) {
                TimeKeyExpression expression2 = extractExpression.getExpression();
                if (expression2.getStringValue().equalsIgnoreCase("CURRENT_DATE") || expression2.getStringValue().equalsIgnoreCase("CURRENT_TIME") || expression2.getStringValue().equalsIgnoreCase("CURRENT_TIMESTAMP")) {
                    String upperCase = extractExpression.getName().toUpperCase();
                    boolean z = -1;
                    switch (upperCase.hashCode()) {
                        case -2020697580:
                            if (upperCase.equals("MINUTE")) {
                                z = 7;
                                break;
                            }
                            break;
                        case 67452:
                            if (upperCase.equals("DAY")) {
                                z = 3;
                                break;
                            }
                            break;
                        case 67884:
                            if (upperCase.equals("DOW")) {
                                z = 4;
                                break;
                            }
                            break;
                        case 67886:
                            if (upperCase.equals("DOY")) {
                                z = 5;
                                break;
                            }
                            break;
                        case 2223588:
                            if (upperCase.equals("HOUR")) {
                                z = 6;
                                break;
                            }
                            break;
                        case 2660340:
                            if (upperCase.equals("WEEK")) {
                                z = 2;
                                break;
                            }
                            break;
                        case 2719805:
                            if (upperCase.equals("YEAR")) {
                                z = false;
                                break;
                            }
                            break;
                        case 73542240:
                            if (upperCase.equals("MONTH")) {
                                z = true;
                                break;
                            }
                            break;
                    }
                    switch (z) {
                        case false:
                            return Long.valueOf(currentYear());
                        case true:
                            return Long.valueOf(currentMonth());
                        case true:
                            return Long.valueOf(currentWeekOfYear());
                        case true:
                            return Long.valueOf(currentDayOfMonth());
                        case true:
                            return Long.valueOf(currentDayOfWeek());
                        case true:
                            return Long.valueOf(currentDayOfYear());
                        case true:
                            return Long.valueOf(currentHour());
                        case true:
                            return Long.valueOf(currentMinute());
                    }
                }
            }
        } else if (expression instanceof Addition) {
            Comparable<?> value = getValue(((Addition) expression).getLeftExpression());
            Comparable<?> value2 = getValue(((Addition) expression).getRightExpression());
            if ((value instanceof Long) && (value2 instanceof Long)) {
                return Long.valueOf(((Long) value).longValue() + ((Long) value2).longValue());
            }
            if ((value instanceof Double) && (value2 instanceof Double)) {
                return Double.valueOf(((Double) value).doubleValue() + ((Double) value2).doubleValue());
            }
        } else if (expression instanceof Subtraction) {
            Comparable<?> value3 = getValue(((Subtraction) expression).getLeftExpression());
            Comparable<?> value4 = getValue(((Subtraction) expression).getRightExpression());
            if ((value3 instanceof Long) && (value4 instanceof Long)) {
                return Long.valueOf(((Long) value3).longValue() - ((Long) value4).longValue());
            }
            if ((value3 instanceof Double) && (value4 instanceof Double)) {
                return Double.valueOf(((Double) value3).doubleValue() - ((Double) value4).doubleValue());
            }
        } else if (expression instanceof Multiplication) {
            Comparable<?> value5 = getValue(((Multiplication) expression).getLeftExpression());
            Comparable<?> value6 = getValue(((Multiplication) expression).getRightExpression());
            if ((value5 instanceof Long) && (value6 instanceof Long)) {
                return Long.valueOf(((Long) value5).longValue() * ((Long) value6).longValue());
            }
            if ((value5 instanceof Double) && (value6 instanceof Double)) {
                return Double.valueOf(((Double) value5).doubleValue() * ((Double) value6).doubleValue());
            }
        } else if (expression instanceof Division) {
            Comparable<?> value7 = getValue(((Division) expression).getLeftExpression());
            Comparable<?> value8 = getValue(((Division) expression).getRightExpression());
            if ((value7 instanceof Long) && (value8 instanceof Long)) {
                return Long.valueOf(((Long) value7).longValue() / ((Long) value8).longValue());
            }
            if ((value7 instanceof Double) && (value8 instanceof Double)) {
                return Double.valueOf(((Double) value7).doubleValue() / ((Double) value8).doubleValue());
            }
        }
        throw Exceptions.illegalArgument("Unsupported expression: '%s'%s", new Object[]{expression, createGithubIssueLink(expression)});
    }

    private long currentYear() {
        return this.localDateTime.getYear();
    }

    private long currentMonth() {
        return this.localDateTime.getMonthValue();
    }

    private long currentWeekOfYear() {
        return this.localDateTime.get(IsoFields.WEEK_OF_WEEK_BASED_YEAR);
    }

    private long currentDayOfMonth() {
        return this.localDateTime.getDayOfMonth();
    }

    private long currentDayOfWeek() {
        return this.localDateTime.getDayOfWeek().getValue();
    }

    private long currentDayOfYear() {
        return this.localDateTime.getDayOfYear();
    }

    private long currentHour() {
        return this.localDateTime.getHour();
    }

    private long currentMinute() {
        return this.localDateTime.getMinute();
    }
}
