@@ -1,4 +1,5 @@
"""Manifest class."""
import copy
import logging
from typing import Any, Dict, List, Optional, Tuple, Union, cast
@@ -13,40 +14,32 @@ from manifest.clients.cohere import CohereClient
from manifest.clients.dummy import DummyClient
from manifest.clients.huggingface import HuggingFaceClient
from manifest.clients.openai import OpenAIClient
from manifest.clients.openaichat import OpenAIChatClient
from manifest.clients.toma import TOMAClient
from manifest.request import Request
from manifest.response import Response
from manifest.session import Session

logging.getLogger("openai").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
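# Registry of available clients keyed by client name; optional clients
# (ChatGPT, diffusion) are only registered below if their imports succeed.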
CLIENT_CONSTRUCTORS = {
    "openai": OpenAIClient,
    "cohere": CohereClient,
    "ai21": AI21Client,
    "huggingface": HuggingFaceClient,
    "dummy": DummyClient,
    "toma": TOMAClient,
    OpenAIClient.NAME: OpenAIClient,
    OpenAIChatClient.NAME: OpenAIChatClient,
    CohereClient.NAME: CohereClient,
    AI21Client.NAME: AI21Client,
    HuggingFaceClient.NAME: HuggingFaceClient,
    DummyClient.NAME: DummyClient,
    TOMAClient.NAME: TOMAClient,
}
# ChatGPT
try:
    from manifest.clients.chatgpt import ChatGPTClient

    CLIENT_CONSTRUCTORS["chatgpt"] = ChatGPTClient
except Exception:
    logger.info("ChatGPT not installed. Skipping import.")
    pass
# Diffusion
DIFFUSION_CLIENTS = ["diffuser", "tomadiffuser"]
try:
    from manifest.clients.diffuser import DiffuserClient
    from manifest.clients.toma_diffuser import TOMADiffuserClient

    CLIENT_CONSTRUCTORS["diffuser"] = DiffuserClient
    CLIENT_CONSTRUCTORS["tomadiffuser"] = TOMADiffuserClient
    CLIENT_CONSTRUCTORS[DiffuserClient.NAME] = DiffuserClient
    CLIENT_CONSTRUCTORS[TOMADiffuserClient.NAME] = TOMADiffuserClient
except Exception:
    logger.info("Diffusion not supported. Skipping import.")
    pass
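# Illustrative usage sketch (assumes the `client_name` constructor argument used
# via `self.client_name` below, and that the package exports `Manifest`):
#
#   from manifest import Manifest
#
#   manifest = Manifest(client_name="dummy", cache_name="noop")
#   text = manifest.run("Say hello", overwrite_cache=False)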
@@ -69,7 +62,6 @@ class Manifest:
        client_connection: Optional[str] = None,
        cache_name: str = "noop",
        cache_connection: Optional[str] = None,
        session_id: Optional[str] = None,
        stop_token: str = "",
        **kwargs: Any,
    ):
@@ -81,9 +73,6 @@ class Manifest:
            client_connection: connection string for client.
            cache_name: name of cache.
            cache_connection: connection string for cache.
            session_id: session id for user session cache.
                None (default) means no session logging.
                "_default" means generate new session id.
            stop_token: stop token for prompt generation.
                Can be overridden in run
@@ -114,17 +103,6 @@ class Manifest:
        self.client = CLIENT_CONSTRUCTORS[self.client_name](  # type: ignore
            client_connection, client_args=kwargs
        )
        if session_id is not None:
            if self.client_name == "diffuser":
                raise NotImplementedError(
                    "Session logging not implemented for Diffuser client."
                )
            if session_id == "_default":
                # Set session_id to None for Session random id
                session_id = None
            self.session = Session(session_id)
        else:
            self.session = None
        if len(kwargs) > 0:
            raise ValueError(f"{list(kwargs.items())} arguments are not recognized.")
@@ -195,11 +173,133 @@ class Manifest:
            logger.warning(f"{list(request_unused_kwargs)} arguments are unused.")
        return

    def _split_cached_requests(
        self,
        request: Request,
        overwrite_cache: bool,
    ) -> Tuple[Dict[int, Response], Request]:
        """Split a request into cached responses and Requests to run.

        Args:
            request: request object.
            overwrite_cache: whether to overwrite cache.

        Returns:
            cached_idx_to_response: dict of cached responses.
            new_request: request object with only prompts to run.
        """
        cached_idx_to_response: Dict[int, Response] = {}
        new_request = copy.deepcopy(request)
        if not overwrite_cache:
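            # For a list of prompts, check the cache for each prompt individually;
            # for a single string prompt, check the cache once for the whole request.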
            if isinstance(new_request.prompt, list):
                new_request.prompt = []
                for idx, prompt_str in enumerate(request.prompt):
                    single_request = copy.deepcopy(request)
                    single_request.prompt = prompt_str
                    possible_response = self.cache.get(
                        self.client.get_cache_key(single_request)
                    )
                    if possible_response:
                        cached_idx_to_response[idx] = possible_response
                    else:
                        new_request.prompt.append(prompt_str)
            else:
                possible_response = self.cache.get(
                    self.client.get_cache_key(new_request)
                )
                if possible_response:
                    cached_idx_to_response[0] = possible_response
                    new_request.prompt = None
        return cached_idx_to_response, new_request

    def _stitch_responses_and_cache(
        self,
        request: Request,
        response: Union[Response, None],
        cached_idx_to_response: Dict[int, Response],
    ) -> Response:
        """Stitch together the cached and uncached responses."""
        # We stitch the responses (the choices) here from both the new request and
        # the cached entries.
        all_model_choices = []
        all_input_prompts = []
        response_idx = 0
        number_prompts = len(cached_idx_to_response)
        single_output = False
        if response:
            if isinstance(response.get_request()["prompt"], str):
                single_output = True
                number_prompts += 1
            else:
                number_prompts += len(response.get_request()["prompt"])
        response_gen_key = None
        response_logits_key = None
        response_item_key = None
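        # Walk the prompts in their original order, taking choices either from the
        # cache or from the freshly run response.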
        for idx in range(number_prompts):
            if idx in cached_idx_to_response:
                cached_res = cached_idx_to_response[idx]
                response_gen_key = cached_res.generation_key
                response_logits_key = cached_res.logits_key
                response_item_key = cached_res.item_key
                all_input_prompts.append(cached_res.get_request()["prompt"])
                if request.n == 1:
                    assert (
                        len(cached_res.get_json_response()[response_gen_key]) == 1
                    ), "cached response should have only one choice"
                    all_model_choices.append(
                        cached_res.get_json_response()[response_gen_key][0]
                    )
                else:
                    all_model_choices.extend(
                        cached_res.get_json_response()[response_gen_key]
                    )
            else:
                assert response is not None, "response should not be None"
                response = cast(Response, response)
                response_gen_key = response.generation_key
                response_logits_key = response.logits_key
                response_item_key = response.item_key
                # the choices list in the response is a flat one.
                # length is request.n * num_prompts
                current_choices = response.get_json_response()[response_gen_key][
                    response_idx * request.n : (response_idx + 1) * request.n
                ]
                all_model_choices.extend(current_choices)
                if isinstance(response.get_request()["prompt"], list):
                    prompt = response.get_request()["prompt"][response_idx]
                else:
                    prompt = str(response.get_request()["prompt"])
                all_input_prompts.append(prompt)
                # set cache
                new_request = copy.deepcopy(request)
                new_request.prompt = prompt
                cache_key = self.client.get_cache_key(new_request)
                new_response_key = copy.deepcopy(response.get_json_response())
                new_response_key[response_gen_key] = current_choices
                self.cache.set(cache_key, new_response_key)
                response_idx += 1
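        # Rebuild a single request/response pair covering all prompts, cached and new.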
        new_request = copy.deepcopy(request)
        new_request.prompt = (
            all_input_prompts
            if len(all_input_prompts) > 1 or not single_output
            else all_input_prompts[0]
        )
        response_obj = Response(
            {response_gen_key: all_model_choices},
            cached=len(cached_idx_to_response) > 0,
            request_params=self.client.get_cache_key(new_request),
            generation_key=response_gen_key,
            logits_key=response_logits_key,
            item_key=response_item_key,
        )
        return response_obj

    def run(
        self,
        prompt: Union[str, List[str]],
        overwrite_cache: bool = False,
        run_id: Optional[str] = None,
        stop_token: Optional[str] = None,
        return_response: bool = False,
        **kwargs: Any,
@@ -210,7 +310,6 @@ class Manifest:
        Args:
            prompt: prompt(s) to run.
            overwrite_cache: whether to overwrite cache.
            run_id: run id for cache to repeat same run.
            stop_token: stop token for prompt generation.
                Default is self.stop_token.
                "" for no stop token.
@@ -223,28 +322,88 @@
        stop_token = stop_token if stop_token is not None else self.stop_token
        # Must pass kwargs as a dict so the client "pop" methods remove used arguments
        request_params = self.client.get_request_params(prompt, kwargs)
        request_params = self.client.get_request(prompt, kwargs)
        # Avoid nested list of results - enforce n = 1 for batch
        if is_batch and request_params.n > 1:
            raise ValueError("Batch mode does not support n > 1.")
        possible_request, full_kwargs = self.client.get_request(request_params)
        self._validate_kwargs(kwargs, request_params)
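        # Pull any already-cached responses out of the request so that only
        # uncached prompts are sent to the client.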
        cached_idx_to_response, request_params = self._split_cached_requests(
            request_params, overwrite_cache
        )
        # If not None value or empty list - run new request
        if request_params.prompt:
            response = self.client.run_request(request_params)
        else:
            # Nothing to run
            response = None

        final_response = self._stitch_responses_and_cache(
            request=request_params,
            response=response,
            cached_idx_to_response=cached_idx_to_response,
        )

        # Extract text results
        if return_response:
            return final_response
        else:
            return final_response.get_response(stop_token, is_batch)

    async def arun_batch(
        self,
        prompts: List[str],
        overwrite_cache: bool = False,
        stop_token: Optional[str] = None,
        return_response: bool = False,
        **kwargs: Any,
    ) -> Union[List[str], List[np.ndarray], Response]:
        """
        Run a batch of prompts with async.

        Args:
            prompts: prompts to run.
            overwrite_cache: whether to overwrite cache.
            stop_token: stop token for prompt generation.
                Default is self.stop_token.
                "" for no stop token.
            return_response: whether to return Response object.

        Returns:
            response from prompt.
        """
        stop_token = stop_token if stop_token is not None else self.stop_token
        # Must pass kwargs as a dict so the client "pop" methods remove used arguments
        request_params = self.client.get_request(prompts, kwargs)
        # Avoid nested list of results - enforce n = 1 for batch
        if request_params.n > 1:
            raise ValueError("Batch mode does not support n > 1.")
        self._validate_kwargs(kwargs, request_params)
        # Create cache key
        cache_key = full_kwargs.copy()
        # Make query model dependent
        cache_key.update(self.client.get_model_params())
        if run_id:
            cache_key["run_id"] = run_id
        response_obj = self.cache.get(cache_key, overwrite_cache, possible_request)
        # Log session dictionary values
        if self.session:
            self.session.log_query(cache_key, response_obj.to_dict())
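        # As in run(), split out prompts whose responses are already cached.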
        cached_idx_to_response, request_params = self._split_cached_requests(
            request_params, overwrite_cache
        )
        # If not None value or empty list - run new request
        if request_params.prompt:
            response = await self.client.arun_batch_request(request_params)
        else:
            # Nothing to run
            response = None

        final_response = self._stitch_responses_and_cache(
            request=request_params,
            response=response,
            cached_idx_to_response=cached_idx_to_response,
        )
        # Extract text results
        if return_response:
            return response_obj
            return final_response
        else:
            return response_obj.get_response(stop_token, is_batch)
            return cast(
                Union[List[str], List[np.ndarray]],
                final_response.get_response(stop_token, True),
            )

    def score_prompt(
        self,
@@ -257,8 +416,6 @@
        Returns the response object with logits of the prompt.

        Prompt scoring is not part of a session cache.

        Args:
            prompt: prompt(s) to run.
            overwrite_cache: whether to overwrite cache.
@@ -267,66 +424,31 @@
            response from prompt.
        """
        # Must pass kwargs as a dict so the client "pop" methods remove used arguments
        request_params = self.client.get_request_params(prompt, kwargs)
        request_params = self.client.get_request(prompt, kwargs)
        request_params.request_type = "score_prompt"

        if request_params.n > 1:
            raise ValueError("Sequence scoring does not support n > 1.")
        self._validate_kwargs(kwargs, request_params)

        cached_idx_to_response, request_params = self._split_cached_requests(
            request_params, overwrite_cache
        )
        # If not None value or empty list - run new request
        if request_params.prompt:
            try:
                possible_request, full_kwargs = cast(
                response = cast(
                    HuggingFaceClient, self.client
                ).get_score_prompt_request(request_params)
            except AttributeError:
                raise ValueError("`score_prompt` only supported for HF models.")
        else:
            # Nothing to run
            response = None
        self._validate_kwargs(kwargs, request_params)
        # Create cache key
        cache_key = full_kwargs.copy()
        # Make query model dependent
        cache_key.update(self.client.get_model_params())
        response_obj = self.cache.get(cache_key, overwrite_cache, possible_request)
        return response_obj.to_dict()

    def get_last_queries(
        self,
        last_n: int = -1,
        return_raw_values: bool = False,
        stop_token: Optional[str] = None,
    ) -> List[Tuple[Any, Any]]:
        """
        Get last n queries from current session.

        If last_n is -1, return all queries. By default will only return the
        prompt text and result text unless return_raw_values is True.

        Args:
            last_n: last n queries.
            return_raw_values: whether to return raw values as dicts.
            stop_token: stop token for prompt results to be applied to all results.

        Returns:
            last n list of queries and outputs.
        """
        if self.session is None:
            raise ValueError(
                "Session was not initialized. Set `session_id` when loading Manifest."
        final_response = self._stitch_responses_and_cache(
            request=request_params,
            response=response,
            cached_idx_to_response=cached_idx_to_response,
        )
        stop_token = stop_token if stop_token is not None else self.stop_token
        last_queries = self.session.get_last_queries(last_n)
        if not return_raw_values:
            last_queries = [
                (
                    query["prompt"],
                    Response.from_dict(response).get_response(
                        stop_token, is_batch=isinstance(query["prompt"], list)
                    ),
                )  # type: ignore
                for query, response in last_queries
            ]
        return last_queries

    def open_explorer(self) -> None:
        """Open the explorer for jupyter widget."""
        # Open explorer
        # TODO: implement
        pass
        return final_response.to_dict()