debug
David Day committed

Files changed:
- model_worker.py (+12 -33)
- requirements.txt (+1 -0)
model_worker.py
CHANGED
@@ -52,12 +52,12 @@ class ModelWorker:
             torch_device='cpu',
             device_map="cpu",
         )
-        self.model.to(…)
+        self.model.to('cuda')
 
     @spaces.GPU
     def generate_stream(self, params):
         tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
-        logger.info(f'Model devices: {…}')
+        logger.info(f'Model devices: {model.device}')
 
         prompt = params["prompt"]
         ori_prompt = prompt
@@ -70,17 +70,18 @@ class ModelWorker:
 
         images = [load_image_from_base64(image) for image in images]
         images = process_images(images, image_processor, model.config)
+        logger.info(f'Images: {images.shape}')
 
         if type(images) is list:
-            images = [image.to(…) for image in images]
+            images = [image.to(model.device, dtype=torch.float16) for image in images]
         else:
-            images = images.to(…)
+            images = images.to(model.device, dtype=torch.float16)
 
         if self.load_bf16:
             images = images.to(dtype=torch.bfloat16)
 
         replace_token = DEFAULT_IMAGE_TOKEN
-        if getattr(…):
+        if getattr(model.config, 'mm_use_im_start_end', False):
             replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
         prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
 
@@ -99,15 +100,15 @@ class ModelWorker:
         stop_str = params.get("stop", None)
         do_sample = True if temperature > 0.001 else False
 
-        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).…
+        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
         keywords = [stop_str]
         stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
-        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=…)
+        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=None)
 
         max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
 
         if max_new_tokens < 1:
-            yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode()…
+            yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode()
             return
 
         thread = Thread(target=model.generate, kwargs=dict(
@@ -128,33 +129,11 @@ class ModelWorker:
             generated_text += new_text
             if generated_text.endswith(stop_str):
                 generated_text = generated_text[:-len(stop_str)]
-            yield json.dumps({"text": generated_text, "error_code": 0}).encode()…
+            yield json.dumps({"text": generated_text, "error_code": 0}).encode()
 
     def generate_stream_gate(self, params):
-        try:
-            for x in self.generate_stream(params):
-                yield x
-        except ValueError as e:
-            print("Caught ValueError:", e)
-            ret = {
-                "text": server_error_msg,
-                "error_code": 1,
-            }
-            yield json.dumps(ret).encode() + b"\0"
-        except torch.cuda.CudaError as e:
-            print("Caught torch.cuda.CudaError:", e)
-            ret = {
-                "text": server_error_msg,
-                "error_code": 1,
-            }
-            yield json.dumps(ret).encode() + b"\0"
-        except Exception as e:
-            print("Caught Unknown Error", e)
-            ret = {
-                "text": server_error_msg,
-                "error_code": 1,
-            }
-            yield json.dumps(ret).encode() + b"\0"
+        for x in self.generate_stream(params):
+            yield x
 
 def release_model_semaphore(fn=None):
     model_semaphore.release()
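Note: the startup `self.model.to('cuda')` plus the `@spaces.GPU` decorator is the standard Hugging Face ZeroGPU pattern: the model is loaded on CPU at import time, the eager `.to('cuda')` is intercepted by the `spaces` package, and a real GPU is only attached while a decorated function runs. A minimal sketch of that pattern, using a placeholder `gpt2` checkpoint rather than this Space's actual model:

import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; the Space loads its own model inside ModelWorker.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="cpu")
model.to('cuda')  # intercepted by ZeroGPU; weights move when a GPU is attached

@spaces.GPU
def generate(prompt):
    # A CUDA device is only guaranteed to exist inside this call.
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
    output_ids = model.generate(input_ids, max_new_tokens=32, do_sample=False)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)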
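Note: the `mm_use_im_start_end` branch touched in the second hunk follows the LLaVA convention for multimodal prompts, where the `<image>` placeholder may be wrapped in explicit boundary tokens depending on the checkpoint config. A self-contained sketch, assuming this worker follows the LLaVA codebase (constant values as defined in LLaVA's constants module):

DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"

def prepare_prompt(prompt: str, mm_use_im_start_end: bool) -> str:
    # Some checkpoints are trained with explicit image boundary tokens;
    # the config flag decides whether the placeholder gets wrapped.
    replace_token = DEFAULT_IMAGE_TOKEN
    if mm_use_im_start_end:
        replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
    return prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

prepare_prompt("What is in this image? <image>", True)
# -> 'What is in this image? <im_start><image><im_end>'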
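Note: the streaming mechanics the diff leaves in place are the stock `transformers` pattern: `model.generate` runs on a background thread while the worker iterates a `TextIteratorStreamer`, and `timeout=None` (as set in this commit) blocks indefinitely waiting for the next chunk. A condensed, self-contained sketch of that loop, simplified from the worker:

from threading import Thread
from transformers import TextIteratorStreamer

def stream_generate(model, tokenizer, input_ids, max_new_tokens=256):
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=None)
    thread = Thread(target=model.generate, kwargs=dict(
        inputs=input_ids, streamer=streamer, max_new_tokens=max_new_tokens))
    thread.start()
    generated_text = ""
    for new_text in streamer:  # yields decoded text chunks as they arrive
        generated_text += new_text
        yield generated_text
    thread.join()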
requirements.txt
CHANGED
@@ -11,4 +11,5 @@ einops==0.6.1
 einops-exts==0.0.4
 timm==0.6.13
 httpx==0.24.0
+numpy==1.26.4
 scipy
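Note: the commit does not say why numpy is pinned; a plausible reason (an assumption on my part, not stated anywhere in the diff) is that the older pinned wheels in this stack were compiled against the NumPy 1.x C ABI and fail to import under NumPy 2. A quick sanity check to run inside the Space's container:

import numpy as np
assert np.__version__ == "1.26.4", np.__version__  # the version this commit pins
import torch  # imports cleanly only if the wheel matches the installed NumPy ABI
print(torch.__version__, np.__version__)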