@@ -646,36 +646,16 @@ class CreateCompletionRequest(BaseModel):
646646 }
647647
648648
649- def make_logit_bias_processor (
649+ def _logit_bias_tokens_to_input_ids (
650650 llama : llama_cpp .Llama ,
651651 logit_bias : Dict [str , float ],
652- logit_bias_type : Optional [Literal ["input_ids" , "tokens" ]],
653- ):
654- if logit_bias_type is None :
655- logit_bias_type = "input_ids"
656-
657- to_bias : Dict [int , float ] = {}
658- if logit_bias_type == "input_ids" :
659- for input_id , score in logit_bias .items ():
660- input_id = int (input_id )
661- to_bias [input_id ] = score
662-
663- elif logit_bias_type == "tokens" :
664- for token , score in logit_bias .items ():
665- token = token .encode ("utf-8" )
666- for input_id in llama .tokenize (token , add_bos = False , special = True ):
667- to_bias [input_id ] = score
668-
669- def logit_bias_processor (
670- input_ids : npt .NDArray [np .intc ],
671- scores : npt .NDArray [np .single ],
672- ) -> npt .NDArray [np .single ]:
673- new_scores = np .copy (scores ) # Does it make sense to copy the whole array or can we just overwrite the original one?
674- for input_id , score in to_bias .items ():
675- new_scores [input_id ] = score + scores [input_id ]
676- return new_scores
677-
678- return logit_bias_processor
652+ ) -> Dict [str , float ]:
653+ to_bias : Dict [str , float ] = {}
654+ for token , score in logit_bias .items ():
655+ token = token .encode ("utf-8" )
656+ for input_id in llama .tokenize (token , add_bos = False , special = True ):
657+ to_bias [str (input_id )] = score
658+ return to_bias
679659
680660
681661@router .post (
@@ -694,17 +674,16 @@ async def create_completion(
694674 exclude = {
695675 "n" ,
696676 "best_of" ,
697- "logit_bias" ,
698677 "logit_bias_type" ,
699678 "user" ,
700679 }
701680 kwargs = body .model_dump (exclude = exclude )
702681
703682 if body .logit_bias is not None :
704- kwargs ["logits_processor " ] = llama_cpp . LogitsProcessorList (
705- [
706- make_logit_bias_processor ( llama , body .logit_bias , body . logit_bias_type ),
707- ]
683+ kwargs ["logit_bias " ] = (
684+ _logit_bias_tokens_to_input_ids ( llama , body . logit_bias )
685+ if body .logit_bias_type == "tokens"
686+ else body . logit_bias
708687 )
709688
710689 if body .grammar is not None :
@@ -851,17 +830,16 @@ async def create_chat_completion(
851830) -> llama_cpp .ChatCompletion :
852831 exclude = {
853832 "n" ,
854- "logit_bias" ,
855833 "logit_bias_type" ,
856834 "user" ,
857835 }
858836 kwargs = body .model_dump (exclude = exclude )
859837
860838 if body .logit_bias is not None :
861- kwargs ["logits_processor " ] = llama_cpp . LogitsProcessorList (
862- [
863- make_logit_bias_processor ( llama , body .logit_bias , body . logit_bias_type ),
864- ]
839+ kwargs ["logit_bias " ] = (
840+ _logit_bias_tokens_to_input_ids ( llama , body . logit_bias )
841+ if body .logit_bias_type == "tokens"
842+ else body . logit_bias
865843 )
866844
867845 if body .grammar is not None :
0 commit comments