Spaces:
Runtime error
Runtime error
# Copied from Megatron-LM.
def get_a_and_b_segments(sample, np_rng):
    """Split a multi-sentence sample into two token segments, A and B.

    Sentences ``sample[0] .. sample[split - 1]`` are concatenated into
    segment A and the remaining sentences into segment B. With three or more
    sentences the split point is drawn as ``np_rng.randint(1, n_sentences)``;
    with exactly two sentences it is fixed at 1. Afterwards one
    ``np_rng.random()`` draw decides whether the segments are swapped
    (swap when the draw is below 0.5).

    Args:
        sample: Sequence of sentences, each an iterable of tokens. It must
            contain at least two sentences.
        np_rng: Random generator exposing ``randint(low, high)`` and
            ``random()``; presumably a ``numpy.random.RandomState`` (confirm
            against the caller).

    Returns:
        A tuple ``(tokens_a, tokens_b, is_next_random)``. The two token lists
        are freshly built lists. ``is_next_random`` is ``True`` exactly when
        the segments were swapped.

    Raises:
        AssertionError: If ``sample`` holds fewer than two sentences.
    """
    num_sentences = len(sample)
    # NOTE(review): validation via ``assert`` is stripped under ``python -O``.
    assert num_sentences > 1, 'make sure each sample has at least two sentences.'

    # numpy's randint excludes its upper bound, so the split falls in
    # [1, num_sentences - 1] and neither segment is empty.
    split = np_rng.randint(1, num_sentences) if num_sentences >= 3 else 1

    first = [token for idx in range(split) for token in sample[idx]]
    second = [
        token for idx in range(split, num_sentences) for token in sample[idx]
    ]

    # "Random next": with probability 0.5, present the segments reversed.
    if np_rng.random() < 0.5:
        return second, first, True
    return first, second, False