1use std::{collections::HashMap, path::PathBuf, sync::Arc};
7
8use async_trait::async_trait;
9use bytes::Bytes;
10use serde::{Deserialize, Serialize};
11use tokio::sync::RwLock;
12use tracing::{debug, info, instrument, warn};
13
14use crate::{
15 Transport::TransportStrategy,
16 Transport::TransportType,
17 Transport::TransportStats,
18 Transport::TransportConfig,
19 WASM::{
20 HostBridge::HostBridgeImpl,
21 MemoryManager::{MemoryLimits, MemoryManagerImpl},
22 Runtime::{WASMConfig, WASMRuntime},
23 WASMStats,
24 },
25};
26
27#[derive(Clone, Debug)]
29pub struct WASMTransportImpl {
30 runtime:Arc<WASMRuntime>,
32 memory_manager:Arc<RwLock<MemoryManagerImpl>>,
34 bridge:Arc<HostBridgeImpl>,
36 modules:Arc<RwLock<HashMap<String, WASMModuleInfo>>>,
38 config:TransportConfig,
40 connected:Arc<RwLock<bool>>,
42 stats:Arc<RwLock<TransportStats>>,
44}
45
46#[derive(Debug, Clone)]
48pub struct WASMModuleInfo {
49 pub id:String,
51 pub name:Option<String>,
53 pub path:Option<PathBuf>,
55 pub loaded_at:u64,
57 pub function_stats:HashMap<String, FunctionCallStats>,
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct FunctionCallStats {
64 pub call_count:u64,
66 pub total_time_us:u64,
68 pub last_call_at:Option<u64>,
70 pub error_count:u64,
72}
73
74impl FunctionCallStats {
75 pub fn record_call(&mut self, time_us:u64) {
77 self.call_count += 1;
78 self.total_time_us += time_us;
79 self.last_call_at = Some(
80 std::time::SystemTime::now()
81 .duration_since(std::time::UNIX_EPOCH)
82 .map(|d| d.as_secs())
83 .unwrap_or(0),
84 );
85 }
86
87 pub fn record_error(&mut self) { self.error_count += 1; }
89}
90
91impl Default for FunctionCallStats {
92 fn default() -> Self { Self { call_count:0, total_time_us:0, last_call_at:None, error_count:0 } }
93}
94
95impl WASMTransportImpl {
96 pub fn new(enable_wasi:bool, memory_limit_mb:u64, max_execution_time_ms:u64) -> anyhow::Result<Self> {
98 let config = WASMConfig::new(memory_limit_mb, max_execution_time_ms, enable_wasi);
99
100 let runtime_result = tokio::runtime::Runtime::new()
103 .map_err(|e| anyhow::anyhow!("Failed to create tokio runtime: {}", e))?
104 .block_on(WASMRuntime::new(config.clone()))
105 .map_err(|e| anyhow::anyhow!("Failed to create WASM runtime: {}", e))?;
106 let runtime = Arc::new(runtime_result);
107
108 let memory_limits = MemoryLimits::new(memory_limit_mb, (memory_limit_mb as f64 * 0.75) as u64, 100);
109 let memory_manager = Arc::new(RwLock::new(MemoryManagerImpl::new(memory_limits)));
110 let bridge = Arc::new(HostBridgeImpl::new());
111
112 Ok(Self {
113 runtime,
114 memory_manager,
115 bridge,
116 modules:Arc::new(RwLock::new(HashMap::new())),
117 config:TransportConfig::default(),
118 connected:Arc::new(RwLock::new(true)), stats:Arc::new(RwLock::new(TransportStats::default())),
120 })
121 }
122
123 pub fn with_config(wasm_config:WASMConfig, transport_config:TransportConfig) -> anyhow::Result<Self> {
125 let runtime_result = tokio::runtime::Runtime::new()
126 .map_err(|e| anyhow::anyhow!("Failed to create tokio runtime: {}", e))?
127 .block_on(WASMRuntime::new(wasm_config.clone()))
128 .map_err(|e| anyhow::anyhow!("Failed to create WASM runtime: {}", e))?;
129 let runtime = Arc::new(runtime_result);
130
131 let memory_limits = MemoryLimits::new(
132 wasm_config.memory_limit_mb,
133 (wasm_config.memory_limit_mb as f64 * 0.75) as u64,
134 100,
135 );
136 let memory_manager = Arc::new(RwLock::new(MemoryManagerImpl::new(memory_limits)));
137 let bridge = Arc::new(HostBridgeImpl::new());
138
139 Ok(Self {
140 runtime,
141 memory_manager,
142 bridge,
143 modules:Arc::new(RwLock::new(HashMap::new())),
144 config:transport_config,
145 connected:Arc::new(RwLock::new(true)),
146 stats:Arc::new(RwLock::new(TransportStats::default())),
147 })
148 }
149
150 pub fn runtime(&self) -> &Arc<WASMRuntime> { &self.runtime }
152
153 pub fn memory_manager(&self) -> &Arc<RwLock<MemoryManagerImpl>> { &self.memory_manager }
155
156 pub fn bridge(&self) -> &Arc<HostBridgeImpl> { &self.bridge }
158
159 pub async fn get_modules(&self) -> HashMap<String, WASMModuleInfo> { self.modules.read().await.clone() }
161
162 pub async fn get_wasm_stats(&self) -> WASMStats {
164 let memory_manager = self.memory_manager.read().await;
165 let managers = self.modules.read().await;
166
167 WASMStats {
168 modules_loaded:managers.len(),
169 active_instances:managers.len(), total_memory_mb:memory_manager.current_usage_mb() as u64,
171 total_execution_time_ms:0, function_calls:self.stats.read().await.messages_sent,
173 }
174 }
175
176 #[instrument(skip(self, module_id, function_name, args))]
178 pub async fn call_wasm_function(
179 &self,
180 module_id:&str,
181 function_name:&str,
182 args:Vec<Bytes>,
183 ) -> anyhow::Result<Bytes> {
184 let start = std::time::Instant::now();
185
186 debug!(
187 "Calling WASM function: {}::{} with {} arguments",
188 module_id,
189 function_name,
190 args.len()
191 );
192
193 let modules = self.modules.read().await;
194 let module = modules
195 .get(module_id)
196 .ok_or_else(|| anyhow::anyhow!("Module not found: {}", module_id))?;
197
198 let response = Bytes::new();
201
202 let mut modules_mut = self.modules.write().await;
204 if let Some(module) = modules_mut.get_mut(module_id) {
205 let stats = module.function_stats.entry(function_name.to_string()).or_default();
206 stats.record_call(start.elapsed().as_micros() as u64);
207 }
208
209 drop(modules_mut);
210
211 let mut stats = self.stats.write().await;
213 stats.record_sent(args.iter().map(|b| b.len() as u64).sum(), start.elapsed().as_micros() as u64);
214 stats.record_received(response.len() as u64);
215
216 Ok(response)
217 }
218}
219
220#[async_trait]
221impl TransportStrategy for WASMTransportImpl {
222 type Error = WASMTransportError;
223
224 #[instrument(skip(self))]
225 async fn connect(&self) -> Result<(), Self::Error> {
226 info!("WASM transport connecting");
227
228 *self.connected.write().await = true;
230
231 info!("WASM transport connected");
232
233 Ok(())
234 }
235
236 #[instrument(skip(self, request))]
237 async fn send(&self, request:&[u8]) -> Result<Vec<u8>, Self::Error> {
238 let start = std::time::Instant::now();
239
240 if !self.is_connected() {
241 return Err(WASMTransportError::NotConnected);
242 }
243
244 debug!("Sending WASM transport request ({} bytes)", request.len());
245
246 let request_str =
249 std::str::from_utf8(request).map_err(|e| WASMTransportError::InvalidRequest(e.to_string()))?;
250
251 let parts:Vec<&str> = request_str.splitn(3, ':').collect();
252 if parts.len() < 3 {
253 return Err(WASMTransportError::InvalidRequest("Invalid request format".to_string()));
254 }
255
256 let module_id = parts[0];
257 let function_name = parts[1];
258 let args_base64 = parts[2];
259
260 let args = vec![Bytes::from(
262 base64::decode(args_base64).map_err(|e| WASMTransportError::InvalidRequest(e.to_string()))?,
263 )];
264
265 let response = self
267 .call_wasm_function(module_id, function_name, args)
268 .await
269 .map_err(|e| WASMTransportError::FunctionCallFailed(e.to_string()))?;
270
271 let response_vec = response.to_vec();
273
274 let latency_us = start.elapsed().as_micros() as u64;
275
276 debug!("WASM transport request completed in {}µs", latency_us);
277
278 Ok(response_vec)
279 }
280
281 #[instrument(skip(self, data))]
282 async fn send_no_response(&self, data:&[u8]) -> Result<(), Self::Error> {
283 if !self.is_connected() {
284 return Err(WASMTransportError::NotConnected);
285 }
286
287 debug!("Sending WASM transport request without response ({} bytes)", data.len());
288
289 self.send(data).await?;
291 Ok(())
292 }
293
294 #[instrument(skip(self))]
295 async fn close(&self) -> Result<(), Self::Error> {
296 info!("Closing WASM transport");
297
298 *self.connected.write().await = false;
299
300 info!("WASM transport closed");
301
302 Ok(())
303 }
304
305 fn is_connected(&self) -> bool { self.connected.blocking_read().to_owned() }
306
307 fn transport_type(&self) -> TransportType {
308 TransportType::WASM
309 }
310}
311
312#[derive(Debug, thiserror::Error)]
314pub enum WASMTransportError {
315#[error("Module not found: {0}")]
317ModuleNotFound(String),
318
319#[error("Function not found: {0}")]
321FunctionNotFound(String),
322
323#[error("Function call failed: {0}")]
325FunctionCallFailed(String),
326
327#[error("Memory error: {0}")]
329MemoryError(String),
330
331#[error("Runtime error: {0}")]
333RuntimeError(String),
334
335#[error("Invalid request: {0}")]
337InvalidRequest(String),
338
339#[error("Not connected")]
341NotConnected,
342
343#[error("Compilation failed: {0}")]
345CompilationFailed(String),
346
347#[error("Timeout")]
349Timeout,
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Transport::Strategy::TransportStrategy;

    /// A freshly constructed transport starts connected.
    #[test]
    fn test_wasm_transport_creation() {
        let transport = WASMTransportImpl::new(true, 512, 30000).expect("transport should be created");
        assert!(transport.is_connected());
    }

    /// `record_call` bumps the counter, accumulates time and stamps the call.
    #[test]
    fn test_function_call_stats() {
        let mut stats = FunctionCallStats::default();
        stats.record_call(100);
        assert_eq!(stats.call_count, 1);
        assert_eq!(stats.total_time_us, 100);
        assert!(stats.last_call_at.is_some());
    }

    /// `record_error` only touches the error counter.
    #[test]
    fn test_function_call_stats_record_error() {
        let mut stats = FunctionCallStats::default();
        stats.record_error();
        assert_eq!(stats.error_count, 1);
        assert_eq!(stats.call_count, 0);
        assert!(stats.last_call_at.is_none());
    }

    /// `close` must succeed and leave the transport disconnected.
    #[tokio::test]
    async fn test_wasm_transport_not_connected_after_close() {
        let transport = WASMTransportImpl::new(true, 512, 30000).unwrap();
        transport.close().await.expect("close should succeed");
        assert!(!transport.is_connected());
    }

    /// A new transport reports no loaded modules.
    #[tokio::test]
    async fn test_get_wasm_stats() {
        let transport = WASMTransportImpl::new(true, 512, 30000).unwrap();
        let stats = transport.get_wasm_stats().await;
        assert_eq!(stats.modules_loaded, 0);
        assert_eq!(stats.active_instances, 0);
    }
}