1
17 package org.apache.tomcat.websocket.server;
18
19 import java.io.IOException;
20 import java.util.Arrays;
21 import java.util.Collections;
22 import java.util.EnumSet;
23 import java.util.Map;
24 import java.util.Set;
25 import java.util.concurrent.ConcurrentHashMap;
26 import java.util.concurrent.ConcurrentSkipListMap;
27
28 import javax.servlet.DispatcherType;
29 import javax.servlet.FilterRegistration;
30 import javax.servlet.ServletContext;
31 import javax.servlet.ServletException;
32 import javax.servlet.http.HttpServletRequest;
33 import javax.servlet.http.HttpServletResponse;
34 import javax.websocket.CloseReason;
35 import javax.websocket.CloseReason.CloseCodes;
36 import javax.websocket.DeploymentException;
37 import javax.websocket.Encoder;
38 import javax.websocket.server.ServerContainer;
39 import javax.websocket.server.ServerEndpoint;
40 import javax.websocket.server.ServerEndpointConfig;
41 import javax.websocket.server.ServerEndpointConfig.Configurator;
42
43 import org.apache.tomcat.InstanceManager;
44 import org.apache.tomcat.util.res.StringManager;
45 import org.apache.tomcat.websocket.WsSession;
46 import org.apache.tomcat.websocket.WsWebSocketContainer;
47 import org.apache.tomcat.websocket.pojo.PojoMethodMapping;
48
49
59 public class WsServerContainer extends WsWebSocketContainer
60 implements ServerContainer {
61
62 private static final StringManager sm = StringManager.getManager(WsServerContainer.class);
63
64 private static final CloseReason AUTHENTICATED_HTTP_SESSION_CLOSED =
65 new CloseReason(CloseCodes.VIOLATED_POLICY,
66 "This connection was established under an authenticated " +
67 "HTTP session that has ended.");
68
69 private final WsWriteTimeout wsWriteTimeout = new WsWriteTimeout();
70
71 private final ServletContext servletContext;
72 private final Map<String,ExactPathMatch> configExactMatchMap = new ConcurrentHashMap<>();
73 private final Map<Integer,ConcurrentSkipListMap<String,TemplatePathMatch>> configTemplateMatchMap =
74 new ConcurrentHashMap<>();
75 private volatile boolean enforceNoAddAfterHandshake =
76 org.apache.tomcat.websocket.Constants.STRICT_SPEC_COMPLIANCE;
77 private volatile boolean addAllowed = true;
78 private final Map<String,Set<WsSession>> authenticatedSessions = new ConcurrentHashMap<>();
79 private volatile boolean endpointsRegistered = false;
80 private volatile boolean deploymentFailed = false;
81
82 WsServerContainer(ServletContext servletContext) {
83
84 this.servletContext = servletContext;
85 setInstanceManager((InstanceManager) servletContext.getAttribute(InstanceManager.class.getName()));
86
87
88 String value = servletContext.getInitParameter(
89 Constants.BINARY_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM);
90 if (value != null) {
91 setDefaultMaxBinaryMessageBufferSize(Integer.parseInt(value));
92 }
93
94 value = servletContext.getInitParameter(
95 Constants.TEXT_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM);
96 if (value != null) {
97 setDefaultMaxTextMessageBufferSize(Integer.parseInt(value));
98 }
99
100 value = servletContext.getInitParameter(
101 Constants.ENFORCE_NO_ADD_AFTER_HANDSHAKE_CONTEXT_INIT_PARAM);
102 if (value != null) {
103 setEnforceNoAddAfterHandshake(Boolean.parseBoolean(value));
104 }
105
106 FilterRegistration.Dynamic fr = servletContext.addFilter(
107 "Tomcat WebSocket (JSR356) Filter", new WsFilter());
108 fr.setAsyncSupported(true);
109
110 EnumSet<DispatcherType> types = EnumSet.of(DispatcherType.REQUEST,
111 DispatcherType.FORWARD);
112
113 fr.addMappingForUrlPatterns(types, true, ");
114 }
115
116
117 /**
118 * Published the provided endpoint implementation at the specified path with
119 * the specified configuration. {@link #WsServerContainer(ServletContext)}
120 * must be called before calling this method.
121 *
122 * @param sec The configuration to use when creating endpoint instances
123 * @throws DeploymentException if the endpoint cannot be published as
124 * requested
125 */
126 @Override
127 public void addEndpoint(ServerEndpointConfig sec) throws DeploymentException {
128 addEndpoint(sec, false);
129 }
130
131
132 void addEndpoint(ServerEndpointConfig sec, boolean fromAnnotatedPojo) throws DeploymentException {
133
134 if (enforceNoAddAfterHandshake && !addAllowed) {
135 throw new DeploymentException(
136 sm.getString("serverContainer.addNotAllowed"));
137 }
138
139 if (servletContext == null) {
140 throw new DeploymentException(
141 sm.getString("serverContainer.servletContextMissing"));
142 }
143
144 if (deploymentFailed) {
145 throw new DeploymentException(sm.getString("serverContainer.failedDeployment",
146 servletContext.getContextPath(), servletContext.getVirtualServerName()));
147 }
148
149 try {
150 String path = sec.getPath();
151
152
153 PojoMethodMapping methodMapping = new PojoMethodMapping(sec.getEndpointClass(),
154 sec.getDecoders(), path);
155 if (methodMapping.getOnClose() != null || methodMapping.getOnOpen() != null
156 || methodMapping.getOnError() != null || methodMapping.hasMessageHandlers()) {
157 sec.getUserProperties().put(org.apache.tomcat.websocket.pojo.Constants.POJO_METHOD_MAPPING_KEY,
158 methodMapping);
159 }
160
161 UriTemplate uriTemplate = new UriTemplate(path);
162 if (uriTemplate.hasParameters()) {
163 Integer key = Integer.valueOf(uriTemplate.getSegmentCount());
164 ConcurrentSkipListMap<String,TemplatePathMatch> templateMatches =
165 configTemplateMatchMap.get(key);
166 if (templateMatches == null) {
167
168
169 templateMatches = new ConcurrentSkipListMap<>();
170 configTemplateMatchMap.putIfAbsent(key, templateMatches);
171 templateMatches = configTemplateMatchMap.get(key);
172 }
173 TemplatePathMatch newMatch = new TemplatePathMatch(sec, uriTemplate, fromAnnotatedPojo);
174 TemplatePathMatch oldMatch = templateMatches.putIfAbsent(uriTemplate.getNormalizedPath(), newMatch);
175 if (oldMatch != null) {
176
177
178 if (oldMatch.isFromAnnotatedPojo() && !newMatch.isFromAnnotatedPojo() &&
179 oldMatch.getConfig().getEndpointClass() == newMatch.getConfig().getEndpointClass()) {
180
181 templateMatches.put(path, oldMatch);
182 } else {
183
184 throw new DeploymentException(
185 sm.getString("serverContainer.duplicatePaths", path,
186 sec.getEndpointClass(),
187 sec.getEndpointClass()));
188 }
189 }
190 } else {
191
192 ExactPathMatch newMatch = new ExactPathMatch(sec, fromAnnotatedPojo);
193 ExactPathMatch oldMatch = configExactMatchMap.put(path, newMatch);
194 if (oldMatch != null) {
195
196
197 if (oldMatch.isFromAnnotatedPojo() && !newMatch.isFromAnnotatedPojo() &&
198 oldMatch.getConfig().getEndpointClass() == newMatch.getConfig().getEndpointClass()) {
199
200 configExactMatchMap.put(path, oldMatch);
201 } else {
202
203 throw new DeploymentException(
204 sm.getString("serverContainer.duplicatePaths", path,
205 oldMatch.getConfig().getEndpointClass(),
206 sec.getEndpointClass()));
207 }
208 }
209 }
210
211 endpointsRegistered = true;
212 } catch (DeploymentException de) {
213 failDeployment();
214 throw de;
215 }
216 }
217
218
219
226 @Override
227 public void addEndpoint(Class<?> pojo) throws DeploymentException {
228 addEndpoint(pojo, false);
229 }
230
231
232 void addEndpoint(Class<?> pojo, boolean fromAnnotatedPojo) throws DeploymentException {
233
234 if (deploymentFailed) {
235 throw new DeploymentException(sm.getString("serverContainer.failedDeployment",
236 servletContext.getContextPath(), servletContext.getVirtualServerName()));
237 }
238
239 ServerEndpointConfig sec;
240
241 try {
242 ServerEndpoint annotation = pojo.getAnnotation(ServerEndpoint.class);
243 if (annotation == null) {
244 throw new DeploymentException(
245 sm.getString("serverContainer.missingAnnotation",
246 pojo.getName()));
247 }
248 String path = annotation.value();
249
250
251 validateEncoders(annotation.encoders());
252
253
254 Class<? extends Configurator> configuratorClazz =
255 annotation.configurator();
256 Configurator configurator = null;
257 if (!configuratorClazz.equals(Configurator.class)) {
258 try {
259 configurator = annotation.configurator().getConstructor().newInstance();
260 } catch (ReflectiveOperationException e) {
261 throw new DeploymentException(sm.getString(
262 "serverContainer.configuratorFail",
263 annotation.configurator().getName(),
264 pojo.getClass().getName()), e);
265 }
266 }
267 sec = ServerEndpointConfig.Builder.create(pojo, path).
268 decoders(Arrays.asList(annotation.decoders())).
269 encoders(Arrays.asList(annotation.encoders())).
270 subprotocols(Arrays.asList(annotation.subprotocols())).
271 configurator(configurator).
272 build();
273 } catch (DeploymentException de) {
274 failDeployment();
275 throw de;
276 }
277
278 addEndpoint(sec, fromAnnotatedPojo);
279 }
280
281
282 void failDeployment() {
283 deploymentFailed = true;
284
285
286 endpointsRegistered = false;
287 configExactMatchMap.clear();
288 configTemplateMatchMap.clear();
289 }
290
291
292 boolean areEndpointsRegistered() {
293 return endpointsRegistered;
294 }
295
296
297
315 public void doUpgrade(HttpServletRequest request,
316 HttpServletResponse response, ServerEndpointConfig sec,
317 Map<String,String> pathParams)
318 throws ServletException, IOException {
319 UpgradeUtil.doUpgrade(this, request, response, sec, pathParams);
320 }
321
322
323 public WsMappingResult findMapping(String path) {
324
325
326
327 if (addAllowed) {
328 addAllowed = false;
329 }
330
331
332 ExactPathMatch match = configExactMatchMap.get(path);
333 if (match != null) {
334 return new WsMappingResult(match.getConfig(), Collections.<String, String>emptyMap());
335 }
336
337
338 UriTemplate pathUriTemplate = null;
339 try {
340 pathUriTemplate = new UriTemplate(path);
341 } catch (DeploymentException e) {
342
343 return null;
344 }
345
346
347 Integer key = Integer.valueOf(pathUriTemplate.getSegmentCount());
348 ConcurrentSkipListMap<String,TemplatePathMatch> templateMatches = configTemplateMatchMap.get(key);
349
350 if (templateMatches == null) {
351
352
353 return null;
354 }
355
356
357
358 ServerEndpointConfig sec = null;
359 Map<String,String> pathParams = null;
360 for (TemplatePathMatch templateMatch : templateMatches.values()) {
361 pathParams = templateMatch.getUriTemplate().match(pathUriTemplate);
362 if (pathParams != null) {
363 sec = templateMatch.getConfig();
364 break;
365 }
366 }
367
368 if (sec == null) {
369
370 return null;
371 }
372
373 return new WsMappingResult(sec, pathParams);
374 }
375
376
377
378 public boolean isEnforceNoAddAfterHandshake() {
379 return enforceNoAddAfterHandshake;
380 }
381
382
383 public void setEnforceNoAddAfterHandshake(
384 boolean enforceNoAddAfterHandshake) {
385 this.enforceNoAddAfterHandshake = enforceNoAddAfterHandshake;
386 }
387
388
389 protected WsWriteTimeout getTimeout() {
390 return wsWriteTimeout;
391 }
392
393
394
399 @Override
400 protected void registerSession(Object key, WsSession wsSession) {
401 super.registerSession(key, wsSession);
402 if (wsSession.isOpen() &&
403 wsSession.getUserPrincipal() != null &&
404 wsSession.getHttpSessionId() != null) {
405 registerAuthenticatedSession(wsSession,
406 wsSession.getHttpSessionId());
407 }
408 }
409
410
411
416 @Override
417 protected void unregisterSession(Object key, WsSession wsSession) {
418 if (wsSession.getUserPrincipal() != null &&
419 wsSession.getHttpSessionId() != null) {
420 unregisterAuthenticatedSession(wsSession,
421 wsSession.getHttpSessionId());
422 }
423 super.unregisterSession(key, wsSession);
424 }
425
426
427 private void registerAuthenticatedSession(WsSession wsSession,
428 String httpSessionId) {
429 Set<WsSession> wsSessions = authenticatedSessions.get(httpSessionId);
430 if (wsSessions == null) {
431 wsSessions = Collections.newSetFromMap(
432 new ConcurrentHashMap<WsSession,Boolean>());
433 authenticatedSessions.putIfAbsent(httpSessionId, wsSessions);
434 wsSessions = authenticatedSessions.get(httpSessionId);
435 }
436 wsSessions.add(wsSession);
437 }
438
439
440 private void unregisterAuthenticatedSession(WsSession wsSession,
441 String httpSessionId) {
442 Set<WsSession> wsSessions = authenticatedSessions.get(httpSessionId);
443
444 if (wsSessions != null) {
445 wsSessions.remove(wsSession);
446 }
447 }
448
449
450 public void closeAuthenticatedSession(String httpSessionId) {
451 Set<WsSession> wsSessions = authenticatedSessions.remove(httpSessionId);
452
453 if (wsSessions != null && !wsSessions.isEmpty()) {
454 for (WsSession wsSession : wsSessions) {
455 try {
456 wsSession.close(AUTHENTICATED_HTTP_SESSION_CLOSED);
457 } catch (IOException e) {
458
459
460 }
461 }
462 }
463 }
464
465
466 private static void validateEncoders(Class<? extends Encoder>[] encoders)
467 throws DeploymentException {
468
469 for (Class<? extends Encoder> encoder : encoders) {
470
471
472 @SuppressWarnings("unused")
473 Encoder instance;
474 try {
475 encoder.getConstructor().newInstance();
476 } catch(ReflectiveOperationException e) {
477 throw new DeploymentException(sm.getString(
478 "serverContainer.encoderFail", encoder.getName()), e);
479 }
480 }
481 }
482
483
484 private static class TemplatePathMatch {
485 private final ServerEndpointConfig config;
486 private final UriTemplate uriTemplate;
487 private final boolean fromAnnotatedPojo;
488
489 public TemplatePathMatch(ServerEndpointConfig config, UriTemplate uriTemplate,
490 boolean fromAnnotatedPojo) {
491 this.config = config;
492 this.uriTemplate = uriTemplate;
493 this.fromAnnotatedPojo = fromAnnotatedPojo;
494 }
495
496
497 public ServerEndpointConfig getConfig() {
498 return config;
499 }
500
501
502 public UriTemplate getUriTemplate() {
503 return uriTemplate;
504 }
505
506
507 public boolean isFromAnnotatedPojo() {
508 return fromAnnotatedPojo;
509 }
510 }
511
512
513 private static class ExactPathMatch {
514 private final ServerEndpointConfig config;
515 private final boolean fromAnnotatedPojo;
516
517 public ExactPathMatch(ServerEndpointConfig config, boolean fromAnnotatedPojo) {
518 this.config = config;
519 this.fromAnnotatedPojo = fromAnnotatedPojo;
520 }
521
522
523 public ServerEndpointConfig getConfig() {
524 return config;
525 }
526
527
528 public boolean isFromAnnotatedPojo() {
529 return fromAnnotatedPojo;
530 }
531 }
532 }
533