/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * license agreements; and to You under the Apache License, version 2.0:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * This file is part of the Apache Pekko project, which was derived from Akka.
 */

/*
 * Copyright (C) since 2016 Lightbend Inc. <https://www.lightbend.com>
 */

package org.apache.pekko.stream.connectors.s3.impl

import java.util.UUID
import org.apache.pekko
import pekko.actor.ActorSystem
import pekko.http.scaladsl.Http
import pekko.http.scaladsl.model.Uri.Query
import pekko.http.scaladsl.model._
import pekko.http.scaladsl.model.headers.{ `Raw-Request-URI`, ByteRange, RawHeader }
import pekko.stream.connectors.s3.headers.{ CannedAcl, ServerSideEncryption, StorageClass }
import pekko.stream.connectors.s3._
import pekko.stream.connectors.testkit.scaladsl.LogCapturing
import pekko.testkit.{ SocketUtil, TestKit, TestProbe }
import pekko.util.ByteString
import org.scalatest.concurrent.{ IntegrationPatience, ScalaFutures }
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import software.amazon.awssdk.auth.credentials._
import software.amazon.awssdk.regions.Region
import software.amazon.awssdk.regions.providers._

import scala.concurrent.ExecutionContext

/**
 * Unit tests for [[HttpRequests]]: checks that each request builder produces the expected URI
 * (scheme, host, port, path, query), headers and entity for a given implicit [[S3Settings]].
 */
class HttpRequestsSpec extends AnyFlatSpec with Matchers with ScalaFutures with IntegrationPatience with LogCapturing {

  // ---------------------------------------------------------------------------------------------
  // test fixtures
  // ---------------------------------------------------------------------------------------------

  /**
   * Builds [[S3Settings]] backed by a region provider that always answers `s3Region`.
   *
   * @param bufferType           buffering strategy; defaults to in-memory buffering
   * @param awsCredentials       credentials provider; defaults to anonymous credentials
   * @param s3Region             region reported by the region provider; defaults to us-east-1
   * @param listBucketApiVersion ListObjects API version; defaults to version 2
   * @return settings suitable for passing implicitly to [[HttpRequests]]
   */
  def getSettings(
      bufferType: BufferType = MemoryBufferType,
      awsCredentials: AwsCredentialsProvider = AnonymousCredentialsProvider.create(),
      s3Region: Region = Region.US_EAST_1,
      listBucketApiVersion: ApiVersion = ApiVersion.ListBucketVersion2): S3Settings = {
    val regionProvider = new AwsRegionProvider {
      override def getRegion: Region = s3Region
    }

    S3Settings(bufferType, awsCredentials, regionProvider, listBucketApiVersion)
  }

  // The key contains `@`, which the path assertions below expect to be rendered unescaped.
  val location = S3Location("bucket", "image-1024@2x")
  val contentType = MediaTypes.`image/jpeg`
  val acl = CannedAcl.PublicRead
  val metaHeaders: Map[String, String] = Map("location" -> "San Francisco", "orientation" -> "portrait")
  val multipartUpload = MultipartUpload("test-bucket", "testKey", "uploadId")

  /** Asserts that every entry of `metaHeaders` was rendered as an `x-amz-meta-<name>` raw header. */
  private def assertMetaHeadersPresent(req: HttpRequest): Unit =
    metaHeaders.foreach { case (name, value) =>
      req.headers should contain(RawHeader(s"x-amz-meta-$name", value))
    }

  /** Asserts that the request targets the `http://localhost:8080` endpoint configured via `withEndpointUrl`. */
  private def assertLocalhostEndpoint(req: HttpRequest): Unit = {
    req.uri.scheme shouldEqual "http"
    req.uri.authority.host.address shouldEqual "localhost"
    req.uri.authority.port shouldEqual 8080
  }

  // ---------------------------------------------------------------------------------------------
  // multipart upload initiation / addressing
  // ---------------------------------------------------------------------------------------------

  it should "initiate multipart upload when the region is us-east-1" in {
    implicit val settings: S3Settings = getSettings()

    val req =
      HttpRequests.initiateMultipartUploadRequest(
        location,
        contentType,
        S3Headers().withCannedAcl(acl).withMetaHeaders(MetaHeaders(metaHeaders)).headers)

    req.entity shouldEqual HttpEntity.empty(contentType)
    req.headers should contain(RawHeader("x-amz-acl", acl.value))
    req.uri.authority.host.toString shouldEqual "bucket.s3.us-east-1.amazonaws.com"
    req.uri.path.toString shouldEqual "/image-1024@2x"

    assertMetaHeadersPresent(req)
  }

  it should "initiate multipart upload with other regions" in {
    implicit val settings: S3Settings = getSettings(s3Region = Region.US_EAST_2)

    val req =
      HttpRequests.initiateMultipartUploadRequest(
        location,
        contentType,
        S3Headers().withCannedAcl(acl).withMetaHeaders(MetaHeaders(metaHeaders)).headers)

    req.entity shouldEqual HttpEntity.empty(contentType)
    req.headers should contain(RawHeader("x-amz-acl", acl.value))
    req.uri.authority.host.toString shouldEqual "bucket.s3.us-east-2.amazonaws.com"
    req.uri.path.toString shouldEqual "/image-1024@2x"

    assertMetaHeadersPresent(req)
  }

  it should "throw an error if path-style access is false and the bucket name contains non-LDH characters" in {
    implicit val settings: S3Settings = getSettings(s3Region = Region.EU_WEST_1)

    assertThrows[IllegalUriException](
      HttpRequests.getDownloadRequest(S3Location("invalid_bucket_name", "image-1024@2x")))
  }

  it should "throw an error if the key uses `..`" in {
    implicit val settings: S3Settings = getSettings(s3Region = Region.EU_WEST_1)

    assertThrows[IllegalUriException](
      HttpRequests.getDownloadRequest(S3Location("validbucket", "../other-bucket/image-1024@2x")))
  }

  it should "throw an error when using `..` with path-style access" in {
    implicit val settings: S3Settings =
      getSettings(s3Region = Region.EU_WEST_1).withAccessStyle(AccessStyle.PathAccessStyle)

    val traversingBuckets = Seq("invalid/../bucket_name", "../bucket_name", "bucket_name/..", "..")

    traversingBuckets.foreach { bucket =>
      // the clue identifies which bucket name slipped through if an assertion fails
      withClue(s"bucket = '$bucket': ") {
        assertThrows[IllegalUriException](
          HttpRequests.getDownloadRequest(S3Location(bucket, "image-1024@2x")))
      }
    }
  }

  it should "initiate multipart upload with path-style access in region us-east-1" in {
    implicit val settings: S3Settings =
      getSettings(s3Region = Region.US_EAST_1).withAccessStyle(AccessStyle.PathAccessStyle)

    val req =
      HttpRequests.initiateMultipartUploadRequest(
        location,
        contentType,
        S3Headers().withCannedAcl(acl).withMetaHeaders(MetaHeaders(metaHeaders)).headers)

    req.uri.authority.host.toString shouldEqual "s3.us-east-1.amazonaws.com"
    req.uri.path.toString shouldEqual "/bucket/image-1024@2x"
  }

  it should "support download requests with path-style access in region us-east-1" in {
    implicit val settings: S3Settings =
      getSettings(s3Region = Region.US_EAST_1).withAccessStyle(AccessStyle.PathAccessStyle)

    val req = HttpRequests.getDownloadRequest(location)

    req.uri.authority.host.toString shouldEqual "s3.us-east-1.amazonaws.com"
    req.uri.path.toString shouldEqual "/bucket/image-1024@2x"
    req.uri.rawQueryString shouldBe empty
  }

  it should "initiate multipart upload with path-style access in other regions" in {
    implicit val settings: S3Settings =
      getSettings(s3Region = Region.US_WEST_2).withAccessStyle(AccessStyle.PathAccessStyle)

    val req =
      HttpRequests.initiateMultipartUploadRequest(
        location,
        contentType,
        S3Headers().withCannedAcl(acl).withMetaHeaders(MetaHeaders(metaHeaders)).headers)

    req.uri.authority.host.toString shouldEqual "s3.us-west-2.amazonaws.com"
    req.uri.path.toString shouldEqual "/bucket/image-1024@2x"
  }

  it should "support download requests with path-style access in other regions" in {
    implicit val settings: S3Settings =
      getSettings(s3Region = Region.EU_WEST_1).withAccessStyle(AccessStyle.PathAccessStyle)

    val req = HttpRequests.getDownloadRequest(location)

    req.uri.authority.host.toString shouldEqual "s3.eu-west-1.amazonaws.com"
    req.uri.path.toString shouldEqual "/bucket/image-1024@2x"
    req.uri.rawQueryString shouldBe empty
  }

  it should "support download requests via configured `endpointUrl`" in {
    implicit val settings: S3Settings =
      getSettings(s3Region = Region.EU_WEST_1).withEndpointUrl("http://localhost:8080")

    assertLocalhostEndpoint(HttpRequests.getDownloadRequest(location))
  }

  // ---------------------------------------------------------------------------------------------
  // key encoding in download requests
  // ---------------------------------------------------------------------------------------------

  it should "support download requests with keys starting with /" in {
    // the official client supports this and this translates
    // into an object at path /[empty string]/...
    // added this test because of a tricky uri building issue
    // in case of pathStyleAccess = false
    implicit val settings: S3Settings = getSettings()

    val location = S3Location("bucket", "/test/foo.txt")

    val req = HttpRequests.getDownloadRequest(location)

    req.uri.authority.host.toString shouldEqual "bucket.s3.us-east-1.amazonaws.com"
    req.uri.path.toString shouldEqual "//test/foo.txt"
    req.uri.rawQueryString shouldBe empty
  }

  it should "support download requests with keys ending with /" in {
    // object with a slash at the end of the filename should be accessible
    implicit val settings: S3Settings = getSettings()

    val location = S3Location("bucket", "/test//")

    val req = HttpRequests.getDownloadRequest(location)

    req.uri.authority.host.toString shouldEqual "bucket.s3.us-east-1.amazonaws.com"
    req.uri.path.toString shouldEqual "//test//"
    req.uri.rawQueryString shouldBe empty
  }

  it should "support download requests with keys containing spaces" in {
    implicit val settings: S3Settings = getSettings()

    val location = S3Location("bucket", "test folder/test file.txt")

    val req = HttpRequests.getDownloadRequest(location)

    req.uri.authority.host.toString shouldEqual "bucket.s3.us-east-1.amazonaws.com"
    req.uri.path.toString shouldEqual "/test%20folder/test%20file.txt"
    req.uri.rawQueryString shouldBe empty
  }

  it should "support download requests with keys containing plus" in {
    implicit val settings: S3Settings = getSettings()

    val location = S3Location("bucket", "test folder/1 + 2 = 3")
    val req = HttpRequests.getDownloadRequest(location)
    req.uri.authority.host.toString shouldEqual "bucket.s3.us-east-1.amazonaws.com"
    req.uri.path.toString shouldEqual "/test%20folder/1%20+%202%20=%203"
    // the parsed path keeps `+` literal, while the raw request URI sent on the wire encodes it as %2B
    req.headers should contain(`Raw-Request-URI`("/test%20folder/1%20%2B%202%20=%203"))
  }

  it should "support download requests with keys containing spaces with path-style access in other regions" in {
    implicit val settings: S3Settings =
      getSettings(s3Region = Region.EU_WEST_1).withAccessStyle(AccessStyle.PathAccessStyle)

    val location = S3Location("bucket", "test folder/test file.txt")

    val req = HttpRequests.getDownloadRequest(location)

    req.uri.authority.host.toString shouldEqual "s3.eu-west-1.amazonaws.com"
    req.uri.path.toString shouldEqual "/bucket/test%20folder/test%20file.txt"
    req.uri.rawQueryString shouldBe empty
  }

  it should "add versionId query parameter when provided" in {
    implicit val settings: S3Settings = getSettings().withAccessStyle(AccessStyle.PathAccessStyle)

    val location = S3Location("bucket", "test/foo.txt")
    val versionId = "123456"
    val req = HttpRequests.getDownloadRequest(location, versionId = Some(versionId))

    req.uri.authority.host.toString shouldEqual "s3.us-east-1.amazonaws.com"
    req.uri.path.toString shouldEqual "/bucket/test/foo.txt"
    req.uri.rawQueryString shouldBe Some(s"versionId=$versionId")
  }

  // ---------------------------------------------------------------------------------------------
  // multipart upload requests against a custom endpoint / with encryption and headers
  // ---------------------------------------------------------------------------------------------

  it should "support multipart init upload requests via configured `endpointUrl`" in {
    implicit val settings: S3Settings =
      getSettings(s3Region = Region.EU_WEST_1).withEndpointUrl("http://localhost:8080")

    val req =
      HttpRequests.initiateMultipartUploadRequest(
        location,
        contentType,
        S3Headers().withCannedAcl(acl).withMetaHeaders(MetaHeaders(metaHeaders)).headers)

    assertLocalhostEndpoint(req)
  }

  it should "support multipart upload part requests via configured `endpointUrl`" in {
    implicit val settings: S3Settings =
      getSettings(s3Region = Region.EU_WEST_1).withEndpointUrl("http://localhost:8080")

    val req =
      HttpRequests.uploadPartRequest(multipartUpload, 1, MemoryChunk(ByteString.empty))

    assertLocalhostEndpoint(req)
  }

  it should "properly multipart upload part request with customer keys server side encryption" in {
    implicit val settings: S3Settings =
      getSettings(s3Region = Region.EU_WEST_1).withAccessStyle(AccessStyle.PathAccessStyle)
    val myKey = "my-key"
    val md5Key = "md5-key"
    val s3Headers = ServerSideEncryption.customerKeys(myKey).withMd5(md5Key).headersFor(UploadPart)
    val req = HttpRequests.uploadPartRequest(multipartUpload, 1, MemoryChunk(ByteString.empty), s3Headers)

    req.headers should contain(RawHeader("x-amz-server-side-encryption-customer-algorithm", "AES256"))
    req.headers should contain(RawHeader("x-amz-server-side-encryption-customer-key", myKey))
    req.headers should contain(RawHeader("x-amz-server-side-encryption-customer-key-MD5", md5Key))
  }

  it should "support multipart upload complete requests via configured `endpointUrl`" in {
    implicit val settings: S3Settings =
      getSettings(s3Region = Region.EU_WEST_1).withEndpointUrl("http://localhost:8080")
    // completeMultipartUploadRequest returns a Future and needs an ExecutionContext to build it
    implicit val executionContext: ExecutionContext = ExecutionContext.global

    val req =
      HttpRequests.completeMultipartUploadRequest(multipartUpload, (1, "part") :: Nil, Nil).futureValue

    assertLocalhostEndpoint(req)
  }

  it should "initiate multipart upload with AES-256 server side encryption" in {
    implicit val settings: S3Settings = getSettings(s3Region = Region.US_EAST_2)
    val s3Headers = ServerSideEncryption.aes256().headersFor(InitiateMultipartUpload)
    val req = HttpRequests.initiateMultipartUploadRequest(location, contentType, s3Headers)

    req.headers should contain(RawHeader("x-amz-server-side-encryption", "AES256"))
  }

  it should "initiate multipart upload with aws:kms server side encryption" in {
    implicit val settings: S3Settings = getSettings(s3Region = Region.US_EAST_2)
    val testArn = "arn:aws:kms:my-region:my-account-id:key/my-key-id"
    val s3Headers = ServerSideEncryption.kms(testArn).headersFor(InitiateMultipartUpload)
    val req = HttpRequests.initiateMultipartUploadRequest(location, contentType, s3Headers)

    req.headers should contain(RawHeader("x-amz-server-side-encryption", "aws:kms"))
    req.headers should contain(RawHeader("x-amz-server-side-encryption-aws-kms-key-id", testArn))
  }

  it should "initiate multipart upload with customer keys encryption" in {
    implicit val settings: S3Settings = getSettings(s3Region = Region.US_EAST_2)
    val myKey = "my-key"
    val md5Key = "md5-key"
    val s3Headers = ServerSideEncryption.customerKeys(myKey).withMd5(md5Key).headersFor(InitiateMultipartUpload)
    val req = HttpRequests.initiateMultipartUploadRequest(location, contentType, s3Headers)

    req.headers should contain(RawHeader("x-amz-server-side-encryption-customer-algorithm", "AES256"))
    req.headers should contain(RawHeader("x-amz-server-side-encryption-customer-key", myKey))
    req.headers should contain(RawHeader("x-amz-server-side-encryption-customer-key-MD5", md5Key))
  }

  it should "initiate multipart upload with custom s3 storage class" in {
    implicit val settings: S3Settings = getSettings(s3Region = Region.US_EAST_2)
    val s3Headers = S3Headers().withStorageClass(StorageClass.ReducedRedundancy).headers
    val req = HttpRequests.initiateMultipartUploadRequest(location, contentType, s3Headers)

    req.headers should contain(RawHeader("x-amz-storage-class", "REDUCED_REDUNDANCY"))
  }

  it should "initiate multipart upload with custom s3 headers" in {
    implicit val settings: S3Settings = getSettings(s3Region = Region.US_EAST_2)
    val s3Headers = S3Headers().withCustomHeaders(Map("Cache-Control" -> "no-cache")).headers
    val req = HttpRequests.initiateMultipartUploadRequest(location, contentType, s3Headers)

    req.headers should contain(RawHeader("Cache-Control", "no-cache"))
  }

  // ---------------------------------------------------------------------------------------------
  // list bucket requests
  // ---------------------------------------------------------------------------------------------

  it should "properly construct the list bucket request with no prefix, continuation token or delimiter passed" in {
    implicit val settings: S3Settings =
      getSettings(s3Region = Region.US_EAST_2).withAccessStyle(AccessStyle.PathAccessStyle)

    val req =
      HttpRequests.listBucket(location.bucket)

    req.uri.query() shouldEqual Query("list-type" -> "2")
  }

  it should "properly construct the list bucket request with a prefix and token passed" in {
    implicit val settings: S3Settings =
      getSettings(s3Region = Region.US_EAST_2).withAccessStyle(AccessStyle.PathAccessStyle)

    val req =
      HttpRequests.listBucket(location.bucket, Some("random/prefix"), Some("randomToken"))

    req.uri.query() shouldEqual Query(
      "list-type" -> "2",
      "prefix" -> "random/prefix",
      "continuation-token" -> "randomToken")
  }

  it should "properly construct the list bucket request with a delimiter and token passed" in {
    implicit val settings: S3Settings =
      getSettings(s3Region = Region.US_EAST_2).withAccessStyle(AccessStyle.PathAccessStyle)

    val req =
      HttpRequests.listBucket(location.bucket, delimiter = Some("/"), continuationToken = Some("randomToken"))

    req.uri.query() shouldEqual Query("list-type" -> "2", "delimiter" -> "/", "continuation-token" -> "randomToken")
  }

  it should "properly construct the list bucket request when using api version 1" in {
    implicit val settings: S3Settings =
      getSettings(s3Region = Region.US_EAST_2, listBucketApiVersion = ApiVersion.ListBucketVersion1)
        .withAccessStyle(AccessStyle.PathAccessStyle)

    val req =
      HttpRequests.listBucket(location.bucket)

    req.uri.query() shouldEqual Query()
  }

  it should "properly construct the list bucket request when using api version set to 1 and a continuation token" in {
    implicit val settings: S3Settings =
      getSettings(s3Region = Region.US_EAST_2, listBucketApiVersion = ApiVersion.ListBucketVersion1)
        .withAccessStyle(AccessStyle.PathAccessStyle)

    val req =
      HttpRequests.listBucket(location.bucket, continuationToken = Some("randomToken"))

    // version 1 sends the continuation token as `marker`
    req.uri.query() shouldEqual Query("marker" -> "randomToken")
  }

  it should "support custom endpoint configured by `endpointUrl`" in {
    implicit val system: ActorSystem = ActorSystem("HttpRequestsSpec")

    try {
      val probe = TestProbe()
      val address = SocketUtil.temporaryServerAddress()

      import pekko.http.scaladsl.server.Directives._

      // Wait for the mock server to be bound before sending the request; otherwise the client
      // can race the bind and fail with "connection refused".
      Http()
        .newServerAt(address.getHostName, address.getPort)
        .bind(extractRequestContext { ctx =>
          probe.ref ! ctx.request
          complete("MOCK")
        })
        .futureValue

      implicit val setting: S3Settings =
        getSettings().withEndpointUrl(s"http://${address.getHostName}:${address.getPort}/")

      val req =
        HttpRequests.listBucket(location.bucket, Some("random/prefix"), Some("randomToken"))

      Http().singleRequest(req).futureValue shouldBe a[HttpResponse]

      probe.expectMsgType[HttpRequest]
    } finally {
      TestKit.shutdownActorSystem(system)
    }
  }

  // ---------------------------------------------------------------------------------------------
  // multipart copy requests
  // ---------------------------------------------------------------------------------------------

  it should "add two (source, range) headers to multipart upload (copy) request when byte range populated" in {
    implicit val settings: S3Settings = getSettings()

    val multipartUpload = MultipartUpload("target-bucket", "target-key", UUID.randomUUID().toString)
    val copyPartition = CopyPartition(1, S3Location("source-bucket", "some/source-key"), Some(ByteRange(0, 5242880L)))
    val multipartCopy = MultipartCopy(multipartUpload, copyPartition)

    val request = HttpRequests.uploadCopyPartRequest(multipartCopy)
    request.headers should contain(RawHeader("x-amz-copy-source", "/source-bucket/some%2Fsource-key"))
    request.headers should contain(RawHeader("x-amz-copy-source-range", "bytes=0-5242879"))
  }

  it should "add only source header to multipart upload (copy) request when byte range missing" in {
    implicit val settings: S3Settings = getSettings()

    val multipartUpload = MultipartUpload("target-bucket", "target-key", UUID.randomUUID().toString)
    val copyPartition = CopyPartition(1, S3Location("source-bucket", "some/source-key"))
    val multipartCopy = MultipartCopy(multipartUpload, copyPartition)

    val request = HttpRequests.uploadCopyPartRequest(multipartCopy)
    request.headers should contain(RawHeader("x-amz-copy-source", "/source-bucket/some%2Fsource-key"))
    request.headers.map(_.lowercaseName()) should not contain "x-amz-copy-source-range"
  }

  it should "add versionId parameter to source header if provided" in {
    implicit val settings: S3Settings = getSettings()

    val multipartUpload = MultipartUpload("target-bucket", "target-key", UUID.randomUUID().toString)
    val copyPartition = CopyPartition(1, S3Location("source-bucket", "some/source-key"), Some(ByteRange(0, 5242880L)))
    val multipartCopy = MultipartCopy(multipartUpload, copyPartition)

    val request = HttpRequests.uploadCopyPartRequest(multipartCopy, Some("abcdwxyz"))
    request.headers should contain(
      RawHeader("x-amz-copy-source", "/source-bucket/some%2Fsource-key?versionId=abcdwxyz"))
    request.headers should contain(RawHeader("x-amz-copy-source-range", "bytes=0-5242879"))
  }

  // ---------------------------------------------------------------------------------------------
  // bucket management requests
  // ---------------------------------------------------------------------------------------------

  it should "create make bucket request" in {
    implicit val settings: S3Settings = getSettings()

    val request = HttpRequests.bucketManagementRequest(location, method = HttpMethods.PUT)

    // Date is added by pekko by default
    request.uri.authority.host.toString should equal("bucket.s3.us-east-1.amazonaws.com")
    request.entity.contentLengthOption should equal(Some(0))
    request.uri.queryString() should equal(None)
    request.method should equal(HttpMethods.PUT)
  }

  it should "create delete bucket request" in {
    implicit val settings: S3Settings = getSettings()

    val request = HttpRequests.bucketManagementRequest(location, method = HttpMethods.DELETE)

    // Date is added by pekko by default
    request.uri.authority.host.toString should equal("bucket.s3.us-east-1.amazonaws.com")
    request.entity.contentLengthOption should equal(Some(0))
    request.uri.queryString() should equal(None)
    request.method should equal(HttpMethods.DELETE)
  }

  it should "create checkIfExists bucket request" in {
    implicit val settings: S3Settings = getSettings()

    val request: HttpRequest = HttpRequests.bucketManagementRequest(location, method = HttpMethods.HEAD)

    // Date is added by pekko by default
    request.uri.authority.host.toString should equal("bucket.s3.us-east-1.amazonaws.com")
    request.entity.contentLengthOption should equal(Some(0))
    request.uri.queryString() should equal(None)
    request.method should equal(HttpMethods.HEAD)
  }

  it should "add x-amz-mfa headers for a putBucketVersioning request" in {
    implicit val settings: S3Settings = getSettings()

    val serialNumber = "serial-number"
    val tokenCode = "token-code"

    val request: HttpRequest = HttpRequests.bucketVersioningRequest(
      "target-bucket",
      Some(MFAStatus.Enabled(MFA(serialNumber, tokenCode))),
      HttpMethods.PUT)

    // the MFA header value is "<serial number> <token code>"
    request.headers.collectFirst {
      case httpHeader if httpHeader.is("x-amz-mfa") => httpHeader.value()
    } should equal(Some(s"$serialNumber $tokenCode"))
  }
}
