@@ -287,7 +287,7 @@ class Manifest:
    def run(
        self,
        prompt: Union[str, List[str]],
        prompt: Union[str, List[str], List[Dict[str, str]]],
        overwrite_cache: bool = False,
        stop_token: Optional[str] = None,
        return_response: bool = False,
@@ -296,6 +296,8 @@ class Manifest:
        """
        Run the prompt.

        Orchestrates between the standard run, chat run, and batch run.

        Args:
            prompt: prompt(s) to run.
            overwrite_cache: whether to overwrite cache.
@@ -307,9 +309,68 @@ class Manifest:
        Returns:
            response from prompt.
        """
        is_batch = isinstance(prompt, list)
        if not isinstance(prompt, list) and not isinstance(prompt, str):
            raise ValueError(
                f"Invalid prompt type: {type(prompt)}. "
                "Prompt must be a string or list of strings "
                "or list of dicts."
            )
        if isinstance(prompt, list) and not prompt:
            raise ValueError("Prompt cannot be empty list")
        # Get the client to run
        client = self.client_pool.get_next_client()
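        # A list of dicts is treated as a single chat conversation and is routed
        # to _run_chat; strings and lists of strings take the standard _run path.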
        if isinstance(prompt, list) and isinstance(prompt[0], dict):
            if not client.IS_CHAT:
                raise ValueError(
                    f"Client {client} does not support dict chat prompt. "
                    "Please use a chat model."
                )
            if stop_token:
                logger.warning(
                    "stop_token is not supported for chat prompt. "
                    "Ignoring stop_token."
                )
            return self._run_chat(
                prompt=cast(List[Dict[str, str]], prompt),
                client=client,
                overwrite_cache=overwrite_cache,
                return_response=return_response,
            )
        else:
            return self._run(
                prompt=cast(Union[str, List[str]], prompt),
                client=client,
                overwrite_cache=overwrite_cache,
                stop_token=stop_token,
                return_response=return_response,
                **kwargs,
            )
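
    # Example usage (a sketch; assumes a Manifest instance backed by a
    # chat-capable client):
    #   manifest.run("Why is grass green?")                       # -> str
    #   manifest.run(["Why is grass green?", "What is 2 + 2?"])   # -> List[str]
    #   manifest.run([{"role": "user", "content": "Hi there!"}])  # chat path -> str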

    def _run(
        self,
        prompt: Union[str, List[str]],
        client: Client,
        overwrite_cache: bool = False,
        stop_token: Optional[str] = None,
        return_response: bool = False,
        **kwargs: Any,
    ) -> Union[str, List[str], np.ndarray, List[np.ndarray], Response]:
        """
        Run the prompt.

        Args:
            prompt: prompt(s) to run.
            client: client to run.
            overwrite_cache: whether to overwrite cache.
            stop_token: stop token for prompt generation.
                Default is self.stop_token.
                "" for no stop token.
            return_response: whether to return Response object.

        Returns:
            response from prompt.
        """
        is_batch = isinstance(prompt, list)
        stop_token = stop_token if stop_token is not None else self.stop_token
        # Must pass kwargs as a dict so the client's "pop" methods can remove used arguments
        request_params = client.get_request(prompt, kwargs)
@@ -344,6 +405,67 @@ class Manifest:
        else:
            return final_response.get_response(stop_token, is_batch)

    def _run_chat(
        self,
        prompt: List[Dict[str, str]],
        client: Client,
        overwrite_cache: bool = False,
        return_response: bool = False,
        **kwargs: Any,
    ) -> Union[str, Response]:
        """
        Run the prompt.

        Args:
            prompt: prompt dictionary to run.
            client: client to run.
            overwrite_cache: whether to overwrite cache.
            stop_token: stop token for prompt generation.
                Default is self.stop_token.
                "" for no stop token.
            return_response: whether to return Response object.

        Returns:
            response from prompt.
        """
        is_batch = False
        # Get a request for an empty prompt to handle all kwargs
        request_params = client.get_request("", kwargs)
        # Add prompt and cast as chat request
        request_params_dict = request_params.to_dict()
        request_params_dict["prompt"] = prompt
        request_params_as_chat = LMChatRequest(**request_params_dict)
        # Avoid nested list of results - enforce n = 1 for batch
        if request_params_as_chat.n > 1:
            raise ValueError("Chat mode does not support n > 1.")
        self._validate_kwargs(kwargs, request_params_as_chat)
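        # Split the request into cached and uncached parts; any cached response is
        # merged back in by _stitch_responses_and_cache below.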
        cached_idx_to_response, request_params_as_chat = self._split_cached_requests(  # type: ignore  # noqa: E501
            request_params_as_chat, client, overwrite_cache
        )
        # If anything is left to run (not None and not an empty list), issue a new request
        if request_params_as_chat.prompt:
            # Start timing metrics
            self.client_pool.start_timer()
            response = client.run_chat_request(request_params_as_chat)
            self.client_pool.end_timer()
        else:
            # Nothing to run
            response = None
        final_response = self._stitch_responses_and_cache(
            request=request_params_as_chat,
            client=client,
            response=response,
            cached_idx_to_response=cached_idx_to_response,
        )
        # Extract text results
        if return_response:
            return final_response
        else:
            return cast(str, final_response.get_response("", is_batch))
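
    # Chat prompts are lists of message dicts. A sketch of the expected shape
    # (the exact keys follow the underlying chat client, e.g. OpenAI-style):
    #   [{"role": "system", "content": "You are a helpful assistant."},
    #    {"role": "user", "content": "Summarize the plot of Hamlet."}]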

    async def arun_batch(
        self,
        prompts: List[str],
@@ -381,6 +503,13 @@ class Manifest:
        Returns:
            response from prompt.
        """
        if not isinstance(prompts, list):
            raise ValueError("Prompts must be a list of strings.")
        if not prompts:
            raise ValueError("Prompts must not be empty.")
        if not isinstance(prompts[0], str):
            raise ValueError("Prompts must be a list of strings.")
        # Split the prompts into chunks
        prompt_chunks: List[Tuple[Client, List[str]]] = []
        if chunk_size > 0:
@@ -464,67 +593,6 @@ class Manifest:
        )
        return final_response

    def run_chat(
        self,
        prompt: List[Dict[str, str]],
        overwrite_cache: bool = False,
        return_response: bool = False,
        **kwargs: Any,
    ) -> Union[str, Response]:
        """
        Run the prompt.

        Args:
            prompt: prompt dictionary to run.
            overwrite_cache: whether to overwrite cache.
            stop_token: stop token for prompt generation.
                Default is self.stop_token.
                "" for no stop token.
            return_response: whether to return Response object.

        Returns:
            response from prompt.
        """
        is_batch = False
        # Get the client to run
        client = self.client_pool.get_next_client()
        # Get a request for an empty prompt to handle all kwargs
        request_params = client.get_request("", kwargs)
        # Add prompt and cast as chat request
        request_params_dict = request_params.to_dict()
        request_params_dict["prompt"] = prompt
        request_params_as_chat = LMChatRequest(**request_params_dict)
        # Avoid nested list of results - enforce n = 1 for batch
        if request_params_as_chat.n > 1:
            raise ValueError("Chat mode does not support n > 1.")
        self._validate_kwargs(kwargs, request_params_as_chat)
        cached_idx_to_response, request_params_as_chat = self._split_cached_requests(  # type: ignore  # noqa: E501
            request_params_as_chat, client, overwrite_cache
        )
        # If not None value or empty list - run new request
        if request_params_as_chat.prompt:
            # Start timing metrics
            self.client_pool.start_timer()
            response = client.run_chat_request(request_params_as_chat)
            self.client_pool.end_timer()
        else:
            # Nothing to run
            response = None
        final_response = self._stitch_responses_and_cache(
            request=request_params_as_chat,
            client=client,
            response=response,
            cached_idx_to_response=cached_idx_to_response,
        )
        # Extract text results
        if return_response:
            return final_response
        else:
            return cast(str, final_response.get_response("", is_batch))

    def score_prompt(
        self,
        prompt: Union[str, List[str]],