1use std::{collections::HashMap, sync::Arc};
8
9use anyhow::{Context, Result};
10use bytes::Bytes;
11use serde::{Serialize, de::DeserializeOwned};
12use tokio::sync::{RwLock, mpsc, oneshot};
13use tracing::{debug, error, instrument, warn};
14use wasmtime::{Caller, Extern, Func, Linker, Store};
15
/// Errors produced by the host <-> WASM bridge.
#[derive(Debug, thiserror::Error)]
pub enum BridgeError {
    /// No host function is registered under the given name. `call_host_function`
    /// also returns this when an entry has neither a sync nor an async callback.
    #[error("Function not found: {0}")]
    FunctionNotFound(String),

    /// A function signature was rejected; the payload describes why.
    /// NOTE(review): not constructed anywhere in the code visible here.
    #[error("Invalid function signature: {0}")]
    InvalidSignature(String),

    /// Encoding a value failed.
    /// NOTE(review): not constructed in the visible code (the serde helpers return anyhow errors).
    #[error("Serialization failed: {0}")]
    SerializationError(String),

    /// Decoding a value failed.
    /// NOTE(review): not constructed in the visible code (the serde helpers return anyhow errors).
    #[error("Deserialization failed: {0}")]
    DeserializationError(String),

    /// A registered host callback returned an error; the payload is
    /// `"<function name>: <callback error>"`.
    #[error("Host function error: {0}")]
    HostFunctionError(String),

    /// An operation timed out.
    /// NOTE(review): not constructed anywhere in the code visible here.
    #[error("Communication timeout")]
    Timeout,

    /// The other end of a channel is gone: a host -> WASM send found no receiver, or
    /// an `AsyncCallback` was already used or its receiver was dropped.
    #[error("Bridge closed")]
    BridgeClosed,
}
47
/// Result type for bridge operations that fail with a typed [`BridgeError`].
pub type BridgeResult<T> = Result<T, BridgeError>;
50
/// Declared shape of a host function: its name, parameter types, return type,
/// and whether it is asynchronous.
///
/// NOTE(review): the bridge stores this alongside each registered function but
/// never checks call arguments against it in the code visible here.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct FunctionSignature {
    /// Function name.
    pub name:String,
    /// Parameter types, in call order.
    pub param_types:Vec<ParamType>,
    /// Return type. `None` and `Some(ReturnType::Void)` can both express "nothing";
    /// NOTE(review): confirm which one callers are expected to use.
    pub return_type:Option<ReturnType>,
    /// Whether the function is meant to be registered as asynchronous.
    pub is_async:bool,
}
63
/// Parameter types that may appear in a [`FunctionSignature`].
///
/// `I32`/`I64`/`F32`/`F64` mirror the WASM numeric types. `Ptr` and `Len` are
/// presumably an offset into guest linear memory and the byte length of the region
/// it points at. NOTE(review): neither is interpreted in this file; confirm with callers.
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub enum ParamType {
    I32,
    I64,
    F32,
    F64,
    Ptr,
    Len,
}
80
/// Return types that may appear in a [`FunctionSignature`].
///
/// The numeric variants mirror the WASM numeric types; `Void` means the function
/// returns nothing.
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub enum ReturnType {
    I32,
    I64,
    F32,
    F64,
    Void,
}
95
/// A request addressed to a named host function.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HostMessage {
    /// Identifier used to correlate this request with its [`HostResponse`].
    pub message_id:String,
    /// Name of the host function to invoke.
    pub function:String,
    /// Encoded arguments, one buffer per argument.
    pub args:Vec<Bytes>,
    /// Token for an asynchronous reply, if any. Presumably a token returned by
    /// `HostBridgeImpl::create_async_callback`. NOTE(review): confirm.
    pub callback_token:Option<u64>,
}
108
/// Outcome of a [`HostMessage`], delivered through an [`AsyncCallback`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HostResponse {
    /// `message_id` of the request being answered.
    pub message_id:String,
    /// Whether the call succeeded.
    pub success:bool,
    /// Result payload. Presumably set when `success` is true. NOTE(review): not enforced by the type.
    pub data:Option<Bytes>,
    /// Error description. Presumably set when `success` is false. NOTE(review): not enforced by the type.
    pub error:Option<String>,
}
121
/// One-shot handle used to deliver a [`HostResponse`] for one message.
///
/// Clones share the same underlying sender through the `Arc`, so only the first
/// `send` call across all clones can deliver a response.
#[derive(Clone)]
pub struct AsyncCallback {
    /// Shared, take-once oneshot sender; it becomes `None` after the first `send`.
    sender:Arc<tokio::sync::Mutex<Option<tokio::sync::oneshot::Sender<HostResponse>>>>,
    /// Identifier of the message this callback answers (the only field shown by `Debug`).
    message_id:String,
}
130
131impl std::fmt::Debug for AsyncCallback {
132 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133 f.debug_struct("AsyncCallback")
134 .field("message_id", &self.message_id)
135 .finish()
136 }
137}
138
139impl AsyncCallback {
140 pub async fn send(self, response:HostResponse) -> Result<()> {
142 let mut sender_opt = self.sender.lock().await;
143 if let Some(sender) = sender_opt.take() {
144 sender.send(response).map_err(|_| BridgeError::BridgeClosed)?;
145 Ok(())
146 } else {
147 Err(BridgeError::BridgeClosed.into())
148 }
149 }
150}
151
/// A message carried over the bridge's host <-> WASM channels.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct WASMMessage {
    /// Name of the function the message targets (logged by `send_to_wasm`).
    pub function:String,
    /// Encoded arguments, one buffer per argument.
    pub args:Vec<Bytes>,
}
160
/// Synchronous host function implementation: takes the encoded arguments and
/// returns the encoded result (`Result` here is `anyhow::Result`, per the file's imports).
///
/// This is a plain `fn` pointer, so only free functions and non-capturing
/// closures qualify.
pub type HostFunctionCallback = fn(Vec<Bytes>) -> Result<Bytes>;

/// Asynchronous host function implementation returning a boxed future.
///
/// Also a plain `fn` pointer (no captured state). NOTE(review): the future must be
/// `Unpin`, and `async` blocks generally are not. Callers likely need
/// `Box::new(Box::pin(fut))`; a `Pin<Box<dyn Future + Send>>` return type would be
/// more conventional.
pub type AsyncHostFunctionCallback =
    fn(Vec<Bytes>) -> Box<dyn std::future::Future<Output = Result<Bytes>> + Send + Unpin>;
167
/// A registry entry for one host function.
///
/// The registration methods set exactly one of `callback` / `async_callback`.
/// `call_host_function` tries `callback` first.
#[derive(Debug)]
pub struct HostFunction {
    /// Name the function is registered under.
    pub name:String,
    /// Declared signature, stored as-is (not validated against call arguments).
    pub signature:FunctionSignature,
    // NOTE(review): `call_host_function` reads this field, so this allow looks
    // unnecessary; `name` and `signature` are the fields never read in this file.
    #[allow(dead_code)]
    /// Synchronous implementation, if this is a sync function.
    pub callback:Option<HostFunctionCallback>,
    // NOTE(review): same as above; this field is read by `call_host_function`.
    #[allow(dead_code)]
    /// Asynchronous implementation, if this is an async function.
    pub async_callback:Option<AsyncHostFunctionCallback>,
}
182
183#[derive(Debug)]
185pub struct HostBridgeImpl {
186 host_functions:Arc<RwLock<HashMap<String, HostFunction>>>,
188 wasm_to_host_rx:mpsc::UnboundedReceiver<WASMMessage>,
190 host_to_wasm_tx:mpsc::UnboundedSender<WASMMessage>,
192 async_callbacks:Arc<RwLock<HashMap<u64, AsyncCallback>>>,
194 next_callback_token:Arc<std::sync::atomic::AtomicU64>,
196}
197
198impl HostBridgeImpl {
199 pub fn new() -> Self {
201 let (wasm_to_host_tx, wasm_to_host_rx) = mpsc::unbounded_channel();
202 let (host_to_wasm_tx, host_to_wasm_rx) = mpsc::unbounded_channel();
203
204 drop(host_to_wasm_rx);
207
208 Self {
209 host_functions:Arc::new(RwLock::new(HashMap::new())),
210 wasm_to_host_rx,
211 host_to_wasm_tx,
212 async_callbacks:Arc::new(RwLock::new(HashMap::new())),
213 next_callback_token:Arc::new(std::sync::atomic::AtomicU64::new(0)),
214 }
215 }
216
217 #[instrument(skip(self, callback))]
219 pub async fn register_host_function(
220 &self,
221 name:&str,
222 signature:FunctionSignature,
223 callback:HostFunctionCallback,
224 ) -> BridgeResult<()> {
225 debug!("Registering host function: {}", name);
226
227 let mut functions = self.host_functions.write().await;
228
229 if functions.contains_key(name) {
230 warn!("Host function already registered: {}", name);
231 }
232
233 functions.insert(
234 name.to_string(),
235 HostFunction { name:name.to_string(), signature, callback:Some(callback), async_callback:None },
236 );
237
238 debug!("Host function registered successfully: {}", name);
239 Ok(())
240 }
241
242 #[instrument(skip(self, callback))]
244 pub async fn register_async_host_function(
245 &self,
246 name:&str,
247 signature:FunctionSignature,
248 callback:AsyncHostFunctionCallback,
249 ) -> BridgeResult<()> {
250 debug!("Registering async host function: {}", name);
251
252 let mut functions = self.host_functions.write().await;
253
254 functions.insert(
255 name.to_string(),
256 HostFunction { name:name.to_string(), signature, callback:None, async_callback:Some(callback) },
257 );
258
259 debug!("Async host function registered successfully: {}", name);
260 Ok(())
261 }
262
263 #[instrument(skip(self, args))]
265 pub async fn call_host_function(&self, function_name:&str, args:Vec<Bytes>) -> BridgeResult<Bytes> {
266 debug!("Calling host function: {}", function_name);
267
268 let functions = self.host_functions.read().await;
269 let func = functions
270 .get(function_name)
271 .ok_or_else(|| BridgeError::FunctionNotFound(function_name.to_string()))?;
272
273 if let Some(callback) = func.callback {
274 let result =
276 callback(args).map_err(|e| BridgeError::HostFunctionError(format!("{}: {}", function_name, e)))?;
277 debug!("Host function call completed: {}", function_name);
278 Ok(result)
279 } else if let Some(async_callback) = func.async_callback {
280 let future = async_callback(args);
282 let result = future
283 .await
284 .map_err(|e| BridgeError::HostFunctionError(format!("{}: {}", function_name, e)))?;
285 debug!("Async host function call completed: {}", function_name);
286 Ok(result)
287 } else {
288 Err(BridgeError::FunctionNotFound(format!(
289 "No callback for function: {}",
290 function_name
291 )))
292 }
293 }
294
295 #[instrument(skip(self, message))]
297 pub async fn send_to_wasm(&self, message:WASMMessage) -> BridgeResult<()> {
298 let function_name = message.function.clone();
299 self.host_to_wasm_tx.send(message).map_err(|_| BridgeError::BridgeClosed)?;
300 debug!("Message sent to WASM: {}", function_name);
301 Ok(())
302 }
303
304 pub async fn receive_from_wasm(&mut self) -> Option<WASMMessage> { self.wasm_to_host_rx.recv().await }
306
307 #[instrument(skip(self))]
309 pub async fn create_async_callback(&self, message_id:String) -> (AsyncCallback, u64) {
310 let token = self.next_callback_token.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
311 let (tx, _rx) = oneshot::channel();
312
313 let callback = AsyncCallback {
315 sender:Arc::new(tokio::sync::Mutex::new(Some(tx))),
316 message_id:message_id.clone(),
317 };
318
319 self.async_callbacks.write().await.insert(token, callback.clone());
320
321 (callback, token)
322 }
323
324 #[instrument(skip(self))]
326 pub async fn get_callback(&self, token:u64) -> Option<AsyncCallback> {
327 self.async_callbacks.write().await.remove(&token)
328 }
329
330 pub async fn get_host_functions(&self) -> Vec<String> { self.host_functions.read().await.keys().cloned().collect() }
332
333 #[instrument(skip(self))]
335 pub async fn unregister_host_function(&self, name:&str) -> bool {
336 let mut functions = self.host_functions.write().await;
337 let removed = functions.remove(name).is_some();
338 if removed {
339 debug!("Host function unregistered: {}", name);
340 }
341 removed
342 }
343
344 pub async fn clear(&self) {
346 debug!("Clearing all registered host functions");
347 self.host_functions.write().await.clear();
348 self.async_callbacks.write().await.clear();
349 }
350}
351
352impl Default for HostBridgeImpl {
353 fn default() -> Self { Self::new() }
354}
355
356pub fn serialize_to_bytes<T:Serialize>(data:&T) -> Result<Bytes> {
358 serde_json::to_vec(data)
359 .map(Bytes::from)
360 .map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
361}
362
363pub fn deserialize_from_bytes<T:DeserializeOwned>(bytes:&Bytes) -> Result<T> {
365 serde_json::from_slice(bytes).map_err(|e| anyhow::anyhow!("Deserialization error: {}", e))
366}
367
368pub fn marshal_args(args:Vec<Bytes>) -> Result<Vec<wasmtime::Val>> {
370 args.iter()
371 .map(|bytes| {
372 let value:serde_json::Value = serde_json::from_slice(bytes)?;
373 match value {
374 serde_json::Value::Number(n) => {
375 if let Some(i) = n.as_i64() {
376 Ok(wasmtime::Val::I32(i as i32))
377 } else if let Some(f) = n.as_f64() {
378 Ok(wasmtime::Val::F64(f.to_bits()))
379 } else {
380 Err(anyhow::anyhow!("Invalid number value"))
381 }
382 },
383 _ => Err(anyhow::anyhow!("Unsupported argument type")),
384 }
385 })
386 .collect()
387}
388
389pub fn unmarshal_return(val:wasmtime::Val) -> Result<Bytes> {
391 match val {
392 wasmtime::Val::I32(i) => {
393 let json = serde_json::to_string(&i)?;
394 Ok(Bytes::from(json))
395 },
396 wasmtime::Val::I64(i) => {
397 let json = serde_json::to_string(&i)?;
398 Ok(Bytes::from(json))
399 },
400 wasmtime::Val::F32(f) => {
401 let json = serde_json::to_string(&f)?;
402 Ok(Bytes::from(json))
403 },
404 wasmtime::Val::F64(f) => {
405 let json = serde_json::to_string(&f)?;
406 Ok(Bytes::from(json))
407 },
408 _ => Err(anyhow::anyhow!("Unsupported return type")),
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_function_signature_creation() {
        let signature = FunctionSignature {
            name:"test_func".to_string(),
            param_types:vec![ParamType::I32, ParamType::Ptr],
            return_type:Some(ReturnType::I32),
            is_async:false,
        };

        assert_eq!(signature.name, "test_func");
        assert_eq!(signature.param_types.len(), 2);
    }

    #[tokio::test]
    async fn test_host_bridge_creation() {
        let bridge = HostBridgeImpl::new();
        assert_eq!(bridge.get_host_functions().await.len(), 0);
    }

    #[tokio::test]
    async fn test_register_host_function() {
        let bridge = HostBridgeImpl::new();

        let signature = FunctionSignature {
            name:"echo".to_string(),
            param_types:vec![ParamType::I32],
            return_type:Some(ReturnType::I32),
            is_async:false,
        };

        let result = bridge
            .register_host_function("echo", signature, |args| Ok(args[0].clone()))
            .await;

        assert!(result.is_ok());
        assert_eq!(bridge.get_host_functions().await.len(), 1);

        // The registered function is actually callable and returns its input.
        let echoed = bridge
            .call_host_function("echo", vec![Bytes::from_static(b"7")])
            .await
            .expect("registered function should be callable");
        assert_eq!(echoed, Bytes::from_static(b"7"));

        // Unknown names are reported as FunctionNotFound.
        let missing = bridge.call_host_function("missing", Vec::new()).await;
        assert!(matches!(missing, Err(BridgeError::FunctionNotFound(_))));
    }

    #[test]
    fn test_serialize_deserialize() {
        let data = vec![1, 2, 3, 4, 5];
        let bytes = serialize_to_bytes(&data).unwrap();
        let recovered:Vec<i32> = deserialize_from_bytes(&bytes).unwrap();
        assert_eq!(data, recovered);
    }

    #[test]
    fn test_marshal_unmarshal() {
        // 2.5 rather than 3.14: clippy's deny-by-default `approx_constant` lint rejects 3.14.
        let args = vec![serialize_to_bytes(&42i32).unwrap(), serialize_to_bytes(&2.5f64).unwrap()];

        let marshaled = marshal_args(args).expect("numeric arguments should marshal");
        assert_eq!(marshaled.len(), 2);
        assert!(matches!(marshaled[0], wasmtime::Val::I32(42)));
        // F64 carries the float's bit pattern.
        assert!(matches!(&marshaled[1], wasmtime::Val::F64(bits) if *bits == 2.5f64.to_bits()));

        let returned = unmarshal_return(wasmtime::Val::I32(42)).expect("i32 should unmarshal");
        assert_eq!(returned, Bytes::from_static(b"42"));

        // Non-numeric arguments are rejected.
        assert!(marshal_args(vec![Bytes::from_static(b"\"text\"")]).is_err());
    }
}