Returns:
tuple: (local_args, local_kwargs, cache_values) where local_args contains
local tensors for x and w; cache_values contains [x_layout, w_layout, transpose_a, transpose_b].
"""
args, kwargs = _normalize_matmul_args(*args, **kwargs)
x_tensor, w_tensor, transpose_a, transpose_b = args
local_args = (x_tensor.to_local(), w_tensor.to_local(), transpose_a, transpose_b)
local_kwargs = {}
cache_values = [x_tensor.layout, w_tensor.layout, transpose_a, transpose_b]
return local_args, local_kwargs, cache_values
def infer_layout(self, cache_values: list) -> Tuple[tuple, None]:
"""
Infer output layout for MatMul operator (output = x @ w, with possible transpose).