127 lines
3.4 KiB
Lua
127 lines
3.4 KiB
Lua
local serde = require("@lune/serde")
|
|
local ollama = require("@ollama")
|
|
|
|
--- One answer produced by a model for one question, waiting to be graded.
export type AnswerRecord = {
	-- Identifying fields; evaluateAll also reports these to its progress callback.
	model: string;
	category: string;
	index: number;
	-- Text sent to the grader: the question, the answer under test, and the expected answer.
	question: string;
	candidate: string;
	reference: string;
	-- Maximum points for this question; the grader's score is clamped to [0, pointReward].
	pointReward: number;
}
|
|
|
|
--- Verdict produced by the grader for a single answer.
export type EvalJSON = {
	score: number; -- points awarded, clamped to [0, maxPoints] by evalAnswer
	rationale: string; -- short justification text from the grader
	correct: boolean?; -- grader's judgement of whether the candidate matches the reference
}
|
|
|
|
--- An AnswerRecord extended with the grader's score and rationale.
export type ScoredRecord = AnswerRecord & {
	score: number; -- points awarded by the grader
	rationale: string; -- grader's justification text
}
|
|
|
|
--- Restricts `n` to the closed range [lo, hi].
--- The lower bound is checked first, so if lo > hi any n below lo yields lo.
--- A NaN input is returned unchanged (both comparisons are false).
local function clamp(n: number, lo: number, hi: number): number
	local result = n
	if result < lo then
		result = lo
	elseif result > hi then
		result = hi
	end
	return result
end
|
|
|
|
--- Grades one candidate answer against a reference answer using an Ollama model.
---
--- The grader is instructed to reply with a JSON object of the form
--- `{"score": number, "rationale": string, "correct": boolean}`.
---
--- Failures are reported in the returned value rather than raised:
--- * if the response carries a `statusCode`, or
--- * if the reply is not a JSON object with a numeric, non-NaN `score`,
--- the result is `{ score = 0, correct = false }` with an explanatory rationale.
---
--- @param evaluatorModel Name of the Ollama model used as the grader.
--- @param question The question that was asked.
--- @param reference The expected (reference) answer.
--- @param candidate The answer being graded.
--- @param maxPoints Upper bound of the score; the grader's score is floored and clamped to [0, maxPoints].
--- @return EvalJSON with the bounded score, a rationale string, and a strict boolean `correct`.
local function evalAnswer(evaluatorModel: string, question: string, reference: string, candidate: string, maxPoints: number): EvalJSON
	-- NOTE(review): a client is obtained on every call; confirm with the @ollama
	-- module that `serve()` is cheap/idempotent enough to call once per answer.
	local client = ollama.serve()

	local system = table.concat({
		"You are a strict grader.",
		"Return ONLY valid JSON complying with the given schema.",
		"No prose, no markdown, no code fences.",
	}, " ")

	local schema = [[{"score": number, "rationale": string, "correct": boolean}]]

	-- NOTE(review): question/reference/candidate are inserted verbatim between
	-- triple-quote fences. A candidate containing `"""` can close its section and
	-- address the grader directly (prompt injection); consider sanitising if the
	-- answers come from untrusted models.
	local instructions = string.format([[You will grade a candidate answer.
Constraints:
- Award an integer score from 0 to %d.
- Keep rationale 1-2 short sentences.
- Set correct=true if the candidate meaningfully matches the reference answer, else false.
Output:
- Return ONLY a single JSON object matching this schema: %s

Question:
"""
%s
"""

Reference Answer:
"""
%s
"""

Candidate Answer:
"""
%s
"""
]], maxPoints, schema, question, reference, candidate)

	local r = client:generateCompletion({
		model = evaluatorModel,
		prompt = instructions,
		system = system,
		format = "json",
		keep_alive = "5m",
		options = {
			num_ctx = 32000,
		},
	})

	-- NOTE(review): assumes the @ollama client only sets `statusCode` on a failed
	-- request — confirm against that module.
	if r.statusCode then
		return { score = 0, rationale = "evaluator request failed", correct = false }
	end

	-- Decoding may raise on malformed JSON (or a nil response); treat that as bad output.
	local ok, obj = pcall(serde.decode, "json", r.response)

	-- `obj.score ~= obj.score` rejects NaN, which would otherwise pass through
	-- math.floor and clamp unchanged.
	if not ok or type(obj) ~= "table" or type(obj.score) ~= "number" or obj.score ~= obj.score then
		return { score = 0, rationale = "invalid or non-JSON evaluator output", correct = false }
	end

	local bounded = clamp(math.floor(obj.score), 0, maxPoints)
	local rationale = tostring(obj.rationale or "")

	-- Require a real JSON `true`: a plain truthiness test would count values such
	-- as the string "false" or the number 0 as correct.
	return { score = bounded, rationale = rationale, correct = obj.correct == true }
end
|
|
|
|
--- Grades a single answer record and returns a new ScoredRecord.
---
--- Every AnswerRecord field is copied into a fresh table (the input record is not
--- modified), and the grader's score and rationale are added. The grader's
--- `correct` flag is not carried over, because ScoredRecord has no field for it.
---
--- @param evaluatorModel Name of the Ollama model used as the grader.
--- @param ar The answer to grade; `ar.pointReward` is used as the maximum score.
--- @return The scored copy of `ar`.
local function evaluateOne(evaluatorModel: string, ar: AnswerRecord): ScoredRecord
	local verdict = evalAnswer(evaluatorModel, ar.question, ar.reference, ar.candidate, ar.pointReward)

	local scored: ScoredRecord = {
		-- identifying fields
		model = ar.model,
		category = ar.category,
		index = ar.index,
		-- graded content
		question = ar.question,
		reference = ar.reference,
		candidate = ar.candidate,
		pointReward = ar.pointReward,
		-- grader verdict
		score = verdict.score,
		rationale = verdict.rationale,
	}

	return scored
end
|
|
|
|
--- Grades every answer in `answers`, one after another.
---
--- @param evaluatorModel Name of the Ollama model used as the grader.
--- @param answers The records to grade, visited in the table's iteration order.
--- @param onProgress Optional callback, invoked after each record is graded. It
---   receives the number of records graded so far, `#answers`, and the record's
---   model/category/index.
--- @return The scored records, in the order they were graded.
local function evaluateAll(evaluatorModel: string, answers: { AnswerRecord }, onProgress: ((current: number, total: number, ctx: { model: string, category: string, index: number }) -> ())?): { ScoredRecord }
	local results: { ScoredRecord } = {}
	local total = #answers
	local done = 0

	for _, record in answers do
		results[#results + 1] = evaluateOne(evaluatorModel, record)
		done = done + 1

		if onProgress then
			local ctx = {
				model = record.model,
				category = record.category,
				index = record.index,
			}
			onProgress(done, total, ctx)
		end
	end

	return results
end
|
|
|
|
-- Public API of this module: single-record and batch grading.
local Evaluator = {
	evaluateOne = evaluateOne,
	evaluateAll = evaluateAll,
}

return Evaluator
|