1 /*
2  *  Licensed to the Apache Software Foundation (ASF) under one or more
3  *  contributor license agreements.  See the NOTICE file distributed with
4  *  this work for additional information regarding copyright ownership.
5  *  The ASF licenses this file to You under the Apache License, Version 2.0
6  *  (the "License"); you may not use this file except in compliance with
7  *  the License.  You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  */

17 package org.apache.tomcat.websocket.server;
18
19 import java.lang.reflect.Modifier;
20 import java.util.HashSet;
21 import java.util.Set;
22
23 import javax.servlet.ServletContainerInitializer;
24 import javax.servlet.ServletContext;
25 import javax.servlet.ServletException;
26 import javax.servlet.annotation.HandlesTypes;
27 import javax.websocket.ContainerProvider;
28 import javax.websocket.DeploymentException;
29 import javax.websocket.Endpoint;
30 import javax.websocket.server.ServerApplicationConfig;
31 import javax.websocket.server.ServerEndpoint;
32 import javax.websocket.server.ServerEndpointConfig;
33
34 import org.apache.tomcat.util.compat.JreCompat;
35
36 /**
37  * Registers an interest in any class that is annotated with
38  * {@link ServerEndpoint} so that Endpoint can be published via the WebSocket
39  * server.
40  */

41 @HandlesTypes({ServerEndpoint.class, ServerApplicationConfig.class,
42         Endpoint.class})
43 public class WsSci implements ServletContainerInitializer {
44
45     @Override
46     public void onStartup(Set<Class<?>> clazzes, ServletContext ctx)
47             throws ServletException {
48
49         WsServerContainer sc = init(ctx, true);
50
51         if (clazzes == null || clazzes.size() == 0) {
52             return;
53         }
54
55         // Group the discovered classes by type
56         Set<ServerApplicationConfig> serverApplicationConfigs = new HashSet<>();
57         Set<Class<? extends Endpoint>> scannedEndpointClazzes = new HashSet<>();
58         Set<Class<?>> scannedPojoEndpoints = new HashSet<>();
59
60         try {
61             // wsPackage is "javax.websocket."
62             String wsPackage = ContainerProvider.class.getName();
63             wsPackage = wsPackage.substring(0, wsPackage.lastIndexOf('.') + 1);
64             for (Class<?> clazz : clazzes) {
65                 JreCompat jreCompat = JreCompat.getInstance();
66                 int modifiers = clazz.getModifiers();
67                 if (!Modifier.isPublic(modifiers) ||
68                         Modifier.isAbstract(modifiers) ||
69                         Modifier.isInterface(modifiers) ||
70                         !jreCompat.isExported(clazz)) {
71                     // Non-publicabstractinterface or not in an exported
72                     // package (Java 9+) - skip it.
73                     continue;
74                 }
75                 // Protect against scanning the WebSocket API JARs
76                 if (clazz.getName().startsWith(wsPackage)) {
77                     continue;
78                 }
79                 if (ServerApplicationConfig.class.isAssignableFrom(clazz)) {
80                     serverApplicationConfigs.add(
81                             (ServerApplicationConfig) clazz.getConstructor().newInstance());
82                 }
83                 if (Endpoint.class.isAssignableFrom(clazz)) {
84                     @SuppressWarnings("unchecked")
85                     Class<? extends Endpoint> endpoint =
86                             (Class<? extends Endpoint>) clazz;
87                     scannedEndpointClazzes.add(endpoint);
88                 }
89                 if (clazz.isAnnotationPresent(ServerEndpoint.class)) {
90                     scannedPojoEndpoints.add(clazz);
91                 }
92             }
93         } catch (ReflectiveOperationException e) {
94             throw new ServletException(e);
95         }
96
97         // Filter the results
98         Set<ServerEndpointConfig> filteredEndpointConfigs = new HashSet<>();
99         Set<Class<?>> filteredPojoEndpoints = new HashSet<>();
100
101         if (serverApplicationConfigs.isEmpty()) {
102             filteredPojoEndpoints.addAll(scannedPojoEndpoints);
103         } else {
104             for (ServerApplicationConfig config : serverApplicationConfigs) {
105                 Set<ServerEndpointConfig> configFilteredEndpoints =
106                         config.getEndpointConfigs(scannedEndpointClazzes);
107                 if (configFilteredEndpoints != null) {
108                     filteredEndpointConfigs.addAll(configFilteredEndpoints);
109                 }
110                 Set<Class<?>> configFilteredPojos =
111                         config.getAnnotatedEndpointClasses(
112                                 scannedPojoEndpoints);
113                 if (configFilteredPojos != null) {
114                     filteredPojoEndpoints.addAll(configFilteredPojos);
115                 }
116             }
117         }
118
119         try {
120             // Deploy endpoints
121             for (ServerEndpointConfig config : filteredEndpointConfigs) {
122                 sc.addEndpoint(config);
123             }
124             // Deploy POJOs
125             for (Class<?> clazz : filteredPojoEndpoints) {
126                 sc.addEndpoint(clazz, true);
127             }
128         } catch (DeploymentException e) {
129             throw new ServletException(e);
130         }
131     }
132
133
134     static WsServerContainer init(ServletContext servletContext,
135             boolean initBySciMechanism) {
136
137         WsServerContainer sc = new WsServerContainer(servletContext);
138
139         servletContext.setAttribute(
140                 Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE, sc);
141
142         servletContext.addListener(new WsSessionListener(sc));
143         // Can't register the ContextListener again if the ContextListener is
144         // calling this method
145         if (initBySciMechanism) {
146             servletContext.addListener(new WsContextListener());
147         }
148
149         return sc;
150     }
151 }
152