from sqlparse import parse
from sqlparse.tokens import Keyword, CTE, DML
from sqlparse.sql import Identifier, IdentifierList, Parenthesis
from collections import namedtuple
from .meta import TableMetadata, ColumnMetadata


# TableExpression is a namedtuple describing one CTE, used internally.
# name: the alias assigned to the CTE in the query
# columns: tuple of column names the CTE exposes
# start: index into the original string of the opening parenthesis of the
#        CTE body
# stop: index into the original string just past the closing parenthesis of
#       the CTE body (exclusive, so text[start:stop] is the whole body)
TableExpression = namedtuple(
    "TableExpression", ["name", "columns", "start", "stop"]
)


def isolate_query_ctes(full_text, text_before_cursor):
    """Simplify a query by converting CTEs into table metadata objects

    :param full_text: the complete query text being edited
    :param text_before_cursor: the portion of ``full_text`` before the cursor
    :return: tuple ``(full_text, text_before_cursor, meta)``. When the cursor
        is inside a CTE body, both strings are narrowed to that body;
        otherwise they are narrowed to the text after the last CTE. ``meta``
        is always a tuple of TableMetadata for the CTEs that precede the
        narrowed region (empty when the query defines no CTEs).
    """

    if not full_text or not full_text.strip():
        return full_text, text_before_cursor, ()

    ctes, _ = extract_ctes(full_text)
    if not ctes:
        return full_text, text_before_cursor, ()

    current_position = len(text_before_cursor)
    meta = []

    for cte in ctes:
        if cte.start < current_position < cte.stop:
            # Currently editing a cte - treat its body as the current
            # full_text. Only the CTEs before this one are in `meta`; return
            # it as a tuple so every return path has the same type.
            text_before_cursor = full_text[cte.start: current_position]
            full_text = full_text[cte.start: cte.stop]
            return full_text, text_before_cursor, tuple(meta)

        # Append this cte to the list of available table metadata. Columns
        # are materialized into a tuple rather than a one-shot generator so
        # consumers can iterate them more than once.
        cols = tuple(ColumnMetadata(name, None, ()) for name in cte.columns)
        meta.append(TableMetadata(cte.name, cols))

    # Editing past the last cte (ie the main body of the query)
    last_stop = ctes[-1].stop
    full_text = full_text[last_stop:]
    text_before_cursor = text_before_cursor[last_stop: current_position]

    return full_text, text_before_cursor, tuple(meta)


def extract_ctes(sql):
    """Extract common table expressions (CTEs) from a query

    Returns tuple (ctes, remainder_sql)

    ctes is a list of TableExpression namedtuples
    remainder_sql is the text from the original query after the CTEs have
    been stripped. When the query does not start with WITH, ctes is empty
    and remainder_sql is the unchanged input.
    """

    statements = parse(sql) if sql else ()
    if not statements:
        return [], sql
    p = statements[0]

    # Make sure the first meaningful token is "WITH" which is necessary to
    # define CTEs
    idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True)
    if not (tok and tok.ttype == CTE):
        return [], sql

    # Get the next meaningful token (skipping whitespace and comments),
    # which should be the first CTE
    idx, tok = p.token_next(idx, skip_ws=True, skip_cm=True)

    # "WITH RECURSIVE name AS (...)": step over the RECURSIVE modifier so the
    # CTE definitions that follow it are still found
    if tok is not None and tok.match(Keyword, "RECURSIVE"):
        idx, tok = p.token_next(idx, skip_ws=True, skip_cm=True)

    if not tok:
        return ([], "")
    start_pos = token_start_pos(p.tokens, idx)
    ctes = []

    if isinstance(tok, IdentifierList):
        # Multiple ctes. Walk the list's children once, tracking each child's
        # character offset. Only Identifier children can be CTE definitions:
        # get_identifiers() may also yield other tokens (e.g. bare keywords
        # in half-typed input) that have no get_real_name() method.
        offset = start_pos
        for child in tok.tokens:
            if isinstance(child, Identifier):
                cte = get_cte_from_token(child, offset)
                if cte:
                    ctes.append(cte)
            offset += len(str(child))
    elif isinstance(tok, Identifier):
        # A single CTE
        cte = get_cte_from_token(tok, start_pos)
        if cte:
            ctes.append(cte)

    # Collapse everything after the ctes into a remainder query
    remainder = "".join(str(rest) for rest in p.tokens[idx + 1:])

    return ctes, remainder


def get_cte_from_token(tok, pos0):
    """Build a TableExpression from a single ``name AS (...)`` token.

    :param tok: the grouped token holding one CTE definition
    :param pos0: offset of ``tok`` within the original query text
    :return: a TableExpression, or None when the token has no name or no
        parenthesized body
    """
    name = tok.get_real_name()
    if not name:
        return None

    # Locate the opening parenthesis of the CTE body
    paren_idx, body = tok.token_next_by(Parenthesis)
    if not body:
        return None

    begin = pos0 + token_start_pos(tok.tokens, paren_idx)
    # str(body) includes both parens, so `end` lands just past the ")"
    end = begin + len(str(body))
    return TableExpression(name, extract_column_names(body), begin, end)


def extract_column_names(parsed):
    """Return the column names exposed by a CTE body (or statement).

    For a SELECT these come from its select list; for INSERT/UPDATE/DELETE
    they come from the RETURNING clause. Any other statement, or a
    modification without RETURNING, yields an empty tuple.
    """
    # Find the first DML token to check if it's a
    # SELECT or INSERT/UPDATE/DELETE
    idx, tok = parsed.token_next_by(t=DML)
    tok_val = tok and tok.value.lower()

    if tok_val in ("insert", "update", "delete"):
        # Jump ahead to the RETURNING clause where the list of column names
        # is. Pass the pattern and start index by keyword: the first
        # positional parameter of token_next_by is an instance-type filter,
        # so passing idx positionally makes sqlparse call
        # isinstance(token, idx) and raise TypeError for any idx != 0.
        idx, tok = parsed.token_next_by(m=(Keyword, "returning"), idx=idx)
        if tok is None:
            # No RETURNING clause, so the statement exposes no columns
            return ()
    elif tok_val != "select":
        # Must be invalid CTE
        return ()

    # The next token should be either a column name, or a list of column names
    idx, tok = parsed.token_next(idx, skip_ws=True, skip_cm=True)
    return tuple(t.get_name() for t in _identifiers(tok))


def token_start_pos(tokens, idx):
    """Return the character offset at which ``tokens[idx]`` begins.

    The offset is the combined length of the string forms of every token
    before index ``idx``.
    """
    offset = 0
    for preceding in tokens[:idx]:
        offset += len(str(preceding))
    return offset


def _identifiers(tok):
    """Lazily yield the Identifier tokens contained in ``tok``.

    A lone Identifier is yielded as-is; for an IdentifierList, only its
    Identifier members are yielded, in order; anything else yields nothing.
    """
    if isinstance(tok, Identifier):
        yield tok
        return
    if not isinstance(tok, IdentifierList):
        return
    for member in tok.get_identifiers():
        # NB: get_identifiers() can also produce non-Identifier tokens
        if isinstance(member, Identifier):
            yield member
