diff --git a/README.md b/README.md
index 2cafba1..2069f15 100644
--- a/README.md
+++ b/README.md
@@ -63,8 +63,9 @@ Feel free to adjust the configuration according to your needs.
 > Get `config.yaml` path with command `aichat --info` or repl command `.info`.
 
 ```yaml
-model: openai:gpt-3.5-turbo      # The Large Language Model (LLM) to use
-temperature: 1.0                 # Controls the randomness and creativity of the LLM's responses
+model: openai:gpt-3.5-turbo      # Specify the language model to use
+temperature: null                # Set default temperature parameter
+top_p: null                      # Set default top-p parameter
 save: true                       # Indicates whether to persist the message
 save_session: null               # Controls the persistence of the session, if null, asking the user
 highlight: true                  # Controls syntax highlighting
diff --git a/config.example.yaml b/config.example.yaml
index deed762..3b29c86 100644
--- a/config.example.yaml
+++ b/config.example.yaml
@@ -1,5 +1,6 @@
-model: openai:gpt-3.5-turbo      # The Large Language Model (LLM) to use
-temperature: 1.0                 # Controls the randomness and creativity of the LLM's responses
+model: openai:gpt-3.5-turbo      # Specify the language model to use
+temperature: null                # Set default temperature parameter
+top_p: null                      # Set default top-p parameter
 save: true                       # Indicates whether to persist the message
 save_session: null               # Controls the persistence of the session, if null, asking the user
 highlight: true                  # Controls syntax highlighting
diff --git a/src/client/claude.rs b/src/client/claude.rs
index 283ae99..054731c 100644
--- a/src/client/claude.rs
+++ b/src/client/claude.rs
@@ -140,6 +140,7 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
     let SendData {
         mut messages,
         temperature,
+        top_p,
         stream,
     } = data;
 
@@ -205,6 +206,9 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
     if let Some(v) = temperature {
         body["temperature"] = v.into();
     }
+    if let Some(v) = top_p {
+        body["top_p"] = v.into();
+    }
     if stream {
         body["stream"] = true.into();
     }
diff --git a/src/client/cohere.rs b/src/client/cohere.rs
index 4d713b6..cfab0fa 100644
--- a/src/client/cohere.rs
+++ b/src/client/cohere.rs
@@ -110,6 +110,7 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
     let SendData {
         mut messages,
         temperature,
+        top_p,
         stream,
     } = data;
 
@@ -173,6 +174,9 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
     if let Some(temperature) = temperature {
         body["temperature"] = temperature.into();
     }
+    if let Some(top_p) = top_p {
+        body["p"] = top_p.into();
+    }
     if stream {
         body["stream"] = true.into();
     }
diff --git a/src/client/common.rs b/src/client/common.rs
index 62c5f92..593140d 100644
--- a/src/client/common.rs
+++ b/src/client/common.rs
@@ -323,6 +323,7 @@ pub struct ExtraConfig {
 
 pub struct SendData {
     pub messages: Vec<Message>,
     pub temperature: Option<f64>,
+    pub top_p: Option<f64>,
     pub stream: bool,
 }
diff --git a/src/client/ernie.rs b/src/client/ernie.rs
index 68d1ace..ffe10f0 100644
--- a/src/client/ernie.rs
+++ b/src/client/ernie.rs
@@ -230,6 +230,7 @@ fn build_body(data: SendData, model: &Model) -> Value {
     let SendData {
         mut messages,
         temperature,
+        top_p,
         stream,
     } = data;
 
@@ -242,6 +243,9 @@ fn build_body(data: SendData, model: &Model) -> Value {
     if let Some(temperature) = temperature {
         body["temperature"] = temperature.into();
     }
+    if let Some(top_p) = top_p {
+        body["top_p"] = top_p.into();
+    }
 
     if let Some(max_output_tokens) = model.max_output_tokens {
         body["max_output_tokens"] = max_output_tokens.into();
diff --git a/src/client/ollama.rs b/src/client/ollama.rs
index 055730a..434e487 100644
--- a/src/client/ollama.rs
+++ b/src/client/ollama.rs
@@ -122,6 +122,7 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
     let SendData {
         messages,
         temperature,
+        top_p,
         stream,
     } = data;
 
@@ -185,6 +186,9 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
     if let Some(temperature) = temperature {
         body["options"]["temperature"] = temperature.into();
     }
+    if let Some(top_p) = top_p {
+        body["options"]["top_p"] = top_p.into();
+    }
 
     Ok(body)
 }
diff --git a/src/client/openai.rs b/src/client/openai.rs
index f35c6ab..2c5d99b 100644
--- a/src/client/openai.rs
+++ b/src/client/openai.rs
@@ -130,6 +130,7 @@ pub fn openai_build_body(data: SendData, model: &Model) -> Value {
     let SendData {
         messages,
         temperature,
+        top_p,
         stream,
     } = data;
 
@@ -139,13 +140,16 @@ pub fn openai_build_body(data: SendData, model: &Model) -> Value {
     });
     if let Some(max_tokens) = model.max_output_tokens {
-        body["max_tokens"] = json!(max_tokens);
+        body["max_tokens"] = max_tokens.into();
     } else if model.name == "gpt-4-vision-preview" {
         // The default max_tokens of gpt-4-vision-preview is only 16, we need to make it larger
-        body["max_tokens"] = json!(4096);
+        body["max_tokens"] = 4096.into();
     }
 
-    if let Some(v) = temperature {
-        body["temperature"] = v.into();
+    if let Some(temperature) = temperature {
+        body["temperature"] = temperature.into();
+    }
+    if let Some(top_p) = top_p {
+        body["top_p"] = top_p.into();
     }
     if stream {
         body["stream"] = true.into();
diff --git a/src/client/qianwen.rs b/src/client/qianwen.rs
index fb97964..c5a72e0 100644
--- a/src/client/qianwen.rs
+++ b/src/client/qianwen.rs
@@ -130,7 +130,6 @@ async fn send_message_streaming(
     is_vl: bool,
 ) -> Result<()> {
     let mut es = builder.eventsource()?;
-    let mut offset = 0;
 
     while let Some(event) = es.next().await {
         match event {
@@ -139,12 +138,10 @@ async fn send_message_streaming(
                 let data: Value = serde_json::from_str(&message.data)?;
                 catch_error(&data)?;
                 if is_vl {
-                    let text =
-                        data["output"]["choices"][0]["message"]["content"][0]["text"].as_str();
-                    if let Some(text) = text {
-                        let text = &text[offset..];
+                    if let Some(text) =
+                        data["output"]["choices"][0]["message"]["content"][0]["text"].as_str()
+                    {
                         handler.text(text)?;
-                        offset += text.len();
                     }
                 } else if let Some(text) = data["output"]["text"].as_str() {
                     handler.text(text)?;
@@ -169,11 +166,12 @@ fn build_body(data: SendData, model: &Model, is_vl: bool) -> Result<(Value, bool
     let SendData {
         messages,
         temperature,
+        top_p,
         stream,
     } = data;
 
     let mut has_upload = false;
-    let (input, parameters) = if is_vl {
+    let input = if is_vl {
        let messages: Vec<Value> = messages
             .into_iter()
             .map(|message| {
@@ -199,40 +197,37 @@ fn build_body(data: SendData, model: &Model, is_vl: bool) -> Result<(Value, bool
             })
             .collect();
 
-        let input = json!({
+        json!({
             "messages": messages,
-        });
-
-        let mut parameters = json!({});
-        if let Some(v) = temperature {
-            parameters["temperature"] = v.into();
-        }
-        (input, parameters)
+        })
     } else {
-        let input = json!({
+        json!({
             "messages": messages,
-        });
+        })
+    };
 
-        let mut parameters = json!({});
-        if stream {
-            parameters["incremental_output"] = true.into();
-        }
+    let mut parameters = json!({});
+    if stream {
+        parameters["incremental_output"] = true.into();
+    }
 
-        if let Some(max_tokens) = model.max_output_tokens {
-            parameters["max_tokens"] = max_tokens.into();
-        }
+    if let Some(max_tokens) = model.max_output_tokens {
+        parameters["max_tokens"] = max_tokens.into();
+    }
 
-        if let Some(v) = temperature {
-            parameters["temperature"] = v.into();
-        }
-        (input, parameters)
-    };
+    if let Some(temperature) = temperature {
+        parameters["temperature"] = temperature.into();
+    }
+    if let Some(top_p) = top_p {
+        parameters["top_p"] = top_p.into();
+    }
 
     let body = json!({
         "model": &model.name,
         "input": input,
         "parameters": parameters
     });
+
     Ok((body, has_upload))
 }
 
diff --git a/src/client/vertexai.rs b/src/client/vertexai.rs
index d9c1019..e0ae567 100644
--- a/src/client/vertexai.rs
+++ b/src/client/vertexai.rs
@@ -158,7 +158,8 @@ pub(crate) fn build_body(
     let SendData {
         mut messages,
         temperature,
-        ..
+        top_p,
+        stream: _,
     } = data;
 
     patch_system_message(&mut messages);
@@ -223,6 +224,10 @@ pub(crate) fn build_body(
         body["generationConfig"]["temperature"] = temperature.into();
     }
 
+    if let Some(top_p) = top_p {
+        body["generationConfig"]["topP"] = top_p.into();
+    }
+
     Ok(body)
 }
diff --git a/src/config/mod.rs b/src/config/mod.rs
index e1cec4d..87be519 100644
--- a/src/config/mod.rs
+++ b/src/config/mod.rs
@@ -53,6 +53,7 @@ pub struct Config {
     #[serde(rename(serialize = "model", deserialize = "model"))]
     pub model_id: Option<String>,
     pub temperature: Option<f64>,
+    pub top_p: Option<f64>,
     pub dry_run: bool,
     pub save: bool,
     pub save_session: Option<bool>,
@@ -89,6 +90,7 @@ impl Default for Config {
         Self {
             model_id: None,
             temperature: None,
+            top_p: None,
             save: true,
             save_session: None,
             highlight: true,
@@ -297,6 +299,7 @@ impl Config {
         if let Some(session) = self.session.as_mut() {
             session.guard_empty()?;
             session.set_temperature(role.temperature);
+            session.set_top_p(role.top_p);
         }
         self.role = Some(role);
         Ok(())
@@ -335,6 +338,16 @@ impl Config {
         }
     }
 
+    pub fn set_top_p(&mut self, value: Option<f64>) {
+        if let Some(session) = self.session.as_mut() {
+            session.set_top_p(value);
+        } else if let Some(role) = self.role.as_mut() {
+            role.set_top_p(value);
+        } else {
+            self.top_p = value;
+        }
+    }
+
     pub fn set_save_session(&mut self, value: Option<bool>) {
         if let Some(session) = self.session.as_mut() {
             session.set_save_session(value);
@@ -411,6 +424,7 @@ impl Config {
         let items = vec![
             ("model", self.model.id()),
             ("temperature", format_option(&self.temperature)),
+            ("top_p", format_option(&self.top_p)),
             ("dry_run", self.dry_run.to_string()),
             ("save", self.save.to_string()),
             ("save_session", format_option(&self.save_session)),
@@ -478,6 +492,7 @@ impl Config {
             ".session" => self.list_sessions(),
             ".set" => vec![
                 "temperature ",
+                "top_p ",
                 "compress_threshold",
                 "save ",
                 "save_session ",
@@ -529,6 +544,10 @@ impl Config {
                 let value = parse_value(value)?;
                 self.set_temperature(value);
             }
+            "top_p" => {
+                let value = parse_value(value)?;
+                self.set_top_p(value);
+            }
             "compress_threshold" => {
                 let value = parse_value(value)?;
                 self.set_compress_threshold(value);
@@ -756,10 +775,18 @@ impl Config {
         } else {
             self.temperature
         };
+        let top_p = if let Some(session) = input.session(&self.session) {
+            session.top_p()
+        } else if let Some(role) = input.role() {
+            role.top_p
+        } else {
+            self.top_p
+        };
         self.model.max_input_tokens_limit(&messages)?;
         Ok(SendData {
             messages,
             temperature,
+            top_p,
             stream,
         })
     }
@@ -791,6 +818,11 @@ impl Config {
                 output.insert("temperature", temperature.to_string());
             }
         }
+        if let Some(top_p) = self.top_p {
+            if top_p != 0.0 {
+                output.insert("top_p", top_p.to_string());
+            }
+        }
         if self.dry_run {
             output.insert("dry_run", "true".to_string());
         }
diff --git a/src/config/role.rs b/src/config/role.rs
index 50d5b5e..b226622 100644
--- a/src/config/role.rs
+++ b/src/config/role.rs
@@ -16,12 +16,10 @@ pub const INPUT_PLACEHOLDER: &str = "__INPUT__";
 
 #[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct Role {
-    /// Role name
     pub name: String,
-    /// Prompt text
     pub prompt: String,
-    /// Temperature value
     pub temperature: Option<f64>,
+    pub top_p: Option<f64>,
 }
 
 impl Role {
@@ -30,6 +28,7 @@ impl Role {
             name: TEMP_ROLE.into(),
             prompt: prompt.into(),
             temperature: None,
+            top_p: None,
         }
     }
 
@@ -67,6 +66,7 @@ If there is a lack of details, provide most logical solution.
 Output plain text only, without any markdown formatting."#
             ),
             temperature: None,
+            top_p: None,
         }
     }
 
@@ -79,6 +79,7 @@ Provide short responses in about 80 words.
 APPLY MARKDOWN formatting when possible."#
                 .into(),
             temperature: None,
+            top_p: None,
         }
     }
 
@@ -89,6 +90,7 @@ APPLY MARKDOWN formatting when possible."#
 If there is a lack of details, provide most logical solution, without requesting further clarification."#
                 .into(),
             temperature: None,
+            top_p: None,
         }
     }
 
@@ -106,6 +108,10 @@ If there is a lack of details, provide most logical solution, without requesting
         self.temperature = value;
     }
 
+    pub fn set_top_p(&mut self, value: Option<f64>) {
+        self.top_p = value;
+    }
+
     pub fn complete_prompt_args(&mut self, name: &str) {
         self.name = name.to_string();
         self.prompt = complete_prompt_args(&self.prompt, &self.name);
diff --git a/src/config/session.rs b/src/config/session.rs
index 801615f..4ebf0d3 100644
--- a/src/config/session.rs
+++ b/src/config/session.rs
@@ -18,6 +18,7 @@ pub struct Session {
     #[serde(rename(serialize = "model", deserialize = "model"))]
     model_id: String,
     temperature: Option<f64>,
+    top_p: Option<f64>,
     #[serde(default)]
     save_session: Option<bool>,
     messages: Vec<Message>,
@@ -43,6 +44,7 @@ impl Session {
         Self {
             model_id: config.model.id(),
             temperature: config.temperature,
+            top_p: config.top_p,
             save_session: config.save_session,
             messages: vec![],
             compressed_messages: vec![],
@@ -80,6 +82,10 @@ impl Session {
         self.temperature
     }
 
+    pub fn top_p(&self) -> Option<f64> {
+        self.top_p
+    }
+
     pub fn save_session(&self) -> Option<bool> {
         self.save_session
     }
@@ -111,6 +117,9 @@ impl Session {
         if let Some(temperature) = self.temperature() {
             data["temperature"] = temperature.into();
         }
+        if let Some(top_p) = self.top_p() {
+            data["top_p"] = top_p.into();
+        }
         if let Some(save_session) = self.save_session() {
             data["save_session"] = save_session.into();
         }
@@ -140,6 +149,9 @@ impl Session {
         if let Some(temperature) = self.temperature() {
             items.push(("temperature", temperature.to_string()));
         }
+        if let Some(top_p) = self.top_p() {
+            items.push(("top_p", top_p.to_string()));
+        }
 
         if let Some(save_session) = self.save_session() {
             items.push(("save_session", save_session.to_string()));
@@ -207,6 +219,13 @@ impl Session {
         }
     }
 
+    pub fn set_top_p(&mut self, value: Option<f64>) {
+        if self.top_p != value {
+            self.top_p = value;
+            self.dirty = true;
+        }
+    }
+
     pub fn set_save_session(&mut self, value: Option<bool>) {
         if self.save_session != value {
             self.save_session = value;
diff --git a/src/serve.rs b/src/serve.rs
index 2286b61..e443bd2 100644
--- a/src/serve.rs
+++ b/src/serve.rs
@@ -134,6 +134,7 @@ impl Server {
             model,
             messages,
             temperature,
+            top_p,
             max_tokens,
             stream,
         } = req_body;
@@ -161,6 +162,7 @@ impl Server {
         let send_data: SendData = SendData {
             messages,
             temperature,
+            top_p,
             stream,
         };
 
@@ -242,6 +244,7 @@ struct ChatCompletionReqBody {
     model: String,
     messages: Vec<Message>,
     temperature: Option<f64>,
+    top_p: Option<f64>,
     max_tokens: Option<isize>,
     #[serde(default)]
     stream: bool,
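For reviewers: the precedence this change gives `top_p` (an active session's value first, then the active role's, then the global config default) mirrors the existing `temperature` handling in `src/config/mod.rs`. Below is a minimal, self-contained sketch of that resolution order; `Session`, `Role`, and `Config` are reduced to just the relevant field, so treat this as an illustration under those simplifying assumptions rather than the project's actual types:

```rust
// Illustration only: simplified stand-ins for aichat's Session/Role/Config.
struct Session { top_p: Option<f64> }
struct Role { top_p: Option<f64> }
struct Config { top_p: Option<f64> }

// Resolution order as implemented in the src/config/mod.rs hunk: an active
// session is consulted first and its value taken as-is (even if unset),
// otherwise the active role, otherwise the global config default.
fn resolve_top_p(session: Option<&Session>, role: Option<&Role>, config: &Config) -> Option<f64> {
    if let Some(session) = session {
        session.top_p
    } else if let Some(role) = role {
        role.top_p
    } else {
        config.top_p
    }
}

fn main() {
    let config = Config { top_p: Some(0.8) };
    let role = Role { top_p: None };
    let session = Session { top_p: Some(0.3) };

    // An active session wins outright.
    assert_eq!(resolve_top_p(Some(&session), Some(&role), &config), Some(0.3));
    // An active role masks the config default even when it sets no value,
    // which is why None (not 0.8) comes back here.
    assert_eq!(resolve_top_p(None, Some(&role), &config), None);
    // With neither active, the config default applies.
    assert_eq!(resolve_top_p(None, None, &config), Some(0.8));
}
```

End to end, a user would set the default in `config.yaml` (`top_p: null` out of the box, per the README hunk) or adjust it at runtime via the new `.set top_p` REPL completion; whatever value survives this resolution is serialized into the provider-specific request body (`top_p` for OpenAI-compatible APIs, `p` for Cohere, `topP` for Vertex AI).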