diff --git a/backend/src/main/java/com/printcalculator/config/AllowedOriginService.java b/backend/src/main/java/com/printcalculator/config/AllowedOriginService.java new file mode 100644 index 0000000..679b309 --- /dev/null +++ b/backend/src/main/java/com/printcalculator/config/AllowedOriginService.java @@ -0,0 +1,88 @@ +package com.printcalculator.config; + +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Service; + +import java.net.URI; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Locale; +import java.util.Set; + +@Service +public class AllowedOriginService { + + private final List allowedOrigins; + + public AllowedOriginService( + @Value("${app.frontend.base-url:http://localhost:4200}") String frontendBaseUrl, + @Value("${app.cors.additional-allowed-origins:}") String additionalAllowedOrigins + ) { + LinkedHashSet configuredOrigins = new LinkedHashSet<>(); + addConfiguredOrigin(configuredOrigins, frontendBaseUrl, "app.frontend.base-url"); + + for (String rawOrigin : additionalAllowedOrigins.split(",")) { + addConfiguredOrigin(configuredOrigins, rawOrigin, "app.cors.additional-allowed-origins"); + } + + if (configuredOrigins.isEmpty()) { + throw new IllegalStateException("At least one allowed origin must be configured."); + } + this.allowedOrigins = List.copyOf(configuredOrigins); + } + + public List getAllowedOrigins() { + return allowedOrigins; + } + + public boolean isAllowed(String rawOriginOrUrl) { + String normalizedOrigin = normalizeRequestOrigin(rawOriginOrUrl); + return normalizedOrigin != null && allowedOrigins.contains(normalizedOrigin); + } + + private void addConfiguredOrigin(Set configuredOrigins, String rawOriginOrUrl, String propertyName) { + if (rawOriginOrUrl == null || rawOriginOrUrl.isBlank()) { + return; + } + + String normalizedOrigin = normalizeRequestOrigin(rawOriginOrUrl); + if (normalizedOrigin == null) { + throw new IllegalStateException(propertyName + " must contain absolute http(s) URLs."); + } + configuredOrigins.add(normalizedOrigin); + } + + private String normalizeRequestOrigin(String rawOriginOrUrl) { + if (rawOriginOrUrl == null || rawOriginOrUrl.isBlank()) { + return null; + } + + try { + URI uri = URI.create(rawOriginOrUrl.trim()); + String scheme = uri.getScheme(); + String host = uri.getHost(); + if (scheme == null || host == null) { + return null; + } + + String normalizedScheme = scheme.toLowerCase(Locale.ROOT); + if (!"http".equals(normalizedScheme) && !"https".equals(normalizedScheme)) { + return null; + } + + String normalizedHost = host.toLowerCase(Locale.ROOT); + int port = uri.getPort(); + if (isDefaultPort(normalizedScheme, port) || port < 0) { + return normalizedScheme + "://" + normalizedHost; + } + return normalizedScheme + "://" + normalizedHost + ":" + port; + } catch (IllegalArgumentException ignored) { + return null; + } + } + + private boolean isDefaultPort(String scheme, int port) { + return ("http".equals(scheme) && port == 80) + || ("https".equals(scheme) && port == 443); + } +} diff --git a/backend/src/main/java/com/printcalculator/security/AdminCsrfProtectionFilter.java b/backend/src/main/java/com/printcalculator/security/AdminCsrfProtectionFilter.java new file mode 100644 index 0000000..47321d4 --- /dev/null +++ b/backend/src/main/java/com/printcalculator/security/AdminCsrfProtectionFilter.java @@ -0,0 +1,60 @@ +package com.printcalculator.security; + +import com.printcalculator.config.AllowedOriginService; +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.stereotype.Component; +import org.springframework.web.filter.OncePerRequestFilter; + +import java.io.IOException; +import java.util.Locale; +import java.util.Set; + +@Component +public class AdminCsrfProtectionFilter extends OncePerRequestFilter { + + private static final Set SAFE_METHODS = Set.of("GET", "HEAD", "OPTIONS", "TRACE"); + + private final AllowedOriginService allowedOriginService; + + public AdminCsrfProtectionFilter(AllowedOriginService allowedOriginService) { + this.allowedOriginService = allowedOriginService; + } + + @Override + protected boolean shouldNotFilter(HttpServletRequest request) { + String path = resolvePath(request); + String method = request.getMethod() == null ? "" : request.getMethod().toUpperCase(Locale.ROOT); + return !path.startsWith("/api/admin/") || SAFE_METHODS.contains(method); + } + + @Override + protected void doFilterInternal(HttpServletRequest request, + HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { + String origin = request.getHeader(HttpHeaders.ORIGIN); + String referer = request.getHeader(HttpHeaders.REFERER); + + if (allowedOriginService.isAllowed(origin) || allowedOriginService.isAllowed(referer)) { + filterChain.doFilter(request, response); + return; + } + + response.setStatus(HttpServletResponse.SC_FORBIDDEN); + response.setContentType(MediaType.APPLICATION_JSON_VALUE); + response.getWriter().write("{\"error\":\"CSRF_INVALID\"}"); + } + + private String resolvePath(HttpServletRequest request) { + String path = request.getRequestURI(); + String contextPath = request.getContextPath(); + if (contextPath != null && !contextPath.isEmpty() && path.startsWith(contextPath)) { + return path.substring(contextPath.length()); + } + return path; + } +}