import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import numpy as np
import pandas as pd


def get_recon_error_df(orig, recon):
    """
    Given original data and reconstructed data, returns min-max normalized
    Euclidean distances between original points and reconstructed points.

    Parameters
    ----------
    orig : numpy array or a pd.DataFrame
        The original X, one example per row.

    recon : numpy array or a pd.DataFrame
        The reconstructed X, same shape as ``orig``.

    Returns
    -------
    pd.DataFrame
        A single ``"recon_error"`` column holding each example's
        reconstruction error rescaled to [0, 1]. When the inputs are
        DataFrames, their row index is carried over. If every example has
        the same error (e.g. a perfect reconstruction, or a single row),
        all values are 0.0 rather than NaN.
    """
    # Row-wise L2 distance between each example and its reconstruction.
    loss = np.sqrt(np.sum((orig - recon) ** 2, axis=1))
    if len(loss) > 0:  # np.min/np.max raise on an empty ndarray
        lowest = np.min(loss)
        span = np.max(loss) - lowest
        # Min-max normalization. A zero range would otherwise compute 0 / 0
        # and turn the whole column into NaN; map it to all zeros instead.
        if span > 0:
            loss = (loss - lowest) / span
        else:
            loss = loss - lowest
    loss_df = pd.DataFrame(data=loss, columns=["recon_error"])
    return loss_df

def pca_faces(X_faces, n_components=(10, 50, 100)):
    """
    Applies PCA to a dataset of faces to reduce dimensionality and reconstructs the images.

    A separate PCA model is fitted for every requested number of components.
    The data is projected onto those components and then mapped back to the
    original feature space.

    Parameters
    ----------
    X_faces : np.ndarray
        The input dataset of face images, where each row represents an image.
    n_components : sequence or scalar, default=(10, 50, 100)
        Values passed one at a time as ``n_components`` to
        ``sklearn.decomposition.PCA``. A single scalar value is treated as a
        one-element sequence.

    Returns
    -------
    list of np.ndarray
        One reconstruction of ``X_faces`` per entry of ``n_components``, in
        the same order.
    """
    # Accept a single setting as well as a sequence of settings.
    if np.isscalar(n_components):
        n_components = [n_components]
    reduced_images = []
    for n in n_components:
        pca = PCA(n_components=n)
        pca.fit(X_faces)
        # Project onto the retained components, then back to pixel space.
        X_hat = pca.inverse_transform(pca.transform(X_faces))
        reduced_images.append(X_hat)
    return reduced_images



def plot_strong_comp_images(
    X_anims, Z, W, compn=1, image_shape=(100, 100), positive_direction=True,
    n_images=6,
):
    """
    Visualizes the images in `X_anims` where the specified component `compn` has the most
    extreme values, based on the transformed data `Z` and components `W`.

    Parameters
    ----------
    X_anims : numpy.ndarray
        Original input images in array form, one flattened image per row.

    Z : numpy.ndarray
        Transformed data obtained after applying dimensionality reduction.
        Row ``i`` is assumed to correspond to ``X_anims[i]``, because the
        selected indices are used to index ``X_anims`` directly.

    W : numpy.ndarray
        Learned components (e.g., principal components or basis vectors from NMF/SVD),
        one flattened component per row.

    compn : int, optional (default=1)
        The index of the component to analyze.

    image_shape : tuple, optional (default=(100, 100))
        The shape of the images to be reshaped for visualization.

    positive_direction : bool, optional (default=True)
        If True, selects images with the highest positive values for the given component.
        If False, selects images with the lowest (most negative) values.

    n_images : int, optional (default=6)
        How many example images to show next to the component.

    Returns
    -------
    None
        Displays a figure showing:
        - The component itself as an image (first panel).
        - The `n_images` images from `X_anims` where the component `compn`
          has the strongest values. Each panel is titled with the image's row
          index. Panels with no image to show are hidden.

    Raises
    ------
    ValueError
        If `n_images` is negative.
    """
    if n_images < 0:
        raise ValueError(f"n_images must be non-negative, got {n_images}")

    order = np.argsort(Z[:, compn])
    if positive_direction:
        # Descending order: largest scores first.
        order = order[::-1]
    selected = order[:n_images]

    n_panels = n_images + 1
    # squeeze=False keeps a 2-D axes array even when there is only one panel.
    # The width scales with the panel count and equals the original 12 for 7 panels.
    fig, axes = plt.subplots(
        1, n_panels, figsize=(12 * n_panels / 7, 4),
        subplot_kw={"xticks": (), "yticks": ()}, squeeze=False,
    )
    axes = axes[0]
    axes[0].set_title(f"Component {compn}")
    axes[0].imshow(W[compn].reshape(image_shape))
    for axis, image in zip(axes[1:], selected):
        axis.set_title(image)
        axis.imshow(X_anims[image].reshape(image_shape))
    # Hide panels left over when X_anims has fewer rows than n_images.
    for axis in axes[1 + len(selected):]:
        axis.set_axis_off()
    fig.tight_layout()
    plt.show()


