grove/Transport/
WASMTransport.rs

//! WASM Transport Implementation
//!
//! Provides direct communication with WASM modules.
//! Handles calls to and from WebAssembly instances.
5
6use std::{collections::HashMap, path::PathBuf, sync::Arc};
7
8use async_trait::async_trait;
9use bytes::Bytes;
10use serde::{Deserialize, Serialize};
11use tokio::sync::RwLock;
12use tracing::{debug, info, instrument, warn};
13
14use crate::{
15	Transport::TransportStrategy,
16	Transport::TransportType,
17	Transport::TransportStats,
18	Transport::TransportConfig,
19	WASM::{
20		HostBridge::HostBridgeImpl,
21		MemoryManager::{MemoryLimits, MemoryManagerImpl},
22		Runtime::{WASMConfig, WASMRuntime},
23		WASMStats,
24	},
25};
26
27/// WASM transport for direct module communication
28#[derive(Clone, Debug)]
29pub struct WASMTransportImpl {
30	/// WASM runtime
31	runtime:Arc<WASMRuntime>,
32	/// Memory manager
33	memory_manager:Arc<RwLock<MemoryManagerImpl>>,
34	/// Host bridge for communication
35	bridge:Arc<HostBridgeImpl>,
36	/// Loaded modules
37	modules:Arc<RwLock<HashMap<String, WASMModuleInfo>>>,
38	/// Transport configuration
39	config:TransportConfig,
40	/// Connection state
41	connected:Arc<RwLock<bool>>,
42	/// Transport statistics
43	stats:Arc<RwLock<TransportStats>>,
44}
45
46/// Information about a loaded WASM module
47#[derive(Debug, Clone)]
48pub struct WASMModuleInfo {
49	/// Module ID
50	pub id:String,
51	/// Module name (if available)
52	pub name:Option<String>,
53	/// Path to module file
54	pub path:Option<PathBuf>,
55	/// Module loaded timestamp
56	pub loaded_at:u64,
57	/// Function statistics
58	pub function_stats:HashMap<String, FunctionCallStats>,
59}
60
61/// Statistics for function calls
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct FunctionCallStats {
64	/// Number of calls
65	pub call_count:u64,
66	/// Total execution time in microseconds
67	pub total_time_us:u64,
68	/// Last call timestamp
69	pub last_call_at:Option<u64>,
70	/// Number of errors
71	pub error_count:u64,
72}
73
74impl FunctionCallStats {
75	/// Record a successful function call
76	pub fn record_call(&mut self, time_us:u64) {
77		self.call_count += 1;
78		self.total_time_us += time_us;
79		self.last_call_at = Some(
80			std::time::SystemTime::now()
81				.duration_since(std::time::UNIX_EPOCH)
82				.map(|d| d.as_secs())
83				.unwrap_or(0),
84		);
85	}
86
87	/// Record a failed function call
88	pub fn record_error(&mut self) { self.error_count += 1; }
89}
90
91impl Default for FunctionCallStats {
92	fn default() -> Self { Self { call_count:0, total_time_us:0, last_call_at:None, error_count:0 } }
93}
94
95impl WASMTransportImpl {
96	/// Create a new WASM transport with default configuration
97	pub fn new(enable_wasi:bool, memory_limit_mb:u64, max_execution_time_ms:u64) -> anyhow::Result<Self> {
98		let config = WASMConfig::new(memory_limit_mb, max_execution_time_ms, enable_wasi);
99
100		// Create runtime - this would normally be async, but for now we do it
101		// synchronously In production, this would need to be properly awaited
102		let runtime_result = tokio::runtime::Runtime::new()
103			.map_err(|e| anyhow::anyhow!("Failed to create tokio runtime: {}", e))?
104			.block_on(WASMRuntime::new(config.clone()))
105			.map_err(|e| anyhow::anyhow!("Failed to create WASM runtime: {}", e))?;
106		let runtime = Arc::new(runtime_result);
107
108		let memory_limits = MemoryLimits::new(memory_limit_mb, (memory_limit_mb as f64 * 0.75) as u64, 100);
109		let memory_manager = Arc::new(RwLock::new(MemoryManagerImpl::new(memory_limits)));
110		let bridge = Arc::new(HostBridgeImpl::new());
111
112		Ok(Self {
113			runtime,
114			memory_manager,
115			bridge,
116			modules:Arc::new(RwLock::new(HashMap::new())),
117			config:TransportConfig::default(),
118			connected:Arc::new(RwLock::new(true)), // WASM transport is always "connected" locally
119			stats:Arc::new(RwLock::new(TransportStats::default())),
120		})
121	}
122
123	/// Create a new WASM transport with custom configuration
124	pub fn with_config(wasm_config:WASMConfig, transport_config:TransportConfig) -> anyhow::Result<Self> {
125		let runtime_result = tokio::runtime::Runtime::new()
126			.map_err(|e| anyhow::anyhow!("Failed to create tokio runtime: {}", e))?
127			.block_on(WASMRuntime::new(wasm_config.clone()))
128			.map_err(|e| anyhow::anyhow!("Failed to create WASM runtime: {}", e))?;
129		let runtime = Arc::new(runtime_result);
130
131		let memory_limits = MemoryLimits::new(
132			wasm_config.memory_limit_mb,
133			(wasm_config.memory_limit_mb as f64 * 0.75) as u64,
134			100,
135		);
136		let memory_manager = Arc::new(RwLock::new(MemoryManagerImpl::new(memory_limits)));
137		let bridge = Arc::new(HostBridgeImpl::new());
138
139		Ok(Self {
140			runtime,
141			memory_manager,
142			bridge,
143			modules:Arc::new(RwLock::new(HashMap::new())),
144			config:transport_config,
145			connected:Arc::new(RwLock::new(true)),
146			stats:Arc::new(RwLock::new(TransportStats::default())),
147		})
148	}
149
150	/// Get a reference to the WASM runtime
151	pub fn runtime(&self) -> &Arc<WASMRuntime> { &self.runtime }
152
153	/// Get a reference to the memory manager
154	pub fn memory_manager(&self) -> &Arc<RwLock<MemoryManagerImpl>> { &self.memory_manager }
155
156	/// Get a reference to the host bridge
157	pub fn bridge(&self) -> &Arc<HostBridgeImpl> { &self.bridge }
158
159	/// Get all loaded modules
160	pub async fn get_modules(&self) -> HashMap<String, WASMModuleInfo> { self.modules.read().await.clone() }
161
162	/// Get WASM runtime statistics
163	pub async fn get_wasm_stats(&self) -> WASMStats {
164		let memory_manager = self.memory_manager.read().await;
165		let managers = self.modules.read().await;
166
167		WASMStats {
168			modules_loaded:managers.len(),
169			active_instances:managers.len(), // In real implementation, track instances
170			total_memory_mb:memory_manager.current_usage_mb() as u64,
171			total_execution_time_ms:0, // Track from actual calls
172			function_calls:self.stats.read().await.messages_sent,
173		}
174	}
175
176	/// Call a function in a WASM module
177	#[instrument(skip(self, module_id, function_name, args))]
178	pub async fn call_wasm_function(
179		&self,
180		module_id:&str,
181		function_name:&str,
182		args:Vec<Bytes>,
183	) -> anyhow::Result<Bytes> {
184		let start = std::time::Instant::now();
185
186		debug!(
187			"Calling WASM function: {}::{} with {} arguments",
188			module_id,
189			function_name,
190			args.len()
191		);
192
193		let modules = self.modules.read().await;
194		let module = modules
195			.get(module_id)
196			.ok_or_else(|| anyhow::anyhow!("Module not found: {}", module_id))?;
197
198		// In a real implementation, this would call the actual WASM function
199		// For now, we return a mock response
200		let response = Bytes::new();
201
202		// Update statistics
203		let mut modules_mut = self.modules.write().await;
204		if let Some(module) = modules_mut.get_mut(module_id) {
205			let stats = module.function_stats.entry(function_name.to_string()).or_default();
206			stats.record_call(start.elapsed().as_micros() as u64);
207		}
208
209		drop(modules_mut);
210
211		// Update transport statistics
212		let mut stats = self.stats.write().await;
213		stats.record_sent(args.iter().map(|b| b.len() as u64).sum(), start.elapsed().as_micros() as u64);
214		stats.record_received(response.len() as u64);
215
216		Ok(response)
217	}
218}
219
220#[async_trait]
221impl TransportStrategy for WASMTransportImpl {
222	type Error = WASMTransportError;
223
224	#[instrument(skip(self))]
225	async fn connect(&self) -> Result<(), Self::Error> {
226		info!("WASM transport connecting");
227
228		// WASM transport is always "connected" locally
229		*self.connected.write().await = true;
230
231		info!("WASM transport connected");
232
233		Ok(())
234	}
235
236	#[instrument(skip(self, request))]
237	async fn send(&self, request:&[u8]) -> Result<Vec<u8>, Self::Error> {
238		let start = std::time::Instant::now();
239
240		if !self.is_connected() {
241			return Err(WASMTransportError::NotConnected);
242		}
243
244		debug!("Sending WASM transport request ({} bytes)", request.len());
245
246		// Parse request - it should contain module ID and function name
247		// For simplicity, we use a minimal format: module_id:function_name:base64_args
248		let request_str =
249			std::str::from_utf8(request).map_err(|e| WASMTransportError::InvalidRequest(e.to_string()))?;
250
251		let parts:Vec<&str> = request_str.splitn(3, ':').collect();
252		if parts.len() < 3 {
253			return Err(WASMTransportError::InvalidRequest("Invalid request format".to_string()));
254		}
255
256		let module_id = parts[0];
257		let function_name = parts[1];
258		let args_base64 = parts[2];
259
260		// Decode arguments from base64
261		let args = vec![Bytes::from(
262			base64::decode(args_base64).map_err(|e| WASMTransportError::InvalidRequest(e.to_string()))?,
263		)];
264
265		// Call the WASM function
266		let response = self
267			.call_wasm_function(module_id, function_name, args)
268			.await
269			.map_err(|e| WASMTransportError::FunctionCallFailed(e.to_string()))?;
270
271		// Convert response to Vec<u8>
272		let response_vec = response.to_vec();
273
274		let latency_us = start.elapsed().as_micros() as u64;
275
276		debug!("WASM transport request completed in {}µs", latency_us);
277
278		Ok(response_vec)
279	}
280
281	#[instrument(skip(self, data))]
282	async fn send_no_response(&self, data:&[u8]) -> Result<(), Self::Error> {
283		if !self.is_connected() {
284			return Err(WASMTransportError::NotConnected);
285		}
286
287		debug!("Sending WASM transport request without response ({} bytes)", data.len());
288
289		// For fire-and-forget calls, we still execute but ignore the response
290		self.send(data).await?;
291		Ok(())
292	}
293
294	#[instrument(skip(self))]
295	async fn close(&self) -> Result<(), Self::Error> {
296		info!("Closing WASM transport");
297
298		*self.connected.write().await = false;
299
300		info!("WASM transport closed");
301
302		Ok(())
303	}
304
305	fn is_connected(&self) -> bool { self.connected.blocking_read().to_owned() }
306
307	fn transport_type(&self) -> TransportType {
308		TransportType::WASM
309	}
310}
311
312/// WASM transport errors
313#[derive(Debug, thiserror::Error)]
314pub enum WASMTransportError {
315/// Module not found error
316#[error("Module not found: {0}")]
317ModuleNotFound(String),
318
319/// Function not found error
320#[error("Function not found: {0}")]
321FunctionNotFound(String),
322
323/// Function call failed error
324#[error("Function call failed: {0}")]
325FunctionCallFailed(String),
326
327/// Memory error
328#[error("Memory error: {0}")]
329MemoryError(String),
330
331/// Runtime error
332#[error("Runtime error: {0}")]
333RuntimeError(String),
334
335/// Invalid request error
336#[error("Invalid request: {0}")]
337InvalidRequest(String),
338
339/// Not connected error
340#[error("Not connected")]
341NotConnected,
342
343/// Compilation failed error
344#[error("Compilation failed: {0}")]
345CompilationFailed(String),
346
347/// Timeout error
348#[error("Timeout")]
349Timeout,
350}
351
#[cfg(test)]
mod tests {
	use super::*;
	use crate::Transport::Strategy::TransportStrategy;

	/// Construction succeeds and the new transport reports itself connected.
	#[test]
	fn test_wasm_transport_creation() {
		let created = WASMTransportImpl::new(true, 512, 30000);
		assert!(created.is_ok());

		// WASM transport should always be connected
		let transport = created.unwrap();
		assert!(transport.is_connected());
	}

	/// A single recorded call shows up in the count, the accumulated time
	/// and the last-call timestamp.
	#[test]
	fn test_function_call_stats() {
		let mut call_stats = FunctionCallStats::default();

		call_stats.record_call(100);

		assert_eq!(call_stats.call_count, 1);
		assert_eq!(call_stats.total_time_us, 100);
		assert!(call_stats.last_call_at.is_some());
	}

	/// After `close`, the transport no longer reports itself connected.
	#[tokio::test]
	async fn test_wasm_transport_not_connected_after_close() {
		let transport = WASMTransportImpl::new(true, 512, 30000).unwrap();

		// The outcome of `close` is deliberately discarded; only the resulting
		// connection flag is under test.
		let _close_outcome:anyhow::Result<()> =
			transport.close().await.map_err(|e| anyhow::anyhow!(e.to_string()));

		assert!(!transport.is_connected());
	}

	/// A transport with no modules loaded reports zero modules and instances.
	#[tokio::test]
	async fn test_get_wasm_stats() {
		let transport = WASMTransportImpl::new(true, 512, 30000).unwrap();

		let wasm_stats = transport.get_wasm_stats().await;

		assert_eq!(wasm_stats.modules_loaded, 0);
		assert_eq!(wasm_stats.active_instances, 0);
	}
}
389}