@@ -17,60 +17,118 @@ module internal Utils =
1717
1818
1919[<AbstractClass>]
20+ /// <summary>TBD</summary>
2021type Distribution < 'T >() =
22+
23+ /// <summary>TBD</summary>
2124 abstract member sample: unit -> 'T
25+
26+ /// <summary>TBD</summary>
2227 abstract member logprob: 'T -> Tensor
2328
2429
2530[<AbstractClass>]
31+ /// <summary>TBD</summary>
2632type TensorDistribution () =
2733 inherit Distribution< Tensor>()
34+
35+ /// <summary>TBD</summary>
2836 member d.sample ( numSamples : int ) = Array.init numSamples ( fun _ -> d.sample()) |> dsharp.stack
37+
38+ /// <summary>TBD</summary>
2939 abstract member batchShape: int []
40+
41+ /// <summary>TBD</summary>
3042 abstract member eventShape: int []
43+
44+ /// <summary>TBD</summary>
3145 abstract member mean: Tensor
46+
47+ /// <summary>TBD</summary>
3248 abstract member stddev: Tensor
49+
50+ /// <summary>TBD</summary>
3351 abstract member variance: Tensor
52+
3453 default d.stddev = d.variance.sqrt()
3554 default d.variance = d.stddev * d.stddev
55+
56+ /// <summary>TBD</summary>
3657 member d.prob ( value ) = d.logprob( value) .exp()
3758
3859
60+ /// <summary>TBD</summary>
3961type Normal ( mean : Tensor , stddev : Tensor ) =
4062 inherit TensorDistribution()
4163 do if mean.shape <> stddev.shape then failwithf " Expecting mean and standard deviation with same shape, received %A , %A " mean.shape stddev.shape
4264 do if mean.dim > 1 then failwithf " Expecting scalar parameters (0-dimensional mean and stddev) or a batch of scalar parameters (1-dimensional mean and stddev)"
65+
66+ /// <summary>TBD</summary>
4367 override d.batchShape = d.mean.shape
68+
69+ /// <summary>TBD</summary>
4470 override d.eventShape = [||]
71+
72+ /// <summary>TBD</summary>
4573 override d.mean = mean
74+
75+ /// <summary>TBD</summary>
4676 override d.stddev = stddev
77+
78+ /// <summary>TBD</summary>
4779 override d.sample () = d.mean + dsharp.randnLike( d.mean) * d.stddev
80+
81+ /// <summary>TBD</summary>
4882 override d.logprob ( value ) =
4983 if value.shape <> d.batchShape then failwithf " Expecting a value with shape %A , received %A " d.batchShape value.shape
5084 let v = value - d.mean in -( v * v) / ( 2. * d.variance) - ( log d.stddev) - logSqrt2Pi
85+
86+ /// <summary>TBD</summary>
5187 override d.ToString () = sprintf " Normal(mean:%A , stddev:%A )" d.mean d.stddev
5288
5389
90+ /// <summary>TBD</summary>
5491type Uniform ( low : Tensor , high : Tensor ) =
5592 inherit TensorDistribution()
5693 do if low.shape <> high.shape then failwithf " Expecting low and high with same shape, received %A , %A " low.shape high.shape
5794 do if low.dim > 1 then failwithf " Expecting scalar parameters (0-dimensional low and high) or a batch of scalar parameters (1-dimensional low and high)"
95+
96+ /// <summary>TBD</summary>
5897 member d.low = low
98+
99+ /// <summary>TBD</summary>
59100 member d.high = high
101+
102+ /// <summary>TBD</summary>
60103 member d.range = high - low
104+
105+ /// <summary>TBD</summary>
61106 override d.batchShape = low.shape
107+
108+ /// <summary>TBD</summary>
62109 override d.eventShape = [||]
110+
111+ /// <summary>TBD</summary>
63112 override d.mean = ( low + high) / 2.
113+
114+ /// <summary>TBD</summary>
64115 override d.variance = d.range * d.range / 12.
116+
117+ /// <summary>TBD</summary>
65118 override d.sample () = d.low + dsharp.randLike( d.low) * d.range
119+
120+ /// <summary>TBD</summary>
66121 override d.logprob ( value ) =
67122 if value.shape <> d.batchShape then failwithf " Expecting a value with shape %A , received %A " d.batchShape value.shape
68123 let lb = low.le( value) .cast( low.dtype)
69124 let ub = high.gt( value) .cast( high.dtype)
70125 log ( lb * ub) - log d.range
126+
127+ /// <summary>TBD</summary>
71128 override d.ToString () = sprintf " Uniform(low:%A , high:%A )" d.low d.high
72129
73130
131+ /// <summary>TBD</summary>
74132type Bernoulli (? probs : Tensor , ? logits : Tensor ) =
75133 inherit TensorDistribution()
76134 let _probs , _logits , _dtype =
@@ -79,20 +137,39 @@ type Bernoulli(?probs:Tensor, ?logits:Tensor) =
79137 | Some p, None -> let pp = p.float() in clampProbs pp, probsToLogits pp true , p.dtype // Do not normalize probs
80138 | None, Some lp -> let lpp = lp.float() in logitsToProbs lpp true , lpp, lp.dtype // Do not normalize logits
81139 | None, None -> failwithf " Expecting either probs or logits"
140+
141+ /// <summary>TBD</summary>
82142 member d.probs = _ probs.cast(_ dtype)
143+
144+ /// <summary>TBD</summary>
83145 member d.logits = _ logits.cast(_ dtype)
146+
147+ /// <summary>TBD</summary>
84148 override d.batchShape = d.probs.shape
149+
150+ /// <summary>TBD</summary>
85151 override d.eventShape = [||]
152+
153+ /// <summary>TBD</summary>
86154 override d.mean = d.probs
155+
156+ /// <summary>TBD</summary>
87157 override d.variance = (_ probs * ( 1. - _ probs)) .cast(_ dtype)
158+
159+ /// <summary>TBD</summary>
88160 override d.sample () = dsharp.bernoulli(_ probs) .cast(_ dtype)
161+
162+ /// <summary>TBD</summary>
89163 override d.logprob ( value ) =
90164 if value.shape <> d.batchShape then failwithf " Expecting a value with shape %A , received %A " d.batchShape value.shape
91165 let lp = (_ probs ** value) * (( 1. - _ probs) ** ( 1. - value)) // Correct but numerical stability can be improved
92166 lp.log() .cast(_ dtype)
167+
168+ /// <summary>TBD</summary>
93169 override d.ToString () = sprintf " Bernoulli(probs:%A )" d.probs
94170
95171
172+ /// <summary>TBD</summary>
96173type Categorical (? probs : Tensor , ? logits : Tensor ) =
97174 inherit TensorDistribution()
98175 let _probs , _logits , _dtype =
@@ -102,13 +179,29 @@ type Categorical(?probs:Tensor, ?logits:Tensor) =
102179 | None, Some lp -> let lpp = ( lp - lp.logsumexp(- 1 , keepDim= true )) .float() in logitsToProbs lpp false , lpp, lp.dtype // Normalize logits
103180 | None, None -> failwithf " Expecting either probs or logits"
104181 do if _ probs.dim < 1 || _ probs.dim > 2 then failwithf " Expecting vector parameters (1-dimensional probs or logits) or batch of vector parameters (2-dimensional probs or logits), received shape %A " _ probs.shape
182+
183+ /// <summary>TBD</summary>
105184 member d.probs = _ probs.cast(_ dtype)
185+
186+ /// <summary>TBD</summary>
106187 member d.logits = _ logits.cast(_ dtype)
188+
189+ /// <summary>TBD</summary>
107190 override d.batchShape = if d.probs.dim = 1 then [||] else [| d.probs.shape.[ 0 ]|]
191+
192+ /// <summary>TBD</summary>
108193 override d.eventShape = [||]
194+
195+ /// <summary>TBD</summary>
109196 override d.mean = dsharp.onesLike( d.probs) * System.Double.NaN
197+
198+ /// <summary>TBD</summary>
110199 override d.stddev = dsharp.onesLike( d.probs) * System.Double.NaN
200+
201+ /// <summary>TBD</summary>
111202 override d.sample () = dsharp.multinomial(_ probs, 1 ) .cast(_ dtype) .squeeze()
203+
204+ /// <summary>TBD</summary>
112205 override d.logprob ( value ) =
113206 if value.shape <> d.batchShape then failwithf " Expecting a value with shape %A , received %A " d.batchShape value.shape
114207 if d.batchShape.Length = 0 then
@@ -118,9 +211,12 @@ type Categorical(?probs:Tensor, ?logits:Tensor) =
118211 let is = value.int() .toArray() :?> int[]
119212 let lp = Array.init d.batchShape.[ 0 ] ( fun i -> _ logits.[ i, is.[ i]]) |> dsharp.stack
120213 lp.cast(_ dtype)
214+
215+ /// <summary>TBD</summary>
121216 override d.ToString () = sprintf " Categorical(probs:%A )" d.probs
122217
123218
219+ /// <summary>TBD</summary>
124220type Empirical < 'T when 'T:equality >( values :seq < 'T >, ? weights : Tensor , ? logWeights : Tensor , ? combineDuplicates : bool ) =
125221 inherit Distribution< 'T>()
126222 let _categorical , _weighted =
@@ -156,25 +252,49 @@ type Empirical<'T when 'T:equality>(values:seq<'T>, ?weights:Tensor, ?logWeights
156252 _ values <- newValues
157253 _ categorical <- Categorical( logits= newLogWeights)
158254 _ weighted <- not ( Seq.allEqual (_ categorical.probs.unstack()))
255+
256+ /// <summary>TBD</summary>
159257 member d.values = _ values
258+
259+ /// <summary>TBD</summary>
160260 member d.valuesTensor = _ valuesTensor.Force()
261+
262+ /// <summary>TBD</summary>
161263 member d.length = d.values.Length
264+
265+ /// <summary>TBD</summary>
162266 member d.weights = _ categorical.probs
267+
268+ /// <summary>TBD</summary>
163269 member d.logWeights = _ categorical.logits
270+
271+ /// <summary>TBD</summary>
164272 member d.isWeighted = _ weighted
273+
274+ /// <summary>TBD</summary>
165275 member d.Item
166276 with get( i ) = d.values.[ i], d.weights.[ i]
277+
278+ /// <summary>TBD</summary>
167279 member d.GetSlice ( start , finish ) =
168280 let start = defaultArg start 0
169281 let finish = defaultArg finish d.length - 1
170282 Empirical( d.values.[ start.. finish], logWeights= d.logWeights.[ start.. finish])
283+
284+ /// <summary>TBD</summary>
171285 member d.unweighted () = Empirical( d.values)
286+
287+ /// <summary>TBD</summary>
172288 member d.map ( f : 'T -> 'a ) = Empirical( Array.map f d.values, logWeights= d.logWeights)
289+
290+ /// <summary>TBD</summary>
173291 member d.filter ( predicate : 'T -> bool ) =
174292 let results = ResizeArray< 'T* Tensor>()
175293 Array.iteri ( fun i v -> if predicate v then results.Add( v, d.logWeights.[ i])) d.values
176294 let v , lw = results.ToArray() |> Array.unzip
177295 Empirical( v, logWeights= dsharp.stack( lw))
296+
297+ /// <summary>TBD</summary>
178298 member d.sample (? minIndex : int , ? maxIndex : int ) = // minIndex is inclusive, maxIndex is exclusive
179299 let minIndex = defaultArg minIndex 0
180300 let maxIndex = defaultArg maxIndex d.length
@@ -185,7 +305,11 @@ type Empirical<'T when 'T:equality>(values:seq<'T>, ?weights:Tensor, ?logWeights
185305 d.[ minIndex.. maxIndex]. sample()
186306 else
187307 let i = _ categorical.sample() |> int in d.values.[ i]
308+
309+ /// <summary>TBD</summary>
188310 member d.resample ( numSamples , ? minIndex : int , ? maxIndex : int ) = Array.init numSamples ( fun _ -> d.sample( ?minIndex= minIndex, ?maxIndex= maxIndex)) |> Empirical
311+
312+ /// <summary>TBD</summary>
189313 member d.thin ( numSamples , ? minIndex : int , ? maxIndex : int ) =
190314 if d.isWeighted then failwithf " Cannot thin weighted distribution. Consider transforming Empirical to an unweigted Empirical by resampling."
191315 let minIndex = defaultArg minIndex 0
@@ -196,23 +320,50 @@ type Empirical<'T when 'T:equality>(values:seq<'T>, ?weights:Tensor, ?logWeights
196320 let v = d.values.[ i]
197321 results.Add( v)
198322 Empirical( results.ToArray())
323+
324+ /// <summary>TBD</summary>
199325 member d.combineDuplicates () = Empirical( d.values, logWeights= d.logWeights, combineDuplicates= true )
326+
327+ /// <summary>TBD</summary>
200328 member d.expectation ( f : Tensor -> Tensor ) =
201329 if d.isWeighted then d.valuesTensor.unstack() |> Seq.mapi ( fun i v -> d.weights.[ i]*( f v)) |> dsharp.stack |> dsharp.sum( 0 )
202330 else d.valuesTensor.unstack() |> Seq.map f |> dsharp.stack |> dsharp.mean( 0 )
331+
332+ /// <summary>TBD</summary>
203333 member d.mean = d.expectation( id)
334+
335+ /// <summary>TBD</summary>
204336 member d.variance = let mean = d.mean in d.expectation( fun x -> ( x- mean)** 2 )
337+
338+ /// <summary>TBD</summary>
205339 member d.stddev = dsharp.sqrt( d.variance)
340+
341+ /// <summary>TBD</summary>
206342 member d.mode =
207343 if d.isWeighted then
208344 let dCombined = d.combineDuplicates()
209345 let i = dCombined.logWeights.argmax() in dCombined.values.[ i.[ 0 ]]
210346 else
211347 let vals , _ = d.values |> Array.getUniqueCounts true
212348 vals.[ 0 ]
349+
350+ /// <summary>TBD</summary>
213351 member d.min = d.valuesTensor.min()
352+
353+ /// <summary>TBD</summary>
214354 member d.max = d.valuesTensor.max()
355+
356+ /// <summary>TBD</summary>
215357 member d.effectiveSampleSize = 1. / d.weights.pow( 2 ) .sum()
358+
359+ /// <summary>TBD</summary>
216360 override d.sample () = d.sample( minIndex= 0 , maxIndex= d.length)
361+
362+ /// <summary>TBD</summary>
217363 override d.logprob ( _ ) = failwith " Not supported" // TODO: can be implemented using density estimation
218- override d.ToString () = sprintf " Empirical(length:%A )" d.length
364+
365+ /// <summary>TBD</summary>
366+
367+ /// <summary>TBD</summary>
368+ override d.ToString () = sprintf " Empirical(length:%A )" d.length
369+
0 commit comments