136 lines
4.1 KiB
Python
136 lines
4.1 KiB
Python
"""
|
||
Algorithm for testing d-separation in DAGs.
|
||
|
||
*d-separation* is a test for conditional independence in probability
|
||
distributions that can be factorized using DAGs. It is a purely
|
||
graphical test that uses the underlying graph and makes no reference
|
||
to the actual distribution parameters. See [1]_ for a formal
|
||
definition.
|
||
|
||
The implementation is based on the conceptually simple linear time
|
||
algorithm presented in [2]_. Refer to [3]_, [4]_ for a couple of
|
||
alternative algorithms.
|
||
|
||
|
||
Examples
|
||
--------
|
||
|
||
>>>
|
||
>>> # HMM graph with five states and observation nodes
|
||
... g = nx.DiGraph()
|
||
>>> g.add_edges_from(
|
||
... [
|
||
... ("S1", "S2"),
|
||
... ("S2", "S3"),
|
||
... ("S3", "S4"),
|
||
... ("S4", "S5"),
|
||
... ("S1", "O1"),
|
||
... ("S2", "O2"),
|
||
... ("S3", "O3"),
|
||
... ("S4", "O4"),
|
||
... ("S5", "O5"),
|
||
... ]
|
||
... )
|
||
>>>
|
||
>>> # states/obs before 'S3' are d-separated from states/obs after 'S3'
|
||
... nx.d_separated(g, {"S1", "S2", "O1", "O2"}, {"S4", "S5", "O4", "O5"}, {"S3"})
|
||
True
|
||
|
||
|
||
References
|
||
----------
|
||
|
||
.. [1] Pearl, J. (2009). Causality. Cambridge: Cambridge University Press.
|
||
|
||
.. [2] Darwiche, A. (2009). Modeling and reasoning with Bayesian networks. Cambridge: Cambridge University Press.
|
||
|
||
.. [3] Shachter, R. D. (1998). Bayes-ball: rational pastime (for determining irrelevance and requisite information in belief networks and influence diagrams). In Proceedings of the Fourteenth Conference on Uncertainty in Artificial Intelligence (pp. 480–487). San Francisco, CA, USA: Morgan Kaufmann Publishers Inc.
|
||
|
||
.. [4] Koller, D., & Friedman, N. (2009). Probabilistic graphical models: principles and techniques. The MIT Press.
|
||
|
||
"""
|
||
|
||
from collections import deque
|
||
from typing import AbstractSet
|
||
|
||
import networkx as nx
|
||
from networkx.utils import not_implemented_for, UnionFind
|
||
|
||
__all__ = ["d_separated"]
|
||
|
||
|
||
@not_implemented_for("undirected")
def d_separated(G: nx.DiGraph, x: AbstractSet, y: AbstractSet, z: AbstractSet) -> bool:
    """
    Test whether ``z`` d-separates the node sets ``x`` and ``y`` in ``G``.

    The check works on a private copy of ``G``; the input graph is not
    modified. The copy is pruned in two steps and the answer is then read
    off from its weakly connected components.

    Parameters
    ----------
    G : graph
        A NetworkX DAG.

    x : set
        First set of nodes in ``G``.

    y : set
        Second set of nodes in ``G``.

    z : set
        Set of conditioning nodes in ``G``. May be the empty set.

    Returns
    -------
    b : bool
        ``True`` if ``x`` is d-separated from ``y`` given ``z`` in ``G``,
        ``False`` otherwise. ``True`` is also returned when ``x`` or ``y``
        is empty.

    Raises
    ------
    NetworkXError
        If ``G`` is not a directed acyclic graph. d-separation is
        normally used with acyclic directed graphical models, so other
        inputs are rejected.

    NodeNotFound
        If any node of ``x``, ``y`` or ``z`` is not present in ``G``.

    """
    if not nx.is_directed_acyclic_graph(G):
        raise nx.NetworkXError("graph should be directed acyclic")

    relevant = x.union(y).union(z)

    for node in relevant:
        if node not in G.nodes:
            raise nx.NodeNotFound("one or more specified nodes not found in the graph")

    pruned = G.copy()

    # Step 1: repeatedly strip leaves (nodes without outgoing edges) that
    # belong to none of x, y, z, until no such leaf is left.
    pending = deque(n for n in pruned.nodes if pruned.out_degree[n] == 0)
    while pending:
        candidate = pending.popleft()
        if candidate in relevant:
            continue
        # A parent whose only outgoing edge points at `candidate` becomes a
        # leaf once `candidate` is gone, so queue it before the removal.
        pending.extend(
            parent
            for parent in pruned.predecessors(candidate)
            if pruned.out_degree[parent] == 1
        )
        pruned.remove_node(candidate)

    # Step 2: drop every edge leaving a conditioning node. The edge view is
    # materialised first because the graph is mutated while removing.
    outgoing_from_z = list(pruned.out_edges(z))
    pruned.remove_edges_from(outgoing_from_z)

    # Step 3: merge each weakly connected component into one disjoint-set
    # class, then merge all of `x` together and all of `y` together. `x`
    # and `y` are connected iff their representatives end up in the same
    # class.
    classes = UnionFind(pruned.nodes())
    for group in nx.weakly_connected_components(pruned):
        classes.union(*group)
    classes.union(*x)
    classes.union(*y)

    connected = bool(x and y) and classes[next(iter(x))] == classes[next(iter(y))]
    return not connected
|