1use std::{collections::HashMap, sync::Arc};
7
8use anyhow::{Context, Result};
9use serde::{Deserialize, Serialize};
10use tokio::sync::RwLock;
11use tracing::{debug, info, instrument, warn};
12use wasmtime::{Caller, Engine, Extern, Func, FuncType, Linker, Store, StoreLimits, Trap, Val, ValType};
13
14use crate::WASM::HostBridge::{FunctionSignature, HostBridgeImpl as HostBridge, HostBridgeImpl, HostFunctionCallback, ParamType, ReturnType};
15
16pub struct HostFunctionRegistry {
18 functions:Arc<RwLock<HashMap<String, RegisteredHostFunction>>>,
20 bridge:Arc<HostBridge>,
22}
23
24#[derive(Debug, Clone)]
26struct RegisteredHostFunction {
27 name:String,
29 signature:FunctionSignature,
31 callback:Option<HostFunctionCallback>,
33 registered_at:u64,
35 stats:FunctionStats,
37}
38
/// Per-function call counters.
///
/// All counters start at zero / `None` via `Default`.
/// NOTE(review): the field meanings below are inferred from their names; nothing in the code
/// visible next to this struct updates them after registration.
#[derive(Debug, Clone, Default)]
pub struct FunctionStats {
	/// Number of invocations (presumably successful and failed alike; confirm with the writer).
	pub call_count:u64,

	/// Accumulated execution time; the name suggests nanoseconds.
	pub total_execution_ns:u64,

	/// Timestamp of the most recent call, or `None` if the function was never called.
	/// The unit is not established by this file.
	pub last_call_at:Option<u64>,

	/// Number of invocations that ended in an error.
	pub error_count:u64,
}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct ExportConfig {
55 pub auto_export:bool,
57 pub enable_stats:bool,
59 pub max_functions:usize,
61 pub name_prefix:Option<String>,
63}
64
65impl Default for ExportConfig {
66 fn default() -> Self {
67 Self {
68 auto_export:true,
69 enable_stats:true,
70 max_functions:1000,
71 name_prefix:Some("host_".to_string()),
72 }
73 }
74}
75
76pub struct FunctionExportImpl {
78 registry:Arc<HostFunctionRegistry>,
79 config:ExportConfig,
80}
81
82impl FunctionExportImpl {
83 pub fn new(bridge:Arc<HostBridge>) -> Self {
85 Self {
86 registry:Arc::new(HostFunctionRegistry { functions:Arc::new(RwLock::new(HashMap::new())), bridge }),
87 config:ExportConfig::default(),
88 }
89 }
90
91 pub fn with_config(bridge:Arc<HostBridge>, config:ExportConfig) -> Self {
93 Self {
94 registry:Arc::new(HostFunctionRegistry { functions:Arc::new(RwLock::new(HashMap::new())), bridge }),
95 config,
96 }
97 }
98
99 #[instrument(skip(self, callback))]
101 pub async fn register_function(
102 &self,
103 name:&str,
104 signature:FunctionSignature,
105 callback:HostFunctionCallback,
106 ) -> Result<()> {
107 info!("Registering host function for export: {}", name);
108
109 let functions = self.registry.functions.read().await;
110
111 if functions.len() >= self.config.max_functions {
113 return Err(anyhow::anyhow!(
114 "Maximum number of exported functions reached: {}",
115 self.config.max_functions
116 ));
117 }
118
119 drop(functions);
120
121 let mut functions = self.registry.functions.write().await;
122
123 let registered_at = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?.as_secs();
124
125 functions.insert(
126 name.to_string(),
127 RegisteredHostFunction {
128 name:name.to_string(),
129 signature,
130 callback:Some(callback),
131 registered_at,
132 stats:FunctionStats::default(),
133 },
134 );
135
136 debug!("Host function registered for WASM export: {}", name);
137 Ok(())
138 }
139
140 #[instrument(skip(self, callbacks))]
142 pub async fn register_functions(
143 &self,
144 signatures:Vec<FunctionSignature>,
145 callbacks:Vec<HostFunctionCallback>,
146 ) -> Result<()> {
147 if signatures.len() != callbacks.len() {
148 return Err(anyhow::anyhow!("Number of signatures must match number of callbacks"));
149 }
150
151 for (sig, callback) in signatures.into_iter().zip(callbacks) {
152 let name = sig.name.clone();
153 self.register_function(&name, sig, callback).await?;
154 }
155
156 Ok(())
157 }
158
159 #[instrument(skip(self, linker))]
161 pub async fn export_to_linker<T>(&self, linker:&mut Linker<T>) -> Result<()>
162 where
163 T: Send + 'static, {
164 info!(
165 "Exporting {} host functions to linker",
166 self.registry.functions.read().await.len()
167 );
168
169 let functions = self.registry.functions.read().await;
170
171 for (name, func) in functions.iter() {
172 self.export_single_function(linker, name, func)?;
173 }
174
175 info!("All host functions exported to linker");
176 Ok(())
177 }
178
179 fn export_single_function<T>(&self, linker:&mut Linker<T>, name:&str, func:&RegisteredHostFunction) -> Result<()>
181 where
182 T: Send + 'static, {
183 debug!("Exporting function: {}", name);
184
185 let callback = func
186 .callback
187 .ok_or_else(|| anyhow::anyhow!("No callback available for function: {}", name))?;
188
189 let func_name = if let Some(prefix) = &self.config.name_prefix {
190 format!("{}{}", prefix, name)
191 } else {
192 name.to_string()
193 };
194
195 let func_name_for_debug = func_name.clone();
196 let func_name_inner = func_name.clone();
197
198 let wrapped_callback =
200 move |mut _caller:Caller<'_, T>, args:&[wasmtime::Val]| -> Result<Vec<wasmtime::Val>, wasmtime::Trap> {
201 let start = std::time::Instant::now();
202
203 let args_bytes:Result<Vec<bytes::Bytes>, _> = args
205 .iter()
206 .map(|arg| {
207 match arg {
208 wasmtime::Val::I32(i) => {
209 serde_json::to_vec(i)
210 .map(bytes::Bytes::from)
211 .map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
212 },
213 wasmtime::Val::I64(i) => {
214 serde_json::to_vec(i)
215 .map(bytes::Bytes::from)
216 .map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
217 },
218 wasmtime::Val::F32(f) => {
219 serde_json::to_vec(f)
220 .map(bytes::Bytes::from)
221 .map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
222 },
223 wasmtime::Val::F64(f) => {
224 serde_json::to_vec(f)
225 .map(bytes::Bytes::from)
226 .map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
227 },
228 _ => Err(anyhow::anyhow!("Unsupported argument type")),
229 }
230 })
231 .collect();
232
233 let args_bytes = args_bytes.map_err(|_| {
234 warn!("Error converting arguments for function '{}'", func_name_inner);
235 wasmtime::Trap::StackOverflow
236 })?;
237
238 let result = callback(args_bytes);
240
241 match result {
242 Ok(response_bytes) => {
243 let result_val:serde_json::Value = serde_json::from_slice(&response_bytes)
245 .map_err(|_| {
246 warn!("Error deserializing response for function '{}'", func_name_inner);
247 wasmtime::Trap::StackOverflow
248 })?;
249
250 let ret_val = match result_val {
251 serde_json::Value::Number(n) => {
252 if let Some(i) = n.as_i64() {
253 wasmtime::Val::I32(i as i32)
254 } else if let Some(f) = n.as_f64() {
255 wasmtime::Val::I64(f as i64)
256 } else {
257 warn!("Invalid number format for function '{}'", func_name_inner);
258 return Err(wasmtime::Trap::StackOverflow);
259 }
260 },
261 _ => {
262 warn!("Unsupported response type for function '{}'", func_name_inner);
263 return Err(wasmtime::Trap::StackOverflow);
264 },
265 };
266
267 Ok(vec![ret_val])
268 },
269 Err(e) => {
270 debug!("Host function '{}' returned error: {}", func_name_inner, e);
272 Err(wasmtime::Trap::StackOverflow)
273 },
274 }
275 };
276
277 let _wasmparser_signature = wasmparser::FuncType::new([wasmparser::ValType::I32], [wasmparser::ValType::I32]);
279
280 let func_name_for_logging = func_name.clone();
283 linker.func_wrap(
284 "_host", &func_name,
286 move |mut caller:wasmtime::Caller<'_, T>, input_param:i32| -> i32 {
287 let start = std::time::Instant::now();
289
290 let args_bytes = match serde_json::to_vec(&input_param)
292 .map(bytes::Bytes::from) {
293 Ok(b) => b,
294 Err(e) => {
295 warn!("Serialization error for function '{}': {}", func_name_for_logging, e);
296 return -1i32;
297 }
298 };
299
300 let result = callback(vec![args_bytes]);
302
303 match result {
304 Ok(response_bytes) => {
305 let result_val:serde_json::Value = match serde_json::from_slice(&response_bytes) {
307 Ok(v) => v,
308 Err(_) => {
309 warn!("Error deserializing response for function '{}'", func_name_for_logging);
310 return -1i32;
311 }
312 };
313
314 let ret_val = match result_val {
316 serde_json::Value::Number(n) => {
317 if let Some(i) = n.as_i64() {
318 i as i32
319 } else if let Some(f) = n.as_f64() {
320 f as i32
321 } else {
322 warn!("Invalid number format for function '{}'", func_name_for_logging);
323 -1i32
324 }
325 },
326 serde_json::Value::Bool(b) => {
327 if b { 1i32 } else { 0i32 }
328 },
329 _ => {
330 warn!("Unsupported response type for function '{}', expected number or bool", func_name_for_logging);
331 -1i32
332 },
333 };
334
335 let duration = start.elapsed();
337 debug!("[FunctionExport] Host function '{}' executed successfully in {}µs", func_name_for_logging, duration.as_micros());
338
339 ret_val
340 },
341 Err(e) => {
342 debug!("[FunctionExport] Host function '{}' returned error: {}", func_name_for_logging, e);
344 -1i32
346 },
347 }
348 },
349 ).map_err(|e| {
350 warn!("[FunctionExport] Failed to wrap host function '{}': {}", func_name_for_debug, e);
351 e
352 })?;
353
354 debug!("[FunctionExport] Host function '{}' registered successfully", func_name_for_debug);
355
356 Ok(())
357 }
358
359 fn wasmtime_signature_from_signature(&self, _sig:&FunctionSignature) -> Result<wasmparser::FuncType> {
361 Ok(wasmparser::FuncType::new([], []))
364 }
365
366 pub async fn get_function_names(&self) -> Vec<String> {
368 self.registry.functions.read().await.keys().cloned().collect()
369 }
370
371 pub async fn get_function_stats(&self, name:&str) -> Option<FunctionStats> {
373 self.registry.functions.read().await.get(name).map(|f| f.stats.clone())
374 }
375
376 #[instrument(skip(self))]
378 pub async fn unregister_function(&self, name:&str) -> Result<bool> {
379 let mut functions = self.registry.functions.write().await;
380 let removed = functions.remove(name).is_some();
381
382 if removed {
383 info!("Unregistered host function: {}", name);
384 } else {
385 warn!("Attempted to unregister non-existent function: {}", name);
386 }
387
388 Ok(removed)
389 }
390
391 pub async fn clear(&self) {
393 info!("Clearing all registered host functions");
394 self.registry.functions.write().await.clear();
395 }
396}
397
#[cfg(test)]
mod tests {
	use super::*;

	/// Builds an exporter over a fresh bridge with the default configuration.
	fn new_export() -> FunctionExportImpl { FunctionExportImpl::new(Arc::new(HostBridgeImpl::new())) }

	/// Builds a synchronous signature taking one `I32` and returning an `I32`.
	fn i32_signature(name:&str) -> FunctionSignature {
		FunctionSignature {
			name:name.to_string(),
			param_types:vec![ParamType::I32],
			return_type:Some(ReturnType::I32),
			is_async:false,
		}
	}

	/// A new exporter starts with an empty registry.
	#[tokio::test]
	async fn test_function_export_creation() {
		let export = new_export();

		assert_eq!(export.get_function_names().await.len(), 0);
	}

	/// Registering one function succeeds and makes it visible in the name list.
	#[tokio::test]
	async fn test_register_function() {
		let export = new_export();

		let echo = |args:Vec<bytes::Bytes>| Ok(args.get(0).cloned().unwrap_or(bytes::Bytes::new()));

		let outcome:anyhow::Result<()> = export.register_function("echo", i32_signature("echo"), echo).await;
		assert!(outcome.is_ok());
		assert_eq!(export.get_function_names().await.len(), 1);
	}

	/// Unregistering a known function reports `true` and empties the registry.
	#[tokio::test]
	async fn test_unregister_function() {
		let export = new_export();

		let noop = |_:Vec<bytes::Bytes>| Ok(bytes::Bytes::new());
		let _:anyhow::Result<()> = export.register_function("test", i32_signature("test"), noop).await;

		let removed:bool = export.unregister_function("test").await.unwrap();
		assert!(removed);
		assert_eq!(export.get_function_names().await.len(), 0);
	}

	/// The default configuration enables auto-export and allows 1000 functions.
	#[test]
	fn test_export_config_default() {
		let defaults = ExportConfig::default();
		assert!(defaults.auto_export);
		assert_eq!(defaults.max_functions, 1000);
	}

	/// Default statistics start with zeroed counters.
	#[test]
	fn test_function_stats_default() {
		let fresh = FunctionStats::default();
		assert_eq!(fresh.call_count, 0);
		assert_eq!(fresh.error_count, 0);
	}
}