# Origin: justiceai / language.py (Princeaka)
# Last change: commit 55ce9fc "Update language.py" (verified)
"""
language.py — robust loader + adapter for language.bin
This loader attempts multiple safe options to load a local language model file `language.bin`
and adapt it into a small, predictable translation API:
- translate(text, src, tgt)
- translate_to_en(text, src)
- translate_from_en(text, tgt)
- detect(text) / detect_language(text) (if provided by model)
- model_info() for debugging
Loading strategy (in order):
1. If a language.py module is present (importable) we prefer it (the app already tries this).
2. If language.bin exists:
- Try to detect if it's a safetensors file and (if safetensors is installed) attempt to load.
- Try torch.load with weights_only=True (safe for "weights-only" files).
- If that fails and you explicitly allow insecure loading, try torch.load(..., weights_only=False).
To allow this, set the environment variable LANGUAGE_LOAD_ALLOW_INSECURE=1.
NOTE: loading with weights_only=False may execute arbitrary code from the file. Only do this
when you trust the source of language.bin.
- Try pickle.load as a last attempt (may fail for many binary formats).
3. Fallback: no model loaded (the app will fall back to heuristics).
Security note:
- Re-running torch.load with weights_only=False can run arbitrary code embedded in the file, and so can pickle.load.
Only enable LANGUAGE_LOAD_ALLOW_INSECURE if you trust the file origin.
"""
import importlib
import io
import logging
import os
import sys
from pathlib import Path
# Module-level state shared by the loader functions and the translation adapter.
_model = None  # whatever load() produced (module, object, dict, ...), or None
_load_errors = []  # (tag, detail) tuples appended by each failed loading attempt

logger = logging.getLogger("local_language")
logger.setLevel(logging.INFO)
def _refers_to_this_file(mod) -> bool:
    """Return True if *mod* is this loader module or another copy of the same file.

    This file is itself named language.py, so importing "language" from here can hand
    back this very module (partially initialised while it is still being imported) or,
    when the file is run as a script, a second copy of it. Using either as the model
    would make translate()/detect() dispatch back into themselves recursively.
    """
    if mod is sys.modules.get(__name__):
        return True
    mod_file = getattr(mod, "__file__", None)
    this_file = globals().get("__file__")
    if not mod_file or not this_file:
        return False
    try:
        return Path(mod_file).resolve() == Path(this_file).resolve()
    except (OSError, RuntimeError):
        return False


def _try_import_language_module():
    """Import a separate module named ``language`` and return it, or None.

    Import failures are recorded in ``_load_errors`` under "import_language_py".
    If the import resolves to this loader file itself, it is ignored (also recorded
    under "import_language_py") so that language.bin loading can proceed.
    """
    try:
        mod = importlib.import_module("language")
    except Exception as e:
        _load_errors.append(("import_language_py", repr(e)))
        return None
    if _refers_to_this_file(mod):
        _load_errors.append(("import_language_py", "skipped: 'language' resolved to this loader module"))
        return None
    logger.info("Found importable language.py module; using it.")
    return mod
def _is_likely_safetensors(path: Path) -> bool:
    """Return True if *path* looks like a safetensors file.

    A name ending in ``.safetensors`` is accepted without opening the file. Otherwise
    the header is sniffed: a safetensors file starts with an 8-byte little-endian
    unsigned header length N, followed by an N-byte JSON header beginning with ``{``.
    Sniffing lets a safetensors file saved under another name (e.g. ``language.bin``)
    be recognised. Missing, unreadable or too-short files return False.
    """
    import struct  # local: only needed for header sniffing

    path = Path(path)
    if path.name.endswith(".safetensors"):
        return True
    try:
        with open(path, "rb") as fh:
            head = fh.read(9)
        file_size = path.stat().st_size
    except OSError:
        return False
    if len(head) < 9 or head[8:9] != b"{":
        return False
    (header_len,) = struct.unpack("<Q", head[:8])
    # The header must be at least "{}", must fit inside the file, and the format
    # caps it at 100 MB; torch zip/pickle files fail these checks.
    return 2 <= header_len <= min(file_size - 8, 100_000_000)
def _try_safetensors_load(path: Path):
    """Load *path* with ``safetensors.torch.load_file``; return the result or None.

    Failures are recorded in ``_load_errors`` rather than raised:
    "safetensors_not_installed" when the import fails, "safetensors_load_failed"
    when loading the file fails.
    """
    try:
        from safetensors.torch import load_file  # type: ignore
    except Exception as exc:
        _load_errors.append(("safetensors_not_installed", repr(exc)))
        return None

    try:
        tensor_dict = load_file(str(path))
        logger.info("Loaded safetensors file into tensor dict (language.bin treated as safetensors).")
    except Exception as exc:
        _load_errors.append(("safetensors_load_failed", repr(exc)))
        return None
    # Returned untouched; any adaptation of raw tensors is up to the caller.
    return tensor_dict
def _try_torch_load(path: Path, weights_only: bool):
    """Load *path* with ``torch.load`` onto the CPU; return the object or None.

    Args:
        path: file to load.
        weights_only: passed through to ``torch.load``. True selects the restricted
            weights-only unpickler; False allows full unpickling, which can execute
            arbitrary code embedded in the file.

    Returns:
        The loaded object, or None on failure (the reason is appended to
        ``_load_errors``).

    On torch builds that predate the ``weights_only`` argument, the only available mode
    is full unpickling. That is therefore used only when ``weights_only=False`` was
    requested. A weights-only request is refused rather than silently downgraded to
    the insecure mode.
    """
    try:
        import torch
    except Exception as e:
        _load_errors.append(("torch_not_installed", repr(e)))
        return None
    try:
        # In PyTorch 2.6+, torch.load defaults weights_only=True. Passing explicitly for clarity.
        obj = torch.load(str(path), map_location="cpu", weights_only=weights_only)
    except TypeError as e:
        if "weights_only" not in str(e):
            # A TypeError raised while loading, not an unsupported-keyword problem.
            _load_errors.append((f"torch_load_failed_weights_only={weights_only}", repr(e)))
            return None
        if weights_only:
            # Old torch cannot do a restricted load; its plain torch.load is the
            # insecure path, which the caller must opt into explicitly.
            _load_errors.append(("torch_load_weights_only_unsupported", repr(e)))
            return None
        try:
            obj = torch.load(str(path), map_location="cpu")
        except Exception as e2:
            _load_errors.append(("torch_load_typeerror_then_failed", repr(e2)))
            return None
        logger.info("torch.load succeeded (no weights_only argument supported by local torch).")
        return obj
    except Exception as e:
        _load_errors.append((f"torch_load_failed_weights_only={weights_only}", repr(e)))
        return None
    logger.info(f"torch.load succeeded (weights_only={weights_only}).")
    return obj
def _try_pickle_load(path: Path):
    """Unpickle the file at *path*; return the object, or None on any failure.

    SECURITY: pickle.load can execute arbitrary code embedded in the file, so this
    must only be used for files from a trusted source. Failures are recorded in
    ``_load_errors`` under "pickle_load_failed".
    """
    import pickle

    try:
        with open(path, "rb") as handle:
            loaded = pickle.load(handle)
        logger.info("Loaded language.bin via pickle.")
    except Exception as exc:
        _load_errors.append(("pickle_load_failed", repr(exc)))
        return None
    return loaded
def _attempt_load(path: Path):
    """Try each loading strategy for *path* in order; return the first non-None result.

    Order:
      1. safetensors, when ``_is_likely_safetensors(path)`` says so.
      2. ``torch.load(weights_only=True)``.
      3. Only when LANGUAGE_LOAD_ALLOW_INSECURE is "1"/"true"/"yes" (case-insensitive):
         ``torch.load(weights_only=False)``, then plain ``pickle.load``. Both can
         execute arbitrary code embedded in the file, so neither runs without the opt-in.

    Returns:
        The loaded object, or None when every enabled strategy failed. Details of the
        failures are appended to ``_load_errors``.
    """
    # 1) Safetensors heuristics
    if _is_likely_safetensors(path):
        logger.info("language.bin looks like safetensors (by filename). Attempting safetensors load.")
        obj = _try_safetensors_load(path)
        if obj is not None:
            return obj
    # 2) Try torch.load in safe (weights-only) mode first (PyTorch 2.6+ default is weights_only=True)
    obj = _try_torch_load(path, weights_only=True)
    if obj is not None:
        return obj
    # 3) Everything past this point can run code from the file: opt-in only.
    allow_insecure = os.environ.get("LANGUAGE_LOAD_ALLOW_INSECURE", "").strip().lower() in ("1", "true", "yes")
    if not allow_insecure:
        _load_errors.append((
            "insecure_fallbacks_disabled",
            "set LANGUAGE_LOAD_ALLOW_INSECURE=1 to allow torch.load(weights_only=False) and pickle.load",
        ))
        return None
    logger.warning("LANGUAGE_LOAD_ALLOW_INSECURE is set -> attempting torch.load with weights_only=False (INSECURE).")
    obj = _try_torch_load(path, weights_only=False)
    if obj is not None:
        return obj
    logger.warning("torch.load(weights_only=False) failed or returned None.")
    # 4) Plain pickle as the last resort; it is just as unsafe, hence also behind the flag.
    return _try_pickle_load(path)
def _load_language_bin_if_present(path=None):
    """Locate ``language.bin`` and load it via ``_attempt_load``.

    Args:
        path: optional explicit file to load. When omitted, ``language.bin`` in the
            current working directory is tried first (the historical behaviour), then
            ``language.bin`` next to this module. This means loading no longer depends
            on the process being started from the right directory.

    Returns:
        The loaded object, or None if no candidate file exists or every strategy
        failed. On success the object is also stored in the module-level ``_model``;
        on failure ``_model`` is left unchanged.
    """
    global _model
    if path is not None:
        candidates = [Path(path)]
    else:
        candidates = [Path("language.bin")]
        module_file = globals().get("__file__")
        if module_file:
            candidates.append(Path(module_file).resolve().with_name("language.bin"))
    p = next((c for c in candidates if c.exists()), None)
    if p is None:
        return None
    logger.info("language.bin found (%s); attempting to load with safe fallbacks...", p)
    # Try multiple strategies
    obj = _attempt_load(p)
    if obj is None:
        logger.warning("All attempts to load language.bin failed. See _load_errors for details.")
    else:
        _model = obj
    return obj
def load():
    """Load the language backend and return it (or None if nothing could be loaded).

    An importable ``language`` module takes priority and is stored in ``_model``.
    Otherwise the result of trying ``language.bin`` is returned; that helper updates
    ``_model`` itself on success.
    """
    global _model
    module = _try_import_language_module()
    if module is None:
        # No module available -> fall back to the binary file on disk.
        return _load_language_bin_if_present()
    _model = module
    return _model
# Load eagerly at import time (the app also calls load_local_language_module separately).
# Importing this module must not fail because of a bad model file, so errors are only logged.
try:
    load()
except Exception as exc:
    logger.warning(f"language.py loader encountered error during import: {exc}")
# --- Adapter / API functions the app expects --- #
def model_info() -> dict:
    """Return a small summary of the loaded model/object to help debugging.

    The dict has the same keys whether or not a model is loaded: ``loaded``, ``type``,
    ``repr`` (truncated to 1000 chars), ``load_errors`` (first 20 entries of
    ``_load_errors``), ``has_translate``, ``has_translate_to_en``,
    ``has_translate_from_en``, ``has_detect``, ``callable`` and ``dir`` (public
    attribute names). Introspection failures fall back to placeholder values instead
    of raising.
    """
    info = {
        "loaded": False,
        "type": None,
        "repr": None,
        "load_errors": _load_errors[:20],
        "has_translate": False,
        "has_translate_to_en": False,
        "has_translate_from_en": False,
        "has_detect": False,
        "callable": False,
        "dir": [],
    }
    if _model is None:
        return info
    info["loaded"] = True
    try:
        info["type"] = type(_model).__name__
    except Exception:
        info["type"] = "<unknown>"
    try:
        info["repr"] = repr(_model)[:1000]
    except Exception:
        info["repr"] = "<unreprable>"
    try:
        info["has_translate"] = hasattr(_model, "translate")
        info["has_translate_to_en"] = hasattr(_model, "translate_to_en")
        info["has_translate_from_en"] = hasattr(_model, "translate_from_en")
        info["has_detect"] = hasattr(_model, "detect") or hasattr(_model, "detect_language")
        info["callable"] = callable(_model)
    except Exception:
        # Best-effort: objects whose attribute access raises keep the False defaults.
        pass
    try:
        info["dir"] = [n for n in dir(_model) if not n.startswith("_")]
    except Exception:
        info["dir"] = []
    return info
def _safe_call_translate(text: str, src: str, tgt: str) -> str:
    """
    Try multiple call patterns to invoke translation functions on the loaded object.
    Fall back to returning original text if nothing works.

    Strategies, in order; the first call that returns (with any value, even None) wins:
      1. ``_model.translate(text, src, tgt)``; on TypeError, ``_model.translate(text, "src->tgt")``.
      2. ``_model.translate_to_en(text, src)`` when tgt is "en"/"eng" (case-insensitive),
         then ``_model.translate_from_en(text, tgt)`` when src is "en"/"eng".
      3. Calling ``_model`` directly with (text, src, tgt), then (text, src), then (text),
         moving to the next arity only when the previous call raised TypeError.
      4. HF-style: ``_model.tokenizer`` plus ``_model.generate`` on a batch of one.
      5. If ``_model`` is a dict: look up ``(src, tgt)`` then ``"src->tgt"``; a callable
         value is called with text, a str value is returned as-is (text is ignored).

    Exceptions (``Exception`` subclasses) never propagate. Each strategy's failure is
    logged at debug level, except the innermost fallbacks of steps 1 and 3, which are
    dropped silently.

    Args:
        text: text to translate.
        src: source language code.
        tgt: target language code.

    Returns:
        The winning strategy's result (NOTE(review): not type-checked, so it may not be
        a str if the model returns something else), or ``text`` unchanged.
    """
    if _model is None:
        return text
    # 1) Preferred explicit API
    try:
        if hasattr(_model, "translate"):
            try:
                return _model.translate(text, src, tgt)
            except TypeError:
                try:
                    # some translate implementations take (text, "src->tgt")
                    return _model.translate(text, f"{src}->{tgt}")
                except Exception:
                    # Second signature also failed: fall through without logging.
                    pass
    except Exception as e:
        logger.debug(f"_model.translate attempt failed: {e}")
    # 2) Dedicated helpers
    try:
        if tgt.lower() in ("en", "eng") and hasattr(_model, "translate_to_en"):
            return _model.translate_to_en(text, src)
    except Exception as e:
        logger.debug(f"_model.translate_to_en attempt failed: {e}")
    try:
        if src.lower() in ("en", "eng") and hasattr(_model, "translate_from_en"):
            return _model.translate_from_en(text, tgt)
    except Exception as e:
        logger.debug(f"_model.translate_from_en attempt failed: {e}")
    # 3) Callable model (call signature may vary)
    # NOTE(review): any callable object is tried here before the HF-style step 4;
    # a non-TypeError failure skips the remaining arities and is logged below.
    try:
        if callable(_model):
            try:
                return _model(text, src, tgt)
            except TypeError:
                try:
                    return _model(text, src) # maybe (text, src)
                except TypeError:
                    try:
                        return _model(text) # maybe (text)
                    except Exception:
                        pass
    except Exception as e:
        logger.debug(f"_model callable attempts failed: {e}")
    # 4) HF-style model object with attached tokenizer (best-effort)
    try:
        # model could be a dict of tensors (weights-only) - not directly usable for translation
        tokenizer = getattr(_model, "tokenizer", None)
        generate = getattr(_model, "generate", None)
        # Truthiness test: both attributes must be present and truthy.
        if tokenizer and generate:
            inputs = tokenizer([text], return_tensors="pt", truncation=True)
            outputs = _model.generate(**inputs, max_length=1024)
            # Only the first (and only) sequence of the batch is returned.
            decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
            return decoded
    except Exception as e:
        logger.debug(f"_model HF-style generate attempt failed: {e}")
    # 5) dict-like mapping (('src','tgt') -> fn or str)
    try:
        if isinstance(_model, dict):
            key = (src, tgt)
            if key in _model:
                val = _model[key]
                if callable(val):
                    return val(text)
                if isinstance(val, str):
                    return val
            # Values that are neither callable nor str fall through to the string key.
            key2 = f"{src}->{tgt}"
            if key2 in _model:
                val = _model[key2]
                if callable(val):
                    return val(text)
                if isinstance(val, str):
                    return val
    except Exception as e:
        logger.debug(f"_model dict-like attempt failed: {e}")
    # Nothing worked: return input (no hallucination)
    return text
def translate(text: str, src: str, tgt: str) -> str:
    """Translate *text* from language *src* into language *tgt* with the loaded model.

    Falsy *text* is returned unchanged. Falsy language codes are replaced with "und"
    before dispatching to ``_safe_call_translate``.
    """
    if text:
        source = src or "und"
        target = tgt or "und"
        return _safe_call_translate(text, source, target)
    return text
def translate_to_en(text: str, src: str) -> str:
    """Translate *text* from language *src* into English.

    Uses the model's own ``translate_to_en(text, src)`` when it exists. If that helper
    is missing or raises, the failure is logged at debug level and the generic
    ``translate(text, src, "en")`` path is used instead. Falsy *text* is returned unchanged.
    """
    if not text:
        return text
    # prefer dedicated helper if present
    if _model is not None:
        try:
            if hasattr(_model, "translate_to_en"):
                return _model.translate_to_en(text, src)
        except Exception as e:
            # Best-effort: note why the helper failed, then use the generic fallback.
            logger.debug(f"_model.translate_to_en attempt failed: {e}")
    return translate(text, src, "en")
def translate_from_en(text: str, tgt: str) -> str:
    """Translate English *text* into language *tgt*.

    Uses the model's own ``translate_from_en(text, tgt)`` when it exists. If that
    helper is missing or raises, the failure is logged at debug level and the generic
    ``translate(text, "en", tgt)`` path is used instead. Falsy *text* is returned unchanged.
    """
    if not text:
        return text
    if _model is not None:
        try:
            if hasattr(_model, "translate_from_en"):
                return _model.translate_from_en(text, tgt)
        except Exception as e:
            # Best-effort: note why the helper failed, then use the generic fallback.
            logger.debug(f"_model.translate_from_en attempt failed: {e}")
    return translate(text, "en", tgt)
def detect(text: str) -> "str | None":
    """Detect the language of *text* with the loaded model, if it supports detection.

    Tries ``_model.detect_language(text)`` first, then ``_model.detect(text)``. A
    detector that raises is logged at debug level and the next one is tried.

    Returns:
        Whatever the model's detector returns, or None when *text* is empty, no model
        is loaded, the model exposes no detector, or every detector raised.
    """
    if not text or _model is None:
        return None
    for attr_name in ("detect_language", "detect"):
        try:
            detector = getattr(_model, attr_name, None)
            if detector is None:
                continue
            return detector(text)
        except Exception as e:
            logger.debug(f"model detect attempt failed: {e}")
    return None
# Small CLI for manual testing: print diagnostics, then translate argv[3:] from argv[1] to argv[2].
if __name__ == "__main__":
    print("model_info:", model_info())
    if len(sys.argv) < 4:
        print("Usage: python language.py <src> <tgt> <text...>")
        print("Example: python language.py es en 'hola mundo'")
    else:
        cli_src, cli_tgt, *words = sys.argv[1:]
        print("translate:", translate(" ".join(words), cli_src, cli_tgt))