Untitled
unknown
plain_text
a year ago
12 kB
7
Indexable
Never
ipdb> where /home/weirdab/goalrelabel_locobot_fullgcsl/launch_main.py(286)<module>() 284 print("params before run", params) 285 --> 286 run(**params) 287 # dd_utils.launch(run, params, mode='local', instance_type='c4.xlarge') 288 /home/weirdab/goalrelabel_locobot_fullgcsl/launch_main.py(210)run() 208 **huge_kwargs 209 ) --> 210 algo.train() 211 212 /home/weirdab/goalrelabel_locobot_fullgcsl/huge/algo/huge.py(1511)train() 1509 1510 self.policy.eval() -> 1511 self.evaluate_policy(self.eval_episodes, greedy=True, prefix='Eval') 1512 last_time = time.time() 1513 # End Evaluation Code /home/weirdab/goalrelabel_locobot_fullgcsl/huge/algo/huge.py(1864)evaluate_policy() 1862 goal = self.env.extract_goal(self.env.sample_goal()) 1863 -> 1864 states, actions, goal_state, _, _ , img_states= self.sample_trajectory(goal=goal, greedy=greedy, save_video_trajectory=index==0, video_filename=video_filename) 1865 all_actions.extend(actions) 1866 all_states.append(states) /home/weirdab/goalrelabel_locobot_fullgcsl/huge/algo/huge.py(1109)sample_trajectory() 1107 action = self.policy.act_vectorized(observation_image[None], img_goal[None], horizon=horizon[None], greedy=True, noise=0)[0] 1108 else: -> 1109 action = self.policy.act_vectorized(observation[None], goal[None], horizon=horizon[None], greedy=True, noise=0)[0] 1110 else: 1111 if self.use_images_in_policy: /home/weirdab/goalrelabel_locobot_fullgcsl/huge/algo/networks.py(563)act_vectorized() 561 # if horizon is not None: 562 # horizon = torch.tensor(horizon, dtype=torch.float32) --> 563 dist = self.forward(obs, goal, horizon) 564 565 samples = dist.sample() /home/weirdab/goalrelabel_locobot_fullgcsl/huge/algo/networks.py(545)forward() 543 (obs, goal), 544 temperature=1.0, --> 545 name="actor", 546 ) 547 return dist [... skipping 1 hidden frame(s)] > /home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(1060)apply() 1058 method, self, 1059 mutable=mutable, capture_intermediates=capture_intermediates -> 1060 )(variables, *args, **kwargs, rngs=rngs) 1061 1062 @traceback_util.api_boundary /home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/core/scope.py(691)wrapper() 689 **kwargs) -> Union[Any, Tuple[Any, VariableDict]]: 690 with bind(variables, rngs=rngs, mutable=mutable).temporary() as root: --> 691 y = fn(root, *args, **kwargs) 692 if mutable is not False: 693 return y, root.mutable_variables() /home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(1312)scope_fn() 1310 _context.capture_stack.append(capture_intermediates) 1311 try: -> 1312 return fn(module.clone(parent=scope), *args, **kwargs) 1313 finally: 1314 _context.capture_stack.pop() /home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/transforms.py(602)wrapped_fn() 600 def wrapped_fn(self, *args, **kwargs): 601 if not force and not linen_module._use_named_call: --> 602 return prewrapped_fn(self, *args, **kwargs) 603 fn_name = class_fn.__name__ 604 method_suffix = f'.{fn_name}' if fn_name != '__call__' else '' /home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(318)wrapped_module_method() 316 if args and isinstance(args[0], Module): 317 self, args = args[0], args[1:] --> 318 return self._call_wrapped_method(fun, args, kwargs) 319 else: 320 return fun(*args, **kwargs) /home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(602)_call_wrapped_method() 600 _context.module_stack.append(self) 601 try: --> 602 y = fun(self, *args, **kwargs) 603 if _context.capture_stack: 604 filter_fn = _context.capture_stack[-1] /home/weirdab/robonetv3/code/jaxrl_minimal/jaxrl_m/common/common.py(77)__call__() 75 return out 76 ---> 77 return self.modules[name](*args, **kwargs) 78 79 /home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/transforms.py(602)wrapped_fn() 600 def wrapped_fn(self, *args, **kwargs): 601 if not force and not linen_module._use_named_call: --> 602 return prewrapped_fn(self, *args, **kwargs) 603 fn_name = class_fn.__name__ 604 method_suffix = f'.{fn_name}' if fn_name != '__call__' else '' /home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(318)wrapped_module_method() 316 if args and isinstance(args[0], Module): 317 self, args = args[0], args[1:] --> 318 return self._call_wrapped_method(fun, args, kwargs) 319 else: 320 return fun(*args, **kwargs) /home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(602)_call_wrapped_method() 600 _context.module_stack.append(self) 601 try: --> 602 y = fun(self, *args, **kwargs) 603 if _context.capture_stack: 604 filter_fn = _context.capture_stack[-1] /home/weirdab/robonetv3/code/jaxrl_minimal/jaxrl_m/networks/actor_critic_nets.py(124)__call__() 122 self, observations: jnp.ndarray, temperature: float = 1.0, train: bool = False 123 ) -> distrax.Distribution: --> 124 outputs = self.network(self.encoder(observations), train=train) 125 126 means = nn.Dense(self.action_dim, kernel_init=default_init())(outputs) /home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/transforms.py(602)wrapped_fn() 600 def wrapped_fn(self, *args, **kwargs): 601 if not force and not linen_module._use_named_call: --> 602 return prewrapped_fn(self, *args, **kwargs) 603 fn_name = class_fn.__name__ 604 method_suffix = f'.{fn_name}' if fn_name != '__call__' else '' /home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(318)wrapped_module_method() 316 if args and isinstance(args[0], Module): 317 self, args = args[0], args[1:] --> 318 return self._call_wrapped_method(fun, args, kwargs) 319 else: 320 return fun(*args, **kwargs) /home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(602)_call_wrapped_method() 600 _context.module_stack.append(self) 601 try: --> 602 y = fun(self, *args, **kwargs) 603 if _context.capture_stack: 604 filter_fn = _context.capture_stack[-1] /home/weirdab/robonetv3/code/jaxrl_minimal/jaxrl_m/common/encoding.py(75)__call__() 73 # early goal concat 74 encoder_inputs = jnp.concatenate([obs_image, goal_image], axis=-1) ---> 75 encoding = self.encoder(encoder_inputs) 76 else: 77 # late fusion /home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/transforms.py(602)wrapped_fn() 600 def wrapped_fn(self, *args, **kwargs): 601 if not force and not linen_module._use_named_call: --> 602 return prewrapped_fn(self, *args, **kwargs) 603 fn_name = class_fn.__name__ 604 method_suffix = f'.{fn_name}' if fn_name != '__call__' else '' /home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(318)wrapped_module_method() 316 if args and isinstance(args[0], Module): 317 self, args = args[0], args[1:] --> 318 return self._call_wrapped_method(fun, args, kwargs) 319 else: 320 return fun(*args, **kwargs) /home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(602)_call_wrapped_method() 600 _context.module_stack.append(self) 601 try: --> 602 y = fun(self, *args, **kwargs) 603 if _context.capture_stack: 604 filter_fn = _context.capture_stack[-1] /home/weirdab/robonetv3/code/jaxrl_minimal/jaxrl_m/vision/resnet_v1.py(244)__call__() 242 )(x) 243 --> 244 x = norm(name="norm_init")(x) 245 x = act(x) 246 x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME") /home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/transforms.py(602)wrapped_fn() 600 def wrapped_fn(self, *args, **kwargs): 601 if not force and not linen_module._use_named_call: --> 602 return prewrapped_fn(self, *args, **kwargs) 603 fn_name = class_fn.__name__ 604 method_suffix = f'.{fn_name}' if fn_name != '__call__' else '' /home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(318)wrapped_module_method() 316 if args and isinstance(args[0], Module): 317 self, args = args[0], args[1:] --> 318 return self._call_wrapped_method(fun, args, kwargs) 319 else: 320 return fun(*args, **kwargs) /home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(602)_call_wrapped_method() 600 _context.module_stack.append(self) 601 try: --> 602 y = fun(self, *args, **kwargs) 603 if _context.capture_stack: 604 filter_fn = _context.capture_stack[-1] /home/weirdab/robonetv3/code/jaxrl_minimal/jaxrl_m/vision/resnet_v1.py(124)__call__() 122 if x.ndim == 3: 123 x = x[jnp.newaxis] --> 124 x = super().__call__(x) 125 return x[0] 126 else: /home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/transforms.py(602)wrapped_fn() 600 def wrapped_fn(self, *args, **kwargs): 601 if not force and not linen_module._use_named_call: --> 602 return prewrapped_fn(self, *args, **kwargs) 603 fn_name = class_fn.__name__ 604 method_suffix = f'.{fn_name}' if fn_name != '__call__' else '' /home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(318)wrapped_module_method() 316 if args and isinstance(args[0], Module): 317 self, args = args[0], args[1:] --> 318 return self._call_wrapped_method(fun, args, kwargs) 319 else: 320 return fun(*args, **kwargs) /home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(602)_call_wrapped_method() 600 _context.module_stack.append(self) 601 try: --> 602 y = fun(self, *args, **kwargs) 603 if _context.capture_stack: 604 filter_fn = _context.capture_stack[-1] /home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/normalization.py(305)__call__() 303 feature_shape = tuple([1 for d in input_shape[:-1]] + [input_shape[-1]]) 304 if self.use_scale: --> 305 x = x * self.param('scale', self.scale_init, feature_shape) 306 if self.use_bias: 307 x = x + self.param('bias', self.bias_init, feature_shape) /home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(895)param() 893 if self._name_taken(name): 894 raise errors.NameInUseError('param', name, self.__class__.__name__) --> 895 v = self.scope.param(name, init_fn, *init_args) 896 self._state.children[name] = 'params' 897 return v /home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/core/scope.py(624)param() 622 if jnp.shape(val) != jnp.shape(abs_val): 623 raise errors.ScopeParamShapeError(name, self.path_text, --> 624 jnp.shape(val), jnp.shape(abs_val)) 625 else: 626 if not self.is_mutable_collection('params'):