diff --git a/cpp/program/setup.cpp b/cpp/program/setup.cpp index 8aff21d9d..61df764f0 100644 --- a/cpp/program/setup.cpp +++ b/cpp/program/setup.cpp @@ -797,6 +797,11 @@ vector Setup::loadParams( else if(cfg.contains("humanSLOppExploreProbWeightful"+idxStr)) params.humanSLOppExploreProbWeightful = cfg.getDouble("humanSLOppExploreProbWeightful"+idxStr, 0.0, 1.0); else if(cfg.contains("humanSLOppExploreProbWeightful")) params.humanSLOppExploreProbWeightful = cfg.getDouble("humanSLOppExploreProbWeightful", 0.0, 1.0); else params.humanSLOppExploreProbWeightful = 0.0; + if(!hasHumanModel && cfg.contains("humanSLValueProportion"+idxStr)) throwHumanParsingError("humanSLValueProportion"+idxStr); + else if(!hasHumanModel && cfg.contains("humanSLValueProportion")) throwHumanParsingError("humanSLValueProportion"); + else if(cfg.contains("humanSLValueProportion"+idxStr)) params.humanSLValueProportion = cfg.getDouble("humanSLValueProportion"+idxStr, 0.0, 1.0); + else if(cfg.contains("humanSLValueProportion")) params.humanSLValueProportion = cfg.getDouble("humanSLValueProportion", 0.0, 1.0); + else params.humanSLValueProportion = 0.0; if(!hasHumanModel && cfg.contains("humanSLChosenMoveProp"+idxStr)) throwHumanParsingError("humanSLChosenMoveProp"+idxStr); else if(!hasHumanModel && cfg.contains("humanSLChosenMoveProp")) throwHumanParsingError("humanSLChosenMoveProp"); else if(cfg.contains("humanSLChosenMoveProp"+idxStr)) params.humanSLChosenMoveProp = cfg.getDouble("humanSLChosenMoveProp"+idxStr, 0.0, 1.0); diff --git a/cpp/search/search.cpp b/cpp/search/search.cpp index 58af0bec1..4c49bf76b 100644 --- a/cpp/search/search.cpp +++ b/cpp/search/search.cpp @@ -1068,9 +1068,14 @@ void Search::computeRootValues() { //Grab a neural net evaluation for the current position and use that as the center if(!foundExpectedScoreFromTree) { NNResultBuf nnResultBuf; + NNResultBuf humanResultBuf; bool includeOwnerMap = true; - computeRootNNEvaluation(nnResultBuf,includeOwnerMap); + bool includeHumanResult = humanEvaluator != NULL && searchParams.humanSLValueProportion > 0; + computeRootNNEvaluation(nnResultBuf,humanResultBuf,includeOwnerMap,includeHumanResult); expectedScore = nnResultBuf.result->whiteScoreMean; + if(includeHumanResult) { + expectedScore += searchParams.humanSLValueProportion * ((double)(humanResultBuf.result->whiteScoreMean) - expectedScore); + } } recentScoreCenter = expectedScore * (1.0 - searchParams.dynamicScoreCenterZeroWeight); diff --git a/cpp/search/search.h b/cpp/search/search.h index f327112ee..45de489d1 100644 --- a/cpp/search/search.h +++ b/cpp/search/search.h @@ -434,11 +434,9 @@ struct Search { // searchhelpers.cpp //---------------------------------------------------------------------------------------- double getResultUtility(double winlossValue, double noResultValue) const; - double getResultUtilityFromNN(const NNOutput& nnOutput) const; double getScoreUtility(double scoreMeanAvg, double scoreMeanSqAvg) const; double getScoreUtilityDiff(double scoreMeanAvg, double scoreMeanSqAvg, double delta) const; double getApproxScoreUtilityDerivative(double scoreMean) const; - double getUtilityFromNN(const NNOutput& nnOutput) const; //---------------------------------------------------------------------------------------- // Miscellaneous search biasing helpers, root move selection, etc. @@ -517,7 +515,7 @@ struct Search { // Neural net queries // searchnnhelpers.cpp //---------------------------------------------------------------------------------------- - void computeRootNNEvaluation(NNResultBuf& nnResultBuf, bool includeOwnerMap); + void computeRootNNEvaluation(NNResultBuf& nnResultBuf, NNResultBuf& humanResultBuf, bool includeOwnerMap, bool includeHumanResult); bool initNodeNNOutput( SearchThread& thread, SearchNode& node, bool isRoot, bool skipCache, bool isReInit @@ -610,6 +608,7 @@ struct Search { bool assumeNoExistingWeight ); void addCurrentNNOutputAsLeafValue(SearchNode& node, bool assumeNoExistingWeight); + double getThisNodeNNUtility(const SearchNode& node) const; double computeWeightFromNNOutput(const NNOutput* nnOutput) const; diff --git a/cpp/search/searchexplorehelpers.cpp b/cpp/search/searchexplorehelpers.cpp index 07a320886..a0149ad74 100644 --- a/cpp/search/searchexplorehelpers.cpp +++ b/cpp/search/searchexplorehelpers.cpp @@ -299,10 +299,10 @@ double Search::getFpuValueForChildrenAssumeVisited( double parentUtilityForFPU = parentUtility; if(searchParams.fpuParentWeightByVisitedPolicy) { double avgWeight = std::min(1.0, pow(policyProbMassVisited, searchParams.fpuParentWeightByVisitedPolicyPow)); - parentUtilityForFPU = avgWeight * parentUtility + (1.0 - avgWeight) * getUtilityFromNN(*(node.getNNOutput())); + parentUtilityForFPU = avgWeight * parentUtility + (1.0 - avgWeight) * getThisNodeNNUtility(node); } else if(searchParams.fpuParentWeight > 0.0) { - parentUtilityForFPU = searchParams.fpuParentWeight * getUtilityFromNN(*(node.getNNOutput())) + (1.0 - searchParams.fpuParentWeight) * parentUtility; + parentUtilityForFPU = searchParams.fpuParentWeight * getThisNodeNNUtility(node) + (1.0 - searchParams.fpuParentWeight) * parentUtility; } double fpuValue; diff --git a/cpp/search/searchhelpers.cpp b/cpp/search/searchhelpers.cpp index 060e9875a..401858ac7 100644 --- a/cpp/search/searchhelpers.cpp +++ b/cpp/search/searchhelpers.cpp @@ -261,13 +261,6 @@ double Search::getResultUtility(double winLossValue, double noResultValue) const ); } -double Search::getResultUtilityFromNN(const NNOutput& nnOutput) const { - return ( - (nnOutput.whiteWinProb - nnOutput.whiteLossProb) * searchParams.winLossUtilityFactor + - nnOutput.whiteNoResultProb * searchParams.noResultUtilityForWhite - ); -} - double Search::getScoreUtility(double scoreMeanAvg, double scoreMeanSqAvg) const { double scoreMean = scoreMeanAvg; double scoreMeanSq = scoreMeanSqAvg; @@ -301,12 +294,6 @@ double Search::getApproxScoreUtilityDerivative(double scoreMean) const { } -double Search::getUtilityFromNN(const NNOutput& nnOutput) const { - double resultUtility = getResultUtilityFromNN(nnOutput); - return resultUtility + getScoreUtility(nnOutput.whiteScoreMean, nnOutput.whiteScoreMeanSq); -} - - bool Search::isAllowedRootMove(Loc moveLoc) const { assert(moveLoc == Board::PASS_LOC || rootBoard.isOnBoard(moveLoc)); diff --git a/cpp/search/searchnnhelpers.cpp b/cpp/search/searchnnhelpers.cpp index 14033feea..723d8968c 100644 --- a/cpp/search/searchnnhelpers.cpp +++ b/cpp/search/searchnnhelpers.cpp @@ -6,7 +6,7 @@ #include "../core/using.h" //------------------------ -void Search::computeRootNNEvaluation(NNResultBuf& nnResultBuf, bool includeOwnerMap) { +void Search::computeRootNNEvaluation(NNResultBuf& nnResultBuf, NNResultBuf& humanResultBuf, bool includeOwnerMap, bool includeHumanResult) { Board board = rootBoard; const BoardHistory& hist = rootHistory; Player pla = rootPla; @@ -32,6 +32,15 @@ void Search::computeRootNNEvaluation(NNResultBuf& nnResultBuf, bool includeOwner nnInputParams, nnResultBuf, skipCache, includeOwnerMap ); + + if(includeHumanResult) { + assert(humanEvaluator != NULL); + humanEvaluator->evaluate( + board, hist, pla, &searchParams.humanSLProfile, + nnInputParams, + humanResultBuf, skipCache, includeOwnerMap + ); + } } bool Search::needsHumanOutputAtRoot() const { @@ -42,7 +51,8 @@ bool Search::needsHumanOutputInTree() const { searchParams.humanSLPlaExploreProbWeightless > 0 || searchParams.humanSLPlaExploreProbWeightful > 0 || searchParams.humanSLOppExploreProbWeightless > 0 || - searchParams.humanSLOppExploreProbWeightful > 0 + searchParams.humanSLOppExploreProbWeightful > 0 || + searchParams.humanSLValueProportion > 0 ); } diff --git a/cpp/search/searchparams.cpp b/cpp/search/searchparams.cpp index 18b532dfe..b24486f6c 100644 --- a/cpp/search/searchparams.cpp +++ b/cpp/search/searchparams.cpp @@ -112,6 +112,7 @@ SearchParams::SearchParams() humanSLPlaExploreProbWeightful(0.0), humanSLOppExploreProbWeightless(0.0), humanSLOppExploreProbWeightful(0.0), + humanSLValueProportion(0.0), humanSLChosenMoveProp(0.0), humanSLChosenMoveIgnorePass(false), humanSLChosenMovePiklLambda(1000000000.0) @@ -252,6 +253,7 @@ bool SearchParams::operator==(const SearchParams& other) const { humanSLOppExploreProbWeightless == other.humanSLOppExploreProbWeightless && humanSLOppExploreProbWeightful == other.humanSLOppExploreProbWeightful && + humanSLValueProportion == other.humanSLValueProportion && humanSLChosenMoveProp == other.humanSLChosenMoveProp && humanSLChosenMoveIgnorePass == other.humanSLChosenMoveIgnorePass && humanSLChosenMovePiklLambda == other.humanSLChosenMovePiklLambda @@ -499,6 +501,7 @@ json SearchParams::changeableParametersToJson() const { ret["humanSLOppExploreProbWeightless"] = humanSLOppExploreProbWeightless; ret["humanSLOppExploreProbWeightful"] = humanSLOppExploreProbWeightful; + ret["humanSLValueProportion"] = humanSLValueProportion; ret["humanSLChosenMoveProp"] = humanSLChosenMoveProp; ret["humanSLChosenMoveIgnorePass"] = humanSLChosenMoveIgnorePass; ret["humanSLChosenMovePiklLambda"] = humanSLChosenMovePiklLambda; @@ -650,6 +653,7 @@ void SearchParams::printParams(std::ostream& out) const { PRINTPARAM(humanSLPlaExploreProbWeightful); PRINTPARAM(humanSLOppExploreProbWeightless); PRINTPARAM(humanSLOppExploreProbWeightful); + PRINTPARAM(humanSLValueProportion); PRINTPARAM(humanSLChosenMoveProp); PRINTPARAM(humanSLChosenMoveIgnorePass); PRINTPARAM(humanSLChosenMovePiklLambda); diff --git a/cpp/search/searchparams.h b/cpp/search/searchparams.h index 7139f24ef..1f538203f 100644 --- a/cpp/search/searchparams.h +++ b/cpp/search/searchparams.h @@ -162,6 +162,9 @@ struct SearchParams { double humanSLOppExploreProbWeightless; double humanSLOppExploreProbWeightful; + //Mix in this amount of the humanSL value into the values at nodes + double humanSLValueProportion; + //These three are PRIOR to the normal chosenMoveTemperature. double humanSLChosenMoveProp; //Proportion of final move selection probability using human SL policy bool humanSLChosenMoveIgnorePass; //If true, ignore human SL pass probability and use KataGo's passing logic diff --git a/cpp/search/searchresults.cpp b/cpp/search/searchresults.cpp index 6bc368543..3649275f1 100644 --- a/cpp/search/searchresults.cpp +++ b/cpp/search/searchresults.cpp @@ -2285,6 +2285,16 @@ bool Search::getPrunedNodeValues(const SearchNode* nodePtr, ReportedSearchValues double scoreMean = (double)nnOutput->whiteScoreMean; double scoreMeanSq = (double)nnOutput->whiteScoreMeanSq; double lead = (double)nnOutput->whiteLead; + if(humanEvaluator != NULL && searchParams.humanSLValueProportion > 0) { + const NNOutput* humanOutput = node.getHumanOutput(); + assert(humanOutput != NULL); + winProb += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteWinProb) - winProb); + lossProb += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteLossProb) - lossProb); + noResultProb += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteNoResultProb) - noResultProb); + scoreMean += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteScoreMean) - scoreMean); + scoreMeanSq += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteScoreMeanSq) - scoreMeanSq); + lead += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteLead) - lead); + } double utility = getResultUtility(winProb-lossProb, noResultProb) + getScoreUtility(scoreMean, scoreMeanSq); diff --git a/cpp/search/searchupdatehelpers.cpp b/cpp/search/searchupdatehelpers.cpp index ff39b208c..9481de3da 100644 --- a/cpp/search/searchupdatehelpers.cpp +++ b/cpp/search/searchupdatehelpers.cpp @@ -81,6 +81,29 @@ void Search::addLeafValue( } } +double Search::getThisNodeNNUtility(const SearchNode& node) const { + const NNOutput* nnOutput = node.getNNOutput(); + assert(nnOutput != NULL); + double winProb = (double)nnOutput->whiteWinProb; + double lossProb = (double)nnOutput->whiteLossProb; + double noResultProb = (double)nnOutput->whiteNoResultProb; + double scoreMean = (double)nnOutput->whiteScoreMean; + double scoreMeanSq = (double)nnOutput->whiteScoreMeanSq; + if(humanEvaluator != NULL && searchParams.humanSLValueProportion > 0) { + const NNOutput* humanOutput = node.getHumanOutput(); + assert(humanOutput != NULL); + winProb += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteWinProb) - winProb); + lossProb += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteLossProb) - lossProb); + noResultProb += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteNoResultProb) - noResultProb); + scoreMean += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteScoreMean) - scoreMean); + scoreMeanSq += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteScoreMeanSq) - scoreMeanSq); + } + double utility = + getResultUtility(winProb-lossProb, noResultProb) + + getScoreUtility(scoreMean, scoreMeanSq); + return utility; +} + void Search::addCurrentNNOutputAsLeafValue(SearchNode& node, bool assumeNoExistingWeight) { const NNOutput* nnOutput = node.getNNOutput(); assert(nnOutput != NULL); @@ -92,6 +115,16 @@ void Search::addCurrentNNOutputAsLeafValue(SearchNode& node, bool assumeNoExisti double scoreMeanSq = (double)nnOutput->whiteScoreMeanSq; double lead = (double)nnOutput->whiteLead; double weight = computeWeightFromNNOutput(nnOutput); + if(humanEvaluator != NULL && searchParams.humanSLValueProportion > 0) { + const NNOutput* humanOutput = node.getHumanOutput(); + assert(humanOutput != NULL); + winProb += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteWinProb) - winProb); + lossProb += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteLossProb) - lossProb); + noResultProb += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteNoResultProb) - noResultProb); + scoreMean += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteScoreMean) - scoreMean); + scoreMeanSq += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteScoreMeanSq) - scoreMeanSq); + lead += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteLead) - lead); + } addLeafValue(node,winProb-lossProb,noResultProb,scoreMean,scoreMeanSq,lead,weight,false,assumeNoExistingWeight); } @@ -248,6 +281,17 @@ void Search::recomputeNodeStats(SearchNode& node, SearchThread& thread, int numV double scoreMean = (double)nnOutput->whiteScoreMean; double scoreMeanSq = (double)nnOutput->whiteScoreMeanSq; double lead = (double)nnOutput->whiteLead; + if(humanEvaluator != NULL && searchParams.humanSLValueProportion > 0) { + const NNOutput* humanOutput = node.getHumanOutput(); + assert(humanOutput != NULL); + winProb += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteWinProb) - winProb); + lossProb += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteLossProb) - lossProb); + noResultProb += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteNoResultProb) - noResultProb); + scoreMean += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteScoreMean) - scoreMean); + scoreMeanSq += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteScoreMeanSq) - scoreMeanSq); + lead += searchParams.humanSLValueProportion * ((double)(humanOutput->whiteLead) - lead); + } + double utility = getResultUtility(winProb-lossProb, noResultProb) + getScoreUtility(scoreMean, scoreMeanSq); diff --git a/cpp/tests/results/runOutputTests.txt b/cpp/tests/results/runOutputTests.txt index 0b89725b1..bce35cd1b 100644 --- a/cpp/tests/results/runOutputTests.txt +++ b/cpp/tests/results/runOutputTests.txt @@ -20666,6 +20666,7 @@ humanSLPlaExploreProbWeightless: 0 humanSLPlaExploreProbWeightful: 0 humanSLOppExploreProbWeightless: 0 humanSLOppExploreProbWeightful: 0 +humanSLValueProportion: 0 humanSLChosenMoveProp: 0 humanSLChosenMoveIgnorePass: 0 humanSLChosenMovePiklLambda: 1e+09 @@ -20773,6 +20774,7 @@ humanSLPlaExploreProbWeightless: 0 humanSLPlaExploreProbWeightful: 0 humanSLOppExploreProbWeightless: 0 humanSLOppExploreProbWeightful: 0 +humanSLValueProportion: 0 humanSLChosenMoveProp: 0 humanSLChosenMoveIgnorePass: 0 humanSLChosenMovePiklLambda: 1e+09 @@ -20880,6 +20882,7 @@ humanSLPlaExploreProbWeightless: 0 humanSLPlaExploreProbWeightful: 0 humanSLOppExploreProbWeightless: 0 humanSLOppExploreProbWeightful: 0 +humanSLValueProportion: 0 humanSLChosenMoveProp: 0 humanSLChosenMoveIgnorePass: 0 humanSLChosenMovePiklLambda: 1e+09 @@ -20987,6 +20990,7 @@ humanSLPlaExploreProbWeightless: 0 humanSLPlaExploreProbWeightful: 0 humanSLOppExploreProbWeightless: 0 humanSLOppExploreProbWeightful: 0 +humanSLValueProportion: 0 humanSLChosenMoveProp: 0 humanSLChosenMoveIgnorePass: 0 humanSLChosenMovePiklLambda: 1e+09 @@ -21094,6 +21098,7 @@ humanSLPlaExploreProbWeightless: 0 humanSLPlaExploreProbWeightful: 0 humanSLOppExploreProbWeightless: 0 humanSLOppExploreProbWeightful: 0 +humanSLValueProportion: 0 humanSLChosenMoveProp: 0 humanSLChosenMoveIgnorePass: 0 humanSLChosenMovePiklLambda: 1e+09 @@ -21201,6 +21206,7 @@ humanSLPlaExploreProbWeightless: 0 humanSLPlaExploreProbWeightful: 0 humanSLOppExploreProbWeightless: 0 humanSLOppExploreProbWeightful: 0 +humanSLValueProportion: 0 humanSLChosenMoveProp: 0 humanSLChosenMoveIgnorePass: 0 humanSLChosenMovePiklLambda: 1e+09 @@ -21308,6 +21314,7 @@ humanSLPlaExploreProbWeightless: 0 humanSLPlaExploreProbWeightful: 0 humanSLOppExploreProbWeightless: 0 humanSLOppExploreProbWeightful: 0 +humanSLValueProportion: 0 humanSLChosenMoveProp: 0 humanSLChosenMoveIgnorePass: 0 humanSLChosenMovePiklLambda: 1e+09 @@ -21415,6 +21422,7 @@ humanSLPlaExploreProbWeightless: 0 humanSLPlaExploreProbWeightful: 0 humanSLOppExploreProbWeightless: 0 humanSLOppExploreProbWeightful: 0 +humanSLValueProportion: 0 humanSLChosenMoveProp: 0 humanSLChosenMoveIgnorePass: 0 humanSLChosenMovePiklLambda: 1e+09