mirror of https://github.com/sigoden/aichat
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
644 lines
20 KiB
Rust
644 lines
20 KiB
Rust
use super::{openai::OpenAIConfig, ClientConfig, Message, Model, ReplyHandler};
|
|
|
|
use crate::{
|
|
config::{GlobalConfig, Input},
|
|
render::{render_error, render_stream},
|
|
utils::{prompt_input_integer, prompt_input_string, tokenize, AbortSignal, PromptKind},
|
|
};
|
|
|
|
use anyhow::{bail, Context, Result};
|
|
use async_trait::async_trait;
|
|
use futures_util::{Stream, StreamExt};
|
|
use reqwest::{Client as ReqwestClient, ClientBuilder, Proxy, RequestBuilder};
|
|
use serde::Deserialize;
|
|
use serde_json::{json, Value};
|
|
use std::{env, future::Future, time::Duration};
|
|
use tokio::{sync::mpsc::unbounded_channel, time::sleep};
|
|
|
|
/// Generates the client registry for the crate.
///
/// Each input tuple is `(module, type-name literal, config type, client type)`.
/// The expansion declares one `mod` per provider, the `type`-tagged
/// `ClientConfig` enum, a concrete client struct per provider, and the free
/// functions used to create and enumerate clients.
#[macro_export]
macro_rules! register_client {
    (
        $(($module:ident, $name:literal, $config:ident, $client:ident),)+
    ) => {
        $(
            mod $module;
        )+
        $(
            use self::$module::$config;
        )+

        /// One entry of the `clients` list; its `type` field selects the variant.
        #[derive(Debug, Clone, serde::Deserialize)]
        #[serde(tag = "type")]
        pub enum ClientConfig {
            $(
                #[serde(rename = $name)]
                $config($config),
            )+
            /// Any `type` value not claimed by a registered client.
            #[serde(other)]
            Unknown,
        }

        $(
            #[derive(Debug)]
            pub struct $client {
                global_config: $crate::config::GlobalConfig,
                config: $config,
                model: $crate::client::Model,
            }

            impl $client {
                pub const NAME: &'static str = $name;

                /// Builds this client when the currently selected model belongs to
                /// one of its configured entries; returns `None` otherwise.
                pub fn init(global_config: &$crate::config::GlobalConfig) -> Option<Box<dyn Client>> {
                    let model = global_config.read().model.clone();
                    let config = global_config
                        .read()
                        .clients
                        .iter()
                        .filter_map(|entry| match entry {
                            ClientConfig::$config(c) => Some(c),
                            _ => None,
                        })
                        .find(|c| Self::name(c) == &model.client_name)
                        .cloned()?;

                    Some(Box::new(Self {
                        global_config: global_config.clone(),
                        config,
                        model,
                    }))
                }

                /// The configured client name, or the client type name when unset.
                pub fn name(config: &$config) -> &str {
                    match config.name.as_deref() {
                        Some(name) => name,
                        None => Self::NAME,
                    }
                }
            }
        )+

        /// Returns the first registered client that accepts the current model.
        pub fn init_client(config: &$crate::config::GlobalConfig) -> anyhow::Result<Box<dyn Client>> {
            $(
                if let Some(client) = $client::init(config) {
                    return Ok(client);
                }
            )+
            let client_name = config.read().model.client_name.clone();
            anyhow::bail!("Unknown client '{}'", &client_name)
        }

        /// Keeps the current model if it has `capabilities`; otherwise switches to
        /// the first listed model that does, or fails if none does.
        pub fn ensure_model_capabilities(client: &mut dyn Client, capabilities: $crate::client::ModelCapabilities) -> anyhow::Result<()> {
            if client.model().capabilities.contains(capabilities) {
                return Ok(());
            }
            let candidate = client
                .list_models()
                .into_iter()
                .find(|m| m.capabilities.contains(capabilities));
            match candidate {
                Some(model) => {
                    client.set_model(model);
                    Ok(())
                }
                None => anyhow::bail!(
                    "The current model lacks the corresponding capability."
                ),
            }
        }

        /// Type names of every registered client, in registration order.
        pub fn list_client_types() -> Vec<&'static str> {
            [$($client::NAME),+].to_vec()
        }

        /// Interactively builds a config for the client whose type name is `client`.
        pub fn create_client_config(client: &str) -> anyhow::Result<(String, serde_json::Value)> {
            $(
                if client == $client::NAME {
                    return create_config(&$client::PROMPTS, $client::NAME);
                }
            )+
            anyhow::bail!("Unknown client '{}'", client)
        }

        /// Collects the models of every configured client; unknown entries contribute none.
        pub fn list_models(config: &$crate::config::Config) -> Vec<$crate::client::Model> {
            let mut models = Vec::new();
            for client_config in &config.clients {
                match client_config {
                    $(ClientConfig::$config(c) => models.extend($client::list_models(c)),)+
                    ClientConfig::Unknown => {}
                }
            }
            models
        }
    };
}
|
|
|
|
/// Defines a provider module for a backend speaking the OpenAI
/// chat-completions protocol. The expansion provides:
/// - the config struct;
/// - the `Client` trait impl;
/// - a built-in model list;
/// - a request builder that posts to `$api_base/chat/completions`.
#[macro_export]
macro_rules! openai_compatible_client {
    (
        $config:ident,
        $client:ident,
        $api_base:literal,
        [$(($name:literal, $capabilities:literal, $max_input_tokens:literal $(, $max_output_tokens:literal)? )),+$(,)?]
    ) => {
        use $crate::client::openai::openai_build_body;
        use $crate::client::{ExtraConfig, $client, Model, ModelConfig, PromptType, SendData};
        use $crate::utils::PromptKind;

        use anyhow::Result;
        use reqwest::{Client as ReqwestClient, RequestBuilder};
        use serde::Deserialize;

        const API_BASE: &str = $api_base;

        #[derive(Debug, Clone, Deserialize)]
        pub struct $config {
            /// Client name; falls back to the client type name when absent.
            pub name: Option<String>,
            /// API key; when absent, the `<NAME>_API_KEY` environment variable is consulted.
            pub api_key: Option<String>,
            /// User-defined models; when empty, the built-in list is used.
            #[serde(default)]
            pub models: Vec<ModelConfig>,
            pub extra: Option<ExtraConfig>,
        }

        impl_client_trait!(
            $client,
            $crate::client::openai::openai_send_message,
            $crate::client::openai::openai_send_message_streaming
        );

        impl $client {
            list_models_fn!(
                $config,
                [
                    $(
                        ($name, $capabilities, $max_input_tokens $(, $max_output_tokens)?),
                    )+
                ]
            );
            config_get_fn!(api_key, get_api_key);

            pub const PROMPTS: [PromptType<'static>; 1] = [("api_key", "API Key:", false, PromptKind::String)];

            fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
                // A missing key is tolerated: the request is then sent without auth.
                let api_key = self.get_api_key();
                let body = openai_build_body(data, &self.model);
                let url = format!("{API_BASE}/chat/completions");

                debug!("Request: {url} {body}");

                let request = client.post(url).json(&body);
                Ok(match api_key {
                    Ok(key) => request.bearer_auth(key),
                    Err(_) => request,
                })
            }
        }
    }
}
|
|
|
|
/// Expands to the accessor methods of the `Client` trait for a client struct
/// with `global_config`, `config` (carrying an `extra` field) and `model` fields.
#[macro_export]
macro_rules! client_common_fns {
    () => {
        fn config(
            &self,
        ) -> (
            &$crate::config::GlobalConfig,
            &Option<$crate::client::ExtraConfig>,
        ) {
            let extra = &self.config.extra;
            (&self.global_config, extra)
        }

        fn list_models(&self) -> Vec<Model> {
            let local_config = &self.config;
            Self::list_models(local_config)
        }

        fn model(&self) -> &Model {
            let current = &self.model;
            current
        }

        fn set_model(&mut self, model: Model) {
            self.model = model;
        }
    };
}
|
|
|
|
/// Implements `Client` for `$crate::client::$client`.
///
/// The implementation routes both transport hooks through the client's own
/// `request_builder` and the given send functions.
#[macro_export]
macro_rules! impl_client_trait {
    ($client:ident, $send_message:path, $send_message_streaming:path) => {
        #[async_trait::async_trait]
        impl $crate::client::Client for $crate::client::$client {
            client_common_fns!();

            async fn send_message_inner(
                &self,
                client: &reqwest::Client,
                data: $crate::client::SendData,
            ) -> anyhow::Result<String> {
                $send_message(self.request_builder(client, data)?).await
            }

            async fn send_message_streaming_inner(
                &self,
                client: &reqwest::Client,
                handler: &mut $crate::client::ReplyHandler,
                data: $crate::client::SendData,
            ) -> Result<()> {
                $send_message_streaming(self.request_builder(client, data)?, handler).await
            }
        }
    };
}
|
|
|
|
/// Generates `fn $fn_name(&self) -> anyhow::Result<String>`.
///
/// The generated method returns `self.config.$field_name` when set. Otherwise
/// it reads the environment variable `<CLIENT_NAME>_<FIELD_NAME>` (upper-cased).
#[macro_export]
macro_rules! config_get_fn {
    ($field_name:ident, $fn_name:ident) => {
        fn $fn_name(&self) -> anyhow::Result<String> {
            if let Some(value) = self.config.$field_name.clone() {
                return Ok(value);
            }
            let env_name = format!(
                "{}_{}",
                Self::name(&self.config),
                stringify!($field_name)
            )
            .to_ascii_uppercase();
            std::env::var(&env_name).map_err(|_| {
                anyhow::anyhow!("Miss '{}' in client configuration", stringify!($field_name))
            })
        }
    };
}
|
|
|
|
/// Generates `pub fn list_models(local_config: &$config) -> Vec<Model>`.
///
/// With only a config type, the models come solely from `local_config.models`.
/// With a built-in list, that list is used whenever `local_config.models` is empty.
#[macro_export]
macro_rules! list_models_fn {
    ($config:ident) => {
        pub fn list_models(local_config: &$config) -> Vec<Model> {
            Model::from_config(Self::name(local_config), &local_config.models)
        }
    };
    ($config:ident, [$(($name:literal, $capabilities:literal, $max_input_tokens:literal $(, $max_output_tokens:literal)? )),+$(,)?]) => {
        pub fn list_models(local_config: &$config) -> Vec<Model> {
            let client_name = Self::name(local_config);
            if !local_config.models.is_empty() {
                return Model::from_config(client_name, &local_config.models);
            }
            vec![
                $(
                    Model::new(client_name, $name)
                        .set_capabilities($capabilities.into())
                        .set_max_input_tokens(Some($max_input_tokens))
                        $(.set_max_output_tokens(Some($max_output_tokens)))?
                ),+
            ]
        }
    };
}
|
|
|
|
/// Returns early from the enclosing function with an `Unsupported model` error.
#[macro_export]
macro_rules! unsupported_model {
    ($name:expr) => {
        return ::core::result::Result::Err(anyhow::anyhow!("Unsupported model '{}'", $name))
    };
}
|
|
|
|
#[async_trait]
|
|
pub trait Client: Sync + Send {
|
|
fn config(&self) -> (&GlobalConfig, &Option<ExtraConfig>);
|
|
|
|
fn list_models(&self) -> Vec<Model>;
|
|
|
|
fn model(&self) -> &Model;
|
|
|
|
fn set_model(&mut self, model: Model);
|
|
|
|
fn build_client(&self) -> Result<ReqwestClient> {
|
|
let mut builder = ReqwestClient::builder();
|
|
let options = self.config().1;
|
|
let timeout = options
|
|
.as_ref()
|
|
.and_then(|v| v.connect_timeout)
|
|
.unwrap_or(10);
|
|
let proxy = options.as_ref().and_then(|v| v.proxy.clone());
|
|
builder = set_proxy(builder, &proxy)?;
|
|
let client = builder
|
|
.connect_timeout(Duration::from_secs(timeout))
|
|
.build()
|
|
.with_context(|| "Failed to build client")?;
|
|
Ok(client)
|
|
}
|
|
|
|
async fn send_message(&self, input: Input) -> Result<String> {
|
|
let global_config = self.config().0;
|
|
if global_config.read().dry_run {
|
|
let content = global_config.read().echo_messages(&input);
|
|
return Ok(content);
|
|
}
|
|
let client = self.build_client()?;
|
|
let data = global_config.read().prepare_send_data(&input, false)?;
|
|
self.send_message_inner(&client, data)
|
|
.await
|
|
.with_context(|| "Failed to get answer")
|
|
}
|
|
|
|
async fn send_message_streaming(
|
|
&self,
|
|
input: &Input,
|
|
handler: &mut ReplyHandler,
|
|
) -> Result<()> {
|
|
async fn watch_abort(abort: AbortSignal) {
|
|
loop {
|
|
if abort.aborted() {
|
|
break;
|
|
}
|
|
sleep(Duration::from_millis(100)).await;
|
|
}
|
|
}
|
|
let abort = handler.get_abort();
|
|
let input = input.clone();
|
|
tokio::select! {
|
|
ret = async {
|
|
let global_config = self.config().0;
|
|
if global_config.read().dry_run {
|
|
let content = global_config.read().echo_messages(&input);
|
|
let tokens = tokenize(&content);
|
|
for token in tokens {
|
|
tokio::time::sleep(Duration::from_millis(10)).await;
|
|
handler.text(&token)?;
|
|
}
|
|
return Ok(());
|
|
}
|
|
let client = self.build_client()?;
|
|
let data = global_config.read().prepare_send_data(&input, true)?;
|
|
self.send_message_streaming_inner(&client, handler, data).await
|
|
} => {
|
|
handler.done()?;
|
|
ret.with_context(|| "Failed to get answer")
|
|
}
|
|
_ = watch_abort(abort.clone()) => {
|
|
handler.done()?;
|
|
Ok(())
|
|
},
|
|
}
|
|
}
|
|
|
|
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String>;
|
|
|
|
async fn send_message_streaming_inner(
|
|
&self,
|
|
client: &ReqwestClient,
|
|
handler: &mut ReplyHandler,
|
|
data: SendData,
|
|
) -> Result<()>;
|
|
}
|
|
|
|
impl Default for ClientConfig {
|
|
fn default() -> Self {
|
|
Self::OpenAIConfig(OpenAIConfig::default())
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Default)]
|
|
pub struct ExtraConfig {
|
|
pub proxy: Option<String>,
|
|
pub connect_timeout: Option<u64>,
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct SendData {
|
|
pub messages: Vec<Message>,
|
|
pub temperature: Option<f64>,
|
|
pub top_p: Option<f64>,
|
|
pub stream: bool,
|
|
}
|
|
|
|
/// A configuration prompt: `(path, description, required, kind)`.
///
/// `path` is a dotted key (e.g. `api_key` or `models[].name`) that
/// `set_config_value` writes the answer to.
pub type PromptType<'a> = (&'a str, &'a str, bool, PromptKind);
|
|
|
|
pub fn create_config(list: &[PromptType], client: &str) -> Result<(String, Value)> {
|
|
let mut config = json!({
|
|
"type": client,
|
|
});
|
|
let mut model = client.to_string();
|
|
for (path, desc, required, kind) in list {
|
|
match kind {
|
|
PromptKind::String => {
|
|
let value = prompt_input_string(desc, *required)?;
|
|
set_config_value(&mut config, path, kind, &value);
|
|
if *path == "name" {
|
|
model = value;
|
|
}
|
|
}
|
|
PromptKind::Integer => {
|
|
let value = prompt_input_integer(desc, *required)?;
|
|
set_config_value(&mut config, path, kind, &value);
|
|
}
|
|
}
|
|
}
|
|
|
|
let clients = json!(vec![config]);
|
|
Ok((model, clients))
|
|
}
|
|
|
|
pub async fn send_stream(
|
|
input: &Input,
|
|
client: &dyn Client,
|
|
config: &GlobalConfig,
|
|
abort: AbortSignal,
|
|
) -> Result<String> {
|
|
let (tx, rx) = unbounded_channel();
|
|
let mut stream_handler = ReplyHandler::new(tx, abort.clone());
|
|
|
|
let (send_ret, rend_ret) = tokio::join!(
|
|
client.send_message_streaming(input, &mut stream_handler),
|
|
render_stream(rx, config, abort.clone()),
|
|
);
|
|
if let Err(err) = rend_ret {
|
|
render_error(err, config.read().highlight);
|
|
}
|
|
let output = stream_handler.get_buffer().to_string();
|
|
match send_ret {
|
|
Ok(_) => {
|
|
println!();
|
|
Ok(output)
|
|
}
|
|
Err(err) => {
|
|
if !output.is_empty() {
|
|
println!();
|
|
}
|
|
Err(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
#[allow(unused)]
|
|
pub async fn send_message_as_streaming<F, Fut>(
|
|
builder: RequestBuilder,
|
|
handler: &mut ReplyHandler,
|
|
f: F,
|
|
) -> Result<()>
|
|
where
|
|
F: FnOnce(RequestBuilder) -> Fut,
|
|
Fut: Future<Output = Result<String>>,
|
|
{
|
|
let text = f(builder).await?;
|
|
handler.text(&text)?;
|
|
handler.done()?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub fn catch_error(data: &Value, status: u16) -> Result<()> {
|
|
if (200..300).contains(&status) {
|
|
return Ok(());
|
|
}
|
|
debug!("Invalid response, status: {status}, data: {data}");
|
|
if let Some(error) = data["error"].as_object() {
|
|
if let (Some(typ), Some(message)) = (error["type"].as_str(), error["message"].as_str()) {
|
|
bail!("{message} (type: {typ})");
|
|
}
|
|
} else if let Some(error) = data[0]["error"].as_object() {
|
|
if let (Some(status), Some(message)) = (error["status"].as_str(), error["message"].as_str())
|
|
{
|
|
bail!("{message} (status: {status})")
|
|
}
|
|
} else if let Some(error) = data["error"].as_str() {
|
|
bail!("{error}");
|
|
} else if let Some(message) = data["message"].as_str() {
|
|
bail!("{message}");
|
|
}
|
|
bail!("Invalid response data: {data} (status: {status})");
|
|
}
|
|
|
|
pub fn maybe_catch_error(data: &Value) -> Result<()> {
|
|
if let (Some(code), Some(message)) = (data["code"].as_str(), data["message"].as_str()) {
|
|
debug!("Invalid response: {}", data);
|
|
bail!("{message} (code: {code})");
|
|
} else if let (Some(error_code), Some(error_msg)) =
|
|
(data["error_code"].as_number(), data["error_msg"].as_str())
|
|
{
|
|
debug!("Invalid response: {}", data);
|
|
bail!("{error_msg} (error_code: {error_code})");
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn json_stream<S, F>(mut stream: S, mut handle: F) -> Result<()>
|
|
where
|
|
S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin,
|
|
F: FnMut(&str) -> Result<()>,
|
|
{
|
|
let mut buffer = vec![];
|
|
let mut cursor = 0;
|
|
let mut start = 0;
|
|
let mut balances = vec![];
|
|
let mut quoting = false;
|
|
let mut escape = false;
|
|
while let Some(chunk) = stream.next().await {
|
|
let chunk = chunk?;
|
|
let chunk = std::str::from_utf8(&chunk)?;
|
|
buffer.extend(chunk.chars());
|
|
for i in cursor..buffer.len() {
|
|
let ch = buffer[i];
|
|
if quoting {
|
|
if ch == '\\' {
|
|
escape = !escape;
|
|
} else {
|
|
if !escape && ch == '"' {
|
|
quoting = false;
|
|
}
|
|
escape = false;
|
|
}
|
|
continue;
|
|
}
|
|
match ch {
|
|
'"' => {
|
|
quoting = true;
|
|
escape = false;
|
|
}
|
|
'{' => {
|
|
if balances.is_empty() {
|
|
start = i;
|
|
}
|
|
balances.push(ch);
|
|
}
|
|
'[' => {
|
|
if start != 0 {
|
|
balances.push(ch);
|
|
}
|
|
}
|
|
'}' => {
|
|
balances.pop();
|
|
if balances.is_empty() {
|
|
let value: String = buffer[start..=i].iter().collect();
|
|
handle(&value)?;
|
|
}
|
|
}
|
|
']' => {
|
|
balances.pop();
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
cursor = buffer.len();
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn set_config_value(json: &mut Value, path: &str, kind: &PromptKind, value: &str) {
|
|
let segs: Vec<&str> = path.split('.').collect();
|
|
match segs.as_slice() {
|
|
[name] => json[name] = to_json(kind, value),
|
|
[scope, name] => match scope.split_once('[') {
|
|
None => {
|
|
if json.get(scope).is_none() {
|
|
let mut obj = json!({});
|
|
obj[name] = to_json(kind, value);
|
|
json[scope] = obj;
|
|
} else {
|
|
json[scope][name] = to_json(kind, value);
|
|
}
|
|
}
|
|
Some((scope, _)) => {
|
|
if json.get(scope).is_none() {
|
|
let mut obj = json!({});
|
|
obj[name] = to_json(kind, value);
|
|
json[scope] = json!([obj]);
|
|
} else {
|
|
json[scope][0][name] = to_json(kind, value);
|
|
}
|
|
}
|
|
},
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
fn to_json(kind: &PromptKind, value: &str) -> Value {
|
|
if value.is_empty() {
|
|
return Value::Null;
|
|
}
|
|
match kind {
|
|
PromptKind::String => value.into(),
|
|
PromptKind::Integer => match value.parse::<i32>() {
|
|
Ok(value) => value.into(),
|
|
Err(_) => value.into(),
|
|
},
|
|
}
|
|
}
|
|
|
|
fn set_proxy(builder: ClientBuilder, proxy: &Option<String>) -> Result<ClientBuilder> {
|
|
let proxy = if let Some(proxy) = proxy {
|
|
if proxy.is_empty() || proxy == "false" || proxy == "-" {
|
|
return Ok(builder);
|
|
}
|
|
proxy.clone()
|
|
} else if let Ok(proxy) = env::var("HTTPS_PROXY").or_else(|_| env::var("ALL_PROXY")) {
|
|
proxy
|
|
} else {
|
|
return Ok(builder);
|
|
};
|
|
let builder =
|
|
builder.proxy(Proxy::all(&proxy).with_context(|| format!("Invalid proxy `{proxy}`"))?);
|
|
Ok(builder)
|
|
}
|