From a7edccaeb9e277242c927aba44d4f8776abc4c6a Mon Sep 17 00:00:00 2001 From: yisding <yi.s.ding@gmail.com> Date: Wed, 30 Aug 2023 14:48:22 -0700 Subject: [PATCH] update newline replacement logic on embeddings (#7484) --- CHANGELOG.md | 5 +++++ llama_index/embeddings/openai.py | 35 +++++++++++++++++++++++++------- 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bfbbca8bad..e787511ead 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # ChangeLog +## Unreleased + +### Bug Fixes / Nits +- Only convert newlines to spaces for text 001 embedding models in OpenAI + ## [0.8.14] - 2023-08-30 ### New Features diff --git a/llama_index/embeddings/openai.py b/llama_index/embeddings/openai.py index 35fb1f9f37..216eaccc4c 100644 --- a/llama_index/embeddings/openai.py +++ b/llama_index/embeddings/openai.py @@ -118,7 +118,13 @@ def get_embedding( like matplotlib, plotly, scipy, sklearn. """ - text = text.replace("\n", " ") + if ( + engine is not None + and engine.endswith("001") + and not engine.endswith("code-001") + ): + # replace newlines, which can negatively affect performance on text-001 models. + text = text.replace("\n", " ") return openai.Embedding.create(input=[text], model=engine, **kwargs)["data"][0][ "embedding" ] @@ -140,8 +146,13 @@ async def aget_embedding( like matplotlib, plotly, scipy, sklearn. """ - # replace newlines, which can negatively affect performance. - text = text.replace("\n", " ") + if ( + engine is not None + and engine.endswith("001") + and not engine.endswith("code-001") + ): + # replace newlines, which can negatively affect performance on text-001 models. + text = text.replace("\n", " ") return (await openai.Embedding.acreate(input=[text], model=engine, **kwargs))[ "data" @@ -166,8 +177,13 @@ def get_embeddings( """ assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048." - # replace newlines, which can negatively affect performance. - list_of_text = [text.replace("\n", " ") for text in list_of_text] + if ( + engine is not None + and engine.endswith("001") + and not engine.endswith("code-001") + ): + # replace newlines, which can negatively affect performance on text-001 models. + list_of_text = [text.replace("\n", " ") for text in list_of_text] data = openai.Embedding.create(input=list_of_text, model=engine, **kwargs).data return [d["embedding"] for d in data] @@ -191,8 +207,13 @@ async def aget_embeddings( """ assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048." - # replace newlines, which can negatively affect performance. - list_of_text = [text.replace("\n", " ") for text in list_of_text] + if ( + engine is not None + and engine.endswith("001") + and not engine.endswith("code-001") + ): + # replace newlines, which can negatively affect performance on text-001 models. + list_of_text = [text.replace("\n", " ") for text in list_of_text] data = ( await openai.Embedding.acreate(input=list_of_text, model=engine, **kwargs) -- GitLab