/*
 * Decompiled with CFR 0.152.
 */
package com.ghostchu.peerbanhelper.util.traversal.forwarder;

import com.ghostchu.peerbanhelper.BanList;
import com.ghostchu.peerbanhelper.ExternalSwitch;
import com.ghostchu.peerbanhelper.Main;
import com.ghostchu.peerbanhelper.event.banwave.PeerBanEvent;
import com.ghostchu.peerbanhelper.text.Lang;
import com.ghostchu.peerbanhelper.text.TextManager;
import com.ghostchu.peerbanhelper.util.IPAddressUtil;
import com.ghostchu.peerbanhelper.util.ipdb.IPDBManager;
import com.ghostchu.peerbanhelper.util.traversal.NatAddressProvider;
import com.ghostchu.peerbanhelper.util.traversal.forwarder.Forwarder;
import com.ghostchu.peerbanhelper.util.traversal.forwarder.ForwarderIOHandlerType;
import com.ghostchu.peerbanhelper.util.traversal.forwarder.iohandler.EpollHandler;
import com.ghostchu.peerbanhelper.util.traversal.forwarder.iohandler.ForwarderIOHandler;
import com.ghostchu.peerbanhelper.util.traversal.forwarder.iohandler.IOUringHandler;
import com.ghostchu.peerbanhelper.util.traversal.forwarder.iohandler.KQueueHandler;
import com.ghostchu.peerbanhelper.util.traversal.forwarder.iohandler.NioHandler;
import com.ghostchu.peerbanhelper.util.traversal.forwarder.table.ConnectionStatistics;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.Maps;
import com.google.common.eventbus.Subscribe;
import inet.ipaddr.IPAddress;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.MultiThreadIoEventLoopGroup;
import io.netty.channel.epoll.Epoll;
import io.netty.channel.kqueue.KQueue;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.uring.IoUring;
import io.netty.util.AttributeKey;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.GenericFutureListener;
import java.io.IOException;
import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.UnknownHostException;
import java.util.Arrays;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.LongAdder;
import lombok.Generated;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class TCPForwarderImpl
implements AutoCloseable,
Forwarder,
NatAddressProvider {
    // Class-wide SLF4J logger.
    @Generated
    private static final Logger log = LoggerFactory.getLogger(TCPForwarderImpl.class);
    // Port the listening socket is bound to in start().
    private final int proxyPort;
    // Host/interface the listening socket is bound to in start().
    private final String proxyHost;
    // Host of the upstream server every accepted connection is relayed to.
    private final String upstreamHost;
    // Port of the upstream server every accepted connection is relayed to.
    private final int upstreamPort;
    // Consulted when a connection is accepted and again by the periodic cleanup task.
    private final BanList banList;
    // Event loop groups handed to ServerBootstrap.group(boss, worker) in start().
    private final EventLoopGroup bossGroup;
    private final EventLoopGroup workerGroup;
    // Pool of size 1 whose threads come from a virtual-thread factory; the constructor
    // schedules cleanupBannedConnections() on it every 30 seconds.
    private final ScheduledExecutorService sched = Executors.newScheduledThreadPool(1, Thread.ofVirtual().factory());
    // Downstream (client) remote address -> local address of the socket opened towards upstream.
    // Wrapped by Maps.synchronizedBiMap; the inverse view is what translate() queries.
    private final BiMap<InetSocketAddress, InetSocketAddress> connectionMap = Maps.synchronizedBiMap((BiMap)HashBiMap.create());
    // Downstream remote address -> per-connection statistics.
    private final Map<InetSocketAddress, ConnectionStatistics> connectionStats = new ConcurrentHashMap<InetSocketAddress, ConnectionStatistics>();
    // Downstream remote address -> downstream channel; used to find and close connections of banned peers.
    private final Map<InetSocketAddress, Channel> downstreamChannelMap = new ConcurrentHashMap<InetSocketAddress, Channel>();
    // Incremented once per accepted connection that passed the duplicate and ban checks.
    private final LongAdder connectionHandled = new LongAdder();
    // Incremented when connecting to upstream fails or the frontend handler catches an exception.
    private final LongAdder connectionFailed = new LongAdder();
    // Incremented when a connection is declined or closed because its address is banned.
    private final LongAdder connectionBlocked = new LongAdder();
    // Incremented when a connection is refused because its IP already has a tracked connection.
    private final LongAdder connectionRejected = new LongAdder();
    // Total bytes read from downstream channels (i.e. travelling towards upstream).
    private final LongAdder totalToUpstream = new LongAdder();
    // Total bytes read from upstream channels (i.e. travelling towards downstream).
    private final LongAdder totalToDownstream = new LongAdder();
    // Queried for geo data that is attached to each connection's statistics.
    private final IPDBManager ipdb;
    // Listening channel; assigned in start(), so it stays null until start() succeeds.
    private Channel serverChannel;
    // Transport implementation (io_uring, epoll, kqueue or NIO) chosen in the constructor.
    private final ForwarderIOHandler ioHandler;

    public TCPForwarderImpl(BanList banList, String proxyHost, int proxyPort, String upstreamHost, int upstreamPort, IPDBManager ipdb) {
        this.banList = banList;
        this.proxyHost = proxyHost;
        this.proxyPort = proxyPort;
        this.upstreamHost = upstreamHost;
        this.upstreamPort = upstreamPort;
        this.ipdb = ipdb;
        this.ioHandler = IoUring.isAvailable() ? new IOUringHandler() : (Epoll.isAvailable() ? new EpollHandler() : (KQueue.isAvailable() ? new KQueueHandler() : new NioHandler()));
        log.debug("IOHandler selected: {}", (Object)this.ioHandler.getClass().getSimpleName());
        this.bossGroup = new MultiThreadIoEventLoopGroup(this.ioHandler.ioHandlerFactory());
        this.workerGroup = new MultiThreadIoEventLoopGroup(this.ioHandler.ioHandlerFactory());
        log.debug("Netty TCPForwarder created: proxy {}:{}, upstream {}:{}", new Object[]{proxyHost, proxyPort, upstreamHost, upstreamPort});
        this.sched.scheduleAtFixedRate(this::cleanupBannedConnections, 0L, 30L, TimeUnit.SECONDS);
        Main.getEventBus().register((Object)this);
    }

    /**
     * Configures the server bootstrap and binds the listening socket on
     * {@code proxyHost:proxyPort}, blocking until the bind completes.
     *
     * <p>Every accepted channel gets a fresh {@code ProxyFrontendHandler}. Accepted channels
     * are created with {@code AUTO_READ} disabled, so nothing is read from a client until a
     * handler explicitly asks for it.
     *
     * <p>If the calling thread is interrupted while waiting for the bind, the interrupt flag
     * is restored, the failure is logged and {@link #close()} is invoked.
     */
    @Override
    public void start() {
        ServerBootstrap b = new ServerBootstrap();
        // One chained expression: group/channel/childHandler are set first, the result is passed
        // through ioHandler.apply(...), then the socket options are applied to what apply returns.
        ((ServerBootstrap)((ServerBootstrap)this.ioHandler.apply(((ServerBootstrap)b.group(this.bossGroup, this.workerGroup).channel(this.ioHandler.serverSocketChannelClass())).childHandler((ChannelHandler)new ChannelInitializer<SocketChannel>(this){
            // Decompiler-emitted reference to the enclosing TCPForwarderImpl instance.
            final /* synthetic */ TCPForwarderImpl this$0;
            {
                TCPForwarderImpl tCPForwarderImpl = this$0;
                Objects.requireNonNull(tCPForwarderImpl);
                this.this$0 = tCPForwarderImpl;
            }

            // Installs the frontend handler that vets the client and opens the upstream connection.
            protected void initChannel(SocketChannel ch) {
                ch.pipeline().addLast(new ChannelHandler[]{new ProxyFrontendHandler(this.this$0)});
            }
        })).option(ChannelOption.SO_BACKLOG, (Object)128)).option(ChannelOption.SO_REUSEADDR, (Object)true)).childOption(ChannelOption.SO_KEEPALIVE, (Object)true).childOption(ChannelOption.AUTO_READ, (Object)false);
        try {
            // sync() blocks until the bind finishes; the bound channel is kept so close() can shut it.
            this.serverChannel = b.bind(this.proxyHost, this.proxyPort).sync().channel();
            log.debug("Netty TCPForwarder started and listening on {}:{}", (Object)this.proxyHost, (Object)this.proxyPort);
        }
        catch (InterruptedException e) {
            // Restore the interrupt flag before reporting and releasing resources.
            Thread.currentThread().interrupt();
            log.error(TextManager.tlUI(Lang.AUTOSTUN_TCP_FORWARDER_UNABLE_START, new Object[0]), (Throwable)e);
            this.close();
        }
    }

    /**
     * Event-bus callback: closes every tracked connection whose client address is contained in
     * the address carried by the ban event, counting each one as blocked.
     *
     * @param peerBanEvent the ban notification; its peer address is matched against all
     *                     tracked downstream addresses
     */
    @Subscribe
    public void onPeerBanned(PeerBanEvent peerBanEvent) {
        IPAddress bannedPeerAddr = peerBanEvent.getPeer().getAddress();
        log.debug("Received PeerBanEvent for {}", bannedPeerAddr);
        for (InetSocketAddress trackedSocket : this.downstreamChannelMap.keySet()) {
            String hostText = trackedSocket.getAddress().getHostAddress();
            IPAddress trackedIp = IPAddressUtil.getIPAddress(hostText);
            if (!bannedPeerAddr.contains(trackedIp)) {
                continue;
            }
            log.debug("Closing connection from banned address from banEvent: {}", trackedSocket);
            this.closeConnectionByDownstreamAddress(trackedSocket);
            this.connectionBlocked.increment();
        }
    }

    /**
     * Periodic task: closes every tracked connection whose client address is on the ban list
     * and counts each one as blocked.
     *
     * <p>This method is run by {@code scheduleAtFixedRate}, which cancels all subsequent runs
     * as soon as one execution throws. Each entry is therefore checked inside its own
     * try/catch so a failure on one address neither skips the remaining ones nor stops the
     * periodic sweep for good.
     */
    private void cleanupBannedConnections() {
        this.downstreamChannelMap.forEach((clientAddress, channel) -> {
            try {
                IPAddress clientIp = IPAddressUtil.getIPAddress(clientAddress.getAddress().getHostAddress());
                if (this.banList.contains(clientIp)) {
                    log.debug("Closing connection from banned address during cleanup: {}", clientAddress);
                    this.closeConnectionByDownstreamAddress(clientAddress);
                    this.connectionBlocked.increment();
                }
            } catch (RuntimeException e) {
                // Swallow deliberately: rethrowing would terminate the scheduled task permanently.
                log.debug("Unexpected error while checking {} during banned-connection cleanup", clientAddress, e);
            }
        });
    }

    /**
     * Closes both legs of the connection tracked for the given client address.
     *
     * <p>The upstream leg is found through the relay attribute stored on the downstream
     * channel and is closed first; nothing happens when the address is not tracked.
     *
     * @param downstreamAddress remote address of the client whose connection should be closed
     */
    private void closeConnectionByDownstreamAddress(InetSocketAddress downstreamAddress) {
        Channel clientSide = this.downstreamChannelMap.get(downstreamAddress);
        if (clientSide == null) {
            return;
        }
        if (clientSide.hasAttr(RelayHandler.RELAY_CHANNEL_KEY)) {
            Channel serverSide = clientSide.attr(RelayHandler.RELAY_CHANNEL_KEY).get();
            if (serverSide != null) {
                serverSide.close();
            }
        }
        clientSide.close();
    }

    /**
     * Reports whether a connection from the same IP address (any port) is already tracked in
     * {@code connectionMap}.
     *
     * <p>{@code connectionMap} is a {@code Maps.synchronizedBiMap} wrapper. Such wrappers only
     * guard single calls; iterating one of their collection views requires the caller to hold
     * the wrapper's monitor, otherwise a concurrent put/remove from another event loop can
     * corrupt the iteration or raise {@code ConcurrentModificationException}.
     *
     * @param downstreamSocketAddress remote address of the newly accepted client
     * @return {@code true} if some tracked connection has the same raw IP bytes
     */
    private boolean isDuplicateConnection(InetSocketAddress downstreamSocketAddress) {
        byte[] candidateIp = downstreamSocketAddress.getAddress().getAddress();
        synchronized (this.connectionMap) {
            for (InetSocketAddress tracked : this.connectionMap.keySet()) {
                if (Arrays.equals(candidateIp, tracked.getAddress().getAddress())) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Stops the forwarder: detaches it from the event bus, stops the cleanup scheduler,
     * closes the listening channel (if {@link #start()} ever bound one) and shuts down both
     * event loop groups, waiting for each step to finish.
     *
     * <p>The constructor registers this instance on the application event bus; without the
     * matching unregister the bus would keep a reference to a closed forwarder and keep
     * delivering ban events to it.
     */
    @Override
    public void close() {
        try {
            Main.getEventBus().unregister(this);
        } catch (IllegalArgumentException alreadyUnregistered) {
            // close() can run more than once (start() calls it on interruption); a second
            // unregister is reported by the bus as an IllegalArgumentException and is harmless.
            log.debug("TCPForwarder was already unregistered from the event bus");
        }
        this.sched.shutdown();
        if (this.serverChannel != null) {
            this.serverChannel.close().syncUninterruptibly();
        }
        this.bossGroup.shutdownGracefully().syncUninterruptibly();
        this.workerGroup.shutdownGracefully().syncUninterruptibly();
    }

    /**
     * @return the number of downstream channels currently tracked
     */
    @Override
    public long getEstablishedConnections() {
        int trackedChannels = downstreamChannelMap.size();
        return trackedChannels;
    }

    /**
     * Returns an immutable snapshot of the connection table keyed by the client (downstream)
     * remote address; values are the local addresses of the matching upstream sockets.
     *
     * <p>The copy iterates the synchronized BiMap's entries, so it is taken while holding the
     * wrapper's monitor; iterating without it is not safe against concurrent modification.
     *
     * @return an immutable copy, unaffected by later connection changes
     */
    @Override
    public BiMap<InetSocketAddress, InetSocketAddress> getDownstreamAddressAsKeyConnectionMap() {
        synchronized (this.connectionMap) {
            return ImmutableBiMap.copyOf(this.connectionMap);
        }
    }

    /**
     * Returns an immutable snapshot of the connection table keyed by the local address of the
     * upstream-side socket; values are the client (downstream) remote addresses.
     *
     * <p>The copy iterates the inverse view of the synchronized BiMap, so it is taken while
     * holding the wrapper's monitor; iterating without it is not safe against concurrent
     * modification.
     *
     * @return an immutable copy, unaffected by later connection changes
     */
    @Override
    public BiMap<InetSocketAddress, InetSocketAddress> getProxyLAddressAsKeyConnectionMap() {
        synchronized (this.connectionMap) {
            return ImmutableBiMap.copyOf(this.connectionMap.inverse());
        }
    }

    /**
     * @return an unmodifiable copy of the per-connection statistics, keyed by the client
     *         (downstream) remote address
     */
    @Override
    public Map<InetSocketAddress, ConnectionStatistics> getDownstreamAddressAsKeyConnectionStats() {
        Map<InetSocketAddress, ConnectionStatistics> snapshot = Map.copyOf(connectionStats);
        return snapshot;
    }

    /**
     * @return total bytes counted so far in the downstream-to-upstream direction
     */
    @Override
    public long getTotalToUpstream() {
        return totalToUpstream.sum();
    }

    /**
     * @return total bytes counted so far in the upstream-to-downstream direction
     */
    @Override
    public long getTotalToDownstream() {
        return totalToDownstream.sum();
    }

    /**
     * @return current value of the failed-connection counter
     */
    @Override
    public long getConnectionFailed() {
        return connectionFailed.sum();
    }

    /**
     * @return current value of the handled-connection counter
     */
    @Override
    public long getConnectionHandled() {
        return connectionHandled.sum();
    }

    /**
     * @return current value of the blocked (banned address) connection counter
     */
    @Override
    public long getConnectionBlocked() {
        return connectionBlocked.sum();
    }

    /**
     * @return current value of the rejected (duplicate IP) connection counter
     */
    @Override
    public long getConnectionRejected() {
        return connectionRejected.sum();
    }

    /**
     * @return the port this forwarder listens on
     */
    @Override
    public int getProxyPort() {
        return proxyPort;
    }

    /**
     * @return the port of the upstream server
     */
    @Override
    public int getUpstreamPort() {
        return upstreamPort;
    }

    /**
     * @return the host/interface this forwarder listens on
     */
    @Override
    public String getProxyHost() {
        return proxyHost;
    }

    /**
     * @return the host of the upstream server
     */
    @Override
    public String getUpstremHost() {
        return upstreamHost;
    }

    /**
     * Maps the local address of an upstream-side socket back to the client address it was
     * opened for.
     *
     * @param nattedAddress local address of a socket this forwarder opened towards upstream
     * @return the client (downstream) remote address, or {@code null} if no tracked
     *         connection uses that local address
     */
    @Override
    @Nullable
    public InetSocketAddress translate(@Nullable InetSocketAddress nattedAddress) {
        BiMap<InetSocketAddress, InetSocketAddress> byUpstreamLocalAddress = connectionMap.inverse();
        return byUpstreamLocalAddress.get(nattedAddress);
    }

    /**
     * @return the type reported by the transport handler selected in the constructor
     */
    @Override
    public ForwarderIOHandlerType getForwarderIOHandlerType() {
        return ioHandler.ioHandlerType();
    }

    public class RelayHandler
    extends ChannelInboundHandlerAdapter {
        public static final AttributeKey<Channel> RELAY_CHANNEL_KEY = AttributeKey.valueOf((String)"relayChannel");
        private final Channel relayChannel;
        final /* synthetic */ TCPForwarderImpl this$0;

        public RelayHandler(TCPForwarderImpl this$0, Channel relayChannel) {
            TCPForwarderImpl tCPForwarderImpl = this$0;
            Objects.requireNonNull(tCPForwarderImpl);
            this.this$0 = tCPForwarderImpl;
            this.relayChannel = relayChannel;
        }

        public void channelActive(ChannelHandlerContext ctx) {
            ctx.read();
        }

        public void channelRead(ChannelHandlerContext ctx, Object msg) {
            if (this.relayChannel.isActive()) {
                this.relayChannel.writeAndFlush(msg).addListener(future -> {
                    if (future.isSuccess()) {
                        ctx.channel().read();
                    } else {
                        future.channel().close();
                    }
                });
            } else {
                ReferenceCountUtil.release((Object)msg);
                ctx.close();
            }
        }

        public void channelInactive(ChannelHandlerContext ctx) {
            if (this.relayChannel.isActive()) {
                this.relayChannel.writeAndFlush((Object)Unpooled.EMPTY_BUFFER).addListener((GenericFutureListener)ChannelFutureListener.CLOSE);
            }
            this.cleanupConnectionState(ctx.channel());
        }

        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
            if (!(cause instanceof IOException)) {
                log.debug("Exception in RelayHandler from {}", (Object)ctx.channel().remoteAddress(), (Object)cause);
            }
            ctx.close();
            if (this.relayChannel.isActive()) {
                this.relayChannel.close();
            }
        }

        private void cleanupConnectionState(Channel channel) {
            InetSocketAddress downstreamAddress;
            Channel downstreamChannel;
            Channel channel2 = downstreamChannel = channel.hasAttr(RELAY_CHANNEL_KEY) ? (Channel)((Channel)channel.attr(RELAY_CHANNEL_KEY).get()).attr(RELAY_CHANNEL_KEY).get() : channel;
            if (downstreamChannel != null && (downstreamAddress = (InetSocketAddress)downstreamChannel.remoteAddress()) != null) {
                this.this$0.connectionMap.remove((Object)downstreamAddress);
                this.this$0.connectionStats.remove(downstreamAddress);
                this.this$0.downstreamChannelMap.remove(downstreamAddress);
            }
        }
    }

    public static class TrafficCounterHandler
    extends ChannelDuplexHandler {
        private final ConnectionStatistics stats;
        private final boolean downstreamChannel;
        private final LongAdder totalToUpstream;
        private final LongAdder totalToDownstream;

        public TrafficCounterHandler(ConnectionStatistics stats, boolean downstreamChannel, LongAdder totalToUpstream, LongAdder totalToDownstream) {
            this.stats = stats;
            this.downstreamChannel = downstreamChannel;
            this.totalToUpstream = totalToUpstream;
            this.totalToDownstream = totalToDownstream;
        }

        public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
            if (msg instanceof ByteBuf) {
                ByteBuf buf = (ByteBuf)msg;
                int readableBytes = buf.readableBytes();
                if (this.downstreamChannel) {
                    this.stats.getToUpstreamBytes().add(readableBytes);
                    this.totalToUpstream.add(readableBytes);
                } else {
                    this.stats.getToDownstreamBytes().add(readableBytes);
                    this.totalToDownstream.add(readableBytes);
                }
                this.stats.setLastActivityAt();
            }
            super.channelRead(ctx, msg);
        }
    }

    @ChannelHandler.Sharable
    public class ProxyFrontendHandler
    extends ChannelInboundHandlerAdapter {
        final /* synthetic */ TCPForwarderImpl this$0;

        /**
         * @param this$0 the enclosing forwarder; must not be {@code null}
         */
        public ProxyFrontendHandler(TCPForwarderImpl this$0) {
            this.this$0 = Objects.requireNonNull(this$0);
        }

        /**
         * Vets a newly accepted client and, if it is allowed, opens the matching upstream
         * connection and wires the two channels together.
         *
         * <p>Clients are refused when another connection from the same IP is already tracked
         * (counted as rejected) or when the IP is banned (counted as blocked). Otherwise the
         * upstream connect is started from a virtual thread, because the connect helper may
         * block. When it succeeds, relay and traffic-counting handlers are installed on both
         * channels, the connection is recorded in the forwarder's tables, this handler
         * removes itself and auto-read is switched on for the client channel.
         *
         * <p>Two failure paths are handled explicitly: an exception thrown by the connect
         * helper is treated like a failed connect (otherwise the client channel would stay
         * open forever with reading disabled), and a client that disconnected while the
         * upstream connect was in flight gets its upstream channel closed instead of being
         * registered (its inactive event has already passed, so nothing would ever remove the
         * table entries again).
         */
        @Override
        public void channelActive(ChannelHandlerContext ctx) {
            final Channel downstreamChannel = ctx.channel();
            final InetSocketAddress downstreamSocketAddress = (InetSocketAddress) downstreamChannel.remoteAddress();
            if (this.this$0.isDuplicateConnection(downstreamSocketAddress)) {
                log.debug("Multiple connections from the same IP address detected: {}, disconnecting...", downstreamSocketAddress.getAddress().getHostAddress());
                downstreamChannel.close();
                this.this$0.connectionRejected.increment();
                return;
            }
            IPAddress downstreamIpAddress = IPAddressUtil.getIPAddress(downstreamSocketAddress.getAddress().getHostAddress());
            if (this.ifBannedAddress(downstreamIpAddress)) {
                log.debug("Decline banned connection from {}:{}", downstreamIpAddress, downstreamSocketAddress.getPort());
                downstreamChannel.close();
                this.this$0.connectionBlocked.increment();
                return;
            }
            this.this$0.connectionHandled.increment();
            log.debug("Accepted new downstream connection: {}", downstreamSocketAddress);

            // The upstream channel shares the client's event loop, so both legs of one
            // connection are always serviced by the same thread.
            final Bootstrap b = new Bootstrap();
            b.group(downstreamChannel.eventLoop())
                    .channel(this.this$0.ioHandler.clientSocketChannelClass())
                    .handler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) {
                            ch.pipeline().addLast(new RelayHandler(ProxyFrontendHandler.this.this$0, downstreamChannel));
                        }
                    })
                    .option(ChannelOption.SO_KEEPALIVE, true);

            Thread.ofVirtual().name("Connection establisher").start(() -> {
                ChannelFuture connectFuture;
                try {
                    connectFuture = this.connectToUpstreamFriendly(b, downstreamSocketAddress);
                } catch (RuntimeException e) {
                    this.onUpstreamConnectFailed(downstreamChannel, e);
                    return;
                }
                connectFuture.addListener((ChannelFutureListener) future -> {
                    if (!future.isSuccess()) {
                        this.onUpstreamConnectFailed(downstreamChannel, future.cause());
                        return;
                    }
                    Channel upstreamChannel = future.channel();
                    if (!downstreamChannel.isActive()) {
                        // The client left while we were connecting; registering it now would
                        // leave entries that no channelInactive callback will ever remove.
                        log.debug("Downstream {} closed before upstream connection was ready, dropping upstream connection", downstreamSocketAddress);
                        upstreamChannel.close();
                        return;
                    }
                    log.debug("Connected to upstream: {}, starting dual directory forward channel", upstreamChannel.remoteAddress());
                    ctx.pipeline().addLast(new RelayHandler(this.this$0, upstreamChannel));
                    downstreamChannel.attr(RelayHandler.RELAY_CHANNEL_KEY).set(upstreamChannel);
                    upstreamChannel.attr(RelayHandler.RELAY_CHANNEL_KEY).set(downstreamChannel);
                    ConnectionStatistics stats = new ConnectionStatistics();
                    stats.setEstablishedAt();
                    InetSocketAddress upstreamLocalSocketAddress = (InetSocketAddress) upstreamChannel.localAddress();
                    this.this$0.connectionMap.put(downstreamSocketAddress, upstreamLocalSocketAddress);
                    stats.setIpGeoData(this.this$0.ipdb.queryIPDB(downstreamSocketAddress.getAddress()).geoData().get());
                    this.this$0.connectionStats.put(downstreamSocketAddress, stats);
                    this.this$0.downstreamChannelMap.put(downstreamSocketAddress, downstreamChannel);
                    downstreamChannel.pipeline().addFirst(new TrafficCounterHandler(stats, true, this.this$0.totalToUpstream, this.this$0.totalToDownstream));
                    upstreamChannel.pipeline().addFirst(new TrafficCounterHandler(stats, false, this.this$0.totalToUpstream, this.this$0.totalToDownstream));
                    // This handler's job is done; from here on the relay handlers take over.
                    ctx.pipeline().remove((ChannelHandler) this);
                    downstreamChannel.config().setAutoRead(true);
                });
            });
        }

        /**
         * Closes the client channel after a failed upstream connect, bumps the failure
         * counter and escalates to an error-level log on every 50th failure.
         */
        private void onUpstreamConnectFailed(Channel downstreamChannel, Throwable cause) {
            log.debug("Error connecting to upstream server", cause);
            downstreamChannel.close();
            this.this$0.connectionFailed.increment();
            if (this.this$0.connectionFailed.sum() % 50L == 0L) {
                log.error(TextManager.tlUI(Lang.AUTOSTUN_TCP_FORWARDER_UNABLE_CONNECT_UPSTREAM, this.this$0.upstreamHost + ":" + this.this$0.upstreamPort, this.this$0.connectionFailed.sum()), cause);
            }
        }

        /**
         * @param downstreamIpAddress the client's IP address
         * @return whether the forwarder's ban list contains that address
         */
        private boolean ifBannedAddress(IPAddress downstreamIpAddress) {
            boolean banned = this.this$0.banList.contains(downstreamIpAddress);
            return banned;
        }

        /**
         * Connects to the upstream server, trying to make the local end of the connection
         * resemble the real client when the upstream is on loopback.
         *
         * <p>When the upstream address is a loopback address, the client is IPv4 and the
         * {@code pbh.TCPForwarder.useFriendlyAddressForLoopback} switch is on (default), the
         * outgoing socket is bound to {@code 127.b.c.d} (the client's last three octets) —
         * first with the client's own port, then with an ephemeral port. Both of those
         * attempts block until they finish. If neither works, or the conditions are not met,
         * a plain connect with no explicit local address is started and its pending future
         * is returned.
         *
         * <p>An upstream host that cannot be resolved yields a {@code null}
         * {@code InetSocketAddress.getAddress()}; that case skips the loopback handling and
         * goes straight to the plain connect instead of failing with a
         * {@code NullPointerException}.
         *
         * @param b                bootstrap prepared by the caller; its local address is modified
         * @param downstreamSocket remote address of the client being forwarded
         * @return the future of the connect attempt that was kept (already completed for a
         *         successful friendly bind, possibly still pending otherwise)
         */
        private ChannelFuture connectToUpstreamFriendly(Bootstrap b, InetSocketAddress downstreamSocket) {
            InetAddress incomingAddress = downstreamSocket.getAddress();
            InetSocketAddress upstreamAddress = new InetSocketAddress(this.this$0.upstreamHost, this.this$0.upstreamPort);
            // getAddress() is null when the host name could not be resolved.
            InetAddress resolvedUpstream = upstreamAddress.getAddress();
            boolean useFriendlyAddress = resolvedUpstream != null
                    && resolvedUpstream.isLoopbackAddress()
                    && incomingAddress instanceof Inet4Address
                    && ExternalSwitch.parseBoolean("pbh.TCPForwarder.useFriendlyAddressForLoopback", true);
            if (useFriendlyAddress) {
                try {
                    byte[] clientOctets = incomingAddress.getAddress();
                    InetAddress outgoingAddress = InetAddress.getByAddress(new byte[]{127, clientOctets[1], clientOctets[2], clientOctets[3]});
                    try {
                        ChannelFuture firstAttempt = b.localAddress(outgoingAddress, downstreamSocket.getPort()).connect(upstreamAddress).syncUninterruptibly();
                        if (firstAttempt.isSuccess()) {
                            return firstAttempt;
                        }
                        throw new IllegalStateException("Failed to bind to friendly address " + outgoingAddress.getHostAddress() + ":" + downstreamSocket.getPort());
                    } catch (Exception e) {
                        log.debug("Failed to bind to friendly address {}:{}, trying random port.", outgoingAddress.getHostAddress(), downstreamSocket.getPort());
                        try {
                            ChannelFuture secondAttempt = b.clone().localAddress(outgoingAddress, 0).connect(upstreamAddress).syncUninterruptibly();
                            if (secondAttempt.isSuccess()) {
                                return secondAttempt;
                            }
                        } catch (Exception e2) {
                            log.debug("Failed to bind to friendly address {} with random port as well.", outgoingAddress.getHostAddress());
                        }
                    }
                } catch (UnknownHostException ignored) {
                    // Only thrown for an illegal array length; the array above always has 4 bytes.
                }
                // Reported only when a friendly bind was really attempted and did not work out.
                log.debug("Failed to bind to friendly address, falling back to default.");
            }
            // Clear any local address left on the bootstrap by the attempts above.
            return b.clone().localAddress((SocketAddress) null).connect(upstreamAddress);
        }

        /**
         * Closes the client channel and counts the connection as failed. Only non-I/O
         * exceptions are logged.
         */
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
            boolean plainIoFailure = cause instanceof IOException;
            if (!plainIoFailure) {
                log.debug("Exception in ProxyFrontendHandler", cause);
            }
            ctx.close();
            this.this$0.connectionFailed.increment();
        }
    }
}

