Anki/rslib/src/sync/http_client/io_monitor.rs
Damien Elmes 3707e54ffa Rework syncing code, and replace local sync server (#2329)
This PR replaces the existing Python-driven sync server with a new one in Rust.
The new server supports both collection and media syncing, and is compatible
with both the new protocol mentioned below, and older clients. A setting has
been added to the preferences screen to point Anki to a local server, and a
similar setting is likely to come to AnkiMobile soon.

Documentation is available here: <https://docs.ankiweb.net/sync-server.html>

In addition to the new server and refactoring, this PR also makes changes to the
sync protocol. The existing sync protocol places payloads and metadata inside a
multipart POST body, which causes a few headaches:

- Legacy clients build the request in a non-deterministic order, meaning the
entire request needs to be scanned to extract the metadata.
- Reqwest's multipart API directly writes the multipart body, without exposing
the resulting stream to us, making it harder to track the progress of the
transfer. We've been relying on a patched version of reqwest for timeouts,
which is a pain to keep up to date.

To address these issues, the metadata is now sent in an HTTP header, with the
data payload sent directly in the body. Instead of the slower gzip, we now
use zstd. The old timeout handling code has been replaced with a new implementation
that wraps the request and response body streams to track progress, allowing us
to drop the git dependencies for reqwest, hyper-timeout and tokio-io-timeout.

The other main change to the protocol is that one-way syncs no longer need to
downgrade the collection to schema 11 prior to sending.
2023-01-18 12:43:46 +10:00

292 lines
9 KiB
Rust

// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::{
io::{Cursor, ErrorKind},
sync::{Arc, Mutex},
time::Duration,
};
use bytes::Bytes;
use futures::{Stream, StreamExt, TryStreamExt};
use reqwest::{
header::{CONTENT_TYPE, LOCATION},
Body, RequestBuilder, Response, StatusCode,
};
use tokio::{
io::AsyncReadExt,
select,
time::{interval, Instant},
};
use tokio_util::io::{ReaderStream, StreamReader};
use crate::{
error::Result,
sync::{
error::{HttpError, HttpResult, HttpSnafu, OrHttpErr},
request::header_and_stream::{decode_zstd_body_stream, encode_zstd_body_stream},
response::ORIGINAL_SIZE,
},
};
/// Shared handle that observes the byte streams of an HTTP transfer.
///
/// Serves two purposes:
/// - allows us to monitor data sending/receiving and abort if
/// the transfer stalls
/// - allows us to monitor amount of data moving, to provide progress
/// reporting
///
/// The state lives in an `Arc<Mutex<IoMonitorInner>>`, so every clone of a
/// monitor reads and updates the same counters and activity timestamp.
#[derive(Clone)]
pub struct IoMonitor(pub Arc<Mutex<IoMonitorInner>>);
impl IoMonitor {
    /// Creates a monitor with all byte counters at zero and the activity
    /// timestamp set to the current instant.
    pub fn new() -> Self {
        Self(Arc::new(Mutex::new(IoMonitorInner {
            last_activity: Instant::now(),
            bytes_sent: 0,
            total_bytes_to_send: 0,
            bytes_received: 0,
            total_bytes_to_receive: 0,
        })))
    }

    /// Converts a byte count to `u32`, clamping at `u32::MAX` instead of
    /// silently truncating the way a bare `as u32` cast would.
    fn clamp_len(len: usize) -> u32 {
        len.min(u32::MAX as usize) as u32
    }

    /// Wraps `stream` so that every successfully yielded chunk refreshes the
    /// activity timestamp and is added to the progress counters.
    ///
    /// - `sending`: `true` to account the data as outgoing, `false` as
    ///   incoming.
    /// - `total_bytes`: expected size of the stream; added to the matching
    ///   `total_bytes_to_*` counter immediately, before any data flows.
    /// - `stream`: the byte stream to observe.
    ///
    /// Chunks are passed through unmodified. An error item from the
    /// underlying stream is converted with `or_http_err` (status
    /// `SEE_OTHER`, context `"stream failure"`) and does not count as
    /// activity.
    ///
    /// The counters are only used for progress reporting, so they saturate
    /// at `u32::MAX` rather than overflowing (which would panic in debug
    /// builds and wrap around in release builds).
    ///
    /// # Panics
    ///
    /// Panics if the shared mutex has been poisoned.
    pub fn wrap_stream<S, E>(
        &self,
        sending: bool,
        total_bytes: u32,
        stream: S,
    ) -> impl Stream<Item = HttpResult<Bytes>> + Send + Sync + 'static
    where
        S: Stream<Item = Result<Bytes, E>> + Send + Sync + 'static,
        E: std::error::Error + Send + Sync + 'static,
    {
        let shared = self.0.clone();
        {
            // Scoped so the guard is released before the stream is returned.
            let mut state = shared.lock().unwrap();
            // Starting a new stream counts as activity, so the stall timer
            // is measured from here rather than from monitor creation.
            state.last_activity = Instant::now();
            let expected = if sending {
                &mut state.total_bytes_to_send
            } else {
                &mut state.total_bytes_to_receive
            };
            *expected = expected.saturating_add(total_bytes);
        }
        stream.map(move |res| match res {
            Ok(bytes) => {
                let chunk_len = Self::clamp_len(bytes.len());
                let mut state = shared.lock().unwrap();
                state.last_activity = Instant::now();
                if sending {
                    state.bytes_sent = state.bytes_sent.saturating_add(chunk_len);
                } else {
                    state.bytes_received = state.bytes_received.saturating_add(chunk_len);
                }
                Ok(bytes)
            }
            err => err.or_http_err(StatusCode::SEE_OTHER, "stream failure"),
        })
    }

    /// Returns if no I/O activity observed for `stall_time`.
    ///
    /// The activity timestamp is polled once per second (every 10ms in test
    /// builds), so a stall is detected up to one poll interval late. While
    /// data keeps moving this future never completes; it is meant to be
    /// raced against the transfer it guards.
    ///
    /// # Panics
    ///
    /// Panics if the shared mutex has been poisoned.
    pub async fn timeout(&self, stall_time: Duration) {
        let poll_interval = Duration::from_millis(if cfg!(test) { 10 } else { 1000 });
        let mut ticker = interval(poll_interval);
        loop {
            let now = ticker.tick().await;
            // The guard is a temporary, so the lock is not held across the
            // next `.await`.
            let last_activity = self.0.lock().unwrap().last_activity;
            if now.duration_since(last_activity) > stall_time {
                return;
            }
        }
    }

    /// Takes care of encoding provided request data and setting content type
    /// to binary, and returns the decompressed response body.
    ///
    /// - `request`: a prepared request; its content type and body are set
    ///   here.
    /// - `request_body`: the raw payload, passed through
    ///   `encode_zstd_body_stream` as it is sent.
    /// - `stall_duration`: how long the transfer may go without any
    ///   monitored I/O before it is abandoned.
    ///
    /// # Errors
    ///
    /// - Errors from sending the request, and error statuses reported by
    ///   `error_for_status`, are propagated.
    /// - A permanent redirect is reported through `map_redirect_to_error`.
    /// - A response without a valid `ORIGINAL_SIZE` header fails with
    ///   `"missing original size"`.
    /// - A failure while reading the response body is reported with status
    ///   `SEE_OTHER` and context `"reading stream"`.
    /// - If no activity is seen for `stall_duration`, the request is dropped
    ///   and an error with status `REQUEST_TIMEOUT` is returned.
    pub async fn zstd_request_with_timeout(
        &self,
        request: RequestBuilder,
        request_body: Vec<u8>,
        stall_duration: Duration,
    ) -> HttpResult<Vec<u8>> {
        /// Cap on the buffer reserved up front for the response. The size
        /// comes from a header the server controls, so it must not be
        /// trusted as an allocation size: a huge value would panic with a
        /// capacity overflow on 32-bit targets, or force a multi-GiB
        /// allocation elsewhere. Larger responses simply grow the buffer.
        const MAX_PREALLOCATION: usize = 32 * 1024 * 1024;

        let request_total = Self::clamp_len(request_body.len());
        // The monitor wraps the raw (uncompressed) bytes, so progress is
        // counted before compression.
        let request_body_stream = encode_zstd_body_stream(self.wrap_stream(
            true,
            request_total,
            ReaderStream::new(Cursor::new(request_body)),
        ));
        let response_body_stream = async move {
            let resp = request
                .header(CONTENT_TYPE, "application/octet-stream")
                .body(Body::wrap_stream(request_body_stream))
                .send()
                .await?
                .error_for_status()?;
            map_redirect_to_error(&resp)?;
            let response_total = resp
                .headers()
                .get(&ORIGINAL_SIZE)
                .and_then(|v| v.to_str().ok())
                .and_then(|v| v.parse::<u32>().ok())
                .or_bad_request("missing original size")?;
            // The monitor wraps the decoded stream, so received bytes are
            // counted after decompression.
            let response_stream = self.wrap_stream(
                false,
                response_total,
                decode_zstd_body_stream(resp.bytes_stream()),
            );
            let mut reader = StreamReader::new(
                response_stream
                    .map_err(|e| std::io::Error::new(ErrorKind::ConnectionAborted, e.to_string())),
            );
            let mut buf = Vec::with_capacity((response_total as usize).min(MAX_PREALLOCATION));
            reader
                .read_to_end(&mut buf)
                .await
                .or_http_err(StatusCode::SEE_OTHER, "reading stream")?;
            Ok::<_, HttpError>(buf)
        };
        // Whichever branch finishes first wins; the other future is dropped,
        // which abandons the in-flight request when the stall timer fires.
        select! {
            // happy path
            data = response_body_stream => Ok(data?),
            // timeout
            _ = self.timeout(stall_duration) => {
                HttpSnafu {
                    code: StatusCode::REQUEST_TIMEOUT,
                    context: "timeout monitor",
                    source: None,
                }.fail()
            }
        }
    }
}
/// Turns a permanent-redirect response into an error carrying the redirect
/// target; any other response is accepted as-is.
///
/// Reqwest cannot follow the redirect itself because the streamed request
/// body has already been consumed, so the target is handed back to the sync
/// driver, which is expected to retry against the new location.
///
/// Fails with a bad-request error if the redirect has no `Location` header
/// or the header is not valid UTF-8.
fn map_redirect_to_error(resp: &Response) -> HttpResult<()> {
    // Nothing to do unless the server answered with a permanent redirect.
    if resp.status() != StatusCode::PERMANENT_REDIRECT {
        return Ok(());
    }

    let header = resp
        .headers()
        .get(LOCATION)
        .or_bad_request("missing location header")?;
    let target = String::from_utf8(header.as_bytes().to_vec())
        .or_bad_request("location was not in utf8")?;

    // Calling this on `None` always takes the error path, surfacing the
    // redirect target to the caller.
    None.or_permanent_redirect(target)?;
    Ok(())
}
/// Activity timestamp and progress counters held behind the mutex inside an
/// `IoMonitor`.
#[derive(Debug)]
pub struct IoMonitorInner {
/// Set when the monitor is created, when a stream is wrapped, and each time
/// a wrapped stream yields a chunk. `IoMonitor::timeout` compares it against
/// the allowed stall time.
last_activity: Instant,
/// Running total of chunk lengths yielded by streams wrapped as sending.
pub bytes_sent: u32,
/// Sum of the `total_bytes` values announced when sending streams were
/// wrapped.
pub total_bytes_to_send: u32,
/// Running total of chunk lengths yielded by streams wrapped as receiving.
pub bytes_received: u32,
/// Sum of the `total_bytes` values announced when receiving streams were
/// wrapped.
pub total_bytes_to_receive: u32,
}
/// A public argument-less `new()` should be mirrored by `Default`, so the
/// monitor can be used with `..Default::default()` and `#[derive(Default)]`
/// on containing types.
impl Default for IoMonitor {
    /// Equivalent to [`IoMonitor::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod test {
    use async_stream::stream;
    use futures::{pin_mut, StreamExt};
    use tokio::{select, time::sleep};
    use wiremock::{
        matchers::{method, path},
        Mock, MockServer, ResponseTemplate,
    };

    use super::*;
    use crate::sync::error::HttpError;

    /// Builds a duration of `millis` milliseconds, scaled up tenfold on
    /// Windows.
    fn millis(millis: u64) -> Duration {
        Duration::from_millis(millis * if cfg!(windows) { 10 } else { 1 })
    }

    /// A stream that produces nothing for longer than the stall limit must
    /// lose the race against the timeout.
    #[tokio::test]
    async fn can_fail_before_any_bytes() {
        let monitor = IoMonitor::new();
        let stream = monitor.wrap_stream(
            true,
            0,
            stream! {
                sleep(millis(2000)).await;
                yield Ok::<_, HttpError>(Bytes::from("1"))
            },
        );
        pin_mut!(stream);
        select! {
            _ = stream.next() => panic!("expected failure"),
            _ = monitor.timeout(millis(100)) => ()
        };
    }

    /// Regular chunks keep the timeout from firing; a later gap longer than
    /// the stall limit makes it fire.
    #[tokio::test]
    async fn fails_when_data_stops_moving() {
        let monitor = IoMonitor::new();
        let stream = monitor.wrap_stream(
            true,
            0,
            stream! {
                for _ in 0..10 {
                    sleep(millis(10)).await;
                    yield Ok::<_, HttpError>(Bytes::from("1"))
                }
                sleep(millis(50)).await;
                yield Ok::<_, HttpError>(Bytes::from("1"))
            },
        );
        pin_mut!(stream);
        for _ in 0..10 {
            select! {
                _ = stream.next() => (),
                _ = monitor.timeout(millis(20)) => panic!("expected success")
            };
        }
        select! {
            _ = stream.next() => panic!("expected timeout"),
            _ = monitor.timeout(millis(20)) => ()
        };
    }

    /// A request to an address that cannot be connected to must end in an
    /// error rather than hang.
    #[tokio::test]
    async fn connect_timeout_works() {
        let monitor = IoMonitor::new();
        let req = monitor.zstd_request_with_timeout(
            reqwest::Client::new().post("http://0.0.0.1"),
            vec![],
            millis(50),
        );
        req.await.unwrap_err();
    }

    /// A prompt, well-formed reply is returned successfully.
    #[tokio::test]
    async fn http_success() {
        let mock_server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/"))
            .respond_with(ResponseTemplate::new(200).insert_header(ORIGINAL_SIZE.clone(), "0"))
            .mount(&mock_server)
            .await;
        let monitor = IoMonitor::new();
        let req = monitor.zstd_request_with_timeout(
            reqwest::Client::new().post(mock_server.uri()),
            vec![],
            millis(10),
        );
        req.await.unwrap();
    }

    /// A reply delayed beyond the stall limit must be cut off by the timeout.
    #[tokio::test]
    async fn delay_before_reply_fails() {
        let mock_server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/"))
            .respond_with(
                // The reply itself is valid (same as in `http_success`), so
                // the only possible source of an error is the stall timeout.
                // Without the size header the request would also fail with
                // "missing original size", and this test would pass even if
                // the timeout never fired.
                ResponseTemplate::new(200)
                    .insert_header(ORIGINAL_SIZE.clone(), "0")
                    .set_delay(millis(50)),
            )
            .mount(&mock_server)
            .await;
        let monitor = IoMonitor::new();
        let req = monitor.zstd_request_with_timeout(
            reqwest::Client::new().post(mock_server.uri()),
            vec![],
            millis(10),
        );
        req.await.unwrap_err();
    }
}