/*
    wearechat - generic multiplayer game with voip
    Copyright (C) 2025 metamuffin
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU Affero General Public License as published by
    the Free Software Foundation, version 3 of the License only.
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Affero General Public License for more details.
    You should have received a copy of the GNU Affero General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
*/
use crate::download::Downloader;
use anyhow::Result;
use egui::{Grid, Widget};
use humansize::DECIMAL;
use image::ImageReader;
use log::debug;
use std::{
    collections::{HashMap, HashSet},
    hash::Hash,
    io::Cursor,
    marker::PhantomData,
    sync::{Arc, RwLock},
    time::Instant,
};
use weareshared::{
    Affine3A,
    packets::Resource,
    resources::{Image, MeshPart, Prefab},
};
use wgpu::{
    AddressMode, BindGroup, BindGroupDescriptor, BindGroupEntry, BindGroupLayout, BindingResource,
    Buffer, BufferUsages, Device, Extent3d, FilterMode, Queue, SamplerDescriptor, Texture,
    TextureDescriptor, TextureDimension, TextureFormat, TextureUsages, TextureViewDescriptor,
    util::{BufferInitDescriptor, DeviceExt, TextureDataOrder},
};
/// A lock-protected cache that remembers which keys were asked for while no
/// value was available.
///
/// A lookup through [`DemandMap::try_get`] that misses records the key in a
/// "needed" set. A producer can poll [`DemandMap::needed`], create the missing
/// values and store them with [`DemandMap::insert`].
pub struct DemandMap<K, V> {
    inner: RwLock<DemandMapState<K, V>>,
}

/// State guarded by the lock inside a [`DemandMap`].
struct DemandMapState<K, V> {
    /// Values that have been inserted so far.
    values: HashMap<K, V>,
    /// Keys that were looked up while no value was present.
    needed: HashSet<K>,
    /// Running sum of the `size` arguments given to [`DemandMap::insert`].
    size_metric: usize,
}

impl<K: Hash + Eq + Clone, V: Clone> DemandMap<K, V> {
    /// Creates an empty map with nothing cached and nothing needed.
    pub fn new() -> Self {
        Self {
            inner: RwLock::new(DemandMapState {
                needed: HashSet::new(),
                values: HashMap::new(),
                size_metric: 0,
            }),
        }
    }

    /// Returns a snapshot of all keys that are currently marked as needed.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn needed(&self) -> Vec<K> {
        self.inner.read().unwrap().needed.iter().cloned().collect()
    }

    /// Stores `value` under `key`, clears the key's "needed" mark and adds
    /// `size` to the size metric.
    ///
    /// The metric is only ever increased; replacing an existing value does
    /// not subtract the size that was recorded for the old one.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn insert(&self, key: K, value: V, size: usize) {
        let mut state = self.inner.write().unwrap();
        state.needed.remove(&key);
        state.values.insert(key, value);
        state.size_metric += size;
    }

    /// Returns a clone of the value stored for `key`, or marks `key` as
    /// needed and returns `None` when there is none.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn try_get(&self, key: K) -> Option<V> {
        // A write lock is taken up front because a miss has to mutate `needed`.
        let mut state = self.inner.write().unwrap();
        if let Some(value) = state.values.get(&key) {
            return Some(value.clone());
        }
        state.needed.insert(key);
        None
    }
}
/// Holds the GPU handles and the on-demand caches that `update` fills from
/// downloaded resources.
///
/// NOTE(review): several field types below look truncated (for example `Arc`
/// without a type argument, or a stray `>`), as if their generic arguments
/// were lost. Confirm the exact types against version control before editing.
pub struct ScenePreparer {
    /// Used to create buffers, textures, samplers and bind groups.
    device: Arc,
    /// Passed to `create_texture` for uploading texture data.
    queue: Arc,
    /// Bind group layout handed to `create_texture` for every texture.
    texture_bgl: BindGroupLayout,
    /// Decoded images uploaded through `create_texture`; the value is the
    /// pair that function returns.
    textures: DemandMap>, (Arc, Arc)>,
    /// 1x1 textures keyed by a flag: `update` fills the colour channels with
    /// 255 for `true` and 0 for `false` (alpha is always 255).
    placeholder_textures: DemandMap, Arc)>,
    /// Index buffers together with the number of `u32` indices they hold.
    index_buffers: DemandMap>, (Arc, u32)>,
    /// Vertex attribute buffers together with the number of `f32` values
    /// they hold.
    vertex_buffers: DemandMap>, (Arc, u32)>,
    /// Stand-in vertex buffers that `update` requests for mesh parts lacking
    /// normals or texture coordinates.
    placeholder_vertex_buffers: DemandMap<(u32, bool), Arc>,
    /// Fully assembled mesh parts.
    mesh_parts: DemandMap, Arc>,
    /// Fully assembled prefabs; this is the only cache exposed publicly.
    pub prefabs: DemandMap, Arc>,
}
pub struct RPrefab(pub Vec<(Affine3A, Arc)>);
pub struct RMeshPart {
    pub index_count: u32,
    pub index: Arc,
    pub va_position: Arc,
    pub va_normal: Arc,
    pub va_texcoord: Arc,
    pub tex_albedo: Arc,
    pub tex_normal: Arc,
    pub double_sided: bool,
}
impl ScenePreparer {
    pub fn new(device: Arc, queue: Arc, texture_bgl: BindGroupLayout) -> Self {
        Self {
            texture_bgl,
            index_buffers: DemandMap::new(),
            vertex_buffers: DemandMap::new(),
            mesh_parts: DemandMap::new(),
            prefabs: DemandMap::new(),
            textures: DemandMap::new(),
            placeholder_vertex_buffers: DemandMap::new(),
            placeholder_textures: DemandMap::new(),
            device,
            queue,
        }
    }
    pub fn update(&self, dls: &Downloader) -> Result {
        let mut num_done = 0;
        for pres in self.prefabs.needed() {
            if let Some(prefab) = dls.try_get(pres.clone())? {
                let mut rprefab = RPrefab(Vec::new());
                for (aff, partres) in &prefab.mesh {
                    if let Some(part) = self.mesh_parts.try_get(partres.clone()) {
                        rprefab.0.push((*aff, part.clone()));
                    }
                }
                if rprefab.0.len() == prefab.mesh.len() {
                    self.prefabs.insert(pres.clone(), Arc::new(rprefab), 0);
                    debug!("prefab created ({pres})");
                    num_done += 1;
                }
            }
        }
        for pres in self.index_buffers.needed() {
            let start = Instant::now();
            if let Some(buf) = dls.try_get(pres.clone())? {
                let buf = buf
                    .into_iter()
                    .flatten()
                    .flat_map(u32::to_le_bytes)
                    .collect::>();
                let buffer = self.device.create_buffer_init(&BufferInitDescriptor {
                    contents: &buf,
                    label: None,
                    usage: BufferUsages::INDEX | BufferUsages::COPY_DST,
                });
                self.index_buffers.insert(
                    pres.clone(),
                    (Arc::new(buffer), (buf.len() / size_of::()) as u32),
                    buf.len(),
                );
                debug!(
                    "index buffer created (len={}, took {:?}) {pres}",
                    buf.len() / size_of::(),
                    start.elapsed(),
                );
                num_done += 1;
            }
        }
        for pres in self.vertex_buffers.needed() {
            let start = Instant::now();
            if let Some(buf) = dls.try_get(pres.clone())? {
                let buf = buf
                    .into_iter()
                    .flat_map(f32::to_le_bytes)
                    .collect::>();
                let buffer = self.device.create_buffer_init(&BufferInitDescriptor {
                    contents: &buf,
                    label: None,
                    usage: BufferUsages::VERTEX | BufferUsages::COPY_DST,
                });
                self.vertex_buffers.insert(
                    pres.clone(),
                    (Arc::new(buffer), (buf.len() / size_of::()) as u32),
                    buf.len(),
                );
                debug!(
                    "vertex attribute buffer created (len={}, took {:?}) {pres}",
                    buf.len() / size_of::(),
                    start.elapsed()
                );
                num_done += 1;
            }
        }
        for pres in self.textures.needed() {
            let start = Instant::now();
            if let Some(buf) = dls.try_get(pres.clone())? {
                let image = ImageReader::new(Cursor::new(buf.0)).with_guessed_format()?;
                let image = image.decode()?;
                let dims = (image.width(), image.height());
                let image = image.into_rgba8();
                let image = image.into_vec();
                let tex_bg = create_texture(
                    &self.device,
                    &self.queue,
                    &self.texture_bgl,
                    &image,
                    dims.0,
                    dims.1,
                );
                self.textures.insert(pres.clone(), tex_bg, image.len());
                debug!(
                    "texture created (res={}x{}, took {:?})",
                    dims.0,
                    dims.1,
                    start.elapsed()
                );
                num_done += 1;
            }
        }
        for variant in self.placeholder_textures.needed() {
            let v = if variant { 255 } else { 0 };
            let tex_bg = create_texture(
                &self.device,
                &self.queue,
                &self.texture_bgl,
                &[v, v, v, 255],
                1,
                1,
            );
            self.placeholder_textures.insert(variant, tex_bg, 4);
            num_done += 1;
        }
        for pres in self.mesh_parts.needed() {
            let start = Instant::now();
            if let Some(part) = dls.try_get(pres.clone())? {
                if let (Some(indexres), Some(positionres)) = (part.index, part.va_position) {
                    let index = self.index_buffers.try_get(indexres);
                    let position = self
                        .vertex_buffers
                        .try_get(Resource(positionres.0, PhantomData));
                    let vertex_count = position.as_ref().map(|(_, c)| *c / 3);
                    let normal = if let Some(res) = part.va_normal {
                        self.vertex_buffers
                            .try_get(Resource(res.0, PhantomData))
                            .map(|e| e.0)
                    } else {
                        vertex_count
                            .map(|vc| self.placeholder_vertex_buffers.try_get((vc * 4, false)))
                            .flatten()
                    };
                    let texcoord = if let Some(res) = part.va_texcoord {
                        self.vertex_buffers
                            .try_get(Resource(res.0, PhantomData))
                            .map(|e| e.0)
                    } else {
                        vertex_count
                            .map(|vc| self.placeholder_vertex_buffers.try_get((vc * 2, false)))
                            .flatten()
                    };
                    let mut tex_albedo = None;
                    if let Some(albedores) = part.tex_albedo {
                        if let Some((_tex, bg)) = self.textures.try_get(albedores) {
                            tex_albedo = Some(bg)
                        }
                    } else {
                        if let Some((_tex, bg)) = self.placeholder_textures.try_get(true) {
                            tex_albedo = Some(bg)
                        }
                    }
                    let mut tex_normal = None;
                    if let Some(albedores) = part.tex_normal {
                        if let Some((_tex, bg)) = self.textures.try_get(albedores) {
                            tex_normal = Some(bg)
                        }
                    } else {
                        if let Some((_tex, bg)) = self.placeholder_textures.try_get(false) {
                            tex_normal = Some(bg)
                        }
                    }
                    if let (
                        Some(va_normal),
                        Some((index, index_count)),
                        Some(va_texcoord),
                        Some((va_position, _)),
                        Some(tex_normal),
                        Some(tex_albedo),
                    ) = (normal, index, texcoord, position, tex_normal, tex_albedo)
                    {
                        debug!("part created (took {:?}) {pres}", start.elapsed());
                        self.mesh_parts.insert(
                            pres,
                            Arc::new(RMeshPart {
                                index_count,
                                index,
                                va_normal,
                                va_position,
                                va_texcoord,
                                tex_albedo,
                                tex_normal,
                                double_sided: part.g_double_sided.is_some(),
                            }),
                            0,
                        );
                        num_done += 1;
                    }
                }
            }
        }
        Ok(num_done)
    }
}
fn create_texture(
    device: &Device,
    queue: &Queue,
    bgl: &BindGroupLayout,
    data: &[u8],
    width: u32,
    height: u32,
) -> (Arc, Arc) {
    let texture = device.create_texture_with_data(
        &queue,
        &TextureDescriptor {
            label: None,
            size: Extent3d {
                depth_or_array_layers: 1,
                width,
                height,
            },
            mip_level_count: 1,
            sample_count: 1,
            dimension: TextureDimension::D2,
            format: TextureFormat::Rgba8UnormSrgb,
            usage: TextureUsages::TEXTURE_BINDING | TextureUsages::COPY_DST,
            view_formats: &[],
        },
        TextureDataOrder::LayerMajor,
        data,
    );
    let textureview = texture.create_view(&TextureViewDescriptor::default());
    let sampler = device.create_sampler(&SamplerDescriptor {
        address_mode_u: AddressMode::Repeat,
        address_mode_v: AddressMode::Repeat,
        mag_filter: FilterMode::Linear,
        min_filter: FilterMode::Linear,
        ..Default::default()
    });
    let bindgroup = device.create_bind_group(&BindGroupDescriptor {
        label: None,
        layout: &bgl,
        entries: &[
            BindGroupEntry {
                binding: 0,
                resource: BindingResource::TextureView(&textureview),
            },
            BindGroupEntry {
                binding: 1,
                resource: BindingResource::Sampler(&sampler),
            },
        ],
    });
    (Arc::new(texture), Arc::new(bindgroup))
}
/// Renders the statistics of a [`DemandMap`] as the tail of a grid row.
impl<K, V> Widget for &DemandMap<K, V> {
    /// Adds three labels — the number of needed keys, the number of cached
    /// values and the formatted size metric — and then ends the current row.
    ///
    /// # Panics
    ///
    /// Panics if the map's internal lock is poisoned.
    fn ui(self, ui: &mut egui::Ui) -> egui::Response {
        let state = self.inner.read().unwrap();
        ui.label(state.needed.len().to_string());
        ui.label(state.values.len().to_string());
        ui.label(humansize::format_size(state.size_metric, DECIMAL));
        ui.end_row();
        ui.response()
    }
}
impl Widget for &ScenePreparer {
    fn ui(self, ui: &mut egui::Ui) -> egui::Response {
        Grid::new("sp")
            .num_columns(4)
            .show(ui, |ui| {
                ui.label("prefabs");
                self.prefabs.ui(ui);
                ui.label("mesh_parts");
                self.mesh_parts.ui(ui);
                ui.label("vertex_buffers");
                self.vertex_buffers.ui(ui);
                ui.label("index_buffers");
                self.index_buffers.ui(ui);
                ui.label("placeholder_textures");
                self.placeholder_textures.ui(ui);
                ui.label("placeholder_vertex_buffers");
                self.placeholder_vertex_buffers.ui(ui);
                ui.label("textures");
                self.textures.ui(ui);
            })
            .response
    }
}