package org.springframework.web.socket.messaging;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;
import jodd.util.StringPool;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.context.SmartLifecycle;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.SubProtocolCapable;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
import org.springframework.web.socket.handler.SessionLimitExceededException;
import org.springframework.web.socket.handler.WebSocketSessionDecorator;
import org.springframework.web.socket.sockjs.transport.session.PollingSockJsSession;
import org.springframework.web.socket.sockjs.transport.session.StreamingSockJsSession;

/* loaded from: input_file:BOOT-INF/lib/spring-websocket-5.2.8.RELEASE.jar:org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.class */
public class SubProtocolWebSocketHandler implements WebSocketHandler, SubProtocolCapable, MessageHandler, SmartLifecycle {
    /** Default value of {@link #timeToFirstMessage}, in milliseconds. */
    private static final int DEFAULT_TIME_TO_FIRST_MESSAGE = 60000;

    /** Channel handed to sub-protocol handlers for messages coming from clients. */
    private final MessageChannel clientInboundChannel;

    /** Channel this handler subscribes to while running. */
    private final SubscribableChannel clientOutboundChannel;

    /** Fallback handler used when a session has no accepted sub-protocol. */
    @Nullable
    private SubProtocolHandler defaultProtocolHandler;

    private final Log logger = LogFactory.getLog(SubProtocolWebSocketHandler.class);

    /** Sub-protocol name (compared case-insensitively) to the handler registered for it. */
    private final Map<String, SubProtocolHandler> protocolHandlerLookup = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);

    /** All registered handlers, in registration order. */
    private final Set<SubProtocolHandler> protocolHandlers = new LinkedHashSet<>();

    /** Currently tracked sessions, keyed by session id. */
    private final Map<String, WebSocketSessionHolder> sessions = new ConcurrentHashMap<>();

    // Passed as-is to the ConcurrentWebSocketSessionDecorator constructor;
    // presumably milliseconds — TODO confirm against the decorator's documentation.
    private int sendTimeLimit = 10 * 1000;

    // Passed as-is to the ConcurrentWebSocketSessionDecorator constructor;
    // presumably bytes — TODO confirm against the decorator's documentation.
    private int sendBufferSizeLimit = 512 * 1024;

    /** Milliseconds a session may exist without a handled message before it is closed. */
    private int timeToFirstMessage = DEFAULT_TIME_TO_FIRST_MESSAGE;

    /** Timestamp of the last completed sweep for sessions that never handled a message. */
    private volatile long lastSessionCheckTime = System.currentTimeMillis();

    /** Acquired with tryLock so that only one thread sweeps sessions at a time. */
    private final ReentrantLock sessionCheckLock = new ReentrantLock();

    private final DefaultStats stats = new DefaultStats();

    private volatile boolean running = false;

    /** Monitor held while starting and stopping. */
    private final Object lifecycleMonitor = new Object();

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:BOOT-INF/lib/spring-websocket-5.2.8.RELEASE.jar:org/springframework/web/socket/messaging/SubProtocolWebSocketHandler$DefaultStats.class */
    /**
     * {@link Stats} implementation backed by atomic counters.
     *
     * <p>Per-transport counters are selected by the concrete type of the
     * underlying session. Sessions are unwrapped before being classified, so a
     * decorated session is counted against the same transport on increment and
     * on decrement.
     */
    public class DefaultStats implements Stats {
        private final AtomicInteger total = new AtomicInteger();
        private final AtomicInteger webSocket = new AtomicInteger();
        private final AtomicInteger httpStreaming = new AtomicInteger();
        private final AtomicInteger httpPolling = new AtomicInteger();
        private final AtomicInteger limitExceeded = new AtomicInteger();
        private final AtomicInteger noMessagesReceived = new AtomicInteger();
        private final AtomicInteger transportError = new AtomicInteger();

        private DefaultStats() {
        }

        @Override // org.springframework.web.socket.messaging.SubProtocolWebSocketHandler.Stats
        public int getTotalSessions() {
            return this.total.get();
        }

        @Override // org.springframework.web.socket.messaging.SubProtocolWebSocketHandler.Stats
        public int getWebSocketSessions() {
            return this.webSocket.get();
        }

        @Override // org.springframework.web.socket.messaging.SubProtocolWebSocketHandler.Stats
        public int getHttpStreamingSessions() {
            return this.httpStreaming.get();
        }

        @Override // org.springframework.web.socket.messaging.SubProtocolWebSocketHandler.Stats
        public int getHttpPollingSessions() {
            return this.httpPolling.get();
        }

        @Override // org.springframework.web.socket.messaging.SubProtocolWebSocketHandler.Stats
        public int getLimitExceededSessions() {
            return this.limitExceeded.get();
        }

        @Override // org.springframework.web.socket.messaging.SubProtocolWebSocketHandler.Stats
        public int getNoMessagesReceivedSessions() {
            return this.noMessagesReceived.get();
        }

        @Override // org.springframework.web.socket.messaging.SubProtocolWebSocketHandler.Stats
        public int getTransportErrorSessions() {
            return this.transportError.get();
        }

        /** Counts a new session against its transport counter and the running total. */
        void incrementSessionCount(WebSocketSession webSocketSession) {
            getCountFor(webSocketSession).incrementAndGet();
            this.total.incrementAndGet();
        }

        /** Removes a session from its transport counter; the running total is not reduced. */
        void decrementSessionCount(WebSocketSession webSocketSession) {
            getCountFor(webSocketSession).decrementAndGet();
        }

        void incrementLimitExceededCount() {
            this.limitExceeded.incrementAndGet();
        }

        void incrementNoMessagesReceivedCount() {
            this.noMessagesReceived.incrementAndGet();
        }

        void incrementTransportError() {
            this.transportError.incrementAndGet();
        }

        /**
         * Selects the per-transport counter for the given session.
         *
         * @param webSocketSession the session, possibly wrapped in decorators
         * @return the polling counter, the streaming counter, or the WebSocket
         * counter for any other session type
         */
        AtomicInteger getCountFor(WebSocketSession webSocketSession) {
            // Classify the innermost session: the enclosing handler increments with
            // the raw session but may decrement with the decorated one, and a
            // decorator never matches the SockJS session types below.
            WebSocketSession target = WebSocketSessionDecorator.unwrap(webSocketSession);
            if (target instanceof PollingSockJsSession) {
                return this.httpPolling;
            }
            if (target instanceof StreamingSockJsSession) {
                return this.httpStreaming;
            }
            return this.webSocket;
        }

        /**
         * Renders all counters on one line. The "closed abnormally" figure is the
         * sum of the send-limit and no-messages-received counters.
         */
        @Override
        public String toString() {
            return SubProtocolWebSocketHandler.this.sessions.size() + " current WS(" + this.webSocket.get() + ")-HttpStream(" + this.httpStreaming.get() + ")-HttpPoll(" + this.httpPolling.get() + "), " + this.total.get() + " total, " + (this.limitExceeded.get() + this.noMessagesReceived.get()) + " closed abnormally (" + this.noMessagesReceived.get() + " connect failure, " + this.limitExceeded.get() + " send limit, " + this.transportError.get() + " transport error)";
        }
    }

    /* loaded from: input_file:BOOT-INF/lib/spring-websocket-5.2.8.RELEASE.jar:org/springframework/web/socket/messaging/SubProtocolWebSocketHandler$Stats.class */
    /**
     * Read-only access to the session counters kept by the enclosing handler.
     */
    public interface Stats {

        /** @return the value of the total-sessions counter */
        int getTotalSessions();

        /** @return the value of the WebSocket sessions counter */
        int getWebSocketSessions();

        /** @return the value of the HTTP streaming sessions counter */
        int getHttpStreamingSessions();

        /** @return the value of the HTTP polling sessions counter */
        int getHttpPollingSessions();

        /** @return the value of the limit-exceeded counter */
        int getLimitExceededSessions();

        /** @return the value of the no-messages-received counter */
        int getNoMessagesReceivedSessions();

        /** @return the value of the transport-error counter */
        int getTransportErrorSessions();
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:BOOT-INF/lib/spring-websocket-5.2.8.RELEASE.jar:org/springframework/web/socket/messaging/SubProtocolWebSocketHandler$WebSocketSessionHolder.class */
    /**
     * Pairs a session with its creation timestamp and a flag recording whether
     * a message has been handled for it.
     */
    public static class WebSocketSessionHolder {
        private final WebSocketSession session;

        /** Wall-clock time, in milliseconds, at which this holder was created. */
        private final long createTime;

        /** Set once via {@link #setHasHandledMessages()}; never reset. */
        private volatile boolean hasHandledMessages;

        public WebSocketSessionHolder(WebSocketSession webSocketSession) {
            this.session = webSocketSession;
            this.createTime = System.currentTimeMillis();
        }

        /** @return the session passed to the constructor */
        public WebSocketSession getSession() {
            return session;
        }

        /** @return the creation timestamp in milliseconds */
        public long getCreateTime() {
            return createTime;
        }

        /** Marks this session as having handled a message. */
        public void setHasHandledMessages() {
            hasHandledMessages = true;
        }

        /** @return whether {@link #setHasHandledMessages()} has been called */
        public boolean hasHandledMessages() {
            return hasHandledMessages;
        }

        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("WebSocketSessionHolder[session=");
            text.append(session);
            text.append(", createTime=").append(createTime);
            text.append(", hasHandledMessages=").append(hasHandledMessages);
            return text.append("]").toString();
        }
    }

    /**
     * Creates a handler bound to the given inbound and outbound channels.
     *
     * @param messageChannel the channel for messages from clients; must not be null
     * @param subscribableChannel the channel to subscribe to on start; must not be null
     */
    public SubProtocolWebSocketHandler(MessageChannel messageChannel, SubscribableChannel subscribableChannel) {
        // Both arguments are validated before either field is assigned.
        Assert.notNull(messageChannel, "Inbound MessageChannel must not be null");
        Assert.notNull(subscribableChannel, "Outbound MessageChannel must not be null");

        this.clientInboundChannel = messageChannel;
        this.clientOutboundChannel = subscribableChannel;
    }

    /**
     * Replaces all registered sub-protocol handlers with the given ones.
     * Existing registrations are cleared first, then each handler is added via
     * {@link #addProtocolHandler(SubProtocolHandler)} in list order.
     *
     * @param list the handlers to register
     */
    public void setProtocolHandlers(List<SubProtocolHandler> list) {
        this.protocolHandlerLookup.clear();
        this.protocolHandlers.clear();
        for (SubProtocolHandler handler : list) {
            addProtocolHandler(handler);
        }
    }

    /**
     * Returns the registered handlers in registration order.
     *
     * @return a new list; modifying it does not affect this handler
     */
    public List<SubProtocolHandler> getProtocolHandlers() {
        return new ArrayList<>(this.protocolHandlers);
    }

    /**
     * Registers a handler for every sub-protocol it reports as supported.
     *
     * <p>A handler that reports no sub-protocols is not registered; an error is
     * logged instead. Registration is all-or-nothing: every protocol is checked
     * for a conflicting mapping before any mapping is written, so a rejected
     * handler leaves the existing mappings untouched.
     *
     * @param subProtocolHandler the handler to register
     * @throws IllegalStateException if one of its protocols is already mapped
     * to a different handler
     */
    public void addProtocolHandler(SubProtocolHandler subProtocolHandler) {
        List<String> supportedProtocols = subProtocolHandler.getSupportedProtocols();
        if (CollectionUtils.isEmpty(supportedProtocols)) {
            if (this.logger.isErrorEnabled()) {
                this.logger.error("No sub-protocols for " + subProtocolHandler);
            }
            return;
        }
        // Validate first so that a conflict cannot overwrite an existing mapping
        // or leave this handler registered for only some of its protocols.
        for (String protocol : supportedProtocols) {
            SubProtocolHandler existing = this.protocolHandlerLookup.get(protocol);
            if (existing != null && existing != subProtocolHandler) {
                throw new IllegalStateException("Cannot map " + subProtocolHandler + " to protocol '" + protocol + "': already mapped to " + existing + ".");
            }
        }
        for (String protocol : supportedProtocols) {
            this.protocolHandlerLookup.put(protocol, subProtocolHandler);
        }
        this.protocolHandlers.add(subProtocolHandler);
    }

    /**
     * Returns the sub-protocol lookup map.
     *
     * @return the internal map itself, not a copy
     */
    public Map<String, SubProtocolHandler> getProtocolHandlerMap() {
        return protocolHandlerLookup;
    }

    /**
     * Sets the handler used when a session has no accepted sub-protocol.
     *
     * <p>If no sub-protocol mappings exist yet, a non-null handler is also
     * registered as the only protocol handler. Passing {@code null} clears the
     * default and never touches the registered handlers.
     *
     * @param subProtocolHandler the default handler, or {@code null} for none
     */
    public void setDefaultProtocolHandler(@Nullable SubProtocolHandler subProtocolHandler) {
        this.defaultProtocolHandler = subProtocolHandler;
        // A null handler must not be forwarded: registration dereferences it.
        if (subProtocolHandler != null && this.protocolHandlerLookup.isEmpty()) {
            setProtocolHandlers(Collections.singletonList(subProtocolHandler));
        }
    }

    /**
     * @return the configured default handler, or {@code null} if none is set
     */
    @Nullable
    public SubProtocolHandler getDefaultProtocolHandler() {
        return defaultProtocolHandler;
    }

    /**
     * Returns the names of all sub-protocols that have a registered handler.
     *
     * @return a new list built from the keys of the lookup map
     */
    @Override // org.springframework.web.socket.SubProtocolCapable
    public List<String> getSubProtocols() {
        return new ArrayList<>(this.protocolHandlerLookup.keySet());
    }

    /**
     * Sets the send time limit handed to the session decorator for sessions
     * decorated afterwards.
     *
     * @param i the limit; presumably in milliseconds — TODO confirm
     */
    public void setSendTimeLimit(int i) {
        sendTimeLimit = i;
    }

    /** @return the configured send time limit */
    public int getSendTimeLimit() {
        return sendTimeLimit;
    }

    /**
     * Sets the send buffer size limit handed to the session decorator for
     * sessions decorated afterwards.
     *
     * @param i the limit; presumably in bytes — TODO confirm
     */
    public void setSendBufferSizeLimit(int i) {
        sendBufferSizeLimit = i;
    }

    /** @return the configured send buffer size limit */
    public int getSendBufferSizeLimit() {
        return sendBufferSizeLimit;
    }

    /**
     * Sets how long a session may exist without a handled message before the
     * periodic session check closes it.
     *
     * @param i the time in milliseconds
     */
    public void setTimeToFirstMessage(int i) {
        timeToFirstMessage = i;
    }

    /** @return the configured time to first message, in milliseconds */
    public int getTimeToFirstMessage() {
        return timeToFirstMessage;
    }

    /** @return the textual form of the session counters */
    public String getStatsInfo() {
        return stats.toString();
    }

    /** @return the live session counters of this handler */
    public Stats getStats() {
        return stats;
    }

    /**
     * Subscribes this handler to the outbound channel and marks it running.
     * Fails with the message "No handlers" when neither a default handler nor
     * any protocol handler is configured.
     */
    @Override // org.springframework.context.Lifecycle
    public final void start() {
        boolean hasHandlers = this.defaultProtocolHandler != null || !this.protocolHandlers.isEmpty();
        Assert.isTrue(hasHandlers, "No handlers");
        synchronized (this.lifecycleMonitor) {
            this.clientOutboundChannel.subscribe(this);
            this.running = true;
        }
    }

    /**
     * Marks this handler as stopped, unsubscribes it from the outbound channel
     * and then closes every tracked session with {@link CloseStatus#GOING_AWAY}.
     * A failure to close one session is logged as a warning and does not stop
     * the remaining sessions from being closed.
     */
    @Override // org.springframework.context.Lifecycle
    public final void stop() {
        synchronized (this.lifecycleMonitor) {
            this.running = false;
            this.clientOutboundChannel.unsubscribe(this);
        }
        // Sessions are closed outside the lifecycle monitor.
        for (WebSocketSessionHolder holder : this.sessions.values()) {
            try {
                holder.getSession().close(CloseStatus.GOING_AWAY);
            } catch (Throwable ex) {
                if (!this.logger.isWarnEnabled()) {
                    continue;
                }
                this.logger.warn("Failed to close '" + holder.getSession() + "': " + ex);
            }
        }
    }

    /**
     * Stops this handler and then invokes the given callback, both while
     * holding the lifecycle monitor.
     *
     * @param runnable the callback to run once {@link #stop()} has returned
     */
    @Override // org.springframework.context.SmartLifecycle
    public final void stop(Runnable runnable) {
        synchronized (lifecycleMonitor) {
            this.stop();
            runnable.run();
        }
    }

    /** @return whether this handler has been started and not yet stopped */
    @Override // org.springframework.context.Lifecycle
    public final boolean isRunning() {
        return running;
    }

    /**
     * Starts tracking a newly connected session. The session is counted,
     * decorated, stored under the decorated session's id, and announced to its
     * sub-protocol handler. A session that is no longer open is ignored.
     */
    @Override // org.springframework.web.socket.WebSocketHandler
    public void afterConnectionEstablished(WebSocketSession webSocketSession) throws Exception {
        if (!webSocketSession.isOpen()) {
            return;
        }
        // Counted with the session as received, before decoration.
        this.stats.incrementSessionCount(webSocketSession);
        WebSocketSession decorated = decorateSession(webSocketSession);
        WebSocketSessionHolder holder = new WebSocketSessionHolder(decorated);
        this.sessions.put(decorated.getId(), holder);
        SubProtocolHandler handler = findProtocolHandler(decorated);
        handler.afterSessionStarted(decorated, this.clientInboundChannel);
    }

    /**
     * Passes a message received from a client to the session's sub-protocol
     * handler. When the session is tracked, the stored (decorated) session is
     * used in place of the one passed in, and the session is flagged as having
     * handled a message once the handler returns. A session check follows.
     */
    @Override // org.springframework.web.socket.WebSocketHandler
    public void handleMessage(WebSocketSession webSocketSession, WebSocketMessage<?> webSocketMessage) throws Exception {
        WebSocketSessionHolder holder = this.sessions.get(webSocketSession.getId());
        WebSocketSession target = (holder != null ? holder.getSession() : webSocketSession);
        SubProtocolHandler handler = findProtocolHandler(target);
        handler.handleMessageFromClient(target, webSocketMessage, this.clientInboundChannel);
        if (holder != null) {
            // Only reached when the handler did not throw.
            holder.setHasHandledMessages();
        }
        checkSessions();
    }

    /**
     * Sends an outbound message to the client session it is addressed to.
     *
     * <p>The message is dropped with an error log if no session id can be
     * resolved from it, and with a debug log if no session is tracked under
     * that id. When the sub-protocol handler reports that a session limit was
     * exceeded, the session is counted, cleared and closed with the status
     * carried by the exception. Any other failure is logged at debug level and
     * not rethrown.
     *
     * @param message the message to send to a client
     */
    @Override // org.springframework.messaging.MessageHandler
    public void handleMessage(Message<?> message) throws MessagingException {
        String sessionId = resolveSessionId(message);
        if (sessionId == null) {
            if (this.logger.isErrorEnabled()) {
                this.logger.error("Could not find session id in " + message);
            }
            return;
        }
        WebSocketSessionHolder holder = this.sessions.get(sessionId);
        if (holder == null) {
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("No session for " + message);
            }
            return;
        }
        WebSocketSession session = holder.getSession();
        try {
            findProtocolHandler(session).handleMessageToClient(session, message);
        } catch (SessionLimitExceededException ex) {
            try {
                if (this.logger.isDebugEnabled()) {
                    // Plain literal: no third-party string-constant class is needed here.
                    this.logger.debug("Terminating '" + session + "'", ex);
                } else if (this.logger.isWarnEnabled()) {
                    this.logger.warn("Terminating '" + session + "': " + ex.getMessage());
                }
                this.stats.incrementLimitExceededCount();
                // Clear before closing so the session is no longer tracked while it closes.
                clearSession(session, ex.getStatus());
                session.close(ex.getStatus());
            } catch (Exception secondEx) {
                if (this.logger.isDebugEnabled()) {
                    this.logger.debug("Failure while closing session " + sessionId + ".", secondEx);
                }
            }
        } catch (Exception ex) {
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("Failed to send message to client in " + session + ": " + message, ex);
            }
        }
    }

    /**
     * Records a transport error in the statistics. Neither the session nor the
     * error is otherwise used here.
     */
    @Override // org.springframework.web.socket.WebSocketHandler
    public void handleTransportError(WebSocketSession webSocketSession, Throwable th) throws Exception {
        stats.incrementTransportError();
    }

    /**
     * Stops tracking a closed session and notifies its sub-protocol handler.
     */
    @Override // org.springframework.web.socket.WebSocketHandler
    public void afterConnectionClosed(WebSocketSession webSocketSession, CloseStatus closeStatus) throws Exception {
        this.clearSession(webSocketSession, closeStatus);
    }

    /** @return always {@code false} */
    @Override // org.springframework.web.socket.WebSocketHandler
    public boolean supportsPartialMessages() {
        return false;
    }

    /**
     * Wraps a session in a {@link ConcurrentWebSocketSessionDecorator} built
     * with the configured send time limit and send buffer size limit.
     *
     * @param webSocketSession the session to wrap
     * @return the decorated session
     */
    protected WebSocketSession decorateSession(WebSocketSession webSocketSession) {
        int timeLimit = getSendTimeLimit();
        int bufferSizeLimit = getSendBufferSizeLimit();
        return new ConcurrentWebSocketSessionDecorator(webSocketSession, timeLimit, bufferSizeLimit);
    }

    /**
     * Picks the sub-protocol handler for a session.
     *
     * <p>Resolution order: the handler mapped to the session's accepted
     * protocol; otherwise the default handler; otherwise the single registered
     * handler. A failure to read the accepted protocol is logged and treated as
     * if no protocol had been accepted.
     *
     * @param webSocketSession the session to find a handler for
     * @return the handler, never {@code null}
     * @throws IllegalStateException if the accepted protocol has no handler, or
     * if no protocol was accepted, no default is set, and the number of
     * registered handlers is not exactly one
     */
    protected final SubProtocolHandler findProtocolHandler(WebSocketSession webSocketSession) {
        String protocol = null;
        try {
            protocol = webSocketSession.getAcceptedProtocol();
        } catch (Exception ex) {
            this.logger.error("Failed to obtain session.getAcceptedProtocol(): will use the default protocol handler (if configured).", ex);
        }

        if (StringUtils.hasLength(protocol)) {
            SubProtocolHandler mapped = this.protocolHandlerLookup.get(protocol);
            if (mapped == null) {
                throw new IllegalStateException("No handler for '" + protocol + "' among " + this.protocolHandlerLookup);
            }
            return mapped;
        }
        if (this.defaultProtocolHandler != null) {
            return this.defaultProtocolHandler;
        }
        if (this.protocolHandlers.size() == 1) {
            return this.protocolHandlers.iterator().next();
        }
        throw new IllegalStateException("Multiple protocol handlers configured and no protocol was negotiated. Consider configuring a default SubProtocolHandler.");
    }

    /**
     * Asks each mapped handler, then the default handler, for the session id
     * of a message and returns the first non-null answer.
     *
     * @param message the message to inspect
     * @return the resolved session id, or {@code null} if no handler resolves one
     */
    @Nullable
    private String resolveSessionId(Message<?> message) {
        for (SubProtocolHandler handler : this.protocolHandlerLookup.values()) {
            String candidate = handler.resolveSessionId(message);
            if (candidate != null) {
                return candidate;
            }
        }
        if (this.defaultProtocolHandler != null) {
            return this.defaultProtocolHandler.resolveSessionId(message);
        }
        return null;
    }

    /**
     * Closes sessions that have existed for at least the time-to-first-message
     * interval without handling a message.
     *
     * <p>The sweep only runs while this handler is running, at most once per
     * time-to-first-message interval, and only on the thread that obtains the
     * check lock without waiting. Each closed session is counted as
     * "no messages received" and closed with
     * {@link CloseStatus#SESSION_NOT_RELIABLE}.
     */
    private void checkSessions() {
        long now = System.currentTimeMillis();
        if (!isRunning() || now - this.lastSessionCheckTime < getTimeToFirstMessage()) {
            return;
        }
        // Another thread is already sweeping: skip rather than wait.
        if (!this.sessionCheckLock.tryLock()) {
            return;
        }
        try {
            for (WebSocketSessionHolder holder : this.sessions.values()) {
                if (holder.hasHandledMessages()) {
                    continue;
                }
                long age = now - holder.getCreateTime();
                if (age < getTimeToFirstMessage()) {
                    continue;
                }
                WebSocketSession session = holder.getSession();
                if (this.logger.isInfoEnabled()) {
                    this.logger.info("No messages received after " + age + " ms. Closing " + session + ".");
                }
                try {
                    this.stats.incrementNoMessagesReceivedCount();
                    session.close(CloseStatus.SESSION_NOT_RELIABLE);
                } catch (Throwable ex) {
                    if (this.logger.isWarnEnabled()) {
                        this.logger.warn("Failed to close unreliable " + session, ex);
                    }
                }
            }
        } finally {
            // The sweep time is recorded even if the loop exits with an exception.
            this.lastSessionCheckTime = now;
            this.sessionCheckLock.unlock();
        }
    }

    /**
     * Stops tracking a session and tells its sub-protocol handler that it has
     * ended. The session count is only decremented when the session was still
     * tracked, so clearing the same session twice decrements once. The handler
     * is notified on every call.
     *
     * @param webSocketSession the session that ended
     * @param closeStatus the status passed on to the sub-protocol handler
     */
    private void clearSession(WebSocketSession webSocketSession, CloseStatus closeStatus) throws Exception {
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("Clearing session " + webSocketSession.getId());
        }
        boolean wasTracked = this.sessions.remove(webSocketSession.getId()) != null;
        if (wasTracked) {
            this.stats.decrementSessionCount(webSocketSession);
        }
        SubProtocolHandler handler = findProtocolHandler(webSocketSession);
        handler.afterSessionEnded(webSocketSession, closeStatus, this.clientInboundChannel);
    }

    /** @return the class name followed by the registered protocol handlers */
    @Override
    public String toString() {
        return "SubProtocolWebSocketHandler" + protocolHandlers;
    }
}
