Skip to content
Snippets Groups Projects
Unverified Commit 2fe3a2b6 authored by Emanuel Ferreira's avatar Emanuel Ferreira Committed by GitHub
Browse files

chore: enhancement optional args extractors (#462)

parent eb3d4af2
No related branches found
No related tags found
No related merge requests found
......@@ -16,7 +16,10 @@ import {
console.log(nodes);
const keywordExtractor = new KeywordExtractor(openaiLLM, 5);
const keywordExtractor = new KeywordExtractor({
llm: openaiLLM,
keywords: 5,
});
const nodesWithKeywordMetadata = await keywordExtractor.processNodes(nodes);
......
......@@ -19,10 +19,10 @@ import {
}),
]);
const questionsAnsweredExtractor = new QuestionsAnsweredExtractor(
openaiLLM,
5,
);
const questionsAnsweredExtractor = new QuestionsAnsweredExtractor({
llm: openaiLLM,
questions: 5,
});
const nodesWithQuestionsMetadata =
await questionsAnsweredExtractor.processNodes(nodes);
......
......@@ -16,7 +16,9 @@ import {
}),
]);
const summaryExtractor = new SummaryExtractor(openaiLLM);
const summaryExtractor = new SummaryExtractor({
llm: openaiLLM,
});
const nodesWithSummaryMetadata = await summaryExtractor.processNodes(nodes);
......
......@@ -11,7 +11,10 @@ import { Document, OpenAI, SimpleNodeParser, TitleExtractor } from "llamaindex";
}),
]);
const titleExtractor = new TitleExtractor(openaiLLM, 1);
const titleExtractor = new TitleExtractor({
llm: openaiLLM,
nodes: 5,
});
const nodesWithTitledMetadata = await titleExtractor.processNodes(nodes);
......
import { BaseNode, MetadataMode, TextNode } from "../Node";
import { LLM } from "../llm";
import { LLM, OpenAI } from "../llm";
import {
defaultKeywordExtractorPromptTemplate,
defaultQuestionAnswerPromptTemplate,
......@@ -11,6 +11,11 @@ import { BaseExtractor } from "./types";
const STRIP_REGEX = /(\r\n|\n|\r)/gm;
/**
 * Options for {@link KeywordExtractor}.
 */
type KeywordExtractArgs = {
  // LLM used to extract keywords; defaults to a new OpenAI instance when omitted.
  llm?: LLM;
  // Number of keywords to extract per node; defaults to 5. Must be >= 1.
  keywords?: number;
};
/**
 * Metadata record produced by KeywordExtractor for a single node.
 */
type ExtractKeyword = {
  // Keywords extracted from the node's text, returned as a single string.
  excerptKeywords: string;
};
......@@ -38,12 +43,14 @@ export class KeywordExtractor extends BaseExtractor {
* @param {number} keywords Number of keywords to extract.
* @throws {Error} If keywords is less than 1.
*/
constructor(options?: KeywordExtractArgs) {
  // Use `!= null` rather than truthiness: with `options?.keywords &&` a
  // caller passing `keywords: 0` would skip the guard and store 0, even
  // though the contract says keywords must be greater than 0.
  if (options?.keywords != null && options.keywords < 1)
    throw new Error("Keywords must be greater than 0");
  super();
  // Defaults: a fresh OpenAI LLM and 5 keywords per node.
  this.llm = options?.llm ?? new OpenAI();
  this.keywords = options?.keywords ?? 5;
}
/**
......@@ -81,6 +88,13 @@ export class KeywordExtractor extends BaseExtractor {
}
}
/**
 * Options for {@link TitleExtractor}.
 */
type TitleExtractorsArgs = {
  // LLM used to generate titles; defaults to a new OpenAI instance when omitted.
  llm?: LLM;
  // Number of nodes to consider; defaults to 5.
  nodes?: number;
  // Prompt template for per-node title extraction; defaults to
  // defaultTitleExtractorPromptTemplate().
  nodeTemplate?: string;
  // Prompt template for merging candidate titles; defaults to
  // defaultTitleCombinePromptTemplate().
  combineTemplate?: string;
};
/**
 * Metadata record produced by TitleExtractor for a single node.
 */
type ExtractTitle = {
  // Title generated for the document the node belongs to.
  documentTitle: string;
};
......@@ -128,20 +142,16 @@ export class TitleExtractor extends BaseExtractor {
* @param {string} nodeTemplate The prompt template to use for the title extractor.
* @param {string} combineTemplate The prompt template to merge titles with.
*/
constructor(options?: TitleExtractorsArgs) {
  super();
  // Defaults: a fresh OpenAI LLM, 5 nodes, and the library's stock
  // title-extraction and title-combination prompt templates.
  this.llm = options?.llm ?? new OpenAI();
  this.nodes = options?.nodes ?? 5;
  this.nodeTemplate =
    options?.nodeTemplate ?? defaultTitleExtractorPromptTemplate();
  this.combineTemplate =
    options?.combineTemplate ?? defaultTitleCombinePromptTemplate();
}
/**
......@@ -197,6 +207,13 @@ export class TitleExtractor extends BaseExtractor {
}
}
/**
 * Options for {@link QuestionsAnsweredExtractor}.
 */
type QuestionAnswerExtractArgs = {
  // LLM used to generate questions; defaults to a new OpenAI instance when omitted.
  llm?: LLM;
  // Number of questions to generate; defaults to 5. Must be >= 1.
  questions?: number;
  // Prompt template; defaults to defaultQuestionAnswerPromptTemplate()
  // built from the resolved question count.
  promptTemplate?: string;
  // Whether the extracted metadata is used for embeddings only; defaults to false.
  embeddingOnly?: boolean;
};
/**
 * Metadata record produced by QuestionsAnsweredExtractor for a single node.
 */
type ExtractQuestion = {
  // Questions the node's excerpt can answer, returned as a single string.
  questionsThisExcerptCanAnswer: string;
};
......@@ -238,25 +255,21 @@ export class QuestionsAnsweredExtractor extends BaseExtractor {
* @param {string} promptTemplate The prompt template to use for the question extractor.
* @param {boolean} embeddingOnly Whether to use metadata for embeddings only.
*/
constructor(options?: QuestionAnswerExtractArgs) {
  // Use `!= null` rather than truthiness: with `options?.questions &&` a
  // caller passing `questions: 0` would skip the guard and store 0, even
  // though the contract says questions must be greater than 0.
  if (options?.questions != null && options.questions < 1)
    throw new Error("Questions must be greater than 0");
  super();
  this.llm = options?.llm ?? new OpenAI();
  this.questions = options?.questions ?? 5;
  // Build the default prompt from the *resolved* count so an explicit
  // `questions` value is reflected in the generated template.
  this.promptTemplate =
    options?.promptTemplate ??
    defaultQuestionAnswerPromptTemplate({
      numQuestions: this.questions,
      contextStr: "",
    });
  this.embeddingOnly = options?.embeddingOnly ?? false;
}
/**
......@@ -303,6 +316,12 @@ export class QuestionsAnsweredExtractor extends BaseExtractor {
}
}
/**
 * Options for {@link SummaryExtractor}.
 */
type SummaryExtractArgs = {
  // LLM used to generate summaries; defaults to a new OpenAI instance when omitted.
  llm?: LLM;
  // Which summaries to produce; entries are validated against
  // 'self' | 'prev' | 'next'. Defaults to ["self"].
  summaries?: string[];
  // Prompt template; defaults to defaultSummaryExtractorPromptTemplate().
  promptTemplate?: string;
};
type ExtractSummary = {
sectionSummary: string;
prevSectionSummary: string;
......@@ -335,24 +354,25 @@ export class SummaryExtractor extends BaseExtractor {
private _prevSummary: boolean;
private _nextSummary: boolean;
/**
 * Creates a SummaryExtractor.
 * @param {SummaryExtractArgs} options Optional LLM, summaries list, and prompt template.
 * @throws {Error} If `summaries` contains none of 'self', 'prev', 'next'.
 */
constructor(options?: SummaryExtractArgs) {
  // Resolve the default once; `summaries` is guaranteed non-null below,
  // so no `summaries &&` guard or `?.includes(...) ?? false` is needed.
  const summaries = options?.summaries ?? ["self"];
  // NOTE(review): `.some` passes as long as ONE entry is valid, so e.g.
  // ["self", "bogus"] is accepted with "bogus" ignored — confirm whether
  // `.every` was intended before tightening this.
  if (!summaries.some((s) => ["self", "prev", "next"].includes(s)))
    throw new Error("Summaries must be one of 'self', 'prev', 'next'");
  super();
  this.llm = options?.llm ?? new OpenAI();
  this.summaries = summaries;
  this.promptTemplate =
    options?.promptTemplate ?? defaultSummaryExtractorPromptTemplate();
  // Cache which summary positions were requested.
  this._selfSummary = summaries.includes("self");
  this._prevSummary = summaries.includes("prev");
  this._nextSummary = summaries.includes("next");
}
/**
......
......@@ -75,7 +75,10 @@ describe("[MetadataExtractor]: Extractors should populate the metadata", () => {
new Document({ text: DEFAULT_LLM_TEXT_OUTPUT }),
]);
const keywordExtractor = new KeywordExtractor(serviceContext.llm, 5);
const keywordExtractor = new KeywordExtractor({
llm: serviceContext.llm,
keywords: 5,
});
const nodesWithKeywordMetadata = await keywordExtractor.processNodes(nodes);
......@@ -91,7 +94,10 @@ describe("[MetadataExtractor]: Extractors should populate the metadata", () => {
new Document({ text: DEFAULT_LLM_TEXT_OUTPUT }),
]);
const titleExtractor = new TitleExtractor(serviceContext.llm, 5);
const titleExtractor = new TitleExtractor({
llm: serviceContext.llm,
nodes: 5,
});
const nodesWithKeywordMetadata = await titleExtractor.processNodes(nodes);
......@@ -107,10 +113,10 @@ describe("[MetadataExtractor]: Extractors should populate the metadata", () => {
new Document({ text: DEFAULT_LLM_TEXT_OUTPUT }),
]);
const questionsAnsweredExtractor = new QuestionsAnsweredExtractor(
serviceContext.llm,
5,
);
const questionsAnsweredExtractor = new QuestionsAnsweredExtractor({
llm: serviceContext.llm,
questions: 5,
});
const nodesWithKeywordMetadata =
await questionsAnsweredExtractor.processNodes(nodes);
......@@ -127,7 +133,9 @@ describe("[MetadataExtractor]: Extractors should populate the metadata", () => {
new Document({ text: DEFAULT_LLM_TEXT_OUTPUT }),
]);
const summaryExtractor = new SummaryExtractor(serviceContext.llm);
const summaryExtractor = new SummaryExtractor({
llm: serviceContext.llm,
});
const nodesWithKeywordMetadata = await summaryExtractor.processNodes(nodes);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment