Merge pull request #1689 from BerriAI/litellm_set_organization_on_config.yaml

[Feat] Set OpenAI organization for litellm.completion, Proxy Config
Ishaan Jaff 2024-01-30 11:47:42 -08:00 committed by GitHub
commit dd9c78819a
6 changed files with 107 additions and 3 deletions

@@ -188,7 +188,7 @@ print(response)
</Tabs>
## Save Model-specific params (API Base, API Keys, Temperature, Max Tokens, Seed, Headers etc.)
## Save Model-specific params (API Base, API Keys, Temperature, Max Tokens, Seed, Organization, Headers etc.)
You can use the config to save model-specific information like api_base, api_key, temperature, max_tokens, etc.
[**All input params**](https://docs.litellm.ai/docs/completion/input#input-params-1)
@@ -210,6 +210,12 @@ model_list:
      api_key: sk-123
      api_base: https://openai-gpt-4-test-v-2.openai.azure.com/
      temperature: 0.2
  - model_name: openai-gpt-3.5
    litellm_params:
      model: openai/gpt-3.5-turbo
      api_key: sk-123
      organization: org-ikDc4ex8NB
      temperature: 0.2
  - model_name: mistral-7b
    litellm_params:
      model: ollama/mistral
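For comparison, the same `organization` knob is available per-request in the Python SDK; a minimal sketch (assumes `OPENAI_API_KEY` is set in the environment, and reuses the placeholder org ID from this PR's tests):

```python
import litellm

# Hedged sketch: `organization` is forwarded to the underlying OpenAI client.
# "org-ikDc4ex8NB" is the placeholder org ID used in this PR's tests.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    organization="org-ikDc4ex8NB",
)
print(response)
```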

@@ -221,6 +221,7 @@ class OpenAIChatCompletion(BaseLLM):
        headers: Optional[dict] = None,
        custom_prompt_dict: dict = {},
        client=None,
        organization: Optional[str] = None,
    ):
        super().completion()
        exception_mapping_worked = False
@@ -254,6 +255,7 @@ class OpenAIChatCompletion(BaseLLM):
                        timeout=timeout,
                        client=client,
                        max_retries=max_retries,
                        organization=organization,
                    )
                else:
                    return self.acompletion(
@@ -266,6 +268,7 @@ class OpenAIChatCompletion(BaseLLM):
                        timeout=timeout,
                        client=client,
                        max_retries=max_retries,
                        organization=organization,
                    )
            elif optional_params.get("stream", False):
                return self.streaming(
@@ -278,6 +281,7 @@ class OpenAIChatCompletion(BaseLLM):
                    timeout=timeout,
                    client=client,
                    max_retries=max_retries,
                    organization=organization,
                )
            else:
                if not isinstance(max_retries, int):
@@ -291,6 +295,7 @@ class OpenAIChatCompletion(BaseLLM):
                        http_client=litellm.client_session,
                        timeout=timeout,
                        max_retries=max_retries,
                        organization=organization,
                    )
                else:
                    openai_client = client
@@ -358,6 +363,7 @@ class OpenAIChatCompletion(BaseLLM):
        timeout: float,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        organization: Optional[str] = None,
        client=None,
        max_retries=None,
        logging_obj=None,
@@ -372,6 +378,7 @@ class OpenAIChatCompletion(BaseLLM):
                http_client=litellm.aclient_session,
                timeout=timeout,
                max_retries=max_retries,
                organization=organization,
            )
        else:
            openai_aclient = client
@@ -412,6 +419,7 @@ class OpenAIChatCompletion(BaseLLM):
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        organization: Optional[str] = None,
        client=None,
        max_retries=None,
        headers=None,
@@ -423,6 +431,7 @@ class OpenAIChatCompletion(BaseLLM):
                http_client=litellm.client_session,
                timeout=timeout,
                max_retries=max_retries,
                organization=organization,
            )
        else:
            openai_client = client
@@ -454,6 +463,7 @@ class OpenAIChatCompletion(BaseLLM):
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        organization: Optional[str] = None,
        client=None,
        max_retries=None,
        headers=None,
@@ -467,6 +477,7 @@ class OpenAIChatCompletion(BaseLLM):
                http_client=litellm.aclient_session,
                timeout=timeout,
                max_retries=max_retries,
                organization=organization,
            )
        else:
            openai_aclient = client
@@ -748,8 +759,11 @@ class OpenAIChatCompletion(BaseLLM):
        messages: Optional[list] = None,
        input: Optional[list] = None,
        prompt: Optional[str] = None,
        organization: Optional[str] = None,
    ):
        client = AsyncOpenAI(api_key=api_key, timeout=timeout)
        client = AsyncOpenAI(
            api_key=api_key, timeout=timeout, organization=organization
        )
        if model is None and mode != "image_generation":
            raise Exception("model is not set")
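All of the constructor changes above reduce to one pattern: hand `organization` to the `OpenAI`/`AsyncOpenAI` clients and let the SDK do the rest. A standalone sketch of that behavior, not litellm code (key and org ID are placeholders; in openai>=1.x the value is sent as the `OpenAI-Organization` request header):

```python
from openai import AsyncOpenAI, OpenAI

# Hedged sketch: both clients accept `organization` at construction time;
# the placeholders mirror the values used elsewhere in this PR.
sync_client = OpenAI(api_key="sk-123", organization="org-ikDc4ex8NB")
async_client = AsyncOpenAI(api_key="sk-123", organization="org-ikDc4ex8NB")

# Passing None keeps the SDK default (OPENAI_ORG_ID from the environment).
default_client = OpenAI(api_key="sk-123", organization=None)
```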

@@ -450,6 +450,7 @@ def completion(
    num_retries = kwargs.get("num_retries", None)  ## deprecated
    max_retries = kwargs.get("max_retries", None)
    context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
    organization = kwargs.get("organization", None)
    ### CUSTOM MODEL COST ###
    input_cost_per_token = kwargs.get("input_cost_per_token", None)
    output_cost_per_token = kwargs.get("output_cost_per_token", None)
@@ -787,7 +788,8 @@ def completion(
            or "https://api.openai.com/v1"
        )
        openai.organization = (
            litellm.organization
            organization
            or litellm.organization
            or get_secret("OPENAI_ORGANIZATION")
            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
        )
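The `or`-chain above gives the per-call kwarg priority over the module-level setting, which in turn beats the `OPENAI_ORGANIZATION` secret. A self-contained sketch of that precedence (function name and parameters are illustrative, not litellm API):

```python
from typing import Optional

def resolve_organization(
    call_kwarg: Optional[str],
    module_setting: Optional[str],
    env_secret: Optional[str],
) -> Optional[str]:
    # First truthy value wins; None is the OpenAI SDK's documented default.
    return call_kwarg or module_setting or env_secret or None

assert resolve_organization("org-a", "org-b", "org-c") == "org-a"
assert resolve_organization(None, "org-b", "org-c") == "org-b"
assert resolve_organization(None, None, None) is None
```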
@@ -827,6 +829,7 @@ def completion(
                timeout=timeout,
                custom_prompt_dict=custom_prompt_dict,
                client=client,  # pass AsyncOpenAI, OpenAI client
                organization=organization,
            )
        except Exception as e:
            ## LOGGING - log the original exception returned
@@ -3224,6 +3227,7 @@ async def ahealth_check(
        or custom_llm_provider == "text-completion-openai"
    ):
        api_key = model_params.get("api_key") or get_secret("OPENAI_API_KEY")
        organization = model_params.get("organization")

        timeout = (
            model_params.get("timeout")
@@ -3241,6 +3245,7 @@ async def ahealth_check(
            mode=mode,
            prompt=prompt,
            input=input,
            organization=organization,
        )
    else:
        if mode == "embedding":

@@ -1411,6 +1411,12 @@ class Router:
                max_retries = litellm.get_secret(max_retries_env_name)
                litellm_params["max_retries"] = max_retries
            organization = litellm_params.get("organization", None)
            if isinstance(organization, str) and organization.startswith("os.environ/"):
                organization_env_name = organization.replace("os.environ/", "")
                organization = litellm.get_secret(organization_env_name)
                litellm_params["organization"] = organization

            if "azure" in model_name:
                if api_base is None:
                    raise ValueError(
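The block above adds `os.environ/` indirection, so a proxy config can reference a secret instead of hard-coding the org ID. A minimal sketch of the convention (plain `os.environ` stands in for `litellm.get_secret`, which can also consult secret managers; the env-var name is just an example):

```python
import os
from typing import Optional

def resolve_config_value(value: Optional[str]) -> Optional[str]:
    # Mirrors the router's convention: "os.environ/NAME" is replaced with
    # the NAME environment variable; literal values pass through unchanged.
    if isinstance(value, str) and value.startswith("os.environ/"):
        return os.environ.get(value.replace("os.environ/", ""))
    return value

os.environ["OPENAI_ORGANIZATION"] = "org-ikDc4ex8NB"  # example value
assert resolve_config_value("os.environ/OPENAI_ORGANIZATION") == "org-ikDc4ex8NB"
assert resolve_config_value("org-literal") == "org-literal"
```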
@@ -1610,6 +1616,7 @@ class Router:
                    base_url=api_base,
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                    http_client=httpx.AsyncClient(
                        transport=AsyncCustomHTTPTransport(),
                        limits=httpx.Limits(
@@ -1630,6 +1637,7 @@ class Router:
                    base_url=api_base,
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                    http_client=httpx.Client(
                        transport=CustomHTTPTransport(),
                        limits=httpx.Limits(
@@ -1651,6 +1659,7 @@ class Router:
                    base_url=api_base,
                    timeout=stream_timeout,
                    max_retries=max_retries,
                    organization=organization,
                    http_client=httpx.AsyncClient(
                        transport=AsyncCustomHTTPTransport(),
                        limits=httpx.Limits(
@@ -1672,6 +1681,7 @@ class Router:
                    base_url=api_base,
                    timeout=stream_timeout,
                    max_retries=max_retries,
                    organization=organization,
                    http_client=httpx.Client(
                        transport=CustomHTTPTransport(),
                        limits=httpx.Limits(

@@ -569,6 +569,22 @@ def test_completion_openai():
# test_completion_openai()


def test_completion_openai_organization():
    try:
        litellm.set_verbose = True
        try:
            response = completion(
                model="gpt-3.5-turbo", messages=messages, organization="org-ikDc4ex8NB"
            )
            pytest.fail("Request should have failed - This organization does not exist")
        except Exception as e:
            assert "No such organization: org-ikDc4ex8NB" in str(e)
    except Exception as e:
        print(e)
        pytest.fail(f"Error occurred: {e}")


def test_completion_text_openai():
    try:
        # litellm.set_verbose = True

@@ -387,3 +387,56 @@ def test_router_init_gpt_4_vision_enhancements():
        print("passed")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


def test_openai_with_organization():
    try:
        print("Testing OpenAI with organization")
        model_list = [
            {
                "model_name": "openai-bad-org",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "organization": "org-ikDc4ex8NB",
                },
            },
            {
                "model_name": "openai-good-org",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            },
        ]
        router = Router(model_list=model_list)

        print(router.model_list)
        print(router.model_list[0])

        openai_client = router._get_client(
            deployment=router.model_list[0],
            kwargs={"input": ["hello"], "model": "openai-bad-org"},
        )
        print(vars(openai_client))

        assert openai_client.organization == "org-ikDc4ex8NB"

        # bad org raises error
        try:
            response = router.completion(
                model="openai-bad-org",
                messages=[{"role": "user", "content": "this is a test"}],
            )
            pytest.fail("Request should have failed - This organization does not exist")
        except Exception as e:
            print("Got exception: " + str(e))
            assert "No such organization: org-ikDc4ex8NB" in str(e)

        # good org works
        response = router.completion(
            model="openai-good-org",
            messages=[{"role": "user", "content": "this is a test"}],
            max_tokens=5,
        )
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")