grove/WASM/
HostBridge.rs

1//! Host Bridge
2//!
3//! Provides bidirectional communication between the host (Grove) and WASM
4//! modules. Handles function calls, data transfer, and marshalling between the
5//! two environments.
6
7use std::{collections::HashMap, sync::Arc};
8
9use anyhow::{Context, Result};
10use bytes::Bytes;
11use serde::{Serialize, de::DeserializeOwned};
12use tokio::sync::{RwLock, mpsc, oneshot};
13use tracing::{debug, error, instrument, warn};
14use wasmtime::{Caller, Extern, Func, Linker, Store};
15
16/// Host bridge error types
17#[derive(Debug, thiserror::Error)]
18pub enum BridgeError {
19/// Function not found error
20#[error("Function not found: {0}")]
21FunctionNotFound(String),
22
23/// Invalid function signature error
24#[error("Invalid function signature: {0}")]
25InvalidSignature(String),
26
27/// Serialization failed error
28#[error("Serialization failed: {0}")]
29SerializationError(String),
30
31/// Deserialization failed error
32#[error("Deserialization failed: {0}")]
33DeserializationError(String),
34
35/// Host function error
36#[error("Host function error: {0}")]
37HostFunctionError(String),
38
39/// Communication timeout error
40#[error("Communication timeout")]
41Timeout,
42
43/// Bridge closed error
44#[error("Bridge closed")]
45BridgeClosed,
46}
47
48/// Type-safe result for operations
49pub type BridgeResult<T> = Result<T, BridgeError>;
50
51/// Function signature information
52#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
53pub struct FunctionSignature {
54	/// Function name
55	pub name:String,
56	/// Parameter types
57	pub param_types:Vec<ParamType>,
58	/// Return type
59	pub return_type:Option<ReturnType>,
60	/// Whether this is an async function
61	pub is_async:bool,
62}
63
64/// Parameter types for WASM functions
65#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
66pub enum ParamType {
67/// 32-bit signed integer parameter
68I32,
69/// 64-bit signed integer parameter
70I64,
71/// 32-bit floating point parameter
72F32,
73/// 64-bit floating point parameter
74F64,
75/// Pointer to memory
76Ptr,
77/// Length parameter following a pointer
78Len,
79}
80
81/// Return types for WASM functions
82#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
83pub enum ReturnType {
84/// 32-bit signed integer return type
85I32,
86/// 64-bit signed integer return type
87I64,
88/// 32-bit floating point return type
89F32,
90/// 64-bit floating point return type
91F64,
92/// No return value (void)
93Void,
94}
95
96/// Message sent from WASM to host
97#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
98pub struct HostMessage {
99	/// Message ID for correlation
100	pub message_id:String,
101	/// Function name to call
102	pub function:String,
103	/// Serialized arguments
104	pub args:Vec<Bytes>,
105	/// Callback token for async responses
106	pub callback_token:Option<u64>,
107}
108
109/// Response sent from host to WASM
110#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
111pub struct HostResponse {
112	/// Correlating message ID
113	pub message_id:String,
114	/// Success flag
115	pub success:bool,
116	/// Response data
117	pub data:Option<Bytes>,
118	/// Error message if failed
119	pub error:Option<String>,
120}
121
122/// Callback for async function responses
123#[derive(Clone)]
124pub struct AsyncCallback {
125/// Sender for transmitting the response
126sender:Arc<tokio::sync::Mutex<Option<tokio::sync::oneshot::Sender<HostResponse>>>>,
127/// Message ID for correlation
128message_id:String,
129}
130
131impl std::fmt::Debug for AsyncCallback {
132	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133		f.debug_struct("AsyncCallback")
134			.field("message_id", &self.message_id)
135			.finish()
136	}
137}
138
139impl AsyncCallback {
140	/// Send response through the callback
141	pub async fn send(self, response:HostResponse) -> Result<()> {
142		let mut sender_opt = self.sender.lock().await;
143		if let Some(sender) = sender_opt.take() {
144			sender.send(response).map_err(|_| BridgeError::BridgeClosed)?;
145			Ok(())
146		} else {
147			Err(BridgeError::BridgeClosed.into())
148		}
149	}
150}
151
152/// Message from host to WASM
153#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
154pub struct WASMMessage {
155	/// Target function in WASM
156	pub function:String,
157	/// Arguments
158	pub args:Vec<Bytes>,
159}
160
161/// Host function callback type
162pub type HostFunctionCallback = fn(Vec<Bytes>) -> Result<Bytes>;
163
164/// Async host function callback type
165pub type AsyncHostFunctionCallback =
166	fn(Vec<Bytes>) -> Box<dyn std::future::Future<Output = Result<Bytes>> + Send + Unpin>;
167
168/// Host function definition
169#[derive(Debug)]
170pub struct HostFunction {
171	/// Function name
172	pub name:String,
173	/// Function signature
174	pub signature:FunctionSignature,
175	/// Synchronous callback - not serializable (skip serde derive)
176	#[allow(dead_code)]
177	pub callback:Option<HostFunctionCallback>,
178	/// Async callback - not serializable (skip serde derive)
179	#[allow(dead_code)]
180	pub async_callback:Option<AsyncHostFunctionCallback>,
181}
182
183/// Host Bridge for WASM communication
184#[derive(Debug)]
185pub struct HostBridgeImpl {
186	/// Registry of host functions exported to WASM
187	host_functions:Arc<RwLock<HashMap<String, HostFunction>>>,
188	/// Channel for receiving messages from WASM
189	wasm_to_host_rx:mpsc::UnboundedReceiver<WASMMessage>,
190	/// Channel for sending messages to WASM
191	host_to_wasm_tx:mpsc::UnboundedSender<WASMMessage>,
192	/// Active async callbacks
193	async_callbacks:Arc<RwLock<HashMap<u64, AsyncCallback>>>,
194	/// Next callback token
195	next_callback_token:Arc<std::sync::atomic::AtomicU64>,
196}
197
198impl HostBridgeImpl {
199	/// Create a new host bridge
200	pub fn new() -> Self {
201		let (wasm_to_host_tx, wasm_to_host_rx) = mpsc::unbounded_channel();
202		let (host_to_wasm_tx, host_to_wasm_rx) = mpsc::unbounded_channel();
203
204		// In a real implementation, we'd need to wire these up properly
205		// For now, we drop the receiver to avoid unused warnings
206		drop(host_to_wasm_rx);
207
208		Self {
209			host_functions:Arc::new(RwLock::new(HashMap::new())),
210			wasm_to_host_rx,
211			host_to_wasm_tx,
212			async_callbacks:Arc::new(RwLock::new(HashMap::new())),
213			next_callback_token:Arc::new(std::sync::atomic::AtomicU64::new(0)),
214		}
215	}
216
217	/// Register a host function to be exported to WASM
218	#[instrument(skip(self, callback))]
219	pub async fn register_host_function(
220		&self,
221		name:&str,
222		signature:FunctionSignature,
223		callback:HostFunctionCallback,
224	) -> BridgeResult<()> {
225		debug!("Registering host function: {}", name);
226
227		let mut functions = self.host_functions.write().await;
228
229		if functions.contains_key(name) {
230			warn!("Host function already registered: {}", name);
231		}
232
233		functions.insert(
234			name.to_string(),
235			HostFunction { name:name.to_string(), signature, callback:Some(callback), async_callback:None },
236		);
237
238		debug!("Host function registered successfully: {}", name);
239		Ok(())
240	}
241
242	/// Register an async host function
243	#[instrument(skip(self, callback))]
244	pub async fn register_async_host_function(
245		&self,
246		name:&str,
247		signature:FunctionSignature,
248		callback:AsyncHostFunctionCallback,
249	) -> BridgeResult<()> {
250		debug!("Registering async host function: {}", name);
251
252		let mut functions = self.host_functions.write().await;
253
254		functions.insert(
255			name.to_string(),
256			HostFunction { name:name.to_string(), signature, callback:None, async_callback:Some(callback) },
257		);
258
259		debug!("Async host function registered successfully: {}", name);
260		Ok(())
261	}
262
263	/// Call a host function from WASM
264	#[instrument(skip(self, args))]
265	pub async fn call_host_function(&self, function_name:&str, args:Vec<Bytes>) -> BridgeResult<Bytes> {
266		debug!("Calling host function: {}", function_name);
267
268		let functions = self.host_functions.read().await;
269		let func = functions
270			.get(function_name)
271			.ok_or_else(|| BridgeError::FunctionNotFound(function_name.to_string()))?;
272
273		if let Some(callback) = func.callback {
274			// Synchronous call
275			let result =
276				callback(args).map_err(|e| BridgeError::HostFunctionError(format!("{}: {}", function_name, e)))?;
277			debug!("Host function call completed: {}", function_name);
278			Ok(result)
279		} else if let Some(async_callback) = func.async_callback {
280			// Async call
281			let future = async_callback(args);
282			let result = future
283				.await
284				.map_err(|e| BridgeError::HostFunctionError(format!("{}: {}", function_name, e)))?;
285			debug!("Async host function call completed: {}", function_name);
286			Ok(result)
287		} else {
288			Err(BridgeError::FunctionNotFound(format!(
289				"No callback for function: {}",
290				function_name
291			)))
292		}
293	}
294
295	/// Send a message to WASM
296	#[instrument(skip(self, message))]
297	pub async fn send_to_wasm(&self, message:WASMMessage) -> BridgeResult<()> {
298		let function_name = message.function.clone();
299		self.host_to_wasm_tx.send(message).map_err(|_| BridgeError::BridgeClosed)?;
300		debug!("Message sent to WASM: {}", function_name);
301		Ok(())
302	}
303
304	/// Receive a message from WASM (blocking)
305	pub async fn receive_from_wasm(&mut self) -> Option<WASMMessage> { self.wasm_to_host_rx.recv().await }
306
307	/// Create async callback
308	#[instrument(skip(self))]
309	pub async fn create_async_callback(&self, message_id:String) -> (AsyncCallback, u64) {
310		let token = self.next_callback_token.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
311		let (tx, _rx) = oneshot::channel();
312
313		// Create callback with Arc-wrapped sender
314		let callback = AsyncCallback {
315			sender:Arc::new(tokio::sync::Mutex::new(Some(tx))),
316			message_id:message_id.clone(),
317		};
318
319		self.async_callbacks.write().await.insert(token, callback.clone());
320
321		(callback, token)
322	}
323
324	/// Get callback by token
325	#[instrument(skip(self))]
326	pub async fn get_callback(&self, token:u64) -> Option<AsyncCallback> {
327		self.async_callbacks.write().await.remove(&token)
328	}
329
330	/// Get all registered host functions
331	pub async fn get_host_functions(&self) -> Vec<String> { self.host_functions.read().await.keys().cloned().collect() }
332
333	/// Unregister a host function
334	#[instrument(skip(self))]
335	pub async fn unregister_host_function(&self, name:&str) -> bool {
336		let mut functions = self.host_functions.write().await;
337		let removed = functions.remove(name).is_some();
338		if removed {
339			debug!("Host function unregistered: {}", name);
340		}
341		removed
342	}
343
344	/// Clear all registered functions
345	pub async fn clear(&self) {
346		debug!("Clearing all registered host functions");
347		self.host_functions.write().await.clear();
348		self.async_callbacks.write().await.clear();
349	}
350}
351
352impl Default for HostBridgeImpl {
353	fn default() -> Self { Self::new() }
354}
355
356/// Utility function to serialize data to Bytes
357pub fn serialize_to_bytes<T:Serialize>(data:&T) -> Result<Bytes> {
358	serde_json::to_vec(data)
359		.map(Bytes::from)
360		.map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
361}
362
363/// Utility function to deserialize Bytes to data
364pub fn deserialize_from_bytes<T:DeserializeOwned>(bytes:&Bytes) -> Result<T> {
365	serde_json::from_slice(bytes).map_err(|e| anyhow::anyhow!("Deserialization error: {}", e))
366}
367
368/// Marshal arguments for WASM function call
369pub fn marshal_args(args:Vec<Bytes>) -> Result<Vec<wasmtime::Val>> {
370	args.iter()
371		.map(|bytes| {
372			let value:serde_json::Value = serde_json::from_slice(bytes)?;
373			match value {
374				serde_json::Value::Number(n) => {
375					if let Some(i) = n.as_i64() {
376						Ok(wasmtime::Val::I32(i as i32))
377					} else if let Some(f) = n.as_f64() {
378						Ok(wasmtime::Val::F64(f.to_bits()))
379					} else {
380						Err(anyhow::anyhow!("Invalid number value"))
381					}
382				},
383				_ => Err(anyhow::anyhow!("Unsupported argument type")),
384			}
385		})
386		.collect()
387}
388
389/// Unmarshal return values from WASM function call
390pub fn unmarshal_return(val:wasmtime::Val) -> Result<Bytes> {
391	match val {
392		wasmtime::Val::I32(i) => {
393			let json = serde_json::to_string(&i)?;
394			Ok(Bytes::from(json))
395		},
396		wasmtime::Val::I64(i) => {
397			let json = serde_json::to_string(&i)?;
398			Ok(Bytes::from(json))
399		},
400		wasmtime::Val::F32(f) => {
401			let json = serde_json::to_string(&f)?;
402			Ok(Bytes::from(json))
403		},
404		wasmtime::Val::F64(f) => {
405			let json = serde_json::to_string(&f)?;
406			Ok(Bytes::from(json))
407		},
408		_ => Err(anyhow::anyhow!("Unsupported return type")),
409	}
410}
411
#[cfg(test)]
mod tests {
	use super::*;

	/// Signature used by the tests that register the `echo` function.
	fn echo_signature() -> FunctionSignature {
		FunctionSignature {
			name:"echo".to_string(),
			param_types:vec![ParamType::I32],
			return_type:Some(ReturnType::I32),
			is_async:false,
		}
	}

	/// Host callback that returns its first argument, or an error when it is
	/// called without arguments (instead of panicking on an index).
	fn echo(args:Vec<Bytes>) -> Result<Bytes> {
		args.first().cloned().ok_or_else(|| anyhow::anyhow!("echo expects one argument"))
	}

	#[test]
	fn test_function_signature_creation() {
		let signature = FunctionSignature {
			name:"test_func".to_string(),
			param_types:vec![ParamType::I32, ParamType::Ptr],
			return_type:Some(ReturnType::I32),
			is_async:false,
		};

		assert_eq!(signature.name, "test_func");
		assert_eq!(signature.param_types.len(), 2);
	}

	#[tokio::test]
	async fn test_host_bridge_creation() {
		let bridge = HostBridgeImpl::new();
		assert_eq!(bridge.get_host_functions().await.len(), 0);
	}

	#[tokio::test]
	async fn test_register_host_function() {
		let bridge = HostBridgeImpl::new();

		let result = bridge.register_host_function("echo", echo_signature(), echo).await;

		assert!(result.is_ok());
		assert_eq!(bridge.get_host_functions().await, vec!["echo".to_string()]);
	}

	#[tokio::test]
	async fn test_call_host_function() {
		let bridge = HostBridgeImpl::new();
		bridge.register_host_function("echo", echo_signature(), echo).await.unwrap();

		// A registered function receives the arguments and its result is returned.
		let payload = serialize_to_bytes(&7i32).unwrap();
		let reply = bridge.call_host_function("echo", vec![payload.clone()]).await.unwrap();
		assert_eq!(reply, payload);

		// A callback failure is reported as a host function error.
		let failed = bridge.call_host_function("echo", Vec::new()).await;
		assert!(matches!(failed, Err(BridgeError::HostFunctionError(_))));

		// An unknown name is reported as not found.
		let missing = bridge.call_host_function("missing", Vec::new()).await;
		assert!(matches!(missing, Err(BridgeError::FunctionNotFound(_))));

		// After unregistering, the function is gone.
		assert!(bridge.unregister_host_function("echo").await);
		assert!(!bridge.unregister_host_function("echo").await);
	}

	#[test]
	fn test_serialize_deserialize() {
		let data = vec![1, 2, 3, 4, 5];
		let bytes = serialize_to_bytes(&data).unwrap();
		let recovered:Vec<i32> = deserialize_from_bytes(&bytes).unwrap();
		assert_eq!(data, recovered);
	}

	#[test]
	fn test_marshal_unmarshal() {
		// 2.5 is exactly representable, and avoids an approximation of pi
		// that clippy rejects.
		let args = vec![serialize_to_bytes(&42i32).unwrap(), serialize_to_bytes(&2.5f64).unwrap()];

		let marshaled = marshal_args(args).unwrap();
		assert_eq!(marshaled.len(), 2);
		assert!(matches!(&marshaled[0], wasmtime::Val::I32(42)));
		assert!(matches!(&marshaled[1], wasmtime::Val::F64(bits) if f64::from_bits(*bits) == 2.5));

		// Anything that is not a JSON number is rejected.
		assert!(marshal_args(vec![serialize_to_bytes(&"text").unwrap()]).is_err());

		// Integers come back as their JSON encoding.
		let encoded = unmarshal_return(wasmtime::Val::I32(42)).unwrap();
		assert_eq!(encoded, serialize_to_bytes(&42i32).unwrap());
	}
}