#! /usr/bin/python3

import sys
import os
from os.path import expanduser, dirname
import subprocess
import json
import argparse
from pathlib import Path
import sqlite3
import locale
import xml.etree.ElementTree
import logging
import traceback
from urllib.parse import urlparse, parse_qs
from datetime import datetime
import threading
import time

import paho.mqtt.client as mqtt
import yaml
import pafy
import requests

from mpv import MPV, MpvEventID, MpvEventEndFile

# Server version string; shown by the ``--version`` command-line option.
version = "3.3.1"

class MpvEventEndFile:
    """End-file reason codes compared against ``ev["event"]["reason"]``.

    This local definition deliberately shadows the ``MpvEventEndFile``
    imported from ``mpv``. The numbers follow libmpv's ``mpv_end_file_reason``
    values. NOTE(review): presumably redefined so the codes do not depend on
    the installed python-mpv version.
    """

    EOF = 0
    STOP = 2
    QUIT = 3
    ERROR = 4
    REDIRECT = 5

# Adapted from python-mpv's MPV.wait_for_property
# (https://github.com/jaseg/python-mpv/blob/master/mpv.py#L582),
# with an added timeout so callers cannot block forever.
def mpv_wait_for_property(self, name, cond=lambda val: val, level_sensitive=True, timeout=1):
    """Wait until ``cond`` evaluates to a truthy value on the named property.

    This can be used to wait for properties such as ``idle_active``, which
    indicates that the player is done with regular playback and is idling.

    Args:
        self: the player object to observe (this is a free function that
            takes the player explicitly rather than a method).
        name: mpv property name, e.g. ``"time-pos"``; dashes are mapped to
            underscores to read the current value as an attribute.
        cond: predicate applied to the property value (default: truthiness).
        level_sensitive: if true, return immediately when ``cond`` already
            holds for the current value; otherwise only a later change can
            satisfy the wait.
        timeout: maximum number of seconds to wait for a satisfying change.

    Returns:
        True if ``cond`` was satisfied, False if the timeout expired.
    """
    sema = threading.Semaphore(value=0)

    def observer(name, val):
        if cond(val):
            sema.release()

    # Observe before reading the current value so a change that happens in
    # between is not missed.
    self.observe_property(name, observer)
    try:
        if level_sensitive and cond(getattr(self, name.replace('-', '_'))):
            return True
        return sema.acquire(timeout=timeout)
    finally:
        # Always detach the observer, even if reading the property raised,
        # so observers do not pile up on the player.
        try:
            self.unobserve_property(name, observer)
        except ValueError:
            pass

def xdg_config_dir():
    """Return the zeromedia configuration directory.

    Uses ``$XDG_CONFIG_HOME/zeromedia``, falling back to
    ``~/.config/zeromedia`` when the variable is unset *or empty*. The XDG
    Base Directory spec requires an empty value to be treated as unset;
    otherwise the result would be the relative path ``zeromedia``.
    """
    config_home = os.getenv("XDG_CONFIG_HOME") or os.path.expanduser("~/.config/")
    return os.path.join(config_home, "zeromedia")

def xdg_data_dir():
    """Return the zeromedia data directory.

    Uses ``$XDG_DATA_HOME/zeromedia`` when the variable is set and non-empty.
    Otherwise it returns the XDG Base Directory default
    ``~/.local/share/zeromedia``. The exception is when a directory from the
    previous, non-standard default ``~/.data/zeromedia`` already exists; that
    one is returned so previously stored data is still found.
    """
    data_home = os.getenv("XDG_DATA_HOME")
    if data_home:
        return os.path.join(data_home, "zeromedia")

    legacy_dir = os.path.join(os.path.expanduser("~/.data/"), "zeromedia")
    if os.path.isdir(legacy_dir):
        return legacy_dir
    return os.path.join(os.path.expanduser("~/.local/share/"), "zeromedia")

class Player:
    """Wrap an ``MPV`` instance and forward its state changes to callbacks.

    Callbacks:
        volume_cb(volume): called when the ``volume`` property changes.
        pause_cb(time_pos): called when the ``pause`` property changes.
        seek_cb(time_pos): called on mpv's ``PLAYBACK_RESTART`` event.
        stop_cb(time_pos): called when a file ends. It receives the last
            known position on stop/quit and ``None`` at end of file.
        error_cb(message): called when a file ends because of an error.
    """

    # Longest time, in seconds, play() waits for playback to start before it
    # stops waiting (playback itself is left running).
    LOAD_TIMEOUT = 60

    def __init__(self, options, volume_cb, pause_cb, seek_cb, stop_cb, error_cb):
        default_options = {
            "ytdl": True,
            "input_default_bindings": True,
            "input_vo_keyboard": True,
            "config": True
        }
        default_options.update(options)

        # Options whose value is None become positional MPV flags; all others
        # are passed as keyword options.
        extra_mpv_flags = [k for k, v in default_options.items() if v is None]
        extra_mpv_opts = {k: v for k, v in default_options.items() if v is not None}

        self._mpv = MPV(*extra_mpv_flags, **extra_mpv_opts)

        # Last non-None value observed for mpv's 'time-pos' property.
        self.time_pos = None
        # Set by the end-file handler when a file fails to play. Initialised
        # here so the attribute exists before the first play().
        self._file_error = False

        self._volume_cb = volume_cb
        self._pause_cb  = pause_cb
        self._stop_cb   = stop_cb
        self._seek_cb   = seek_cb
        self._error_cb  = error_cb

        @self._mpv.property_observer('time-pos')
        def time_pos(name, pos):
            # Keep the previous position when mpv clears the property, so it
            # is still available to the stop handler.
            if pos is not None:
                self.time_pos = pos

        @self._mpv.property_observer('pause')
        def pause(name, pause):
            self._pause_cb(self.time_pos)

        @self._mpv.property_observer('volume')
        def volume(name, vol):
            self._volume_cb(vol)

        def event(ev):
            if ev["event_id"] == MpvEventID.END_FILE:
                reason = ev["event"]["reason"]
                if reason in (MpvEventEndFile.STOP, MpvEventEndFile.QUIT):
                    self._stop_cb(self.time_pos)
                    logging.debug("Player: stopped at %s", str(self.time_pos))
                elif reason == MpvEventEndFile.EOF:
                    self._stop_cb(None)
                    logging.debug("Player: end of file")
                elif reason == MpvEventEndFile.ERROR:
                    # Raise the flag before the (possibly slow) callback so a
                    # play() waiting for playback notices the error promptly.
                    self._file_error = True
                    self._error_cb("file error")
                    logging.critical("Player: file error")
            elif ev["event_id"] == MpvEventID.PLAYBACK_RESTART:
                self._seek_cb(self.time_pos)

        self._mpv.register_event_callback(event)

    def play(self, file, title=None):
        """Load ``file`` and block until playback has actually started.

        Waits until mpv reports a ``path`` and a truthy (non-zero)
        ``time-pos``, until the end-file handler flags an error, or until
        ``LOAD_TIMEOUT`` seconds have passed. Unless an error was flagged,
        ``title`` is then shown on the OSD, or mpv's ``media-title`` when
        ``title`` is None.
        """
        self.stop()

        self._file_error = False

        self._mpv.loadfile(file)
        self._mpv.pause = False

        # Each mpv_wait_for_property call is itself bounded; the deadline
        # stops the loop from spinning forever if playback never starts.
        deadline = time.monotonic() + self.LOAD_TIMEOUT
        while not self._file_error:
            if mpv_wait_for_property(self._mpv, "path") and mpv_wait_for_property(self._mpv, "time-pos"):
                break
            if time.monotonic() >= deadline:
                logging.warning("Player: playback did not start within %s s", self.LOAD_TIMEOUT)
                break

        if self._file_error:
            return

        osd_text = self._mpv.media_title if title is None else title
        if osd_text is not None:
            self._mpv.command("show-text", osd_text)

        logging.debug("Player: playing %s", self._mpv.path)

    def pause(self):
        """Toggle mpv's pause state."""
        self._mpv.pause = not self._mpv.pause

    def stop(self):
        """Stop playback of the current file."""
        self._mpv.command("stop")

    def position(self, pos=None):
        """Jump to ``pos`` when given, then show the progress bar on the OSD."""
        if pos is not None:
            self._mpv.time_pos = pos
        self._mpv.command("osd-msg-bar", "show-progress")

    def seek(self, offset):
        """Seek by ``offset`` with mpv's ``seek`` command, showing the OSD bar."""
        self._mpv.command("osd-msg-bar", "seek", offset)

    def set_volume(self, val=None, add=None):
        """Change the volume by ``add`` or, if ``add`` is None, set it to ``val``.

        Both values are converted with ``int()``. A ValueError from that
        conversion propagates; a TypeError is logged and ignored.
        """
        try:
            if add is not None:
                add = int(add)
                self._mpv.command("osd-msg-bar", "add", "volume", add)
            elif val is not None:
                val = int(val)
                self._mpv.command("osd-msg-bar", "set", "volume", val)
        except TypeError:
            logging.warning("Player: ignoring invalid volume request val=%r add=%r", val, add)

        logging.debug("Player: volume %s", str(self._mpv.volume))

    def mute(self):
        """Toggle mpv's mute state."""
        self._mpv.mute = not self._mpv.mute

    def get_status(self):
        """Return a snapshot of the current mpv playback properties as a dict."""
        return {
            "current":     self._mpv.path,
            "time_pos":    self._mpv.time_pos,
            "percent_pos": self._mpv.percent_pos,
            "pause":       self._mpv.pause,
            "volume":      self._mpv.volume,
            "duration":    self._mpv.duration,
        }

class History:
    """Persist the latest playback position of each played path in SQLite.

    The connection is opened with ``check_same_thread=False`` so it can be
    used from more than one thread. Python's sqlite3 module leaves
    serialisation to the caller, so all access goes through a lock.
    """

    _create_table = """CREATE TABLE IF NOT EXISTS history (
    path TEXT,
    timestamp TEXT,
    position REAL,
    deleted INTEGER,
    PRIMARY KEY (path)
    );
    """

    # record() runs both statements: update an existing row, or insert a new
    # one (the insert is ignored when the path already exists).
    _update = "UPDATE history SET timestamp = datetime('now'), position = ? WHERE path = ?;"
    _insert = "INSERT OR IGNORE INTO history VALUES (?, datetime('now'), ? , 0);"
    _last   = """SELECT path, position FROM history
                 WHERE path GLOB ?
                 ORDER BY datetime(timestamp) DESC LIMIT 1;
              """

    def __init__(self, path):
        directory = os.path.dirname(path)
        # A bare file name has no directory part, and os.makedirs("") fails.
        if directory:
            os.makedirs(directory, exist_ok=True)
        self._lock = threading.Lock()
        self._conn = sqlite3.connect(path, check_same_thread=False)
        self._conn.row_factory = sqlite3.Row

        with self._lock:
            self._conn.execute(self._create_table)
            self._conn.commit()

    def record(self, path, position):
        """Store ``position`` as the latest position of ``path``.

        ``position`` is stored as given (it may be None). A None ``path`` is
        ignored. Otherwise SQLite would add a fresh NULL-keyed row on every
        call, since NULLs never collide in the primary key.
        """
        if path is None:
            logging.debug("History: no path, position not recorded")
            return
        logging.info("History: position recorded")
        with self._lock:
            self._conn.execute(self._update, (position, path))
            self._conn.execute(self._insert, (path, position))
            self._conn.commit()

    def last(self, pattern):
        """Return ``(path, position)`` of the newest entry matching ``pattern``.

        ``pattern`` is an SQLite GLOB. A pattern without ``*`` is wrapped in
        ``*...*`` so it matches as a substring.

        Raises:
            ValueError: if no recorded path matches.
        """
        if "*" not in pattern:
            pattern = "*" + pattern + "*"
        with self._lock:
            row = self._conn.execute(self._last, (pattern,)).fetchone()
        if row is None:
            raise ValueError("no history entry matches " + pattern)
        path, position = row
        return (path, position)

class Filesystem:
    """Browse ``file://`` channels and step through their files in walk order.

    Client-facing paths have the form ``"<channel>:<path relative to the
    channel root>"``. Directory entries carry a trailing ``/``.
    """

    def _relative(self, name, base, fullpath):
        """Return ``fullpath`` as ``"<name>:<path relative to base>"``."""
        return name + ":" + str(Path(fullpath).relative_to(base))

    def list(self, name, base, suffix=None):
        """List one directory of a ``file://`` channel.

        Args:
            name: channel name used as the prefix of returned paths.
            base: the channel's ``file://`` URL.
            suffix: directory below the channel root; None lists the root.

        Returns:
            ``{"files": [{"path": ...}, ...], "directories": [...]}``, each
            list sorted with ``locale.strxfrm``.
        """
        base = base[len("file://"):]
        path = Path(base)

        if suffix is not None:
            path = path / suffix

        files = []
        dirs  = []
        for f in path.iterdir():
            if f.is_dir():
                dirs.append(self._relative(name, base, f) + "/")
            else:
                files.append(self._relative(name, base, f))

        files.sort(key=locale.strxfrm)
        files = [ {"path": f} for f in files ]
        dirs.sort(key=locale.strxfrm)

        return {"files": files, "directories": dirs}

    def _iter_files(self, path):
        """Yield every file below ``path``, top-down, in locale-sorted order."""
        for root, dirs, files in os.walk(path):
            files.sort(key=locale.strxfrm)
            # Sorting ``dirs`` in place also fixes the order os.walk descends.
            dirs.sort(key=locale.strxfrm)

            for f in files:
                yield os.path.join(root, f)

    def prevnext(self, name, backend, suffix):
        """Return the channel paths of the files before and after ``suffix``.

        Neighbours follow the order of ``_iter_files``. A missing neighbour
        is None. ``(None, None)`` is returned when ``suffix`` is not found
        below the channel root (e.g. the file was renamed or deleted).
        """
        base   = backend[len("file://"):]
        target = os.path.join(base, suffix)

        prev_file = None
        walk = self._iter_files(base)
        for current in walk:
            if current == target:
                break
            prev_file = current
        else:
            # Target never seen. Without this, "previous" would be the very
            # last file of the whole tree.
            return None, None

        next_file = next(walk, None)

        prev_path = None if prev_file is None else self._relative(name, base, prev_file)
        next_path = None if next_file is None else self._relative(name, base, next_file)
        return prev_path, next_path

class Youtube:
    """Resolve YouTube videos, playlists and channel/user feeds."""

    # XML namespace of the Atom feeds served at youtube.com/feeds/videos.xml.
    _ATOM = "{http://www.w3.org/2005/Atom}"
    # Seconds before the feed HTTP request is abandoned.
    _FEED_TIMEOUT = 30

    def _url(self, playlist_id, video_id):
        """Build the watch URL of ``video_id`` within playlist ``playlist_id``."""
        return "https://www.youtube.com/watch?v={video_id}&list={playlist_id}".format(playlist_id=playlist_id, video_id=video_id)

    def match(self, path):
        """Return True if ``path`` is a watch, playlist, channel or user URL."""
        return path.startswith(("https://www.youtube.com/watch",
                                "https://www.youtube.com/playlist",
                                "https://www.youtube.com/channel",
                                "https://www.youtube.com/user"))

    def _need_feed(self, path):
        """Return True for channel/user URLs, which are resolved via the Atom feed."""
        return path.startswith(("https://www.youtube.com/channel",
                                "https://www.youtube.com/user"))

    def get_feed(self, url):
        """Return the videos of a channel or user feed.

        Items are ``{"path": <video URL>, "title": <title>}`` in the reverse
        of the order the feed lists them.

        Raises:
            ValueError: if ``url`` is neither a channel nor a user URL.
            requests.HTTPError: if the feed request fails.
        """
        if url.startswith("https://www.youtube.com/channel"):
            channel_id = url.split("/")[4]
            feed_url = "https://www.youtube.com/feeds/videos.xml?channel_id=" + channel_id
        elif url.startswith("https://www.youtube.com/user"):
            username = url.split("/")[4]
            feed_url = "https://www.youtube.com/feeds/videos.xml?user=" + username
        else:
            raise ValueError("not a YouTube channel or user URL: " + url)

        feed = requests.get(feed_url, timeout=self._FEED_TIMEOUT)
        feed.raise_for_status()
        xmltree = xml.etree.ElementTree.fromstring(feed.content)
        entries = xmltree.findall(self._ATOM + "entry")
        return [{"path": entry.find(self._ATOM + "link").attrib["href"],
                 "title": entry.find(self._ATOM + "title").text}
                for entry in reversed(entries)]

    def get_playlist(self, url):
        """Return ``(videos, playlist_id)``, with videos sorted by date added."""
        pl = pafy.get_playlist(url)

        for item in pl["items"]:
            item["added"] = datetime.strptime(item["playlist_meta"]["added"], "%d/%m/%Y")

        pl["items"].sort(key=lambda x: x["added"])

        return [x["pafy"] for x in pl["items"]], pl["playlist_id"]

    def list(self, url, name=None):
        """List a feed or playlist as ``{"files": [...], "directories": []}``.

        Playlist entries are addressed by watch URL, or by
        ``"<name>:<index>"`` when a channel ``name`` is given.
        """
        if self._need_feed(url):
            return {"files": self.get_feed(url), "directories": []}

        pl, playlist_id = self.get_playlist(url)

        files = []
        for idx, vid in enumerate(pl):
            data = {}
            if name is None:
                data["path"] = self._url(playlist_id, vid.videoid)
            else:
                data["path"] = name + ":" + str(idx)
            data["title"] = vid.title
            data["duration"] = vid.duration

            files.append(data)

        return {"files": files, "directories": []}

    def playlist(self, url, index):
        """Return the watch URL of entry ``index`` of the playlist at ``url``.

        Raises:
            ValueError: if ``index`` is not an integer or is past the end.
        """
        index = int(index)

        pl, playlist_id = self.get_playlist(url)
        if index < len(pl):
            vid = pl[index]

            return self._url(playlist_id, vid.videoid)
        else:
            raise ValueError("index too large")

    def play(self, url):
        """Resolve ``url`` to ``(stream_url, title)`` of the best stream.

        Feeds and playlists resolve to their last entry.
        """
        if url.startswith("https://www.youtube.com/watch"):
            vid = pafy.new(url)
        elif self._need_feed(url):
            pl = self.get_feed(url)
            vid = pafy.new(pl[-1]["path"])
        else:
            pl, _ = self.get_playlist(url)
            vid = pl[-1]

        stream = vid.getbest()
        path   = stream.url

        return path, vid.title

    def prevnext(self, name, backend, suffix):
        """Return ``(previous, next)`` neighbours of a video; None for a missing side.

        Without a channel ``backend``, ``name + ":" + suffix`` is a full watch
        URL (the path was split at its first colon). Its neighbours are watch
        URLs within the same playlist, and ``(None, None)`` is returned when
        the URL has no ``list=`` parameter or the video is not in the list.
        With a playlist ``backend``, ``suffix`` is an entry index and the
        neighbours are ``"<name>:<index>"`` paths.
        """
        if backend is None:
            return self._prevnext_url(name + ":" + suffix)
        return self._prevnext_index(name, backend, suffix)

    def _prevnext_url(self, path):
        """Neighbours of the video at watch URL ``path`` as watch URLs."""
        if "list=" not in path:
            return None, None
        videoid = pafy.new(path).videoid
        playlist, playlist_id = self.get_playlist(path)

        prev_id = None
        remaining = iter(playlist)
        for video in remaining:
            if video.videoid == videoid:
                break
            prev_id = video.videoid
        else:
            return None, None

        next_video = next(remaining, None)
        next_id = None if next_video is None else next_video.videoid

        # Never build "watch?v=None" URLs for a missing neighbour.
        prev_url = None if prev_id is None else self._url(playlist_id, prev_id)
        next_url = None if next_id is None else self._url(playlist_id, next_id)
        return prev_url, next_url

    def _prevnext_index(self, name, backend, suffix):
        """Neighbours of entry ``suffix`` of the ``backend`` playlist as channel paths."""
        playlist, _ = self.get_playlist(backend)
        index = int(suffix)

        prev_path = name + ":" + str(index - 1) if index > 0 else None
        # Valid indices are 0 .. len(playlist) - 1.
        next_path = name + ":" + str(index + 1) if index + 1 < len(playlist) else None
        return prev_path, next_path

class DVD:
    """List, step through and eject DVD tracks using ``lsdvd`` and ``eject``."""

    def __init__(self, dvd_device=None):
        # Device path appended to the lsdvd/eject command lines; None leaves
        # the device choice to those tools.
        self._dvd_device = dvd_device

    def _with_device(self, command):
        """Return ``command`` with the configured device appended, if any."""
        if self._dvd_device is None:
            return command
        return command + [self._dvd_device]

    def _lsdvd(self):
        """Return lsdvd's XML disc description as an Element, or None on failure.

        None covers both a failing lsdvd (non-zero exit, e.g. no disc) and an
        lsdvd that cannot be run at all (e.g. not installed).
        """
        try:
            output = subprocess.check_output(self._with_device(["lsdvd", "-Ox"]),
                                             stderr=subprocess.DEVNULL)
        except (subprocess.CalledProcessError, OSError):
            return None

        return xml.etree.ElementTree.fromstring(output)

    def list(self, name):
        """List the disc's tracks as ``{"directories": [], "files": [...]}``.

        Each file is ``{"path": "<name>:<ix - 1>", "duration": <length>}``,
        taken from lsdvd's ``ix`` and ``length`` elements. An unreadable disc
        yields an empty listing.
        """
        info = self._lsdvd()
        if info is None:
            return {"files": [], "directories": []}

        tracks = [{"path": name + ":" + str(int(track.find('ix').text) - 1),
                   "duration": float(track.find('length').text)}
                  for track in info.findall('track')]
        return {"directories": [], "files": tracks}

    def prevnext(self, name, backend, suffix):
        """Return ``(previous, next)`` track paths around track ``suffix``.

        Tracks are numbered from 0 as in ``list``. A side with no track is
        None.
        """
        track_count = len(self.list(name)["files"])
        track = int(suffix)

        prev_path = name + ":" + str(track - 1) if track > 0 else None
        next_path = name + ":" + str(track + 1) if track + 1 < track_count else None
        return prev_path, next_path

    def eject(self):
        """Eject the disc with the ``eject`` command."""
        subprocess.call(self._with_device(["eject"]))

class Dispatcher:
    """Route ``do_<command>`` requests to the player and the channel backends.

    ``channels`` maps a channel name to either a backend URL string
    (``file://``, ``dvd://``/``dvdread://`` or ``https://www.youtube.com``)
    or a dict mapping entry names to playable paths. Client paths have the
    form ``"<channel>:<suffix>"``. ``status_cb`` receives the current status
    dict whenever the player reports a change.
    """

    def __init__(self, mpv_options, channels, status_cb):
        self._channels   = channels
        self._status_cb  = status_cb

        # Client path of the item most recently requested through _play.
        self._asked_file = None
        # Last error message reported by the player; cleared by _play.
        self._error      = None

        self._player     = Player(mpv_options,
            volume_cb=lambda _: status_cb(self.status()),
            stop_cb=self._record_position,
            seek_cb=lambda _: status_cb(self.status()),
            pause_cb=lambda _: status_cb(self.status()),
            error_cb=self._set_error)
        self._filesystem = Filesystem()
        self._dvd        = DVD(mpv_options.get("dvd-device", None))
        self._youtube    = Youtube()
        self._history    = History(os.path.join(xdg_data_dir(), "history.db"))

    def _record_position(self, pos):
        """Player stop callback: store ``pos`` for the requested path, publish status."""
        self._history.record(self._asked_file, pos)
        self._status_cb(self.status())

    def _set_error(self, error):
        """Player error callback: remember ``error`` and publish status."""
        self._error = error
        self._status_cb(self.status())

    def _playing(self):
        """Return True while mpv has a current path."""
        status = self._player.get_status()
        return status["current"] is not None

    def _paused(self):
        """Return mpv's pause flag."""
        status = self._player.get_status()
        return status["pause"]

    def status(self):
        """Return the player status, reporting the client path as ``current``."""
        status = self._player.get_status()
        status["current"] = None if status["current"] is None else self._asked_file
        status["error"]   = self._error

        return status

    def _play(self, zpath):
        """Resolve the client path ``zpath`` through its channel and play it."""
        logging.info("Zeromedia: play %s", zpath)
        self._error = None

        title = None
        path = zpath
        parts = path.split(":", 1)
        if len(parts) == 2:
            name    = parts[0]
            suffix  = parts[1]
            if name in self._channels:
                backend = self._channels[name]

                if isinstance(backend, str):
                    if len(suffix) == 0:
                        path = backend
                    elif backend.startswith("file://"):
                        path = os.path.join(backend, suffix)
                    elif backend.startswith("https://www.youtube.com/playlist"):
                        path = self._youtube.playlist(backend, suffix)
                    elif backend.startswith("dvd://") or backend.startswith("dvdread://"):
                        path = backend + suffix

                elif isinstance(backend, dict):
                    path = backend[suffix]

        if path.startswith("https://www.youtube.com"):
            path, title = self._youtube.play(path)

        # NOTE(review): Player.play() stops the previous file first, and its
        # stop callback may run after this assignment, so the previous file's
        # position could be recorded under the new path. Confirm ordering
        # with python-mpv's event thread.
        self._asked_file = zpath
        self._player.play(path, title=title)

    def do(self, cmd, args):
        """Run ``do_<cmd>(**args)`` and return the answer for the client.

        Returns the command's return value, or the current status when it
        returns None, or an ``{"error": ...}`` dict when it raises. An
        unknown command is logged and yields None.
        """
        try:
            func = getattr(self, "do_" + cmd)
        except AttributeError:
            logging.warning("Unknown command: %s", cmd)
            return None

        try:
            ret = func(**args)
            return self.status() if ret is None else ret
        except ValueError as e:
            logging.warning("Value error in %s: %s", cmd, " ".join(str(s) for s in e.args))
            return {"error": "value error", "details": e.args}
        except TypeError as e:
            logging.warning("Type error in %s: %s", cmd, " ".join(str(s) for s in e.args))
            return {"error": "type error", "details": e.args}
        except Exception as e:
            # Boundary handler: answer the client instead of breaking the
            # MQTT callback, keep the traceback in the log, and let
            # KeyboardInterrupt/SystemExit propagate.
            logging.exception("Unknown exception in %s: %s", cmd, str(e))
            return {"error": "unknown error"}

    def do_status(self):
        return self.status()

    def do_play(self, path):
        self._play(path)

    def do_pause(self):
        logging.info("Zeromedia: pause")
        self._player.pause()

    def do_stop(self):
        self._player.stop()

    def do_position(self, pos=None):
        self._player.position(pos=pos)

    def do_seek(self, offset):
        self._player.seek(offset)

    def do_volume(self, val=None, add=None):
        """Set the volume to ``val`` (clamped to 0..130) or change it by ``add``."""
        if val is not None:
            # Accept numeric strings from JSON clients; non-numeric values
            # raise ValueError, which do() reports as a value error.
            val = max(0, min(int(val), 130))
        self._player.set_volume(val=val, add=add)

    def do_mute(self):
        self._player.mute()

    def do_last(self, pattern="*"):
        """Return the most recent history entry matching ``pattern``."""
        path, position = self._history.last(pattern)
        logging.info("Zeromedia: last %s %s", str(position), path)
        return {"path": path, "position": position}

    def do_continue(self, pattern="*"):
        """Unpause; if nothing is playing, resume the newest matching history entry.

        An entry recorded with no position (it played to the end) continues
        with the next item instead.
        """
        if self._paused():
            self._player.pause()
        if self._playing():
            return

        path, position = self._history.last(pattern)
        logging.info("Zeromedia: continue %s %s", str(position), path)
        if position is None:
            self._nextprev(path, +1)
        else:
            self._play(path)
            self._player.position(pos=position)

    def _split_path(self, path):
        """Split ``"name:suffix"`` into ``(name, backend or None, suffix)``."""
        name, suffix = path.split(":", 1)
        backend      = self._channels.get(name, None)
        return name, backend, suffix

    def _nextprev(self, path, direction):
        """Play the neighbour of ``path``: next if ``direction`` > 0, else previous.

        Returns None without playing when ``path`` has no backend that knows
        about neighbours (unknown channel, dict channel, other schemes).

        Raises:
            ValueError: if the requested neighbour does not exist.
        """
        name, backend, suffix = self._split_path(path)
        # Unknown channels (None) and dict channels have no URL to inspect.
        backend_url = backend if isinstance(backend, str) else ""

        if backend_url.startswith("file://"):
            prev_path, next_path = self._filesystem.prevnext(name, backend, suffix)
        elif backend_url.startswith(("dvd://", "dvdread://")):
            prev_path, next_path = self._dvd.prevnext(name, backend, suffix)
        elif path.startswith("https://www.youtube.com") or backend_url.startswith("https://www.youtube.com"):
            prev_path, next_path = self._youtube.prevnext(name, backend, suffix)
        else:
            return

        if direction > 0:
            if next_path is None:
                raise ValueError("next not found")
            self._play(next_path)
        else:
            if prev_path is None:
                raise ValueError("previous not found")
            self._play(prev_path)

    def do_next(self):
        status = self.status()
        if status["current"] is None:
            self.do_continue()
        else:
            self._nextprev(status["current"], +1)

    def do_previous(self):
        status = self.status()
        if status["current"] is None:
            return None
        else:
            self._nextprev(status["current"], -1)

    def do_screen(self, state):
        """Force the display on or off through ``xset dpms``."""
        logging.info("Player: screen %s", str(state))
        if state:
            subprocess.call(["xset", "dpms", "force", "on"])
        else:
            subprocess.call(["xset", "dpms", "force", "off"])

    def do_eject(self, state=None):
        self._dvd.eject()

    def do_list(self, path=None):
        """List channels (no ``path``) or the contents of a channel path or YouTube URL."""
        answer = {}

        if not path:
            channels = [ ch + ":" for ch in self._channels ]
            channels.sort(key=locale.strxfrm)
            answer = {"directories": channels, "files": [], "path": ""}
        else:
            parts   = path.split(":", 1)
            name    = parts[0]
            if name in self._channels:
                backend = self._channels[name]
                if isinstance(backend, dict):
                    answer =  {"directories": [], "files": [{"path": name + ":" + key} for key in backend]}
                elif backend.startswith("file://"):
                    if len(parts) == 2:
                        answer = self._filesystem.list(name, backend, suffix=parts[1])
                    else:
                        answer = self._filesystem.list(name, backend)
                elif backend.startswith("dvd://") or backend.startswith("dvdread://"):
                    answer = self._dvd.list(name)
                elif backend.startswith("https://www.youtube.com"):
                    answer = self._youtube.list(backend, name)
            elif path.startswith("https://www.youtube.com"):
                answer = self._youtube.list(path)
            else:
                answer = {"error": "non-existent channel"}

        answer["path"] = path
        return answer

if __name__ == "__main__":
    # Use the user's locale (listings are sorted with locale.strxfrm) but
    # keep LC_NUMERIC at "C", which libmpv requires. locale.resetlocale()
    # was deprecated in Python 3.11 and removed in 3.13; setlocale(LC_ALL, "")
    # is its documented replacement.
    try:
        locale.setlocale(locale.LC_ALL, "")
    except locale.Error as err:
        sys.stderr.write("zeromedia: unsupported locale ({}), using C\n".format(err))
    locale.setlocale(locale.LC_NUMERIC, "C")

    parser = argparse.ArgumentParser(description='Zeromedia Server')
    parser.add_argument('--version', action='version', version='%(prog)s ' + version)
    parser.add_argument("-c", "--config",
                        dest    = "config",
                        default = os.path.join(xdg_config_dir(), "configuration.yml"),
                        help    = "Configuration file")
    parser.add_argument('-v', '--verbose',
                        dest    = 'verbose',
                        choices = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        default = 'WARNING',
                        metavar = 'VERBOSITY',
                        help    = 'verbosity level')
    args = parser.parse_args()

    with open(args.config) as f:
        # safe_load returns None for an empty file.
        config = yaml.safe_load(f) or {}
    if config.get("channels") is None:
        config["channels"] = {}

    logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', level=getattr(logging, args.verbose), filename="zeromedia.log")
    console = logging.StreamHandler()
    logging.getLogger('').addHandler(console)

    topic  = config["mqtt"]["topic"]
    client = mqtt.Client()
    # Status the broker publishes on our behalf if the connection drops.
    client.will_set(topic + "/status", json.dumps({'error': "server disconnected",
                                        'volume':0,
                                        'percent_pos': 0,
                                        'pause': False,
                                        'current': None,
                                        'time_pos': 0}))

    def on_status(status):
        client.publish(topic + "/status", json.dumps(status))

    # An "mpv:" key with no value loads as None; treat it as no options.
    dispatcher = Dispatcher(config.get("mpv") or {}, config["channels"], on_status)

    def on_connect(client, userdata, flags, rc):
        if int(rc) == 0:
            logging.info("MQTT: Connection successful")
        else:
            logging.critical("MQTT: Connection error %s", str(rc))
        client.subscribe(topic + "/command/#")
    client.on_connect = on_connect

    def on_message(client, userdata, msg):
        """Run ``<topic>/command/<cmd>`` with its JSON payload; answer on ``<topic>/data/<cmd>``."""
        logging.debug("MQTT: %s %s", msg.topic, msg.payload)

        command = msg.topic[len(topic + "/command") + 1:]
        payload = msg.payload

        if len(payload) == 0:
            payload = b"{}"
        try:
            payload = json.loads(payload.decode("UTF-8"))
        except ValueError:  # UnicodeDecodeError and json.JSONDecodeError
            logging.error("MQTT: Decoding failed for %s", msg.payload)
            ans = {"error": "json decoding failed"}
        else:
            ans = dispatcher.do(command, payload)

        if ans is None:
            ans = {}
        on_status(dispatcher.status())
        logging.debug("MQTT: Answer %s", ans)
        client.publish(topic + "/data/" + command, json.dumps(ans))
    client.on_message = on_message

    client.connect(config["mqtt"]["server"], config["mqtt"].get("port", 1883))
    client.loop_forever()
