1use std::{env, fs::File, io::BufReader, path::PathBuf, time::Duration};
39
40use log::{debug, error, info, warn};
41use tonic::transport::{Channel, Endpoint};
42#[cfg(feature = "mtls")]
43#[cfg(feature = "mtls")]
44use rustls::ClientConfig;
45#[cfg(feature = "mtls")]
46use rustls::RootCertStore;
47
/// Address Mountain is expected to listen on by default: IPv6 loopback, port 50051.
pub const DEFAULT_MOUNTAIN_ADDRESS:&str = "[::1]:50051";

/// Default time allowed for establishing the connection, in seconds.
pub const DEFAULT_CONNECTION_TIMEOUT_SECS:u64 = 5;

/// Default timeout for requests issued through the client, in seconds.
pub const DEFAULT_REQUEST_TIMEOUT_SECS:u64 = 30;
61
/// TLS / mutual-TLS settings for the connection to Mountain.
///
/// Only compiled with the `mtls` feature.
#[cfg(feature = "mtls")]
#[derive(Clone, Debug)]
pub struct TlsConfig {
	/// PEM file holding the CA certificate(s) to trust. When `None`,
	/// `create_tls_client_config` loads the system root store instead.
	pub ca_cert_path:Option<PathBuf>,

	/// PEM file with the client certificate chain (used for mTLS together
	/// with `client_key_path`).
	pub client_cert_path:Option<PathBuf>,

	/// PEM file with the client private key (used for mTLS together with
	/// `client_cert_path`).
	pub client_key_path:Option<PathBuf>,

	/// Server name used for SNI; the client falls back to `"localhost"` when
	/// this is `None`.
	pub server_name:Option<String>,

	/// Whether server certificates should be verified. Setting this to
	/// `false` only produces a warning; verification is not actually turned off.
	pub verify_certs:bool,
}
85
#[cfg(feature = "mtls")]
impl Default for TlsConfig {
	/// No certificate paths, no server-name override, verification requested.
	fn default() -> Self {
		let verify_certs = true;

		Self {
			verify_certs,
			server_name:Option::None,
			client_key_path:Option::None,
			client_cert_path:Option::None,
			ca_cert_path:Option::None,
		}
	}
}
98
#[cfg(feature = "mtls")]
impl TlsConfig {
	/// Server-authentication-only TLS: trust the CA in `ca_cert_path`, present
	/// no client certificate, and use `"localhost"` as the server name.
	pub fn server_auth(ca_cert_path:PathBuf) -> Self {
		Self { ca_cert_path:Some(ca_cert_path), server_name:Some(String::from("localhost")), ..Self::default() }
	}

	/// Mutual TLS: like [`TlsConfig::server_auth`], but also presents the
	/// client certificate chain and private key read from the given PEM files.
	pub fn mtls(ca_cert_path:PathBuf, client_cert_path:PathBuf, client_key_path:PathBuf) -> Self {
		Self {
			client_cert_path:Some(client_cert_path),
			client_key_path:Some(client_key_path),
			..Self::server_auth(ca_cert_path)
		}
	}
}
137
/// Builds a rustls [`ClientConfig`] from a [`TlsConfig`].
///
/// Trust anchors come from `ca_cert_path` when it is set (every PEM
/// certificate in the file is added). Otherwise they come from the platform's
/// native root store; entries rustls cannot parse are skipped with a warning.
/// When both `client_cert_path` and `client_key_path` are set, the client
/// certificate chain and private key are loaded for mutual TLS. If only one of
/// them is set, a warning is logged and no client authentication is configured.
/// ALPN is restricted to `h2` (gRPC).
///
/// `verify_certs = false` is not supported: certificates are always verified,
/// and a warning says so.
///
/// # Errors
///
/// Returns an error when:
/// - a configured file cannot be opened or parsed;
/// - the CA or client-certificate file contains no certificates, or the key
///   file contains no private key;
/// - the native root store cannot be loaded;
/// - rustls rejects an explicitly configured CA certificate or the client
///   certificate/key pair.
#[cfg(feature = "mtls")]
pub fn create_tls_client_config(tls_config:&TlsConfig) -> Result<ClientConfig, Box<dyn std::error::Error>> {
	info!("Creating TLS client configuration");

	let mut root_store = RootCertStore::empty();

	if let Some(ca_path) = &tls_config.ca_cert_path {
		debug!("Loading CA certificate from {:?}", ca_path);

		let ca_file = File::open(ca_path).map_err(|e| format!("Failed to open CA certificate file: {}", e))?;
		let mut reader = BufReader::new(ca_file);

		let certs = rustls_pemfile::certs(&mut reader)
			.collect::<Result<Vec<_>, _>>()
			.map_err(|e| format!("Failed to parse CA certificate: {}", e))?;

		if certs.is_empty() {
			return Err("No CA certificates found in file".into());
		}

		// An explicitly configured CA must be fully usable, so any rejected
		// certificate is a hard error here.
		for cert in certs {
			root_store
				.add(cert)
				.map_err(|e| format!("Failed to add CA certificate to root store: {}", e))?;
		}

		info!("Loaded CA certificate from {:?}", ca_path);
	} else {
		debug!("Loading system root certificates");

		let native_certs = rustls_native_certs::load_native_certs()
			.map_err(|e| format!("Failed to load system root certificates: {}", e))?;

		if native_certs.is_empty() {
			warn!("No system root certificates found");
		}

		// Platform stores can contain certificates rustls cannot parse; skip
		// those instead of letting one bad entry abort the whole configuration.
		let (added, ignored) = root_store.add_parsable_certificates(native_certs);

		if ignored > 0 {
			warn!("Ignored {} unparsable system root certificates", ignored);
		}

		info!("Loaded {} system root certificates", added);
	}

	let client_certs = match (&tls_config.client_cert_path, &tls_config.client_key_path) {
		(Some(cert_path), Some(key_path)) => {
			debug!("Loading client certificate from {:?}", cert_path);
			let cert_file =
				File::open(cert_path).map_err(|e| format!("Failed to open client certificate file: {}", e))?;
			let mut cert_reader = BufReader::new(cert_file);

			let certs = rustls_pemfile::certs(&mut cert_reader)
				.collect::<Result<Vec<_>, _>>()
				.map_err(|e| format!("Failed to parse client certificate: {}", e))?;

			if certs.is_empty() {
				return Err("No client certificates found in file".into());
			}

			debug!("Loading client private key from {:?}", key_path);
			let key_file = File::open(key_path).map_err(|e| format!("Failed to open private key file: {}", e))?;
			let mut key_reader = BufReader::new(key_file);

			let key = rustls_pemfile::private_key(&mut key_reader)
				.map_err(|e| format!("Failed to parse private key: {}", e))?
				.ok_or("No private key found in file")?;

			Some((certs, key))
		},
		// A half-configured identity used to be dropped silently; make the
		// misconfiguration visible without failing the connection setup.
		(Some(_), None) | (None, Some(_)) => {
			warn!(
				"Client certificate and private key must both be set for mTLS; continuing without client \
				 authentication"
			);
			None
		},
		(None, None) => None,
	};

	let mut config = match client_certs {
		Some((certs, key)) => {
			let client_config = ClientConfig::builder()
				.with_root_certificates(root_store)
				.with_client_auth_cert(certs, key)
				.map_err(|e| format!("Failed to configure client authentication: {}", e))?;

			info!("Configured mTLS with client certificate");

			client_config
		},
		None => {
			let client_config = ClientConfig::builder().with_root_certificates(root_store).with_no_client_auth();

			info!("Configured TLS with server authentication only");

			client_config
		},
	};

	// gRPC runs over HTTP/2.
	config.alpn_protocols = vec![b"h2".to_vec()];

	if !tls_config.verify_certs {
		// Nothing above disables verification, so say so explicitly instead of
		// claiming it was turned off.
		warn!("Disabling certificate verification is not supported; server certificates will still be verified");
	}

	info!("TLS client configuration created successfully");

	Ok(config)
}
264
/// Settings used by [`MountainClient::connect`] to reach Mountain.
#[derive(Clone, Debug)]
pub struct MountainClientConfig {
	/// gRPC address of Mountain (default `[::1]:50051`).
	pub address:String,

	/// Connect timeout applied to the endpoint, in seconds.
	pub connection_timeout_secs:u64,

	/// Request timeout in seconds (currently used by `health_check`).
	pub request_timeout_secs:u64,

	/// Optional TLS / mTLS settings; `None` means a plaintext connection.
	#[cfg(feature = "mtls")]
	pub tls_config:Option<TlsConfig>,
}
281
282impl Default for MountainClientConfig {
283 fn default() -> Self {
284 Self {
285 address:DEFAULT_MOUNTAIN_ADDRESS.to_string(),
286 connection_timeout_secs:DEFAULT_CONNECTION_TIMEOUT_SECS,
287 request_timeout_secs:DEFAULT_REQUEST_TIMEOUT_SECS,
288 #[cfg(feature = "mtls")]
289 tls_config:None,
290 }
291 }
292}
293
294impl MountainClientConfig {
295 pub fn new(address:impl Into<String>) -> Self { Self { address:address.into(), ..Default::default() } }
303
304 pub fn from_env() -> Self {
324 let address = env::var("MOUNTAIN_ADDRESS").unwrap_or_else(|_| DEFAULT_MOUNTAIN_ADDRESS.to_string());
325
326 let connection_timeout_secs = env::var("MOUNTAIN_CONNECTION_TIMEOUT_SECS")
327 .ok()
328 .and_then(|s| s.parse().ok())
329 .unwrap_or(DEFAULT_CONNECTION_TIMEOUT_SECS);
330
331 let request_timeout_secs = env::var("MOUNTAIN_REQUEST_TIMEOUT_SECS")
332 .ok()
333 .and_then(|s| s.parse().ok())
334 .unwrap_or(DEFAULT_REQUEST_TIMEOUT_SECS);
335
336 #[cfg(feature = "mtls")]
337 let tls_config = if env::var("MOUNTAIN_TLS_ENABLED")
338 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
339 .unwrap_or(false)
340 {
341 Some(TlsConfig {
342 ca_cert_path:env::var("MOUNTAIN_CA_CERT").ok().map(PathBuf::from),
343 client_cert_path:env::var("MOUNTAIN_CLIENT_CERT").ok().map(PathBuf::from),
344 client_key_path:env::var("MOUNTAIN_CLIENT_KEY").ok().map(PathBuf::from),
345 server_name:env::var("MOUNTAIN_SERVER_NAME").ok(),
346 verify_certs:env::var("MOUNTAIN_VERIFY_CERTS")
347 .map(|v| v != "0" && !v.eq_ignore_ascii_case("false"))
348 .unwrap_or(true),
349 })
350 } else {
351 None
352 };
353
354 #[cfg(not(feature = "mtls"))]
355 let tls_config = None;
356
357 Self {
358 address,
359 connection_timeout_secs,
360 request_timeout_secs,
361 #[cfg(feature = "mtls")]
362 tls_config,
363 }
364 }
365
366 pub fn with_connection_timeout(mut self, timeout_secs:u64) -> Self {
374 self.connection_timeout_secs = timeout_secs;
375 self
376 }
377
378 pub fn with_request_timeout(mut self, timeout_secs:u64) -> Self {
386 self.request_timeout_secs = timeout_secs;
387 self
388 }
389
390 #[cfg(feature = "mtls")]
398 pub fn with_tls(mut self, tls_config:TlsConfig) -> Self {
399 self.tls_config = Some(tls_config);
400 self
401 }
402}
403
404#[derive(Debug, Clone)]
410pub struct MountainClient {
411 channel:Channel,
413
414 config:MountainClientConfig,
416}
417
418impl MountainClient {
419 pub async fn connect(config:MountainClientConfig) -> Result<Self, Box<dyn std::error::Error>> {
430 info!("Connecting to Mountain at {}", config.address);
431
432 let endpoint = Endpoint::from_shared(config.address.clone())?
433 .connect_timeout(Duration::from_secs(config.connection_timeout_secs));
434
435 #[cfg(feature = "mtls")]
437 if let Some(tls_config) = &config.tls_config {
438 info!("TLS configuration provided, configuring secure connection");
439
440 let _client_config = create_tls_client_config(tls_config).map_err(|e| {
441 error!("Failed to create TLS client configuration: {}", e);
442 format!("TLS configuration error: {}", e)
443 })?;
444
445 let domain_name = tls_config.server_name.clone().unwrap_or_else(|| "localhost".to_string());
447 info!("Setting server name for SNI: {}", domain_name);
448
449 let tls = tonic::transport::ClientTlsConfig::new().domain_name(domain_name.clone());
451 let channel = endpoint
452 .tcp_keepalive(Some(Duration::from_secs(60)))
453 .tls_config(tls)?
454 .connect()
455 .await
456 .map_err(|e| format!("Failed to connect with TLS: {}", e))?;
457
458 info!("Successfully connected to Mountain at {} with TLS", config.address);
459 return Ok(Self { channel, config });
460 }
461
462 debug!("Using unencrypted connection");
464 let channel = endpoint.connect().await?;
465 info!("Successfully connected to Mountain at {}", config.address);
466
467 Ok(Self { channel, config })
468 }
469
470 pub fn channel(&self) -> &Channel { &self.channel }
475
476 pub fn config(&self) -> &MountainClientConfig { &self.config }
481
482 pub async fn health_check(&self) -> Result<bool, Box<dyn std::error::Error>> {
489 debug!("Checking Mountain health");
490
491 match tokio::time::timeout(Duration::from_secs(self.config.request_timeout_secs), async {
493 Ok::<(), Box<dyn std::error::Error>>(())
496 })
497 .await
498 {
499 Ok(Ok(())) => {
500 debug!("Mountain health check: healthy");
501 Ok(true)
502 },
503 Ok(Err(e)) => {
504 warn!("Mountain health check: disconnected - {}", e);
505 Ok(false)
506 },
507 Err(_) => {
508 warn!("Mountain health check: timeout");
509 Ok(false)
510 },
511 }
512 }
513
514 pub async fn get_status(&self) -> Result<String, Box<dyn std::error::Error>> {
522 debug!("Getting Mountain status");
523
524 Ok("connected".to_string())
527 }
528
529 pub async fn get_config(&self, key:&str) -> Result<Option<String>, Box<dyn std::error::Error>> {
540 debug!("Getting Mountain config: {}", key);
541
542 Ok(None)
545 }
546
547 pub async fn set_config(&self, key:&str, value:&str) -> Result<(), Box<dyn std::error::Error>> {
559 debug!("Setting Mountain config: {} = {}", key, value);
560
561 Ok(())
564 }
565}
566
567pub async fn connect_to_mountain() -> Result<MountainClient, Box<dyn std::error::Error>> {
572 MountainClient::connect(MountainClientConfig::default()).await
573}
574
575pub async fn connect_to_mountain_at(address:impl Into<String>) -> Result<MountainClient, Box<dyn std::error::Error>> {
583 MountainClient::connect(MountainClientConfig::new(address)).await
584}
585
#[cfg(test)]
mod tests {
	use std::sync::{Mutex, MutexGuard};

	use super::*;

	/// Serializes every test in this module that reads or writes `MOUNTAIN_*`
	/// environment variables. The harness runs tests on several threads and
	/// the environment is process-global. Without this lock, one test can
	/// set or remove a variable while another is inside `from_env`.
	static ENV_LOCK:Mutex<()> = Mutex::new(());

	/// Every variable consulted by `MountainClientConfig::from_env`.
	const MOUNTAIN_VARS:&[&str] = &[
		"MOUNTAIN_ADDRESS",
		"MOUNTAIN_CONNECTION_TIMEOUT_SECS",
		"MOUNTAIN_REQUEST_TIMEOUT_SECS",
		"MOUNTAIN_TLS_ENABLED",
		"MOUNTAIN_CA_CERT",
		"MOUNTAIN_CLIENT_CERT",
		"MOUNTAIN_CLIENT_KEY",
		"MOUNTAIN_SERVER_NAME",
		"MOUNTAIN_VERIFY_CERTS",
	];

	/// Takes `ENV_LOCK` and starts from an environment with no `MOUNTAIN_*`
	/// variables. A poisoned lock (an earlier env test panicked) is recovered,
	/// so one failure does not cascade into the others.
	fn lock_clean_env() -> MutexGuard<'static, ()> {
		let guard = ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
		clear_mountain_env();
		guard
	}

	/// Removes every `MOUNTAIN_*` variable. Caller must hold `ENV_LOCK`.
	fn clear_mountain_env() {
		for name in MOUNTAIN_VARS.iter().copied() {
			// SAFETY: only called while ENV_LOCK is held, and every test in this
			// module that touches the environment acquires that lock first.
			unsafe { env::remove_var(name) }
		}
	}

	/// Sets one environment variable. Caller must hold `ENV_LOCK`.
	fn set_env(name:&str, value:&str) {
		// SAFETY: callers hold ENV_LOCK (see `clear_mountain_env`).
		unsafe { env::set_var(name, value) }
	}

	#[test]
	fn test_default_config() {
		let config = MountainClientConfig::default();
		assert_eq!(config.address, DEFAULT_MOUNTAIN_ADDRESS);
		assert_eq!(config.connection_timeout_secs, DEFAULT_CONNECTION_TIMEOUT_SECS);
		assert_eq!(config.request_timeout_secs, DEFAULT_REQUEST_TIMEOUT_SECS);
	}

	#[test]
	fn test_config_builder() {
		let config = MountainClientConfig::new("[::1]:50060")
			.with_connection_timeout(10)
			.with_request_timeout(60);

		assert_eq!(config.address, "[::1]:50060");
		assert_eq!(config.connection_timeout_secs, 10);
		assert_eq!(config.request_timeout_secs, 60);
	}

	#[cfg(feature = "mtls")]
	#[test]
	fn test_tls_config_server_auth() {
		let tls = TlsConfig::server_auth(std::path::PathBuf::from("/path/to/ca.pem"));
		assert_eq!(tls.server_name, Some("localhost".to_string()));
		assert!(tls.client_cert_path.is_none());
		assert!(tls.client_key_path.is_none());
		assert!(tls.ca_cert_path.is_some());
		assert!(tls.verify_certs);
	}

	#[cfg(feature = "mtls")]
	#[test]
	fn test_tls_config_mtls() {
		let tls = TlsConfig::mtls(
			std::path::PathBuf::from("/path/to/ca.pem"),
			std::path::PathBuf::from("/path/to/cert.pem"),
			std::path::PathBuf::from("/path/to/key.pem"),
		);
		assert!(tls.client_cert_path.is_some());
		assert!(tls.client_key_path.is_some());
		assert!(tls.ca_cert_path.is_some());
		assert!(tls.verify_certs);
		assert_eq!(tls.server_name, Some("localhost".to_string()));
	}

	#[cfg(feature = "mtls")]
	#[test]
	fn test_tls_config_default() {
		let tls = TlsConfig::default();
		assert!(tls.ca_cert_path.is_none());
		assert!(tls.client_cert_path.is_none());
		assert!(tls.client_key_path.is_none());
		assert!(tls.server_name.is_none());
		assert!(tls.verify_certs);
	}

	#[test]
	fn test_from_env_default() {
		let _guard = lock_clean_env();

		let config = MountainClientConfig::from_env();
		assert_eq!(config.address, DEFAULT_MOUNTAIN_ADDRESS);
		assert_eq!(config.connection_timeout_secs, DEFAULT_CONNECTION_TIMEOUT_SECS);
		assert_eq!(config.request_timeout_secs, DEFAULT_REQUEST_TIMEOUT_SECS);
		#[cfg(feature = "mtls")]
		assert!(config.tls_config.is_none());
	}

	#[test]
	fn test_from_env_custom() {
		let _guard = lock_clean_env();
		set_env("MOUNTAIN_ADDRESS", "[::1]:50060");
		set_env("MOUNTAIN_CONNECTION_TIMEOUT_SECS", "10");
		set_env("MOUNTAIN_REQUEST_TIMEOUT_SECS", "60");

		let config = MountainClientConfig::from_env();
		assert_eq!(config.address, "[::1]:50060");
		assert_eq!(config.connection_timeout_secs, 10);
		assert_eq!(config.request_timeout_secs, 60);

		clear_mountain_env();
	}

	#[test]
	fn test_from_env_invalid_timeouts_fall_back_to_defaults() {
		let _guard = lock_clean_env();
		set_env("MOUNTAIN_CONNECTION_TIMEOUT_SECS", "soon");
		set_env("MOUNTAIN_REQUEST_TIMEOUT_SECS", "-1");

		let config = MountainClientConfig::from_env();
		assert_eq!(config.connection_timeout_secs, DEFAULT_CONNECTION_TIMEOUT_SECS);
		assert_eq!(config.request_timeout_secs, DEFAULT_REQUEST_TIMEOUT_SECS);

		clear_mountain_env();
	}

	#[cfg(feature = "mtls")]
	#[test]
	fn test_from_env_tls() {
		let _guard = lock_clean_env();
		set_env("MOUNTAIN_TLS_ENABLED", "1");
		set_env("MOUNTAIN_CA_CERT", "/path/to/ca.pem");
		set_env("MOUNTAIN_SERVER_NAME", "mymountain.com");

		let config = MountainClientConfig::from_env();
		assert!(config.tls_config.is_some());
		let tls = config.tls_config.unwrap();
		assert_eq!(tls.ca_cert_path, Some(std::path::PathBuf::from("/path/to/ca.pem")));
		assert_eq!(tls.server_name, Some("mymountain.com".to_string()));
		assert!(tls.verify_certs);

		clear_mountain_env();
	}

	#[cfg(feature = "mtls")]
	#[test]
	fn test_from_env_mtls() {
		let _guard = lock_clean_env();
		set_env("MOUNTAIN_TLS_ENABLED", "true");
		set_env("MOUNTAIN_CA_CERT", "/path/to/ca.pem");
		set_env("MOUNTAIN_CLIENT_CERT", "/path/to/cert.pem");
		set_env("MOUNTAIN_CLIENT_KEY", "/path/to/key.pem");

		let config = MountainClientConfig::from_env();
		assert!(config.tls_config.is_some());
		let tls = config.tls_config.unwrap();
		assert_eq!(tls.ca_cert_path, Some(std::path::PathBuf::from("/path/to/ca.pem")));
		assert_eq!(tls.client_cert_path, Some(std::path::PathBuf::from("/path/to/cert.pem")));
		assert_eq!(tls.client_key_path, Some(std::path::PathBuf::from("/path/to/key.pem")));
		assert!(tls.verify_certs);

		clear_mountain_env();
	}
}