diff --git a/Cargo.lock b/Cargo.lock index 07172cc..618d2b2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -884,9 +884,9 @@ checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675" [[package]] name = "codee" -version = "0.3.0" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f18d705321923b1a9358e3fc3c57c3b50171196827fc7f5f10b053242aca627" +checksum = "fd8bbfdadf2f8999c6e404697bc08016dce4a3d77dec465b36c9a0652fdb3327" dependencies = [ "serde", "serde_json", @@ -1229,9 +1229,9 @@ dependencies = [ [[package]] name = "derive-where" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e73f2692d4bd3cac41dca28934a39894200c9fabf49586d77d0e5954af1d7902" +checksum = "510c292c8cf384b1a340b816a9a6cf2599eb8f566a44949024af88418000c50b" dependencies = [ "proc-macro2", "quote", @@ -1446,9 +1446,9 @@ checksum = "a1731451909bde27714eacba19c2566362a7f35224f52b153d3f42cf60f72472" [[package]] name = "errno" -version = "0.3.12" +version = "0.3.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cea14ef9355e3beab063703aa9dab15afd25f0667c341310c1e5274bb1d0da18" +checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad" dependencies = [ "libc", "windows-sys 0.59.0", @@ -3410,9 +3410,9 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.12" +version = "0.5.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee4e529991f949c5e25755532370b8af5d114acae52326361d68d47af64aa842" +checksum = "fcebb1209ee276352ef14ff8732e24cc2b02bbac986cd74a4c81bcb2f9881970" dependencies = [ "cfg_aliases", "libc", @@ -3455,9 +3455,9 @@ dependencies = [ [[package]] name = "r-efi" -version = "5.2.0" +version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" [[package]] name = "radium" @@ -3585,6 +3585,26 @@ dependencies = [ "bitflags", ] +[[package]] +name = "ref-cast" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a0ae411dbe946a674d89546582cea4ba2bb8defac896622d6496f14c23ba5cf" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1165225c21bff1f3bbce98f5a1f889949bc902d3575308cc7b0de30b4f6d27c7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.102", +] + [[package]] name = "regex" version = "1.11.1" @@ -3809,9 +3829,9 @@ dependencies = [ [[package]] name = "rust_decimal" -version = "1.37.1" +version = "1.37.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "faa7de2ba56ac291bd90c6b9bece784a52ae1411f9506544b3eae36dd2356d50" +checksum = "b203a6425500a03e0919c42d3c47caca51e79f1132046626d2c8871c5092035d" dependencies = [ "arrayvec", "borsh", @@ -3859,9 +3879,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.27" +version = "0.23.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "730944ca083c1c233a75c09f199e973ca499344a2b7ba9e755c457e86fb4a321" +checksum = "7160e3e10bf4535308537f3c4e1641468cd0e485175d6163087c0393c7d46643" dependencies = [ "once_cell", "ring", @@ -3922,6 +3942,18 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "schemars" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd191f9397d57d581cddd31014772520aa448f65ef991055d7f61582c65165f" +dependencies = [ + "dyn-clone", + "ref-cast", + "serde", + "serde_json", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -4110,6 +4142,15 @@ dependencies = [ "zeroize", ] +[[package]] +name = "secrecy" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a" +dependencies = [ + "zeroize", +] + [[package]] name = "security-framework" version = "2.11.1" @@ -4243,15 +4284,16 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.12.0" +version = "3.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6b6f7f2fcb69f747921f79f3926bd1e203fce4fef62c268dd3abfb6d86029aa" +checksum = "bf65a400f8f66fb7b0552869ad70157166676db75ed8181f8104ea91cf9d0b42" dependencies = [ "base64 0.22.1", "chrono", "hex", "indexmap 1.9.3", "indexmap 2.9.0", + "schemars", "serde", "serde_derive", "serde_json", @@ -4261,9 +4303,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.12.0" +version = "3.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d00caa5193a3c8362ac2b73be6b9e768aa5a4b2f721d8f4b339600c3cb51f8e" +checksum = "81679d9ed988d5e9a5e6531dc3f2c28efbd639cbd1dfb628df08edea6004da77" dependencies = [ "darling", "proc-macro2", @@ -4424,6 +4466,24 @@ version = "0.0.4" name = "shield-email" version = "0.0.4" +[[package]] +name = "shield-examples-axum" +version = "0.0.4" +dependencies = [ + "axum", + "shield", + "shield-axum", + "shield-memory", + "shield-oidc", + "time", + "tokio", + "tower-sessions", + "tracing", + "tracing-subscriber", + "utoipa", + "utoipa-swagger-ui", +] + [[package]] name = "shield-examples-leptos-actix" version = "0.0.4" @@ -4524,6 +4584,7 @@ dependencies = [ "async-trait", "serde", "shield", + "shield-oauth", "shield-oidc", "uuid", ] @@ -4533,8 +4594,10 @@ name = "shield-oauth" version = "0.0.4" dependencies = [ "async-trait", + "bon", "chrono", "oauth2", + "secrecy", "serde", "shield", ] @@ -4548,6 +4611,7 @@ dependencies = [ "chrono", "oauth2", "openidconnect", + "secrecy", "serde", "shield", "tracing", @@ -4561,6 +4625,7 @@ dependencies = [ "chrono", "sea-orm", "sea-orm-migration", + "secrecy", "serde", "serde_json", "shield", @@ -5081,12 +5146,11 @@ dependencies = [ [[package]] name = "thread_local" -version = "1.1.8" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" dependencies = [ "cfg-if", - "once_cell", ] [[package]] @@ -5410,9 +5474,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.29" +version = "0.1.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b1ffbcf9c6f6b99d386e7444eb608ba646ae452a36b39737deb9663b610f662" +checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" dependencies = [ "proc-macro2", "quote", @@ -5847,9 +5911,9 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2853738d1cc4f2da3a225c18ec6c3721abb31961096e9dbf5ab35fa88b19cfdb" +checksum = "8782dd5a41a24eed3a4f40b606249b3e236ca61adf1f25ea4d45c73de122b502" dependencies = [ "rustls-pki-types", ] @@ -6173,18 +6237,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.25" +version = "0.8.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1702d9583232ddb9174e01bb7c15a2ab8fb1bc6f227aa1233858c351a3ba0cb" +checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.25" +version = "0.8.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28a6e20d751156648aa063f3800b706ee209a32c0b4d9f24be3d980b01be55ef" +checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 892647f..3311f55 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ leptos_meta = "0.8.0-beta" leptos_router = "0.8.0-beta" sea-orm = "1.1.2" sea-orm-migration = "1.1.2" +secrecy = "0.10.3" serde = "1.0.215" serde_json = "1.0.133" shield = { path = "./packages/core/shield", version = "0.0.4" } diff --git a/examples/axum/Cargo.toml b/examples/axum/Cargo.toml new file mode 100644 index 0000000..ca292b9 --- /dev/null +++ b/examples/axum/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "shield-examples-axum" +description = "Example with Axum." +publish = false + +authors.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +version.workspace = true + +[dependencies] +axum = { workspace = true } +shield.workspace = true +shield-axum = { workspace = true, features = ["utoipa"] } +shield-memory = { workspace = true, features = ["method-oidc"] } +shield-oidc = { workspace = true, features = ["native-tls"] } +time = "0.3.37" +tokio = { workspace = true, features = ["rt-multi-thread"] } +tower-sessions = { workspace = true } +tracing.workspace = true +tracing-subscriber.workspace = true +utoipa.workspace = true +utoipa-swagger-ui = { version = "9.0.0", features = ["axum", "vendored"] } diff --git a/examples/axum/src/main.rs b/examples/axum/src/main.rs new file mode 100644 index 0000000..57611ca --- /dev/null +++ b/examples/axum/src/main.rs @@ -0,0 +1,77 @@ +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + +#[tokio::main] +async fn main() { + use std::sync::Arc; + + use axum::{Router, middleware::from_fn, routing::get}; + + use shield::{Shield, ShieldOptions}; + use shield_axum::{AuthRoutes, ShieldLayer, auth_required}; + use shield_memory::{MemoryStorage, User}; + use shield_oidc::{Keycloak, OidcMethod}; + use time::Duration; + use tokio::net::TcpListener; + use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer}; + use tracing::{info, level_filters::LevelFilter}; + use utoipa::OpenApi; + use utoipa_swagger_ui::SwaggerUi; + + // Initialize tracing + tracing_subscriber::fmt() + .with_max_level(LevelFilter::DEBUG) + .init(); + + // Configuration + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 3000); + + // Initialize sessions + let session_store = MemoryStore::default(); + let session_layer = SessionManagerLayer::new(session_store) + .with_secure(false) + .with_expiry(Expiry::OnInactivity(Duration::minutes(10))); + + // Initialize Shield + let storage = MemoryStorage::new(); + let shield = Shield::new( + storage.clone(), + vec![Arc::new( + OidcMethod::new(storage).with_providers([Keycloak::builder( + "keycloak", + "http://localhost:18080/realms/Shield", + "client1", + ) + .client_secret("xcpQsaGbRILTljPtX4npjmYMBjKrariJ") + .redirect_url(format!( + "http://localhost:{}/api/auth/sign-in/callback/oidc/keycloak", + addr.port() + )) + .build()]), + )], + ShieldOptions::default(), + ); + let shield_layer = ShieldLayer::new(shield.clone()); + + // Initialize OpenAPI specification (optional) + #[derive(OpenApi)] + #[openapi(nest( + (path = "/api/auth", api = AuthRoutes, tags = ["auth"]), + ))] + struct Docs; + + // Initialize router + let router = Router::new() + .route("/api/protected", get(async || "Protected")) + .route_layer(from_fn(auth_required::)) + .nest("/api/auth", AuthRoutes::router::()) + .merge(SwaggerUi::new("/api-docs").url("/api/openapi.json", Docs::openapi())) + .layer(shield_layer) + .layer(session_layer); + + // Start app + info!("listening on http://{}", &addr); + let listener = TcpListener::bind(&addr).await.unwrap(); + axum::serve(listener, router.into_make_service()) + .await + .unwrap(); +} diff --git a/examples/leptos-actix/src/app.rs b/examples/leptos-actix/src/app.rs index 4083b87..68b2daa 100644 --- a/examples/leptos-actix/src/app.rs +++ b/examples/leptos-actix/src/app.rs @@ -4,7 +4,6 @@ use leptos_router::{ components::{Route, Router, Routes}, path, }; -use shield_leptos::routes::{SignIn, SignOut}; use crate::home::HomePage; @@ -37,9 +36,6 @@ pub fn App() -> impl IntoView {
- - -
diff --git a/examples/leptos-axum/src/app.rs b/examples/leptos-axum/src/app.rs index 18ba341..9841806 100644 --- a/examples/leptos-axum/src/app.rs +++ b/examples/leptos-axum/src/app.rs @@ -4,7 +4,6 @@ use leptos_router::{ components::{Route, Router, Routes}, path, }; -use shield_leptos::routes::{SignIn, SignOut}; use crate::home::HomePage; @@ -37,9 +36,6 @@ pub fn App() -> impl IntoView {
- - -
diff --git a/packages/core/shield/src/action.rs b/packages/core/shield/src/action.rs new file mode 100644 index 0000000..d65f1a3 --- /dev/null +++ b/packages/core/shield/src/action.rs @@ -0,0 +1,108 @@ +use std::any::Any; + +use async_trait::async_trait; + +use crate::{ + error::ShieldError, form::Form, provider::Provider, request::Request, response::Response, + session::Session, +}; + +pub const SIGN_IN_ACTION_ID: &str = "sign-in"; +pub const SIGN_IN_CALLBACK_ACTION_ID: &str = "sign-in-callback"; +pub const SIGN_OUT_ACTION_ID: &str = "sign-out"; + +#[async_trait] +pub trait Action: ErasedAction + Send + Sync { + fn id(&self) -> String; + + fn render(&self, provider: P) -> Form; + + async fn call( + &self, + provider: P, + session: Session, + request: Request, + ) -> Result; +} + +#[async_trait] +pub trait ErasedAction: Send + Sync { + fn erased_id(&self) -> String; + + fn erased_render(&self, provider: Box) -> Form; + + async fn erased_call( + &self, + provider: Box, + session: Session, + request: Request, + ) -> Result; +} + +#[macro_export] +macro_rules! erased_action { + ($action:ident $(, < $( $generic_name:ident : $generic_type:ident ),+ > )*) => { + #[async_trait] + impl $( < $( $generic_name: $generic_type + 'static ),+ > )* $crate::ErasedAction for $action $( < $( $generic_name ),+ > )* { + fn erased_id(&self) -> String { + self.id() + } + + fn erased_render(&self, provider: Box) -> $crate::Form { + self.render(*provider.downcast().expect("TODO")) + } + + async fn erased_call( + &self, + provider: Box, + session: $crate::Session, + request: $crate::Request, + ) -> Result<$crate::Response, ShieldError> { + self.call(*provider.downcast().expect("TODO"), session, request) + .await + } + } + }; +} + +#[cfg(test)] +pub(crate) mod tests { + use async_trait::async_trait; + + use crate::{ + error::ShieldError, form::Form, provider::tests::TestProvider, request::Request, + response::Response, session::Session, + }; + + use super::Action; + + pub const TEST_ACTION_ID: &str = "action"; + + #[derive(Default)] + pub struct TestAction {} + + #[async_trait] + impl Action for TestAction { + fn id(&self) -> String { + TEST_ACTION_ID.to_owned() + } + + fn render(&self, _provider: TestProvider) -> Form { + Form { + inputs: vec![], + attributes: None, + } + } + + async fn call( + &self, + _provider: TestProvider, + _session: Session, + _request: Request, + ) -> Result { + Ok(Response::Default) + } + } + + erased_action!(TestAction); +} diff --git a/packages/core/shield/src/error.rs b/packages/core/shield/src/error.rs index ea25668..760956f 100644 --- a/packages/core/shield/src/error.rs +++ b/packages/core/shield/src/error.rs @@ -3,15 +3,28 @@ use thiserror::Error; #[derive(Debug, Error)] pub enum MethodError { #[error("method `{0}` not found")] - MethodNotFound(String), + NotFound(String), +} + +#[derive(Debug, Error)] +pub enum ActionError { + #[error("action `{0}` not found")] + NotFound(String), } #[derive(Debug, Error)] pub enum ProviderError { #[error("provider is missing")] - ProviderMissing, - #[error("provider `{0}` not found")] - ProviderNotFound(String), + Missing, + #[error("{}", provider_not_found_message(.0))] + NotFound(Option), +} + +fn provider_not_found_message(provider_id: &Option) -> String { + match provider_id { + Some(id) => format!("provider `{id}` not found"), + None => "provider not found".to_owned(), + } } #[derive(Debug, Error)] @@ -52,6 +65,8 @@ pub enum ShieldError { #[error(transparent)] Method(#[from] MethodError), #[error(transparent)] + Action(#[from] ActionError), + #[error(transparent)] Provider(#[from] ProviderError), #[error(transparent)] Configuration(#[from] ConfigurationError), diff --git a/packages/core/shield/src/lib.rs b/packages/core/shield/src/lib.rs index a97d459..f8718e7 100644 --- a/packages/core/shield/src/lib.rs +++ b/packages/core/shield/src/lib.rs @@ -1,3 +1,4 @@ +mod action; mod error; mod form; mod method; @@ -11,6 +12,7 @@ mod shield_dyn; mod storage; mod user; +pub use action::*; pub use error::*; pub use form::*; pub use method::*; diff --git a/packages/core/shield/src/method.rs b/packages/core/shield/src/method.rs index b3a3378..7d44238 100644 --- a/packages/core/shield/src/method.rs +++ b/packages/core/shield/src/method.rs @@ -1,45 +1,94 @@ +use std::any::Any; + use async_trait::async_trait; -use crate::{ - error::ShieldError, - options::ShieldOptions, - provider::Provider, - request::{SignInCallbackRequest, SignInRequest, SignOutRequest}, - response::Response, - session::Session, -}; +use crate::{ErasedAction, action::Action, error::ShieldError, provider::Provider}; #[async_trait] -pub trait Method: Send + Sync { +pub trait Method: Send + Sync { fn id(&self) -> String; - async fn providers(&self) -> Result>, ShieldError>; + fn actions(&self) -> Vec>>; - async fn provider_by_id( - &self, - provider_id: &str, - ) -> Result>, ShieldError>; + fn action_by_id(&self, action_id: &str) -> Option>> { + self.actions() + .into_iter() + .find(|action| action.id() == action_id) + } - async fn sign_in( - &self, - request: SignInRequest, - session: Session, - options: &ShieldOptions, - ) -> Result; + async fn providers(&self) -> Result, ShieldError>; - async fn sign_in_callback( - &self, - request: SignInCallbackRequest, - session: Session, - options: &ShieldOptions, - ) -> Result; + async fn provider_by_id(&self, provider_id: Option<&str>) -> Result, ShieldError> { + Ok(self + .providers() + .await? + .into_iter() + .find(|provider| provider.id().as_deref() == provider_id)) + } +} + +#[async_trait] +pub trait ErasedMethod: Send + Sync { + fn erased_id(&self) -> String; + + fn erased_actions(&self) -> Vec>; + + fn erased_action_by_id(&self, action_id: &str) -> Option>; + + async fn erased_providers(&self) -> Result>, ShieldError>; - async fn sign_out( + async fn erased_provider_by_id( &self, - request: SignOutRequest, - session: Session, - options: &ShieldOptions, - ) -> Result, ShieldError>; + provider_id: Option<&str>, + ) -> Result>, ShieldError>; +} + +#[macro_export] +macro_rules! erased_method { + ($method:ident $(, < $( $generic_name:ident : $generic_type:ident ),+ > )*) => { + #[async_trait] + impl $( < $( $generic_name: $generic_type + 'static ),+ > )* $crate::ErasedMethod for $method $( < $( $generic_name ),+ > )* { + fn erased_id(&self) -> String { + self.id() + } + + fn erased_actions(&self) -> Vec> { + self.actions() + .into_iter() + .map(|action| action as Box) + .collect() + } + + fn erased_action_by_id( + &self, + action_id: &str, + ) -> Option> { + self.action_by_id(action_id) + .map(|action| action as Box) + } + + async fn erased_providers( + &self, + ) -> Result>, ShieldError> { + self.providers().await.map(|providers| { + providers + .into_iter() + .map(|provider| Box::new(provider) as Box) + .collect() + }) + } + + async fn erased_provider_by_id( + &self, + provider_id: Option<&str>, + ) -> Result>, ShieldError> { + self.provider_by_id(provider_id).await.map(|provider| { + provider + .map(|provider| Box::new(provider) as Box) + }) + } + } + }; } #[cfg(test)] @@ -47,12 +96,9 @@ pub(crate) mod tests { use async_trait::async_trait; use crate::{ - ShieldOptions, + action::{Action, tests::TestAction}, error::ShieldError, - provider::Provider, - request::{SignInCallbackRequest, SignInRequest, SignOutRequest}, - response::Response, - session::Session, + provider::tests::TestProvider, }; use super::Method; @@ -65,54 +111,27 @@ pub(crate) mod tests { } impl TestMethod { - pub fn with_id(mut self, id: &'static str) -> Self { + // TODO + pub fn _with_id(mut self, id: &'static str) -> Self { self.id = Some(id); self } } #[async_trait] - impl Method for TestMethod { + impl Method for TestMethod { fn id(&self) -> String { self.id.unwrap_or(TEST_METHOD_ID).to_owned() } - async fn providers(&self) -> Result>, ShieldError> { - Ok(vec![]) - } - - async fn provider_by_id( - &self, - _provider_id: &str, - ) -> Result>, ShieldError> { - Ok(None) + fn actions(&self) -> Vec>> { + vec![Box::new(TestAction::default())] } - async fn sign_in( - &self, - _request: SignInRequest, - _session: Session, - _options: &ShieldOptions, - ) -> Result { - todo!("redirect back?") - } - - async fn sign_in_callback( - &self, - _request: SignInCallbackRequest, - _session: Session, - _options: &ShieldOptions, - ) -> Result { - todo!("redirect back?") - } - - async fn sign_out( - &self, - _request: SignOutRequest, - _session: Session, - _options: &ShieldOptions, - ) -> Result, ShieldError> { - Ok(None) + async fn providers(&self) -> Result, ShieldError> { + Ok(vec![TestProvider::default()]) } } + + erased_method!(TestMethod); } diff --git a/packages/core/shield/src/provider.rs b/packages/core/shield/src/provider.rs index b33e84f..ad1f51f 100644 --- a/packages/core/shield/src/provider.rs +++ b/packages/core/shield/src/provider.rs @@ -1,26 +1,36 @@ -use serde::{Deserialize, Serialize}; - -use crate::form::Form; - pub trait Provider: Send + Sync { fn method_id(&self) -> String; fn id(&self) -> Option; fn name(&self) -> String; +} - fn icon_url(&self) -> Option; +#[cfg(test)] +pub(crate) mod tests { + use async_trait::async_trait; - fn form(&self) -> Option
; -} + use crate::method::tests::TEST_METHOD_ID; + + use super::Provider; + + pub const TEST_PROVIDER_NAME: &str = "Test"; + + #[derive(Default)] + pub struct TestProvider {} + + #[async_trait] + impl Provider for TestProvider { + fn method_id(&self) -> String { + TEST_METHOD_ID.to_owned() + } + + fn id(&self) -> Option { + None + } -#[derive(Clone, Debug, Deserialize, Serialize)] -#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] -#[serde(rename_all = "camelCase")] -pub struct ProviderVisualisation { - pub key: String, - pub method_id: String, - pub provider_id: Option, - pub name: String, - pub icon_url: Option, + fn name(&self) -> String { + TEST_PROVIDER_NAME.to_owned() + } + } } diff --git a/packages/core/shield/src/request.rs b/packages/core/shield/src/request.rs index d7a2d4e..f046eda 100644 --- a/packages/core/shield/src/request.rs +++ b/packages/core/shield/src/request.rs @@ -1,29 +1,7 @@ -use serde::{Deserialize, Serialize}; use serde_json::Value; -#[derive(Clone, Debug, Deserialize, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct SignInRequest { - pub method_id: String, - pub provider_id: Option, - pub redirect_url: Option, - pub data: Option, - pub form_data: Option, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct SignInCallbackRequest { - pub method_id: String, - pub provider_id: Option, - pub redirect_url: Option, - pub query: Option, - pub data: Option, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct SignOutRequest { - pub method_id: String, - pub provider_id: Option, +#[derive(Clone, Debug)] +pub struct Request { + pub query: Value, + pub form_data: Value, } diff --git a/packages/core/shield/src/response.rs b/packages/core/shield/src/response.rs index bd07377..4b3b4b6 100644 --- a/packages/core/shield/src/response.rs +++ b/packages/core/shield/src/response.rs @@ -1,6 +1,6 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Clone, Debug, Deserialize, Serialize)] +#[derive(Clone, Debug)] pub enum Response { + // TODO: Remove temporary default variant. + Default, Redirect(String), } diff --git a/packages/core/shield/src/shield.rs b/packages/core/shield/src/shield.rs index 8175a44..361fce6 100644 --- a/packages/core/shield/src/shield.rs +++ b/packages/core/shield/src/shield.rs @@ -1,38 +1,29 @@ -use std::{collections::HashMap, sync::Arc}; +use std::{any::Any, collections::HashMap, sync::Arc}; use futures::future::try_join_all; use crate::{ - MethodError, - error::{SessionError, ShieldError}, - method::Method, - options::ShieldOptions, - provider::{Provider, ProviderVisualisation}, - request::{SignInCallbackRequest, SignInRequest, SignOutRequest}, - response::Response, - session::Session, - storage::Storage, - user::User, + error::ShieldError, method::ErasedMethod, options::ShieldOptions, storage::Storage, user::User, }; #[derive(Clone)] pub struct Shield { storage: Arc>, - methods: Arc>>, + methods: Arc>>, options: ShieldOptions, } impl Shield { - pub fn new(storage: S, providers: Vec>, options: ShieldOptions) -> Self + pub fn new(storage: S, methods: Vec>, options: ShieldOptions) -> Self where S: Storage + 'static, { Self { storage: Arc::new(storage), methods: Arc::new( - providers + methods .into_iter() - .map(|provider| (provider.id(), provider)) + .map(|method| (method.erased_id(), method)) .collect(), ), options, @@ -47,187 +38,27 @@ impl Shield { &self.options } - pub fn method_by_id(&self, provider_id: &str) -> Option<&dyn Method> { - self.methods.get(provider_id).map(|v| &**v) + pub fn method_by_id(&self, method_id: &str) -> Option<&dyn ErasedMethod> { + self.methods.get(method_id).map(|v| &**v) } - pub async fn providers(&self) -> Result>, ShieldError> { - try_join_all(self.methods.values().map(|provider| provider.providers())) - .await - .map(|providers| providers.into_iter().flatten().collect::>()) - } - - pub async fn provider_visualisations(&self) -> Result, ShieldError> { - self.providers().await.map(|providers| { - providers - .into_iter() - .map(|provider| { - let method_id = provider.method_id(); - let provider_id = provider.id(); - - ProviderVisualisation { - key: match &provider_id { - Some(provider_id) => format!("{method_id}-{provider_id}"), - None => method_id.clone(), - }, - method_id, - provider_id, - name: provider.name(), - icon_url: provider.icon_url(), - } - }) - .collect() - }) + pub async fn providers(&self) -> Result>, ShieldError> { + try_join_all( + self.methods + .values() + .map(|provider| provider.erased_providers()), + ) + .await + .map(|providers| providers.into_iter().flatten().collect::>()) } pub async fn provider_by_id( &self, method_id: &str, provider_id: Option<&str>, - ) -> Result>, ShieldError> { + ) -> Result>, ShieldError> { match self.method_by_id(method_id) { - Some(provider) => provider.provider_by_id(provider_id.expect("TODO")).await, - None => Ok(None), - } - } - - pub async fn sign_in( - &self, - request: SignInRequest, - session: Session, - ) -> Result { - let provider = match self.methods.get(&request.method_id) { - Some(provider) => provider, - None => return Err(MethodError::MethodNotFound(request.method_id).into()), - }; - - // TODO: validate redirect URL - - { - let session_data = session.data(); - let mut session_data = session_data - .lock() - .map_err(|err| SessionError::Lock(err.to_string()))?; - - session_data.redirect_url = request.redirect_url.clone(); - }; - - let response = provider - .sign_in(request, session.clone(), &self.options) - .await; - - session.update().await?; - - response - } - - pub async fn sign_in_callback( - &self, - request: SignInCallbackRequest, - session: Session, - ) -> Result { - let provider = match self.methods.get(&request.method_id) { - Some(provider) => provider, - None => return Err(MethodError::MethodNotFound(request.method_id).into()), - }; - - let redirect_url = { - let session_data = session.data(); - let session_data = session_data - .lock() - .map_err(|err| SessionError::Lock(err.to_string()))?; - - session_data.redirect_url.clone() - }; - - let response = provider - .sign_in_callback( - SignInCallbackRequest { - redirect_url: request.redirect_url.or(redirect_url), - ..request - }, - session.clone(), - &self.options, - ) - .await; - - session.update().await?; - - response - } - - pub async fn sign_out(&self, session: Session) -> Result { - let authenticated = { - let session_data = session.data(); - let session_data = session_data - .lock() - .map_err(|err| SessionError::Lock(err.to_string()))?; - - session_data.authentication.clone() - }; - - let response = if let Some(authenticated) = authenticated { - let provider = match self.methods.get(&authenticated.method_id) { - Some(provider) => provider, - None => { - return Err(MethodError::MethodNotFound(authenticated.method_id).into()); - } - }; - - provider - .sign_out( - SignOutRequest { - method_id: authenticated.method_id, - provider_id: authenticated.provider_id, - }, - session.clone(), - &self.options, - ) - .await? - } else { - None - }; - - let response = - response.unwrap_or_else(|| Response::Redirect(self.options.sign_out_redirect.clone())); - - session.purge().await?; - - Ok(response) - } - - pub async fn user(&self, session: &Session) -> Result, ShieldError> { - let authentication = { - let session_data = session.data(); - let session_data = session_data - .lock() - .map_err(|err| SessionError::Lock(err.to_string()))?; - - session_data.authentication.clone() - }; - - match authentication { - Some(authentication) => { - if self - .provider_by_id( - &authentication.method_id, - authentication.provider_id.as_deref(), - ) - .await? - .is_none() - { - session.purge().await?; - return Ok(None); - } - - let user = self.storage().user_by_id(&authentication.user_id).await?; - - if user.is_none() { - session.purge().await?; - } - - Ok(user) - } + Some(provider) => provider.erased_provider_by_id(provider_id).await, None => Ok(None), } } @@ -235,11 +66,8 @@ impl Shield { #[cfg(test)] mod tests { - use std::sync::Arc; - use crate::{ ShieldOptions, - method::tests::{TEST_METHOD_ID, TestMethod}, storage::tests::{TEST_STORAGE_ID, TestStorage}, }; @@ -251,31 +79,4 @@ mod tests { assert_eq!(TEST_STORAGE_ID, shield.storage().id()); } - - #[test] - fn test_providers() { - let shield = Shield::new( - TestStorage::default(), - vec![ - Arc::new(TestMethod::default().with_id("test1")), - Arc::new(TestMethod::default().with_id("test2")), - ], - ShieldOptions::default(), - ); - - assert_eq!( - None, - shield - .method_by_id(TEST_METHOD_ID) - .map(|provider| provider.id()) - ); - assert_eq!( - Some("test1".to_owned()), - shield.method_by_id("test1").map(|provider| provider.id()) - ); - assert_eq!( - Some("test2".to_owned()), - shield.method_by_id("test2").map(|provider| provider.id()) - ); - } } diff --git a/packages/core/shield/src/shield_dyn.rs b/packages/core/shield/src/shield_dyn.rs index 2a15d86..cccd81a 100644 --- a/packages/core/shield/src/shield_dyn.rs +++ b/packages/core/shield/src/shield_dyn.rs @@ -1,67 +1,19 @@ -use std::sync::Arc; +use std::{any::Any, sync::Arc}; use async_trait::async_trait; -use crate::{ - error::ShieldError, - provider::{Provider, ProviderVisualisation}, - request::{SignInCallbackRequest, SignInRequest}, - response::Response, - session::Session, - shield::Shield, - user::User, -}; +use crate::{error::ShieldError, shield::Shield, user::User}; #[async_trait] pub trait DynShield: Send + Sync { - async fn providers(&self) -> Result>, ShieldError>; - - async fn provider_visualisations(&self) -> Result, ShieldError>; - - async fn sign_in( - &self, - request: SignInRequest, - session: Session, - ) -> Result; - - async fn sign_in_callback( - &self, - request: SignInCallbackRequest, - session: Session, - ) -> Result; - - async fn sign_out(&self, session: Session) -> Result; + async fn providers(&self) -> Result>, ShieldError>; } #[async_trait] impl DynShield for Shield { - async fn providers(&self) -> Result>, ShieldError> { + async fn providers(&self) -> Result>, ShieldError> { self.providers().await } - - async fn provider_visualisations(&self) -> Result, ShieldError> { - self.provider_visualisations().await - } - - async fn sign_in( - &self, - request: SignInRequest, - session: Session, - ) -> Result { - self.sign_in(request, session).await - } - - async fn sign_in_callback( - &self, - request: SignInCallbackRequest, - session: Session, - ) -> Result { - self.sign_in_callback(request, session).await - } - - async fn sign_out(&self, session: Session) -> Result { - self.sign_out(session).await - } } pub struct ShieldDyn(Arc); @@ -71,31 +23,7 @@ impl ShieldDyn { Self(Arc::new(shield)) } - pub async fn providers(&self) -> Result>, ShieldError> { + pub async fn providers(&self) -> Result>, ShieldError> { self.0.providers().await } - - pub async fn provider_visualisations(&self) -> Result, ShieldError> { - self.0.provider_visualisations().await - } - - pub async fn sign_in( - &self, - request: SignInRequest, - session: Session, - ) -> Result { - self.0.sign_in(request, session).await - } - - pub async fn sign_in_callback( - &self, - request: SignInCallbackRequest, - session: Session, - ) -> Result { - self.0.sign_in_callback(request, session).await - } - - pub async fn sign_out(&self, session: Session) -> Result { - self.0.sign_out(session).await - } } diff --git a/packages/integrations/shield-axum/src/error.rs b/packages/integrations/shield-axum/src/error.rs index 349876c..928adf1 100644 --- a/packages/integrations/shield-axum/src/error.rs +++ b/packages/integrations/shield-axum/src/error.rs @@ -4,7 +4,7 @@ use axum::{ response::{IntoResponse, Response}, }; use serde::Serialize; -use shield::{MethodError, ProviderError, ShieldError, StorageError}; +use shield::{ActionError, MethodError, ProviderError, ShieldError, StorageError}; #[derive(Serialize)] #[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] @@ -45,11 +45,14 @@ impl IntoResponse for RouteError { fn into_response(self) -> Response { let status_code = match &self.0 { ShieldError::Method(method_error) => match method_error { - MethodError::MethodNotFound(_) => StatusCode::NOT_FOUND, + MethodError::NotFound(_) => StatusCode::NOT_FOUND, + }, + ShieldError::Action(action_error) => match action_error { + ActionError::NotFound(_) => StatusCode::NOT_FOUND, }, ShieldError::Provider(provider_error) => match provider_error { - ProviderError::ProviderMissing => StatusCode::BAD_REQUEST, - ProviderError::ProviderNotFound(_) => StatusCode::NOT_FOUND, + ProviderError::Missing => StatusCode::BAD_REQUEST, + ProviderError::NotFound(_) => StatusCode::NOT_FOUND, }, ShieldError::Configuration(_) => StatusCode::INTERNAL_SERVER_ERROR, ShieldError::Session(_) => StatusCode::INTERNAL_SERVER_ERROR, diff --git a/packages/integrations/shield-axum/src/lib.rs b/packages/integrations/shield-axum/src/lib.rs index b1c8976..41537f3 100644 --- a/packages/integrations/shield-axum/src/lib.rs +++ b/packages/integrations/shield-axum/src/lib.rs @@ -2,7 +2,6 @@ mod error; mod extract; mod middleware; mod path; -mod response; mod router; mod routes; diff --git a/packages/integrations/shield-axum/src/path.rs b/packages/integrations/shield-axum/src/path.rs index 494a830..f8a0c5c 100644 --- a/packages/integrations/shield-axum/src/path.rs +++ b/packages/integrations/shield-axum/src/path.rs @@ -3,9 +3,11 @@ use serde::Deserialize; #[derive(Deserialize)] #[cfg_attr(feature = "utoipa", derive(utoipa::IntoParams))] #[serde(rename_all = "camelCase")] -pub struct AuthPathParams { - /// ID of authentication method. +pub struct ActionPathParams { + /// ID of the method. pub method_id: String, - /// ID of authentication provider (optional). + /// ID of the action. + pub action_id: String, + /// ID of provider (optional). pub provider_id: Option, } diff --git a/packages/integrations/shield-axum/src/response.rs b/packages/integrations/shield-axum/src/response.rs deleted file mode 100644 index 2396b88..0000000 --- a/packages/integrations/shield-axum/src/response.rs +++ /dev/null @@ -1,17 +0,0 @@ -use axum::response::{IntoResponse, Redirect, Response}; - -pub struct RouteResponse(shield::Response); - -impl IntoResponse for RouteResponse { - fn into_response(self) -> Response { - match self.0 { - shield::Response::Redirect(url) => Redirect::to(&url).into_response(), - } - } -} - -impl From for RouteResponse { - fn from(value: shield::Response) -> Self { - Self(value) - } -} diff --git a/packages/integrations/shield-axum/src/router.rs b/packages/integrations/shield-axum/src/router.rs index 57a4d3a..09e4c2e 100644 --- a/packages/integrations/shield-axum/src/router.rs +++ b/packages/integrations/shield-axum/src/router.rs @@ -7,24 +7,15 @@ use shield::User; use crate::routes::*; #[cfg_attr(feature = "utoipa", derive(utoipa::OpenApi))] -#[cfg_attr( - feature = "utoipa", - openapi(paths(providers, sign_in, sign_in_callback, sign_out, user)) -)] +#[cfg_attr(feature = "utoipa", openapi(paths()))] pub struct AuthRoutes; impl AuthRoutes { pub fn router() -> Router { Router::new() - .route("/providers", get(providers::)) - .route("/sign-in/{methodId}", post(sign_in::)) - .route("/sign-in/{methodId}/{providerId}", post(sign_in::)) - .route("/sign-in/callback/{methodId}", get(sign_in_callback::)) - .route( - "/sign-in/callback/{methodId}/{providerId}", - get(sign_in_callback::), - ) - .route("/sign-out", post(sign_out::)) - .route("/user", get(user::)) + .route("/{methodId}/{actionId}", get(action::)) + .route("/{methodId}/{actionId}", post(action::)) + .route("/{methodId}/{actionId}/{providerId}", get(action::)) + .route("/{methodId}/{actionId}/{providerId}", post(action::)) } } diff --git a/packages/integrations/shield-axum/src/routes.rs b/packages/integrations/shield-axum/src/routes.rs index accae39..8bd911f 100644 --- a/packages/integrations/shield-axum/src/routes.rs +++ b/packages/integrations/shield-axum/src/routes.rs @@ -1,11 +1,3 @@ -mod providers; -mod sign_in; -mod sign_in_callback; -mod sign_out; -mod user; +mod action; -pub use providers::*; -pub use sign_in::*; -pub use sign_in_callback::*; -pub use sign_out::*; -pub use user::*; +pub use action::*; diff --git a/packages/integrations/shield-axum/src/routes/action.rs b/packages/integrations/shield-axum/src/routes/action.rs new file mode 100644 index 0000000..fe2cd68 --- /dev/null +++ b/packages/integrations/shield-axum/src/routes/action.rs @@ -0,0 +1,42 @@ +use axum::{ + Form, + extract::{Path, Query}, +}; +use serde_json::Value; +use shield::{ActionError, MethodError, ProviderError, Request, ShieldError, User}; + +use crate::{ExtractSession, ExtractShield, RouteError, path::ActionPathParams}; + +pub async fn action( + Path(ActionPathParams { + method_id, + action_id, + provider_id, + .. + }): Path, + ExtractShield(shield): ExtractShield, + ExtractSession(session): ExtractSession, + Query(query): Query, + Form(form_data): Form, +) -> Result<(), RouteError> { + let method = shield + .method_by_id(&method_id) + .ok_or(ShieldError::Method(MethodError::NotFound(method_id)))?; + + let action = method + .erased_action_by_id(&action_id) + .ok_or(ShieldError::Action(ActionError::NotFound(action_id)))?; + + // TODO: Check if this action supports the HTTP method (GET/POST). + + let provider = method + .erased_provider_by_id(provider_id.as_deref()) + .await? + .ok_or(ShieldError::Provider(ProviderError::NotFound(provider_id)))?; + + action + .erased_call(provider, session, Request { query, form_data }) + .await?; + + Ok(()) +} diff --git a/packages/integrations/shield-axum/src/routes/providers.rs b/packages/integrations/shield-axum/src/routes/providers.rs deleted file mode 100644 index a5d5583..0000000 --- a/packages/integrations/shield-axum/src/routes/providers.rs +++ /dev/null @@ -1,26 +0,0 @@ -use axum::Json; -use shield::{ProviderVisualisation, User}; - -use crate::{ - error::{ErrorBody, RouteError}, - extract::ExtractShield, -}; - -#[cfg_attr( - feature = "utoipa", - utoipa::path( - get, - path = "/providers", - operation_id = "getProviders", - description = "Get a list of authentication providers.", - responses( - (status = 200, description = "List of authentication providers.", body = Vec), - (status = 500, description = "Internal server error.", body = ErrorBody), - ) - ) -)] -pub async fn providers( - ExtractShield(shield): ExtractShield, -) -> Result>, RouteError> { - Ok(Json(shield.provider_visualisations().await?)) -} diff --git a/packages/integrations/shield-axum/src/routes/sign_in.rs b/packages/integrations/shield-axum/src/routes/sign_in.rs deleted file mode 100644 index 92e5189..0000000 --- a/packages/integrations/shield-axum/src/routes/sign_in.rs +++ /dev/null @@ -1,62 +0,0 @@ -use axum::{Form, extract::Path}; -use serde::{Deserialize, Serialize}; -use shield::{SignInRequest, User}; - -use crate::{ - error::{ErrorBody, RouteError}, - extract::{ExtractSession, ExtractShield}, - path::AuthPathParams, - response::RouteResponse, -}; - -#[derive(Clone, Debug, Default, Deserialize, Serialize)] -#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] -#[serde(rename_all = "camelCase")] -pub struct SignInData { - redirect_url: Option, -} - -#[cfg_attr( - feature = "utoipa", - utoipa::path( - post, - path = "/sign-in/{methodId}/{providerId}", - operation_id = "signIn", - description = "Sign in to an account with the specified authentication provider.", - params( - AuthPathParams, - ), - request_body = SignInData, - responses( - (status = 200, description = "Successfully signed in."), - (status = 303, description = "Redirect to authentication provider for sign in."), - (status = 400, description = "Bad request.", body = ErrorBody), - (status = 404, description = "Not found.", body = ErrorBody), - (status = 500, description = "Internal server error.", body = ErrorBody), - ) - ) -)] -pub async fn sign_in( - Path(AuthPathParams { - method_id, - provider_id, - }): Path, - ExtractShield(shield): ExtractShield, - ExtractSession(session): ExtractSession, - Form(data): Form, -) -> Result { - let response = shield - .sign_in( - SignInRequest { - method_id, - provider_id, - redirect_url: data.redirect_url, - data: None, - form_data: None, - }, - session, - ) - .await?; - - Ok(response.into()) -} diff --git a/packages/integrations/shield-axum/src/routes/sign_in_callback.rs b/packages/integrations/shield-axum/src/routes/sign_in_callback.rs deleted file mode 100644 index 54f81bf..0000000 --- a/packages/integrations/shield-axum/src/routes/sign_in_callback.rs +++ /dev/null @@ -1,53 +0,0 @@ -use axum::extract::{Path, Query}; -use serde_json::Value; -use shield::{SignInCallbackRequest, User}; - -use crate::{ - error::{ErrorBody, RouteError}, - extract::{ExtractSession, ExtractShield}, - path::AuthPathParams, - response::RouteResponse, -}; - -#[cfg_attr( - feature = "utoipa", - utoipa::path( - post, - path = "/sign-in/callback/{methodId}/{providerId}", - operation_id = "signInCallback", - description = "Callback after signing in with authentication provider.", - params( - AuthPathParams, - ), - responses( - (status = 200, description = "Successfully signed in."), - (status = 400, description = "Bad request.", body = ErrorBody), - (status = 404, description = "Not found.", body = ErrorBody), - (status = 500, description = "Internal server error.", body = ErrorBody), - ) - ) -)] -pub async fn sign_in_callback( - Path(AuthPathParams { - method_id, - provider_id, - }): Path, - Query(query): Query, - ExtractShield(shield): ExtractShield, - ExtractSession(session): ExtractSession, -) -> Result { - let response = shield - .sign_in_callback( - SignInCallbackRequest { - method_id, - provider_id, - redirect_url: None, - query: Some(query), - data: None, - }, - session, - ) - .await?; - - Ok(response.into()) -} diff --git a/packages/integrations/shield-axum/src/routes/sign_out.rs b/packages/integrations/shield-axum/src/routes/sign_out.rs deleted file mode 100644 index 8659118..0000000 --- a/packages/integrations/shield-axum/src/routes/sign_out.rs +++ /dev/null @@ -1,30 +0,0 @@ -use shield::User; - -use crate::{ - error::{ErrorBody, RouteError}, - extract::{ExtractSession, ExtractShield}, - response::RouteResponse, -}; - -#[cfg_attr( - feature = "utoipa", - utoipa::path( - post, - path = "/sign-out", - operation_id = "signOut", - description = "Sign out of the current account.", - responses( - (status = 201, description = "Successfully signed out."), - (status = 400, description = "Bad request.", body = ErrorBody), - (status = 500, description = "Internal server error.", body = ErrorBody), - ) - ) -)] -pub async fn sign_out( - ExtractShield(shield): ExtractShield, - ExtractSession(session): ExtractSession, -) -> Result { - let response = shield.sign_out(session).await?; - - Ok(response.into()) -} diff --git a/packages/integrations/shield-axum/src/routes/user.rs b/packages/integrations/shield-axum/src/routes/user.rs deleted file mode 100644 index cfbbf5d..0000000 --- a/packages/integrations/shield-axum/src/routes/user.rs +++ /dev/null @@ -1,59 +0,0 @@ -use axum::Json; -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use shield::{ConfigurationError, EmailAddress, ShieldError, User}; - -use crate::{ - error::{ErrorBody, RouteError}, - extract::ExtractUser, -}; - -#[derive(Deserialize, Serialize)] -#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] -#[cfg_attr(feature = "utoipa", schema(as = User))] -#[serde(rename_all = "camelCase")] -pub struct UserBody { - id: String, - name: Option, - email_addresses: Vec, - additional: Value, -} - -impl UserBody { - async fn new(user: U) -> Result { - let email_addresses = user.email_addresses().await?; - - Ok(Self { - id: user.id(), - name: user.name(), - email_addresses, - additional: serde_json::to_value(user.additional()).map_err(|err| { - ConfigurationError::Invalid(format!( - "additional user data is not serializable: {err}" - )) - })?, - }) - } -} - -#[cfg_attr( - feature = "utoipa", - utoipa::path( - get, - path = "/user", - operation_id = "getUser", - description = "Get the current user account.", - responses( - (status = 200, description = "Current user account.", body = UserBody), - (status = 401, description = "No account signed in.", body = ErrorBody), - (status = 500, description = "Internal server error.", body = ErrorBody), - ) - ) -)] -pub async fn user( - ExtractUser(user): ExtractUser, -) -> Result, RouteError> { - let user = user.ok_or(ShieldError::Unauthorized)?; - - Ok(Json(UserBody::new(user).await?)) -} diff --git a/packages/integrations/shield-leptos/src/lib.rs b/packages/integrations/shield-leptos/src/lib.rs index 21b4ecc..398c7f0 100644 --- a/packages/integrations/shield-leptos/src/lib.rs +++ b/packages/integrations/shield-leptos/src/lib.rs @@ -1,3 +1,2 @@ pub mod context; pub mod integration; -pub mod routes; diff --git a/packages/integrations/shield-leptos/src/routes/sign_in.rs b/packages/integrations/shield-leptos/src/routes/sign_in.rs deleted file mode 100644 index 9b1c3e8..0000000 --- a/packages/integrations/shield-leptos/src/routes/sign_in.rs +++ /dev/null @@ -1,79 +0,0 @@ -use leptos::{either::Either, prelude::*}; -use shield::ProviderVisualisation; - -#[server] -pub async fn providers() -> Result, ServerFnError> { - use crate::context::extract_shield; - - let shield = extract_shield().await; - - shield - .provider_visualisations() - .await - .map_err(|err| err.into()) -} - -#[server] -pub async fn sign_in(method_id: String, provider_id: Option) -> Result<(), ServerFnError> { - use shield::{Response, ShieldError, SignInRequest}; - - use crate::context::expect_server_integration; - - let server_integration = expect_server_integration(); - let shield = server_integration.extract_shield().await; - let session = server_integration.extract_session().await; - - let response = shield - .sign_in( - SignInRequest { - method_id, - provider_id, - redirect_url: None, - data: None, - form_data: None, - }, - session, - ) - .await - .map_err(ServerFnError::::from)?; - - match response { - Response::Redirect(url) => { - server_integration.redirect(&url); - - Ok(()) - } - } -} - -#[component] -pub fn SignIn() -> impl IntoView { - let providers = OnceResource::new(providers()); - let sign_in = ServerAction::::new(); - - view! { -

"Sign in"

- - - {move || Suspend::new(async move { match providers.await { - Ok(providers) => Either::Left(view! { - - - - - - - - - }), - Err(err) => Either::Right(view! { - {err.to_string()} - }) - }})} - - } -} diff --git a/packages/integrations/shield-leptos/src/routes/sign_out.rs b/packages/integrations/shield-leptos/src/routes/sign_out.rs deleted file mode 100644 index 11ab8e3..0000000 --- a/packages/integrations/shield-leptos/src/routes/sign_out.rs +++ /dev/null @@ -1,38 +0,0 @@ -use leptos::prelude::*; - -#[server] -pub async fn sign_out() -> Result<(), ServerFnError> { - use shield::{Response, ShieldError}; - - use crate::context::expect_server_integration; - - let server_integration = expect_server_integration(); - let shield = server_integration.extract_shield().await; - let session = server_integration.extract_session().await; - - let response = shield - .sign_out(session) - .await - .map_err(ServerFnError::::from)?; - - match response { - Response::Redirect(url) => { - server_integration.redirect(&url); - - Ok(()) - } - } -} - -#[component] -pub fn SignOut() -> impl IntoView { - let sign_out = ServerAction::::new(); - - view! { -

"Sign out"

- - - - - } -} diff --git a/packages/integrations/shield-tower/src/service.rs b/packages/integrations/shield-tower/src/service.rs index a0b40f4..ee7345d 100644 --- a/packages/integrations/shield-tower/src/service.rs +++ b/packages/integrations/shield-tower/src/service.rs @@ -74,14 +74,14 @@ where }; let shield_session = Session::new(session_storage); - let user = match shield.user(&shield_session).await { - Ok(user) => user, - Err(_err) => return Ok(Self::internal_server_error()), - }; + // let user = match shield.user(&shield_session).await { + // Ok(user) => user, + // Err(_err) => return Ok(Self::internal_server_error()), + // }; req.extensions_mut().insert(shield); req.extensions_mut().insert(shield_session); - req.extensions_mut().insert(user); + // req.extensions_mut().insert(user); inner.call(req).await }) diff --git a/packages/methods/shield-credentials/src/actions.rs b/packages/methods/shield-credentials/src/actions.rs new file mode 100644 index 0000000..82b597f --- /dev/null +++ b/packages/methods/shield-credentials/src/actions.rs @@ -0,0 +1,3 @@ +mod sign_in; + +pub use sign_in::*; diff --git a/packages/methods/shield-credentials/src/actions/sign_in.rs b/packages/methods/shield-credentials/src/actions/sign_in.rs new file mode 100644 index 0000000..cf6a9b2 --- /dev/null +++ b/packages/methods/shield-credentials/src/actions/sign_in.rs @@ -0,0 +1,64 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use serde::de::DeserializeOwned; +use shield::{ + Action, Authentication, Form, Request, Response, SIGN_IN_ACTION_ID, Session, SessionError, + ShieldError, User, erased_action, +}; + +use crate::{credentials::Credentials, provider::CredentialsProvider}; + +pub struct CredentialsSignInAction { + credentials: Arc>, +} + +impl CredentialsSignInAction { + pub fn new(credentials: Arc>) -> Self { + Self { credentials } + } +} + +#[async_trait] +impl Action + for CredentialsSignInAction +{ + fn id(&self) -> String { + SIGN_IN_ACTION_ID.to_owned() + } + + fn render(&self, _provider: CredentialsProvider) -> Form { + self.credentials.form() + } + + async fn call( + &self, + _provider: CredentialsProvider, + session: Session, + request: Request, + ) -> Result { + let data = serde_json::from_value(request.form_data) + .map_err(|err| ShieldError::Validation(err.to_string()))?; + + let user = self.credentials.sign_in(data).await?; + + session.renew().await?; + + { + let session_data = session.data(); + let mut session_data = session_data + .lock() + .map_err(|err| SessionError::Lock(err.to_string()))?; + + session_data.authentication = Some(Authentication { + method_id: self.id(), + provider_id: None, + user_id: user.id(), + }); + } + + Ok(Response::Default) + } +} + +erased_action!(CredentialsSignInAction, ); diff --git a/packages/methods/shield-credentials/src/lib.rs b/packages/methods/shield-credentials/src/lib.rs index c9b7a3d..2c16649 100644 --- a/packages/methods/shield-credentials/src/lib.rs +++ b/packages/methods/shield-credentials/src/lib.rs @@ -1,3 +1,4 @@ +mod actions; mod credentials; mod email_password; mod method; diff --git a/packages/methods/shield-credentials/src/method.rs b/packages/methods/shield-credentials/src/method.rs index 834ff52..f5c2dde 100644 --- a/packages/methods/shield-credentials/src/method.rs +++ b/packages/methods/shield-credentials/src/method.rs @@ -2,12 +2,9 @@ use std::sync::Arc; use async_trait::async_trait; use serde::de::DeserializeOwned; -use shield::{ - Authentication, Method, Provider, Response, Session, SessionError, ShieldError, ShieldOptions, - SignInCallbackRequest, SignInRequest, SignOutRequest, User, -}; +use shield::{Action, Method, ShieldError, User, erased_method}; -use crate::{Credentials, provider::CredentialsProvider}; +use crate::{Credentials, actions::CredentialsSignInAction, provider::CredentialsProvider}; pub const CREDENTIALS_METHOD_ID: &str = "credentials"; @@ -24,84 +21,22 @@ impl CredentialsMethod { } #[async_trait] -impl Method for CredentialsMethod { +impl Method + for CredentialsMethod +{ fn id(&self) -> String { CREDENTIALS_METHOD_ID.to_owned() } - async fn providers(&self) -> Result>, ShieldError> { - Ok(vec![Box::new(CredentialsProvider::new( + fn actions(&self) -> Vec>> { + vec![Box::new(CredentialsSignInAction::new( self.credentials.clone(), - ))]) + ))] } - async fn provider_by_id( - &self, - _provider_id: &str, - ) -> Result>, ShieldError> { - Ok(None) - } - - async fn sign_in( - &self, - request: SignInRequest, - session: Session, - options: &ShieldOptions, - ) -> Result { - if request.provider_id.is_some() { - return Err(ShieldError::Validation( - "Provider should be none.".to_owned(), - )); - } - - let Some(form_data) = request.form_data else { - return Err(ShieldError::Validation("Missing form data.".to_owned())); - }; - - let data = serde_json::from_value(form_data) - .map_err(|err| ShieldError::Validation(err.to_string()))?; - - let user = self.credentials.sign_in(data).await?; - - session.renew().await?; - - { - let session_data = session.data(); - let mut session_data = session_data - .lock() - .map_err(|err| SessionError::Lock(err.to_string()))?; - - session_data.authentication = Some(Authentication { - method_id: self.id(), - provider_id: None, - user_id: user.id(), - }); - } - - Ok(Response::Redirect( - request - .redirect_url - .unwrap_or(options.sign_in_redirect.clone()), - )) - } - - async fn sign_in_callback( - &self, - _request: SignInCallbackRequest, - _session: Session, - _options: &ShieldOptions, - ) -> Result { - Err(ShieldError::Validation( - "Credentials method does not have a sign in callback.".to_owned(), - )) - } - - async fn sign_out( - &self, - _request: SignOutRequest, - _session: Session, - _options: &ShieldOptions, - ) -> Result, ShieldError> { - Ok(None) + async fn providers(&self) -> Result, ShieldError> { + Ok(vec![CredentialsProvider]) } } + +erased_method!(CredentialsMethod, ); diff --git a/packages/methods/shield-credentials/src/provider.rs b/packages/methods/shield-credentials/src/provider.rs index 99f7fa9..f764e6f 100644 --- a/packages/methods/shield-credentials/src/provider.rs +++ b/packages/methods/shield-credentials/src/provider.rs @@ -1,21 +1,10 @@ -use std::sync::Arc; +use shield::Provider; -use serde::de::DeserializeOwned; -use shield::{Form, Provider, User}; +use crate::CREDENTIALS_METHOD_ID; -use crate::{CREDENTIALS_METHOD_ID, Credentials}; +pub struct CredentialsProvider; -pub struct CredentialsProvider { - credentials: Arc>, -} - -impl CredentialsProvider { - pub(crate) fn new(credentials: Arc>) -> Self { - Self { credentials } - } -} - -impl Provider for CredentialsProvider { +impl Provider for CredentialsProvider { fn method_id(&self) -> String { CREDENTIALS_METHOD_ID.to_owned() } @@ -27,12 +16,4 @@ impl Provider for CredentialsProvider { fn name(&self) -> String { "Credentials".to_owned() } - - fn icon_url(&self) -> Option { - None - } - - fn form(&self) -> Option { - Some(self.credentials.form()) - } } diff --git a/packages/methods/shield-oauth/Cargo.toml b/packages/methods/shield-oauth/Cargo.toml index d6d2f20..d28bbd2 100644 --- a/packages/methods/shield-oauth/Cargo.toml +++ b/packages/methods/shield-oauth/Cargo.toml @@ -15,7 +15,9 @@ rustls-tls = ["oauth2/rustls-tls"] [dependencies] async-trait.workspace = true +bon.workspace = true chrono.workspace = true oauth2 = { version = "5.0.0", default-features = false, features = ["reqwest"] } +secrecy.workspace = true serde.workspace = true shield.workspace = true diff --git a/packages/integrations/shield-leptos/src/routes.rs b/packages/methods/shield-oauth/src/actions.rs similarity index 57% rename from packages/integrations/shield-leptos/src/routes.rs rename to packages/methods/shield-oauth/src/actions.rs index 757615c..47d587d 100644 --- a/packages/integrations/shield-leptos/src/routes.rs +++ b/packages/methods/shield-oauth/src/actions.rs @@ -1,5 +1,7 @@ mod sign_in; +mod sign_in_callback; mod sign_out; pub use sign_in::*; +pub use sign_in_callback::*; pub use sign_out::*; diff --git a/packages/methods/shield-oauth/src/actions/sign_in.rs b/packages/methods/shield-oauth/src/actions/sign_in.rs new file mode 100644 index 0000000..d1291b6 --- /dev/null +++ b/packages/methods/shield-oauth/src/actions/sign_in.rs @@ -0,0 +1,91 @@ +use async_trait::async_trait; +use oauth2::{CsrfToken, PkceCodeChallenge, Scope, url::form_urlencoded::parse}; +use shield::{ + Action, ConfigurationError, Form, Request, Response, SIGN_IN_ACTION_ID, Session, SessionError, + ShieldError, erased_action, +}; + +use crate::{ + method::OAUTH_METHOD_ID, + provider::{OauthProvider, OauthProviderPkceCodeChallenge}, + session::OauthSession, +}; + +pub struct OauthSignInAction; + +#[async_trait] +impl Action for OauthSignInAction { + fn id(&self) -> String { + SIGN_IN_ACTION_ID.to_owned() + } + + fn render(&self, _provider: OauthProvider) -> Form { + Form { + inputs: vec![], + attributes: None, + } + } + + async fn call( + &self, + provider: OauthProvider, + session: Session, + _request: Request, + ) -> Result { + let client = provider.oauth_client().await?; + + let mut authorization_request = client + .authorize_url(CsrfToken::new_random) + .map_err(|err| ConfigurationError::Invalid(err.to_string()))?; + + let pkce_code_challenge = match provider.pkce_code_challenge { + OauthProviderPkceCodeChallenge::None => None, + OauthProviderPkceCodeChallenge::Plain => Some(PkceCodeChallenge::new_random_plain()), + OauthProviderPkceCodeChallenge::S256 => Some(PkceCodeChallenge::new_random_sha256()), + }; + + if let Some((pkce_code_challenge, _)) = &pkce_code_challenge { + authorization_request = + authorization_request.set_pkce_challenge(pkce_code_challenge.clone()); + } + + if let Some(scopes) = provider.scopes { + authorization_request = + authorization_request.add_scopes(scopes.into_iter().map(Scope::new)); + } + + if let Some(authorization_url_params) = provider.authorization_url_params { + let params = parse(authorization_url_params.trim_start_matches('?').as_bytes()); + + for (name, value) in params { + authorization_request = + authorization_request.add_extra_param(name.into_owned(), value.into_owned()); + } + } + + let (auth_url, csrf_token) = authorization_request.url(); + + { + let session_data = session.data(); + let mut session_data = session_data + .lock() + .map_err(|err| SessionError::Lock(err.to_string()))?; + + session_data.authentication = None; + + session_data.set_method( + OAUTH_METHOD_ID, + OauthSession { + csrf: Some(csrf_token.secret().clone()), + pkce_verifier: pkce_code_challenge + .map(|(_, pkce_code_verifier)| pkce_code_verifier.secret().clone()), + oauth_connection_id: None, + }, + )?; + } + + Ok(Response::Redirect(auth_url.to_string())) + } +} + +erased_action!(OauthSignInAction); diff --git a/packages/methods/shield-oauth/src/actions/sign_in_callback.rs b/packages/methods/shield-oauth/src/actions/sign_in_callback.rs new file mode 100644 index 0000000..0f9ff18 --- /dev/null +++ b/packages/methods/shield-oauth/src/actions/sign_in_callback.rs @@ -0,0 +1,305 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use chrono::{DateTime, Duration, FixedOffset, Utc}; +use oauth2::{ + AuthorizationCode, PkceCodeVerifier, TokenResponse, basic::BasicTokenResponse, + url::form_urlencoded::parse, +}; +use secrecy::SecretString; +use shield::{ + Action, Authentication, ConfigurationError, CreateEmailAddress, CreateUser, Form, Request, + Response, SIGN_IN_CALLBACK_ACTION_ID, Session, SessionError, ShieldError, UpdateUser, User, + erased_action, +}; + +use crate::{ + client::async_http_client, + connection::{CreateOauthConnection, OauthConnection, UpdateOauthConnection}, + method::OAUTH_METHOD_ID, + options::OauthOptions, + provider::{OauthProvider, OauthProviderPkceCodeChallenge}, + session::OauthSession, + storage::OauthStorage, +}; + +pub struct OauthSignInCallbackAction { + options: OauthOptions, + storage: Arc>, +} + +impl OauthSignInCallbackAction { + pub fn new(options: OauthOptions, storage: Arc>) -> Self { + Self { options, storage } + } + + // TODO: Consider if there is a better location for the functions below. + + async fn create_user(&self, email: Option<&str>, name: Option<&str>) -> Result { + if let Some(email) = email { + match self.storage.user_by_email(email).await? { + Some(_) => Err(ShieldError::Validation( + "\ + Email address `{email}` is already used by another account. \ + To link a new provider, sign in to with your exising account first. \ + If this is not your account, please contact support for assistence.\ + " + .to_owned(), + )), + None => Ok(self + .storage + .create_user( + CreateUser { + name: name.map(ToOwned::to_owned), + }, + CreateEmailAddress { + email: email.to_string(), + is_primary: true, + // TODO: from claim? + is_verified: false, + // TODO: generate if not verified + verification_token: None, + verification_token_expired_at: None, + verified_at: None, + }, + ) + .await?), + } + } else { + Err(ShieldError::Validation( + "Missing email address in OpenID Connect claims.".to_owned(), + )) + } + } + + async fn update_user(&self, user_id: &str, name: Option<&str>) -> Result { + self.storage + .update_user(UpdateUser { + id: user_id.to_owned(), + name: name.map(ToOwned::to_owned).map(Some), + }) + .await + .map_err(ShieldError::Storage) + } + + async fn create_oauth_connection( + &self, + provider_id: String, + user_id: String, + identifier: String, + token_response: BasicTokenResponse, + ) -> Result { + let (token_type, access_token, refresh_token, expired_at, scopes) = + parse_token_response(token_response)?; + + self.storage + .create_oauth_connection(CreateOauthConnection { + identifier, + token_type, + access_token, + refresh_token, + expired_at, + scopes, + provider_id, + user_id, + }) + .await + .map_err(ShieldError::Storage) + } + + async fn update_oauth_connection( + &self, + connection_id: String, + token_response: BasicTokenResponse, + ) -> Result { + let (token_type, access_token, refresh_token, expired_at, scopes) = + parse_token_response(token_response)?; + + self.storage + .update_oauth_connection(UpdateOauthConnection { + id: connection_id, + token_type: Some(token_type), + access_token: Some(access_token), + refresh_token: refresh_token.map(Some), + expired_at: expired_at.map(Some), + scopes: scopes.map(Some), + }) + .await + .map_err(ShieldError::Storage) + } +} + +#[async_trait] +impl Action for OauthSignInCallbackAction { + fn id(&self) -> String { + SIGN_IN_CALLBACK_ACTION_ID.to_owned() + } + + fn render(&self, _provider: OauthProvider) -> Form { + Form { + inputs: vec![], + attributes: None, + } + } + + async fn call( + &self, + provider: OauthProvider, + session: Session, + request: Request, + ) -> Result { + let OauthSession { + csrf, + pkce_verifier, + .. + } = { + let session_data = session.data(); + let session_data = session_data + .lock() + .map_err(|err| SessionError::Lock(err.to_string()))?; + + session_data.method(OAUTH_METHOD_ID)? + }; + + let state = request + .query + .get("state") + .and_then(|code| code.as_str()) + .ok_or_else(|| ShieldError::Validation("Missing state.".to_owned()))?; + + if csrf.is_none_or(|csrf| csrf != state) { + return Err(ShieldError::Validation("Invalid state.".to_owned())); + } + + let authorization_code = request + .query + .get("code") + .and_then(|code| code.as_str()) + .ok_or_else(|| ShieldError::Validation("Missing authorization code.".to_owned()))?; + + let client = provider.oauth_client().await?; + + let mut token_request = client + .exchange_code(AuthorizationCode::new(authorization_code.to_owned())) + .map_err(|err| { + ShieldError::Configuration(ConfigurationError::Missing(err.to_string())) + })?; + + if let Some(pkce_verifier) = pkce_verifier { + token_request = token_request.set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier)); + } else if provider.pkce_code_challenge != OauthProviderPkceCodeChallenge::None { + return Err(ShieldError::Validation("Missing PKCE verifier.".to_owned())); + } + + if let Some(token_url_params) = provider.token_url_params { + let params = parse(token_url_params.trim_start_matches('?').as_bytes()); + + for (name, value) in params { + token_request = + token_request.add_extra_param(name.into_owned(), value.into_owned()); + } + } + + let async_http_client = async_http_client()?; + + let token_response = token_request + .request_async(&async_http_client) + .await + .map_err(|err| ShieldError::Request(err.to_string()))?; + + // TODO: user info + let identifier = ""; + let email = Some(""); + let name = Some(""); + + let (connection, user) = match self + .storage + .oauth_connection_by_identifier(&provider.id, identifier) + .await? + { + Some(connection) => { + let connection = self + .update_oauth_connection(connection.id, token_response) + .await?; + + let user = self.update_user(&connection.user_id, name).await?; + + (connection, user) + } + None => { + let user = self.create_user(email, name).await?; + + let connection = self + .create_oauth_connection( + provider.id.clone(), + user.id(), + identifier.to_owned(), + token_response, + ) + .await?; + + (connection, user) + } + }; + + session.renew().await?; + + { + let session_data = session.data(); + let mut session_data = session_data + .lock() + .map_err(|err| SessionError::Lock(err.to_string()))?; + + session_data.authentication = Some(Authentication { + method_id: self.id(), + provider_id: Some(provider.id), + user_id: user.id(), + }); + + session_data.set_method( + OAUTH_METHOD_ID, + OauthSession { + csrf: None, + pkce_verifier: None, + oauth_connection_id: Some(connection.id), + }, + )?; + } + + Ok(Response::Redirect(self.options.sign_in_redirect.clone())) + } +} + +erased_action!(OauthSignInCallbackAction, ); + +type ParsedTokenResponse = ( + String, + SecretString, + Option, + Option>, + Option>, +); + +fn parse_token_response( + token_response: BasicTokenResponse, +) -> Result { + Ok(( + token_response.token_type().as_ref().to_string(), + token_response.access_token().secret().as_str().into(), + token_response + .refresh_token() + .map(|refresh_token| refresh_token.secret().as_str().into()), + match token_response.expires_in() { + Some(expires_in) => Some( + (Utc::now() + + Duration::from_std(expires_in) + .map_err(|err| ShieldError::Validation(err.to_string()))?) + .into(), + ), + None => None, + }, + token_response + .scopes() + .map(|scopes| scopes.iter().map(|scope| scope.to_string()).collect()), + )) +} diff --git a/packages/methods/shield-oauth/src/actions/sign_out.rs b/packages/methods/shield-oauth/src/actions/sign_out.rs new file mode 100644 index 0000000..40b4d4c --- /dev/null +++ b/packages/methods/shield-oauth/src/actions/sign_out.rs @@ -0,0 +1,35 @@ +use async_trait::async_trait; +use shield::{ + Action, Form, Request, Response, SIGN_OUT_ACTION_ID, Session, ShieldError, erased_action, +}; + +use crate::provider::OauthProvider; + +pub struct OauthSignOutAction; + +#[async_trait] +impl Action for OauthSignOutAction { + fn id(&self) -> String { + SIGN_OUT_ACTION_ID.to_owned() + } + + fn render(&self, _provider: OauthProvider) -> Form { + Form { + inputs: vec![], + attributes: None, + } + } + + async fn call( + &self, + _provider: OauthProvider, + _session: Session, + _request: Request, + ) -> Result { + // TODO: OAuth token revocation. + + Ok(Response::Default) + } +} + +erased_action!(OauthSignOutAction); diff --git a/packages/methods/shield-oauth/src/connection.rs b/packages/methods/shield-oauth/src/connection.rs index e34d89c..f217faf 100644 --- a/packages/methods/shield-oauth/src/connection.rs +++ b/packages/methods/shield-oauth/src/connection.rs @@ -1,12 +1,13 @@ use chrono::{DateTime, FixedOffset}; +use secrecy::SecretString; #[derive(Clone, Debug)] pub struct OauthConnection { pub id: String, pub identifier: String, pub token_type: String, - pub access_token: String, - pub refresh_token: Option, + pub access_token: SecretString, + pub refresh_token: Option, pub expired_at: Option>, pub scopes: Option>, pub provider_id: String, @@ -17,8 +18,8 @@ pub struct OauthConnection { pub struct CreateOauthConnection { pub identifier: String, pub token_type: String, - pub access_token: String, - pub refresh_token: Option, + pub access_token: SecretString, + pub refresh_token: Option, pub expired_at: Option>, pub scopes: Option>, pub provider_id: String, @@ -29,8 +30,8 @@ pub struct CreateOauthConnection { pub struct UpdateOauthConnection { pub id: String, pub token_type: Option, - pub access_token: Option, - pub refresh_token: Option>, + pub access_token: Option, + pub refresh_token: Option>, pub expired_at: Option>>, pub scopes: Option>>, } diff --git a/packages/methods/shield-oauth/src/lib.rs b/packages/methods/shield-oauth/src/lib.rs index 4928900..5f97baf 100644 --- a/packages/methods/shield-oauth/src/lib.rs +++ b/packages/methods/shield-oauth/src/lib.rs @@ -1,11 +1,14 @@ +mod actions; mod client; mod connection; mod method; +mod options; mod provider; mod session; mod storage; pub use connection::*; pub use method::*; +pub use options::*; pub use provider::*; pub use storage::*; diff --git a/packages/methods/shield-oauth/src/method.rs b/packages/methods/shield-oauth/src/method.rs index c81d1cc..d84927c 100644 --- a/packages/methods/shield-oauth/src/method.rs +++ b/packages/methods/shield-oauth/src/method.rs @@ -1,38 +1,37 @@ +use std::sync::Arc; + use async_trait::async_trait; -use chrono::{DateTime, Duration, FixedOffset, Utc}; -use oauth2::{ - AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse, - basic::BasicTokenResponse, url::form_urlencoded::parse, -}; -use shield::{ - Authentication, ConfigurationError, CreateEmailAddress, CreateUser, Method, Provider, - ProviderError, Response, Session, SessionError, ShieldError, ShieldOptions, - SignInCallbackRequest, SignInRequest, SignOutRequest, UpdateUser, User, -}; +use shield::{Action, Method, ShieldError, User, erased_method}; use crate::{ - CreateOauthConnection, OauthConnection, UpdateOauthConnection, - client::async_http_client, - provider::{OauthProvider, OauthProviderPkceCodeChallenge}, - session::OauthSession, + actions::{OauthSignInAction, OauthSignInCallbackAction, OauthSignOutAction}, + options::OauthOptions, + provider::OauthProvider, storage::OauthStorage, }; pub const OAUTH_METHOD_ID: &str = "oauth"; pub struct OauthMethod { + options: OauthOptions, providers: Vec, - storage: Box>, + storage: Arc>, } impl OauthMethod { pub fn new + 'static>(storage: S) -> Self { Self { + options: OauthOptions::default(), providers: vec![], - storage: Box::new(storage), + storage: Arc::new(storage), } } + pub fn with_options(mut self, options: OauthOptions) -> Self { + self.options = options; + self + } + pub fn with_providers>(mut self, providers: I) -> Self { self.providers = providers.into_iter().collect(); self @@ -41,13 +40,13 @@ impl OauthMethod { async fn oauth_provider_by_id_or_slug( &self, provider_id: &str, - ) -> Result { + ) -> Result, ShieldError> { if let Some(provider) = self .providers .iter() .find(|provider| provider.id == provider_id) { - return Ok(provider.clone()); + return Ok(Some(provider.clone())); } if let Some(provider) = self @@ -55,382 +54,49 @@ impl OauthMethod { .oauth_provider_by_id_or_slug(provider_id) .await? { - return Ok(provider); - } - - Err(ProviderError::ProviderNotFound(provider_id.to_owned()).into()) - } - - async fn create_user(&self, email: Option<&str>, name: Option<&str>) -> Result { - if let Some(email) = email { - match self.storage.user_by_email(email).await? { - Some(_) => Err(ShieldError::Validation( - "\ - Email address `{email}` is already used by another account. \ - To link a new provider, sign in to with your exising account first. \ - If this is not your account, please contact support for assistence.\ - " - .to_owned(), - )), - None => Ok(self - .storage - .create_user( - CreateUser { - name: name.map(ToOwned::to_owned), - }, - CreateEmailAddress { - email: email.to_string(), - is_primary: true, - // TODO: from claim? - is_verified: false, - // TODO: generate if not verified - verification_token: None, - verification_token_expired_at: None, - verified_at: None, - }, - ) - .await?), - } - } else { - Err(ShieldError::Validation( - "Missing email address in OpenID Connect claims.".to_owned(), - )) + return Ok(Some(provider)); } - } - - async fn update_user(&self, user_id: &str, name: Option<&str>) -> Result { - self.storage - .update_user(UpdateUser { - id: user_id.to_owned(), - name: name.map(ToOwned::to_owned).map(Some), - }) - .await - .map_err(ShieldError::Storage) - } - async fn create_oauth_connection( - &self, - provider_id: String, - user_id: String, - identifier: String, - token_response: BasicTokenResponse, - ) -> Result { - let (token_type, access_token, refresh_token, expired_at, scopes) = - parse_token_response(token_response)?; - - self.storage - .create_oauth_connection(CreateOauthConnection { - identifier, - token_type, - access_token, - refresh_token, - expired_at, - scopes, - provider_id, - user_id, - }) - .await - .map_err(ShieldError::Storage) - } - - async fn update_oauth_connection( - &self, - connection_id: String, - token_response: BasicTokenResponse, - ) -> Result { - let (token_type, access_token, refresh_token, expired_at, scopes) = - parse_token_response(token_response)?; - - self.storage - .update_oauth_connection(UpdateOauthConnection { - id: connection_id, - token_type: Some(token_type), - access_token: Some(access_token), - refresh_token: refresh_token.map(Some), - expired_at: expired_at.map(Some), - scopes: scopes.map(Some), - }) - .await - .map_err(ShieldError::Storage) + Ok(None) } } #[async_trait] -impl Method for OauthMethod { +impl Method for OauthMethod { fn id(&self) -> String { OAUTH_METHOD_ID.to_owned() } - async fn providers(&self) -> Result>, ShieldError> { - let providers = self + fn actions(&self) -> Vec>> { + vec![ + Box::new(OauthSignInAction), + Box::new(OauthSignInCallbackAction::new( + self.options.clone(), + self.storage.clone(), + )), + Box::new(OauthSignOutAction), + ] + } + + async fn providers(&self) -> Result, ShieldError> { + Ok(self .providers .iter() .cloned() - .chain(self.storage.oauth_providers().await?); - - Ok(providers - .map(|provider| Box::new(provider) as Box) + .chain(self.storage.oauth_providers().await?) .collect()) } async fn provider_by_id( &self, - provider_id: &str, - ) -> Result>, ShieldError> { - self.oauth_provider_by_id_or_slug(provider_id) - .await - .map(|provider| Some(Box::new(provider) as Box)) - } - - async fn sign_in( - &self, - request: SignInRequest, - session: Session, - _options: &ShieldOptions, - ) -> Result { - let provider = match request.provider_id { - Some(provider_id) => self.oauth_provider_by_id_or_slug(&provider_id).await?, - None => return Err(ProviderError::ProviderMissing.into()), - }; - - let client = provider.oauth_client().await?; - - let mut authorization_request = client - .authorize_url(CsrfToken::new_random) - .map_err(|err| ConfigurationError::Invalid(err.to_string()))?; - - let pkce_code_challenge = match provider.pkce_code_challenge { - OauthProviderPkceCodeChallenge::None => None, - OauthProviderPkceCodeChallenge::Plain => Some(PkceCodeChallenge::new_random_plain()), - OauthProviderPkceCodeChallenge::S256 => Some(PkceCodeChallenge::new_random_sha256()), - }; - - if let Some((pkce_code_challenge, _)) = &pkce_code_challenge { - authorization_request = - authorization_request.set_pkce_challenge(pkce_code_challenge.clone()); - } - - if let Some(scopes) = provider.scopes { - authorization_request = - authorization_request.add_scopes(scopes.into_iter().map(Scope::new)); - } - - if let Some(authorization_url_params) = provider.authorization_url_params { - let params = parse(authorization_url_params.trim_start_matches('?').as_bytes()); - - for (name, value) in params { - authorization_request = - authorization_request.add_extra_param(name.into_owned(), value.into_owned()); - } - } - - let (auth_url, csrf_token) = authorization_request.url(); - - { - let session_data = session.data(); - let mut session_data = session_data - .lock() - .map_err(|err| SessionError::Lock(err.to_string()))?; - - session_data.authentication = None; - - session_data.set_method( - OAUTH_METHOD_ID, - OauthSession { - csrf: Some(csrf_token.secret().clone()), - pkce_verifier: pkce_code_challenge - .map(|(_, pkce_code_verifier)| pkce_code_verifier.secret().clone()), - oauth_connection_id: None, - }, - )?; - } - - Ok(Response::Redirect(auth_url.to_string())) - } - - async fn sign_in_callback( - &self, - request: SignInCallbackRequest, - session: Session, - options: &ShieldOptions, - ) -> Result { - let OauthSession { - csrf, - pkce_verifier, - .. - } = { - let session_data = session.data(); - let session_data = session_data - .lock() - .map_err(|err| SessionError::Lock(err.to_string()))?; - - session_data.method(OAUTH_METHOD_ID)? - }; - - let state = request - .query - .as_ref() - .and_then(|query| query.get("state")) - .and_then(|code| code.as_str()) - .ok_or_else(|| ShieldError::Validation("Missing state.".to_owned()))?; - - if csrf.is_none_or(|csrf| csrf != state) { - return Err(ShieldError::Validation("Invalid state.".to_owned())); - } - - let authorization_code = request - .query - .as_ref() - .and_then(|query| query.get("code")) - .and_then(|code| code.as_str()) - .ok_or_else(|| ShieldError::Validation("Missing authorization code.".to_owned()))?; - - let provider = match request.provider_id { - Some(provider_id) => self.oauth_provider_by_id_or_slug(&provider_id).await?, - None => return Err(ProviderError::ProviderMissing.into()), - }; - - let client = provider.oauth_client().await?; - - let mut token_request = client - .exchange_code(AuthorizationCode::new(authorization_code.to_owned())) - .map_err(|err| { - ShieldError::Configuration(ConfigurationError::Missing(err.to_string())) - })?; - - if let Some(pkce_verifier) = pkce_verifier { - token_request = token_request.set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier)); - } else if provider.pkce_code_challenge != OauthProviderPkceCodeChallenge::None { - return Err(ShieldError::Validation("Missing PKCE verifier.".to_owned())); - } - - if let Some(token_url_params) = provider.token_url_params { - let params = parse(token_url_params.trim_start_matches('?').as_bytes()); - - for (name, value) in params { - token_request = - token_request.add_extra_param(name.into_owned(), value.into_owned()); - } - } - - let async_http_client = async_http_client()?; - - let token_response = token_request - .request_async(&async_http_client) - .await - .map_err(|err| ShieldError::Request(err.to_string()))?; - - // TODO: user info - let identifier = ""; - let email = Some(""); - let name = Some(""); - - let (connection, user) = match self - .storage - .oauth_connection_by_identifier(&provider.id, identifier) - .await? - { - Some(connection) => { - let connection = self - .update_oauth_connection(connection.id, token_response) - .await?; - - let user = self.update_user(&connection.user_id, name).await?; - - (connection, user) - } - None => { - let user = self.create_user(email, name).await?; - - let connection = self - .create_oauth_connection( - provider.id.clone(), - user.id(), - identifier.to_owned(), - token_response, - ) - .await?; - - (connection, user) - } - }; - - session.renew().await?; - - { - let session_data = session.data(); - let mut session_data = session_data - .lock() - .map_err(|err| SessionError::Lock(err.to_string()))?; - - session_data.authentication = Some(Authentication { - method_id: self.id(), - provider_id: Some(provider.id), - user_id: user.id(), - }); - - session_data.set_method( - OAUTH_METHOD_ID, - OauthSession { - csrf: None, - pkce_verifier: None, - oauth_connection_id: Some(connection.id), - }, - )?; + provider_id: Option<&str>, + ) -> Result, ShieldError> { + if let Some(provider_id) = provider_id { + self.oauth_provider_by_id_or_slug(provider_id).await + } else { + Ok(None) } - - Ok(Response::Redirect( - request - .redirect_url - .unwrap_or(options.sign_in_redirect.clone()), - )) - } - - async fn sign_out( - &self, - request: SignOutRequest, - _session: Session, - _options: &ShieldOptions, - ) -> Result, ShieldError> { - let _provider = match request.provider_id { - Some(provider_id) => self.oauth_provider_by_id_or_slug(&provider_id).await?, - None => return Err(ProviderError::ProviderMissing.into()), - }; - - // TODO: OAuth token revocation. - - Ok(None) } } -type ParsedTokenResponse = ( - String, - String, - Option, - Option>, - Option>, -); - -fn parse_token_response( - token_response: BasicTokenResponse, -) -> Result { - Ok(( - token_response.token_type().as_ref().to_string(), - token_response.access_token().secret().clone(), - token_response - .refresh_token() - .map(|refresh_token| refresh_token.secret().clone()), - match token_response.expires_in() { - Some(expires_in) => Some( - (Utc::now() - + Duration::from_std(expires_in) - .map_err(|err| ShieldError::Validation(err.to_string()))?) - .into(), - ), - None => None, - }, - token_response - .scopes() - .map(|scopes| scopes.iter().map(|scope| scope.to_string()).collect()), - )) -} +erased_method!(OauthMethod, ); diff --git a/packages/methods/shield-oauth/src/options.rs b/packages/methods/shield-oauth/src/options.rs new file mode 100644 index 0000000..3bc43e3 --- /dev/null +++ b/packages/methods/shield-oauth/src/options.rs @@ -0,0 +1,14 @@ +use bon::Builder; + +#[derive(Builder, Clone, Debug)] +#[builder(on(String, into), state_mod(vis = "pub(crate)"))] +pub struct OauthOptions { + #[builder(default = "/")] + pub sign_in_redirect: String, +} + +impl Default for OauthOptions { + fn default() -> Self { + Self::builder().build() + } +} diff --git a/packages/methods/shield-oauth/src/provider.rs b/packages/methods/shield-oauth/src/provider.rs index 71b5280..8e6311e 100644 --- a/packages/methods/shield-oauth/src/provider.rs +++ b/packages/methods/shield-oauth/src/provider.rs @@ -6,7 +6,8 @@ use oauth2::{ BasicTokenIntrospectionResponse, BasicTokenResponse, }, }; -use shield::{ConfigurationError, Form, Provider}; +use secrecy::{ExposeSecret, SecretString}; +use shield::{ConfigurationError, Provider}; use crate::method::OAUTH_METHOD_ID; @@ -43,7 +44,7 @@ pub struct OauthProvider { pub slug: Option, pub visibility: OauthProviderVisibility, pub client_id: String, - pub client_secret: Option, + pub client_secret: Option, pub scopes: Option>, pub redirect_url: Option, pub authorization_url: Option, @@ -63,7 +64,8 @@ impl OauthProvider { let mut client = BasicClient::new(ClientId::new(self.client_id.clone())); if let Some(client_secret) = &self.client_secret { - client = client.set_client_secret(ClientSecret::new(client_secret.clone())); + client = client + .set_client_secret(ClientSecret::new(client_secret.expose_secret().to_owned())); } if let Some(redirect_url) = &self.redirect_url { @@ -129,12 +131,4 @@ impl Provider for OauthProvider { fn name(&self) -> String { self.name.clone() } - - fn icon_url(&self) -> Option { - self.icon_url.clone() - } - - fn form(&self) -> Option { - None - } } diff --git a/packages/methods/shield-oidc/Cargo.toml b/packages/methods/shield-oidc/Cargo.toml index 97b199d..f9c3295 100644 --- a/packages/methods/shield-oidc/Cargo.toml +++ b/packages/methods/shield-oidc/Cargo.toml @@ -26,6 +26,7 @@ oauth2 = { version = "5.0.0", default-features = false, features = [ openidconnect = { version = "4.0.0", default-features = false, features = [ "reqwest", ] } +secrecy.workspace = true serde.workspace = true shield.workspace = true tracing.workspace = true diff --git a/packages/methods/shield-oidc/src/actions.rs b/packages/methods/shield-oidc/src/actions.rs new file mode 100644 index 0000000..47d587d --- /dev/null +++ b/packages/methods/shield-oidc/src/actions.rs @@ -0,0 +1,7 @@ +mod sign_in; +mod sign_in_callback; +mod sign_out; + +pub use sign_in::*; +pub use sign_in_callback::*; +pub use sign_out::*; diff --git a/packages/methods/shield-oidc/src/actions/sign_in.rs b/packages/methods/shield-oidc/src/actions/sign_in.rs new file mode 100644 index 0000000..d3b1602 --- /dev/null +++ b/packages/methods/shield-oidc/src/actions/sign_in.rs @@ -0,0 +1,99 @@ +use async_trait::async_trait; +use openidconnect::{ + CsrfToken, Nonce, PkceCodeChallenge, Scope, core::CoreAuthenticationFlow, + url::form_urlencoded::parse, +}; +use shield::{ + Action, Form, Request, Response, SIGN_IN_ACTION_ID, Session, SessionError, ShieldError, + erased_action, +}; + +use crate::{ + method::OIDC_METHOD_ID, + provider::{OidcProvider, OidcProviderPkceCodeChallenge}, + session::OidcSession, +}; + +pub struct OidcSignInAction; + +#[async_trait] +impl Action for OidcSignInAction { + fn id(&self) -> String { + SIGN_IN_ACTION_ID.to_owned() + } + + fn render(&self, _provider: OidcProvider) -> Form { + Form { + inputs: vec![], + attributes: None, + } + } + + async fn call( + &self, + provider: OidcProvider, + session: Session, + _request: Request, + ) -> Result { + let client = provider.oidc_client().await?; + + let mut authorization_request = client.authorize_url( + CoreAuthenticationFlow::AuthorizationCode, + CsrfToken::new_random, + Nonce::new_random, + ); + + let pkce_code_challenge = match provider.pkce_code_challenge { + OidcProviderPkceCodeChallenge::None => None, + OidcProviderPkceCodeChallenge::Plain => Some(PkceCodeChallenge::new_random_plain()), + OidcProviderPkceCodeChallenge::S256 => Some(PkceCodeChallenge::new_random_sha256()), + }; + + if let Some((pkce_code_challenge, _)) = &pkce_code_challenge { + authorization_request = + authorization_request.set_pkce_challenge(pkce_code_challenge.clone()); + } + + if let Some(scopes) = provider.scopes { + authorization_request = + authorization_request.add_scopes(scopes.into_iter().map(Scope::new)); + } + + if let Some(authorization_url_params) = provider.authorization_url_params { + let params = parse(authorization_url_params.trim_start_matches('?').as_bytes()); + + for (name, value) in params { + authorization_request = + authorization_request.add_extra_param(name.into_owned(), value.into_owned()); + } + } + + let (auth_url, csrf_token, nonce) = authorization_request.url(); + + { + // TODO: Add a generic type for session data to actions, so the action caller can be read/write the session. + + let session_data = session.data(); + let mut session_data = session_data + .lock() + .map_err(|err| SessionError::Lock(err.to_string()))?; + + session_data.authentication = None; + + session_data.set_method( + OIDC_METHOD_ID, + OidcSession { + csrf: Some(csrf_token.secret().clone()), + nonce: Some(nonce.secret().clone()), + pkce_verifier: pkce_code_challenge + .map(|(_, pkce_code_verifier)| pkce_code_verifier.secret().clone()), + oidc_connection_id: None, + }, + )?; + } + + Ok(Response::Redirect(auth_url.to_string())) + } +} + +erased_action!(OidcSignInAction); diff --git a/packages/methods/shield-oidc/src/actions/sign_in_callback.rs b/packages/methods/shield-oidc/src/actions/sign_in_callback.rs new file mode 100644 index 0000000..2f6fff6 --- /dev/null +++ b/packages/methods/shield-oidc/src/actions/sign_in_callback.rs @@ -0,0 +1,342 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use chrono::{DateTime, Duration, FixedOffset, Utc}; +use openidconnect::{ + AuthorizationCode, EmptyAdditionalClaims, Nonce, OAuth2TokenResponse, PkceCodeVerifier, + TokenResponse, UserInfoClaims, + core::{CoreGenderClaim, CoreTokenResponse}, + url::form_urlencoded::parse, +}; +use secrecy::SecretString; +use shield::{ + Action, Authentication, ConfigurationError, CreateEmailAddress, CreateUser, Form, Request, + Response, SIGN_IN_CALLBACK_ACTION_ID, Session, SessionError, ShieldError, UpdateUser, User, + erased_action, +}; +use tracing::debug; + +use crate::{ + claims::Claims, + client::async_http_client, + connection::{CreateOidcConnection, OidcConnection, UpdateOidcConnection}, + method::OIDC_METHOD_ID, + options::OidcOptions, + provider::{OidcProvider, OidcProviderPkceCodeChallenge}, + session::OidcSession, + storage::OidcStorage, +}; + +pub struct OidcSignInCallbackAction { + options: OidcOptions, + storage: Arc>, +} + +impl OidcSignInCallbackAction { + pub fn new(options: OidcOptions, storage: Arc>) -> Self { + Self { options, storage } + } + + // TODO: Consider if there is a better location for the functions below. + + async fn create_user(&self, claims: &Claims) -> Result { + if let Some(email) = claims.email() { + match self.storage.user_by_email(email).await? { + Some(_) => Err(ShieldError::Validation( + "\ + Email address `{email}` is already used by another account. \ + To link a new provider, sign in to with your exising account first. \ + If this is not your account, please contact support for assistence.\ + " + .to_owned(), + )), + None => Ok(self + .storage + .create_user( + CreateUser { + name: claims + .name() + .and_then(|name| name.get(None).map(|name| name.to_string())), + }, + CreateEmailAddress { + email: email.to_string(), + is_primary: true, + // TODO: from claim? + is_verified: false, + // TODO: generate if not verified + verification_token: None, + verification_token_expired_at: None, + verified_at: None, + }, + ) + .await?), + } + } else { + Err(ShieldError::Validation( + "Missing email address in OpenID Connect claims.".to_owned(), + )) + } + } + + async fn update_user(&self, user_id: &str, claims: &Claims) -> Result { + self.storage + .update_user(UpdateUser { + id: user_id.to_owned(), + name: claims + .name() + .and_then(|name| name.get(None).map(|name| name.to_string())) + .map(Some), + }) + .await + .map_err(ShieldError::Storage) + } + + async fn create_oidc_connection( + &self, + provider_id: String, + user_id: String, + identifier: String, + token_response: CoreTokenResponse, + ) -> Result { + let (token_type, access_token, refresh_token, id_token, expired_at, scopes) = + parse_token_response(token_response)?; + + self.storage + .create_oidc_connection(CreateOidcConnection { + identifier, + token_type, + access_token, + refresh_token, + id_token, + expired_at, + scopes, + provider_id, + user_id, + }) + .await + .map_err(ShieldError::Storage) + } + + async fn update_oidc_connection( + &self, + connection_id: String, + token_response: CoreTokenResponse, + ) -> Result { + let (token_type, access_token, refresh_token, id_token, expired_at, scopes) = + parse_token_response(token_response)?; + + self.storage + .update_oidc_connection(UpdateOidcConnection { + id: connection_id, + token_type: Some(token_type), + access_token: Some(access_token), + refresh_token: refresh_token.map(Some), + id_token: id_token.map(Some), + expired_at: expired_at.map(Some), + scopes: scopes.map(Some), + }) + .await + .map_err(ShieldError::Storage) + } +} + +#[async_trait] +impl Action for OidcSignInCallbackAction { + fn id(&self) -> String { + SIGN_IN_CALLBACK_ACTION_ID.to_owned() + } + + fn render(&self, _provider: OidcProvider) -> Form { + Form { + inputs: vec![], + attributes: None, + } + } + + async fn call( + &self, + provider: OidcProvider, + session: Session, + request: Request, + ) -> Result { + let OidcSession { + csrf, + nonce, + pkce_verifier, + .. + } = { + let session_data = session.data(); + let session_data = session_data + .lock() + .map_err(|err| SessionError::Lock(err.to_string()))?; + + session_data.method(OIDC_METHOD_ID)? + }; + + let state = request + .query + .get("state") + .and_then(|code| code.as_str()) + .ok_or_else(|| ShieldError::Validation("Missing state.".to_owned()))?; + + if csrf.is_none_or(|csrf| csrf != state) { + return Err(ShieldError::Validation("Invalid state.".to_owned())); + } + + let authorization_code = request + .query + .get("code") + .and_then(|code| code.as_str()) + .ok_or_else(|| ShieldError::Validation("Missing authorization code.".to_owned()))?; + + let client = provider.oidc_client().await?; + + let mut token_request = client + .exchange_code(AuthorizationCode::new(authorization_code.to_owned())) + .map_err(|err| { + ShieldError::Configuration(ConfigurationError::Missing(err.to_string())) + })?; + + if let Some(pkce_verifier) = pkce_verifier { + token_request = token_request.set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier)); + } else if provider.pkce_code_challenge != OidcProviderPkceCodeChallenge::None { + return Err(ShieldError::Validation("Missing PKCE verifier.".to_owned())); + } + + if let Some(token_url_params) = provider.token_url_params { + let params = parse(token_url_params.trim_start_matches('?').as_bytes()); + + for (name, value) in params { + token_request = + token_request.add_extra_param(name.into_owned(), value.into_owned()); + } + } + + let async_http_client = async_http_client()?; + + let token_response = token_request + .request_async(&async_http_client) + .await + .map_err(|err| ShieldError::Request(err.to_string()))?; + + let claims = if let Some(id_token) = token_response.id_token() { + let claims = id_token + .claims( + &client.id_token_verifier(), + &Nonce::new( + nonce + .ok_or_else(|| ShieldError::Validation("Missing nonce.".to_owned()))?, + ), + ) + .map_err(|err| ShieldError::Validation(err.to_string()))?; + + Claims::from(claims.clone()) + } else { + let claims: UserInfoClaims = client + .user_info(token_response.access_token().to_owned(), None) + .map_err(|err| ConfigurationError::Missing(err.to_string()))? + .request_async(&async_http_client) + .await + .map_err(|err| ShieldError::Request(err.to_string()))?; + + Claims::from(claims) + }; + + debug!("{:?}\n{:?}", claims.subject(), claims); + + let (connection, user) = match self + .storage + .oidc_connection_by_identifier(&provider.id, claims.subject()) + .await? + { + Some(connection) => { + let connection = self + .update_oidc_connection(connection.id, token_response) + .await?; + + let user = self.update_user(&connection.user_id, &claims).await?; + + (connection, user) + } + None => { + let user = self.create_user(&claims).await?; + + let connection = self + .create_oidc_connection( + provider.id.clone(), + user.id(), + claims.subject().to_string(), + token_response, + ) + .await?; + + (connection, user) + } + }; + + session.renew().await?; + + { + let session_data = session.data(); + let mut session_data = session_data + .lock() + .map_err(|err| SessionError::Lock(err.to_string()))?; + + session_data.authentication = Some(Authentication { + method_id: self.id(), + provider_id: Some(provider.id), + user_id: user.id(), + }); + + session_data.set_method( + OIDC_METHOD_ID, + OidcSession { + csrf: None, + nonce: None, + pkce_verifier: None, + oidc_connection_id: Some(connection.id), + }, + )?; + } + + Ok(Response::Redirect(self.options.sign_in_redirect.clone())) + } +} + +erased_action!(OidcSignInCallbackAction, ); + +type ParsedTokenResponse = ( + String, + SecretString, + Option, + Option, + Option>, + Option>, +); + +fn parse_token_response( + token_response: CoreTokenResponse, +) -> Result { + Ok(( + token_response.token_type().as_ref().to_string(), + token_response.access_token().secret().as_str().into(), + token_response + .refresh_token() + .map(|refresh_token| refresh_token.secret().as_str().into()), + token_response + .id_token() + .map(|id_token| id_token.to_string().into()), + match token_response.expires_in() { + Some(expires_in) => Some( + (Utc::now() + + Duration::from_std(expires_in) + .map_err(|err| ShieldError::Validation(err.to_string()))?) + .into(), + ), + None => None, + }, + token_response + .scopes() + .map(|scopes| scopes.iter().map(|scope| scope.to_string()).collect()), + )) +} diff --git a/packages/methods/shield-oidc/src/actions/sign_out.rs b/packages/methods/shield-oidc/src/actions/sign_out.rs new file mode 100644 index 0000000..3c1adde --- /dev/null +++ b/packages/methods/shield-oidc/src/actions/sign_out.rs @@ -0,0 +1,84 @@ +use async_trait::async_trait; +use shield::{ + Action, Form, Request, Response, SIGN_OUT_ACTION_ID, Session, ShieldError, erased_action, +}; + +use crate::provider::OidcProvider; + +pub struct OidcSignOutAction; + +#[async_trait] +impl Action for OidcSignOutAction { + fn id(&self) -> String { + SIGN_OUT_ACTION_ID.to_owned() + } + + fn render(&self, _provider: OidcProvider) -> Form { + Form { + inputs: vec![], + attributes: None, + } + } + + async fn call( + &self, + _provider: OidcProvider, + _session: Session, + _request: Request, + ) -> Result { + // TODO: See [`OidcProvider::oidc_client`]. + + // let provider = match request.provider_id { + // Some(provider_id) => self.oidc_provider_by_id_or_slug(&provider_id).await?, + // None => return Err(ProviderError::ProviderMissing.into()), + // }; + + // let connection_id = { + // let session_data = session.data(); + // let session_data = session_data + // .lock() + // .map_err(|err| SessionError::Lock(err.to_string()))?; + + // session_data.oidc_connection_id.clone() + // }; + + // if let Some(connection_id) = connection_id { + // if let Some(connection) = self.storage.oidc_connection_by_id(&connection_id).await? { + // debug!("revoking access token {:?}", connection.access_token); + + // let token = AccessToken::new(connection.access_token); + + // let client = subprovider.oidc_client().await?; + + // let revocation_request = match client.revoke_token(token.into()) { + // Ok(revocation_request) => Some(revocation_request), + // Err(openidconnect::ConfigurationError::MissingUrl("revocation")) => None, + // Err(err) => return Err(ConfigurationError::Invalid(err.to_string()).into()), + // }; + + // if let Some(revocation_request) = revocation_request { + // let mut revocation_request = revocation_request; + + // if let Some(revocation_url_params) = subprovider.revocation_url_params { + // let params = + // parse(revocation_url_params.trim_start_matches('?').as_bytes()); + + // for (name, value) in params { + // revocation_request = revocation_request + // .add_extra_param(name.into_owned(), value.into_owned()); + // } + // } + + // revocation_request + // .request_async(async_http_client) + // .await + // .map_err(|err| ShieldError::Request(err.to_string()))?; + // } + // } + // } + + Ok(Response::Default) + } +} + +erased_action!(OidcSignOutAction); diff --git a/packages/methods/shield-oidc/src/connection.rs b/packages/methods/shield-oidc/src/connection.rs index 271aa20..c8727a5 100644 --- a/packages/methods/shield-oidc/src/connection.rs +++ b/packages/methods/shield-oidc/src/connection.rs @@ -1,13 +1,14 @@ use chrono::{DateTime, FixedOffset}; +use secrecy::SecretString; #[derive(Clone, Debug)] pub struct OidcConnection { pub id: String, pub identifier: String, pub token_type: String, - pub access_token: String, - pub refresh_token: Option, - pub id_token: Option, + pub access_token: SecretString, + pub refresh_token: Option, + pub id_token: Option, pub expired_at: Option>, pub scopes: Option>, pub provider_id: String, @@ -18,9 +19,9 @@ pub struct OidcConnection { pub struct CreateOidcConnection { pub identifier: String, pub token_type: String, - pub access_token: String, - pub refresh_token: Option, - pub id_token: Option, + pub access_token: SecretString, + pub refresh_token: Option, + pub id_token: Option, pub expired_at: Option>, pub scopes: Option>, pub provider_id: String, @@ -31,9 +32,9 @@ pub struct CreateOidcConnection { pub struct UpdateOidcConnection { pub id: String, pub token_type: Option, - pub access_token: Option, - pub refresh_token: Option>, - pub id_token: Option>, + pub access_token: Option, + pub refresh_token: Option>, + pub id_token: Option>, pub expired_at: Option>>, pub scopes: Option>>, } diff --git a/packages/methods/shield-oidc/src/lib.rs b/packages/methods/shield-oidc/src/lib.rs index b3ce963..98cc550 100644 --- a/packages/methods/shield-oidc/src/lib.rs +++ b/packages/methods/shield-oidc/src/lib.rs @@ -1,9 +1,11 @@ +mod actions; mod builders; mod claims; mod client; mod connection; mod metadata; mod method; +mod options; mod provider; mod session; mod storage; @@ -11,5 +13,6 @@ mod storage; pub use builders::*; pub use connection::*; pub use method::*; +pub use options::*; pub use provider::*; pub use storage::*; diff --git a/packages/methods/shield-oidc/src/method.rs b/packages/methods/shield-oidc/src/method.rs index da33d23..eb3da79 100644 --- a/packages/methods/shield-oidc/src/method.rs +++ b/packages/methods/shield-oidc/src/method.rs @@ -1,39 +1,37 @@ +use std::sync::Arc; + use async_trait::async_trait; -use chrono::{DateTime, Duration, FixedOffset, Utc}; -use openidconnect::{ - AuthorizationCode, CsrfToken, EmptyAdditionalClaims, Nonce, OAuth2TokenResponse, - PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse, UserInfoClaims, - core::{CoreAuthenticationFlow, CoreGenderClaim, CoreTokenResponse}, - url::form_urlencoded::parse, -}; -use shield::{ - Authentication, ConfigurationError, CreateEmailAddress, CreateUser, Method, Provider, - ProviderError, Response, Session, SessionError, ShieldError, ShieldOptions, - SignInCallbackRequest, SignInRequest, SignOutRequest, UpdateUser, User, -}; -use tracing::debug; +use shield::{Action, Method, ShieldError, User, erased_method}; use crate::{ - CreateOidcConnection, OidcConnection, OidcProviderPkceCodeChallenge, UpdateOidcConnection, - claims::Claims, client::async_http_client, provider::OidcProvider, session::OidcSession, + actions::{OidcSignInAction, OidcSignInCallbackAction, OidcSignOutAction}, + options::OidcOptions, + provider::OidcProvider, storage::OidcStorage, }; pub const OIDC_METHOD_ID: &str = "oidc"; pub struct OidcMethod { + options: OidcOptions, providers: Vec, - storage: Box>, + storage: Arc>, } impl OidcMethod { pub fn new + 'static>(storage: S) -> Self { Self { + options: OidcOptions::default(), providers: vec![], - storage: Box::new(storage), + storage: Arc::new(storage), } } + pub fn with_options(mut self, options: OidcOptions) -> Self { + self.options = options; + self + } + pub fn with_providers>(mut self, providers: I) -> Self { self.providers = providers.into_iter().collect(); self @@ -42,13 +40,13 @@ impl OidcMethod { async fn oidc_provider_by_id_or_slug( &self, provider_id: &str, - ) -> Result { + ) -> Result, ShieldError> { if let Some(provider) = self .providers .iter() .find(|provider| provider.id == provider_id) { - return Ok(provider.clone()); + return Ok(Some(provider.clone())); } if let Some(provider) = self @@ -56,462 +54,49 @@ impl OidcMethod { .oidc_provider_by_id_or_slug(provider_id) .await? { - return Ok(provider); + return Ok(Some(provider)); } - Err(ProviderError::ProviderNotFound(provider_id.to_owned()).into()) - } - - async fn create_user(&self, claims: &Claims) -> Result { - if let Some(email) = claims.email() { - match self.storage.user_by_email(email).await? { - Some(_) => Err(ShieldError::Validation( - "\ - Email address `{email}` is already used by another account. \ - To link a new provider, sign in to with your exising account first. \ - If this is not your account, please contact support for assistence.\ - " - .to_owned(), - )), - None => Ok(self - .storage - .create_user( - CreateUser { - name: claims - .name() - .and_then(|name| name.get(None).map(|name| name.to_string())), - }, - CreateEmailAddress { - email: email.to_string(), - is_primary: true, - // TODO: from claim? - is_verified: false, - // TODO: generate if not verified - verification_token: None, - verification_token_expired_at: None, - verified_at: None, - }, - ) - .await?), - } - } else { - Err(ShieldError::Validation( - "Missing email address in OpenID Connect claims.".to_owned(), - )) - } - } - - async fn update_user(&self, user_id: &str, claims: &Claims) -> Result { - self.storage - .update_user(UpdateUser { - id: user_id.to_owned(), - name: claims - .name() - .and_then(|name| name.get(None).map(|name| name.to_string())) - .map(Some), - }) - .await - .map_err(ShieldError::Storage) - } - - async fn create_oidc_connection( - &self, - provider_id: String, - user_id: String, - identifier: String, - token_response: CoreTokenResponse, - ) -> Result { - let (token_type, access_token, refresh_token, id_token, expired_at, scopes) = - parse_token_response(token_response)?; - - self.storage - .create_oidc_connection(CreateOidcConnection { - identifier, - token_type, - access_token, - refresh_token, - id_token, - expired_at, - scopes, - provider_id, - user_id, - }) - .await - .map_err(ShieldError::Storage) - } - - async fn update_oidc_connection( - &self, - connection_id: String, - token_response: CoreTokenResponse, - ) -> Result { - let (token_type, access_token, refresh_token, id_token, expired_at, scopes) = - parse_token_response(token_response)?; - - self.storage - .update_oidc_connection(UpdateOidcConnection { - id: connection_id, - token_type: Some(token_type), - access_token: Some(access_token), - refresh_token: refresh_token.map(Some), - id_token: id_token.map(Some), - expired_at: expired_at.map(Some), - scopes: scopes.map(Some), - }) - .await - .map_err(ShieldError::Storage) + Ok(None) } } #[async_trait] -impl Method for OidcMethod { +impl Method for OidcMethod { fn id(&self) -> String { OIDC_METHOD_ID.to_owned() } - async fn providers(&self) -> Result>, ShieldError> { - let providers = self + fn actions(&self) -> Vec>> { + vec![ + Box::new(OidcSignInAction), + Box::new(OidcSignInCallbackAction::new( + self.options.clone(), + self.storage.clone(), + )), + Box::new(OidcSignOutAction), + ] + } + + async fn providers(&self) -> Result, ShieldError> { + Ok(self .providers .iter() .cloned() - .chain(self.storage.oidc_providers().await?); - - Ok(providers - .map(|provider| Box::new(provider) as Box) + .chain(self.storage.oidc_providers().await?) .collect()) } async fn provider_by_id( &self, - provider_id: &str, - ) -> Result>, ShieldError> { - self.oidc_provider_by_id_or_slug(provider_id) - .await - .map(|provider| Some(Box::new(provider) as Box)) - } - - async fn sign_in( - &self, - request: SignInRequest, - session: Session, - _options: &ShieldOptions, - ) -> Result { - let provider = match request.provider_id { - Some(provider_id) => self.oidc_provider_by_id_or_slug(&provider_id).await?, - None => return Err(ProviderError::ProviderMissing.into()), - }; - - let client = provider.oidc_client().await?; - - let mut authorization_request = client.authorize_url( - CoreAuthenticationFlow::AuthorizationCode, - CsrfToken::new_random, - Nonce::new_random, - ); - - let pkce_code_challenge = match provider.pkce_code_challenge { - OidcProviderPkceCodeChallenge::None => None, - OidcProviderPkceCodeChallenge::Plain => Some(PkceCodeChallenge::new_random_plain()), - OidcProviderPkceCodeChallenge::S256 => Some(PkceCodeChallenge::new_random_sha256()), - }; - - if let Some((pkce_code_challenge, _)) = &pkce_code_challenge { - authorization_request = - authorization_request.set_pkce_challenge(pkce_code_challenge.clone()); - } - - if let Some(scopes) = provider.scopes { - authorization_request = - authorization_request.add_scopes(scopes.into_iter().map(Scope::new)); - } - - if let Some(authorization_url_params) = provider.authorization_url_params { - let params = parse(authorization_url_params.trim_start_matches('?').as_bytes()); - - for (name, value) in params { - authorization_request = - authorization_request.add_extra_param(name.into_owned(), value.into_owned()); - } - } - - let (auth_url, csrf_token, nonce) = authorization_request.url(); - - { - let session_data = session.data(); - let mut session_data = session_data - .lock() - .map_err(|err| SessionError::Lock(err.to_string()))?; - - session_data.authentication = None; - - session_data.set_method( - OIDC_METHOD_ID, - OidcSession { - csrf: Some(csrf_token.secret().clone()), - nonce: Some(nonce.secret().clone()), - pkce_verifier: pkce_code_challenge - .map(|(_, pkce_code_verifier)| pkce_code_verifier.secret().clone()), - oidc_connection_id: None, - }, - )?; - } - - Ok(Response::Redirect(auth_url.to_string())) - } - - async fn sign_in_callback( - &self, - request: SignInCallbackRequest, - session: Session, - options: &ShieldOptions, - ) -> Result { - let OidcSession { - csrf, - nonce, - pkce_verifier, - .. - } = { - let session_data = session.data(); - let session_data = session_data - .lock() - .map_err(|err| SessionError::Lock(err.to_string()))?; - - session_data.method(OIDC_METHOD_ID)? - }; - - let state = request - .query - .as_ref() - .and_then(|query| query.get("state")) - .and_then(|code| code.as_str()) - .ok_or_else(|| ShieldError::Validation("Missing state.".to_owned()))?; - - if csrf.is_none_or(|csrf| csrf != state) { - return Err(ShieldError::Validation("Invalid state.".to_owned())); - } - - let authorization_code = request - .query - .as_ref() - .and_then(|query| query.get("code")) - .and_then(|code| code.as_str()) - .ok_or_else(|| ShieldError::Validation("Missing authorization code.".to_owned()))?; - - let provider = match request.provider_id { - Some(provider_id) => self.oidc_provider_by_id_or_slug(&provider_id).await?, - None => return Err(ProviderError::ProviderMissing.into()), - }; - - let client = provider.oidc_client().await?; - - let mut token_request = client - .exchange_code(AuthorizationCode::new(authorization_code.to_owned())) - .map_err(|err| { - ShieldError::Configuration(ConfigurationError::Missing(err.to_string())) - })?; - - if let Some(pkce_verifier) = pkce_verifier { - token_request = token_request.set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier)); - } else if provider.pkce_code_challenge != OidcProviderPkceCodeChallenge::None { - return Err(ShieldError::Validation("Missing PKCE verifier.".to_owned())); - } - - if let Some(token_url_params) = provider.token_url_params { - let params = parse(token_url_params.trim_start_matches('?').as_bytes()); - - for (name, value) in params { - token_request = - token_request.add_extra_param(name.into_owned(), value.into_owned()); - } - } - - let async_http_client = async_http_client()?; - - let token_response = token_request - .request_async(&async_http_client) - .await - .map_err(|err| ShieldError::Request(err.to_string()))?; - - let claims = if let Some(id_token) = token_response.id_token() { - let claims = id_token - .claims( - &client.id_token_verifier(), - &Nonce::new( - nonce - .ok_or_else(|| ShieldError::Validation("Missing nonce.".to_owned()))?, - ), - ) - .map_err(|err| ShieldError::Validation(err.to_string()))?; - - Claims::from(claims.clone()) + provider_id: Option<&str>, + ) -> Result, ShieldError> { + if let Some(provider_id) = provider_id { + self.oidc_provider_by_id_or_slug(provider_id).await } else { - let claims: UserInfoClaims = client - .user_info(token_response.access_token().to_owned(), None) - .map_err(|err| ConfigurationError::Missing(err.to_string()))? - .request_async(&async_http_client) - .await - .map_err(|err| ShieldError::Request(err.to_string()))?; - - Claims::from(claims) - }; - - debug!("{:?}\n{:?}", claims.subject(), claims); - - let (connection, user) = match self - .storage - .oidc_connection_by_identifier(&provider.id, claims.subject()) - .await? - { - Some(connection) => { - let connection = self - .update_oidc_connection(connection.id, token_response) - .await?; - - let user = self.update_user(&connection.user_id, &claims).await?; - - (connection, user) - } - None => { - let user = self.create_user(&claims).await?; - - let connection = self - .create_oidc_connection( - provider.id.clone(), - user.id(), - claims.subject().to_string(), - token_response, - ) - .await?; - - (connection, user) - } - }; - - session.renew().await?; - - { - let session_data = session.data(); - let mut session_data = session_data - .lock() - .map_err(|err| SessionError::Lock(err.to_string()))?; - - session_data.authentication = Some(Authentication { - method_id: self.id(), - provider_id: Some(provider.id), - user_id: user.id(), - }); - - session_data.set_method( - OIDC_METHOD_ID, - OidcSession { - csrf: None, - nonce: None, - pkce_verifier: None, - oidc_connection_id: Some(connection.id), - }, - )?; + Ok(None) } - - Ok(Response::Redirect( - request - .redirect_url - .unwrap_or(options.sign_in_redirect.clone()), - )) - } - - async fn sign_out( - &self, - _request: SignOutRequest, - _session: Session, - _options: &ShieldOptions, - ) -> Result, ShieldError> { - // TODO: See [`OidcProvider::oidc_client`]. - - // let provider = match request.provider_id { - // Some(provider_id) => self.oidc_provider_by_id_or_slug(&provider_id).await?, - // None => return Err(ProviderError::ProviderMissing.into()), - // }; - - // let connection_id = { - // let session_data = session.data(); - // let session_data = session_data - // .lock() - // .map_err(|err| SessionError::Lock(err.to_string()))?; - - // session_data.oidc_connection_id.clone() - // }; - - // if let Some(connection_id) = connection_id { - // if let Some(connection) = self.storage.oidc_connection_by_id(&connection_id).await? { - // debug!("revoking access token {:?}", connection.access_token); - - // let token = AccessToken::new(connection.access_token); - - // let client = subprovider.oidc_client().await?; - - // let revocation_request = match client.revoke_token(token.into()) { - // Ok(revocation_request) => Some(revocation_request), - // Err(openidconnect::ConfigurationError::MissingUrl("revocation")) => None, - // Err(err) => return Err(ConfigurationError::Invalid(err.to_string()).into()), - // }; - - // if let Some(revocation_request) = revocation_request { - // let mut revocation_request = revocation_request; - - // if let Some(revocation_url_params) = subprovider.revocation_url_params { - // let params = - // parse(revocation_url_params.trim_start_matches('?').as_bytes()); - - // for (name, value) in params { - // revocation_request = revocation_request - // .add_extra_param(name.into_owned(), value.into_owned()); - // } - // } - - // revocation_request - // .request_async(async_http_client) - // .await - // .map_err(|err| ShieldError::Request(err.to_string()))?; - // } - // } - // } - - Ok(None) } } -type ParsedTokenResponse = ( - String, - String, - Option, - Option, - Option>, - Option>, -); - -fn parse_token_response( - token_response: CoreTokenResponse, -) -> Result { - Ok(( - token_response.token_type().as_ref().to_string(), - token_response.access_token().secret().clone(), - token_response - .refresh_token() - .map(|refresh_token| refresh_token.secret().clone()), - token_response - .id_token() - .map(|id_token| id_token.to_string()), - match token_response.expires_in() { - Some(expires_in) => Some( - (Utc::now() - + Duration::from_std(expires_in) - .map_err(|err| ShieldError::Validation(err.to_string()))?) - .into(), - ), - None => None, - }, - token_response - .scopes() - .map(|scopes| scopes.iter().map(|scope| scope.to_string()).collect()), - )) -} +erased_method!(OidcMethod, ); diff --git a/packages/methods/shield-oidc/src/options.rs b/packages/methods/shield-oidc/src/options.rs new file mode 100644 index 0000000..336e268 --- /dev/null +++ b/packages/methods/shield-oidc/src/options.rs @@ -0,0 +1,14 @@ +use bon::Builder; + +#[derive(Builder, Clone, Debug)] +#[builder(on(String, into), state_mod(vis = "pub(crate)"))] +pub struct OidcOptions { + #[builder(default = "/")] + pub sign_in_redirect: String, +} + +impl Default for OidcOptions { + fn default() -> Self { + Self::builder().build() + } +} diff --git a/packages/methods/shield-oidc/src/provider.rs b/packages/methods/shield-oidc/src/provider.rs index 7b51d0e..52e6755 100644 --- a/packages/methods/shield-oidc/src/provider.rs +++ b/packages/methods/shield-oidc/src/provider.rs @@ -10,6 +10,7 @@ use openidconnect::{ CoreTokenResponse, }, }; +use secrecy::{ExposeSecret, SecretString}; use shield::{ConfigurationError, Provider}; use crate::{ @@ -50,9 +51,13 @@ pub enum OidcProviderPkceCodeChallenge { Plain, S256, } - +#[expect(clippy::duplicated_attributes)] #[derive(Builder, Clone, Debug)] -#[builder(on(String, into), state_mod(vis = "pub(crate)"))] +#[builder( + on(String, into), + on(SecretString, into), + state_mod(vis = "pub(crate)") +)] pub struct OidcProvider { pub id: String, pub name: String, @@ -61,7 +66,7 @@ pub struct OidcProvider { #[builder(default = OidcProviderVisibility::Public)] pub visibility: OidcProviderVisibility, pub client_id: String, - pub client_secret: Option, + pub client_secret: Option, pub scopes: Option>, pub redirect_url: Option, pub discovery_url: Option, @@ -177,7 +182,9 @@ impl OidcProvider { let mut client = CoreClient::from_provider_metadata( provider_metadata, ClientId::new(self.client_id.clone()), - self.client_secret.clone().map(ClientSecret::new), + self.client_secret + .clone() + .map(|client_secret| ClientSecret::new(client_secret.expose_secret().to_owned())), ); // TODO: Upstream: _option version of these (and other) functions which set the type to EndpointMaybeSet. @@ -217,12 +224,4 @@ impl Provider for OidcProvider { fn name(&self) -> String { self.name.clone() } - - fn icon_url(&self) -> Option { - self.icon_url.clone() - } - - fn form(&self) -> Option { - None - } } diff --git a/packages/storage/shield-memory/Cargo.toml b/packages/storage/shield-memory/Cargo.toml index 1227608..58b69f1 100644 --- a/packages/storage/shield-memory/Cargo.toml +++ b/packages/storage/shield-memory/Cargo.toml @@ -13,13 +13,13 @@ default = [] all-methods = [ # "method-credentials", # "method-email", - # "method-oauth", + "method-oauth", "method-oidc", # "method-webauthn", ] # method-credentials = ["dep:shield-credentials"] # method-email = ["dep:shield-email"] -# method-oauth = ["dep:shield-oauth"] +method-oauth = ["dep:shield-oauth"] method-oidc = ["dep:shield-oidc"] [dependencies] @@ -28,7 +28,7 @@ serde.workspace = true shield.workspace = true # shield-credentials = { workspace = true, optional = true } # shield-email = { workspace = true, optional = true } -# shield-oauth = { workspace = true, optional = true } +shield-oauth = { workspace = true, optional = true } shield-oidc = { workspace = true, optional = true } # shield-webauthn = { workspace = true, optional = true } uuid = { workspace = true, features = ["v4"] } diff --git a/packages/storage/shield-memory/src/providers.rs b/packages/storage/shield-memory/src/providers.rs index 6ba5800..6408e6b 100644 --- a/packages/storage/shield-memory/src/providers.rs +++ b/packages/storage/shield-memory/src/providers.rs @@ -1,2 +1,4 @@ +#[cfg(feature = "method-oauth")] +pub mod oauth; #[cfg(feature = "method-oidc")] pub mod oidc; diff --git a/packages/storage/shield-memory/src/providers/oauth.rs b/packages/storage/shield-memory/src/providers/oauth.rs new file mode 100644 index 0000000..9bde888 --- /dev/null +++ b/packages/storage/shield-memory/src/providers/oauth.rs @@ -0,0 +1,129 @@ +use std::sync::{Arc, Mutex}; + +use async_trait::async_trait; +use shield::StorageError; +use shield_oauth::{ + CreateOauthConnection, OauthConnection, OauthProvider, OauthStorage, UpdateOauthConnection, +}; +use uuid::Uuid; + +use crate::{storage::MemoryStorage, user::User}; + +#[derive(Clone, Debug, Default)] +pub struct OauthMemoryStorage { + connections: Arc>>, +} + +#[async_trait] +impl OauthStorage for MemoryStorage { + async fn oauth_providers(&self) -> Result, StorageError> { + Ok(vec![]) + } + + async fn oauth_provider_by_id_or_slug( + &self, + _provider_id: &str, + ) -> Result, StorageError> { + Ok(None) + } + + async fn oauth_connection_by_id( + &self, + connection_id: &str, + ) -> Result, StorageError> { + Ok(self + .oauth + .connections + .lock() + .map_err(|err| StorageError::Engine(err.to_string()))? + .iter() + .find(|connection| connection.id == connection_id) + .cloned()) + } + + async fn oauth_connection_by_identifier( + &self, + provider_id: &str, + identifier: &str, + ) -> Result, StorageError> { + Ok(self + .oauth + .connections + .lock() + .map_err(|err| StorageError::Engine(err.to_string()))? + .iter() + .find(|connection| { + connection.provider_id == provider_id && connection.identifier == identifier + }) + .cloned()) + } + + async fn create_oauth_connection( + &self, + connection: CreateOauthConnection, + ) -> Result { + let connection = OauthConnection { + id: Uuid::new_v4().to_string(), + identifier: connection.identifier, + token_type: connection.token_type, + access_token: connection.access_token, + refresh_token: connection.refresh_token, + expired_at: connection.expired_at, + scopes: connection.scopes, + provider_id: connection.provider_id, + user_id: connection.user_id, + }; + + self.oauth + .connections + .lock() + .map_err(|err| StorageError::Engine(err.to_string()))? + .push(connection.clone()); + + Ok(connection) + } + + async fn update_oauth_connection( + &self, + connection: UpdateOauthConnection, + ) -> Result { + let mut connections = self + .oauth + .connections + .lock() + .map_err(|err| StorageError::Engine(err.to_string()))?; + + let connection_mut = connections + .iter_mut() + .find(|c| c.id == connection.id) + .ok_or_else(|| StorageError::NotFound("User".to_owned(), connection.id.clone()))?; + + if let Some(token_type) = connection.token_type { + connection_mut.token_type = token_type; + } + if let Some(access_token) = connection.access_token { + connection_mut.access_token = access_token; + } + if let Some(refresh_token) = connection.refresh_token { + connection_mut.refresh_token = refresh_token; + } + if let Some(expired_at) = connection.expired_at { + connection_mut.expired_at = expired_at; + } + if let Some(scopes) = connection.scopes { + connection_mut.scopes = scopes; + } + + Ok(connection_mut.clone()) + } + + async fn delete_oauth_connection(&self, connection_id: &str) -> Result<(), StorageError> { + self.oauth + .connections + .lock() + .map_err(|err| StorageError::Engine(err.to_string()))? + .retain(|connection| connection.id != connection_id); + + Ok(()) + } +} diff --git a/packages/storage/shield-memory/src/storage.rs b/packages/storage/shield-memory/src/storage.rs index cf2333e..c7d2bc9 100644 --- a/packages/storage/shield-memory/src/storage.rs +++ b/packages/storage/shield-memory/src/storage.rs @@ -13,6 +13,8 @@ pub const MEMORY_STORAGE_ID: &str = "memory"; #[derive(Clone, Debug, Default)] pub struct MemoryStorage { pub(crate) users: Arc>>, + #[cfg(feature = "method-oauth")] + pub(crate) oauth: crate::providers::oauth::OauthMemoryStorage, #[cfg(feature = "method-oidc")] pub(crate) oidc: crate::providers::oidc::OidcMemoryStorage, } diff --git a/packages/storage/shield-sea-orm/Cargo.toml b/packages/storage/shield-sea-orm/Cargo.toml index aa315d3..44c7144 100644 --- a/packages/storage/shield-sea-orm/Cargo.toml +++ b/packages/storage/shield-sea-orm/Cargo.toml @@ -32,6 +32,7 @@ async-trait.workspace = true chrono.workspace = true sea-orm.workspace = true sea-orm-migration.workspace = true +secrecy.workspace = true serde = { workspace = true, features = ["derive"] } serde_json.workspace = true shield.workspace = true diff --git a/packages/storage/shield-sea-orm/src/lib.rs b/packages/storage/shield-sea-orm/src/lib.rs index 2d0c2ba..62323d4 100644 --- a/packages/storage/shield-sea-orm/src/lib.rs +++ b/packages/storage/shield-sea-orm/src/lib.rs @@ -1,7 +1,7 @@ pub mod base; pub mod entities; +mod methods; pub mod migrations; -mod providers; mod storage; mod user; diff --git a/packages/storage/shield-sea-orm/src/providers.rs b/packages/storage/shield-sea-orm/src/methods.rs similarity index 100% rename from packages/storage/shield-sea-orm/src/providers.rs rename to packages/storage/shield-sea-orm/src/methods.rs diff --git a/packages/storage/shield-sea-orm/src/providers/oauth.rs b/packages/storage/shield-sea-orm/src/methods/oauth.rs similarity index 92% rename from packages/storage/shield-sea-orm/src/providers/oauth.rs rename to packages/storage/shield-sea-orm/src/methods/oauth.rs index 1c9852c..14309a6 100644 --- a/packages/storage/shield-sea-orm/src/providers/oauth.rs +++ b/packages/storage/shield-sea-orm/src/methods/oauth.rs @@ -1,5 +1,6 @@ use async_trait::async_trait; use sea_orm::{ActiveModelTrait, ActiveValue, ColumnTrait, EntityTrait, QueryFilter}; +use secrecy::ExposeSecret; use shield::StorageError; use shield_oauth::{ CreateOauthConnection, OauthConnection, OauthProvider, OauthProviderPkceCodeChallenge, @@ -76,8 +77,12 @@ impl OauthStorage for SeaOrmStorage { let active_model = oauth_provider_connection::ActiveModel { identifier: ActiveValue::Set(connection.identifier), token_type: ActiveValue::Set(connection.token_type), - access_token: ActiveValue::Set(connection.access_token), - refresh_token: ActiveValue::Set(connection.refresh_token), + access_token: ActiveValue::Set(connection.access_token.expose_secret().to_owned()), + refresh_token: ActiveValue::Set( + connection + .refresh_token + .map(|refresh_token| refresh_token.expose_secret().to_owned()), + ), expired_at: ActiveValue::Set(connection.expired_at), scopes: ActiveValue::Set(connection.scopes.map(|scopes| scopes.join(","))), provider_id: ActiveValue::Set(Self::parse_uuid(&connection.provider_id)?), @@ -109,10 +114,12 @@ impl OauthStorage for SeaOrmStorage { active_model.token_type = ActiveValue::Set(token_type); } if let Some(access_token) = connection.access_token { - active_model.access_token = ActiveValue::Set(access_token); + active_model.access_token = ActiveValue::Set(access_token.expose_secret().to_owned()); } if let Some(refresh_token) = connection.refresh_token { - active_model.refresh_token = ActiveValue::Set(refresh_token); + active_model.refresh_token = ActiveValue::Set( + refresh_token.map(|refresh_token| refresh_token.expose_secret().to_owned()), + ); } if let Some(expired_at) = connection.expired_at { active_model.expired_at = ActiveValue::Set(expired_at); @@ -173,7 +180,7 @@ impl TryFrom for OauthProvider { icon_url: value.icon_url, visibility: value.visibility.into(), client_id: value.client_id, - client_secret: value.client_secret, + client_secret: value.client_secret.map(Into::into), scopes: value .scopes .map(|scopes| scopes.split(',').map(|s| s.to_string()).collect()), @@ -197,8 +204,8 @@ impl From for OauthConnection { id: value.id.to_string(), identifier: value.identifier, token_type: value.token_type, - access_token: value.access_token, - refresh_token: value.refresh_token, + access_token: value.access_token.into(), + refresh_token: value.refresh_token.map(Into::into), expired_at: value.expired_at, scopes: value .scopes diff --git a/packages/storage/shield-sea-orm/src/providers/oidc.rs b/packages/storage/shield-sea-orm/src/methods/oidc.rs similarity index 89% rename from packages/storage/shield-sea-orm/src/providers/oidc.rs rename to packages/storage/shield-sea-orm/src/methods/oidc.rs index 8ab2a85..8b323cd 100644 --- a/packages/storage/shield-sea-orm/src/providers/oidc.rs +++ b/packages/storage/shield-sea-orm/src/methods/oidc.rs @@ -1,5 +1,6 @@ use async_trait::async_trait; use sea_orm::{ActiveModelTrait, ActiveValue, ColumnTrait, EntityTrait, QueryFilter}; +use secrecy::ExposeSecret; use shield::StorageError; use shield_oidc::{ CreateOidcConnection, OidcConnection, OidcProvider, OidcProviderPkceCodeChallenge, @@ -74,9 +75,17 @@ impl OidcStorage for SeaOrmStorage { let active_model = oidc_provider_connection::ActiveModel { identifier: ActiveValue::Set(connection.identifier), token_type: ActiveValue::Set(connection.token_type), - access_token: ActiveValue::Set(connection.access_token), - refresh_token: ActiveValue::Set(connection.refresh_token), - id_token: ActiveValue::Set(connection.id_token), + access_token: ActiveValue::Set(connection.access_token.expose_secret().to_owned()), + refresh_token: ActiveValue::Set( + connection + .refresh_token + .map(|refresh_token| refresh_token.expose_secret().to_owned()), + ), + id_token: ActiveValue::Set( + connection + .id_token + .map(|id_token| id_token.expose_secret().to_owned()), + ), expired_at: ActiveValue::Set(connection.expired_at), scopes: ActiveValue::Set(connection.scopes.map(|scopes| scopes.join(","))), provider_id: ActiveValue::Set(Self::parse_uuid(&connection.provider_id)?), @@ -108,13 +117,16 @@ impl OidcStorage for SeaOrmStorage { active_model.token_type = ActiveValue::Set(token_type); } if let Some(access_token) = connection.access_token { - active_model.access_token = ActiveValue::Set(access_token); + active_model.access_token = ActiveValue::Set(access_token.expose_secret().to_owned()); } if let Some(refresh_token) = connection.refresh_token { - active_model.refresh_token = ActiveValue::Set(refresh_token); + active_model.refresh_token = ActiveValue::Set( + refresh_token.map(|refresh_token| refresh_token.expose_secret().to_owned()), + ); } if let Some(id_token) = connection.id_token { - active_model.id_token = ActiveValue::Set(id_token); + active_model.id_token = + ActiveValue::Set(id_token.map(|id_token| id_token.expose_secret().to_owned())); } if let Some(expired_at) = connection.expired_at { active_model.expired_at = ActiveValue::Set(expired_at); @@ -175,7 +187,7 @@ impl TryFrom for OidcProvider { icon_url: value.icon_url, visibility: value.visibility.into(), client_id: value.client_id, - client_secret: value.client_secret, + client_secret: value.client_secret.map(Into::into), scopes: value .scopes .map(|scopes| scopes.split(',').map(|s| s.to_string()).collect()), @@ -208,9 +220,9 @@ impl From for OidcConnection { id: value.id.to_string(), identifier: value.identifier, token_type: value.token_type, - access_token: value.access_token, - refresh_token: value.refresh_token, - id_token: value.id_token, + access_token: value.access_token.into(), + refresh_token: value.refresh_token.map(Into::into), + id_token: value.id_token.map(Into::into), expired_at: value.expired_at, scopes: value .scopes