dev #51

Merged
JoeKung merged 5 commits from dev into main 2026-03-22 23:06:02 +01:00
2 changed files with 148 additions and 0 deletions
Showing only changes of commit cc343ee27c - Show all commits

View File

@@ -0,0 +1,88 @@
package com.printcalculator.config;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import java.net.URI;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
@Service
public class AllowedOriginService {
private final List<String> allowedOrigins;
public AllowedOriginService(
@Value("${app.frontend.base-url:http://localhost:4200}") String frontendBaseUrl,
@Value("${app.cors.additional-allowed-origins:}") String additionalAllowedOrigins
) {
LinkedHashSet<String> configuredOrigins = new LinkedHashSet<>();
addConfiguredOrigin(configuredOrigins, frontendBaseUrl, "app.frontend.base-url");
for (String rawOrigin : additionalAllowedOrigins.split(",")) {
addConfiguredOrigin(configuredOrigins, rawOrigin, "app.cors.additional-allowed-origins");
}
if (configuredOrigins.isEmpty()) {
throw new IllegalStateException("At least one allowed origin must be configured.");
}
this.allowedOrigins = List.copyOf(configuredOrigins);
}
public List<String> getAllowedOrigins() {
return allowedOrigins;
}
public boolean isAllowed(String rawOriginOrUrl) {
String normalizedOrigin = normalizeRequestOrigin(rawOriginOrUrl);
return normalizedOrigin != null && allowedOrigins.contains(normalizedOrigin);
}
private void addConfiguredOrigin(Set<String> configuredOrigins, String rawOriginOrUrl, String propertyName) {
if (rawOriginOrUrl == null || rawOriginOrUrl.isBlank()) {
return;
}
String normalizedOrigin = normalizeRequestOrigin(rawOriginOrUrl);
if (normalizedOrigin == null) {
throw new IllegalStateException(propertyName + " must contain absolute http(s) URLs.");
}
configuredOrigins.add(normalizedOrigin);
}
private String normalizeRequestOrigin(String rawOriginOrUrl) {
if (rawOriginOrUrl == null || rawOriginOrUrl.isBlank()) {
return null;
}
try {
URI uri = URI.create(rawOriginOrUrl.trim());
String scheme = uri.getScheme();
String host = uri.getHost();
if (scheme == null || host == null) {
return null;
}
String normalizedScheme = scheme.toLowerCase(Locale.ROOT);
if (!"http".equals(normalizedScheme) && !"https".equals(normalizedScheme)) {
return null;
}
String normalizedHost = host.toLowerCase(Locale.ROOT);
int port = uri.getPort();
if (isDefaultPort(normalizedScheme, port) || port < 0) {
return normalizedScheme + "://" + normalizedHost;
}
return normalizedScheme + "://" + normalizedHost + ":" + port;
} catch (IllegalArgumentException ignored) {
return null;
}
}
private boolean isDefaultPort(String scheme, int port) {
return ("http".equals(scheme) && port == 80)
|| ("https".equals(scheme) && port == 443);
}
}

View File

@@ -0,0 +1,60 @@
package com.printcalculator.security;
import com.printcalculator.config.AllowedOriginService;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;
import java.io.IOException;
import java.util.Locale;
import java.util.Set;
@Component
public class AdminCsrfProtectionFilter extends OncePerRequestFilter {
private static final Set<String> SAFE_METHODS = Set.of("GET", "HEAD", "OPTIONS", "TRACE");
private final AllowedOriginService allowedOriginService;
public AdminCsrfProtectionFilter(AllowedOriginService allowedOriginService) {
this.allowedOriginService = allowedOriginService;
}
@Override
protected boolean shouldNotFilter(HttpServletRequest request) {
String path = resolvePath(request);
String method = request.getMethod() == null ? "" : request.getMethod().toUpperCase(Locale.ROOT);
return !path.startsWith("/api/admin/") || SAFE_METHODS.contains(method);
}
@Override
protected void doFilterInternal(HttpServletRequest request,
HttpServletResponse response,
FilterChain filterChain) throws ServletException, IOException {
String origin = request.getHeader(HttpHeaders.ORIGIN);
String referer = request.getHeader(HttpHeaders.REFERER);
if (allowedOriginService.isAllowed(origin) || allowedOriginService.isAllowed(referer)) {
filterChain.doFilter(request, response);
return;
}
response.setStatus(HttpServletResponse.SC_FORBIDDEN);
response.setContentType(MediaType.APPLICATION_JSON_VALUE);
response.getWriter().write("{\"error\":\"CSRF_INVALID\"}");
}
private String resolvePath(HttpServletRequest request) {
String path = request.getRequestURI();
String contextPath = request.getContextPath();
if (contextPath != null && !contextPath.isEmpty() && path.startsWith(contextPath)) {
return path.substring(contextPath.length());
}
return path;
}
}