import numpy as np
import numpy.typing as npt
from util import *

# Location of the embeddings CSV. Not referenced by the code shown in this
# file -- presumably read elsewhere; TODO confirm.
EMBEDDINGS_FILE_PATH = "./openai_embeddings.csv"

# Puzzle data file; main() hands this path to retrieve_puzzles().
PUZZLES_PATH = "./puzzles.json"

# Puzzle dimensions: NUM_GROUPS groups of GROUP_SIZE words each.
GROUP_SIZE = 4
NUM_GROUPS = 4
# Total number of words on the board, derived so the three values cannot drift apart.
NUM_WORDS = GROUP_SIZE * NUM_GROUPS

def find_best_quartet(words: list[str], embeddings: dict[str, npt.NDArray]) -> list[list[str]]:
    """Return the best grouping of ``words`` into groups (not yet implemented).

    Placeholder: the search has not been written, so both arguments are
    currently ignored and an empty list is always returned.

    Args:
        words: The puzzle's words.
        embeddings: Mapping from word to its embedding vector.

    Returns:
        A list of word groups; at present always ``[]``.
    """
    # TODO: YOUR CODE HERE
    best_quartet: list[list[str]] = []
    return best_quartet

def calculate_group_internal_score(group: list[str], embeddings: dict[str, npt.NDArray], alpha: float = 0.75) -> np.floating:
    """Score how similar the words inside one group are to each other.

    The cosine similarity is computed for every unordered pair of words in
    ``group``. The score is a weighted blend of the mean and the minimum of
    those similarities, so one poorly fitting word lowers the score even
    when the average is high.

    Args:
        group: Words belonging to the group; needs at least two words.
        embeddings: Mapping from word to its embedding vector. Vectors are
            assumed to be 1-D and of equal length -- TODO confirm against
            the embeddings loader.
        alpha: Weight of the mean similarity; ``1 - alpha`` weights the
            minimum similarity.

    Returns:
        ``alpha * mean(similarities) + (1 - alpha) * min(similarities)``.

    Raises:
        ValueError: If ``group`` has fewer than two words (no pairs exist).
        KeyError: If a word in ``group`` has no entry in ``embeddings``.
    """
    if len(group) < 2:
        raise ValueError("group must contain at least two words to form a pair")

    vectors = np.asarray([embeddings[word] for word in group], dtype=float)

    # Scaling each row to unit length makes the Gram matrix below hold the
    # cosine similarity of every pair of words.
    unit_vectors = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
    similarity_matrix = unit_vectors @ unit_vectors.T

    # Strict upper triangle: each unordered pair once, no self-similarity.
    rows, cols = np.triu_indices(len(group), k=1)
    cosine_similarities = similarity_matrix[rows, cols]

    return (alpha * np.mean(cosine_similarities)) + ((1 - alpha) * np.min(cosine_similarities))

def calculate_group_external_score(group: list[str], other_groups: list[list[str]], embeddings: dict[str, npt.NDArray]) -> np.floating:
    """Score how well one group is separated from the words outside it.

    The cosine similarity is computed between every word in ``group`` and
    every word in ``other_groups``. The mean is negated, so a group that
    resembles the outside words less gets a higher score.

    Args:
        group: Words belonging to the group being scored.
        other_groups: The remaining groups to compare against.
        embeddings: Mapping from word to its embedding vector. Vectors are
            assumed to be 1-D and of equal length -- TODO confirm against
            the embeddings loader.

    Returns:
        The negated mean cross-group cosine similarity, or NaN when
        ``group`` is empty or ``other_groups`` holds no words.

    Raises:
        KeyError: If a word has no entry in ``embeddings``.
    """
    outside_words = [word for other in other_groups for word in other]
    if not group or not outside_words:
        # Nothing to compare: the mean is undefined, so report NaN explicitly
        # instead of letting numpy warn about an empty mean.
        return np.float64(np.nan)

    inside_vectors = np.asarray([embeddings[word] for word in group], dtype=float)
    outside_vectors = np.asarray([embeddings[word] for word in outside_words], dtype=float)

    # Unit-length rows make the matrix product hold cosine similarities.
    inside_unit = inside_vectors / np.linalg.norm(inside_vectors, axis=1, keepdims=True)
    outside_unit = outside_vectors / np.linalg.norm(outside_vectors, axis=1, keepdims=True)
    cosine_similarities = inside_unit @ outside_unit.T

    return -np.mean(cosine_similarities)

def calculate_word_ambiguity(word: str, words: list[str], embeddings: dict[str, npt.NDArray]) -> float:
    """Return an ambiguity measure for ``word`` (not yet implemented).

    Placeholder: the metric has not been written, so all arguments are
    currently ignored and the result is always zero.

    Args:
        word: The word to assess.
        words: The puzzle's words.
        embeddings: Mapping from word to its embedding vector.

    Returns:
        The ambiguity value; at present always ``0``.
    """
    # TODO: YOUR CODE HERE
    ambiguity = 0
    return ambiguity

def calculate_total_score(quartet: list[list[str]], words: list[str], embeddings: dict[str, npt.NDArray],
                          alpha: float = 0.8, beta: float = 0.2, penalty_weight: float = 0.05):
    """Combine cohesion, separation and ambiguity into one score for a grouping.

    Each group in ``quartet`` contributes one internal score from
    ``calculate_group_internal_score`` and one external score from
    ``calculate_group_external_score``, measured against all the other
    groups. Each word in ``words`` contributes one ambiguity penalty from
    ``calculate_word_ambiguity``.

    Args:
        quartet: The candidate grouping, one list of words per group.
        words: All words of the puzzle.
        embeddings: Mapping from word to its embedding vector.
        alpha: Weight of the mean internal score. It is not forwarded to
            ``calculate_group_internal_score``, which keeps its own default.
        beta: Weight of the mean external score.
        penalty_weight: Weight of the mean ambiguity penalty, which is subtracted.

    Returns:
        ``alpha * mean(internal) + beta * mean(external)
        - penalty_weight * mean(ambiguity)``. An empty ``quartet`` or an
        empty ``words`` list yields NaN, since the mean of an empty list is NaN.
    """
    internal_scores = [
        calculate_group_internal_score(group, embeddings) for group in quartet
    ]

    # Exclude by position rather than by equality so that two groups with
    # identical contents are still compared against each other.
    external_scores = [
        calculate_group_external_score(group, quartet[:index] + quartet[index + 1:], embeddings)
        for index, group in enumerate(quartet)
    ]

    ambiguity_penalties = [
        calculate_word_ambiguity(word, words, embeddings) for word in words
    ]

    overall_score = alpha * np.mean(internal_scores) \
                    + beta * np.mean(external_scores) \
                    - penalty_weight * np.mean(ambiguity_penalties)

    return overall_score

def main():
    """Example usage: print the result of find_best_quartet for puzzle #751."""
    # Puzzle #751 serves as the worked example.
    puzzle_number = 751
    all_puzzles = retrieve_puzzles(PUZZLES_PATH)
    answers = all_puzzles[puzzle_number]
    todays_words = convert_to_words(answers)

    # Written out by hand, the two values above would be:
    # todays_words = ['nick', 'palm', 'pinch', 'pocket', 'brush', 'dress', 'shave', 'shower', 'neat', 'sharp', 'smart', 'tidy', 'birth', 'key', 'mile', 'touch']
    # answers = [['nick', 'palm', 'pinch', 'pocket'], ['brush', 'dress', 'shave', 'shower'], ['neat', 'sharp', 'smart', 'tidy'], ['birth', 'key', 'mile', 'touch']]

    todays_embeddings = get_todays_embeddings(todays_words)
    best_quartet = find_best_quartet(todays_words, todays_embeddings)
    print(best_quartet)

# Run the example only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()