ionosphere commited on
Commit
e05e6ee
·
1 Parent(s): 7814748

First commit for Gaia Space

Browse files
Files changed (5) hide show
  1. .gitignore +1 -0
  2. Dockerfile +1 -0
  3. README.md +24 -1
  4. requirements.txt +4 -1
  5. src/streamlit_app.py +313 -29
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .env
Dockerfile CHANGED
@@ -9,6 +9,7 @@ RUN apt-get update && apt-get install -y \
9
  && rm -rf /var/lib/apt/lists/*
10
 
11
  COPY requirements.txt ./
 
12
  COPY src/ ./src/
13
 
14
  RUN pip3 install -r requirements.txt
 
9
  && rm -rf /var/lib/apt/lists/*
10
 
11
  COPY requirements.txt ./
12
+ COPY .env ./src/
13
  COPY src/ ./src/
14
 
15
  RUN pip3 install -r requirements.txt
README.md CHANGED
@@ -11,10 +11,33 @@ pinned: false
11
  short_description: Example de reconnaissance de caractères avec MistralOCR
12
  license: mit
13
  ---
 
14
 
15
- # Welcome to Streamlit!
16
 
17
  Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
18
 
19
  If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
20
  forums](https://discuss.streamlit.io).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  short_description: Example de reconnaissance de caractères avec MistralOCR
12
  license: mit
13
  ---
14
+ ![image](https://www.osfarm.org/assets/img/logo_white.png)
15
 
16
+ # Welcome to Gaia template OCR build by OSFarm !
17
 
18
  Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
19
 
20
  If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
21
  forums](https://discuss.streamlit.io).
22
+
23
+ ## Déploiement local (dev mode)
24
+
25
+ 1. Set a correct .env
26
+
27
+ ```env
28
+ MISTRAL_API_KEY=XXXXXXXXXXXXXXXXXXXXXXXXX
29
+ ```
30
+
31
+ 2. Build image
32
+
33
+ ```sh
34
+ docker build --pull --rm -f 'Dockerfile' -t 'mistralocr:latest' '.'
35
+ ```
36
+
37
+ 3. Launch app
38
+
39
+ ```sh
40
+ docker run --rm -d -p 8501:8501/tcp mistralocr:latest
41
+ ```
42
+
43
+ Open a Firefox or Chrome at localhost:8501
requirements.txt CHANGED
@@ -1,3 +1,6 @@
1
  altair
2
  pandas
3
- streamlit
 
 
 
 
1
  altair
2
  pandas
3
+ streamlit
4
+ mistralai
5
+ streamlit_chat
6
+ python-dotenv
src/streamlit_app.py CHANGED
@@ -1,10 +1,20 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  """
7
- # Welcome to Streamlit!
 
 
8
 
9
  Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
  If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
@@ -13,28 +23,302 @@ forums](https://discuss.streamlit.io).
13
  In the meantime, below is an example of what you can do with just a few lines of code:
14
  """
15
 
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import base64
3
+ import tempfile
4
+ import os
5
+ from mistralai import Mistral
6
+ from PIL import Image
7
+ import io
8
+ from dotenv import load_dotenv
9
+
10
+ load_dotenv()
11
+
12
+ MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY")
13
 
14
  """
15
+ ![image](https://www.osfarm.org/assets/img/logo_white.png)
16
+
17
+ # Welcome to Gaia OCR Template by OSFarm!
18
 
19
  Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
20
  If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
 
23
  In the meantime, below is an example of what you can do with just a few lines of code:
24
  """
25
 
26
+ SYSTEM_PROMPT = """From the user prompt coming from purchase invoice below, extract informations strictly as instructed.
27
+ Most of the time, the pattern of a purchase invoice is composed of supplier informations, invoice informations and one or many invoice lines.
28
+ Information come from France in french language.
29
+ Return the purchase informations in JSON format like an API according to the schema.
30
+ Do not return 'description', 'type' or 'format' attributes in the response.
31
+ Use it only to detect correct value of each attributes.
32
+ example of a response : { supplier: { name: "AXA", address: "10 rue du Bouil bleu", postal_code: "17250", ... }, invoice: {number: "FA25632", ... }, items: [{number: '1', ... }, {number: '2', ... }, ...]}.
33
+ for the items, try to detect the role of the item in 'merchandise' or 'service' in role attribute.
34
+ for all the date, try to convert it in the following format : 'DD/MM/YYYY'.
35
+ for the items, try to classify it like an accountant in nature attribute.
36
+ """
37
+
38
+ JSON_SCHEMA = {
39
+ "name": "PurchaseInvoice",
40
+ "schema_definition": {
41
+ "$defs": {
42
+ "Explanation": {
43
+ "properties": {
44
+ "explanation": {
45
+ "title": "Explanation",
46
+ "type": "string",
47
+ },
48
+ "output": {"title": "Output", "type": "string"},
49
+ },
50
+ "required": ["explanation", "output"],
51
+ "title": "Explanation",
52
+ "type": "object",
53
+ "additionalProperties": False,
54
+ }
55
+ },
56
+ "properties": {
57
+ "steps": {
58
+ "items": {"$ref": "#/$defs/Explanation"},
59
+ "title": "Steps",
60
+ "type": "array",
61
+ },
62
+ "final_answer": {"title": "Final Answer", "type": "string"},
63
+ },
64
+ "required": ["steps", "final_answer"],
65
+ "title": "MathDemonstration",
66
+ "type": "object",
67
+ "additionalProperties": False,
68
+ },
69
+ "description": None,
70
+ "strict": True
71
+ }
72
+
73
+ def upload_pdf(client, content, filename):
74
+ """
75
+ Uploads a PDF to Mistral's API and retrieves a signed URL for processing.
76
+
77
+ Args:
78
+ client (Mistral): Mistral API client instance.
79
+ content (bytes): The content of the PDF file.
80
+ filename (str): The name of the PDF file.
81
+
82
+ Returns:
83
+ str: Signed URL for the uploaded PDF.
84
+ """
85
+ with tempfile.TemporaryDirectory() as temp_dir:
86
+ temp_path = os.path.join(temp_dir, filename)
87
+
88
+ with open(temp_path, "wb") as tmp:
89
+ tmp.write(content)
90
+
91
+ try:
92
+ with open(temp_path, "rb") as file_obj:
93
+ file_upload = client.files.upload(
94
+ file={"file_name": filename, "content": file_obj},
95
+ purpose="ocr"
96
+ )
97
+
98
+ signed_url = client.files.get_signed_url(file_id=file_upload.id)
99
+ return signed_url.url
100
+ finally:
101
+ if os.path.exists(temp_path):
102
+ os.remove(temp_path)
103
+
104
+ def extract_json_from_doc(client, document_source):
105
+ """
106
+ Extracts JSON data from a document using Mistral's OCR API.
107
+
108
+ Args:
109
+ client (Mistral): Mistral API client instance.
110
+ document_source (dict): The source of the document (URL or image).
111
+
112
+ Returns:
113
+ dict: The extracted JSON data.
114
+ """
115
+ # Specify model
116
+ model = "mistral-small-latest"
117
+
118
+ messages = [
119
+ {
120
+ "role": "system",
121
+ "content": SYSTEM_PROMPT,
122
+ },
123
+ {
124
+ "role": "user",
125
+ "content": [
126
+ {
127
+ "type": "text",
128
+ "text": "what is the last sentence in the document"
129
+ },
130
+ document_source
131
+ ]
132
+ }
133
+ ]
134
+
135
+ print(messages)
136
+
137
+ chat_response = client.chat.complete(
138
+ model=model,
139
+ messages=messages,
140
+ response_format = {
141
+ "type": "json_object" #, "json_schema": JSON_SCHEMA
142
+ }
143
+ )
144
+
145
+ print(chat_response.choices[0].message.content)
146
+
147
+ return chat_response.choices[0].message.content
148
+
149
+ def process_ocr(client, document_source):
150
+ """
151
+ Processes a document using Mistral's OCR API.
152
+
153
+ Args:
154
+ client (Mistral): Mistral API client instance.
155
+ document_source (dict): The source of the document (URL or image).
156
+
157
+ Returns:
158
+ OCRResponse: The response from Mistral's OCR API.
159
+ """
160
+ return client.ocr.process(
161
+ model="mistral-ocr-latest",
162
+ document=document_source,
163
+ include_image_base64=True
164
+ )
165
+
166
+ def display_pdf(file):
167
+ """
168
+ Displays a PDF in Streamlit using an iframe.
169
+
170
+ Args:
171
+ file (str): Path to the PDF file.
172
+ """
173
+ with open(file, "rb") as f:
174
+ base64_pdf = base64.b64encode(f.read()).decode("utf-8")
175
+ pdf_display = f'<iframe src="data:application/pdf;base64,{base64_pdf}" width="700" height="1000" type="application/pdf"></iframe>'
176
+ st.markdown(pdf_display, unsafe_allow_html=True)
177
+
178
+ def main():
179
+ """
180
+ Main function to run the Streamlit app.
181
+ """
182
+ st.set_page_config(page_title="OCR Facture avec Mistral", layout="wide")
183
+
184
+ # Sidebar: Authentication for Mistral API
185
+ if not MISTRAL_API_KEY:
186
+ api_key = st.sidebar.text_input("Mistral API Key", type="password")
187
+ else:
188
+ api_key = MISTRAL_API_KEY
189
+
190
+ if not api_key:
191
+ st.warning("Enter API key to continue")
192
+ return
193
+
194
+ # Initialize Mistral API client
195
+ client = Mistral(api_key=api_key)
196
+
197
+ # Main app interface
198
+ st.header("OCR Facture avec Mistral")
199
+
200
+ # Input method selection: URL, PDF Upload, or Image Upload
201
+ input_method = st.radio("Format de la facture:", ["URL", "PDF", "Image"])
202
+
203
+ document_source = None
204
+ preview_content = None
205
+ content_type = None
206
+
207
+ if input_method == "URL":
208
+ # Handle document URL input
209
+ url = st.text_input("Document URL:")
210
+ if url:
211
+ document_source = {
212
+ "type": "document_url",
213
+ "document_url": url
214
+ }
215
+ preview_content = url
216
+ content_type = "url"
217
+
218
+ elif input_method == "PDF":
219
+ # Handle PDF file upload
220
+ uploaded_file = st.file_uploader("Choisissez un PDF", type=["pdf"])
221
+ if uploaded_file:
222
+ content = uploaded_file.read()
223
+ preview_content = uploaded_file
224
+
225
+ # Save the uploaded PDF temporarily for display purposes
226
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
227
+ tmp.write(content)
228
+ pdf_path = tmp.name
229
+
230
+ display_pdf(pdf_path) # Display the uploaded PDF
231
+
232
+ # Prepare document source for OCR processing
233
+ document_source = {
234
+ "type": "document_url",
235
+ "document_url": upload_pdf(client, content, uploaded_file.name)
236
+ }
237
+ content_type = "pdf"
238
+
239
+ elif input_method == "Image":
240
+ # Handle image file upload
241
+ uploaded_image = st.file_uploader("Choisissez une image", type=["png", "jpg", "jpeg"])
242
+ if uploaded_image:
243
+ # Display the uploaded image
244
+ image = Image.open(uploaded_image)
245
+ st.image(image, caption="Uploaded Image", use_container_width=True)
246
+
247
+ # Convert image to base64
248
+ buffered = io.BytesIO()
249
+ image.save(buffered, format="PNG")
250
+ img_str = base64.b64encode(buffered.getvalue()).decode()
251
+
252
+ # Prepare document source for OCR processing
253
+ document_source = {
254
+ "type": "image_url",
255
+ "image_url": f"data:image/png;base64,{img_str}"
256
+ }
257
+ content_type = "image"
258
+
259
+ if document_source and st.button("Générer les données au format JSON"):
260
+ # Process the document when the user clicks the button
261
+ with st.spinner("Extracting JSON content..."):
262
+ try:
263
+ ocr_response = extract_json_from_doc(client, document_source)
264
+
265
+ with st.expander("Response"):
266
+ st.json(ocr_response)
267
+
268
+ except Exception as e:
269
+ # Display an error message if processing fails
270
+ st.error(f"Processing error: {str(e)}")
271
+
272
+ if document_source and st.button("Générer un Document"):
273
+ # Process the document when the user clicks the button
274
+ with st.spinner("Extracting content..."):
275
+ try:
276
+ ocr_response = process_ocr(client, document_source)
277
+
278
+ if ocr_response and ocr_response.pages:
279
+ # Combine extracted text from all pages into one string
280
+ extracted_content = "\n\n".join(
281
+ [f"**Page {i+1}**\n{page.markdown}"
282
+ for i, page in enumerate(ocr_response.pages)]
283
+ )
284
+
285
+ # Display extracted content in Markdown format
286
+ st.subheader("Extracted Content")
287
+ st.markdown(extracted_content)
288
+
289
+ # Prepare plain text version
290
+ plain_text_content = "\n\n".join(
291
+ [f"Page {i+1}\n{page.markdown}"
292
+ for i, page in enumerate(ocr_response.pages)]
293
+ )
294
+
295
+ # Add download buttons for both text and Markdown formats
296
+ col1, col2 = st.columns(2)
297
+ with col1:
298
+ st.download_button(
299
+ label="Download Text",
300
+ data=plain_text_content,
301
+ file_name="extracted_content.txt",
302
+ mime="text/plain"
303
+ )
304
+ with col2:
305
+ st.download_button(
306
+ label="Download Markdown",
307
+ data=extracted_content,
308
+ file_name="extracted_content.md",
309
+ mime="text/markdown"
310
+ )
311
+
312
+ # Optional: Show raw response for debugging purposes
313
+ with st.expander("Réponse API"):
314
+ st.json(ocr_response.model_dump())
315
+
316
+ else:
317
+ st.warning("No content extracted.")
318
+
319
+ except Exception as e:
320
+ # Display an error message if processing fails
321
+ st.error(f"Processing error: {str(e)}")
322
+
323
+ if __name__ == "__main__":
324
+ main()