#!/usr/bin/env python3

import argparse
from collections import namedtuple
from enum import Enum, unique
from functools import partial
import gzip
import json
import logging
from math import ceil, floor
from multiprocessing import cpu_count
from multiprocessing.dummy import Pool as ThreadPool
import os
from pathlib import Path
import sys
from time import sleep
from typing import List, Iterable, Set
from urllib import request
from urllib.error import URLError, HTTPError

# hack so ArgumentParser can accept negative numbers
# see https://github.com/valhalla/valhalla/issues/3426
for idx, token in enumerate(sys.argv):
    # Prefix leading-digit negatives (e.g. "-120") with a space so argparse
    # does not mistake them for option flags.
    if len(token) > 1 and token[0] == '-' and token[1].isdigit():
        sys.argv[idx] = ' ' + token

# Help text shown by argparse (see parser below).
# NOTE(review): the text mentions --input-geojson-dir but the actual flag is
# --from-geojson — confirm and update the wording in a behavior-changing edit.
description = """Downloads Tilezen's elevation tiles that either intersect with features in all .geojson files in
              --input-geojson-dir or that intersect with --bbox. NOTE: geojson method requires shapely.
              """

# set up the logger basics: log to stderr; the level is chosen in __main__
# from the accumulated -v flags
LOGGER = logging.getLogger(__name__)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)5s: %(message)s"))
LOGGER.addHandler(handler)


# Argument parser
@unique
class TileCompression(Enum):
    """Supported on-disk formats for downloaded elevation tiles."""

    UNCOMPRESSED = 1
    GZIP = 2
    LZ4 = 3

    @property
    def extension(self):
        """Return the filename suffix used for this compression format."""
        return {
            TileCompression.UNCOMPRESSED: "",
            TileCompression.GZIP: ".gz",
            TileCompression.LZ4: ".lz4",
        }[self]


# Outcome of a tile download attempt sequence: OK, FAILED (retries exhausted),
# or CORRUPTED (the gzip payload could not be decompressed).
DownloadStatus = Enum("DownloadStatus", ["OK", "FAILED", "CORRUPTED"])


parser = argparse.ArgumentParser(description=description)
# exactly one of the three tile-selection methods (-g/-b/-t) may be given
method = parser.add_mutually_exclusive_group()
method.add_argument(
    "-g",
    "--from-geojson",
    help="Absolute or relative path to directory with .geojson files used "
    "as input to tile intersection. Requires shapely.",
    type=Path,
)
method.add_argument(
    "-b",
    "--from-bbox",
    help="Bounding box coordinates in the format 'minX,minY,maxX,maxY'.",
    type=str,
)
method.add_argument(
    "-t",
    "--from-tiles",
    help="Only download tiles covered by the Valhalla graph (requires a Valhalla config JSON file).",
    action="store_true",
)
# shared options: config sources, output location, parallelism, verbosity
parser.add_argument(
    "-c",
    "--config",
    help="Absolute or relative path to the Valhalla config JSON. "
    "If present, can be used for download location and bbox.",
    type=Path,
)
parser.add_argument(
    "-i",
    "--inline-config",
    help="Inline JSON config, will override --config JSON if present",
    type=str,
    default='{}',
)
parser.add_argument(
    "-o",
    "--outdir",
    help="Absolute or relative path to the directory where the tiles will be saved. Overrides config JSON.",
    type=Path,
)
parser.add_argument(
    "-p", "--parallelism", help="Number of processing units to use.", type=int, default=cpu_count()
)
parser.add_argument(
    "-v",
    "--verbosity",
    help="Accumulative verbosity flags; -v: INFO, -vv: DEBUG",
    action="count",
    default=0,
)

# output-format options; gzip (the format served by AWS) is the default when
# neither flag is given, and the two flags are mutually exclusive
group = parser.add_mutually_exclusive_group()
group.add_argument(
    "-d",
    "--decompress",
    help="If set, downloaded files will be decompressed (default format is gzip).",
    action="store_true",
)
group.add_argument(
    "-z",
    "--lz4",
    help="If set, downloaded files will be recompressed with LZ4. Requires lz4. "
    "Requires ~12%% more disk space (216GB total) vs gzip. While you pay upfront in extra CPU and space on the "
    "initial download, tiles will be several times faster to decompress vz gzip.",
    action="store_true",
)

# A skadi elevation tile: file name (e.g. "N47W122.hgt") and the directory
# it lives in (e.g. "N47").
Tile = namedtuple("Tile", ["name", "dir"])
# edge length of a Valhalla level-2 graph tile in degrees
LOCAL_SIZE = 0.25


def get_tile_info(
    x: int,
    y: int,
) -> Tile:
    """Return the skadi tile name and directory for a 1x1 degree cell.

    :param x: longitude of the cell's lower-left corner in whole degrees
    :param y: latitude of the cell's lower-left corner in whole degrees
    :returns: Tile with e.g. name "S34E007.hgt" and dir "S34" for (7, -34)
    """
    # f-strings replace the duplicated %-format logic: the directory name is
    # exactly the latitude prefix of the file name, so build it once and reuse it
    ns = "S" if y < 0 else "N"
    ew = "W" if x < 0 else "E"
    dir_name = f"{ns}{abs(y):02d}"
    return Tile(f"{dir_name}{ew}{abs(x):03d}.hgt", dir_name)


def grid_from_bounds(bounds: List[float]) -> Iterable[List[int]]:
    """Yield the 1x1 degree cells [min_x, min_y, max_x, max_y] covering *bounds*.

    The bounds are first snapped outward to the 1 x 1 degree grid (floor the
    minimums, ceil the maximums), then one cell is yielded per column/row of
    the expanded box.

    :param bounds: bounding box as [min_x, min_y, max_x, max_y] in degrees
    :returns: generator of [x, y, x + 1, y + 1] integer cells

    Note: the return annotation was corrected from List[List[int]] — this is a
    generator function, so it returns an iterable of cells, not a list.
    """
    # expand bbox by snapping to a 1 x 1 degree grid
    min_x, min_y = [floor(c) for c in bounds[:2]]
    max_x, max_y = [ceil(c) for c in bounds[-2:]]
    # loop through x and y range and create the grid
    for x in range(min_x, max_x):
        for y in range(min_y, max_y):
            yield [x, y, x + 1, y + 1]


def get_tiles_with_geojson(input_dir: Path) -> Set[Tile]:
    """Return the set of elevation tiles intersecting any (Multi)Polygon
    feature found in the .geojson files inside *input_dir*.

    Exits the process with status 1 if shapely is not installed or if
    *input_dir* does not exist / contains no .geojson files.
    """
    try:
        from shapely.geometry import Polygon, box
    except ImportError:
        LOGGER.critical(
            "Could not import shapely. Please install shapely or use another download method instead."
        )
        sys.exit(1)

    # bail early if there is nothing to read
    if not input_dir.is_dir() or not len(list(input_dir.glob('*.geojson'))) > 0:
        LOGGER.critical(
            f"Geojson directory does not exist or contains no GeoJSON files: {input_dir.resolve()}"
        )
        sys.exit(1)

    def get_outer_rings(input_dir: Path) -> Iterable[Polygon]:
        """Yield a Polygon built from the first ring of every (Multi)Polygon
        feature; subsequent rings (holes) are not used for coverage."""
        for file in input_dir.glob("*.geojson"):
            with open(file) as f:
                geojson = json.load(f)
            for feature in geojson["features"]:
                if feature["geometry"]["type"] == "Polygon":
                    # coordinates[0] is the outer ring
                    yield Polygon(feature["geometry"]["coordinates"][0])
                if feature["geometry"]["type"] == "MultiPolygon":
                    for single_polygon in feature["geometry"]["coordinates"]:
                        yield Polygon(single_polygon[0])

    tile_infos = set()
    for poly in get_outer_rings(input_dir):
        # test every 1x1 degree cell in the polygon's bbox for a real
        # intersection, not just bbox overlap
        for rect in grid_from_bounds(poly.bounds):
            if poly.intersects(box(*rect)):
                tile_x, tile_y, *_ = rect
                tile_infos.add(get_tile_info(tile_x, tile_y))

    return tile_infos


def get_tiles_with_bbox(bbox_str: str) -> Set[Tile]:
    """Parse and validate a 'minX,minY,maxX,maxY' string and return the set
    of elevation tiles covering that bounding box.

    Exits the process with status 1 when the string cannot be parsed into
    exactly four floats or the coordinates are out of range / inverted.
    """
    try:
        # the unpacking below also enforces exactly four values
        bbox = [float(part) for part in bbox_str.split(",")]
        west, south, east, north = bbox
    except ValueError:
        LOGGER.critical(f"BBOX {bbox_str} is not a comma-separated string of coordinates.")
        sys.exit(1)

    # reject inverted boxes and coordinates outside the WGS84 range
    if west > east or south > north or west < -180 or east > 180 or south < -90 or north > 90:
        LOGGER.critical(f"Bbox invalid: {bbox}")
        sys.exit(1)

    LOGGER.debug(f"Received valid bbox: {bbox}")

    # one tile per 1x1 degree cell, identified by the cell's lower-left corner
    return {get_tile_info(cell[0], cell[1]) for cell in grid_from_bounds(bbox)}


def get_tiles_with_graph(graph_dir: Path) -> Set[Tile]:
    """Return the elevation tiles covering every level-2 graph tile found
    under *graph_dir* (a Valhalla tile directory).

    Exits the process with status 1 if *graph_dir* does not exist.
    """
    if not graph_dir.is_dir():
        LOGGER.critical(f"Graph directory {graph_dir.resolve()} does not exist.")
        sys.exit(1)

    tile_infos = set()
    # level-2 graph tiles live in the "2/" subdirectory
    local_dir = graph_dir.joinpath('2')
    # number of LOCAL_SIZE-degree columns per row of the world grid (360/0.25 = 1440)
    tile_width = int(360 / LOCAL_SIZE)
    for tile_fp in local_dir.rglob('*.gph'):
        # turn the path into a tile ID: concatenate the path components below
        # "2/", e.g. 000/753/542.gph -> "000753542" -> 753542
        tile_id = int(
            str(tile_fp.parent.relative_to(local_dir).joinpath(tile_fp.stem)).replace(os.sep, '')
        )
        # row-major decode (id = row * tile_width + col), convert the tile's
        # SW corner back to degrees, then floor onto the 1x1 elevation grid
        tile_x, tile_y = floor(int(tile_id % tile_width) * LOCAL_SIZE - 180), floor(
            int(tile_id / tile_width) * LOCAL_SIZE - 90
        )

        tile_infos.add(get_tile_info(tile_x, tile_y))

    return tile_infos


def download(tile: Tile, output_dir: Path, compression: TileCompression) -> bool:
    """Download one skadi elevation tile from AWS into *output_dir*,
    converting it to the requested *compression* format on the fly.

    :param tile: name/dir pair identifying the tile
    :param output_dir: root elevation directory; a <tile.dir> subdirectory is created
    :param compression: target on-disk format (gzip passthrough, uncompressed, or lz4)
    :returns: True on a fresh successful download; False when the file
        already existed or when all attempts failed/were corrupted.
    """
    dest_directory = Path(output_dir, tile.dir)
    dest_directory.mkdir(parents=True, exist_ok=True)

    filepath = dest_directory.joinpath(tile.name + compression.extension)
    if filepath.is_file():
        # Skip if the file already exists
        return False

    # the AWS public dataset serves tiles gzip-compressed
    url = f"https://elevation-tiles-prod.s3.us-east-1.amazonaws.com/skadi/{tile.dir}/{tile.name}.gz"

    LOGGER.info(f"Downloading tile {tile.name}")

    download_status = DownloadStatus.FAILED
    for i in range(5):
        try:
            LOGGER.debug(f"Downloading tile {tile.dir}/{tile.name} for the {i}th time.")

            # back off before each attempt: (i**2)/2*30 gives 0 s, 15 s, 60 s,
            # 135 s, 240 s — roughly 7.5 minutes of waiting in total
            sleep((i**2) / 2 * 30)
            with request.urlopen(url) as res, open(filepath, "wb") as f:

                if compression is TileCompression.GZIP:
                    # keep the server's gzip bytes untouched
                    f.write(res.read())
                else:
                    # decompress the stream for UNCOMPRESSED output or LZ4 recompression
                    with gzip.GzipFile(fileobj=res, mode="rb") as gz:
                        try:
                            uncompressed = gz.read()
                        except Exception as e:
                            # truncated/invalid gzip stream: mark corrupted and
                            # retry (continue targets the attempt loop)
                            download_status = DownloadStatus.CORRUPTED
                            LOGGER.error(
                                f"Decompression error on tile {tile.dir}/{tile.name}: {e}. Likely a corrupted download."
                            )
                            continue
                        if compression is TileCompression.UNCOMPRESSED:
                            f.write(uncompressed)
                        elif compression is TileCompression.LZ4:
                            # Compression level 6 was chosen after some benchmarking as the approx efficient frontier
                            # between compression time and space savings (decompression time is roughly constant regardless
                            # of level). The end result is larger than the maximally gzipped tiles from AWS, but only
                            # by around 12%.
                            # imported lazily so the dependency is only needed for -z
                            import lz4.frame

                            with lz4.frame.LZ4FrameCompressor(
                                block_size=lz4.frame.BLOCKSIZE_MAX4MB, compression_level=6
                            ) as compressor:
                                # Optimization: we know the exact size of every
                                # uncompressed hgt file (25934402 = 3601 * 3601 * 2 bytes)
                                f.write(compressor.begin(25934402))
                                f.write(compressor.compress(uncompressed))
                                f.write(compressor.flush())

            download_status = DownloadStatus.OK
            LOGGER.debug(f"Successfully downloaded tile {tile.dir}/{tile.name}")
            break
        except HTTPError as e:
            LOGGER.error(f"Download failed with HTTP error {e.code}: {e.reason}.\nTrying again...")
            continue
        except URLError as e:
            LOGGER.error(
                f"Download failed of elevation tile {tile.dir}/{tile.name}: {e.reason}.\nTrying again.."
            )
            continue
        except ImportError:
            # raised by the lazy "import lz4.frame" above; no point retrying
            LOGGER.critical(
                "Could not import lz4. Please install lz4 or use another compression format."
            )
            sys.exit(1)

    if download_status == DownloadStatus.CORRUPTED:
        # remove the file written before decompression failed
        LOGGER.error(f"Tile {tile.dir}/{tile.name} was corrupted, removing it...")
        filepath.unlink()
        return False
    elif download_status == DownloadStatus.FAILED:
        LOGGER.error(f"Tile {tile.dir}/{tile.name} couldn't be downloaded...")
        return False

    return True


if __name__ == "__main__":
    args = parser.parse_args()

    # set the right logger level from the accumulated -v flags
    if args.verbosity == 0:
        LOGGER.setLevel(logging.CRITICAL)
    elif args.verbosity == 1:
        LOGGER.setLevel(logging.INFO)
    elif args.verbosity >= 2:
        LOGGER.setLevel(logging.DEBUG)

    # merge config sources: --inline-config keys override --config file keys
    config = dict()
    if args.config:
        with open(args.config) as f:
            config = json.load(f)

    if args.inline_config:
        config.update(**json.loads(args.inline_config))

    # resolve output directory: --outdir beats the config JSON
    # NOTE(review): raises KeyError when config lacks additional_data.elevation —
    # confirm whether a friendlier error is wanted here
    if args.outdir:
        elevation_fp = args.outdir
    elif config:
        elevation_fp = Path(config["additional_data"]["elevation"] or "elevation")
    else:
        LOGGER.critical("Either config or outdir is required.")
        sys.exit(1)

    # determine the tile set via exactly one of the three methods
    if args.from_geojson:
        tiles = get_tiles_with_geojson(args.from_geojson)
    elif args.from_bbox:
        tiles = get_tiles_with_bbox(args.from_bbox)
    elif args.from_tiles:
        if not config:
            LOGGER.critical("--from-tiles requires a config to be specified.")
            sys.exit(1)
        tiles = get_tiles_with_graph(Path(config["mjolnir"]["tile_dir"]))
    else:
        LOGGER.critical("No download method specified.")
        sys.exit(1)

    # gzip passthrough is the default; -d and -z are mutually exclusive flags
    tile_compression = TileCompression.GZIP
    if args.decompress:
        tile_compression = TileCompression.UNCOMPRESSED
    elif args.lz4:
        tile_compression = TileCompression.LZ4

    LOGGER.debug(sorted(tiles, key=lambda x: x.name))

    # create the threadpool and download; threads suffice since the work is
    # network-I/O bound
    pool = ThreadPool(args.parallelism)
    results = pool.imap_unordered(
        partial(download, output_dir=elevation_fp, compression=tile_compression), tiles
    )

    # consuming the iterator drives the downloads; count only fresh downloads
    # (download() returns False for pre-existing or failed tiles)
    sum_downloaded = list(filter(lambda res: res is True, results))
    LOGGER.info(f"Downloaded {len(sum_downloaded)} tiles. Exiting.")
