1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
//! A buffer for reading data from the network.
//!
//! The `InputBuffer` is a buffer of bytes similar to a first-in, first-out queue.
//! It is filled by reading from a stream supporting `Read` and is then
//! accessible as a cursor for reading bytes.
#![deny(missing_debug_implementations)]
extern crate bytes;

use std::error;
use std::fmt;
use std::io::{Cursor, Read, Result as IoResult};

use bytes::{Buf, BufMut};

/// A FIFO buffer for reading packets from the network.
///
/// The wrapped vector holds everything read so far; the cursor position marks
/// how much of it has already been consumed. Consumed bytes stay in the vector
/// until `remove_garbage` drops them.
#[derive(Debug)]
pub struct InputBuffer(Cursor<Vec<u8>>);

/// The recommended minimum read size.
///
/// Used as the initial capacity of `InputBuffer::new` and as the reserve
/// requested by `InputBuffer::prepare`.
pub const MIN_READ: usize = 4096;

impl InputBuffer {
    /// Create a new empty input buffer with room for `MIN_READ` bytes.
    pub fn new() -> Self {
        Self::with_capacity(MIN_READ)
    }

    /// Create a new empty input buffer with the given initial capacity.
    pub fn with_capacity(capacity: usize) -> Self {
        Self::from_partially_read(Vec::with_capacity(capacity))
    }

    /// Create an input buffer filled with previously read data.
    ///
    /// The cursor starts at position zero, so all of `part` counts as unread.
    pub fn from_partially_read(part: Vec<u8>) -> Self {
        InputBuffer(Cursor::new(part))
    }

    /// Get the data as a cursor.
    pub fn as_cursor(&self) -> &Cursor<Vec<u8>> {
        &self.0
    }

    /// Get the data as a mutable cursor.
    ///
    /// The cursor allows its position to be moved past the end of the data;
    /// such a position is treated as "everything consumed" by
    /// `remove_garbage`.
    pub fn as_cursor_mut(&mut self) -> &mut Cursor<Vec<u8>> {
        &mut self.0
    }

    /// Remove the already consumed portion of the data.
    ///
    /// Afterwards the cursor position is zero and the vector contains only
    /// the bytes that have not been consumed yet.
    pub fn remove_garbage(&mut self) {
        let filled = self.0.get_ref().len() as u64;
        // The cursor is exposed mutably, so its position may lie beyond the
        // data. Clamp it (in `u64`, before narrowing) so that `drain` cannot
        // panic on an out-of-range end index.
        let consumed = self.0.position().min(filled) as usize;
        // Dropping the drain iterator removes the range from the vector.
        drop(self.0.get_mut().drain(..consumed));
        self.0.set_position(0);
    }

    /// Get the rest of the buffer and destroy the buffer.
    ///
    /// Only the bytes that have not been consumed are returned.
    pub fn into_vec(mut self) -> Vec<u8> {
        self.remove_garbage();
        self.0.into_inner()
    }

    /// Read next portion of data from the given input stream.
    ///
    /// Shorthand for `self.prepare().read_from(stream)`; returns the number
    /// of bytes appended to the buffer.
    pub fn read_from<S: Read>(&mut self, stream: &mut S) -> IoResult<usize> {
        self.prepare().read_from(stream)
    }

    /// Prepare reading with the default reserve of `MIN_READ` bytes.
    pub fn prepare<'t>(&'t mut self) -> DoRead<'t> {
        self.prepare_reserve(MIN_READ)
    }

    /// Prepare reading with the given reserve.
    ///
    /// `reserve` is the number of bytes the subsequent read may append.
    /// Garbage collection is scheduled only when it is both needed and
    /// sufficient to make room for `reserve` bytes without reallocating.
    pub fn prepare_reserve<'t>(&'t mut self, reserve: usize) -> DoRead<'t> {
        // Space that we have right now.
        let free_space = self.total_len() - self.filled_len();
        // Space that we could have after garbage collect. Saturate so that an
        // out-of-range cursor position cannot overflow the sum.
        let total_space = free_space.saturating_add(self.consumed_len());
        // If garbage collect would help, schedule it.
        let remove_garbage = free_space < reserve && total_space >= reserve;

        DoRead {
            buf: self,
            remove_garbage,
            reserve,
        }
    }
}

impl Default for InputBuffer {
    /// Equivalent to `InputBuffer::new()`.
    fn default() -> Self {
        Self::new()
    }
}

impl InputBuffer {
    /// Get the total buffer length, i.e. the capacity of the backing vector.
    fn total_len(&self) -> usize {
        self.0.get_ref().capacity()
    }

    /// Get the filled buffer length, i.e. the number of bytes stored in the
    /// backing vector (consumed or not).
    fn filled_len(&self) -> usize {
        self.0.get_ref().len()
    }

    /// Get the consumed data length.
    ///
    /// The result never exceeds `filled_len()`: the cursor position can be
    /// moved past the end of the data through the mutable cursor accessor,
    /// and callers subtract this value from sizes derived from the filled
    /// length. The clamp is done in `u64` so the narrowing cast cannot
    /// truncate.
    fn consumed_len(&self) -> usize {
        let filled = self.filled_len() as u64;
        self.0.position().min(filled) as usize
    }
}

// `Buf` is implemented by forwarding to the inner cursor. The calls are
// spelled out as `Buf::method(..)` on purpose: `std::io::Read` is in scope
// and also provides a `bytes` method for the cursor.
impl Buf for InputBuffer {
    /// Number of bytes that have not been consumed yet.
    fn remaining(&self) -> usize {
        let cursor = &self.0;
        Buf::remaining(cursor)
    }

    /// The unconsumed bytes as a slice.
    fn bytes(&self) -> &[u8] {
        let cursor = &self.0;
        Buf::bytes(cursor)
    }

    /// Mark the next `size` bytes as consumed.
    fn advance(&mut self, size: usize) {
        let cursor = &mut self.0;
        Buf::advance(cursor, size)
    }
}

/// The reference to the buffer used for reading.
///
/// Created by `InputBuffer::prepare` / `InputBuffer::prepare_reserve` and
/// consumed by `DoRead::read_from`.
#[derive(Debug)]
pub struct DoRead<'t> {
    // The buffer the read will append to.
    buf: &'t mut InputBuffer,
    // Whether consumed data is dropped from the buffer before reading.
    remove_garbage: bool,
    // Number of bytes to make room for; also the size of the slice handed to
    // the stream.
    reserve: usize,
}

impl<'t> DoRead<'t> {
    /// Enforce the size limit.
    pub fn with_limit(mut self, limit: usize) -> Result<Self, SizeLimit> {
        // Total size we shall have after reserve.
        let total_len = self.buf.filled_len() + self.reserve;
        // Size we could free if we collect garbage.
        let consumed_len = self.buf.consumed_len();
        // Shall we fit if we remove data already consumed?
        if total_len - consumed_len <= limit {
            // Shall we not fit if we don't remove data already consumed?
            if total_len > limit {
                self.remove_garbage = true;
            }
            Ok(self)
        } else {
            Err(SizeLimit)
        }
    }

    /// Read next portion of data from the given input stream.
    pub fn read_from<S: Read>(self, stream: &mut S) -> IoResult<usize> {
        if self.remove_garbage {
            self.buf.remove_garbage();
        }

        let v: &mut Vec<u8> = self.buf.0.get_mut();

        v.reserve(self.reserve);

        assert!(v.capacity() > v.len());
        let size = unsafe {
            // TODO: This can be replaced by std::mem::MaybeUninit::first_ptr_mut() once
            // it is stabilized.
            let data = &mut v.bytes_mut()[..self.reserve];
            // We first have to initialize the data or otherwise casting to a byte slice
            // below is UB. See also code of std::io::copy(), tokio::AsyncRead::poll_read_buf()
            // and others.
            //
            // Read::read() might read uninitialized data otherwise, and generally creating
            // references to uninitialized data is UB.
            for x in data.iter_mut() {
                *x.as_mut_ptr() = 0;
            }
            // Now it's safe to cast it to a byte slice
            let data = std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut u8, data.len());
            let size = stream.read(data)?;
            v.advance_mut(size);
            size
        };
        Ok(size)
    }
}

/// Size limit error.
///
/// Signals that a requested read could grow the buffer beyond the limit the
/// caller asked to enforce.
#[derive(Clone, Copy, Debug)]
pub struct SizeLimit;

impl fmt::Display for SizeLimit {
    /// Writes the short identifier of the error.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("SizeLimit")
    }
}

impl error::Error for SizeLimit {
    /// Human-readable description of the error.
    fn description(&self) -> &'static str {
        let text: &'static str = "Size limit exceeded";
        text
    }
}

// Unit tests driving `InputBuffer` from an in-memory `Cursor` acting as the
// input stream.
#[cfg(test)]
mod tests {

    use super::InputBuffer;
    use bytes::Buf;
    use std::io::Cursor;

    // A single read with the default reserve picks up the whole short input.
    #[test]
    fn simple_reading() {
        let mut inp = Cursor::new(b"Hello World!".to_vec());
        let mut buf = InputBuffer::new();
        let size = buf.read_from(&mut inp).unwrap();
        assert_eq!(size, 12);
        assert_eq!(buf.bytes(), b"Hello World!");
    }

    // Reads with small explicit reserves: each read appends at most `reserve`
    // bytes, and bytes consumed via `advance` no longer show up in `bytes()`.
    #[test]
    fn partial_reading() {
        let mut inp = Cursor::new(b"Hello World!".to_vec());
        let mut buf = InputBuffer::with_capacity(4);
        let size = buf.prepare_reserve(4).read_from(&mut inp).unwrap();
        assert_eq!(size, 4);
        assert_eq!(buf.bytes(), b"Hell");
        // Consume "He".
        buf.advance(2);
        assert_eq!(buf.bytes(), b"ll");
        let size = buf.prepare_reserve(1).read_from(&mut inp).unwrap();
        assert_eq!(size, 1);
        assert_eq!(buf.bytes(), b"llo");
        let size = buf.prepare_reserve(4).read_from(&mut inp).unwrap();
        assert_eq!(size, 4);
        assert_eq!(buf.bytes(), b"llo Wor");
        // The reserve exceeds what is left in the input: only 3 bytes arrive.
        let size = buf.prepare_reserve(16).read_from(&mut inp).unwrap();
        assert_eq!(size, 3);
        assert_eq!(buf.bytes(), b"llo World!");
    }

    // `with_limit` bounds the unconsumed data plus the reserve.
    #[test]
    fn limiting() {
        let mut inp = Cursor::new(b"Hello World!".to_vec());
        let mut buf = InputBuffer::with_capacity(4);
        // Empty buffer: 0 + 4 <= 5, so the read is allowed.
        let size = buf
            .prepare_reserve(4)
            .with_limit(5)
            .unwrap()
            .read_from(&mut inp)
            .unwrap();
        assert_eq!(size, 4);
        assert_eq!(buf.bytes(), b"Hell");
        buf.advance(2);
        assert_eq!(buf.bytes(), b"ll");
        {
            // 2 unconsumed bytes + 4 reserved = 6 > 5: rejected.
            let e = buf.prepare_reserve(4).with_limit(5);
            assert!(e.is_err());
        }
        // After consuming one more byte: 1 unconsumed + 4 reserved = 5 <= 5.
        buf.advance(1);
        let size = buf
            .prepare_reserve(4)
            .with_limit(5)
            .unwrap()
            .read_from(&mut inp)
            .unwrap();
        assert_eq!(size, 4);
        assert_eq!(buf.bytes(), b"lo Wo");
    }
}