grove/Host/
ExtensionManager.rs

1//! Extension Manager Module
2//!
3//! Handles extension discovery, loading, and management.
4//! Provides query and monitoring capabilities for extensions.
5
6use crate::Host::HostConfig;
7use crate::WASM::Runtime::WASMRuntime;
8use anyhow::{Context, Result};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::{Path, PathBuf};
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use tracing::{debug, info, instrument, warn};
15
16/// Extension manager for handling extension lifecycle
17pub struct ExtensionManagerImpl {
18    /// WASM runtime for executing extensions
19    wasm_runtime: Arc<WASMRuntime>,
20    /// Host configuration
21    config: HostConfig,
22    /// Loaded extensions
23    extensions: Arc<RwLock<HashMap<String, ExtensionInfo>>>,
24    /// Extension statistics
25    stats: Arc<RwLock<ExtensionStats>>,
26}
27
28/// Extension information
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct ExtensionInfo {
31    /// Extension ID (e.g., "publisher.extension-name")
32    pub id: String,
33    /// Extension display name
34    pub display_name: String,
35    /// Extension description
36    pub description: String,
37    /// Extension version
38    pub version: String,
39    /// Publisher name
40    pub publisher: String,
41    /// Path to extension directory
42    pub path: PathBuf,
43    /// Entry point file
44    pub entry_point: PathBuf,
45    /// Activation events
46    pub activation_events: Vec<String>,
47    /// Type of extension (wasm, native, etc.)
48    pub extension_type: ExtensionType,
49    /// Extension state
50    pub state: ExtensionState,
51    /// Extension capabilities
52    pub capabilities: Vec<String>,
53    /// Dependencies
54    pub dependencies: Vec<String>,
55    /// Extension manifest (JSON)
56    pub manifest: serde_json::Value,
57    /// Load timestamp
58    pub loaded_at: u64,
59    /// Activation timestamp
60    pub activated_at: Option<u64>,
61}
62
63/// Extension type
64#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
65pub enum ExtensionType {
66    /// WebAssembly extension
67    WASM,
68    /// Native Rust extension
69    Native,
70    /// JavaScript/TypeScript extension (via Cocoon compatibility)
71    JavaScript,
72    /// Unknown type
73    Unknown,
74}
75
76/// Extension state
77#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
78pub enum ExtensionState {
79    /// Extension is loaded but not activated
80    Loaded,
81    /// Extension is activated and running
82    Activated,
83    /// Extension is deactivated
84    Deactivated,
85    /// Extension encountered an error
86    Error,
87}
88
89/// Extension statistics
90#[derive(Debug, Clone, Default, Serialize, Deserialize)]
91pub struct ExtensionStats {
92    /// Total number of extensions loaded
93    pub total_loaded: usize,
94    /// Total number of extensions activated
95    pub total_activated: usize,
96    /// Total number of extensions deactivated
97    pub total_deactivated: usize,
98    /// Total activation time in milliseconds
99    pub total_activation_time_ms: u64,
100    /// Number of errors encountered
101    pub errors: u64,
102}
103
104impl ExtensionManagerImpl {
105    /// Create a new extension manager
106    pub fn new(wasm_runtime: Arc<WASMRuntime>, config: HostConfig) -> Self {
107        Self {
108            wasm_runtime,
109            config,
110            extensions: Arc::new(RwLock::new(HashMap::new())),
111            stats: Arc::new(RwLock::new(ExtensionStats::default())),
112        }
113    }
114
115    /// Load an extension from a path
116    #[instrument(skip(self, path))]
117    pub async fn load_extension(&self, path: &PathBuf) -> Result<String> {
118        info!("Loading extension from: {:?}", path);
119
120        // Validate path
121        if !path.exists() {
122            return Err(anyhow::anyhow!("Extension path does not exist: {:?}", path));
123        }
124
125        // Parse manifest
126        let manifest = self.parse_manifest(path)?;
127        let extension_id = self.extract_extension_id(&manifest)?;
128
129        // Check if extension is already loaded
130        let extensions = self.extensions.read().await;
131        if extensions.contains_key(&extension_id) {
132            warn!("Extension already loaded: {}", extension_id);
133            return Ok(extension_id);
134        }
135        drop(extensions);
136
137        // Determine extension type
138        let extension_type = self.determine_extension_type(path, &manifest)?;
139
140        // Create extension info
141        let extension_info = ExtensionInfo {
142            id: extension_id.clone(),
143            display_name: manifest
144                .get("displayName")
145                .and_then(|v| v.as_str())
146                .unwrap_or("")
147                .to_string(),
148            description: manifest
149                .get("description")
150                .and_then(|v| v.as_str())
151                .unwrap_or("")
152                .to_string(),
153            version: manifest
154                .get("version")
155                .and_then(|v| v.as_str())
156                .unwrap_or("0.0.0")
157                .to_string(),
158            publisher: manifest
159                .get("publisher")
160                .and_then(|v| v.as_str())
161                .unwrap_or("")
162                .to_string(),
163            path: path.clone(),
164            entry_point: path.join(
165                manifest
166                    .get("main")
167                    .and_then(|v| v.as_str())
168                    .unwrap_or("dist/extension.js"),
169            ),
170            activation_events: self.extract_activation_events(&manifest),
171            extension_type,
172            state: ExtensionState::Loaded,
173            capabilities: self.extract_capabilities(&manifest),
174            dependencies: self.extract_dependencies(&manifest),
175            manifest,
176            loaded_at: std::time::SystemTime::now()
177                .duration_since(std::time::UNIX_EPOCH)
178                .map(|d| d.as_secs())
179                .unwrap_or(0),
180            activated_at: None,
181        };
182
183        // Register extension
184        let mut extensions = self.extensions.write().await;
185        extensions.insert(extension_id.clone(), extension_info);
186
187        // Update statistics
188        let mut stats = self.stats.write().await;
189        stats.total_loaded += 1;
190
191        info!("Extension loaded successfully: {}", extension_id);
192
193        Ok(extension_id)
194    }
195
196    /// Unload an extension
197    #[instrument(skip(self, extension_id))]
198    pub async fn unload_extension(&self, extension_id: &str) -> Result<()> {
199        info!("Unloading extension: {}", extension_id);
200
201        let mut extensions = self.extensions.write().await;
202        extensions.remove(extension_id);
203
204        info!("Extension unloaded: {}", extension_id);
205
206        Ok(())
207    }
208
209    /// Get an extension by ID
210    pub async fn get_extension(&self, extension_id: &str) -> Option<ExtensionInfo> {
211        self.extensions.read().await.get(extension_id).cloned()
212    }
213
214    /// List all loaded extensions
215    pub async fn list_extensions(&self) -> Vec<String> {
216        self.extensions
217            .read()
218            .await
219            .keys()
220            .cloned()
221            .collect()
222    }
223
224    /// List extensions in a specific state
225    pub async fn list_extensions_by_state(&self, state: ExtensionState) -> Vec<ExtensionInfo> {
226        self.extensions
227            .read()
228            .await
229            .values()
230            .filter(|ext| ext.state == state)
231            .cloned()
232            .collect()
233    }
234
235    /// Update extension state
236    #[instrument(skip(self, extension_id))]
237    pub async fn update_state(&self, extension_id: &str, state: ExtensionState) -> Result<()> {
238        let mut extensions = self.extensions.write().await;
239        if let Some(info) = extensions.get_mut(extension_id) {
240            info.state = state;
241            if state == ExtensionState::Activated {
242                info.activated_at = Some(
243                    std::time::SystemTime::now()
244                        .duration_since(std::time::UNIX_EPOCH)
245                        .map(|d| d.as_secs())
246                        .unwrap_or(0),
247                );
248                
249                let mut stats = self.stats.write().await;
250                stats.total_activated += 1;
251            } else if state == ExtensionState::Deactivated {
252                let mut stats = self.stats.write().await;
253                stats.total_deactivated += 1;
254            }
255            Ok(())
256        } else {
257            Err(anyhow::anyhow!("Extension not found: {}", extension_id))
258        }
259    }
260
261    /// Get extension manager statistics
262    pub async fn stats(&self) -> ExtensionStats {
263        self.stats.read().await.clone()
264    }
265
266    /// Discover extensions in configured paths
267    #[instrument(skip(self))]
268    pub async fn discover_extensions(&self) -> Result<Vec<PathBuf>> {
269        info!("Discovering extensions in configured paths");
270
271        let mut extensions = Vec::new();
272
273        for discovery_path in &self.config.discovery_paths {
274            match self.discover_in_path(discovery_path).await {
275                Ok(mut found) => extensions.append(&mut found),
276                Err(e) => {
277                    warn!("Failed to discover extensions in {}: {}", discovery_path, e);
278                }
279            }
280        }
281
282        info!("Discovered {} extensions", extensions.len());
283
284        Ok(extensions)
285    }
286
287    /// Discover extensions in a specific path
288    async fn discover_in_path(&self, path: &str) -> Result<Vec<PathBuf>> {
289        let path = PathBuf::from(shellexpand::tilde(path).as_ref());
290
291        if !path.exists() {
292            return Ok(Vec::new());
293        }
294
295        let mut extensions = Vec::new();
296
297        // Read directory entries
298        let mut entries = tokio::fs::read_dir(&path)
299            .await
300            .context(format!("Failed to read directory: {:?}", path))?;
301
302        while let Some(entry) = entries.next_entry().await? {
303            let entry_path = entry.path();
304            
305            // Skip if not a directory
306            if !entry_path.is_dir() {
307                continue;
308            }
309
310            // Check for package.json or manifest.json
311            let manifest_path = entry_path.join("package.json");
312            let alt_manifest_path = entry_path.join("manifest.json");
313
314            if manifest_path.exists() || alt_manifest_path.exists() {
315                extensions.push(entry_path.clone());
316                debug!("Discovered extension: {:?}", entry_path);
317            }
318        }
319
320        Ok(extensions)
321    }
322
323    /// Parse extension manifest
324    fn parse_manifest(&self, path: &Path) -> Result<serde_json::Value> {
325        let manifest_path = path.join("package.json");
326        let alt_manifest_path = path.join("manifest.json");
327
328        let manifest_content = if manifest_path.exists() {
329            tokio::runtime::Runtime::new()
330                .unwrap()
331                .block_on(tokio::fs::read_to_string(&manifest_path))
332                .context("Failed to read package.json")?
333        } else if alt_manifest_path.exists() {
334            tokio::runtime::Runtime::new()
335                .unwrap()
336                .block_on(tokio::fs::read_to_string(&alt_manifest_path))
337                .context("Failed to read manifest.json")?
338        } else {
339            return Err(anyhow::anyhow!("No manifest found in extension path"));
340        };
341
342        let manifest: serde_json::Value =
343            serde_json::from_str(&manifest_content).context("Failed to parse manifest")?;
344
345        Ok(manifest)
346    }
347
348    /// Extract extension ID from manifest
349    fn extract_extension_id(&self, manifest: &serde_json::Value) -> Result<String> {
350        let publisher = manifest
351            .get("publisher")
352            .and_then(|v| v.as_str())
353            .ok_or_else(|| anyhow::anyhow!("Missing publisher in manifest"))?;
354
355        let name = manifest
356            .get("name")
357            .and_then(|v| v.as_str())
358            .ok_or_else(|| anyhow::anyhow!("Missing name in manifest"))?;
359
360        Ok(format!("{}.{}", publisher, name))
361    }
362
363    /// Determine extension type
364    fn determine_extension_type(
365        &self,
366        path: &Path,
367        manifest: &serde_json::Value,
368    ) -> Result<ExtensionType> {
369        // Check for WASM file
370        let wasm_path = path.join("extension.wasm");
371        if wasm_path.exists() {
372            return Ok(ExtensionType::WASM);
373        }
374
375        // Check for Rust project
376        let cargo_path = path.join("Cargo.toml");
377        if cargo_path.exists() {
378            return Ok(ExtensionType::Native);
379        }
380
381        // Check for JavaScript/TypeScript
382        let main = manifest.get("main").and_then(|v| v.as_str());
383        if let Some(main) = main {
384            let main_path = path.join(main);
385            if main_path.exists() && (main.ends_with(".js") || main.ends_with(".ts")) {
386                return Ok(ExtensionType::JavaScript);
387            }
388        }
389
390        Ok(ExtensionType::Unknown)
391    }
392
393    /// Extract activation events from manifest
394    fn extract_activation_events(&self, manifest: &serde_json::Value) -> Vec<String> {
395        manifest
396            .get("activationEvents")
397            .and_then(|v| v.as_array())
398            .map(|arr| {
399                arr.iter()
400                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
401                    .collect()
402            })
403            .unwrap_or_default()
404    }
405
406    /// Extract capabilities from manifest
407    fn extract_capabilities(&self, manifest: &serde_json::Value) -> Vec<String> {
408        manifest
409            .get("capabilities")
410            .and_then(|v| v.as_object())
411            .map(|obj| obj.keys().cloned().collect())
412            .unwrap_or_default()
413    }
414
415    /// Extract dependencies from manifest
416    fn extract_dependencies(&self, manifest: &serde_json::Value) -> Vec<String> {
417        manifest
418            .get("extensionDependencies")
419            .and_then(|v| v.as_array())
420            .map(|arr| {
421                arr.iter()
422                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
423                    .collect()
424            })
425            .unwrap_or_default()
426    }
427}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// Each extension type variant compares equal to itself.
    #[test]
    fn test_extension_type() {
        assert_eq!(ExtensionType::WASM, ExtensionType::WASM);
        assert_eq!(ExtensionType::Native, ExtensionType::Native);
        assert_eq!(ExtensionType::JavaScript, ExtensionType::JavaScript);
    }

    /// Each extension state variant compares equal to itself.
    #[test]
    fn test_extension_state() {
        assert_eq!(ExtensionState::Loaded, ExtensionState::Loaded);
        assert_eq!(ExtensionState::Activated, ExtensionState::Activated);
        assert_eq!(ExtensionState::Deactivated, ExtensionState::Deactivated);
        assert_eq!(ExtensionState::Error, ExtensionState::Error);
    }

    /// A freshly created manager has no registered extensions.
    #[tokio::test]
    async fn test_extension_manager_creation() {
        // Await the runtime constructor directly: `#[tokio::test]` already
        // runs this body inside a Tokio runtime, and calling `block_on` on a
        // second runtime from here panics.
        let wasm_runtime = Arc::new(
            crate::WASM::Runtime::WASMRuntime::new(crate::WASM::Runtime::WASMConfig::default())
                .await
                .unwrap(),
        );
        let config = HostConfig::default();
        let manager = ExtensionManagerImpl::new(wasm_runtime, config);

        assert_eq!(manager.list_extensions().await.len(), 0);
    }

    /// Default statistics start at zero.
    #[test]
    fn test_extension_stats_default() {
        let stats = ExtensionStats::default();
        assert_eq!(stats.total_loaded, 0);
        assert_eq!(stats.total_activated, 0);
    }
}