import { TraitWeights } from "context/GenerateConfigContext";
import logIfNotProd from "utils/logIfNotProd";

/**
 * Checks that a collection of per-trait weight tables agrees with the
 * requested generation parameters.
 *
 * Validity requires both of the following:
 *  1. `weights` contains exactly `numTraits` entries, and
 *  2. for every entry, the inner table's values add up to exactly
 *     `numImages` (compared with strict equality, so no tolerance is applied).
 *
 * Each failure is reported via `logIfNotProd` before returning `false`.
 *
 * @param weights - outer collection of per-trait weight tables
 * @param numImages - the total each inner table's values must sum to
 * @param numTraits - the number of entries `weights` must contain
 * @returns `true` when both conditions hold, otherwise `false`
 */
export default function isTraitWeightsValid(
  weights: TraitWeights,
  numImages: number,
  numTraits: number
): boolean {
  const hasExpectedTraitCount = weights.size === numTraits;
  if (!hasExpectedTraitCount) {
    logIfNotProd("trait weights invalid 1", weights, numImages, numTraits);
    return false;
  }

  // Stop at the first trait whose weights do not add up to the image count.
  for (const traitTable of Array.from(weights.values())) {
    // Accumulate left to right starting from 0 so the floating-point result
    // matches a straightforward reduce over the same values.
    let total = 0;
    for (const weight of Array.from(traitTable.values())) {
      total += weight;
    }

    const matchesImageCount = total === numImages;
    if (!matchesImageCount) {
      logIfNotProd(
        "trait weights invalid 2",
        weights,
        numImages,
        numTraits,
        total
      );
      return false;
    }
  }

  return true;
}
