import { useEffect, useState } from 'react';
import {
  unpackCascade,
  DetectionBox,
  ClassifyRegionFunction,
  clusterDetections,
  runCascade,
  UpdateMemoryFunction,
  instantiateDetectionMemory,
  useRefValue,
} from '../index';

/**
 * Floor for the `minsize` option of getDetections: smaller requested values
 * are raised to this before being passed to runCascade.
 */
const MINSIZE_MINIMUM = 20;

/**
 * Memoized load of the face-finder cascade. Holds the in-flight (or resolved)
 * promise so concurrent callers share a single download; reset to null when a
 * load fails so a later call retries instead of replaying the error forever.
 */
let regionClassifierPromise: Promise<ClassifyRegionFunction> | null = null;

/**
 * Downloads `/models/facefinder.cascade` and unpacks it into a region classifier.
 *
 * @throws Error if the server answers with a non-2xx status.
 */
async function loadRegionClassifier(): Promise<ClassifyRegionFunction> {
  const cascadeUrl = '/models/facefinder.cascade';
  const response = await fetch(cascadeUrl);
  // Without this check an HTML error page would be unpacked as a cascade.
  if (!response.ok) {
    throw new Error(
      `Failed to load face cascade from ${cascadeUrl}: HTTP ${response.status}`
    );
  }
  const buffer = await response.arrayBuffer();
  return unpackCascade(new Int8Array(buffer));
}

/**
 * Returns the shared region classifier, downloading and unpacking the cascade
 * on first use. All callers receive the same classifier instance.
 *
 * @returns a promise for the classifier
 * @throws Error (as a rejected promise) if the cascade cannot be loaded
 */
export async function getRegionClassifier(): Promise<ClassifyRegionFunction> {
  if (regionClassifierPromise) return regionClassifierPromise;

  const pending = loadRegionClassifier();
  regionClassifierPromise = pending;
  // Forget a failed load (but only if it is still the current one) so the
  // next call can retry.
  pending.catch(() => {
    if (regionClassifierPromise === pending) regionClassifierPromise = null;
  });
  return pending;
}

/**
 * Converts interleaved RGBA pixels into a single-channel grayscale buffer.
 *
 * Luminance is approximated with the integer weights 2:7:1 for red:green:blue
 * (0.2R + 0.7G + 0.1B); alpha is ignored. Fractional results are truncated
 * when stored into the Uint8Array.
 *
 * @param rgba - row-major pixel data, 4 bytes (R, G, B, A) per pixel
 * @param nrows - image height in pixels
 * @param ncols - image width in pixels
 * @returns row-major grayscale values, one byte per pixel
 */
function rgbaToGrayscale(
  rgba: Uint8ClampedArray,
  nrows: number,
  ncols: number
) {
  const pixelCount = nrows * ncols;
  const luminance = new Uint8Array(pixelCount);
  for (let px = 0, src = 0; px < pixelCount; px += 1, src += 4) {
    const red = rgba[src];
    const green = rgba[src + 1];
    const blue = rgba[src + 2];
    luminance[px] = (2 * red + 7 * green + 1 * blue) / 10;
  }
  return luminance;
}

/**
 * Tuning options for getDetections. All fields are optional at the call site
 * (getDetections takes a `Partial` of this); the defaults are listed per field.
 */
interface IGetDetectionsConfig {
  /** Step of the sliding detection window, as a fraction of its size (default 0.1, i.e. 10%). */
  shiftfactor: number;

  /** Minimum size of a face (default 100); values below MINSIZE_MINIMUM are raised to it. */
  minsize: number;

  /** Maximum size of a face (default 1000). */
  maxsize: number;

  /**
   * Multiscale processing: ratio by which the detection window is enlarged
   * when moving to the next scale (default 1.1, i.e. 10%).
   */
  scalefactor: number;
}

/**
 * Detects faces in a single RGBA frame.
 *
 * Pipeline: grayscale conversion -> cascade scan (runCascade) ->
 * `updateMemory` -> clustering with an IoU threshold of 0.2.
 * `updateMemory` presumably merges this frame's raw detections with those
 * remembered from earlier frames — confirm against instantiateDetectionMemory.
 *
 * @param imageData - the frame to scan
 * @param updateMemory - detection memory callback, e.g. from instantiateDetectionMemory
 * @param config - optional tuning; `minsize` is never allowed below MINSIZE_MINIMUM
 * @returns the clustered detections
 */
export async function getDetections(
  imageData: ImageData,
  updateMemory: UpdateMemoryFunction,
  {
    shiftfactor = 0.1,
    minsize = 100,
    maxsize = 1000,
    scalefactor = 1.1,
  }: Partial<IGetDetectionsConfig> = {}
) {
  const classifyRegion = await getRegionClassifier();

  const ncols = imageData.width;
  const nrows = imageData.height;
  const grayImage = {
    pixels: rgbaToGrayscale(imageData.data, nrows, ncols),
    nrows,
    ncols,
    ldim: ncols,
  };

  const cascadeOptions = {
    shiftfactor,
    minsize: Math.max(minsize, MINSIZE_MINIMUM),
    maxsize,
    scalefactor,
  };
  // Raw per-window hits; per the original notes each is (row, column, scale, score).
  const rawDetections = runCascade(grayImage, classifyRegion, cascadeOptions);
  const rememberedDetections = updateMemory(rawDetections);

  // Merge overlapping windows into single detections (IoU threshold 0.2).
  return clusterDetections(rememberedDetections, 0.2);
}

/**
 * Strokes the square outline of one detection onto a 2D canvas context.
 *
 * The square is centred on (centerX, centerY) with side length `scale`; all
 * coordinates are multiplied by `resizeBy`, e.g. to map a detection found on a
 * downscaled frame back onto a full-resolution canvas.
 *
 * @param canvasContext - context to draw on (its lineWidth/strokeStyle are overwritten)
 * @param detectionBox - detection to outline
 * @param options - `lineWidth` (default 2) and `boxColor` (any CSS color)
 * @param resizeBy - factor applied to every coordinate and to the side length
 * @throws Error when `detectionBox` is falsy
 */
function drawDetectionBox(
  canvasContext: CanvasRenderingContext2D,
  detectionBox: DetectionBox,
  { lineWidth = 2, boxColor }: { lineWidth?: number; boxColor: string },
  resizeBy: number
) {
  if (!detectionBox) throw new Error('No detection box found to draw!');

  const { centerX, centerY, scale: side } = detectionBox;
  const halfSide = side / 2;
  const left = (centerX - halfSide) * resizeBy;
  const top = (centerY - halfSide) * resizeBy;
  const scaledSide = side * resizeBy;

  canvasContext.beginPath();
  canvasContext.rect(left, top, scaledSide, scaledSide);
  /* eslint-disable no-param-reassign */
  canvasContext.lineWidth = lineWidth;
  canvasContext.strokeStyle = boxColor;
  /* eslint-enable no-param-reassign */
  canvasContext.stroke();
}

/**
 * Face status reported by useFaceDetection:
 * - `'NO'`: no detection reached `detectionScoreThreshold`, or the video/canvas is missing.
 * - `'FAR'`: a face was detected but is smaller than `validRangeMinFaceSize`.
 * - `'YES'`: a face was detected with size >= `validRangeMinFaceSize`.
 */
export type FaceStatus = 'NO' | 'FAR' | 'YES';

/** Parameters accepted by useFaceDetection. */
export interface UseFaceDetectionParams {
  /** Video element frames are read from; detection is idle while null. */
  inputVideo: HTMLVideoElement | null;
  /** Canvas the detection box is drawn on; it is sized to the video's intrinsic resolution. */
  outputCanvas: HTMLCanvasElement | null;
  /** Minimum face size, in input-video pixels, for a detection to be 'YES' rather than 'FAR'. Default 200. */
  validRangeMinFaceSize?: number;
  /** Smallest face size, in input-video pixels, the detector searches for. Default 120. */
  detectionMinFaceSize?: number;
  /** Minimum score the first clustered detection must reach to be accepted; between 10 and 300, default 40. */
  detectionScoreThreshold?: number;
  /** Whether to draw the accepted detection's box on `outputCanvas`. Default true. */
  shouldDrawDetection?: boolean;
  /** Called on every frame with an accepted detection ('FAR' or 'YES'); never called with 'NO'. */
  onDetection?: (faceStatus: FaceStatus, detection: DetectionBox) => void;
  /** We use the detections of the last n frames, where n is framesMemorySize. Read only on the first render; later changes are ignored. Default 5. */
  framesMemorySize?: number;
  /** Detection runs on a 1/n-resolution copy of the input for better performance, n being downscaleBy. Default 3. */
  downscaleBy?: number;
}

/**
 * React hook that continuously runs face detection on `inputVideo`.
 *
 * On every animation frame the current video frame is drawn onto an offscreen
 * canvas downscaled by `downscaleBy` and passed to getDetections. The first
 * resulting detection is accepted if its score reaches `detectionScoreThreshold`.
 * The outcome is:
 * - exposed as `faceStatus`;
 * - optionally drawn onto `outputCanvas` (translucent red box for 'FAR', white
 *   for 'YES');
 * - reported through `onDetection`.
 *
 * The loop idles until the video reports a non-zero intrinsic size (e.g. once
 * its metadata has loaded). It resizes both canvases whenever that size changes.
 *
 * @returns `{ faceStatus }`, the status of the most recently processed frame
 */
export function useFaceDetection({
  inputVideo,
  outputCanvas,
  validRangeMinFaceSize = 200,
  detectionMinFaceSize = 120,
  detectionScoreThreshold = 40,
  shouldDrawDetection = true,
  onDetection,
  framesMemorySize = 5,
  downscaleBy = 3,
}: UseFaceDetectionParams) {
  // Created once: later changes to framesMemorySize are intentionally ignored.
  const [updateMemoryCb] = useState(() =>
    instantiateDetectionMemory(framesMemorySize)
  );
  // Values read through refs so the loop sees updates without restarting.
  const onDetectionRef = useRefValue(onDetection);
  const downscaledDetectionMinFaceSizeRef = useRefValue(
    detectionMinFaceSize / downscaleBy
  );
  const downscaledValidMinFaceSizeRef = useRefValue(
    validRangeMinFaceSize / downscaleBy
  );
  const [faceStatus, setFaceStatus] = useState<FaceStatus>('NO');

  // setup
  useEffect(() => {
    if (!inputVideo || !outputCanvas) {
      setFaceStatus('NO');
    }
  }, [inputVideo, outputCanvas]);

  // run
  useEffect(() => {
    if (!inputVideo || !outputCanvas) return undefined;

    // Non-null aliases, so the async loop below does not depend on narrowing
    // of the hook's parameters surviving into nested closures.
    const video = inputVideo;
    const canvas = outputCanvas;

    // run face detection on smaller canvas for performance
    const inputCanvas = document.createElement('canvas');
    const outputCanvasCtx = canvas.getContext('2d');
    const inputCanvasCtx = inputCanvas.getContext('2d');
    if (!inputCanvasCtx || !outputCanvasCtx) return undefined;

    let isMounted = true;
    let rafId: number | null = null;

    // Intrinsic video size the canvases are currently sized for (0 = not yet).
    let videoWidth = 0;
    let videoHeight = 0;
    let inputCanvasWidth = 0;
    let inputCanvasHeight = 0;

    // Resize both canvases when the video's intrinsic size changes. Assigning
    // width/height clears a canvas, so only do it on an actual change.
    const syncCanvasSizes = () => {
      if (
        video.videoWidth === videoWidth &&
        video.videoHeight === videoHeight
      ) {
        return;
      }
      videoWidth = video.videoWidth;
      videoHeight = video.videoHeight;
      inputCanvasWidth = videoWidth / downscaleBy;
      inputCanvasHeight = videoHeight / downscaleBy;
      canvas.width = videoWidth;
      canvas.height = videoHeight;
      inputCanvas.width = inputCanvasWidth;
      inputCanvas.height = inputCanvasHeight;
    };

    const loop = async () => {
      try {
        if (!isMounted) return;

        syncCanvasSizes();
        // Until the video knows its size there is nothing to read, and
        // getImageData would throw IndexSizeError on a zero-sized region,
        // which previously ended the loop for good. Just try again next frame.
        if (inputCanvas.width === 0 || inputCanvas.height === 0) {
          rafId = window.requestAnimationFrame(loop);
          return;
        }

        inputCanvasCtx.drawImage(
          video,
          0,
          0,
          videoWidth,
          videoHeight,
          0,
          0,
          inputCanvasWidth,
          inputCanvasHeight
        );

        const imageData = inputCanvasCtx.getImageData(
          0,
          0,
          inputCanvasWidth,
          inputCanvasHeight
        );
        const detections = await getDetections(imageData, updateMemoryCb, {
          maxsize: Math.max(inputCanvasWidth, inputCanvasHeight),
          minsize: downscaledDetectionMinFaceSizeRef.current,
        });
        if (!isMounted) return;

        const best = detections[0];
        const validDetection =
          best != null && best.score >= detectionScoreThreshold ? best : null;

        outputCanvasCtx.clearRect(0, 0, videoWidth, videoHeight);
        if (!validDetection) {
          setFaceStatus('NO');
        } else {
          // detection scale is in downscaled pixels, hence the downscaled threshold
          const isFar =
            validDetection.scale < downscaledValidMinFaceSizeRef.current;
          const status: FaceStatus = isFar ? 'FAR' : 'YES';
          if (shouldDrawDetection) {
            drawDetectionBox(
              outputCanvasCtx,
              validDetection,
              { boxColor: isFar ? '#f003' : '#fff' },
              downscaleBy
            );
          }
          setFaceStatus(status);
          onDetectionRef.current?.(status, validDetection);
        }
        rafId = window.requestAnimationFrame(loop);
      } catch (err) {
        // Unexpected failure: log it and stop the loop.
        console.error(err);
      }
    };

    // loop() handles its own errors, so the returned promise never rejects.
    void loop();

    return () => {
      isMounted = false;
      if (rafId !== null) window.cancelAnimationFrame(rafId);
    };
  }, [
    inputVideo,
    outputCanvas,
    shouldDrawDetection,
    detectionScoreThreshold,
    downscaledValidMinFaceSizeRef,
    downscaledDetectionMinFaceSizeRef,
    onDetectionRef,
    updateMemoryCb,
    downscaleBy,
  ]);

  return { faceStatus };
}
