/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.community.rag.content.retriever.neo4j;

import dev.langchain4j.community.rag.content.retriever.neo4j.Neo4jGraph;
import dev.langchain4j.community.rag.content.retriever.neo4j.Neo4jUtils;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Query;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.neo4j.cypherdsl.core.Statement;
import org.neo4j.cypherdsl.core.renderer.Configuration;
import org.neo4j.cypherdsl.core.renderer.Dialect;
import org.neo4j.cypherdsl.core.renderer.Renderer;
import org.neo4j.cypherdsl.parser.CypherParser;
import org.neo4j.driver.Record;
import org.neo4j.driver.types.Type;
import org.neo4j.driver.types.TypeSystem;

public class Neo4jText2CypherRetriever
implements ContentRetriever {
    private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from((String)"Task:Generate Cypher statement to query a graph database.\nInstructions\nUse only the provided relationship types and properties in the schema.\nDo not use any other relationship types or properties that are not provided.\n\nSchema:\n{{schema}}\n\n{{examples}}\nNote: Do not include any explanations or apologies in your responses.\nDo not respond to any questions that might ask anything else than for you to construct a Cypher statement.\nDo not include any text except the generated Cypher statement.\nThe question is: {{question}}\n");
    private static final Type NODE = TypeSystem.getDefault().NODE();
    private static final Type RELATIONSHIP = TypeSystem.getDefault().RELATIONSHIP();
    private static final Type PATH = TypeSystem.getDefault().PATH();
    private final Neo4jGraph graph;
    private final ChatModel chatModel;
    private final PromptTemplate promptTemplate;
    private final int maxRetries;
    private final List<String> examples;
    private final List<String> relationships;
    private final String dialect;

    public Neo4jText2CypherRetriever(Neo4jGraph graph, ChatModel chatModel, PromptTemplate promptTemplate, List<String> examples, int maxRetries, List<String> relationships, String dialect) {
        this.graph = (Neo4jGraph)ValidationUtils.ensureNotNull((Object)graph, (String)"graph");
        this.chatModel = (ChatModel)ValidationUtils.ensureNotNull((Object)chatModel, (String)"chatModel");
        this.promptTemplate = (PromptTemplate)Utils.getOrDefault((Object)promptTemplate, (Object)DEFAULT_PROMPT_TEMPLATE);
        this.examples = Utils.getOrDefault(examples, List.of());
        this.maxRetries = maxRetries;
        this.relationships = Utils.getOrDefault(relationships, List.of());
        this.dialect = (String)Utils.getOrDefault((Object)dialect, (Object)Dialect.NEO4J_5_26.name());
    }

    public static Builder builder() {
        return new Builder();
    }

    public Neo4jGraph getGraph() {
        return this.graph;
    }

    public ChatModel getChatModel() {
        return this.chatModel;
    }

    public PromptTemplate getPromptTemplate() {
        return this.promptTemplate;
    }

    public List<Content> retrieve(Query query) {
        String question = query.text();
        String schema = this.graph.getSchema();
        String examplesString = "";
        if (!this.examples.isEmpty()) {
            String exampleJoin = String.join((CharSequence)"\n", this.examples);
            examplesString = String.format("Cypher examples: \n%s\n", exampleJoin);
        }
        Map<String, String> templateVariables = Map.of("schema", schema, "question", question, "examples", examplesString);
        String cypherPrompt = this.promptTemplate.apply(templateVariables).text();
        ArrayList<UserMessage> messages = new ArrayList<UserMessage>();
        messages.add(UserMessage.from((String)cypherPrompt));
        String emptyResultMsg = "The query result is empty. If `maxRetries` number is not reached, the query will be re-generated";
        try {
            return (List)RetryUtils.withRetry(() -> {
                List<String> response;
                String cypherQuery = this.generateCypherQuery(messages);
                try {
                    response = this.executeQuery(cypherQuery);
                }
                catch (Exception e) {
                    String errorUserMsg = String.format("The previous Cypher Statement throws the following error, consider it to return the correct statement: `%s`.\nPlease, try to return a valid query.\n\nCypher query:\n", e.getMessage());
                    messages.add(UserMessage.from((String)errorUserMsg));
                    throw e;
                }
                List<Content> list = response.stream().map(Content::from).toList();
                if (list.isEmpty()) {
                    String errorUserMsg = "The previous Cypher Statement returns no result, consider it to return the correct statement.\nPlease, try to return a valid query.\n\nCypher query:\n";
                    messages.add(UserMessage.from((String)errorUserMsg));
                    throw new RuntimeException(emptyResultMsg);
                }
                return list;
            }, (int)this.maxRetries);
        }
        catch (Exception e) {
            if (e.getMessage().contains(emptyResultMsg)) {
                return List.of();
            }
            throw e;
        }
    }

    private String getFixedCypherWithDSL(String cypher) {
        if (this.relationships.isEmpty()) {
            return cypher;
        }
        Statement statement = CypherParser.parse((String)cypher);
        Configuration.Builder configuration = Configuration.newConfig().withPrettyPrint(false).alwaysEscapeNames(false).withEnforceSchema(true).withDialect(Dialect.valueOf((String)this.dialect));
        this.relationships.stream().map(Configuration::relationshipDefinition).forEach(arg_0 -> ((Configuration.Builder)configuration).withRelationshipDefinition(arg_0));
        return Renderer.getRenderer((Configuration)configuration.build()).render(statement);
    }

    private String generateCypherQuery(List<ChatMessage> messages) {
        String cypherQuery = this.chatModel.chat(messages).aiMessage().text();
        cypherQuery = this.getFixedCypherWithDSL(cypherQuery);
        return Neo4jUtils.getBacktickText(cypherQuery);
    }

    private List<String> executeQuery(String cypherQuery) {
        List<Record> records = this.graph.executeRead(cypherQuery);
        return records.stream().flatMap(r -> r.values().stream()).map(value -> {
            boolean isEntity;
            boolean bl = isEntity = NODE.isTypeOf(value) || RELATIONSHIP.isTypeOf(value) || PATH.isTypeOf(value);
            if (isEntity) {
                return value.asMap().toString();
            }
            return value.toString();
        }).toList();
    }

    public static class Builder {
        protected Neo4jGraph graph;
        protected ChatModel chatModel;
        protected PromptTemplate promptTemplate;
        protected List<String> relationships;
        protected String dialect;
        protected int maxRetries = 3;
        protected List<String> examples;

        public Builder graph(Neo4jGraph graph) {
            this.graph = graph;
            return this;
        }

        public Builder chatModel(ChatModel chatModel) {
            this.chatModel = chatModel;
            return this;
        }

        public Builder promptTemplate(PromptTemplate promptTemplate) {
            this.promptTemplate = promptTemplate;
            return this;
        }

        public Builder relationships(List<String> relationships) {
            this.relationships = relationships;
            return this;
        }

        public Builder dialect(String dialect) {
            this.dialect = dialect;
            return this;
        }

        public Builder maxRetries(int maxRetries) {
            this.maxRetries = maxRetries;
            return this;
        }

        public Builder examples(List<String> examples) {
            this.examples = examples;
            return this;
        }

        public Neo4jText2CypherRetriever build() {
            return new Neo4jText2CypherRetriever(this.graph, this.chatModel, this.promptTemplate, this.examples, this.maxRetries, this.relationships, this.dialect);
        }
    }
}

