fix(back-end): fix csrm and cors
This commit is contained in:
@@ -0,0 +1,88 @@
|
|||||||
|
package com.printcalculator.config;
|
||||||
|
|
||||||
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.net.URI;
|
||||||
|
import java.util.LinkedHashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Locale;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
@Service
|
||||||
|
public class AllowedOriginService {
|
||||||
|
|
||||||
|
private final List<String> allowedOrigins;
|
||||||
|
|
||||||
|
public AllowedOriginService(
|
||||||
|
@Value("${app.frontend.base-url:http://localhost:4200}") String frontendBaseUrl,
|
||||||
|
@Value("${app.cors.additional-allowed-origins:}") String additionalAllowedOrigins
|
||||||
|
) {
|
||||||
|
LinkedHashSet<String> configuredOrigins = new LinkedHashSet<>();
|
||||||
|
addConfiguredOrigin(configuredOrigins, frontendBaseUrl, "app.frontend.base-url");
|
||||||
|
|
||||||
|
for (String rawOrigin : additionalAllowedOrigins.split(",")) {
|
||||||
|
addConfiguredOrigin(configuredOrigins, rawOrigin, "app.cors.additional-allowed-origins");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (configuredOrigins.isEmpty()) {
|
||||||
|
throw new IllegalStateException("At least one allowed origin must be configured.");
|
||||||
|
}
|
||||||
|
this.allowedOrigins = List.copyOf(configuredOrigins);
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<String> getAllowedOrigins() {
|
||||||
|
return allowedOrigins;
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isAllowed(String rawOriginOrUrl) {
|
||||||
|
String normalizedOrigin = normalizeRequestOrigin(rawOriginOrUrl);
|
||||||
|
return normalizedOrigin != null && allowedOrigins.contains(normalizedOrigin);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addConfiguredOrigin(Set<String> configuredOrigins, String rawOriginOrUrl, String propertyName) {
|
||||||
|
if (rawOriginOrUrl == null || rawOriginOrUrl.isBlank()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
String normalizedOrigin = normalizeRequestOrigin(rawOriginOrUrl);
|
||||||
|
if (normalizedOrigin == null) {
|
||||||
|
throw new IllegalStateException(propertyName + " must contain absolute http(s) URLs.");
|
||||||
|
}
|
||||||
|
configuredOrigins.add(normalizedOrigin);
|
||||||
|
}
|
||||||
|
|
||||||
|
private String normalizeRequestOrigin(String rawOriginOrUrl) {
|
||||||
|
if (rawOriginOrUrl == null || rawOriginOrUrl.isBlank()) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
URI uri = URI.create(rawOriginOrUrl.trim());
|
||||||
|
String scheme = uri.getScheme();
|
||||||
|
String host = uri.getHost();
|
||||||
|
if (scheme == null || host == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
String normalizedScheme = scheme.toLowerCase(Locale.ROOT);
|
||||||
|
if (!"http".equals(normalizedScheme) && !"https".equals(normalizedScheme)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
String normalizedHost = host.toLowerCase(Locale.ROOT);
|
||||||
|
int port = uri.getPort();
|
||||||
|
if (isDefaultPort(normalizedScheme, port) || port < 0) {
|
||||||
|
return normalizedScheme + "://" + normalizedHost;
|
||||||
|
}
|
||||||
|
return normalizedScheme + "://" + normalizedHost + ":" + port;
|
||||||
|
} catch (IllegalArgumentException ignored) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean isDefaultPort(String scheme, int port) {
|
||||||
|
return ("http".equals(scheme) && port == 80)
|
||||||
|
|| ("https".equals(scheme) && port == 443);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,60 @@
|
|||||||
|
package com.printcalculator.security;
|
||||||
|
|
||||||
|
import com.printcalculator.config.AllowedOriginService;
|
||||||
|
import jakarta.servlet.FilterChain;
|
||||||
|
import jakarta.servlet.ServletException;
|
||||||
|
import jakarta.servlet.http.HttpServletRequest;
|
||||||
|
import jakarta.servlet.http.HttpServletResponse;
|
||||||
|
import org.springframework.http.HttpHeaders;
|
||||||
|
import org.springframework.http.MediaType;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
import org.springframework.web.filter.OncePerRequestFilter;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Locale;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
@Component
|
||||||
|
public class AdminCsrfProtectionFilter extends OncePerRequestFilter {
|
||||||
|
|
||||||
|
private static final Set<String> SAFE_METHODS = Set.of("GET", "HEAD", "OPTIONS", "TRACE");
|
||||||
|
|
||||||
|
private final AllowedOriginService allowedOriginService;
|
||||||
|
|
||||||
|
public AdminCsrfProtectionFilter(AllowedOriginService allowedOriginService) {
|
||||||
|
this.allowedOriginService = allowedOriginService;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected boolean shouldNotFilter(HttpServletRequest request) {
|
||||||
|
String path = resolvePath(request);
|
||||||
|
String method = request.getMethod() == null ? "" : request.getMethod().toUpperCase(Locale.ROOT);
|
||||||
|
return !path.startsWith("/api/admin/") || SAFE_METHODS.contains(method);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void doFilterInternal(HttpServletRequest request,
|
||||||
|
HttpServletResponse response,
|
||||||
|
FilterChain filterChain) throws ServletException, IOException {
|
||||||
|
String origin = request.getHeader(HttpHeaders.ORIGIN);
|
||||||
|
String referer = request.getHeader(HttpHeaders.REFERER);
|
||||||
|
|
||||||
|
if (allowedOriginService.isAllowed(origin) || allowedOriginService.isAllowed(referer)) {
|
||||||
|
filterChain.doFilter(request, response);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
response.setStatus(HttpServletResponse.SC_FORBIDDEN);
|
||||||
|
response.setContentType(MediaType.APPLICATION_JSON_VALUE);
|
||||||
|
response.getWriter().write("{\"error\":\"CSRF_INVALID\"}");
|
||||||
|
}
|
||||||
|
|
||||||
|
private String resolvePath(HttpServletRequest request) {
|
||||||
|
String path = request.getRequestURI();
|
||||||
|
String contextPath = request.getContextPath();
|
||||||
|
if (contextPath != null && !contextPath.isEmpty() && path.startsWith(contextPath)) {
|
||||||
|
return path.substring(contextPath.length());
|
||||||
|
}
|
||||||
|
return path;
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user