001/**
002 * Copyright (c) 2015-2022, Michael Yang 杨福海 (fuhai999@gmail.com).
003 * <p>
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 * <p>
008 * http://www.apache.org/licenses/LICENSE-2.0
009 * <p>
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package io.jboot.test.web;
017
018import com.google.common.collect.LinkedHashMultimap;
019import com.google.common.collect.Multimap;
020import com.jfinal.kit.LogKit;
021import io.jboot.utils.StrUtil;
022
023import javax.servlet.*;
024import javax.servlet.http.*;
025import java.io.BufferedReader;
026import java.io.IOException;
027import java.io.InputStreamReader;
028import java.security.Principal;
029import java.util.*;
030
031public class MockHttpServletRequest implements HttpServletRequest {
032
033    protected String contextPath;
034    protected String method = "GET";
035    protected String pathInfo;
036    protected String pathTranslated;
037    protected String queryString = "";
038    protected String requestURI;
039    protected String servletPath;
040    protected String characterEncoding = "UTF-8";
041    protected String protocol = "HTTP/1.1";
042
043    private String remoteAddr = "127.0.0.1";
044    private String remoteHost = "localhost";
045    private int remotePort = 80;
046    private String localName = "localhost";
047    private String localAddr = "127.0.0.1";
048    private int localPort = 80;
049
050    protected String remoteUser;
051    protected String authType;
052    protected Principal userPrincipal;
053
054    protected StringBuffer requestURL;
055    protected HttpSession session;
056    protected ServletInputStream inputStream;
057
058    protected byte[] content;
059
060
061    protected ServletContext servletContext = MockServletContext.DEFAULT;
062    protected HttpServletResponse response;
063
064    protected Map<String, String> headers = new HashMap<>();
065    protected Map<String, Object> attributeMap = new HashMap<>();
066    protected Map<String, String[]> parameters = new HashMap<>();
067    protected Set<Cookie> cookies = new HashSet<>();
068    protected LinkedList<Locale> locales = new LinkedList<>();
069
070    protected boolean requestedSessionIdValid = true;
071
072    protected boolean requestedSessionIdFromCookie = true;
073
074    protected boolean requestedSessionIdFromURL = false;
075
076    protected final Set<String> userRoles = new HashSet<>();
077    protected final Multimap<String, Part> parts = LinkedHashMultimap.create();
078
079
080    public MockHttpServletRequest() {
081//        super(MockProxy.create(HttpServletRequest.class));
082    }
083
084    @Override
085    public String getContextPath() {
086        if (contextPath == null) {
087            contextPath = servletContext.getContextPath();
088        }
089        return contextPath;
090    }
091
092    public void setContextPath(String contextPath) {
093        this.contextPath = contextPath;
094    }
095
096
097    @Override
098    public String getHeader(String name) {
099        return headers.get(name.toLowerCase());
100    }
101
102    public void setHeaders(Map<String, String> headers) {
103        if (headers != null) {
104            headers.forEach((s, s2) -> MockHttpServletRequest.this.headers.put(s.toLowerCase(), s2));
105        }
106    }
107
108    public void addHeader(String name, Object value) {
109        headers.put(name.toLowerCase(), value.toString());
110    }
111
112    @Override
113    public String getMethod() {
114        return method;
115    }
116
117    public void setMethod(String method) {
118        this.method = method;
119    }
120
121
122    @Override
123    public String getPathInfo() {
124        return pathInfo;
125    }
126
127    public void setPathInfo(String pathInfo) {
128        this.pathInfo = pathInfo;
129    }
130
131
132    @Override
133    public String getPathTranslated() {
134        return (this.pathInfo != null ? getRealPath(this.pathInfo) : null);
135    }
136
137
138    @Override
139    public String getQueryString() {
140        return queryString;
141    }
142
143    public void setQueryString(String queryString) {
144        this.queryString = queryString;
145    }
146
147
148    @Override
149    public String getRemoteUser() {
150        return remoteUser;
151    }
152
153
154    public void addUserRole(String role) {
155        this.userRoles.add(role);
156    }
157
158    @Override
159    public boolean isUserInRole(String role) {
160        return userRoles.contains(role);
161    }
162
163    public void setRemoteUser(String remoteUser) {
164        this.remoteUser = remoteUser;
165    }
166
167
168    @Override
169    public String getAuthType() {
170        return authType;
171    }
172
173    public void setAuthType(String authType) {
174        this.authType = authType;
175    }
176
177    @Override
178    public String getRequestURI() {
179        return requestURI;
180    }
181
182    public void setRequestURI(String requestURI) {
183        this.requestURI = requestURI;
184    }
185
186
187    @Override
188    public StringBuffer getRequestURL() {
189        return requestURL;
190    }
191
192    public void setRequestURL(StringBuffer requestURL) {
193        this.requestURL = requestURL;
194    }
195
196    @Override
197    public String getRequestedSessionId() {
198        if (session != null) {
199            return session.getId();
200        }
201        return null;
202    }
203
204
205    @Override
206    public String getServletPath() {
207        return servletPath;
208    }
209
210    public void setServletPath(String servletPath) {
211        this.servletPath = servletPath;
212        if (requestURI == null) {
213            this.requestURI = getContextPath() + servletPath;
214        }
215    }
216
217
218    @Override
219    public HttpSession getSession() {
220        return getSession(true);
221    }
222
223
224    @Override
225    public HttpSession getSession(boolean create) {
226        if (session != null) {
227            return session;
228        }
229
230        String sessionId = getCookieValue("jsessionId");
231        if (sessionId != null) {
232            session = new MockHttpSession(sessionId, getServletContext());
233            session.setMaxInactiveInterval(60 * 60);
234        } else if (create) {
235            sessionId = UUID.randomUUID().toString().replace("-", "");
236            session = new MockHttpSession(sessionId, getServletContext());
237            session.setMaxInactiveInterval(60 * 60);
238            setCookie("jsessionId", sessionId, -1);
239        }
240        return session;
241    }
242
243
244    @Override
245    public String changeSessionId() {
246        String sessionId = UUID.randomUUID().toString().replace("-", "");
247        session = new MockHttpSession(sessionId, getServletContext());
248        session.setMaxInactiveInterval(60 * 60);
249        setCookie("jsessionId", sessionId, -1);
250        return sessionId;
251    }
252
253    /**
254     * Get cookie value by cookie name.
255     */
256    private String getCookieValue(String name) {
257        Cookie cookie = getCookieObject(name);
258        return cookie != null ? cookie.getValue() : null;
259    }
260
261    /**
262     * Get cookie object by cookie name.
263     */
264    private Cookie getCookieObject(String name) {
265        for (Cookie cookie : cookies) {
266            if (cookie.getName().equals(name)) {
267                return cookie;
268            }
269        }
270        return null;
271    }
272
273    /**
274     * @param name
275     * @param value
276     * @param maxAgeInSeconds
277     */
278    private void setCookie(String name, String value, int maxAgeInSeconds) {
279        Cookie cookie = new Cookie(name, value);
280        cookie.setMaxAge(maxAgeInSeconds);
281        response.addCookie(cookie);
282    }
283
284
285    @Override
286    public Principal getUserPrincipal() {
287        return userPrincipal;
288    }
289
290    public void setUserPrincipal(Principal userPrincipal) {
291        this.userPrincipal = userPrincipal;
292    }
293
294
295    @Override
296    public Object getAttribute(String key) {
297        return attributeMap.get(key);
298    }
299
300    @Override
301    public Enumeration<String> getAttributeNames() {
302        return Collections.enumeration(attributeMap.keySet());
303    }
304
305
306    @Override
307    public String getCharacterEncoding() {
308        return characterEncoding;
309    }
310
311    @Override
312    public int getContentLength() {
313        String cl = this.getHeader("content-length");
314        if (cl != null) {
315            try {
316                return Integer.parseInt(cl);
317            } catch (NumberFormatException e) {
318                return 0;
319            }
320        }
321
322        if (inputStream != null) {
323            try {
324                return inputStream.available();
325            } catch (IOException e) {
326                return 0;
327            }
328        }
329
330        return 0;
331    }
332
333    @Override
334    public String getContentType() {
335        return this.getHeader("content-type");
336    }
337
338
339    @Override
340    public ServletInputStream getInputStream() throws IOException {
341        if (inputStream == null) {
342            inputStream = new MockServletInputStream("");
343        }
344        return inputStream;
345    }
346
347    public void setInputStream(ServletInputStream ins) {
348        this.inputStream = ins;
349    }
350
351
352    public void setContent(byte[] content) {
353        this.content = content;
354    }
355
356
357    @Override
358    public long getContentLengthLong() {
359        return (this.content != null ? this.content.length : -1);
360    }
361
362
363    @Override
364    public String getParameter(String key) {
365        if (parameters.containsKey(key)) {
366            return parameters.get(key)[0];
367        }
368        return null;
369    }
370
371    public void addParameter(String key, Number num) {
372        addParameter(key, num.toString());
373    }
374
375    public void addParameter(String key, String value) {
376        addParameter(key, new String[]{value});
377    }
378
379    public void addParameter(String key, Object value) {
380        addParameter(key, new String[]{String.valueOf(value)});
381    }
382
383
384    public void addParameter(String key, String[] values) {
385        parameters.put(key, values);
386
387        if ("GET".equalsIgnoreCase(getMethod())) {
388            updateQueryString();
389        }
390    }
391
392    public void addQueryParameter(String key, Object value) {
393
394        if ("GET".equalsIgnoreCase(getMethod())) {
395            parameters.put(key, new String[]{String.valueOf(value)});
396        }
397
398        Map queryStringMap = StrUtil.isNotBlank(queryString) ? StrUtil.queryStringToMap(this.queryString) : new HashMap();
399        queryStringMap.put(key, value);
400        setQueryString(StrUtil.mapToQueryString(queryStringMap));
401    }
402
403    private void updateQueryString() {
404        StringBuilder sb = new StringBuilder();
405        for (String key : parameters.keySet()) {
406            if (key == null || key.length() == 0) {
407                continue;
408            }
409            if (sb.length() > 0) {
410                sb.append("&");
411            }
412
413            sb.append(key.trim()).append("=");
414            String[] values = parameters.get(key);
415            if (values == null || values.length == 0) {
416                continue;
417            }
418
419            if (values.length == 1) {
420                sb.append(StrUtil.urlEncode(values[0]));
421            } else {
422                for (int i = 0; i < values.length; i++) {
423                    if (i == 0) {
424                        sb.append(StrUtil.urlEncode(values[i]));
425                    } else {
426                        if (sb.length() > 0) {
427                            sb.append("&");
428                        }
429                        sb.append(key.trim()).append("=").append(StrUtil.urlEncode(values[i]));
430                    }
431                }
432            }
433        }
434
435        setQueryString(sb.toString());
436    }
437
438    @Override
439    public Map<String, String[]> getParameterMap() {
440        return parameters;
441    }
442
443    @Override
444    public Enumeration<String> getParameterNames() {
445        return Collections.enumeration(parameters.keySet());
446    }
447
448    @Override
449    public String[] getParameterValues(String name) {
450        return parameters.get(name);
451    }
452
453
454    @Override
455    public String getProtocol() {
456        return protocol;
457    }
458
459    public void setProtocol(String protocol) {
460        this.protocol = protocol;
461    }
462
463    @Override
464    public void removeAttribute(String key) {
465        attributeMap.remove(key);
466    }
467
468    @Override
469    public void setAttribute(String key, Object value) {
470        attributeMap.put(key, value);
471    }
472
473    @Override
474    public Cookie[] getCookies() {
475        return cookies.toArray(new Cookie[cookies.size()]);
476    }
477
478    public void setCookies(Set<Cookie> cookies) {
479        this.cookies = cookies;
480    }
481
482    @Override
483    public Enumeration<Locale> getLocales() {
484        return Collections.enumeration(locales);
485    }
486
487    public void setLocales(LinkedList<Locale> locales) {
488        this.locales = locales;
489    }
490
491    @Override
492    public void setCharacterEncoding(String characterEncoding) {
493        this.characterEncoding = characterEncoding;
494    }
495
496    @Override
497    public ServletContext getServletContext() {
498        return servletContext;
499    }
500
501    @Override
502    public AsyncContext startAsync() throws IllegalStateException {
503        return null;
504    }
505
506    @Override
507    public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) throws IllegalStateException {
508        return null;
509    }
510
511    @Override
512    public boolean isAsyncStarted() {
513        return false;
514    }
515
516    @Override
517    public String getServerName() {
518        return "localhost";
519    }
520
521    @Override
522    public int getServerPort() {
523        return 80;
524    }
525
526    @Override
527    public BufferedReader getReader() throws IOException {
528        return new BufferedReader(new InputStreamReader(getInputStream()));
529    }
530
531
532    public void setRemotePort(int remotePort) {
533        this.remotePort = remotePort;
534    }
535
536    @Override
537    public int getRemotePort() {
538        return remotePort;
539    }
540
541    public void setRemoteAddr(String remoteAddr) {
542        this.remoteAddr = remoteAddr;
543    }
544
545    @Override
546    public String getRemoteAddr() {
547        return remoteAddr;
548    }
549
550    public void setRemoteHost(String remoteHost) {
551        this.remoteHost = remoteHost;
552    }
553
554    @Override
555    public String getRemoteHost() {
556        return remoteHost;
557    }
558
559    @Override
560    public boolean isRequestedSessionIdFromURL() {
561        return requestedSessionIdFromURL;
562    }
563
564    @Override
565    public boolean isRequestedSessionIdFromUrl() {
566        return requestedSessionIdFromURL;
567    }
568
569    public void setRequestedSessionIdFromURL(boolean requestedSessionIdFromURL) {
570        this.requestedSessionIdFromURL = requestedSessionIdFromURL;
571    }
572
573    @Override
574    public boolean authenticate(HttpServletResponse response) throws IOException, ServletException {
575        return false;
576    }
577
578    @Override
579    public void login(String username, String password) throws ServletException {
580        LogKit.error("Unsupport login method!");
581    }
582
583    public void setRequestedSessionIdFromCookie(boolean requestedSessionIdFromCookie) {
584        this.requestedSessionIdFromCookie = requestedSessionIdFromCookie;
585    }
586
587    @Override
588    public boolean isRequestedSessionIdFromCookie() {
589        return requestedSessionIdFromCookie;
590    }
591
592    public void setRequestedSessionIdValid(boolean requestedSessionIdValid) {
593        this.requestedSessionIdValid = requestedSessionIdValid;
594    }
595
596    @Override
597    public boolean isRequestedSessionIdValid() {
598        return requestedSessionIdValid;
599    }
600
601    @Override
602    public String getScheme() {
603        return "http";
604    }
605
606    @Override
607    public Locale getLocale() {
608        return this.locales.getFirst();
609    }
610
611    public void setLocalPort(int localPort) {
612        this.localPort = localPort;
613    }
614
615    @Override
616    public int getLocalPort() {
617        return localPort;
618    }
619
620    public void setLocalAddr(String localAddr) {
621        this.localAddr = localAddr;
622    }
623
624    @Override
625    public String getLocalAddr() {
626        return localAddr;
627    }
628
629    public void setLocalName(String localName) {
630        this.localName = localName;
631    }
632
633    @Override
634    public String getLocalName() {
635        return localName;
636    }
637
638    @Override
639    public boolean isAsyncSupported() {
640        return false;
641    }
642
643    @Override
644    public AsyncContext getAsyncContext() {
645        return null;
646    }
647
648    @Override
649    public boolean isSecure() {
650        return false;
651    }
652
653    @Override
654    public RequestDispatcher getRequestDispatcher(String path) {
655        return null;
656    }
657
658    @Override
659    public String getRealPath(String path) {
660        return this.servletContext.getRealPath(path);
661    }
662
663    @Override
664    public long getDateHeader(String name) {
665        return Long.valueOf(getHeader(name));
666    }
667
668    @Override
669    public DispatcherType getDispatcherType() {
670        return DispatcherType.REQUEST;
671    }
672
673    @Override
674    public Enumeration<String> getHeaders(String name) {
675        String header = getHeader(name);
676        String[] headers = header.split(";");
677        return Collections.enumeration(Arrays.asList(headers));
678    }
679
680    @Override
681    public Enumeration<String> getHeaderNames() {
682        return Collections.enumeration(headers.keySet());
683    }
684
685    @Override
686    public int getIntHeader(String name) {
687        return Integer.valueOf(getHeader(name));
688    }
689
690    @Override
691    public void logout() throws ServletException {
692        this.userPrincipal = null;
693        this.remoteUser = null;
694        this.authType = null;
695    }
696
697    public void addPart(Part part) {
698        this.parts.put(part.getName(), part);
699    }
700
701    @Override
702    public Part getPart(String name) throws IOException, ServletException {
703        final Collection<Part> parts = this.parts.get(name);
704        for (Part part : parts) {
705            return part;
706        }
707        return null;
708    }
709
710    @Override
711    public Collection<Part> getParts() throws IOException, ServletException {
712        List<Part> result = new LinkedList<>(this.parts.values());
713        return result;
714    }
715
716    @Override
717    public <T extends HttpUpgradeHandler> T upgrade(Class<T> handlerClass) throws IOException, ServletException {
718        LogKit.error("Unsupport upgrade method!");
719        return null;
720    }
721
722    public void setServletContext(ServletContext servletContext) {
723        this.servletContext = servletContext;
724    }
725
726    public HttpServletResponse getResponse() {
727        return response;
728    }
729
730    public void setResponse(HttpServletResponse response) {
731        this.response = response;
732    }
733}