diff --git a/legged_gym/envs/base/legged_robot.py b/legged_gym/envs/base/legged_robot.py index 98ae04d..00587b3 100644 --- a/legged_gym/envs/base/legged_robot.py +++ b/legged_gym/envs/base/legged_robot.py @@ -105,6 +105,10 @@ class LeggedRobot(BaseTask): self.compute_reward() env_ids = self.reset_buf.nonzero(as_tuple=False).flatten() self.reset_idx(env_ids) + + if self.cfg.domain_rand.push_robots: + self._push_robots() + self.compute_observations() # in some cases a simulation step might be required to refresh some obs (for example body positions) self.last_actions[:] = self.actions[:] @@ -139,6 +143,7 @@ class LeggedRobot(BaseTask): self._resample_commands(env_ids) # reset buffers + self.actions[env_ids] = 0. self.last_actions[env_ids] = 0. self.last_dof_vel[env_ids] = 0. self.feet_air_time[env_ids] = 0. @@ -284,9 +289,6 @@ class LeggedRobot(BaseTask): heading = torch.atan2(forward[:, 1], forward[:, 0]) self.commands[:, 2] = torch.clip(0.5*wrap_to_pi(self.commands[:, 3] - heading), -1., 1.) - if self.cfg.domain_rand.push_robots and (self.common_step_counter % self.cfg.domain_rand.push_interval == 0): - self._push_robots() - def _resample_commands(self, env_ids): """ Randommly select commands of some environments @@ -367,9 +369,17 @@ class LeggedRobot(BaseTask): def _push_robots(self): """ Random pushes the robots. Emulates an impulse by setting a randomized base velocity. """ + env_ids = torch.arange(self.num_envs, device=self.device) + push_env_ids = env_ids[self.episode_length_buf[env_ids] % int(self.cfg.domain_rand.push_interval) == 0] + if len(push_env_ids) == 0: + return max_vel = self.cfg.domain_rand.max_push_vel_xy self.root_states[:, 7:9] = torch_rand_float(-max_vel, max_vel, (self.num_envs, 2), device=self.device) # lin vel x/y - self.gym.set_actor_root_state_tensor(self.sim, gymtorch.unwrap_tensor(self.root_states)) + + env_ids_int32 = push_env_ids.to(dtype=torch.int32) + self.gym.set_actor_root_state_tensor_indexed(self.sim, + gymtorch.unwrap_tensor(self.root_states), + gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32))