# NPU, so reuse it directly; the device path keeps it for Torch parity.ifself.offload_to_cpu:local_view=local_tensorelse:with_no_grad():local_view=local_tensor.narrow(dim=shard_dim,start=0,length=length)set_requires_grad_if_needed(self.sharded_param,local_view)self.sharded_param._local_tensor=local_viewifnotself.sharded_param._local_tensor.is_contiguous():raiseAssertionError(
667668669670671672673674675676
# wait here or the collective reads a half-written shard (verified racy).ms.runtime.synchronize()updated_local_tensor=Trueifself.pin_memoryandnotlocal_tensor.is_pinned():local_tensor=local_tensor.pin_memory()updated_local_tensor=Trueifself.offload_to_cpu:# Offloaded shards stay on host (``view(-1)`` would relocate them to# NPU); all_gather_inputs stages the device copy per unshard. Rebind# unconditionally: the to("cpu")/pin above can produce a new host
676677678679680681682683684685
# unconditionally: the to("cpu")/pin above can produce a new host# tensor even when same_local_tensor was True, leaving the old NPU# storage referenced here.self._sharded_param_data=local_tensorelifnotsame_local_tensor:self._sharded_param_data=local_tensor.view(-1)returnlocal_tensor,updated_local_tensordef_get_unsharded_param_data(self,async_op:bool=False)->Tuple[ms.Tensor,Optional[CommHandle]]:"""