diff --git a/rslib/src/sync/http.rs b/rslib/src/sync/http.rs new file mode 100644 index 000000000..5778f7181 --- /dev/null +++ b/rslib/src/sync/http.rs @@ -0,0 +1,88 @@ +use super::{Chunk, Graves, SanityCheckCounts, UnchunkedChanges}; +use crate::prelude::*; +use serde::{Deserialize, Serialize}; +#[derive(Serialize, Deserialize, Debug)] +#[serde(rename_all = "camelCase")] +pub enum SyncRequest { + HostKey(HostKeyIn), + Meta(MetaIn), + Start(StartIn), + ApplyGraves(ApplyGravesIn), + ApplyChanges(ApplyChangesIn), + Chunk, + ApplyChunk(ApplyChunkIn), + #[serde(rename = "sanityCheck2")] + SanityCheck(SanityCheckIn), + Finish, + Abort, +} + +impl SyncRequest { + /// Return method name and payload bytes. + pub(crate) fn to_method_and_json(&self) -> Result<(&'static str, Vec)> { + use serde_json::to_vec; + Ok(match self { + SyncRequest::HostKey(v) => ("hostKey", to_vec(&v)?), + SyncRequest::Meta(v) => ("meta", to_vec(&v)?), + SyncRequest::Start(v) => ("start", to_vec(&v)?), + SyncRequest::ApplyGraves(v) => ("applyGraves", to_vec(&v)?), + SyncRequest::ApplyChanges(v) => ("applyChanges", to_vec(&v)?), + SyncRequest::Chunk => ("chunk", b"{}".to_vec()), + SyncRequest::ApplyChunk(v) => ("applyChunk", to_vec(&v)?), + SyncRequest::SanityCheck(v) => ("sanityCheck2", to_vec(&v)?), + SyncRequest::Finish => ("finish", b"{}".to_vec()), + SyncRequest::Abort => ("abort", b"{}".to_vec()), + }) + } +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct HostKeyIn { + #[serde(rename = "u")] + pub username: String, + #[serde(rename = "p")] + pub password: String, +} +#[derive(Serialize, Deserialize, Debug)] +pub struct HostKeyOut { + pub key: String, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct MetaIn { + #[serde(rename = "v")] + pub sync_version: u8, + #[serde(rename = "cv")] + pub client_version: String, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct StartIn { + #[serde(rename = "minUsn")] + pub client_usn: Usn, + #[serde(rename = "offset", default)] + pub minutes_west: Option, + #[serde(rename = "lnewer")] + pub local_is_newer: bool, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct ApplyGravesIn { + pub chunk: Graves, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct ApplyChangesIn { + pub changes: UnchunkedChanges, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct ApplyChunkIn { + pub chunk: Chunk, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct SanityCheckIn { + pub client: SanityCheckCounts, + pub full: bool, +} diff --git a/rslib/src/sync/http_client.rs b/rslib/src/sync/http_client.rs index f72d68011..a202adbaf 100644 --- a/rslib/src/sync/http_client.rs +++ b/rslib/src/sync/http_client.rs @@ -2,11 +2,29 @@ // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html use super::server::SyncServer; -use super::*; +use super::{ + Chunk, FullSyncProgress, Graves, SanityCheckCounts, SanityCheckOut, SyncMeta, UnchunkedChanges, +}; +use crate::prelude::*; +use crate::{err::SyncErrorKind, notes::guid, version::sync_client_version}; use async_trait::async_trait; use bytes::Bytes; +use flate2::write::GzEncoder; +use flate2::Compression; use futures::Stream; +use futures::StreamExt; use reqwest::Body; +use reqwest::{multipart, Client, Response}; +use serde::de::DeserializeOwned; + +use super::http::{ + ApplyChangesIn, ApplyChunkIn, ApplyGravesIn, HostKeyIn, HostKeyOut, MetaIn, SanityCheckIn, + StartIn, SyncRequest, +}; +use std::io::prelude::*; +use std::path::Path; +use std::time::Duration; +use tempfile::NamedTempFile; // fixme: 100mb limit @@ -22,61 +40,6 @@ pub struct HTTPSyncClient { full_sync_progress_fn: Option, } -#[derive(Serialize)] -struct HostKeyIn<'a> { - #[serde(rename = "u")] - username: &'a str, - #[serde(rename = "p")] - password: &'a str, -} -#[derive(Deserialize)] -struct HostKeyOut { - key: String, -} - -#[derive(Serialize)] -struct MetaIn<'a> { - #[serde(rename = "v")] - sync_version: u8, - #[serde(rename = "cv")] - client_version: &'a str, -} - -#[derive(Serialize, Deserialize, Debug)] -struct StartIn { - #[serde(rename = "minUsn")] - local_usn: Usn, - #[serde(rename = "offset")] - minutes_west: Option, - // only used to modify behaviour of changes() - #[serde(rename = "lnewer")] - local_is_newer: bool, - // used by 2.0 clients - #[serde(skip_serializing_if = "Option::is_none")] - local_graves: Option, -} - -#[derive(Serialize, Deserialize, Debug)] -struct ApplyGravesIn { - chunk: Graves, -} - -#[derive(Serialize, Deserialize, Debug)] -struct ApplyChangesIn { - changes: UnchunkedChanges, -} - -#[derive(Serialize, Deserialize, Debug)] -struct ApplyChunkIn { - chunk: Chunk, -} - -#[derive(Serialize, Deserialize, Debug)] -struct SanityCheckIn { - client: SanityCheckCounts, - full: bool, -} - pub struct Timeouts { pub connect_secs: u64, pub request_secs: u64, @@ -99,158 +62,64 @@ impl Timeouts { } } } -#[derive(Serialize)] -struct Empty {} - -impl HTTPSyncClient { - pub fn new(hkey: Option, host_number: u32) -> HTTPSyncClient { - let timeouts = Timeouts::new(); - let client = Client::builder() - .connect_timeout(Duration::from_secs(timeouts.connect_secs)) - .timeout(Duration::from_secs(timeouts.request_secs)) - .io_timeout(Duration::from_secs(timeouts.io_secs)) - .build() - .unwrap(); - let skey = guid(); - let endpoint = sync_endpoint(host_number); - HTTPSyncClient { - hkey, - skey, - client, - endpoint, - full_sync_progress_fn: None, - } - } - - pub fn set_full_sync_progress_fn(&mut self, func: Option) { - self.full_sync_progress_fn = func; - } - - async fn json_request(&self, method: &str, json: &T, timeout_long: bool) -> Result - where - T: serde::Serialize, - { - let req_json = serde_json::to_vec(json)?; - - let mut gz = GzEncoder::new(Vec::new(), Compression::fast()); - gz.write_all(&req_json)?; - let part = multipart::Part::bytes(gz.finish()?); - - self.request(method, part, timeout_long).await - } - - async fn json_request_deserialized(&self, method: &str, json: &T) -> Result - where - T: Serialize, - T2: DeserializeOwned, - { - self.json_request(method, json, false) - .await? - .json() - .await - .map_err(Into::into) - } - - async fn request( - &self, - method: &str, - data_part: multipart::Part, - timeout_long: bool, - ) -> Result { - let data_part = data_part.file_name("data"); - - let mut form = multipart::Form::new() - .part("data", data_part) - .text("c", "1"); - if let Some(hkey) = &self.hkey { - form = form.text("k", hkey.clone()).text("s", self.skey.clone()); - } - - let url = format!("{}{}", self.endpoint, method); - let mut req = self.client.post(&url).multipart(form); - - if timeout_long { - req = req.timeout(Duration::from_secs(60 * 60)); - } - - req.send().await?.error_for_status().map_err(Into::into) - } - - pub(crate) async fn login(&mut self, username: &str, password: &str) -> Result<()> { - let resp: HostKeyOut = self - .json_request_deserialized("hostKey", &HostKeyIn { username, password }) - .await?; - self.hkey = Some(resp.key); - - Ok(()) - } - - pub(crate) fn hkey(&self) -> &str { - self.hkey.as_ref().unwrap() - } -} #[async_trait(?Send)] impl SyncServer for HTTPSyncClient { async fn meta(&self) -> Result { - let meta_in = MetaIn { + let input = SyncRequest::Meta(MetaIn { sync_version: SYNC_VERSION, - client_version: sync_client_version(), - }; - self.json_request_deserialized("meta", &meta_in).await + client_version: sync_client_version().to_string(), + }); + self.json_request(&input).await } async fn start( &mut self, - local_usn: Usn, + client_usn: Usn, minutes_west: Option, local_is_newer: bool, ) -> Result { - let input = StartIn { - local_usn, + let input = SyncRequest::Start(StartIn { + client_usn, minutes_west, local_is_newer, - local_graves: None, - }; - self.json_request_deserialized("start", &input).await + }); + self.json_request(&input).await } async fn apply_graves(&mut self, chunk: Graves) -> Result<()> { - let input = ApplyGravesIn { chunk }; - let resp = self.json_request("applyGraves", &input, false).await?; - resp.error_for_status()?; - Ok(()) + let input = SyncRequest::ApplyGraves(ApplyGravesIn { chunk }); + self.json_request(&input).await } async fn apply_changes(&mut self, changes: UnchunkedChanges) -> Result { - let input = ApplyChangesIn { changes }; - self.json_request_deserialized("applyChanges", &input).await + let input = SyncRequest::ApplyChanges(ApplyChangesIn { changes }); + self.json_request(&input).await } async fn chunk(&mut self) -> Result { - self.json_request_deserialized("chunk", &Empty {}).await + let input = SyncRequest::Chunk; + self.json_request(&input).await } async fn apply_chunk(&mut self, chunk: Chunk) -> Result<()> { - let input = ApplyChunkIn { chunk }; - let resp = self.json_request("applyChunk", &input, false).await?; - resp.error_for_status()?; - Ok(()) + let input = SyncRequest::ApplyChunk(ApplyChunkIn { chunk }); + self.json_request(&input).await } async fn sanity_check(&mut self, client: SanityCheckCounts) -> Result { - let input = SanityCheckIn { client, full: true }; - self.json_request_deserialized("sanityCheck2", &input).await + let input = SyncRequest::SanityCheck(SanityCheckIn { client, full: true }); + self.json_request(&input).await } async fn finish(&mut self) -> Result { - Ok(self.json_request_deserialized("finish", &Empty {}).await?) + let input = SyncRequest::Finish; + self.json_request(&input).await } async fn abort(&mut self) -> Result<()> { - let resp = self.json_request("abort", &Empty {}, false).await?; - resp.error_for_status()?; - Ok(()) + let input = SyncRequest::Abort; + self.json_request(&input).await } async fn full_upload(mut self: Box, col_path: &Path, _can_consume: bool) -> Result<()> { @@ -301,13 +170,101 @@ impl SyncServer for HTTPSyncClient { } impl HTTPSyncClient { + pub fn new(hkey: Option, host_number: u32) -> HTTPSyncClient { + let timeouts = Timeouts::new(); + let client = Client::builder() + .connect_timeout(Duration::from_secs(timeouts.connect_secs)) + .timeout(Duration::from_secs(timeouts.request_secs)) + .io_timeout(Duration::from_secs(timeouts.io_secs)) + .build() + .unwrap(); + let skey = guid(); + let endpoint = sync_endpoint(host_number); + HTTPSyncClient { + hkey, + skey, + client, + endpoint, + full_sync_progress_fn: None, + } + } + + pub fn set_full_sync_progress_fn(&mut self, func: Option) { + self.full_sync_progress_fn = func; + } + + async fn json_request(&self, req: &SyncRequest) -> Result + where + T: DeserializeOwned, + { + let (method, req_json) = req.to_method_and_json()?; + self.request_bytes(method, &req_json, false) + .await? + .json() + .await + .map_err(Into::into) + } + + async fn request_bytes( + &self, + method: &str, + req: &[u8], + timeout_long: bool, + ) -> Result { + let mut gz = GzEncoder::new(Vec::new(), Compression::fast()); + gz.write_all(req)?; + let part = multipart::Part::bytes(gz.finish()?); + let resp = self.request(method, part, timeout_long).await?; + resp.error_for_status().map_err(Into::into) + } + + async fn request( + &self, + method: &str, + data_part: multipart::Part, + timeout_long: bool, + ) -> Result { + let data_part = data_part.file_name("data"); + + let mut form = multipart::Form::new() + .part("data", data_part) + .text("c", "1"); + if let Some(hkey) = &self.hkey { + form = form.text("k", hkey.clone()).text("s", self.skey.clone()); + } + + let url = format!("{}{}", self.endpoint, method); + let mut req = self.client.post(&url).multipart(form); + + if timeout_long { + req = req.timeout(Duration::from_secs(60 * 60)); + } + + req.send().await?.error_for_status().map_err(Into::into) + } + + pub(crate) async fn login>(&mut self, username: S, password: S) -> Result<()> { + let input = SyncRequest::HostKey(HostKeyIn { + username: username.into(), + password: password.into(), + }); + let output: HostKeyOut = self.json_request(&input).await?; + self.hkey = Some(output.key); + + Ok(()) + } + + pub(crate) fn hkey(&self) -> &str { + self.hkey.as_ref().unwrap() + } + async fn download_inner( &self, ) -> Result<( usize, impl Stream>, )> { - let resp: reqwest::Response = self.json_request("download", &Empty {}, true).await?; + let resp: reqwest::Response = self.request_bytes("download", b"{}", true).await?; let len = resp.content_length().unwrap_or_default(); Ok((len as usize, resp.bytes_stream())) } @@ -386,7 +343,7 @@ fn sync_endpoint(host_number: u32) -> String { #[cfg(test)] mod test { use super::*; - use crate::err::SyncErrorKind; + use crate::{err::SyncErrorKind, sync::SanityCheckDueCounts}; use tokio::runtime::Runtime; async fn http_client_inner(username: String, password: String) -> Result<()> { diff --git a/rslib/src/sync/mod.rs b/rslib/src/sync/mod.rs index 6f861a489..5749db3d6 100644 --- a/rslib/src/sync/mod.rs +++ b/rslib/src/sync/mod.rs @@ -1,6 +1,7 @@ // Copyright: Ankitects Pty Ltd and contributors // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html +pub mod http; mod http_client; mod server; @@ -10,30 +11,23 @@ use crate::{ deckconf::DeckConfSchema11, decks::DeckSchema11, err::SyncErrorKind, - notes::{guid, Note}, + notes::Note, notetype::{NoteType, NoteTypeSchema11}, prelude::*, revlog::RevlogEntry, serde::{default_on_invalid, deserialize_int_from_number}, storage::open_and_check_sqlite_file, tags::{join_tags, split_tags}, - version::sync_client_version, }; -use flate2::write::GzEncoder; -use flate2::Compression; -use futures::StreamExt; pub use http_client::FullSyncProgressFn; use http_client::HTTPSyncClient; pub use http_client::Timeouts; use itertools::Itertools; -use reqwest::{multipart, Client, Response}; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde::{Deserialize, Serialize}; use serde_json::Value; use serde_tuple::Serialize_tuple; -use server::SyncServer; -use std::io::prelude::*; -use std::{collections::HashMap, path::Path, time::Duration}; -use tempfile::NamedTempFile; +pub(crate) use server::SyncServer; +use std::collections::HashMap; #[derive(Default, Debug, Clone, Copy)] pub struct NormalSyncProgress { @@ -1194,6 +1188,8 @@ impl From for sync_status_out::Required { #[cfg(test)] mod test { + use std::path::Path; + use async_trait::async_trait; use lazy_static::lazy_static;