Exact Inference
Exact Inference as Conditional Distribution
This is one of the core tasks on probabilistic models.
Consider a probabilistic model where we are given
- some observed evidence \(X_E\),
- some unobserved random variables \(X_F\) whose distribution we are interested in, and
- other variables \(X_R\) that are neither observed nor of interest.
Inference asks for the conditional distribution
\[
p(X_F \mid X_E) = \frac{p(X_F, X_E)}{p(X_E)} = \frac{\sum_{X_R} p(X_F, X_E, X_R)}{\sum_{X_F, X_R} p(X_F, X_E, X_R)}.
\]
Thus, we need to marginalize out all of \(X_R\) and then normalize by the probability of the evidence.
Variable Elimination
Consider the conditional distribution above, and note that it requires a huge number of summations. For example, consider a simple chain of variables \(A\rightarrow B\rightarrow C\rightarrow D\), where we are interested in the marginal \(p(D)\).
If we do the summation naively, it is
\[
p(D) = \sum_{C}\sum_{B}\sum_{A} p(A)\,p(B\mid A)\,p(C\mid B)\,p(D\mid C),
\]
resulting in \(O(k^n)\) time, where \(k\) is the number of states of each variable and \(n\) is the number of variables.
On the other hand, we can use dynamic programming to perform variable elimination by decomposing the triple summation. Observe that
\[
p(D) = \sum_{C} p(D\mid C) \sum_{B} p(C\mid B) \sum_{A} p(A)\,p(B\mid A),
\]
where each inner sum is computed once and reused. Thus, the runtime is reduced to \(O(nk^2)\).
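To make this concrete, here is a minimal numpy sketch (the conditional probability tables are randomly generated placeholders, so the numbers themselves carry no meaning) that computes \(p(D)\) both by the naive triple summation and by the decomposed, matrix-vector form; the two agree.

import numpy as np

k = 3                                   # number of states per variable (arbitrary)
rng = np.random.default_rng(0)

def random_cpt(rows, cols):
    # random table t[i, j] = p(child = i | parent = j); columns sum to 1
    t = rng.random((rows, cols))
    return t / t.sum(axis=0, keepdims=True)

p_A = rng.random(k); p_A /= p_A.sum()   # p(A)
p_B_A = random_cpt(k, k)                # p(B | A)
p_C_B = random_cpt(k, k)                # p(C | B)
p_D_C = random_cpt(k, k)                # p(D | C)

# naive summation over all joint states: O(k^n)
p_D_naive = np.zeros(k)
for a in range(k):
    for b in range(k):
        for c in range(k):
            p_D_naive += p_A[a] * p_B_A[b, a] * p_C_B[c, b] * p_D_C[:, c]

# variable elimination: eliminate A, then B, then C, each step O(k^2)
m_B = p_B_A @ p_A                       # sum_A p(A) p(B|A)
m_C = p_C_B @ m_B                       # sum_B p(C|B) m_B(B)
p_D = p_D_C @ m_C                       # sum_C p(D|C) m_C(C)

assert np.allclose(p_D_naive, p_D)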
Intermediate Factors
Consider the distribution given by
\[
p(X, A, B, C) = p(X)\,p(A\mid X)\,p(B\mid A)\,p(C\mid B, X).
\]
Suppose that we'd like to marginalize over \(X\), so that
\[
p(A, B, C) = p(B\mid A) \sum_X p(X)\,p(A\mid X)\,p(C\mid B, X).
\]
However, \(\sum_X p(X)p(A|X)p(C|B,X)\) is not a valid conditional or marginal distribution of the model; it is just an unnormalized function of \(A, B, C\).
Note that the only purpose of writing down these intermediate quantities is to cache them in dynamic-programming tables for the final computation. Thus, they do not need to be distributions until the computation is finished.
Additionally, each conditional distribution \(p(A|B)\) is a function of the variables \(A, B\). Thus, we introduce factors \(\phi\), which are not necessarily normalized distributions but describe the local relationship between random variables.
In addition, we introduce another intermediate factor \(\tau\) for the summations that we want to store temporarily. For example, we can let \(\tau(A,B,C) = \sum_X p(X)p(A|X)p(C|B,X)\), so that \(X\) is eliminated. More formally, the inference task becomes
\[
p(X_F \mid X_E) \propto \sum_{X_R} \prod_{\phi\in\Phi} \phi,
\]
where, for a DAG, the initial set of factors \(\Phi\) is given by the conditional probability tables,
\[
\Phi = \{\phi_i(X_i, X_{\mathrm{pa}(i)}) = p(X_i \mid X_{\mathrm{pa}(i)})\}_{i=1}^{n}.
\]
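The intermediate factor \(\tau(A,B,C)=\sum_X p(X)p(A\mid X)p(C\mid B,X)\) is just a tensor contraction over the tables; here is a minimal numpy sketch with arbitrary placeholder tables (not part of any real model):

import numpy as np

rng = np.random.default_rng(0)
kX, kA, kB, kC = 2, 3, 3, 2                   # arbitrary state counts

p_X = rng.random(kX); p_X /= p_X.sum()        # p(X)
p_A_X = rng.random((kA, kX))                  # p(A | X), indexed [a, x]
p_A_X /= p_A_X.sum(axis=0, keepdims=True)
p_C_BX = rng.random((kC, kB, kX))             # p(C | B, X), indexed [c, b, x]
p_C_BX /= p_C_BX.sum(axis=0, keepdims=True)

# tau[a, b, c] = sum_x p(X=x) p(A=a | X=x) p(C=c | B=b, X=x)
tau = np.einsum('x,ax,cbx->abc', p_X, p_A_X, p_C_BX)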
VE Implementation
Note that the above VE algorithm is an abstraction, where we sum over probability functions symbolically. Now, consider an implementation where each variable has a finite number of states, and each state probability \(p(X=x)\) is a fixed number, so that the probability functions are well-defined tables.
Consider a set of factors \(\phi\in\Phi\) built from the conditional probabilities, a set of query variables \(X_f \in Q\), a set of evidence variables \(X_e \in E\) with observed values \(X_e = x_e\), and a sequence of remaining variables \(X_r\in Z\) in the chosen elimination ordering.
for each observed variable Xe in E:
    for each factor phi(..., Xe) that mentions Xe:
        replace the factor with the restricted factor phi(..., Xe=xe)
for each Xr in Z:
    Phi_Xr = the set of factors in Phi that mention Xr
    tau = sum over Xr of prod(Phi_Xr)
    remove Phi_Xr from Phi
    add tau to Phi
# all non-query variables are eliminated now
return normalize(prod(Phi))
Factors
Each factor \(\phi\) or \(\tau\) is a function that takes a specific assignment of its scoped variables and returns a non-negative real number. Thus, factors are implemented as lookup tables, where each row holds one state configuration and its associated value. Each table has \(\prod_{X_i\in \mathrm{scope}(\phi)} |X_i|\) rows.
For the \(\phi\)'s, we obtain them directly from the conditional probability functions at initialization time; for example, we initialize \(\phi(A,B) = p(A|B)\). The \(\tau\)'s are obtained from the prod and sum operations defined below.
import pandas as pd
f = pd.DataFrame({"A": [0, 0, 1, 1], "B": [0, 1, 0, 1], "value": [.9, .1, .4, .6]})
f
|   | A | B | value |
|---|---|---|---|
| 0 | 0 | 0 | 0.9 |
| 1 | 0 | 1 | 0.1 |
| 2 | 1 | 0 | 0.4 |
| 3 | 1 | 1 | 0.6 |
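The second factor \(g(B, C)\), used in the examples below, is defined in the same way:

g = pd.DataFrame({"B": [0, 0, 1, 1], "C": [0, 1, 0, 1], "value": [.7, .3, .8, .2]})
g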
|   | B | C | value |
|---|---|---|---|
| 0 | 0 | 0 | 0.7 |
| 1 | 0 | 1 | 0.3 |
| 2 | 1 | 0 | 0.8 |
| 3 | 1 | 1 | 0.2 |
Product
prod(f,g)
takes two factors (tables) \(f, g\) that share at least one scoped variable, and returns a new factor \(h\).
We take the inner join of the two factors, and multiply the values for each row.
def prod(f, g):
    # factor product: inner join on the shared columns, then multiply the values
    f = f.rename(columns={"value": "value_x"})
    g = g.rename(columns={"value": "value_y"})
    h = f.merge(g)  # joins on every column the two tables have in common
    h['value'] = h['value_x'] * h['value_y']
    h = h.drop(['value_x', 'value_y'], axis=1)
    return h
h_prod = prod(f, g)
h_prod
|   | A | B | C | value |
|---|---|---|---|---|
| 0 | 0 | 0 | 0 | 0.63 |
| 1 | 0 | 0 | 1 | 0.27 |
| 2 | 1 | 0 | 0 | 0.28 |
| 3 | 1 | 0 | 1 | 0.12 |
| 4 | 0 | 1 | 0 | 0.08 |
| 5 | 0 | 1 | 1 | 0.02 |
| 6 | 1 | 1 | 0 | 0.48 |
| 7 | 1 | 1 | 1 | 0.12 |
Sum
sum(f, X)
takes a factor \(f\) and a variable \(X\), and returns a new factor obtained by summing \(X\) out of \(f\).
def sum(f, X):
    # marginalize X out of f: group by the remaining variables and add up the values
    # (the resulting column order may change, since set() is unordered)
    f_group = f.groupby(list(set(f.columns) - {X, "value"}))[['value']].sum()
    new_f = f_group.reset_index()
    return new_f
h_sum = sum(h_prod, "C")
h_sum
|   | B | A | value |
|---|---|---|---|
| 0 | 0 | 0 | 0.9 |
| 1 | 0 | 1 | 0.4 |
| 2 | 1 | 0 | 0.1 |
| 3 | 1 | 1 | 0.6 |
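Note that summing \(C\) back out recovers the values of the original factor \(f\) (up to column order), since every \(B\)-row of \(g\) sums to one over \(C\).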
Restriction
restrict(f, X, x)
takes a factor \(f\), an evidence variable \(X\), and the evidence value \(x\), and returns a new factor containing only the rows where \(X=x\), with the now-constant \(X\) column dropped.
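A minimal sketch of this operation, consistent with the table below (filter the matching rows, then drop the restricted column):

def restrict(f, X, x):
    # keep only the rows consistent with the evidence X = x,
    # then drop the X column since it is now constant
    return f[f[X] == x].drop(columns=[X])

h_restrict = restrict(h_prod, "C", 1)
h_restrict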
|   | A | B | value |
|---|---|---|---|
| 1 | 0 | 0 | 0.27 |
| 3 | 1 | 0 | 0.12 |
| 5 | 0 | 1 | 0.02 |
| 7 | 1 | 1 | 0.12 |
Implementation
"""Variable Elimination
Args:
Phi: A list of factors as pd.DataFrame
Q: A list of str, representing the query variable
E: A list of (str, state), representing the evidence var and evidence
R: A list of str, given the elimination ordering
"""
for evar, evidence in E:
for i, f in enumerate(Phi):
if evar in f.columns:
Phi[i] = restrict(f, evar, evidence)
for var in R:
tau = None
to_remove = []
for i, f in enumerate(Phi):
if var in f.columns:
tau = prod(f, tau) if tau is not None else f
to_remove.append(i)
while len(to_remove) > 0:
del Phi[to_remove.pop()]
if tau is not None:
tau = sum(tau, var)
Phi.append(tau)
p = Phi[0]
for tau in Phi[1:]:
p = prod(p, tau)
p['value'] /= p['value'].sum()
return p
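As a usage sketch, treat the earlier tables \(f\) and \(g\) as the factors of a small chain over \(A, B, C\); the prior on \(A\) below is an arbitrary uniform table added purely for illustration:

prior_A = pd.DataFrame({"A": [0, 1], "value": [.5, .5]})  # assumed uniform prior, for illustration only

# query A given the evidence C = 1, eliminating B
posterior_A = variable_elimination(Phi=[prior_A, f, g], Q=["A"], E=[("C", 1)], R=["B"])
posterior_A  # a normalized factor over A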
VE Ordering and Message Passing
Consider a model \(T=(V,E)\) that is a tree, and let \(N(i)\) be the neighbors of vertex \(i\). Then the joint distribution is
\[
p(x_{1:n}) = \frac{1}{Z} \prod_{i\in V} \phi_i(x_i) \prod_{(i,j)\in E} \phi_{ij}(x_i, x_j),
\]
where the factors are initialized from the given conditional probabilities and \(Z\) is the normalizer.
Now, define the message passed from \(j\) to its neighbor \(i\) as
\[
m_{j\to i}(x_i) = \sum_{x_j} \phi_j(x_j)\,\phi_{ij}(x_i, x_j) \prod_{k\in N(j)\setminus\{i\}} m_{k\to j}(x_j).
\]
If \(x_j\) is observed with value \(\bar x_j\), then since we restrict \(x_j = \bar x_j\), the summation disappears and the message becomes
\[
m_{j\to i}(x_i) = \phi_j(\bar x_j)\,\phi_{ij}(x_i, \bar x_j) \prod_{k\in N(j)\setminus\{i\}} m_{k\to j}(\bar x_j).
\]
Once the message passing is complete, we can compute the beliefs
\[
b_i(x_i) \propto \phi_i(x_i) \prod_{j\in N(i)} m_{j\to i}(x_i),
\]
which, after normalization, give the desired marginals.
In a tree, a leaf has only its parent as a neighbor. Therefore, if we start message passing from the leaves and propagate towards the root, we can cache the numerical value of the message on each edge without ever recomputing an edge.
Thus, by the tree property, we have the following message passing algorithm (a small code sketch follows the list):
- Choose any vertex to be the root \(r\).
- Pass messages from all leaves to \(r\), and then pass messages from \(r\) back to the leaves.
- For each query variable, compute its belief and normalize it.
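A minimal numpy sketch of the two-pass schedule on a three-node chain \(x_1 - x_2 - x_3\), with \(x_2\) as the root; the potentials are arbitrary non-negative placeholders made up for illustration:

import numpy as np

k = 2
rng = np.random.default_rng(0)

# arbitrary non-negative unary and pairwise potentials (illustration only)
phi = [rng.random(k) for _ in range(3)]   # phi[i](x_i)
psi_12 = rng.random((k, k))               # psi_12[x1, x2]
psi_23 = rng.random((k, k))               # psi_23[x2, x3]

# upward pass: leaves -> root (x2)
m_1_to_2 = psi_12.T @ phi[0]              # sum_{x1} phi1(x1) psi12(x1, x2)
m_3_to_2 = psi_23 @ phi[2]                # sum_{x3} phi3(x3) psi23(x2, x3)

# downward pass: root -> leaves
m_2_to_1 = psi_12 @ (phi[1] * m_3_to_2)   # sum_{x2} phi2(x2) psi12(x1, x2) m_{3->2}(x2)
m_2_to_3 = psi_23.T @ (phi[1] * m_1_to_2) # sum_{x2} phi2(x2) psi23(x2, x3) m_{1->2}(x2)

# beliefs: unary potential times all incoming messages, then normalize
b1 = phi[0] * m_2_to_1; b1 /= b1.sum()
b2 = phi[1] * m_1_to_2 * m_3_to_2; b2 /= b2.sum()
b3 = phi[2] * m_2_to_3; b3 /= b3.sum()

# sanity check against the brute-force joint
joint = np.einsum('a,b,c,ab,bc->abc', phi[0], phi[1], phi[2], psi_12, psi_23)
joint /= joint.sum()
assert np.allclose(b2, joint.sum(axis=(0, 2)))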
Message Passing for VE
Note that the time complexity of VE is
\[
O\!\left(m\,k^{N_{\max}}\right),
\]
where \(m\) is the number of initial factors, \(k\) is the number of states of each random variable, and \(N_{\max}\) is the maximum number of random variables appearing inside any single summation. Thus, the elimination ordering is important for the running time of VE.
Determining the optimal ordering on an arbitrary graph is NP-hard. However, we do have optimal orderings on trees: any elimination ordering that goes from the leaves inwards towards any root is optimal.
If we have a DAGM that is a tree, we can eliminate variables directly from the leaves towards the query variables. In this case, we obtain the optimal runtime, and the computation performed by VE is exactly the message passing described above.