# The sharded parameter must already be a DTensor, because the code below
# may replace its private `_local_tensor`. An explicit raise (unlike
# `assert`) is not stripped when Python runs with -O.
if not isinstance(self.sharded_param, DTensor):
    raise AssertionError(f"Expected DTensor, got {type(self.sharded_param)}")
if updated_local_tensor:
    # Only change the local tensor object if needed
    # NOTE(review): _no_grad() presumably disables autograd recording so the
    # narrow/re-assignment below is not tracked -- confirm its semantics.
    with _no_grad():
        # View of the first `length` entries of `local_tensor` along
        # `shard_dim`. Tensor.narrow returns a view sharing storage with its
        # input rather than a copy (assuming local_tensor is a torch.Tensor).
        local_view = local_tensor.narrow(dim=shard_dim, start=0, length=length)
        # Presumably mirrors sharded_param's requires_grad onto the new view
        # -- TODO confirm against set_requires_grad_if_needed.
        set_requires_grad_if_needed(self.sharded_param, local_view)
        # Replace the DTensor's private local shard with the new view in place.
        self.sharded_param._local_tensor = local_view
    # A narrow() view is not guaranteed to be contiguous (e.g. narrowing a
    # non-leading dim, or narrowing a non-contiguous input), so reject that
    # case explicitly.
    if not self.sharded_param._local_tensor.is_contiguous():
        raise AssertionError(
            "Expected sharded_param._local_tensor to be contiguous"
        )