#!/usr/bin/env python3
# -*- coding: utf-8 -*-
###############################################################################
# $Id$
#
#  Project:  GDAL scripts
#  Purpose:  Perform a pansharpening operation
#  Author:   Even Rouault <even.rouault at spatialys.com>
#
###############################################################################
#  Copyright (c) 2015, Even Rouault <even.rouault at spatialys.com>
#  Copyright (c) 2021, Idan Miara <idan@miara.com>
#
#  Permission is hereby granted, free of charge, to any person obtaining a
#  copy of this software and associated documentation files (the "Software"),
#  to deal in the Software without restriction, including without limitation
#  the rights to use, copy, modify, merge, publish, distribute, sublicense,
#  and/or sell copies of the Software, and to permit persons to whom the
#  Software is furnished to do so, subject to the following conditions:
#
#  The above copyright notice and this permission notice shall be included
#  in all copies or substantial portions of the Software.
#
#  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
#  OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
#  THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
#  FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
#  DEALINGS IN THE SOFTWARE.
###############################################################################

import os
import os.path
import sys
from numbers import Real
from typing import List, Optional, Sequence, Union

from osgeo import gdal
from osgeo_utils.auxiliary.util import GetOutputDriverFor


def Usage(isError):
    """Print the gdal_pansharpen command-line synopsis.

    The text is written to stderr when *isError* is true and to stdout
    otherwise. Returns the exit code to use: 2 for an error, 0 otherwise.
    """
    stream = sys.stdout
    exit_code = 0
    if isError:
        stream = sys.stderr
        exit_code = 2

    synopsis = (
        "Usage: gdal_pansharpen [--help] [--help-general]",
        "                       <pan_dataset> {<spectral_dataset>[,band=<num>]} {<spectral_dataset>[,band=<num>]}... <out_dataset>",
        "                       [-of <format>] [-b <band>]... [-w <weight>]...",
        "                       [-r {nearest|bilinear|cubic|cubicspline|lanczos|average}]",
        "                       [-threads {ALL_CPUS|<number>}] [-bitdepth <val>] [-nodata <val>]",
        "                       [-spat_adjust {union|intersection|none|nonewithoutwarning}]",
        "                       [-verbose_vrt] [-co <NAME>=<VALUE>]... [-q]",
        "",
        "Create a dataset resulting from a pansharpening operation.",
    )
    for text_line in synopsis:
        print(text_line, file=stream)
    return exit_code


def main(argv=sys.argv):
    """Parse gdal_pansharpen command-line arguments and run the operation.

    :param argv: full argument vector, program name first.
    :return: process exit code. It is gdal_pansharpen()'s result (0 on
        success, 1 on failure), or Usage()'s code (2 on a usage error, 0 for
        --help), or 0 when gdal.GeneralCmdLineProcessor() returns None
        (presumably after handling a generic GDAL option such as
        --help-general).
    """
    argv = gdal.GeneralCmdLineProcessor(argv)
    if argv is None:
        return 0

    pan_name = None
    spectral_names = []
    band_nums = []
    weights = []
    driver_name = None
    creation_options = []
    progress_callback = gdal.TermProgress_nocb
    resampling = None
    spat_adjust = None
    verbose_vrt = False
    num_threads = None
    bitdepth = None
    nodata_value = None

    i = 1
    argc = len(argv)
    while i < argc:
        arg = argv[i]
        # Value-taking options only match when a value follows them;
        # otherwise they fall through to the "Unrecognized option" branch.
        has_value = i < argc - 1
        if arg in ("-of", "-f") and has_value:
            i += 1
            driver_name = argv[i]
        elif arg == "-r" and has_value:
            i += 1
            resampling = argv[i]
        elif arg == "-spat_adjust" and has_value:
            i += 1
            spat_adjust = argv[i]
        elif arg == "-b" and has_value:
            i += 1
            try:
                band_nums.append(int(argv[i]))
            except ValueError:
                sys.stderr.write("Invalid band number for -b : %s\n" % argv[i])
                return Usage(isError=True)
        elif arg == "-w" and has_value:
            i += 1
            try:
                weights.append(float(argv[i]))
            except ValueError:
                sys.stderr.write("Invalid weight for -w : %s\n" % argv[i])
                return Usage(isError=True)
        elif arg == "-co" and has_value:
            i += 1
            creation_options.append(argv[i])
        elif arg == "-threads" and has_value:
            i += 1
            num_threads = argv[i]
        elif arg == "-bitdepth" and has_value:
            i += 1
            bitdepth = argv[i]
        elif arg == "-nodata" and has_value:
            i += 1
            nodata_value = argv[i]
        elif arg == "-q":
            progress_callback = None
        elif arg == "-verbose_vrt":
            verbose_vrt = True
        elif arg == "--help":
            return Usage(isError=False)
        elif arg.startswith("-"):
            # startswith() instead of arg[0]: an empty-string argument must
            # not raise IndexError.
            sys.stderr.write("Unrecognized option : %s\n" % arg)
            return Usage(isError=True)
        elif pan_name is None:
            pan_name = arg
        else:
            spectral_names.append(arg)

        i += 1

    # The pan dataset must be followed by at least one spectral dataset and
    # the output dataset.
    if pan_name is None or len(spectral_names) < 2:
        return Usage(isError=True)

    # The last positional argument is the output dataset.
    dst_filename = spectral_names.pop()
    return gdal_pansharpen(
        argv=None,
        pan_name=pan_name,
        spectral_names=spectral_names,
        band_nums=band_nums,
        weights=weights,
        dst_filename=dst_filename,
        driver_name=driver_name,
        creation_options=creation_options,
        resampling=resampling,
        spat_adjust=spat_adjust,
        num_threads=num_threads,
        bitdepth=bitdepth,
        nodata_value=nodata_value,
        verbose_vrt=verbose_vrt,
        progress_callback=progress_callback,
    )


def gdal_pansharpen(
    argv: Optional[Sequence[str]] = None,
    pan_name: Optional[str] = None,
    spectral_names: Optional[Sequence[str]] = None,
    spectral_ds: Optional[List[gdal.Dataset]] = None,
    spectral_bands: Optional[List[gdal.Band]] = None,
    band_nums: Optional[Sequence[int]] = None,
    weights: Optional[Sequence[float]] = None,
    dst_filename: Optional[str] = None,
    driver_name: Optional[str] = None,
    creation_options: Optional[Sequence[str]] = None,
    resampling: Optional[str] = None,
    spat_adjust: Optional[str] = None,
    num_threads: Optional[Union[int, str]] = None,
    bitdepth: Optional[Union[int, str]] = None,
    nodata_value: Optional[Union[Real, str]] = None,
    verbose_vrt: bool = False,
    progress_callback: Optional = gdal.TermProgress_nocb,
):
    """Create a dataset resulting from a pansharpening operation.

    A VRTPansharpenedDataset XML description is built from band 1 of the
    panchromatic dataset *pan_name* and from the spectral bands. With the
    VRT driver, that XML is written directly to *dst_filename*. With any
    other driver, it is opened in memory and copied with CreateCopy().

    :param argv: legacy entry point. When non-empty, the call is forwarded
        to main(argv) and every other argument is ignored.
    :param pan_name: path of the panchromatic dataset.
    :param spectral_names: spectral inputs, each ``<file>`` (all bands) or
        ``<file>,band=<num>``. They are parsed with parse_spectral_names()
        and appended after any *spectral_ds* / *spectral_bands* given. The
        caller's lists are not modified.
    :param spectral_ds: dataset owning each entry of *spectral_bands*
        (index-aligned). Its description is used as the source filename.
    :param spectral_bands: spectral bands to pansharpen.
    :param band_nums: 1-based indices into the spectral bands selecting the
        output bands (default: all of them, in order).
    :param weights: one weight per spectral band, written to <Weights>.
    :param dst_filename: output dataset path.
    :param driver_name: output driver. It is guessed from *dst_filename*
        when None.
    :param creation_options: creation options passed to CreateCopy().
    :param resampling: see *nodata_value*.
    :param spat_adjust: see *nodata_value*.
    :param num_threads: see *nodata_value*.
    :param bitdepth: see *nodata_value*.
    :param nodata_value: this parameter and the four above are written
        verbatim into the matching PansharpeningOptions element when not
        None.
    :param verbose_vrt: VRT output only. The written file is reopened in
        update mode and its metadata is re-set.
    :param progress_callback: progress callback given to CreateCopy().
    :return: 0 on success, 1 on failure.
    """
    from xml.sax.saxutils import escape

    if argv:
        # this is here for backwards compatibility
        return main(argv)

    spectral_names = spectral_names or []
    # Copy so that parse_spectral_names() never extends the caller's lists.
    spectral_ds = list(spectral_ds or [])
    spectral_bands = list(spectral_bands or [])
    # Normalise to a list so that the comparison with the default band list
    # below also works when a tuple is passed.
    band_nums = list(band_nums or [])
    weights = weights or []
    creation_options = creation_options or []

    if spectral_names:
        # A non-zero result means a spectral input could not be used; do not
        # silently continue with only the bands parsed so far.
        if parse_spectral_names(
            spectral_names=spectral_names,
            spectral_ds=spectral_ds,
            spectral_bands=spectral_bands,
        ):
            return 1

    if pan_name is None or not spectral_bands:
        return 1

    if len(spectral_ds) < len(spectral_bands):
        print("There must be as many spectral datasets as spectral bands")
        return 1

    pan_ds = gdal.Open(pan_name)
    if pan_ds is None:
        return 1

    if driver_name is None:
        driver_name = GetOutputDriverFor(dst_filename)
    is_vrt = driver_name.upper() == "VRT"

    default_band_nums = list(range(1, len(spectral_bands) + 1))
    if not band_nums:
        band_nums = default_band_nums
    else:
        for band in band_nums:
            # Band numbers are 1-based; 0 would silently index the last band.
            if band < 1 or band > len(spectral_bands):
                print("Invalid band number in -b: %d" % band)
                return 1

    if weights and len(weights) != len(spectral_bands):
        print("There must be as many -w values specified as input spectral bands")
        return 1

    vrt_xml = """<VRTDataset subClass="VRTPansharpenedDataset">\n"""
    # Output bands only need explicit declarations when they differ from the
    # default "all spectral bands, in order".
    if band_nums != default_band_nums:
        for i, band in enumerate(band_nums):
            sband = spectral_bands[band - 1]
            datatype = gdal.GetDataTypeName(sband.DataType)
            colorname = gdal.GetColorInterpretationName(sband.GetColorInterpretation())
            vrt_xml += """  <VRTRasterBand dataType="%s" band="%d" subClass="VRTPansharpenedRasterBand">
      <ColorInterp>%s</ColorInterp>
  </VRTRasterBand>\n""" % (
                datatype,
                i + 1,
                colorname,
            )

    vrt_xml += """  <PansharpeningOptions>\n"""

    if weights:
        vrt_xml += """      <AlgorithmOptions>\n"""
        vrt_xml += "        <Weights>%s</Weights>\n" % ",".join(
            "%.16g" % weight for weight in weights
        )
        vrt_xml += """      </AlgorithmOptions>\n"""

    if resampling is not None:
        vrt_xml += f"      <Resampling>{resampling}</Resampling>\n"

    if num_threads is not None:
        vrt_xml += f"      <NumThreads>{num_threads}</NumThreads>\n"

    if bitdepth is not None:
        vrt_xml += f"      <BitDepth>{bitdepth}</BitDepth>\n"

    if nodata_value is not None:
        vrt_xml += f"      <NoData>{nodata_value}</NoData>\n"

    if spat_adjust is not None:
        vrt_xml += (
            f"      <SpatialExtentAdjustment>{spat_adjust}</SpatialExtentAdjustment>\n"
        )

    pan_relative, pan_source = _vrt_source_filename(pan_name, dst_filename, is_vrt)
    vrt_xml += """    <PanchroBand>
      <SourceFilename relativeToVRT="%s">%s</SourceFilename>
      <SourceBand>1</SourceBand>
    </PanchroBand>\n""" % (
        pan_relative,
        escape(pan_source),
    )

    # Map each 1-based spectral band index to the first output band using it.
    dst_band_for_source = {}
    for j, band in enumerate(band_nums):
        dst_band_for_source.setdefault(band, j + 1)

    for i, sband in enumerate(spectral_bands):
        dst_band = dst_band_for_source.get(i + 1)
        dstband = ' dstBand="%d"' % dst_band if dst_band is not None else ""

        ms_relative, ms_name = _vrt_source_filename(
            spectral_ds[i].GetDescription(), dst_filename, is_vrt
        )

        vrt_xml += """    <SpectralBand%s>
      <SourceFilename relativeToVRT="%s">%s</SourceFilename>
      <SourceBand>%d</SourceBand>
    </SpectralBand>\n""" % (
            dstband,
            ms_relative,
            escape(ms_name),
            sband.GetBand(),
        )

    vrt_xml += """  </PansharpeningOptions>\n"""
    vrt_xml += """</VRTDataset>\n"""

    if is_vrt:
        f = gdal.VSIFOpenL(dst_filename, "wb")
        if f is None:
            print("Cannot create %s" % dst_filename)
            return 1
        # VSIFWriteL counts bytes, not characters: encode first so that
        # non-ASCII filenames do not truncate the written XML.
        data = vrt_xml.encode("utf-8")
        written = gdal.VSIFWriteL(data, 1, len(data), f)
        gdal.VSIFCloseL(f)
        if written != len(data):
            print("Cannot write %s" % dst_filename)
            return 1
        if verbose_vrt:
            vrt_ds = gdal.Open(dst_filename, gdal.GA_Update)
            if vrt_ds is None:
                return 1
            # NOTE(review): re-setting the metadata presumably marks the
            # dataset dirty so GDAL re-serializes the (verbose) VRT when it
            # is closed; confirm against VRT driver behaviour.
            vrt_ds.SetMetadata(vrt_ds.GetMetadata())
        else:
            vrt_ds = gdal.Open(dst_filename)
        if vrt_ds is None:
            return 1

        return 0

    vrt_ds = gdal.Open(vrt_xml)
    if vrt_ds is None:
        return 1
    driver = gdal.GetDriverByName(driver_name)
    if driver is None:
        print("Cannot find driver %s" % driver_name)
        return 1
    out_ds = driver.CreateCopy(
        dst_filename, vrt_ds, 0, creation_options, callback=progress_callback
    )
    if out_ds is None:
        return 1
    return 0


def _vrt_source_filename(name, dst_filename, is_vrt):
    """Return the (relativeToVRT flag, filename) pair for a <SourceFilename>.

    For VRT output, a relative *name* is rewritten relative to the directory
    of *dst_filename* and flagged "1". Otherwise *name* is kept as-is with
    flag "0".
    """
    if is_vrt and not os.path.isabs(name):
        try:
            return "1", os.path.relpath(name, os.path.dirname(dst_filename))
        except ValueError:
            # Thrown if generating a relative path is not possible, e.g. if
            # name is on a different Windows drive from dst_filename
            pass
    return "0", name


def parse_spectral_names(
    spectral_names: Sequence[str],
    spectral_ds: List[gdal.Dataset],
    spectral_bands: List[gdal.Band],
):
    """Open spectral inputs and append their datasets and bands to the lists.

    Each entry of *spectral_names* has one of two forms:

    - ``<filename>`` selects all bands of the dataset.
    - ``<filename>,band=<num>`` selects a single band.

    For every selected band, the owning dataset is appended to *spectral_ds*
    and the band to *spectral_bands*, so the two lists stay index-aligned.

    :return: 0 on success. Returns 1 as soon as a dataset cannot be opened
        or a ``band=`` value is not a valid band of its dataset; the lists
        may then be partially filled.
    """
    band_marker = ",band="
    for spectral_arg in spectral_names:
        pos = spectral_arg.find(band_marker)
        if pos > 0:
            # add selected band
            spectral_name = spectral_arg[:pos]
            ds = gdal.Open(spectral_name)
            if ds is None:
                return 1
            try:
                band_num = int(spectral_arg[pos + len(band_marker) :])
            except ValueError:
                print("Invalid band number in %s" % spectral_arg)
                return 1
            band = ds.GetRasterBand(band_num)
            # Never append a missing band: it would only fail later, far from
            # the offending argument.
            if band is None:
                print("Invalid band number in %s" % spectral_arg)
                return 1
            spectral_ds.append(ds)
            spectral_bands.append(band)
        else:
            # add all bands
            ds = gdal.Open(spectral_arg)
            if ds is None:
                return 1
            for band_index in range(1, ds.RasterCount + 1):
                spectral_ds.append(ds)
                spectral_bands.append(ds.GetRasterBand(band_index))
    return 0


if __name__ == "__main__":
    # Use main()'s return value (0 on success) as the process exit status.
    exit_status = main(sys.argv)
    sys.exit(exit_status)
