1use std::{collections::HashMap, sync::Arc};
7
8use anyhow::{Context, Result};
9use bytes::Bytes;
10use serde::{Deserialize, Serialize};
11use tokio::sync::RwLock;
12use tracing::{debug, instrument, warn};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct APICallRequest {
17 pub extension_id:String,
19 pub api_method:String,
21 pub arguments:Vec<serde_json::Value>,
23 pub correlation_id:Option<String>,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct APICallResponse {
30 pub success:bool,
32 pub data:Option<serde_json::Value>,
34 pub error:Option<String>,
36 pub correlation_id:Option<String>,
38}
39
40pub struct APICall {
42 extension_id:String,
44 api_method:String,
46 arguments:Vec<serde_json::Value>,
48 timestamp:u64,
50}
51
52type APIMethodHandler = fn(&str, Vec<serde_json::Value>) -> Result<serde_json::Value>;
54
55type AsyncAPIMethodHandler =
57 fn(&str, Vec<serde_json::Value>) -> Box<dyn std::future::Future<Output = Result<serde_json::Value>> + Send + Unpin>;
58
59#[derive(Clone)]
61pub struct APIMethodInfo {
62 name:String,
64 description:String,
66 parameters:Option<serde_json::Value>,
68 returns:Option<serde_json::Value>,
70 is_async:bool,
72 call_count:u64,
74 total_time_us:u64,
76}
77
78pub struct APIBridgeImpl {
80 api_methods:Arc<RwLock<HashMap<String, APIMethodInfo>>>,
82 stats:Arc<RwLock<APIStats>>,
84 contexts:Arc<RwLock<HashMap<String, APIContext>>>,
86}
87
88#[derive(Debug, Clone, Default, Serialize, Deserialize)]
90pub struct APIStats {
91 pub total_calls:u64,
93 pub successful_calls:u64,
95 pub failed_calls:u64,
97 pub avg_latency_us:u64,
99 pub active_contexts:usize,
101}
102
103#[derive(Debug, Clone)]
105pub struct APIContext {
106 pub extension_id:String,
108 pub context_id:String,
110 pub workspace_folder:Option<String>,
112 pub active_editor:Option<String>,
114 pub selections:Vec<Selection>,
116 pub created_at:u64,
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct Selection {
123 pub start_line:u32,
125 pub start_character:u32,
127 pub end_line:u32,
129 pub end_character:u32,
131}
132
133impl Default for Selection {
134 fn default() -> Self { Self { start_line:0, start_character:0, end_line:0, end_character:0 } }
135}
136
137impl APIBridgeImpl {
138 pub fn new() -> Self {
140 let bridge = Self {
141 api_methods:Arc::new(RwLock::new(HashMap::new())),
142 stats:Arc::new(RwLock::new(APIStats::default())),
143 contexts:Arc::new(RwLock::new(HashMap::new())),
144 };
145
146 bridge.register_builtin_methods();
147
148 bridge
149 }
150
151 fn register_builtin_methods(&self) {
153 debug!("Registered built-in VS Code API methods");
162 }
163
164 pub async fn register_method(
166 &self,
167 name:&str,
168 description:&str,
169 parameters:Option<serde_json::Value>,
170 returns:Option<serde_json::Value>,
171 is_async:bool,
172 ) -> Result<()> {
173 let mut methods = self.api_methods.write().await;
174
175 if methods.contains_key(name) {
176 warn!("API method already registered: {}", name);
177 }
178
179 methods.insert(
180 name.to_string(),
181 APIMethodInfo {
182 name:name.to_string(),
183 description:description.to_string(),
184 parameters,
185 returns,
186 is_async,
187 call_count:0,
188 total_time_us:0,
189 },
190 );
191
192 debug!("Registered API method: {}", name);
193
194 Ok(())
195 }
196
197 #[instrument(skip(self))]
199 pub async fn create_context(&self, extension_id:&str) -> Result<APIContext> {
200 let context_id = format!("{}-{}", extension_id, uuid::Uuid::new_v4());
201
202 let context = APIContext {
203 extension_id:extension_id.to_string(),
204 context_id:context_id.clone(),
205 workspace_folder:None,
206 active_editor:None,
207 selections:Vec::new(),
208 created_at:std::time::SystemTime::now()
209 .duration_since(std::time::UNIX_EPOCH)
210 .map(|d| d.as_secs())
211 .unwrap_or(0),
212 };
213
214 let mut contexts = self.contexts.write().await;
215 contexts.insert(context_id.clone(), context.clone());
216
217 let mut stats = self.stats.write().await;
219 stats.active_contexts = contexts.len();
220
221 debug!("Created API context for extension: {}", extension_id);
222
223 Ok(context)
224 }
225
226 pub async fn get_context(&self, context_id:&str) -> Option<APIContext> {
228 self.contexts.read().await.get(context_id).cloned()
229 }
230
231 pub async fn update_context(&self, context:APIContext) -> Result<()> {
233 let mut contexts = self.contexts.write().await;
234 contexts.insert(context.context_id.clone(), context);
235 Ok(())
236 }
237
238 pub async fn remove_context(&self, context_id:&str) -> Result<bool> {
240 let mut contexts = self.contexts.write().await;
241 let removed = contexts.remove(context_id).is_some();
242
243 if removed {
244 let mut stats = self.stats.write().await;
245 stats.active_contexts = contexts.len();
246 }
247
248 Ok(removed)
249 }
250
251 #[instrument(skip(self, request))]
253 pub async fn handle_call(&self, request:APICallRequest) -> Result<APICallResponse> {
254 let start = std::time::Instant::now();
255
256 debug!(
257 "Handling API call: {} from extension {}",
258 request.api_method, request.extension_id
259 );
260
261 let exists = {
263 let methods = self.api_methods.read().await;
264 methods.contains_key(&request.api_method)
265 };
266
267 if !exists {
268 return Ok(APICallResponse {
269 success:false,
270 data:None,
271 error:Some(format!("API method not found: {}", request.api_method)),
272 correlation_id:request.correlation_id,
273 });
274 }
275
276 let result = self
279 .execute_method(&request.extension_id, &request.api_method, &request.arguments)
280 .await;
281
282 let elapsed_us = start.elapsed().as_micros() as u64;
283
284 let mut stats = self.stats.write().await;
286 stats.total_calls += 1;
287 stats.total_calls += 1;
288 if exists {
289 stats.successful_calls += 1;
290 stats.avg_latency_us =
292 (stats.avg_latency_us * (stats.successful_calls - 1) + elapsed_us) / stats.successful_calls;
293 }
294
295 {
297 let mut methods = self.api_methods.write().await;
298 if let Some(method) = methods.get_mut(&request.api_method) {
299 method.call_count += 1;
300 method.total_time_us += elapsed_us;
301 }
302 }
303
304 debug!("API call {} completed in {}µs", request.api_method, elapsed_us);
305
306 match result {
307 Ok(data) => {
308 Ok(
309 APICallResponse {
310 success:true,
311 data:Some(data),
312 error:None,
313 correlation_id:request.correlation_id,
314 },
315 )
316 },
317 Err(e) => {
318 Ok(APICallResponse {
319 success:false,
320 data:None,
321 error:Some(e.to_string()),
322 correlation_id:request.correlation_id,
323 })
324 },
325 }
326 }
327
328 async fn execute_method(
330 &self,
331 _extension_id:&str,
332 _method_name:&str,
333 _arguments:&[serde_json::Value],
334 ) -> Result<serde_json::Value> {
335 Ok(serde_json::Value::Null)
344 }
345
346 pub async fn stats(&self) -> APIStats { self.stats.read().await.clone() }
348
349 pub async fn get_methods(&self) -> Vec<APIMethodInfo> { self.api_methods.read().await.values().cloned().collect() }
351
352 pub async fn unregister_method(&self, name:&str) -> Result<bool> {
354 let mut methods = self.api_methods.write().await;
355 let removed = methods.remove(name).is_some();
356
357 if removed {
358 debug!("Unregistered API method: {}", name);
359 }
360
361 Ok(removed)
362 }
363}
364
365impl Default for APIBridgeImpl {
366 fn default() -> Self { Self::new() }
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh bridge starts with zeroed call counters.
    #[tokio::test]
    async fn test_api_bridge_creation() {
        let APIStats { total_calls, successful_calls, .. } = APIBridgeImpl::new().stats().await;
        assert_eq!(total_calls, 0);
        assert_eq!(successful_calls, 0);
    }

    /// Creating a context records the owning extension and assigns an id.
    #[tokio::test]
    async fn test_context_creation() {
        let bridge = APIBridgeImpl::new();
        let ctx = bridge
            .create_context("test.ext")
            .await
            .expect("creating a context should succeed");

        assert_eq!(ctx.extension_id, "test.ext");
        assert!(!ctx.context_id.is_empty());
    }

    /// A registered method shows up in `get_methods`.
    #[tokio::test]
    async fn test_method_registration() {
        let bridge = APIBridgeImpl::new();
        bridge
            .register_method("test.method", "Test method", None, None, false)
            .await
            .expect("registering a method should succeed");

        let found = bridge
            .get_methods()
            .await
            .into_iter()
            .any(|info| info.name == "test.method");
        assert!(found);
    }

    /// Request fields round-trip through construction unchanged.
    #[test]
    fn test_api_call_request() {
        let req = APICallRequest {
            extension_id:"test.ext".to_string(),
            api_method:"test.method".to_string(),
            arguments:vec![serde_json::json!("arg1")],
            correlation_id:Some("test-id".to_string()),
        };

        assert_eq!(req.extension_id, "test.ext");
        assert_eq!(req.api_method, "test.method");
        assert_eq!(req.arguments.len(), 1);
    }

    /// The default selection sits at line 0.
    #[test]
    fn test_selection_default() {
        let Selection { start_line, end_line, .. } = Selection::default();
        assert_eq!(start_line, 0);
        assert_eq!(end_line, 0);
    }
}