First working version
Browse files- completions.py +4 -4
- frontend/src/components/TokenChip.tsx +16 -1
- frontend/src/components/app.tsx +21 -13
completions.py
CHANGED
|
@@ -91,13 +91,13 @@ def generate_outputs(model: PreTrainedModel, inputs: BatchEncoding, num_samples:
|
|
| 91 |
def extract_replacements(outputs: GenerateOutput | torch.LongTensor, tokenizer: Tokenizer, num_inputs: int, input_len: int, num_samples: int = 5) -> list[list[str]]:
|
| 92 |
all_new_words = []
|
| 93 |
for i in range(num_inputs):
|
| 94 |
-
replacements =
|
| 95 |
for j in range(num_samples):
|
| 96 |
generated_ids = outputs[i * num_samples + j][input_len:]
|
| 97 |
new_word = tokenizer.convert_ids_to_tokens(generated_ids.tolist())[0]
|
| 98 |
if starts_with_space(new_word):
|
| 99 |
-
replacements.
|
| 100 |
-
all_new_words.append(replacements)
|
| 101 |
return all_new_words
|
| 102 |
|
| 103 |
#%%
|
|
@@ -128,7 +128,7 @@ def check_text(input_text: str, model: PreTrainedModel, tokenizer: Tokenizer, de
|
|
| 128 |
input_ids = inputs["input_ids"]
|
| 129 |
|
| 130 |
#%%
|
| 131 |
-
num_samples =
|
| 132 |
start_time = time.time()
|
| 133 |
outputs = generate_outputs(model, inputs, num_samples)
|
| 134 |
end_time = time.time()
|
|
|
|
def extract_replacements(outputs: GenerateOutput | torch.LongTensor, tokenizer: Tokenizer, num_inputs: int, input_len: int, num_samples: int = 5) -> list[list[str]]:
    """Collect candidate replacement words for each sampled input position.

    The generation batch is laid out so that the ``num_samples`` samples for
    input ``i`` occupy rows ``i * num_samples`` .. ``(i + 1) * num_samples - 1``.
    For each sample, only the first newly generated token (after the prompt)
    is considered, and it is kept only when it begins a new word (i.e. carries
    the tokenizer's leading-space marker).

    Args:
        outputs: Generated sequences of shape (num_inputs * num_samples, seq_len).
        tokenizer: Tokenizer used to map token ids back to token strings.
        num_inputs: Number of input positions that were sampled.
        input_len: Prompt length; tokens before this index are skipped.
        num_samples: Number of samples drawn per input position.

    Returns:
        One sorted, de-duplicated list of replacement strings per input,
        each prefixed with a literal space.
    """
    all_new_words = []
    for i in range(num_inputs):
        replacements: set[str] = set()
        for j in range(num_samples):
            generated_ids = outputs[i * num_samples + j][input_len:]
            new_word = tokenizer.convert_ids_to_tokens(generated_ids.tolist())[0]
            if starts_with_space(new_word):
                # Swap the tokenizer's space-marker character for a real space.
                # NOTE(review): assumes the marker is exactly one leading char —
                # confirm against starts_with_space / the tokenizer in use.
                replacements.add(" " + new_word[1:])
        # sorted() accepts any iterable; the intermediate list() was redundant.
        all_new_words.append(sorted(replacements))
    return all_new_words
|
| 102 |
|
| 103 |
#%%
|
|
|
|
| 128 |
input_ids = inputs["input_ids"]
|
| 129 |
|
| 130 |
#%%
|
| 131 |
+
num_samples = 10
|
| 132 |
start_time = time.time()
|
| 133 |
outputs = generate_outputs(model, inputs, num_samples)
|
| 134 |
end_time = time.time()
|
frontend/src/components/TokenChip.tsx
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
import React, { useState } from "react"
|
| 2 |
|
| 3 |
export const TokenChip = ({
|
| 4 |
token,
|
|
@@ -14,6 +14,20 @@ export const TokenChip = ({
|
|
| 14 |
onReplace: (newToken: string) => void
|
| 15 |
}) => {
|
| 16 |
const [isExpanded, setIsExpanded] = useState(false);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
const handleClick = () => {
|
| 19 |
if (logprob < threshold && replacements.length > 0) {
|
|
@@ -39,6 +53,7 @@ export const TokenChip = ({
|
|
| 39 |
{token}
|
| 40 |
{isExpanded && (
|
| 41 |
<select
|
|
|
|
| 42 |
onChange={handleReplacement}
|
| 43 |
value={token}
|
| 44 |
style={{
|
|
|
|
| 1 |
+
import React, { useState, useEffect, useRef } from "react"
|
| 2 |
|
| 3 |
export const TokenChip = ({
|
| 4 |
token,
|
|
|
|
| 14 |
onReplace: (newToken: string) => void
|
| 15 |
}) => {
|
| 16 |
const [isExpanded, setIsExpanded] = useState(false);
|
| 17 |
+
const dropdownRef = useRef<HTMLSelectElement>(null);
|
| 18 |
+
|
| 19 |
+
useEffect(() => {
|
| 20 |
+
const handleClickOutside = (event: MouseEvent) => {
|
| 21 |
+
if (dropdownRef.current && !dropdownRef.current.contains(event.target as Node)) {
|
| 22 |
+
setIsExpanded(false);
|
| 23 |
+
}
|
| 24 |
+
};
|
| 25 |
+
|
| 26 |
+
document.addEventListener("mousedown", handleClickOutside);
|
| 27 |
+
return () => {
|
| 28 |
+
document.removeEventListener("mousedown", handleClickOutside);
|
| 29 |
+
};
|
| 30 |
+
}, []);
|
| 31 |
|
| 32 |
const handleClick = () => {
|
| 33 |
if (logprob < threshold && replacements.length > 0) {
|
|
|
|
| 53 |
{token}
|
| 54 |
{isExpanded && (
|
| 55 |
<select
|
| 56 |
+
ref={dropdownRef}
|
| 57 |
onChange={handleReplacement}
|
| 58 |
value={token}
|
| 59 |
style={{
|
frontend/src/components/app.tsx
CHANGED
|
@@ -37,32 +37,40 @@ export default function App() {
|
|
| 37 |
const [context, setContext] = useState("")
|
| 38 |
const [wordlist, setWordlist] = useState("")
|
| 39 |
const [showWholePrompt, setShowWholePrompt] = useState(false)
|
| 40 |
-
const [text, setText] = useState("
|
| 41 |
const [mode, setMode] = useState<"edit" | "check">("edit")
|
| 42 |
const [words, setWords] = useState<Word[]>([])
|
| 43 |
const [isLoading, setIsLoading] = useState(false)
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
const toggleMode = async () => {
|
| 46 |
if (mode === "edit") {
|
| 47 |
setIsLoading(true)
|
| 48 |
-
|
| 49 |
-
const checkedWords = await checkText(text)
|
| 50 |
-
setWords(checkedWords)
|
| 51 |
-
} finally {
|
| 52 |
-
setMode("check")
|
| 53 |
-
setIsLoading(false)
|
| 54 |
-
}
|
| 55 |
} else {
|
| 56 |
setMode("edit")
|
| 57 |
}
|
| 58 |
}
|
| 59 |
|
| 60 |
-
const handleReplace = (index: number, newToken: string) => {
|
| 61 |
const updatedWords = [...words]
|
| 62 |
updatedWords[index].text = newToken
|
|
|
|
|
|
|
| 63 |
setWords(updatedWords)
|
| 64 |
setText(updatedWords.map(w => w.text).join(""))
|
| 65 |
-
setMode("edit")
|
|
|
|
| 66 |
}
|
| 67 |
|
| 68 |
let result
|
|
@@ -101,7 +109,7 @@ export default function App() {
|
|
| 101 |
<details>
|
| 102 |
<summary>Advanced settings</summary>
|
| 103 |
<label>
|
| 104 |
-
<strong>Threshold:</strong> <input type="number" step="
|
| 105 |
<small>
|
| 106 |
The <a href="https://en.wikipedia.org/wiki/Log_probability" target="_blank" rel="noreferrer">logprob</a> threshold.
|
| 107 |
Tokens with logprobs smaller than this will be marked red.
|
|
@@ -132,8 +140,8 @@ export default function App() {
|
|
| 132 |
|
| 133 |
<p>
|
| 134 |
<small>
|
| 135 |
-
|
| 136 |
-
Made with
|
| 137 |
<br />
|
| 138 |
This software is provided with absolutely no warranty.
|
| 139 |
</small>
|
|
|
|
| 37 |
const [context, setContext] = useState("")
|
| 38 |
const [wordlist, setWordlist] = useState("")
|
| 39 |
const [showWholePrompt, setShowWholePrompt] = useState(false)
|
| 40 |
+
const [text, setText] = useState("I just drove to the store to but eggs, but they had some.")
|
| 41 |
const [mode, setMode] = useState<"edit" | "check">("edit")
|
| 42 |
const [words, setWords] = useState<Word[]>([])
|
| 43 |
const [isLoading, setIsLoading] = useState(false)
|
| 44 |
|
| 45 |
+
const check = async () => {
|
| 46 |
+
setIsLoading(true)
|
| 47 |
+
try {
|
| 48 |
+
const checkedWords = await checkText(text)
|
| 49 |
+
setWords(checkedWords)
|
| 50 |
+
} finally {
|
| 51 |
+
setIsLoading(false)
|
| 52 |
+
setMode("check")
|
| 53 |
+
}
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
const toggleMode = async () => {
|
| 57 |
if (mode === "edit") {
|
| 58 |
setIsLoading(true)
|
| 59 |
+
await check()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
} else {
|
| 61 |
setMode("edit")
|
| 62 |
}
|
| 63 |
}
|
| 64 |
|
| 65 |
+
const handleReplace = async (index: number, newToken: string) => {
|
| 66 |
const updatedWords = [...words]
|
| 67 |
updatedWords[index].text = newToken
|
| 68 |
+
updatedWords[index].logprob = 0
|
| 69 |
+
updatedWords[index].replacements = []
|
| 70 |
setWords(updatedWords)
|
| 71 |
setText(updatedWords.map(w => w.text).join(""))
|
| 72 |
+
// setMode("edit")
|
| 73 |
+
await check()
|
| 74 |
}
|
| 75 |
|
| 76 |
let result
|
|
|
|
| 109 |
<details>
|
| 110 |
<summary>Advanced settings</summary>
|
| 111 |
<label>
|
| 112 |
+
<strong>Threshold:</strong> <input type="number" step="1" value={threshold} onChange={e => setThreshold(Number(e.target.value))} />
|
| 113 |
<small>
|
| 114 |
The <a href="https://en.wikipedia.org/wiki/Log_probability" target="_blank" rel="noreferrer">logprob</a> threshold.
|
| 115 |
Tokens with logprobs smaller than this will be marked red.
|
|
|
|
| 140 |
|
| 141 |
<p>
|
| 142 |
<small>
|
| 143 |
+
Based on <a href="https://github.com/vgel/gpted">GPTed</a> by <a href="https://vgel.me">Theia Vogel</a>.
|
| 144 |
+
Made with React, Transformers, LLama 3.2, and transitively, most of the web.
|
| 145 |
<br />
|
| 146 |
This software is provided with absolutely no warranty.
|
| 147 |
</small>
|