forked from phoenix/litellm-mirror
(fix) router: azure/embedding support
This commit is contained in:
parent
e58b3d5df0
commit
3891462b29
2 changed files with 33 additions and 1 deletions
|
@ -327,7 +327,7 @@ class Router:
|
||||||
data["model"] = original_model_string[:index_of_model_id]
|
data["model"] = original_model_string[:index_of_model_id]
|
||||||
else:
|
else:
|
||||||
data["model"] = original_model_string
|
data["model"] = original_model_string
|
||||||
model_client = deployment.get("async_client", None)
|
model_client = deployment.get("client", None)
|
||||||
|
|
||||||
return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
|
return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
|
||||||
|
|
||||||
|
@ -830,6 +830,7 @@ class Router:
|
||||||
or custom_llm_provider == "openai"
|
or custom_llm_provider == "openai"
|
||||||
or custom_llm_provider == "azure"
|
or custom_llm_provider == "azure"
|
||||||
or "ft:gpt-3.5-turbo" in model_name
|
or "ft:gpt-3.5-turbo" in model_name
|
||||||
|
or model_name in litellm.open_ai_embedding_models
|
||||||
):
|
):
|
||||||
# glorified / complicated reading of configs
|
# glorified / complicated reading of configs
|
||||||
# user can pass vars directly or they can pass os.environ/AZURE_API_KEY, in which case we will read the env
|
# user can pass vars directly or they can pass os.environ/AZURE_API_KEY, in which case we will read the env
|
||||||
|
|
|
@ -299,3 +299,34 @@ def test_aembedding_on_router():
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
# test_aembedding_on_router()
|
# test_aembedding_on_router()
|
||||||
|
|
||||||
|
|
||||||
|
def test_azure_aembedding_on_router():
    """Smoke-test Router.aembedding against an Azure embedding deployment.

    Builds a one-deployment model list for ``azure/azure-embedding-model``,
    issues a single async embedding call through the router, prints the
    response, and resets the router. Requires AZURE_API_KEY and
    AZURE_API_BASE in the environment; any exception is reported via
    ``pytest.fail`` with the traceback printed first.
    """
    litellm.set_verbose = True
    try:
        deployments = [
            {
                "model_name": "text-embedding-ada-002",
                "litellm_params": {
                    "model": "azure/azure-embedding-model",
                    # Credentials come from the environment — test assumes
                    # they are set; a KeyError here is caught below.
                    "api_key": os.environ["AZURE_API_KEY"],
                    "api_base": os.environ["AZURE_API_BASE"],
                },
                "tpm": 100000,
                "rpm": 10000,
            },
        ]

        async def embedding_call():
            # Fresh router per call; reset afterwards so no clients leak
            # into other tests.
            router = Router(model_list=deployments)
            response = await router.aembedding(
                model="text-embedding-ada-002",
                input=["good morning from litellm"],
            )
            print(response)
            router.reset()

        asyncio.run(embedding_call())
    except Exception as e:
        traceback.print_exc()
        pytest.fail(f"Error occurred: {e}")
# test_azure_aembedding_on_router()
|
Loading…
Add table
Add a link
Reference in a new issue