Update helpers.py

This commit is contained in:
Ziluo Ding 2024-07-01 16:26:40 +08:00
parent e9107dfb7c
commit 6f298b4435

View File

@ -157,8 +157,7 @@ def export_policy_as_jit(actor_critic, path):
path = os.path.join(path, 'policy_1.pt') path = os.path.join(path, 'policy_1.pt')
model = copy.deepcopy(actor_critic.actor).to('cpu') model = copy.deepcopy(actor_critic.actor).to('cpu')
traced_script_module = torch.jit.script(model) traced_script_module = torch.jit.script(model)
traced_script_module.save(path) traced_script_module.save(path)
class PolicyExporterLSTM(torch.nn.Module): class PolicyExporterLSTM(torch.nn.Module):
def __init__(self, actor_critic): def __init__(self, actor_critic):