Update app.py
app.py
CHANGED
@@ -5,7 +5,6 @@ from gradio import routes
 from typing import List, Type
 from petals import AutoDistributedModelForCausalLM
 from transformers import AutoTokenizer
-import npc_data
 import requests, os, re, asyncio, json
 
 loop = asyncio.get_event_loop()
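This hunk drops the `import npc_data` line, yet the `chat()` handler further down still reads from an `npc_story` mapping, so the NPC definitions presumably now live elsewhere in app.py. A hypothetical sketch of the shape that mapping would need (the names and descriptions here are invented, not from the commit):

# Hypothetical stand-in for the removed npc_data module: chat() only
# needs a name -> character-description mapping, e.g.
npc_story = {
    "Rabbit": "A cheerful carrot farmer who speaks in short, upbeat sentences.",
    "Owl": "The island's night watch; formal and slightly long-winded.",
}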
@@ -31,7 +30,7 @@ routes.get_types = get_types
 
 # App code
 
-model_name = "
+model_name = "petals-team/StableBeluga2"
 
 #daekeun-ml/Llama-2-ko-instruct-13B
 #quantumaikr/llama-2-70b-fb16-korean
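The `model = None` just below suggests the Petals model is connected lazily rather than at startup. A minimal sketch of how the chosen checkpoint would be loaded over the public swarm, assuming the usual transformers-style `from_pretrained` entry point; the `get_model` helper is hypothetical:

from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

model_name = "petals-team/StableBeluga2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = None

def get_model():
    # Hypothetical lazy loader: joins the swarm on first call only.
    global model
    if model is None:
        model = AutoDistributedModelForCausalLM.from_pretrained(model_name)
    return model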
@@ -39,6 +38,12 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 model = None
 
+history = {
+    "":{
+        
+    }
+}
+
 def check(model_name):
     data = requests.get("https://health.petals.dev/api/v1/state").json()
     out = []
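`check()` polls the Petals health monitor before every chat. A sketch of what such a check can look like; the exact JSON fields used below (`model_reports`, `name`, `state`) are assumptions about the endpoint's response shape, not taken from this diff:

import requests

def is_model_healthy(model_name: str) -> bool:
    # Assumed response shape: {"model_reports": [{"name": ..., "state": ...}]}
    data = requests.get("https://health.petals.dev/api/v1/state").json()
    for report in data.get("model_reports", []):
        if report.get("name") == model_name:
            return report.get("state") == "healthy"
    return False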
@@ -72,23 +77,51 @@ def chat(id, npc, text):
 
     if check(model_name):
-
-
-
-
-
-
-
-
-User:
-{user_message}
-[/INST]
+        global history
+        if not npc in npc_story:
+            return "no npc"
+
+        if not npc in history:
+            history[npc] = {}
+        if not id in history[npc]:
+            history[npc][id] = ""
+        if len(history[npc][id].split("###")) > 10:
+            history[npc][id] = "###" + history[npc][id].split("###", 3)[3]
+        npc_list = str([k for k in npc_story.keys()]).replace('\'', '')
+        town_story = f"""[{id}'s village]
+A few residents live together on a small island in a remote place.
+
+Currently {npc_list} are living here."""
+
+        system_message = f"""1. You are fluent in Korean.
+2. You are now role-playing. Imagine {npc}'s reactions and express them in an engaging way.
+3. You are {npc}. Think and speak from {npc}'s point of view.
+4. Based on the given information, write plausible, vivid lines for {npc}.
+5. Read the given information about {npc} carefully, and play the character plainly, without overacting.
+6. Never intrude on the User's role. Do not repeat the same words.
+7. Keep to {npc}'s manner of speech."""
+
+        prom = f"""<<SYS>>
+{system_message}<</SYS>>
+
+{town_story}
+
+### Character information: {npc_story[npc]}
+
+### Instruction:
+Referring to {npc}'s information, write what {npc} would say, naturally and suited to the situation.
+{history[npc][id]}
+
+### User:
+{text}
+
+### {npc}:
 """
 
         inputs = tokenizer(prom, return_tensors="pt")["input_ids"]
-        outputs = model.generate(inputs, max_new_tokens=
-
-
+        outputs = model.generate(inputs, do_sample=True, temperature=0.6, top_p=0.75, max_new_tokens=100)
+        output = tokenizer.decode(outputs[0])[len(prom)+3:-1].split("<")[0].split("###")[0].replace(". ", ".\n")
+        print(outputs)
         print(output)
     else:
         output = "no model"
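The pruning logic added above keeps each per-user transcript bounded: once the "###"-delimited history holds more than ten segments, everything before the third "###" marker is discarded, dropping the two oldest turns. A standalone run of that exact expression (the sample turns are invented):

h = "###a###b###c###d###e###f###g###h###i###j###k"
if len(h.split("###")) > 10:          # 12 segments here, so the branch fires
    h = "###" + h.split("###", 3)[3]  # keep everything after the 3rd marker
print(h)  # ###c###d###e###f###g###h###i###j###k -- turns a and b are gone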
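The new generate/decode lines do several things at once; spelled out step by step below. The sampling parameters come from the commit itself, while the slicing details (the `len(prom)+3` offset that skips the echoed prompt plus leading special-token characters) are read off the code rather than verified against the tokenizer:

inputs = tokenizer(prom, return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, do_sample=True, temperature=0.6,
                         top_p=0.75, max_new_tokens=100)

decoded = tokenizer.decode(outputs[0])
reply = (decoded[len(prom) + 3:-1]  # skip the echoed prompt and trailing char
         .split("<")[0]             # cut at the first special token, e.g. "</s>"
         .split("###")[0]           # cut at the next "###" turn marker
         .replace(". ", ".\n"))     # line-break between sentences for display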