grove/WASM/
FunctionExport.rs

1//! Function Export Module
2//!
3//! Handles exporting host functions to WASM modules.
4//! Provides registration and management of functions that WASM can call.
5
6use std::{collections::HashMap, sync::Arc};
7
8use anyhow::{Context, Result};
9use serde::{Deserialize, Serialize};
10use tokio::sync::RwLock;
11use tracing::{debug, info, instrument, warn};
12use wasmtime::{Caller, Engine, Extern, Func, FuncType, Linker, Store, StoreLimits, Trap, Val, ValType};
13
14use crate::WASM::HostBridge::{FunctionSignature, HostBridgeImpl as HostBridge, HostBridgeImpl, HostFunctionCallback, ParamType, ReturnType};
15
16/// Host function registry for WASM exports
17pub struct HostFunctionRegistry {
18	/// Registered host functions
19	functions:Arc<RwLock<HashMap<String, RegisteredHostFunction>>>,
20	/// Associated host bridge
21	bridge:Arc<HostBridge>,
22}
23
24/// Registered host function with metadata
25#[derive(Debug, Clone)]
26struct RegisteredHostFunction {
27	/// Function name
28	name:String,
29	/// Function signature
30	signature:FunctionSignature,
31	/// Synchronous callback
32	callback:Option<HostFunctionCallback>,
33	/// Registration timestamp
34	registered_at:u64,
35	/// Call statistics
36	stats:FunctionStats,
37}
38
/// Usage counters kept for a single exported host function.
#[derive(Debug, Clone, Default)]
pub struct FunctionStats {
	/// Total number of invocations.
	pub call_count:u64,
	/// Accumulated execution time over all invocations, in nanoseconds.
	pub total_execution_ns:u64,
	/// Timestamp of the latest invocation (`None` until one has been recorded).
	pub last_call_at:Option<u64>,
	/// Number of invocations that ended in an error.
	pub error_count:u64,
}
51
52/// Export configuration for WASM functions
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct ExportConfig {
55	/// Enable function export by default
56	pub auto_export:bool,
57	/// Enable timing statistics
58	pub enable_stats:bool,
59	/// Maximum number of functions that can be exported
60	pub max_functions:usize,
61	/// Function name prefix for exports
62	pub name_prefix:Option<String>,
63}
64
65impl Default for ExportConfig {
66	fn default() -> Self {
67		Self {
68			auto_export:true,
69			enable_stats:true,
70			max_functions:1000,
71			name_prefix:Some("host_".to_string()),
72		}
73	}
74}
75
76/// Function export for WASM
77pub struct FunctionExportImpl {
78	registry:Arc<HostFunctionRegistry>,
79	config:ExportConfig,
80}
81
82impl FunctionExportImpl {
83	/// Create a new function export manager
84	pub fn new(bridge:Arc<HostBridge>) -> Self {
85		Self {
86			registry:Arc::new(HostFunctionRegistry { functions:Arc::new(RwLock::new(HashMap::new())), bridge }),
87			config:ExportConfig::default(),
88		}
89	}
90
91	/// Create with custom configuration
92	pub fn with_config(bridge:Arc<HostBridge>, config:ExportConfig) -> Self {
93		Self {
94			registry:Arc::new(HostFunctionRegistry { functions:Arc::new(RwLock::new(HashMap::new())), bridge }),
95			config,
96		}
97	}
98
99	/// Register a host function for export to WASM
100	#[instrument(skip(self, callback))]
101	pub async fn register_function(
102		&self,
103		name:&str,
104		signature:FunctionSignature,
105		callback:HostFunctionCallback,
106	) -> Result<()> {
107		info!("Registering host function for export: {}", name);
108
109		let functions = self.registry.functions.read().await;
110
111		// Check max function limit
112		if functions.len() >= self.config.max_functions {
113			return Err(anyhow::anyhow!(
114				"Maximum number of exported functions reached: {}",
115				self.config.max_functions
116			));
117		}
118
119		drop(functions);
120
121		let mut functions = self.registry.functions.write().await;
122
123		let registered_at = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?.as_secs();
124
125		functions.insert(
126			name.to_string(),
127			RegisteredHostFunction {
128				name:name.to_string(),
129				signature,
130				callback:Some(callback),
131				registered_at,
132				stats:FunctionStats::default(),
133			},
134		);
135
136		debug!("Host function registered for WASM export: {}", name);
137		Ok(())
138	}
139
140	/// Register multiple host functions
141	#[instrument(skip(self, callbacks))]
142	pub async fn register_functions(
143		&self,
144		signatures:Vec<FunctionSignature>,
145		callbacks:Vec<HostFunctionCallback>,
146	) -> Result<()> {
147		if signatures.len() != callbacks.len() {
148			return Err(anyhow::anyhow!("Number of signatures must match number of callbacks"));
149		}
150
151		for (sig, callback) in signatures.into_iter().zip(callbacks) {
152			let name = sig.name.clone();
153			self.register_function(&name, sig, callback).await?;
154		}
155
156		Ok(())
157	}
158
159	/// Export all registered functions to a WASMtime linker
160	#[instrument(skip(self, linker))]
161	pub async fn export_to_linker<T>(&self, linker:&mut Linker<T>) -> Result<()>
162	where
163		T: Send + 'static, {
164		info!(
165			"Exporting {} host functions to linker",
166			self.registry.functions.read().await.len()
167		);
168
169		let functions = self.registry.functions.read().await;
170
171		for (name, func) in functions.iter() {
172			self.export_single_function(linker, name, func)?;
173		}
174
175		info!("All host functions exported to linker");
176		Ok(())
177	}
178
179	/// Export a single function to the linker
180	fn export_single_function<T>(&self, linker:&mut Linker<T>, name:&str, func:&RegisteredHostFunction) -> Result<()>
181	where
182		T: Send + 'static, {
183		debug!("Exporting function: {}", name);
184
185		let callback = func
186			.callback
187			.ok_or_else(|| anyhow::anyhow!("No callback available for function: {}", name))?;
188
189		let func_name = if let Some(prefix) = &self.config.name_prefix {
190			format!("{}{}", prefix, name)
191		} else {
192			name.to_string()
193		};
194		
195		let func_name_for_debug = func_name.clone();
196		let func_name_inner = func_name.clone();
197
198		// Create a wrapper function that handles stats and error handling
199		let wrapped_callback =
200			move |mut _caller:Caller<'_, T>, args:&[wasmtime::Val]| -> Result<Vec<wasmtime::Val>, wasmtime::Trap> {
201				let start = std::time::Instant::now();
202
203				// Convert args to bytes
204				let args_bytes:Result<Vec<bytes::Bytes>, _> = args
205					.iter()
206					.map(|arg| {
207						match arg {
208							wasmtime::Val::I32(i) => {
209								serde_json::to_vec(i)
210									.map(bytes::Bytes::from)
211									.map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
212							},
213							wasmtime::Val::I64(i) => {
214								serde_json::to_vec(i)
215									.map(bytes::Bytes::from)
216									.map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
217							},
218							wasmtime::Val::F32(f) => {
219								serde_json::to_vec(f)
220									.map(bytes::Bytes::from)
221									.map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
222							},
223							wasmtime::Val::F64(f) => {
224								serde_json::to_vec(f)
225									.map(bytes::Bytes::from)
226									.map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
227							},
228							_ => Err(anyhow::anyhow!("Unsupported argument type")),
229						}
230					})
231					.collect();
232
233				let args_bytes = args_bytes.map_err(|_| {
234					warn!("Error converting arguments for function '{}'", func_name_inner);
235					wasmtime::Trap::StackOverflow
236				})?;
237
238				// Call the callback
239				let result = callback(args_bytes);
240
241				match result {
242					Ok(response_bytes) => {
243						// Deserialize response
244						let result_val:serde_json::Value = serde_json::from_slice(&response_bytes)
245							.map_err(|_| {
246								warn!("Error deserializing response for function '{}'", func_name_inner);
247								wasmtime::Trap::StackOverflow
248							})?;
249
250						let ret_val = match result_val {
251							serde_json::Value::Number(n) => {
252								if let Some(i) = n.as_i64() {
253									wasmtime::Val::I32(i as i32)
254								} else if let Some(f) = n.as_f64() {
255									wasmtime::Val::I64(f as i64)
256								} else {
257									warn!("Invalid number format for function '{}'", func_name_inner);
258									return Err(wasmtime::Trap::StackOverflow);
259								}
260							},
261							_ => {
262								warn!("Unsupported response type for function '{}'", func_name_inner);
263								return Err(wasmtime::Trap::StackOverflow);
264							},
265						};
266
267						Ok(vec![ret_val])
268					},
269					Err(e) => {
270						// Error handling
271						debug!("Host function '{}' returned error: {}", func_name_inner, e);
272						Err(wasmtime::Trap::StackOverflow)
273					},
274				}
275			};
276
277		// Define the function signature for WASMtime
278		let _wasmparser_signature = wasmparser::FuncType::new([wasmparser::ValType::I32], [wasmparser::ValType::I32]);
279
280		// Register host function with the linker using simple i32->i32 signature
281		// In Wasmtime 20, func_wrap expects parameters to be inferred from the closure signature
282		let func_name_for_logging = func_name.clone();
283		linker.func_wrap(
284			"_host", // Module name for host functions
285			&func_name,
286			move |mut caller:wasmtime::Caller<'_, T>, input_param:i32| -> i32 {
287				// Track function call for metrics
288				let start = std::time::Instant::now();
289
290				// Convert input parameter to bytes for callback
291				let args_bytes = match serde_json::to_vec(&input_param)
292					.map(bytes::Bytes::from) {
293					Ok(b) => b,
294					Err(e) => {
295						warn!("Serialization error for function '{}': {}", func_name_for_logging, e);
296						return -1i32;
297					}
298				};
299
300				// Call the registered callback
301				let result = callback(vec![args_bytes]);
302
303				match result {
304					Ok(response_bytes) => {
305						// Deserialize response
306						let result_val:serde_json::Value = match serde_json::from_slice(&response_bytes) {
307							Ok(v) => v,
308							Err(_) => {
309								warn!("Error deserializing response for function '{}'", func_name_for_logging);
310								return -1i32;
311							}
312						};
313
314						// Extract result value
315						let ret_val = match result_val {
316							serde_json::Value::Number(n) => {
317								if let Some(i) = n.as_i64() {
318									i as i32
319								} else if let Some(f) = n.as_f64() {
320									f as i32
321								} else {
322									warn!("Invalid number format for function '{}'", func_name_for_logging);
323									-1i32
324								}
325							},
326							serde_json::Value::Bool(b) => {
327								if b { 1i32 } else { 0i32 }
328							},
329							_ => {
330								warn!("Unsupported response type for function '{}', expected number or bool", func_name_for_logging);
331								-1i32
332							},
333						};
334
335						// Log successful call
336						let duration = start.elapsed();
337						debug!("[FunctionExport] Host function '{}' executed successfully in {}µs", func_name_for_logging, duration.as_micros());
338
339						ret_val
340					},
341					Err(e) => {
342						// Error handling - return error code to WASM caller
343						debug!("[FunctionExport] Host function '{}' returned error: {}", func_name_for_logging, e);
344						// Return -1 to indicate error to WASM caller
345						-1i32
346					},
347				}
348			},
349		).map_err(|e| {
350			warn!("[FunctionExport] Failed to wrap host function '{}': {}", func_name_for_debug, e);
351			e
352		})?;
353
354		debug!("[FunctionExport] Host function '{}' registered successfully", func_name_for_debug);
355
356		Ok(())
357	}
358
359	/// Convert our signature to WASMtime signature type
360	fn wasmtime_signature_from_signature(&self, _sig:&FunctionSignature) -> Result<wasmparser::FuncType> {
361		// This is a placeholder - actual implementation depends on the exact types
362		// In production, this would map ParamType and ReturnType to WASMtime types
363		Ok(wasmparser::FuncType::new([], []))
364	}
365
366	/// Get all registered function names
367	pub async fn get_function_names(&self) -> Vec<String> {
368		self.registry.functions.read().await.keys().cloned().collect()
369	}
370
371	/// Get function statistics
372	pub async fn get_function_stats(&self, name:&str) -> Option<FunctionStats> {
373		self.registry.functions.read().await.get(name).map(|f| f.stats.clone())
374	}
375
376	/// Unregister a function
377	#[instrument(skip(self))]
378	pub async fn unregister_function(&self, name:&str) -> Result<bool> {
379		let mut functions = self.registry.functions.write().await;
380		let removed = functions.remove(name).is_some();
381
382		if removed {
383			info!("Unregistered host function: {}", name);
384		} else {
385			warn!("Attempted to unregister non-existent function: {}", name);
386		}
387
388		Ok(removed)
389	}
390
391	/// Clear all registered functions
392	pub async fn clear(&self) {
393		info!("Clearing all registered host functions");
394		self.registry.functions.write().await.clear();
395	}
396}
397
#[cfg(test)]
mod tests {
	use super::*;

	/// Build an `i32 -> i32` signature with the given name.
	fn i32_signature(name:&str) -> FunctionSignature {
		FunctionSignature {
			name:name.to_string(),
			param_types:vec![ParamType::I32],
			return_type:Some(ReturnType::I32),
			is_async:false,
		}
	}

	#[tokio::test]
	async fn test_function_export_creation() {
		let bridge = Arc::new(HostBridgeImpl::new());
		let export = FunctionExportImpl::new(bridge);

		assert!(export.get_function_names().await.is_empty());
	}

	#[tokio::test]
	async fn test_register_function() {
		let export = FunctionExportImpl::new(Arc::new(HostBridgeImpl::new()));

		let callback = |args:Vec<bytes::Bytes>| Ok(args.first().cloned().unwrap_or_default());

		let result:anyhow::Result<()> = export.register_function("echo", i32_signature("echo"), callback).await;
		assert!(result.is_ok());
		assert_eq!(export.get_function_names().await, vec!["echo".to_string()]);

		// A freshly registered function has zeroed statistics; unknown names have none.
		let stats = export.get_function_stats("echo").await.expect("stats for registered function");
		assert_eq!(stats.call_count, 0);
		assert!(export.get_function_stats("missing").await.is_none());
	}

	#[tokio::test]
	async fn test_unregister_function() {
		let export = FunctionExportImpl::new(Arc::new(HostBridgeImpl::new()));

		let callback = |_:Vec<bytes::Bytes>| Ok(bytes::Bytes::new());
		let registered:anyhow::Result<()> = export.register_function("test", i32_signature("test"), callback).await;
		assert!(registered.is_ok(), "registration must succeed for the unregister test to be meaningful");

		assert!(export.unregister_function("test").await.unwrap());
		assert!(export.get_function_names().await.is_empty());

		// Removing it again finds nothing.
		assert!(!export.unregister_function("test").await.unwrap());
	}

	#[tokio::test]
	async fn test_register_rejects_new_name_when_full() {
		let config = ExportConfig { max_functions:1, ..ExportConfig::default() };
		let export = FunctionExportImpl::with_config(Arc::new(HostBridgeImpl::new()), config);

		let callback = |_:Vec<bytes::Bytes>| Ok(bytes::Bytes::new());
		let first:anyhow::Result<()> = export.register_function("first", i32_signature("first"), callback).await;
		assert!(first.is_ok());

		let second:anyhow::Result<()> = export.register_function("second", i32_signature("second"), callback).await;
		assert!(second.is_err());
		assert_eq!(export.get_function_names().await, vec!["first".to_string()]);
	}

	#[test]
	fn test_export_config_default() {
		let config = ExportConfig::default();
		assert!(config.auto_export);
		assert!(config.enable_stats);
		assert_eq!(config.max_functions, 1000);
		assert_eq!(config.name_prefix.as_deref(), Some("host_"));
	}

	#[test]
	fn test_function_stats_default() {
		let stats = FunctionStats::default();
		assert_eq!(stats.call_count, 0);
		assert_eq!(stats.total_execution_ns, 0);
		assert_eq!(stats.error_count, 0);
		assert!(stats.last_call_at.is_none());
	}
}
462}