package tsp.pacman;

import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;

/**
 * A node of a Monte Carlo Tree Search tree over {@link PacmanState}s.
 *
 * <p>Each node holds the game state it represents, the sum of all simulation scores that were
 * back-propagated through it, and how many such simulations there were. Player 0 is Pac-Man;
 * any other player {@code p} is a ghost located at {@code state.findGhosts()[p - 1]}.
 */
public class MCTSNode {
    private final PacmanState state;
    /** Sum of every score passed to {@link #addScore(double)} on this node or a descendant. */
    private double totalScore;
    /** Number of scores back-propagated through this node. */
    private int totalVisits;
    /** Child nodes; {@code null} until {@link #generateSons()} has been called. */
    private List<MCTSNode> sons;
    /** Parent node, or {@code null} for the root (or after {@link #deleteParent()}). */
    private MCTSNode parent;

    /** Creates a parentless (root) node for {@code state} with no statistics yet. */
    public MCTSNode(PacmanState state) {
        this.state = state;
        this.totalScore = 0;
        this.totalVisits = 0;
    }

    /** Creates a child node for {@code state} attached to {@code parent}. */
    public MCTSNode(PacmanState state, MCTSNode parent) {
        this(state);
        this.parent = parent;
    }

    /** Detaches this node from its parent, so later back-propagation stops at this node. */
    public void deleteParent() {
        parent = null;
    }

    /**
     * Records one simulation result on this node and on every ancestor up to the root.
     *
     * @param score the simulation score to add
     */
    public void addScore(double score) {
        // Iterative walk: equivalent to recursing into the parent, but cannot overflow the stack.
        for (MCTSNode node = this; node != null; node = node.parent) {
            node.totalScore += score;
            node.totalVisits += 1;
        }
    }

    /**
     * Computes the UCB1 value used to choose between siblings.
     *
     * @param C            exploration constant weighting the confidence bonus
     * @param totalParents visit count of this node's parent
     * @return {@link Double#MAX_VALUE} for an unvisited node (so it is tried first); otherwise the
     *     signed mean score plus {@code C * sqrt(ln(totalParents) / totalVisits)}
     */
    public double getUCB(double C, int totalParents) {
        if (totalVisits == 0) {
            return Double.MAX_VALUE;
        }
        double mean = totalScore / totalVisits;
        // The mean is used as-is when this state's current player is 1 and negated otherwise.
        // NOTE(review): this presumably means "player 1 to move => Pac-Man just moved (maximizes),
        // otherwise a ghost just moved (minimizes)", assuming turns cycle 0, 1, ..., k, 0 — confirm.
        double exploitation = state.getCurrentPlayer() == 1 ? mean : -mean;
        // ln(0) is -Infinity and would make the bonus NaN (e.g. for the root in printTree);
        // clamp so a parent with no visits yields no exploration bonus instead.
        double exploration = C * Math.sqrt(Math.log(Math.max(totalParents, 1)) / totalVisits);
        return exploitation + exploration;
    }

    /** Returns {@code true} while {@link #generateSons()} has not been called on this node. */
    public boolean isLeaf() {
        return sons == null;
    }

    /** Returns the (live, mutable) list of children, or {@code null} if not yet generated. */
    public List<MCTSNode> getSons() {
        return sons;
    }

    /** Returns the game state represented by this node. */
    public PacmanState getState() {
        return state;
    }

    /** Expands this node with one child per legal move of the player whose turn it is. */
    public void generateSons() {
        int currentPlayer = state.getCurrentPlayer();
        sons = currentPlayer == 0
                ? getSonsPacman()
                : getSonsGhosts(state.findGhosts()[currentPlayer - 1]);
    }

    /** Builds one child per legal Pac-Man move; never returns an empty list. */
    private List<MCTSNode> getSonsPacman() {
        List<MCTSNode> children = new LinkedList<>();
        for (Position position : state.findPossibleMoves()) {
            children.add(new MCTSNode(state.move(position), this));
        }
        if (children.isEmpty()) {
            // No legal move: Pac-Man stays on its own square so the node still has a child.
            children.add(new MCTSNode(state.move(state.findPacman()), this));
        }
        return children;
    }

    /** Builds one child per legal move of the ghost at {@code positionGhost}; never empty. */
    private List<MCTSNode> getSonsGhosts(Position positionGhost) {
        List<MCTSNode> children = new LinkedList<>();
        for (Position position : state.findPossibleMoves(positionGhost)) {
            children.add(new MCTSNode(state.moveGhost(positionGhost, position), this));
        }
        if (children.isEmpty()) {
            // No legal move: the ghost stays where it is so the node still has a child.
            children.add(new MCTSNode(state.moveGhost(positionGhost, positionGhost), this));
        }
        return children;
    }

    /**
     * Returns the child with the highest {@link #getUCB(double, int)} value (first one on ties).
     *
     * @param C exploration constant forwarded to {@code getUCB}
     * @return the selected child; {@code null} only if {@code sons} is empty
     */
    public MCTSNode selectSon(double C) {
        double bestUCB = Double.NEGATIVE_INFINITY;
        MCTSNode best = null;
        for (MCTSNode node : sons) {
            double ucb = node.getUCB(C, totalVisits);
            // `best == null` guarantees a non-null result even if every UCB compares false (NaN).
            if (best == null || ucb > bestUCB) {
                bestUCB = ucb;
                best = node;
            }
        }
        return best;
    }

    /**
     * Picks the child with the highest average score as the move to play, printing each
     * child's average to standard output.
     *
     * @return the best visited child, or {@code null} if no child has been visited
     *     (an unvisited child's average is NaN and never wins the comparison)
     */
    public MCTSNode selectSonFinal() {
        double bestAverage = -Double.MAX_VALUE;
        MCTSNode best = null;
        for (MCTSNode node : sons) {
            double finalScore = node.totalScore / node.totalVisits;
            System.out.println(finalScore);
            if (finalScore > bestAverage) {
                bestAverage = finalScore;
                best = node;
            }
        }
        return best;
    }

    /** Descends from this node via {@link #selectSon(double)} until reaching a leaf. */
    public MCTSNode selectLeaf(double C) {
        MCTSNode node = this;
        while (!node.isLeaf()) {
            node = node.selectSon(C);
        }
        return node;
    }

    /**
     * Plays uniformly random moves from this node's state for at most {@code maxDepth} plies (or
     * until a final state) and back-propagates the resulting score with {@link #addScore}.
     *
     * @param maxDepth maximum number of moves to simulate
     */
    public void runSimulation(int maxDepth) {
        PacmanState current = state;
        for (int remaining = maxDepth; remaining > 0 && !current.isFinalState(); remaining--) {
            int currentPlayer = current.getCurrentPlayer();
            if (currentPlayer == 0) {
                Position target = randomMove(current.findPossibleMoves(), current.findPacman());
                current = current.move(target);
            } else {
                Position ghost = current.findGhosts()[currentPlayer - 1];
                Position target = randomMove(current.findPossibleMoves(ghost), ghost);
                current = current.moveGhost(ghost, target);
            }
        }
        this.addScore(evaluate(current));
    }

    /** Returns a uniformly random element of {@code moves}, or {@code fallback} if it is empty. */
    private static Position randomMove(List<Position> moves, Position fallback) {
        if (moves.isEmpty()) {
            return fallback;
        }
        return moves.get(ThreadLocalRandom.current().nextInt(moves.size()));
    }

    /**
     * Scores the state a simulation ended in: minus the remaining food count for a non-final
     * state, twice that penalty for a lost game, and {@code getScore()} for any other final state.
     */
    private static int evaluate(PacmanState end) {
        int remainingFood = end.findFoods().size();
        if (!end.isFinalState()) {
            return -remainingFood;
        }
        return end.isLost() ? -remainingFood * 2 : end.getScore();
    }

    /**
     * Prints "visits totalScore ucb" for this node and its descendants, one line per node,
     * indented one space per level, stopping below depth 5.
     *
     * @param depth depth of this node, used for indentation and the cut-off
     */
    public void printTree(int depth) {
        if (depth > 5) {
            return;
        }
        for (int i = 0; i < depth; i++) {
            System.out.print(" ");
        }
        int parentVisits = this.parent != null ? this.parent.totalVisits : 0;
        System.out.println(totalVisits + " " + totalScore + " " + getUCB(1.0, parentVisits));
        if (!isLeaf()) {
            for (MCTSNode son : sons) {
                son.printTree(depth + 1);
            }
        }
    }

    /** Prints the tree rooted at this node, starting at depth 0. */
    public void printTree() {
        printTree(0);
    }
}