"""
Algorithm for testing d-separation in DAGs.

*d-separation* is a test for conditional independence in probability
distributions that can be factorized using DAGs.  It is a purely
graphical test that uses the underlying graph and makes no reference
to the actual distribution parameters.  See [1]_ for a formal
definition.

The implementation is based on the conceptually simple linear time
algorithm presented in [2]_.  Refer to [3]_, [4]_ for a couple of
alternative algorithms.


Examples
--------

>>> # HMM graph with five states and observation nodes
>>> g = nx.DiGraph()
>>> g.add_edges_from(
...     [
...         ("S1", "S2"),
...         ("S2", "S3"),
...         ("S3", "S4"),
...         ("S4", "S5"),
...         ("S1", "O1"),
...         ("S2", "O2"),
...         ("S3", "O3"),
...         ("S4", "O4"),
...         ("S5", "O5"),
...     ]
... )
>>> # states/obs before 'S3' are d-separated from states/obs after 'S3'
>>> nx.d_separated(g, {"S1", "S2", "O1", "O2"}, {"S4", "S5", "O4", "O5"}, {"S3"})
True


References
----------

.. [1] Pearl, J. (2009). Causality. Cambridge: Cambridge University Press.

.. [2] Darwiche, A. (2009). Modeling and reasoning with Bayesian networks.
   Cambridge: Cambridge University Press.

.. [3] Shachter, R. D. (1998).
   Bayes-ball: the rational pastime (for determining irrelevance and
   requisite information in belief networks and influence diagrams).
   In Proceedings of the Fourteenth Conference on Uncertainty in
   Artificial Intelligence (pp. 480–487).
   San Francisco, CA, USA: Morgan Kaufmann Publishers Inc.

.. [4] Koller, D., & Friedman, N. (2009).
   Probabilistic graphical models: principles and techniques. The MIT Press.

"""

from collections import deque
from typing import AbstractSet

import networkx as nx
from networkx.utils import not_implemented_for, UnionFind

__all__ = ["d_separated"]


@not_implemented_for("undirected")
def d_separated(G: nx.DiGraph, x: AbstractSet, y: AbstractSet, z: AbstractSet) -> bool:
    """
    Return whether node sets ``x`` and ``y`` are d-separated by ``z``.

    Parameters
    ----------
    G : graph
        A NetworkX DAG. Directed multigraphs are accepted as well; parallel
        edges do not influence the result.

    x : set
        First set of nodes in ``G``.

    y : set
        Second set of nodes in ``G``.

    z : set
        Set of conditioning nodes in ``G``. Can be empty set.

    Any collection of hashable nodes (``set``, ``frozenset``, a dict keys
    view, ...) may be given for ``x``, ``y`` and ``z``; each one is copied
    into a ``set`` before use.

    Returns
    -------
    b : bool
        A boolean that is true if ``x`` is d-separated from ``y`` given ``z`` in ``G``.
        If ``x`` or ``y`` is empty the result is ``True``.

    Raises
    ------
    NetworkXError
        The *d-separation* test is commonly used with directed
        graphical models which are acyclic.  Accordingly, the algorithm
        raises a :exc:`NetworkXError` if the input graph is not a DAG.

    NodeNotFound
        If any of the input nodes are not found in the graph,
        a :exc:`NodeNotFound` exception is raised.

    Notes
    -----
    The test works on a copy of ``G`` in three steps:

    1. Repeatedly delete leaves (nodes with no outgoing edges) that are not
       in ``x | y | z``.
    2. Delete every edge leaving a node of ``z``.
    3. ``x`` and ``y`` are d-separated iff no node of ``x`` lies in the same
       weakly connected component as a node of ``y``.
    """
    if not nx.is_directed_acyclic_graph(G):
        raise nx.NetworkXError("graph should be directed acyclic")

    # Work on real ``set`` objects: ``AbstractSet`` (collections.abc.Set)
    # does not guarantee a ``.union`` method, and a one-shot iterable would
    # otherwise be exhausted by its first use below.
    x, y, z = set(x), set(y), set(z)
    union_xyz = x | y | z

    if any(n not in G.nodes for n in union_xyz):
        raise nx.NodeNotFound("one or more specified nodes not found in the graph")

    # Nothing can connect an empty side to the other one.
    if not x or not y:
        return True

    G_copy = G.copy()

    # Step 1: prune leaves that are not in x | y | z until none remain.
    leaves = deque(n for n in G_copy.nodes if G_copy.out_degree[n] == 0)
    while leaves:
        leaf = leaves.popleft()
        if leaf in union_xyz:
            continue
        # Snapshot the parents, remove the leaf, then check which parents
        # became leaves. Testing ``out_degree == 0`` after removal (instead
        # of ``== 1`` before it) stays correct for multigraphs, where
        # parallel edges to the same child inflate the out-degree.
        parents = list(G_copy.predecessors(leaf))
        G_copy.remove_node(leaf)
        leaves.extend(p for p in parents if G_copy.out_degree[p] == 0)

    # Step 2: cut every edge leaving a conditioning node. The edge list is
    # materialised first because the graph is mutated while removing.
    G_copy.remove_edges_from(list(G_copy.out_edges(z)))

    # Step 3: merge each weakly connected component in a disjoint-set, then
    # merge all of ``x`` together and all of ``y`` together. The two
    # representatives coincide iff some node of ``x`` shares a component
    # with some node of ``y``.
    disjoint_set = UnionFind(G_copy.nodes())
    for component in nx.weakly_connected_components(G_copy):
        disjoint_set.union(*component)
    disjoint_set.union(*x)
    disjoint_set.union(*y)

    return disjoint_set[next(iter(x))] != disjoint_set[next(iter(y))]