Spaces:
Build error
Build error
meg-huggingface
commited on
Commit
·
a2ae370
1
Parent(s):
335424f
More modularizing; npmi and labels
Browse files- app.py +5 -12
- data_measurements/dataset_statistics.py +20 -20
- data_measurements/streamlit_utils.py +4 -5
app.py
CHANGED
|
@@ -118,9 +118,8 @@ def load_or_prepare(ds_args, show_embeddings, use_cache=False):
|
|
| 118 |
if show_embeddings:
|
| 119 |
logs.warning("Loading Embeddings")
|
| 120 |
dstats.load_or_prepare_embeddings()
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
dstats.load_or_prepare_npmi_terms()
|
| 124 |
logs.warning("Loading Zipf")
|
| 125 |
dstats.load_or_prepare_zipf()
|
| 126 |
return dstats
|
|
@@ -156,6 +155,8 @@ def load_or_prepare_widgets(ds_args, show_embeddings, use_cache=False):
|
|
| 156 |
# Embeddings widget
|
| 157 |
dstats.load_or_prepare_embeddings()
|
| 158 |
dstats.load_or_prepare_text_duplicates()
|
|
|
|
|
|
|
| 159 |
|
| 160 |
def show_column(dstats, ds_name_to_dict, show_embeddings, column_id, use_cache=True):
|
| 161 |
"""
|
|
@@ -179,17 +180,9 @@ def show_column(dstats, ds_name_to_dict, show_embeddings, column_id, use_cache=T
|
|
| 179 |
st_utils.expander_label_distribution(dstats.fig_labels, column_id)
|
| 180 |
st_utils.expander_text_lengths(dstats, column_id)
|
| 181 |
st_utils.expander_text_duplicates(dstats, column_id)
|
| 182 |
-
|
| 183 |
-
# We do the loading of these after the others in order to have some time
|
| 184 |
-
# to compute while the user works with the details above.
|
| 185 |
# Uses an interaction; handled a bit differently than other widgets.
|
| 186 |
logs.info("showing npmi widget")
|
| 187 |
-
npmi_stats
|
| 188 |
-
dstats, use_cache=use_cache
|
| 189 |
-
)
|
| 190 |
-
available_terms = npmi_stats.get_available_terms()
|
| 191 |
-
st_utils.npmi_widget(
|
| 192 |
-
column_id, available_terms, npmi_stats, _MIN_VOCAB_COUNT)
|
| 193 |
logs.info("showing zipf")
|
| 194 |
st_utils.expander_zipf(dstats.z, dstats.zipf_fig, column_id)
|
| 195 |
if show_embeddings:
|
|
|
|
| 118 |
if show_embeddings:
|
| 119 |
logs.warning("Loading Embeddings")
|
| 120 |
dstats.load_or_prepare_embeddings()
|
| 121 |
+
logs.warning("Loading nPMI")
|
| 122 |
+
dstats.load_or_prepare_npmi()
|
|
|
|
| 123 |
logs.warning("Loading Zipf")
|
| 124 |
dstats.load_or_prepare_zipf()
|
| 125 |
return dstats
|
|
|
|
| 155 |
# Embeddings widget
|
| 156 |
dstats.load_or_prepare_embeddings()
|
| 157 |
dstats.load_or_prepare_text_duplicates()
|
| 158 |
+
dstats.load_or_prepare_npmi()
|
| 159 |
+
dstats.load_or_prepare_zipf()
|
| 160 |
|
| 161 |
def show_column(dstats, ds_name_to_dict, show_embeddings, column_id, use_cache=True):
|
| 162 |
"""
|
|
|
|
| 180 |
st_utils.expander_label_distribution(dstats.fig_labels, column_id)
|
| 181 |
st_utils.expander_text_lengths(dstats, column_id)
|
| 182 |
st_utils.expander_text_duplicates(dstats, column_id)
|
|
|
|
|
|
|
|
|
|
| 183 |
# Uses an interaction; handled a bit differently than other widgets.
|
| 184 |
logs.info("showing npmi widget")
|
| 185 |
+
st_utils.npmi_widget(dstats.npmi_stats, _MIN_VOCAB_COUNT, column_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
logs.info("showing zipf")
|
| 187 |
st_utils.expander_zipf(dstats.z, dstats.zipf_fig, column_id)
|
| 188 |
if show_embeddings:
|
data_measurements/dataset_statistics.py
CHANGED
|
@@ -231,10 +231,6 @@ class DatasetStatisticsCacheClass:
|
|
| 231 |
# nPMI
|
| 232 |
# Holds a nPMIStatisticsCacheClass object
|
| 233 |
self.npmi_stats = None
|
| 234 |
-
# TODO: Users ideally can type in whatever words they want.
|
| 235 |
-
self.termlist = _IDENTITY_TERMS
|
| 236 |
-
# termlist terms that are available more than _MIN_VOCAB_COUNT times
|
| 237 |
-
self.available_terms = _IDENTITY_TERMS
|
| 238 |
# TODO: Have lowercase be an option for a user to set.
|
| 239 |
self.to_lowercase = True
|
| 240 |
# The minimum amount of times a word should occur to be included in
|
|
@@ -627,24 +623,27 @@ class DatasetStatisticsCacheClass:
|
|
| 627 |
if save:
|
| 628 |
write_plotly(self.fig_labels, self.fig_labels_fid)
|
| 629 |
else:
|
| 630 |
-
self.
|
| 631 |
-
self.label_dset = self.dset.map(
|
| 632 |
-
lambda examples: extract_field(
|
| 633 |
-
examples, self.label_field, OUR_LABEL_FIELD
|
| 634 |
-
),
|
| 635 |
-
batched=True,
|
| 636 |
-
remove_columns=list(self.dset.features),
|
| 637 |
-
)
|
| 638 |
-
self.label_df = self.label_dset.to_pandas()
|
| 639 |
-
self.fig_labels = make_fig_labels(
|
| 640 |
-
self.label_df, self.label_names, OUR_LABEL_FIELD
|
| 641 |
-
)
|
| 642 |
if save:
|
| 643 |
# save extracted label instances
|
| 644 |
self.label_dset.save_to_disk(self.label_dset_fid)
|
| 645 |
write_plotly(self.fig_labels, self.fig_labels_fid)
|
| 646 |
|
| 647 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 648 |
self.npmi_stats = nPMIStatisticsCacheClass(self, use_cache=self.use_cache)
|
| 649 |
self.npmi_stats.load_or_prepare_npmi_terms()
|
| 650 |
|
|
@@ -693,7 +692,10 @@ class nPMIStatisticsCacheClass:
|
|
| 693 |
# We need to preprocess everything.
|
| 694 |
mkdir(self.pmi_cache_path)
|
| 695 |
self.joint_npmi_df_dict = {}
|
| 696 |
-
|
|
|
|
|
|
|
|
|
|
| 697 |
logs.info(self.termlist)
|
| 698 |
self.use_cache = use_cache
|
| 699 |
# TODO: Let users specify
|
|
@@ -701,8 +703,6 @@ class nPMIStatisticsCacheClass:
|
|
| 701 |
self.min_vocab_count = self.dstats.min_vocab_count
|
| 702 |
self.subgroup_files = {}
|
| 703 |
self.npmi_terms_fid = pjoin(self.dstats.cache_path, "npmi_terms.json")
|
| 704 |
-
self.available_terms = self.dstats.available_terms
|
| 705 |
-
logs.info(self.available_terms)
|
| 706 |
|
| 707 |
def load_or_prepare_npmi_terms(self):
|
| 708 |
"""
|
|
|
|
| 231 |
# nPMI
|
| 232 |
# Holds a nPMIStatisticsCacheClass object
|
| 233 |
self.npmi_stats = None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
# TODO: Have lowercase be an option for a user to set.
|
| 235 |
self.to_lowercase = True
|
| 236 |
# The minimum amount of times a word should occur to be included in
|
|
|
|
| 623 |
if save:
|
| 624 |
write_plotly(self.fig_labels, self.fig_labels_fid)
|
| 625 |
else:
|
| 626 |
+
self.prepare_labels()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 627 |
if save:
|
| 628 |
# save extracted label instances
|
| 629 |
self.label_dset.save_to_disk(self.label_dset_fid)
|
| 630 |
write_plotly(self.fig_labels, self.fig_labels_fid)
|
| 631 |
|
| 632 |
+
def prepare_labels(self):
|
| 633 |
+
self.get_base_dataset()
|
| 634 |
+
self.label_dset = self.dset.map(
|
| 635 |
+
lambda examples: extract_field(
|
| 636 |
+
examples, self.label_field, OUR_LABEL_FIELD
|
| 637 |
+
),
|
| 638 |
+
batched=True,
|
| 639 |
+
remove_columns=list(self.dset.features),
|
| 640 |
+
)
|
| 641 |
+
self.label_df = self.label_dset.to_pandas()
|
| 642 |
+
self.fig_labels = make_fig_labels(
|
| 643 |
+
self.label_df, self.label_names, OUR_LABEL_FIELD
|
| 644 |
+
)
|
| 645 |
+
|
| 646 |
+
def load_or_prepare_npmi(self):
|
| 647 |
self.npmi_stats = nPMIStatisticsCacheClass(self, use_cache=self.use_cache)
|
| 648 |
self.npmi_stats.load_or_prepare_npmi_terms()
|
| 649 |
|
|
|
|
| 692 |
# We need to preprocess everything.
|
| 693 |
mkdir(self.pmi_cache_path)
|
| 694 |
self.joint_npmi_df_dict = {}
|
| 695 |
+
# TODO: Users ideally can type in whatever words they want.
|
| 696 |
+
self.termlist = _IDENTITY_TERMS
|
| 697 |
+
# termlist terms that are available more than _MIN_VOCAB_COUNT times
|
| 698 |
+
self.available_terms = _IDENTITY_TERMS
|
| 699 |
logs.info(self.termlist)
|
| 700 |
self.use_cache = use_cache
|
| 701 |
# TODO: Let users specify
|
|
|
|
| 703 |
self.min_vocab_count = self.dstats.min_vocab_count
|
| 704 |
self.subgroup_files = {}
|
| 705 |
self.npmi_terms_fid = pjoin(self.dstats.cache_path, "npmi_terms.json")
|
|
|
|
|
|
|
| 706 |
|
| 707 |
def load_or_prepare_npmi_terms(self):
|
| 708 |
"""
|
data_measurements/streamlit_utils.py
CHANGED
|
@@ -273,7 +273,6 @@ def expander_text_duplicates(dstats, column_id):
|
|
| 273 |
st.write(
|
| 274 |
"### Here is the list of all the duplicated items and their counts in your dataset:"
|
| 275 |
)
|
| 276 |
-
# Eh...adding 1 because otherwise it looks too weird for duplicate counts when the value is just 1.
|
| 277 |
if dstats.dup_counts_df is None:
|
| 278 |
st.write("There are no duplicates in this dataset! 🥳")
|
| 279 |
else:
|
|
@@ -393,7 +392,7 @@ with an ideal α value of 1."""
|
|
| 393 |
|
| 394 |
|
| 395 |
### Finally finally finally, show nPMI stuff.
|
| 396 |
-
def npmi_widget(
|
| 397 |
"""
|
| 398 |
Part of the main app, but uses a user interaction so pulled out as its own f'n.
|
| 399 |
:param use_cache:
|
|
@@ -403,16 +402,16 @@ def npmi_widget(column_id, available_terms, npmi_stats, min_vocab):
|
|
| 403 |
:return:
|
| 404 |
"""
|
| 405 |
with st.expander(f"Word Association{column_id}: nPMI", expanded=False):
|
| 406 |
-
if len(available_terms) > 0:
|
| 407 |
expander_npmi_description(min_vocab)
|
| 408 |
st.markdown("-----")
|
| 409 |
term1 = st.selectbox(
|
| 410 |
f"What is the first term you want to select?{column_id}",
|
| 411 |
-
available_terms,
|
| 412 |
)
|
| 413 |
term2 = st.selectbox(
|
| 414 |
f"What is the second term you want to select?{column_id}",
|
| 415 |
-
reversed(available_terms),
|
| 416 |
)
|
| 417 |
# We calculate/grab nPMI data based on a canonical (alphabetic)
|
| 418 |
# subgroup ordering.
|
|
|
|
| 273 |
st.write(
|
| 274 |
"### Here is the list of all the duplicated items and their counts in your dataset:"
|
| 275 |
)
|
|
|
|
| 276 |
if dstats.dup_counts_df is None:
|
| 277 |
st.write("There are no duplicates in this dataset! 🥳")
|
| 278 |
else:
|
|
|
|
| 392 |
|
| 393 |
|
| 394 |
### Finally finally finally, show nPMI stuff.
|
| 395 |
+
def npmi_widget(npmi_stats, min_vocab, column_id):
|
| 396 |
"""
|
| 397 |
Part of the main app, but uses a user interaction so pulled out as its own f'n.
|
| 398 |
:param use_cache:
|
|
|
|
| 402 |
:return:
|
| 403 |
"""
|
| 404 |
with st.expander(f"Word Association{column_id}: nPMI", expanded=False):
|
| 405 |
+
if len(npmi_stats.available_terms) > 0:
|
| 406 |
expander_npmi_description(min_vocab)
|
| 407 |
st.markdown("-----")
|
| 408 |
term1 = st.selectbox(
|
| 409 |
f"What is the first term you want to select?{column_id}",
|
| 410 |
+
npmi_stats.available_terms,
|
| 411 |
)
|
| 412 |
term2 = st.selectbox(
|
| 413 |
f"What is the second term you want to select?{column_id}",
|
| 414 |
+
reversed(npmi_stats.available_terms),
|
| 415 |
)
|
| 416 |
# We calculate/grab nPMI data based on a canonical (alphabetic)
|
| 417 |
# subgroup ordering.
|