CustomRestClient.java

package se.jobtechdev.personaldatagateway.api.util;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.*;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.web.client.RestClient;
import se.jobtechdev.personaldatagateway.api.exception.ApiException;
import se.jobtechdev.personaldatagateway.api.generated.model.ProblemDetails;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import static se.jobtechdev.personaldatagateway.api.util.ProblemDetailsFactory.createProblemDetails;

/**
 * Static helpers for forwarding a request to a datasource through Spring's {@link RestClient}
 * and translating 4xx/5xx responses into {@link ApiException}s.
 */
public final class CustomRestClient {
  private static final Logger logger = LoggerFactory.getLogger(CustomRestClient.class);

  /**
   * Shared mapper used only to read error bodies with the default configuration. An
   * {@link ObjectMapper} is safe to share once configured, so there is no need to build one per
   * error response.
   */
  private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

  private CustomRestClient() {
  }

  /**
   * Sends a single request to {@code path} and returns the raw response.
   *
   * @param path         base URL the request is sent to
   * @param method       HTTP method name, resolved with {@link HttpMethod#valueOf(String)}
   * @param headers      header specification parsed by {@code HeaderExtractor.extractHeader}
   *                     (the expected format is defined there)
   * @param acceptHeader value for the {@code Accept} header, or {@code null} to send only the
   *                     headers derived from {@code headers}
   * @return the response body bytes, its {@code Content-Type} ({@code null} when absent or a
   *     wildcard type) and its status
   * @throws ApiException if the datasource answers with a 4xx/5xx status (raised by
   *     {@link #errorHandler})
   */
  public static DataWrapper retrieveData(
      String path, String method, String headers, String acceptHeader) {
    final var datasourceHeaders = HeaderExtractor.extractHeader(headers);
    if (acceptHeader != null) {
      datasourceHeaders.put(HttpHeaders.ACCEPT, acceptHeader);
    }
    final var customClient = RestClient.builder()
        .baseUrl(path)
        .defaultHeaders(HeaderConsumer.createFromHeaders(datasourceHeaders))
        .build();

    final var responseEntity = customClient.method(HttpMethod.valueOf(method))
        .retrieve()
        .onStatus(HttpStatusCode::isError, CustomRestClient::errorHandler)
        .toEntity(byte[].class);

    var responseContentType = responseEntity.getHeaders().getContentType();
    // Wildcards are not allowed in Content-Type response header. Ask MediaType itself instead of
    // searching the string for '*', which would also match parameter values or subtype tokens.
    if (responseContentType != null
        && (responseContentType.isWildcardType() || responseContentType.isWildcardSubtype())) {
      responseContentType = null;
    }

    // NOTE(review): HttpStatus.valueOf throws IllegalArgumentException for status codes that
    // HttpStatus does not define (e.g. a non-standard 2xx); confirm whether that should be mapped.
    final var httpStatus = HttpStatus.valueOf(responseEntity.getStatusCode().value());

    return new DataWrapper(responseEntity.getBody(), responseContentType, httpStatus);
  }

  /**
   * Error handler registered for 4xx/5xx responses. It always throws.
   *
   * <p>The response body is used as the problem details when it parses as a
   * {@link ProblemDetails}. Otherwise, details are built from the response's status code and
   * status text.
   *
   * @param request  the request that failed (unused; required by the handler signature)
   * @param response the error response
   * @throws ApiException always. When reading the status fails with an {@link IOException}, it
   *     carries a 500 problem instead.
   */
  static void errorHandler(HttpRequest request, ClientHttpResponse response) {
    final ProblemDetails problemDetails;
    try {
      final var responseBody = convertInputStreamToString(response.getBody());
      final var extracted = extractProblemDetails(responseBody);
      if (extracted.isPresent()) {
        problemDetails = extracted.get();
      } else {
        problemDetails = new ProblemDetails()
            .status(response.getStatusCode().value())
            .title(response.getStatusText());
      }
    } catch (IOException e) {
      final var message = "Failed to create ErrorResponse";
      logger.error(message, e);
      throw new ApiException(createProblemDetails(HttpStatus.INTERNAL_SERVER_ERROR, message));
    }
    throw new ApiException(problemDetails);
  }

  /**
   * Reads {@code inputStream} fully as UTF-8 text and closes it.
   *
   * <p>Lines are joined with {@code "\n"}, so line terminators are normalized and a trailing
   * terminator is dropped. UTF-8 is used explicitly rather than the platform default, because
   * JSON bodies are UTF-8 (RFC 8259).
   *
   * @param inputStream stream to read; it is closed before returning
   * @return the decoded text
   * @throws ApiException (500) if reading fails
   */
  static String convertInputStreamToString(InputStream inputStream) {
    try (BufferedReader reader =
        new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
      return reader.lines().collect(Collectors.joining("\n"));
    } catch (Exception e) {
      final var message = "Failed to convert InputStream to String";
      logger.error(message, e);
      throw new ApiException(createProblemDetails(HttpStatus.INTERNAL_SERVER_ERROR, message));
    }
  }

  /**
   * Attempts to parse {@code responseBody} as {@link ProblemDetails}.
   *
   * @param responseBody raw response text
   * @return the parsed details, or empty when the body is not a parseable {@link ProblemDetails}
   *     (including a JSON {@code null} body)
   */
  static Optional<ProblemDetails> extractProblemDetails(String responseBody) {
    try {
      return Optional.ofNullable(OBJECT_MAPPER.readValue(responseBody, ProblemDetails.class));
    } catch (Exception ignored) {
      // Best effort: datasources may return arbitrary, non-ProblemDetails error bodies; the
      // caller falls back to details built from the status line.
      return Optional.empty();
    }
  }

  /** Factory for {@link Consumer}s that copy a header map into Spring {@link HttpHeaders}. */
  public static final class HeaderConsumer {
    private HeaderConsumer() {
    }

    /**
     * Creates a consumer that calls {@link HttpHeaders#add} for every entry of {@code headers}.
     *
     * @param headers header names mapped to a single value each
     * @return a consumer suitable for {@code RestClient.Builder#defaultHeaders}
     */
    public static Consumer<HttpHeaders> createFromHeaders(Map<String, String> headers) {
      return httpHeaders -> headers.forEach(httpHeaders::add);
    }
  }
}