Factory function to create a BestOfN wrapper around any module. Runs the module up to N times and returns the best result according to a reward function, stopping early once a result's score reaches threshold.
Arguments
- module
A Module object to wrap
- N
Maximum number of attempts (default 3)
- reward_fn
Reward function with signature function(prediction, inputs), returning a score between 0 and 1. Use as_reward_fn() to convert a metric. If NULL, uses a default that returns 1.0 for all valid predictions.
- threshold
Score threshold for early stopping (default 1.0)
- fail_count
Maximum consecutive failures before erroring (default N)
- ...
Additional arguments passed to the module constructor
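How N, reward_fn, threshold, and fail_count interact can be pictured with a minimal plain-R sketch. This is not the package's implementation: best_of_n_sketch and run_once are hypothetical names, and the sketch assumes that a failed attempt is one that raises an error, that a success resets the consecutive-failure counter, and that the wrapper errors once the counter reaches fail_count.

```r
# Hypothetical sketch of best-of-N selection; not the package's code.
# `run_once` stands in for a single module call and returns a prediction.
best_of_n_sketch <- function(run_once, inputs, N = 3,
                             reward_fn = function(prediction, inputs) 1.0,
                             threshold = 1.0, fail_count = N) {
  best <- NULL
  best_score <- -Inf
  failures <- 0
  attempts <- 0
  for (i in seq_len(N)) {
    attempts <- i
    pred <- tryCatch(run_once(inputs), error = function(e) NULL)
    if (is.null(pred)) {
      # Assumption: an attempt that raises an error counts as a failure
      failures <- failures + 1
      if (failures >= fail_count) stop("too many consecutive failures")
      next
    }
    failures <- 0
    score <- reward_fn(pred, inputs)
    # Keep the highest-scoring prediction seen so far
    if (score > best_score) {
      best <- pred
      best_score <- score
    }
    # Early stopping once the best score reaches the threshold
    if (best_score >= threshold) break
  }
  list(prediction = best, score = best_score, attempts = attempts)
}

# Deterministic stand-in for a module: returns canned answers in order
answers <- c("New York City", "Paris", "Lyon")
counter <- 0
run_once <- function(inputs) {
  counter <<- counter + 1
  list(answer = answers[counter])
}

# Reward 1 for a one-word answer, 0 otherwise
one_word <- function(pred, inputs) {
  if (lengths(strsplit(pred$answer, "\\s+")) == 1) 1 else 0
}

result <- best_of_n_sketch(run_once, list(question = "Capital of France?"),
                           N = 3, reward_fn = one_word)
```

In this run the first answer scores 0 and the second scores 1, so the loop stops after two attempts and the third answer is never requested. Lowering threshold trades answer quality for fewer calls; raising N does the opposite.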
Examples
# Create a basic QA module
qa <- module(signature("question -> answer"))
# Wrap it with best-of-3 selection
wrapper <- best_of_n(qa, N = 3)
# With a custom reward function
one_word_reward <- function(pred, inputs) {
  # Trim first so leading whitespace does not produce an empty first "word"
  words <- strsplit(trimws(as.character(pred$answer)), "\\s+")[[1]]
  if (length(words) == 1) 1.0 else 0.0
}
wrapper <- best_of_n(qa, N = 5, reward_fn = one_word_reward)
# Using a metric-based reward function
wrapper <- best_of_n(
  qa,
  N = 3,
  reward_fn = as_reward_fn(metric_exact_match(field = "answer")),
  threshold = 1.0
)
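Because a reward function is an ordinary R function, it can be checked on hand-built inputs before any model is called. The sketch below assumes a prediction exposes its fields with `$`, as the pred$answer access above suggests, and uses plain lists as stand-ins. brevity_reward is a hypothetical name; it illustrates a graded score in the 0 to 1 range rather than an all-or-nothing one.

```r
# Hypothetical graded reward: 1 for a one-word answer, less for longer ones
brevity_reward <- function(prediction, inputs) {
  words <- strsplit(trimws(as.character(prediction$answer)), "\\s+")[[1]]
  # An empty answer earns no reward
  if (length(words) == 0) return(0)
  1 / length(words)
}

# Plain lists stand in for predictions (assumption: fields are read with `$`)
candidates <- list(
  list(answer = "Paris"),
  list(answer = "New York City")
)
scores <- sapply(candidates, brevity_reward, inputs = list())
print(scores)
```

With a graded reward like this, a threshold below 1.0 lets the wrapper stop early on an answer that is short enough rather than insisting on a single word.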
