1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
/// A module for managing a Google API access token
use goauth::{
    auth::{JwtClaims, Token},
    credentials::Credentials,
};
use log::*;
use smpl_jwt::Jwt;
use std::{
    sync::{
        atomic::{AtomicBool, Ordering},
        {Arc, RwLock},
    },
    time::Instant,
};

pub use goauth::scopes::Scope;

/// Load credentials from the file named by the `GOOGLE_APPLICATION_CREDENTIALS`
/// environment variable.
///
/// Returns a human-readable error string if the variable cannot be read or the file
/// cannot be loaded as credentials.
fn load_credentials() -> Result<Credentials, String> {
    // Use standard GOOGLE_APPLICATION_CREDENTIALS environment variable
    let path = match std::env::var("GOOGLE_APPLICATION_CREDENTIALS") {
        Ok(path) => path,
        Err(_) => {
            return Err(
                "GOOGLE_APPLICATION_CREDENTIALS environment variable not found".to_string(),
            );
        }
    };

    match Credentials::from_file(&path) {
        Ok(credentials) => Ok(credentials),
        Err(err) => Err(format!(
            "Failed to read GCP credentials from {}: {}",
            path, err
        )),
    }
}

/// A Google API access token that can be refreshed in place.
///
/// The token and the refresh flag live behind `Arc`, so clones of an `AccessToken`
/// share them: a refresh performed through one clone is visible through all of them.
#[derive(Clone)]
pub struct AccessToken {
    /// Credentials whose issuer, token URI and RSA key are used to request tokens.
    credentials: Credentials,
    /// The scope every requested token is issued for.
    scope: Scope,
    /// The current token, paired with the `Instant` at which it was obtained.
    token: Arc<RwLock<(Token, Instant)>>,
    /// `true` while a refresh request is in flight, so concurrent refreshes are skipped.
    refresh_active: Arc<AtomicBool>,
}

impl AccessToken {
    /// Create a new access token for `scope`, fetching an initial token immediately.
    ///
    /// Credentials are loaded from the file named by the `GOOGLE_APPLICATION_CREDENTIALS`
    /// environment variable.
    ///
    /// # Errors
    ///
    /// Returns a descriptive `String` if the credentials cannot be loaded, their RSA key
    /// is invalid, or the initial token request fails.
    pub async fn new(scope: Scope) -> Result<Self, String> {
        let credentials = load_credentials()?;
        // Validate the key up front so a bad credentials file is reported at construction
        // time rather than on some later refresh.
        credentials
            .rsa_key()
            .map_err(|err| format!("Invalid rsa key: {}", err))?;

        let token = Self::get_token(&credentials, &scope).await?;
        Ok(Self {
            credentials,
            scope,
            token: Arc::new(RwLock::new(token)),
            refresh_active: Arc::new(AtomicBool::new(false)),
        })
    }

    /// The project that this token grants access to
    pub fn project(&self) -> String {
        self.credentials.project()
    }

    /// Request a new token for `scope`, signed with the credentials' RSA key.
    ///
    /// Returns the token together with the `Instant` at which it was received; `refresh`
    /// measures the token's age from that instant.
    ///
    /// # Errors
    ///
    /// Returns a descriptive `String` if the RSA key is invalid or the token request fails.
    async fn get_token(
        credentials: &Credentials,
        scope: &Scope,
    ) -> Result<(Token, Instant), String> {
        info!("Requesting token for {:?} scope", scope);
        let claims = JwtClaims::new(
            credentials.iss(),
            scope,
            credentials.token_uri(),
            None,
            None,
        );
        // `new` validates the key, but propagate rather than panic if it is ever invalid.
        let rsa_key = credentials
            .rsa_key()
            .map_err(|err| format!("Invalid rsa key: {}", err))?;
        let jwt = Jwt::new(claims, rsa_key, None);

        let token = goauth::get_token(&jwt, credentials)
            .await
            .map_err(|err| format!("Failed to refresh access token: {}", err))?;

        info!("Token expires in {} seconds", token.expires_in());
        Ok((token, Instant::now()))
    }

    /// Call this function regularly to ensure the access token does not expire.
    ///
    /// A new token is requested once at least half of the current token's lifetime has
    /// elapsed. Only one refresh runs at a time; concurrent callers return immediately.
    /// A failed refresh is logged and the current token is kept, so a later call retries.
    pub async fn refresh(&self) {
        /// Clears the "refresh in flight" flag when dropped. The flag is therefore reset
        /// even if this future is cancelled at the `.await` below or the request panics.
        /// Otherwise the flag would stay `true` and every later refresh would be skipped.
        struct RefreshClaim<'a>(&'a AtomicBool);

        impl Drop for RefreshClaim<'_> {
            fn drop(&mut self) {
                self.0.store(false, Ordering::Relaxed);
            }
        }

        // Check if it's time to try a token refresh. The read guard is confined to this
        // scope so that it is released before the `.await` below.
        {
            let token_r = self.token.read().unwrap();
            if token_r.1.elapsed().as_secs() < token_r.0.expires_in() as u64 / 2 {
                return;
            }

            // Claim the refresh. Relaxed is sufficient: the atomic read-modify-write alone
            // guarantees at most one caller wins, and the token data itself is guarded by
            // the RwLock.
            if self
                .refresh_active
                .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
                .is_err()
            {
                // Refresh already pending
                return;
            }
        }
        let claim = RefreshClaim(&self.refresh_active);

        info!("Refreshing token");
        let new_token = Self::get_token(&self.credentials, &self.scope).await;
        {
            let mut token_w = self.token.write().unwrap();
            match new_token {
                Ok(new_token) => *token_w = new_token,
                Err(err) => warn!("{}", err),
            }
            // Release the claim while still holding the write lock, so a caller that then
            // observes the flag cleared also observes the updated (or retained) token.
            drop(claim);
        }
    }

    /// Return an access token suitable for use in an HTTP authorization header
    pub fn get(&self) -> String {
        let token_r = self.token.read().unwrap();
        format!("{} {}", token_r.0.token_type(), token_r.0.access_token())
    }
}