Donghao Huang
committed
Commit · 6b469d2 · 1 Parent(s): 571afe2
fixed bug on llama-2
- app_modules/llm_inference.py +9 -8
- app_modules/llm_loader.py +1 -1
- test.py +5 -1
app_modules/llm_inference.py
CHANGED
@@ -35,7 +35,12 @@ class LLMInference(metaclass=abc.ABCMeta):
         return self.chain

     def call_chain(
-        self, inputs, streaming_handler, q: Queue = None, tracing: bool = False
+        self,
+        inputs,
+        streaming_handler,
+        q: Queue = None,
+        tracing: bool = False,
+        testing: bool = False,
     ):
         print(inputs)
         if self.llm_loader.streamer.for_huggingface:
@@ -46,11 +51,7 @@ class LLMInference(metaclass=abc.ABCMeta):

         chain = self.get_chain(tracing)
         result = (
-            self._run_chain(
-                chain,
-                inputs,
-                streaming_handler,
-            )
+            self._run_chain(chain, inputs, streaming_handler, testing)
             if streaming_handler is not None
             else chain(inputs)
         )
@@ -74,7 +75,7 @@ class LLMInference(metaclass=abc.ABCMeta):
     def _execute_chain(self, chain, inputs, q, sh):
         q.put(chain(inputs, callbacks=[sh]))

-    def _run_chain(self, chain, inputs, streaming_handler):
+    def _run_chain(self, chain, inputs, streaming_handler, testing):
         que = Queue()

         t = Thread(
@@ -83,7 +84,7 @@ class LLMInference(metaclass=abc.ABCMeta):
         )
         t.start()

-        if self.llm_loader.streamer.for_huggingface:
+        if self.llm_loader.streamer.for_huggingface and not testing:
             count = (
                 2
                 if "chat_history" in inputs and len(inputs.get("chat_history")) > 0
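For context, this change threads a new testing flag from call_chain into _run_chain so the HuggingFace streamer bookkeeping can be skipped during tests. Below is a minimal, self-contained sketch of the thread-plus-queue pattern the code above relies on; run_chain and fake_chain are illustrative names, not part of the repository.

from queue import Queue
from threading import Thread


def run_chain(chain, inputs, streaming_handler, testing=False):
    # Run the chain on a worker thread and collect its result via a Queue,
    # mirroring _execute_chain/_run_chain above.
    que = Queue()
    t = Thread(target=lambda: que.put(chain(inputs, callbacks=[streaming_handler])))
    t.start()

    if not testing:
        # The real _run_chain does streamer-specific token handling here for
        # HuggingFace models; testing=True bypasses it.
        pass

    t.join()
    return que.get()


def fake_chain(inputs, callbacks=None):
    # Stand-in for a LangChain chain: any callable taking (inputs, callbacks=...).
    return {"answer": "echo: " + inputs["question"]}


print(run_chain(fake_chain, {"question": "hi"}, streaming_handler=None))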
app_modules/llm_loader.py
CHANGED
@@ -227,6 +227,7 @@ class LLMLoader:
             if "gpt4all-j" in MODEL_NAME_OR_PATH
             or "dolly" in MODEL_NAME_OR_PATH
             or "Qwen" in MODEL_NAME_OR_PATH
+            or "Llama-2" in MODEL_NAME_OR_PATH
             else 0
         )
         use_fast = (
@@ -452,7 +453,6 @@ class LLMLoader:
                 top_p=0.95,
                 top_k=0,  # select from top 0 tokens (because zero, relies on top_p)
                 repetition_penalty=1.115,
-                use_auth_token=token,
                 token=token,
             )
         )
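Dropping use_auth_token=token matters because recent transformers releases deprecate use_auth_token in favour of token, and passing both in the same call can raise an error; only token=token is kept. A hedged sketch of loading a gated Llama-2 checkpoint with the newer argument (the model id and the HUGGINGFACE_AUTH_TOKEN variable are assumptions, not taken from this repo):

import os
from transformers import AutoTokenizer

# Assumed env var name; gated Llama-2 repos require an accepted license and an access token.
hf_token = os.environ.get("HUGGINGFACE_AUTH_TOKEN")

tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",  # assumed model id, for illustration only
    token=hf_token,  # newer transformers: use `token`, not `use_auth_token`
)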
test.py
CHANGED
@@ -69,7 +69,11 @@ while True:

     start = timer()
     result = qa_chain.call_chain(
-        {"question": query, "chat_history": chat_history}, custom_handler
+        {"question": query, "chat_history": chat_history},
+        custom_handler,
+        None,
+        False,
+        True,
     )
     end = timer()
     print(f"Completed in {end - start:.3f}s")
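The five positional arguments map onto the new call_chain signature (inputs, streaming_handler, q, tracing, testing). The same call written with keyword arguments, reusing the names already defined in test.py, reads more clearly; this is a readability sketch, not part of the commit:

result = qa_chain.call_chain(
    {"question": query, "chat_history": chat_history},
    custom_handler,
    q=None,
    tracing=False,
    testing=True,  # skip the HuggingFace streamer bookkeeping in _run_chain
)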