import java.util.LinkedList; import java.util.List; public class MCTSNode { //... private List sons; public double getUCB(double C, int totalParents){ if (totalVisits == 0) return Double.MAX_VALUE; if (state.getCurrentPlayer() == 1) return totalScore / totalVisits + C * Math.sqrt(Math.log(totalParents) / totalVisits); return -totalScore / totalVisits + C * Math.sqrt(Math.log(totalParents) / totalVisits); } public boolean isLeaf(){ return sons == null; } public List getSons() { return sons; } public MCTSNode selectSon(double C){ double bestUCB = -Double.MAX_VALUE; MCTSNode bestState = null; for (MCTSNode node: sons){ double ucb = node.getUCB(C, totalVisits); if (ucb > bestUCB){ bestUCB = ucb; bestState = node; } } return bestState; } public MCTSNode selectLeaf(double C){ if (this.isLeaf()) return this; return selectSon(C).selectLeaf(C); } }