Hard Sudoku Solver Algorithm – Part 2
Introduction
In part 1, we solved some hard sudoku puzzles using the backtracking approach. While the algorithm did the job, it couldn't beat the time limit of the more advanced challenge: Hard Sudoku Solver 1.
This time, we are going to find a better solution so that we can pass the tests in the new challenge within the 10-second time limit.
It is recommended that you try the challenge on your own before reading the solution.
Problem description
From Codewars’ Hard Sudoku Solver 1:
There are several difficulty levels of sudoku games; we can estimate the difficulty of a sudoku game based on how many of its 81 cells are given.
- Easy sudoku generally have over 32 givens
- Medium sudoku have around 30–32 givens
- Hard sudoku have around 28–30 givens
- Very Hard sudoku have less than 28 givens
Note: The minimum number of givens required to create a unique (with no multiple solutions) sudoku game is 17.
A hard sudoku game means that at the start no cell will have a single candidate, and thus it requires guessing and trial and error. A very hard one will have several layers of multiple candidates for any empty cell.
Your Task
Write a function that solves sudoku puzzles of any difficulty. The function will take a sudoku grid and it should return a 9×9 array with the proper answer for the puzzle.
Or it should raise an error in cases of: invalid grid (not 9×9, cell with values not in the range 1~9); multiple solutions for the same puzzle or the puzzle is unsolvable.
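The grid-validation part of the task can be checked before any solving starts. Below is a minimal sketch. It assumes the grid is a list of lists and that 0 marks an empty cell, as in the sample puzzles. `validate_grid` is a hypothetical helper name, not part of the Codewars kata. The sketch only checks shape and value range. Duplicate givens still have to be caught when the values are placed on the board.

```python
def validate_grid(grid):
    """Return True if grid is a 9x9 list of lists whose entries are ints 0..9.

    Assumption: 0 marks an empty cell; every other entry must be an int in 1..9.
    Conflicting givens (e.g. two 5s in one row) are NOT detected here.
    """
    if not isinstance(grid, list) or len(grid) != 9:
        return False
    for row in grid:
        if not isinstance(row, list) or len(row) != 9:
            return False
        for v in row:
            # bool is a subclass of int in Python, so reject it explicitly
            if isinstance(v, bool) or not isinstance(v, int) or not 0 <= v <= 9:
                return False
    return True
```

A solver entry point could call this first and raise `ValueError` when it returns False.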
Analysis
Compared to the problem in part 1, this problem requires us to raise an error if the sudoku has more than one solution. Therefore, instead of returning the first solution we find, we have to keep testing other possibilities to make sure that there is no other solution. The good news is that we can raise the error as soon as we find a second solution.
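The "stop at the second solution" idea does not depend on how the search is done. Here is a minimal sketch that assumes a 9x9 list-of-lists board with 0 for empty cells and givens that do not conflict with each other. A plain backtracking counter takes a `limit` and returns as soon as that many solutions are known. `count_solutions` and `allowed` are hypothetical helper names used only for illustration.

```python
def allowed(board, r, c, x):
    """True if digit x does not already appear in the row, column, or box of (r, c)."""
    if any(board[r][i] == x or board[i][c] == x for i in range(9)):
        return False
    br, bc = r - r % 3, c - c % 3
    return all(board[i][j] != x
               for i in range(br, br + 3) for j in range(bc, bc + 3))


def count_solutions(board, limit=2):
    """Count completions of board (0 = empty), stopping once `limit` are found.

    The board is filled in place during the search and restored before returning.
    """
    # find the first empty cell in row-major order
    for k in range(81):
        r, c = divmod(k, 9)
        if board[r][c] == 0:
            break
    else:
        return 1  # no empty cell: the board itself is one complete solution
    count = 0
    for x in range(1, 10):
        if allowed(board, r, c, x):
            board[r][c] = x
            count += count_solutions(board, limit - count)
            board[r][c] = 0  # undo the trial value
            if count >= limit:
                break  # early exit: e.g. a second solution already proves non-uniqueness
    return count
```

With `limit=2`, the return value has three possible meanings. A result of 0 means the puzzle is unsolvable, 1 means the solution is unique, and 2 means "at least two". That is all the kata needs in order to decide whether to raise an error.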
The hardest test that we are going to solve is the puzzle below, which has only 21 given numbers:
```python
# Should solve very hard puzzle
puzzle = [[8, 0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 3, 6, 0, 0, 0, 0, 0],
          [0, 7, 0, 0, 9, 0, 2, 0, 0],
          [0, 5, 0, 0, 0, 7, 0, 0, 0],
          [0, 0, 0, 0, 4, 5, 7, 0, 0],
          [0, 0, 0, 1, 0, 0, 0, 3, 0],
          [0, 0, 1, 0, 0, 0, 0, 6, 8],
          [0, 0, 8, 5, 0, 0, 0, 1, 0],
          [0, 9, 0, 0, 0, 0, 4, 0, 0]]
solution = [[8, 1, 2, 7, 5, 3, 6, 4, 9],
            [9, 4, 3, 6, 8, 2, 1, 7, 5],
            [6, 7, 5, 4, 9, 1, 2, 8, 3],
            [1, 5, 4, 2, 3, 7, 8, 9, 6],
            [3, 6, 9, 8, 4, 5, 7, 2, 1],
            [2, 8, 7, 1, 6, 9, 5, 3, 4],
            [5, 2, 1, 9, 7, 4, 3, 6, 8],
            [4, 3, 8, 5, 2, 6, 9, 1, 7],
            [7, 9, 6, 3, 1, 8, 4, 5, 2]]
assert sudoku_solver(puzzle) == solution
```
The backtracking approach
The first thing that we want to try is of course reusing the solution from the last challenge, with some modifications to fit the new additional requirements.
Here is the sample code:
```python
import copy


def sudoku_solver(puzzle):
    sudoku = Sudoku(puzzle)
    if not sudoku.valid:
        raise ValueError
    sudoku.solve()
    if len(sudoku.solutions) != 1:
        raise ValueError
    return sudoku.solutions[0]


class Sudoku:
    def __init__(self, board):
        # the 9x9 sudoku board
        self.board = []
        # list of solutions to this sudoku
        self.solutions = []
        # is the board valid
        self.valid = True
        # init a blank board with all cells filled with zero
        for i in range(9):
            r = [0] * 9
            self.board.append(r)
        # check the input board dimensions; stop here if they are wrong,
        # otherwise the copy below could index outside the 9x9 board
        if not self.validate(board):
            self.valid = False
            return
        # copy the input board to self.board
        for r in range(9):
            for c in range(9):
                if board[r][c]:
                    if not self.set(r, c, board[r][c]):
                        self.valid = False

    # validate board dimensions
    def validate(self, board):
        if len(board) != 9:
            return False
        for r in range(len(board)):
            if len(board[r]) != 9:
                return False
        return True

    # the main function to solve the sudoku
    def solve(self):
        # recursively solve the sudoku, starting from cell 0
        self.solve_helper(0)
        return self.solutions

    # given the sudoku has been filled up to cell k-1, try to solve the sudoku from cell k.
    # k is the cell index, counting from left to right and top to bottom:
    # k is 0, 1, 2, ..., 8 for cell (0,0), (0,1), (0,2), ..., (0,8)
    # k is 9, 10, 11, ... for cell (1,0), (1,1), (1,2), ...
    # k is 80 for cell (8,8) (last cell)
    def solve_helper(self, k):
        # if we get past the last cell, every cell has been filled with a valid value:
        # record a copy of the board as a solution
        if k > 80:
            self.solutions.append(copy.deepcopy(self.board))
            return
        r = k // 9
        c = k % 9
        # if this cell has been filled, go on to solve the next cell
        if self.board[r][c] != 0:
            return self.solve_helper(k + 1)
        # try to fill each value from 1 to 9 in this cell
        for x in range(1, 10):
            # fill the cell with value x only if x has not appeared in the same row, column or 3x3 box
            if self.check(r, c, x):
                # start backtracking: try x in this cell,
                self.board[r][c] = x
                # then try to solve from the next cell k+1
                self.solve_helper(k + 1)
                if len(self.solutions) >= 2:
                    # the problem requires us to raise an error if there is more than one solution,
                    # so we can skip further processing once 2 solutions have been found
                    return
                # then clear the cell to restore the board to its status before x was filled
                self.board[r][c] = 0
        # if we are here, we have tried every value in this cell; any solution
        # found below this point has already been recorded in self.solutions

    # check if value x can be put at cell [row][col]:
    # return False if x is not a digit 1..9, or if x is already used
    # in the same row, column, or 3x3 box; return True otherwise
    def check(self, row, col, x):
        if not isinstance(x, int):
            return False
        if x < 1 or x > 9:
            return False
        for i in range(9):
            if self.board[row][i] == x:
                return False
            if self.board[i][col] == x:
                return False
        box_start_row = row - row % 3
        box_start_col = col - col % 3
        for r in range(box_start_row, box_start_row + 3):
            for c in range(box_start_col, box_start_col + 3):
                if self.board[r][c] == x:
                    return False
        return True

    # check if x can be put in cell [r][c]:
    # if yes, put x in cell [r][c] and return True; if no, return False
    def set(self, r, c, x):
        if not self.check(r, c, x):
            return False
        self.board[r][c] = x
        return True
```
Test results
This code takes 45 seconds on my computer to solve the very hard puzzle above. Obviously, it didn't pass Codewars' hidden tests within the time limit.
Algorithm Optimizations
In the backtracking approach, we pick the first vacant cell and try every value in it. For each value, we then recursively pick the next vacant cell and try every value in that cell. We do skip the values that would conflict with existing values in the same row, column, and box. Even so, the number of branches that the program must try is still overwhelming.
The whole idea of optimizing this algorithm is to reduce the number of branches we have to test as much as possible.
Algorithm Optimization #1: Try to infer the values for the vacant cells from the existing values
The values can be inferred in two ways:
- Rule 1: If cell [r][c] can only be filled with one value d (because all other values already exist in the same row, column, or box), we can fill d in [r][c].
- Rule 2: Because a value d must be filled exactly once in each unit (a unit means a row, a column, or a box), if all cells in a unit cannot be filled with d except for one cell, we can fill d in that cell.
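Here is a rough sketch of Rules 1 and 2 on a plain 9x9 board (0 = empty). It recomputes candidate sets from scratch on every call, so it favors clarity over speed. The helper names `candidates`, `all_units`, and `infer_once` are hypothetical. A placement that both rules find may appear twice in the result.

```python
def candidates(board, r, c):
    """Digits not yet used in the row, column, or 3x3 box of cell (r, c)."""
    used = set(board[r]) | {board[i][c] for i in range(9)}
    br, bc = r - r % 3, c - c % 3
    used |= {board[i][j] for i in range(br, br + 3) for j in range(bc, bc + 3)}
    return set(range(1, 10)) - used


def all_units():
    """The 27 units (9 rows, 9 columns, 9 boxes) as lists of (r, c) pairs."""
    rows = [[(r, c) for c in range(9)] for r in range(9)]
    cols = [[(r, c) for r in range(9)] for c in range(9)]
    boxes = [[(br + i, bc + j) for i in range(3) for j in range(3)]
             for br in (0, 3, 6) for bc in (0, 3, 6)]
    return rows + cols + boxes


def infer_once(board):
    """One pass of Rule 1 and Rule 2; returns a list of (r, c, d) placements."""
    found = []
    # Rule 1: an empty cell with exactly one candidate
    for r in range(9):
        for c in range(9):
            if board[r][c] == 0:
                cand = candidates(board, r, c)
                if len(cand) == 1:
                    found.append((r, c, next(iter(cand))))
    # Rule 2: a digit missing from a unit that fits in exactly one of its empty cells
    for unit in all_units():
        present = {board[r][c] for r, c in unit}
        for d in set(range(1, 10)) - present:
            places = [(r, c) for r, c in unit
                      if board[r][c] == 0 and d in candidates(board, r, c)]
            if len(places) == 1:
                found.append((places[0][0], places[0][1], d))
    return found
```

You can get repeated inference by applying the returned placements and calling `infer_once` again until it returns nothing. An efficient solver would update candidates incrementally instead of rescanning the whole board.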
Algorithm Optimization #2: Try to fail as early as possible
If we can detect that the current board status cannot be completed in any way, stop further processing on that case and return immediately.
We can detect this condition with the following checks:
- Rule 3: If cell [r][c] cannot be filled with any value (because all values already exist in the same row, column, or box), the board is unsolvable.
- Rule 4: If value d cannot be filled in any cell within a unit, the board is unsolvable.
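Rules 3 and 4 can be sketched the same way. This sketch assumes the same 0-for-empty board and uses a hypothetical `is_dead_end` helper. It only looks at candidates, so it does not detect givens that are already duplicated.

```python
def candidates(board, r, c):
    """Digits not yet used in the row, column, or 3x3 box of cell (r, c)."""
    used = set(board[r]) | {board[i][c] for i in range(9)}
    br, bc = r - r % 3, c - c % 3
    used |= {board[i][j] for i in range(br, br + 3) for j in range(bc, bc + 3)}
    return set(range(1, 10)) - used


def is_dead_end(board):
    """True if Rule 3 or Rule 4 shows that the board cannot be completed."""
    # Rule 3: an empty cell with no candidate at all
    for r in range(9):
        for c in range(9):
            if board[r][c] == 0 and not candidates(board, r, c):
                return True
    # Rule 4: a digit missing from a unit, with no empty cell left that accepts it
    rows = [[(r, c) for c in range(9)] for r in range(9)]
    cols = [[(r, c) for r in range(9)] for c in range(9)]
    boxes = [[(br + i, bc + j) for i in range(3) for j in range(3)]
             for br in (0, 3, 6) for bc in (0, 3, 6)]
    for unit in rows + cols + boxes:
        present = {board[r][c] for r, c in unit}
        for d in set(range(1, 10)) - present:
            if not any(board[r][c] == 0 and d in candidates(board, r, c)
                       for r, c in unit):
                return True
    return False
```

For example, suppose row 0 holds 1 to 6 in its first six cells and a 9 sits in the top-right box. Then 9 has no place left in row 0 (Rule 4), even though every empty cell in that row still has candidates.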
Algorithm Optimization #3: Try to branch starting from the cell with fewest possible values
When no more cells can be inferred from the board, we have to go back to testing values at a vacant cell and recursively solve the board with the newly filled test value, similar to what we have been doing in the backtracking approach.
However, we can significantly reduce the number of cases to test by picking the cell with the minimum number of possible values, instead of picking the first vacant cell as usual. It's like reducing the cases from 9*9*9*9*... to 2*3*2*2*...
Not only does this reduce the number of cases at each level of recursion, it also significantly reduces the depth of recursion, thanks to the inference process in Algorithm Optimization #1. In practice, this approach can reduce the number of recursive calls from hundreds of thousands to only several hundred.
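Choosing the branching cell can be sketched as a linear scan that keeps the empty cell with the smallest candidate set. `pick_cell` is a hypothetical name, and candidates are recomputed from the board for simplicity. The scan stops early on a cell with 0 or 1 candidates, because no other cell can do better.

```python
def candidates(board, r, c):
    """Digits not yet used in the row, column, or 3x3 box of cell (r, c)."""
    used = set(board[r]) | {board[i][c] for i in range(9)}
    br, bc = r - r % 3, c - c % 3
    used |= {board[i][j] for i in range(br, br + 3) for j in range(bc, bc + 3)}
    return set(range(1, 10)) - used


def pick_cell(board):
    """Return (r, c, cands) for the empty cell with the fewest candidates.

    Returns None when the board has no empty cell. Ties go to the first such
    cell in row-major order.
    """
    best = None
    for r in range(9):
        for c in range(9):
            if board[r][c] != 0:
                continue
            cands = candidates(board, r, c)
            if best is None or len(cands) < len(best[2]):
                best = (r, c, cands)
                if len(cands) <= 1:
                    return best  # no cell can have fewer candidates: stop scanning
    return best
```

A backtracking loop would then try only the digits in the returned set at that cell. The plain approach instead tries 1 through 9 at the first empty cell.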
Effectiveness
Believe it or not, with the above optimizations, most puzzles in Codewars' tests, including hard ones, can be solved with only one or two recursive calls. By one recursive call, I mean that we can infer the whole board without testing any value at any cell. In fact, except for the "very hard puzzle" mentioned at the beginning of this topic, all other puzzles can be solved with at most 6 recursive calls.
Implementation Optimizations
Although the algorithm optimizations seem promising, a bad implementation can ruin everything. Next, we are going to look at some implementation optimizations.
Implementation Optimization #1: Use a bit mask to mark allowed values at a cell
To efficiently check the possibilities at a cell, we can use a bit mask for each cell. The bit mask is initialized as 0x1ff in hex, or 111111111 in binary.
- If value d is still allowed to be filled in a cell without conflicting with any other cells in the same row, column, or box, the dth bit of the cell’s mask is turned on.
- If value d cannot be filled in a cell, the dth bit of the cell's mask is turned off.
More facts can be inferred from the above mask rule:
- The initial value of the cell’s mask contains 9 set bits, indicating that all 9 values are still allowed in that cell.
- If a mask has only one set bit, at the dth position, we can consider d filled in that cell (Rule 1).
- If a mask is 0, the puzzle is unsolvable in its current state (Rule 3).
- Every time a value d is set at cell [r][c], we also clear the dth bit in every other cell in the same row, column, and box.
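The mask rules above can be illustrated on a flat list of 81 masks, where cell (r, c) has index r * 9 + c. This is a minimal sketch with hypothetical helpers `peers_of` and `place`. Unlike the full solver, it does not propagate any further. For example, it does not apply Rule 1 to a peer that is left with a single bit.

```python
FULL = 0x1FF  # binary 111111111: digits 1..9 are all still allowed


def peers_of(cell):
    """Flat indices of the 20 cells sharing a row, column, or box with cell."""
    r, c = divmod(cell, 9)
    br, bc = r - r % 3, c - c % 3
    same = {r * 9 + j for j in range(9)} | {i * 9 + c for i in range(9)}
    same |= {i * 9 + j for i in range(br, br + 3) for j in range(bc, bc + 3)}
    same.discard(cell)
    return same


def place(mask, cell, d):
    """Fix digit d at cell and clear bit d-1 in every peer (no further propagation).

    Returns False if some peer is left with no allowed digit (Rule 3).
    """
    bit = 1 << (d - 1)
    mask[cell] = bit
    peers = peers_of(cell)
    for p in peers:
        mask[p] &= ~bit
    return all(mask[p] != 0 for p in peers)


mask = [FULL] * 81
place(mask, 0, 5)  # put a 5 in the top-left cell
print(bin(mask[0]), bin(mask[1]), bin(mask[80]))
# prints: 0b10000 0b111101111 0b111111111
```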
Some useful functions:
Check if a mask has only one bit (which means that cell has been set with a value):
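```python
def is_single_bit(m):
    return (m & (m - 1)) == 0
```
where m is the mask. Note that this expression is also True when m is 0. The solver treats a 0 mask as a contradiction (Rule 3) and discards that branch, so this never produces a wrong answer.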
Check if a value d is allowed in cell [r][c]:
```python
def is_allowed(m, d):
    return m & (1 << (d - 1))
```
where m is the mask and d is the value we want to check.
Get the corresponding value from the mask, given we already know that the mask has only one set bit:
```python
import math

def get_value(m):
    return int(math.log2(m)) + 1
```
where m is the mask.
Count number of set bits in a mask:
```python
def count_bits(m):
    count = 0
    while m:
        m &= (m - 1)  # clear the lowest set bit
        count += 1
    return count
```
where m is the mask.
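Here is a small worked example of these helpers on a single cell's mask. It assumes the convention above, where bit d-1 set means digit d is allowed. It also uses Python's built-in `int.bit_length()`. For a mask with exactly one set bit, `bit_length()` returns the same value as `get_value` without going through floating point.

```python
import math


def is_single_bit(m):
    return (m & (m - 1)) == 0


def is_allowed(m, d):
    return (m & (1 << (d - 1))) != 0


def get_value(m):
    return int(math.log2(m)) + 1


def count_bits(m):
    count = 0
    while m:
        m &= m - 1  # drop the lowest set bit
        count += 1
    return count


# A cell where only the digits 3 and 5 are still allowed (bits 2 and 4).
m = (1 << 2) | (1 << 4)  # 0b000010100 == 20
print(count_bits(m), is_single_bit(m), is_allowed(m, 3), is_allowed(m, 4))
# prints: 2 False True False

# A peer is set to 3, so bit 2 is cleared here; only 5 remains.
m &= ~(1 << 2)  # 0b000010000 == 16
print(is_single_bit(m), get_value(m), m.bit_length())
# prints: True 5 5
```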
I have also tried using a precalculated bits_set table for fast lookup when counting the number of set bits in a mask. However, I did not notice any improvement in running time. The function goes as follows:
```python
# precalculate the bits_set lookup table for every byte value 0..255
bits_set = [0] * 256
for i in range(256):
    for d in range(1, 10):
        if i & (1 << d-1):
            bits_set[i] += 1

def count_bits(m):
    return bits_set[m & 255] + bits_set[(m >> 8) & 255] \
        + bits_set[(m >> 16) & 255] + bits_set[(m >> 24) & 255]
```
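To double-check the lookup table, or to compare it with simpler alternatives, a quick sketch like the one below confirms that three counting approaches agree on all 512 possible 9-bit masks. This table is filled with a slightly different recurrence: the count for i is the count for i >> 1 plus the lowest bit of i. A cell mask never exceeds 9 bits, so two byte lookups are enough. On Python 3.10 or newer, the built-in `int.bit_count()` is another option. It is left out of the code so that the sketch also runs on older versions.

```python
# Set-bit counts for every byte value, built incrementally.
bits_set = [0] * 256
for i in range(1, 256):
    bits_set[i] = bits_set[i >> 1] + (i & 1)


def count_bits_table(m):
    # a cell mask uses only 9 bits, so two byte lookups are enough
    return bits_set[m & 255] + bits_set[(m >> 8) & 255]


def count_bits_loop(m):
    count = 0
    while m:
        m &= m - 1  # clear the lowest set bit
        count += 1
    return count


def count_bits_bin(m):
    return bin(m).count("1")


# All three agree on every possible 9-bit cell mask.
for m in range(512):
    assert count_bits_table(m) == count_bits_loop(m) == count_bits_bin(m)
```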
Implementation Optimization #2: Use single dimension array instead of two-dimension array
At least in Python. In my case, storing the masks in a single-dimension array instead of a two-dimension array made the code about twice as fast.
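Converting between the two layouts is simple arithmetic. The cell at row r, column c lives at index r * 9 + c, and `divmod` recovers the pair. The helper names in this sketch are hypothetical. A plausible but unverified reason for the speed-up is that each access then needs one list subscript instead of two.

```python
def to_index(r, c):
    """Flat index (0..80) of the cell at row r, column c."""
    return r * 9 + c


def to_row_col(k):
    """(row, column) of flat index k."""
    return divmod(k, 9)


def flatten(board):
    """9x9 list of lists -> flat list of 81 values, row by row."""
    return [v for row in board for v in row]


def unflatten(cells):
    """Flat list of 81 values -> 9x9 list of lists."""
    return [cells[r * 9:(r + 1) * 9] for r in range(9)]
```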
Implementation Optimization #3: Set a value d to a cell by clearing all other bits instead of directly setting the dth bit
This way, every time we clear the eth bit of a cell, we can count the remaining cells that still allow e in that cell's units (row, column, and box). If that number is 1, we can infer the place to fill e (Rule 2). If the number is 0, we can fail early (Rule 4).
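Here is a stripped-down sketch of this bookkeeping. It assumes the flat 81-mask layout and, to stay short, looks only at the row unit, whereas the solver in this post checks the row, column, and box. `clear_and_check` is a hypothetical name. It only reports the outcome: whether e has no place left (Rule 4), exactly one place left (Rule 2), or several. It does not go on to set the forced cell.

```python
FULL = 0x1FF  # all digits 1..9 allowed


def clear_and_check(mask, cell, e):
    """Clear digit e at cell, then count the cells in the same row that still allow e.

    Returns ("fail", None) if no cell in the row allows e any more (Rule 4),
    ("forced", k) if only cell k still allows e (Rule 2), otherwise ("ok", None).
    """
    bit = 1 << (e - 1)
    mask[cell] &= ~bit
    row_start = (cell // 9) * 9
    places = [k for k in range(row_start, row_start + 9) if mask[k] & bit]
    if not places:
        return ("fail", None)
    if len(places) == 1:
        return ("forced", places[0])
    return ("ok", None)


# Usage: remove digit 4 from cells 0..7 of the first row, one cell at a time.
mask = [FULL] * 81
results = [clear_and_check(mask, k, 4) for k in range(8)]
print(results[-1])  # ('forced', 8): only cell 8 can still hold a 4 in row 0
```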
Solution
Sample code
The following sample code is written in Python.
```python
import copy
import math


def sudoku_solver(puzzle):
    # init a blank sudoku
    sudoku = Sudoku()
    # set the input board to our sudoku
    sudoku.setboard(puzzle)
    # if the input is invalid, raise an error
    if not sudoku.valid:
        raise ValueError
    # solve the sudoku, the results (if any) will be stored in sudoku.result_boards
    sudoku.solve()
    # if there is no solution, or there is more than 1 solution
    if len(sudoku.result_boards) != 1:
        raise ValueError
    # return the only solution
    return sudoku.result_boards[0]


class Sudoku:
    ### Notes:
    # cells are indexed from 0 to 80, row by row

    ### init static constants:
    # list of the digits that will be filled in the sudoku
    digits = [1, 2, 3, 4, 5, 6, 7, 8, 9]

    # List of the units (row, column, box) that each cell belongs to.
    # For example, cell[9] (row 1, col 0) has 3 units:
    # units[9][0] == [9, 10, 11, 12, 13, 14, 15, 16, 17] (all cells in the same row)
    # units[9][1] == [0, 9, 18, 27, 36, 45, 54, 63, 72]  (all cells in the same column)
    # units[9][2] == [0, 1, 2, 9, 10, 11, 18, 19, 20]    (all cells in the same box)
    units = []

    # list of all peers of each cell, which are all the unique cells other than
    # the cell itself in the same row, column and box.
    # For example, cell[9] (row 1, col 0) has the following 20 peers:
    # peers[9] == [10, 11, 12, 13, 14, 15, 16, 17, 0, 18, 27, 36, 45, 54, 63, 72, 1, 2, 19, 20]
    peers = []

    # init units and peers table
    for i in range(81):
        units.append([])
        # add cells in same row
        units[i].append([])
        r = i // 9
        for c in range(9):
            units[i][0].append(r * 9 + c)
        # add cells in same col
        units[i].append([])
        c = i % 9
        for r in range(9):
            units[i][1].append(r * 9 + c)
        # add cells in same box
        units[i].append([])
        br = (i // 9) // 3
        cr = (i % 9) // 3
        for r in range(br * 3, br * 3 + 3):
            for c in range(cr * 3, cr * 3 + 3):
                units[i][2].append(r * 9 + c)
        # collect all neighbor cells of each cell
        peers.append([])
        for unit in units[i]:
            for cell in unit:
                if cell not in peers[i]:
                    peers[i].append(cell)
        peers[i].remove(i)

    # init a blank sudoku
    def __init__(self):
        self.mask = []
        self.valid = True
        self.solutions = []
        self.result_boards = []

    # set the input board to our sudoku
    def setboard(self, board):
        # the mask array of the 81 cells
        self.mask = []
        # whether the sudoku is valid
        self.valid = True
        # the list of solutions (if any) in mask format
        self.solutions = []
        # the list of solutions (if any) in human readable array format
        self.result_boards = []
        # check the input board dimensions; stop early so we never index out of range
        if not self.validate(board):
            self.valid = False
            return
        # init the mask array with all cells set to 0x1ff, indicating that all 9 digits are still allowed
        self.mask = [0x1ff] * 81
        # set the input board to this sudoku, updating the peers' masks along the way
        for r in range(9):
            for c in range(9):
                if board[r][c]:
                    # if the digit cannot be set at a cell, the input board is invalid (or unsolvable)
                    if not self.set(r * 9 + c, board[r][c]):
                        self.valid = False
                        return

    # validate board dimensions
    def validate(self, board):
        if len(board) != 9:
            return False
        for r in range(len(board)):
            if len(board[r]) != 9:
                return False
        return True

    # convert a mask array to a human readable two-dimensional array
    def mask_to_board(self, mask):
        board = []
        for r in range(9):
            board.append([0] * 9)
        for r in range(9):
            for c in range(9):
                if self.is_single_bit(mask[r * 9 + c]):
                    for d in self.digits:
                        if mask[r * 9 + c] & (1 << (d - 1)):
                            board[r][c] = d
        return board

    # clone the current status of the sudoku
    def clone(self):
        sudoku = Sudoku()
        sudoku.mask = copy.copy(self.mask)
        sudoku.valid = self.valid
        return sudoku

    # the main function to solve the sudoku
    def solve(self):
        # call the recursive function solve_helper to solve
        self.solve_helper()
        # convert the solution masks into human readable two-dimensional arrays
        for result in self.solutions:
            self.result_boards.append(self.mask_to_board(result))

    # recursive function to solve the board
    def solve_helper(self):
        # choose the vacant cell with the fewest possibilities
        cell = self.find_vacant_with_min_possibilities()
        # if all cells have been filled (no vacant cell), we have found a solution!
        if cell is None:
            self.add_solution(self.mask)
            return
        # try the remaining possible values in this cell
        for d in self.digits:
            # skip if d is not allowed in this cell
            if not (self.mask[cell] & (1 << (d - 1))):
                continue
            # clone the sudoku status...
            sudoku = self.clone()
            # ... and try digit d in the cloned one;
            # if setting d in this cell leads to a contradiction, skip this digit
            if not sudoku.set(cell, d):
                continue
            # solve the cloned sudoku with the newly filled value
            sudoku.solve_helper()
            # collect the solutions (if any) found for the cloned sudoku
            for solution in sudoku.solutions:
                self.add_solution(solution)
            # the problem requires us to raise an error if there is more than one solution,
            # so we can skip further processing once 2 solutions have been found
            if len(self.solutions) >= 2:
                return

    # a mask is considered set if there's only one bit turned on.
    # a non-zero m has exactly one bit turned on if (m & (m - 1)) == 0
    def is_single_bit(self, m):
        return (m & (m - 1)) == 0

    # count the number of turned on bits in a mask
    def count_bits(self, m):
        count = 0
        while m:
            m &= (m - 1)
            count += 1
        return count

    # add a solution to our collection, skip if that solution already exists
    def add_solution(self, mask):
        for result in self.solutions:
            if result == mask:
                return
        self.solutions.append(copy.deepcopy(mask))

    # find the vacant cell with the fewest allowed values
    def find_vacant_with_min_possibilities(self):
        vacant_cnt = 0
        best_vacant_possibilities = 10
        best_vacant_i = 0
        for i in range(81):
            # a vacant cell has at least 2 allowed values, so 2 cannot be beaten
            if best_vacant_possibilities == 2:
                break
            if not self.is_single_bit(self.mask[i]):
                vacant_cnt += 1
                choices = self.count_bits(self.mask[i])
                if choices < best_vacant_possibilities:
                    best_vacant_possibilities = choices
                    best_vacant_i = i
        if vacant_cnt == 0:
            # no more vacant cells
            return None
        return best_vacant_i

    # set digit d at cell by clearing all the other bits (except for the dth bit) in mask[cell]
    # return False if a contradiction is detected, True otherwise
    def set(self, cell, d):
        other_values = [d2 for d2 in self.digits
                        if d2 != d and self.mask[cell] & (1 << (d2 - 1))]
        for d2 in other_values:
            if not self.clear(cell, d2):
                return False
        return True

    # remove digit d from the allowed values at cell by clearing the dth bit
    # return False if a contradiction is detected, True otherwise
    def clear(self, cell, d):
        # if already cleared
        if not (self.mask[cell] & (1 << (d - 1))):
            return True
        # turn off bit d to mark that d is no longer allowed at this cell
        self.mask[cell] &= ~(1 << (d - 1))
        if self.mask[cell] == 0:
            # error: no value is allowed at this cell (Rule 3)
            return False
        elif self.is_single_bit(self.mask[cell]):
            # Rule 1: this cell has only one allowed value val,
            # so val is the value at cell: eliminate val from the peers.
            val = int(math.log2(self.mask[cell])) + 1
            for cell2 in self.peers[cell]:
                if not self.clear(cell2, val):
                    return False
        # Rule 2: if all cells in a unit cannot be filled with d except for one cell2,
        # we can fill d in that cell2.
        for u in self.units[cell]:
            dplaces = [cell2 for cell2 in u if self.mask[cell2] & (1 << (d - 1))]
            if len(dplaces) == 0:
                # error: no place for this value (Rule 4)
                return False
            elif len(dplaces) == 1:
                # d can only be in one place in the unit; assign it there
                if not self.set(dplaces[0], d):
                    return False
        return True
```
Test the code
```python
puzzle = [[8, 0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 3, 6, 0, 0, 0, 0, 0],
          [0, 7, 0, 0, 9, 0, 2, 0, 0],
          [0, 5, 0, 0, 0, 7, 0, 0, 0],
          [0, 0, 0, 0, 4, 5, 7, 0, 0],
          [0, 0, 0, 1, 0, 0, 0, 3, 0],
          [0, 0, 1, 0, 0, 0, 0, 6, 8],
          [0, 0, 8, 5, 0, 0, 0, 1, 0],
          [0, 9, 0, 0, 0, 0, 4, 0, 0]]
solution = [[8, 1, 2, 7, 5, 3, 6, 4, 9],
            [9, 4, 3, 6, 8, 2, 1, 7, 5],
            [6, 7, 5, 4, 9, 1, 2, 8, 3],
            [1, 5, 4, 2, 3, 7, 8, 9, 6],
            [3, 6, 9, 8, 4, 5, 7, 2, 1],
            [2, 8, 7, 1, 6, 9, 5, 3, 4],
            [5, 2, 1, 9, 7, 4, 3, 6, 8],
            [4, 3, 8, 5, 2, 6, 9, 1, 7],
            [7, 9, 6, 3, 1, 8, 4, 5, 2]]
assert sudoku_solver(puzzle) == solution
```
Result
The new code solves the "very hard puzzle" above in about 0.1s. Compared to the 45s taken by the backtracking approach, this is more than 400 times faster.
When we submit the code on Codewars, it passes all the tests in less than 300ms. Yay!
Conclusion
By applying some optimizations on both the algorithm and the implementation, we can see a significant increase in performance of our sudoku solver program.
Notice that the rules added to the algorithm are borrowed from the way humans solve sudokus. Does that mean human brains have always been implemented with the most advanced algorithms?
There is still room for improvement when it comes to even harder puzzles, like the one below:
```python
# 17 given cells with thousands of solutions
puzzle = [[0, 0, 0, 0, 0, 6, 0, 0, 0],
          [0, 5, 9, 0, 0, 0, 0, 0, 8],
          [2, 0, 0, 0, 0, 8, 0, 0, 0],
          [0, 4, 5, 0, 0, 0, 0, 0, 0],
          [0, 0, 3, 0, 0, 0, 0, 0, 0],
          [0, 0, 6, 0, 0, 3, 0, 5, 4],
          [0, 0, 0, 3, 2, 5, 0, 0, 6],
          [0, 0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0, 0]]
```
or this one:
```python
# 17 given cells with no solution
puzzle = [[0, 0, 0, 0, 0, 5, 0, 8, 0],
          [0, 0, 0, 6, 0, 1, 0, 4, 3],
          [0, 0, 0, 0, 0, 0, 0, 0, 0],
          [0, 1, 0, 5, 0, 0, 0, 0, 0],
          [0, 0, 0, 1, 0, 6, 0, 0, 0],
          [3, 0, 0, 0, 0, 0, 0, 0, 5],
          [5, 3, 0, 0, 0, 0, 0, 6, 1],
          [0, 0, 0, 0, 0, 0, 0, 0, 4],
          [0, 0, 0, 0, 0, 0, 0, 0, 0]]
```
For these two puzzles, the algorithm in this topic still takes too long to complete. Hopefully we can find better algorithms to solve these puzzles faster in the future. Solving sudoku in parallel (e.g. with multithreading or MapReduce) is also an interesting topic to discuss.