Conditioning with Turing Chains name_map #478

@ya0

Description

My goal is to use the Turing framework to calculate individual predictions (IPRED) for a specific set of parameters, for example, naively, the means of the posterior samples.
The idea was:

  1. Sample posterior and extract means for parameters.
  2. Instantiate a model with missing values for measurements, set error term in statistical model to very small number.
  3. Condition this model on the mean parameters (this is where the problem occurs). If you use a Turing chain's name_map to build the dict for a parameter defined with filldist, the conditioning does not work. You need to parse the values back into a single name.
  4. Use rand() to calculate the predictions.

The problem: Turing reports the names of a multivariate distribution individually (a[1], a[2], ...), but conditioning expects a single name with a list of values. This is somewhat counterintuitive: calling rand() on the model conditioned via the Turing name_map will still return (i.e. sample) the multidimensional parameter, under the EXACT names I used to condition.

Maybe I am on the wrong path, or conditioning could be extended to allow conditioning individual indices of multidimensional priors.
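The mismatch can be sketched with a toy model (the model `toy` is only an illustration; the behavior described is what I observe with the `|` conditioning syntax):

```julia
using Turing

@model function toy()
    a ~ filldist(Normal(), 2)
    x ~ Normal(sum(a), 1)
end

# Conditioning with the flattened names reported in the chain's name_map
# (Symbol("a[1]"), Symbol("a[2]")) does NOT fix `a`:
m1 = toy() | (var"a[1]" = 0.5, var"a[2]" = -0.5)
rand(m1)  # still contains a fresh draw of `a`

# Conditioning with the single name `:a` and a vector of values works:
m2 = toy() | (a = [0.5, -0.5],)
rand(m2)  # only `x` is sampled; `a` stays fixed at [0.5, -0.5]
```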

Example of a longitudinal mixed effects model

$$ \begin{align*} f(t,a,b) &= a \cdot t + b \\ y_{ij} &\sim N(f(t_{ij},a_{i},b), f(t_{ij},a_{i},b)\cdot \sigma ) \end{align*} $$

After sampling the posterior for $a=(a_1, \dots, a_n),b$ we want to use the pointwise means $\hat a=(\hat a_1, \dots, \hat a_n),\hat b$ to calculate the predictions $f(t_{ij},\hat a_{i},\hat b)$ by calling
$$\hat y_{ij} \sim N(f(t_{ij},\hat a_{i},\hat b), f(t_{ij},\hat a_{i},\hat b)\cdot \hat\sigma )$$
With the trick of setting $\hat\sigma$ to the smallest positive normal floating-point number (floatmin() ≈ 2.2e-308), $\hat y_{ij}$ will be extremely close to $f(t_{ij},\hat a_{i},\hat b)$.
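The trick can be checked in isolation (only Distributions is needed):

```julia
using Distributions

μ = 42.0
# With the scale set to floatmin(), the noise term underflows when it is
# added to the mean, so draws are numerically indistinguishable from μ:
rand(Normal(μ, floatmin())) ≈ μ  # true
```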

in code:

## conditioning example
## Longitudinal mixed effects model
using Turing, Distributions, StatsPlots

# linear model
f(t,a,b) = a*t + b

# simple mixed effects model
# a individual parameters
# b population parameter
b = 10
n = 100
a = rand(LogNormal(1,1),n)
timepoints = range(0,20,9)

# simulate data with multiplicative error
# for every individual at every timepoint
σ = 0.01
y_longitudinal = []
for (id, aᵢ) in enumerate(a)
    for tⱼ in timepoints
        fᵢⱼ = f(tⱼ,aᵢ,b)
        measurement = rand(Normal(fᵢⱼ, fᵢⱼ * σ))
        push!(y_longitudinal, id, tⱼ, measurement)
    end
end
y_longitudinal = reshape(y_longitudinal, 3, 9*100)'


# corresponding model
@model function linear_longitudinal(y_ids, y_times, y_measurements)
    n = length(unique(y_ids))
    m = length(y_measurements)

    a ~ filldist(LogNormal(),n)
    b ~ truncated(Normal(10,1),0.1,100)
    σ ~ Beta(1.1,5)    

    for row in 1:m
        pred = f(y_times[row], a[y_ids[row]], b)
        y_measurements[row] ~ Normal(pred, pred * σ)
    end
end

# ----------------------------------------------------------
# 1. Sample posterior and extract means for parameters, set error term in statistical model to very small number
# ----------------------------------------------------------
y_ids, y_times, y_measurements = eachcol(y_longitudinal)
model = linear_longitudinal(y_ids, y_times, y_measurements)
chn = sample(model, NUTS(), 500)
pointwise_means = summarize(chn)[:,2]  # column 2 holds the posterior means
pointwise_means[end] = floatmin()      # overwrite the σ estimate with the smallest positive float

# ----------------------------------------------------------
# 2. Instantiate a model with missing values for measurements
# ----------------------------------------------------------
pred_model = linear_longitudinal(y_ids, y_times, fill(missing,length(y_measurements)))

# ----------------------------------------------------------
# 3. Condition this model with the mean parameters
# first using the name_map.parameters from Turing
# ----------------------------------------------------------
names = chn.name_map.parameters

conditioning_1 = NamedTuple{Tuple(names)}(pointwise_means)
pred_model_conditioned_1 = pred_model | conditioning_1
# ----------------------------------------------------------
# 4. Use rand() to calculate the predictions
# ----------------------------------------------------------
pred_1 = rand(pred_model_conditioned_1)
pred_1_array = [v for v in pred_1]
# here we exclude the sample for a
pred_1_array = pred_1_array[2:end]
scatter(pred_1_array, y_measurements)
# notice how a is predicted even though we gave each a[i] a value


# ----------------------------------------------------------
# 3. Condition this model with the mean parameters
# parsing the parameters before conditioning
# ----------------------------------------------------------

# preprocessing and condition a on an array
pointwise_means_processed = [[pointwise_means[1:100]];pointwise_means[101:102]]
conditioning_2 = NamedTuple{Tuple([:a;names[101:102]])}(pointwise_means_processed)
pred_model_conditioned_2 = pred_model | conditioning_2
# ----------------------------------------------------------
# 4. Use rand() to calculate the predictions
# ----------------------------------------------------------
pred_2 = rand(pred_model_conditioned_2)
pred_2_array = [v for v in pred_2]
scatter(pred_2_array,y_measurements)
# here only the predictions of f conditioned on the means are sampled
