Spaces:
Sleeping
Sleeping
Update utils/sdg_classifier.py
Browse files- utils/sdg_classifier.py +31 -31
utils/sdg_classifier.py
CHANGED
|
@@ -14,27 +14,27 @@ except ImportError:
|
|
| 14 |
logging.info("Streamlit not installed")
|
| 15 |
|
| 16 |
## Labels dictionary ###
|
| 17 |
-
_lab_dict = {0: '
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
|
| 36 |
@st.cache(allow_output_mutation=True)
|
| 37 |
-
def
|
| 38 |
"""
|
| 39 |
loads the document classifier using haystack, where the name/path of model
|
| 40 |
in HF-hub as string is used to fetch the model object. Either configfile or
|
|
@@ -57,7 +57,7 @@ def load_sdgClassifier(config_file:str = None, classifier_name:str = None):
|
|
| 57 |
return
|
| 58 |
else:
|
| 59 |
config = getconfig(config_file)
|
| 60 |
-
classifier_name = config.get('
|
| 61 |
|
| 62 |
logging.info("Loading classifier")
|
| 63 |
doc_classifier = TransformersDocumentClassifier(
|
|
@@ -68,7 +68,7 @@ def load_sdgClassifier(config_file:str = None, classifier_name:str = None):
|
|
| 68 |
|
| 69 |
|
| 70 |
@st.cache(allow_output_mutation=True)
|
| 71 |
-
def
|
| 72 |
threshold:float = 0.8,
|
| 73 |
classifier_model:TransformersDocumentClassifier= None
|
| 74 |
)->Tuple[DataFrame,Series]:
|
|
@@ -95,10 +95,10 @@ def sdg_classification(haystack_doc:List[Document],
|
|
| 95 |
the number of times it is covered/discussed/count_of_paragraphs.
|
| 96 |
|
| 97 |
"""
|
| 98 |
-
logging.info("Working on
|
| 99 |
if not classifier_model:
|
| 100 |
if check_streamlit():
|
| 101 |
-
classifier_model = st.session_state['
|
| 102 |
else:
|
| 103 |
logging.warning("No streamlit envinornment found, Pass the classifier")
|
| 104 |
return
|
|
@@ -109,23 +109,23 @@ def sdg_classification(haystack_doc:List[Document],
|
|
| 109 |
labels_= [(l.meta['classification']['label'],
|
| 110 |
l.meta['classification']['score'],l.content,) for l in results]
|
| 111 |
|
| 112 |
-
df = DataFrame(labels_, columns=["
|
| 113 |
|
| 114 |
df = df.sort_values(by="Relevancy", ascending=False).reset_index(drop=True)
|
| 115 |
df.index += 1
|
| 116 |
df =df[df['Relevancy']>threshold]
|
| 117 |
|
| 118 |
# creating the dataframe for value counts of SDG, along with 'title' of SDGs
|
| 119 |
-
x = df['
|
| 120 |
x = x.rename('count')
|
| 121 |
-
x = x.rename_axis('
|
| 122 |
-
x["
|
| 123 |
x = x.sort_values(by=['count'], ascending=False)
|
| 124 |
-
x['SDG_name'] = x['
|
| 125 |
-
x['SDG_Num'] = x['
|
| 126 |
|
| 127 |
-
df['
|
| 128 |
-
df = df.sort_values('
|
| 129 |
|
| 130 |
return df, x
|
| 131 |
|
|
|
|
| 14 |
logging.info("Streamlit not installed")
|
| 15 |
|
| 16 |
## Labels dictionary ###
|
| 17 |
+
_lab_dict = {0: 'Agricultural communities',
|
| 18 |
+
1: 'Children',
|
| 19 |
+
2: 'Coastal communities',
|
| 20 |
+
3: 'Ethnic, racial or other minorities',
|
| 21 |
+
4: 'Fishery communities',
|
| 22 |
+
5: 'Informal sector workers',
|
| 23 |
+
6: 'Members of indigenous and local communities',
|
| 24 |
+
7: 'Migrants and displaced persons',
|
| 25 |
+
8: 'Older persons',
|
| 26 |
+
9: 'Other',
|
| 27 |
+
10: 'Persons living in poverty',
|
| 28 |
+
11: 'Persons with disabilities',
|
| 29 |
+
12: 'Persons with pre-existing health conditions',
|
| 30 |
+
13: 'Residents of drought-prone regions',
|
| 31 |
+
14: 'Rural populations',
|
| 32 |
+
15: 'Sexual minorities (LGBTQI+)',
|
| 33 |
+
16: 'Urban populations',
|
| 34 |
+
17: 'Women and other genders'}
|
| 35 |
|
| 36 |
@st.cache(allow_output_mutation=True)
|
| 37 |
+
def load_Classifier(config_file:str = None, classifier_name:str = None):
|
| 38 |
"""
|
| 39 |
loads the document classifier using haystack, where the name/path of model
|
| 40 |
in HF-hub as string is used to fetch the model object. Either configfile or
|
|
|
|
| 57 |
return
|
| 58 |
else:
|
| 59 |
config = getconfig(config_file)
|
| 60 |
+
classifier_name = config.get('vulnerability','MODEL')
|
| 61 |
|
| 62 |
logging.info("Loading classifier")
|
| 63 |
doc_classifier = TransformersDocumentClassifier(
|
|
|
|
| 68 |
|
| 69 |
|
| 70 |
@st.cache(allow_output_mutation=True)
|
| 71 |
+
def classification(haystack_doc:List[Document],
|
| 72 |
threshold:float = 0.8,
|
| 73 |
classifier_model:TransformersDocumentClassifier= None
|
| 74 |
)->Tuple[DataFrame,Series]:
|
|
|
|
| 95 |
the number of times it is covered/discussed/count_of_paragraphs.
|
| 96 |
|
| 97 |
"""
|
| 98 |
+
logging.info("Working on Vulnerability Classification")
|
| 99 |
if not classifier_model:
|
| 100 |
if check_streamlit():
|
| 101 |
+
classifier_model = st.session_state['vulnerability_classifier']
|
| 102 |
else:
|
| 103 |
logging.warning("No streamlit envinornment found, Pass the classifier")
|
| 104 |
return
|
|
|
|
| 109 |
labels_= [(l.meta['classification']['label'],
|
| 110 |
l.meta['classification']['score'],l.content,) for l in results]
|
| 111 |
|
| 112 |
+
df = DataFrame(labels_, columns=["Vulnerability","Relevancy","text"])
|
| 113 |
|
| 114 |
df = df.sort_values(by="Relevancy", ascending=False).reset_index(drop=True)
|
| 115 |
df.index += 1
|
| 116 |
df =df[df['Relevancy']>threshold]
|
| 117 |
|
| 118 |
# creating the dataframe for value counts of Vulnerability labels, along with the 'title' of each label
|
| 119 |
+
x = df['Vulnerability'].value_counts()
|
| 120 |
x = x.rename('count')
|
| 121 |
+
x = x.rename_axis('Vulnerability').reset_index()
|
| 122 |
+
x["Vulnerability"] = pd.to_numeric(x["Vulnerability"])
|
| 123 |
x = x.sort_values(by=['count'], ascending=False)
|
| 124 |
+
x['SDG_name'] = x['Vulnerability'].apply(lambda x: _lab_dict[x])
|
| 125 |
+
x['SDG_Num'] = x['Vulnerability'].apply(lambda x: "Vulnerability "+str(x))
|
| 126 |
|
| 127 |
+
df['Vulnerability'] = pd.to_numeric(df['Vulnerability'])
|
| 128 |
+
df = df.sort_values('Vulnerability')
|
| 129 |
|
| 130 |
return df, x
|
| 131 |
|