#
# Copyright (c) 2008-2015 Thierry Florac <tflorac AT ulthar.net>
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL).  A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE.
#

from datetime import date

from onf_website.reference.insee.model import Commune, CodePostal
from pyams_alchemy import Base
from pyams_alchemy.engine import get_user_session
from pyams_alchemy.mixin import DynamicSchemaMixin
from sqlalchemy import Column, Date, ForeignKey, Integer, Unicode
from sqlalchemy.orm import relation
from sqlalchemy.sql import and_, or_, text
from zope.interface import implementer, provider
from zope.schema.interfaces import ITitledTokenizedTerm

from onf_website.reference.orga.model.interfaces import IStructure, IStructureModel


# Database schema of the mapped tables (``__schema__`` of the models below), and name of the
# registered SQLAlchemy engine used by default when a session is required
PARENT_SCHEMA, PARENT_SESSION = 'refstruct', 'REFSTRUCT'


@provider(IStructureModel)
@implementer(IStructure, ITitledTokenizedTerm)
class Structure(DynamicSchemaMixin, Base):
    """Base ONF structure

    Each record is considered valid between its ``date_debut`` and ``date_fin`` dates, a NULL
    ``date_fin`` denoting a record which is still open. Structures also provide ``value``,
    ``token`` and ``title`` attributes so that they can be used as vocabulary terms.
    """

    __tablename__ = 'vm_bu_structures'
    __schema__ = PARENT_SCHEMA

    id = Column(Integer, primary_key=True)
    code = Column(Unicode(6))
    code_sign = Column(Unicode(6))
    libelle = Column(Unicode(30))
    libelle_long = Column(Unicode(40))
    niveau = Column(Integer)
    libelle_type = Column(Unicode(20))
    libelle_type_long = Column('libelle_long_type', Unicode(40))
    date_debut = Column('d_debut', Date)
    date_fin = Column('d_fin', Date)
    date_etat = Column('etat', Date)
    complement_adresse = Column('adresse_comp', Unicode(40))
    num_voie = Column(Integer)
    ext_voie = Column(Unicode(5))
    nature_voie = Column('nat_voie', Unicode(32))
    nom_commune = Column(Unicode(32))
    code_commune = Column(Unicode(5))
    code_postal = Column(Unicode(10))
    bureau_distrib = Column('bur_distrib', Unicode(32))
    pays = Column(Unicode(6))
    code_iso_pays = Column('pays_code_iso', Unicode(2))
    categorie = Column('code_categorie', Unicode(10))

    @property
    def value(self):
        """Vocabulary term value: the structure code"""
        return self.code

    @property
    def token(self):
        """Vocabulary term token: the structure code converted to string"""
        return str(self.code)

    @property
    def title(self):
        """Vocabulary term title

        Returns "<code_sign> - <libelle_long>", or only the long label when the structure
        has no sign code.
        """
        if self.code_sign:
            return '%s - %s' % (self.code_sign, self.libelle_long)
        return self.libelle_long

    @classmethod
    def get_session(cls, session=PARENT_SESSION):
        """Get SQLAlchemy session

        :param session: a registered SQLAlchemy engine name, or an already created session
            object, which is returned unchanged
        :return: SQLAlchemy session
        """
        if isinstance(session, str):
            session = get_user_session(session)
        return session

    @classmethod
    def get_default_params(cls, reference_date=None):
        """Get criteria selecting structures which are valid at a given date

        :param date reference_date: reference date; today's date is used if None
        :return list: SQLAlchemy criteria matching records started at or before reference
            date, and which are not closed or closed at or after reference date
        """
        if reference_date is None:
            reference_date = date.today()
        return [Structure.date_debut <= reference_date,
                or_(Structure.date_fin.is_(None),
                    Structure.date_fin >= reference_date)]

    @classmethod
    def get(cls, code, reference_date=None, session=PARENT_SESSION):
        """Get structures matching given code(s) and valid at reference date

        :param code: a structure code, or a list, tuple or set of codes
        :param date reference_date: reference date; today's date is used if None
        :param session: a registered SQLAlchemy engine name, or an already created session
            object
        :return: a query of matching structures
        """
        if isinstance(code, (list, tuple, set)):
            params = [Structure.code.in_(code)]
        else:
            params = [Structure.code == code]
        params.extend(cls.get_default_params(reference_date))
        session = cls.get_session(session)
        return session.query(Structure).filter(and_(*params))

    @classmethod
    def find(cls, query, reference_date=None, exact=False, session=PARENT_SESSION):
        """Find structures matching given query and valid at reference date

        :param query: either a string, which is stripped, upper-cased and matched against
            structure code, sign code and labels; or a mapping. A mapping may contain a
            'query' key, which is matched like a plain string (except that codes are not
            upper-cased), or separate 'key', 'code' and 'label' entries ('key' and 'code' may
            be lists, tuples or sets). A mapping may also provide a 'type' (matched against
            structure type label) and a 'category' (a list of categories, or a string of
            semicolon-separated categories)
        :param date reference_date: reference date; today's date is used if None
        :param bool exact: if True, sign codes and labels must match exactly; otherwise sign
            codes are matched as prefixes and labels as substrings
        :param session: a registered SQLAlchemy engine name, or an already created session
            object
        :return: a query of matching structures
        """
        if isinstance(query, dict):
            if 'query' in query:
                # "or ''" protects against explicit None values
                key = code = (query.get('query') or '').strip()
                label = key.upper()
            else:
                key = query.get('key', '')
                if isinstance(key, str):
                    key = key.strip()
                code = query.get('code', '')
                if isinstance(code, str):
                    code = code.strip()
                label = (query.get('label') or '').strip().upper()
            type_ = (query.get('type') or '').strip()
            category = query.get('category', None)
            if isinstance(category, str):
                category = category.strip().split(';')
        else:
            key = code = label = query.strip().upper()
            type_ = None
            category = None
        # code and label criteria are alternatives (OR-ed together)...
        code_params = []
        if key:
            if isinstance(key, (list, tuple, set)):
                code_params.append(Structure.code.in_(key))
            else:
                code_params.append(Structure.code == key)
        if code:
            if exact:
                if isinstance(code, (list, tuple, set)):
                    code_params.append(Structure.code_sign.in_(code))
                else:
                    code_params.append(Structure.code_sign == code)
            else:
                if isinstance(code, (list, tuple, set)):
                    code_params.append(or_(*[Structure.code_sign.like(c + '%')
                                             for c in code]))
                else:
                    code_params.append(Structure.code_sign.like(code + '%'))
        if label:
            if exact:
                code_params.append(or_(Structure.libelle == label,
                                       Structure.libelle_long == label))
            else:
                pattern = '%' + label + '%'
                code_params.append(or_(Structure.libelle.like(pattern),
                                       Structure.libelle_long.like(pattern)))
        # ...while type, category and validity criteria are all required (AND-ed)
        params = []
        if code_params:
            params.append(or_(*code_params))
        if type_:
            params.append(Structure.libelle_type == type_)
        if category:
            params.append(Structure.categorie.in_(category))
        params.extend(cls.get_default_params(reference_date))
        session = cls.get_session(session)
        return session.query(Structure).filter(and_(*params))

    @classmethod
    def find_by_insee_code(cls, insee_code, fields=None, session=PARENT_SESSION):
        """Find structures (territorial units) affected to a given INSEE code

        :param str insee_code: INSEE code to look for
        :param tuple fields: if not None, a tuple containing attributes to return; by default,
            complete :ref:`Structure` records are returned
        :param str session: a registered SQLAlchemy engine name, or an already created session
            object
        :return iterator: an iterator of structures (territorial units) records matching given
            INSEE code
        """
        if not fields:
            fields = (Structure,)
        session = Structure.get_session(session)
        return (session.query(*fields)
                .join(Structure.affectations)
                .filter(*Structure.get_default_params())
                .filter(AffectationCommune.code_commune == insee_code)
                .distinct())

    @classmethod
    def find_by_postal_code(cls, postal_code, fields=None, session=PARENT_SESSION):
        """Find structures (territorial units) affected to a given postal code

        :param str postal_code: postal code to look for
        :param tuple fields: if not None, a tuple containing attributes to return; by default,
            complete :ref:`Structure` records are returned
        :param str session: a registered SQLAlchemy engine name, or an already created session
            object
        :return iterator: an iterator of structures (territorial units) records matching given
            postal code
        """
        if not fields:
            fields = (Structure,)
        session = Structure.get_session(session)
        return (session.query(*fields)
                .join(Structure.affectations)
                .join(AffectationCommune.commune)
                .join(Commune.ref_codespostaux)
                .filter(*Structure.get_default_params())
                .filter(CodePostal.code_postal == postal_code)
                .distinct())

    def _get_hierarchy_statement(self, statement):
        """Build a textual SQL statement from a recursive query template

        Schema and table names are formatted into the template because identifiers can't be
        bound, while this structure's code is passed as the ``:code`` bound parameter, so that
        it is quoted by the database driver instead of being pasted into the SQL string.
        """
        return text(statement.format(schema=self.__table__.schema,
                                     tablename=self.__tablename__)).bindparams(code=self.code)

    def get_parents(self, fields=None, session=PARENT_SESSION):
        """Get a structure and all its parents

        Parents are resolved recursively through the ``code_parent`` column, only following
        records which are not closed (``d_fin`` is NULL or after current database date).

        NOTE(review): ordering differs between backends: on Oracle, rows are sorted by
        hierarchical ``level`` (this structure first); on PostgreSQL, by the ``niveau``
        column; confirm both backends give the same order.

        :param tuple fields: if not None, a tuple containing attributes to return; by default,
            complete :ref:`Structure` records are returned
        :param session: a registered SQLAlchemy engine name, or an already created session
            object
        :return: a query of structures records
        """
        session = Structure.get_session(session)
        connection = session.connection()
        if connection.engine.url.get_backend_name() == 'oracle':
            statement = ("select struc.*, level "
                         "from {schema}.{tablename} struc "
                         "start with (code=:code) and (d_fin is null or d_fin > sysdate) "
                         "connect by nocycle (prior code_parent = code) and "
                         "(d_fin is null or d_fin > sysdate) "
                         "order by level")
        else:  # PostgreSQL
            statement = ("with recursive q as ("
                         "    select * "
                         "       from {schema}.{tablename} struc1 "
                         "       where (struc1.code=:code) and "
                         "             (struc1.d_fin is null or struc1.d_fin > now()) "
                         "    union all "
                         "    select struc2.* "
                         "       from {schema}.{tablename} struc2 "
                         "       join q on q.code_parent = struc2.code "
                         "       where (struc2.d_fin is null) or (struc2.d_fin > now()) "
                         ") select * from q order by niveau")
        if not fields:
            fields = (Structure,)
        return session.query(*fields).from_statement(self._get_hierarchy_statement(statement))

    def get_first_parent(self, category, session=PARENT_SESSION):
        """Get first parent matching given categories

        :param category: a list of categories, or a string of semicolon-separated categories
        :param session: a registered SQLAlchemy engine name, or an already created session
            object
        :return: the first parent structure whose category matches, or None
        """
        if isinstance(category, str):
            category = category.split(';')
        # results are loaded into a list: negative slicing of a Query is not supported by all
        # SQLAlchemy versions
        parents = list(self.get_parents(session=session))
        # assumes get_parents() returns this structure last, so that it's excluded and nearest
        # parents are checked first - TODO confirm on Oracle, where rows are ordered by level
        for parent in reversed(parents[:-1]):
            if parent.categorie in category:
                return parent
        return None

    def get_children(self, fields=None, session=PARENT_SESSION, category=None,
                     direct_only=False, reference_date=None):
        """Get a structure and all its children

        :param tuple fields: if not None, a tuple containing attributes to return; by default,
            complete :ref:`Structure` records are returned
        :param session: a registered SQLAlchemy engine name, or an already created session
            object
        :param category: if set, a list of categories, or a string of semicolon-separated
            categories, used to filter returned structures
        :param bool direct_only: if True, only this structure and its direct children are
            returned, sorted by sign code; otherwise, all descendants are returned
        :param date reference_date: reference date used to check structures validity; today's
            date is used if None. NOTE: this date is only used when *direct_only* is True;
            recursive queries check validity against current database date
        :return: a query of structures, or a list of structures when *category* is set and
            *direct_only* is False
        """
        if not fields:
            fields = (Structure,)
        if category and isinstance(category, str):
            category = category.strip().split(';')
        session = Structure.get_session(session)
        structures = session.query(*fields)
        if direct_only:
            structures = structures.filter(
                or_(Structure.code == self.code, Structure.code_parent == self.code),
                *self.get_default_params(reference_date))
            if category:
                structures = structures.filter(Structure.categorie.in_(category))
            return structures.order_by(Structure.code_sign)
        connection = session.connection()
        if connection.engine.url.get_backend_name() == 'oracle':
            statement = ("select struc.* "
                         "from {schema}.{tablename} struc "
                         "start with (code=:code) and (d_fin is null or d_fin > sysdate) "
                         "connect by nocycle (code_parent = prior code) and "
                         "(d_fin is null or d_fin > sysdate) "
                         "order siblings by code_sign, libelle_long")
        else:  # PostgreSQL
            statement = ("with recursive q as ("
                         "   select * "
                         "       from {schema}.{tablename} struc1 "
                         "       where (struc1.code=:code) and "
                         "             (struc1.d_fin is null or struc1.d_fin > now()) "
                         "   union all "
                         "   select struc2.* "
                         "       from {schema}.{tablename} struc2 "
                         "       join q on q.code = struc2.code_parent "
                         "       where (struc2.d_fin is null) or (struc2.d_fin > now()) "
                         ") select * from q order by code_sign, libelle_long")
        # statement must be wrapped in text(): raw strings are rejected by from_statement
        # in recent SQLAlchemy versions
        result = structures.from_statement(self._get_hierarchy_statement(statement))
        if category:
            return [structure for structure in result if structure.categorie in category]
        return result


# Code of the parent structure, used by Structure.get_parents() and Structure.get_children()
# to walk the structures hierarchy. It is declared after the class body, presumably so that
# the ForeignKey can reference the ``Structure.code`` attribute directly; this relies on the
# declarative base accepting columns assigned to an already created mapped class.
Structure.code_parent = Column('code_parent', Unicode(6), ForeignKey(Structure.code))


class AffectationCommune(DynamicSchemaMixin, Base):
    """Affectation of a commune to an ONF territorial unit

    Each record links a commune code to the code of the territorial unit it belongs to.
    """

    __schema__ = PARENT_SCHEMA
    __tablename__ = 'communes_ut'

    # commune code (compared with an INSEE code in Structure.find_by_insee_code)
    code_commune = Column(Unicode(length=5), primary_key=True)
    # territorial unit code (joined with Structure.code_sign)
    code_ut = Column(Unicode(length=6))


# Link from an affectation to its commune record, joined on commune code; ``Commune.code`` is
# declared as the foreign side and ``uselist=False`` makes this attribute a scalar. A
# ``Commune.affectation`` backref is also created. Used by Structure.find_by_postal_code()
# to join communes and their postal codes.
AffectationCommune.commune = relation(Commune,
                                      primaryjoin=AffectationCommune.code_commune==Commune.code,
                                      foreign_keys=Commune.code,
                                      backref='affectation',
                                      uselist=False)
# Link from an affectation to the structures whose sign code matches the territorial unit
# code; ``uselist=True`` returns a collection (presumably because several structure records,
# for example successive versions, may share the same sign code). The ``Structure.affectations``
# backref is used by Structure.find_by_insee_code() and Structure.find_by_postal_code() joins.
# NOTE(review): as ``Structure.code_sign`` is declared as the foreign side, the backref is
# presumably many-to-one, so ``structure.affectations`` may be a scalar despite its plural
# name - confirm whether it should be a collection.
AffectationCommune.structure = relation(Structure,
                                        primaryjoin=AffectationCommune.code_ut==Structure.code_sign,
                                        foreign_keys=Structure.code_sign,
                                        backref='affectations',
                                        uselist=True)
