diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 0c7c9f7a1090d..dd1fd3dff5df2 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -549,7 +549,7 @@ def _aggregate_total_loss(self): agg_loss /= self.dp_world_size assert self.global_rank in self.grid.pp_group - losses = torch.Tensor([self.dp_group_loss, agg_loss]).to(self.device) + losses = torch.stack([self.dp_group_loss, agg_loss]) if self.is_pipe_parallel: dist.broadcast(tensor=losses, src=self.global_rank, group=self.mpu.get_pipe_parallel_group()) else: