VLA-Adapter committed
Commit e915ad2 · verified · 1 Parent(s): c7a120c

Update modeling_prismatic.py

Files changed (1):
  1. modeling_prismatic.py +31 -31
modeling_prismatic.py CHANGED

@@ -338,7 +338,7 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
 f"there might be inference-time regressions due to dependency changes. If in doubt, please"
 f"use the above versions."
 )
-# import pdb; pdb.set_trace()
+
 # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)
 self.vision_backbone = PrismaticVisionBackbone(
 config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
@@ -432,7 +432,7 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):

 # Move the noisy action features into their correct positions
 # print(noisy_action_features.size())
-# import pdb; pdb.set_trace()
+
 repositioned_noisy_action_features[batch_indices, masked_indices] = noisy_action_features

 # Combine original input embeddings and noisy action embeddings using the mask
@@ -475,7 +475,7 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
 def _build_multimodal_attention(self, input_embeddings, projected_patch_embeddings, attention_mask):
 """Build multimodal embeddings and attention mask"""
 # Update attention mask
-# import pdb; pdb.set_trace()
+
 projected_patch_attention_mask = None
 if attention_mask is not None:
 projected_patch_attention_mask = torch.full(
@@ -589,12 +589,12 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
 # Get input embeddings (from language model embeddings)
 input_embeddings = self.get_input_embeddings()(input_ids) # (B, seq_len, D)

-# import pdb; pdb.set_trace()
+
 # Extract action masks
 all_actions_mask = self._process_action_masks(labels)

 # Extract the language portion of the input embeddings (i.e. remove the action tokens portion)
-# import pdb; pdb.set_trace()
+
 # print(input_embeddings[~all_actions_mask].size())
 language_embeddings = input_embeddings[~all_actions_mask].reshape(
 input_embeddings.shape[0], -1, input_embeddings.shape[2]
@@ -624,7 +624,7 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):

 # Process action embeddings
 if noisy_actions is not None:
-# import pdb; pdb.set_trace()
+
 if self.version == 'v1':
 # action_queries = self.action_queries.weight # (1, h)
 # action_queries = action_queries.view(1, 1, action_queries.shape[1]).repeat(input_embeddings.shape[0], 1, 1) # (b, chunk_size, h)
@@ -642,7 +642,7 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
 all_actions_mask = self._process_action_masks(labels)
 input_embeddings = self._replace_input_embeddings(
 input_embeddings, all_actions_mask, action_queries)
-# import pdb; pdb.set_trace()
+

 else:
 # Get mask corresponding to all action tokens
@@ -665,7 +665,7 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
 all_actions_mask = self._process_action_masks(labels)
 input_embeddings = self._replace_input_embeddings(
 input_embeddings, all_actions_mask, action_queries)
-# import pdb; pdb.set_trace()
+
 else:
 # Replace the embeddings of the action tokens with zeros
 # (Later on, the positional embeddings will be added to them)
@@ -677,14 +677,14 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
 multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
 input_embeddings, projected_patch_embeddings, attention_mask
 )
-# import pdb; pdb.set_trace()
+
 # Build labels for multimodal sequence if needed
 multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings)

-# import pdb; pdb.set_trace()
+
 # Dispatch to language model
 if self.version == 'v1':
-# import pdb; pdb.set_trace()
+
 language_model_output = self.language_model(
 input_ids=None,
 attention_mask=multimodal_attention_mask,
@@ -697,7 +697,7 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
 output_hidden_states=output_hidden_states,
 return_dict=return_dict,
 )
-# import pdb; pdb.set_trace()
+
 else:
 language_model_output = self.language_model(
 input_ids=None,
@@ -802,7 +802,7 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
 def __init__(self, config: OpenVLAConfig) -> None:
 super().__init__(config)
 self.norm_stats = config.norm_stats
-# import pdb; pdb.set_trace()
+

 # Compute action bins
 self.bins = np.linspace(-1, 1, config.n_action_bins)
@@ -1048,7 +1048,7 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
 # Clone embedding for reuse in each timestep
 curr_noisy_actions = noise

-# import pdb; pdb.set_trace()
+

 action_queries = self.action_queries.weight # (1, h)
 action_queries = action_queries.view(1, action_queries.shape[0], action_queries.shape[1]).repeat(input_embeddings.shape[0], 1, 1) # (b, chunk_size, h)
@@ -1068,7 +1068,7 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
 input_embeddings, projected_patch_embeddings, attention_mask
 )

-# import pdb; pdb.set_trace()
+
 # Forward pass through language model
 language_model_output = self.language_model(
 input_ids=None,
@@ -1083,21 +1083,21 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
 return_dict=True,
 )
 multi_layer_hidden_states = []
-# import pdb; pdb.set_trace()
+
 for item in language_model_output.hidden_states[0:]:
 # last_hidden_states = output.hidden_states[-1] # (B, seq_len, D)
 # Get hidden states for text portion of prompt+response (after the vision patches)
 text_hidden_states = item
 # Get hidden states for action portion of response
 actions_hidden_states = text_hidden_states[:, NUM_PATCHES+ NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + NUM_TOKENS, :,].reshape(1, 1, NUM_TOKENS, -1).to(torch.bfloat16)
-# import pdb; pdb.set_trace()
+
 batch_size = item.shape[0]
 task_latten_states = item[:, :NUM_PATCHES].reshape(batch_size, 1, NUM_PATCHES , -1)
 all_hidden_states = torch.cat((task_latten_states, actions_hidden_states),2)
 multi_layer_hidden_states.append(all_hidden_states)
-# import pdb; pdb.set_trace()
+
 multi_layer_hidden_states = torch.cat(multi_layer_hidden_states, dim = 1)
-# import pdb; pdb.set_trace()
+



@@ -1176,21 +1176,21 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):

 # Extract hidden states for action tokens
 multi_layer_hidden_states = []
-# import pdb; pdb.set_trace()
+
 for item in language_model_output.hidden_states[0:]:
 # last_hidden_states = output.hidden_states[-1] # (B, seq_len, D)
 # Get hidden states for text portion of prompt+response (after the vision patches)
 text_hidden_states = item
 # Get hidden states for action portion of response
 actions_hidden_states = text_hidden_states[:, NUM_PATCHES+ NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + NUM_TOKENS, :,].reshape(1, 1, NUM_TOKENS, -1).to(torch.bfloat16)
-# import pdb; pdb.set_trace()
+
 batch_size = item.shape[0]
 task_latten_states = item[:, :NUM_PATCHES].reshape(batch_size, 1, NUM_PATCHES , -1)
 all_hidden_states = torch.cat((task_latten_states, actions_hidden_states),2)
 multi_layer_hidden_states.append(all_hidden_states)
-# import pdb; pdb.set_trace()
+
 multi_layer_hidden_states = torch.cat(multi_layer_hidden_states, dim = 1)
-# import pdb; pdb.set_trace()
+

 # Handle different prediction methods
 if action_head is not None:
@@ -1311,11 +1311,11 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
 Returns:
 Tuple of (unnormalized_actions, action_hidden_states)
 """
-# import pdb; pdb.set_trace()
+
 # If the special empty token ('') does not already appear after the colon (':') token in the prompt
 # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time

-# 如果是 minivla, 不用加这个判断!!!!!
+
 # if not torch.all(input_ids[:, -1] == 29871):
 # input_ids = torch.cat(
 # (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
@@ -1332,7 +1332,7 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
 # Get number of tokens in prompt (excluding the start token)
 NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1 # Subtract action tokens and stop token

-# import pdb; pdb.set_trace()
+

 # Prepare inputs by adding necessary tokens
 input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)
@@ -1362,7 +1362,7 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
 projected_patch_embeddings = self._process_proprio_features(
 projected_patch_embeddings, proprio, proprio_projector
 )
-# import pdb; pdb.set_trace()
+
 # Use diffusion if provided, otherwise use regression or discrete prediction
 use_diffusion = noisy_action_projector is not None and hasattr(action_head, "noise_scheduler")
 use_flow_matching = noisy_action_projector is not None and hasattr(action_head, "sample_actions")
@@ -1380,7 +1380,7 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
 if use_diffusion:
 NUM_PATCHES += 1

-# import pdb; pdb.set_trace()
+
 if use_flow_matching:
 # Sample random noise with shape equal to output action, used as the starting state for flow matching
 noise = action_head.sample_noise((1, NUM_ACTIONS_CHUNK, ACTION_DIM),device=input_embeddings.device, dtype=input_embeddings.dtype)
@@ -1403,10 +1403,10 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
 noise = torch.randn(
 size=(1, NUM_ACTIONS_CHUNK, ACTION_DIM), device=input_embeddings.device, dtype=input_embeddings.dtype
 )
-# import pdb; pdb.set_trace()
+
 if self.version == 'v1':

-# import pdb; pdb.set_trace()
+
 # Run diffusion-based prediction
 normalized_actions, actions_hidden_states = self._run_diffusion_prediction_V1(
 input_embeddings, # [1, 86, 4096]
@@ -1465,7 +1465,7 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
 action_head,
 )

-# import pdb; pdb.set_trace()
+
 # Unnormalize predicted actions
 actions = self._unnormalize_actions(normalized_actions, unnorm_key)

 