1#![warn(missing_debug_implementations, missing_docs, unreachable_pub)]
5
6use crate::filter::AsyncFilter;
7use futures_util::future::Either;
8use pin_project_lite::pin_project;
9use std::sync::{Arc, Mutex};
10use std::time::Duration;
11use std::{
12 future,
13 pin::Pin,
14 task::{Context, Poll},
15};
16use tracing::error;
17
18mod delay;
19mod latency;
20mod rotating_histogram;
21mod select;
22
23use delay::Delay;
24use latency::Latency;
25use rotating_histogram::RotatingHistogram;
26use select::Select;
27
28type Histo = Arc<Mutex<RotatingHistogram>>;
29type Service<S, P> = select::Select<
30 SelectPolicy<P>,
31 Latency<Histo, S>,
32 Delay<DelayPolicy, AsyncFilter<Latency<Histo, S>, PolicyPredicate<P>>>,
33>;
34
35#[derive(Debug)]
39pub struct Hedge<S, P>(Service<S, P>);
40
41pin_project! {
42 #[derive(Debug)]
46 pub struct Future<S, Request>
47 where
48 S: tower_service::Service<Request>,
49 {
50 #[pin]
51 inner: S::Future,
52 }
53}
54
/// A policy that decides which requests may be duplicated for hedging and
/// whether the duplicate should actually be sent.
pub trait Policy<Request> {
    /// Produces a copy of `req` to be used for a hedge request.
    ///
    /// Returning `None` means no copy is available for this request. Within
    /// this module the result is additionally discarded when the latency
    /// histogram holds fewer than the configured minimum number of data points.
    fn clone_request(&self, req: &Request) -> Option<Request>;

    /// Reports whether the hedge request for `req` should be issued.
    ///
    /// Within this module a `false` answer makes the hedge branch yield a
    /// future that never resolves instead of an error.
    fn can_retry(&self, req: &Request) -> bool;
}
64
/// Adapter that exposes a [`Policy`] as an async predicate by consulting
/// [`Policy::can_retry`] for each request.
#[doc(hidden)]
#[derive(Clone, Debug)]
pub struct PolicyPredicate<P>(P);
70
71#[doc(hidden)]
72#[derive(Debug)]
73pub struct DelayPolicy {
74 histo: Histo,
75 latency_percentile: f32,
76}
77
78#[doc(hidden)]
79#[derive(Debug)]
80pub struct SelectPolicy<P> {
81 policy: P,
82 histo: Histo,
83 min_data_points: u64,
84}
85
86impl<S, P> Hedge<S, P> {
87 pub fn new<Request>(
89 service: S,
90 policy: P,
91 min_data_points: u64,
92 latency_percentile: f32,
93 period: Duration,
94 ) -> Hedge<S, P>
95 where
96 S: tower_service::Service<Request> + Clone,
97 S::Error: Into<crate::BoxError>,
98 P: Policy<Request> + Clone,
99 {
100 let histo = Arc::new(Mutex::new(RotatingHistogram::new(period)));
101 Self::new_with_histo(service, policy, min_data_points, latency_percentile, histo)
102 }
103
104 pub fn new_with_mock_latencies<Request>(
107 service: S,
108 policy: P,
109 min_data_points: u64,
110 latency_percentile: f32,
111 period: Duration,
112 latencies_ms: &[u64],
113 ) -> Hedge<S, P>
114 where
115 S: tower_service::Service<Request> + Clone,
116 S::Error: Into<crate::BoxError>,
117 P: Policy<Request> + Clone,
118 {
119 let histo = Arc::new(Mutex::new(RotatingHistogram::new(period)));
120 {
121 let mut locked = histo.lock().unwrap();
122 for latency in latencies_ms.iter() {
123 locked.read().record(*latency).unwrap();
124 }
125 }
126 Self::new_with_histo(service, policy, min_data_points, latency_percentile, histo)
127 }
128
129 fn new_with_histo<Request>(
130 service: S,
131 policy: P,
132 min_data_points: u64,
133 latency_percentile: f32,
134 histo: Histo,
135 ) -> Hedge<S, P>
136 where
137 S: tower_service::Service<Request> + Clone,
138 S::Error: Into<crate::BoxError>,
139 P: Policy<Request> + Clone,
140 {
141 let recorded_a = Latency::new(histo.clone(), service.clone());
144 let recorded_b = Latency::new(histo.clone(), service);
145
146 let filtered = AsyncFilter::new(recorded_b, PolicyPredicate(policy.clone()));
148
149 let delay_policy = DelayPolicy {
152 histo: histo.clone(),
153 latency_percentile,
154 };
155 let delayed = Delay::new(delay_policy, filtered);
156
157 let select_policy = SelectPolicy {
160 policy,
161 histo,
162 min_data_points,
163 };
164 Hedge(Select::new(select_policy, recorded_a, delayed))
165 }
166}
167
168impl<S, P, Request> tower_service::Service<Request> for Hedge<S, P>
169where
170 S: tower_service::Service<Request> + Clone,
171 S::Error: Into<crate::BoxError>,
172 P: Policy<Request> + Clone,
173{
174 type Response = S::Response;
175 type Error = crate::BoxError;
176 type Future = Future<Service<S, P>, Request>;
177
178 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
179 self.0.poll_ready(cx)
180 }
181
182 fn call(&mut self, request: Request) -> Self::Future {
183 Future {
184 inner: self.0.call(request),
185 }
186 }
187}
188
189impl<S, Request> std::future::Future for Future<S, Request>
190where
191 S: tower_service::Service<Request>,
192 S::Error: Into<crate::BoxError>,
193{
194 type Output = Result<S::Response, crate::BoxError>;
195
196 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
197 self.project().inner.poll(cx).map_err(Into::into)
198 }
199}
200
/// Number of nanoseconds in one millisecond.
const NANOS_PER_MILLI: u32 = 1_000_000;
/// Number of milliseconds in one second.
const MILLIS_PER_SEC: u64 = 1_000;

/// Converts `duration` to milliseconds.
///
/// Any fractional millisecond is rounded up, and the result saturates at
/// `u64::MAX` instead of overflowing.
fn millis(duration: Duration) -> u64 {
    let nanos = duration.subsec_nanos();

    // Ceiling division of the sub-second part: bump the quotient whenever a
    // remainder is left over.
    let whole = nanos / NANOS_PER_MILLI;
    let subsec_millis = if nanos % NANOS_PER_MILLI == 0 {
        whole
    } else {
        whole + 1
    };

    let secs_as_millis = duration.as_secs().saturating_mul(MILLIS_PER_SEC);
    secs_as_millis.saturating_add(u64::from(subsec_millis))
}
212
213impl latency::Record for Histo {
214 fn record(&mut self, latency: Duration) {
215 let mut locked = self.lock().unwrap();
216 locked.write().record(millis(latency)).unwrap_or_else(|e| {
217 error!("Failed to write to hedge histogram: {:?}", e);
218 })
219 }
220}
221
222impl<P, Request> crate::filter::AsyncPredicate<Request> for PolicyPredicate<P>
223where
224 P: Policy<Request>,
225{
226 type Future = Either<
227 future::Ready<Result<Request, crate::BoxError>>,
228 future::Pending<Result<Request, crate::BoxError>>,
229 >;
230 type Request = Request;
231
232 fn check(&mut self, request: Request) -> Self::Future {
233 if self.0.can_retry(&request) {
234 Either::Left(future::ready(Ok(request)))
235 } else {
236 Either::Right(future::pending())
241 }
242 }
243}
244
245impl<Request> delay::Policy<Request> for DelayPolicy {
246 fn delay(&self, _req: &Request) -> Duration {
247 let mut locked = self.histo.lock().unwrap();
248 let millis = locked
249 .read()
250 .value_at_quantile(self.latency_percentile.into());
251 Duration::from_millis(millis)
252 }
253}
254
255impl<P, Request> select::Policy<Request> for SelectPolicy<P>
256where
257 P: Policy<Request>,
258{
259 fn clone_request(&self, req: &Request) -> Option<Request> {
260 self.policy.clone_request(req).filter(|_| {
261 let mut locked = self.histo.lock().unwrap();
262 locked.read().len() >= self.min_data_points
265 })
266 }
267}