grove/Transport/
gRPCTransport.rs1use std::sync::Arc;
7
8use async_trait::async_trait;
9use bytes::Bytes;
10use tokio::sync::RwLock;
11use tonic::transport::{Channel, Endpoint};
12use tracing::{debug, info, instrument, warn};
13
14use crate::Transport::TransportStrategy;
15use crate::Transport::TransportType;
16use crate::Transport::TransportStats;
17use crate::Transport::TransportConfig;
18
19#[derive(Clone, Debug)]
21pub struct GrpcTransport {
22 endpoint:String,
24 channel:Arc<RwLock<Option<Channel>>>,
26 config:TransportConfig,
28 connected:Arc<RwLock<bool>>,
30 stats:Arc<RwLock<TransportStats>>,
32}
33
34impl GrpcTransport {
35 pub fn new(address:&str) -> anyhow::Result<Self> {
50 Ok(Self {
51 endpoint:address.to_string(),
52 channel:Arc::new(RwLock::new(None)),
53 config:TransportConfig::default(),
54 connected:Arc::new(RwLock::new(false)),
55 stats:Arc::new(RwLock::new(TransportStats::default())),
56 })
57 }
58
59 pub fn with_config(address:&str, config:TransportConfig) -> anyhow::Result<Self> {
61 Ok(Self {
62 endpoint:address.to_string(),
63 channel:Arc::new(RwLock::new(None)),
64 config,
65 connected:Arc::new(RwLock::new(false)),
66 stats:Arc::new(RwLock::new(TransportStats::default())),
67 })
68 }
69
70 pub fn endpoint(&self) -> &str { &self.endpoint }
72
73 pub async fn channel(&self) -> anyhow::Result<Channel> {
75 let channel = self.channel.read().await;
76 channel
77 .as_ref()
78 .cloned()
79 .ok_or_else(|| anyhow::anyhow!("gRPC channel not connected"))
80 }
81
82 pub async fn stats(&self) -> TransportStats { self.stats.read().await.clone() }
84
85 fn build_endpoint(&self) -> anyhow::Result<Endpoint> {
87 let endpoint = Endpoint::from_shared(self.endpoint.clone())?
88 .timeout(self.config.connection_timeout)
89 .connect_timeout(self.config.connection_timeout)
90 .tcp_keepalive(Some(self.config.keepalive_interval));
91
92 Ok(endpoint)
93 }
94}
95
96#[async_trait]
97impl TransportStrategy for GrpcTransport {
98 type Error = GrpcTransportError;
99
100 #[instrument(skip(self))]
101 async fn connect(&self) -> Result<(), Self::Error> {
102 info!("Connecting to gRPC endpoint: {}", self.endpoint);
103
104 let endpoint = self
105 .build_endpoint()
106 .map_err(|e| GrpcTransportError::ConnectionFailed(e.to_string()))?;
107
108 let channel = endpoint
109 .connect()
110 .await
111 .map_err(|e| GrpcTransportError::ConnectionFailed(e.to_string()))?;
112
113 *self.channel.write().await = Some(channel);
114 *self.connected.write().await = true;
115
116 info!("gRPC connection established: {}", self.endpoint);
117 debug!("Connected to gRPC endpoint: {}", self.endpoint);
118
119 Ok(())
120 }
121
122 #[instrument(skip(self, request))]
123 async fn send(&self, request:&[u8]) -> Result<Vec<u8>, Self::Error> {
124 let start = std::time::Instant::now();
125
126 if !self.is_connected() {
127 return Err(GrpcTransportError::NotConnected);
128 }
129
130 debug!("Sending gRPC request ({} bytes)", request.len());
131
132 let response = vec![]; let latency_us = start.elapsed().as_micros() as u64;
137
138 let mut stats = self.stats.write().await;
140 stats.record_sent(request.len() as u64, latency_us);
141 stats.record_received(response.len() as u64);
142
143 debug!("gRPC request completed in {}µs", latency_us);
144
145 Ok(response)
146 }
147
148 #[instrument(skip(self, data))]
149 async fn send_no_response(&self, data:&[u8]) -> Result<(), Self::Error> {
150 if !self.is_connected() {
151 return Err(GrpcTransportError::NotConnected);
152 }
153
154 debug!("Sending gRPC request without response ({} bytes)", data.len());
155
156 let mut stats = self.stats.write().await;
159 stats.record_sent(data.len() as u64, 0);
160
161 Ok(())
162 }
163
164 #[instrument(skip(self))]
165 async fn close(&self) -> Result<(), Self::Error> {
166 info!("Closing gRPC connection: {}", self.endpoint);
167
168 *self.channel.write().await = None;
169 *self.connected.write().await = false;
170
171 info!("gRPC connection closed: {}", self.endpoint);
172
173 Ok(())
174 }
175
176 fn is_connected(&self) -> bool { self.connected.blocking_read().to_owned() }
177
178 fn transport_type(&self) -> TransportType {
179 TransportType::gRPC
180 }
181}
182
183#[derive(Debug, thiserror::Error)]
185pub enum GrpcTransportError {
186#[error("Connection failed: {0}")]
188ConnectionFailed(String),
189
190#[error("Send failed: {0}")]
192SendFailed(String),
193
194#[error("Receive failed: {0}")]
196ReceiveFailed(String),
197
198#[error("Not connected")]
200NotConnected,
201
202#[error("Timeout")]
204Timeout,
205
206#[error("gRPC error: {0}")]
208GrpcError(String),
209}
210
211impl From<tonic::transport::Error> for GrpcTransportError {
212 fn from(err:tonic::transport::Error) -> Self { GrpcTransportError::ConnectionFailed(err.to_string()) }
213}
214
215impl From<tonic::Status> for GrpcTransportError {
216 fn from(status:tonic::Status) -> Self { GrpcTransportError::GrpcError(status.to_string()) }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Address shared by every test; no server needs to listen on it.
    const ADDRESS:&str = "127.0.0.1:50050";

    #[test]
    fn test_grpc_transport_creation() {
        let transport = GrpcTransport::new(ADDRESS).expect("construction should succeed");
        assert_eq!(transport.endpoint(), ADDRESS);
    }

    #[test]
    fn test_grpc_transport_with_config() {
        let config = TransportConfig::default().with_max_retries(5);
        assert!(GrpcTransport::with_config(ADDRESS, config).is_ok());
    }

    #[tokio::test]
    async fn test_grpc_transport_not_connected() {
        let transport = GrpcTransport::new(ADDRESS).unwrap();
        assert!(!transport.is_connected());
    }

    #[tokio::test]
    async fn test_grpc_transport_stats() {
        let snapshot = GrpcTransport::new(ADDRESS).unwrap().stats().await;
        assert_eq!(snapshot.messages_sent, 0);
        assert_eq!(snapshot.messages_received, 0);
    }
}