Back to Repositories

Testing Weighted Maximum Cardinality Matching Implementation in Algorithms Repository

This test suite validates the Weighted Maximum Cardinality Matching (WMCM) algorithm implementation in Java, focusing on graph matching optimization and cost calculations. The tests cover various graph configurations and edge cases to ensure correct matching pairs and minimum weight calculations.

Test Coverage Overview

The test suite provides comprehensive coverage of WMCM algorithm functionality including:
  • Small to medium-sized graph validations
  • Perfect and non-perfect matching scenarios
  • Negative edge weight handling
  • Disjoint graph testing
  • Random graph generation and validation
  • Brute-force comparison tests

Implementation Analysis

The testing approach uses the JUnit 5 framework with a combination of specific test cases and randomized testing. The randomized tests run every registered WMCM implementation and compare the results against a brute-force implementation for verification. Only the recursive implementation is currently registered; the iterative one is commented out. The test structure employs helper methods for matrix creation and edge weight management.

Key patterns include a matrix-based graph representation (a null entry means there is no edge), randomly generated symmetric cost matrices, and assertions on both the minimum cost and the exact matching pairs.

Technical Details

Testing tools and configuration:
  • JUnit Jupiter for test execution
  • Google Truth for assertions
  • MwpmInterface as the common interface for the implementations under test
  • Matrix-based graph representation
  • Utility methods for graph construction
  • Randomized test data generation

Best Practices Demonstrated

The test suite exemplifies several testing best practices including thorough edge case coverage, randomized testing for robustness, and comprehensive validation of algorithm correctness. It demonstrates proper test isolation, clear test naming conventions, and effective use of helper methods for test setup and validation.

Notable practices include generating symmetric random cost matrices, cost consistency checks (the summed cost of the returned matching must equal the reported minimum cost), and comparison against brute-force solutions for accuracy verification.

williamfiset/algorithms

src/test/java/com/williamfiset/algorithms/dp/WeightedMaximumCardinalityMatchingTest.java

            
package com.williamfiset.algorithms.dp;

import static com.google.common.truth.Truth.assertThat;

import java.util.*;
import org.junit.jupiter.api.*;

/**
 * Tests for the weighted maximum cardinality matching (WMCM) implementations.
 *
 * <p>Graphs are square {@code Double[][]} cost matrices in which a {@code null} entry means there
 * is no edge. A matching is returned as a flat array where each consecutive pair {@code
 * (matching[2k], matching[2k + 1])} is one matched edge.
 */
public class WeightedMaximumCardinalityMatchingTest {

  /** Number of random instances generated by each randomized test. */
  static final int LOOPS = 50;

  /**
   * Large magnitude used to build edge weights far below {@code -INF}. NOTE(review): presumably
   * mirrors a sentinel value used inside the implementation — confirm.
   */
  static final double INF = 987654321.0;

  /**
   * Reference solver. It tries every permutation of the nodes, pairs up consecutive entries, and
   * keeps the cheapest total. This takes O(n! * n) time, so it is only suitable for very small n.
   *
   * <p>Every evaluated pair must have a non-null cost, because unboxing a {@code null} throws. The
   * matrix must therefore be complete off the diagonal. For odd n, the last node of each
   * permutation stays unmatched.
   */
  static class BruteForceMwpm {
    private final int n;
    private final Double[][] matrix;
    private double minWeightMatching = Double.POSITIVE_INFINITY;
    // Set once solve() has enumerated all permutations, so repeated queries do not redo O(n!) work.
    private boolean solved = false;

    public BruteForceMwpm(Double[][] matrix) {
      this.matrix = matrix;
      this.n = matrix.length;
    }

    /** Returns the minimum matching cost, computing it on the first call. */
    public double getMinWeightCost() {
      solve();
      return minWeightMatching;
    }

    /**
     * Returns the cost of matching {@code p[0]-p[1], p[2]-p[3], ...}. For odd n, the trailing
     * element is ignored.
     */
    public double computeMatchingCost(int[] p) {
      double total = 0;
      for (int i = 0; i < n / 2; i++) {
        total += matrix[p[2 * i]][p[2 * i + 1]];
      }
      return total;
    }

    /** Enumerates every permutation once and records the cheapest matching; later calls no-op. */
    public void solve() {
      if (solved) return;
      int[] permutation = new int[n];
      for (int i = 0; i < n; i++) permutation[i] = i;

      // Starting from the identity (the smallest permutation) visits every permutation.
      do {
        double matchingCost = computeMatchingCost(permutation);
        if (matchingCost < minWeightMatching) {
          minWeightMatching = matchingCost;
        }
      } while (nextPermutation(permutation));
      solved = true;
    }

    /**
     * Rearranges {@code sequence} in place into the next lexicographically greater permutation.
     *
     * <p>Duplicate values are handled, so repeated permutations are skipped. To visit every
     * permutation, start from the smallest (ascending) permutation and call this in a do-while
     * loop, as {@link #solve()} does.
     *
     * @return {@code true} if the sequence was advanced; {@code false} if it was already the
     *     greatest permutation, in which case it is left unchanged
     */
    public static boolean nextPermutation(int[] sequence) {
      int first = getFirst(sequence);
      if (first == -1) return false;
      // Swap the pivot with the rightmost element strictly greater than it...
      int toSwap = sequence.length - 1;
      while (sequence[first] >= sequence[toSwap]) --toSwap;
      swap(sequence, first++, toSwap);
      // ...then reverse the non-increasing suffix so it becomes the smallest possible suffix.
      toSwap = sequence.length - 1;
      while (first < toSwap) swap(sequence, first++, toSwap--);
      return true;
    }

    /** Returns the rightmost index i with sequence[i] < sequence[i + 1], or -1 if none exists. */
    private static int getFirst(int[] sequence) {
      for (int i = sequence.length - 2; i >= 0; --i) {
        if (sequence[i] < sequence[i + 1]) return i;
      }
      return -1;
    }

    /** Swaps {@code sequence[i]} and {@code sequence[j]}. */
    private static void swap(int[] sequence, int i, int j) {
      int tmp = sequence[i];
      sequence[i] = sequence[j];
      sequence[j] = tmp;
    }
  }

  /** Returns one instance of every implementation under test, each built from {@code costMatrix}. */
  private static MwpmInterface[] getImplementations(Double[][] costMatrix) {
    return new MwpmInterface[] {
      new WeightedMaximumCardinalityMatchingRecursive(costMatrix),
      // Disabled: the iterative implementation is not currently exercised by these tests.
      // new WeightedMaximumCardinalityMatchingIterative(costMatrix),
    };
  }

  /** Creates an n x n cost matrix with no edges (every entry, including the diagonal, is null). */
  private static Double[][] createEmptyMatrix(int n) {
    // Arrays of a reference type are already null-initialised, so no explicit fill is needed.
    return new Double[n][n];
  }

  /** Sets the cost of the undirected edge {@code from <-> to} in both directions. */
  private static void addUndirectedWeightedEdge(Double[][] g, int from, int to, double weight) {
    g[from][to] = weight;
    g[to][from] = weight;
  }

  /**
   * Returns a random even size in {@code [2, 2 * maxPairs]}. A pair count k is drawn uniformly
   * from {@code [0, maxPairs]} and 0 is bumped to 1, so size 2 is twice as likely as the others.
   */
  private static int randomEvenSize(int maxPairs) {
    return Math.max(1, (int) (Math.random() * (maxPairs + 1))) * 2;
  }

  @Test
  public void testSmallGraph_oddSize() {
    int n = 5;
    Double[][] g = createEmptyMatrix(n);
    // Expected matching: 0-2 and 3-4 (node 1 stays unmatched), cost 1 + 2.
    addUndirectedWeightedEdge(g, 0, 1, 8);
    addUndirectedWeightedEdge(g, 0, 2, 1);
    addUndirectedWeightedEdge(g, 1, 3, 8);
    addUndirectedWeightedEdge(g, 2, 3, 8);
    addUndirectedWeightedEdge(g, 2, 4, 8);
    addUndirectedWeightedEdge(g, 3, 4, 2);

    MwpmInterface mwpm = new WeightedMaximumCardinalityMatchingRecursive(g);
    double cost = mwpm.getMinWeightCost();
    assertThat(cost).isEqualTo(3.0);

    int[] matching = mwpm.getMatching();
    int[] expectedMatching = {0, 2, 3, 4};
    assertThat(matching).isEqualTo(expectedMatching);
  }

  @Test
  public void testSmallestMatrix1() {
    // nodes 0 & 1 make the mwpm
    Double[][] costMatrix = {
      {0.0, 1.0},
      {1.0, 0.0},
    };
    for (MwpmInterface mwpm : getImplementations(costMatrix)) {
      double cost = mwpm.getMinWeightCost();
      assertThat(cost).isEqualTo(1.0);

      int[] matching = mwpm.getMatching();
      int[] expectedMatching = {0, 1};
      assertThat(matching).isEqualTo(expectedMatching);
    }
  }

  @Test
  public void testSmallMatrix1() {
    // nodes 0 & 2 and 1 & 3 make the mwpm
    Double[][] costMatrix = {
      {0.0, 2.0, 1.0, 2.0},
      {2.0, 0.0, 2.0, 1.0},
      {1.0, 2.0, 0.0, 2.0},
      {2.0, 1.0, 2.0, 0.0},
    };

    for (MwpmInterface mwpm : getImplementations(costMatrix)) {
      double cost = mwpm.getMinWeightCost();
      assertThat(cost).isEqualTo(2.0);

      int[] matching = mwpm.getMatching();
      int[] expectedMatching = {0, 2, 1, 3};
      assertThat(matching).isEqualTo(expectedMatching);
    }
  }

  @Test
  public void testSmallMatrix2() {
    // nodes 0 & 1 and 2 & 3 make the mwpm
    Double[][] costMatrix = {
      {0.0, 1.0, 2.0, 2.0},
      {1.0, 0.0, 2.0, 2.0},
      {2.0, 2.0, 0.0, 1.0},
      {2.0, 2.0, 1.0, 0.0},
    };

    for (MwpmInterface mwpm : getImplementations(costMatrix)) {
      double cost = mwpm.getMinWeightCost();
      assertThat(cost).isEqualTo(2.0);

      int[] matching = mwpm.getMatching();
      int[] expectedMatching = {0, 1, 2, 3};
      assertThat(matching).isEqualTo(expectedMatching);
    }
  }

  @Test
  public void testMediumMatrix1() {
    // mwpm between 0 & 5, 1 & 2, 3 & 4
    Double[][] costMatrix = {
      {0.0, 9.0, 9.0, 9.0, 9.0, 1.0},
      {9.0, 0.0, 1.0, 9.0, 9.0, 9.0},
      {9.0, 1.0, 0.0, 9.0, 9.0, 9.0},
      {9.0, 9.0, 9.0, 0.0, 1.0, 9.0},
      {9.0, 9.0, 9.0, 1.0, 0.0, 9.0},
      {1.0, 9.0, 9.0, 9.0, 9.0, 0.0},
    };

    for (MwpmInterface mwpm : getImplementations(costMatrix)) {
      double cost = mwpm.getMinWeightCost();
      assertThat(cost).isEqualTo(3.0);

      int[] matching = mwpm.getMatching();
      int[] expectedMatching = {0, 5, 1, 2, 3, 4};
      assertThat(matching).isEqualTo(expectedMatching);
    }
  }

  @Test
  public void testMediumMatrix2() {
    // mwpm between 0 & 1, 2 & 4, 3 & 5
    Double[][] costMatrix = {
      {0.0, 1.0, 9.0, 9.0, 9.0, 9.0},
      {1.0, 0.0, 9.0, 9.0, 9.0, 9.0},
      {9.0, 9.0, 0.0, 9.0, 1.0, 9.0},
      {9.0, 9.0, 9.0, 0.0, 9.0, 1.0},
      {9.0, 9.0, 1.0, 9.0, 0.0, 9.0},
      {9.0, 9.0, 9.0, 1.0, 9.0, 0.0},
    };

    for (MwpmInterface mwpm : getImplementations(costMatrix)) {
      double cost = mwpm.getMinWeightCost();
      assertThat(cost).isEqualTo(3.0);

      int[] matching = mwpm.getMatching();
      int[] expectedMatching = {0, 1, 2, 4, 3, 5};
      assertThat(matching).isEqualTo(expectedMatching);
    }
  }

  @Test
  public void testMediumGraph_evenSize_fromSlides() {
    int n = 6;
    Double[][] g = createEmptyMatrix(n);

    addUndirectedWeightedEdge(g, 0, 1, 7);
    addUndirectedWeightedEdge(g, 0, 2, 6);
    addUndirectedWeightedEdge(g, 0, 4, -1);
    addUndirectedWeightedEdge(g, 1, 3, 1);
    addUndirectedWeightedEdge(g, 1, 4, 3);
    addUndirectedWeightedEdge(g, 1, 5, 5);
    addUndirectedWeightedEdge(g, 2, 4, 5);
    addUndirectedWeightedEdge(g, 3, 5, 3);
    addUndirectedWeightedEdge(g, 4, 5, 8);

    MwpmInterface mwpm = new WeightedMaximumCardinalityMatchingRecursive(g);
    double cost = mwpm.getMinWeightCost();
    // Cheapest perfect matching: 0-2, 1-4, 3-5 = 6 + 3 + 3.
    assertThat(cost).isEqualTo(12.0);

    int[] matching = mwpm.getMatching();
    int[] expectedMatching = {0, 2, 1, 4, 3, 5};
    assertThat(matching).isEqualTo(expectedMatching);
  }

  @Test
  public void testMediumGraph_evenSize_nonPerfectMatchingFromSlides() {
    int n = 6;
    Double[][] g = createEmptyMatrix(n);

    // Every edge except 3-4 touches node 1, so at most two edges can be matched at once.
    addUndirectedWeightedEdge(g, 0, 1, 6);
    addUndirectedWeightedEdge(g, 1, 2, 7);
    addUndirectedWeightedEdge(g, 1, 5, 8);
    addUndirectedWeightedEdge(g, 1, 4, 9);
    addUndirectedWeightedEdge(g, 1, 3, 10);
    addUndirectedWeightedEdge(g, 3, 4, 11);

    MwpmInterface mwpm = new WeightedMaximumCardinalityMatchingRecursive(g);
    double cost = mwpm.getMinWeightCost();
    // Cheapest two-edge matching: 0-1 + 3-4 = 6 + 11.
    assertThat(cost).isEqualTo(17.0);

    int[] matching = mwpm.getMatching();
    int[] expectedMatching = {0, 1, 3, 4};
    assertThat(matching).isEqualTo(expectedMatching);
  }

  @Test
  public void testNegativeEdgeWeights() {
    int n = 6;
    Double[][] g = createEmptyMatrix(n);
    // On the path 0-1-2-3-4-5 the only perfect matching is 0-1, 2-3, 4-5, even though
    // 1-2 and 3-4 are individually cheaper; maximum cardinality takes precedence.
    addUndirectedWeightedEdge(g, 0, 1, -1); // selected
    addUndirectedWeightedEdge(g, 1, 2, -2);
    addUndirectedWeightedEdge(g, 2, 3, -3); // selected
    addUndirectedWeightedEdge(g, 3, 4, -4);
    addUndirectedWeightedEdge(g, 4, 5, -5); // selected

    MwpmInterface mwpm = new WeightedMaximumCardinalityMatchingRecursive(g);
    double cost = mwpm.getMinWeightCost();
    assertThat(cost).isEqualTo(-1.0 + -3.0 + -5.0);

    int[] matching = mwpm.getMatching();
    int[] expectedMatching = {0, 1, 2, 3, 4, 5};
    assertThat(matching).isEqualTo(expectedMatching);
  }

  @Test
  public void testNegativeEdge_smallerThanINFWeights() {
    int n = 6;
    Double[][] g = createEmptyMatrix(n);
    addUndirectedWeightedEdge(g, 0, 1, -1 - 50 * INF); // selected
    addUndirectedWeightedEdge(g, 1, 2, -2 - 50 * INF);
    addUndirectedWeightedEdge(g, 2, 3, -3 - 50 * INF); // selected
    addUndirectedWeightedEdge(g, 3, 4, -4 - 50 * INF);
    addUndirectedWeightedEdge(g, 4, 5, -5 - 50 * INF); // selected

    MwpmInterface mwpm = new WeightedMaximumCardinalityMatchingRecursive(g);
    double cost = mwpm.getMinWeightCost();
    // All values involved are integers well below 2^53, so the double arithmetic is exact.
    double expectedCost = -1.0 + -3.0 + -5.0 + (-3 * 50 * INF);
    assertThat(cost).isEqualTo(expectedCost);

    int[] matching = mwpm.getMatching();
    int[] expectedMatching = {0, 1, 2, 3, 4, 5};
    assertThat(matching).isEqualTo(expectedMatching);
    assertMatchingCost(matching, g, expectedCost);
  }

  @Test
  public void testDisjointGraph() {
    int n = 8;
    Double[][] g = createEmptyMatrix(n);
    // Two identical components; each one's cheapest perfect matching costs 3 + 2 = 5.
    addUndirectedWeightedEdge(g, 0, 1, 3);
    addUndirectedWeightedEdge(g, 0, 2, 5);
    addUndirectedWeightedEdge(g, 1, 2, 1);
    addUndirectedWeightedEdge(g, 1, 3, 4);
    addUndirectedWeightedEdge(g, 2, 3, 2);

    addUndirectedWeightedEdge(g, 4, 5, 3);
    addUndirectedWeightedEdge(g, 4, 6, 5);
    addUndirectedWeightedEdge(g, 5, 6, 1);
    addUndirectedWeightedEdge(g, 5, 7, 4);
    addUndirectedWeightedEdge(g, 6, 7, 2);

    MwpmInterface mwpm = new WeightedMaximumCardinalityMatchingRecursive(g);
    double cost = mwpm.getMinWeightCost();
    assertThat(cost).isEqualTo(10.0);

    int[] matching = mwpm.getMatching();
    assertMatchingCost(matching, g, 10.0);
  }

  @Test
  public void testHarderWmcm_fromSlides() {
    int n = 11;
    Double[][] g = createEmptyMatrix(n);
    // Every edge touches node 0, 1 or 2, so the maximum cardinality is 3.
    addUndirectedWeightedEdge(g, 0, 1, 1);
    addUndirectedWeightedEdge(g, 0, 3, 8);
    addUndirectedWeightedEdge(g, 0, 4, 9);
    addUndirectedWeightedEdge(g, 0, 5, 7); // selected
    addUndirectedWeightedEdge(g, 1, 2, 1);
    addUndirectedWeightedEdge(g, 1, 6, 7); // selected
    addUndirectedWeightedEdge(g, 1, 7, 8);
    addUndirectedWeightedEdge(g, 1, 8, 9);
    addUndirectedWeightedEdge(g, 2, 9, 9);
    addUndirectedWeightedEdge(g, 2, 10, 7); // selected

    MwpmInterface mwpm = new WeightedMaximumCardinalityMatchingRecursive(g);
    double cost = mwpm.getMinWeightCost();
    assertThat(cost).isEqualTo(7.0 + 7.0 + 7.0);

    int[] matching = mwpm.getMatching();
    int[] expectedMatching = {0, 5, 1, 6, 2, 10};
    assertThat(matching).isEqualTo(expectedMatching);
  }

  @Test
  public void testMatchingOutputsUniqueNodes() {
    for (int loop = 0; loop < LOOPS; loop++) {
      int n = randomEvenSize(10); // n is one of 2, 4, ..., 20
      Double[][] costMatrix = new Double[n][n];
      randomFillSymmetricMatrix(costMatrix, 100);

      for (MwpmInterface mwpm : getImplementations(costMatrix)) {
        int[] matching = mwpm.getMatching();
        Set<Integer> seen = new HashSet<>();
        for (int node : matching) {
          // Every reported node must be a valid index into the cost matrix.
          assertThat(node).isAtLeast(0);
          assertThat(node).isLessThan(n);
          seen.add(node);
        }
        assertThat(seen).hasSize(matching.length);
      }
    }
  }

  @Test
  public void testMatchingAndCostAreConsistent() {
    for (int loop = 0; loop < LOOPS; loop++) {
      int n = randomEvenSize(10); // n is one of 2, 4, ..., 20
      Double[][] costMatrix = new Double[n][n];
      randomFillSymmetricMatrix(costMatrix, 100);

      for (MwpmInterface mwpm : getImplementations(costMatrix)) {
        assertMatchingCost(mwpm.getMatching(), costMatrix, mwpm.getMinWeightCost());
      }
    }
  }

  @Test
  public void testAgainstBruteForce_largeValues() {
    for (int loop = 0; loop < LOOPS; loop++) {
      int n = randomEvenSize(5); // n is one of 2, 4, 6, 8, 10
      Double[][] costMatrix = new Double[n][n];
      randomFillSymmetricMatrix(costMatrix, /* maxValue= */ 10000);
      assertMatchesBruteForce(costMatrix);
    }
  }

  @Test
  public void testAgainstBruteForce_smallValues() {
    for (int loop = 0; loop < LOOPS; loop++) {
      int n = randomEvenSize(5); // n is one of 2, 4, 6, 8, 10
      Double[][] costMatrix = new Double[n][n];
      randomFillSymmetricMatrix(costMatrix, /* maxValue= */ 3);
      assertMatchesBruteForce(costMatrix);
    }
  }

  /**
   * Fills the strict upper triangle of {@code dist} with random integer-valued costs in {@code [0,
   * maxValue)} and mirrors them below the diagonal. Diagonal entries are not written.
   */
  public static void randomFillSymmetricMatrix(Double[][] dist, int maxValue) {
    for (int i = 0; i < dist.length; i++) {
      for (int j = i + 1; j < dist.length; j++) {
        double val = (int) (Math.random() * maxValue);
        dist[i][j] = dist[j][i] = val;
      }
    }
  }

  /** Asserts that every implementation's minimum cost equals the brute-force optimum. */
  private static void assertMatchesBruteForce(Double[][] costMatrix) {
    MwpmInterface[] impls = getImplementations(costMatrix);
    // The brute-force search is O(n!), so run it once per matrix rather than once per impl.
    double bfSoln = new BruteForceMwpm(costMatrix).getMinWeightCost();
    for (MwpmInterface mwpm : impls) {
      assertThat(mwpm.getMinWeightCost()).isEqualTo(bfSoln);
    }
  }

  /**
   * Asserts that {@code matching} is a flat list of node pairs whose summed edge costs in {@code
   * costMatrix} equal {@code expectedMatchingCost}. This checks that the matching and the cost are
   * consistent; it does not by itself prove that the matching is optimal.
   */
  private static void assertMatchingCost(
      int[] matching, Double[][] costMatrix, double expectedMatchingCost) {
    // Matchings are flat (u, v) pairs, so an odd length is malformed.
    assertThat(matching.length % 2).isEqualTo(0);
    double total = 0;
    for (int i = 0; i < matching.length; i += 2) {
      total += costMatrix[matching[i]][matching[i + 1]];
    }
    assertThat(total).isEqualTo(expectedMatchingCost);
  }
}