@@ -554,13 +554,62 @@ float bf16_to_f32(uint16_t bfloat16) {
     return *reinterpret_cast<float *>(&val_bits);
 }
 
+uint16_t f8_e4m3_to_f16(uint8_t f8) {
+    // TODO: do we also need to support the UZ (E4M3FNUZ) variant?
+
+    const uint32_t exponent_bias = 7;
+    // E4M3 has no infinities; 0x7f and 0xff are the only NaN encodings
+    if (f8 == 0xff) {
+        return ggml_fp32_to_fp16(-NAN);
+    } else if (f8 == 0x7f) {
+        return ggml_fp32_to_fp16(NAN);
+    }
+
+    uint32_t sign     = f8 & 0x80;
+    uint32_t exponent = (f8 & 0x78) >> 3;
+    uint32_t mantissa = f8 & 0x07;
+    uint32_t result   = sign << 24;  // sign: bit 7 -> f32 bit 31
+    if (exponent == 0) {
+        // subnormal: 0.mmm * 2^(1 - bias); stays +/-0 if mantissa == 0
+        if (mantissa > 0) {
+            exponent = 0x7f - exponent_bias;
+
+            // normalize: shift left until the implicit bit (0x04) is set;
+            // a 3-bit mantissa needs at most 2 shifts, hence 2 copies
+            if ((mantissa & 0x04) == 0) {
+                mantissa &= 0x03;
+                mantissa <<= 1;
+                exponent -= 1;
+            }
+            if ((mantissa & 0x04) == 0) {
+                mantissa &= 0x03;
+                mantissa <<= 1;
+                exponent -= 1;
+            }
+
+            result |= (mantissa & 0x03) << 21;
+            result |= exponent << 23;
+        }
+    } else {
+        // normal: re-bias the exponent (f8 bias 7 -> f32 bias 127)
+        result |= mantissa << 20;
+        exponent += 0x7f - exponent_bias;
+        result |= exponent << 23;
+    }
+
+    return ggml_fp32_to_fp16(*reinterpret_cast<const float *>(&result));
+}
+
 void bf16_to_f32_vec(uint16_t * src, float * dst, int64_t n) {
     // support inplace op
     for (int64_t i = n - 1; i >= 0; i--) {
         dst[i] = bf16_to_f32(src[i]);
     }
 }
 
+void f8_e4m3_to_f16_vec(uint8_t * src, uint16_t * dst, int64_t n) {
+    // support inplace op: iterate backwards so each source byte is read
+    // before the widened f16 write can reach it
+    for (int64_t i = n - 1; i >= 0; i--) {
+        dst[i] = f8_e4m3_to_f16(src[i]);
+    }
+}
+
 void convert_tensor(void * src,
                     ggml_type src_type,
                     void * dst,
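
For reference, the E4M3 layout decoded above is 1 sign bit, 4 exponent bits (bias 7), and 3 mantissa bits; the format has NaN (0x7f/0xff) but no infinity, and its largest finite value is 448. Below is a small standalone cross-check of the decoding, written arithmetically with ldexp instead of bit manipulation. This is a hypothetical test harness, not part of the patch, and the name f8_e4m3_to_f32_ref is invented here:

// Hypothetical reference decoder for FP8 E4M3 (not from the patch).
#include <cmath>
#include <cstdint>
#include <cstdio>

float f8_e4m3_to_f32_ref(uint8_t f8) {
    int sign     = (f8 >> 7) & 1;
    int exponent = (f8 >> 3) & 0x0f;
    int mantissa = f8 & 0x07;
    if (exponent == 0x0f && mantissa == 0x07) {
        return sign ? -NAN : NAN;  // the only NaN encodings; E4M3 has no inf
    }
    // normal: (-1)^s * 1.mmm * 2^(e - 7); subnormal: (-1)^s * 0.mmm * 2^(1 - 7)
    float m = (exponent == 0) ? mantissa / 8.0f : 1.0f + mantissa / 8.0f;
    int   e = (exponent == 0) ? 1 - 7 : exponent - 7;
    return (sign ? -1.0f : 1.0f) * std::ldexp(m, e);
}

int main() {
    // known encodings: 0x01 = smallest subnormal 2^-9, 0x30 = 0.5,
    // 0x38 = 1.0, 0x7e = largest finite value 448, 0xb8 = -1.0
    const uint8_t cases[] = {0x00, 0x01, 0x30, 0x38, 0x7e, 0xb8};
    for (uint8_t c : cases) {
        std::printf("0x%02x -> %g\n", c, f8_e4m3_to_f32_ref(c));
    }
}

Comparing this against ggml_fp16_to_fp32(f8_e4m3_to_f16(b)) for all 256 byte values would be a cheap exhaustive test of the patch's converter.
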
@@ -794,6 +843,8 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
         ttype = GGML_TYPE_F32;
     } else if (dtype == "F32") {
         ttype = GGML_TYPE_F32;
+    } else if (dtype == "F8_E4M3") {
+        // ggml has no f8 type; widen to f16, which holds every E4M3 value exactly
+        ttype = GGML_TYPE_F16;
     }
     return ttype;
 }
@@ -866,7 +917,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
 
     ggml_type type = str_to_ggml_type(dtype);
     if (type == GGML_TYPE_COUNT) {
-        LOG_ERROR("unsupported dtype '%s'", dtype.c_str());
+        LOG_ERROR("unsupported dtype '%s' (tensor '%s')", dtype.c_str(), name.c_str());
         return false;
     }
 
@@ -903,6 +954,10 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
     if (dtype == "BF16") {
         tensor_storage.is_bf16 = true;
         GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
+    } else if (dtype == "F8_E4M3") {
+        tensor_storage.is_f8_e4m3 = true;
+        // f8 -> f16: nbytes() counts f16 storage, the file stores one byte per element
+        GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
     } else {
         GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size);
     }
@@ -1537,6 +1592,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 if (tensor_storage.is_bf16) {
                     // inplace op
                     bf16_to_f32_vec((uint16_t *)dst_tensor->data, (float *)dst_tensor->data, tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e4m3) {
+                    // inplace op
+                    f8_e4m3_to_f16_vec((uint8_t *)dst_tensor->data, (uint16_t *)dst_tensor->data, tensor_storage.nelements());
                 }
             } else {
                 read_buffer.resize(tensor_storage.nbytes());
@@ -1545,6 +1603,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 if (tensor_storage.is_bf16) {
                     // inplace op
                     bf16_to_f32_vec((uint16_t *)read_buffer.data(), (float *)read_buffer.data(), tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e4m3) {
+                    // inplace op
+                    f8_e4m3_to_f16_vec((uint8_t *)read_buffer.data(), (uint16_t *)read_buffer.data(), tensor_storage.nelements());
                 }
 
                 convert_tensor((void *)read_buffer.data(), tensor_storage.type, dst_tensor->data,
@@ -1557,6 +1618,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 if (tensor_storage.is_bf16) {
                     // inplace op
                     bf16_to_f32_vec((uint16_t *)read_buffer.data(), (float *)read_buffer.data(), tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e4m3) {
+                    // inplace op
+                    f8_e4m3_to_f16_vec((uint8_t *)read_buffer.data(), (uint16_t *)read_buffer.data(), tensor_storage.nelements());
                 }
 
                 if (tensor_storage.type == dst_tensor->type) {
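
A note on the "inplace op" pattern used above: in every case the source element is narrower than its destination (1-byte f8 to 2-byte f16, 2-byte bf16 to 4-byte f32) and the buffer is already sized for the destination type, so iterating from the last element down guarantees each source element is read before a widening write can overwrite it. A minimal sketch of the idea (illustrative only, not from the patch):

// Back-to-front in-place widening, the same trick as f8_e4m3_to_f16_vec.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int64_t n = 4;
    // buffer sized for the wide output; the narrow input occupies the front
    std::vector<uint16_t> buf(n, 0);
    uint8_t *src = reinterpret_cast<uint8_t *>(buf.data());
    for (int64_t i = 0; i < n; i++) {
        src[i] = static_cast<uint8_t>(i + 1);  // stand-in f8 payload: 1 2 3 4
    }
    uint16_t *dst = buf.data();
    for (int64_t i = n - 1; i >= 0; i--) {
        // dst[i] occupies bytes 2i..2i+1, which only overlap source
        // elements at index >= i, all of which were already consumed
        dst[i] = src[i];
    }
    for (int64_t i = 0; i < n; i++) {
        std::printf("%u ", dst[i]);  // prints: 1 2 3 4
    }
    std::printf("\n");
}

A forward loop would fail immediately: the very first wide write to dst[0] would clobber src[1] before it is read.
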