177fc2bc440801be44225b507ed35eab8bd024f8
[zxing.git] / zxingorg / src / com / google / zxing / web / DoSFilter.java
1 /*\r
2  * Copyright 2008 ZXing authors\r
3  *\r
4  * Licensed under the Apache License, Version 2.0 (the "License");\r
5  * you may not use this file except in compliance with the License.\r
6  * You may obtain a copy of the License at\r
7  *\r
8  *      http://www.apache.org/licenses/LICENSE-2.0\r
9  *\r
10  * Unless required by applicable law or agreed to in writing, software\r
11  * distributed under the License is distributed on an "AS IS" BASIS,\r
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r
13  * See the License for the specific language governing permissions and\r
14  * limitations under the License.\r
15  */\r
16 \r
17 package com.google.zxing.web;\r
18 \r
19 import javax.servlet.Filter;\r
20 import javax.servlet.FilterChain;\r
21 import javax.servlet.FilterConfig;\r
22 import javax.servlet.ServletContext;\r
23 import javax.servlet.ServletException;\r
24 import javax.servlet.ServletRequest;\r
25 import javax.servlet.ServletResponse;\r
26 import javax.servlet.http.HttpServletResponse;\r
27 import java.io.IOException;\r
28 import java.net.InetAddress;\r
29 import java.net.UnknownHostException;\r
30 import java.util.Collection;\r
31 import java.util.Collections;\r
32 import java.util.HashSet;\r
33 import java.util.Set;\r
34 import java.util.Timer;\r
35 import java.util.TimerTask;\r
36 import java.util.regex.Pattern;\r
37 \r
38 /**\r
39  * A {@link Filter} that rejects requests from hosts that are sending too many\r
40  * requests in too short a time.\r
41  * \r
42  * @author Sean Owen\r
43  */\r
44 public final class DoSFilter implements Filter {\r
45 \r
46   private static final int MAX_ACCESSES_PER_IP_PER_TIME = 10;\r
47   private static final long MAX_ACCESS_INTERVAL_MSEC = 10L * 1000L;\r
48   private static final long UNBAN_INTERVAL_MSEC = 60L * 60L * 1000L;\r
49   private static final Pattern COMMA_PATTERN = Pattern.compile(",");\r
50 \r
51   private final IPTrie numRecentAccesses;\r
52   private final Timer timer;\r
53   private final Set<String> bannedIPAddresses;\r
54   private final Collection<String> manuallyBannedIPAddresses;\r
55   private ServletContext context;\r
56 \r
57   public DoSFilter() {\r
58     numRecentAccesses = new IPTrie();\r
59     timer = new Timer("DosFilter reset timer");\r
60     bannedIPAddresses = Collections.synchronizedSet(new HashSet<String>());\r
61     manuallyBannedIPAddresses = new HashSet<String>();\r
62   }\r
63 \r
64   public void init(FilterConfig filterConfig) {\r
65     context = filterConfig.getServletContext();\r
66     String bannedIPs = filterConfig.getInitParameter("bannedIPs");\r
67     if (bannedIPs != null) {\r
68       for (String ip : COMMA_PATTERN.split(bannedIPs)) {\r
69         manuallyBannedIPAddresses.add(ip.trim());\r
70       }\r
71     }\r
72     timer.scheduleAtFixedRate(new ResetTask(), 0L, MAX_ACCESS_INTERVAL_MSEC);\r
73     timer.scheduleAtFixedRate(new UnbanTask(), 0L, UNBAN_INTERVAL_MSEC);\r
74   }\r
75 \r
76   public void doFilter(ServletRequest request,\r
77                        ServletResponse response,\r
78                        FilterChain chain) throws IOException, ServletException {\r
79     if (isBanned(request)) {\r
80       HttpServletResponse servletResponse = (HttpServletResponse) response;\r
81       servletResponse.sendError(HttpServletResponse.SC_FORBIDDEN);\r
82     } else {\r
83       chain.doFilter(request, response);\r
84     }\r
85   }\r
86 \r
87   private boolean isBanned(ServletRequest request) {\r
88     String remoteIPAddressString = request.getRemoteAddr();\r
89     if (bannedIPAddresses.contains(remoteIPAddressString) ||\r
90             manuallyBannedIPAddresses.contains(remoteIPAddressString)) {\r
91       return true;\r
92     }\r
93     InetAddress remoteIPAddress;\r
94     try {\r
95       remoteIPAddress = InetAddress.getByName(remoteIPAddressString);\r
96     } catch (UnknownHostException uhe) {\r
97       context.log("Can't determine host from: " + remoteIPAddressString + "; assuming banned");\r
98       return true;\r
99     }\r
100     if (numRecentAccesses.incrementAndGet(remoteIPAddress) > MAX_ACCESSES_PER_IP_PER_TIME) {\r
101       context.log("Possible DoS attack from " + remoteIPAddressString);\r
102       bannedIPAddresses.add(remoteIPAddressString);\r
103       return true;\r
104     }\r
105     return false;\r
106   }\r
107 \r
108   public void destroy() {\r
109     timer.cancel();\r
110     numRecentAccesses.clear();\r
111     bannedIPAddresses.clear();\r
112   }\r
113 \r
114   private final class ResetTask extends TimerTask {\r
115     @Override\r
116     public void run() {\r
117       numRecentAccesses.clear();\r
118     }\r
119   }\r
120 \r
121   private final class UnbanTask extends TimerTask {\r
122     @Override\r
123     public void run() {\r
124       bannedIPAddresses.clear();\r
125     }\r
126   }\r
127 \r
128 }