Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 15 additions & 29 deletions roofit/hs3/src/JSONFactories_RooFitCore.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -543,12 +543,6 @@ class RooAddPdfStreamer : public RooFit::JSONIO::Exporter {
{
const RooArg_t *pdf = static_cast<const RooArg_t *>(func);
elem["type"] << key();
std::string name = elem["name"].val();
/*elem["name"] << RooJSONFactoryWSTool::sanitizeName(name);
RooJSONFactoryWSTool::fillSeqSanitizedName(elem["summands"], pdf->pdfList());
RooJSONFactoryWSTool::fillSeqSanitizedName(elem["coefficients"], pdf->coefList());
*/
elem["name"] << name;
RooJSONFactoryWSTool::fillSeq(elem["summands"], pdf->pdfList());
RooJSONFactoryWSTool::fillSeq(elem["coefficients"], pdf->coefList());
elem["extended"] << (pdf->extendMode() != RooArg_t::CanNotBeExtended);
Expand All @@ -563,12 +557,6 @@ class RooRealSumPdfStreamer : public RooFit::JSONIO::Exporter {
{
const RooRealSumPdf *pdf = static_cast<const RooRealSumPdf *>(func);
elem["type"] << key();
std::string name = elem["name"].val();
/*elem["name"] << RooJSONFactoryWSTool::sanitizeName(name);
RooJSONFactoryWSTool::fillSeqSanitizedName(elem["samples"], pdf->funcList());
RooJSONFactoryWSTool::fillSeqSanitizedName(elem["coefficients"], pdf->coefList());
*/
elem["name"] << name;
RooJSONFactoryWSTool::fillSeq(elem["samples"], pdf->funcList());
RooJSONFactoryWSTool::fillSeq(elem["coefficients"], pdf->coefList());
elem["extended"] << (pdf->extendMode() != RooAbsPdf::CanNotBeExtended);
Expand All @@ -583,12 +571,6 @@ class RooRealSumFuncStreamer : public RooFit::JSONIO::Exporter {
{
const RooRealSumFunc *pdf = static_cast<const RooRealSumFunc *>(func);
elem["type"] << key();
std::string name = elem["name"].val();
/*elem["name"] << RooJSONFactoryWSTool::sanitizeName(name);
RooJSONFactoryWSTool::fillSeqSanitizedName(elem["samples"], pdf->funcList());
RooJSONFactoryWSTool::fillSeqSanitizedName(elem["coefficients"], pdf->coefList());
*/
elem["name"] << name;
RooJSONFactoryWSTool::fillSeq(elem["samples"], pdf->funcList());
RooJSONFactoryWSTool::fillSeq(elem["coefficients"], pdf->coefList());
return true;
Expand Down Expand Up @@ -687,6 +669,7 @@ class RooFormulaArgStreamer : public RooFit::JSONIO::Exporter {
const RooArg_t *pdf = static_cast<const RooArg_t *>(func);
elem["type"] << key();
TString expression(pdf->expression());
cleanExpression(expression);
// If the tokens follow the "x[#]" convention, the square braces enclosing each number
// ensures that there is a unique mapping between the token and parameter name
// If the tokens follow the "@#" convention, the numbers are not enclosed by braces.
Expand All @@ -701,6 +684,19 @@ class RooFormulaArgStreamer : public RooFit::JSONIO::Exporter {
elem["expression"] << expression.Data();
return true;
}

private:
void cleanExpression(TString &expr) const
{
expr.ReplaceAll("TMath::Exp", "exp");
expr.ReplaceAll("TMath::Min", "min");
expr.ReplaceAll("TMath::Max", "max");
expr.ReplaceAll("TMath::Log", "log");
expr.ReplaceAll("TMath::Cos", "cos");
expr.ReplaceAll("TMath::Sin", "sin");
expr.ReplaceAll("TMath::Sqrt", "sqrt");
expr.ReplaceAll("TMath::Power", "pow");
}
};
template <class RooArg_t>
class RooPolynomialStreamer : public RooFit::JSONIO::Exporter {
Expand Down Expand Up @@ -784,9 +780,6 @@ class RooTruthModelStreamer : public RooFit::JSONIO::Exporter {
{
auto *pdf = static_cast<const RooTruthModel *>(func);
elem["type"] << key();
std::string name = elem["name"].val();
// elem["name"] << RooJSONFactoryWSTool::sanitizeName(name);
elem["name"] << name;
elem["x"] << pdf->convVar().GetName();

return true;
Expand All @@ -800,9 +793,6 @@ class RooGaussModelStreamer : public RooFit::JSONIO::Exporter {
{
auto *pdf = static_cast<const RooGaussModel *>(func);
elem["type"] << key();
std::string name = elem["name"].val();
// elem["name"] << RooJSONFactoryWSTool::sanitizeName(name);
elem["name"] << name;
elem["x"] << pdf->convVar().GetName();
elem["mean"] << pdf->getMean().GetName();
elem["sigma"] << pdf->getSigma().GetName();
Expand Down Expand Up @@ -913,10 +903,6 @@ class RooRealIntegralStreamer : public RooFit::JSONIO::Exporter {
bool exportObject(RooJSONFactoryWSTool *, const RooAbsArg *func, JSONNode &elem) const override
{
auto *integral = static_cast<const RooRealIntegral *>(func);
std::string name = elem["name"].val();
// elem["name"] << RooJSONFactoryWSTool::sanitizeName(name);
elem["name"] << name;

elem["type"] << key();
std::string integrand = integral->integrand().GetName();
// elem["integrand"] << RooJSONFactoryWSTool::sanitizeName(integrand);
Expand Down Expand Up @@ -1060,7 +1046,7 @@ STATIC_EXECUTE([]() {
registerExporter<RooRealIntegralStreamer>(RooRealIntegral::Class(), false);
registerExporter<RooDerivativeStreamer>(RooDerivative::Class(), false);
registerExporter<RooFFTConvPdfStreamer>(RooFFTConvPdf::Class(), false);
registerExporter<RooExtendPdfStreamer>(RooExtendPdf::Class(), false);
registerExporter<RooExtendPdfStreamer>(RooExtendPdf::Class(), false);
});

} // namespace
27 changes: 16 additions & 11 deletions roofit/hs3/src/RooJSONFactoryWSTool.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -483,18 +483,13 @@ void exportAttributes(const RooAbsArg *arg, JSONNode &rootnode)
*
* @param ws The RooWorkspace in which the observables will be created.
* @param node The JSONNode containing information about the observables to be created.
* @param out The RooArgSet to which the created observables will be added.
* @param out The RooAbsCollection to which the created observables will be added.
* @return void
*/
void getObservables(RooWorkspace const &ws, const JSONNode &node, RooArgSet &out)
void getObservables(RooWorkspace const &ws, const JSONNode &node, RooAbsCollection &out)
{
std::map<std::string, Var> vars;
for (const auto &p : node["axes"].children()) {
vars.emplace(RooJSONFactoryWSTool::name(p), Var(p));
}

for (auto v : vars) {
std::string name(v.first);
std::string name(RooJSONFactoryWSTool::name(p));
if (ws.var(name)) {
out.add(*ws.var(name));
} else {
Expand Down Expand Up @@ -528,9 +523,9 @@ std::unique_ptr<RooAbsData> loadData(const JSONNode &p, RooWorkspace &workspace)
return RooJSONFactoryWSTool::readBinnedData(p, name, RooJSONFactoryWSTool::readAxes(p));
} else if (type == "unbinned") {
// unbinned
RooArgSet vars;
getObservables(workspace, p, vars);
RooArgList varlist(vars);
RooArgList varlist;
getObservables(workspace, p, varlist);
RooArgSet vars(varlist);
auto data = std::make_unique<RooDataSet>(name, name, vars, RooFit::WeightVar());
auto &coords = p["entries"];
if (!coords.is_seq()) {
Expand Down Expand Up @@ -2508,6 +2503,10 @@ RooWorkspace RooJSONFactoryWSTool::cleanWS(const RooWorkspace &ws, bool onlyMode
tmpWS.import(*obj);
}

for (auto *obj : ws.allResolutionModels()) {
tmpWS.import(*obj);
}

/*
if (auto* mc = dynamic_cast<RooStats::ModelConfig*>(obj)) {
// Import the PDF
Expand Down Expand Up @@ -2583,6 +2582,12 @@ RooWorkspace RooJSONFactoryWSTool::sanitizeWS(const RooWorkspace &ws)
}
}

// Resolution Models
for (auto *obj : tmpWS.allResolutionModels()) {
if (!isValidName(obj->GetName())) {
obj->SetName(sanitizeName(obj->GetName()).c_str());
}
}
// Datasets
for (auto *data : tmpWS.allData()) {
// Sanitize dataset name
Expand Down