/*
    wearechat - generic multiplayer game with voip
    Copyright (C) 2025 metamuffin
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU Affero General Public License as published by
    the Free Software Foundation, version 3 of the License only.
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Affero General Public License for more details.
    You should have received a copy of the GNU Affero General Public License
    along with this program.  If not, see .
*/
use std::{
    collections::{HashMap, VecDeque, hash_map::Entry},
    env::var,
    sync::mpsc::{Receiver, SyncSender, TrySendError, sync_channel},
    time::Instant,
};
use anyhow::{Result, anyhow};
use audiopus::{
    Application, Channels, SampleRate,
    coder::{Decoder, Encoder},
};
use cpal::{
    Stream,
    traits::{DeviceTrait, HostTrait, StreamTrait},
};
use glam::Vec3;
use log::{debug, info, warn};
use nnnoiseless::{DenoiseState, RnnModel};
pub struct Audio {
    _instream: Option,
    _outstream: Option,
    rx: Receiver>,
    tx: SyncSender,
}
impl Audio {
    pub fn new() -> Result {
        let host = cpal::default_host();
        let indev = host
            .default_input_device()
            .ok_or(anyhow!("no input device"))?;
        let outdev = host
            .default_output_device()
            .ok_or(anyhow!("no output device"))?;
        let mut inconf = indev.default_input_config()?.config();
        inconf.channels = 1;
        inconf.sample_rate = cpal::SampleRate(48_000);
        inconf.buffer_size = cpal::BufferSize::Fixed(480);
        let mut outconf = outdev.default_input_config()?.config();
        outconf.channels = 2;
        outconf.sample_rate = cpal::SampleRate(48_000);
        outconf.buffer_size = cpal::BufferSize::Fixed(480);
        let (mut aenc, rx) = AEncoder::new()?;
        let (mut adec, tx) = ADecoder::new()?;
        let instream = if var("DISABLE_AUDIO_IN").is_err() {
            Some(indev.build_input_stream(
                &inconf,
                move |samples: &[f32], _| {
                    if let Err(e) = aenc.data(samples) {
                        warn!("encoder error: {e}");
                    }
                },
                |err| warn!("audio input error: {err}"),
                None,
            )?)
        } else {
            None
        };
        let outstream = if var("DISABLE_AUDIO_OUT").is_err() {
            Some(outdev.build_output_stream(
                &outconf,
                move |samples: &mut [f32], _| {
                    if let Err(e) = adec.data(samples) {
                        warn!("decoder error: {e}");
                    }
                },
                |err| warn!("audio output error: {err}"),
                None,
            )?)
        } else {
            None
        };
        Ok(Self {
            _instream: instream,
            _outstream: outstream,
            rx,
            tx,
        })
    }
    pub fn pop_output(&mut self) -> Option> {
        self.rx.try_recv().ok()
    }
    pub fn incoming_packet(&mut self, channel: u128, pos: Vec3, data: Vec) {
        if let Err(e) = self.tx.send(APlayPacket { pos, data, channel }) {
            warn!("audio output buffer overflow: {e:?}")
        }
    }
}
#[derive(Debug)]
pub struct APlayPacket {
    pos: Vec3,
    channel: u128,
    data: Vec,
}
/// Samples per encoder frame: 10 ms of mono audio at 48 kHz.
const AE_FRAME_SIZE: usize = 48_000 / 100;
pub struct AEncoder {
    encoder: Encoder,
    sender: SyncSender>,
    buffer: VecDeque,
    noise_rnn: DenoiseState<'static>,
    trigger: VadTrigger,
}
/// Voice-activity gate with hysteresis and a short hangover.
struct VadTrigger {
    /// Whether the gate is currently open.
    transmitting: bool,
    /// When a frame last exceeded the active energy threshold.
    last_sig: Instant,
}
impl AEncoder {
    pub fn new() -> Result<(Self, Receiver>)> {
        let (sender, rx) = sync_channel(1024);
        Ok((
            Self {
                noise_rnn: *DenoiseState::from_model(RnnModel::default()),
                encoder: Encoder::new(SampleRate::Hz48000, Channels::Mono, Application::Voip)?,
                sender,
                buffer: VecDeque::new(),
                trigger: VadTrigger {
                    last_sig: Instant::now(),
                    transmitting: false,
                },
            },
            rx,
        ))
    }
    pub fn data(&mut self, samples: &[f32]) -> Result<()> {
        self.buffer.extend(samples);
        while self.buffer.len() >= AE_FRAME_SIZE {
            let mut out = [0u8; AE_FRAME_SIZE];
            let mut denoise = [0f32; AE_FRAME_SIZE];
            let mut raw = [0f32; AE_FRAME_SIZE];
            for i in 0..AE_FRAME_SIZE {
                raw[i] = self.buffer.pop_front().unwrap() * 32768.0;
            }
            self.noise_rnn.process_frame(&mut denoise, &raw);
            for e in &mut denoise {
                *e /= 32768.0;
            }
            let energy = measure_energy(&denoise);
            let (tx, end_tx) = self.trigger.update(energy);
            if tx {
                let size = self.encoder.encode_float(&denoise, &mut out)?;
                debug!("encoded size={size}");
                let _ = self.sender.try_send(out[..size].to_vec());
            }
            if end_tx {
                let _ = self.sender.try_send(vec![]);
            }
        }
        Ok(())
    }
}
impl VadTrigger {
    pub fn update(&mut self, energy: f32) -> (bool, bool) {
        let now = Instant::now();
        let thres = if self.transmitting { 0.1 } else { 1. };
        if energy > thres {
            self.last_sig = now;
        }
        let last_sig_elapsed = (now - self.last_sig).as_secs_f32();
        let prev_transmitting = self.transmitting;
        self.transmitting = last_sig_elapsed < 0.2;
        match (prev_transmitting, self.transmitting) {
            (false, true) => info!("start transmit"),
            (true, false) => info!("end transmit"),
            _ => (),
        }
        (self.transmitting, prev_transmitting && !self.transmitting)
    }
}
/// Length of the playback ring buffer in stereo frames (one second at 48 kHz).
const BUFFER_SIZE: usize = 48_000;
/// Initial lead, in frames, of a new channel's write position over the
/// playback head: 9 600 frames = 200 ms at 48 kHz (jitter compensation).
const JITTER_COMP: usize = 4_800 * 2;
pub struct ADecoder {
    decoder: Decoder,
    receiver: Receiver,
    channels: HashMap,
    playback: usize,
    buffer: Box<[[f32; 2]; BUFFER_SIZE]>,
}
impl ADecoder {
    pub fn new() -> Result<(Self, SyncSender)> {
        let (tx, receiver) = sync_channel(1024);
        Ok((
            Self {
                decoder: Decoder::new(SampleRate::Hz48000, Channels::Mono)?,
                receiver,
                channels: HashMap::new(),
                playback: 0,
                buffer: unsafe { Box::new_zeroed().assume_init() },
            },
            tx,
        ))
    }
    pub fn data(&mut self, samples: &mut [f32]) -> Result<()> {
        for p in self.receiver.try_iter() {
            if p.data.is_empty() {
                if self.channels.remove(&p.channel).is_some() {
                    info!("channel {} ended", p.channel % 1_000_000)
                }
            } else {
                let mut output = [0f32; AE_FRAME_SIZE];
                let size = self.decoder.decode_float(
                    Some(p.data.as_slice()),
                    output.as_mut_slice(),
                    false,
                )?;
                let channel_cursor = self.channels.entry(p.channel).or_insert_with(|| {
                    info!("channel {} started", p.channel % 1_000_000);
                    (self.playback + JITTER_COMP) % BUFFER_SIZE
                });
                let free_space = *channel_cursor - self.playback;
                for i in 0..size.min(free_space) {
                    // TODO positional audio
                    let _ = p.pos;
                    self.buffer[*channel_cursor][0] += output[i];
                    self.buffer[*channel_cursor][1] += output[i];
                    *channel_cursor += 1;
                    *channel_cursor %= BUFFER_SIZE
                }
            }
        }
        for x in samples.array_chunks_mut::<2>() {
            *x = self.buffer[self.playback];
            self.buffer[self.playback] = [0.; 2];
            self.playback += 1;
            self.playback %= BUFFER_SIZE;
        }
        Ok(())
    }
}
/// Sum of squared samples of one frame; this is the signal level the
/// voice-activity trigger compares against its thresholds.
fn measure_energy(samples: &[f32]) -> f32 {
    samples.iter().fold(0.0, |acc, &s| acc + s * s)
}