Cautious Next Token Prediction
Yizhou Wang, Lingzhi Zhang, Yue Bai, Mang Tik Chiu, Zhengmian Hu, Mingyuan Zhang, Qihua Dong, Yu Yin, Sohrab Amirghodsi, Yun Fu
Findings of ACL 2025
Biological intelligence performs better when cautious. Does digital intelligence also benefit from cautiousness?
Ilya Sutskever has suggested that next-token prediction can mimic how the human brain functions in intelligence, and can lead to AGI:
If that is really the case (personally, I believe it is), the next-token prediction paradigm should exhibit phenomena similar to those of human brains.
When humans think about a problem and feel uncertain in the process, they become more cautious and explore multiple possible solution paths before choosing the one they are most confident in. This is very common in everyday human life. In exam settings, for instance, individuals often re-check key steps and sample alternative reasoning paths when unsure, only finalizing the path that best resonates with previously established facts. Such careful practice usually results in better human performance.
By analogy, we hypothesize that LLMs might also benefit from conditionally branching into multiple future continuations whenever they sense high uncertainty. Our core insight is to compute an entropy measure that indicates how “unsure” the model is, and to trigger more thorough exploration exactly at those points. Once the possible continuations are sampled, the model’s own likelihood function can judge the best candidate to proceed with, mirroring how humans choose the strongest proof line. In this work, we are thrilled to confirm empirically that this practice benefits LLMs as well. The answer to the question is therefore yes: digital intelligence can also benefit from cautiousness!
We propose Cautious Next Token Prediction (CNTP), a novel training-free decoding strategy that adaptively samples multiple continuations based on model confidence. When the model exhibits high prediction entropy at a given step, CNTP samples multiple trials independently, stops at punctuation, and selects the trial with the lowest perplexity. This approach mimics human behavior: when uncertain, we explore multiple thinking paths before choosing the most confident one.
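Below is a minimal, self-contained sketch of the two rules at the heart of CNTP, the entropy-to-trial-count mapping and the perplexity-based selection (helper names here are illustrative; the default thresholds follow the full implementation later in this README):

```python
import math

def num_trials(entropy, h_low=0.01, h_high=1.5, n_min=1, n_max=10):
    # The less confident the model (the higher the entropy), the more trials it samples.
    t = min(max((entropy - h_low) / (h_high - h_low), 0.0), 1.0)
    return int(n_min + (n_max - n_min) * t)

def perplexity(token_probs):
    # exp of the mean negative log-likelihood of a sampled continuation.
    return math.exp(-sum(math.log(p) for p in token_probs) / len(token_probs))

# Toy example: three sampled continuations, given as per-token probabilities.
trials = [[0.9, 0.8], [0.4, 0.3, 0.5], [0.7, 0.6, 0.9]]
best = min(trials, key=perplexity)  # CNTP proceeds with the lowest-perplexity trial
print(num_trials(2.0), best)        # high entropy -> 10 trials; best = [0.9, 0.8]
```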
Algorithm 1 Cautious Next Token Prediction (CNTP)
- 🎯 Adaptive Sampling: Dynamically adjusts trial numbers based on model confidence
- 🚀 Training-Free: No additional training or fine-tuning required
- ⚡ Efficient: Focuses computational resources only where the model is uncertain
- 🎨 Balanced: Maintains stochasticity, coherence, and creativity
- 🔧 Compatible: Can be integrated with existing methods like self-consistency
| Decoding Method | Stochasticity | Coherence | Creativity | Efficiency |
|---|---|---|---|---|
| Stochastic Decoding | ✅ | ❌ | ✅ | ✅ |
| Greedy Decoding | ❌ | ✅ | ❌ | ✅ |
| Beam Search | ❌ | ✅ | ❌ | ❌ |
| CNTP (Ours) | ✅ | ✅ | ✅ | ✅ |
- Add all the code for the baseline methods
```bash
git clone https://github.com/wyzjack/CNTP.git
cd CNTP
```
For LLM experiments on GSM8K and StrategyQA, please install the dependencies in the `envs/gsm8k_strategyqa.yml` file.
```bash
conda env create -f envs/gsm8k_strategyqa.yml
conda activate gsm8k_strategyqa
cd custom_transformers_packages/gsm8k_strategyqa
pip install -e .
```
For LLM experiments on MATH and TruthfulQA, please install the dependencies in the `envs/math_truthfulqa.yml` file.
```bash
conda env create -f envs/math_truthfulqa.yml
conda activate math_truthfulqa
cd custom_transformers_packages/math_truthfulqa
pip install -e .
```
For VLM experiments on MMVet and MathVista, please install the dependencies in the `envs/mmvet_mathvista.yml` file.
```bash
conda env create -f envs/mmvet_mathvista.yml
conda activate mmvet_mathvista
cd custom_transformers_packages/mmvet_mathvista
pip install -e .
```
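After creating an environment, a quick sanity check can confirm that the editable custom `transformers` build is the one being imported (a minimal check; the exact version string depends on the environment, and the code snippet later in this README targets `transformers==4.47.1`):

```python
import transformers

print(transformers.__version__)  # e.g. 4.47.1, depending on the environment
print(transformers.__file__)     # should resolve inside custom_transformers_packages/
```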
Our experiments were conducted on 8 A100 GPUs with 80GB memory. The following instructions are for reproducing all the main results of our paper.
```bash
cd experiments/SelfEval-Guided-Decoding
```

Follow the instructions in `SelfEval-Guided-Decoding/README.md` to run the experiments.
```bash
cd experiments/open-instruct
```

Follow the instructions in `open-instruct/README.md` to run the experiments.
```bash
cd experiments/VLMEvalKit
```

Follow the instructions in `VLMEvalKit/README.md` to run the experiments.
The core implementation of CNTP modifies the `_sample` function of the `GenerationMixin` class in `transformers/generation/utils.py` in the `transformers` library. Below is the modified function supporting CNTP for `transformers==4.47.1`.

```python
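# If this snippet is read standalone, the following imports are required; inside
# transformers/generation/utils.py they are already in scope.
import copy
import gc
from typing import Optional, Union

import numpy as np
import torch
from torch import nn

from transformers.cache_utils import DynamicCache, EncoderDecoderCache, StaticCache
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.generation.streamers import BaseStreamer
from transformers.generation.utils import (
    GenerateDecoderOnlyOutput,
    GenerateEncoderDecoderOutput,
    GenerateNonBeamOutput,
)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
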
def _sample_reflect_perplexity(
self,
input_ids: torch.LongTensor,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
synced_gpus: bool,
streamer: Optional["BaseStreamer"],
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
entropy_threshold_low: Optional[float] = None,
entropy_threshold_high: Optional[float] = None,
max_trials: Optional[int] = None,
**model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
pad_token_id = generation_config._pad_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate
max_length = generation_config.max_length
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
do_sample = generation_config.do_sample
scores = () if (return_dict_in_generate and output_scores) else None
raw_logits = () if (return_dict_in_generate and output_logits) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
batch_size, cur_len = input_ids.shape
this_peer_finished = False
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
# Define punctuation tokens as a list
punctuation_tokens = [tokenizer.encode(".", add_special_tokens=False)[0],
tokenizer.encode(",", add_special_tokens=False)[0],
tokenizer.encode("\n", add_special_tokens=False)[0],
tokenizer.encode("?", add_special_tokens=False)[0],
tokenizer.encode("!", add_special_tokens=False)[0],
tokenizer.encode(":", add_special_tokens=False)[0],
tokenizer.encode(";", add_special_tokens=False)[0],
tokenizer.encode(")", add_special_tokens=False)[0],
tokenizer.encode("]", add_special_tokens=False)[0],
tokenizer.encode("}", add_special_tokens=False)[0]]
max_trials = max_trials if max_trials is not None else 10
min_trials = 1
entropy_threshold_low = entropy_threshold_low if entropy_threshold_low is not None else 0.01
entropy_threshold_high = entropy_threshold_high if entropy_threshold_high is not None else 1.5
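    # Shannon entropy of the next-token distribution: higher entropy means the
    # model is less certain about the next token.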
def calculate_entropy(probs):
eps = 1e-10
entropy = -torch.sum(probs * torch.log(probs + eps), dim=-1)
return entropy
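    # Map entropy linearly from [entropy_threshold_low, entropy_threshold_high]
    # onto [min_trials, max_trials]: the less confident the model, the more trials.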
def determine_trial_number(entropy):
normalized_entropy = torch.clamp((entropy - entropy_threshold_low) / (entropy_threshold_high - entropy_threshold_low), 0, 1)
num_trials = min_trials + (max_trials - min_trials) * normalized_entropy
return int(num_trials.item())
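    # Deep-copy the KV cache so each trial can branch from the same prefix
    # without mutating the cache shared across trials.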
def clone_cache(past_key_values):
if isinstance(past_key_values, tuple):
return tuple(tuple(x.clone() for x in layer) for layer in past_key_values)
elif isinstance(past_key_values, DynamicCache):
new_cache = DynamicCache()
for layer_idx in range(len(past_key_values.key_cache)):
new_cache.key_cache.append(past_key_values.key_cache[layer_idx].clone())
new_cache.value_cache.append(past_key_values.value_cache[layer_idx].clone())
return new_cache
elif isinstance(past_key_values, EncoderDecoderCache):
return EncoderDecoderCache(
clone_cache(past_key_values.self_attention_cache),
clone_cache(past_key_values.cross_attention_cache)
)
elif isinstance(past_key_values, StaticCache):
return past_key_values.clone()
return past_key_values
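    # Main decoding loop: measure next-token entropy at every step and branch
    # into multiple sampled trials whenever the model is uncertain.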
while self._has_unfinished_sequences(
this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
):
original_cache = None
if "past_key_values" in model_kwargs:
original_cache = clone_cache(model_kwargs["past_key_values"])
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
outputs = self(**model_inputs, return_dict=True)
if synced_gpus and this_peer_finished:
continue
next_token_logits = outputs.logits[:, -1, :].clone().float()
next_token_logits = next_token_logits.to(input_ids.device)
next_token_scores = logits_processor(input_ids, next_token_logits)
probs = nn.functional.softmax(next_token_scores, dim=-1)
# Calculate entropy and determine number of trials
entropy = calculate_entropy(probs)
num_trials = determine_trial_number(entropy)
if num_trials > 1:
print("entropy: {}, num_trials: {}".format(entropy, num_trials))
# Generate multiple trials
trial_sequences = []
trial_scores = []
trial_outputs = []
trial_model_kwargs = []
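            # Roll out num_trials independent continuations from the current
            # prefix (batch size 1 is assumed), each stopping at the first
            # punctuation token or when the stopping criteria fire.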
for trial_idx in range(num_trials):
# print(trial_idx)
trial_tokens = []
trial_probs = [] # Store individual token probabilities instead of joint prob
trial_step_scores = [] # Store scores for each step
trial_step_logits = [] # Store logits for each step
curr_input_ids = input_ids.clone()
# Clone model kwargs and cache for this trial
curr_model_kwargs = copy.deepcopy(model_kwargs)
if original_cache is not None:
curr_model_kwargs["past_key_values"] = clone_cache(original_cache)
while True:
curr_inputs = self.prepare_inputs_for_generation(
curr_input_ids,
**curr_model_kwargs
)
curr_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
curr_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
curr_outputs = self(**curr_inputs, return_dict=True)
curr_logits = curr_outputs.logits[:, -1, :].clone().float()
curr_logits = curr_logits.to(curr_input_ids.device)
curr_scores = logits_processor(curr_input_ids, curr_logits)
curr_probs = nn.functional.softmax(curr_scores, dim=-1)
# Store scores and logits for each step
if output_scores:
trial_step_scores.append(curr_scores)
if output_logits:
trial_step_logits.append(curr_logits)
# Sample next token
curr_token = torch.multinomial(curr_probs, num_samples=1).squeeze(1)
trial_tokens.append(curr_token)
trial_probs.append(curr_probs[0, curr_token.item()].item())
curr_input_ids = torch.cat([curr_input_ids, curr_token.unsqueeze(1)], dim=-1)
curr_model_kwargs = self._update_model_kwargs_for_generation(
curr_outputs, curr_model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
# Check stopping criteria for the current trial's input_ids
# Convert to tuple for stopping criteria compatibility
accumulated_scores = tuple(trial_step_scores) if output_scores else None
trial_should_stop = stopping_criteria(curr_input_ids, accumulated_scores).any()
if trial_should_stop or (curr_token.item() in punctuation_tokens):
break
trial_sequences.append(torch.stack(trial_tokens))
trial_scores.append(trial_probs) # Store list of probabilities
trial_outputs.append((curr_outputs, trial_step_scores, trial_step_logits)) # Store outputs along with step scores/logits
trial_model_kwargs.append(curr_model_kwargs)
                # Free per-trial tensors before starting the next trial
del curr_outputs
del curr_logits
del curr_scores
del curr_probs
del curr_input_ids
del curr_model_kwargs
del curr_inputs
torch.cuda.empty_cache()
gc.collect()
            # Score each sampled trial by the perplexity of its full continuation
perplexities = []
for probs in trial_scores:
# Calculate negative log likelihood
nll = -sum(np.log(p) for p in probs)
# Calculate perplexity: exp(average negative log likelihood)
sequence_length = len(probs)
perplexity = np.exp(nll / sequence_length)
perplexities.append(perplexity)
# Select best trial based on lowest perplexity
best_trial_idx = min(range(len(perplexities)), key=lambda i: perplexities[i])
# Get the full sequence for the best trial (no truncation needed)
next_tokens = trial_sequences[best_trial_idx].squeeze(1).unsqueeze(0)
best_outputs, best_step_scores, best_step_logits = trial_outputs[best_trial_idx]
# Store scores, attentions and hidden_states for the best trial
if return_dict_in_generate:
if output_scores:
for token_score in best_step_scores:
scores += (token_score,)
if output_logits:
for token_logit in best_step_logits:
raw_logits += (token_logit,)
# finished sentences should have their next token be a padding token
if has_eos_stopping_criteria:
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = trial_model_kwargs[best_trial_idx]
# clear trial data
trial_sequences.clear()
trial_scores.clear()
trial_outputs.clear()
trial_model_kwargs.clear()
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
this_peer_finished = unfinished_sequences.max() == 0
cur_len += next_tokens.shape[-1]
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
del outputs
del best_outputs
del best_step_scores
del best_step_logits
else:
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (next_token_scores,)
if output_logits:
raw_logits += (next_token_logits,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states:
decoder_hidden_states += (
(outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
)
# token selection
if do_sample:
probs = nn.functional.softmax(next_token_scores, dim=-1)
# TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(next_token_scores, dim=-1)
# finished sentences should have their next token be a padding token
if has_eos_stopping_criteria:
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
# print(input_ids.shape)
if streamer is not None:
# print("streamer is not None")
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
this_peer_finished = unfinished_sequences.max() == 0
cur_len += 1
del outputs
del original_cache
if streamer is not None:
streamer.end()
# Before returning, truncate everything to max_length if needed
max_length = stopping_criteria.max_length
if input_ids.shape[1] > max_length:
input_ids = input_ids[:, :max_length]
# clear unused variables
if 'trial_sequences' in locals():
del trial_sequences
if 'trial_scores' in locals():
del trial_scores
if 'trial_outputs' in locals():
del trial_outputs
if 'trial_model_kwargs' in locals():
del trial_model_kwargs
if 'normalized_scores' in locals():
del normalized_scores
if 'best_trial_idx' in locals():
del best_trial_idx
if 'min_length' in locals():
del min_length
if 'trial_tokens' in locals():
del trial_tokens
if 'trial_probs' in locals():
del trial_probs
if 'trial_step_scores' in locals():
del trial_step_scores
if 'trial_step_logits' in locals():
del trial_step_logits
torch.cuda.empty_cache()
# force garbage collection
gc.collect()
if return_dict_in_generate:
if self.config.is_encoder_decoder:
return GenerateEncoderDecoderOutput(
sequences=input_ids,
scores=scores,
logits=raw_logits,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return GenerateDecoderOnlyOutput(
sequences=input_ids,
scores=scores,
logits=raw_logits,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
        return input_ids
```
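Below is a sketch of how one might route `model.generate` through the function above. The monkey-patch wiring is an assumption for illustration (the released `custom_transformers_packages` handle the dispatch internally), and the model name is only an example; the threshold values mirror the defaults in the signature above:

```python
# Hypothetical wiring: override GenerationMixin._sample with the CNTP sampler.
# The released custom transformers packages may dispatch differently.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationMixin

model_name = "meta-llama/Llama-3.1-8B-Instruct"  # example model, not prescribed by the paper
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto"
)

def _cntp_sample(self, *args, **kwargs):
    # Forward every _sample call to the CNTP variant with the defaults above.
    return _sample_reflect_perplexity(
        self, *args, tokenizer=tokenizer,
        entropy_threshold_low=0.01, entropy_threshold_high=1.5, max_trials=10,
        **kwargs,
    )

GenerationMixin._sample = _cntp_sample  # affects all subsequent generate() calls

inputs = tokenizer("Q: What is 17 * 23?\nA:", return_tensors="pt").to(model.device)
out = model.generate(**inputs, do_sample=True, temperature=0.7, max_new_tokens=256)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```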
After this paper was released on arXiv, we were notified of the concurrent work Entropix, which we were not aware of when writing the paper. Entropix also leverages the entropy of the model's output logits to decide the LLM sampling strategy, choosing among four actions: 1) insert a CoT or pause token, 2) resample, 3) argmax, 4) branch. It is a general and elegant approach for letting LLMs simulate o1-like effects. The idea of using model confidence to change the sampling strategy is similar to ours. However, our CNTP differs in two respects: first, we propose stopping at punctuation marks, enabling multiple locally optimal branching and sampling points within each answer generation; second, we design an explicit negative correlation between the number of sampled trials and the model's confidence, achieving superiority over the baseline decoding approaches. We leave a direct comparison of CNTP and Entropix to future work.
If you find CNTP useful, please cite our paper:
@inproceedings{wang-etal-2025-cautious,
title = "Cautious Next Token Prediction",
author = "Wang, Yizhou and
Zhang, Lingzhi and
Bai, Yue and
Chiu, Mang Tik and
Hu, Zhengmian and
Zhang, Mingyuan and
Dong, Qihua and
Yin, Yu and
Amirghodsi, Sohrab and
Fu, Yun",
editor = "Che, Wanxiang and
Nabende, Joyce and
Shutova, Ekaterina and
Pilehvar, Mohammad Taher",
booktitle = "Findings of the Association for Computational Linguistics: ACL 2025",
month = jul,
year = "2025",
address = "Vienna, Austria",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2025.findings-acl.1318/",
pages = "25685--25697",
ISBN = "979-8-89176-256-5",
abstract = "Next token prediction paradigm has been prevailing for autoregressive models in the era of LLMs. The current default sampling choice for popular LLMs is temperature scaling together with nucleus sampling to balance diversity and coherence. Nevertheless, such approach leads to inferior performance in various NLP tasks when the model is not certain about testing questions. To this end, we propose a brand new training-free decoding strategy, dubbed as Cautious Next Token Prediction (CNTP). In the decoding process, if the model has comparatively high prediction entropy at a certain step, we sample multiple trials starting from the step independently and stop when encountering any punctuation. Then we select the trial with the lowest perplexity score viewed as the most probable and reliable trial path given the model{'}s capacity. The trial number is negatively correlated with the prediction confidence, i.e., the less confident the model is, the more trials it should sample. This is consistent with human beings' behaviour: when feeling uncertain or unconfident, one tends to think more creatively, exploring multiple thinking paths, to cautiously select the path one feels most confident about. Extensive experiments on both LLMs and MLLMs show that our proposed CNTP approach outperforms existing standard decoding strategies consistently by a clear margin. Moreover, the integration of CNTP with self consistency can further improve over vanilla self consistency. We believe our proposed CNTP has the potential to become one of the default choices for LLM decoding. Code is available at https://github.com/wyzjack/CNTP."
}
This project is licensed under the MIT License - see the LICENSE file for details.
Our code relies heavily on the public GitHub repos open-instruct, SelfEval-Guided-Decoding, and VLMEvalKit. Many thanks to them! We also thank the anonymous ACL'25 reviewers for their constructive comments, which helped us improve the paper.
- Yizhou Wang: wyzjack990122@gmail.com