Mountain/IPC/Enhanced/
SecureMessageChannel.rs

//! # Secure Message Channel
//!
//! Security enhancements for IPC messages: AES-256-GCM encryption, HMAC
//! authentication, and key management (random key generation and periodic
//! key rotation).

5
6use std::{
7	collections::HashMap,
8	marker::PhantomData,
9	sync::Arc,
10	time::{Duration, SystemTime},
11};
12
13use log::{debug, error, info, trace, warn};
14use ring::{
15	aead::{self, AES_256_GCM, LessSafeKey, NONCE_LEN, UnboundKey},
16	hmac,
17	rand::{SecureRandom, SystemRandom},
18};
19use serde::{Deserialize, Serialize};
20use tokio::sync::RwLock;
21use bincode::serde::{decode_from_slice, encode_to_vec};
22
23/// Security configuration
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct SecurityConfig {
26	pub encryption_algorithm:String,
27	pub key_rotation_interval_hours:u64,
28	pub hmac_algorithm:String,
29	pub nonce_size_bytes:usize,
30	pub auth_tag_size_bytes:usize,
31	pub max_message_size_bytes:usize,
32}
33
34impl Default for SecurityConfig {
35	fn default() -> Self {
36		Self {
37			encryption_algorithm:"AES-256-GCM".to_string(),
38			// Rotate encryption keys every 24 hours for forward secrecy.
39			key_rotation_interval_hours:24,
40			hmac_algorithm:"HMAC-SHA256".to_string(),
41			nonce_size_bytes:NONCE_LEN,
42			auth_tag_size_bytes:AES_256_GCM.tag_len(),
43			// Maximum message size: 10MB to prevent memory exhaustion attacks.
44			max_message_size_bytes:10 * 1024 * 1024,
45		}
46	}
47}
48
49/// Encryption key with metadata
50#[derive(Debug, Clone)]
51struct EncryptionKey {
52	key:LessSafeKey,
53	created_at:SystemTime,
54	key_id:String,
55	usage_count:usize,
56}
57
58impl EncryptionKey {
59	fn new(key_bytes:&[u8]) -> Result<Self, String> {
60		let unbound_key =
61			UnboundKey::new(&AES_256_GCM, key_bytes).map_err(|e| format!("Failed to create unbound key: {}", e))?;
62
63		Ok(Self {
64			key:LessSafeKey::new(unbound_key),
65			created_at:SystemTime::now(),
66			key_id:Self::generate_key_id(),
67			usage_count:0,
68		})
69	}
70
71	fn generate_key_id() -> String {
72		let rng = SystemRandom::new();
73		let mut id_bytes = [0u8; 8];
74		rng.fill(&mut id_bytes).unwrap();
75		hex::encode(id_bytes)
76	}
77
78	fn is_expired(&self, rotation_interval:Duration) -> bool {
79		self.created_at.elapsed().unwrap_or_default() > rotation_interval
80	}
81
82	fn increment_usage(&mut self) { self.usage_count += 1; }
83}
84
85/// Secure message channel with encryption and authentication
86pub struct SecureMessageChannel {
87	pub config:SecurityConfig,
88	pub current_key:Arc<RwLock<EncryptionKey>>,
89	pub previous_keys:Arc<RwLock<HashMap<String, EncryptionKey>>>,
90	pub hmac_key:Arc<RwLock<Vec<u8>>>,
91	pub rng:SystemRandom,
92	pub key_rotation_task:Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
93}
94
95impl SecureMessageChannel {
96	/// Create a new secure message channel
97	pub fn new(config:SecurityConfig) -> Result<Self, String> {
98		let rng = SystemRandom::new();
99
100		// Generate encryption key
101		let mut encryption_key_bytes = vec![0u8; 32];
102		rng.fill(&mut encryption_key_bytes)
103			.map_err(|e| format!("Failed to generate encryption key: {}", e))?;
104
105		let encryption_key = EncryptionKey::new(&encryption_key_bytes)?;
106
107		// Generate HMAC key
108		let mut hmac_key = vec![0u8; 32];
109		rng.fill(&mut hmac_key)
110			.map_err(|e| format!("Failed to generate HMAC key: {}", e))?;
111
112		let channel = Self {
113			config,
114			current_key:Arc::new(RwLock::new(encryption_key)),
115			previous_keys:Arc::new(RwLock::new(HashMap::new())),
116			hmac_key:Arc::new(RwLock::new(hmac_key)),
117			rng,
118			key_rotation_task:Arc::new(RwLock::new(None)),
119		};
120
121		info!(
122			"[SecureMessageChannel] Created secure channel with {} encryption",
123			channel.config.encryption_algorithm
124		);
125
126		Ok(channel)
127	}
128
129	/// Start the secure channel with automatic key rotation
130	pub async fn start(&self) -> Result<(), String> {
131		// Start key rotation task
132		self.start_key_rotation().await;
133
134		info!("[SecureMessageChannel] Secure channel started");
135		Ok(())
136	}
137
138	/// Stop the secure channel
139	pub async fn stop(&self) -> Result<(), String> {
140		// Stop key rotation task
141		{
142			let mut rotation_task = self.key_rotation_task.write().await;
143			if let Some(task) = rotation_task.take() {
144				task.abort();
145			}
146		}
147
148		// Clear all cryptographic keys from memory.
149		{
150			let mut current_key = self.current_key.write().await;
151			// Replace with a zeroized key to overwrite sensitive material.
152			*current_key = EncryptionKey::new(&[0u8; 32]).unwrap();
153		}
154
155		{
156			let mut previous_keys = self.previous_keys.write().await;
157			previous_keys.clear();
158		}
159
160		{
161			let mut hmac_key = self.hmac_key.write().await;
162			// Zero out the HMAC key material to prevent leakage.
163			hmac_key.fill(0);
164		}
165
166		info!("[SecureMessageChannel] Secure channel stopped");
167		Ok(())
168	}
169
170	/// Encrypt and authenticate a message
171	pub async fn encrypt_message<T:Serialize>(&self, message:&T) -> Result<EncryptedMessage, String> {
172		// Serialize message
173		let serialized_data = encode_to_vec(message, bincode::config::standard())
174			.map_err(|e| format!("Failed to serialize message: {}", e))?;
175
176		// Check message size
177		if serialized_data.len() > self.config.max_message_size_bytes {
178			return Err(format!("Message too large: {} bytes", serialized_data.len()));
179		}
180
181		// Get current encryption key
182		let mut current_key = self.current_key.write().await;
183		current_key.increment_usage();
184
185		// Generate nonce
186		let mut nonce = vec![0u8; self.config.nonce_size_bytes];
187		self.rng
188			.fill(&mut nonce)
189			.map_err(|e| format!("Failed to generate nonce: {}", e))?;
190
191		// Encrypt message
192		let mut in_out = serialized_data.clone();
193		let nonce_slice:&[u8] = &nonce;
194		let nonce_array:[u8; NONCE_LEN] = nonce_slice.try_into().map_err(|_| "Invalid nonce length".to_string())?;
195
196		let aead_nonce = aead::Nonce::assume_unique_for_key(nonce_array);
197
198		current_key
199			.key
200			.seal_in_place_append_tag(aead_nonce, aead::Aad::empty(), &mut in_out)
201			.map_err(|e| format!("Encryption failed: {}", e))?;
202
203		// Create HMAC
204		let hmac_key = self.hmac_key.read().await;
205		let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, &hmac_key);
206		let hmac_tag = hmac::sign(&hmac_key, &in_out);
207
208		let encrypted_message = EncryptedMessage {
209			key_id:current_key.key_id.clone(),
210			nonce:nonce.to_vec(),
211			ciphertext:in_out,
212			hmac_tag:hmac_tag.as_ref().to_vec(),
213			timestamp:SystemTime::now()
214				.duration_since(SystemTime::UNIX_EPOCH)
215				.unwrap_or_default()
216				.as_millis() as u64,
217		};
218
219		trace!(
220			"[SecureMessageChannel] Message encrypted (size: {} bytes)",
221			encrypted_message.ciphertext.len()
222		);
223
224		Ok(encrypted_message)
225	}
226
227	/// Decrypt and verify a message
228	pub async fn decrypt_message<T:for<'de> Deserialize<'de>>(&self, encrypted:&EncryptedMessage) -> Result<T, String> {
229		// Verify HMAC
230		let hmac_key = self.hmac_key.read().await;
231		let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, &hmac_key);
232
233		hmac::verify(&hmac_key, &encrypted.ciphertext, &encrypted.hmac_tag)
234			.map_err(|_| "HMAC verification failed".to_string())?;
235
236		// Get encryption key
237		let encryption_key = self.get_encryption_key(&encrypted.key_id).await?;
238
239		// Decrypt message
240		let mut in_out = encrypted.ciphertext.clone();
241		let nonce_slice:&[u8] = &encrypted.nonce;
242		let nonce_array:[u8; NONCE_LEN] = nonce_slice.try_into().map_err(|_| "Invalid nonce length".to_string())?;
243
244		let nonce = aead::Nonce::assume_unique_for_key(nonce_array);
245
246		encryption_key
247			.key
248			.open_in_place(nonce, aead::Aad::empty(), &mut in_out)
249			.map_err(|e| format!("Decryption failed: {}", e))?;
250
251		// Remove authentication tag
252		let plaintext_len = in_out.len() - AES_256_GCM.tag_len();
253		in_out.truncate(plaintext_len);
254
255		// Deserialize message
256		let (message, _) = decode_from_slice(&in_out, bincode::config::standard())
257			.map_err(|e| format!("Failed to deserialize message: {}", e))?;
258
259		trace!("[SecureMessageChannel] Message decrypted successfully");
260
261		Ok(message)
262	}
263
264	/// Rotate encryption keys
265	pub async fn rotate_keys(&self) -> Result<(), String> {
266		info!("[SecureMessageChannel] Rotating encryption keys");
267
268		// Generate new encryption key
269		let mut new_key_bytes = vec![0u8; 32];
270		self.rng
271			.fill(&mut new_key_bytes)
272			.map_err(|e| format!("Failed to generate new encryption key: {}", e))?;
273
274		let new_key = EncryptionKey::new(&new_key_bytes)?;
275
276		// Move current key to previous keys
277		{
278			let mut current_key = self.current_key.write().await;
279			let mut previous_keys = self.previous_keys.write().await;
280
281			previous_keys.insert(current_key.key_id.clone(), current_key.clone());
282			*current_key = new_key;
283		}
284
285		// Clean up old keys
286		self.cleanup_old_keys().await;
287
288		debug!("[SecureMessageChannel] Key rotation completed");
289		Ok(())
290	}
291
292	/// Get encryption key by ID
293	async fn get_encryption_key(&self, key_id:&str) -> Result<EncryptionKey, String> {
294		// Check current key first
295		let current_key = self.current_key.read().await;
296		if current_key.key_id == key_id {
297			return Ok(current_key.clone());
298		}
299
300		// Check previous keys
301		let previous_keys = self.previous_keys.read().await;
302		if let Some(key) = previous_keys.get(key_id) {
303			return Ok(key.clone());
304		}
305
306		Err(format!("Encryption key not found: {}", key_id))
307	}
308
309	/// Start automatic key rotation
310	async fn start_key_rotation(&self) {
311		let channel = Arc::new(self.clone());
312
313		let rotation_interval = Duration::from_secs(self.config.key_rotation_interval_hours * 3600);
314
315		let task = tokio::spawn(async move {
316			let mut interval = tokio::time::interval(rotation_interval);
317
318			loop {
319				interval.tick().await;
320
321				if let Err(e) = channel.rotate_keys().await {
322					error!("[SecureMessageChannel] Automatic key rotation failed: {}", e);
323				}
324			}
325		});
326
327		{
328			let mut rotation_task = self.key_rotation_task.write().await;
329			*rotation_task = Some(task);
330		}
331	}
332
333	/// Cleanup old keys
334	async fn cleanup_old_keys(&self) {
335		let rotation_interval = Duration::from_secs(self.config.key_rotation_interval_hours * 3600);
336		// Keep previous keys for 2 rotation cycles to support key rollover.
337		let max_age = rotation_interval * 2;
338
339		let mut previous_keys = self.previous_keys.write().await;
340
341		previous_keys.retain(|_, key| !key.is_expired(max_age));
342
343		debug!("[SecureMessageChannel] Cleaned up {} old keys", previous_keys.len());
344	}
345
346	/// Get security statistics
347	pub async fn get_stats(&self) -> SecurityStats {
348		let current_key = self.current_key.read().await;
349		let previous_keys = self.previous_keys.read().await;
350
351		SecurityStats {
352			current_key_id:current_key.key_id.clone(),
353			current_key_age_seconds:current_key.created_at.elapsed().unwrap_or_default().as_secs(),
354			current_key_usage_count:current_key.usage_count,
355			previous_keys_count:previous_keys.len(),
356			config:self.config.clone(),
357		}
358	}
359
360	/// Validate message integrity
361	pub async fn validate_message_integrity(&self, encrypted:&EncryptedMessage) -> Result<bool, String> {
362		// Check timestamp (prevent replay attacks)
363		let message_time = SystemTime::UNIX_EPOCH + Duration::from_millis(encrypted.timestamp);
364		let current_time = SystemTime::now();
365
366		if current_time.duration_since(message_time).unwrap_or_default() > Duration::from_secs(300) {
367			// Message is older than 5 minutes
368			return Ok(false);
369		}
370
371		// Verify HMAC
372		let hmac_key = self.hmac_key.read().await;
373		let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, &hmac_key);
374
375		match hmac::verify(&hmac_key, &encrypted.ciphertext, &encrypted.hmac_tag) {
376			Ok(_) => Ok(true),
377			Err(_) => Ok(false),
378		}
379	}
380
381	/// Create a secure channel with default configuration
382	pub fn default_channel() -> Result<Self, String> { Self::new(SecurityConfig::default()) }
383
384	/// Create a high-security channel
385	pub fn high_security_channel() -> Result<Self, String> {
386		Self::new(SecurityConfig {
387			// Rotate keys hourly for maximum security.
388			key_rotation_interval_hours:1,
389			// Smaller message size limit: 1MB for stricter controls.
390			max_message_size_bytes:1 * 1024 * 1024,
391			..Default::default()
392		})
393	}
394}
395
396impl Clone for SecureMessageChannel {
397	fn clone(&self) -> Self {
398		Self {
399			config:self.config.clone(),
400			current_key:self.current_key.clone(),
401			previous_keys:self.previous_keys.clone(),
402			hmac_key:self.hmac_key.clone(),
403			rng:SystemRandom::new(),
404			key_rotation_task:Arc::new(RwLock::new(None)),
405		}
406	}
407}
408
409/// Encrypted message structure
410#[derive(Debug, Clone, Serialize, Deserialize)]
411pub struct EncryptedMessage {
412	pub key_id:String,
413	pub nonce:Vec<u8>,
414	pub ciphertext:Vec<u8>,
415	pub hmac_tag:Vec<u8>,
416	pub timestamp:u64,
417}
418
419/// Security statistics
420#[derive(Debug, Clone, Serialize, Deserialize)]
421pub struct SecurityStats {
422	pub current_key_id:String,
423	pub current_key_age_seconds:u64,
424	pub current_key_usage_count:usize,
425	pub previous_keys_count:usize,
426	pub config:SecurityConfig,
427}
428
429/// Utility functions for secure messaging
430impl SecureMessageChannel {
431	/// Generate a secure random key
432	pub fn generate_secure_key(key_size_bytes:usize) -> Result<Vec<u8>, String> {
433		let rng = SystemRandom::new();
434		let mut key = vec![0u8; key_size_bytes];
435
436		rng.fill(&mut key)
437			.map_err(|e| format!("Failed to generate secure key: {}", e))?;
438
439		Ok(key)
440	}
441
442	/// Calculate message overhead for encryption
443	pub fn calculate_encryption_overhead(message_size:usize) -> usize {
444		// Nonce + HMAC tag + encryption overhead + additional padding.
445		NONCE_LEN + AES_256_GCM.tag_len() + 16
446	}
447
448	/// Estimate encrypted message size
449	pub fn estimate_encrypted_size(original_size:usize) -> usize {
450		original_size + Self::calculate_encryption_overhead(original_size)
451	}
452
453	/// Create message with secure headers
454	pub async fn create_secure_message<T:Serialize>(
455		&self,
456		message:&T,
457		additional_headers:HashMap<String, String>,
458	) -> Result<SecureMessage<T>, String> {
459		let encrypted = self.encrypt_message(message).await?;
460
461		Ok(SecureMessage::<T> {
462			encrypted,
463			headers:additional_headers,
464			version:"1.0".to_string(),
465			_marker:PhantomData,
466		})
467	}
468}
469
470/// Secure message with headers
471#[derive(Debug, Clone, Serialize, Deserialize)]
472pub struct SecureMessage<T> {
473	pub encrypted:EncryptedMessage,
474	pub headers:HashMap<String, String>,
475	pub version:String,
476	#[serde(skip)]
477	_marker:PhantomData<T>,
478}
479
#[cfg(test)]
mod tests {
	use super::*;

	#[tokio::test]
	async fn test_secure_channel_creation() {
		let channel = SecureMessageChannel::default_channel().unwrap();
		assert_eq!(channel.config.encryption_algorithm, "AES-256-GCM");
	}

	#[test]
	fn test_high_security_channel_config() {
		let channel = SecureMessageChannel::high_security_channel().unwrap();
		assert_eq!(channel.config.key_rotation_interval_hours, 1);
		assert_eq!(channel.config.max_message_size_bytes, 1024 * 1024);
	}

	#[tokio::test]
	async fn test_message_encryption_decryption() {
		let channel = SecureMessageChannel::default_channel().unwrap();
		channel.start().await.unwrap();

		let test_message = "Hello, secure world!";
		let encrypted = channel.encrypt_message(&test_message).await.unwrap();

		assert!(!encrypted.ciphertext.is_empty());
		assert!(!encrypted.hmac_tag.is_empty());
		assert!(!encrypted.nonce.is_empty());

		let decrypted:String = channel.decrypt_message(&encrypted).await.unwrap();
		assert_eq!(decrypted, test_message);

		channel.stop().await.unwrap();
	}

	#[tokio::test]
	async fn test_message_validation() {
		let channel = SecureMessageChannel::default_channel().unwrap();
		channel.start().await.unwrap();

		let test_message = "Test validation";
		let encrypted = channel.encrypt_message(&test_message).await.unwrap();

		let is_valid = channel.validate_message_integrity(&encrypted).await.unwrap();
		assert!(is_valid);

		channel.stop().await.unwrap();
	}

	#[tokio::test]
	async fn test_key_rotation() {
		// The background rotation task is intentionally not started. It runs on
		// its own schedule and could change `previous_keys_count` between the
		// assertions below.
		let channel = SecureMessageChannel::default_channel().unwrap();

		let stats_before = channel.get_stats().await;

		channel.rotate_keys().await.unwrap();

		let stats_after = channel.get_stats().await;
		assert_ne!(stats_before.current_key_id, stats_after.current_key_id);
		assert_eq!(stats_after.previous_keys_count, 1);
	}

	#[tokio::test]
	async fn test_decrypt_after_rotation_uses_previous_key() {
		let channel = SecureMessageChannel::default_channel().unwrap();
		let encrypted = channel.encrypt_message(&"before rotation").await.unwrap();

		channel.rotate_keys().await.unwrap();

		let decrypted:String = channel.decrypt_message(&encrypted).await.unwrap();
		assert_eq!(decrypted, "before rotation");
	}

	#[tokio::test]
	async fn test_tampered_ciphertext_is_rejected() {
		let channel = SecureMessageChannel::default_channel().unwrap();
		let mut encrypted = channel.encrypt_message(&"integrity").await.unwrap();

		encrypted.ciphertext[0] ^= 0x01;

		assert!(channel.decrypt_message::<String>(&encrypted).await.is_err());
		assert!(!channel.validate_message_integrity(&encrypted).await.unwrap());
	}

	#[tokio::test]
	async fn test_unknown_key_id_is_rejected() {
		let channel = SecureMessageChannel::default_channel().unwrap();
		let mut encrypted = channel.encrypt_message(&"wrong key").await.unwrap();

		encrypted.key_id = "0000000000000000".to_string();

		assert!(channel.decrypt_message::<String>(&encrypted).await.is_err());
	}

	#[tokio::test]
	async fn test_stale_message_fails_validation() {
		let channel = SecureMessageChannel::default_channel().unwrap();
		let mut encrypted = channel.encrypt_message(&"old news").await.unwrap();

		encrypted.timestamp = 0;

		assert!(!channel.validate_message_integrity(&encrypted).await.unwrap());
	}

	#[tokio::test]
	async fn test_stop_invalidates_existing_messages() {
		let channel = SecureMessageChannel::default_channel().unwrap();
		channel.start().await.unwrap();

		let encrypted = channel.encrypt_message(&"pre-stop").await.unwrap();
		channel.stop().await.unwrap();

		assert!(channel.decrypt_message::<String>(&encrypted).await.is_err());
	}

	#[tokio::test]
	async fn test_oversized_message_is_rejected() {
		let channel = SecureMessageChannel::new(SecurityConfig {
			max_message_size_bytes:16,
			..Default::default()
		})
		.unwrap();

		let big_message = "x".repeat(64);
		assert!(channel.encrypt_message(&big_message).await.is_err());
	}

	#[test]
	fn test_secure_key_generation() {
		let key = SecureMessageChannel::generate_secure_key(32).unwrap();
		assert_eq!(key.len(), 32);
	}

	#[test]
	fn test_encryption_overhead_calculation() {
		let overhead = SecureMessageChannel::calculate_encryption_overhead(100);
		assert!(overhead > 0);

		let estimated_size = SecureMessageChannel::estimate_encrypted_size(100);
		assert!(estimated_size > 100);
	}
}