Hard Sudoku Solver Algorithm – Part 2

Introduction

In part 1, we solved some hard sudoku puzzles using the backtracking approach. While the algorithm did the job, it couldn’t beat the time limit in the more advanced challenge: Hard Sudoku Solver 1.

This time, we are going to find a better solution so that we can pass the tests in the new challenge within the 10-second time limit.

It is recommended that you try the challenge on your own first before reading the solution.

Problem description

From Codewars’ Hard Sudoku Solver 1:

There are several difficulty of sudoku games, we can estimate the difficulty of a sudoku game based on how many cells are given of the 81 cells of the game.

  • Easy sudoku generally have over 32 givens
  • Medium sudoku have around 30–32 givens
  • Hard sudoku have around 28–30 givens
  • Very Hard sudoku have less than 28 givens

Note: The minimum of givens required to create a unique (with no multiple solutions) sudoku game is 17.

A hard sudoku game means that at start no cell will have a single candidates and thus require guessing and trial and error. A very hard will have several layers of multiple candidates for any empty cell.

Your Task

Write a function that solves sudoku puzzles of any difficulty. The function will take a sudoku grid and it should return a 9×9 array with the proper answer for the puzzle.

Or it should raise an error in cases of: invalid grid (not 9×9, cell with values not in the range 1~9); multiple solutions for the same puzzle or the puzzle is unsolvable.

Analysis

Compared to the problem in part 1, this problem requires us to raise an error if there is more than one solution to the sudoku. Therefore, instead of returning the first solution we find, we have to keep testing other possibilities to make sure that there is no other solution. The good news is that we can stop and raise the error as soon as we find a second solution.

The hardest test that we are going to solve is the puzzle below, which has only 21 given numbers:

#### Should solve very hard puzzle
puzzle = \
    [[8, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 0, 3, 6, 0, 0, 0, 0, 0],
     [0, 7, 0, 0, 9, 0, 2, 0, 0],
     [0, 5, 0, 0, 0, 7, 0, 0, 0],
     [0, 0, 0, 0, 4, 5, 7, 0, 0],
     [0, 0, 0, 1, 0, 0, 0, 3, 0],
     [0, 0, 1, 0, 0, 0, 0, 6, 8],
     [0, 0, 8, 5, 0, 0, 0, 1, 0],
     [0, 9, 0, 0, 0, 0, 4, 0, 0]]
solution = \
    [[8, 1, 2, 7, 5, 3, 6, 4, 9],
     [9, 4, 3, 6, 8, 2, 1, 7, 5],
     [6, 7, 5, 4, 9, 1, 2, 8, 3],
     [1, 5, 4, 2, 3, 7, 8, 9, 6],
     [3, 6, 9, 8, 4, 5, 7, 2, 1],
     [2, 8, 7, 1, 6, 9, 5, 3, 4],
     [5, 2, 1, 9, 7, 4, 3, 6, 8],
     [4, 3, 8, 5, 2, 6, 9, 1, 7],
     [7, 9, 6, 3, 1, 8, 4, 5, 2]]
assert(sudoku_solver(puzzle) == solution)

The backtracking approach

The first thing that we want to try is of course reusing the solution from the last challenge, with some modifications to fit the new additional requirements.

The sample code is as follows:

import copy

def sudoku_solver(puzzle):
    sudoku = Sudoku(puzzle)
    if not sudoku.valid:
        raise ValueError

    sudoku.solve()

    if len(sudoku.solutions) != 1:
        raise ValueError

    return sudoku.solutions[0]

class Sudoku:

    def __init__(self, board):
        # the 9x9 sudoku board
        self.board = []
        # list of solutions to this sudoku
        self.solutions = []
        # is the board valid
        self.valid = True


        # init a blank board with all cells filled with zero
        for i in range(9):
            r = [0] * 9
            self.board.append(r)

        # check the input board dimensions
        if not self.validate(board):
            self.valid = False

        # copy the input board to self.board
        for r in range(len(board)):
            for c in range(len(board[r])):
                if board[r][c]:
                    if not self.set(r, c, board[r][c]):
                        self.valid = False


    # validate board dimensions
    def validate(self, board):
        if len(board) != 9:
            return False
        for r in range(len(board)):
            if len(board[r]) != 9:
                return False
        return True

    # the main function to solve the sudoku
    def solve(self):
        # recursively solve the sudoku, starting from cell 0
        self.solve_helper(0)
        return self.solutions

    # given the sudoku has been filled up to cell k-1, try to solve the sudoku from cell k
    # k is the cell index, counting from left to right and top to bottom.
    #     k is 0, 1, 2, ..., 8     for cell (0,0), (0,1), (0,2), ..., (0,8)
    #     k is 9, 10, 11, ...,     for cell (1,0), (1,1), (1,2), ...
    #     k is 80                  for cell (8,8) (last cell)
    def solve_helper(self, k):
        # if we get past the last cell, it means we have filled every cell with a valid value,
        # so record a copy of the board as a solution
        if k > 80:
            self.solutions.append(copy.deepcopy(self.board))
            return

        r = k // 9
        c = k % 9

        # if this cell has been filled, go on to solve the next cell
        if self.board[r][c] != 0:
            return self.solve_helper(k+1)

        # try to fill each value from 1 to 9 in this cell
        for x in range(1, 10):
            # fill the cell with value x only if x has not appeared on the same row or col or 3x3 box
            if self.check(r, c, x):
                # start backtracking:
                # try x in this cell,
                self.board[r][c] = x
                # then try to solve from the next cell k+1,
                self.solve_helper(k+1)
                if len(self.solutions) >= 2:
                    # the problem requires us to raise an error if there is more than one solution,
                    # so we can skip further processing once 2 solutions have been found
                    return
                # then clear cell to return the board to the status before x was filled
                self.board[r][c] = 0

        # if we are here, we have tried every value in this cell;
        # return to the caller so that it can try its next value
        # (any solutions found along the way are already stored in self.solutions)
        return

    # check if value x can be put at cell[row][col]
    # return False   if value x has already been used in other cells on current row, or column, or 3x3 box
    # return True    otherwise
    def check(self, row, col, x):
        if not isinstance(x, int):
            return False
        if x < 1 or x > 9:
            return False
        for i in range(9):
            if self.board[row][i] == x:
                return False
            if self.board[i][col] == x:
                return False
        box_start_row = row - row % 3
        box_start_col = col - col % 3
        for r in range(box_start_row, box_start_row + 3):
            for c in range(box_start_col, box_start_col + 3):
                if self.board[r][c] == x:
                    return False

        return True

    # check if x can be put in cell [r][c]
    # if yes, put x in cell [r][c] and return True
    # if no, return False
    def set(self, r, c, x):
        if not self.check(r, c, x):
            return False

        self.board[r][c] = x
        return True

Test results

This code takes 45 seconds on my computer to solve the very hard puzzle above, so it obviously cannot make it through Codewars’ hidden tests within the time limit.

Algorithm Optimizations

In the backtracking approach, we pick the first vacant cell and try every value at it. For each value, we then recursively pick the next vacant cell and try every value at the new cell. Although we do skip the values that would violate the existing values within the same row, column, and box, the number of branches that the program must try is still overwhelming.

The whole idea behind optimizing this algorithm is to reduce the number of branches that we have to test as much as possible.

Algorithm Optimization #1: Try to infer the values for the vacant cells from the existing values

The values can be inferred in two ways:

  • Rule 1: If cell [r][c] can only be filled with one value d (because all other values already exist in the same row, column, or box), we can fill d in [r][c].
  • Rule 2: Because a value d must be filled exactly once in each unit (a unit means a row, a column, or a box), if all cells in a unit cannot be filled with d except for one cell, we can fill d in that cell.
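
As a rough illustration of the two rules, here is a minimal sketch that works directly on a 9×9 list of lists where 0 marks a vacant cell. The helper names candidates, rule1 and rule2_row are made up for this example and are not part of the final solution; Rule 2 is only shown for rows to keep it short.

```python
def candidates(board, r, c):
    """Return the set of digits that do not clash with the row, column, or box of (r, c)."""
    used = set(board[r]) | {board[i][c] for i in range(9)}
    br, bc = r - r % 3, c - c % 3
    used |= {board[i][j] for i in range(br, br + 3) for j in range(bc, bc + 3)}
    return set(range(1, 10)) - used


def rule1(board, r, c):
    """Rule 1: if exactly one digit fits in the vacant cell (r, c), return it, else None."""
    cand = candidates(board, r, c)
    return next(iter(cand)) if len(cand) == 1 else None


def rule2_row(board, r, d):
    """Rule 2 (rows only): if d fits in exactly one vacant cell of row r, return its column."""
    if d in board[r]:
        return None  # d is already placed in this row
    places = [c for c in range(9)
              if board[r][c] == 0 and d in candidates(board, r, c)]
    return places[0] if len(places) == 1 else None


# Hypothetical example board: row 0 misses 1, 2, 3; columns 0 and 1 already contain a 1.
board = [[0] * 9 for _ in range(9)]
board[0] = [0, 0, 0, 4, 5, 6, 7, 8, 9]
board[3][0] = 1
board[6][1] = 1

print(rule1(board, 0, 2))      # None: 1, 2 and 3 all still fit in cell (0, 2)
print(rule2_row(board, 0, 1))  # 2: column 2 is the only place left for 1 in row 0
```

Here Rule 1 alone cannot decide cell (0, 2), but Rule 2 can, which is why it pays to apply both.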

Algorithm Optimization #2: Try to fail as early as possible

If we can detect that the current board status cannot be completed in any way, stop further processing on that case and return immediately.

We can detect this condition with the following checks:

  • Rule 3: If cell [r][c] cannot be filled with any value (because all values already exist in the same row, column, or box), the board is unsolvable.
  • Rule 4: If value d cannot be filled in any cell within a unit, the board is unsolvable.
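
Both checks are cheap once we track, for every cell, the set of values that are still allowed. The sketch below assumes such a structure: cands is a hypothetical dict mapping a cell index to its set of allowed digits (a filled cell holds a single digit), and unit is a list of nine cell indices. Neither name comes from the final solution.

```python
def rule3_violated(cands):
    """Rule 3: some cell has no allowed value left."""
    return any(len(allowed) == 0 for allowed in cands.values())


def rule4_violated(cands, unit):
    """Rule 4: some digit has no cell left in the given unit."""
    return any(all(d not in cands[cell] for cell in unit)
               for d in range(1, 10))


# Hypothetical row 0 (cells 0..8): digits 4..9 are placed, cells 0..2 only allow 2 or 3.
unit = list(range(9))
cands = {0: {2, 3}, 1: {2, 3}, 2: {2, 3},
         3: {4}, 4: {5}, 5: {6}, 6: {7}, 7: {8}, 8: {9}}

print(rule3_violated(cands))        # False: every cell still has a candidate
print(rule4_violated(cands, unit))  # True: the digit 1 has nowhere to go in this row
```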

Algorithm Optimization #3: Try to branch starting from the cell with fewest possible values

When no more cells can be inferred from the board, we have to go back to testing values at a vacant cell and recursively solve the board with the newly filled test value, similar to what we have been doing in the backtracking approach.

However, we can significantly reduce the number of cases to test by picking the cell with the minimum number of possible values, instead of picking the first vacant cell as usual. It’s like reducing the cases from 9*9*9*9*... to 2*3*2*2*....

Not only does this reduce the number of cases on each recursion, it also significantly reduces the depth of recursion, thanks to the inference process in Algorithm Optimization #1. In practice, this approach can reduce the number of recursive calls from hundreds of thousands to only several hundred.
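
The cell selection itself is a small piece of code. A minimal sketch, assuming cands is a hypothetical list of 81 sets of allowed digits (one per cell, with a single digit for filled cells):

```python
def pick_branch_cell(cands):
    """Return the index of the vacant cell with the fewest allowed digits, or None."""
    vacant = [i for i, allowed in enumerate(cands) if len(allowed) > 1]
    if not vacant:
        return None  # every cell is decided, nothing left to branch on
    return min(vacant, key=lambda i: len(cands[i]))


# Hypothetical state: everything is decided except cells 5 and 40.
cands = [{1} for _ in range(81)]
cands[5] = {2, 4, 7}
cands[40] = {3, 9}

print(pick_branch_cell(cands))  # 40: two choices beat three
```

To put numbers on the comparison above: over four consecutive guesses, 9*9*9*9 is 6561 combinations, while 2*3*2*2 is only 24.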

Effectiveness

Believe it or not, using the above optimizations, most puzzles in Codewars’ tests, including hard ones, can be solved using only one or two recursive calls. By one recursive call, I mean that we can infer the whole board and don’t even need to test any value at any cell. In fact, except for the “very hard puzzle” mentioned at the beginning of this topic, all other puzzles can be solved with 6 recursive calls at most.

Implementation Optimizations

Although the algorithm optimizations seem promising, a bad implementation can ruin everything. Next, we are going to look at some implementation optimizations.

Implementation Optimization #1: Use a bit mask to mark allowed values at a cell

To efficiently check the possibilities at a cell, we can use a bit mask for each cell. The bit mask will be initialized as 0x1ff in hex or 111111111 in binary.

  • If value d is still allowed to be filled in a cell without conflicting with any other cells in the same row, column, or box, the dth bit of the cell’s mask (counting from 1 at the least significant bit) is turned on.
  • If value d cannot be filled in a cell, the dth bit of the cell’s mask is turned off.

More facts can be inferred from the above mask rule:

  • The initial value of the cell’s mask contains 9 set bits, indicating that all 9 values are still allowed in that cell.
  • If a mask has only 1 set bit, at the dth position, we can consider that d is filled in that cell (Rule 1).
  • If a mask is 0, we can consider the puzzle status unsolvable (Rule 3).
  • Every time a value d is set at cell [r][c], we also clear the dth bit in every other cell in the same row, column, or box.
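
To make the mask rules concrete, here is a small worked example. The names FULL, bit and allowed_digits are only used for this illustration.

```python
FULL = 0x1ff  # 0b111111111: all nine digits allowed


def bit(d):
    """Mask with only the bit for digit d (1..9) set."""
    return 1 << (d - 1)


def allowed_digits(m):
    """List the digits whose bits are still set in mask m."""
    return [d for d in range(1, 10) if m & bit(d)]


mask = FULL
mask &= ~bit(5)              # a peer was set to 5, so 5 is no longer allowed here
print(allowed_digits(mask))  # [1, 2, 3, 4, 6, 7, 8, 9]

for d in (1, 2, 3, 4, 6, 8, 9):
    mask &= ~bit(d)          # more peers get filled
print(bin(mask))             # 0b1000000: only the bit of 7 is left, so the cell is 7

mask &= ~bit(7)
print(mask)                  # 0: no digit fits any more (Rule 3)
```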

Some useful functions:
Check if a mask has only one set bit (which means that the cell has been set with a value):

def is_single_bit(m):
    return (m & (m - 1)) == 0

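where m is the mask. Note that this expression is also true for m == 0, so an empty mask (Rule 3) has to be checked separately.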

Check if a value d is allowed in cell [r][c]:

def is_allowed(m, d):
    return m & (1<<(d-1))

where m is the mask, d is the value we want to check.

Get the corresponding value from the mask, given we already know that the mask has only one set bit:

import math

def get_value(m):
    return int(math.log2(m)) + 1

where m is the mask.
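
If you prefer to stay with integer arithmetic, Python's built-in int.bit_length() gives the same result for a single-bit mask. This is an alternative, not what the solution below uses; the name get_value_int is made up here.

```python
def get_value_int(m):
    """Digit encoded by a mask with exactly one set bit, without floating point."""
    return m.bit_length()


# Check against the definition of the mask: digit d is stored as 1 << (d - 1).
for d in range(1, 10):
    assert get_value_int(1 << (d - 1)) == d
```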

Count number of set bits in a mask:

def count_bits(m):
    count = 0
    while m:
        m &= (m - 1)
        count += 1
    return count

where m is the mask.

I have also tried using a precalculated bits_set table for fast lookup when counting the set bits in a mask, but I did not notice any improvement in running time. The function is as follows:

# precalculate the bits_set index table
bits_set = [0] * 256
for i in range(256):
    for d in range(1, 10):
        if i & (1 << (d - 1)):
            bits_set[i] += 1

def count_bits(m):
    return bits_set[m & 255] + bits_set[(m >> 8) & 255] \
        + bits_set[(m >> 16) & 255] + bits_set[(m >> 24) & 255]
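
Another option worth measuring is to let the interpreter do the counting. The sketch below uses bin(m).count("1"), which works on any Python 3 version; on Python 3.10 or newer, m.bit_count() does the same. Whether it is faster than the two versions above would need to be measured, and the names count_bits_builtin and count_bits_loop are only for this example.

```python
def count_bits_builtin(m):
    """Count set bits via the binary string representation."""
    return bin(m).count("1")


def count_bits_loop(m):
    """Reference: the loop version, which clears the lowest set bit each round."""
    count = 0
    while m:
        m &= m - 1
        count += 1
    return count


# Both agree on every possible 9-bit mask.
assert all(count_bits_builtin(m) == count_bits_loop(m) for m in range(0x200))
print(count_bits_builtin(0x1ff))  # 9
```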

Implementation Optimization #2: Use single dimension array instead of two-dimension array

At least in Python. In my case, using a single-dimension array to store the masks instead of a two-dimension array made the code about twice as fast.
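
The conversion between the two layouts is simple index arithmetic. A minimal sketch, with helper names chosen for this example only:

```python
def to_index(r, c):
    """Flat index of the cell at row r, column c."""
    return r * 9 + c


def to_row_col(cell):
    """Row and column of a flat cell index."""
    return divmod(cell, 9)


def flatten(board):
    """Turn a 9x9 list of lists into a flat list of 81 values, row by row."""
    return [value for row in board for value in row]


board = [[r * 9 + c for c in range(9)] for r in range(9)]
flat = flatten(board)
print(to_index(1, 0))   # 9
print(to_row_col(80))   # (8, 8)
print(flat[to_index(4, 7)] == board[4][7])  # True
```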

Implementation Optimization #3: Set a value d to a cell by clearing all other bits instead of directly setting the dth bit

By doing this, every time we clear the bit of some other value e, we have a chance to count the number of remaining cells that still allow e in the related units (row, column, or box). If that number is 1, we can infer the place to fill e (Rule 2). If the number is 0, we can fail early (Rule 4).

Solution

Sample code

The following sample code is written in Python.

import copy
import math

def sudoku_solver(puzzle):
    # init a blank sudoku
    sudoku = Sudoku()
    
    # set the input board to our sudoku
    sudoku.setboard(puzzle)
    
    # if the input is invalid, raise an error
    if not sudoku.valid:
        raise ValueError

    # solve the sudoku, the results (if any) will be stored in sudoku.result_boards
    sudoku.solve()

    # if there are no solution, or there are more than 1 solution
    if len(sudoku.result_boards) != 1:
        raise ValueError

    # return the only solution
    return sudoku.result_boards[0]


class Sudoku:
    ### Notes:
    # cells will be indexed from 0 to 80


    ### init static constants:

    # list of the digits that will be filled in the sudoku
    digits = [1, 2, 3, 4, 5, 6, 7, 8, 9]

    # List of block/units that each cell stays in.
    # For example, cell[9] (row 1, col 0) has 3 units:
    # units[9][0] == [9, 10, 11, 12, 13, 14, 15, 16, 17] (all cells in the same row)
    # units[9][1] == [0,  9, 18, 27, 36, 45, 54, 63, 72] (all cells in the same column)
    # units[9][2] == [0,  1,  2,  9, 10, 11, 18, 19, 20] (all cells in the same box)
    units = []

    # list of all peers of each cell, which are all the other cells in the same row, column and box.
    # For example, cell[9] (row 1, col 0) has the following 20 peers:
    # peers[9] == [10, 11, 12, 13, 14, 15, 16, 17, 0, 18, 27, 36, 45, 54, 63, 72, 1, 2, 19, 20]
    peers = []

    # init units and peers table
    for i in range(81):
        units.append([])
        # add cells in same row
        units[i].append([])
        r = int(i / 9)
        for c in range(9):
            units[i][0].append(r * 9 + c)
        # add cells in same col
        units[i].append([])
        c = int(i % 9)
        for r in range(9):
            units[i][1].append(r * 9 + c)
        # add cells in same box
        units[i].append([])
        br = int(int(i / 9) / 3)
        cr = int(int(i % 9) / 3)
        for r in range(br * 3, br * 3 + 3):
            for c in range(cr * 3, cr * 3 + 3):
                units[i][2].append(r * 9 + c)
        # collect all neighbor cells of each cell
        peers.append([])
        for unit in units[i]:
            for cell in unit:
                if cell not in peers[i]:
                    peers[i].append(cell)
        peers[i].remove(i)

    # init a blank sudoku
    def __init__(self):
        self.mask = []
        self.valid = True
        self.solutions = []
        self.result_boards = []

    # set the input board to our sudoku
    def setboard(self, board):
        # the mask array of the 81 cells
        self.mask = []

        # whether the sudoku is valid
        self.valid = True

        # the list of solutions (if any) in mask format
        self.solutions = []

        # the list of solutions (if any) in human readable array format
        self.result_boards = []

        # check the input board dimensions; stop here if they are wrong,
        # because the loop below assumes a 9x9 board
        if not self.validate(board):
            self.valid = False
            return

        # init mask matrix with all cells set to 0x1ff, indicating that all 9 digits are still allowed in that cell
        self.mask = [0x1ff] * 81

        # set the input board to this sudoku matrix, and also update the peers' masks for each cell along the way
        for r in range(len(board)):
            for c in range(len(board[r])):
                if board[r][c]:
                    # if the digit cannot be set at a cell, we mark that the input board is invalid (unsolvable)
                    if not self.set(r*9+c, board[r][c]):
                        self.valid = False
                        return


    # validate board dimensions
    def validate(self, board):
        if len(board) != 9:
            return False
        for r in range(len(board)):
            if len(board[r]) != 9:
                return False
        return True

    # convert mask to human readable two-dimensional array
    def mask_to_board(self, mask):
        board = []
        for r in range(9):
            board.append([0] * 9)
        for r in range(9):
            for c in range(9):
                if self.is_single_bit(mask[r*9+c]):
                    for d in self.digits:
                        if mask[r*9+c] & (1 << (d-1)):
                            board[r][c] = d
        return board

    # clone the current status of the sudoku
    def clone(self):
        sudoku = Sudoku()
        sudoku.mask = copy.copy(self.mask)
        sudoku.valid = self.valid
        return sudoku

    # the main function to solve the sudoku
    def solve(self):
        # call the recursive function solve_helper to solve
        self.solve_helper()

        # convert the solution masks into human readable two-dimensional array and stored in result_boards
        for result in self.solutions:
            self.result_boards.append(self.mask_to_board(result))

    # recursive function to solve the board
    def solve_helper(self):
        # choose the vacant cell with the fewest possibilities
        cell = self.find_vacant_with_min_possibilities()

        # if all cells have been filled (no vacant cell), we have found a solution!
        if cell is None:
            self.add_solution(self.mask)
            return

        # try the remaining possible value in this cell
        for d in self.digits:
            # skip if d is not allowed in this cell
            if not (self.mask[cell] & (1<<(d-1))):
                continue

            # clone the sudoku status...
            sudoku = self.clone()

            # ... and try digit d in the cloned one to start searching for a solution
            # if setting d in this cell leads to an unsolvable sudoku: stop further processing
            if not sudoku.set(cell, d):
                continue

            # solve the cloned sudoku with the newly filled value
            sudoku.solve_helper()

            # if we found any solutions for the cloned sudoku:
            if len(sudoku.solutions) > 0:
                # collect those solutions for our current sudoku
                for solution in sudoku.solutions:
                    self.add_solution(solution)

            # the problem requires us to raise an error if there is more than one solution,
            # so we can skip further processing once 2 solutions have been found
            if len(self.solutions) >= 2:
                return


    # a mask is considered as set if there's only one bit turned on.
    # a non-zero m has exactly one bit turned on if (m & (m - 1)) == 0
    # (the expression is also true for m == 0; clear() checks for that case first)
    def is_single_bit(self, m):
        return (m & (m - 1)) == 0

    # count number of turned on bits in a mask
    def count_bits(self, m):
        count = 0
        while m:
            m &= (m - 1)
            count += 1
        return count

    # add a solution to our collection, skip if that solution already exists
    def add_solution(self, mask):
        for result in self.solutions:
            if result == mask:
                return
        self.solutions.append(copy.deepcopy(mask))


    # find the vacant cell with the fewest allowed values
    def find_vacant_with_min_possibilities(self):
        vacant_cnt = 0
        best_vacant_possibilities = 10
        best_vacant_i = 0
        for i in range(81):
            if best_vacant_possibilities == 2:
                break
            if not self.is_single_bit(self.mask[i]):
                vacant_cnt += 1
                choices = self.count_bits(self.mask[i])

                if choices < best_vacant_possibilities:
                    best_vacant_possibilities = choices
                    best_vacant_i = i

        if vacant_cnt == 0:
            # no more vacant cell:
            return None

        return best_vacant_i

    # set digit d at cell by clearing all the other bits (except for dth bit) in mask[cell]
    # return False if a contradiction is detected.
    # return True otherwise
    def set(self, cell, d):
        other_values = [ d2 for d2 in self.digits if d2 != d and self.mask[cell] & (1<<(d2-1)) ]
        for d2 in other_values:
            if not self.clear(cell, d2):
                return False
        return True

    # removing a digit d from being allowed at cell by clearing the dth bit
    def clear(self, cell, d):
        # if already cleared
        if not (self.mask[cell] & (1<<(d-1))):
            return True

        # turn off bit at d to mark d is no longer allowed at this cell
        self.mask[cell] &= ~(1<<(d-1))

        # Rule 1: If this cell has only one allowed value d2,
        # then make d2 the value at cell and eliminate d2 from the peers.
        if self.mask[cell] == 0:
            return False  # error: no value is allowed at this cell (Rule 3)
        elif self.is_single_bit(self.mask[cell]):
            val = int(math.log2(self.mask[cell])) + 1
            for cell2 in self.peers[cell]:
                if not self.clear(cell2, val):
                    return False

        ## Rule 2: If all cells in the unit cannot be filled with d except for one cell2,
        # we can fill d in that cell2.
        for u in self.units[cell]:
            dplaces = [cell2 for cell2 in u if self.mask[cell2] & (1<<(d-1))]
            if len(dplaces) == 0:
                return False  # error: no place for this value (Rule 4)
            elif len(dplaces) == 1:
                # d can only be in one place in unit; assign it there
                if not self.set(dplaces[0], d):
                    return False
        return True

Test the code

puzzle = \
    [[8, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 0, 3, 6, 0, 0, 0, 0, 0],
     [0, 7, 0, 0, 9, 0, 2, 0, 0],
     [0, 5, 0, 0, 0, 7, 0, 0, 0],
     [0, 0, 0, 0, 4, 5, 7, 0, 0],
     [0, 0, 0, 1, 0, 0, 0, 3, 0],
     [0, 0, 1, 0, 0, 0, 0, 6, 8],
     [0, 0, 8, 5, 0, 0, 0, 1, 0],
     [0, 9, 0, 0, 0, 0, 4, 0, 0]]
solution = \
    [[8, 1, 2, 7, 5, 3, 6, 4, 9],
     [9, 4, 3, 6, 8, 2, 1, 7, 5],
     [6, 7, 5, 4, 9, 1, 2, 8, 3],
     [1, 5, 4, 2, 3, 7, 8, 9, 6],
     [3, 6, 9, 8, 4, 5, 7, 2, 1],
     [2, 8, 7, 1, 6, 9, 5, 3, 4],
     [5, 2, 1, 9, 7, 4, 3, 6, 8],
     [4, 3, 8, 5, 2, 6, 9, 1, 7],
     [7, 9, 6, 3, 1, 8, 4, 5, 2]]
assert(sudoku_solver(puzzle) == solution)

Result

The new code solves the "very hard puzzle" above in about 0.1s. Compared to the 45 seconds of the backtracking approach, this is a speedup of more than 400 times.
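
Timings like these depend heavily on the machine. To measure them yourself, a small helper based on time.perf_counter() is enough. The helper name time_call is made up for this example, and sum merely stands in for the solver:

```python
import time


def time_call(fn, *args):
    """Call fn(*args) and return (result, elapsed seconds)."""
    start = time.perf_counter()
    result = fn(*args)
    return result, time.perf_counter() - start


# Stand-in workload; replace sum and its argument with sudoku_solver and a puzzle.
result, elapsed = time_call(sum, range(1000))
print(result)  # 499500
print(f"{elapsed:.6f}s")
```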

Submitting the code on Codewars, we pass all the tests in less than 300ms. Yay!

Conclusion

By applying some optimizations on both the algorithm and the implementation, we can see a significant increase in performance of our sudoku solver program.

Notice that the rules added to the algorithm are copied from the way humans solve sudokus. Does that mean human brains have always been implemented with the most advanced algorithms?

There is still room for improvement, for example to handle even harder puzzles like the one below:

# 17 cells with thousands of solutions
puzzle = \
    [[0, 0, 0, 0, 0, 6, 0, 0, 0],
     [0, 5, 9, 0, 0, 0, 0, 0, 8],
     [2, 0, 0, 0, 0, 8, 0, 0, 0],
     [0, 4, 5, 0, 0, 0, 0, 0, 0],
     [0, 0, 3, 0, 0, 0, 0, 0, 0],
     [0, 0, 6, 0, 0, 3, 0, 5, 4],
     [0, 0, 0, 3, 2, 5, 0, 0, 6],
     [0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0, 0, 0, 0, 0]]

or this one:

# 17 cells with no solution
puzzle = \
    [[0, 0, 0, 0, 0, 5, 0, 8, 0],
     [0, 0, 0, 6, 0, 1, 0, 4, 3],
     [0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 1, 0, 5, 0, 0, 0, 0, 0],
     [0, 0, 0, 1, 0, 6, 0, 0, 0],
     [3, 0, 0, 0, 0, 0, 0, 0, 5],
     [5, 3, 0, 0, 0, 0, 0, 6, 1],
     [0, 0, 0, 0, 0, 0, 0, 0, 4],
     [0, 0, 0, 0, 0, 0, 0, 0, 0]]

For these two puzzles, the algorithm in this topic still takes too long to complete. Hopefully we can find better algorithms to solve these puzzles faster in the future. On the other hand, solving sudoku in parallel (e.g. with multiple threads or MapReduce) is also an interesting topic to discuss.
