1use std::{
7 collections::HashMap,
8 marker::PhantomData,
9 sync::Arc,
10 time::{Duration, SystemTime},
11};
12
13use log::{debug, error, info, trace, warn};
14use ring::{
15 aead::{self, AES_256_GCM, LessSafeKey, NONCE_LEN, UnboundKey},
16 hmac,
17 rand::{SecureRandom, SystemRandom},
18};
19use serde::{Deserialize, Serialize};
20use tokio::sync::RwLock;
21use bincode::serde::{decode_from_slice, encode_to_vec};
22
/// Tunable parameters of a [`SecureMessageChannel`].
///
/// NOTE(review): field order is part of the serialized form for positional serde formats such
/// as bincode — do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityConfig {
    /// Human-readable cipher name (default `"AES-256-GCM"`). The cipher is always AES-256-GCM;
    /// outside of tests this module only reads the field for logging.
    pub encryption_algorithm: String,
    /// Period, in hours, of the automatic key rotation started by `SecureMessageChannel::start`.
    /// At each rotation, retired keys older than twice this period (measured from key creation)
    /// are discarded.
    pub key_rotation_interval_hours: u64,
    /// Human-readable MAC name (default `"HMAC-SHA256"`); HMAC-SHA256 is always used.
    pub hmac_algorithm: String,
    /// Nonce length in bytes (default `NONCE_LEN`). AES-256-GCM only accepts `NONCE_LEN`-byte
    /// nonces, so this should not be changed.
    pub nonce_size_bytes: usize,
    /// Authentication-tag length in bytes (default `AES_256_GCM.tag_len()`); informational.
    pub auth_tag_size_bytes: usize,
    /// Largest accepted serialized plaintext, in bytes (default 10 MiB).
    pub max_message_size_bytes: usize,
}
33
34impl Default for SecurityConfig {
35 fn default() -> Self {
36 Self {
37 encryption_algorithm:"AES-256-GCM".to_string(),
38 key_rotation_interval_hours:24,
40 hmac_algorithm:"HMAC-SHA256".to_string(),
41 nonce_size_bytes:NONCE_LEN,
42 auth_tag_size_bytes:AES_256_GCM.tag_len(),
43 max_message_size_bytes:10 * 1024 * 1024,
45 }
46 }
47}
48
/// One AES-256-GCM key together with the bookkeeping the channel keeps for it.
#[derive(Debug, Clone)]
struct EncryptionKey {
    /// The AES-256-GCM key used to seal and open messages.
    key: LessSafeKey,
    /// Wall-clock creation time; age-based expiry is measured from here.
    created_at: SystemTime,
    /// Random identifier (16 hex characters) recorded in every `EncryptedMessage` it seals.
    key_id: String,
    /// Number of `encrypt_message` calls that used this key.
    usage_count: usize,
}
57
58impl EncryptionKey {
59 fn new(key_bytes:&[u8]) -> Result<Self, String> {
60 let unbound_key =
61 UnboundKey::new(&AES_256_GCM, key_bytes).map_err(|e| format!("Failed to create unbound key: {}", e))?;
62
63 Ok(Self {
64 key:LessSafeKey::new(unbound_key),
65 created_at:SystemTime::now(),
66 key_id:Self::generate_key_id(),
67 usage_count:0,
68 })
69 }
70
71 fn generate_key_id() -> String {
72 let rng = SystemRandom::new();
73 let mut id_bytes = [0u8; 8];
74 rng.fill(&mut id_bytes).unwrap();
75 hex::encode(id_bytes)
76 }
77
78 fn is_expired(&self, rotation_interval:Duration) -> bool {
79 self.created_at.elapsed().unwrap_or_default() > rotation_interval
80 }
81
82 fn increment_usage(&mut self) { self.usage_count += 1; }
83}
84
/// Symmetric encrypt-then-MAC channel: AES-256-GCM encryption followed by an HMAC-SHA256 tag,
/// with periodic key rotation.
///
/// All key state lives behind `Arc<RwLock<_>>`; cloning the channel shares that state.
///
/// NOTE(review): every field is `pub`, so any code holding a channel can read or replace raw key
/// material directly; consider narrowing the visibility.
pub struct SecureMessageChannel {
    /// Parameters the channel was created with.
    pub config: SecurityConfig,
    /// Key used for new encryptions.
    pub current_key: Arc<RwLock<EncryptionKey>>,
    /// Retired keys by key id; still accepted for decryption until cleaned up.
    pub previous_keys: Arc<RwLock<HashMap<String, EncryptionKey>>>,
    /// Raw HMAC-SHA256 key bytes (32 bytes when generated by `new`).
    pub hmac_key: Arc<RwLock<Vec<u8>>>,
    /// System CSPRNG used for keys and nonces.
    pub rng: SystemRandom,
    /// Handle of the background rotation task spawned by `start`, if any.
    pub key_rotation_task: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
}
94
95impl SecureMessageChannel {
96 pub fn new(config:SecurityConfig) -> Result<Self, String> {
98 let rng = SystemRandom::new();
99
100 let mut encryption_key_bytes = vec![0u8; 32];
102 rng.fill(&mut encryption_key_bytes)
103 .map_err(|e| format!("Failed to generate encryption key: {}", e))?;
104
105 let encryption_key = EncryptionKey::new(&encryption_key_bytes)?;
106
107 let mut hmac_key = vec![0u8; 32];
109 rng.fill(&mut hmac_key)
110 .map_err(|e| format!("Failed to generate HMAC key: {}", e))?;
111
112 let channel = Self {
113 config,
114 current_key:Arc::new(RwLock::new(encryption_key)),
115 previous_keys:Arc::new(RwLock::new(HashMap::new())),
116 hmac_key:Arc::new(RwLock::new(hmac_key)),
117 rng,
118 key_rotation_task:Arc::new(RwLock::new(None)),
119 };
120
121 info!(
122 "[SecureMessageChannel] Created secure channel with {} encryption",
123 channel.config.encryption_algorithm
124 );
125
126 Ok(channel)
127 }
128
129 pub async fn start(&self) -> Result<(), String> {
131 self.start_key_rotation().await;
133
134 info!("[SecureMessageChannel] Secure channel started");
135 Ok(())
136 }
137
138 pub async fn stop(&self) -> Result<(), String> {
140 {
142 let mut rotation_task = self.key_rotation_task.write().await;
143 if let Some(task) = rotation_task.take() {
144 task.abort();
145 }
146 }
147
148 {
150 let mut current_key = self.current_key.write().await;
151 *current_key = EncryptionKey::new(&[0u8; 32]).unwrap();
153 }
154
155 {
156 let mut previous_keys = self.previous_keys.write().await;
157 previous_keys.clear();
158 }
159
160 {
161 let mut hmac_key = self.hmac_key.write().await;
162 hmac_key.fill(0);
164 }
165
166 info!("[SecureMessageChannel] Secure channel stopped");
167 Ok(())
168 }
169
170 pub async fn encrypt_message<T:Serialize>(&self, message:&T) -> Result<EncryptedMessage, String> {
172 let serialized_data = encode_to_vec(message, bincode::config::standard())
174 .map_err(|e| format!("Failed to serialize message: {}", e))?;
175
176 if serialized_data.len() > self.config.max_message_size_bytes {
178 return Err(format!("Message too large: {} bytes", serialized_data.len()));
179 }
180
181 let mut current_key = self.current_key.write().await;
183 current_key.increment_usage();
184
185 let mut nonce = vec![0u8; self.config.nonce_size_bytes];
187 self.rng
188 .fill(&mut nonce)
189 .map_err(|e| format!("Failed to generate nonce: {}", e))?;
190
191 let mut in_out = serialized_data.clone();
193 let nonce_slice:&[u8] = &nonce;
194 let nonce_array:[u8; NONCE_LEN] = nonce_slice.try_into().map_err(|_| "Invalid nonce length".to_string())?;
195
196 let aead_nonce = aead::Nonce::assume_unique_for_key(nonce_array);
197
198 current_key
199 .key
200 .seal_in_place_append_tag(aead_nonce, aead::Aad::empty(), &mut in_out)
201 .map_err(|e| format!("Encryption failed: {}", e))?;
202
203 let hmac_key = self.hmac_key.read().await;
205 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, &hmac_key);
206 let hmac_tag = hmac::sign(&hmac_key, &in_out);
207
208 let encrypted_message = EncryptedMessage {
209 key_id:current_key.key_id.clone(),
210 nonce:nonce.to_vec(),
211 ciphertext:in_out,
212 hmac_tag:hmac_tag.as_ref().to_vec(),
213 timestamp:SystemTime::now()
214 .duration_since(SystemTime::UNIX_EPOCH)
215 .unwrap_or_default()
216 .as_millis() as u64,
217 };
218
219 trace!(
220 "[SecureMessageChannel] Message encrypted (size: {} bytes)",
221 encrypted_message.ciphertext.len()
222 );
223
224 Ok(encrypted_message)
225 }
226
227 pub async fn decrypt_message<T:for<'de> Deserialize<'de>>(&self, encrypted:&EncryptedMessage) -> Result<T, String> {
229 let hmac_key = self.hmac_key.read().await;
231 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, &hmac_key);
232
233 hmac::verify(&hmac_key, &encrypted.ciphertext, &encrypted.hmac_tag)
234 .map_err(|_| "HMAC verification failed".to_string())?;
235
236 let encryption_key = self.get_encryption_key(&encrypted.key_id).await?;
238
239 let mut in_out = encrypted.ciphertext.clone();
241 let nonce_slice:&[u8] = &encrypted.nonce;
242 let nonce_array:[u8; NONCE_LEN] = nonce_slice.try_into().map_err(|_| "Invalid nonce length".to_string())?;
243
244 let nonce = aead::Nonce::assume_unique_for_key(nonce_array);
245
246 encryption_key
247 .key
248 .open_in_place(nonce, aead::Aad::empty(), &mut in_out)
249 .map_err(|e| format!("Decryption failed: {}", e))?;
250
251 let plaintext_len = in_out.len() - AES_256_GCM.tag_len();
253 in_out.truncate(plaintext_len);
254
255 let (message, _) = decode_from_slice(&in_out, bincode::config::standard())
257 .map_err(|e| format!("Failed to deserialize message: {}", e))?;
258
259 trace!("[SecureMessageChannel] Message decrypted successfully");
260
261 Ok(message)
262 }
263
264 pub async fn rotate_keys(&self) -> Result<(), String> {
266 info!("[SecureMessageChannel] Rotating encryption keys");
267
268 let mut new_key_bytes = vec![0u8; 32];
270 self.rng
271 .fill(&mut new_key_bytes)
272 .map_err(|e| format!("Failed to generate new encryption key: {}", e))?;
273
274 let new_key = EncryptionKey::new(&new_key_bytes)?;
275
276 {
278 let mut current_key = self.current_key.write().await;
279 let mut previous_keys = self.previous_keys.write().await;
280
281 previous_keys.insert(current_key.key_id.clone(), current_key.clone());
282 *current_key = new_key;
283 }
284
285 self.cleanup_old_keys().await;
287
288 debug!("[SecureMessageChannel] Key rotation completed");
289 Ok(())
290 }
291
292 async fn get_encryption_key(&self, key_id:&str) -> Result<EncryptionKey, String> {
294 let current_key = self.current_key.read().await;
296 if current_key.key_id == key_id {
297 return Ok(current_key.clone());
298 }
299
300 let previous_keys = self.previous_keys.read().await;
302 if let Some(key) = previous_keys.get(key_id) {
303 return Ok(key.clone());
304 }
305
306 Err(format!("Encryption key not found: {}", key_id))
307 }
308
309 async fn start_key_rotation(&self) {
311 let channel = Arc::new(self.clone());
312
313 let rotation_interval = Duration::from_secs(self.config.key_rotation_interval_hours * 3600);
314
315 let task = tokio::spawn(async move {
316 let mut interval = tokio::time::interval(rotation_interval);
317
318 loop {
319 interval.tick().await;
320
321 if let Err(e) = channel.rotate_keys().await {
322 error!("[SecureMessageChannel] Automatic key rotation failed: {}", e);
323 }
324 }
325 });
326
327 {
328 let mut rotation_task = self.key_rotation_task.write().await;
329 *rotation_task = Some(task);
330 }
331 }
332
333 async fn cleanup_old_keys(&self) {
335 let rotation_interval = Duration::from_secs(self.config.key_rotation_interval_hours * 3600);
336 let max_age = rotation_interval * 2;
338
339 let mut previous_keys = self.previous_keys.write().await;
340
341 previous_keys.retain(|_, key| !key.is_expired(max_age));
342
343 debug!("[SecureMessageChannel] Cleaned up {} old keys", previous_keys.len());
344 }
345
346 pub async fn get_stats(&self) -> SecurityStats {
348 let current_key = self.current_key.read().await;
349 let previous_keys = self.previous_keys.read().await;
350
351 SecurityStats {
352 current_key_id:current_key.key_id.clone(),
353 current_key_age_seconds:current_key.created_at.elapsed().unwrap_or_default().as_secs(),
354 current_key_usage_count:current_key.usage_count,
355 previous_keys_count:previous_keys.len(),
356 config:self.config.clone(),
357 }
358 }
359
360 pub async fn validate_message_integrity(&self, encrypted:&EncryptedMessage) -> Result<bool, String> {
362 let message_time = SystemTime::UNIX_EPOCH + Duration::from_millis(encrypted.timestamp);
364 let current_time = SystemTime::now();
365
366 if current_time.duration_since(message_time).unwrap_or_default() > Duration::from_secs(300) {
367 return Ok(false);
369 }
370
371 let hmac_key = self.hmac_key.read().await;
373 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, &hmac_key);
374
375 match hmac::verify(&hmac_key, &encrypted.ciphertext, &encrypted.hmac_tag) {
376 Ok(_) => Ok(true),
377 Err(_) => Ok(false),
378 }
379 }
380
381 pub fn default_channel() -> Result<Self, String> { Self::new(SecurityConfig::default()) }
383
384 pub fn high_security_channel() -> Result<Self, String> {
386 Self::new(SecurityConfig {
387 key_rotation_interval_hours:1,
389 max_message_size_bytes:1 * 1024 * 1024,
391 ..Default::default()
392 })
393 }
394}
395
396impl Clone for SecureMessageChannel {
397 fn clone(&self) -> Self {
398 Self {
399 config:self.config.clone(),
400 current_key:self.current_key.clone(),
401 previous_keys:self.previous_keys.clone(),
402 hmac_key:self.hmac_key.clone(),
403 rng:SystemRandom::new(),
404 key_rotation_task:Arc::new(RwLock::new(None)),
405 }
406 }
407}
408
/// Output of `SecureMessageChannel::encrypt_message`, and the input to decryption/validation.
///
/// NOTE(review): field order is part of the serialized form for positional serde formats such
/// as bincode — do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptedMessage {
    /// Id of the key that sealed this message; looked up among current and retired keys.
    pub key_id: String,
    /// AES-GCM nonce; must be exactly `NONCE_LEN` bytes to decrypt.
    pub nonce: Vec<u8>,
    /// AES-256-GCM ciphertext with the authentication tag appended.
    pub ciphertext: Vec<u8>,
    /// HMAC-SHA256 tag computed by `encrypt_message`; see that method for the bytes it covers.
    pub hmac_tag: Vec<u8>,
    /// Encryption time in milliseconds since the UNIX epoch.
    pub timestamp: u64,
}
418
/// Snapshot returned by `SecureMessageChannel::get_stats`.
///
/// NOTE(review): field order is part of the serialized form for positional serde formats such
/// as bincode — do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityStats {
    /// Id of the key currently used for encryption.
    pub current_key_id: String,
    /// Whole seconds since the current key was created (0 if the clock went backwards).
    pub current_key_age_seconds: u64,
    /// Number of encryptions performed with the current key.
    pub current_key_usage_count: usize,
    /// Number of retired keys still kept for decryption.
    pub previous_keys_count: usize,
    /// Copy of the channel's configuration.
    pub config: SecurityConfig,
}
428
429impl SecureMessageChannel {
431 pub fn generate_secure_key(key_size_bytes:usize) -> Result<Vec<u8>, String> {
433 let rng = SystemRandom::new();
434 let mut key = vec![0u8; key_size_bytes];
435
436 rng.fill(&mut key)
437 .map_err(|e| format!("Failed to generate secure key: {}", e))?;
438
439 Ok(key)
440 }
441
442 pub fn calculate_encryption_overhead(message_size:usize) -> usize {
444 NONCE_LEN + AES_256_GCM.tag_len() + 16
446 }
447
448 pub fn estimate_encrypted_size(original_size:usize) -> usize {
450 original_size + Self::calculate_encryption_overhead(original_size)
451 }
452
453 pub async fn create_secure_message<T:Serialize>(
455 &self,
456 message:&T,
457 additional_headers:HashMap<String, String>,
458 ) -> Result<SecureMessage<T>, String> {
459 let encrypted = self.encrypt_message(message).await?;
460
461 Ok(SecureMessage::<T> {
462 encrypted,
463 headers:additional_headers,
464 version:"1.0".to_string(),
465 _marker:PhantomData,
466 })
467 }
468}
469
/// An [`EncryptedMessage`] plus plaintext metadata, tagged with the payload type `T`.
///
/// NOTE(review): `headers` and `version` are neither encrypted nor covered by the message HMAC,
/// so treat them as untrusted on receipt. Field order is part of the serialized form for
/// positional serde formats such as bincode — do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecureMessage<T> {
    /// The encrypted, authenticated payload.
    pub encrypted: EncryptedMessage,
    /// Caller-supplied headers, stored as given.
    pub headers: HashMap<String, String>,
    /// Envelope format version (`"1.0"` when built by `create_secure_message`).
    pub version: String,
    /// Records the payload type without storing a value; skipped by serde.
    #[serde(skip)]
    _marker: PhantomData<T>,
}
479
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_secure_channel_creation() {
        let channel = SecureMessageChannel::default_channel().unwrap();
        assert_eq!(channel.config.encryption_algorithm, "AES-256-GCM");
    }

    #[tokio::test]
    async fn test_message_encryption_decryption() {
        let channel = SecureMessageChannel::default_channel().unwrap();
        channel.start().await.unwrap();

        let test_message = "Hello, secure world!";
        let encrypted = channel.encrypt_message(&test_message).await.unwrap();

        assert!(!encrypted.ciphertext.is_empty());
        assert!(!encrypted.hmac_tag.is_empty());
        assert!(!encrypted.nonce.is_empty());

        let decrypted: String = channel.decrypt_message(&encrypted).await.unwrap();
        assert_eq!(decrypted, test_message);

        channel.stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_message_validation() {
        let channel = SecureMessageChannel::default_channel().unwrap();
        channel.start().await.unwrap();

        let test_message = "Test validation";
        let encrypted = channel.encrypt_message(&test_message).await.unwrap();

        let is_valid = channel.validate_message_integrity(&encrypted).await.unwrap();
        assert!(is_valid);

        channel.stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_key_rotation() {
        // Deliberately not started: the background rotation task would race this manual
        // rotation and make `previous_keys_count` nondeterministic.
        let channel = SecureMessageChannel::default_channel().unwrap();

        let stats_before = channel.get_stats().await;

        channel.rotate_keys().await.unwrap();

        let stats_after = channel.get_stats().await;
        assert_ne!(stats_before.current_key_id, stats_after.current_key_id);
        assert_eq!(stats_after.previous_keys_count, 1);
    }

    #[tokio::test]
    async fn test_decrypt_with_retired_key_after_rotation() {
        let channel = SecureMessageChannel::default_channel().unwrap();

        let encrypted = channel.encrypt_message(&"before rotation").await.unwrap();
        channel.rotate_keys().await.unwrap();

        let decrypted: String = channel.decrypt_message(&encrypted).await.unwrap();
        assert_eq!(decrypted, "before rotation");
    }

    #[tokio::test]
    async fn test_tampered_ciphertext_is_rejected() {
        let channel = SecureMessageChannel::default_channel().unwrap();

        let mut encrypted = channel.encrypt_message(&"payload").await.unwrap();
        encrypted.ciphertext[0] ^= 0x01;

        assert!(channel.decrypt_message::<String>(&encrypted).await.is_err());
        assert!(!channel.validate_message_integrity(&encrypted).await.unwrap());
    }

    #[tokio::test]
    async fn test_oversized_message_is_rejected() {
        let channel = SecureMessageChannel::new(SecurityConfig {
            max_message_size_bytes: 8,
            ..Default::default()
        })
        .unwrap();

        let result = channel
            .encrypt_message(&"this message is longer than eight bytes")
            .await;
        assert!(result.is_err());
    }

    #[test]
    fn test_secure_key_generation() {
        let key = SecureMessageChannel::generate_secure_key(32).unwrap();
        assert_eq!(key.len(), 32);
    }

    #[test]
    fn test_encryption_overhead_calculation() {
        let overhead = SecureMessageChannel::calculate_encryption_overhead(100);
        assert!(overhead > 0);

        let estimated_size = SecureMessageChannel::estimate_encrypted_size(100);
        assert!(estimated_size > 100);
        assert_eq!(estimated_size, 100 + overhead);
    }
}