diff --git a/packages/core/src/response-synthesizers/factory.ts b/packages/core/src/response-synthesizers/factory.ts index 5d14b025e991b46e9d2734df2e3e1b05f6f83c8a..85d4d3e5fd0c685450ace59dcb593438cadc9a37 100644 --- a/packages/core/src/response-synthesizers/factory.ts +++ b/packages/core/src/response-synthesizers/factory.ts @@ -403,27 +403,27 @@ class MultiModal extends BaseSynthesizer { } } -export function getResponseSynthesizer( - mode: ResponseMode, +const modeToSynthesizer = { + compact: CompactAndRefine, + refine: Refine, + tree_summarize: TreeSummarize, + multi_modal: MultiModal, +} as const; + +export function getResponseSynthesizer<Mode extends ResponseMode>( + mode: Mode, options: BaseSynthesizerOptions & { textQATemplate?: TextQAPrompt; refineTemplate?: RefinePrompt; summaryTemplate?: TreeSummarizePrompt; metadataMode?: MetadataMode; } = {}, -) { - switch (mode) { - case "compact": { - return new CompactAndRefine(options); - } - case "refine": { - return new Refine(options); - } - case "tree_summarize": { - return new TreeSummarize(options); - } - case "multi_modal": { - return new MultiModal(options); - } +): InstanceType<(typeof modeToSynthesizer)[Mode]> { + const Synthesizer: (typeof modeToSynthesizer)[Mode] = modeToSynthesizer[mode]; + if (!Synthesizer) { + throw new Error(`Invalid response mode: ${mode}`); } + return new Synthesizer(options) as InstanceType< + (typeof modeToSynthesizer)[Mode] + >; }