Fixed database typo and removed unnecessary class identifier.
This commit is contained in:
parent
00ad49a143
commit
45fb349a7d
5098 changed files with 952558 additions and 85 deletions
|
@ -0,0 +1,102 @@
|
|||
# Author: Brian M. Clapper, G. Varoquaux, Lars Buitinck
|
||||
# License: BSD
|
||||
|
||||
from numpy.testing import assert_array_equal
|
||||
from pytest import raises as assert_raises
|
||||
|
||||
import numpy as np
|
||||
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
from scipy.sparse.sputils import matrix
|
||||
|
||||
|
||||
def test_linear_sum_assignment():
|
||||
for sign in [-1, 1]:
|
||||
for cost_matrix, expected_cost in [
|
||||
# Square
|
||||
([[400, 150, 400],
|
||||
[400, 450, 600],
|
||||
[300, 225, 300]],
|
||||
[150, 400, 300]
|
||||
),
|
||||
|
||||
# Rectangular variant
|
||||
([[400, 150, 400, 1],
|
||||
[400, 450, 600, 2],
|
||||
[300, 225, 300, 3]],
|
||||
[150, 2, 300]),
|
||||
|
||||
# Square
|
||||
([[10, 10, 8],
|
||||
[9, 8, 1],
|
||||
[9, 7, 4]],
|
||||
[10, 1, 7]),
|
||||
|
||||
# Rectangular variant
|
||||
([[10, 10, 8, 11],
|
||||
[9, 8, 1, 1],
|
||||
[9, 7, 4, 10]],
|
||||
[10, 1, 4]),
|
||||
|
||||
# n == 2, m == 0 matrix
|
||||
([[], []],
|
||||
[]),
|
||||
|
||||
# Square with positive infinities
|
||||
([[10, float("inf"), float("inf")],
|
||||
[float("inf"), float("inf"), 1],
|
||||
[float("inf"), 7, float("inf")]],
|
||||
[10, 1, 7]),
|
||||
]:
|
||||
|
||||
maximize = sign == -1
|
||||
cost_matrix = sign * np.array(cost_matrix)
|
||||
expected_cost = sign * np.array(expected_cost)
|
||||
|
||||
row_ind, col_ind = linear_sum_assignment(cost_matrix,
|
||||
maximize=maximize)
|
||||
assert_array_equal(row_ind, np.sort(row_ind))
|
||||
assert_array_equal(expected_cost, cost_matrix[row_ind, col_ind])
|
||||
|
||||
cost_matrix = cost_matrix.T
|
||||
row_ind, col_ind = linear_sum_assignment(cost_matrix,
|
||||
maximize=maximize)
|
||||
assert_array_equal(row_ind, np.sort(row_ind))
|
||||
assert_array_equal(np.sort(expected_cost),
|
||||
np.sort(cost_matrix[row_ind, col_ind]))
|
||||
|
||||
|
||||
def test_linear_sum_assignment_input_validation():
|
||||
|
||||
assert_raises(ValueError, linear_sum_assignment, [1, 2, 3])
|
||||
|
||||
C = [[1, 2, 3], [4, 5, 6]]
|
||||
assert_array_equal(linear_sum_assignment(C),
|
||||
linear_sum_assignment(np.asarray(C)))
|
||||
assert_array_equal(linear_sum_assignment(C),
|
||||
linear_sum_assignment(matrix(C)))
|
||||
|
||||
I = np.identity(3)
|
||||
assert_array_equal(linear_sum_assignment(I.astype(np.bool_)),
|
||||
linear_sum_assignment(I))
|
||||
assert_raises(ValueError, linear_sum_assignment, I.astype(str))
|
||||
|
||||
I[0][0] = np.nan
|
||||
assert_raises(ValueError, linear_sum_assignment, I)
|
||||
|
||||
I = np.identity(3)
|
||||
I[1][1] = -np.inf
|
||||
assert_raises(ValueError, linear_sum_assignment, I)
|
||||
|
||||
I = np.identity(3)
|
||||
I[:, 0] = np.inf
|
||||
assert_raises(ValueError, linear_sum_assignment, I)
|
||||
|
||||
|
||||
def test_constant_cost_matrix():
|
||||
# Fixes #11602
|
||||
n = 8
|
||||
C = np.ones((n, n))
|
||||
row_ind, col_ind = linear_sum_assignment(C)
|
||||
assert_array_equal(row_ind, np.arange(n))
|
||||
assert_array_equal(col_ind, np.arange(n))
|
Loading…
Add table
Add a link
Reference in a new issue