# cython: infer_types=True, binding=True
import importlib
import sys
from collections import defaultdict
from typing import Callable, Optional

from thinc.api import Config, Model

from ._parser_internals.transition_system import TransitionSystem

from ._parser_internals.ner cimport BiluoPushDown
from .transition_parser cimport Parser

from ..language import Language
from ..scorer import get_ner_prf
from ..training import remove_bilu_prefix
from ..util import registry

# Default model settings for the NER component, written as a config string.
# The top-level [model] block selects the "spacy.TransitionBasedParser.v2"
# architecture with state_type = "ner"; the nested [model.tok2vec] block
# selects the "spacy.HashEmbedCNN.v2" architecture for the embedding layer.
default_model_config = """
[model]
@architectures = "spacy.TransitionBasedParser.v2"
state_type = "ner"
extra_state_tokens = false
hidden_width = 64
maxout_pieces = 2
use_upper = true

[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v2"
pretrained_vectors = null
width = 96
depth = 4
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true
"""
# Parse the string once at import time and keep only its "model" section.
DEFAULT_NER_MODEL = Config().from_str(default_model_config)["model"]


def ner_score(examples, **kwargs):
    """Score the given examples for NER.

    Thin wrapper that hands `examples` and any extra keyword arguments
    straight to `get_ner_prf` and returns whatever that function returns.
    """
    prf_scores = get_ner_prf(examples, **kwargs)
    return prf_scores


def make_ner_scorer():
    """Return the default NER scoring callable, `ner_score`."""
    scorer = ner_score
    return scorer


cdef class EntityRecognizer(Parser):
    """Named entity recognition pipeline component.

    Subclass of `Parser` that uses `BiluoPushDown` as its transition system.

    DOCS: https://spacy.io/api/entityrecognizer
    """
    TransitionSystem = BiluoPushDown

    def __init__(
        self,
        vocab,
        model,
        name="ner",
        moves=None,
        *,
        update_with_oracle_cut_size=100,
        beam_width=1,
        beam_density=0.0,
        beam_update_prob=0.0,
        multitasks=tuple(),
        incorrect_spans_key=None,
        scorer=ner_score,
    ):
        """Create an EntityRecognizer.

        Every argument is forwarded unchanged to `Parser.__init__`. Two
        parser settings are not exposed here and are pinned to fixed values
        (`min_action_freq=1`, `learn_tokens=False`) because they are not
        relevant for NER. `scorer` defaults to the module-level `ner_score`.
        """
        super().__init__(
            vocab,
            model,
            name,
            moves,
            # Settings pinned for this component.
            min_action_freq=1,   # not relevant for NER
            learn_tokens=False,  # not relevant for NER
            # Settings passed through from the caller.
            update_with_oracle_cut_size=update_with_oracle_cut_size,
            beam_width=beam_width,
            beam_density=beam_density,
            beam_update_prob=beam_update_prob,
            multitasks=multitasks,
            incorrect_spans_key=incorrect_spans_key,
            scorer=scorer,
        )

    def add_multitask_objective(self, mt_component):
        """Append `mt_component` to this component's multi-task objectives.

        Experimental.
        """
        self._multitasks.append(mt_component)

    def init_multitask_objectives(self, get_examples, nlp=None, **cfg):
        """Size and initialize every registered multi-task component.

        Each component's model output dimension ("nO") is set to the number
        of entity labels, as is its "output_layer" reference when the model
        has one; the component is then initialized with `get_examples` and
        `nlp`. Extra keyword arguments in `cfg` are accepted but unused.
        Experimental and internal.
        """
        # TODO: transfer self.model.get_ref("tok2vec") to the multitask's model ?
        for mt_component in self._multitasks:
            mt_model = mt_component.model
            n_labels = len(self.labels)
            mt_model.set_dim("nO", n_labels)
            if mt_model.has_ref("output_layer"):
                mt_model.get_ref("output_layer").set_dim("nO", n_labels)
            mt_component.initialize(get_examples, nlp=nlp)

    @property
    def labels(self):
        """Sorted tuple of the distinct entity labels known to the component.

        Derived from the available move names: only moves whose first
        character is B, I, L or U are considered (e.g. B-PERSON, I-PERSON,
        L-PERSON, U-PERSON), and each is passed through `remove_bilu_prefix`
        before de-duplication.
        """
        entity_types = {
            remove_bilu_prefix(move_name)
            for move_name in self.move_names
            if move_name[0] in ("B", "I", "L", "U")
        }
        return tuple(sorted(entity_types))

    def scored_ents(self, beams):
        """Collect per-entity scores for each beam/doc that was processed.

        Returns a list with one dict per beam, in the order of `beams`. Each
        dict maps a (start, end, label) tuple to the sum of the scores of all
        parses in that beam (as yielded by `self.moves.get_beam_parses`) that
        contain the entity.
        """
        per_beam_scores = []
        for beam in beams:
            totals = defaultdict(float)
            for parse_score, parse_ents in self.moves.get_beam_parses(beam):
                for ent_start, ent_end, ent_label in parse_ents:
                    ent_key = (ent_start, ent_end, ent_label)
                    totals[ent_key] += parse_score
            per_beam_scores.append(totals)
        return per_beam_scores


# Setup backwards compatibility hook for factories
def __getattr__(name):
    """Resolve legacy factory names that are no longer defined in this module.

    Looking up `make_ner` or `make_beam_ner` on this module imports
    `spacy.pipeline.factories` on demand and returns the attribute of the
    same name from it. Any other name raises AttributeError.
    """
    if name in ("make_ner", "make_beam_ner"):
        factories = importlib.import_module("spacy.pipeline.factories")
        return getattr(factories, name)
    raise AttributeError(f"module {__name__} has no attribute {name}")
