Upload 235 files
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- .gitattributes +1 -0
 - mllm/__init__.py +0 -0
 - mllm/__pycache__/__init__.cpython-310.pyc +0 -0
 - mllm/config/__init__.py +1 -0
 - mllm/config/__pycache__/__init__.cpython-310.pyc +0 -0
 - mllm/config/__pycache__/config.cpython-310.pyc +0 -0
 - mllm/config/config.py +135 -0
 - mllm/conversation/__init__.py +1 -0
 - mllm/conversation/__pycache__/__init__.cpython-310.pyc +0 -0
 - mllm/conversation/__pycache__/base_conversation.cpython-310.pyc +0 -0
 - mllm/conversation/base_conversation.py +503 -0
 - mllm/dataset/__init__.py +7 -0
 - mllm/dataset/__pycache__/__init__.cpython-310.pyc +0 -0
 - mllm/dataset/__pycache__/builder.cpython-310.pyc +0 -0
 - mllm/dataset/__pycache__/root.cpython-310.pyc +0 -0
 - mllm/dataset/__pycache__/single_image_convsation.cpython-310.pyc +0 -0
 - mllm/dataset/__pycache__/single_image_interactive.cpython-310.pyc +0 -0
 - mllm/dataset/builder.py +118 -0
 - mllm/dataset/process_function/__init__.py +13 -0
 - mllm/dataset/process_function/__pycache__/__init__.cpython-310.pyc +0 -0
 - mllm/dataset/process_function/__pycache__/box_process_function.cpython-310.pyc +0 -0
 - mllm/dataset/process_function/__pycache__/shikra_process_function.cpython-310.pyc +0 -0
 - mllm/dataset/process_function/box_process_function.py +326 -0
 - mllm/dataset/process_function/shikra_process_function.py +178 -0
 - mllm/dataset/root.py +67 -0
 - mllm/dataset/single_image_convsation.py +284 -0
 - mllm/dataset/single_image_dataset/__init__.py +13 -0
 - mllm/dataset/single_image_dataset/__pycache__/__init__.cpython-310.pyc +0 -0
 - mllm/dataset/single_image_dataset/__pycache__/caption.cpython-310.pyc +0 -0
 - mllm/dataset/single_image_dataset/__pycache__/clevr.cpython-310.pyc +0 -0
 - mllm/dataset/single_image_dataset/__pycache__/flickr.cpython-310.pyc +0 -0
 - mllm/dataset/single_image_dataset/__pycache__/gpt_gen.cpython-310.pyc +0 -0
 - mllm/dataset/single_image_dataset/__pycache__/gqa.cpython-310.pyc +0 -0
 - mllm/dataset/single_image_dataset/__pycache__/instr.cpython-310.pyc +0 -0
 - mllm/dataset/single_image_dataset/__pycache__/point_qa.cpython-310.pyc +0 -0
 - mllm/dataset/single_image_dataset/__pycache__/pope.cpython-310.pyc +0 -0
 - mllm/dataset/single_image_dataset/__pycache__/rec.cpython-310.pyc +0 -0
 - mllm/dataset/single_image_dataset/__pycache__/reg.cpython-310.pyc +0 -0
 - mllm/dataset/single_image_dataset/__pycache__/vcr.cpython-310.pyc +0 -0
 - mllm/dataset/single_image_dataset/__pycache__/vqaex.cpython-310.pyc +0 -0
 - mllm/dataset/single_image_dataset/__pycache__/vqav2.cpython-310.pyc +0 -0
 - mllm/dataset/single_image_dataset/caption.py +31 -0
 - mllm/dataset/single_image_dataset/clevr.py +116 -0
 - mllm/dataset/single_image_dataset/flickr.py +68 -0
 - mllm/dataset/single_image_dataset/gpt_gen.py +58 -0
 - mllm/dataset/single_image_dataset/gqa.py +233 -0
 - mllm/dataset/single_image_dataset/instr.py +24 -0
 - mllm/dataset/single_image_dataset/point_qa.py +247 -0
 - mllm/dataset/single_image_dataset/pope.py +36 -0
 - mllm/dataset/single_image_dataset/rec.py +128 -0
 
    	
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+mllm/demo/assets/baseball.png filter=lfs diff=lfs merge=lfs -text
    	
mllm/__init__.py ADDED
(empty file)
    	
mllm/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (137 Bytes)
    	
mllm/config/__init__.py ADDED
@@ -0,0 +1 @@
+from .config import prepare_args
    	
mllm/config/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (187 Bytes)
    	
mllm/config/__pycache__/config.cpython-310.pyc ADDED
Binary file (4.24 kB)
    	
mllm/config/config.py ADDED
@@ -0,0 +1,135 @@
+import os
+import sys
+import logging
+import argparse
+from dataclasses import dataclass, field
+from typing import List, Tuple
+from argparse import SUPPRESS
+
+import datasets
+import transformers
+from mmengine.config import Config, DictAction
+from transformers import HfArgumentParser, set_seed, add_start_docstrings
+from transformers import Seq2SeqTrainingArguments as HFSeq2SeqTrainingArguments
+from transformers.trainer_utils import get_last_checkpoint, is_main_process
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+    datefmt="%m/%d/%Y %H:%M:%S",
+    handlers=[logging.StreamHandler(sys.stdout), ],
+)
+
+
+@dataclass
+@add_start_docstrings(HFSeq2SeqTrainingArguments.__doc__)
+class Seq2SeqTrainingArguments(HFSeq2SeqTrainingArguments):
+    do_multi_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the multi-test set."})
+
+
+def prepare_args(args=None):
+    parser = argparse.ArgumentParser()
+    parser.add_argument('config', help='train config file path')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+             'in xxx=yyy format will be merged into config file. If the value to '
+             'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+             'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+             'Note that the quotation marks are necessary and that no white space '
+             'is allowed.')
+
+    hf_parser = HfArgumentParser((Seq2SeqTrainingArguments,))
+    hf_parser, required = block_required_error(hf_parser)
+
+    args, unknown_args = parser.parse_known_args(args)
+    known_hf_args, unknown_args = hf_parser.parse_known_args(unknown_args)
+    if unknown_args:
+        raise ValueError(f"Some specified arguments are not used "
+                         f"by the ArgumentParser or HfArgumentParser\n: {unknown_args}")
+
+    # load 'cfg' and 'training_args' from file and cli
+    cfg = Config.fromfile(args.config)
+    if args.cfg_options is not None:
+        cfg.merge_from_dict(args.cfg_options)
+    training_args = cfg.training_args
+    training_args.update(vars(known_hf_args))
+
+    # check training_args require
+    req_but_not_assign = [item for item in required if item not in training_args]
+    if req_but_not_assign:
+        raise ValueError(f"Requires {req_but_not_assign} but not assign.")
+
+    # update cfg.training_args
+    cfg.training_args = training_args
+
+    # initialize and return
+    training_args = Seq2SeqTrainingArguments(**training_args)
+    training_args = check_output_dir(training_args)
+
+    # logging
+    if is_main_process(training_args.local_rank):
+        to_logging_cfg = Config()
+        to_logging_cfg.model_args = cfg.model_args
+        to_logging_cfg.data_args = cfg.data_args
+        to_logging_cfg.training_args = cfg.training_args
+        logger.info(to_logging_cfg.pretty_text)
+
+    # setup logger
+    if training_args.should_log:
+        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
+        transformers.logging.set_verbosity_info()
+    log_level = training_args.get_process_log_level()
+    logger.setLevel(log_level)
+    datasets.utils.logging.set_verbosity(log_level)
+    transformers.logging.set_verbosity(log_level)
+    transformers.logging.enable_default_handler()
+    transformers.logging.enable_explicit_format()
+    # setup_print_for_distributed(is_main_process(training_args))
+
+    # Log on each process the small summary:
+    logger.info(f"Training/evaluation parameters {training_args}")
+    logger.warning(
+        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
+        + f"  distributed training: {bool(training_args.local_rank != -1)}, fp16 training: {training_args.fp16}"
+    )
+
+    # Set seed before initializing model.
+    set_seed(training_args.seed)
+
+    return cfg, training_args
+
+
+def block_required_error(hf_parser: HfArgumentParser) -> Tuple[HfArgumentParser, List]:
+    required = []
+    # noinspection PyProtectedMember
+    for action in hf_parser._actions:
+        if action.required:
+            required.append(action.dest)
+        action.required = False
+        action.default = SUPPRESS
+    return hf_parser, required
+
+
+def check_output_dir(training_args):
+    # Detecting last checkpoint.
+    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
+        last_checkpoint = get_last_checkpoint(training_args.output_dir)
+        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+            raise ValueError(
+                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+                "Use --overwrite_output_dir to overcome."
+            )
+        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
+            logger.info(
+                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+            )
+    return training_args
+
+
+if __name__ == "__main__":
+    _ = prepare_args()
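For orientation, a minimal sketch of how prepare_args is typically driven from a training entry point. The script name, config path, and override below are hypothetical; the config file is assumed to be an mmengine-style config that defines model_args, data_args, and training_args (with output_dir set either there or on the command line):

# Hypothetical launcher, e.g.:
#   python train.py configs/example_config.py --cfg-options training_args.output_dir=exp/run0
from mllm.config import prepare_args

if __name__ == "__main__":
    # Parses sys.argv: positional config path, --cfg-options overrides, plus any
    # Seq2SeqTrainingArguments flags recognized by the HfArgumentParser.
    cfg, training_args = prepare_args()
    print(training_args.output_dir)   # fully initialized Seq2SeqTrainingArguments
    print(cfg.model_args)             # remaining config sections stay on the mmengine Config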
    	
mllm/conversation/__init__.py ADDED
@@ -0,0 +1 @@
+from .base_conversation import SeparatorStyle, Conversation, register_conv_template, get_conv_template
    	
mllm/conversation/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (290 Bytes)
    	
mllm/conversation/__pycache__/base_conversation.cpython-310.pyc ADDED
Binary file (11.4 kB)
    	
mllm/conversation/base_conversation.py ADDED
@@ -0,0 +1,503 @@
+# copy from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
+"""
+Conversation prompt templates.
+"""
+
+import dataclasses
+from enum import auto, Enum
+from typing import List, Tuple, Any, Dict
+
+
+class SeparatorStyle(Enum):
+    """Separator styles."""
+
+    ADD_COLON_SINGLE = auto()
+    ADD_COLON_TWO = auto()
+    ADD_SPACE_TWO = auto()
+    NO_COLON_SINGLE = auto()
+    BAIZE = auto()
+    DOLLY = auto()
+    RWKV = auto()
+    PHOENIX = auto()
+    NEW_LINE = auto()
+    BILLA = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that keeps all conversation history."""
+
+    # The name of this template
+    name: str
+    # System prompts
+    system: str
+    # Two roles
+    roles: List[str]
+    # All messages
+    messages: List[List[str]]
+    # Offset of few shot examples
+    offset: int
+    # Separators
+    sep_style: SeparatorStyle
+    sep: str
+    sep2: str = None
+    # Stop criteria (the default one is EOS token)
+    stop_str: str = None
+    # Stops generation if meeting any token in this list
+    stop_token_ids: List[int] = None
+
+    # Used for the state in the gradio servers.
+    # TODO(lmzheng): move this out of this class.
+    conv_id: Any = None
+    skip_next: bool = False
+    model_name: str = None
+
+    def get_prompt(self) -> str:
+        """Get the prompt for generation."""
+        if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
+            ret = self.system + self.sep
+            for role, message in self.messages:
+                if message:
+                    ret += role + ": " + message + self.sep
+                else:
+                    ret += role + ":"
+            return ret
+        elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
+            seps = [self.sep, self.sep2]
+            ret = self.system + seps[0]
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+            return ret
+        elif self.sep_style == SeparatorStyle.ADD_SPACE_TWO:
+            seps = [self.sep, self.sep2]
+            ret = self.system + seps[0]
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += role + " " + message + seps[i % 2]
+                else:
+                    ret += role + ""
+            return ret
+        elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
+            ret = self.system
+            for role, message in self.messages:
+                if message:
+                    ret += role + message + self.sep
+                else:
+                    ret += role
+            return ret
+        elif self.sep_style == SeparatorStyle.BAIZE:
+            ret = self.system + "\n"
+            for role, message in self.messages:
+                if message:
+                    ret += role + message + "\n"
+                else:
+                    ret += role
+            return ret
+        elif self.sep_style == SeparatorStyle.DOLLY:
+            seps = [self.sep, self.sep2]
+            ret = self.system
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += role + ":\n" + message + seps[i % 2]
+                    if i % 2 == 1:
+                        ret += "\n\n"
+                else:
+                    ret += role + ":\n"
+            return ret
+        elif self.sep_style == SeparatorStyle.RWKV:
+            ret = self.system
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += (
+                            role
+                            + ": "
+                            + message.replace("\r\n", "\n").replace("\n\n", "\n")
+                    )
+                    ret += "\n\n"
+                else:
+                    ret += role + ":"
+            return ret
+        elif self.sep_style == SeparatorStyle.PHOENIX:
+            ret = self.system
+            for role, message in self.messages:
+                if message:
+                    ret += role + ": " + "<s>" + message + "</s>"
+                else:
+                    ret += role + ": " + "<s>"
+            return ret
+        elif self.sep_style == SeparatorStyle.NEW_LINE:
+            ret = self.system + self.sep
+            for role, message in self.messages:
+                if message:
+                    ret += role + "\n" + message + self.sep
+                else:
+                    ret += role + "\n"
+            return ret
+        elif self.sep_style == SeparatorStyle.BILLA:
+            ret = self.system + self.sep
+            for role, message in self.messages:
+                if message:
+                    ret += role + ": " + message + self.sep
+                else:
+                    ret += role + ": "  # must be end with a space
+            return ret
+        else:
+            raise ValueError(f"Invalid style: {self.sep_style}")
+
+    def append_message(self, role: str, message: str):
+        """Append a new message."""
+        self.messages.append([role, message])
+
+    def to_gradio_chatbot(self):
+        """Convert the history to gradio chatbot format"""
+        ret = []
+        for i, (role, msg) in enumerate(self.messages[self.offset:]):
+            if i % 2 == 0:
+                ret.append([msg, None])
+            else:
+                ret[-1][-1] = msg
+        return ret
+
+    def to_openai_api_messages(self):
+        """Convert the conversation to OpenAI chat completion format."""
+        ret = [{"role": "system", "content": self.system}]
+
+        for i, (_, msg) in enumerate(self.messages[self.offset:]):
+            if i % 2 == 0:
+                ret.append({"role": "user", "content": msg})
+            else:
+                if msg is not None:
+                    ret.append({"role": "assistant", "content": msg})
+        return ret
+
+    def copy(self):
+        return Conversation(
+            name=self.name,
+            system=self.system,
+            roles=self.roles,
+            messages=[[x, y] for x, y in self.messages],
+            offset=self.offset,
+            sep_style=self.sep_style,
+            sep=self.sep,
+            sep2=self.sep2,
+            stop_str=self.stop_str,
+            stop_token_ids=self.stop_token_ids,
+            conv_id=self.conv_id,
+            model_name=self.model_name,
+        )
+
+    def dict(self):
+        return {
+            "name": self.name,
+            "system": self.system,
+            "roles": self.roles,
+            "messages": self.messages,
+            "offset": self.offset,
+            "conv_id": self.conv_id,
+            "model_name": self.model_name,
+        }
+
+
+# A global registry for all conversation templates
+conv_templates: Dict[str, Conversation] = {}
+
+
+def register_conv_template(template: Conversation, override: bool = False):
+    """Register a new conversation template."""
+    if not override:
+        assert template.name not in conv_templates, f"{template.name} has been registered."
+    conv_templates[template.name] = template
+
+
+def get_conv_template(name: str) -> Conversation:
+    """Get a conversation template."""
+    return conv_templates[name].copy()
+
+
+# A template with one conversation example
+register_conv_template(
+    Conversation(
+        name="one_shot",
+        system="A chat between a curious human and an artificial intelligence assistant. "
+               "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+        roles=("Human", "Assistant"),
+        messages=(
+            (
+                "Human",
+                "What are the key differences between renewable and non-renewable energy sources?",
+            ),
+            (
+                "Assistant",
+                "Renewable energy sources are those that can be replenished naturally in a relatively "
+                "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+                "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+                "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+                "renewable and non-renewable energy sources:\n"
+                "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+                "energy sources are finite and will eventually run out.\n"
+                "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+                "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+                "and other negative effects.\n"
+                "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+                "have lower operational costs than non-renewable sources.\n"
+                "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+                "locations than non-renewable sources.\n"
+                "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+                "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+                "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+                "non-renewable sources are not, and their depletion can lead to economic and social instability.",
+            ),
+        ),
+        offset=2,
+        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
+        sep="\n### ",
+        stop_str="###",
+    )
+)
+
+# Vicuna v1.1 template
+register_conv_template(
+    Conversation(
+        name="vicuna_v1.1",
+        system="A chat between a curious user and an artificial intelligence assistant. "
+               "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+        roles=("USER", "ASSISTANT"),
+        messages=(),
+        offset=0,
+        sep_style=SeparatorStyle.ADD_COLON_TWO,
+        sep=" ",
+        sep2="</s>",
+    )
+)
+
+# Koala default template
+register_conv_template(
+    Conversation(
+        name="koala_v1",
+        system="BEGINNING OF CONVERSATION:",
+        roles=("USER", "GPT"),
+        messages=(),
+        offset=0,
+        sep_style=SeparatorStyle.ADD_COLON_TWO,
+        sep=" ",
+        sep2="</s>",
+    )
+)
+
+# Dolly V2 default template
+register_conv_template(
+    Conversation(
+        name="dolly_v2",
+        system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
+        roles=("### Instruction", "### Response"),
+        messages=(),
+        offset=0,
+        sep_style=SeparatorStyle.DOLLY,
+        sep="\n\n",
+        sep2="### End",
+    )
+)
+
+# OpenAssistant Pythia default template
+register_conv_template(
+    Conversation(
+        name="oasst_pythia",
+        system="",
+        roles=("<|prompter|>", "<|assistant|>"),
+        messages=(),
+        offset=0,
         
            +
                    sep_style=SeparatorStyle.NO_COLON_SINGLE,
         
     | 
| 313 | 
         
            +
                    sep="<|endoftext|>",
         
     | 
| 314 | 
         
            +
                )
         
     | 
| 315 | 
         
            +
            )
         
     | 
| 316 | 
         
            +
             
     | 
| 317 | 
         
            +
            # StableLM Alpha default template
         
     | 
| 318 | 
         
            +
            register_conv_template(
         
     | 
| 319 | 
         
            +
                Conversation(
         
     | 
| 320 | 
         
            +
                    name="stablelm",
         
     | 
| 321 | 
         
            +
                    system="""<|SYSTEM|># StableLM Tuned (Alpha version)
         
     | 
| 322 | 
         
            +
            - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
         
     | 
| 323 | 
         
            +
            - StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
         
     | 
| 324 | 
         
            +
            - StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
         
     | 
| 325 | 
         
            +
            - StableLM will refuse to participate in anything that could harm a human.
         
     | 
| 326 | 
         
            +
            """,
         
     | 
| 327 | 
         
            +
                    roles=("<|USER|>", "<|ASSISTANT|>"),
         
     | 
| 328 | 
         
            +
                    messages=(),
         
     | 
| 329 | 
         
            +
                    offset=0,
         
     | 
| 330 | 
         
            +
                    sep_style=SeparatorStyle.NO_COLON_SINGLE,
         
     | 
| 331 | 
         
            +
                    sep="",
         
     | 
| 332 | 
         
            +
                    stop_token_ids=[50278, 50279, 50277, 1, 0],
         
     | 
| 333 | 
         
            +
                )
         
     | 
| 334 | 
         
            +
            )
         
     | 
| 335 | 
         
            +
             
     | 
| 336 | 
         
            +
            # Baize default template
         
     | 
| 337 | 
         
            +
            register_conv_template(
         
     | 
| 338 | 
         
            +
                Conversation(
         
     | 
| 339 | 
         
            +
                    name="baize",
         
     | 
| 340 | 
         
            +
                    system="The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.",
         
     | 
| 341 | 
         
            +
                    roles=("[|Human|]", "[|AI|]"),
         
     | 
| 342 | 
         
            +
                    messages=(
         
     | 
| 343 | 
         
            +
                        ("[|Human|]", "Hello!"),
         
     | 
| 344 | 
         
            +
                        ("[|AI|]", "Hi!"),
         
     | 
| 345 | 
         
            +
                    ),
         
     | 
| 346 | 
         
            +
                    offset=2,
         
     | 
| 347 | 
         
            +
                    sep_style=SeparatorStyle.BAIZE,
         
     | 
| 348 | 
         
            +
                    sep="[|Human|]",
         
     | 
| 349 | 
         
            +
                    stop_str="[|Human|]",
         
     | 
| 350 | 
         
            +
                )
         
     | 
| 351 | 
         
            +
            )
         
     | 
| 352 | 
         
            +
             
     | 
| 353 | 
         
            +
            # RWKV-4-Raven default template
         
     | 
| 354 | 
         
            +
            register_conv_template(
         
     | 
| 355 | 
         
            +
                Conversation(
         
     | 
| 356 | 
         
            +
                    name="rwkv",
         
     | 
| 357 | 
         
            +
                    system="The following is a coherent verbose detailed conversation between Bob and Alice.\n\n",
         
     | 
| 358 | 
         
            +
                    roles=("Bob", "Alice"),
         
     | 
| 359 | 
         
            +
                    messages=(
         
     | 
| 360 | 
         
            +
                        ("Bob", "Hi"),
         
     | 
| 361 | 
         
            +
                        (
         
     | 
| 362 | 
         
            +
                            "Alice",
         
     | 
| 363 | 
         
            +
                            "Hi. I am your assistant and I will answer all questions. Please feel free to ask any question and I will always answer it.",
         
     | 
| 364 | 
         
            +
                        ),
         
     | 
| 365 | 
         
            +
                    ),
         
     | 
| 366 | 
         
            +
                    offset=2,
         
     | 
| 367 | 
         
            +
                    sep_style=SeparatorStyle.RWKV,
         
     | 
| 368 | 
         
            +
                    sep="",
         
     | 
| 369 | 
         
            +
                    stop_str="\n\n",
         
     | 
| 370 | 
         
            +
                )
         
     | 
| 371 | 
         
            +
            )
         
     | 
| 372 | 
         
            +
             
     | 
| 373 | 
         
            +
            # Buddy default template
         
     | 
| 374 | 
         
            +
            register_conv_template(
         
     | 
| 375 | 
         
            +
                Conversation(
         
     | 
| 376 | 
         
            +
                    name="openbuddy",
         
     | 
| 377 | 
         
            +
                    system="""Consider a conversation between User (a human) and Assistant (named Buddy).
         
     | 
| 378 | 
         
            +
            Buddy is an INTP-T, a friendly, intelligent and multilingual AI assistant, by OpenBuddy team. GitHub: https://github.com/OpenBuddy/OpenBuddy
         
     | 
| 379 | 
         
            +
            Buddy cannot access the Internet.
         
     | 
| 380 | 
         
            +
            Buddy can fluently speak the user's language (e.g. English, Chinese).
         
     | 
| 381 | 
         
            +
            Buddy can generate poems, stories, code, essays, songs, parodies, and more.
         
     | 
| 382 | 
         
            +
            Buddy possesses vast knowledge about the world, history, and culture.
         
     | 
| 383 | 
         
            +
            Buddy's responses are always safe, creative, high-quality, human-like, and interesting.
         
     | 
| 384 | 
         
            +
            Buddy strictly refuses to discuss political, NSFW, or other unsafe topics.
         
     | 
| 385 | 
         
            +
             
     | 
| 386 | 
         
            +
            User: Hi.
         
     | 
| 387 | 
         
            +
            Assistant: Hi, I'm Buddy, your AI assistant. How can I help you today?""",
         
     | 
| 388 | 
         
            +
                    roles=("User", "Assistant"),
         
     | 
| 389 | 
         
            +
                    messages=(),
         
     | 
| 390 | 
         
            +
                    offset=0,
         
     | 
| 391 | 
         
            +
                    sep_style=SeparatorStyle.ADD_COLON_SINGLE,
         
     | 
| 392 | 
         
            +
                    sep="\n",
         
     | 
| 393 | 
         
            +
                )
         
     | 
| 394 | 
         
            +
            )
         
     | 
| 395 | 
         
            +
             
     | 
| 396 | 
         
            +
            # Phoenix default template
         
     | 
| 397 | 
         
            +
            register_conv_template(
         
     | 
| 398 | 
         
            +
                Conversation(
         
     | 
| 399 | 
         
            +
                    name="phoenix",
         
     | 
| 400 | 
         
            +
                    system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
         
     | 
| 401 | 
         
            +
                    roles=("Human", "Assistant"),
         
     | 
| 402 | 
         
            +
                    messages=(),
         
     | 
| 403 | 
         
            +
                    offset=0,
         
     | 
| 404 | 
         
            +
                    sep_style=SeparatorStyle.PHOENIX,
         
     | 
| 405 | 
         
            +
                    sep="</s>",
         
     | 
| 406 | 
         
            +
                )
         
     | 
| 407 | 
         
            +
            )
         
     | 
| 408 | 
         
            +
             
     | 
| 409 | 
         
            +
            # ChatGPT default template
         
     | 
| 410 | 
         
            +
            register_conv_template(
         
     | 
| 411 | 
         
            +
                Conversation(
         
     | 
| 412 | 
         
            +
                    name="chatgpt",
         
     | 
| 413 | 
         
            +
                    system="You are a helpful assistant.",
         
     | 
| 414 | 
         
            +
                    roles=("user", "assistant"),
         
     | 
| 415 | 
         
            +
                    messages=(),
         
     | 
| 416 | 
         
            +
                    offset=0,
         
     | 
| 417 | 
         
            +
                    sep_style=None,
         
     | 
| 418 | 
         
            +
                    sep=None,
         
     | 
| 419 | 
         
            +
                )
         
     | 
| 420 | 
         
            +
            )
         
     | 
| 421 | 
         
            +
             
     | 
| 422 | 
         
            +
            # Claude default template
         
     | 
| 423 | 
         
            +
            register_conv_template(
         
     | 
| 424 | 
         
            +
                Conversation(
         
     | 
| 425 | 
         
            +
                    name="claude",
         
     | 
| 426 | 
         
            +
                    system="",
         
     | 
| 427 | 
         
            +
                    roles=("Human", "Assistant"),
         
     | 
| 428 | 
         
            +
                    messages=(),
         
     | 
| 429 | 
         
            +
                    offset=0,
         
     | 
| 430 | 
         
            +
                    sep_style=SeparatorStyle.ADD_COLON_SINGLE,
         
     | 
| 431 | 
         
            +
                    sep="\n\n",
         
     | 
| 432 | 
         
            +
                )
         
     | 
| 433 | 
         
            +
            )
         
     | 
| 434 | 
         
            +
             
     | 
| 435 | 
         
            +
            # MPT default template
         
     | 
| 436 | 
         
            +
            register_conv_template(
         
     | 
| 437 | 
         
            +
                Conversation(
         
     | 
| 438 | 
         
            +
                    name="mpt",
         
     | 
| 439 | 
         
            +
                    system="""<|im_start|>system
         
     | 
| 440 | 
         
            +
            - You are a helpful assistant chatbot trained by MosaicML.
         
     | 
| 441 | 
         
            +
            - You answer questions.
         
     | 
| 442 | 
         
            +
            - You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
         
     | 
| 443 | 
         
            +
            - You are more than just an information source, you are also able to write poetry, short stories, and make jokes.
         
     | 
| 444 | 
         
            +
            """,
         
     | 
| 445 | 
         
            +
                    roles=("<|im_start|>user", "<|im_start|>assistant"),
         
     | 
| 446 | 
         
            +
                    messages=(),
         
     | 
| 447 | 
         
            +
                    offset=0,
         
     | 
| 448 | 
         
            +
                    sep_style=SeparatorStyle.NEW_LINE,
         
     | 
| 449 | 
         
            +
                    sep="<|im_end|>",
         
     | 
| 450 | 
         
            +
                    stop_token_ids=[50278, 0],
         
     | 
| 451 | 
         
            +
                )
         
     | 
| 452 | 
         
            +
            )
         
     | 
| 453 | 
         
            +
             
     | 
| 454 | 
         
            +
            # Bard default template
         
     | 
| 455 | 
         
            +
            # Reference: https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L150
         
     | 
| 456 | 
         
            +
            #            https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L40
         
     | 
| 457 | 
         
            +
            register_conv_template(
         
     | 
| 458 | 
         
            +
                Conversation(
         
     | 
| 459 | 
         
            +
                    name="bard",
         
     | 
| 460 | 
         
            +
                    system="",
         
     | 
| 461 | 
         
            +
                    roles=("0", "1"),
         
     | 
| 462 | 
         
            +
                    messages=(),
         
     | 
| 463 | 
         
            +
                    offset=0,
         
     | 
| 464 | 
         
            +
                    sep_style=None,
         
     | 
| 465 | 
         
            +
                    sep=None,
         
     | 
| 466 | 
         
            +
                )
         
     | 
| 467 | 
         
            +
            )
         
     | 
| 468 | 
         
            +
             
     | 
| 469 | 
         
            +
            # BiLLa default template
         
     | 
| 470 | 
         
            +
            register_conv_template(
         
     | 
| 471 | 
         
            +
                Conversation(
         
     | 
| 472 | 
         
            +
                    name="billa",
         
     | 
| 473 | 
         
            +
                    system="",
         
     | 
| 474 | 
         
            +
                    roles=("Human", "Assistant"),
         
     | 
| 475 | 
         
            +
                    messages=(),
         
     | 
| 476 | 
         
            +
                    offset=0,
         
     | 
| 477 | 
         
            +
                    sep_style=SeparatorStyle.BILLA,
         
     | 
| 478 | 
         
            +
                    sep="\n",
         
     | 
| 479 | 
         
            +
                    stop_str="Human:",
         
     | 
| 480 | 
         
            +
                )
         
     | 
| 481 | 
         
            +
            )
         
     | 
| 482 | 
         
            +
             
     | 
| 483 | 
         
            +
            # custom otter template
         
     | 
| 484 | 
         
            +
            register_conv_template(
         
     | 
| 485 | 
         
            +
                Conversation(
         
     | 
| 486 | 
         
            +
                    name='otter',
         
     | 
| 487 | 
         
            +
                    system='',
         
     | 
| 488 | 
         
            +
                    roles=('User:', 'GPT:<answer>'),
         
     | 
| 489 | 
         
            +
                    messages=(),
         
     | 
| 490 | 
         
            +
                    offset=0,
         
     | 
| 491 | 
         
            +
                    sep_style=SeparatorStyle.ADD_SPACE_TWO,
         
     | 
| 492 | 
         
            +
                    sep=' ',
         
     | 
| 493 | 
         
            +
                    sep2='<|endofchunk|>',
         
     | 
| 494 | 
         
            +
                )
         
     | 
| 495 | 
         
            +
            )
         
     | 
| 496 | 
         
            +
             
     | 
| 497 | 
         
            +
            if __name__ == "__main__":
         
     | 
| 498 | 
         
            +
                conv = get_conv_template("vicuna_v1.1")
         
     | 
| 499 | 
         
            +
                conv.append_message(conv.roles[0], "Hello!")
         
     | 
| 500 | 
         
            +
                conv.append_message(conv.roles[1], "Hi!")
         
     | 
| 501 | 
         
            +
                conv.append_message(conv.roles[0], "How are you?")
         
     | 
| 502 | 
         
            +
                conv.append_message(conv.roles[1], None)
         
     | 
| 503 | 
         
            +
                print(conv.get_prompt())
         
     | 
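For reference, a hedged sketch of the prompt string the __main__ block above would print. The exact rendering is defined by SeparatorStyle.ADD_COLON_TWO earlier in this file; this sketch assumes the usual FastChat-style behaviour (system text, then role-prefixed turns joined alternately by sep and sep2, ending with an empty assistant turn):

# illustration only; the actual output depends on how ADD_COLON_TWO is implemented above
conv = get_conv_template("vicuna_v1.1")
conv.append_message(conv.roles[0], "Hello!")
conv.append_message(conv.roles[1], "Hi!")
conv.append_message(conv.roles[0], "How are you?")
conv.append_message(conv.roles[1], None)
conv.get_prompt()
# expected (approximately):
# "A chat between a curious user and an artificial intelligence assistant. "
# "The assistant gives helpful, detailed, and polite answers to the user's questions. "
# "USER: Hello! ASSISTANT: Hi!</s>USER: How are you? ASSISTANT:"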
    	
mllm/dataset/__init__.py
ADDED
@@ -0,0 +1,7 @@
from .root import *
from .utils import *
from .process_function import *
from .single_image_convsation import *
from .single_image_dataset import *

from .builder import prepare_data
    	
mllm/dataset/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (322 Bytes)

mllm/dataset/__pycache__/builder.cpython-310.pyc
ADDED
Binary file (2.96 kB)

mllm/dataset/__pycache__/root.cpython-310.pyc
ADDED
Binary file (2.42 kB)

mllm/dataset/__pycache__/single_image_convsation.cpython-310.pyc
ADDED
Binary file (11 kB)

mllm/dataset/__pycache__/single_image_interactive.cpython-310.pyc
ADDED
Binary file (4.15 kB)
    	
mllm/dataset/builder.py
ADDED
@@ -0,0 +1,118 @@
from functools import partial
from typing import Callable, Dict, Tuple, Any, Optional

from torch.utils.data import Dataset
from transformers import EvalPrediction, TrainingArguments

from .root import DATASETS, METRICS, TRANSFORMS, FUNCTIONS
from .single_image_convsation import SingleImageConvDataset
from .single_image_interactive import SingleImageInteractive
from ..conversation import get_conv_template
from .utils import init_ceph_client_if_needed

DatasetDict = Dict[str, Dataset]
ComputeMetrics = Callable[[EvalPrediction], Dict]


def prepare_data(
        data_args,
        model_args,
        training_args: TrainingArguments,
        preprocessor: Dict[str, Any],
) -> Tuple[DatasetDict, Optional[ComputeMetrics]]:
    # raw dataset
    datasets = {
        'train': partial(DATASETS.build, data_args.train) if training_args.do_train else None,
        'validation': partial(DATASETS.build, data_args.validation) if training_args.do_eval else None,
        'test': partial(DATASETS.build, data_args.test) if training_args.do_predict else None,
    }
    # compute metric
    compute_metric_cfg = data_args.get('compute_metric', None)
    compute_metrics = build_compute_metric(compute_metric_cfg, preprocessor)
    # conv dataset wrap
    conv_args = model_args.conv_args
    tokenize_kwargs = conv_args.get('tokenize_kwargs', {})
    conv_template = conv_args.get('conv_template', 'vicuna_v1.1')
    conv_template = partial(get_conv_template, name=conv_template)
    transforms = conv_args.get('transforms', None)
    if transforms is not None:
        transforms = TRANSFORMS.build(transforms)
    # process func
    process_func = {}
    for k, v in model_args.process_func_args.items():
        process_func[k] = FUNCTIONS.build(cfg=v)

    conv_dataset_cls = partial(
        SingleImageConvDataset,
        preprocessor=preprocessor,
        process_func=process_func,
        tokenize_kwargs=tokenize_kwargs,
        conv_template=conv_template,
        training_args=training_args,
        transforms=transforms,
    )
    ds = {
        'train': conv_dataset_cls(dataset_generator=datasets['train'], mode='train') if datasets['train'] is not None else None,
        'validation': conv_dataset_cls(dataset_generator=datasets['validation'], mode='validation') if datasets['validation'] is not None else None,
        'test': conv_dataset_cls(dataset_generator=datasets['test'], mode='test') if datasets['test'] is not None else None,
    }

    # multi test set
    if hasattr(data_args, 'multitest') and bool(data_args.multitest) \
            and hasattr(training_args, 'do_multi_predict') and training_args.do_multi_predict:
        print(f"processing multitest set")
        k2v = {}
        for k, item in data_args.multitest.items():
            _dataset_cls = partial(DATASETS.build, item['cfg'])
            _compute_metric = build_compute_metric(item['compute_metric'], preprocessor)
            k2v[k] = {
                "dataset": conv_dataset_cls(dataset_generator=_dataset_cls, mode='test'),
                "compute_metric": _compute_metric
            }
        ds['multitest'] = k2v
        print(f"processing multitest set. done.")

    # in default, ceph client do init at the beginning of program.
    #  importantly, before dataloader worker fork.
    lazy_init = data_args.get('lazy_init', True)
    if not lazy_init:
        init_ceph_client_if_needed()
    return ds, compute_metrics


def build_compute_metric(compute_metric_cfg, preprocessor):
    if compute_metric_cfg is not None:
        compute_metric_cfg = dict(compute_metric_cfg)  # copy cfg because we modify it
        compute_metric_cfg.update(dict(preprocessor=preprocessor))
        compute_metrics = METRICS.build(cfg=compute_metric_cfg)
    else:
        compute_metrics = None
    return compute_metrics


def prepare_interactive(
        model_args,
        preprocessor: Dict[str, Any],
):
    conv_args = model_args.conv_args
    tokenize_kwargs = conv_args.get('tokenize_kwargs', {})
    conv_template = conv_args.get('conv_template', 'vicuna_v1.1')
    conv_template = partial(get_conv_template, name=conv_template)
    transforms = conv_args.get('transforms', None)
    if transforms is not None:
        transforms = TRANSFORMS.build(transforms)
    # process func
    process_func = {}
    for k, v in model_args.process_func_args.items():
        process_func[k] = FUNCTIONS.build(cfg=v)

    ds = SingleImageInteractive(
        preprocessor=preprocessor,
        process_func=process_func,
        tokenize_kwargs=tokenize_kwargs,
        conv_template=conv_template,
        training_args=None,
        transforms=transforms,
        mode='test',
    )
    return ds
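For orientation, a rough sketch of how the return value of prepare_data is shaped and consumed. The argument objects (data_args, model_args, training_args, preprocessor) are assumed to be built by the surrounding training script and its config system; they are not constructed here:

# sketch only; assumes data_args / model_args / training_args / preprocessor already exist
ds, compute_metrics = prepare_data(data_args, model_args, training_args, preprocessor)

train_ds = ds['train']        # SingleImageConvDataset in 'train' mode; None unless do_train
eval_ds = ds['validation']    # None unless training_args.do_eval is set
test_ds = ds['test']          # None unless training_args.do_predict is set

# optional extra test sets, filled only when data_args.multitest and do_multi_predict are set
for name, item in ds.get('multitest', {}).items():
    dataset, metric_fn = item['dataset'], item['compute_metric']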
    	
mllm/dataset/process_function/__init__.py
ADDED
@@ -0,0 +1,13 @@
from .shikra_process_function import (
    ShikraConvProcess,
    ShikraImageProcessor,
    ShikraTextProcess,
)

from .box_process_function import (
    BoxFormatProcess,
    BoxFormatter,
    PlainBoxFormatter,
    TokenFormatter,
    prepare_target_processor,
)
    	
mllm/dataset/process_function/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (458 Bytes)

mllm/dataset/process_function/__pycache__/box_process_function.cpython-310.pyc
ADDED
Binary file (10.7 kB)

mllm/dataset/process_function/__pycache__/shikra_process_function.cpython-310.pyc
ADDED
Binary file (6.02 kB)
        mllm/dataset/process_function/box_process_function.py
    ADDED
    
    | 
         @@ -0,0 +1,326 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import re
         
     | 
| 2 | 
         
            +
            import sys
         
     | 
| 3 | 
         
            +
            import logging
         
     | 
| 4 | 
         
            +
            import typing
         
     | 
| 5 | 
         
            +
            from typing import List, Dict, Any, Tuple, Union
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            from ..utils.transform import norm_box_xyxy, norm_point_xyxy
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            from ..root import (
         
     | 
| 10 | 
         
            +
                FUNCTIONS,
         
     | 
| 11 | 
         
            +
                BaseTargetProcessFunc,
         
     | 
| 12 | 
         
            +
                BOXES_PLACEHOLDER,
         
     | 
| 13 | 
         
            +
                BOXES_PROCESSOR,
         
     | 
| 14 | 
         
            +
                POINTS_PLACEHOLDER,
         
     | 
| 15 | 
         
            +
            )
         
     | 
| 16 | 
         
            +
             
     | 
| 17 | 
         
            +
            from ...utils import smart_tokenizer_and_embedding_resize
         
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
            logger = logging.getLogger(__name__)
         
     | 
| 20 | 
         
            +
            logger.setLevel(logging.INFO)
         
     | 
| 21 | 
         
            +
            logging.basicConfig(
         
     | 
| 22 | 
         
            +
                format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
         
     | 
| 23 | 
         
            +
                datefmt="%m/%d/%Y %H:%M:%S",
         
     | 
| 24 | 
         
            +
                handlers=[logging.StreamHandler(sys.stdout), ],
         
     | 
| 25 | 
         
            +
            )
         
     | 
| 26 | 
         
            +
             
     | 
| 27 | 
         
            +
            Box = List[Union[float, int]]
         
     | 
| 28 | 
         
            +
            Boxes = List[Box]
         
     | 
| 29 | 
         
            +
            BoxesSeq = List[Boxes]
         
     | 
| 30 | 
         
            +
             
     | 
| 31 | 
         
            +
             
     | 
| 32 | 
         
            +
            @FUNCTIONS.register_module()
         
     | 
| 33 | 
         
            +
            class BoxFormatProcess(BaseTargetProcessFunc):
         
     | 
| 34 | 
         
            +
                def __call__(self, raw_conv: List[Dict[str, Any]], target: Dict[str, Any], preprocessor: Dict[str, Any],
         
     | 
| 35 | 
         
            +
                             multimage_mode=False) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
         
     | 
| 36 | 
         
            +
                    box_formatter = preprocessor['target']['boxes']
         
     | 
| 37 | 
         
            +
             
     | 
| 38 | 
         
            +
                    if multimage_mode:
         
     | 
| 39 | 
         
            +
                        target = typing.cast(list, target)
         
     | 
| 40 | 
         
            +
                        outer_normalized_boxes = []
         
     | 
| 41 | 
         
            +
                        for tgt in target:
         
     | 
| 42 | 
         
            +
                            normalized_boxes = []
         
     | 
| 43 | 
         
            +
                            if tgt is not None and 'boxes' in tgt:
         
     | 
| 44 | 
         
            +
                                for box in tgt['boxes']:
         
     | 
| 45 | 
         
            +
                                    normalized_boxes.append(
         
     | 
| 46 | 
         
            +
                                        norm_box_xyxy(box, w=tgt['width'], h=tgt['height'])
         
     | 
| 47 | 
         
            +
                                    )
         
     | 
| 48 | 
         
            +
                            outer_normalized_boxes.append(normalized_boxes)
         
     | 
| 49 | 
         
            +
                        normalized_boxes = outer_normalized_boxes
         
     | 
| 50 | 
         
            +
                        outer_normalized_points = []
         
     | 
| 51 | 
         
            +
                        for tgt in target:
         
     | 
| 52 | 
         
            +
                            normalized_points = []
         
     | 
| 53 | 
         
            +
                            if tgt is not None and 'boxes' in tgt:
         
     | 
| 54 | 
         
            +
                                for box in tgt['boxes']:
         
     | 
| 55 | 
         
            +
                                    normalized_points.append(
         
     | 
| 56 | 
         
            +
                                        norm_box_xyxy(box, w=tgt['width'], h=tgt['height'])
         
     | 
| 57 | 
         
            +
                                    )
         
     | 
| 58 | 
         
            +
                            outer_normalized_points.append(normalized_points)
         
     | 
| 59 | 
         
            +
                        normalized_points = outer_normalized_points
         
     | 
| 60 | 
         
            +
                    else:
         
     | 
| 61 | 
         
            +
                        # normalize target
         
     | 
| 62 | 
         
            +
                        normalized_boxes = []
         
     | 
| 63 | 
         
            +
                        if target is not None and 'boxes' in target:
         
     | 
| 64 | 
         
            +
                            for box in target['boxes']:
         
     | 
| 65 | 
         
            +
                                normalized_boxes.append(
         
     | 
| 66 | 
         
            +
                                    norm_box_xyxy(box, w=target['width'], h=target['height'])
         
     | 
| 67 | 
         
            +
                                )
         
     | 
| 68 | 
         
            +
                        normalized_points = []
         
     | 
| 69 | 
         
            +
                        if target is not None and 'points' in target:
         
     | 
| 70 | 
         
            +
                            for point in target['points']:
         
     | 
| 71 | 
         
            +
                                normalized_points.append(
         
     | 
| 72 | 
         
            +
                                    norm_point_xyxy(point, w=target['width'], h=target['height'])
         
     | 
| 73 | 
         
            +
                                )
         
     | 
| 74 | 
         
            +
             
     | 
| 75 | 
         
            +
                    # convert bboxes_seq
         
     | 
| 76 | 
         
            +
                    for sentence in raw_conv:
         
     | 
| 77 | 
         
            +
                        words: str = sentence['value']
         
     | 
| 78 | 
         
            +
                        boxes_seq: List[List[int]] = sentence.get('boxes_seq', None)
         
     | 
| 79 | 
         
            +
                        if boxes_seq is not None:
         
     | 
| 80 | 
         
            +
                            # map box seq
         
     | 
| 81 | 
         
            +
                            boxes_seq: List[Boxes] = map_obj(normalized_boxes, boxes_seq)
         
     | 
| 82 | 
         
            +
                            # reformat; replace <boxes> placeholder
         
     | 
| 83 | 
         
            +
                            converted = box_formatter(words, boxes_seq)
         
     | 
| 84 | 
         
            +
                            words = converted
         
     | 
| 85 | 
         
            +
                        points_seq: List[List[int]] = sentence.get('points_seq', None)
         
     | 
| 86 | 
         
            +
                        if points_seq is not None:
         
     | 
| 87 | 
         
            +
                            # map point seq
         
     | 
| 88 | 
         
            +
                            points_seq: List[Boxes] = map_obj(normalized_points, points_seq)
         
     | 
| 89 | 
         
            +
                            # reformat; replace <points> placeholder
         
     | 
| 90 | 
         
            +
                            converted = box_formatter.call_on_point(words, points_seq)
         
     | 
| 91 | 
         
            +
                            words = converted
         
     | 
| 92 | 
         
            +
                        if boxes_seq is not None or points_seq is not None:
         
     | 
| 93 | 
         
            +
                            sentence['raw_value'] = sentence['value']
         
     | 
| 94 | 
         
            +
                            sentence['value'] = words
         
     | 
| 95 | 
         
            +
                    return raw_conv, target
         
     | 
| 96 | 
         
            +
             
     | 
| 97 | 
         
            +
             
     | 
| 98 | 
         
            +
            def map_obj(boxes_value: List[List[float]], boxes_seq: List[List[int]]) -> List[List[List[float]]]:
         
     | 
| 99 | 
         
            +
                """
         
     | 
| 100 | 
         
            +
                >>> normalized_boxes = [[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2], [0.3, 0.3, 0.3, 0.3]]
         
     | 
| 101 | 
         
            +
                >>> boxes_seq_ = [[3, 1], [2]]
         
     | 
| 102 | 
         
            +
                >>> var = map_obj(normalized_boxes, boxes_seq_)
         
     | 
| 103 | 
         
            +
                >>> assert var == [[[0.3,0.3,0.3,0.3], [0.1,0.1,0.1,0.1]], [0.2,0.2,0.2,0.2]]
         
     | 
| 104 | 
         
            +
                """
         
     | 
| 105 | 
         
            +
                try:
         
     | 
| 106 | 
         
            +
                    ret = []
         
     | 
| 107 | 
         
            +
                    for boxes in boxes_seq:
         
     | 
| 108 | 
         
            +
                        boxes_ret = []
         
     | 
| 109 | 
         
            +
                        for box_index in boxes:
         
     | 
| 110 | 
         
            +
                            if isinstance(box_index, (list, tuple)):
         
     | 
| 111 | 
         
            +
                                boxes_ret.append(boxes_value[box_index[0]][box_index[1]])
         
     | 
| 112 | 
         
            +
                            else:
         
     | 
| 113 | 
         
            +
                                boxes_ret.append(boxes_value[box_index])
         
     | 
| 114 | 
         
            +
                        ret.append(boxes_ret)
         
     | 
| 115 | 
         
            +
                    return ret
         
     | 
| 116 | 
         
            +
                except:
         
     | 
| 117 | 
         
            +
                    raise SystemExit(f"error: map obj {boxes_value} {boxes_seq}")
         
     | 
| 118 | 
         
            +
             
     | 
| 119 | 
         
            +
             
     | 
| 120 | 
         
            +
            class BoxFormatter:
         
     | 
| 121 | 
         
            +
                def __init__(self, bboxes_token=BOXES_PLACEHOLDER, points_token=POINTS_PLACEHOLDER):
         
     | 
| 122 | 
         
            +
                    self.bboxes_token = bboxes_token
         
     | 
| 123 | 
         
            +
                    self.points_token = points_token
         
     | 
| 124 | 
         
            +
                    # normally the bboxes_token_pat is the same as bboxes_token if u not use some weird token
         
     | 
| 125 | 
         
            +
                    self.bboxes_token_pat = re.compile(bboxes_token)
         
     | 
| 126 | 
         
            +
                    self.points_token_pat = re.compile(points_token)
         
     | 
| 127 | 
         
            +
             
     | 
| 128 | 
         
            +
                def __call__(self, sentence: str, bboxes_seq: BoxesSeq) -> str:
         
     | 
| 129 | 
         
            +
                    all_box = self.bboxes_token_pat.findall(sentence)
         
     | 
| 130 | 
         
            +
                    assert len(all_box) == len(bboxes_seq), f"not match. sentence: {sentence}. boxes:{bboxes_seq}"
         
     | 
| 131 | 
         
            +
                    if len(all_box) == 0:
         
     | 
| 132 | 
         
            +
                        return sentence
         
     | 
| 133 | 
         
            +
                    bboxes_strs = [self.format_box(bboxes) for bboxes in bboxes_seq]
         
     | 
| 134 | 
         
            +
                    converted = sentence.replace(self.bboxes_token, '{}').format(*bboxes_strs)
         
     | 
| 135 | 
         
            +
                    return converted
         
     | 
| 136 | 
         
            +
             
     | 
| 137 | 
         
            +
                def call_on_point(self, sentence: str, points_seq: BoxesSeq) -> str:
         
     | 
| 138 | 
         
            +
                    all_box = self.points_token_pat.findall(sentence)
         
     | 
| 139 | 
         
            +
                    assert len(all_box) == len(points_seq), f"not match. sentence: {sentence}. boxes:{points_seq}"
         
     | 
| 140 | 
         
            +
                    if len(all_box) == 0:
         
     | 
| 141 | 
         
            +
                        return sentence
         
     | 
| 142 | 
         
            +
                    bboxes_strs = [self.format_point(bboxes) for bboxes in points_seq]
         
     | 
| 143 | 
         
            +
                    converted = sentence.replace(self.points_token, '{}').format(*bboxes_strs)
         
     | 
| 144 | 
         
            +
                    return converted
         
     | 
| 145 | 
         
            +
             
     | 
| 146 | 
         
            +
                def format_point(self, points) -> str:
         
     | 
| 147 | 
         
            +
                    raise NotImplementedError
         
     | 
| 148 | 
         
            +
             
     | 
| 149 | 
         
            +
                def format_box(self, bboxes: Boxes) -> str:
         
     | 
| 150 | 
         
            +
                    raise NotImplementedError
         
     | 
| 151 | 
         
            +
             
     | 
| 152 | 
         
            +
                def extract(self, string: str) -> List[Boxes]:
         
     | 
| 153 | 
         
            +
                    raise NotImplementedError
         
     | 
| 154 | 
         
            +
             
     | 
| 155 | 
         
            +
                def extract_point(self, string: str) -> List[Boxes]:
         
     | 
| 156 | 
         
            +
                    raise NotImplementedError
         
     | 
| 157 | 
         
            +
             
     | 
| 158 | 
         
            +
             
     | 
| 159 | 
         
            +
            @BOXES_PROCESSOR.register_module()
         
     | 
| 160 | 
         
            +
            class PlainBoxFormatter(BoxFormatter):
         
     | 
| 161 | 
         
            +
             
     | 
| 162 | 
         
            +
                def __init__(self, *args, precision=3, use_small_brackets=False, **kwargs):
         
     | 
| 163 | 
         
            +
                    super().__init__(*args, **kwargs)
         
     | 
| 164 | 
         
            +
                    self.precision = precision
         
     | 
| 165 | 
         
            +
                    self.use_small_brackets = use_small_brackets
         
     | 
| 166 | 
         
            +
             
     | 
| 167 | 
         
            +
                    small_brackets_pat = re.compile(r'\(\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3}(?:;\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3})*\)')
         
     | 
| 168 | 
         
            +
                    small_brackets_point_pat = re.compile(r'\(\d(?:\.\d*)?(?:,\d(?:\.\d*)?)(?:;\d(?:\.\d*)?(?:,\d(?:\.\d*)?))*\)')
         
     | 
| 169 | 
         
            +
             
     | 
| 170 | 
         
            +
                    middle_brackets_pat = re.compile(r'\[\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3}(?:;\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3})*\]')
         
     | 
| 171 | 
         
            +
                    middle_brackets_point_pat = re.compile(r'\[\d(?:\.\d*)?(?:,\d(?:\.\d*)?)(?:;\d(?:\.\d*)?(?:,\d(?:\.\d*)?))*\]')
         
     | 
| 172 | 
         
            +
             
     | 
| 173 | 
         
            +
                    self.pat = small_brackets_pat if use_small_brackets else middle_brackets_pat
         
     | 
| 174 | 
         
            +
                    self.point_pat = small_brackets_point_pat if use_small_brackets else middle_brackets_point_pat
         
     | 
| 175 | 
         
            +
             
     | 
| 176 | 
         
            +
                def format_box(self, boxes: Boxes) -> str:
         
     | 
| 177 | 
         
            +
                    box_strs = []
         
     | 
| 178 | 
         
            +
                    for box in boxes:
         
     | 
| 179 | 
         
            +
                        box_strs.append(','.join([f"{elem:.{self.precision}f}" for elem in box]))
         
     | 
| 180 | 
         
            +
                    box_str = ';'.join(box_strs)
         
     | 
| 181 | 
         
            +
                    if self.use_small_brackets:
         
     | 
| 182 | 
         
            +
                        return "(" + box_str + ")"
         
     | 
| 183 | 
         
            +
                    return "[" + box_str + "]"
         
     | 
| 184 | 
         
            +
             
     | 
| 185 | 
         
            +
                def format_point(self, points) -> str:
         
     | 
| 186 | 
         
            +
                    return self.format_box(points)
         
     | 
| 187 | 
         
            +
             
     | 
| 188 | 
         
            +
                def extract(self, string: str) -> List[Boxes]:
         
     | 
| 189 | 
         
            +
                    """ balabala<boxes>balabala<boxes> -> [boxes, boxes] """
         
     | 
| 190 | 
         
            +
                    ret = []
         
     | 
| 191 | 
         
            +
                    for bboxes_str in self.pat.findall(string):
         
     | 
| 192 | 
         
            +
                        bboxes = []
         
     | 
| 193 | 
         
            +
                        bbox_strs = bboxes_str.replace("(", "").replace(")", "").replace("[", "").replace("]", "").split(";")
         
     | 
| 194 | 
         
            +
                        for bbox_str in bbox_strs:
         
     | 
| 195 | 
         
            +
                            bbox = list(map(float, bbox_str.split(',')))
         
     | 
| 196 | 
         
            +
                            bboxes.append(bbox)
         
     | 
| 197 | 
         
            +
                        ret.append(bboxes)
         
     | 
| 198 | 
         
            +
                    return ret
         
     | 
| 199 | 
         
            +
             
     | 
| 200 | 
         
            +
                def extract_point(self, string: str) -> List[Boxes]:
         
     | 
| 201 | 
         
            +
                    """ balabala<boxes>balabala<boxes> -> [boxes, boxes] """
         
     | 
| 202 | 
         
            +
                    ret = []
         
     | 
| 203 | 
         
            +
                    for bboxes_str in self.point_pat.findall(string):
         
     | 
| 204 | 
         
            +
                        bboxes = []
         
     | 
| 205 | 
         
            +
                        bbox_strs = bboxes_str.replace("(", "").replace(")", "").replace("[", "").replace("]", "").split(";")
         
     | 
| 206 | 
         
            +
                        for bbox_str in bbox_strs:
         
     | 
| 207 | 
         
            +
                            bbox = list(map(float, bbox_str.split(',')))
         
     | 
| 208 | 
         
            +
                            bboxes.append(bbox)
         
     | 
| 209 | 
         
            +
                        ret.append(bboxes)
         
     | 
| 210 | 
         
            +
                    return ret
         
     | 
| 211 | 
         
            +
             
     | 
| 212 | 
         
            +
             
     | 
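
Illustrative usage (not part of the uploaded box_process_function.py): PlainBoxFormatter substitutes the <boxes> placeholder with bracketed, precision-rounded coordinates and parses them back out of generated text. The import path below is assumed from the file list in this upload.

    from mllm.dataset.process_function.box_process_function import PlainBoxFormatter

    fmt = PlainBoxFormatter()  # defaults: precision=3, middle ('[...]') brackets
    sentence = 'What is the object at <boxes> doing?'
    converted = fmt(sentence, [[[0.12, 0.25, 0.48, 0.66]]])
    # -> 'What is the object at [0.120,0.250,0.480,0.660] doing?'
    fmt.extract(converted)
    # -> [[[0.12, 0.25, 0.48, 0.66]]]
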
@BOXES_PROCESSOR.register_module()
class TokenFormatter(BoxFormatter):

    def __init__(self, num_bins=1001):
        super().__init__()
        self.extract_box_pat = re.compile(r'<b_st><bin_\d*?>(?:<bin_\d*?>){3}(?:<b_sep><bin_\d*?>(?:<bin_\d*?>){3})*<b_ed>')
        self.extract_point_pat = re.compile(r'<p_st><bin_\d*?>(?:<bin_\d*?>){1}(?:<p_sep><bin_\d*?>(?:<bin_\d*?>){1})*<p_ed>')
        self.num_bins = num_bins
        self.use_sep = True
        self.use_begin_end = True

        self.box_begin = '<b_st>'
        self.box_sep = '<b_sep>'
        self.box_end = '<b_ed>'

        self.point_begin = '<p_st>'
        self.point_sep = '<p_sep>'
        self.point_end = '<p_ed>'

    def format_point(self, points) -> str:
        final_str = []
        for bbox in points:
            quant_x0 = "<bin_{}>".format(round((bbox[0] * (self.num_bins - 1))))
            quant_y0 = "<bin_{}>".format(round((bbox[1] * (self.num_bins - 1))))
            region_coord = "{} {}".format(quant_x0, quant_y0)
            final_str.append(region_coord)
        if self.use_sep:
            final_str = self.point_sep.join(final_str)
        else:
            final_str = ''.join(final_str)
        if self.use_begin_end:
            final_str = self.point_begin + final_str + self.point_end
        return final_str

    def format_box(self, bboxes: Boxes) -> str:
        final_str = []
        for bbox in bboxes:
            quant_x0 = "<bin_{}>".format(round((bbox[0] * (self.num_bins - 1))))
            quant_y0 = "<bin_{}>".format(round((bbox[1] * (self.num_bins - 1))))
            quant_x1 = "<bin_{}>".format(round((bbox[2] * (self.num_bins - 1))))
            quant_y1 = "<bin_{}>".format(round((bbox[3] * (self.num_bins - 1))))
            region_coord = "{} {} {} {}".format(quant_x0, quant_y0, quant_x1, quant_y1)
            final_str.append(region_coord)
        if self.use_sep:
            final_str = self.box_sep.join(final_str)
        else:
            final_str = ''.join(final_str)
        if self.use_begin_end:
            final_str = self.box_begin + final_str + self.box_end
        return final_str

    def extract(self, string: str) -> List[Boxes]:
        ret = []
        for bboxes_str in self.extract_box_pat.findall(string.replace(" ", "")):
            bboxes = []
            bbox_strs = bboxes_str.replace(self.box_begin, "").replace(self.box_end, "").split(self.box_sep)
            for bbox_str in bbox_strs:
                elems = list(map(int, re.findall(r'<bin_(\d*?)>', bbox_str)))
                bbox = [elem / (self.num_bins - 1) for elem in elems]
                bboxes.append(bbox)
            ret.append(bboxes)
        return ret

    def extract_point(self, string: str) -> List[Boxes]:
        ret = []
        for bboxes_str in self.extract_point_pat.findall(string):
            bboxes = []
            bbox_strs = bboxes_str.replace(self.point_begin, "").replace(self.point_end, "").split(self.point_sep)
            for bbox_str in bbox_strs:
                elems = list(map(int, re.findall(r'<bin_(\d*?)>', bbox_str)))
                bbox = [elem / (self.num_bins - 1) for elem in elems]
                bboxes.append(bbox)
            ret.append(bboxes)
        return ret

    def post_process_model_tokenizer(self, model, preprocessor, model_args, training_args):
        tokenizer = preprocessor['text']

        additional_special_tokens = [
            self.box_begin, self.box_sep, self.box_end,
            self.point_begin, self.point_sep, self.point_end,
        ]
        for i in range(self.num_bins):
            additional_special_tokens.append(f'<bin_{i}>')

        smart_tokenizer_and_embedding_resize(
            {'additional_special_tokens': additional_special_tokens},
            tokenizer,
            model,
        )
        return model, preprocessor
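
Illustrative usage (not part of the uploaded file): TokenFormatter quantizes each normalized coordinate into one of num_bins discrete <bin_i> special tokens and wraps boxes in begin/end markers; extract() inverts the quantization. Values below assume the default num_bins=1001.

    fmt = TokenFormatter(num_bins=1001)
    fmt.format_box([[0.0, 0.25, 0.5, 1.0]])
    # -> '<b_st><bin_0> <bin_250> <bin_500> <bin_1000><b_ed>'
    fmt.extract('region: <b_st><bin_0> <bin_250> <bin_500> <bin_1000><b_ed>')
    # -> [[[0.0, 0.25, 0.5, 1.0]]]
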
# FIXME: merge into load_pretrained
def prepare_target_processor(
        model,  # multimodal llm
        preprocessor: Dict[str, Any],
        model_args,
        training_args,
):
    if not hasattr(model_args, 'target_processor'):
        return model, preprocessor

    target_processor = {}
    if 'boxes' in model_args['target_processor']:
        boxes_cfg = model_args['target_processor']['boxes']
        boxes_processor = BOXES_PROCESSOR.build(boxes_cfg)
        target_processor['boxes'] = boxes_processor
        if hasattr(boxes_processor, "post_process_model_tokenizer"):
            model, preprocessor = boxes_processor.post_process_model_tokenizer(
                model, preprocessor, model_args, training_args,
            )
    preprocessor['target'] = target_processor
    return model, preprocessor
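
A sketch of how this is wired up (illustrative; the exact shape of model_args is an assumption based on the hasattr/item access above, e.g. an mmengine Config that supports both): the 'boxes' entry is a registry config that BOXES_PROCESSOR.build turns into one of the formatters defined in this file.

    # hypothetical config fragment for model_args.target_processor
    target_processor = dict(
        boxes=dict(type='PlainBoxFormatter', precision=3, use_small_brackets=False),
        # or: boxes=dict(type='TokenFormatter', num_bins=1001)
    )
    # model, preprocessor = prepare_target_processor(model, preprocessor, model_args, training_args)
    # afterwards preprocessor['target']['boxes'] holds the built formatter
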
    	
mllm/dataset/process_function/shikra_process_function.py
ADDED
@@ -0,0 +1,178 @@
import sys
import copy
import warnings
import logging
from typing import Dict, Any, List

import PIL.Image
import torch
from PIL import Image
from transformers import LlamaTokenizer

from ..root import (
    FUNCTIONS,
    IMAGE_PLACEHOLDER,
    BaseImageProcessFunc,
    BaseConvProcessFunc,
    BaseTextProcessFunc,
)
from ...conversation import SeparatorStyle, Conversation

IGNORE_INDEX = -100
DEFAULT_IMAGE_TOKEN = IMAGE_PLACEHOLDER
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout), ],
)


@FUNCTIONS.register_module()
class ShikraConvProcess(BaseConvProcessFunc):
    def __call__(self, raw_conv: List[Dict[str, Any]], preprocessor: Dict[str, Any], conv_template: Conversation) -> List[Dict[str, Any]]:
        conv_processor_cfg = preprocessor['conv']

        image_token_len = conv_processor_cfg['image_token_len']
        sep_image_conv_front = conv_processor_cfg.get('sep_image_conv_front', False)
        use_im_start_end = conv_processor_cfg.get('use_im_start_end', False)
        # assert DEFAULT_IMAGE_PATCH_TOKEN in preprocessor['text'].get_vocab()
        # if use_im_start_end:
        #     assert DEFAULT_IM_START_TOKEN in preprocessor['text'].get_vocab()
        #     assert DEFAULT_IM_END_TOKEN in preprocessor['text'].get_vocab()

        if sep_image_conv_front:
            raw_conv[0]['value'] = raw_conv[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
            raw_conv[0]['value'] = DEFAULT_IMAGE_TOKEN + conv_template.sep + conv_template.roles[0] + ": " + raw_conv[0]['value']
        for sentence in raw_conv:
            replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
            if use_im_start_end:
                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
            sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)

        return raw_conv
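
Illustrative effect (not part of the uploaded file): ShikraConvProcess expands the <image> placeholder into image_token_len copies of <im_patch>, optionally wrapped in <im_start>/<im_end>. A minimal sketch, assuming a plain dict for preprocessor['conv']:

    raw_conv = [{'value': '<image> What is in the picture?'}]
    preprocessor = {'conv': {'image_token_len': 3, 'use_im_start_end': True}}
    # ShikraConvProcess()(raw_conv, preprocessor, conv_template)[0]['value']
    # -> '<im_start><im_patch><im_patch><im_patch><im_end> What is in the picture?'
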
@FUNCTIONS.register_module()
class ShikraTextProcess(BaseTextProcessFunc):

    def __call__(self, conv: Conversation, preprocessor: Dict[str, Any], mode: str, **tokenize_kwargs) -> Dict[str, Any]:
        tokenizer = preprocessor['text']
        assert isinstance(tokenizer, LlamaTokenizer), "only work for LlamaTokenizer"

        _truncation_size = tokenize_kwargs.pop('truncation_size', None)
        _kwargs = {'return_tensors': 'pt'}
        _kwargs.update(tokenize_kwargs)

        if conv.sep_style == SeparatorStyle.ADD_COLON_TWO:
            if mode in ['train']:
                ret = self.tk_conv_colon_two_train(conv, tokenizer, **_kwargs)
            else:
                ret = self.tk_conv_colon_two_eval(conv, tokenizer, **_kwargs)
        else:
            raise ValueError(f"unrecognized conv_style: {conv.sep_style}.\n the conv is {conv}")

        if _truncation_size is None:
            return ret
        if len(ret['input_ids']) <= _truncation_size:
            return ret

        origin_len = len(ret['input_ids'])
        ids_to_remove_num = origin_len - _truncation_size
        # truncation. should carefully not truncate <img_token>
        ids_should_not_remove = list(map(
            tokenizer.convert_tokens_to_ids,
            (DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN)
        ))
        back_no_image = all(ids not in ids_should_not_remove for ids in ret['input_ids'][_truncation_size:])
        if back_no_image:
            tgt_ids = list(range(_truncation_size))
        else:
            ids_to_remove = set()
            for idx in range(origin_len - 1, -1, -1):
                if ret['input_ids'][idx] not in ids_should_not_remove:
                    ids_to_remove.add(idx)
                    if len(ids_to_remove) >= ids_to_remove_num:
                        break
            tgt_ids = [_ for _ in range(origin_len) if _ not in ids_to_remove]
        logger.warning(f"truncate sample size from {origin_len} to {len(tgt_ids)}.")
        assert len(tgt_ids) == _truncation_size, f"{len(tgt_ids)}, {_truncation_size}, {ret['input_ids'].tolist()}"
        truncated_ret = {k: v[tgt_ids] for k, v in ret.items()}
        return truncated_ret

    # noinspection PyMethodMayBeStatic
    def tk_conv_colon_two_train(self, conv, tokenizer, **kwargs):
        conversation = conv.get_prompt()
        input_ids = tokenizer([conversation, ], **kwargs).input_ids[0]
        target = copy.deepcopy(input_ids)
        assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO
        # Mask targets
        sep = conv.sep + conv.roles[1] + ": "
        total_len = int(target.ne(tokenizer.pad_token_id).sum())
        rounds = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break
            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            round_len = len(tokenizer(rou).input_ids)
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2  # <s> <space>
            target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                warnings.warn(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}. (ignored):\n{conversation}")
        return dict(
            input_ids=input_ids,
            attention_mask=input_ids.ne(tokenizer.pad_token_id),
            labels=target,
        )

    # noinspection PyMethodMayBeStatic
    def tk_conv_colon_two_eval(self, conv, tokenizer, **kwargs):
        assert len(conv.messages) >= 2
        # target = conv.messages[-1][-1]
        target = conv.get_prompt()

        conv.messages[-1][-1] = ""
        conversation = conv.get_prompt()
        input_ids = tokenizer([conversation, ], **kwargs).input_ids[0]

        target = tokenizer([target, ], add_special_tokens=False, **kwargs).input_ids[0]
        target[target == tokenizer.pad_token_id] = IGNORE_INDEX
        return dict(
            input_ids=input_ids,
            attention_mask=input_ids.ne(tokenizer.pad_token_id),
            labels=target,
        )


@FUNCTIONS.register_module()
class ShikraImageProcessor(BaseImageProcessFunc):
    def __call__(self, image: Image.Image, preprocessor: Dict[str, Any]) -> Dict[str, Any]:
        image_processor = preprocessor['image']

        if isinstance(image, (list, tuple)):
            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values']
            assert False, 'Shikra not support MultiImage'
        elif isinstance(image, PIL.Image.Image):
            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        else:
            if hasattr(image_processor, 'crop_size'):
                crop_size = image_processor.crop_size
                height, width = crop_size['height'], crop_size['width']
            else:
                raise ValueError("got empty image. and don't know how to pad")
            image = torch.zeros(3, height, width)
        return {'image': image}
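
A note on the shapes returned above (illustrative, not part of the diff): both tokenization paths return a dict of 1-D tensors. In training mode the labels are a copy of input_ids with every token up to and including each "sep + second role + ': '" prefix set to IGNORE_INDEX (-100), so the loss is only computed on the replies; when truncation_size is given, tokens are dropped from the back while the <im_patch>/<im_start>/<im_end> ids are skipped.

    # ret = ShikraTextProcess()(conv, preprocessor, mode='train', truncation_size=2048)
    # ret['input_ids']       LongTensor [seq_len]
    # ret['attention_mask']  BoolTensor [seq_len]  (input_ids != pad_token_id)
    # ret['labels']          LongTensor [seq_len]  (instruction tokens = -100)
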
    	
mllm/dataset/root.py
ADDED
@@ -0,0 +1,67 @@
from typing import Dict, Any, List, Tuple

from PIL import Image
from mmengine import DATASETS, TRANSFORMS, METRICS, FUNCTIONS, Registry

from ..conversation import Conversation

IMAGE_PLACEHOLDER = '<image>'
BOXES_PLACEHOLDER = '<boxes>'
EXPR_PLACEHOLDER = '<expr>'
OBJS_PLACEHOLDER = '<objs>'
QUESTION_PLACEHOLDER = '<question>'
POINTS_PLACEHOLDER = '<points>'
# processor
BOXES_PROCESSOR = Registry('Processor for Boxes')


# only for static type checking
class BaseConvProcessFunc:
    def __call__(
            self,
            raw_conv: List[Dict[str, Any]],
            preprocessor: Dict[str, Any],
            conv_template: Conversation,
    ) -> List[Dict[str, Any]]:
        raise NotImplementedError


class BaseTargetProcessFunc:
    def __call__(
            self,
            raw_conv: List[Dict[str, Any]],
            target: Dict[str, Any],
            preprocessor: Dict[str, Any],
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        raise NotImplementedError


class BaseTextProcessFunc:
    def __call__(
            self,
            conv: Conversation,
            preprocessor: Dict[str, Any],
            mode: str,
            **tokenize_kwargs,
    ) -> Dict[str, Any]:
        raise NotImplementedError


class BaseImageProcessFunc:
    def __call__(
            self,
            image: Image.Image,
            preprocessor: Dict[str, Any],
    ) -> Dict[str, Any]:
        raise NotImplementedError


__all__ = [
    'IMAGE_PLACEHOLDER', 'BOXES_PLACEHOLDER', 'EXPR_PLACEHOLDER', 'OBJS_PLACEHOLDER', 'QUESTION_PLACEHOLDER', 'POINTS_PLACEHOLDER',
    'FUNCTIONS',
    'DATASETS',
    'TRANSFORMS',
    'METRICS',
    'BOXES_PROCESSOR',
    'BaseConvProcessFunc', 'BaseTargetProcessFunc', 'BaseTextProcessFunc', 'BaseImageProcessFunc',
]
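
Illustrative extension point (not part of the uploaded file): new processors plug into these registries with the same register_module/build pattern used by the files above; MyBoxFormatter below is a hypothetical subclass added only for the sketch.

    from mllm.dataset.root import BOXES_PROCESSOR
    from mllm.dataset.process_function.box_process_function import BoxFormatter

    @BOXES_PROCESSOR.register_module()
    class MyBoxFormatter(BoxFormatter):  # hypothetical
        def format_box(self, bboxes):
            return str(bboxes)

    formatter = BOXES_PROCESSOR.build(dict(type='MyBoxFormatter'))
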
    	
mllm/dataset/single_image_convsation.py
ADDED
@@ -0,0 +1,284 @@
import warnings
from functools import partial
from typing import Dict, Any, Callable, List, Optional, Tuple, Type

import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import TrainingArguments

from .root import IMAGE_PLACEHOLDER, BOXES_PLACEHOLDER
from ..conversation import Conversation, get_conv_template
from ..utils import post_process_generate_ids


class SingleImageConvDatasetMixin:

    def __init__(
            self,
            *args,
            preprocessor: Dict[str, Any],
            process_func: Dict[str, Any],
            conv_template: Callable[[], Conversation] = partial(get_conv_template, name='vicuna_v1.1'),
            mode='train',
            tokenize_kwargs: dict = None,
            training_args: TrainingArguments = None,
            transforms: Optional[Callable] = None,
            **kwargs,
    ):
        super().__init__(*args, **kwargs)
        assert mode in ['train', 'validation', 'test']

        self.preprocessor = preprocessor
        self.process_func = process_func
        self.conv_template = conv_template
        self.mode = mode
        self.tokenize_kwargs = tokenize_kwargs if tokenize_kwargs is not None else {}
        self.training_args = training_args
        self.transforms = transforms

    def __getitem__(self, index, debug_mode=False, return_conv=False) -> Dict[str, Any]:
        # getitem
        item = self.get_raw_item(index)
        image: Image.Image = item.get('image', None)
        target: Dict[str, Any] = item.get('target', None)
        raw_conv: List[Dict[str, Any]] = item['conversations']

        # transform
        assert isinstance(image, list) == isinstance(target, list)
        multimage_mode = isinstance(image, list)
        if isinstance(image, list):
            # TODO: validate raw item
            transformed_image, transformed_target = [], []
            for img, tgt in zip(image, target):
                if self.transforms is not None and image is not None:
         
     | 
| 55 | 
         
            +
                                img, tgt = self.transforms(img, tgt)
         
     | 
| 56 | 
         
            +
                            if tgt is not None:
         
     | 
| 57 | 
         
            +
                                tgt['width'], tgt['height'] = img.width, img.height
         
     | 
| 58 | 
         
            +
                            transformed_image.append(img)
         
     | 
| 59 | 
         
            +
                            transformed_target.append(tgt)
         
     | 
| 60 | 
         
            +
                        image, target = transformed_image, transformed_target
         
     | 
| 61 | 
         
            +
                    else:
         
     | 
| 62 | 
         
            +
                        self.validate_raw_item(item)  # only validate for single image.
         
     | 
| 63 | 
         
            +
                        if self.transforms is not None and image is not None:
         
     | 
| 64 | 
         
            +
                            image, target = self.transforms(image, target)
         
     | 
| 65 | 
         
            +
                        has_image = 'image' in item and bool(item['image'])
         
     | 
| 66 | 
         
            +
                        has_target = 'target' in item and bool(item['target']) and any(bool(elem) for elem in item['target'].values())
         
     | 
| 67 | 
         
            +
                        if has_target and has_image:
         
     | 
| 68 | 
         
            +
                            target['width'], target['height'] = image.width, image.height
         
     | 
| 69 | 
         
            +
             
     | 
| 70 | 
         
            +
                    # preprocess
         
     | 
| 71 | 
         
            +
                    raw_conv = self.process_conv(raw_conv)
         
     | 
| 72 | 
         
            +
                    raw_conv, image = self.process_conv_multimage(raw_conv, image)
         
     | 
| 73 | 
         
            +
                    raw_conv, _ = self.process_target(raw_conv, target, multimage_mode=multimage_mode)
         
     | 
| 74 | 
         
            +
                    conv = self.build_conv(raw_conv)
         
     | 
| 75 | 
         
            +
                    if return_conv:
         
     | 
| 76 | 
         
            +
                        # noinspection PyTypeChecker
         
     | 
| 77 | 
         
            +
                        return conv
         
     | 
| 78 | 
         
            +
                    text_dict = self.process_text(conv)
         
     | 
| 79 | 
         
            +
                    image_dict = self.process_image(image)
         
     | 
| 80 | 
         
            +
             
     | 
| 81 | 
         
            +
                    # return
         
     | 
| 82 | 
         
            +
                    ret_dict = {}
         
     | 
| 83 | 
         
            +
                    ret_dict.update(text_dict)
         
     | 
| 84 | 
         
            +
                    ret_dict.update(image_dict)
         
     | 
| 85 | 
         
            +
                    self._print_sample(ret_dict, raw_conv, conv)
         
     | 
| 86 | 
         
            +
                    if debug_mode:
         
     | 
| 87 | 
         
            +
                        return {'ret': ret_dict, 'raw_conv': raw_conv, 'conv': conv, 'image': image}
         
     | 
| 88 | 
         
            +
                    return ret_dict
         
     | 
| 89 | 
         
            +
             
     | 
| 90 | 
         
            +
                def __len__(self):
         
     | 
| 91 | 
         
            +
                    raise NotImplementedError
         
     | 
| 92 | 
         
            +
             
     | 
| 93 | 
         
            +
                # noinspection PyMethodMayBeStatic
         
     | 
| 94 | 
         
            +
                def process_conv_multimage(self, raw_conv, image):
         
     | 
| 95 | 
         
            +
                    # re-sort multi image
         
     | 
| 96 | 
         
            +
                    if image is None:
         
     | 
| 97 | 
         
            +
                        return raw_conv, image
         
     | 
| 98 | 
         
            +
                    if not isinstance(image, (list, tuple)):
         
     | 
| 99 | 
         
            +
                        return raw_conv, image
         
     | 
| 100 | 
         
            +
                    image_seqs = []
         
     | 
| 101 | 
         
            +
                    for conv in raw_conv:
         
     | 
| 102 | 
         
            +
                        image_seqs.extend(conv['image_seq'] if 'image_seq' in conv else [])
         
     | 
| 103 | 
         
            +
                    images = []
         
     | 
| 104 | 
         
            +
                    for idx in image_seqs:
         
     | 
| 105 | 
         
            +
                        images.append(image[idx])
         
     | 
| 106 | 
         
            +
                    return raw_conv, images
         
     | 
| 107 | 
         
            +
             
     | 
| 108 | 
         
            +
                def get_raw_item(self, index) -> Dict[str, Any]:
         
     | 
| 109 | 
         
            +
                    """
         
     | 
| 110 | 
         
            +
                    return item format like this.
         
     | 
| 111 | 
         
            +
                    item = {
         
     | 
| 112 | 
         
            +
                        'image': # PIL.Image.Image,
         
     | 
| 113 | 
         
            +
                        'target': {
         
     | 
| 114 | 
         
            +
                            # xmin, ymin, xmax, ymax
         
     | 
| 115 | 
         
            +
                            'boxes': [
         
     | 
| 116 | 
         
            +
                                [10, 10, 256, 265],  # dog1
         
     | 
| 117 | 
         
            +
                                [24, 18, 378, 768],  # dog2
         
     | 
| 118 | 
         
            +
                                [100, 310, 670, 653],  # man
         
     | 
| 119 | 
         
            +
                                [278, 320, 809, 673],  # rope
         
     | 
| 120 | 
         
            +
                            ],
         
     | 
| 121 | 
         
            +
                        }
         
     | 
| 122 | 
         
            +
             
     | 
| 123 | 
         
            +
                        "conversations": [
         
     | 
| 124 | 
         
            +
                            {
         
     | 
| 125 | 
         
            +
                                'from': 'human',
         
     | 
| 126 | 
         
            +
                                'value': 'What is the relation between the two dogs <boxes> and the man <boxes> in the image <image> ?',
         
     | 
| 127 | 
         
            +
                                'boxes_seq': [[0, 1], [2], ],
         
     | 
| 128 | 
         
            +
                            },
         
     | 
| 129 | 
         
            +
                            {
         
     | 
| 130 | 
         
            +
                                'from': 'gpt',
         
     | 
| 131 | 
         
            +
                                'value': 'a rope <boxes> is connecting the left dog <boxes> with the man <boxes>. '
         
     | 
| 132 | 
         
            +
                                         'So the man <boxes> is walking the dog <boxes>.'
         
     | 
| 133 | 
         
            +
                                        'And the man <boxes> has no relationship with the right dog <boxes>',
         
     | 
| 134 | 
         
            +
                                'boxes_seq': [[3], [0], [2], [2], [0], [2], [1]],
         
     | 
| 135 | 
         
            +
                            }
         
     | 
| 136 | 
         
            +
                        ]
         
     | 
| 137 | 
         
            +
                    }
         
     | 
| 138 | 
         
            +
                    # placeholder: <image> <boxes>
         
     | 
| 139 | 
         
            +
                    """
         
     | 
| 140 | 
         
            +
                    raise NotImplementedError
         
     | 
| 141 | 
         
            +
             
     | 
| 142 | 
         
            +
                # noinspection PyMethodMayBeStatic
         
     | 
| 143 | 
         
            +
                def validate_raw_item(self, item):
         
     | 
| 144 | 
         
            +
                    has_image = 'image' in item and bool(item['image'])
         
     | 
| 145 | 
         
            +
                    has_target = 'target' in item and bool(item['target']) and any(bool(elem) for elem in item['target'].values())
         
     | 
| 146 | 
         
            +
                    has_target_boxes = 'boxes' in item['target'] if has_target else False
         
     | 
| 147 | 
         
            +
                    raw_conv: List[Dict[str, Any]] = item['conversations']
         
     | 
| 148 | 
         
            +
             
     | 
| 149 | 
         
            +
                    # check image
         
     | 
| 150 | 
         
            +
                    human_input_has_image_placeholder = any(
         
     | 
| 151 | 
         
            +
                        sentence['from'] == 'human' and IMAGE_PLACEHOLDER in sentence['value'] for sentence in raw_conv
         
     | 
| 152 | 
         
            +
                    )
         
     | 
| 153 | 
         
            +
                    if human_input_has_image_placeholder:
         
     | 
| 154 | 
         
            +
                        assert has_image
         
     | 
| 155 | 
         
            +
                    if has_image and (not human_input_has_image_placeholder):
         
     | 
| 156 | 
         
            +
                        warnings.warn(f'item has image but the question has no image placeholder.\n{item}')
         
     | 
| 157 | 
         
            +
                    gpt_input_has_image_placeholder = any(
         
     | 
| 158 | 
         
            +
                        sentence['from'] == 'gpt' and IMAGE_PLACEHOLDER in sentence['value'] for sentence in raw_conv
         
     | 
| 159 | 
         
            +
                    )
         
     | 
| 160 | 
         
            +
                    assert not gpt_input_has_image_placeholder
         
     | 
| 161 | 
         
            +
             
     | 
| 162 | 
         
            +
                    # check target
         
     | 
| 163 | 
         
            +
                    has_boxes_placeholder = any(
         
     | 
| 164 | 
         
            +
                        BOXES_PLACEHOLDER in sentence['value'] for sentence in raw_conv
         
     | 
| 165 | 
         
            +
                    )
         
     | 
| 166 | 
         
            +
                    if has_boxes_placeholder:
         
     | 
| 167 | 
         
            +
                        assert has_target_boxes
         
     | 
| 168 | 
         
            +
                    # not check box placeholder num this will be checked in format process
         
     | 
| 169 | 
         
            +
             
     | 
| 170 | 
         
            +
                def build_conv(self, source: List[Dict[str, Any]]) -> Conversation:
         
     | 
| 171 | 
         
            +
                    conv = self.conv_template()
         
     | 
| 172 | 
         
            +
                    role_map = {"human": conv.roles[0], "gpt": conv.roles[1]}
         
     | 
| 173 | 
         
            +
                    assert len(source) > 0
         
     | 
| 174 | 
         
            +
                    assert source[0]['from'] == 'human'
         
     | 
| 175 | 
         
            +
                    for sentence in source:
         
     | 
| 176 | 
         
            +
                        role = role_map[sentence['from']]
         
     | 
| 177 | 
         
            +
                        conv.append_message(role, sentence['value'])
         
     | 
| 178 | 
         
            +
                    return conv
         
     | 
| 179 | 
         
            +
             
     | 
| 180 | 
         
            +
                def process_conv(self, raw_conv: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         
     | 
| 181 | 
         
            +
                    """
         
     | 
| 182 | 
         
            +
                    some utils preprocess for raw_conv.
         
     | 
| 183 | 
         
            +
                        e.g. replace <image> placeholder to sequence <im_start> <im_patch>*256 <im_end>
         
     | 
| 184 | 
         
            +
                    """
         
     | 
| 185 | 
         
            +
                    return self.process_func['conv'](raw_conv, self.preprocessor, self.conv_template)
         
     | 
| 186 | 
         
            +
             
     | 
| 187 | 
         
            +
                def process_target(self, raw_conv: List[Dict[str, Any]], target: Dict[str, Any], multimage_mode=False) -> Tuple[
         
     | 
| 188 | 
         
            +
                    List[Dict[str, Any]], Dict[str, Any]]:
         
     | 
| 189 | 
         
            +
                    """
         
     | 
| 190 | 
         
            +
                    convert target placeholder to actual information in raw_conv.
         
     | 
| 191 | 
         
            +
                        e.g. normalize bounding boxes; convert bounding boxes format; replace <boxes> placeholder
         
     | 
| 192 | 
         
            +
                    """
         
     | 
| 193 | 
         
            +
                    return self.process_func['target'](raw_conv, target, self.preprocessor, multimage_mode=multimage_mode)
         
     | 
| 194 | 
         
            +
             
     | 
| 195 | 
         
            +
                def process_text(self, conv: Conversation) -> Dict[str, Any]:
         
     | 
| 196 | 
         
            +
                    """
         
     | 
| 197 | 
         
            +
                    convert Conversation object to torch.Tensor, e.g. input_ids, labels, attention_mask, etc.
         
     | 
| 198 | 
         
            +
                        self.tokenize_kwargs control something like padding/truncation behavior.
         
     | 
| 199 | 
         
            +
                    """
         
     | 
| 200 | 
         
            +
                    return self.process_func['text'](conv, self.preprocessor, self.mode, **self.tokenize_kwargs)
         
     | 
| 201 | 
         
            +
             
     | 
| 202 | 
         
            +
                def process_image(self, image: Image.Image) -> Dict[str, Any]:
         
     | 
| 203 | 
         
            +
                    """
         
     | 
| 204 | 
         
            +
                    convert Image.Image object to torch.Tensor
         
     | 
| 205 | 
         
            +
                    """
         
     | 
| 206 | 
         
            +
                    return self.process_func['image'](image, self.preprocessor)
         
     | 
| 207 | 
         
            +
             
     | 
| 208 | 
         
            +
                def _print_sample(self, ret_dict, raw_conv, conv):
         
     | 
| 209 | 
         
            +
                    if not hasattr(self, '_printed_sample'):
         
     | 
| 210 | 
         
            +
                        self._printed_sample = True
         
     | 
| 211 | 
         
            +
                        post_processed_labels = post_process_generate_ids(self.preprocessor['text'], ret_dict['labels'])
         
     | 
| 212 | 
         
            +
                        print(f"=================== {self.mode} sample ===================", flush=True)
         
     | 
| 213 | 
         
            +
                        print(f"        input_ids: {self.preprocessor['text'].convert_ids_to_tokens(ret_dict['input_ids'])}")
         
     | 
| 214 | 
         
            +
                        print(f"           labels: {self.preprocessor['text'].convert_ids_to_tokens(post_processed_labels)}")
         
     | 
| 215 | 
         
            +
                        print(f"decoded input_ids: {self.preprocessor['text'].decode(ret_dict['input_ids'])}")
         
     | 
| 216 | 
         
            +
                        print(f"decoded    labels: {self.preprocessor['text'].decode(post_processed_labels)}")
         
     | 
| 217 | 
         
            +
                        if 'image' in ret_dict and ret_dict['image'] is not None:
         
     | 
| 218 | 
         
            +
                            image = ret_dict['image']
         
     | 
| 219 | 
         
            +
                            if isinstance(image, torch.Tensor):
         
     | 
| 220 | 
         
            +
                                print(f"            image: {image.shape}")
         
     | 
| 221 | 
         
            +
                            elif isinstance(image, dict):
         
     | 
| 222 | 
         
            +
                                print(f"            image: {image.keys()}")
         
     | 
| 223 | 
         
            +
                            elif isinstance(image, list) and len(image) > 0:
         
     | 
| 224 | 
         
            +
                                print(f"            image: {len(image)}, {type(image[0])}")
         
     | 
| 225 | 
         
            +
                            else:
         
     | 
| 226 | 
         
            +
                                print(f"            image: {type(image)}")
         
     | 
| 227 | 
         
            +
                        print("====================================================", flush=True)
         
     | 
| 228 | 
         
            +
                        try:
         
     | 
| 229 | 
         
            +
                            if self.training_args is not None:
         
     | 
| 230 | 
         
            +
                                _save_obj = {
         
     | 
| 231 | 
         
            +
                                    'ret_dict': ret_dict,
         
     | 
| 232 | 
         
            +
                                    'raw_conv': raw_conv,
         
     | 
| 233 | 
         
            +
                                    'conv': conv.get_prompt(),
         
     | 
| 234 | 
         
            +
                                }
         
     | 
| 235 | 
         
            +
                                from pathlib import Path
         
     | 
| 236 | 
         
            +
                                output_dir = Path(self.training_args.output_dir)
         
     | 
| 237 | 
         
            +
                                output_dir.mkdir(exist_ok=True, parents=True)
         
     | 
| 238 | 
         
            +
                                _local_rank = self.training_args.local_rank
         
     | 
| 239 | 
         
            +
                                _word_size = self.training_args.world_size
         
     | 
| 240 | 
         
            +
                                _file_path = str(output_dir / f'sample_check_{self.mode}_{_local_rank}_{_word_size}.pt')
         
     | 
| 241 | 
         
            +
                                print(f'saving some sample to {_file_path} for check.')
         
     | 
| 242 | 
         
            +
                                torch.save(_save_obj, _file_path)
         
     | 
| 243 | 
         
            +
                        except Exception as e:
         
     | 
| 244 | 
         
            +
                            warnings.warn(f'try to save samples but get exception: {e.args}. ignored.')
         
     | 
| 245 | 
         
            +
             
     | 
| 246 | 
         
            +
             
     | 
| 247 | 
         
            +
            class SingleImageConvDataset(SingleImageConvDatasetMixin, Dataset):
         
     | 
| 248 | 
         
            +
                _repr_indent = 4
         
     | 
| 249 | 
         
            +
             
     | 
| 250 | 
         
            +
                def __init__(self, *args, dataset_generator: Type[Dataset], **kwargs):
         
     | 
| 251 | 
         
            +
                    super().__init__(*args, **kwargs)
         
     | 
| 252 | 
         
            +
                    self.dataset_generator = dataset_generator
         
     | 
| 253 | 
         
            +
                    self.dataset = None
         
     | 
| 254 | 
         
            +
             
     | 
| 255 | 
         
            +
                def initialize_if_needed(self):
         
     | 
| 256 | 
         
            +
                    """
         
     | 
| 257 | 
         
            +
                    lazy initialize for big in-memory python object due to python 'copy-on-read' behavior
         
     | 
| 258 | 
         
            +
                    when num_worker > 0. refer: https://github.com/pytorch/pytorch/issues/13246
         
     | 
| 259 | 
         
            +
                    """
         
     | 
| 260 | 
         
            +
                    if self.dataset is None:
         
     | 
| 261 | 
         
            +
                        # warnings.warn("it's highly recommended that set persistent_workers=True, "
         
     | 
| 262 | 
         
            +
                        #               "otherwise this initialize code will run in every epoch beginning."
         
     | 
| 263 | 
         
            +
                        #               "(ignore me if set)")
         
     | 
| 264 | 
         
            +
                        self.dataset = self.dataset_generator()
         
     | 
| 265 | 
         
            +
             
     | 
| 266 | 
         
            +
                def __len__(self):
         
     | 
| 267 | 
         
            +
                    self.initialize_if_needed()
         
     | 
| 268 | 
         
            +
                    return len(self.dataset)
         
     | 
| 269 | 
         
            +
             
     | 
| 270 | 
         
            +
                def get_raw_item(self, index) -> Dict[str, Any]:
         
     | 
| 271 | 
         
            +
                    self.initialize_if_needed()
         
     | 
| 272 | 
         
            +
                    return self.dataset[index]
         
     | 
| 273 | 
         
            +
             
     | 
| 274 | 
         
            +
                def __repr__(self) -> str:
         
     | 
| 275 | 
         
            +
                    head = "Dataset " + self.__class__.__name__
         
     | 
| 276 | 
         
            +
                    body = [
         
     | 
| 277 | 
         
            +
                        f"Number of datapoints: {self.__len__()}",
         
     | 
| 278 | 
         
            +
                    ]
         
     | 
| 279 | 
         
            +
                    body += self.dataset.__repr__().splitlines()
         
     | 
| 280 | 
         
            +
                    lines = [head] + [" " * self._repr_indent + line for line in body]
         
     | 
| 281 | 
         
            +
                    return "\n".join(lines)
         
     | 
| 282 | 
         
            +
             
     | 
| 283 | 
         
            +
             
     | 
| 284 | 
         
            +
            __all__ = ['SingleImageConvDatasetMixin', 'SingleImageConvDataset']
         
     | 
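
Below is a minimal usage sketch (not part of this upload; ToyRawDataset and its item are hypothetical) showing the raw-item contract documented in get_raw_item and the lazy initialization of SingleImageConvDataset. The preprocessor/process_func dicts are left empty here, so __getitem__ itself is not exercised; in this repo they are presumably assembled by mllm/dataset/builder.py.

from PIL import Image
from torch.utils.data import Dataset

from mllm.dataset.single_image_convsation import SingleImageConvDataset


class ToyRawDataset(Dataset):  # hypothetical stand-in for a real annotation dataset
    def __len__(self):
        return 1

    def __getitem__(self, index):
        # one item shaped like the example in get_raw_item's docstring
        return {
            'image': Image.new('RGB', (640, 480)),
            'target': {'boxes': [[10, 10, 256, 265]]},
            'conversations': [
                {'from': 'human', 'value': 'Where is the dog <boxes> in the image <image>?', 'boxes_seq': [[0]]},
                {'from': 'gpt', 'value': 'It is on the left.'},
            ],
        }


ds = SingleImageConvDataset(dataset_generator=ToyRawDataset, preprocessor={}, process_func={}, mode='train')
print(len(ds))                                          # 1 -- the underlying dataset is only built on first access
print(ds.get_raw_item(0)['conversations'][0]['from'])   # 'human'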
    	
        mllm/dataset/single_image_dataset/__init__.py
ADDED

@@ -0,0 +1,13 @@
from .flickr import FlickrParser, FlickrDataset
from .rec import RECDataset, RECComputeMetrics
from .reg import REGDataset, GCDataset
from .caption import CaptionDataset
from .instr import InstructDataset
from .gqa import GQADataset, GQAComputeMetrics
from .clevr import ClevrDataset
from .point_qa import Point_QA_local, Point_QA_twice, V7W_POINT, PointQAComputeMetrics
from .gpt_gen import GPT4Gen
from .vcr import VCRDataset, VCRPredDataset
from .vqav2 import VQAv2Dataset
from .vqaex import VQAEXDataset
from .pope import POPEVQADataset
    	
mllm/dataset/single_image_dataset/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (909 Bytes)

mllm/dataset/single_image_dataset/__pycache__/caption.cpython-310.pyc
ADDED
Binary file (1.07 kB)

mllm/dataset/single_image_dataset/__pycache__/clevr.cpython-310.pyc
ADDED
Binary file (3.89 kB)

mllm/dataset/single_image_dataset/__pycache__/flickr.cpython-310.pyc
ADDED
Binary file (2.65 kB)

mllm/dataset/single_image_dataset/__pycache__/gpt_gen.cpython-310.pyc
ADDED
Binary file (1.64 kB)

mllm/dataset/single_image_dataset/__pycache__/gqa.cpython-310.pyc
ADDED
Binary file (6.73 kB)

mllm/dataset/single_image_dataset/__pycache__/instr.cpython-310.pyc
ADDED
Binary file (1.08 kB)

mllm/dataset/single_image_dataset/__pycache__/point_qa.cpython-310.pyc
ADDED
Binary file (5.11 kB)

mllm/dataset/single_image_dataset/__pycache__/pope.cpython-310.pyc
ADDED
Binary file (1.2 kB)

mllm/dataset/single_image_dataset/__pycache__/rec.cpython-310.pyc
ADDED
Binary file (3.73 kB)

mllm/dataset/single_image_dataset/__pycache__/reg.cpython-310.pyc
ADDED
Binary file (1.39 kB)

mllm/dataset/single_image_dataset/__pycache__/vcr.cpython-310.pyc
ADDED
Binary file (5.39 kB)

mllm/dataset/single_image_dataset/__pycache__/vqaex.cpython-310.pyc
ADDED
Binary file (1.48 kB)

mllm/dataset/single_image_dataset/__pycache__/vqav2.cpython-310.pyc
ADDED
Binary file (1.29 kB)
    	
        mllm/dataset/single_image_dataset/caption.py
ADDED

@@ -0,0 +1,31 @@
from ..root import DATASETS, IMAGE_PLACEHOLDER
from ..utils import MInstrDataset


@DATASETS.register_module()
class CaptionDataset(MInstrDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER,))

    def __getitem__(self, index):
        item = self.get_raw_item(index)
        img_path = item['img_path']
        caption = item['caption']

        image = self.get_image(img_path)
        question = self.get_template()

        ret = {
            'image': image,
            'conversations': [
                {
                    'from': 'human',
                    'value': question,
                },
                {
                    'from': 'gpt',
                    'value': caption,
                }
            ]
        }
        return ret
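
Note: CaptionDataset reads only the 'img_path' and 'caption' keys from each raw item. A hypothetical annotation record (the exact jsonl schema is an assumption, shown only for illustration) would look like:

{"img_path": "COCO_train2014_000000000009.jpg", "caption": "A plate of food sits on a wooden table."}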
    	
        mllm/dataset/single_image_dataset/clevr.py
ADDED

@@ -0,0 +1,116 @@
import json

from ..root import DATASETS, IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER, POINTS_PLACEHOLDER
from ..utils import MInstrDataset


@DATASETS.register_module()
class ClevrDataset(MInstrDataset):
    def __init__(self, *args, scene_graph_file, version, **kwargs):
        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
        self.scene_graph_file = scene_graph_file
        self.version = version
        qtype, atype = version.split('-')
        assert qtype in ['q']
        assert atype in ['a', 's', 'bs']
        self.qtype = qtype
        self.atype = atype

        if scene_graph_file is None:
            self.scene_graph = None
        else:
            self.scene_graph = [line for line in open(scene_graph_file, 'r', encoding='utf8')]

    def get_raw_item(self, index):
        question = json.loads(self.data[index])
        if self.scene_graph is None:
            scene = None
        else:
            scene = json.loads(self.scene_graph[question['image_index']])
        return question, scene

    def __getitem__(self, index):
        question, scene = self.get_raw_item(index)
        img_path = question['image_filename']
        image = self.get_image(img_path)

        if self.atype == 'a':
            boxes = []
            answer = f"The answer is {question['answer']}."
            answer_boxes_seq = []
        elif self.atype == 's':
            answer, boxes, answer_boxes_seq = clevr_ss_cot(obj=question, scene=scene, add_ref=False)
            answer += f" The answer is {question['answer']}."
        elif self.atype == 'bs':
            answer, boxes, answer_boxes_seq = clevr_ss_cot(obj=question, scene=scene, add_ref=True)
            answer += f" The answer is {question['answer']}."
        else:
            assert False

        if self.qtype == 'q':
            query_boxes_seq = []
            final_query = self.get_template().replace(QUESTION_PLACEHOLDER, question['question'])
        else:
            assert False

        ret = {
            'image': image,
            'target': {'points': boxes},
            'conversations': [
                {
                    'from': 'human',
                    'value': final_query,
                    'points_seq': query_boxes_seq,
                },
                {
                    'from': 'gpt',
                    'value': answer,
                    'points_seq': answer_boxes_seq,
                }
            ]
        }
        return ret


def get_boxes_idx(boxes_list, refs):
    def get_idx(boxes_list, box):
        if box in boxes_list:
            return boxes_list.index(box)
        else:
            boxes_list.append(box)
            return len(boxes_list) - 1

    idx = [get_idx(boxes_list, box) for box in refs]
    return idx


def clevr_ss_cot(obj, scene, add_ref=False):
    cot = []
    boxes = []
    seq = []

    def can_add_ref():
        if p['function'] in ['unique', 'union', 'intersect', 'relate', 'same_size', 'same_shape', 'same_material', 'same_color']:
            return True
        if p['function'] in ['scene', 'filter_color', 'filter_material', 'filter_shape', 'filter_size']:
            if idx + 1 < len(obj['program']) and obj['program'][idx + 1]['function'] in ['exist', 'count']:
                return True
        return False

    for idx, p in enumerate(obj['program']):
        func = f"{p['function']}:{p['value_inputs'][0]}" if 'value_inputs' in p and p['value_inputs'] else p['function']
        inputs = f"[{','.join(map(str, p['inputs']))}]" if p['inputs'] else ""

        if add_ref and can_add_ref():
            if p['ans']:
                objs = POINTS_PLACEHOLDER
                # use a dedicated name so the loop index `idx` is not shadowed
                ref_idx = get_boxes_idx(boxes_list=boxes, refs=[scene['objects'][_]['pixel_coords'][:2] for _ in p['ans']])
                seq.append(ref_idx)
            else:
                objs = " Found no object."
        else:
            objs = ""
        cot.append(f"{func}{inputs}{objs}")

    ret = " -> ".join(cot)
    return ret, boxes, seq
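
A small sketch (hypothetical program entries, not taken from any real CLEVR question) of the chain-of-thought string that clevr_ss_cot produces when add_ref=False:

from mllm.dataset.single_image_dataset.clevr import clevr_ss_cot

obj = {
    'program': [
        {'function': 'scene', 'inputs': [], 'value_inputs': []},
        {'function': 'filter_color', 'inputs': [0], 'value_inputs': ['red']},
        {'function': 'count', 'inputs': [1], 'value_inputs': []},
    ]
}
answer, boxes, seq = clevr_ss_cot(obj=obj, scene=None, add_ref=False)
print(answer)       # scene -> filter_color:red[0] -> count[1]
print(boxes, seq)   # [] []  -- point references are only collected when add_ref=True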
    	
        mllm/dataset/single_image_dataset/flickr.py
ADDED

@@ -0,0 +1,68 @@
from torch.utils.data import Dataset

from ..root import DATASETS, BOXES_PLACEHOLDER, IMAGE_PLACEHOLDER
from ..utils import MInstrDataset
from ..utils.flickr30k_entities_utils import (
    flatten_annotation,
    PHRASE_ED_PLACEHOLDER,
    PHRASE_ST_PLACEHOLDER,
)


class FlickrParser(Dataset):
    def __init__(self, filename, annotation_dir):
        self.filename = filename
        self.annotation_dir = annotation_dir

        self.indexes = [line.strip() for line in open(filename, 'r', encoding='utf8')]
        self.data = flatten_annotation(self.annotation_dir, self.indexes)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    def dump(self, filename):
        import json
        with open(filename, 'w', encoding='utf8') as f:
            for obj in self.data:
                obj_str = json.dumps(obj)
                f.write(obj_str)
                f.write('\n')


@DATASETS.register_module()
class FlickrDataset(MInstrDataset):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER,))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.get_raw_item(index)
        img_path = f"{item['image_id']}.jpg"
        caption = item['sentence']

        image = self.get_image(img_path)
        caption = caption.replace(PHRASE_ST_PLACEHOLDER, "").replace(PHRASE_ED_PLACEHOLDER, BOXES_PLACEHOLDER)
        question = self.get_template()

        ret = {
            'image': image,
            'target': {'boxes': item['boxes']},
            'conversations': [
                {
                    'from': 'human',
                    'value': question,
                },
                {
                    'from': 'gpt',
                    'value': caption,
                    'boxes_seq': item['boxes_seq'],
                }
            ]
        }
        return ret
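
An illustration of the caption rewrite in FlickrDataset.__getitem__; the actual placeholder strings are defined in mllm/dataset/utils/flickr30k_entities_utils.py, so the values below are assumptions used only for this sketch:

PHRASE_ST_PLACEHOLDER = '<ph_st>'   # assumed value, illustration only
PHRASE_ED_PLACEHOLDER = '<ph_ed>'   # assumed value, illustration only
BOXES_PLACEHOLDER = '<boxes>'       # assumed value, illustration only

sentence = 'A <ph_st>man<ph_ed> walks a <ph_st>dog<ph_ed>.'
caption = sentence.replace(PHRASE_ST_PLACEHOLDER, '').replace(PHRASE_ED_PLACEHOLDER, BOXES_PLACEHOLDER)
print(caption)  # A man<boxes> walks a dog<boxes>.
# item['boxes_seq'] then maps each <boxes> occurrence, in order, to indices into item['boxes'].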
    	
mllm/dataset/single_image_dataset/gpt_gen.py
ADDED
@@ -0,0 +1,58 @@
+from ..root import (
+    DATASETS,
+    QUESTION_PLACEHOLDER,
+    IMAGE_PLACEHOLDER,
+    BOXES_PLACEHOLDER,
+)
+from ..utils import MInstrDataset
+from ..utils.flickr30k_entities_utils import PHRASE_ST_PLACEHOLDER, PHRASE_ED_PLACEHOLDER
+
+
+@DATASETS.register_module()
+class GPT4Gen(MInstrDataset):
+    def __init__(self, *args, version, **kwargs):
+        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
+        self.version = version
+        assert version in ['a', 'c', 'bc']
+
+    def __getitem__(self, item):
+        raw = self.get_raw_item(item)
+        #
+        image = self.get_image(raw['img_path'])
+        #
+        boxes = raw['boxes']
+        #
+        question = raw['question']
+        question = question.replace(PHRASE_ST_PLACEHOLDER, '').replace(PHRASE_ED_PLACEHOLDER, BOXES_PLACEHOLDER)
+        final_question = self.get_template().replace(QUESTION_PLACEHOLDER, question)
+        query_boxes_seq = raw['question_boxes_seq']
+
+        if self.version == 'a':
+            final_answer = raw['answer']
+            answer_boxes_seq = None
+        elif self.version == 'c':
+            final_answer = raw['cot_with_ans'].replace(PHRASE_ST_PLACEHOLDER, '').replace(PHRASE_ED_PLACEHOLDER, '')
+            answer_boxes_seq = None
+        elif self.version == 'bc':
+            final_answer = raw['cot_with_ans'].replace(PHRASE_ST_PLACEHOLDER, '').replace(PHRASE_ED_PLACEHOLDER, BOXES_PLACEHOLDER)
+            answer_boxes_seq = raw['answer_boxes_seq']
+        else:
+            assert False
+
+        ret = {
+            'image': image,
+            'target': {'boxes': boxes},
+            'conversations': [
+                {
+                    'from': 'human',
+                    'value': final_question,
+                    'boxes_seq': query_boxes_seq,
+                },
+                {
+                    'from': 'gpt',
+                    'value': final_answer,
+                    'boxes_seq': answer_boxes_seq,
+                }
+            ]
+        }
+        return ret
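The `version` flag only changes how the assistant turn is assembled: 'a' keeps the bare answer, 'c' keeps the chain-of-thought with the phrase markers stripped, and 'bc' keeps the chain-of-thought and grounds it with box placeholders plus `answer_boxes_seq`. A hedged sketch of the kind of sample a 'bc' item is expected to produce (the coordinates, text, and "<boxes>" string are illustrative, not taken from the repo's annotation files):

# Illustrative only: box coordinates, text, and the "<boxes>" string are made up.
sample = {
    'image': None,  # a PIL.Image in practice, loaded via self.get_image(raw['img_path'])
    'target': {'boxes': [[10, 20, 110, 220], [300, 40, 380, 200]]},  # xyxy pixel boxes
    'conversations': [
        {'from': 'human',
         'value': '<image> What is the man<boxes> holding?',
         'boxes_seq': [[0]]},           # the question's <boxes> points at target box 0
        {'from': 'gpt',
         'value': 'The man<boxes> is holding a cup<boxes>. The answer is a cup.',
         'boxes_seq': [[0], [1]]},      # one index list per <boxes> occurrence in the answer
    ],
}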
    	
mllm/dataset/single_image_dataset/gqa.py
ADDED
@@ -0,0 +1,233 @@
+import json
+import re
+
+from ..root import DATASETS, IMAGE_PLACEHOLDER, BOXES_PLACEHOLDER, QUESTION_PLACEHOLDER, METRICS
+from ..utils.flickr30k_entities_utils import PHRASE_ST_PLACEHOLDER, PHRASE_ED_PLACEHOLDER
+from ..utils import MInstrDataset, BaseComputeMetrics
+
+REFID_PAT = re.compile(r'(\s\((?:(?:\d+(?:,\d+)*)|-)\)\s?)')
+ANS_EXTRACT_PAT = re.compile(r'(?:(?:(?:(?:(?:So t)|(?:T)|(?:t))he answer is)|(?:Answer:)) (.+))')
+
+
+@DATASETS.register_module()
+class GQADataset(MInstrDataset):
+    def __init__(
+            self,
+            *args,
+            scene_graph_file,
+            scene_graph_index,
+            version,
+            question_box_prob=0.5,
+            **kwargs
+    ):
+        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
+        self.scene_graph_file = scene_graph_file
+        self.scene_graph_index = scene_graph_index
+        self.version = version
+        self.question_box_prob = question_box_prob
+        qtype, atype = version.split('-')
+        assert qtype in ['q', 'qb', 'qbp']
+        assert atype in ['a', 'c', 'bc', 's', 'bs', 'l', 'bl']
+        self.qtype = qtype
+        self.atype = atype
+
+        assert bool(scene_graph_file) == bool(scene_graph_index)
+        if scene_graph_file is not None and scene_graph_index is not None:
+            self.scene_graph = [line for line in open(scene_graph_file, 'r', encoding='utf8')]
+            self.scene_index = json.load(open(scene_graph_index, 'r', encoding='utf8'))
+        else:
+            self.scene_graph = None
+            self.scene_index = None
+
+    def get_raw_item(self, index):
+        question = json.loads(self.data[index])
+        if self.scene_graph is None:
+            return question, None
+        scene = json.loads(self.scene_graph[self.scene_index[question['imageId']]])
+        return question, scene
+
+    def __getitem__(self, index):
+        question, scene = self.get_raw_item(index)
+        img_path = f"{question['imageId']}.jpg"
+        image = self.get_image(img_path)
+
+        # answer
+        if self.atype == 'bc':
+            boxes = question['cot']['boxes']
+            answer = question['cot']['value'].replace(PHRASE_ST_PLACEHOLDER, "").replace(PHRASE_ED_PLACEHOLDER, BOXES_PLACEHOLDER)
+            answer_boxes_seq = question['cot']['seq']
+        elif self.atype == 'c':
+            boxes = []
+            answer = question['cot']['value'].replace(PHRASE_ST_PLACEHOLDER, "").replace(PHRASE_ED_PLACEHOLDER, "")
+            answer_boxes_seq = []
+        elif self.atype == 'bs':
+            boxes, bss, answer_boxes_seq = get_bss_example(question, scene)
+            answer = f"{bss}. The answer is {question['answer']}."
+        elif self.atype == 's':
+            boxes = []
+            ss = REFID_PAT.sub('', question['semanticStr'])
+            answer = f"{ss}. The answer is {question['answer']}."
+            answer_boxes_seq = []
+        elif self.atype == 'bl':
+            boxes, answer, answer_boxes_seq = get_bl_example(question, scene)
+        elif self.atype == 'l':
+            boxes = []
+            _, answer, _ = get_bl_example(question, scene)
+            answer = answer.replace(BOXES_PLACEHOLDER, "")
+            answer_boxes_seq = []
+        elif self.atype == 'a':
+            boxes = []
+            answer = f"The answer is {question['answer']}."
+            answer_boxes_seq = []
+        else:
+            assert False
+
+        # question
+        if self.qtype == 'q':
+            boxes, query, query_boxes_seq = prepare_query_dummy(boxes, question, scene)
+        elif self.qtype == 'qb':
+            boxes, query, query_boxes_seq = prepare_query_box(boxes, question, scene)
+        elif self.qtype == 'qbp':
+            if self.rng.uniform() > self.question_box_prob:
+                boxes, query, query_boxes_seq = prepare_query_dummy(boxes, question, scene)
+            else:
+                boxes, query, query_boxes_seq = prepare_query_box(boxes, question, scene)
+        else:
+            assert False
+
+        final_query = self.get_template().replace(QUESTION_PLACEHOLDER, query)
+
+        ret = {
+            'image': image,
+            'target': {'boxes': boxes},
+            'conversations': [
+                {
+                    'from': 'human',
+                    'value': final_query,
+                    'boxes_seq': query_boxes_seq,
+                },
+                {
+                    'from': 'gpt',
+                    'value': answer,
+                    'boxes_seq': answer_boxes_seq,
+                }
+            ]
+        }
+        return ret
+
+
+def prepare_query_dummy(boxes_list, q, scene):
+    return boxes_list, q['question'], []
+
+
+def prepare_query_box(boxes_list, q, scene):
+    def get_boxes_idx(box):
+        if box in boxes_list:
+            return boxes_list.index(box)
+        else:
+            boxes_list.append(box)
+            return len(boxes_list) - 1
+
+    def add_boxes_by_rids(rids):
+        def get_box_xyxy(obj):
+            x, y, w, h = obj['x'], obj['y'], obj['w'], obj['h']
+            return x, y, x + w, y + h
+
+        boxes_idx = []
+        for rid in rids:
+            ref = scene['objects'][rid]
+            ref_box = list(get_box_xyxy(ref))
+            boxes_idx.append(get_boxes_idx(ref_box))
+        return boxes_idx
+
+    sent = list(q['question'].split())
+    query_boxes_seq = []
+    for span, rids_str in q['annotations']['question'].items():
+        span = tuple(map(int, span.split(':')))
+        if len(span) == 1:
+            span = [span[0], span[0] + 1]
+        sent[span[1] - 1] = f"{sent[span[1] - 1]}{BOXES_PLACEHOLDER}"
+        boxes_idx = add_boxes_by_rids(rids_str.split(','))
+        query_boxes_seq.append(boxes_idx)
+    sent_converted = " ".join(sent).strip()
+    return boxes_list, sent_converted, query_boxes_seq
+
+
+def add_boxes_by_rids(boxes_list, rids, scene):
+    def get_boxes_idx(boxes_list, box):
+        if box in boxes_list:
+            return boxes_list.index(box)
+        else:
+            boxes_list.append(box)
+            return len(boxes_list) - 1
+
+    def get_box_xyxy(obj):
+        x, y, w, h = obj['x'], obj['y'], obj['w'], obj['h']
+        return x, y, x + w, y + h
+
+    boxes_idx = []
+    for rid in rids:
+        ref = scene['objects'][rid]
+        ref_box = list(get_box_xyxy(ref))
+        boxes_idx.append(get_boxes_idx(boxes_list, ref_box))
+    return boxes_idx
+
+
+def get_bss_example(question, scene):
+    def format_refids(item):
+        item = item.strip()[1:-1]
+        return item.split(',')
+
+    s = question['semanticStr']
+    print(REFID_PAT.findall(s))
+    formats = []
+    boxes = []
+    seqs = []
+
+    for item in REFID_PAT.findall(s):
+        if '-' in item:
+            formats.append('')
+        else:
+            formats.append('<boxes>')
+            refids = format_refids(item)
+            idx = add_boxes_by_rids(boxes, refids, scene)
+            seqs.append(idx)
+    answer = REFID_PAT.sub('{}', s).format(*formats)
+
+    print(answer)
+    print(boxes)
+    print(seqs)
+    return boxes, answer, seqs
+
+
+def get_bl_example(ann, scene):
+    boxes = []
+    boxes_seq = []
+
+    origin_sent = ann['fullAnswer']
+    origin_sent = re.sub('(?:^Yes,)|(?:^No,)', '', origin_sent).strip()
+    sent = list(origin_sent.split())
+    for span, rids_str in ann['annotations']['fullAnswer'].items():
+        span = tuple(map(int, span.split(':')))
+        if len(span) == 1:
+            span = [span[0], span[0] + 1]
+        sent[span[1] - 1] = f"{sent[span[1] - 1]}{BOXES_PLACEHOLDER}"
+        rids = rids_str.split(',')
+        boxes_idx = add_boxes_by_rids(boxes, rids, scene)
+        boxes_seq.append(boxes_idx)
+
+    answer = "".join(sent)
+    answer += f"The answer is {ann['answer']}."
+    return boxes, answer, boxes_seq
+
+
+@METRICS.register_module()
+class GQAComputeMetrics(BaseComputeMetrics):
+    def extract_ans(self, string: str):
+        try:
+            found = ANS_EXTRACT_PAT.findall(string.strip())
+            if len(found) != 1:
+                return None
+            return found[0].strip().rstrip('.').strip()
+        except (IndexError, AttributeError):
+            return None
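Both `prepare_query_box` and `get_bl_example` consume GQA's token-span annotations, which map a 'start:end' (or single-index) token span of the question or full answer to a comma-separated list of scene-graph object ids; the code appends a box placeholder to the span's last token and records the matching box indices. A small self-contained sketch of that rewriting with a made-up question and annotation (the scene-graph lookup that would turn object ids into xyxy boxes is stubbed out, and the placeholder string is an assumption):

BOXES_PLACEHOLDER = "<boxes>"  # assumed value; the real constant lives in mllm/dataset/root.py

question = "What color is the bag the woman is carrying ?"
annotations = {"4": "2486325", "6": "1038742"}  # token span -> object id(s); ids are made up

sent = question.split()
boxes, seqs = [], []
for span, rids_str in annotations.items():
    span = tuple(map(int, span.split(':')))
    if len(span) == 1:
        span = (span[0], span[0] + 1)
    sent[span[1] - 1] += BOXES_PLACEHOLDER       # attach the placeholder to the span's last token
    idxs = []
    for rid in rids_str.split(','):
        boxes.append(rid)                        # the real code stores the object's xyxy box here
        idxs.append(len(boxes) - 1)
    seqs.append(idxs)

print(" ".join(sent))  # What color is the bag<boxes> the woman<boxes> is carrying ?
print(seqs)            # [[0], [1]]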
    	
mllm/dataset/single_image_dataset/instr.py
ADDED
@@ -0,0 +1,24 @@
+from ..root import DATASETS
+from ..utils import MInstrDataset
+
+
+@DATASETS.register_module()
+class InstructDataset(MInstrDataset):
+    def __init__(self, *args, add_coco_prefix=False, **kwargs):
+        super().__init__(*args, **kwargs, placeholders=(), template_string='', template_file=None)
+        self.add_coco_prefix = add_coco_prefix
+
+    def __getitem__(self, index):
+        item = self.get_raw_item(index)
+        if self.add_coco_prefix:
+            img_path = f"COCO_train2014_{item['image']}"
+        else:
+            img_path = item['image']
+        conversations = item['conversations']
+
+        image = self.get_image(img_path)
+        ret = {
+            'image': image,
+            'conversations': conversations,
+        }
+        return ret
    	
        mllm/dataset/single_image_dataset/point_qa.py
    ADDED
    
    | 
         @@ -0,0 +1,247 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import re
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
            from .. import BaseComputeMetrics
         
     | 
| 4 | 
         
            +
            from ..root import (
         
     | 
| 5 | 
         
            +
                DATASETS,
         
     | 
| 6 | 
         
            +
                METRICS,
         
     | 
| 7 | 
         
            +
                QUESTION_PLACEHOLDER,
         
     | 
| 8 | 
         
            +
                IMAGE_PLACEHOLDER,
         
     | 
| 9 | 
         
            +
                BOXES_PLACEHOLDER,
         
     | 
| 10 | 
         
            +
                POINTS_PLACEHOLDER,
         
     | 
| 11 | 
         
            +
            )
         
     | 
| 12 | 
         
            +
            from ..utils import MInstrDataset
         
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
            # noinspection PyPep8Naming
         
     | 
| 16 | 
         
            +
            @DATASETS.register_module()
         
     | 
| 17 | 
         
            +
            class Point_QA_local(MInstrDataset):
         
     | 
| 18 | 
         
            +
                def __init__(self, *args, version='p', qbp_p_prob=0.5, **kwargs):
         
     | 
| 19 | 
         
            +
                    super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
         
     | 
| 20 | 
         
            +
                    assert version in ['b', 'p', 'bp']
         
     | 
| 21 | 
         
            +
                    self.version = version
         
     | 
| 22 | 
         
            +
                    self.qbp_p_prob = qbp_p_prob
         
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
                def __getitem__(self, index):
         
     | 
| 25 | 
         
            +
                    item = self.get_raw_item(index)
         
     | 
| 26 | 
         
            +
                    # image
         
     | 
| 27 | 
         
            +
                    img_path = item['file_path']
         
     | 
| 28 | 
         
            +
                    image = self.get_image(img_path)
         
     | 
| 29 | 
         
            +
                    # answer
         
     | 
| 30 | 
         
            +
                    answer = item['answer']
         
     | 
| 31 | 
         
            +
                    # question
         
     | 
| 32 | 
         
            +
                    question = item['question']
         
     | 
| 33 | 
         
            +
                    bbox = item['bbox']
         
     | 
| 34 | 
         
            +
                    point = item['point']
         
     | 
| 35 | 
         
            +
             
     | 
| 36 | 
         
            +
                    version = self.version
         
     | 
| 37 | 
         
            +
                    if version == 'bp':
         
     | 
| 38 | 
         
            +
                        version = 'p' if self.rng.random() < self.qbp_p_prob else 'b'
         
     | 
| 39 | 
         
            +
                    if version == 'b':
         
     | 
| 40 | 
         
            +
                        question = question + BOXES_PLACEHOLDER
         
     | 
| 41 | 
         
            +
                        query_boxes_seq = [[0]]
         
     | 
| 42 | 
         
            +
                        query_points_seq = None
         
     | 
| 43 | 
         
            +
                    elif version == 'p':
         
     | 
| 44 | 
         
            +
                        question = question + POINTS_PLACEHOLDER
         
     | 
| 45 | 
         
            +
                        query_boxes_seq = None
         
     | 
| 46 | 
         
            +
                        query_points_seq = [[0]]
         
     | 
| 47 | 
         
            +
                    else:
         
     | 
| 48 | 
         
            +
                        assert False
         
     | 
| 49 | 
         
            +
                    final_question = self.get_template().replace(QUESTION_PLACEHOLDER, question)
         
     | 
| 50 | 
         
            +
             
     | 
| 51 | 
         
            +
                    ret = {
         
     | 
| 52 | 
         
            +
                        'image': image,
         
     | 
| 53 | 
         
            +
                        'target': {
         
     | 
| 54 | 
         
            +
                            'boxes': [bbox],
         
     | 
| 55 | 
         
            +
                            'points': [point],
         
     | 
| 56 | 
         
            +
                        },
         
     | 
| 57 | 
         
            +
                        'conversations': [
         
     | 
| 58 | 
         
            +
                            {
         
     | 
| 59 | 
         
            +
                                'from': 'human',
         
     | 
| 60 | 
         
            +
                                'value': final_question,
         
     | 
| 61 | 
         
            +
                                'boxes_seq': query_boxes_seq,
         
     | 
| 62 | 
         
            +
                                'points_seq': query_points_seq,
         
     | 
| 63 | 
         
            +
                            },
         
     | 
| 64 | 
         
            +
                            {
         
     | 
| 65 | 
         
            +
                                'from': 'gpt',
         
     | 
| 66 | 
         
            +
                                'value': f'The answer is {answer} .',
         
     | 
| 67 | 
         
            +
                            }
         
     | 
| 68 | 
         
            +
                        ]
         
     | 
| 69 | 
         
            +
                    }
         
     | 
| 70 | 
         
            +
                    return ret
         
     | 
| 71 | 
         
            +
             
     | 
| 72 | 
         
            +
             
     | 
| 73 | 
         
            +
            # noinspection PyPep8Naming
         
     | 
| 74 | 
         
            +
            @DATASETS.register_module()
         
     | 
| 75 | 
         
            +
            class Point_QA_twice(MInstrDataset):
         
     | 
| 76 | 
         
            +
                def __init__(self, *args, version='gq-p', bp_p_prob=0.5, **kwargs):
         
     | 
| 77 | 
         
            +
                    super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
         
     | 
| 78 | 
         
            +
                    self.version = version
         
     | 
| 79 | 
         
            +
                    self.bp_p_prob = bp_p_prob
         
     | 
| 80 | 
         
            +
                    qtype, rtype = version.split('-')
         
     | 
| 81 | 
         
            +
                    assert qtype in ['oq', 'sq', 'gq']
         
     | 
| 82 | 
         
            +
                    assert rtype in ['b', 'p', 'bp']
         
     | 
| 83 | 
         
            +
                    self.qtype = qtype
         
     | 
| 84 | 
         
            +
                    self.rtype = rtype
         
     | 
| 85 | 
         
            +
             
     | 
| 86 | 
         
            +
                def __getitem__(self, index):
         
     | 
| 87 | 
         
            +
                    item = self.get_raw_item(index)
         
     | 
| 88 | 
         
            +
                    # image
         
     | 
| 89 | 
         
            +
                    img_path = item['file_path']
         
     | 
| 90 | 
         
            +
                    image = self.get_image(img_path)
         
     | 
| 91 | 
         
            +
                    # answer
         
     | 
| 92 | 
         
            +
                    answer = item['answer']
         
     | 
| 93 | 
         
            +
                    # question
         
     | 
| 94 | 
         
            +
                    bbox = item['bbox']
         
     | 
| 95 | 
         
            +
                    point = item['point']
         
     | 
| 96 | 
         
            +
                    if self.qtype == 'oq':
         
     | 
| 97 | 
         
            +
                        question = item['obj_question']
         
     | 
| 98 | 
         
            +
                    elif self.qtype == 'sq':
         
     | 
| 99 | 
         
            +
                        question = item['super_question']
         
     | 
| 100 | 
         
            +
                    elif self.qtype == 'gq':
         
     | 
| 101 | 
         
            +
                        question = item['general_question']
         
     | 
| 102 | 
         
            +
                    else:
         
     | 
| 103 | 
         
            +
                        assert False
         
     | 
| 104 | 
         
            +
                    rtype = self.rtype
         
     | 
| 105 | 
         
            +
                    if rtype == 'bp':
         
     | 
| 106 | 
         
            +
                        rtype = 'p' if self.rng.random() < self.bp_p_prob else 'b'
         
     | 
| 107 | 
         
            +
                    if rtype == 'p':
         
     | 
| 108 | 
         
            +
                        question = question + POINTS_PLACEHOLDER
         
     | 
| 109 | 
         
            +
                        query_boxes_seq = None
         
     | 
| 110 | 
         
            +
                        query_points_seq = [[0]]
         
     | 
| 111 | 
         
            +
                    elif rtype == 'b':
         
     | 
| 112 | 
         
            +
                        question = question + BOXES_PLACEHOLDER
         
     | 
| 113 | 
         
            +
                        query_boxes_seq = [[0]]
         
     | 
| 114 | 
         
            +
                        query_points_seq = None
         
     | 
| 115 | 
         
            +
                    else:
         
     | 
| 116 | 
         
            +
                        assert False
         
     | 
| 117 | 
         
            +
                    final_question = self.get_template().replace(QUESTION_PLACEHOLDER, question)
         
     | 
| 118 | 
         
            +
             
     | 
| 119 | 
         
            +
                    ret = {
         
     | 
| 120 | 
         
            +
                        'image': image,
         
     | 
| 121 | 
         
            +
                        'target': {
         
     | 
| 122 | 
         
            +
                            'boxes': [bbox],
         
     | 
| 123 | 
         
            +
                            'points': [point],
         
     | 
| 124 | 
         
            +
                        },
         
     | 
| 125 | 
         
            +
                        'conversations': [
         
     | 
| 126 | 
         
            +
                            {
         
     | 
| 127 | 
         
            +
                                'from': 'human',
         
     | 
| 128 | 
         
            +
                                'value': final_question,
         
     | 
| 129 | 
         
            +
                                'boxes_seq': query_boxes_seq,
         
     | 
| 130 | 
         
            +
                                'points_seq': query_points_seq,
         
     | 
| 131 | 
         
            +
                            },
         
     | 
| 132 | 
         
            +
                            {
         
     | 
| 133 | 
         
            +
                                'from': 'gpt',
         
     | 
| 134 | 
         
            +
                                'value': f'The answer is {answer} .',
         
     | 
| 135 | 
         
            +
                            }
         
     | 
| 136 | 
         
            +
                        ]
         
     | 
| 137 | 
         
            +
                    }
         
     | 
| 138 | 
         
            +
                    return ret
         
     | 
| 139 | 
         
            +
             
     | 
| 140 | 
         
            +
             
     | 
| 141 | 
         
            +
            # noinspection PyPep8Naming
         
     | 
| 142 | 
         
            +
            @DATASETS.register_module()
         
     | 
| 143 | 
         
            +
            class V7W_POINT(MInstrDataset):
         
     | 
| 144 | 
         
            +
                def __init__(self, *args, version, do_shuffle_choice=True, **kwargs):
         
     | 
| 145 | 
         
            +
                    super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
         
     | 
| 146 | 
         
            +
                    self.version = version
         
     | 
| 147 | 
         
            +
                    self.do_shuffle_choice = do_shuffle_choice
         
     | 
| 148 | 
         
            +
                    assert version in ['p', 'b']
         
     | 
| 149 | 
         
            +
             
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        item = self.get_raw_item(index)
+        # image
+        img_path = item['file_path']
+        image = self.get_image(img_path)
+        # question
+        bboxes = item['candidates']
+        points = []
+        final_question = item['question'] + ' Candidates: ' + " ".join([BOXES_PLACEHOLDER for _ in range(len(bboxes))])
+        query_boxes_seq = []
+        for _ in range(len(bboxes)):
+            query_boxes_seq.append([_])
+        # answer
+        if self.version == 'p':
+            final_question += f" answer in point format."
+            points.append(item['point'])
+            final_answer = f"The answer is {POINTS_PLACEHOLDER} ."
+            answer_boxes_seq = None
+            answer_points_seq = [[0]]
+        elif self.version == 'b':
+            final_question += f" answer in box format."
+            idx = bboxes.index(item['answer'])
+            final_answer = f"The answer is {BOXES_PLACEHOLDER} ."
+            answer_boxes_seq = [[idx]]
+            answer_points_seq = None
+        else:
+            assert False
+        final_question = self.get_template().replace(QUESTION_PLACEHOLDER, final_question)
+        if self.do_shuffle_choice:
+            self.rng.shuffle(query_boxes_seq)
+            # bboxes, query_boxes_seq, answer_boxes_seq = self.shuffle_boxes(bboxes, query_boxes_seq, answer_boxes_seq)
+
+        ret = {
+            'image': image,
+            'target': {
+                'boxes': bboxes,
+                'points': points,
+            },
+            'conversations': [
+                {
+                    'from': 'human',
+                    'value': final_question,
+                    'boxes_seq': query_boxes_seq,
+                },
+                {
+                    'from': 'gpt',
+                    'value': final_answer,
+                    'boxes_seq': answer_boxes_seq,
+                    'points_seq': answer_points_seq,
+                }
+            ]
+        }
+        return ret
+
+    # def shuffle_boxes(self, bboxes, query_boxes_seq, answer_boxes_seq):
+    #     idx_mapping = list(range(len(bboxes)))
+    #     self.rng.shuffle(idx_mapping)
+    #
+    #     new_bboxes = [None for _ in range(len(bboxes))]
+    #     for idx_old, idx_new in enumerate(idx_mapping):
+    #         new_bboxes[idx_new] = bboxes[idx_old]
+    #
+    #     if query_boxes_seq is None:
+    #         new_query_boxes_seq = None
+    #     else:
+    #         new_query_boxes_seq = []
+    #         for boxes in query_boxes_seq:
+    #             new_boxes = [idx_mapping[box_idx] for box_idx in boxes]
+    #             new_query_boxes_seq.append(new_boxes)
+    #
+    #     if answer_boxes_seq is None:
+    #         new_answer_boxes_seq = None
+    #     else:
+    #         new_answer_boxes_seq = []
+    #         for boxes in answer_boxes_seq:
+    #             new_boxes = [idx_mapping[box_idx] for box_idx in boxes]
+    #             new_answer_boxes_seq.append(new_boxes)
+    #
+    #     return new_bboxes, new_query_boxes_seq, new_answer_boxes_seq
+
+
+ANS_EXTRACT_PAT = re.compile(r'(?:The answer is (.+?)\.)')
+
+
+@METRICS.register_module()
+class PointQAComputeMetrics(BaseComputeMetrics):
+    def extract_ans(self, string: str):
+        try:
+            found = ANS_EXTRACT_PAT.findall(string.strip())
+            if len(found) != 1:
+                return None
+            return found[0].strip()
+        except (IndexError, AttributeError):
+            return None
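`PointQAComputeMetrics.extract_ans` relies on `ANS_EXTRACT_PAT` to pull the answer span out of a generated sentence that follows the dataset's "The answer is ... ." template. A minimal, self-contained sketch of that extraction, using a made-up model output string for illustration only:

    import re

    # Same pattern as in point_qa.py: captures the text between "The answer is" and the first period.
    ANS_EXTRACT_PAT = re.compile(r'(?:The answer is (.+?)\.)')

    pred = "The answer is 3 ."  # hypothetical model output
    found = ANS_EXTRACT_PAT.findall(pred.strip())
    answer = found[0].strip() if len(found) == 1 else None
    print(answer)  # -> '3'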
    	
mllm/dataset/single_image_dataset/pope.py
ADDED
@@ -0,0 +1,36 @@
+from ..root import (
+    DATASETS,
+    QUESTION_PLACEHOLDER,
+    IMAGE_PLACEHOLDER,
+)
+from ..utils import MInstrDataset
+
+
+@DATASETS.register_module()
+class POPEVQADataset(MInstrDataset):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER))
+
+    def __getitem__(self, index):
+        item = self.get_raw_item(index)
+        image = self.get_image(image_path=item['image'])
+
+        question = item['text']
+        final_question = self.get_template().replace(QUESTION_PLACEHOLDER, question)
+
+        label = str(item['label']).lower()
+
+        ret = {
+            'image': image,
+            'conversations': [
+                {
+                    'from': 'human',
+                    'value': final_question,
+                },
+                {
+                    'from': 'gpt',
+                    'value': f"The answer is {label} .",
+                },
+            ]
+        }
+        return ret
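`POPEVQADataset.__getitem__` simply rewrites each raw POPE record into one human/gpt turn. A minimal sketch of that mapping, with a hypothetical template and raw item (both made up; the real template comes from the dataset's template file):

    QUESTION_PLACEHOLDER = '<question>'
    template = '<image> Answer the question with yes or no. <question>'  # hypothetical template

    item = {'image': 'COCO_val2014_000000000001.jpg',
            'text': 'Is there a dog in the image?',
            'label': 'yes'}

    final_question = template.replace(QUESTION_PLACEHOLDER, item['text'])
    answer = f"The answer is {str(item['label']).lower()} ."
    print(final_question)
    print(answer)  # -> 'The answer is yes .'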
    	
mllm/dataset/single_image_dataset/rec.py
ADDED
@@ -0,0 +1,128 @@
+import sys
+import logging
+import warnings
+from typing import Dict, Any, Sequence
+
+import torch
+from torchvision.ops import box_iou
+
+from ..utils import (
+    MInstrDataset,
+    BaseComputeMetrics,
+)
+
+from ..process_function import (
+    BoxFormatter,
+)
+
+from ..root import (
+    DATASETS,
+    METRICS,
+    IMAGE_PLACEHOLDER,
+    BOXES_PLACEHOLDER,
+    EXPR_PLACEHOLDER,
+)
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+    datefmt="%m/%d/%Y %H:%M:%S",
+    handlers=[logging.StreamHandler(sys.stdout), ],
+)
+
+
+@DATASETS.register_module()
+class RECDataset(MInstrDataset):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, EXPR_PLACEHOLDER))
+
+    def __getitem__(self, index):
+        item = self.get_raw_item(index)
+        img_path = item['img_path']
+        expr = item['expression']
+        bbox = item['bbox']
+
+        image = self.get_image(img_path)
+        question = self.get_template().replace(EXPR_PLACEHOLDER, expr)
+
+        ret = {
+            'image': image,
+            'target': {
+                'boxes': [bbox],
+            },
+            'conversations': [
+                {
+                    'from': 'human',
+                    'value': question,
+                },
+                {
+                    'from': 'gpt',
+                    'value': f'Answer: {BOXES_PLACEHOLDER} .',
+                    'boxes_seq': [[0]],
+                }
+            ]
+        }
+        return ret
+
+
+@METRICS.register_module()
+class RECComputeMetrics(BaseComputeMetrics):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.box_formatter: BoxFormatter = self.preprocessor['target']['boxes']
+
+    def calculate_metric(self, preds: Sequence[str], targets: Sequence[str]) -> Dict[str, Any]:
+        failed = 0
+        target_failed = 0
+
+        pred_boxes, target_boxes = [], []
+        for pred, target in zip(preds, targets):
+            extract_pred = self.extract_ans(pred)
+            extract_target = self.extract_ans(target)
+            if extract_target is None:
+                target_failed += 1
+                logger.warning(f"failed to extract ans for target: {target}")
+                continue
+            if extract_pred is None:
+                failed += 1
+                logger.warning(f"failed to extract ans for pred: {pred}")
+                extract_pred = [0, 0, 0, 0]
+            target_boxes.append(extract_target)
+            pred_boxes.append(extract_pred)
+
+        with torch.no_grad():
+            target_boxes = torch.tensor(target_boxes)
+            pred_boxes = torch.tensor(pred_boxes)
+            # normalized box values are tiny; scale them up so the box areas do not collapse to 0
+            ious = box_iou(pred_boxes * 1000, target_boxes * 1000)
+            ious = torch.einsum('i i -> i', ious)  # take the diagonal: pred_i vs. its own target_i
+            # NOTE: iou is only computed over targets that were successfully extracted
+            iou = ious.mean().item()
+            correct = (ious > 0.5).sum().item()
+
+        # HACK: images are currently expanded to squares, so this iou matches the real iou.
+        warn_message = "this iou is calculated on normalized boxes, for non-rigorous training progress checking only. " \
+                       "the value is consistent with the real iou only if image.width == image.height."
+        warnings.warn(warn_message)
+
+        return {
+            'accuracy': 1.0 * correct / len(targets),
+            'target_failed': target_failed,
+            'failed': failed,
+            'iou': iou,
+            'warning': warn_message,
+        }
+
+    def extract_ans(self, string: str):
+        try:
+            list_of_boxes = self.box_formatter.extract(string)
+            if len(list_of_boxes) != 1 or len(list_of_boxes[0]) != 1:
+                return None
+            box = list_of_boxes[0][0]
+            if len(box) != 4:
+                return None
+            return box
+        except Exception as e:
+            logger.warning(f"extract_ans for {string} but get exception: {e}")
+            return None
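The IoU step in `RECComputeMetrics.calculate_metric` scores each prediction only against its own target: `box_iou` returns the full N x N matrix, and `einsum('i i -> i')` keeps its diagonal. A standalone sketch of just that step, with made-up box values:

    import torch
    from torchvision.ops import box_iou

    # Two hypothetical pairs of normalized xyxy boxes, scaled by 1000 as in calculate_metric.
    pred_boxes = torch.tensor([[0.10, 0.10, 0.50, 0.50],
                               [0.20, 0.20, 0.60, 0.60]])
    target_boxes = torch.tensor([[0.10, 0.10, 0.50, 0.50],
                                 [0.00, 0.00, 0.30, 0.30]])

    ious = box_iou(pred_boxes * 1000, target_boxes * 1000)  # 2 x 2 matrix
    ious = torch.einsum('i i -> i', ious)                   # diagonal: [1.00, ~0.04]
    print(ious.mean().item(), (ious > 0.5).sum().item())    # mean iou and acc@0.5 count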