import { useRef, useEffect, useMemo } from "react";
import { Canvas, useFrame, useThree } from "@react-three/fiber";
import {
  type ShaderMaterial,
  Vector2,
  DataTexture,
  RGBAFormat,
  FloatType,
} from "three";
import * as THREE from "three";
import { EffectComposer, Bloom, Vignette } from "@react-three/postprocessing";

/**
 * GLSL sources for the aurora effect.
 *
 * - Vertex shader: displaces plane vertices with time-driven sine/cosine waves.
 * - Fragment shader: blends two palette colours, each cross-fading through the
 *   palette over time, using a 3-octave value-noise fbm as the blend mask.
 *
 * Uniforms: `time` (elapsed seconds), `resolution` (only used to correct the
 * noise coordinates for aspect ratio), `colorPalette` (a 1-texel-tall palette
 * texture) and `colorCount` (the number of palette texels).
 */
const auroraShader = {
  vertexShader: `
    uniform float time;
    varying vec2 vUv;

    void main() {
      vUv = uv;
      vec3 pos = position;
      pos.y += sin(pos.x * 5.0 + time * 2.0) * 0.5;
      pos.z += cos(pos.x * 5.0 + time * 2.0) * 0.2;

      gl_Position = projectionMatrix * modelViewMatrix * vec4(pos, 1.0);
    }
  `,
  fragmentShader: `
    uniform float time;
    uniform vec2 resolution;
    uniform sampler2D colorPalette;
    uniform float colorCount;

    varying vec2 vUv;

    float random(vec2 p) {
      return fract(sin(dot(p.xy, vec2(12.9898, 78.233))) * 43758.5453);
    }

    float noise(vec2 p) {
      vec2 i = floor(p);
      vec2 f = fract(p);
      vec2 u = f * f * (3.0 - 2.0 * f);

      return mix(
        mix(random(i), random(i + vec2(1.0, 0.0)), u.x),
        mix(random(i + vec2(0.0, 1.0)), random(i + vec2(1.0, 1.0)), u.x),
        u.y
      );
    }

    float fbm(vec2 p) {
      float value = 0.0;
      float amplitude = 0.5;
      float frequency = 1.0;

      for (int i = 0; i < 3; i++) {
        value += amplitude * noise(p * frequency);
        frequency *= 2.0;
        amplitude *= 0.5;
      }
      return value;
    }

    vec3 getColorFromPalette(float index) {
      // Guard against an empty palette (mod / division by zero).
      float count = max(colorCount, 1.0);
      // Wrap around the palette so indices past the end start over.
      float wrappedIndex = mod(index, count);
      // Sample the centre of the chosen texel. Sampling at i / count lands
      // exactly on the border between texels i-1 and i, where nearest
      // filtering (the DataTexture default) may return either colour.
      return texture2D(colorPalette, vec2((floor(wrappedIndex) + 0.5) / count, 0.5)).rgb;
    }

    vec3 getSmoothColor(float timeOffset) {
      float slowTime = time * 0.1; // Slow down the color transitions
      float baseIndex = slowTime;

      // Get fractional part for smooth interpolation
      float fraction = fract(baseIndex);

      // Get the two colors to interpolate between
      vec3 color1 = getColorFromPalette(floor(baseIndex) + timeOffset);
      vec3 color2 = getColorFromPalette(floor(baseIndex) + 1.0 + timeOffset);

      // Smooth interpolation between colors
      float smoothFraction = smoothstep(0.0, 1.0, fraction);
      return mix(color1, color2, smoothFraction);
    }

    void main() {
      // Aspect-correct noise coordinates; the max() avoids dividing by zero
      // if the resolution uniform is ever zero-sized.
      vec2 st = vUv * resolution.xy / max(min(resolution.x, resolution.y), 1.0);
      vec3 color = vec3(0.0);

      float timeShift = time * 0.05;
      float n = fbm(st + timeShift);

      // Get two smoothly interpolated colors
      vec3 colorA = getSmoothColor(0.0);
      vec3 colorB = getSmoothColor(colorCount * 0.5); // Offset by half the palette

      // Add some variation based on sine waves
      float sineFactor = sin(time * 0.2) * 0.5 + 0.5;
      colorA = mix(colorA, getSmoothColor(2.0), sineFactor);
      colorB = mix(colorB, getSmoothColor(colorCount * 0.5 + 2.0), sineFactor);

      // Smooth final color mixing
      color = mix(colorA, colorB, smoothstep(0.3, 0.8, n));
      gl_FragColor = vec4(color, 1.0);
    }
  `,
};

/**
 * Packs a list of colour strings into an RGBA float texture that is
 * `colors.length` texels wide and one texel tall. There is one texel per
 * colour, and alpha is always 1.
 *
 * @param colors Colour strings in any form `THREE.Color` can parse.
 * @returns A DataTexture already flagged for upload (`needsUpdate = true`).
 */
const createColorTexture = (colors: string[]) => {
  const rgba = new Float32Array(
    colors.flatMap((css) => {
      const { r, g, b } = new THREE.Color(css);
      return [r, g, b, 1.0];
    })
  );

  const paletteTexture = new DataTexture(
    rgba,
    colors.length,
    1,
    RGBAFormat,
    FloatType
  );
  paletteTexture.needsUpdate = true;

  return paletteTexture;
};

/**
 * Shader material for the aurora plane.
 *
 * The uniform objects are created once and then mutated in place:
 * - `time` advances every frame;
 * - `resolution` follows the canvas size;
 * - the palette texture is rebuilt only when the palette's contents change.
 *
 * Replaced textures, and the last texture on unmount, are disposed so their
 * GPU memory is released.
 *
 * @param colorPalette Colour strings the aurora cycles through.
 */
const AuroraMaterial = ({ colorPalette }: { colorPalette: string[] }) => {
  const materialRef = useRef<ShaderMaterial>(null);
  const { width, height } = useThree((state) => state.size);

  // Compare palettes by content, so a fresh array holding the same colours
  // (e.g. an inline literal in a parent) does not rebuild the texture.
  const paletteKey = JSON.stringify(colorPalette);
  const appliedPaletteKey = useRef(paletteKey);

  // This object is handed to the ShaderMaterial constructor below, so the
  // material holds this same uniforms object; mutating it updates the shader.
  const uniforms = useMemo(
    () => ({
      time: { value: 0 },
      resolution: { value: new Vector2(width, height) },
      colorPalette: { value: createColorTexture(colorPalette) },
      colorCount: { value: colorPalette.length },
    }),
    // Built once on mount; later changes are written into these objects below.
    // eslint-disable-next-line react-hooks/exhaustive-deps
    []
  );

  // R3F reconstructs an object when an `args` entry changes identity. An
  // inline object literal would therefore rebuild the material on every
  // re-render, so keep the constructor parameters referentially stable.
  const materialParams = useMemo(
    () => ({
      vertexShader: auroraShader.vertexShader,
      fragmentShader: auroraShader.fragmentShader,
      uniforms,
    }),
    [uniforms]
  );

  useFrame(({ clock }) => {
    if (materialRef.current) {
      materialRef.current.uniforms.time.value = clock.getElapsedTime();
    }
  });

  // Keep the aspect-correction uniform in sync with the canvas size.
  useEffect(() => {
    uniforms.resolution.value.set(width, height);
  }, [uniforms, width, height]);

  // Swap in a new palette texture when the colours change, freeing the old one.
  useEffect(() => {
    if (appliedPaletteKey.current === paletteKey) return;
    appliedPaletteKey.current = paletteKey;

    const previous = uniforms.colorPalette.value;
    uniforms.colorPalette.value = createColorTexture(colorPalette);
    uniforms.colorCount.value = colorPalette.length;
    previous.dispose();
  }, [uniforms, colorPalette, paletteKey]);

  // Release whichever palette texture is current when the material unmounts.
  useEffect(
    () => () => {
      uniforms.colorPalette.value.dispose();
    },
    [uniforms]
  );

  return (
    <shaderMaterial
      ref={materialRef}
      attach="material"
      args={[materialParams]}
    />
  );
};

/**
 * Plane mesh that carries the aurora material.
 *
 * The plane is 10 world units tall. Its width is scaled so that its aspect
 * ratio matches the canvas. 16x16 segments give the vertex-shader
 * displacement vertices to move.
 *
 * @param colorPalette Colour strings forwarded to the aurora material.
 */
const AuroraPlane = ({ colorPalette }: { colorPalette: string[] }) => {
  // Subscribe only to the canvas size instead of the whole R3F store.
  const { width: canvasWidth, height: canvasHeight } = useThree(
    (state) => state.size
  );

  const [planeWidth, planeHeight] = useMemo((): [number, number] => {
    const height = 10;
    // A zero-height canvas (e.g. a hidden container) would give an infinite
    // aspect and NaN vertex positions; fall back to a square plane until the
    // canvas has a real size.
    const aspect = canvasHeight > 0 ? canvasWidth / canvasHeight : 1;
    return [height * aspect, height];
  }, [canvasWidth, canvasHeight]);

  return (
    <mesh>
      <planeGeometry args={[planeWidth, planeHeight, 16, 16]} />
      <AuroraMaterial colorPalette={colorPalette} />
    </mesh>
  );
};

/**
 * Canvas that renders the aurora plane through a post-processing chain:
 * a low-intensity bloom followed by a vignette.
 *
 * @param colorPalette Colour strings the aurora cycles through.
 */
const AuroraGPU = ({ colorPalette }: { colorPalette: string[] }) => (
  <Canvas
    className="size-full overflow-hidden"
    onCreated={({ camera, size }) => {
      // Only a perspective camera has an aspect ratio to synchronise.
      if (!(camera instanceof THREE.PerspectiveCamera)) return;
      camera.aspect = size.width / size.height;
      camera.updateProjectionMatrix();
    }}
  >
    <AuroraPlane colorPalette={colorPalette} />
    <EffectComposer>
      <Bloom intensity={0.1} luminanceThreshold={0.6} />
      <Vignette eskil={false} offset={0.1} darkness={1.1} />
    </EffectComposer>
  </Canvas>
);

export default AuroraGPU;
