Spaces:
Runtime error
Runtime error
style
Browse files- dalle_mini/model/modeling.py +6 -7
- tools/train/train.py +1 -1
dalle_mini/model/modeling.py
CHANGED
|
@@ -30,21 +30,20 @@ from transformers.modeling_flax_outputs import (
|
|
| 30 |
FlaxSeq2SeqLMOutput,
|
| 31 |
)
|
| 32 |
from transformers.modeling_flax_utils import ACT2FN
|
| 33 |
-
from transformers.utils import logging
|
| 34 |
-
|
| 35 |
from transformers.models.bart.modeling_flax_bart import (
|
| 36 |
FlaxBartAttention,
|
| 37 |
-
|
| 38 |
FlaxBartDecoderLayer,
|
| 39 |
-
FlaxBartEncoderLayerCollection,
|
| 40 |
FlaxBartDecoderLayerCollection,
|
| 41 |
FlaxBartEncoder,
|
| 42 |
-
|
| 43 |
-
|
|
|
|
| 44 |
FlaxBartForConditionalGenerationModule,
|
|
|
|
| 45 |
FlaxBartPreTrainedModel,
|
| 46 |
-
FlaxBartForConditionalGeneration,
|
| 47 |
)
|
|
|
|
| 48 |
|
| 49 |
from .configuration import DalleBartConfig
|
| 50 |
|
|
|
|
| 30 |
FlaxSeq2SeqLMOutput,
|
| 31 |
)
|
| 32 |
from transformers.modeling_flax_utils import ACT2FN
|
|
|
|
|
|
|
| 33 |
from transformers.models.bart.modeling_flax_bart import (
|
| 34 |
FlaxBartAttention,
|
| 35 |
+
FlaxBartDecoder,
|
| 36 |
FlaxBartDecoderLayer,
|
|
|
|
| 37 |
FlaxBartDecoderLayerCollection,
|
| 38 |
FlaxBartEncoder,
|
| 39 |
+
FlaxBartEncoderLayer,
|
| 40 |
+
FlaxBartEncoderLayerCollection,
|
| 41 |
+
FlaxBartForConditionalGeneration,
|
| 42 |
FlaxBartForConditionalGenerationModule,
|
| 43 |
+
FlaxBartModule,
|
| 44 |
FlaxBartPreTrainedModel,
|
|
|
|
| 45 |
)
|
| 46 |
+
from transformers.utils import logging
|
| 47 |
|
| 48 |
from .configuration import DalleBartConfig
|
| 49 |
|
tools/train/train.py
CHANGED
|
@@ -43,7 +43,7 @@ from tqdm import tqdm
|
|
| 43 |
from transformers import AutoTokenizer, HfArgumentParser
|
| 44 |
|
| 45 |
from dalle_mini.data import Dataset
|
| 46 |
-
from dalle_mini.model import
|
| 47 |
|
| 48 |
logger = logging.getLogger(__name__)
|
| 49 |
|
|
|
|
| 43 |
from transformers import AutoTokenizer, HfArgumentParser
|
| 44 |
|
| 45 |
from dalle_mini.data import Dataset
|
| 46 |
+
from dalle_mini.model import DalleBart, DalleBartConfig
|
| 47 |
|
| 48 |
logger = logging.getLogger(__name__)
|
| 49 |
|