/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tinkerpop.gremlin.server.handler;

import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.util.Attribute;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.time.Duration;
import java.time.LocalDateTime;
import java.time.temporal.Temporal;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.function.Function;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.tinkerpop.gremlin.server.Settings;
import org.apache.tinkerpop.gremlin.server.auth.AuthenticatedUser;
import org.apache.tinkerpop.gremlin.server.auth.AuthenticationException;
import org.apache.tinkerpop.gremlin.server.auth.Authenticator;
import org.apache.tinkerpop.gremlin.server.authz.Authorizer;
import org.apache.tinkerpop.gremlin.server.handler.AbstractAuthenticationHandler;
import org.apache.tinkerpop.gremlin.server.handler.StateKey;
import org.apache.tinkerpop.gremlin.util.message.RequestMessage;
import org.apache.tinkerpop.gremlin.util.message.ResponseMessage;
import org.apache.tinkerpop.gremlin.util.message.ResponseStatusCode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@ChannelHandler.Sharable
public class SaslAuthenticationHandler
extends AbstractAuthenticationHandler {
    private static final Logger logger = LoggerFactory.getLogger(SaslAuthenticationHandler.class);
    private static final Base64.Decoder BASE64_DECODER = Base64.getDecoder();
    private static final Base64.Encoder BASE64_ENCODER = Base64.getEncoder();
    public static final Duration MAX_REQUEST_DEFERRABLE_DURATION = Duration.ofSeconds(5L);
    private static final Logger auditLogger = LoggerFactory.getLogger((String)"audit.org.apache.tinkerpop.gremlin.server");
    protected final Settings settings;

    @Deprecated
    public SaslAuthenticationHandler(Authenticator authenticator, Settings settings) {
        this(authenticator, null, settings);
    }

    public SaslAuthenticationHandler(Authenticator authenticator, Authorizer authorizer, Settings settings) {
        super(authenticator, authorizer);
        this.settings = settings;
    }

    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (!(msg instanceof RequestMessage)) {
            logger.warn("{} only processes RequestMessage instances - received {} - channel closing", (Object)((Object)((Object)this)).getClass().getSimpleName(), msg.getClass());
            ctx.close();
            return;
        }
        RequestMessage requestMessage = (RequestMessage)msg;
        Attribute negotiator = ctx.channel().attr(StateKey.NEGOTIATOR);
        Attribute request = ctx.channel().attr(StateKey.REQUEST_MESSAGE);
        Attribute deferredRequests = ctx.channel().attr(StateKey.DEFERRED_REQUEST_MESSAGES);
        if (negotiator.get() == null) {
            try {
                negotiator.set((Object)this.authenticator.newSaslNegotiator(this.getRemoteInetAddress(ctx)));
                request.set((Object)requestMessage);
                ResponseMessage authenticate = ResponseMessage.build((RequestMessage)requestMessage).code(ResponseStatusCode.AUTHENTICATE).create();
                ctx.writeAndFlush((Object)authenticate);
            }
            catch (Exception ex) {
                logger.error(String.format("%s is not ready to handle requests - check its configuration or related services", this.authenticator.getClass().getSimpleName()), (Throwable)ex);
                this.respondWithError(requestMessage, builder -> builder.statusMessage("Authenticator is not ready to handle requests").code(ResponseStatusCode.SERVER_ERROR), ctx);
            }
            return;
        }
        if (!requestMessage.getOp().equals("authentication")) {
            deferredRequests.setIfAbsent((Object)new ImmutablePair((Object)LocalDateTime.now(), new ArrayList()));
            ((List)((Pair)deferredRequests.get()).getValue()).add(requestMessage);
            Duration deferredDuration = Duration.between((Temporal)((Pair)deferredRequests.get()).getKey(), LocalDateTime.now());
            if (deferredDuration.compareTo(MAX_REQUEST_DEFERRABLE_DURATION) > 0) {
                this.respondWithError(requestMessage, builder -> builder.statusMessage("Authentication did not finish in the allowed duration (" + MAX_REQUEST_DEFERRABLE_DURATION + "s).").code(ResponseStatusCode.UNAUTHORIZED), ctx);
                return;
            }
            return;
        }
        if (!requestMessage.getArgs().containsKey("sasl")) {
            this.respondWithError(requestMessage, builder -> builder.statusMessage("Failed to authenticate").code(ResponseStatusCode.UNAUTHORIZED), ctx);
            return;
        }
        Object saslObject = requestMessage.getArgs().get("sasl");
        if (!(saslObject instanceof String)) {
            this.respondWithError(requestMessage, builder -> builder.statusMessage("Incorrect type for : sasl - base64 encoded String is expected").code(ResponseStatusCode.REQUEST_ERROR_MALFORMED_REQUEST), ctx);
            return;
        }
        try {
            byte[] saslResponse = BASE64_DECODER.decode((String)saslObject);
            byte[] saslMessage = ((Authenticator.SaslNegotiator)negotiator.get()).evaluateResponse(saslResponse);
            if (!((Authenticator.SaslNegotiator)negotiator.get()).isComplete()) {
                HashMap<String, String> metadata = new HashMap<String, String>();
                metadata.put("sasl", BASE64_ENCODER.encodeToString(saslMessage));
                ResponseMessage authenticate = ResponseMessage.build((RequestMessage)requestMessage).statusAttributes(metadata).code(ResponseStatusCode.AUTHENTICATE).create();
                ctx.writeAndFlush((Object)authenticate);
                return;
            }
            AuthenticatedUser user = ((Authenticator.SaslNegotiator)negotiator.get()).getAuthenticatedUser();
            ctx.channel().attr(StateKey.AUTHENTICATED_USER).set((Object)user);
            if (this.settings.enableAuditLog.booleanValue()) {
                String address = ctx.channel().remoteAddress().toString();
                if (address.startsWith("/") && address.length() > 1) {
                    address = address.substring(1);
                }
                String[] authClassParts = this.authenticator.getClass().toString().split("[.]");
                auditLogger.info("User {} with address {} authenticated by {}", new Object[]{user.getName(), address, authClassParts[authClassParts.length - 1]});
            }
            ctx.pipeline().remove((ChannelHandler)this);
            RequestMessage original = (RequestMessage)request.get();
            ctx.fireChannelRead((Object)original);
            if (deferredRequests.get() != null) {
                ((List)((Pair)deferredRequests.getAndSet(null)).getValue()).forEach(arg_0 -> ((ChannelHandlerContext)ctx).fireChannelRead(arg_0));
            }
        }
        catch (AuthenticationException ae) {
            this.respondWithError(requestMessage, builder -> builder.statusMessage(ae.getMessage()).code(ResponseStatusCode.UNAUTHORIZED), ctx);
        }
    }

    private void respondWithError(RequestMessage requestMessage, Function<ResponseMessage.Builder, ResponseMessage.Builder> buildResponse, ChannelHandlerContext ctx) {
        Attribute originalRequest = ctx.channel().attr(StateKey.REQUEST_MESSAGE);
        Attribute deferredRequests = ctx.channel().attr(StateKey.DEFERRED_REQUEST_MESSAGES);
        if (!requestMessage.getOp().equals("authentication")) {
            ctx.write((Object)buildResponse.apply(ResponseMessage.build((RequestMessage)requestMessage)).create());
        }
        if (originalRequest.get() != null) {
            ctx.write((Object)buildResponse.apply(ResponseMessage.build((RequestMessage)((RequestMessage)originalRequest.get()))).create());
        }
        if (deferredRequests.get() != null) {
            ((List)((Pair)deferredRequests.getAndSet(null)).getValue()).stream().map(ResponseMessage::build).map(buildResponse).map(ResponseMessage.Builder::create).forEach(arg_0 -> ((ChannelHandlerContext)ctx).write(arg_0));
        }
        ctx.flush();
    }

    private InetAddress getRemoteInetAddress(ChannelHandlerContext ctx) {
        Channel channel = ctx.channel();
        if (null == channel) {
            return null;
        }
        SocketAddress genericSocketAddr = channel.remoteAddress();
        if (!(genericSocketAddr instanceof InetSocketAddress)) {
            return null;
        }
        return ((InetSocketAddress)genericSocketAddr).getAddress();
    }
}

