import keras
import keras_hub

model_presets = [
    # 8B params models
    "hf://google/gemma-2-instruct-9b-keras",
    "hf://meta-llama/Llama-3.1-8B-Instruct",
    "hf://google/codegemma-7b-it-keras",
    "hf://keras/mistral_instruct_7b_en",
    "hf://keras/vicuna_1.5_7b_en",
    # "keras/gemma_1.1_instruct_7b_en",  # won't fit?
    # 1-3B params models
    "hf://meta-llama/Llama-3.2-1B-Instruct",
    "hf://google/gemma-2b-it-keras",
    "hf://meta-llama/Llama-3.2-3B-Instruct",
]
# Human-readable labels: strip the "hf://" scheme and the org prefix.
# Built as a list (rather than chained one-shot `map` iterators) so the
# labels can be iterated over more than once.
model_labels = [
    preset.removeprefix("hf://")
    .removeprefix("google/")
    .removeprefix("keras/")
    .removeprefix("meta-llama/")
    for preset in model_presets
]
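
# Hypothetical convenience mapping, not in the original app: since the labels
# are derived from `model_presets` in order, a short label (e.g. one picked in
# a UI) can be mapped back to its full preset handle.
label_to_preset = dict(zip(model_labels, model_presets))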

def preset_to_website_url(preset):
    preset = preset.removeprefix("hf://")
    url = "http://huggingface.co/" + preset
    return url


def get_appropriate_chat_template(preset):
    return "Vicuna" if "vicuna" in preset else "auto"

def get_default_layout_map(preset_name, device_mesh):
    # Llama's default layout map also works for mistral and vicuna
    # because their transformer layers have the same names.
    if (
        "Llama" in preset_name
        or "mistral" in preset_name
        or "vicuna" in preset_name
    ):
        layout_map = keras_hub.models.Llama3Backbone.get_layout_map(device_mesh)
        # Default layout map patch:
        # this entry is missing for some Llama models (TODO: fix this in keras_hub).
        layout_map["token_embedding/reverse_embeddings"] = ("batch", "model")
        return layout_map
    elif "gemma" in preset_name:
        layout_map = keras_hub.models.GemmaBackbone.get_layout_map(device_mesh)
        if "gemma-2b-" in preset_name:
            # Default layout map patch:
            # Gemma QKV weights are shaped [NB_HEADS, EMBED_DIM, INNER_DIM],
            # Llama QKV weights are shaped [EMBED_DIM, NB_HEADS, INNER_DIM].
            # However, the default layout map for QKV weights on Gemma is
            # (model_dim, data_dim, None), which shards NB_HEADS on the "model"
            # dimension. gemma-2b-it-keras has only 1 head, so this won't work:
            # the entry must be patched.
            # TODO: fix this in the Gemma layout map in keras_hub.
            # (See the layout map inspection sketch after this function.)
            patch_key = "decoder_block.*attention.*(query|key|value).kernel"
            layout_map.pop(patch_key)
            layout_map[patch_key] = (None, "model", "batch")
        return layout_map
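
# Debugging sketch (an assumption, not part of the original app): build a
# throwaway device mesh and print the regex -> sharding-spec entries returned
# by get_default_layout_map(), to check the patches above. LayoutMap is
# dict-like, so iterating over it yields its regex keys.
def debug_print_layout_map(preset_name):
    devices = keras.distribution.list_devices()
    mesh = keras.distribution.DeviceMesh(
        shape=(1, len(devices)), axis_names=["batch", "model"], devices=devices
    )
    layout_map = get_default_layout_map(preset_name, mesh)
    for key in layout_map:
        print(f"{key} -> {layout_map[key]}")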

def log_applied_layout_map(model):
    print("Model class:", type(model).__name__)
    if "Gemma" in type(model).__name__:
        transformer_decoder_block_name = "decoder_block_1"
    elif "Llama" in type(model).__name__:  # works for Llama (Vicuna) and Llama3
        transformer_decoder_block_name = "transformer_layer_1"
    elif "Mistral" in type(model).__name__:
        transformer_decoder_block_name = "transformer_layer_1"
    else:
        print("Unknown architecture. Cannot display the applied layout.")
        return
    # See how layer sharding was applied
    embedding_layer = model.backbone.get_layer("token_embedding")
    print(embedding_layer)
    decoder_block = model.backbone.get_layer(transformer_decoder_block_name)
    print(type(decoder_block))
    for variable in embedding_layer.weights + decoder_block.weights:
        # Print path, shape, applied sharding spec and dtype in aligned columns.
        print(
            f"{variable.path:<58} "
            f"{str(variable.shape):<16} "
            f"{str(variable.value.sharding.spec):<35} "
            f"{str(variable.dtype)}"
        )
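
# Note on the columns printed above: with the JAX backend,
# `variable.value.sharding.spec` is the jax.sharding.PartitionSpec actually
# applied to each weight. Axis names in the spec refer to the
# ("batch", "model") device mesh built in load_model() below, and `None`
# means that dimension of the weight is replicated rather than sharded.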

def load_model(preset):
    devices = keras.distribution.list_devices()
    device_mesh = keras.distribution.DeviceMesh(
        shape=(1, len(devices)), axis_names=["batch", "model"], devices=devices
    )
    model_parallel = keras.distribution.ModelParallel(
        layout_map=get_default_layout_map(preset, device_mesh),
        batch_dim_name="batch",
    )
    with model_parallel.scope():
        # These two buggy models need this workaround to be loaded in bfloat16
        if "google/gemma-2-instruct-9b-keras" in preset:
            model = keras_hub.models.GemmaCausalLM(
                backbone=keras_hub.models.GemmaBackbone.from_preset(
                    preset, dtype="bfloat16"
                ),
                preprocessor=keras_hub.models.GemmaCausalLMPreprocessor.from_preset(
                    preset
                ),
            )
        elif "meta-llama/Llama-3.1-8B-Instruct" in preset:
            model = keras_hub.models.Llama3CausalLM(
                backbone=keras_hub.models.Llama3Backbone.from_preset(
                    preset, dtype="bfloat16"
                ),
                preprocessor=keras_hub.models.Llama3CausalLMPreprocessor.from_preset(
                    preset
                ),
            )
        else:
            model = keras_hub.models.CausalLM.from_preset(preset, dtype="bfloat16")
    log_applied_layout_map(model)
    return model
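
# Minimal usage sketch, not part of the original app: load the smallest preset
# in the list and generate a completion. Assumes a JAX backend with enough
# accelerator memory and access to the gated Hugging Face checkpoints;
# `generate()` is the standard keras_hub.models.CausalLM generation method.
if __name__ == "__main__":
    demo_model = load_model("hf://meta-llama/Llama-3.2-1B-Instruct")
    print(
        demo_model.generate(
            "Explain model parallelism in one sentence.", max_length=128
        )
    )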