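# Streamlit app for inspecting an authorship-attribution model's outputs:
# it shows each "bad" query, the top recommended candidates with word-level
# similarity highlighting (via NLTK's word2vec sample), and the ground-truth
# candidates written by the same author, with cosine distances throughout.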
import streamlit as st
import pandas as pd
import sys
import os
from datasets import load_from_disk
# from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import time
from annotated_text import annotated_text
from nltk.data import find
import nltk
import gensim

ABSOLUTE_PATH = os.path.dirname(__file__)
ASSETS_PATH = os.path.join(ABSOLUTE_PATH, 'model_assets')

# Caching keeps the word2vec model to a single load per session
# (st.cache_resource, available in recent Streamlit releases).
@st.cache_resource
def get_embed_model():
    # Download and load the pruned word2vec sample shipped with NLTK.
    nltk.download("word2vec_sample")
    word2vec_sample = str(find('models/word2vec_sample/pruned.word2vec.txt'))
    model = gensim.models.KeyedVectors.load_word2vec_format(word2vec_sample, binary=False)
    return model

def get_top_n_closest(query_word, candidate, n):
    # Score every word in the candidate text against the query word and
    # return up to n of them, most similar first.
    model = get_embed_model()
    p_c = preprocess_text(candidate)
    similarity = []
    for word in p_c:
        try:
            similarity.append(model.similarity(query_word, word))
        except KeyError:  # word is not in the word2vec vocabulary
            similarity.append(0)
    top_n = min(len(p_c), n)
    # argsort of the negated scores gives indices in descending-similarity order
    top_indices = (-np.array(similarity)).argsort()[:top_n]
    return [p_c[i] for i in top_indices]

def annotate_text(text, words):
    # Split `text` into segments, wrapping each found occurrence of the given
    # words in a ('word', 'SIMILAR') tuple for annotated_text to render.
    annotated = [text]
    for word in words:
        for i in range(len(annotated)):
            if not isinstance(annotated[i], str):
                continue  # already-annotated segments are tuples; skip them
            string = annotated[i]
            try:
                index = string.index(word)
            except ValueError:  # word not present in this segment
                continue
            first = string[:index]
            second = (string[index:index + len(word)], 'SIMILAR')
            third = string[index + len(word):]
            annotated = annotated[:i] + [first, second, third] + annotated[i + 1:]
    return tuple(annotated)

def preprocess_text(s):
    # Replace every non-alphanumeric character with a space, then split on
    # whitespace; split() with no argument also drops empty strings.
    return ''.join(c if c.isalnum() else ' ' for c in s).split()
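
# Example: preprocess_text("Lol, this guy!") returns ['Lol', 'this', 'guy'].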

def get_pairwise_distances(model):
    df = pd.read_csv(f"{ASSETS_PATH}/{model}/pairwise_distances.csv").set_index('index')
    return df

def get_pairwise_distances_chunked(model, chunk):
    # Chunked reading is disabled for now; fall back to the full table.
    # for df in pd.read_csv(f"{ASSETS_PATH}/{model}/pairwise_distances.csv", chunksize=16):
    #     print(df.iloc[0]['queries'])
    #     if chunk == int(df.iloc[0]['queries']):
    #         return df
    return get_pairwise_distances(model)

def get_query_strings():
    df = pd.read_json(f"{ASSETS_PATH}/IUR_Reddit_test_queries_english.jsonl", lines=True)
    df['index'] = df.reset_index().index
    return df
    # Parquet-backed alternative, kept for reference:
    # df['partition'] = df['index'] % 100
    # df.to_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_queries_english.parquet", index='index', partition_cols='partition')
    # return pd.read_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_queries_english.parquet", columns=['fullText', 'index', 'authorIDs'])

def get_candidate_strings():
    df = pd.read_json(f"{ASSETS_PATH}/IUR_Reddit_test_candidates_english.jsonl", lines=True)
    # Keep 'index' both as a column and as the frame's index for .at lookups.
    df['i'] = df['index']
    df = df.set_index('i')
    # df['index'] = df.reset_index().index
    return df
    # Parquet-backed alternative, kept for reference:
    # df['partition'] = df['index'] % 100
    # df.to_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_candidates_english.parquet", index='index', partition_cols='partition')
    # return pd.read_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_candidates_english.parquet", columns=['fullText', 'index', 'authorIDs'])

def get_embedding_dataset(model):
    # Embeddings are stored as a Hugging Face dataset saved with save_to_disk.
    data = load_from_disk(f"{ASSETS_PATH}/{model}/embedding")
    return data

def get_bad_queries(model):
    # Queries that appear in this model's pairwise-distance table.
    df = get_query_strings().iloc[
        list(get_pairwise_distances(model)['queries'].unique())
    ][['fullText', 'index', 'authorIDs']]
    return df

def get_gt_candidates(model, author):
    gt_candidates = get_candidate_strings()
    df = gt_candidates[gt_candidates['authorIDs'] == author]
    return df

def get_candidate_text(l):
    return get_candidate_strings().at[l, 'fullText']

def get_annotated_text(text, word, pos):
    # Find `word` at or after character position `pos` and wrap it in a
    # ('word', 'SELECTED') tuple; also return where the match ends.
    start = text.index(word, pos)
    end = start + len(word)
    return (text[:start], (text[start:end], 'SELECTED'), text[end:]), end
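
# Note: get_annotated_text marks the single SELECTED query word, while
# annotate_text above marks SIMILAR candidate words; both emit the segment
# tuples that the annotated_text() renderer expects.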

# class AgGridBuilder:
#     __static_key = 0
#
#     def build_ag_grid(table, display_columns):
#         AgGridBuilder.__static_key += 1
#         options_builder = GridOptionsBuilder.from_dataframe(table[display_columns])
#         options_builder.configure_pagination(paginationAutoPageSize=False, paginationPageSize=10)
#         options_builder.configure_selection(selection_mode='single', pre_selected_rows=[0])
#         options = options_builder.build()
#         return AgGrid(table, gridOptions=options, fit_columns_on_grid_load=True, key=AgGridBuilder.__static_key, reload_data=True, update_mode=GridUpdateMode.SELECTION_CHANGED | GridUpdateMode.VALUE_CHANGED)
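
# The AgGrid-based tables are disabled (the st_aggrid import above is
# commented out); row selection is handled with st.number_input widgets below.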

if __name__ == '__main__':
    st.set_page_config(layout="wide")

    # Every subdirectory of model_assets (parquet files excluded) is a model.
    models = filter(
        lambda file_name: os.path.isdir(f"{ASSETS_PATH}/{file_name}") and not file_name.endswith(".parquet"),
        os.listdir(ASSETS_PATH),
    )

    with st.sidebar:
        current_model = st.selectbox(
            "Select Model to analyze",
            models,
        )

    pairwise_distances = get_pairwise_distances(current_model)
    embedding_dataset = get_embedding_dataset(current_model)

    candidate_string_grid = None
    gt_candidate_string_grid = None

    with st.container():
        t1 = time.time()
        st.title("Full Text")
        col1, col2 = st.columns([14, 2])
        t2 = time.time()
        query_table = get_bad_queries(current_model)
        t3 = time.time()
        print(query_table)
        with col2:
            # `index` is a positional row number into query_table, so use iloc
            # consistently (the frame's label index is non-contiguous).
            index = st.number_input('Enter query number to inspect', min_value=0, max_value=query_table.shape[0] - 1, step=1)
            query_text = query_table.iloc[index]['fullText']
            preprocessed_query_text = preprocess_text(query_text)
            # 0 means "no word selected"; words are numbered from 1.
            text_highlight_index = st.number_input('Enter word #', min_value=0, max_value=len(preprocessed_query_text), step=1)
            query_index = int(query_table.iloc[index]['index'])
        with col1:
            # Track the highlighted word's position across reruns so repeated
            # words advance through the text instead of re-matching the first hit.
            if 'pos_highlight' not in st.session_state or text_highlight_index == 0:
                st.session_state['pos_highlight'] = text_highlight_index
                st.session_state['pos_history'] = [0]
            if st.session_state['pos_highlight'] > text_highlight_index:
                # Stepping backwards: drop the positions we have moved past.
                st.session_state['pos_history'] = st.session_state['pos_history'][:-2]
                if len(st.session_state['pos_history']) == 0:
                    st.session_state['pos_history'] = [0]
            print("pos", st.session_state['pos_history'], st.session_state['pos_highlight'], text_highlight_index)
            annotated_text_, pos = get_annotated_text(query_text, preprocessed_query_text[text_highlight_index - 1], st.session_state['pos_history'][-1]) if text_highlight_index >= 1 else ((query_text,), 0)
            if st.session_state['pos_highlight'] < text_highlight_index:
                st.session_state['pos_history'].append(pos)
            st.session_state['pos_highlight'] = text_highlight_index
            annotated_text(*annotated_text_)
            # annotated_text("Lol, this", ('guy', 'SELECTED'), "is such a PR chameleon. \n\n In the Chan Zuckerberg Initiative announcement, he made it sound like he was giving away all his money to charity <PERSON> or <PERSON>. http://www.businessinsider.in/Mark-Zuckerberg-says-hes-giving-99-of-his-Facebook-shares-45-billion-to-charity/articleshow/50005321.cms Apparently, its just a VC fund. And there are still people out there who believe Facebook.org was an initiative to bring Internet to the poor.")
        t4 = time.time()
        print(f"query time query text: {t3-t2}, total time: {t4-t1}")

    with st.container():
        st.title("Top 16 Recommended Candidates")
        col1, col2, col3 = st.columns([10, 4, 2])
        rec_candidates = pairwise_distances[pairwise_distances["queries"] == query_index]['candidates']
        print(rec_candidates)
        l = list(rec_candidates)
        with col3:
            candidate_rec_index = st.number_input('Enter recommended candidate number to inspect', min_value=0, max_value=len(l) - 1, step=1)
            print("l:", l, query_index)
            pairwise_candidate_index = int(l[candidate_rec_index])
        with col1:
            st.header("Text")
            t1 = time.time()
            candidate_text = get_candidate_text(pairwise_candidate_index)
            if st.session_state['pos_highlight'] == 0:
                annotated_text(candidate_text)
            else:
                # Highlight the candidate words closest to the selected query word.
                top_n_words_to_highlight = get_top_n_closest(preprocessed_query_text[text_highlight_index - 1], candidate_text, 4)
                print("TOPN", top_n_words_to_highlight)
                annotated_text(*annotate_text(candidate_text, top_n_words_to_highlight))
            t2 = time.time()
        with col2:
            st.header("Cosine Distance")
            distance = pairwise_distances[
                (pairwise_distances['queries'] == query_index)
                & (pairwise_distances['candidates'] == pairwise_candidate_index)
            ]['distances'].iloc[0]
            st.write(float(distance))
        print(f"candidate string retrieval: {t2-t1}")

    with st.container():
        t1 = time.time()
        st.title("Candidates With Same Authors As Query")
        col1, col2, col3 = st.columns([10, 4, 2])
        t2 = time.time()
        # Ground-truth candidates share the selected query's (first) author;
        # note this uses the positional row `index`, not the global query_index.
        gt_candidates = get_gt_candidates(current_model, query_table.iloc[index]['authorIDs'][0])
        t3 = time.time()
        with col3:
            candidate_index = st.number_input('Enter ground truth number to inspect', min_value=0, max_value=gt_candidates.shape[0] - 1, step=1)
            print(gt_candidates.head())
            gt_candidate_index = int(gt_candidates.iloc[candidate_index]['index'])
        with col1:
            st.header("Text")
            st.write(gt_candidates.iloc[candidate_index]['fullText'])
        with col2:
            t4 = time.time()
            st.header("Cosine Distance")
            # Map the candidate's global index to its row in the embedding
            # dataset, then report 1 - cosine similarity as a distance.
            indices = list(embedding_dataset['candidates']['index'])
            query_embedding = np.array([embedding_dataset['queries'][query_index]['embedding']])
            candidate_embedding = np.array([embedding_dataset['candidates'][indices.index(gt_candidate_index)]['embedding']])
            st.write(1 - cosine_similarity(query_embedding, candidate_embedding)[0, 0])
            t5 = time.time()
        print(f"find gt candidates: {t3-t2}, find cosine: {t5-t4}, total: {t5-t1}")