forked from phoenix/litellm-mirror

Merge pull request #2720 from onukura/ollama-batch-embedding

Batch embedding for Ollama

commit 28905c85b6
2 changed files with 52 additions and 62 deletions
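Before this change, passing a list with more than one string to an Ollama embedding call raised "Ollama Embeddings don't support batch embeddings"; after it, a list of strings is accepted and embedded prompt by prompt. A minimal usage sketch, assuming a local Ollama server at the default port, an already-pulled model, and that the synchronous litellm.embedding entry point routes to this Ollama path (model name and prompts are placeholders):

```python
import litellm

# Hypothetical call: model name and prompts are placeholders; assumes a local
# Ollama server with the model already pulled.
response = litellm.embedding(
    model="ollama/llama2",
    input=["Why is the sky blue?", "Why is grass green?"],
    api_base="http://localhost:11434",  # the default the code falls back to anyway
)

# The merged code stores one embedding dict per input string, in order.
for item in response.data:
    print(item["index"], len(item["embedding"]))
```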
```diff
@@ -344,9 +344,9 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
 async def ollama_aembeddings(
-    api_base="http://localhost:11434",
-    model="llama2",
-    prompt="Why is the sky blue?",
+    api_base: str,
+    model: str,
+    prompts: list[str],
     optional_params=None,
     logging_obj=None,
     model_response=None,
@@ -365,6 +365,11 @@ async def ollama_aembeddings(
         ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
             optional_params[k] = v
 
+    total_input_tokens = 0
+    output_data = []
+    timeout = aiohttp.ClientTimeout(total=litellm.request_timeout)  # 10 minutes
+    async with aiohttp.ClientSession(timeout=timeout) as session:
+        for idx, prompt in enumerate(prompts):
             data = {
                 "model": model,
                 "prompt": prompt,
@@ -375,10 +380,8 @@ async def ollama_aembeddings(
                 api_key=None,
                 additional_args={"api_base": url, "complete_input_dict": data, "headers": {}},
             )
-    timeout = aiohttp.ClientTimeout(total=litellm.request_timeout)  # 10 minutes
-    async with aiohttp.ClientSession(timeout=timeout) as session:
-        response = await session.post(url, json=data)
 
+            response = await session.post(url, json=data)
             if response.status != 200:
                 text = await response.text()
                 raise OllamaError(status_code=response.status, message=text)
@@ -395,22 +398,19 @@ async def ollama_aembeddings(
             )
             response_json = await response.json()
-        embeddings = response_json["embedding"]
-        embeddings = [embeddings]  # Ollama currently does not support batch embedding
-        ## RESPONSE OBJECT
-        output_data = []
-        for idx, embedding in enumerate(embeddings):
+            embeddings: list[float] = response_json["embedding"]
             output_data.append(
-                {"object": "embedding", "index": idx, "embedding": embedding}
+                {"object": "embedding", "index": idx, "embedding": embeddings}
             )
 
+            input_tokens = len(encoding.encode(prompt))
+            total_input_tokens += input_tokens
 
     model_response["object"] = "list"
     model_response["data"] = output_data
     model_response["model"] = model
-    input_tokens = len(encoding.encode(prompt))
-
     model_response["usage"] = {
-        "prompt_tokens": input_tokens,
-        "total_tokens": input_tokens,
+        "prompt_tokens": total_input_tokens,
+        "total_tokens": total_input_tokens,
     }
     return model_response
```
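Stitched together, the post-merge ollama_aembeddings reads roughly as below. This is a sketch assembled from the hunks above: the logging_obj.pre_call/post_call sections sit between hunks and are elided, the URL and config setup are not shown in the diff and are assumed, and indentation is reconstructed, so treat it as illustrative rather than the verbatim file.

```python
# Sketch only: relies on the module's existing imports (aiohttp, litellm, OllamaError).
async def ollama_aembeddings(
    api_base: str,
    model: str,
    prompts: list[str],
    optional_params=None,
    logging_obj=None,
    model_response=None,
    encoding=None,
):
    url = f"{api_base}/api/embeddings"  # endpoint setup not shown in the diff; assumed from the module
    total_input_tokens = 0
    output_data = []
    timeout = aiohttp.ClientTimeout(total=litellm.request_timeout)  # 10 minutes
    async with aiohttp.ClientSession(timeout=timeout) as session:
        # "Batch" support is a client-side loop: Ollama's /api/embeddings takes one
        # prompt per request, so each prompt gets its own POST on the shared session.
        for idx, prompt in enumerate(prompts):
            data = {"model": model, "prompt": prompt}
            response = await session.post(url, json=data)
            if response.status != 200:
                text = await response.text()
                raise OllamaError(status_code=response.status, message=text)
            response_json = await response.json()
            embeddings: list[float] = response_json["embedding"]
            output_data.append(
                {"object": "embedding", "index": idx, "embedding": embeddings}
            )
            total_input_tokens += len(encoding.encode(prompt))

    model_response["object"] = "list"
    model_response["data"] = output_data
    model_response["model"] = model
    model_response["usage"] = {
        "prompt_tokens": total_input_tokens,
        "total_tokens": total_input_tokens,
    }
    return model_response
```

The removed comment "# Ollama currently does not support batch embedding" explains the design: since the endpoint only embeds one prompt at a time, batching here is a per-prompt loop over a shared aiohttp session, with token counts accumulated into total_input_tokens for the usage block instead of being taken from a single prompt.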
```diff
@@ -2796,29 +2796,19 @@ def embedding(
             or get_secret("OLLAMA_API_BASE")
             or "http://localhost:11434"
         )
-        ollama_input = None
-        if isinstance(input, list) and len(input) > 1:
-            raise litellm.BadRequestError(
-                message=f"Ollama Embeddings don't support batch embeddings",
-                model=model,  # type: ignore
-                llm_provider="ollama",  # type: ignore
-            )
-        if isinstance(input, list) and len(input) == 1:
-            ollama_input = "".join(input[0])
-        elif isinstance(input, str):
-            ollama_input = input
-        else:
+        if isinstance(input, str):
+            input = [input]
+        if not all(isinstance(item, str) for item in input):
             raise litellm.BadRequestError(
                 message=f"Invalid input for ollama embeddings. input={input}",
                 model=model,  # type: ignore
                 llm_provider="ollama",  # type: ignore
             )
-        if aembedding == True:
+        if aembedding:
             response = ollama.ollama_aembeddings(
                 api_base=api_base,
                 model=model,
-                prompt=ollama_input,
+                prompts=input,
                 encoding=encoding,
                 logging_obj=logging,
                 optional_params=optional_params,
```
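The normalization in this hunk is small enough to read as a standalone rule: a bare string becomes a one-element list, and anything that is not a list of strings is rejected. The helper below is hypothetical (normalize_ollama_embedding_input is not part of litellm) and only restates the merged logic so it can be exercised in isolation; the real code path raises litellm.BadRequestError rather than ValueError.

```python
def normalize_ollama_embedding_input(input):
    """Restates the merged validation: coerce str -> [str], reject non-string items."""
    if isinstance(input, str):
        input = [input]
    if not all(isinstance(item, str) for item in input):
        # litellm raises BadRequestError here; ValueError keeps this sketch dependency-free.
        raise ValueError(f"Invalid input for ollama embeddings. input={input}")
    return input


assert normalize_ollama_embedding_input("hello") == ["hello"]
assert normalize_ollama_embedding_input(["a", "b"]) == ["a", "b"]
try:
    normalize_ollama_embedding_input(["a", 123])
except ValueError as err:
    print(err)
```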