// index.js
const path = require("path");
const fs = require("fs");
const { toChunks } = require("../../helpers");
const { v4 } = require("uuid");

class NativeEmbedder {
  // This is a folder that Mintplex Labs hosts for those who cannot reach the HF model download
  // endpoint for various reasons. This endpoint is not guaranteed to be active or maintained
  // and may go offline at any time at Mintplex Labs's discretion.
  #fallbackHost =
    "https://s3.us-west-1.amazonaws.com/public.useanything.com/support/models/";

  constructor() {
    // Model Card: https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
    this.model = "Xenova/all-MiniLM-L6-v2";
    this.cacheDir = path.resolve(
      process.env.STORAGE_DIR
        ? path.resolve(process.env.STORAGE_DIR, `models`)
        : path.resolve(__dirname, `../../../storage/models`)
    );
    this.modelPath = path.resolve(this.cacheDir, "Xenova", "all-MiniLM-L6-v2");
    this.modelDownloaded = fs.existsSync(this.modelPath);

    // Limit of how many strings we can process in a single pass to stay within resource or network limits
    this.maxConcurrentChunks = 25;
    this.embeddingMaxChunkLength = 1_000;

    // Create the model cache directory if it does not exist (covers upgrades of existing installations)
    if (!fs.existsSync(this.cacheDir))
      fs.mkdirSync(this.cacheDir, { recursive: true });
    this.log("Initialized");
  }

  log(text, ...args) {
    console.log(`\x1b[36m[NativeEmbedder]\x1b[0m ${text}`, ...args);
  }

  #tempfilePath() {
    const filename = `${v4()}.tmp`;
    const tmpPath = process.env.STORAGE_DIR
      ? path.resolve(process.env.STORAGE_DIR, "tmp")
      : path.resolve(__dirname, `../../../storage/tmp`);
    if (!fs.existsSync(tmpPath)) fs.mkdirSync(tmpPath, { recursive: true });
    return path.resolve(tmpPath, filename);
  }

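  // Appends partial JSON to an on-disk write-log so the full embedding result
  // never has to be held in memory at once (see embedChunks below).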
  async #writeToTempfile(filePath, data) {
    try {
      await fs.promises.appendFile(filePath, data, { encoding: "utf8" });
    } catch (e) {
      console.error(`Error writing to tempfile: ${e}`);
    }
  }

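  // Builds the feature-extraction pipeline against the default HF host, or
  // `hostOverride` when given. Returns { pipeline, retry, error }: `pipeline`
  // is null on failure, and `retry` carries the fallback host after a first
  // failure so the caller can try again exactly once.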
  async #fetchWithHost(hostOverride = null) {
    try {
      // Load the ESM-only @xenova/transformers library from CommonJS via dynamic import().
      const pipeline = (...args) =>
        import("@xenova/transformers").then(({ pipeline, env }) => {
          if (!this.modelDownloaded) {
            // If the model is not downloaded yet, log where we are fetching it from.
            if (hostOverride) {
              env.remoteHost = hostOverride;
              env.remotePathTemplate = "{model}/"; // Our S3 fallback URL does not support the HF revision file structure.
            }
            this.log(`Downloading ${this.model} from ${env.remoteHost}`);
          }
          return pipeline(...args);
        });
      return {
        pipeline: await pipeline("feature-extraction", this.model, {
          cache_dir: this.cacheDir,
          ...(!this.modelDownloaded
            ? {
                // Show download progress if we need to download any files
                progress_callback: (data) => {
                  if (!data.hasOwnProperty("progress")) return;
                  // ~~ truncates the fractional progress to a whole percent.
                  console.log(
                    `\x1b[36m[NativeEmbedder - Downloading model]\x1b[0m ${
                      data.file
                    } ${~~data?.progress}%`
                  );
                },
              }
            : {}),
        }),
        retry: false,
        error: null,
      };
    } catch (error) {
      return {
        pipeline: null,
        retry: hostOverride === null ? this.#fallbackHost : false,
        error,
      };
    }
  }

  // This function makes a single fallback attempt (intentionally not recursive) to grab the embedder model on first embed,
  // since at times some clients cannot properly download the model from the HF servers for a number of reasons (IP, VPN, etc.).
  // Given this model is critical and nobody reads the GitHub issues before submitting a bug, we get the same bug
  // report 20 times a day: https://github.com/Mintplex-Labs/anything-llm/issues/821
  // So to attempt to monkey-patch this we have a single fallback URL to help alleviate duplicate bug reports.
  async embedderClient() {
    if (!this.modelDownloaded)
      this.log(
        "The native embedding model has never been run and will be downloaded right now. Subsequent runs will be faster. (~23MB)"
      );

    let fetchResponse = await this.#fetchWithHost();
    if (fetchResponse.pipeline !== null) return fetchResponse.pipeline;

    if (fetchResponse.retry) {
      this.log(
        `Failed to download model from primary URL. Using fallback ${fetchResponse.retry}`
      );
      fetchResponse = await this.#fetchWithHost(fetchResponse.retry);
    }
    if (fetchResponse.pipeline !== null) return fetchResponse.pipeline;
    throw fetchResponse.error;
  }

  async embedTextInput(textInput) {
    // Wrap bare strings so embedChunks treats the input as a single chunk
    // instead of letting toChunks slice the string itself.
    const result = await this.embedChunks(
      Array.isArray(textInput) ? textInput : [textInput]
    );
    return result?.[0] || [];
  }

  // If you are thinking you want to edit this function - you probably don't.
  // This process was benchmarked heavily on a t3.small (2GB RAM, 1 vCPU)
  // and without careful memory management for the V8 garbage collector
  // this function will likely result in an OOM on any resource-constrained deployment.
  // To help manage very large documents we run a concurrent write-log each iteration
  // to keep the embedding results out of memory. `maxConcurrentChunks` is set to 25,
  // as 50 seems to overflow no matter what. Given the above, memory use hovers around ~30%
  // during a very large document (>100K words) but can spike up to 70% before GC.
  // This seems repeatable for all document sizes.
  // While this does take a while, it is zero-setup, 100% free, and on-instance.
  // It still may crash depending on other elements at play - so no promises it works under all conditions.
  async embedChunks(textChunks = []) {
    const tmpFilePath = this.#tempfilePath();
    const chunks = toChunks(textChunks, this.maxConcurrentChunks);
    const chunkLen = chunks.length;

    for (let [idx, chunk] of chunks.entries()) {
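      // Open the JSON array in the write-log on the first batch.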
      if (idx === 0) await this.#writeToTempfile(tmpFilePath, "[");
      let data;
      let pipeline = await this.embedderClient();
      let output = await pipeline(chunk, {
        pooling: "mean",
        normalize: true,
      });

      // `output` is a Tensor from transformers.js; check its underlying data
      // buffer for an empty result (`length` is not defined on Tensor), and
      // release references so the GC can reclaim memory before continuing.
      if (!output?.data?.length) {
        pipeline = null;
        output = null;
        data = null;
        continue;
      }

      data = JSON.stringify(output.tolist());
      await this.#writeToTempfile(tmpFilePath, data);
      this.log(`Embedded Chunk ${idx + 1} of ${chunkLen}`);
      // Separate batches with a comma; close the JSON array after the final batch.
      await this.#writeToTempfile(tmpFilePath, chunkLen - 1 === idx ? "]" : ",");
      // Null out references so V8 can garbage-collect between batches.
      pipeline = null;
      output = null;
      data = null;
    }

    const embeddingResults = JSON.parse(
      fs.readFileSync(tmpFilePath, { encoding: "utf-8" })
    );
    fs.rmSync(tmpFilePath, { force: true });
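    // Each batch was written as an array of embeddings, so flatten one level
    // to return a single flat list with one embedding per input chunk.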
    return embeddingResults.length > 0 ? embeddingResults.flat() : null;
  }
}

module.exports = {
  NativeEmbedder,
};
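
// A minimal usage sketch, not part of the module itself. It assumes
// `@xenova/transformers` and `uuid` are installed and that `../../helpers`
// exports `toChunks`; the require path below is hypothetical and depends on
// where this file lives in your tree.
//
//   const { NativeEmbedder } = require("./index");
//
//   (async () => {
//     const embedder = new NativeEmbedder();
//
//     // A single string yields one embedding vector
//     // (384 floats for all-MiniLM-L6-v2).
//     const vector = await embedder.embedTextInput("Hello world");
//     console.log(vector.length);
//
//     // Many chunks yield one embedding per chunk, or null if nothing embedded.
//     const vectors = await embedder.embedChunks(["first chunk", "second chunk"]);
//     console.log(vectors?.length);
//   })();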