76 lines
2.4 KiB
Python
76 lines
2.4 KiB
Python
|
import numpy as np
|
||
|
import skimage.graph.mcp as mcp
|
||
|
# import stentseg.graph._mcp as mcp
|
||
|
|
||
|
from skimage._shared.testing import assert_array_equal
|
||
|
|
||
|
|
||
|
# Uniform 8x8 float32 grid of ones; test_connections passes it to MCP() as
# the cost array.
a = np.ones((8, 8), dtype=np.float32)
|
||
|
# Module-level counter initialised to zero.
# NOTE(review): no use of `count` is visible in this part of the file --
# confirm whether it is still needed.
count = 0
|
||
|
|
||
|
|
||
|
class MCP(mcp.MCP_Connect):
    """MCP_Connect subclass that records every reported connection.

    Two dictionaries are maintained, both keyed by the ordered id pair
    ``(smaller_id, larger_id)``:

    * ``_conn`` maps the pair to a list of ``(pos_a, pos_b)`` tuples, one
      entry per call to :meth:`create_connection` for that pair.
    * ``_bestconn`` maps the pair to ``(cost, pos_a, pos_b)`` for the
      cheapest connection reported so far.
    """

    def _reset(self):
        """Reset the base-class state and clear the recorded connections."""
        mcp.MCP_Connect._reset(self)
        self._conn = {}
        self._bestconn = {}

    def create_connection(self, id1, id2, pos1, pos2, cost1, cost2):
        """Record one connection between the regions ``id1`` and ``id2``.

        NOTE(review): presumably called by ``MCP_Connect`` whenever two
        fronts meet -- confirm against the scikit-image documentation.

        The ids and the positions are each stored in sorted order so the
        same connection always produces the same key and the same value,
        whichever side reported it.  The connection cost is taken to be
        the smaller of ``cost1`` and ``cost2``.
        """
        # Normalise the inputs into order-independent form.
        pair_key = (min(id1, id2), max(id1, id2))
        first_pos = min(pos1, pos2)
        second_pos = max(pos1, pos2)
        endpoints = (first_pos, second_pos)
        connection_cost = min(cost1, cost2)

        # Append to the full list of connections for this pair.
        if pair_key not in self._conn:
            self._conn[pair_key] = []
        self._conn[pair_key].append(endpoints)

        # Remember only the cheapest connection per pair; a pair that has
        # no entry yet is treated as having infinite cost.
        try:
            best_cost = self._bestconn[pair_key][0]
        except KeyError:
            best_cost = np.inf
        if connection_cost < best_cost:
            self._bestconn[pair_key] = (connection_cost, first_pos, second_pos)
|
||
|
|
||
|
|
||
|
def test_connections():
    """Check the connections recorded between three seed points.

    An ``MCP`` instance is built on the module-level cost array ``a`` and
    grown from three seeds.  The test then verifies that

    * every pair of seeds has at least one recorded connection,
    * no pair of neighbouring positions was recorded more than once, and
    * for each seed pair the cheapest connection meets at the expected
      positions and the two tracebacks join into the expected path.

    The expected paths below imply that seed ids follow the order of the
    seed list (id 0 is ``(1, 1)``, id 1 is ``(7, 7)``, id 2 is ``(1, 7)``).
    """
    # Create MCP object with three seed points.  The local is deliberately
    # not called ``mcp`` so that it does not shadow the imported module.
    seeds = [(1, 1), (7, 7), (1, 7)]
    connector = MCP(a)
    connector.find_costs(seeds)

    # Test that all three seed points are connected
    connections = set(connector._conn)
    assert (0, 1) in connections
    assert (1, 2) in connections
    assert (0, 2) in connections

    # Test that any two neighbors have only been connected once
    for position_tuples in connector._conn.values():
        assert len(position_tuples) == len(set(position_tuples))

    # (seed id pair, expected meeting points, expected full path)
    expectations = [
        (
            (0, 1),
            ((3, 3), (4, 4)),
            [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (7, 7)],
        ),
        (
            (1, 2),
            ((3, 7), (4, 7)),
            [(1, 7), (2, 7), (3, 7), (4, 7), (5, 7), (6, 7), (7, 7)],
        ),
        (
            (0, 2),
            ((1, 3), (1, 4)),
            [(1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7)],
        ),
    ]

    for seed_pair, meeting_points, expected_path in expectations:
        # The stored cost is not checked here, only the positions.
        _, pos1, pos2 = connector._bestconn[seed_pair]

        # Test meeting points
        assert (pos1, pos2) == meeting_points

        # Test the whole path: the traceback of the first position,
        # followed by the reversed traceback of the second position.
        head = connector.traceback(pos1)
        tail = list(reversed(connector.traceback(pos2)))
        assert_array_equal(head + tail, expected_path)
|