Added source code to zxing.org
[zxing.git] / zxingorg / src / com / google / zxing / web / DoSFilter.java
1 /*\r
2  * Copyright 2008 Google Inc.\r
3  *\r
4  * Licensed under the Apache License, Version 2.0 (the "License");\r
5  * you may not use this file except in compliance with the License.\r
6  * You may obtain a copy of the License at\r
7  *\r
8  *      http://www.apache.org/licenses/LICENSE-2.0\r
9  *\r
10  * Unless required by applicable law or agreed to in writing, software\r
11  * distributed under the License is distributed on an "AS IS" BASIS,\r
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r
13  * See the License for the specific language governing permissions and\r
14  * limitations under the License.\r
15  */\r
16 \r
17 package com.google.zxing.web;\r
18 \r
19 import javax.servlet.Filter;\r
20 import javax.servlet.FilterChain;\r
21 import javax.servlet.FilterConfig;\r
22 import javax.servlet.ServletException;\r
23 import javax.servlet.ServletRequest;\r
24 import javax.servlet.ServletResponse;\r
25 import javax.servlet.ServletContext;\r
26 import javax.servlet.http.HttpServletResponse;\r
27 import java.io.IOException;\r
28 import java.net.InetAddress;\r
29 import java.net.UnknownHostException;\r
30 import java.util.Collections;\r
31 import java.util.HashSet;\r
32 import java.util.Set;\r
33 import java.util.Timer;\r
34 import java.util.TimerTask;\r
35 import java.util.Collection;\r
36 \r
37 /**\r
38  * @author Sean Owen\r
39  */\r
40 public final class DoSFilter implements Filter {\r
41 \r
42         private static final int MAX_ACCESSES_PER_IP_PER_TIME = 10;\r
43         private static final long MAX_ACCESS_INTERVAL_MSEC = 10L * 1000L;\r
44         private static final long UNBAN_INTERVAL_MSEC = 60L * 60L * 1000L;\r
45 \r
46         private final IPTrie numRecentAccesses;\r
47         private final Timer timer;\r
48         private final Set<String> bannedIPAddresses;\r
49         private final Collection<String> manuallyBannedIPAddresses;\r
50     private ServletContext context;\r
51 \r
52     public DoSFilter() {\r
53                 numRecentAccesses = new IPTrie();\r
54                 timer = new Timer("DosFilter reset timer");\r
55                 bannedIPAddresses = Collections.synchronizedSet(new HashSet<String>());\r
56                 manuallyBannedIPAddresses = new HashSet<String>();\r
57         }\r
58 \r
59     public void init(FilterConfig filterConfig) {\r
60         context = filterConfig.getServletContext();\r
61         String bannedIPs = filterConfig.getInitParameter("bannedIPs");\r
62             if (bannedIPs != null) {\r
63                     for (String ip : bannedIPs.split(",")) {\r
64                             manuallyBannedIPAddresses.add(ip.trim());\r
65                     }\r
66             }\r
67                 timer.scheduleAtFixedRate(new ResetTask(), 0L, MAX_ACCESS_INTERVAL_MSEC);\r
68             timer.scheduleAtFixedRate(new UnbanTask(), 0L, UNBAN_INTERVAL_MSEC);\r
69     }\r
70 \r
71     public void doFilter(ServletRequest request,\r
72                          ServletResponse response,\r
73                          FilterChain chain) throws IOException, ServletException {\r
74             if (isBanned(request)) {\r
75                     HttpServletResponse servletResponse = (HttpServletResponse) response;\r
76                     servletResponse.sendError(HttpServletResponse.SC_FORBIDDEN);\r
77             } else {\r
78                     chain.doFilter(request, response);\r
79             }\r
80     }\r
81 \r
82         private boolean isBanned(ServletRequest request) {\r
83                 String remoteIPAddressString = request.getRemoteAddr();\r
84                 if (bannedIPAddresses.contains(remoteIPAddressString) ||\r
85                     manuallyBannedIPAddresses.contains(remoteIPAddressString)) {\r
86                         return true;\r
87                 }\r
88                 InetAddress remoteIPAddress;\r
89                 try {\r
90                         remoteIPAddress = InetAddress.getByName(remoteIPAddressString);\r
91                 } catch (UnknownHostException uhe) {\r
92                         context.log("Can't determine host from: " + remoteIPAddressString + "; assuming banned");\r
93                         return true;\r
94                 }\r
95                 if (numRecentAccesses.incrementAndGet(remoteIPAddress) > MAX_ACCESSES_PER_IP_PER_TIME) {\r
96                         context.log("Possible DoS attack from " + remoteIPAddressString);\r
97                         bannedIPAddresses.add(remoteIPAddressString);\r
98                         return true;\r
99                 }\r
100                 return false;\r
101         }\r
102 \r
103         public void destroy() {\r
104             timer.cancel();\r
105         numRecentAccesses.clear();\r
106             bannedIPAddresses.clear();\r
107     }\r
108 \r
109         private final class ResetTask extends TimerTask {\r
110                 @Override\r
111                 public void run() {\r
112                         numRecentAccesses.clear();\r
113                 }\r
114         }\r
115 \r
116         private final class UnbanTask extends TimerTask {\r
117                 @Override\r
118                 public void run() {\r
119                         bannedIPAddresses.clear();\r
120                 }\r
121         }\r
122 \r
123 }