Update modeling_prismatic.py
modeling_prismatic.py  CHANGED  (+31 -31)
@@ -338,7 +338,7 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
                 f"there might be inference-time regressions due to dependency changes. If in doubt, please"
                 f"use the above versions."
             )
-
+
         # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)
         self.vision_backbone = PrismaticVisionBackbone(
             config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
@@ -432,7 +432,7 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
 
         # Move the noisy action features into their correct positions
         # print(noisy_action_features.size())
-
+
         repositioned_noisy_action_features[batch_indices, masked_indices] = noisy_action_features
 
         # Combine original input embeddings and noisy action embeddings using the mask
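For readers unfamiliar with the advanced-indexing pattern in the hunk above, here is a minimal, self-contained sketch (toy shapes and mask layout are assumed for illustration): `torch.where` over the action-token mask yields paired batch/position indices, and assigning through them scatters each noisy action feature into its slot.

import torch

B, seq_len, D, chunk = 2, 5, 3, 2
repositioned_noisy_action_features = torch.zeros(B, seq_len, D)
noisy_action_features = torch.randn(B * chunk, D)

# Assumed layout: the last `chunk` positions of every sequence are action tokens.
all_actions_mask = torch.zeros(B, seq_len, dtype=torch.bool)
all_actions_mask[:, -chunk:] = True
batch_indices, masked_indices = torch.where(all_actions_mask)

# Paired index tensors scatter each feature row into its (batch, position) slot.
repositioned_noisy_action_features[batch_indices, masked_indices] = noisy_action_features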
@@ -475,7 +475,7 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
     def _build_multimodal_attention(self, input_embeddings, projected_patch_embeddings, attention_mask):
         """Build multimodal embeddings and attention mask"""
         # Update attention mask
-
+
         projected_patch_attention_mask = None
         if attention_mask is not None:
             projected_patch_attention_mask = torch.full(
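As a rough sketch of what `_build_multimodal_attention` does (argument shapes and the exact splice point are assumptions; the real method inserts the patches after the BOS token rather than plainly prepending them), the patch attention mask is created with `torch.full` and concatenated with the text attention mask:

import torch

def build_multimodal_attention_sketch(input_embeddings, projected_patch_embeddings, attention_mask):
    # input_embeddings: (B, seq_len, D); projected_patch_embeddings: (B, num_patches, D)
    projected_patch_attention_mask = None
    if attention_mask is not None:
        projected_patch_attention_mask = torch.full(
            (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
            fill_value=True,
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
    # Simplified: prepend all patches to the text sequence.
    multimodal_embeddings = torch.cat([projected_patch_embeddings, input_embeddings], dim=1)
    multimodal_attention_mask = None
    if attention_mask is not None:
        multimodal_attention_mask = torch.cat([projected_patch_attention_mask, attention_mask], dim=1)
    return multimodal_embeddings, multimodal_attention_mask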
@@ -589,12 +589,12 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
         # Get input embeddings (from language model embeddings)
         input_embeddings = self.get_input_embeddings()(input_ids) # (B, seq_len, D)
 
-
+
         # Extract action masks
         all_actions_mask = self._process_action_masks(labels)
 
         # Extract the language portion of the input embeddings (i.e. remove the action tokens portion)
-
+
         # print(input_embeddings[~all_actions_mask].size())
         language_embeddings = input_embeddings[~all_actions_mask].reshape(
             input_embeddings.shape[0], -1, input_embeddings.shape[2]
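The `[~all_actions_mask]` indexing followed by `reshape` above relies on every sample masking the same number of action tokens; a small runnable illustration with toy shapes (not the model's real dimensions):

import torch

B, seq_len, D = 2, 6, 4
input_embeddings = torch.randn(B, seq_len, D)
all_actions_mask = torch.zeros(B, seq_len, dtype=torch.bool)
all_actions_mask[:, -2:] = True  # pretend the last 2 tokens of each sample are action tokens

# Boolean indexing flattens to (B * num_kept, D); reshape restores the batch dimension.
language_embeddings = input_embeddings[~all_actions_mask].reshape(B, -1, D)
print(language_embeddings.shape)  # torch.Size([2, 4, 4])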
@@ -624,7 +624,7 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
 
         # Process action embeddings
         if noisy_actions is not None:
-
+
             if self.version == 'v1':
                 # action_queries = self.action_queries.weight # (1, h)
                 # action_queries = action_queries.view(1, 1, action_queries.shape[1]).repeat(input_embeddings.shape[0], 1, 1) # (b, chunk_size, h)
@@ -642,7 +642,7 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
                 all_actions_mask = self._process_action_masks(labels)
                 input_embeddings = self._replace_input_embeddings(
                     input_embeddings, all_actions_mask, action_queries)
-
+
 
             else:
                 # Get mask corresponding to all action tokens
@@ -665,7 +665,7 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
                 all_actions_mask = self._process_action_masks(labels)
                 input_embeddings = self._replace_input_embeddings(
                     input_embeddings, all_actions_mask, action_queries)
-
+
         else:
             # Replace the embeddings of the action tokens with zeros
             # (Later on, the positional embeddings will be added to them)
@@ -677,14 +677,14 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
         multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
             input_embeddings, projected_patch_embeddings, attention_mask
         )
-
+
         # Build labels for multimodal sequence if needed
         multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings)
 
-
+
         # Dispatch to language model
         if self.version == 'v1':
-
+
             language_model_output = self.language_model(
                 input_ids=None,
                 attention_mask=multimodal_attention_mask,
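Both branches of the dispatch hand the fused sequence to the underlying causal language model through `inputs_embeds` rather than `input_ids`. Purely as a generic illustration of that pattern (the keyword list follows the standard Hugging Face causal-LM signature, not necessarily the truncated call above):

def dispatch_with_embeddings(language_model, multimodal_embeddings, multimodal_attention_mask, multimodal_labels=None):
    # Pass precomputed (vision + text) embeddings directly; input_ids must then be None.
    return language_model(
        input_ids=None,
        attention_mask=multimodal_attention_mask,
        inputs_embeds=multimodal_embeddings,
        labels=multimodal_labels,
        output_hidden_states=True,
        return_dict=True,
    )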
@@ -697,7 +697,7 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
                 output_hidden_states=output_hidden_states,
                 return_dict=return_dict,
             )
-
+
         else:
             language_model_output = self.language_model(
                 input_ids=None,
@@ -802,7 +802,7 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
     def __init__(self, config: OpenVLAConfig) -> None:
         super().__init__(config)
         self.norm_stats = config.norm_stats
-
+
 
         # Compute action bins
         self.bins = np.linspace(-1, 1, config.n_action_bins)
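The `np.linspace(-1, 1, config.n_action_bins)` bins above support the discrete action tokenization used by OpenVLA-style models; a hedged sketch of the round trip (bin count and clipping are illustrative, not read from the config):

import numpy as np

n_action_bins = 256
bins = np.linspace(-1, 1, n_action_bins)
bin_centers = (bins[:-1] + bins[1:]) / 2.0

action = np.array([0.12, -0.87])
# Discretize: continuous values in [-1, 1] -> bin indices.
token_ids = np.digitize(np.clip(action, -1.0, 1.0), bins)
# Decode: bin indices -> bin-center values (clamped to a valid center).
decoded = bin_centers[np.clip(token_ids - 1, 0, bin_centers.shape[0] - 1)]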
@@ -1048,7 +1048,7 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
         # Clone embedding for reuse in each timestep
        curr_noisy_actions = noise
 
-
+
 
         action_queries = self.action_queries.weight # (1, h)
         action_queries = action_queries.view(1, action_queries.shape[0], action_queries.shape[1]).repeat(input_embeddings.shape[0], 1, 1) # (b, chunk_size, h)
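The `view(...).repeat(...)` above broadcasts the learned action queries across the batch; judging from the reshape, the embedding weight behaves as (chunk_size, h) rather than the `(1, h)` noted in the trailing comment. A toy shape check (dimensions assumed):

import torch

chunk_size, h, b = 8, 64, 2
action_queries = torch.randn(chunk_size, h)  # stands in for self.action_queries.weight
queries = action_queries.view(1, action_queries.shape[0], action_queries.shape[1]).repeat(b, 1, 1)
print(queries.shape)  # torch.Size([2, 8, 64])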
@@ -1068,7 +1068,7 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
             input_embeddings, projected_patch_embeddings, attention_mask
         )
 
-
+
         # Forward pass through language model
         language_model_output = self.language_model(
             input_ids=None,
@@ -1083,21 +1083,21 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
             return_dict=True,
         )
         multi_layer_hidden_states = []
-
+
         for item in language_model_output.hidden_states[0:]:
             # last_hidden_states = output.hidden_states[-1] # (B, seq_len, D)
             # Get hidden states for text portion of prompt+response (after the vision patches)
             text_hidden_states = item
             # Get hidden states for action portion of response
             actions_hidden_states = text_hidden_states[:, NUM_PATCHES+ NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + NUM_TOKENS, :,].reshape(1, 1, NUM_TOKENS, -1).to(torch.bfloat16)
-
+
             batch_size = item.shape[0]
             task_latten_states = item[:, :NUM_PATCHES].reshape(batch_size, 1, NUM_PATCHES , -1)
             all_hidden_states = torch.cat((task_latten_states, actions_hidden_states),2)
             multi_layer_hidden_states.append(all_hidden_states)
-
+
         multi_layer_hidden_states = torch.cat(multi_layer_hidden_states, dim = 1)
-
+
 
 
 
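Net effect of the loop above: for every decoder layer, the vision-patch states and the action-token states are sliced out and stacked, giving a tensor of shape (1, num_layers, NUM_PATCHES + NUM_TOKENS, hidden_dim). A condensed shape-only check with dummy tensors (all sizes assumed):

import torch

num_layers, NUM_PATCHES, NUM_PROMPT_TOKENS, NUM_TOKENS, D = 4, 6, 3, 2, 8
hidden_states = [torch.randn(1, NUM_PATCHES + NUM_PROMPT_TOKENS + NUM_TOKENS + 1, D) for _ in range(num_layers)]

layers = []
for item in hidden_states:
    actions = item[:, NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + NUM_TOKENS, :].reshape(1, 1, NUM_TOKENS, -1)
    patches = item[:, :NUM_PATCHES].reshape(1, 1, NUM_PATCHES, -1)
    layers.append(torch.cat((patches, actions), dim=2))
stacked = torch.cat(layers, dim=1)
print(stacked.shape)  # torch.Size([1, 4, 8, 8]) -> (1, num_layers, NUM_PATCHES + NUM_TOKENS, D)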
@@ -1176,21 +1176,21 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
 
         # Extract hidden states for action tokens
         multi_layer_hidden_states = []
-
+
         for item in language_model_output.hidden_states[0:]:
             # last_hidden_states = output.hidden_states[-1] # (B, seq_len, D)
             # Get hidden states for text portion of prompt+response (after the vision patches)
             text_hidden_states = item
             # Get hidden states for action portion of response
             actions_hidden_states = text_hidden_states[:, NUM_PATCHES+ NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + NUM_TOKENS, :,].reshape(1, 1, NUM_TOKENS, -1).to(torch.bfloat16)
-
+
             batch_size = item.shape[0]
             task_latten_states = item[:, :NUM_PATCHES].reshape(batch_size, 1, NUM_PATCHES , -1)
             all_hidden_states = torch.cat((task_latten_states, actions_hidden_states),2)
             multi_layer_hidden_states.append(all_hidden_states)
-
+
         multi_layer_hidden_states = torch.cat(multi_layer_hidden_states, dim = 1)
-
+
 
         # Handle different prediction methods
         if action_head is not None:
@@ -1311,11 +1311,11 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
         Returns:
             Tuple of (unnormalized_actions, action_hidden_states)
         """
-
+
         # If the special empty token ('') does not already appear after the colon (':') token in the prompt
         # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
 
-
+
         # if not torch.all(input_ids[:, -1] == 29871):
         #     input_ids = torch.cat(
         #         (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
@@ -1332,7 +1332,7 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
         # Get number of tokens in prompt (excluding the start token)
         NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1 # Subtract action tokens and stop token
 
-
+
 
         # Prepare inputs by adding necessary tokens
         input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)
@@ -1362,7 +1362,7 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
             projected_patch_embeddings = self._process_proprio_features(
                 projected_patch_embeddings, proprio, proprio_projector
             )
-
+
         # Use diffusion if provided, otherwise use regression or discrete prediction
         use_diffusion = noisy_action_projector is not None and hasattr(action_head, "noise_scheduler")
         use_flow_matching = noisy_action_projector is not None and hasattr(action_head, "sample_actions")
@@ -1380,7 +1380,7 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
         if use_diffusion:
             NUM_PATCHES += 1
 
-
+
         if use_flow_matching:
             # Sample random noise with shape equal to output action, used as the starting state for flow matching
             noise = action_head.sample_noise((1, NUM_ACTIONS_CHUNK, ACTION_DIM),device=input_embeddings.device, dtype=input_embeddings.dtype)
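`action_head.sample_noise` draws the starting state for flow matching; the sampler itself (`action_head.sample_actions`, referenced by the `hasattr` check further up) is not shown in this diff. Purely for orientation, a generic Euler-style flow-matching integration from noise to actions looks roughly like this (function name, velocity predictor, and step count are assumptions, not this repository's API):

import torch

def euler_flow_matching_sample(velocity_fn, noise, num_steps=10):
    # noise: (1, NUM_ACTIONS_CHUNK, ACTION_DIM), the state at t = 0.
    x = noise
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = torch.full((x.shape[0],), i * dt, device=x.device, dtype=x.dtype)
        x = x + dt * velocity_fn(x, t)  # velocity_fn predicts dx/dt given (x, t)
    return x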
@@ -1403,10 +1403,10 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
             noise = torch.randn(
                 size=(1, NUM_ACTIONS_CHUNK, ACTION_DIM), device=input_embeddings.device, dtype=input_embeddings.dtype
             )
-
+
             if self.version == 'v1':
 
-
+
                 # Run diffusion-based prediction
                 normalized_actions, actions_hidden_states = self._run_diffusion_prediction_V1(
                     input_embeddings, # [1, 86, 4096]
@@ -1465,7 +1465,7 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
                 action_head,
             )
 
-
+
         # Unnormalize predicted actions
         actions = self._unnormalize_actions(normalized_actions, unnorm_key)
 
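`_unnormalize_actions` maps the model's normalized outputs back to the robot's action range using the statistics stored under `self.norm_stats[unnorm_key]`; a hedged sketch of the usual OpenVLA convention (mapping [-1, 1] onto per-dimension low/high bounds such as the 1st/99th percentiles; the exact statistics keys are assumed):

import numpy as np

def unnormalize_actions_sketch(normalized_actions, action_low, action_high):
    # normalized_actions in [-1, 1]; action_low / action_high are per-dimension bounds.
    normalized_actions = np.clip(normalized_actions, -1.0, 1.0)
    return 0.5 * (normalized_actions + 1.0) * (action_high - action_low) + action_low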