11 changes: 9 additions & 2 deletions octo/data/dataset.py
@@ -132,7 +132,9 @@ def apply_trajectory_transforms(
num_parallel_calls,
)

# chunks observations and actions
# chunks observations and actions, giving them a new axis at index 1 of size `window_size` and
# `window_size + future_action_window_size`, respectively
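# (e.g., with window_size=2 and future_action_window_size=3, actions get a chunk axis of
# size 5 while observations get one of size 2)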

dataset = dataset.traj_map(
partial(
traj_transforms.chunk_act_obs,
@@ -391,7 +393,9 @@ def is_nonzero_length(traj):
full_dataset = full_dataset.filter(ModuleSpec.instantiate(filter_fcn_spec))
if ignore_errors:
full_dataset = full_dataset.ignore_errors()

full_dataset = full_dataset.traj_map(restructure).filter(is_nonzero_length)

# tries to load from cache, otherwise computes on the fly
dataset_statistics = get_dataset_statistics(
full_dataset,
@@ -454,13 +458,13 @@ def is_nonzero_length(traj):

return dataset, dataset_statistics


def make_single_dataset(
dataset_kwargs: dict,
*,
train: bool,
traj_transform_kwargs: dict = {},
frame_transform_kwargs: dict = {},
user_modify_traj: Optional[Callable] = None,
) -> dl.DLataset:
"""Creates a single dataset from kwargs. Returns a dataset of trajectories.

@@ -474,6 +478,9 @@ def make_single_dataset(
**dataset_kwargs,
train=train,
)

# optionally let the caller edit raw trajectories before the standard transforms
if user_modify_traj is not None:
dataset = dataset.traj_map(user_modify_traj)

dataset = apply_trajectory_transforms(dataset, **traj_transform_kwargs, train=train)
dataset = apply_frame_transforms(dataset, **frame_transform_kwargs, train=train)
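For example, `user_modify_traj` lets callers edit raw trajectories before the standard transforms run. A minimal sketch — the `drop_proprio` hook is hypothetical, and it assumes trajectories are dicts with an "observation" sub-dict, as elsewhere in `octo/data`:

def drop_proprio(traj: dict) -> dict:
    # hypothetical user hook: train from images only by removing proprioception
    traj["observation"].pop("proprio", None)
    return traj

dataset = make_single_dataset(
    dataset_kwargs,
    train=True,
    user_modify_traj=drop_proprio,
)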

1 change: 1 addition & 0 deletions octo/data/utils/data_utils.py
@@ -257,6 +257,7 @@ def normalize_action_and_proprio(
mask = metadata[key].get(
"mask", tf.ones_like(metadata[key]["mean"], dtype=tf.bool)
)

traj = dl.transforms.selective_tree_map(
traj,
match=lambda k, _: k == traj_key,
4 changes: 4 additions & 0 deletions octo/model/components/action_heads.py
@@ -244,6 +244,8 @@ class DiscreteActionHead(nn.Module, ActionHead):
action_dim: int = 7
vocab_size: int = 256
normalization_type: str = "uniform"
low: Optional[float] = None
high: Optional[float] = None

def setup(self):
total_output = self.action_horizon * self.action_dim * self.vocab_size
@@ -267,6 +269,8 @@ def setup(self):
self.action_tokenizer = BinTokenizer(
n_bins=self.vocab_size,
bin_type=self.normalization_type,
low=self.low,
high=self.high
)

def __call__(
7 changes: 5 additions & 2 deletions octo/model/components/tokenizers.py
@@ -244,11 +244,14 @@ class BinTokenizer(nn.Module):

n_bins: int = 256
bin_type: str = "uniform"
low: float = 0
high: float = 1
low: Optional[float] = None
high: Optional[float] = None

def setup(self):
if self.bin_type == "uniform":
if self.low is None or self.high is None:
raise ValueError("low and high must be provided when bin_type is 'uniform'")

self.thresholds = jnp.linspace(self.low, self.high, self.n_bins + 1)
elif self.bin_type == "normal":
self.thresholds = norm.ppf(jnp.linspace(EPS, 1 - EPS, self.n_bins + 1))
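With this change, uniform binning fails fast instead of silently assuming bounds of [0, 1]. A minimal usage sketch, assuming `BinTokenizer` holds no learned parameters (its thresholds are computed in `setup`), so an empty variables dict suffices for `apply`:

import jax.numpy as jnp

from octo.model.components.tokenizers import BinTokenizer

# uniform bins now require explicit bounds
tokenizer = BinTokenizer(n_bins=256, bin_type="uniform", low=-1.0, high=1.0)
actions = jnp.array([[-0.5, 0.0, 0.9]])
tokens = tokenizer.apply({}, actions)  # values -> bin indices
recovered = tokenizer.apply({}, tokens, method=BinTokenizer.decode)  # indices back to approximate values

# "normal" binning is unchanged and derives thresholds from the Gaussian ppf
gaussian_tokenizer = BinTokenizer(n_bins=256, bin_type="normal")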
158 changes: 158 additions & 0 deletions octo/model/octo_model.py
@@ -250,6 +250,162 @@ def sample_actions(
else:
raise ValueError(f"Unknown normalization type: {normalization_type}")
return action

@partial(jax.jit, static_argnames=("train", "sample_shape", "argmax", "beam"))
def sample_future_actions(
self,
observations: Data,
tasks: Data,
unnormalization_statistics: Optional[Data] = None,
normalization_type: NormalizationType = NormalizationType.NORMAL,
beam: int = 1,
timestep_pad_mask: Optional[ArrayLike] = None,
train: bool = False,
argmax: bool = False,
sample_shape: Tuple[int, ...] = (),
rng: Optional[PRNGKey] = None,
temperature: float = 1.0,
):
"""Samples actions from the model. See `action_heads.py` for more info.

Args:
observations: dictionary of arrays of shape (batch_size, window_size, *)
tasks: dict of tasks of shape (batch_size, *)
unnormalization_statistics: dict of statistics for unnormalizing actions (must contain "mean",
"std", and optionally "mask")
normalization_type: type of normalization applied to the actions
timestep_pad_mask: (batch_size, window_size) Boolean mask that is False when the timestep corresponds to padding
train: whether to run in train mode
...see `action_heads.py` for the rest of the kwargs.
Returns:
actions: (*sample_shape, batch_size, action_horizon, action_dim)
"""
if timestep_pad_mask is None:
timestep_pad_mask = observations["pad_mask"]

transformer_outputs = self.run_transformer(
observations, tasks, timestep_pad_mask, train=train
)
action_head = self.module.bind({"params": self.params}).heads[
"action"
]

# logits for the last timestep in the window
action_logits = action_head(transformer_outputs, train=train)[:, -1]

# per-token probabilities over the bin vocabulary
action_distribution = jax.nn.softmax(action_logits, axis=-1)

# keep the `beam` most likely tokens per action dimension; argsort is ascending,
# so the last `beam` entries have the highest probability
action_tokens = jnp.argsort(action_distribution, axis=-1)[..., -beam:].astype(jnp.int32)
confidence = jnp.take_along_axis(action_distribution, action_tokens, axis=-1)

action_tokens = jnp.broadcast_to(
action_tokens, sample_shape + action_tokens.shape
)

action = action_head.action_tokenizer.decode(action_tokens)

if unnormalization_statistics is not None:
if normalization_type == NormalizationType.NORMAL:
mask = unnormalization_statistics.get(
"mask",
jnp.ones_like(unnormalization_statistics["mean"], dtype=bool),
)
action = action[..., : len(mask)]
action = jnp.where(
mask,
(action * unnormalization_statistics["std"])
+ unnormalization_statistics["mean"],
action,
)
elif normalization_type == NormalizationType.BOUNDS:
mask = unnormalization_statistics.get(
"mask", jnp.ones_like(unnormalization_statistics["p01"], dtype=bool)
)
action = action[..., : len(mask)]
action = jnp.where(
mask,
(action + 1)
* (
unnormalization_statistics["p99"]
- unnormalization_statistics["p01"]
)
/ 2
+ unnormalization_statistics["p01"],
action,
)
else:
raise ValueError(f"Unknown normalization type: {normalization_type}")

return action, confidence
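A usage sketch for the new API — the checkpoint path and the observation/task construction are placeholders, and the dataset-statistics layout is an assumption:

model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small")  # placeholder checkpoint
actions, confidence = model.sample_future_actions(
    observations,  # dict of (batch_size, window_size, *) arrays, incl. "pad_mask"
    tasks,
    unnormalization_statistics=model.dataset_statistics["action"],  # assumed layout
    beam=5,  # keep the 5 most likely tokens per action dimension
)
# `confidence` holds each token's softmax probability, e.g. for falling back to a
# safe default when the model is unsure:
is_confident = confidence.max(axis=-1) > 0.5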

@partial(jax.jit, static_argnames=("train", "sample_shape", "beam"))
def sample_trajectory(
self,
observations: Data,
next_action,
tasks: Data,
unnormalization_statistics: Optional[Data] = None,
normalization_type: NormalizationType = NormalizationType.NORMAL,
beam: int = 1,
timestep_pad_mask: Optional[ArrayLike] = None,
train: bool = False,
argmax: bool = False,
sample_shape: Tuple[int, ...] = (),
rng: Optional[PRNGKey] = None,
temperature: float = 1.0,
):
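"""Predicts a trajectory of actions from a dedicated "trajectory" head.

Mirrors `sample_actions`, but reads from `self.module.heads["trajectory"]`,
which is assumed to expose `predict_action` like the heads in `action_heads.py`.
Returns unnormalized actions when `unnormalization_statistics` is provided.
"""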
if timestep_pad_mask is None:
timestep_pad_mask = observations["pad_mask"]

transformer_outputs = self.run_transformer(
observations, tasks, timestep_pad_mask, train=train
)

trajectory_head = self.module.bind({"params": self.params}).heads[
"trajectory"
]

action = trajectory_head.predict_action(
transformer_outputs,
train=train,
argmax=argmax,
sample_shape=sample_shape,
rng=rng,
temperature=temperature,
)

if unnormalization_statistics is not None:
if normalization_type == NormalizationType.NORMAL:
mask = unnormalization_statistics.get(
"mask",
jnp.ones_like(unnormalization_statistics["mean"], dtype=bool),
)
action = action[..., : len(mask)]
action = jnp.where(
mask,
(action * unnormalization_statistics["std"])
+ unnormalization_statistics["mean"],
action,
)
elif normalization_type == NormalizationType.BOUNDS:
mask = unnormalization_statistics.get(
"mask", jnp.ones_like(unnormalization_statistics["p01"], dtype=bool)
)
action = action[..., : len(mask)]
action = jnp.where(
mask,
(action + 1)
* (
unnormalization_statistics["p99"]
- unnormalization_statistics["p01"]
)
/ 2
+ unnormalization_statistics["p01"],
action,
)
else:
raise ValueError(f"Unknown normalization type: {normalization_type}")

return action
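A corresponding usage sketch — it assumes the model was created with a head registered under "trajectory", and note that `next_action` is accepted but currently unused by the method body:

trajectory = model.sample_trajectory(
    observations,
    None,  # next_action: accepted for API symmetry, currently unused
    tasks,
    unnormalization_statistics=model.dataset_statistics["action"],  # assumed layout
    rng=jax.random.PRNGKey(0),
)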

@classmethod
def load_pretrained(
@@ -277,6 +433,8 @@ def load_pretrained(
tf.io.gfile.join(checkpoint_path, "config.json"), "r"
) as f:
config = json.load(f)
# shim to support old configs that use `readouts` instead of `readout_tokenizers`
if "readouts" in config["model"]:
config["model"]["readout_tokenizers"] = config["model"].pop("readouts")

# shim to support old configs
if "pred_horizon" in config["model"]["heads"]["action"]["kwargs"]: