Add comments

pull/5/head
Trinity Pointard 6 years ago
parent acce8261c9
commit 0a403f4825

@ -107,8 +107,6 @@ pub struct CsrfFairingBuilder {
}
impl CsrfFairingBuilder {
/// Create a new builder with default values.
pub fn new() -> Self {
CsrfFairingBuilder {
@ -191,10 +189,10 @@ impl CsrfFairingBuilder {
}
/// Set the secret key used to generate secure cryptographic tokens. If not set, rocket_csrf
/// will attempt to get the secret used by Rocket for it's own private cookies via the
/// will attempt to get the secret used by Rocket for it's own private cookies via the
/// ROCKET_SECRET_KEY environment variable, or will generate a new one at each restart.
/// Having the secret key set (via this or Rocket environment variable) allow tokens to keep
/// their validity in case of an application restart.
/// their validity in case of an application restart.
///
/// # Example
///
@ -213,7 +211,7 @@ impl CsrfFairingBuilder {
self.secret = Some(secret);
self
}
/// Set if this should modify response to insert tokens automatically in all forms. If true,
/// this will insert tokens in all forms it encounter, if false, you will have to add them via
/// [`CsrfToken`], which you may obtain via request guards.
@ -244,6 +242,7 @@ impl CsrfFairingBuilder {
/// Get the fairing from the builder.
pub fn finalize(self) -> Result<CsrfFairing, ()> {
let secret = self.secret.unwrap_or_else(|| {
//use provided secret if one is
env::vars()
.filter(|(key, _)| key == "ROCKET_SECRET_KEY")
.next()
@ -260,26 +259,26 @@ impl CsrfFairingBuilder {
} else {
None
}
})
})//else get secret environment variable
.unwrap_or_else(|| {
eprintln!("[rocket_csrf] No secret key was found, you should consider set one to allow application restart");
eprintln!("[rocket_csrf] No secret key was found, you should consider set one to keep token validity across application restart");
thread_rng().gen()
})
}) //if environment variable is not set, generate a random secret and print a warning
});
let default_target = Path::from(&self.default_target.0);
let mut hashmap = HashMap::new();
hashmap.insert("uri", "");
if default_target.map(hashmap).is_none() {
return Err(()); //invalid default url
}
return Err(());
} //verify if this path is valid as default path, i.e. it have at most one dynamic part which is <uri>
Ok(CsrfFairing {
duration: self.duration,
default_target: (default_target, self.default_target.1),
exceptions: self
.exceptions
.iter()
.map(|(a, b, m)| (Path::from(&a), Path::from(&b), *m))
.map(|(a, b, m)| (Path::from(&a), Path::from(&b), *m))//TODO verify if source and target are compatible
.collect(),
secret: secret,
auto_insert: self.auto_insert,
@ -322,13 +321,13 @@ impl Fairing for CsrfFairing {
}
fn on_attach(&self, rocket: Rocket) -> Result<Rocket, Rocket> {
Ok(rocket.manage((AesGcmCsrfProtection::from_key(self.secret), self.duration)))
Ok(rocket.manage((AesGcmCsrfProtection::from_key(self.secret), self.duration))) //add the Csrf engine to Rocket's managed state
}
fn on_request(&self, request: &mut Request, data: &Data) {
match request.method() {
Get | Head | Connect | Trace | Options => {
let _ = request.guard::<CsrfToken>();; //force regeneration of csrf cookies
Get | Head | Connect | Options => {
let _ = request.guard::<CsrfToken>(); //force regeneration of csrf cookies
return;
}
_ => {}
@ -343,7 +342,7 @@ impl Fairing for CsrfFairing {
.cookies()
.get(csrf::CSRF_COOKIE_NAME)
.and_then(|cookie| BASE64.decode(cookie.value().as_bytes()).ok())
.and_then(|cookie| csrf_engine.parse_cookie(&cookie).ok());
.and_then(|cookie| csrf_engine.parse_cookie(&cookie).ok()); //get and parse Csrf cookie
let _ = request.guard::<CsrfToken>(); //force regeneration of csrf cookies
@ -351,16 +350,18 @@ impl Fairing for CsrfFairing {
.filter(|(key, _)| key == &csrf::CSRF_FORM_FIELD)
.filter_map(|(_, token)| BASE64URL_NOPAD.decode(&token.as_bytes()).ok())
.filter_map(|token| csrf_engine.parse_token(&token).ok())
.next();
.next(); //get and parse Csrf token
if let Some(token) = token {
if let Some(cookie) = cookie {
if csrf_engine.verify_token_pair(&token, &cookie) {
return;
return; //if we got both token and cookie, and they match each other, we do nothing
}
}
}
//Request reaching here are violating Csrf protection
for (src, dst, method) in self.exceptions.iter() {
if let Some(param) = src.extract(&request.uri().to_string()) {
if let Some(destination) = dst.map(param) {
@ -370,6 +371,9 @@ impl Fairing for CsrfFairing {
}
}
}
//if request matched no exception, reroute it to default target
let uri = request.uri().to_string();
let uri = Uri::percent_encode(&uri);
let mut param: HashMap<&str, &str> = HashMap::new();
@ -383,7 +387,8 @@ impl Fairing for CsrfFairing {
if !ct.is_html() {
return;
}
}
} //if content type is not html, we do nothing
let uri = request.uri().to_string();
if self
.auto_insert_disable_prefix
@ -393,33 +398,34 @@ impl Fairing for CsrfFairing {
.is_some()
{
return;
}
//content type is html and we are not on static ressources, so we may need to modify this answer
} //if request is on an ignored prefix, ignore it
let token = match request.guard::<CsrfToken>() {
Outcome::Success(t) => t,
_ => return,
};
}; //if we can't get a token, leave request unchanged, we can't do anything anyway
let body = response.take_body();
let body = response.take_body(); //take request body from Rocket
if body.is_none() {
return;
}
} //if there was no body, leave it that way
let body = body.unwrap();
if let Sized(body_reader, len) = body {
if len <= self.auto_insert_max_size {
//if this is a small enought body, process the full body
let mut res = Vec::with_capacity(len as usize);
CsrfProxy::from(body_reader, token)
.read_to_end(&mut res)
.unwrap();
response.set_sized_body(std::io::Cursor::new(res));
} else {
//if body is of known but long size, change it to a stream to preserve memory, by encapsulating it into our "proxy" struct
let body = body_reader;
response.set_streamed_body(Box::new(CsrfProxy::from(body, token)));
}
} else {
//if body is of unknown size, encapsulate it into our "proxy" struct
let body = body.into_inner();
response.set_streamed_body(Box::new(CsrfProxy::from(body, token)));
}
@ -429,21 +435,20 @@ impl Fairing for CsrfFairing {
enum ParseState {
Reset, //default state
PartialFormMatch(u8), //when parsing "<form"
CloseFormTag, //searching for '>'
SearchInput, //like default state, but inside a form
PartialInputMatch(u8, usize), //when parsing "<input"
PartialFormEndMatch(u8, usize), //when parsing "/form" ('<' done by PartialInputMarch)
SearchMethod(usize), //like default state, but inside input tag
PartialFormEndMatch(u8, usize), //when parsing "</form" ('<' is actally done via PartialInputMarch)
SearchMethod(usize), //when inside the first <input>, search for begining of a param
PartialNameMatch(u8, usize), //when parsing "name="_method""
CloseInputTag, //only if insert after
CloseInputTag, //only if insert after, search for '>' of a "<input name=\"_method\">"
}
struct CsrfProxy<'a> {
underlying: Box<Read + 'a>,
token: Vec<u8>,
buf: Vec<(Vec<u8>, usize)>,
state: ParseState,
insert_tag: Option<usize>,
underlying: Box<Read + 'a>, //the underlying Reader from which we get data
token: Vec<u8>, //a full input tag loaded with a valid token
buf: Vec<(Vec<u8>, usize)>, //a stack of buffers, with a position in case a buffer was not fully transmited
state: ParseState, //state of the parser
insert_tag: Option<usize>, //if we have to insert tag here, and how fare are we in the tag (in case of very short read()s)
}
impl<'a> CsrfProxy<'a> {
@ -468,28 +473,35 @@ impl<'a> CsrfProxy<'a> {
impl<'a> Read for CsrfProxy<'a> {
fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
if let Some(pos) = self.insert_tag {
//if we should insert a tag
let size = buf.len();
let copy_size = std::cmp::min(size, self.token.len() - pos);
buf[0..copy_size].copy_from_slice(&self.token[pos..copy_size + pos]);
let copy_size = std::cmp::min(size, self.token.len() - pos); //get max copy length
buf[0..copy_size].copy_from_slice(&self.token[pos..copy_size + pos]); //copy that mutch
if copy_size == self.token.len() - pos {
//if we copied the full tag, say we don't need to set it again
self.insert_tag = None;
} else {
//if we didn't copy the full tag, save where we were
self.insert_tag = Some(pos + copy_size);
}
return Ok(copy_size);
return Ok(copy_size); //return the lenght of the copied data
}
let len = if let Some((vec, pos)) = self.buf.pop() {
//if there is a buffer to add here
let size = buf.len();
if vec.len() - pos <= size {
//if the part left of the buffer is smaller than buf
buf[0..vec.len() - pos].copy_from_slice(&vec[pos..]);
vec.len()
} else {
//else if the part left of the buffer is bigger than buf
buf.copy_from_slice(&vec[pos..pos + size]);
self.buf.push((vec, pos + size));
size
}
} //send the size of what was read as if it was a normal read on underlying struct
} else {
//if there is no buffer to add, read from underlying struct
let res = self.underlying.read(buf);
if res.is_err() {
return res;
@ -501,45 +513,48 @@ impl<'a> Read for CsrfProxy<'a> {
};
for i in 0..len {
//for each byte
use ParseState::*;
self.state = match self.state {
Reset => if buf[i] as char == '<' {
//if we are in default state and we begin to match any tag
PartialFormMatch(0)
} else {
//if we don't match a tag
Reset
},
PartialFormMatch(count) => match (buf[i] as char, count) {
//progressively match "form"
('f', 0) | ('F', 0) => PartialFormMatch(1),
('o', 1) | ('O', 1) => PartialFormMatch(2),
('r', 2) | ('R', 2) => PartialFormMatch(3),
('m', 3) | ('M', 3) => CloseFormTag,
_ => Reset,
},
CloseFormTag => if buf[i] as char == '>' {
SearchInput
} else {
CloseFormTag
('m', 3) | ('M', 3) => SearchInput, //when we success, go to next state
_ => Reset, //if this don't match, go back to defailt state
},
SearchInput => if buf[i] as char == '<' {
//begin to match any tag
PartialInputMatch(0, i)
} else {
SearchInput
},
PartialInputMatch(count, pos) => match (buf[i] as char, count) {
//progressively match "input"
('i', 0) | ('I', 0) => PartialInputMatch(1, pos),
('n', 1) | ('N', 1) => PartialInputMatch(2, pos),
('p', 2) | ('P', 2) => PartialInputMatch(3, pos),
('u', 3) | ('U', 3) => PartialInputMatch(4, pos),
('t', 4) | ('T', 4) => SearchMethod(pos),
('/', 0) => PartialFormEndMatch(1, pos),
_ => SearchInput,
('t', 4) | ('T', 4) => SearchMethod(pos), //when we success, go to next state
('/', 0) => PartialFormEndMatch(1, pos), //if first char is '/', it may mean we are matching end of form, go to that state
_ => SearchInput, //not a input tag, go back to SearchInput
},
PartialFormEndMatch(count, pos) => match (buf[i] as char, count) {
//progressively match "/form"
('/', 0) => PartialFormEndMatch(1, pos), //unreachable, here only for comprehension
('f', 1) | ('F', 1) => PartialFormEndMatch(2, pos),
('o', 2) | ('O', 2) => PartialFormEndMatch(3, pos),
('r', 3) | ('R', 3) => PartialFormEndMatch(4, pos),
('m', 4) | ('M', 4) => {
//if we match end of form, save "</form>" and anything after to a buffer, and insert our token
self.insert_tag = Some(0);
self.buf.push((buf[pos..].to_vec(), 0));
self.state = Reset;
@ -548,8 +563,10 @@ impl<'a> Read for CsrfProxy<'a> {
_ => SearchInput,
},
SearchMethod(pos) => match buf[i] as char {
' ' => PartialNameMatch(0, pos),
//try to match params
' ' => PartialNameMatch(0, pos), //space, next char is a new param
'>' => {
//end of this <input> tag, it's not Rocket special one, so insert before, saving what comes next to buffer
self.insert_tag = Some(0);
self.buf.push((buf[pos..].to_vec(), 0));
self.state = Reset;
@ -558,12 +575,13 @@ impl<'a> Read for CsrfProxy<'a> {
_ => SearchMethod(pos),
},
PartialNameMatch(count, pos) => match (buf[i] as char, count) {
//progressively match "name='_method'", which must be first to work
('n', 0) | ('N', 0) => PartialNameMatch(1, pos),
('a', 1) | ('A', 1) => PartialNameMatch(2, pos),
('m', 2) | ('M', 2) => PartialNameMatch(3, pos),
('e', 3) | ('E', 3) => PartialNameMatch(4, pos),
('=', 4) => PartialNameMatch(5, pos),
('"', 5) => PartialNameMatch(6, pos),
('"', 5) | ('\'', 5) => PartialNameMatch(6, pos),
('_', 6) => PartialNameMatch(7, pos),
('m', 7) | ('M', 7) => PartialNameMatch(8, pos),
('e', 8) | ('E', 8) => PartialNameMatch(9, pos),
@ -571,10 +589,11 @@ impl<'a> Read for CsrfProxy<'a> {
('h', 10) | ('H', 10) => PartialNameMatch(11, pos),
('o', 11) | ('O', 11) => PartialNameMatch(12, pos),
('d', 12) | ('D', 12) => PartialNameMatch(13, pos),
('"', 13) => CloseInputTag,
_ => SearchMethod(pos),
('"', 13) => CloseInputTag, //we matched, wait for end of this <input> and insert just after
_ => SearchMethod(pos), //we did not match, search next param
},
CloseInputTag => if buf[i] as char == '>' {
//search for '>' at the end of an "<input name='_method'>", and insert token after
self.insert_tag = Some(0);
self.buf.push((buf[i + 1..].to_vec(), 0));
self.state = Reset;
@ -605,7 +624,7 @@ impl Serialize for CsrfToken {
where
S: Serializer,
{
serializer.serialize_str(&self.value)
serializer.serialize_str(&self.value) //simply serialise to the underlying String
}
}
@ -632,11 +651,12 @@ impl<'a, 'r> FromRequest<'a, 'r> for CsrfToken {
} else {
None
}
});
}); //when request guard is called, parse cookie to get it's encrypted secret (if there is a cookie)
match csrf_engine.generate_token_pair(token_value.as_ref(), *duration) {
Ok((token, cookie)) => {
cookies.add(Cookie::new(csrf::CSRF_COOKIE_NAME, cookie.b64_string()));
let mut c = Cookie::new(csrf::CSRF_COOKIE_NAME, cookie.b64_string());
cookies.add(c); //todo add a timeout and a same_site policy to the cookie
Outcome::Success(CsrfToken {
value: BASE64URL_NOPAD.encode(token.value()),
})
@ -655,6 +675,7 @@ struct Path {
impl Path {
fn from(path: &str) -> Self {
let (path, query) = if let Some(pos) = path.find('?') {
//cut the path at pos begining of query parameters
let (path, query) = path.split_at(pos);
let query = &query[1..];
(path, Some(query))
@ -663,14 +684,14 @@ impl Path {
};
Path {
path: path
.split('/')
.filter(|seg| seg != &"")
.split('/')//split path at each '/'
.filter(|seg| seg != &"")//remove empty segments
.map(|seg| {
if seg.get(..1) == Some("<") && seg.get(seg.len() - 1..) == Some(">") {
if seg.get(..1) == Some("<") && seg.get(seg.len() - 1..) == Some(">") {//if the segment start with '<' and end with '>', it is dynamic
PathPart::Dynamic(seg[1..seg.len() - 1].to_owned())
} else {
} else {//else it's static
PathPart::Static(seg.to_owned())
}
}//TODO add support for <..path> to match more than one segment
})
.collect(),
param: query.map(|query| {
@ -679,6 +700,7 @@ impl Path {
(
k.to_owned(),
if v.get(..1) == Some("<") && v.get(v.len() - 1..) == Some(">") {
//do the same kind of parsing as above, but on query params
PathPart::Dynamic(v[1..v.len() - 1].to_owned())
} else {
PathPart::Static(v.to_owned())
@ -691,6 +713,7 @@ impl Path {
}
fn extract<'a>(&self, uri: &'a str) -> Option<HashMap<&str, &'a str>> {
//try to match a str against a path, give back a hashmap of correponding parts if it matched
let mut res: HashMap<&str, &'a str> = HashMap::new();
let (path, query) = if let Some(pos) = uri.find('?') {
let (path, query) = uri.split_at(pos);
@ -707,17 +730,21 @@ impl Path {
if let Some(reference) = reference.next() {
match reference {
PathPart::Static(refe) => if refe != &v {
//static, but not the same, fail to parse
return None;
},
PathPart::Dynamic(key) => {
//dynamic, store to hashmap
res.insert(key, v);
}
};
} else {
//not the same lenght, fail to parse
return None;
}
}
None => if reference.next().is_some() {
//not the same lenght, fail to parse
return None;
} else {
break;
@ -730,17 +757,21 @@ impl Path {
for (k, v) in param {
match v {
PathPart::Static(val) => if val != hm.get::<str>(k)? {
//static but not the same, fail to parse
return None;
},
PathPart::Dynamic(key) => {
//dynamic, store to hashmap
res.insert(key, hm.get::<str>(k)?);
}
}
}
} else {
//param in query, but not in reference, fail to parse
return None;
}
} else if self.param.is_some() {
//param in reference, but not in query, fail to parse
return None;
}
@ -748,8 +779,10 @@ impl Path {
}
fn map(&self, param: HashMap<&str, &str>) -> Option<String> {
//Generate a path from a reference and a hashmap
let mut res = String::new();
for seg in self.path.iter() {
//TODO add a / if no elements in self.path
res.push('/');
match seg {
PathPart::Static(val) => res.push_str(val),
@ -757,6 +790,7 @@ impl Path {
}
}
if let Some(ref keymap) = self.param {
//if there is some query part
res.push('?');
for (k, v) in keymap {
res.push_str(k);
@ -768,7 +802,7 @@ impl Path {
res.push('&');
}
}
Some(res.trim_right_matches('&').to_owned())
Some(res.trim_right_matches('&').to_owned()) //trim the last '&' which was added if there is a query part
}
}
@ -779,10 +813,12 @@ enum PathPart {
}
fn parse_args<'a>(args: &'a str) -> impl Iterator<Item = (&'a str, &'a str)> {
//transform a group of argument into an iterator of key and value
args.split('&').filter_map(|kv| parse_keyvalue(&kv))
}
fn parse_keyvalue<'a>(kv: &'a str) -> Option<(&'a str, &'a str)> {
//convert a single key-value pair into a key and a value
if let Some(pos) = kv.find('=') {
let (key, value) = kv.split_at(pos + 1);
Some((&key[0..pos], value))

Loading…
Cancel
Save