""" lib/words.py """

import itertools
from collections import Counter

import nltk
from nltk.corpus import brown
from nltk.corpus import words  # @UnusedImport
from nltk.probability import FreqDist


class Words:
    """Find dictionary words that can be spelled from the letters of a string.

    On construction the NLTK ``words`` corpus is loaded into a trie and a
    frequency distribution is built from the Brown corpus; the distribution
    is used to rank the words that are found.

    Args:
        min_limit: Shortest word length (inclusive) to report.
        max_limit: Longest word length (inclusive) to report.

    Raises:
        LookupError: If the ``words`` or ``brown`` corpus cannot be located
            (raised by ``nltk.data.find``).
    """

    # Extra directory searched for NLTK corpora, on top of NLTK's defaults.
    NLTK_DATA_PATH = "/var/www/piapp/nltk_data"

    def __init__(self, min_limit=3, max_limit=8):
        # Inclusive bounds on the length of the words that are reported.
        self.min_limit = min_limit
        self.max_limit = max_limit

        # Locate the corpora and build the frequency distribution.
        self._initialize_data()

        # Build the lookup trie from the dictionary corpus.
        self.trie_root = TrieNode()
        self._populate_trie()

    def _initialize_data(self):
        """Verify the required NLTK corpora exist and load the frequency data.

        Nothing is downloaded here: ``nltk.data.find`` raises ``LookupError``
        when a corpus is missing from every directory on ``nltk.data.path``.
        """
        # Register the data directory once; appending it on every
        # instantiation would grow the global search path without bound.
        if self.NLTK_DATA_PATH not in nltk.data.path:
            nltk.data.path.append(self.NLTK_DATA_PATH)
        nltk.data.find("corpora/words")
        nltk.data.find("corpora/brown")

        self.words = words
        self.brown = brown
        # NOTE(review): tokens are counted exactly as they appear in the
        # corpus, while every lookup below uses a lower-cased word, so
        # capitalised occurrences are presumably not counted - confirm
        # whether that is intended before changing it.
        self.freq_dist = FreqDist(brown.words())

    def _populate_trie(self):
        """Insert every word of the NLTK ``words`` corpus, lower-cased."""
        # Lower-case before de-duplicating so case variants are inserted once.
        for word in {entry.lower() for entry in self.words.words()}:
            self._insert_to_trie(word)

    def _insert_to_trie(self, word):
        """Insert ``word`` into the trie, creating nodes as needed."""
        current = self.trie_root
        for letter in word:
            if letter not in current.children:
                current.children[letter] = TrieNode()
            current = current.children[letter]
        current.is_end_of_word = True

    def find_permutations(self, source_string):
        """Return the words buildable from ``source_string``, ranked.

        Args:
            source_string: Letters that may be used; case is ignored and
                each character may be used as often as it occurs.

        Returns:
            A list of ``(word, frequency)`` tuples, most frequent first.
        """
        valid_words = self._find_valid_words(source_string)
        return self.sort_words_by_frequency(valid_words)

    def _find_valid_words(self, input_string):
        """Collect trie words spellable from the letters of ``input_string``.

        The trie is walked depth-first, spending one available letter per
        step, so only prefixes that exist in the dictionary are explored.
        This gives the same result as testing every permutation of the
        input without the factorial blow-up on longer inputs.

        Args:
            input_string: Source of the available letters (case-insensitive).

        Returns:
            A set of lower-case words whose length lies between
            ``min_limit`` and ``max_limit`` (inclusive).
        """
        found_words = set()
        available = Counter(input_string.lower())
        # A word can never be longer than the number of letters on offer.
        longest = min(self.max_limit, sum(available.values()))
        if longest < self.min_limit:
            return found_words

        # Fixed snapshot of the distinct letters; only the counts change
        # while the search runs.
        alphabet = tuple(available)
        shortest = self.min_limit

        def walk(node, prefix):
            if node.is_end_of_word and len(prefix) >= shortest:
                found_words.add("".join(prefix))
            if len(prefix) >= longest:
                return
            for letter in alphabet:
                if not available[letter]:
                    continue
                child = node.children.get(letter)
                if child is None:
                    continue
                # Spend the letter, explore, then give it back (backtrack).
                available[letter] -= 1
                prefix.append(letter)
                walk(child, prefix)
                prefix.pop()
                available[letter] += 1

        walk(self.trie_root, [])
        return found_words

    def _find_in_trie(self, word):
        """Return True if ``word`` was inserted into the trie as a whole word."""
        current = self.trie_root
        for letter in word:
            if letter not in current.children:
                return False
            current = current.children[letter]
        return current.is_end_of_word

    def sort_words_by_frequency(self, unsorted_words):
        """Pair each word with its frequency and sort, most frequent first.

        Args:
            unsorted_words: Iterable of words to rank.

        Returns:
            A list of ``(word, frequency)`` tuples in descending frequency
            order. Words with equal frequency are ordered alphabetically so
            the result does not depend on set iteration order.
        """
        pairs = [(word, self.freq_dist[word]) for word in unsorted_words]
        return sorted(pairs, key=lambda pair: (-pair[1], pair[0]))


class TrieNode:
    """Single node of a character trie.

    Attributes:
        children: Maps a character to the child node reached by it.
        is_end_of_word: True when the path from the root to this node
            spells a complete word.
    """

    # One node is created per trie node for the whole dictionary, so skip
    # the per-instance __dict__ to keep memory use down.
    __slots__ = ("children", "is_end_of_word")

    def __init__(self):
        self.children = {}
        self.is_end_of_word = False


def main():
    """Demo: print the ranked words that can be built from a sample string."""
    sample = "example"
    finder = Words(min_limit=3, max_limit=7)
    ranked = finder.find_permutations(sample)
    print(ranked)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
