Understanding the causal relationships between variables is crucial for making informed decisions and predictions.
The Python package causal-learn offers a toolkit for causal discovery, enabling researchers and practitioners to uncover these relationships from observational data.
In this blog post, we’ll explore how to leverage causal-learn to identify and analyze causal structures,
providing a step-by-step guide to get you started on your causal discovery journey.
Data Generation
First, let’s generate some data (inspired by this notebook):
import numpy as np
from causallearn.search.FCMBased import lingam
from causallearn.search.FCMBased.lingam.utils import make_dot
import networkx as nx
import seaborn as sns
N = 1000
q = np.random.uniform(0, 2, N)
w = np.random.randn(N)
x = np.random.gumbel(0, 1, N) + w
y = 0.6 * q + 0.8 * w + np.random.uniform(0, 1, N)
z = 0.5 * x + np.random.randn(N)
data = np.stack([x, y, w, z, q]).T
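The snippet above draws from NumPy's global random state, so every run produces a different sample. A minimal sketch of a reproducible variant is shown here; the helper name `generate_data` and the default seed are illustrative assumptions, not part of the original notebook:

```python
import numpy as np


def generate_data(n=1000, seed=0):
    """Sample the same five variables with a seeded generator.

    `generate_data` and the default seed are illustrative assumptions.
    Columns are ordered X, Y, W, Z, Q, matching the stacked array above.
    """
    rng = np.random.default_rng(seed)
    q = rng.uniform(0, 2, n)
    w = rng.standard_normal(n)
    x = rng.gumbel(0, 1, n) + w
    y = 0.6 * q + 0.8 * w + rng.uniform(0, 1, n)
    z = 0.5 * x + rng.standard_normal(n)
    return np.stack([x, y, w, z, q]).T


data = generate_data()
print(data.shape)  # (1000, 5)
```

Calling `generate_data` twice with the same seed returns identical arrays, which makes the later plots easier to compare across runs.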
This is the DAG of our data:
nodes = ['X', 'Y', 'W', 'Z', 'Q']
edges = [
    ('W', 'X'),
    ('W', 'Y'),
    ('Q', 'Y'),
    ('X', 'Z'),
]
# Ground-truth graph implied by the data-generating equations
true_graph = nx.DiGraph()
true_graph.add_nodes_from(nodes)
true_graph.add_edges_from(edges)
nx.draw(
    G=true_graph,
    node_color='#00B0F0',
    nodelist=nodes,
    with_labels=True,
    pos=nx.circular_layout(true_graph)
)
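The same ground truth can also be written as a weighted adjacency matrix, which is handy when comparing it against an estimate. The sketch below assumes the convention used by the LiNGAM implementation, where entry `B[i, j]` holds the effect of variable `j` on variable `i`. The names `true_weights` and `B_true` are illustrative, and the coefficients are read off the data-generating equations:

```python
import numpy as np

nodes = ['X', 'Y', 'W', 'Z', 'Q']

# (cause, effect) -> coefficient, taken from the simulation equations
true_weights = {
    ('W', 'X'): 1.0,
    ('W', 'Y'): 0.8,
    ('Q', 'Y'): 0.6,
    ('X', 'Z'): 0.5,
}

index = {name: i for i, name in enumerate(nodes)}
B_true = np.zeros((len(nodes), len(nodes)))
for (cause, effect), weight in true_weights.items():
    # Row = effect, column = cause (assumed LiNGAM convention)
    B_true[index[effect], index[cause]] = weight

print(B_true)
```

Transposing `B_true` gives the row-to-column orientation that networkx uses when it builds a graph from an array.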
Causal Discovery using LiNGAM
Now, let’s estimate the causal graph using the LiNGAM method:
model = lingam.DirectLiNGAM(random_state=42)
model.fit(data)
causal-learn has a built-in function to visualize the estimated DAG:
make_dot(model.adjacency_matrix_, labels=nodes)
We can also use networkx to visualize the result:
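# Transpose so that rows are causes and columns are effects, as networkx expects
G = nx.DiGraph(model.adjacency_matrix_.T)
# Replace the integer node ids 0..4 with the variable names
G = nx.relabel_nodes(G, dict(enumerate(nodes)))
nx.draw(G,
        with_labels=True,
        pos=nx.circular_layout(G))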
We see that LiNGAM estimates the true DAG quite accurately! Another way to visualize the adjacency matrix is a heatmap:
sns.heatmap(model.adjacency_matrix_.T, cmap="rocket_r", cbar=False)
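Note that the edge from X to Z has a low estimated weight. The entries of the adjacency matrix are estimated causal coefficients, not probabilities, and the true coefficient for this edge is 0.5. If we set the threshold to 0.5, the edge would actually disappear: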
sns.heatmap(model.adjacency_matrix_.T>0.5, cmap="rocket_r", cbar=False)
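To see why a cutoff of 0.5 is borderline for the X to Z edge, it helps to look at how an estimate of that coefficient behaves on simulated data. The sketch below uses plain least squares rather than LiNGAM, an arbitrary seed, and a hypothetical helper `strong_edges` that thresholds on absolute value so that strong negative effects are kept as well:

```python
import numpy as np

rng = np.random.default_rng(0)  # seed chosen arbitrarily for this sketch
N = 1000

# Same X -> Z mechanism as in the data-generation step
w = rng.standard_normal(N)
x = rng.gumbel(0, 1, N) + w
z = 0.5 * x + rng.standard_normal(N)

# Least-squares slope of Z on X with an intercept (plain OLS, not LiNGAM)
design = np.column_stack([x, np.ones(N)])
slope, intercept = np.linalg.lstsq(design, z, rcond=None)[0]
print(f"OLS slope of Z on X: {slope:.3f}")


def strong_edges(B, threshold):
    """Boolean mask of entries whose absolute value exceeds the threshold.

    `strong_edges` is an illustrative helper, not a causal-learn function.
    """
    return np.abs(np.asarray(B)) > threshold


# Made-up weights for illustration only, not model output
demo = np.array([[0.0, -0.8],
                 [0.45, 0.0]])
print(strong_edges(demo, 0.5))
```

The slope lands close to the true value of 0.5, so whether it clears a `> 0.5` cutoff depends on sampling noise. A threshold comfortably below the smallest coefficient you care about is a safer choice.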