refactor: unified access token management (#486)

pull/489/head
sigoden 2 weeks ago committed by GitHub
parent 9b283024b4
commit 7c6f75a139
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -0,0 +1,34 @@
use anyhow::{anyhow, Result};
use chrono::Utc;
use indexmap::IndexMap;
use lazy_static::lazy_static;
use parking_lot::RwLock;
lazy_static! {
static ref ACCESS_TOKENS: RwLock<IndexMap<String, (String, i64)>> =
RwLock::new(IndexMap::new());
}
pub fn get_access_token(client_name: &str) -> Result<String> {
ACCESS_TOKENS
.read()
.get(client_name)
.map(|(token, _)| token.clone())
.ok_or_else(|| anyhow!("Invalid access token"))
}
pub fn is_valid_access_token(client_name: &str) -> bool {
let access_tokens = ACCESS_TOKENS.read();
let (token, expires_at) = match access_tokens.get(client_name) {
Some(v) => v,
None => return false,
};
!token.is_empty() && Utc::now().timestamp() < *expires_at
}
pub fn set_access_token(client_name: &str, token: String, expires_at: i64) {
let mut access_tokens = ACCESS_TOKENS.write();
let entry = access_tokens.entry(client_name.to_string()).or_default();
entry.0 = token;
entry.1 = expires_at;
}

@ -102,8 +102,8 @@ macro_rules! register_client {
}
}
pub fn name(config: &$config) -> &str {
config.name.as_deref().unwrap_or(Self::NAME)
pub fn name(local_config: &$config) -> &str {
local_config.name.as_deref().unwrap_or(Self::NAME)
}
}
@ -184,6 +184,10 @@ macro_rules! client_common_fns {
Self::list_models(&self.config)
}
fn name(&self) -> &str {
Self::name(&self.config)
}
fn model(&self) -> &Model {
&self.model
}
@ -259,6 +263,8 @@ pub trait Client: Sync + Send {
fn list_models(&self) -> Vec<Model>;
fn name(&self) -> &str;
fn model(&self) -> &Model;
fn model_mut(&mut self) -> &mut Model;

@ -1,3 +1,4 @@
use super::access_token::*;
use super::{
maybe_catch_error, patch_system_message, sse_stream, Client, CompletionDetails, ErnieClient,
ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData, SsMmessage, SseHandler,
@ -5,7 +6,6 @@ use super::{
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use chrono::Utc;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
@ -14,8 +14,6 @@ use std::env;
const API_BASE: &str = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1";
const ACCESS_TOKEN_URL: &str = "https://aip.baidubce.com/oauth/2.0/token";
static mut ACCESS_TOKEN: (String, i64) = (String::new(), 0);
#[derive(Debug, Clone, Deserialize, Default)]
pub struct ErnieConfig {
pub name: Option<String>,
@ -34,11 +32,11 @@ impl ErnieClient {
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let body = build_body(data, &self.model);
let access_token = get_access_token(self.name())?;
let url = format!(
"{API_BASE}/wenxinworkshop/chat/{}?access_token={}",
"{API_BASE}/wenxinworkshop/chat/{}?access_token={access_token}",
&self.model.name,
unsafe { &ACCESS_TOKEN.0 }
);
debug!("Ernie Request: {url} {body}");
@ -49,7 +47,8 @@ impl ErnieClient {
}
async fn prepare_access_token(&self) -> Result<()> {
if unsafe { ACCESS_TOKEN.0.is_empty() || Utc::now().timestamp() > ACCESS_TOKEN.1 } {
let client_name = self.name();
if !is_valid_access_token(client_name) {
let env_prefix = Self::name(&self.config).to_uppercase();
let api_key = self.config.api_key.clone();
let api_key = api_key
@ -65,7 +64,7 @@ impl ErnieClient {
let token = fetch_access_token(&client, &api_key, &secret_key)
.await
.with_context(|| "Failed to fetch access token")?;
unsafe { ACCESS_TOKEN = (token, 86400) };
set_access_token(client_name, token, 86400);
}
Ok(())
}

@ -1,5 +1,6 @@
#[macro_use]
mod common;
mod access_token;
mod message;
mod model;
mod prompt_format;

@ -1,3 +1,4 @@
use super::access_token::*;
use super::{
catch_error, json_stream, message::*, patch_system_message, Client, CompletionDetails,
ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData, SseHandler,
@ -12,8 +13,6 @@ use serde::Deserialize;
use serde_json::{json, Value};
use std::path::PathBuf;
static mut ACCESS_TOKEN: (String, i64) = (String::new(), 0); // safe under linear operation
#[derive(Debug, Clone, Deserialize, Default)]
pub struct VertexAIConfig {
pub name: Option<String>,
@ -39,6 +38,7 @@ impl VertexAIClient {
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let project_id = self.get_project_id()?;
let location = self.get_location()?;
let access_token = get_access_token(self.name())?;
let base_url = format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers");
@ -52,10 +52,7 @@ impl VertexAIClient {
debug!("VertexAI Request: {url} {body}");
let builder = client
.post(url)
.bearer_auth(unsafe { &ACCESS_TOKEN.0 })
.json(&body);
let builder = client.post(url).bearer_auth(access_token).json(&body);
Ok(builder)
}
@ -70,7 +67,7 @@ impl Client for VertexAIClient {
client: &ReqwestClient,
data: SendData,
) -> Result<(String, CompletionDetails)> {
prepare_access_token(client, &self.config.adc_file).await?;
prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
let builder = self.request_builder(client, data)?;
gemini_send_message(builder).await
}
@ -81,7 +78,7 @@ impl Client for VertexAIClient {
handler: &mut SseHandler,
data: SendData,
) -> Result<()> {
prepare_access_token(client, &self.config.adc_file).await?;
prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
let builder = self.request_builder(client, data)?;
gemini_send_message_streaming(builder, handler).await
}
@ -217,20 +214,24 @@ pub(crate) fn gemini_build_body(
Ok(body)
}
async fn prepare_access_token(client: &reqwest::Client, adc_file: &Option<String>) -> Result<()> {
if unsafe { ACCESS_TOKEN.0.is_empty() || Utc::now().timestamp() > ACCESS_TOKEN.1 } {
let (token, expires_in) = fetch_gcloud_access_token(client, adc_file)
pub async fn prepare_gcloud_access_token(
client: &reqwest::Client,
client_name: &str,
adc_file: &Option<String>,
) -> Result<()> {
if !is_valid_access_token(client_name) {
let (token, expires_in) = fetch_access_token(client, adc_file)
.await
.with_context(|| "Failed to fetch access token")?;
let expires_at = Utc::now()
+ Duration::try_seconds(expires_in)
.ok_or_else(|| anyhow!("Failed to parse expires_in of access_token"))?;
unsafe { ACCESS_TOKEN = (token, expires_at.timestamp()) };
set_access_token(client_name, token, expires_at.timestamp())
}
Ok(())
}
pub async fn fetch_gcloud_access_token(
async fn fetch_access_token(
client: &reqwest::Client,
file: &Option<String>,
) -> Result<(String, i64)> {

@ -1,18 +1,16 @@
use super::access_token::*;
use super::claude::{claude_build_body, claude_send_message, claude_send_message_streaming};
use super::vertexai::fetch_gcloud_access_token;
use super::vertexai::prepare_gcloud_access_token;
use super::{
Client, CompletionDetails, ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData,
SseHandler, VertexAIClaudeClient,
};
use anyhow::{anyhow, Context, Result};
use anyhow::Result;
use async_trait::async_trait;
use chrono::{Duration, Utc};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
static mut ACCESS_TOKEN: (String, i64) = (String::new(), 0); // safe under linear operation
#[derive(Debug, Clone, Deserialize, Default)]
pub struct VertexAIClaudeConfig {
pub name: Option<String>,
@ -36,6 +34,7 @@ impl VertexAIClaudeClient {
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let project_id = self.get_project_id()?;
let location = self.get_location()?;
let access_token = get_access_token(self.name())?;
let base_url = format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers");
let url = format!(
@ -51,10 +50,7 @@ impl VertexAIClaudeClient {
debug!("VertexAIClaude Request: {url} {body}");
let builder = client
.post(url)
.bearer_auth(unsafe { &ACCESS_TOKEN.0 })
.json(&body);
let builder = client.post(url).bearer_auth(access_token).json(&body);
Ok(builder)
}
@ -69,7 +65,7 @@ impl Client for VertexAIClaudeClient {
client: &ReqwestClient,
data: SendData,
) -> Result<(String, CompletionDetails)> {
prepare_access_token(client, &self.config.adc_file).await?;
prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
let builder = self.request_builder(client, data)?;
claude_send_message(builder).await
}
@ -80,21 +76,8 @@ impl Client for VertexAIClaudeClient {
handler: &mut SseHandler,
data: SendData,
) -> Result<()> {
prepare_access_token(client, &self.config.adc_file).await?;
prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
let builder = self.request_builder(client, data)?;
claude_send_message_streaming(builder, handler).await
}
}
async fn prepare_access_token(client: &reqwest::Client, adc_file: &Option<String>) -> Result<()> {
if unsafe { ACCESS_TOKEN.0.is_empty() || Utc::now().timestamp() > ACCESS_TOKEN.1 } {
let (token, expires_in) = fetch_gcloud_access_token(client, adc_file)
.await
.with_context(|| "Failed to fetch access token")?;
let expires_at = Utc::now()
+ Duration::try_seconds(expires_in)
.ok_or_else(|| anyhow!("Failed to parse expires_in of access_token"))?;
unsafe { ACCESS_TOKEN = (token, expires_at.timestamp()) };
}
Ok(())
}

Loading…
Cancel
Save