137 lines
4.1 KiB
Python
137 lines
4.1 KiB
Python
|
"""
Algorithm for testing d-separation in DAGs.

*d-separation* is a test for conditional independence in probability
distributions that can be factorized using DAGs. It is a purely
graphical test that uses the underlying graph and makes no reference
to the actual distribution parameters. See [1]_ for a formal
definition.

The implementation is based on the conceptually simple linear time
algorithm presented in [2]_. Refer to [3]_, [4]_ for a couple of
alternative algorithms.

Examples
--------

>>> # HMM graph with five states and observation nodes
>>> g = nx.DiGraph()
>>> g.add_edges_from(
...     [
...         ("S1", "S2"),
...         ("S2", "S3"),
...         ("S3", "S4"),
...         ("S4", "S5"),
...         ("S1", "O1"),
...         ("S2", "O2"),
...         ("S3", "O3"),
...         ("S4", "O4"),
...         ("S5", "O5"),
...     ]
... )
>>> # states/obs before 'S3' are d-separated from states/obs after 'S3'
>>> nx.d_separated(g, {"S1", "S2", "O1", "O2"}, {"S4", "S5", "O4", "O5"}, {"S3"})
True

References
----------

.. [1] Pearl, J. (2009). Causality. Cambridge: Cambridge University Press.

.. [2] Darwiche, A. (2009). Modeling and reasoning with Bayesian networks. Cambridge: Cambridge University Press.

.. [3] Shachter, R. D. (1998). Bayes-ball: rational pastime (for determining irrelevance and requisite information in belief networks and influence diagrams). In Proceedings of the Fourteenth Conference on Uncertainty in Artificial Intelligence (pp. 480–487). San Francisco, CA, USA: Morgan Kaufmann Publishers Inc.

.. [4] Koller, D., & Friedman, N. (2009). Probabilistic graphical models: principles and techniques. The MIT Press.

"""
|
|||
|
|
|||
|
from collections import deque
|
|||
|
from typing import AbstractSet
|
|||
|
|
|||
|
import networkx as nx
|
|||
|
from networkx.utils import not_implemented_for, UnionFind
|
|||
|
|
|||
|
# Public API of this module: only ``d_separated`` is exported by ``import *``.
__all__ = ["d_separated"]
|
|||
|
|
|||
|
|
|||
|
@not_implemented_for("undirected")
def d_separated(G: nx.DiGraph, x: AbstractSet, y: AbstractSet, z: AbstractSet) -> bool:
    """
    Return whether node sets ``x`` and ``y`` are d-separated by ``z``.

    Parameters
    ----------
    G : graph
        A NetworkX DAG. A ``MultiDiGraph`` is accepted as well; parallel
        edges are taken into account when pruning barren nodes.

    x : set
        First set of nodes in ``G``. Any iterable of hashable nodes is
        accepted; it is converted to a set internally.

    y : set
        Second set of nodes in ``G``. Converted to a set like ``x``.

    z : set
        Set of conditioning nodes in ``G``. Can be empty set. Converted to
        a set like ``x``.

    Returns
    -------
    b : bool
        A boolean that is true if ``x`` is d-separated from ``y`` given ``z`` in ``G``.
        If ``x`` or ``y`` is empty, the result is ``True``.

    Raises
    ------
    NetworkXError
        The *d-separation* test is commonly used with directed
        graphical models which are acyclic. Accordingly, the algorithm
        raises a :exc:`NetworkXError` if the input graph is not a DAG.

    NodeNotFound
        If any of the input nodes are not found in the graph,
        a :exc:`NodeNotFound` exception is raised.

    Notes
    -----
    The test runs on a copy of ``G``; the input graph is not modified.
    The copy is transformed in two steps. First, leaves (nodes with no
    outgoing edges) that are not in ``x | y | z`` are removed repeatedly.
    Then all edges leaving ``z`` are removed. Afterwards, ``x`` and ``y``
    are reported as d-separated exactly when no weakly connected component
    contains nodes from both sets.
    """
    if not nx.is_directed_acyclic_graph(G):
        raise nx.NetworkXError("graph should be directed acyclic")

    # Normalize the inputs to real sets. ``AbstractSet`` only guarantees
    # operators such as ``|``, not a ``.union`` method (``dict.keys()`` has
    # none), and sets give O(1) membership tests in the pruning loop below.
    x, y, z = set(x), set(y), set(z)
    union_xyz = x | y | z

    if any(n not in G.nodes for n in union_xyz):
        raise nx.NodeNotFound("one or more specified nodes not found in the graph")

    # Work on a copy so the caller's graph is left untouched.
    G_copy = G.copy()

    # Transform the graph by removing leaves that are not in x | y | z,
    # until no more leaves can be removed. A parent is queued once its
    # out-degree has dropped to zero *after* the child is removed. Testing
    # for out-degree 1 beforehand would miss parents joined to the leaf by
    # several parallel edges in a MultiDiGraph.
    leaves = deque(n for n in G_copy.nodes if G_copy.out_degree[n] == 0)
    while leaves:
        leaf = leaves.popleft()
        if leaf in union_xyz:
            continue
        # Snapshot the parents: the predecessor view is invalidated by removal.
        parents = list(G_copy.predecessors(leaf))
        G_copy.remove_node(leaf)
        leaves.extend(p for p in parents if G_copy.out_degree[p] == 0)

    # Transform the graph by removing outgoing edges from the conditioning
    # set. The list() snapshot is required because the edge view is lazy and
    # would otherwise be iterated while edges are being deleted.
    G_copy.remove_edges_from(list(G_copy.out_edges(z)))

    # Use a disjoint-set structure to check whether any node in ``x`` lies in
    # the same weakly connected component as a node in ``y``. Merge every
    # component, then merge all of ``x`` and all of ``y``. The two resulting
    # roots coincide iff some component touches both sets.
    disjoint_set = UnionFind(G_copy.nodes())
    for component in nx.weakly_connected_components(G_copy):
        disjoint_set.union(*component)
    disjoint_set.union(*x)
    disjoint_set.union(*y)

    if x and y and disjoint_set[next(iter(x))] == disjoint_set[next(iter(y))]:
        return False
    return True
|