LeetCode 957. Prison Cells After N Days

957. Prison Cells After N Days

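Observation: From day 1 onward, cells[0] and cells[7] are always 0, because the update rule never sets the two end cells. Only the six interior cells can vary, so the number of possible states drops from 2^8 = 256 to 2^6 = 64, which saves both time and space.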

Reference Solution


class Solution:
    """Solve "Prison Cells After N Days" via precomputed state cycles.

    From day 1 onward the two end cells are always 0, so a row is fully
    described by its interior cells, packed into an integer.  ``init``
    walks every such state and records which cycle it belongs to, so a
    query only has to simulate a single day and then index into a cycle.
    """

    # Number of cells in the row handled by ``prisonAfterNDays``.
    _BINS = 8

    def prisonAfterNDays(self, cells: List[int], N: int) -> List[int]:
        """Return the state of the row after ``N`` days.

        Args:
            cells: The initial row, one 0/1 entry per cell (8 cells).
            N: Number of days to advance; expected to be non-negative.

        Returns:
            A new list holding the row after ``N`` days.  The input list
            is never modified.
        """
        # Zero days means "no change".  The cycle lookup below starts
        # from day 1, so it cannot express day 0 (the input may have
        # occupied end cells, which no post-day-1 state ever has).
        if N == 0:
            return list(cells)

        # Build the lookup tables once per instance instead of per call.
        if getattr(self, "bins", None) != self._BINS or not hasattr(self, "loopGroup"):
            self.init(bins=self._BINS)

        # Day 1 always lands on a state with empty end cells, i.e. one
        # of the states indexed by the tables.
        firstDay = self.GetNextDay(cells)
        groupIndex, loopIndex = self.posDict[self.ConvertListToInt(firstDay)]

        # One day has already been consumed, so advance N - 1 further
        # steps around the cycle.
        cycle = self.loopGroup[groupIndex]
        endDay = (loopIndex + N - 1) % len(cycle)
        return self.ConvertIntToList(cycle[endDay])

    def ConvertListToInt(self, cells):
        """Pack the interior cells (indices 1..bins-2) into an integer.

        Cell 1 becomes the least significant bit.  The two end cells are
        ignored.
        """
        interior = cells[1:self.bins - 1]
        return sum(digit << shift for shift, digit in enumerate(interior))

    def ConvertIntToList(self, num):
        """Unpack ``num`` into a row of ``bins`` cells with empty ends.

        Inverse of ``ConvertListToInt``: bit ``i`` of ``num`` goes to
        cell ``i + 1``.
        """
        result = [0] * self.bins
        for shift in range(self.bins - 2):
            result[shift + 1] = (num >> shift) & 1
        return result

    def GetNextDay(self, origin):
        """Return the row that follows ``origin`` after one day.

        An interior cell becomes 1 when its two neighbours are equal and
        0 otherwise.  The two end cells are always 0.
        """
        # The end cells are never written, so they keep their initial 0
        # for any width (no hard-coded index needed).
        nextDay = [0] * self.bins
        for j in range(1, self.bins - 1):
            nextDay[j] = int(origin[j - 1] == origin[j + 1])
        return nextDay

    def init(self, bins):
        """Build the state-cycle lookup tables for rows of ``bins`` cells.

        Sets three attributes:
            bins: the row width.
            posDict: maps a packed state to ``(groupIndex, loopIndex)``,
                where ``groupIndex`` is the smallest-numbered start state
                of its group and ``loopIndex`` is its position in it.
            loopGroup: maps ``groupIndex`` to the list of packed states
                in visiting order.

        For ``bins=8`` the 64 states split into groups of length 1, 7
        and 14; the tables could be hard-coded to skip this work.

        NOTE(review): the grouping treats each walk as a closed cycle,
        which assumes every state lies on a cycle.  That holds for
        ``bins=8``; other widths have not been verified.
        """
        self.bins = bins
        stateCount = 2 ** (self.bins - 2)

        # Simulate one day for every packed state.
        transitMatrix = {
            state: self.ConvertListToInt(
                self.GetNextDay(self.ConvertIntToList(state))
            )
            for state in range(stateCount)
        }

        self.posDict = {}
        self.loopGroup = {}
        for startNum in range(stateCount):
            currentNum = startNum
            # Follow the transitions until reaching a state that has
            # already been assigned to a group.
            while currentNum not in self.posDict:
                group = self.loopGroup.setdefault(startNum, [])
                self.posDict[currentNum] = (startNum, len(group))
                group.append(currentNum)
                currentNum = transitMatrix[currentNum]