Merge pull request #1689 from BerriAI/litellm_set_organization_on_config.yaml
[Feat] Set OpenAI organization for litellm.completion, Proxy Config
This commit is contained in: commit dd9c78819a
6 changed files with 107 additions and 3 deletions
Docs — the proxy config documentation gains an `organization` example:

```diff
@@ -188,7 +188,7 @@ print(response)
 </Tabs>
 
-## Save Model-specific params (API Base, API Keys, Temperature, Max Tokens, Seed, Headers etc.)
+## Save Model-specific params (API Base, API Keys, Temperature, Max Tokens, Seed, Organization, Headers etc.)
 
 You can use the config to save model-specific information like api_base, api_key, temperature, max_tokens, etc.
 
 [**All input params**](https://docs.litellm.ai/docs/completion/input#input-params-1)
@@ -210,6 +210,12 @@ model_list:
       api_key: sk-123
       api_base: https://openai-gpt-4-test-v-2.openai.azure.com/
       temperature: 0.2
+  - model_name: openai-gpt-3.5
+    litellm_params:
+      model: openai/gpt-3.5-turbo
+      api_key: sk-123
+      organization: org-ikDc4ex8NB
+      temperature: 0.2
   - model_name: mistral-7b
     litellm_params:
       model: ollama/mistral
```
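With this change the organization can be pinned per deployment in the proxy config (as above) or passed per request to `litellm.completion`. A minimal sketch of the per-request path, reusing the placeholder org id from this PR's tests:

```python
import litellm

# organization= is now accepted per call; per the resolution order added in
# completion(), it takes precedence over litellm.organization and the
# OPENAI_ORGANIZATION environment variable.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    organization="org-ikDc4ex8NB",  # placeholder id, same as the PR's tests
)
print(response)
```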
OpenAI handler (`OpenAIChatCompletion`) — `organization` is threaded through the sync, async, and streaming completion paths and the async health-check client:

```diff
@@ -221,6 +221,7 @@ class OpenAIChatCompletion(BaseLLM):
         headers: Optional[dict] = None,
         custom_prompt_dict: dict = {},
         client=None,
+        organization: Optional[str] = None,
     ):
         super().completion()
         exception_mapping_worked = False
@@ -254,6 +255,7 @@ class OpenAIChatCompletion(BaseLLM):
                     timeout=timeout,
                     client=client,
                     max_retries=max_retries,
+                    organization=organization,
                 )
             else:
                 return self.acompletion(
@@ -266,6 +268,7 @@ class OpenAIChatCompletion(BaseLLM):
                     timeout=timeout,
                     client=client,
                     max_retries=max_retries,
+                    organization=organization,
                 )
             elif optional_params.get("stream", False):
                 return self.streaming(
@@ -278,6 +281,7 @@ class OpenAIChatCompletion(BaseLLM):
                     timeout=timeout,
                     client=client,
                     max_retries=max_retries,
+                    organization=organization,
                 )
             else:
                 if not isinstance(max_retries, int):
@@ -291,6 +295,7 @@ class OpenAIChatCompletion(BaseLLM):
                     http_client=litellm.client_session,
                     timeout=timeout,
                     max_retries=max_retries,
+                    organization=organization,
                 )
             else:
                 openai_client = client
@@ -358,6 +363,7 @@ class OpenAIChatCompletion(BaseLLM):
         timeout: float,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
+        organization: Optional[str] = None,
         client=None,
         max_retries=None,
         logging_obj=None,
@@ -372,6 +378,7 @@ class OpenAIChatCompletion(BaseLLM):
                 http_client=litellm.aclient_session,
                 timeout=timeout,
                 max_retries=max_retries,
+                organization=organization,
             )
         else:
             openai_aclient = client
@@ -412,6 +419,7 @@ class OpenAIChatCompletion(BaseLLM):
         model: str,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
+        organization: Optional[str] = None,
         client=None,
         max_retries=None,
         headers=None,
@@ -423,6 +431,7 @@ class OpenAIChatCompletion(BaseLLM):
                 http_client=litellm.client_session,
                 timeout=timeout,
                 max_retries=max_retries,
+                organization=organization,
             )
         else:
             openai_client = client
@@ -454,6 +463,7 @@ class OpenAIChatCompletion(BaseLLM):
         model: str,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
+        organization: Optional[str] = None,
         client=None,
         max_retries=None,
         headers=None,
@@ -467,6 +477,7 @@ class OpenAIChatCompletion(BaseLLM):
                 http_client=litellm.aclient_session,
                 timeout=timeout,
                 max_retries=max_retries,
+                organization=organization,
             )
         else:
             openai_aclient = client
@@ -748,8 +759,11 @@ class OpenAIChatCompletion(BaseLLM):
         messages: Optional[list] = None,
         input: Optional[list] = None,
         prompt: Optional[str] = None,
+        organization: Optional[str] = None,
     ):
-        client = AsyncOpenAI(api_key=api_key, timeout=timeout)
+        client = AsyncOpenAI(
+            api_key=api_key, timeout=timeout, organization=organization
+        )
         if model is None and mode != "image_generation":
             raise Exception("model is not set")
 
```
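Every hunk above follows one pattern: wherever the handler constructs an `OpenAI`/`AsyncOpenAI` client itself, the new `organization` parameter is forwarded to the constructor. A standalone sketch of that pattern, assuming the openai v1 Python client (which accepts `organization=`); `build_clients` is an illustrative helper, not the handler's real signature:

```python
from typing import Optional

from openai import AsyncOpenAI, OpenAI


def build_clients(
    api_key: str,
    api_base: str,
    timeout: float,
    max_retries: int,
    organization: Optional[str] = None,  # the parameter this PR threads through
):
    # Mirrors the constructor calls the diff extends: organization is passed
    # alongside the existing timeout/max_retries arguments.
    sync_client = OpenAI(
        api_key=api_key,
        base_url=api_base,
        timeout=timeout,
        max_retries=max_retries,
        organization=organization,
    )
    async_client = AsyncOpenAI(
        api_key=api_key,
        base_url=api_base,
        timeout=timeout,
        max_retries=max_retries,
        organization=organization,
    )
    return sync_client, async_client
```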
`completion()` and `ahealth_check()` — the new kwarg is read, resolved, and passed down:

```diff
@@ -450,6 +450,7 @@ def completion(
     num_retries = kwargs.get("num_retries", None)  ## deprecated
     max_retries = kwargs.get("max_retries", None)
     context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
+    organization = kwargs.get("organization", None)
     ### CUSTOM MODEL COST ###
     input_cost_per_token = kwargs.get("input_cost_per_token", None)
     output_cost_per_token = kwargs.get("output_cost_per_token", None)
@@ -787,7 +788,8 @@ def completion(
             or "https://api.openai.com/v1"
         )
         openai.organization = (
-            litellm.organization
+            organization
+            or litellm.organization
             or get_secret("OPENAI_ORGANIZATION")
             or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
         )
@@ -827,6 +829,7 @@ def completion(
                 timeout=timeout,
                 custom_prompt_dict=custom_prompt_dict,
                 client=client,  # pass AsyncOpenAI, OpenAI client
+                organization=organization,
             )
         except Exception as e:
             ## LOGGING - log the original exception returned
@@ -3224,6 +3227,7 @@ async def ahealth_check(
         or custom_llm_provider == "text-completion-openai"
     ):
         api_key = model_params.get("api_key") or get_secret("OPENAI_API_KEY")
+        organization = model_params.get("organization")
 
         timeout = (
             model_params.get("timeout")
@@ -3241,6 +3245,7 @@ async def ahealth_check(
             mode=mode,
             prompt=prompt,
             input=input,
+            organization=organization,
         )
     else:
         if mode == "embedding":
```
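The resolution order `completion()` now applies: explicit `organization` kwarg, then the module-level `litellm.organization`, then the `OPENAI_ORGANIZATION` secret, else `None` (openai-python's own default). The same chain as a standalone function; `resolve_organization` is a hypothetical helper, and `get_secret()` is approximated here with `os.environ`:

```python
import os
from typing import Optional


def resolve_organization(
    organization: Optional[str] = None,
    litellm_organization: Optional[str] = None,
) -> Optional[str]:
    # Same precedence chain the diff installs for openai.organization.
    return (
        organization                              # per-call kwarg wins
        or litellm_organization                   # module-level default
        or os.environ.get("OPENAI_ORGANIZATION")  # secret/env fallback
        or None                                   # openai-python's default
    )
```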
`Router` — per-deployment `organization` (with `os.environ/` support) is set on the clients the router builds:

```diff
@@ -1411,6 +1411,12 @@ class Router:
                     max_retries = litellm.get_secret(max_retries_env_name)
                 litellm_params["max_retries"] = max_retries
 
+            organization = litellm_params.get("organization", None)
+            if isinstance(organization, str) and organization.startswith("os.environ/"):
+                organization_env_name = organization.replace("os.environ/", "")
+                organization = litellm.get_secret(organization_env_name)
+            litellm_params["organization"] = organization
+
             if "azure" in model_name:
                 if api_base is None:
                     raise ValueError(
@@ -1610,6 +1616,7 @@ class Router:
                     base_url=api_base,
                     timeout=timeout,
                     max_retries=max_retries,
+                    organization=organization,
                     http_client=httpx.AsyncClient(
                         transport=AsyncCustomHTTPTransport(),
                         limits=httpx.Limits(
@@ -1630,6 +1637,7 @@ class Router:
                     base_url=api_base,
                     timeout=timeout,
                     max_retries=max_retries,
+                    organization=organization,
                     http_client=httpx.Client(
                         transport=CustomHTTPTransport(),
                         limits=httpx.Limits(
@@ -1651,6 +1659,7 @@ class Router:
                     base_url=api_base,
                     timeout=stream_timeout,
                     max_retries=max_retries,
+                    organization=organization,
                     http_client=httpx.AsyncClient(
                         transport=AsyncCustomHTTPTransport(),
                         limits=httpx.Limits(
@@ -1672,6 +1681,7 @@ class Router:
                     base_url=api_base,
                     timeout=stream_timeout,
                     max_retries=max_retries,
+                    organization=organization,
                     http_client=httpx.Client(
                         transport=CustomHTTPTransport(),
                         limits=httpx.Limits(
```
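As with `api_key` and `max_retries`, the `Router` now honors the proxy's `os.environ/` convention for `organization`, so a config can reference an environment variable instead of hard-coding the id. A sketch of that lookup as a standalone helper (the Router itself resolves through `litellm.get_secret`):

```python
import os
from typing import Optional


def resolve_env_reference(value: Optional[str]) -> Optional[str]:
    # "os.environ/OPENAI_ORGANIZATION" -> contents of OPENAI_ORGANIZATION
    if isinstance(value, str) and value.startswith("os.environ/"):
        return os.environ.get(value.replace("os.environ/", ""))
    return value


# e.g. in config.yaml:  organization: os.environ/OPENAI_ORGANIZATION
print(resolve_env_reference("os.environ/OPENAI_ORGANIZATION"))
```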
Tests — `litellm.completion` with a nonexistent organization should fail:

```diff
@@ -569,6 +569,22 @@ def test_completion_openai():
 # test_completion_openai()
 
 
+def test_completion_openai_organization():
+    try:
+        litellm.set_verbose = True
+        try:
+            response = completion(
+                model="gpt-3.5-turbo", messages=messages, organization="org-ikDc4ex8NB"
+            )
+            pytest.fail("Request should have failed - This organization does not exist")
+        except Exception as e:
+            assert "No such organization: org-ikDc4ex8NB" in str(e)
+
+    except Exception as e:
+        print(e)
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_completion_text_openai():
     try:
         # litellm.set_verbose = True
```
Router tests — the configured organization lands on the client, a bad org fails, a good org works:

```diff
@@ -387,3 +387,56 @@ def test_router_init_gpt_4_vision_enhancements():
         print("passed")
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
+def test_openai_with_organization():
+    try:
+        print("Testing OpenAI with organization")
+        model_list = [
+            {
+                "model_name": "openai-bad-org",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "organization": "org-ikDc4ex8NB",
+                },
+            },
+            {
+                "model_name": "openai-good-org",
+                "litellm_params": {"model": "gpt-3.5-turbo"},
+            },
+        ]
+
+        router = Router(model_list=model_list)
+
+        print(router.model_list)
+        print(router.model_list[0])
+
+        openai_client = router._get_client(
+            deployment=router.model_list[0],
+            kwargs={"input": ["hello"], "model": "openai-bad-org"},
+        )
+        print(vars(openai_client))
+
+        assert openai_client.organization == "org-ikDc4ex8NB"
+
+        # bad org raises error
+
+        try:
+            response = router.completion(
+                model="openai-bad-org",
+                messages=[{"role": "user", "content": "this is a test"}],
+            )
+            pytest.fail("Request should have failed - This organization does not exist")
+        except Exception as e:
+            print("Got exception: " + str(e))
+            assert "No such organization: org-ikDc4ex8NB" in str(e)
+
+        # good org works
+        response = router.completion(
+            model="openai-good-org",
+            messages=[{"role": "user", "content": "this is a test"}],
+            max_tokens=5,
+        )
+
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
```