/*
 * Decompiled with CFR 0.152.
 */
package com.flipkart.gjex.grpc.interceptor;

import com.flipkart.gjex.core.context.GJEXContext;
import com.flipkart.gjex.core.filter.Filter;
import com.flipkart.gjex.core.filter.MethodFilters;
import com.flipkart.gjex.core.filter.ServerRequestParams;
import com.flipkart.gjex.core.logging.Logging;
import com.flipkart.gjex.core.util.Pair;
import com.flipkart.gjex.grpc.utils.AnnotationUtils;
import com.google.protobuf.GeneratedMessageV3;
import io.grpc.BindableService;
import io.grpc.Context;
import io.grpc.ForwardingServerCall;
import io.grpc.ForwardingServerCallListener;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import java.lang.reflect.Method;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.inject.Named;
import javax.inject.Singleton;
import javax.validation.ConstraintViolationException;

@Singleton
@Named(value="FilterInterceptor")
public class FilterInterceptor
implements ServerInterceptor,
Logging {
    private Map<String, List<Filter>> filtersMap = new HashMap<String, List<Filter>>();

    public void registerFilters(List<Filter> filters, List<BindableService> services) {
        Map classToInstanceMap = filters.stream().collect(Collectors.toMap(Object::getClass, Function.identity()));
        services.forEach(service -> {
            List<Pair<?, Method>> annotatedMethods = AnnotationUtils.getAnnotatedMethods(service.getClass(), MethodFilters.class);
            if (annotatedMethods != null) {
                annotatedMethods.forEach(pair -> {
                    LinkedList filtersForMethod = new LinkedList();
                    Arrays.asList(((Method)pair.getValue()).getAnnotation(MethodFilters.class).value()).forEach(filterClass -> {
                        if (!classToInstanceMap.containsKey(filterClass)) {
                            throw new RuntimeException("Filter instance not bound for Filter class :" + filterClass.getName());
                        }
                        filtersForMethod.add(classToInstanceMap.get(filterClass));
                    });
                    this.filtersMap.put((service.bindService().getServiceDescriptor().getName() + "/" + ((Method)pair.getValue()).getName()).toLowerCase(), filtersForMethod);
                });
            }
        });
    }

    public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
        final List<Filter> filters = this.filtersMap.get(call.getMethodDescriptor().getFullMethodName().toLowerCase());
        Metadata forwardHeaders = new Metadata();
        if (filters == null) {
            return new ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT>(next.startCall((ServerCall)new ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT>(call){}, headers)){};
        }
        for (Filter filter : filters) {
            try {
                ServerRequestParams serverRequestParams = new ServerRequestParams(((SocketAddress)Objects.requireNonNull(call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR))).toString(), call.getMethodDescriptor().getFullMethodName().toLowerCase());
                filter.doFilterRequest(serverRequestParams, headers);
                for (Metadata.Key key : filter.getForwardHeaderKeys()) {
                    Object value = headers.get(key);
                    if (value == null) continue;
                    forwardHeaders.put(key, value);
                }
            }
            catch (StatusRuntimeException se) {
                call.close(se.getStatus(), se.getTrailers());
                return new ServerCall.Listener<ReqT>(){};
            }
        }
        final Context contextWithHeaders = forwardHeaders.keys().isEmpty() ? null : Context.current().withValue(GJEXContext.getHeadersKey(), (Object)forwardHeaders);
        ServerCall.Listener listener = null;
        listener = next.startCall((ServerCall)new ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT>(call){

            public void sendMessage(RespT response) {
                Context previous = FilterInterceptor.this.attachContext(contextWithHeaders);
                try {
                    filters.forEach(filter -> filter.doProcessResponse((GeneratedMessageV3)response));
                    super.sendMessage(response);
                }
                finally {
                    FilterInterceptor.this.detachContext(contextWithHeaders, previous);
                }
            }

            public void sendHeaders(Metadata responseHeaders) {
                Context previous = FilterInterceptor.this.attachContext(contextWithHeaders);
                try {
                    filters.forEach(filter -> filter.doProcessResponseHeaders(responseHeaders));
                    super.sendHeaders(responseHeaders);
                }
                finally {
                    FilterInterceptor.this.detachContext(contextWithHeaders, previous);
                }
            }
        }, headers);
        return new ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT>(listener){

            public void onHalfClose() {
                Context previous = FilterInterceptor.this.attachContext(contextWithHeaders);
                try {
                    super.onHalfClose();
                }
                catch (RuntimeException ex) {
                    FilterInterceptor.this.handleException(call, ex);
                }
                finally {
                    FilterInterceptor.this.detachContext(contextWithHeaders, previous);
                }
            }

            public void onMessage(ReqT request) {
                Context previous = FilterInterceptor.this.attachContext(contextWithHeaders);
                try {
                    filters.forEach(filter -> filter.doProcessRequest((GeneratedMessageV3)request));
                    super.onMessage(request);
                }
                finally {
                    FilterInterceptor.this.detachContext(contextWithHeaders, previous);
                }
            }
        };
    }

    private <ReqT, RespT> void handleException(ServerCall<ReqT, RespT> call, Exception e) {
        this.error("Closing gRPC call due to RuntimeException.", e);
        Status returnStatus = Status.INTERNAL;
        if (ConstraintViolationException.class.isAssignableFrom(e.getClass())) {
            returnStatus = Status.INVALID_ARGUMENT;
        }
        try {
            call.close(returnStatus.withDescription(e.getMessage()), new Metadata());
        }
        catch (IllegalStateException ie) {
            this.warn("Exception while attempting to close ServerCall stream: " + ie.getMessage());
        }
    }

    private Context attachContext(Context context) {
        return context == null ? null : context.attach();
    }

    private void detachContext(Context currentContext, Context previousContext) {
        if (currentContext != null) {
            currentContext.detach(previousContext);
        }
    }
}

