//
// ========================================================================
// Copyright (c) 1995-2020 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under
// the terms of the Eclipse Public License 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0
//
// This Source Code may also be made available under the following
// Secondary Licenses when the conditions for such availability set
// forth in the Eclipse Public License, v. 2.0 are satisfied:
// the Apache License v2.0 which is available at
// https://www.apache.org/licenses/LICENSE-2.0
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//

package org.eclipse.jetty.websocket.javax.tests.server;

import java.nio.file.Path;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.websocket.CloseReason;
import javax.websocket.ContainerProvider;
import javax.websocket.OnClose;
import javax.websocket.OnError;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import javax.websocket.server.ServerEndpoint;

import org.eclipse.jetty.toolchain.test.MavenTestingUtils;
import org.eclipse.jetty.webapp.WebAppContext;
import org.eclipse.jetty.websocket.javax.common.JavaxWebSocketSession;
import org.eclipse.jetty.websocket.javax.tests.EventSocket;
import org.eclipse.jetty.websocket.javax.tests.WSServer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertTrue;

/**
 * Verifies that every lifecycle callback of a server endpoint deployed inside a webapp
 * (constructor, onOpen, onMessage, onError, onClose) runs with the webapp's classloader
 * installed as the thread context classloader.
 */
public class WebAppClassLoaderTest
{
    /**
     * Server endpoint that records the thread context classloader observed in each callback,
     * keyed by the callback name, and counts down {@link #closeLatch} once onClose has run.
     */
    @ServerEndpoint("/echo")
    public static class MySocket
    {
        public static final CountDownLatch closeLatch = new CountDownLatch(1);
        public static final Map<String, ClassLoader> classLoaders = new ConcurrentHashMap<>();

        public MySocket()
        {
            classLoaders.put("constructor", Thread.currentThread().getContextClassLoader());
        }

        @OnOpen
        public void onOpen(Session session)
        {
            classLoaders.put("onOpen", Thread.currentThread().getContextClassLoader());
        }

        @OnMessage
        public void onMessage(Session session, String msg)
        {
            classLoaders.put("onMessage", Thread.currentThread().getContextClassLoader());
        }

        @OnError
        public void onError(Throwable error)
        {
            classLoaders.put("onError", Thread.currentThread().getContextClassLoader());
        }

        @OnClose
        public void onClose(CloseReason closeReason)
        {
            classLoaders.put("onClose", Thread.currentThread().getContextClassLoader());
            closeLatch.countDown();
        }
    }

    private WSServer server;
    private WebAppContext webapp;

    /**
     * Builds a webapp named "app" containing {@link MySocket}, starts the server and deploys
     * the webapp.
     */
    @BeforeEach
    public void startServer() throws Exception
    {
        Path testdir = MavenTestingUtils.getTargetTestingPath(WebAppClassLoaderTest.class.getName());
        server = new WSServer(testdir, "app");
        server.createWebInf();
        server.copyEndpoint(MySocket.class);
        server.start();
        webapp = server.createWebAppContext();
        server.deployWebapp(webapp);
    }

    /**
     * Stops the server. The null check keeps a failure inside {@link #startServer()} (before
     * {@code server} is assigned) from being masked by a NullPointerException here.
     */
    @AfterEach
    public void stopServer() throws Exception
    {
        if (server != null)
        {
            server.stop();
        }
    }

    /**
     * Reads a static field of the {@link MySocket} class as loaded by the webapp classloader.
     *
     * <p>The endpoint class is copied into the webapp (see {@code copyEndpoint} above), and the
     * server-side instance is resolved through the webapp classloader. The {@code MySocket}
     * class referenced directly by this test may therefore be a different {@code Class}
     * object, so its static state is read reflectively through the webapp classloader.</p>
     *
     * @param fieldName name of a static field declared on {@code MySocket}
     * @return the current value of that field
     * @throws Exception if the class cannot be loaded or the field cannot be read
     */
    private Object getWebAppSocketStaticField(String fieldName) throws Exception
    {
        ClassLoader webAppClassLoader = webapp.getClassLoader();
        Class<?> mySocketClass = webAppClassLoader.loadClass(MySocket.class.getName());
        return mySocketClass.getDeclaredField(fieldName).get(null);
    }

    /**
     * Waits up to 5 seconds for the server-side endpoint's onClose callback to run.
     */
    private void awaitServerClose() throws Exception
    {
        // CountDownLatch is a JDK class, so the cast is valid across classloaders.
        CountDownLatch closeLatch = (CountDownLatch)getWebAppSocketStaticField("closeLatch");
        assertTrue(closeLatch.await(5, TimeUnit.SECONDS), "server endpoint onClose was not called");
    }

    /**
     * Returns the context classloader recorded by the server-side endpoint for the given event.
     *
     * @param event callback name used as the key in {@code MySocket.classLoaders}
     * @return the recorded classloader, or {@code null} if that callback never ran
     */
    @SuppressWarnings("unchecked")
    private ClassLoader getClassLoader(String event) throws Exception
    {
        // The field is declared as Map<String, ClassLoader>; erasure forces an unchecked cast,
        // which is suppressed on this method only instead of using a raw type.
        Map<String, ClassLoader> classLoaderMap = (Map<String, ClassLoader>)getWebAppSocketStaticField("classLoaders");
        return classLoaderMap.get(event);
    }

    /**
     * Connects to the endpoint, sends a text message and then aborts the client session. The
     * test relies on this sequence driving the server endpoint through onMessage, onError and
     * onClose. It then asserts that the classloader recorded for {@code event} is the webapp's
     * classloader.
     */
    @ParameterizedTest
    @ValueSource(strings = {"constructor", "onOpen", "onMessage", "onError", "onClose"})
    public void testForWebAppClassLoader(String event) throws Exception
    {
        WebSocketContainer client = ContainerProvider.getWebSocketContainer();
        EventSocket clientSocket = new EventSocket();
        Session session = client.connectToServer(clientSocket, server.getWsUri().resolve("/app/echo"));
        session.getBasicRemote().sendText("trigger onMessage -> onError -> onClose");
        ((JavaxWebSocketSession)session).abort();
        assertTrue(clientSocket.closeLatch.await(5, TimeUnit.SECONDS), "client socket was not closed");
        awaitServerClose();

        ClassLoader webAppClassLoader = webapp.getClassLoader();
        assertThat(event, getClassLoader(event), is(webAppClassLoader));
    }
}
