1 /*
2  *  Licensed to the Apache Software Foundation (ASF) under one or more
3  *  contributor license agreements.  See the NOTICE file distributed with
4  *  this work for additional information regarding copyright ownership.
5  *  The ASF licenses this file to You under the Apache License, Version 2.0
6  *  (the "License"); you may not use this file except in compliance with
7  *  the License.  You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  */

17 package org.apache.tomcat.websocket.server;
18
19 import java.util.Comparator;
20 import java.util.Set;
21 import java.util.concurrent.ConcurrentSkipListSet;
22 import java.util.concurrent.atomic.AtomicInteger;
23
24 import org.apache.tomcat.websocket.BackgroundProcess;
25 import org.apache.tomcat.websocket.BackgroundProcessManager;
26
27 /**
28  * Provides timeouts for asynchronous web socket writes. On the server side we
29  * only have access to {@link javax.servlet.ServletOutputStream} and
30  * {@link javax.servlet.ServletInputStream} so there is no way to set a timeout
31  * for writes to the client.
32  */

33 public class WsWriteTimeout implements BackgroundProcess {
34
35     private final Set<WsRemoteEndpointImplServer> endpoints =
36             new ConcurrentSkipListSet<>(new EndpointComparator());
37     private final AtomicInteger count = new AtomicInteger(0);
38     private int backgroundProcessCount = 0;
39     private volatile int processPeriod = 1;
40
41     @Override
42     public void backgroundProcess() {
43         // This method gets called once a second.
44         backgroundProcessCount ++;
45
46         if (backgroundProcessCount >= processPeriod) {
47             backgroundProcessCount = 0;
48
49             long now = System.currentTimeMillis();
50             for (WsRemoteEndpointImplServer endpoint : endpoints) {
51                 if (endpoint.getTimeoutExpiry() < now) {
52                     // Background thread, not the thread that triggered the
53                     // write so no need to use a dispatch
54                     endpoint.onTimeout(false);
55                 } else {
56                     // Endpoints are ordered by timeout expiry so if this point
57                     // is reached there is no need to check the remaining
58                     // endpoints
59                     break;
60                 }
61             }
62         }
63     }
64
65
66     @Override
67     public void setProcessPeriod(int period) {
68         this.processPeriod = period;
69     }
70
71
72     /**
73      * {@inheritDoc}
74      *
75      * The default value is 1 which means asynchronous write timeouts are
76      * processed every 1 second.
77      */

78     @Override
79     public int getProcessPeriod() {
80         return processPeriod;
81     }
82
83
84     public void register(WsRemoteEndpointImplServer endpoint) {
85         boolean result = endpoints.add(endpoint);
86         if (result) {
87             int newCount = count.incrementAndGet();
88             if (newCount == 1) {
89                 BackgroundProcessManager.getInstance().register(this);
90             }
91         }
92     }
93
94
95     public void unregister(WsRemoteEndpointImplServer endpoint) {
96         boolean result = endpoints.remove(endpoint);
97         if (result) {
98             int newCount = count.decrementAndGet();
99             if (newCount == 0) {
100                 BackgroundProcessManager.getInstance().unregister(this);
101             }
102         }
103     }
104
105
106     /**
107      * Note: this comparator imposes orderings that are inconsistent with equals
108      */

109     private static class EndpointComparator implements
110             Comparator<WsRemoteEndpointImplServer> {
111
112         @Override
113         public int compare(WsRemoteEndpointImplServer o1,
114                 WsRemoteEndpointImplServer o2) {
115
116             long t1 = o1.getTimeoutExpiry();
117             long t2 = o2.getTimeoutExpiry();
118
119             if (t1 < t2) {
120                 return -1;
121             } else if (t1 == t2) {
122                 return 0;
123             } else {
124                 return 1;
125             }
126         }
127     }
128 }
129