// Simulates the sky and calculates its colors.
// All of the remotely interesting code is adapted from this article:
// https://www.scratchapixel.com/lessons/procedural-generation-virtual-worlds/simulating-sky/simulating-colors-of-the-sky
// It is a _great_ article and I highly recommend it.

import p5 from 'p5'
import SunCalc from "suncalc"

/**
 * Tone-maps one linear radiance channel to a display value.
 * Below 1.413 a scaled gamma curve (exponent 1/2.2) is used; above it an
 * exponential roll-off (1 - e^-v) keeps bright values below 1.
 * @param {number} value - linear channel value from Atmosphere.computeIncidentLight
 * @returns {number} display value; non-negative input maps into [0, 1)
 */
function toneMapChannel(value) {
  return value < 1.413 ? Math.pow(value * 0.38317, 1.0 / 2.2) : 1.0 - Math.exp(-value)
}

/**
 * Computes sky colors over a 2D grid for a location and time, using the
 * Atmosphere scattering model.
 *
 * config fields read by this class:
 *   location:   { lat, long, alt } - lat/long go to SunCalc; alt is added to the
 *               Earth radius (6360e3), so presumably metres.
 *   resolution: { x, y } - number of grid samples per axis.
 *   skyRange:   { x: [min, max], y: [min, max] } - only the span (max - min) of
 *               each range affects sampling; the sample grid is centred on 0.
 *
 * Results are stored in this.data[x][y] as [r, g, b] (roughly 0..1), or null for
 * cells outside the sky dome / not yet computed.
 */
export class SkyCalculator {
  constructor(p5, config) {
    this.p5 = p5
    this.location = config.location

    this.resolution = config.resolution
    this.skyRange = config.skyRange

    // Start with an empty (all-null) grid indexed as data[x][y]
    this.data = []
    for (let x = 0; x < this.resolution.x; x++) {
      this.data[x] = []
      for (let y = 0; y < this.resolution.y; y++) {
        this.data[x][y] = null
      }
    }
  }

  /**
   * Computes the sun direction for `datetime` at this.location and fills this.data.
   * @param {Date} datetime - passed straight to SunCalc.getPosition (presumably a Date)
   * @param {{ wait: () => Promise<void> }} [resumable] - optional; awaited every 50 dome samples
   * @returns {Promise<void>} resolves once every dome cell has been computed
   */
  async calculate(datetime, resumable) {
    const { azimuth, altitude: elevation } = SunCalc.getPosition(datetime, this.location.lat, this.location.long)

    // SunCalc returns radians. Use Math.* instead of this.p5.* so the result does not
    // depend on the sketch's angleMode (p5's trig functions treat their argument as
    // degrees when angleMode(DEGREES) is active).
    const cosElevation = Math.cos(elevation)
    const sunDir = new p5.Vector(
      Math.cos(azimuth) * cosElevation,
      Math.sin(elevation),
      Math.sin(azimuth) * cosElevation
    )

    return this.calculateSkyDome(sunDir, resumable)
  }

  /**
   * Fills this.data with tone-mapped sky colors seen from a camera at altitude
   * this.location.alt, looking up at the sky dome projected onto the unit disc.
   * Cells outside the unit disc are left untouched.
   * @param {p5.Vector} sunDir - direction towards the sun
   * @param {{ wait: () => Promise<void> }} [resumable] - optional; awaited every 50 dome samples
   */
  async calculateSkyDome(sunDir, resumable) {
    const atmosphere = new Atmosphere(sunDir)
    const camera = new p5.Vector(0, atmosphere.earthRadius + this.location.alt, 0)

    const [xMin, xMax] = this.skyRange.x
    const [yMin, yMax] = this.skyRange.y
    const xSpan = xMax - xMin
    const ySpan = yMax - yMin

    let iterCount = 0

    for (let j = 0; j < this.resolution.y; ++j) {
      // NOTE(review): only the span of skyRange is used, so the grid is always centred
      // on 0 regardless of xMin/yMin - confirm that asymmetric ranges are not expected.
      const y = (j + 0.5) / (this.resolution.y - 1) * ySpan - ySpan / 2

      for (let i = 0; i < this.resolution.x; ++i) {
        const x = (i + 0.5) / (this.resolution.x - 1) * xSpan - xSpan / 2
        const z2 = x * x + y * y

        // Points outside the unit disc are not part of the sky dome.
        if (z2 > 1) continue

        // Map the disc point to a view direction: phi is the angle around the
        // vertical (y) axis and theta the angle from it, with cos(theta) = 1 - z2.
        const phi = Math.atan2(y, x)
        const theta = Math.acos(1 - z2)
        const sinTheta = Math.sin(theta)
        const dir = new p5.Vector(sinTheta * Math.cos(phi), Math.cos(theta), sinTheta * Math.sin(phi))
        const p = atmosphere.computeIncidentLight(camera, dir, 0, Infinity)

        this.data[i][j] = [toneMapChannel(p.x), toneMapChannel(p.y), toneMapChannel(p.z)]

        // Give the optional resumable a chance to pause every 50 samples
        // (presumably to keep the page responsive during long computations).
        iterCount += 1
        if (resumable && iterCount % 50 === 0) {
          await resumable.wait()
        }
      }
    }
  }

  /**
   * Samples the grid at fractional grid coordinates (x, y), bilinearly
   * interpolating between the four surrounding cells (clamped to the grid).
   * @returns {number[]} [r, g, b] scaled by 255. Null cells are skipped by the
   *   interpolation; if all four cells are null, each channel is 0 (null * 255).
   */
  getColor(x, y) {
    const x1 = this.p5.floor(x)
    const y1 = this.p5.floor(y)
    const fx = x - x1
    const fy = y - y1

    // Corners in the a b / c d layout expected by bilinearInterpolate.
    const a = this.getDataAt(x1, y1)
    const b = this.getDataAt(x1 + 1, y1)
    const c = this.getDataAt(x1, y1 + 1)
    const d = this.getDataAt(x1 + 1, y1 + 1)

    return [0, 1, 2].map((ch) => bilinearInterpolate(a[ch], b[ch], c[ch], d[ch], fx, fy) * 255)
  }

  /**
   * Returns the stored [r, g, b] for a grid cell without interpolation.
   * Coordinates outside the grid are clamped to the nearest edge.
   * @returns {Array} the stored array itself (not a copy), or [null, null, null]
   *   if the cell has no data.
   */
  getDataAt(x, y) {
    const gx = this.p5.constrain(x, 0, this.resolution.x - 1)
    const gy = this.p5.constrain(y, 0, this.resolution.y - 1)

    const cell = this.data[gx][gy]
    return cell == null ? [null, null, null] : cell
  }

  /**
   * Draws every computed grid cell as a filled, unstroked circle inside the
   * rectangle whose top-left corner is (x, y) and whose size is w by h.
   * @param {number} x - left edge of the drawing area
   * @param {number} y - top edge of the drawing area
   * @param {number} w - width of the drawing area
   * @param {number} h - height of the drawing area
   * @param {number} pointSize - circle diameter passed to p5 circle()
   * @param {number} pointAlpha - fill alpha (presumably 0..255 in p5's default color mode)
   */
  simpleRender(x, y, w, h, pointSize, pointAlpha) {
    // Stroke state does not change per point, so set it once.
    this.p5.noStroke()

    for (let gridX = 0; gridX < this.resolution.x; gridX++) {
      for (let gridY = 0; gridY < this.resolution.y; gridY++) {
        const dataPoint = this.data[gridX][gridY]
        if (dataPoint == null) continue

        const [r, g, b] = dataPoint
        const screenX = gridX / this.resolution.x * w + x
        const screenY = gridY / this.resolution.y * h + y

        this.p5.fill(r * 255, g * 255, b * 255, pointAlpha)
        this.p5.circle(screenX, screenY, pointSize)
      }
    }
  }
}

// Bilinearly interpolates a 2x2 cell laid out as:
//   a b
//   c d
// x blends horizontally (a->b and c->d), y blends the top row into the bottom row;
// both are expected to lie in [0, 1].
// Missing (null/undefined) corners are tolerated: a missing value is replaced by its
// partner in the same row, and a missing row by the other row. The result is then a
// smooth fill rather than a true interpolation. The result is nullish only when all
// four corners are missing.
function bilinearInterpolate(a, b, c, d, x, y) {
  // Blend two values by t, falling back to whichever one is present.
  const blend = (from, to, t) => {
    if (from == null) return to
    if (to == null) return from
    return lerp(from, to, t)
  }

  const top = blend(a, b, x)
  const bottom = blend(c, d, x)
  return blend(top, bottom, y)
}

// Linear interpolation: yields `start` at amt = 0 and `stop` at amt = 1.
// amt is not clamped, so values outside [0, 1] extrapolate.
function lerp(start, stop, amt) {
  const delta = stop - start;
  return start + delta * amt;
}

/**
 * Single-scattering sky model (Rayleigh + Mie) around a spherical planet centred
 * at the origin. Radii and scale heights are presumably metres (Earth radius 6360e3).
 */
class Atmosphere {
  /**
   * @param {p5.Vector} sunDirection - direction towards the sun (presumably unit length)
   * @param {object} [options]
   * @param {number} [options.numSamples=16] - integration steps along each view ray
   * @param {number} [options.numSamplesLight=8] - integration steps along each ray towards the sun
   */
  constructor(sunDirection, { numSamples = 16, numSamplesLight = 8 } = {}) {
    this.sunDirection = sunDirection;
    this.earthRadius = 6360e3;
    this.atmosphereRadius = 6420e3;
    this.hr = 7994; // Thickness of the atmosphere if density was uniform (Hr)
    this.hm = 1200; // Same as above but for Mie scattering (Hm)
    this.betaR = new p5.Vector(3.8e-6, 13.5e-6, 33.1e-6); // Rayleigh coefficients per RGB channel
    this.betaM = new p5.Vector(21e-6, 21e-6, 21e-6); // Mie coefficients (same for all channels)
    this.numSamples = numSamples;
    this.numSamplesLight = numSamplesLight;
  }

  /**
   * Integrates in-scattered sunlight along the ray orig + t * dir.
   * @param {p5.Vector} orig - ray origin in planet-centred coordinates
   * @param {p5.Vector} dir - ray direction (presumably unit length)
   * @param {number} tmin - start of the ray interval to integrate
   * @param {number} tmax - end of the ray interval (Infinity = up to the atmosphere's edge)
   * @returns {p5.Vector} linear RGB radiance; a zero vector if the ray misses the atmosphere
   */
  computeIncidentLight(orig, dir, tmin, tmax) {
    const [intersectSuccess, t0, t1] = raySphereIntersect(orig, dir, this.atmosphereRadius);

    // No light if the ray misses the atmosphere or it lies entirely behind the ray.
    // Always return a vector so callers can read .x/.y/.z unconditionally.
    if (!intersectSuccess || t1 < 0) return new p5.Vector();

    // Clip [tmin, tmax] to the part of the ray inside the atmosphere.
    if (t0 > tmin && t0 > 0) tmin = t0;
    if (t1 < tmax) tmax = t1;

    const numSamples = this.numSamples;
    const numSamplesLight = this.numSamplesLight;
    const segmentLength = (tmax - tmin) / numSamples;
    let tCurrent = tmin;

    const sumR = new p5.Vector(); // accumulated Rayleigh in-scattering
    const sumM = new p5.Vector(); // accumulated Mie in-scattering
    let opticalDepthR = 0;
    let opticalDepthM = 0;

    // mu is the cosine of the angle between the view ray and the sun direction
    const mu = dir.dot(this.sunDirection);
    const phaseR = (3 / (16 * Math.PI)) * (1 + mu * mu);
    const g = 0.76; // Mie anisotropy parameter
    const phaseM =
      ((3 / (8 * Math.PI)) * ((1 - g * g) * (1 + mu * mu))) /
      ((2 + g * g) * Math.pow(1 + g * g - 2 * g * mu, 1.5));

    for (let i = 0; i < numSamples; ++i) {
      const samplePosition = p5.Vector.add(
        orig,
        p5.Vector.mult(dir, tCurrent + segmentLength * 0.5)
      );
      const height = samplePosition.mag() - this.earthRadius;

      // Optical depth contributed by this view-ray segment.
      const hr = Math.exp(-height / this.hr) * segmentLength;
      const hm = Math.exp(-height / this.hm) * segmentLength;
      opticalDepthR += hr;
      opticalDepthM += hm;

      // March from the sample towards the sun to get the light's optical depth.
      const [, , t1Light] = raySphereIntersect(samplePosition, this.sunDirection, this.atmosphereRadius);
      const segmentLengthLight = t1Light / numSamplesLight;
      let tCurrentLight = 0;
      let opticalDepthLightR = 0;
      let opticalDepthLightM = 0;
      let reachesSun = true;

      for (let j = 0; j < numSamplesLight; ++j) {
        const samplePositionLight = p5.Vector.add(
          samplePosition,
          p5.Vector.mult(this.sunDirection, tCurrentLight + segmentLengthLight * 0.5)
        );
        const heightLight = samplePositionLight.mag() - this.earthRadius;
        // The light ray goes below the ground, so the planet shadows this sample.
        if (heightLight < 0) {
          reachesSun = false;
          break;
        }
        opticalDepthLightR += Math.exp(-heightLight / this.hr) * segmentLengthLight;
        opticalDepthLightM += Math.exp(-heightLight / this.hm) * segmentLengthLight;
        tCurrentLight += segmentLengthLight;
      }

      if (reachesSun) {
        // Extinction along camera -> sample -> sun; the Mie term is scaled by 1.1.
        const tauR = p5.Vector.mult(this.betaR, opticalDepthR + opticalDepthLightR);
        const tauM = p5.Vector.mult(this.betaM, 1.1 * (opticalDepthM + opticalDepthLightM));
        const tau = tauR.add(tauM);

        const attenuation = new p5.Vector(Math.exp(-tau.x), Math.exp(-tau.y), Math.exp(-tau.z));
        sumR.add(p5.Vector.mult(attenuation, hr));
        sumM.add(p5.Vector.mult(attenuation, hm));
      }
      tCurrent += segmentLength;
    }

    const rayleigh = p5.Vector.mult(sumR, this.betaR).mult(phaseR);
    const mie = p5.Vector.mult(sumM, this.betaM).mult(phaseM);

    // Fixed scale factor of 20 (presumably the sun's intensity).
    return rayleigh.add(mie).mult(20);
  }
}

/**
 * Finds the real roots of a*x^2 + b*x + c = 0.
 * @param {number} a
 * @param {number} b
 * @param {number} c
 * @returns {[boolean, number, number]} [solved, x1, x2]. When solved is false both
 *   roots are 0. The roots are not guaranteed to be in ascending order.
 */
function solveQuadratic(a, b, c) {
  if (a === 0) {
    // Degenerate, linear equation b*x + c = 0.
    if (b === 0) return [false, 0, 0];
    const root = -c / b;
    return [true, root, root];
  }

  if (b === 0) {
    // a*x^2 + c = 0  =>  x = +/- sqrt(-c / a); there are no real roots when -c / a < 0.
    const ratio = -c / a;
    if (ratio < 0) return [false, 0, 0];
    const root = Math.sqrt(ratio);
    return [true, -root, root];
  }

  const discr = b * b - 4 * a * c;
  if (discr < 0) return [false, 0, 0];

  // Numerically stable form: pick the sign that avoids cancellation in q, then get
  // the second root from the product of roots (x1 * x2 = c / a).
  const q = b < 0 ? -0.5 * (b - Math.sqrt(discr)) : -0.5 * (b + Math.sqrt(discr));
  return [true, q / a, c / q];
}

/**
 * Intersects the ray orig + t * dir with a sphere of the given radius centred at the origin.
 * @param {{x: number, y: number, z: number}} orig - ray origin
 * @param {{x: number, y: number, z: number}} dir - ray direction
 * @param {number} radius - sphere radius
 * @returns {[boolean, number, number]} [hit, tNear, tFar]; on a hit the two distances
 *   are swapped if needed so the smaller one comes first.
 */
function raySphereIntersect(orig, dir, radius) {
  const dot = (u, v) => u.x * v.x + u.y * v.y + u.z * v.z;

  // |orig + t * dir|^2 = radius^2 expands to quadA * t^2 + quadB * t + quadC = 0.
  const quadA = dot(dir, dir);
  const quadB = 2 * dot(dir, orig);
  const quadC = dot(orig, orig) - radius * radius;

  const [solved, first, second] = solveQuadratic(quadA, quadB, quadC);
  if (!solved) return [false, first, second];

  return first > second ? [true, second, first] : [true, first, second];
}
