|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Dict, List, Optional, Union |
|
|
|
|
|
import torch |
|
|
|
|
|
from camel.models.reward import BaseRewardModel |
|
|
from camel.types import ModelType |
|
|
|
|
|
|
|
|
class SkyworkRewardModel(BaseRewardModel):
    r"""Reward model based on the transformers, it will download the model
    from huggingface.

    The model is loaded as a single-label sequence classifier
    (``num_labels=1``) in ``bfloat16``; its one logit is reported as the
    reward score.

    Args:
        model_type (Union[ModelType, str]): Model for which a backend is
            created.
        api_key (Optional[str], optional): Not used. (default: :obj:`None`)
        url (Optional[str], optional): Not used. (default: :obj:`None`)
        device_map (Optional[str], optional): choose the device map.
            (default: :obj:`auto`)
        attn_implementation (Optional[str], optional): choose the attention
            implementation. (default: :obj:`flash_attention_2`)
        offload_folder (Optional[str], optional): choose the offload folder.
            (default: :obj:`offload`)
    """

    def __init__(
        self,
        model_type: Union[ModelType, str],
        api_key: Optional[str] = None,
        url: Optional[str] = None,
        device_map: Optional[str] = "auto",
        attn_implementation: Optional[str] = "flash_attention_2",
        offload_folder: Optional[str] = "offload",
    ) -> None:
        # Imported inside the constructor so that `transformers` is only
        # required when this reward model is actually instantiated.
        from transformers import (
            AutoModelForSequenceClassification,
            AutoTokenizer,
        )

        super().__init__(model_type, api_key, url)
        self._client = AutoModelForSequenceClassification.from_pretrained(
            model_type,
            torch_dtype=torch.bfloat16,
            device_map=device_map,
            attn_implementation=attn_implementation,
            offload_folder=offload_folder,
            num_labels=1,
        )
        self._tokenizer = AutoTokenizer.from_pretrained(model_type)

    def evaluate(self, messages: List[Dict[str, str]]) -> Dict[str, float]:
        r"""Evaluate the messages using the Skywork model.

        Args:
            messages (List[Dict[str, str]]): A list of messages.

        Returns:
            Dict[str, float]: A dictionary with the single key ``"Score"``
                mapped to the first logit produced by the model.
        """
        inputs = self._tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            return_tensors="pt",
        )

        # The tokenizer returns CPU tensors, while the model may have been
        # placed on an accelerator by `device_map`. Move the inputs next to
        # the model so the forward pass does not fail on a device mismatch.
        # A "meta" device means the first parameters are offloaded; in that
        # case leave the inputs where they are.
        model_device = getattr(self._client, "device", None)
        if model_device is not None and model_device.type != "meta":
            inputs = inputs.to(model_device)

        with torch.no_grad():
            score = self._client(inputs).logits[0][0].item()
        return {"Score": score}

    def get_scores_types(self) -> List[str]:
        r"""Get the names of the scores produced by :meth:`evaluate`.

        Returns:
            List[str]: list of scores types
        """
        return ["Score"]
|
|
|