138 lines
4.2 KiB
Python
138 lines
4.2 KiB
Python
|
import numpy as np
|
||
|
import heapq
|
||
|
|
||
|
|
||
|
def _revalidate_node_edges(rag, node, heap_list):
|
||
|
"""Handles validation and invalidation of edges incident to a node.
|
||
|
|
||
|
This function invalidates all existing edges incident on `node` and inserts
|
||
|
new items in `heap_list` updated with the valid weights.
|
||
|
|
||
|
rag : RAG
|
||
|
The Region Adjacency Graph.
|
||
|
node : int
|
||
|
The id of the node whose incident edges are to be validated/invalidated
|
||
|
.
|
||
|
heap_list : list
|
||
|
The list containing the existing heap of edges.
|
||
|
"""
|
||
|
# networkx updates data dictionary if edge exists
|
||
|
# this would mean we have to reposition these edges in
|
||
|
# heap if their weight is updated.
|
||
|
# instead we invalidate them
|
||
|
|
||
|
for nbr in rag.neighbors(node):
|
||
|
data = rag[node][nbr]
|
||
|
try:
|
||
|
# invalidate edges incident on `dst`, they have new weights
|
||
|
data['heap item'][3] = False
|
||
|
_invalidate_edge(rag, node, nbr)
|
||
|
except KeyError:
|
||
|
# will handle the case where the edge did not exist in the existing
|
||
|
# graph
|
||
|
pass
|
||
|
|
||
|
wt = data['weight']
|
||
|
heap_item = [wt, node, nbr, True]
|
||
|
data['heap item'] = heap_item
|
||
|
heapq.heappush(heap_list, heap_item)
|
||
|
|
||
|
|
||
|
def _rename_node(graph, node_id, copy_id):
    """Re-key node `node_id` of `graph` as `copy_id`.

    A node `copy_id` is added (via ``graph._add_node_silent``). It receives
    the attributes of `node_id` and is connected to every neighbor of
    `node_id` with the same ``'weight'``. Only ``'weight'`` is carried over
    for edges. Finally, `node_id` is removed from `graph`.

    Parameters
    ----------
    graph : RAG
        The graph, modified in place.
    node_id : int
        The node to be renamed.
    copy_id : int
        The new identifier for the node.
    """
    graph._add_node_silent(copy_id)
    new_attrs = graph.nodes[copy_id]
    new_attrs.update(graph.nodes[node_id])

    for neighbor in graph.neighbors(node_id):
        edge_weight = graph[node_id][neighbor]['weight']
        graph.add_edge(neighbor, copy_id, {'weight': edge_weight})

    graph.remove_node(node_id)


def _invalidate_edge(graph, n1, n2):
|
||
|
""" Invalidates the edge (n1, n2) in the heap. """
|
||
|
graph[n1][n2]['heap item'][3] = False
|
||
|
|
||
|
|
||
|
def merge_hierarchical(labels, rag, thresh, rag_copy, in_place_merge,
                       merge_func, weight_func):
    """Perform hierarchical merging of a RAG.

    Greedily merges the most similar pair of nodes until no edges lower than
    `thresh` remain.

    Parameters
    ----------
    labels : ndarray
        The array of labels.
    rag : RAG
        The Region Adjacency Graph.
    thresh : float
        Regions connected by an edge with weight smaller than `thresh` are
        merged.
    rag_copy : bool
        If set, the RAG is copied before modifying.
    in_place_merge : bool
        If set, the nodes are merged in place. Otherwise, a new node is
        created for each merge.
    merge_func : callable
        This function is called before merging two nodes. For the RAG `graph`
        while merging `src` and `dst`, it is called as follows
        ``merge_func(graph, src, dst)``.
    weight_func : callable
        The function to compute the new weights of the nodes adjacent to the
        merged node. This is directly supplied as the argument `weight_func`
        to `merge_nodes`.

    Returns
    -------
    out : ndarray
        The new labeled array.
    """
    if rag_copy:
        rag = rag.copy()

    # Each heap item is a mutable list ``[weight, n1, n2, valid]``. The same
    # list object is stored on the edge as ``'heap item'`` so the edge's entry
    # can be invalidated later without searching the heap.
    edge_heap = []
    for n1, n2, data in rag.edges(data=True):
        heap_item = [data['weight'], n1, n2, True]
        heapq.heappush(edge_heap, heap_item)
        data['heap item'] = heap_item

    while edge_heap and edge_heap[0][0] < thresh:
        _, n1, n2, valid = heapq.heappop(edge_heap)

        # Items invalidated by an earlier merge are stale; discard them.
        if not valid:
            continue

        # Invalidate every edge touching `n1` or `n2` before they are merged
        # away; the merged node's edges are pushed again further below.
        for nbr in rag.neighbors(n1):
            _invalidate_edge(rag, n1, nbr)
        for nbr in rag.neighbors(n2):
            _invalidate_edge(rag, n2, nbr)

        if in_place_merge:
            src, dst = n1, n2
        else:
            next_id = rag.next_id()
            _rename_node(rag, n2, next_id)
            src, dst = n1, next_id

        merge_func(rag, src, dst)
        new_id = rag.merge_nodes(src, dst, weight_func)
        _revalidate_node_edges(rag, new_id, edge_heap)

    # Map every original label to the index of the node that now owns it.
    label_map = np.arange(labels.max() + 1)
    for ix, (_, node_data) in enumerate(rag.nodes(data=True)):
        for label in node_data['labels']:
            label_map[label] = ix

    return label_map[labels]