# Given an m x n matrix, return all elements of the matrix in spiral order.

class Solution:
    """Container for the spiral-order matrix traversal."""

    def spiralOrder(self, matrix: List[List[int]]) -> List[int]:
        """Return the elements of ``matrix`` in clockwise spiral order.

        The walk starts at the top-left cell, moves right along the top
        row, down the right column, left along the bottom row and up the
        left column, then repeats on the next inner layer until every
        cell has been visited.

        Args:
            matrix: A rectangular grid given as a list of equal-length
                rows. An empty list, or a list of empty rows, is allowed.

        Returns:
            A new list with every element of ``matrix`` in spiral order;
            ``[]`` when the matrix has no elements.
        """
        # Guard against `[]` (no rows) and `[[]]` (rows with no columns);
        # indexing matrix[0] on an empty matrix would raise IndexError.
        if not matrix or not matrix[0]:
            return []

        top, bottom = 0, len(matrix) - 1
        left, right = 0, len(matrix[0]) - 1
        ans = []

        # Peel one rectangular layer per iteration, shrinking the bounds.
        while top <= bottom and left <= right:
            # Top edge, left -> right (includes both top corners).
            ans.extend(matrix[top][left:right + 1])
            # Right edge, below the top-right corner down to the bottom.
            ans.extend(matrix[row][right] for row in range(top + 1, bottom + 1))

            # The return edges exist only when the layer has at least two
            # rows and two columns; otherwise they would revisit cells
            # already emitted by the top or right edge.
            if top < bottom and left < right:
                # Bottom edge, right -> left (excludes bottom-right corner).
                ans.extend(
                    matrix[bottom][col] for col in range(right - 1, left - 1, -1)
                )
                # Left edge, bottom -> top (excludes both left corners).
                ans.extend(matrix[row][left] for row in range(bottom - 1, top, -1))

            top += 1
            bottom -= 1
            left += 1
            right -= 1

        return ans