Skip to content
Snippets Groups Projects
Unverified Commit ce824593 authored by Siraj Aizlewood's avatar Siraj Aizlewood Committed by GitHub
Browse files

Merge pull request #159 from maxyousif15/feat/m.openai-encoder-org-id-support

feat: openai encoder org id support
parents 94686502 4fe981a0
No related branches found
No related tags found
No related merge requests found
...@@ -20,6 +20,7 @@ class OpenAIEncoder(BaseEncoder): ...@@ -20,6 +20,7 @@ class OpenAIEncoder(BaseEncoder):
self, self,
name: Optional[str] = None, name: Optional[str] = None,
openai_api_key: Optional[str] = None, openai_api_key: Optional[str] = None,
openai_org_id: Optional[str] = None,
score_threshold: float = 0.82, score_threshold: float = 0.82,
dimensions: Union[int, NotGiven] = NotGiven(), dimensions: Union[int, NotGiven] = NotGiven(),
): ):
...@@ -27,10 +28,11 @@ class OpenAIEncoder(BaseEncoder): ...@@ -27,10 +28,11 @@ class OpenAIEncoder(BaseEncoder):
name = os.getenv("OPENAI_MODEL_NAME", "text-embedding-ada-002") name = os.getenv("OPENAI_MODEL_NAME", "text-embedding-ada-002")
super().__init__(name=name, score_threshold=score_threshold) super().__init__(name=name, score_threshold=score_threshold)
api_key = openai_api_key or os.getenv("OPENAI_API_KEY") api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
openai_org_id = openai_org_id or os.getenv("OPENAI_ORG_ID")
if api_key is None: if api_key is None:
raise ValueError("OpenAI API key cannot be 'None'.") raise ValueError("OpenAI API key cannot be 'None'.")
try: try:
self.client = openai.Client(api_key=api_key) self.client = openai.Client(api_key=api_key, organization=openai_org_id)
except Exception as e: except Exception as e:
raise ValueError( raise ValueError(
f"OpenAI API client failed to initialize. Error: {e}" f"OpenAI API client failed to initialize. Error: {e}"
......
...@@ -14,7 +14,9 @@ def openai_encoder(mocker): ...@@ -14,7 +14,9 @@ def openai_encoder(mocker):
class TestOpenAIEncoder: class TestOpenAIEncoder:
def test_openai_encoder_init_success(self, mocker): def test_openai_encoder_init_success(self, mocker):
mocker.patch("os.getenv", return_value="fake-api-key") # -- Mock the return value of os.getenv 3 times: model name, api key and org ID
side_effect = ["fake-model-name", "fake-api-key", "fake-org-id"]
mocker.patch("os.getenv", side_effect=side_effect)
encoder = OpenAIEncoder() encoder = OpenAIEncoder()
assert encoder.client is not None assert encoder.client is not None
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment