Spaces:
Sleeping
Sleeping
| import logging | |
def topk_obj_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME,
                   exclude_list=None, topk=10):
    """Build and run the per-image object search query for a set of vectors.

    For every vector in ``xq`` one sub-select is generated. It picks object
    ids through the database's ``distance('topK=...', 'nprobe=32')`` function
    on ``prelogit`` (ordered by ``dist DESC``), optionally drops excluded
    ids, and scores each surviving object with a sigmoid of the dot product
    between ``prelogit`` and the vector. The sub-selects are combined with
    ``UNION ALL`` and grouped per image; ``img_score`` is the sum, over the
    query vectors, of the highest ``pred_logit`` found for that vector's
    label in the image (NaN entries are filtered out before summing).

    Args:
        client: Object exposing ``fetch(sql)``.
        xq: Iterable of vectors exposing ``.tolist()``. A constant ``1`` is
            appended to each vector before it is rendered as a SQL array
            literal (presumably a bias term -- confirm against the schema).
        IMG_DB_NAME: Image table name, interpolated verbatim into the SQL.
        OBJ_DB_NAME: Object table name, interpolated verbatim into the SQL.
        exclude_list: Object ids to exclude from the candidates. ``None``
            (the default) or an empty sequence disables the filter.
        topk: Forwarded as the ``topK`` option of ``distance(...)``.

    Returns:
        Whatever ``client.fetch`` returns for the generated query.

    Note:
        Table names and ids are formatted into the SQL without escaping;
        only pass trusted values.
    """
    # ``None`` stands in for the former ``[]`` default so that no mutable
    # object is shared between calls.
    if exclude_list is None:
        exclude_list = []

    # Render each query vector (plus the trailing 1) as a SQL array literal.
    xq_s = [
        f"[{', '.join(str(float(fnum)) for fnum in _xq.tolist() + [1])}]"
        for _xq in xq
    ]
    exclude_list_str = ','.join(f"'{i}'" for i in exclude_list)
    _cond = (f"WHERE obj_id NOT IN ({exclude_list_str})"
             if len(exclude_list) > 0 else "")

    _subq_str = []
    _img_score_subq = []
    for _l, _xq in enumerate(xq_s):
        # Highest logit among the boxes that were produced for label `_l`.
        _img_score_subq.append(
            f"arrayReduce('maxIf', logit, arrayMap(x->x={_l}, label))")
        # NOTE(review): the inner LIMIT is a fixed 10 and does not follow
        # `topk`; left as-is because callers may rely on over-fetching
        # candidates before the exclusion filter -- confirm intent.
        _subq_str.append(f"""
        SELECT img_id, img_url, img_w, img_h, 1/(1+exp(-arraySum(arrayMap((x,y)->x*y, prelogit, {_xq})))) AS pred_logit,
               obj_id, box_cx, box_cy, box_w, box_h, class_embedding, {_l} AS l
        FROM {OBJ_DB_NAME}
        JOIN {IMG_DB_NAME}
        ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
        PREWHERE obj_id IN (
            SELECT obj_id FROM (
                SELECT obj_id, distance('topK={topk}', 'nprobe=32')(prelogit, {_xq}) AS dist FROM {OBJ_DB_NAME}
                ORDER BY dist DESC
            ) {_cond} LIMIT 10
        )
        """)

    _subq_str = ' UNION ALL '.join(_subq_str)
    _img_score_q = ','.join(_img_score_subq)
    _img_score_q = f"arraySum(arrayFilter(x->NOT isNaN(x), array({_img_score_q}))) AS img_score"
    q_str = f"""
    SELECT img_id, img_url, img_w, img_h, groupArray(obj_id) AS box_id,
           groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
           groupArray(pred_logit) AS logit, groupArray(l) as label, groupArray(class_embedding) AS cls_emb,
           {_img_score_q}
    FROM
    ({_subq_str})
    GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
    """
    xc = client.fetch(q_str)
    return xc
def rev_query(client, xq, img_ids, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08):
    """Score the objects of the given images against every query vector.

    One sub-select per vector in ``xq`` joins the object and image tables,
    restricts rows to ``img_ids`` via ``PREWHERE`` and computes
    ``pred_logit`` as a sigmoid of the dot product between ``prelogit`` and
    the vector. The sub-selects are combined with ``UNION ALL`` and grouped
    per image; ``img_score`` sums, over the query vectors, the highest
    logit found for each label (NaN entries are filtered out first).

    Args:
        client: Object exposing ``fetch(sql)``.
        xq: Iterable of vectors exposing ``.tolist()``; a constant ``1`` is
            appended to each before rendering it as a SQL array literal.
        img_ids: Image ids, each single-quoted and placed in the ``IN``
            list without escaping.
        IMG_DB_NAME: Image table name, interpolated verbatim.
        OBJ_DB_NAME: Object table name, interpolated verbatim.
        thresh: When greater than 0, rows must satisfy
            ``pred_logit > thresh``; otherwise no filter is emitted.

    Returns:
        Whatever ``client.fetch`` returns for the generated query.
    """
    # SQL array literal for each query vector, with the extra trailing 1.
    vector_literals = []
    for vec in xq:
        rendered = ', '.join(str(float(value)) for value in vec.tolist() + [1])
        vector_literals.append(f"[{rendered}]")

    quoted_ids = ','.join(f"'{img_id}'" for img_id in img_ids)

    if thresh > 0:
        score_filter = f"WHERE pred_logit > {thresh}"
    else:
        score_filter = ""

    # Per-label "best logit in this image" expressions for img_score.
    per_label_max = [
        f"arrayReduce('maxIf', logit, arrayMap(x->x={idx}, label))"
        for idx in range(len(vector_literals))
    ]

    per_label_selects = [
        f"""
        SELECT {OBJ_DB_NAME}.img_id AS img_id, img_url, img_w, img_h,
               (1 / (1 + exp(-(arraySum(arrayMap((x,y)->x*y, prelogit, {literal})))))) AS pred_logit,
               obj_id, box_cx, box_cy, box_w, box_h, class_embedding, {idx} AS l
        FROM {OBJ_DB_NAME}
        JOIN {IMG_DB_NAME}
        ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
        PREWHERE img_id IN ({quoted_ids})
        {score_filter}
        """
        for idx, literal in enumerate(vector_literals)
    ]
    union_sql = ' UNION ALL '.join(per_label_selects)

    img_score_expr = (
        "arraySum(arrayFilter(x->NOT isNaN(x), "
        f"array({','.join(per_label_max)}))) AS img_score"
    )

    full_sql = f"""
    SELECT img_id, groupArray(obj_id) AS box_id, img_url, img_w, img_h,
           groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
           groupArray(pred_logit) AS logit, groupArray(l) as label, groupArray(class_embedding) AS cls_emb,
           {img_score_expr}
    FROM
    ({union_sql})
    GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
    """
    return client.fetch(full_sql)
def simple_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08, topk=10):
    """Run one vector search per query vector and group the hits by label.

    Each vector in ``xq`` yields a sub-select that joins the object and
    image tables and computes ``dist`` with the database's
    ``distance('topK=...', 'nprobe=32')`` function on ``prelogit``. The
    sub-selects are combined with ``UNION ALL`` and grouped by the label
    index ``l``; the outer query also emits ``logit`` as a sigmoid of
    ``dist``.

    Args:
        client: Object exposing ``fetch(sql)``.
        xq: Iterable of vectors exposing ``.tolist()``; a constant ``1`` is
            appended to each before rendering it as a SQL array literal.
        IMG_DB_NAME: Image table name, interpolated verbatim.
        OBJ_DB_NAME: Object table name, interpolated verbatim.
        thresh: When greater than 0, ``WHERE pred_logit > thresh`` is added
            to every sub-select; otherwise no filter is emitted.
        topk: Forwarded as the ``topK`` option of ``distance(...)``.

    Returns:
        Whatever ``client.fetch`` returns for the generated query.
    """
    xq_s = [
        f"[{', '.join(str(float(fnum)) for fnum in _xq.tolist() + [1])}]"
        for _xq in xq
    ]

    # NOTE(review): the sub-select below never defines `pred_logit`; unless
    # the joined tables expose a column of that name this filter will be
    # rejected by the database whenever thresh > 0 -- confirm the schema.
    _thresh = f"WHERE pred_logit > {thresh}" if thresh > 0 else ""

    # NOTE(review): LIMIT is a fixed 10 and does not follow `topk` --
    # confirm whether that is intentional.
    subq_str = " UNION ALL ".join(
        f"""
        SELECT {OBJ_DB_NAME}.img_id AS img_id, img_url, img_w, img_h, prelogit,
               obj_id, box_cx, box_cy, box_w, box_h, {_l} AS l, distance('topK={topk}', 'nprobe=32')(prelogit, {_xq}) AS dist
        FROM {OBJ_DB_NAME}
        JOIN {IMG_DB_NAME}
        ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
        {_thresh} LIMIT 10
        """
        for _l, _xq in enumerate(xq_s)
    )
    q_str = f"""
    SELECT groupArray(img_url) AS img_url, groupArray(img_w) AS img_w, groupArray(img_h) AS img_h,
           groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
           l AS label, groupArray(dist) as d,
           groupArray(1 / (1 + exp(-dist))) AS logit FROM (
        {subq_str}
    )
    GROUP BY l
    """
    # The former `res = []` initialiser was dead code: the value was always
    # overwritten by the fetch result.
    return client.fetch(q_str)