Update src/display/utils.py
Browse files- src/display/utils.py +7 -6
src/display/utils.py
CHANGED
|
@@ -240,16 +240,15 @@ class WeightDtype(Enum):
|
|
| 240 |
int2 = ModelDetails("int2")
|
| 241 |
int3 = ModelDetails("int3")
|
| 242 |
int4 = ModelDetails("int4")
|
|
|
|
| 243 |
nf4 = ModelDetails("nf4")
|
| 244 |
fp4 = ModelDetails("fp4")
|
| 245 |
-
|
| 246 |
bf16 = ModelDetails("bfloat16")
|
| 247 |
-
|
| 248 |
|
| 249 |
Unknown = ModelDetails("?")
|
| 250 |
|
| 251 |
-
|
| 252 |
-
|
| 253 |
def from_str(weight_dtype):
|
| 254 |
if weight_dtype in ["int2"]:
|
| 255 |
return WeightDtype.int2
|
|
@@ -257,6 +256,8 @@ class WeightDtype(Enum):
|
|
| 257 |
return WeightDtype.int3
|
| 258 |
if weight_dtype in ["int4"]:
|
| 259 |
return WeightDtype.int4
|
|
|
|
|
|
|
| 260 |
if weight_dtype in ["nf4"]:
|
| 261 |
return WeightDtype.nf4
|
| 262 |
if weight_dtype in ["fp4"]:
|
|
@@ -264,11 +265,11 @@ class WeightDtype(Enum):
|
|
| 264 |
if weight_dtype in ["All"]:
|
| 265 |
return WeightDtype.all
|
| 266 |
if weight_dtype in ["float16"]:
|
| 267 |
-
return WeightDtype.
|
| 268 |
if weight_dtype in ["bfloat16"]:
|
| 269 |
return WeightDtype.bf16
|
| 270 |
if weight_dtype in ["float32"]:
|
| 271 |
-
return WeightDtype.
|
| 272 |
return WeightDtype.Unknown
|
| 273 |
|
| 274 |
class ComputeDtype(Enum):
|
|
|
|
| 240 |
int2 = ModelDetails("int2")
|
| 241 |
int3 = ModelDetails("int3")
|
| 242 |
int4 = ModelDetails("int4")
|
| 243 |
+
int8 = ModelDetails("int8")
|
| 244 |
nf4 = ModelDetails("nf4")
|
| 245 |
fp4 = ModelDetails("fp4")
|
| 246 |
+
f16 = ModelDetails("float16")
|
| 247 |
bf16 = ModelDetails("bfloat16")
|
| 248 |
+
f32 = ModelDetails("float32")
|
| 249 |
|
| 250 |
Unknown = ModelDetails("?")
|
| 251 |
|
|
|
|
|
|
|
| 252 |
def from_str(weight_dtype):
|
| 253 |
if weight_dtype in ["int2"]:
|
| 254 |
return WeightDtype.int2
|
|
|
|
| 256 |
return WeightDtype.int3
|
| 257 |
if weight_dtype in ["int4"]:
|
| 258 |
return WeightDtype.int4
|
| 259 |
+
if weight_dtype in ["int8"]:
|
| 260 |
+
return WeightDtype.int8
|
| 261 |
if weight_dtype in ["nf4"]:
|
| 262 |
return WeightDtype.nf4
|
| 263 |
if weight_dtype in ["fp4"]:
|
|
|
|
| 265 |
if weight_dtype in ["All"]:
|
| 266 |
return WeightDtype.all
|
| 267 |
if weight_dtype in ["float16"]:
|
| 268 |
+
return WeightDtype.f16
|
| 269 |
if weight_dtype in ["bfloat16"]:
|
| 270 |
return WeightDtype.bf16
|
| 271 |
if weight_dtype in ["float32"]:
|
| 272 |
+
return WeightDtype.f32
|
| 273 |
return WeightDtype.Unknown
|
| 274 |
|
| 275 |
class ComputeDtype(Enum):
|