Revistas / build_parquet_embeddings.py
Romanes's picture
Upload 6 files
40fe9ab verified
import argparse
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
def ensure_col(df: pd.DataFrame, name: str, default: str = "") -> pd.DataFrame:
    """Ensure that column *name* exists in *df*.

    If the column is missing it is created and filled with *default*
    (empty string by default, matching the original behavior). The
    DataFrame is mutated in place and also returned for convenience.

    Args:
        df: DataFrame to check/mutate.
        name: Column name that must exist.
        default: Fill value used when the column has to be created.

    Returns:
        The same DataFrame instance that was passed in.
    """
    if name not in df.columns:
        df[name] = default
    return df
def main():
    """Build a Parquet corpus with E5 embeddings for journal recommendation.

    Reads a Scopus CSV export, builds one matching text per row
    (title + source title as soft context), encodes it with a
    Sentence-Transformers E5 model on CPU, and writes an Arrow/Parquet
    table that includes a float32 `embedding` list column.
    """
    p = argparse.ArgumentParser(description="Construye un Parquet con embeddings E5 para recomendación de revistas.")
    p.add_argument("--csv", required=True, help="Ruta al CSV exportado (ej. uptc_afid60077378_scopus_export.csv)")
    p.add_argument("--out", default="scopus_corpus.parquet", help="Ruta de salida Parquet")
    p.add_argument("--model", default="intfloat/multilingual-e5-small", help="Modelo Sentence-Transformers")
    p.add_argument("--batch-size", type=int, default=64, help="Tamaño de batch para el encode")
    args = p.parse_args()

    df = pd.read_csv(args.csv)

    # Ensure the minimum columns of the "simple" export exist.
    for c in ["Title", "Source title", "ISSN", "eISSN", "Year", "Cited by",
              "DOI", "Link", "EID", "Document Type", "Open Access"]:
        ensure_col(df, c)

    # Text used for similarity: works even when Abstract/Keywords are
    # absent. We use the title plus the journal name as soft context.
    df["text_for_match"] = (
        df["Title"].fillna("").astype(str).str.strip()
        + ". Revista: "
        + df["Source title"].fillna("").astype(str).str.strip()
    ).str.replace(r"\s+", " ", regex=True).str.strip()

    print(f"Cargando modelo: {args.model}")
    model = SentenceTransformer(args.model, device="cpu")

    # E5 convention: corpus texts carry the "passage: " prefix.
    # No conditional needed: "passage: " + "" is already "passage: ".
    texts = ["passage: " + t for t in df["text_for_match"].tolist()]

    print(f"Codificando {len(texts)} textos…")
    embs = model.encode(
        texts,
        batch_size=args.batch_size,
        show_progress_bar=True,
        normalize_embeddings=True,  # required so dot product == cosine similarity
    ).astype(np.float32)

    # Type normalization.
    year = pd.to_numeric(df["Year"], errors="coerce").astype("Int64")
    cited = pd.to_numeric(df["Cited by"], errors="coerce").fillna(0).astype(np.int32)
    # FIX: Int64.tolist() yields pd.NA for missing years; not every pyarrow
    # version accepts pd.NA as a null scalar, so map it to None explicitly.
    year_list = [None if pd.isna(v) else int(v) for v in year]

    # Build the Arrow table.
    # FIX: fillna("") before astype(str) on all string columns — otherwise
    # missing values are serialized as the literal string "nan".
    table = pa.table({
        "eid": pa.array(df["EID"].fillna("").astype(str).tolist()),
        "title": pa.array(df["Title"].fillna("").astype(str).tolist()),
        "source_title": pa.array(df["Source title"].fillna("").astype(str).tolist()),
        "issn": pa.array(df["ISSN"].fillna("").astype(str).tolist()),
        "eissn": pa.array(df["eISSN"].fillna("").astype(str).tolist()),
        "year": pa.array(year_list, type=pa.int64()),
        "cited_by": pa.array(cited.tolist(), type=pa.int32()),
        "doi": pa.array(df["DOI"].fillna("").astype(str).tolist()),
        "link": pa.array(df["Link"].fillna("").astype(str).tolist()),
        "Document Type": pa.array(df["Document Type"].fillna("").astype(str).tolist()),
        "Open Access": pa.array(df["Open Access"].fillna("").astype(str).tolist()),
        "text_for_match": pa.array(df["text_for_match"].tolist()),
        "embedding": pa.array(embs.tolist(), type=pa.list_(pa.float32())),
    })

    pq.write_table(table, args.out, compression="zstd")
    dim = len(embs[0]) if len(embs) else 0
    print(f"OK -> {args.out} | filas: {table.num_rows} | dim: {dim}")
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()