import Grid2 from "@mui/material/Unstable_Grid2"
import * as dfd from 'danfojs'
import { useEffect, useState } from "react"
import { AgeGroupChart } from "../../components/charts/LineCharts"
import { AGE_GROUPS, PopulationGroupSizes } from "../municipality/ChangeBarChart"

/** One age-group bar of the pyramid: label, inclusive age bounds and the cohort columns summed into it. */
type AgeGroupSpec = {
    label: string;
    from: number;
    to: number;
    femaleCohorts: string[];
    maleCohorts: string[];
};

/** One entry of the result lists produced by `calculateAgeGroups`. */
type AgeGroupEntry = { sex: 'female' | 'male'; label: string; size: number };

/** Oldest single-year cohort column; per the original note, cohort 99 includes ages 100 and above. */
const OLDEST_COHORT = 99;

/** Data spanning more than this many years is bucketed into the 5-year `AGE_GROUPS`. */
const SINGLE_YEAR_MAX_RANGE = 20;

/**
 * Lists the cohort columns (`f<age>` / `m<age>`) for ages `from..to` (inclusive) that exist in `columns`.
 * Each sex is checked on its own, so a missing female or male column is never looked up.
 */
const existingCohorts = (columns: string[], from: number, to: number): Pick<AgeGroupSpec, 'femaleCohorts' | 'maleCohorts'> => {
    const femaleCohorts: string[] = [];
    const maleCohorts: string[] = [];
    for (let age = from; age <= to; age++) {
        if (columns.includes(`f${age}`)) femaleCohorts.push(`f${age}`);
        if (columns.includes(`m${age}`)) maleCohorts.push(`m${age}`);
    }
    return { femaleCohorts, maleCohorts };
};

/**
 * Builds the age groups to plot, oldest group first.
 *
 * If the observed ages span more than `SINGLE_YEAR_MAX_RANGE` years, the predefined 5-year
 * `AGE_GROUPS` are used, and no cohort above `OLDEST_COHORT` is looked up. Otherwise every single
 * year from `maxAge` down to `minAge` becomes its own group.
 */
const buildAgeGroupSpecs = (columns: string[], minAge: number, maxAge: number): AgeGroupSpec[] => {
    if (maxAge - minAge > SINGLE_YEAR_MAX_RANGE) {
        return AGE_GROUPS.slice().reverse().map(({ from, to }) => ({
            label: `${from}-${to}`,
            from,
            to,
            ...existingCohorts(columns, from, Math.min(to, OLDEST_COHORT)),
        }));
    }
    // Descending single years: maxAge, maxAge - 1, ..., minAge (empty when no age columns were found)
    return Array.from({ length: maxAge - minAge + 1 }, (_, i) => maxAge - i).map(age => ({
        label: `${age}`,
        from: age,
        to: age,
        ...existingCohorts(columns, age, age),
    }));
};

/**
 * Aggregates a population DataFrame into population-pyramid age groups for a single year.
 *
 * The `year` and `area` columns are dropped. Every remaining column is summed over all rows, and
 * columns named `f<age>` / `m<age>` are then summed per age group.
 *
 * @param data Population rows; only the first row's `year` value is read.
 * @returns `undefined` for missing or empty data. Otherwise a one-element array
 *   `[{ [year]: { male, female } }]` whose lists hold `{ sex, label, size }` entries, oldest group
 *   first. Female sizes are negated, and all sizes are rounded to two decimals.
 */
export const calculateAgeGroups = (data?: dfd.DataFrame): any => {
    if (!data || !data.size) return;

    // Read the year before the column is dropped below. TODO: bit fragile (assumes one year per frame)
    const year = (data.loc({ columns: ['year'] }).values[0] as any[])[0];

    // Sum cohort sizes over all areas, then turn the resulting Series back into a one-row DataFrame
    const cohortSizesSeries = data.copy().drop({ columns: ['year', 'area'] }).sum({ axis: 0 });
    const cohortSizes = new dfd.DataFrame([cohortSizesSeries.values], { columns: cohortSizesSeries.index.map(index => index.toString()) });

    // Total of the given cohort columns; an empty selection counts as 0 rather than being passed to `loc`
    const sumCohorts = (cohorts: string[]): number => {
        if (cohorts.length === 0) return 0;
        return cohortSizes.loc({ columns: cohorts }).sum({ axis: 1 }).values[0] as number;
    };

    // Determine the observed min and max ages from cohort column names such as "f0" or "m99"
    let minAge = OLDEST_COHORT;
    let maxAge = 0;
    for (const column of data.columns) {
        const age = parseInt(column.replace(/[fm]/, ''), 10);
        if (!isNaN(age)) {
            minAge = Math.min(minAge, age);
            maxAge = Math.max(maxAge, age);
        }
    }

    const male: AgeGroupEntry[] = [];
    const female: AgeGroupEntry[] = [];

    for (const group of buildAgeGroupSpecs(data.columns, minAge, maxAge)) {
        // Skip groups outside the observed ages. When the data reaches the open-ended top cohort
        // (OLDEST_COHORT), groups whose upper bound lies above it are still kept.
        if (group.from < minAge || (group.to > maxAge && maxAge !== OLDEST_COHORT)) {
            continue;
        }

        female.push({
            sex: 'female',
            label: group.label,
            size: -Number(sumCohorts(group.femaleCohorts).toFixed(2)),
        });

        male.push({
            sex: 'male',
            label: group.label,
            size: Number(sumCohorts(group.maleCohorts).toFixed(2)),
        });
    }

    const result: { [year: string]: { male: AgeGroupEntry[], female: AgeGroupEntry[] } } = {};
    result[year] = { male, female };
    return [result];
};

/**
 * Returns the largest absolute age-group size across all entries of `populationData`.
 *
 * Only the first year object of each entry is inspected; entries built by `calculateAgeGroups`
 * hold exactly one year. Sizes are compared by absolute value because female sizes are negated.
 *
 * @param populationData Age-group entries shaped like `{ [year]: { male, female } }`.
 * @returns The maximum absolute size, or 0 when there is no data. Missing entries, entries without
 *   a year object and non-numeric sizes are ignored instead of throwing.
 */
export const findMaxAgeGroupSize = (populationData?: PopulationGroupSizes): number => {
    if (!populationData) {
        return 0;
    }

    let maxSize = 0;

    for (const entries of populationData.values()) {
        // An entry can be undefined, e.g. when calculateAgeGroups() returned nothing for empty data
        if (!entries) continue;
        const yearData = Object.values(entries)[0];
        if (!yearData) continue;

        for (const group of [...(yearData.male ?? []), ...(yearData.female ?? [])]) {
            const size = Math.abs(group.size);
            // `>` is false for NaN, so non-numeric sizes never become the maximum
            if (size > maxSize) {
                maxSize = size;
            }
        }
    }

    return maxSize;
}

/**
 * Upper bound for the chart's size axis: `maxSize` plus 10% headroom, rounded up to the next
 * multiple of 10 when `maxSize` is below 100, otherwise to the next multiple of 100.
 */
const maxAxisValue = (maxSize: number): number => {
    const step = maxSize < 100 ? 10 : 100;
    return Math.ceil(maxSize * 1.10 / step) * step;
};

/**
 * Age-group chart comparing two years of population data.
 *
 * Age groups are computed with `calculateAgeGroups` for both inputs. Nothing is rendered until
 * both years have produced groups. When either input later becomes missing or empty, the last
 * complete pair is kept.
 *
 * @param currentYearData Population rows for the current year.
 * @param firstYearData Population rows for the first year.
 * @param maxSize Largest age-group size to fit; the axis adds 10% headroom on top.
 */
export const DataframeAgeGroupChart: React.FC<{ currentYearData?: dfd.DataFrame, firstYearData?: dfd.DataFrame, maxSize: number }> = ({
    currentYearData,
    firstYearData,
    maxSize,
}) => {

    const [ageGroups, setAgeGroups] = useState<PopulationGroupSizes>()

    useEffect(() => {
        if (!currentYearData || !currentYearData.size) return
        const current = calculateAgeGroups(currentYearData)
        const first = calculateAgeGroups(firstYearData)
        // calculateAgeGroups() returns undefined for missing/empty data. Concatenating that stored an
        // undefined entry, which made the render below throw on Object.keys(undefined).
        if (!current || !first) return
        setAgeGroups(current.concat(first) as any)
    }, [ currentYearData, firstYearData ])

    if (!ageGroups) return null

    // NOTE(review): ageGroups[0] holds the current year and ageGroups[1] the first year, yet they are
    // passed as minYear/maxYear respectively — confirm this is what AgeGroupChart expects.
    return (
        <Grid2 container>
            <Grid2 xs={12}>
                <AgeGroupChart
                    ageGroups={ageGroups}
                    maxSize={maxAxisValue(maxSize)}
                    minYear={Number(Object.keys(ageGroups[0]))}
                    maxYear={Number(Object.keys(ageGroups[1]))}
                    aspect={0.73}
                />
            </Grid2>
        </Grid2>
    )
}