ipdb> where
/home/weirdab/goalrelabel_locobot_fullgcsl/launch_main.py(286)<module>()
284 print("params before run", params)
285
--> 286 run(**params)
287 # dd_utils.launch(run, params, mode='local', instance_type='c4.xlarge')
288
/home/weirdab/goalrelabel_locobot_fullgcsl/launch_main.py(210)run()
208 **huge_kwargs
209 )
--> 210 algo.train()
211
212
/home/weirdab/goalrelabel_locobot_fullgcsl/huge/algo/huge.py(1511)train()
1509
1510 self.policy.eval()
-> 1511 self.evaluate_policy(self.eval_episodes, greedy=True, prefix='Eval')
1512 last_time = time.time()
1513 # End Evaluation Code
/home/weirdab/goalrelabel_locobot_fullgcsl/huge/algo/huge.py(1864)evaluate_policy()
1862 goal = self.env.extract_goal(self.env.sample_goal())
1863
-> 1864 states, actions, goal_state, _, _ , img_states= self.sample_trajectory(goal=goal, greedy=greedy, save_video_trajectory=index==0, video_filename=video_filename)
1865 all_actions.extend(actions)
1866 all_states.append(states)
/home/weirdab/goalrelabel_locobot_fullgcsl/huge/algo/huge.py(1109)sample_trajectory()
1107 action = self.policy.act_vectorized(observation_image[None], img_goal[None], horizon=horizon[None], greedy=True, noise=0)[0]
1108 else:
-> 1109 action = self.policy.act_vectorized(observation[None], goal[None], horizon=horizon[None], greedy=True, noise=0)[0]
1110 else:
1111 if self.use_images_in_policy:
/home/weirdab/goalrelabel_locobot_fullgcsl/huge/algo/networks.py(563)act_vectorized()
561 # if horizon is not None:
562 # horizon = torch.tensor(horizon, dtype=torch.float32)
--> 563 dist = self.forward(obs, goal, horizon)
564
565 samples = dist.sample()
/home/weirdab/goalrelabel_locobot_fullgcsl/huge/algo/networks.py(545)forward()
543 (obs, goal),
544 temperature=1.0,
--> 545 name="actor",
546 )
547 return dist
[... skipping 1 hidden frame(s)]
> /home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(1060)apply()
1058 method, self,
1059 mutable=mutable, capture_intermediates=capture_intermediates
-> 1060 )(variables, *args, **kwargs, rngs=rngs)
1061
1062 @traceback_util.api_boundary
/home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/core/scope.py(691)wrapper()
689 **kwargs) -> Union[Any, Tuple[Any, VariableDict]]:
690 with bind(variables, rngs=rngs, mutable=mutable).temporary() as root:
--> 691 y = fn(root, *args, **kwargs)
692 if mutable is not False:
693 return y, root.mutable_variables()
/home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(1312)scope_fn()
1310 _context.capture_stack.append(capture_intermediates)
1311 try:
-> 1312 return fn(module.clone(parent=scope), *args, **kwargs)
1313 finally:
1314 _context.capture_stack.pop()
/home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/transforms.py(602)wrapped_fn()
600 def wrapped_fn(self, *args, **kwargs):
601 if not force and not linen_module._use_named_call:
--> 602 return prewrapped_fn(self, *args, **kwargs)
603 fn_name = class_fn.__name__
604 method_suffix = f'.{fn_name}' if fn_name != '__call__' else ''
/home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(318)wrapped_module_method()
316 if args and isinstance(args[0], Module):
317 self, args = args[0], args[1:]
--> 318 return self._call_wrapped_method(fun, args, kwargs)
319 else:
320 return fun(*args, **kwargs)
/home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(602)_call_wrapped_method()
600 _context.module_stack.append(self)
601 try:
--> 602 y = fun(self, *args, **kwargs)
603 if _context.capture_stack:
604 filter_fn = _context.capture_stack[-1]
/home/weirdab/robonetv3/code/jaxrl_minimal/jaxrl_m/common/common.py(77)__call__()
75 return out
76
---> 77 return self.modules[name](*args, **kwargs)
78
79
/home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/transforms.py(602)wrapped_fn()
600 def wrapped_fn(self, *args, **kwargs):
601 if not force and not linen_module._use_named_call:
--> 602 return prewrapped_fn(self, *args, **kwargs)
603 fn_name = class_fn.__name__
604 method_suffix = f'.{fn_name}' if fn_name != '__call__' else ''
/home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(318)wrapped_module_method()
316 if args and isinstance(args[0], Module):
317 self, args = args[0], args[1:]
--> 318 return self._call_wrapped_method(fun, args, kwargs)
319 else:
320 return fun(*args, **kwargs)
/home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(602)_call_wrapped_method()
600 _context.module_stack.append(self)
601 try:
--> 602 y = fun(self, *args, **kwargs)
603 if _context.capture_stack:
604 filter_fn = _context.capture_stack[-1]
/home/weirdab/robonetv3/code/jaxrl_minimal/jaxrl_m/networks/actor_critic_nets.py(124)__call__()
122 self, observations: jnp.ndarray, temperature: float = 1.0, train: bool = False
123 ) -> distrax.Distribution:
--> 124 outputs = self.network(self.encoder(observations), train=train)
125
126 means = nn.Dense(self.action_dim, kernel_init=default_init())(outputs)
/home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/transforms.py(602)wrapped_fn()
600 def wrapped_fn(self, *args, **kwargs):
601 if not force and not linen_module._use_named_call:
--> 602 return prewrapped_fn(self, *args, **kwargs)
603 fn_name = class_fn.__name__
604 method_suffix = f'.{fn_name}' if fn_name != '__call__' else ''
/home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(318)wrapped_module_method()
316 if args and isinstance(args[0], Module):
317 self, args = args[0], args[1:]
--> 318 return self._call_wrapped_method(fun, args, kwargs)
319 else:
320 return fun(*args, **kwargs)
/home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(602)_call_wrapped_method()
600 _context.module_stack.append(self)
601 try:
--> 602 y = fun(self, *args, **kwargs)
603 if _context.capture_stack:
604 filter_fn = _context.capture_stack[-1]
/home/weirdab/robonetv3/code/jaxrl_minimal/jaxrl_m/common/encoding.py(75)__call__()
73 # early goal concat
74 encoder_inputs = jnp.concatenate([obs_image, goal_image], axis=-1)
---> 75 encoding = self.encoder(encoder_inputs)
76 else:
77 # late fusion
/home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/transforms.py(602)wrapped_fn()
600 def wrapped_fn(self, *args, **kwargs):
601 if not force and not linen_module._use_named_call:
--> 602 return prewrapped_fn(self, *args, **kwargs)
603 fn_name = class_fn.__name__
604 method_suffix = f'.{fn_name}' if fn_name != '__call__' else ''
/home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(318)wrapped_module_method()
316 if args and isinstance(args[0], Module):
317 self, args = args[0], args[1:]
--> 318 return self._call_wrapped_method(fun, args, kwargs)
319 else:
320 return fun(*args, **kwargs)
/home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(602)_call_wrapped_method()
600 _context.module_stack.append(self)
601 try:
--> 602 y = fun(self, *args, **kwargs)
603 if _context.capture_stack:
604 filter_fn = _context.capture_stack[-1]
/home/weirdab/robonetv3/code/jaxrl_minimal/jaxrl_m/vision/resnet_v1.py(244)__call__()
242 )(x)
243
--> 244 x = norm(name="norm_init")(x)
245 x = act(x)
246 x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
/home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/transforms.py(602)wrapped_fn()
600 def wrapped_fn(self, *args, **kwargs):
601 if not force and not linen_module._use_named_call:
--> 602 return prewrapped_fn(self, *args, **kwargs)
603 fn_name = class_fn.__name__
604 method_suffix = f'.{fn_name}' if fn_name != '__call__' else ''
/home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(318)wrapped_module_method()
316 if args and isinstance(args[0], Module):
317 self, args = args[0], args[1:]
--> 318 return self._call_wrapped_method(fun, args, kwargs)
319 else:
320 return fun(*args, **kwargs)
/home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(602)_call_wrapped_method()
600 _context.module_stack.append(self)
601 try:
--> 602 y = fun(self, *args, **kwargs)
603 if _context.capture_stack:
604 filter_fn = _context.capture_stack[-1]
/home/weirdab/robonetv3/code/jaxrl_minimal/jaxrl_m/vision/resnet_v1.py(124)__call__()
122 if x.ndim == 3:
123 x = x[jnp.newaxis]
--> 124 x = super().__call__(x)
125 return x[0]
126 else:
/home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/transforms.py(602)wrapped_fn()
600 def wrapped_fn(self, *args, **kwargs):
601 if not force and not linen_module._use_named_call:
--> 602 return prewrapped_fn(self, *args, **kwargs)
603 fn_name = class_fn.__name__
604 method_suffix = f'.{fn_name}' if fn_name != '__call__' else ''
/home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(318)wrapped_module_method()
316 if args and isinstance(args[0], Module):
317 self, args = args[0], args[1:]
--> 318 return self._call_wrapped_method(fun, args, kwargs)
319 else:
320 return fun(*args, **kwargs)
/home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(602)_call_wrapped_method()
600 _context.module_stack.append(self)
601 try:
--> 602 y = fun(self, *args, **kwargs)
603 if _context.capture_stack:
604 filter_fn = _context.capture_stack[-1]
/home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/normalization.py(305)__call__()
303 feature_shape = tuple([1 for d in input_shape[:-1]] + [input_shape[-1]])
304 if self.use_scale:
--> 305 x = x * self.param('scale', self.scale_init, feature_shape)
306 if self.use_bias:
307 x = x + self.param('bias', self.bias_init, feature_shape)
/home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/linen/module.py(895)param()
893 if self._name_taken(name):
894 raise errors.NameInUseError('param', name, self.__class__.__name__)
--> 895 v = self.scope.param(name, init_fn, *init_args)
896 self._state.children[name] = 'params'
897 return v
/home/weirdab/miniconda3/envs/huge/lib/python3.6/site-packages/flax/core/scope.py(624)param()
622 if jnp.shape(val) != jnp.shape(abs_val):
623 raise errors.ScopeParamShapeError(name, self.path_text,
--> 624 jnp.shape(val), jnp.shape(abs_val))
625 else:
626 if not self.is_mutable_collection('params'):