record/
stream.rs

1use std::io::{Read, Write};
2
3use bertie::tls13utils::*;
4use tracing::{debug, error, trace};
5
6use crate::{
7    debug::{info_record, Hex},
8    AppError,
9};
10
/// Frames TLS records on top of a byte stream.
///
/// `stream` is the underlying transport. `buffer` holds bytes already read
/// from `stream` that have not yet been returned as part of a complete record.
///
/// The struct itself places no bounds on `Stream`; the `Read + Write`
/// requirement lives on the `impl` blocks that actually perform I/O.
#[derive(Debug)]
pub struct RecordStream<Stream> {
    stream: Stream,
    buffer: Vec<u8>,
}
19
20impl<Stream> RecordStream<Stream>
21where
22    Stream: Read + Write,
23{
24    pub fn new(stream: Stream) -> Self {
25        Self {
26            stream,
27            buffer: Vec::new(),
28        }
29    }
30}
31
32impl<Stream> RecordStream<Stream>
33where
34    Stream: Read + Write,
35{
36    #[tracing::instrument(skip(self))]
37    pub fn read_record(&mut self) -> Result<Bytes, AppError> {
38        // Buffer to read chunks into.
39        let mut tmp = [0u8; 4096];
40
41        // ```TLS
42        // struct {
43        //     ContentType type;
44        //     ProtocolVersion legacy_record_version;
45        //     uint16 length;
46        //     opaque fragment[TLSPlaintext.length];
47        // } TLSPlaintext;
48        // ```
49        loop {
50            debug!("Search for TLS record in stream buffer.");
51            trace!(buffer = %Hex(&self.buffer), "Buffered data");
52
53            if self.buffer.len() >= 5 {
54                let length = self.buffer[3] as usize * 256 + self.buffer[4] as usize;
55
56                // // TODO: Who does this?
57                // // The length (in bytes) of the following TLSPlaintext.fragment. The length MUST NOT
58                // // exceed 2^14 bytes. An endpoint that receives a record that exceeds this length
59                // // MUST terminate the connection with a "record_overflow" alert.
60                // if length > 16384 {
61                //     // TODO: Correct error?
62                //     panic!("payload has length {}", length);
63                //     return Err(PAYLOAD_TOO_LONG.into());
64                // }
65
66                if self.buffer.len() >= 5 + length {
67                    let record = {
68                        let record = &self.buffer[..5 + length];
69                        info_record(record);
70                        Bytes::from(record)
71                    };
72
73                    self.buffer = self.buffer.split_off(5 + length);
74
75                    if !self.buffer.is_empty() {
76                        debug!("There is still data in the stream buffer.");
77                        trace!(
78                            left = %Hex(&self.buffer),
79                            "There is still data in the stream buffer (content)."
80                        );
81                    }
82
83                    return Ok(record);
84                }
85            }
86
87            debug!(
88                buffer=%Hex(&self.buffer),"No complete TLS record found in stream buffer."
89            );
90            match self.stream.read(&mut tmp) {
91                Ok(l) => match l {
92                    0 => {
93                        error!("Connection closed.");
94                        // TODO: Correct error?
95                        return Err(INSUFFICIENT_DATA.into());
96                    }
97                    amt => {
98                        eprintln!("Read {}", amt);
99                        let data = &tmp[..amt];
100
101                        debug!(amt, "Read data into stream buffer.");
102                        trace!(data=%Hex(data), "Read data into stream buffer (content).");
103
104                        self.buffer.extend_from_slice(data);
105                    }
106                },
107                Err(e) => {
108                    error!("Reading from stream failed with {}", e);
109                    return Err(e.into());
110                }
111            }
112        }
113    }
114
115    #[tracing::instrument(skip(self, record))]
116    pub fn write_record(&mut self, record: Bytes) -> Result<(), AppError> {
117        let data = record.declassify();
118        self.stream.write_all(&data)?;
119
120        debug!(amt = data.len(), "Wrote data.");
121        trace!(data=%Hex(&data), "Wrote data (content).");
122        info_record(&data);
123
124        Ok(())
125    }
126}