use std::{num::NonZeroU32, time::Duration};
use camino::Utf8PathBuf;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use super::ConfigurationSection;
use crate::schema;
/// Default value of `DatabaseConfig::uri`: a bare `postgresql://` URI.
///
/// Used both as the JSON schema default of the field and by `DatabaseConfig::default`.
// The field is an `Option<String>`, so this helper has to return an `Option`
// even though it always yields `Some`; hence the lint exception.
#[allow(clippy::unnecessary_wraps)]
fn default_connection_string() -> Option<String> {
    let uri = String::from("postgresql://");
    Some(uri)
}
/// Default value of `DatabaseConfig::max_connections`: 10 connections.
fn default_max_connections() -> NonZeroU32 {
    NonZeroU32::new(10).expect("10 is a non-zero constant")
}
/// Default value of `DatabaseConfig::connect_timeout`: 30 seconds.
fn default_connect_timeout() -> Duration {
    Duration::new(30, 0)
}
/// Default value of `DatabaseConfig::idle_timeout`: 10 minutes.
// The field is an `Option<Duration>`, so serde's `default = "..."` needs a
// function returning an `Option`; hence the lint exception.
#[allow(clippy::unnecessary_wraps)]
fn default_idle_timeout() -> Option<Duration> {
    let ten_minutes = Duration::from_secs(60 * 10);
    Some(ten_minutes)
}
/// Default value of `DatabaseConfig::max_lifetime`: 30 minutes.
// The field is an `Option<Duration>`, so serde's `default = "..."` needs a
// function returning an `Option`; hence the lint exception.
#[allow(clippy::unnecessary_wraps)]
fn default_max_lifetime() -> Option<Duration> {
    let thirty_minutes = Duration::from_secs(60 * 30);
    Some(thirty_minutes)
}
impl Default for DatabaseConfig {
    /// Builds a configuration whose `uri` is `postgresql://`, with none of the
    /// split connection options and no SSL material set, and with the default
    /// connection limits and timeouts.
    fn default() -> Self {
        Self {
            // Connection target: only the URI is set.
            uri: default_connection_string(),
            host: None,
            port: None,
            socket: None,
            username: None,
            password: None,
            database: None,

            // SSL settings: all left unset.
            ssl_mode: None,
            ssl_ca: None,
            ssl_ca_file: None,
            ssl_certificate: None,
            ssl_certificate_file: None,
            ssl_key: None,
            ssl_key_file: None,

            // Connection limits and timeouts.
            min_connections: 0,
            max_connections: default_max_connections(),
            connect_timeout: default_connect_timeout(),
            idle_timeout: default_idle_timeout(),
            max_lifetime: default_max_lifetime(),
        }
    }
}
/// SSL mode to use for the database connection.
///
/// Serialized in kebab-case: `disable`, `allow`, `prefer`, `require`,
/// `verify-ca` and `verify-full`. The names presumably mirror libpq's
/// `sslmode` connection parameter; confirm against the code that builds the
/// connection options.
// Variants deliberately carry no `///` docs: schemars would then emit a
// `oneOf` with descriptions instead of a plain string `enum` in the schema.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "kebab-case")]
pub enum PgSslMode {
    Disable,
    Allow,
    Prefer,
    Require,
    VerifyCa,
    VerifyFull,
}
/// Database connection settings, read from the `database` configuration section.
///
/// The connection target is described either by `uri` alone, or by the split
/// options `host`, `port`, `socket`, `username`, `password` and `database`.
/// Setting both is rejected during validation.
#[serde_as]
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct DatabaseConfig {
    /// Connection URI.
    ///
    /// Cannot be combined with `host`, `port`, `socket`, `username`,
    /// `password` or `database`. The schema default, and the value set by
    /// `DatabaseConfig::default`, is `postgresql://`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(url, default = "default_connection_string")]
    pub uri: Option<String>,

    /// Host to connect to. Cannot be combined with `uri`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(with = "Option::<schema::Hostname>")]
    pub host: Option<String>,

    /// Port to connect to, between 1 and 65535. Cannot be combined with `uri`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(range(min = 1, max = 65535))]
    pub port: Option<u16>,

    /// Path of the socket to connect through. Cannot be combined with `uri`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(with = "Option<String>")]
    pub socket: Option<Utf8PathBuf>,

    /// User name to connect as. Cannot be combined with `uri`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub username: Option<String>,

    /// Password to connect with. Cannot be combined with `uri`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub password: Option<String>,

    /// Name of the database to connect to. Cannot be combined with `uri`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub database: Option<String>,

    /// SSL mode to use for the connection.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ssl_mode: Option<PgSslMode>,

    /// CA certificate, given inline. Mutually exclusive with `ssl_ca_file`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ssl_ca: Option<String>,

    /// Path to a CA certificate file. Mutually exclusive with `ssl_ca`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(with = "Option<String>")]
    pub ssl_ca_file: Option<Utf8PathBuf>,

    /// Client certificate, given inline.
    ///
    /// Mutually exclusive with `ssl_certificate_file`. It must be set
    /// together with a client key (`ssl_key` or `ssl_key_file`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ssl_certificate: Option<String>,

    /// Path to a client certificate file.
    ///
    /// Mutually exclusive with `ssl_certificate`. It must be set together
    /// with a client key (`ssl_key` or `ssl_key_file`).
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(with = "Option<String>")]
    pub ssl_certificate_file: Option<Utf8PathBuf>,

    /// Client key, given inline.
    ///
    /// Mutually exclusive with `ssl_key_file`. It must be set together with
    /// a client certificate (`ssl_certificate` or `ssl_certificate_file`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ssl_key: Option<String>,

    /// Path to a client key file.
    ///
    /// Mutually exclusive with `ssl_key`. It must be set together with a
    /// client certificate (`ssl_certificate` or `ssl_certificate_file`).
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(with = "Option<String>")]
    pub ssl_key_file: Option<Utf8PathBuf>,

    /// Maximum number of connections. Defaults to 10.
    #[serde(default = "default_max_connections")]
    pub max_connections: NonZeroU32,

    /// Minimum number of connections. Defaults to 0.
    #[serde(default)]
    pub min_connections: u32,

    /// Connection timeout, expressed in seconds. Defaults to 30 seconds.
    #[schemars(with = "u64")]
    #[serde(default = "default_connect_timeout")]
    #[serde_as(as = "serde_with::DurationSeconds<u64>")]
    pub connect_timeout: Duration,

    /// Idle timeout, expressed in seconds.
    ///
    /// Defaults to 10 minutes when the key is absent. An explicit `null`
    /// yields `None`.
    #[schemars(with = "Option<u64>")]
    #[serde(
        default = "default_idle_timeout",
        skip_serializing_if = "Option::is_none"
    )]
    #[serde_as(as = "Option<serde_with::DurationSeconds<u64>>")]
    pub idle_timeout: Option<Duration>,

    /// Maximum lifetime of a connection, expressed in seconds.
    ///
    /// Defaults to 30 minutes when the key is absent. An explicit `null`
    /// yields `None`.
    // The schema must describe the `Option`, as it does for `idle_timeout`.
    // A plain `u64` would make `null` invalid in the schema even though it
    // deserializes fine.
    #[schemars(with = "Option<u64>")]
    #[serde(
        default = "default_max_lifetime",
        skip_serializing_if = "Option::is_none"
    )]
    #[serde_as(as = "Option<serde_with::DurationSeconds<u64>>")]
    pub max_lifetime: Option<Duration>,
}
impl ConfigurationSection for DatabaseConfig {
const PATH: Option<&'static str> = Some("database");
fn validate(&self, figment: &figment::Figment) -> Result<(), figment::error::Error> {
let metadata = figment.find_metadata(Self::PATH.unwrap());
let annotate = |mut error: figment::Error| {
error.metadata = metadata.cloned();
error.profile = Some(figment::Profile::Default);
error.path = vec![Self::PATH.unwrap().to_owned()];
Err(error)
};
let has_split_options = self.host.is_some()
|| self.port.is_some()
|| self.socket.is_some()
|| self.username.is_some()
|| self.password.is_some()
|| self.database.is_some();
if self.uri.is_some() && has_split_options {
return annotate(figment::error::Error::from(
"uri must not be specified if host, port, socket, username, password, or database are specified".to_owned(),
));
}
if self.ssl_ca.is_some() && self.ssl_ca_file.is_some() {
return annotate(figment::error::Error::from(
"ssl_ca must not be specified if ssl_ca_file is specified".to_owned(),
));
}
if self.ssl_certificate.is_some() && self.ssl_certificate_file.is_some() {
return annotate(figment::error::Error::from(
"ssl_certificate must not be specified if ssl_certificate_file is specified"
.to_owned(),
));
}
if self.ssl_key.is_some() && self.ssl_key_file.is_some() {
return annotate(figment::error::Error::from(
"ssl_key must not be specified if ssl_key_file is specified".to_owned(),
));
}
if (self.ssl_key.is_some() || self.ssl_key_file.is_some())
^ (self.ssl_certificate.is_some() || self.ssl_certificate_file.is_some())
{
return annotate(figment::error::Error::from(
"both a ssl_certificate and a ssl_key must be set at the same time or none of them"
.to_owned(),
));
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use figment::{
        providers::{Format, Yaml},
        Figment, Jail,
    };

    use super::*;

    /// Writes `yaml` to `config.yaml` inside the jail. Returns the resulting
    /// `Figment`, which `validate` needs, together with the extracted
    /// `database` section.
    fn load(jail: &Jail, yaml: &str) -> Result<(Figment, DatabaseConfig), figment::Error> {
        jail.create_file("config.yaml", yaml)?;
        let figment = Figment::new().merge(Yaml::file("config.yaml"));
        let config = figment.extract_inner::<DatabaseConfig>("database")?;
        Ok((figment, config))
    }

    #[test]
    fn load_config() {
        Jail::expect_with(|jail| {
            let (figment, config) = load(
                jail,
                r"
                    database:
                      uri: postgresql://user:password@host/database
                ",
            )?;

            assert_eq!(
                config.uri.as_deref(),
                Some("postgresql://user:password@host/database")
            );
            // Keys absent from the file fall back to their serde defaults.
            assert_eq!(config.max_connections.get(), 10);
            assert_eq!(config.min_connections, 0);
            assert_eq!(config.connect_timeout, Duration::from_secs(30));
            assert_eq!(config.idle_timeout, Some(Duration::from_secs(10 * 60)));
            assert_eq!(config.max_lifetime, Some(Duration::from_secs(30 * 60)));
            assert!(config.validate(&figment).is_ok());

            Ok(())
        });
    }

    #[test]
    fn accept_split_options_with_client_certificate() {
        Jail::expect_with(|jail| {
            let (figment, config) = load(
                jail,
                r"
                    database:
                      host: localhost
                      port: 5432
                      username: user
                      database: db
                      ssl_mode: verify-full
                      ssl_certificate_file: /tmp/cert.pem
                      ssl_key_file: /tmp/key.pem
                ",
            )?;

            assert!(config.uri.is_none());
            assert_eq!(config.port, Some(5432));
            assert!(matches!(config.ssl_mode, Some(PgSslMode::VerifyFull)));
            assert!(config.validate(&figment).is_ok());

            Ok(())
        });
    }

    #[test]
    fn reject_uri_with_split_options() {
        Jail::expect_with(|jail| {
            let (figment, config) = load(
                jail,
                r"
                    database:
                      uri: postgresql://user:password@host/database
                      host: localhost
                ",
            )?;

            let error = config.validate(&figment).unwrap_err();
            assert_eq!(error.path, vec!["database".to_owned()]);

            Ok(())
        });
    }

    #[test]
    fn reject_inline_and_file_ca() {
        Jail::expect_with(|jail| {
            let (figment, config) = load(
                jail,
                r"
                    database:
                      host: localhost
                      ssl_ca: inline-ca
                      ssl_ca_file: /tmp/ca.pem
                ",
            )?;

            assert!(config.validate(&figment).is_err());

            Ok(())
        });
    }

    #[test]
    fn reject_key_without_certificate() {
        Jail::expect_with(|jail| {
            let (figment, config) = load(
                jail,
                r"
                    database:
                      host: localhost
                      ssl_key_file: /tmp/key.pem
                ",
            )?;

            assert!(config.validate(&figment).is_err());

            Ok(())
        });
    }
}