Skip to content
Snippets Groups Projects
Unverified Commit a7edccae authored by yisding's avatar yisding Committed by GitHub
Browse files

update newline replacement logic on embeddings (#7484)

parent 69a88df3
No related branches found
No related tags found
No related merge requests found
# ChangeLog
## Unreleased
### Bug Fixes / Nits
- Only convert newlines to spaces for OpenAI embedding models whose name ends in `001`, excluding `code-001` models
## [0.8.14] - 2023-08-30
### New Features
......
......@@ -118,7 +118,13 @@ def get_embedding(
like matplotlib, plotly, scipy, sklearn.
"""
text = text.replace("\n", " ")
if (
engine is not None
and engine.endswith("001")
and not engine.endswith("code-001")
):
# replace newlines, which can negatively affect performance on text-001 models.
text = text.replace("\n", " ")
return openai.Embedding.create(input=[text], model=engine, **kwargs)["data"][0][
"embedding"
]
......@@ -140,8 +146,13 @@ async def aget_embedding(
like matplotlib, plotly, scipy, sklearn.
"""
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
if (
engine is not None
and engine.endswith("001")
and not engine.endswith("code-001")
):
# replace newlines, which can negatively affect performance on text-001 models.
text = text.replace("\n", " ")
return (await openai.Embedding.acreate(input=[text], model=engine, **kwargs))[
"data"
......@@ -166,8 +177,13 @@ def get_embeddings(
"""
assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
# replace newlines, which can negatively affect performance.
list_of_text = [text.replace("\n", " ") for text in list_of_text]
if (
engine is not None
and engine.endswith("001")
and not engine.endswith("code-001")
):
# replace newlines, which can negatively affect performance on text-001 models.
list_of_text = [text.replace("\n", " ") for text in list_of_text]
data = openai.Embedding.create(input=list_of_text, model=engine, **kwargs).data
return [d["embedding"] for d in data]
......@@ -191,8 +207,13 @@ async def aget_embeddings(
"""
assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
# replace newlines, which can negatively affect performance.
list_of_text = [text.replace("\n", " ") for text in list_of_text]
if (
engine is not None
and engine.endswith("001")
and not engine.endswith("code-001")
):
# replace newlines, which can negatively affect performance on text-001 models.
list_of_text = [text.replace("\n", " ") for text in list_of_text]
data = (
await openai.Embedding.acreate(input=list_of_text, model=engine, **kwargs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment