Merge pull request #1689 from BerriAI/litellm_set_organization_on_config.yaml

[Feat] Set OpenAI organization for litellm.completion, Proxy Config
Ishaan Jaff 2024-01-30 11:47:42 -08:00 committed by GitHub
commit dd9c78819a
6 changed files with 107 additions and 3 deletions
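
In short, this PR lets callers pin an OpenAI organization per request and per proxy deployment. A minimal sketch of the new `litellm.completion` kwarg (the org ID below is the same placeholder the tests use, and is intentionally invalid there; substitute a real one):

```python
# Sketch: per-call OpenAI organization, as added by this PR.
# "org-ikDc4ex8NB" is a placeholder ID.
import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    organization="org-ikDc4ex8NB",  # forwarded through to the OpenAI client
)
print(response)
```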


@@ -188,7 +188,7 @@ print(response)
</Tabs>

-## Save Model-specific params (API Base, API Keys, Temperature, Max Tokens, Seed, Headers etc.)
+## Save Model-specific params (API Base, API Keys, Temperature, Max Tokens, Seed, Organization, Headers etc.)

You can use the config to save model-specific information like api_base, api_key, temperature, max_tokens, etc.

[**All input params**](https://docs.litellm.ai/docs/completion/input#input-params-1)
@@ -210,6 +210,12 @@ model_list:
      api_key: sk-123
      api_base: https://openai-gpt-4-test-v-2.openai.azure.com/
      temperature: 0.2
+  - model_name: openai-gpt-3.5
+    litellm_params:
+      model: openai/gpt-3.5-turbo
+      api_key: sk-123
+      organization: org-ikDc4ex8NB
+      temperature: 0.2
  - model_name: mistral-7b
    litellm_params:
      model: ollama/mistral
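
Since the router change later in this diff resolves `os.environ/`-prefixed values for `organization` (just as it does for `api_key` and `max_retries`), the org ID can also come from the environment; a sketch, with an illustrative variable name:

```yaml
model_list:
  - model_name: openai-gpt-3.5
    litellm_params:
      model: openai/gpt-3.5-turbo
      api_key: os.environ/OPENAI_API_KEY
      organization: os.environ/OPENAI_ORGANIZATION  # resolved via litellm.get_secret()
```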


@@ -221,6 +221,7 @@ class OpenAIChatCompletion(BaseLLM):
        headers: Optional[dict] = None,
        custom_prompt_dict: dict = {},
        client=None,
+        organization: Optional[str] = None,
    ):
        super().completion()
        exception_mapping_worked = False
@@ -254,6 +255,7 @@ class OpenAIChatCompletion(BaseLLM):
                        timeout=timeout,
                        client=client,
                        max_retries=max_retries,
+                        organization=organization,
                    )
                else:
                    return self.acompletion(
@@ -266,6 +268,7 @@ class OpenAIChatCompletion(BaseLLM):
                        timeout=timeout,
                        client=client,
                        max_retries=max_retries,
+                        organization=organization,
                    )
            elif optional_params.get("stream", False):
                return self.streaming(
@@ -278,6 +281,7 @@ class OpenAIChatCompletion(BaseLLM):
                    timeout=timeout,
                    client=client,
                    max_retries=max_retries,
+                    organization=organization,
                )
            else:
                if not isinstance(max_retries, int):
@@ -291,6 +295,7 @@ class OpenAIChatCompletion(BaseLLM):
                        http_client=litellm.client_session,
                        timeout=timeout,
                        max_retries=max_retries,
+                        organization=organization,
                    )
                else:
                    openai_client = client
@@ -358,6 +363,7 @@ class OpenAIChatCompletion(BaseLLM):
        timeout: float,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
+        organization: Optional[str] = None,
        client=None,
        max_retries=None,
        logging_obj=None,
@@ -372,6 +378,7 @@ class OpenAIChatCompletion(BaseLLM):
                http_client=litellm.aclient_session,
                timeout=timeout,
                max_retries=max_retries,
+                organization=organization,
            )
        else:
            openai_aclient = client
@@ -412,6 +419,7 @@ class OpenAIChatCompletion(BaseLLM):
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
+        organization: Optional[str] = None,
        client=None,
        max_retries=None,
        headers=None,
@@ -423,6 +431,7 @@ class OpenAIChatCompletion(BaseLLM):
                http_client=litellm.client_session,
                timeout=timeout,
                max_retries=max_retries,
+                organization=organization,
            )
        else:
            openai_client = client
@@ -454,6 +463,7 @@ class OpenAIChatCompletion(BaseLLM):
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
+        organization: Optional[str] = None,
        client=None,
        max_retries=None,
        headers=None,
@@ -467,6 +477,7 @@ class OpenAIChatCompletion(BaseLLM):
                http_client=litellm.aclient_session,
                timeout=timeout,
                max_retries=max_retries,
+                organization=organization,
            )
        else:
            openai_aclient = client
@@ -748,8 +759,11 @@ class OpenAIChatCompletion(BaseLLM):
        messages: Optional[list] = None,
        input: Optional[list] = None,
        prompt: Optional[str] = None,
+        organization: Optional[str] = None,
    ):
-        client = AsyncOpenAI(api_key=api_key, timeout=timeout)
+        client = AsyncOpenAI(
+            api_key=api_key, timeout=timeout, organization=organization
+        )
        if model is None and mode != "image_generation":
            raise Exception("model is not set")


@@ -450,6 +450,7 @@ def completion(
    num_retries = kwargs.get("num_retries", None)  ## deprecated
    max_retries = kwargs.get("max_retries", None)
    context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
+    organization = kwargs.get("organization", None)
    ### CUSTOM MODEL COST ###
    input_cost_per_token = kwargs.get("input_cost_per_token", None)
    output_cost_per_token = kwargs.get("output_cost_per_token", None)
@@ -787,7 +788,8 @@ def completion(
                or "https://api.openai.com/v1"
            )
            openai.organization = (
-                litellm.organization
+                organization
+                or litellm.organization
                or get_secret("OPENAI_ORGANIZATION")
                or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
            )
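
The chain above means a per-call `organization` now wins over the module-level `litellm.organization`, which in turn wins over the `OPENAI_ORGANIZATION` secret. A standalone approximation (with `get_secret` simplified to an env lookup):

```python
# Approximation of the resolution order this hunk introduces.
import os

import litellm


def resolve_organization(organization=None):
    return (
        organization  # 1. explicit completion(..., organization=...) kwarg
        or litellm.organization  # 2. module-level default
        or os.environ.get("OPENAI_ORGANIZATION")  # 3. stands in for get_secret(...)
        or None  # 4. let the OpenAI SDK apply its own default
    )
```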
@@ -827,6 +829,7 @@ def completion(
                timeout=timeout,
                custom_prompt_dict=custom_prompt_dict,
                client=client,  # pass AsyncOpenAI, OpenAI client
+                organization=organization,
            )
        except Exception as e:
            ## LOGGING - log the original exception returned
@@ -3224,6 +3227,7 @@ async def ahealth_check(
        or custom_llm_provider == "text-completion-openai"
    ):
        api_key = model_params.get("api_key") or get_secret("OPENAI_API_KEY")
+        organization = model_params.get("organization")

        timeout = (
            model_params.get("timeout")
@@ -3241,6 +3245,7 @@ async def ahealth_check(
            mode=mode,
            prompt=prompt,
            input=input,
+            organization=organization,
        )
    else:
        if mode == "embedding":


@@ -1411,6 +1411,12 @@ class Router:
                max_retries = litellm.get_secret(max_retries_env_name)
                litellm_params["max_retries"] = max_retries

+            organization = litellm_params.get("organization", None)
+            if isinstance(organization, str) and organization.startswith("os.environ/"):
+                organization_env_name = organization.replace("os.environ/", "")
+                organization = litellm.get_secret(organization_env_name)
+                litellm_params["organization"] = organization
+
            if "azure" in model_name:
                if api_base is None:
                    raise ValueError(
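
With this block in place, a deployment's `organization` can be an `os.environ/` reference, mirroring the handling of `api_key` and `max_retries` above it. A sketch (the env var name is illustrative; the org ID is a placeholder):

```python
# Sketch: os.environ/-style organization on a Router deployment.
import os

from litellm import Router

os.environ["OPENAI_ORGANIZATION"] = "org-ikDc4ex8NB"  # placeholder org ID

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "organization": "os.environ/OPENAI_ORGANIZATION",
            },
        }
    ]
)
```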
@@ -1610,6 +1616,7 @@ class Router:
                    base_url=api_base,
                    timeout=timeout,
                    max_retries=max_retries,
+                    organization=organization,
                    http_client=httpx.AsyncClient(
                        transport=AsyncCustomHTTPTransport(),
                        limits=httpx.Limits(
@@ -1630,6 +1637,7 @@ class Router:
                    base_url=api_base,
                    timeout=timeout,
                    max_retries=max_retries,
+                    organization=organization,
                    http_client=httpx.Client(
                        transport=CustomHTTPTransport(),
                        limits=httpx.Limits(
@@ -1651,6 +1659,7 @@ class Router:
                    base_url=api_base,
                    timeout=stream_timeout,
                    max_retries=max_retries,
+                    organization=organization,
                    http_client=httpx.AsyncClient(
                        transport=AsyncCustomHTTPTransport(),
                        limits=httpx.Limits(
@@ -1672,6 +1681,7 @@ class Router:
                    base_url=api_base,
                    timeout=stream_timeout,
                    max_retries=max_retries,
+                    organization=organization,
                    http_client=httpx.Client(
                        transport=CustomHTTPTransport(),
                        limits=httpx.Limits(


@@ -569,6 +569,22 @@ def test_completion_openai():
# test_completion_openai()

+
+def test_completion_openai_organization():
+    try:
+        litellm.set_verbose = True
+        try:
+            response = completion(
+                model="gpt-3.5-turbo", messages=messages, organization="org-ikDc4ex8NB"
+            )
+            pytest.fail("Request should have failed - This organization does not exist")
+        except Exception as e:
+            assert "No such organization: org-ikDc4ex8NB" in str(e)
+    except Exception as e:
+        print(e)
+        pytest.fail(f"Error occurred: {e}")
+
+
def test_completion_text_openai():
    try:
        # litellm.set_verbose = True


@@ -387,3 +387,56 @@ def test_router_init_gpt_4_vision_enhancements():
        print("passed")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
+
+
+def test_openai_with_organization():
+    try:
+        print("Testing OpenAI with organization")
+        model_list = [
+            {
+                "model_name": "openai-bad-org",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "organization": "org-ikDc4ex8NB",
+                },
+            },
+            {
+                "model_name": "openai-good-org",
+                "litellm_params": {"model": "gpt-3.5-turbo"},
+            },
+        ]
+
+        router = Router(model_list=model_list)
+
+        print(router.model_list)
+        print(router.model_list[0])
+
+        openai_client = router._get_client(
+            deployment=router.model_list[0],
+            kwargs={"input": ["hello"], "model": "openai-bad-org"},
+        )
+        print(vars(openai_client))
+
+        assert openai_client.organization == "org-ikDc4ex8NB"
+
+        # bad org raises error
+        try:
+            response = router.completion(
+                model="openai-bad-org",
+                messages=[{"role": "user", "content": "this is a test"}],
+            )
+            pytest.fail("Request should have failed - This organization does not exist")
+        except Exception as e:
+            print("Got exception: " + str(e))
+            assert "No such organization: org-ikDc4ex8NB" in str(e)
+
+        # good org works
+        response = router.completion(
+            model="openai-good-org",
+            messages=[{"role": "user", "content": "this is a test"}],
+            max_tokens=5,
+        )
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")