feat: add MinTokensLogitProcessor and min_tokens argument to server (… · githubbadguy/llama-cpp-python@5212fb0 · GitHub
Skip to content

Commit 5212fb0

Browse files
authored
feat: add MinTokensLogitProcessor and min_tokens argument to server (abetlen#1333)
* implement min_tokens * set default to 0 * pass min_tokens * fix * remove copy * implement MinTokensLogitsProcessor * format * fix condition
1 parent 389e09c commit 5212fb0

3 files changed

Lines changed: 44 additions & 0 deletions

File tree

llama_cpp/llama.py

Lines changed: 16 additions & 0 deletions

llama_cpp/server/app.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,7 @@ async def create_completion(
275275
"best_of",
276276
"logit_bias_type",
277277
"user",
278+
"min_tokens",
278279
}
279280
kwargs = body.model_dump(exclude=exclude)
280281

@@ -288,6 +289,15 @@ async def create_completion(
288289
if body.grammar is not None:
289290
kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
290291

292+
if body.min_tokens > 0:
293+
_min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
294+
[llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
295+
)
296+
if "logits_processor" not in kwargs:
297+
kwargs["logits_processor"] = _min_tokens_logits_processor
298+
else:
299+
kwargs["logits_processor"].extend(_min_tokens_logits_processor)
300+
291301
iterator_or_completion: Union[
292302
llama_cpp.CreateCompletionResponse,
293303
Iterator[llama_cpp.CreateCompletionStreamResponse],
@@ -445,6 +455,7 @@ async def create_chat_completion(
445455
"n",
446456
"logit_bias_type",
447457
"user",
458+
"min_tokens",
448459
}
449460
kwargs = body.model_dump(exclude=exclude)
450461
llama = llama_proxy(body.model)
@@ -458,6 +469,15 @@ async def create_chat_completion(
458469
if body.grammar is not None:
459470
kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
460471

472+
if body.min_tokens > 0:
473+
_min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
474+
[llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
475+
)
476+
if "logits_processor" not in kwargs:
477+
kwargs["logits_processor"] = _min_tokens_logits_processor
478+
else:
479+
kwargs["logits_processor"].extend(_min_tokens_logits_processor)
480+
461481
iterator_or_completion: Union[
462482
llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
463483
] = await run_in_threadpool(llama.create_chat_completion, **kwargs)

llama_cpp/server/types.py

Lines changed: 8 additions & 0 deletions

0 commit comments

Comments
 (0)