Spaces:
Build error
Build error
| import pandas as pd | |
| import numpy as np | |
| import logging | |
| from datasets.utils import set_progress_bar_enabled | |
| from weakly_supervised_parser.utils.prepare_dataset import NGramify | |
| from weakly_supervised_parser.utils.create_inside_outside_strings import InsideOutside | |
| from weakly_supervised_parser.model.trainer import InsideOutsideStringPredictor | |
| from weakly_supervised_parser.utils.cky_algorithm import get_best_parse | |
| from weakly_supervised_parser.utils.distant_supervision import RuleBasedHeuristic | |
| from weakly_supervised_parser.utils.prepare_dataset import PTBDataset | |
| from weakly_supervised_parser.settings import PTB_TRAIN_SENTENCES_WITHOUT_PUNCTUATION_PATH | |
| from weakly_supervised_parser.model.data_module_loader import DataModule | |
| from weakly_supervised_parser.model.span_classifier import LightningModel | |
# Keep `Dataset.map` from drawing a progress bar.
set_progress_bar_enabled(False)
# Only let warnings (and above) through from the PyTorch Lightning logger.
logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)

# Hard-coded, lower-cased token list used as the stop-word vocabulary below.
# The disabled expression kept here shows how it would be rebuilt from the
# PTB training sentences (presumably how this literal was produced -- confirm):
#   ptb = PTBDataset(data_path=PTB_TRAIN_SENTENCES_WITHOUT_PUNCTUATION_PATH)
#   ptb_top_100_common = [item.lower() for item in RuleBasedHeuristic(corpus=ptb.retrieve_all_sentences()).get_top_tokens(top_most_common_ptb=100)]
# The literal contains repeated entries; they are kept exactly as found.
ptb_top_100_common = [
    'this', 'myself', 'shouldn', 'not', 'analysts', 'same', 'mightn', 'we', 'american', 'the',
    'another', 'until', "aren't", 'when', 'if', 'am', 'over', 'ma', 'as', 'of',
    'with', 'even', 'couldn', 'not', "needn't", 'where', 'there', 'isn', 'however', 'my',
    'sales', 'here', 'at', 'yours', 'into', 'wouldn', 'officials', 'no', "hasn't", 'to',
    'wasn', 'any', 'ours', 'out', 'each', "wasn't", 'is', 'and', 'me', 'off',
    'once', "it's", 'they', 'most', 'also', 'through', 'hasn', 'our', 'or', 'after',
    "weren't", 'about', 'mr.', 'first', 'haven', 'needn', 'have', "isn't", 'now', "didn't",
    'on', 'theirs', 'these', 'before', 'there', 'was', 'which', 'those', 'having', 'do',
    'most', 'own', 'among', 'because', 'for', "should've", "shan't", 'so', 'being', 'few',
    'too', 'to', 'at', 'people', 'her', 'meanwhile', 'both', 'down', 'doesn', 'below',
    'mustn', 'an', 'two', 'more', 'japanese', 'ford', "you'd", 'about', 'but', 'doing',
    'itself', 've', 'under', 'what', 'again', 'then', 'your', 'himself', 'now', 'against',
    'just', 'does', 'net', "couldn't", 'that', 'he', 'revenue', 'because', 'yesterday', 'them',
    'i', 'their', 'all', 'under', 'up', "haven't", 'while', "won't", 'it', 'more',
    'it', 'ain', 'him', 'still', 'a', 'he', 'despite', 'should', 'during', 'nor',
    "shouldn't", 'such', "doesn't", 'are', "that'll", 'since', 'yourselves', 'such', 'those', 'after',
    'weren', "you're", 'd', 'like', 'did', 'hadn', 'themselves', 'its', 'but', 'been',
    's', "don't", 'these', 'they', 'this', 'his', "mightn't", 'moreover', 'how', 'new',
    'above', 'ourselves', 'so', 'why', 'between', 'their', 'general', "wouldn't", 'who', 'i',
    'in', 'don', 'shan', 'u.s.', 'ibm', 'separately', 'had', 'you', 'federal', 'if',
    'our', 'and', 'only', 'y', 'many', 'one', 'no', 'though', 'won', 'last',
    'from', 'each', 'traders', 'john', 'further', 'hers', 'both', "you've", "you'll", 'that',
    'all', 'its', 'only', 'here', 'according', "mustn't", 'while', 'in', 'what', 'didn',
    'when', 'some', 'on', 'can',
    'yourself', 'herself', 'than', 'with', 'has', 'she', 'during', 'will', 'of', 'thus',
    'you', 'very', 'o', 'investors', 'a', 'ms.', 'japan', 'were', 'the', 'we',
    'm', 'as', 'll', 'be', 'by', 'other', 'yet', 'whom', 'some', 'indeed',
    'other', "she's", "hadn't", 'by', 'earlier', 'for', 'instead', 'she', 'an', 't',
    're', 'his', 'then', 'aren', 'although',
]

# Hard-coded sentence-initial token. The disabled expression shows the
# corpus-based computation it stands in for:
#   ptb_most_common_first_token = RuleBasedHeuristic(corpus=ptb.retrieve_all_sentences()).augment_using_most_frequent_starting_token(N=1)[0][0].lower()
ptb_most_common_first_token = "the"

from pytorch_lightning import Trainer

# Module-level trainer shared by the span-scoring code further down.
trainer = Trainer(
    accelerator="auto",
    enable_progress_bar=False,
    max_epochs=-1,
)
class PopulateCKYChart:
    """Score the spans of one sentence and lay the scores out as a CKY chart.

    Attributes set by ``__init__``:
        sentence: the raw, whitespace-separated sentence string.
        sentence_list: ``sentence.split()``.
        sentence_length: number of whitespace-separated tokens.
        span_scores: ``(sentence_length + 1, sentence_length + 1)`` float
            array, initialised to zeros and filled in by ``fill_chart``.
        all_spans: result of ``NGramify(...).generate_ngrams(single_span=True,
            whole_span=True)``. Each entry is indexed as ``span[0]`` to get a
            ``(start, end)`` pair.
    """

    def __init__(self, sentence):
        self.sentence = sentence
        self.sentence_list = sentence.split()
        self.sentence_length = len(self.sentence_list)
        self.span_scores = np.zeros((self.sentence_length + 1, self.sentence_length + 1), dtype=float)
        self.all_spans = NGramify(self.sentence).generate_ngrams(single_span=True, whole_span=True)

    def _capitalised_without_stop_tokens(self):
        """Return True when every token is title/upper case and none is in ``ptb_top_100_common``."""
        tokens = self.sentence.split()
        is_upper_or_title = all(item.istitle() or item.isupper() for item in tokens)
        is_stop = any(item.lower() in ptb_top_100_common for item in tokens)
        return is_upper_or_title and not is_stop

    def compute_scores(self, model, predict_type, scale_axis, predict_batch_size, chunks=128):
        """Build the inside/outside strings for every span and score them.

        Args:
            model: for ``"inside"`` it is handed to the module-level
                ``trainer.predict``; for ``"outside"`` its ``predict_proba``
                method is called.
            predict_type: ``"inside"`` or ``"outside"``.
            scale_axis: forwarded to ``model.predict_proba`` (outside only).
            predict_batch_size: forwarded to ``model.predict_proba`` (outside only).
            chunks: currently unused; kept so existing callers keep working.

        Returns:
            ``(flags, data)``. ``data`` is a DataFrame with the columns
            ``inside_sentence``, ``outside_sentence``, ``span``, a
            ``<predict_type>_scores`` column and ``scores``. ``flags`` is the
            capitalisation heuristic for ``"inside"`` and ``False`` for
            ``"outside"``.

        Raises:
            ValueError: if ``predict_type`` is neither ``"inside"`` nor ``"outside"``.
        """
        # Fail early with a clear message; otherwise `flags` and the "scores"
        # column would never be defined for an unknown prediction type.
        if predict_type not in ("inside", "outside"):
            raise ValueError(f"predict_type must be 'inside' or 'outside', got {predict_type!r}")

        inside_strings = []
        outside_strings = []
        for span in self.all_spans:
            _, inside_string, outside_string = InsideOutside(sentence=self.sentence).create_inside_outside_matrix(span)
            inside_strings.append(inside_string)
            outside_strings.append(outside_string)

        data = pd.DataFrame(
            {
                "inside_sentence": inside_strings,
                "outside_sentence": outside_strings,
                "span": [span[0] for span in self.all_spans],
            }
        )

        if predict_type == "inside":
            test_dataloader = DataModule(
                model_name_or_path="roberta-base",
                train_df=None,
                eval_df=None,
                test_df=data.rename(columns={"inside_sentence": "sentence"})[["sentence"]],
            )
            # NOTE(review): only element 0 of the predict output is used. This
            # presumably relies on all spans arriving in a single batch --
            # confirm against DataModule / the model's predict step.
            inside_scores = list(trainer.predict(model, dataloaders=test_dataloader)[0])
            data["inside_scores"] = inside_scores

            # Heuristic override: a two-token inside string that starts with
            # the most common first token and whose last token is not among
            # the heuristic's top tokens gets score 1.
            lowered = data["inside_sentence"].str.lower()
            lowered_tokens = lowered.str.split()
            force_constituent = (
                lowered.str.startswith(ptb_most_common_first_token)
                & (lowered_tokens.str.len() == 2)
                & (~lowered_tokens.str[-1].isin(RuleBasedHeuristic().get_top_tokens()))
            )
            data.loc[force_constituent, "inside_scores"] = 1

            flags = self._capitalised_without_stop_tokens()
            data["scores"] = data["inside_scores"]
        else:
            outside_scores = list(
                model.predict_proba(
                    spans=data.rename(columns={"outside_sentence": "sentence"})[["sentence"]],
                    scale_axis=scale_axis,
                    predict_batch_size=predict_batch_size,
                )[:, 1]
            )
            data["outside_scores"] = outside_scores
            flags = False
            data["scores"] = data["outside_scores"]

        return flags, data

    def fill_chart(self, model, predict_type, scale_axis, predict_batch_size, data=None):
        """Copy span scores into ``self.span_scores``.

        Args:
            model, predict_type, scale_axis, predict_batch_size: forwarded to
                ``compute_scores`` when ``data`` is not supplied.
            data: optional precomputed DataFrame with ``span`` and ``scores``
                columns. Each span in ``self.all_spans`` must match exactly one
                row, otherwise ``Series.item()`` raises ``ValueError``.

        Returns:
            ``(flags, span_scores, data)``.
        """
        if data is None:
            flags, data = self.compute_scores(model, predict_type, scale_axis, predict_batch_size)
        else:
            # With caller-supplied data there is no compute_scores() call to
            # produce the flag, so derive it by the same rule here instead of
            # leaving the name unbound.
            flags = self._capitalised_without_stop_tokens() if predict_type == "inside" else False

        for span in self.all_spans:
            bounds = span[0]
            # Only a (start, end) tuple addresses a chart cell; anything else
            # is skipped.
            if not (isinstance(bounds, tuple) and len(bounds) == 2):
                continue
            start, end = bounds
            # Valid cells lie strictly above the diagonal of the chart.
            if 0 <= start < end <= self.sentence_length:
                self.span_scores[start, end] = data.loc[data["span"] == bounds, "scores"].item()

        return flags, self.span_scores, data

    def best_parse_tree(self, span_scores):
        """Run ``get_best_parse`` on the chart with its last row and first column dropped."""
        span_scores_cky_format = span_scores[:-1, 1:]
        best_parse = get_best_parse(sentence=[self.sentence_list], spans=span_scores_cky_format)
        return best_parse