sayanAIAI commited on
Commit
fd8623d
·
verified ·
1 Parent(s): 2884a69

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +78 -22
main.py CHANGED
@@ -1,38 +1,94 @@
1
  import os
2
  os.environ['HF_HOME'] = '/tmp'
3
- from flask import Flask, render_template, request, jsonify
4
- from transformers import pipeline, AutoTokenizer
 
 
5
 
6
  app = Flask(__name__)
7
 
8
- model_name = "sshleifer/distilbart-cnn-12-6"
9
- tokenizer = AutoTokenizer.from_pretrained(model_name)
10
- summarizer = pipeline("summarization", model=model_name)
 
11
 
12
- @app.route("/")
13
- def index():
14
- return render_template("index.html")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  @app.route("/summarize", methods=["POST"])
17
- def summarize():
18
- try:
19
- data = request.get_json()
20
- text = data.get("text", "").strip()
21
- if not text:
22
- return jsonify({"error": "No text provided"}), 400
 
 
23
 
24
- input_tokens = tokenizer.encode(text, return_tensors="pt")
25
- input_len = input_tokens.shape[1]
 
26
 
27
- max_len = max(10, min(100, input_len // 2))
28
- min_len = max(5, max_len // 2)
 
 
 
 
 
 
29
 
30
- summary = summarizer(text, max_length=max_len, min_length=min_len, do_sample=False)[0]['summary_text']
31
- return jsonify({"summary": summary})
32
- except Exception as e:
33
- return jsonify({"error": str(e)}), 500
 
 
 
 
 
 
34
 
 
 
 
 
35
 
 
36
 
37
  if __name__ == "__main__":
38
  app.run(debug=True,port=7860)
 
1
import os

# Must be set before transformers is imported so the model cache lands in a
# writable location (e.g. on Hugging Face Spaces, only /tmp is writable).
os.environ['HF_HOME'] = '/tmp'

import math
import textwrap

from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

app = Flask(__name__)

# Summarization model; loaded once at import time so requests reuse it.
MODEL_NAME = "sshleifer/distilbart-cnn-12-6"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
# device=-1 keeps inference on CPU; set to a GPU index where available.
summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device=-1)

# Presets mapping a user-facing length choice to generation token bounds.
LENGTH_PRESETS = {
    "short": {"min_length": 20, "max_length": 60},
    "medium": {"min_length": 60, "max_length": 130},
    "long": {"min_length": 130, "max_length": 300},
}
21
+
22
def chunk_text_by_chars(text, max_chars=1500, overlap=200):
    """Split *text* into overlapping chunks of at most *max_chars* characters.

    Consecutive chunks overlap by *overlap* characters so that sentences cut
    at a chunk boundary still appear whole in the following chunk. Where a
    newline falls in the last 40% of a chunk, the chunk is shortened to end
    there, to avoid cutting mid-sentence. Each chunk is stripped of
    surrounding whitespace.

    Returns a list with at least one element; short inputs are returned as a
    single-element list unchanged.
    """
    n = len(text)
    if n <= max_chars:
        return [text]

    parts = []
    start = 0
    while start < n:
        end = min(n, start + max_chars)
        chunk = text[start:end]

        # Prefer ending at a newline for nicer chunking, but only if that
        # keeps the chunk reasonably large (> 60% of max_chars) and we are
        # not already at the end of the text.
        if end < n:
            nl = chunk.rfind('\n')
            if nl > max_chars * 0.6:
                end = start + nl
                chunk = text[start:end]

        parts.append(chunk.strip())

        # BUG FIX: the original unconditionally did `start = end - overlap`,
        # which moves start *backward* once the final chunk is emitted
        # (end == n), re-appending the tail forever. Stop at the end, and
        # guarantee forward progress even if overlap >= max_chars.
        if end >= n:
            break
        next_start = end - overlap
        start = next_start if next_start > start else end
    return parts
39
+
40
def apply_tone_instruction(text, tone):
    """Prefix *text* with a tone-specific summarization instruction.

    *tone* is matched case-insensitively; a missing/None or unrecognized
    tone falls back to the plain "Summarize:" instruction.
    """
    instructions = {
        "formal": "Summarize in a formal, professional tone:",
        "casual": "Summarize in a casual, conversational tone:",
        "bullet": "Summarize into short bullet points:",
    }
    normalized = (tone or "neutral").lower()
    instr = instructions.get(normalized, "Summarize:")
    return f"{instr}\n\n{text}"
51
 
52
@app.route("/summarize", methods=["POST"])
def summarize_route():
    """Summarize JSON-posted text.

    Request body: {"text": str, "length": "short"|"medium"|"long",
    "tone": str}. Responds with {"summary": str}, or {"error": str} with
    status 400 (bad input) / 500 (model failure).
    """
    # BUG FIX: get_json(force=True) returns None for an empty/malformed body,
    # so data.get(...) raised AttributeError and produced an unhandled 500.
    # silent=True lets us fall back to {} and return the proper 400 below.
    data = request.get_json(force=True, silent=True) or {}

    text = (data.get("text") or "")[:20000]  # safe cap on input size
    length = data.get("length", "medium")
    tone = data.get("tone", "neutral")

    if not text or len(text.split()) < 5:
        return jsonify({"error": "Input too short."}), 400

    preset = LENGTH_PRESETS.get(length, LENGTH_PRESETS["medium"])
    chunks = chunk_text_by_chars(text, max_chars=1500, overlap=200)

    # BUG FIX: restore the try/except the previous revision had around model
    # calls, so inference errors keep the route's JSON error contract instead
    # of leaking an unhandled exception page.
    try:
        summaries = []
        for chunk in chunks:
            prompted = apply_tone_instruction(chunk, tone)
            out = summarizer(prompted,
                             min_length=preset["min_length"],
                             max_length=preset["max_length"],
                             truncation=True)[0]["summary_text"]
            summaries.append(out.strip())

        # If multiple chunk summaries, join and compress once more.
        if len(summaries) == 1:
            final = summaries[0]
        else:
            combined = "\n\n".join(summaries)
            prompted = apply_tone_instruction(combined, tone)
            final = summarizer(prompted,
                               min_length=preset["min_length"],
                               max_length=preset["max_length"],
                               truncation=True)[0]["summary_text"]
    except Exception as e:
        return jsonify({"error": str(e)}), 500

    # Bullet tone: post-process the prose summary into "- " bullet lines
    # (one per sentence/line, capped at 20 bullets).
    # NOTE(review): this check is case-sensitive while apply_tone_instruction
    # lowercases the tone — "Bullet" gets the bullet prompt but no bullet
    # formatting; preserved as-is to avoid a behavior change.
    if tone == "bullet":
        lines = [l.strip() for s in final.splitlines() for l in s.split(". ") if l.strip()]
        final = "\n".join(f"- {l.rstrip('.')}" for l in lines[:20])

    return jsonify({"summary": final})
92
 
93
if __name__ == "__main__":
    # Development entry point: serve on port 7860 (Hugging Face Spaces'
    # default) with the Flask debugger enabled.
    app.run(debug=True, port=7860)