optee_utee/
extension.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::{Error, ErrorKind, Result, Uuid};
19#[cfg(not(feature = "std"))]
20use alloc::{borrow::ToOwned, vec::Vec};
21use optee_utee_sys as raw;
22
23pub struct LoadablePlugin {
24    uuid: Uuid,
25}
26
27pub struct LoadablePluginCommand<'a> {
28    plugin: &'a LoadablePlugin,
29    cmd_id: u32,
30    sub_cmd_id: u32,
31    buffer: Vec<u8>,
32}
33
34impl LoadablePlugin {
35    pub fn new(uuid: &Uuid) -> Self {
36        Self {
37            uuid: uuid.to_owned(),
38        }
39    }
40    /// Invoke plugin with given request data, use when you want to post something into REE.
41    /// ``` rust,no_run
42    /// # use optee_utee::{LoadablePlugin, Uuid};
43    /// # fn main() -> optee_utee::Result<()> {
44    /// # let uuid = Uuid::parse_str("").unwrap();
45    /// # let command_id = 0;
46    /// # let subcommand_id = 0;
47    /// # let request_data = [0_u8; 0];
48    /// let plugin = LoadablePlugin::new(&uuid);
49    /// let result = plugin.invoke(command_id, subcommand_id, &request_data)?;
50    /// # Ok(())
51    /// # }
52    /// ```
53    /// Caution: the size of the shared buffer is set to the len of data, you could get a
54    ///          ShortBuffer error if Plugin return more data than shared buffer, in that case,
55    ///          use invoke_with_capacity and set the capacity manually.
56    pub fn invoke(&self, command_id: u32, subcommand_id: u32, data: &[u8]) -> Result<Vec<u8>> {
57        self.invoke_with_capacity(command_id, subcommand_id, data.len())
58            .chain_write_body(data)
59            .call()
60    }
61    /// Construct a command with shared buffer up to capacity size, write the buffer and call it
62    /// manually, use when you need to control details of the invoking process.
63    /// ```no_run
64    /// # use optee_utee::{Uuid, LoadablePlugin};
65    /// # fn main() -> optee_utee::Result<()> {
66    /// # let plugin = LoadablePlugin::new(&Uuid::parse_str("").unwrap());
67    /// # let request_data = [0_u8; 0];
68    /// # let command_id = 0;
69    /// # let sub_command_id = 0;
70    /// # let capacity = 0;
71    /// let mut cmd = plugin.invoke_with_capacity(command_id, sub_command_id, capacity);
72    /// cmd.write_body(&request_data);
73    /// let result = cmd.call()?;
74    /// # Ok(())
75    /// # }
76    /// ```
77    /// You can also imply a wrapper for performance, for example, imply a std::io::Write so
78    /// serde_json can write to the buffer directly.
79    /// ```no_run
80    /// # use optee_utee::{LoadablePluginCommand, Uuid, LoadablePlugin, trace_println};
81    /// # use optee_utee::ErrorKind;
82    /// # fn main() -> optee_utee::Result<()> {
83    /// # let command_id = 0;
84    /// # let subcommand_id = 0;
85    /// # let capacity = 0;
86    /// # let plugin = LoadablePlugin::new(&Uuid::parse_str("").unwrap());
87    /// struct Wrapper<'a, 'b>(&'b mut LoadablePluginCommand<'a>);
88    /// impl<'a, 'b> std::io::Write for Wrapper<'a, 'b> {
89    ///     fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
90    ///         self.0.write_body(buf);
91    ///         Ok(buf.len())
92    ///     }
93    ///     fn flush(&mut self) -> std::io::Result<()> {
94    ///         Ok(())
95    ///     }
96    /// }
97    /// // serialize data into command directly
98    /// let request_data = serde_json::json!({
99    ///     "age": 100,
100    ///     "name": "name"
101    /// });
102    /// let mut cmd = plugin.invoke_with_capacity(command_id, subcommand_id, capacity);
103    /// serde_json::to_writer(Wrapper(&mut cmd), &request_data).map_err(|err| {
104    ///     trace_println!("serde error: {:?}", err);
105    ///     ErrorKind::Unknown
106    /// })?;
107    /// let result = cmd.call()?;
108    ///
109    /// # Ok(())
110    /// # }
111    /// ```
112    /// Notice: the shared buffer could grow to fit the request data automatically.
113    pub fn invoke_with_capacity(
114        &self,
115        command_id: u32,
116        subcommand_id: u32,
117        capacity: usize,
118    ) -> LoadablePluginCommand<'_> {
119        LoadablePluginCommand::new_with_capacity(self, command_id, subcommand_id, capacity)
120    }
121}
122
123impl<'a> LoadablePluginCommand<'a> {
124    // use this to write request body if needed
125    pub fn write_body(&mut self, data: &[u8]) {
126        self.buffer.extend_from_slice(data);
127    }
128    // same with write_body, but chainable
129    pub fn chain_write_body(mut self, data: &[u8]) -> Self {
130        self.write_body(data);
131        self
132    }
133    // invoke the command, and get result from it
134    pub fn call(self) -> Result<Vec<u8>> {
135        let mut outlen: usize = 0;
136        let mut buffer = self.buffer;
137        buffer.resize(buffer.capacity(), 0); // resize to capacity first
138        match unsafe {
139            raw::tee_invoke_supp_plugin(
140                self.plugin.uuid.as_raw_ptr(),
141                self.cmd_id,
142                self.sub_cmd_id,
143                // convert the pointer manually, as in some platform c_char is i8
144                buffer.as_mut_slice().as_mut_ptr() as *mut _,
145                buffer.len(),
146                &mut outlen as *mut usize,
147            )
148        } {
149            raw::TEE_SUCCESS => {
150                if outlen > buffer.len() {
151                    return Err(ErrorKind::ShortBuffer.into());
152                }
153                buffer.resize(outlen, 0);
154                Ok(buffer)
155            }
156            code => Err(Error::from_raw_error(code)),
157        }
158    }
159}
160
161impl<'a> LoadablePluginCommand<'a> {
162    fn new_with_capacity(
163        plugin: &'a LoadablePlugin,
164        cmd_id: u32,
165        sub_cmd_id: u32,
166        capacity: usize,
167    ) -> Self {
168        Self {
169            plugin,
170            cmd_id,
171            sub_cmd_id,
172            buffer: Vec::with_capacity(capacity),
173        }
174    }
175}
176
#[cfg(test)]
pub mod test_loadable_plugin {
    extern crate std;
    use super::*;
    use alloc::string::ToString;
    use optee_utee_sys::{mock_api, mock_utils::SERIAL_TEST_LOCK};
    use rand::distr::Alphanumeric;

    /// Returns `len` random alphanumeric bytes.
    fn generate_random_bytes(len: usize) -> Vec<u8> {
        use rand::RngExt;

        rand::rng().sample_iter(&Alphanumeric).take(len).collect()
    }

    /// Returns a random (cmd, sub_cmd, request, response) tuple with the given sizes.
    fn generate_test_pairs(
        request_size: usize,
        response_size: usize,
    ) -> (u32, u32, Vec<u8>, Vec<u8>) {
        let cmd: u32 = rand::random();
        let sub_cmd: u32 = rand::random();
        let random_request: Vec<u8> = generate_random_bytes(request_size);
        let random_response: Vec<u8> = generate_random_bytes(response_size);
        (cmd, sub_cmd, random_request, random_response)
    }

    fn random_uuid() -> Uuid {
        Uuid::new_raw(
            rand::random(),
            rand::random(),
            rand::random(),
            rand::random(),
        )
    }

    /// Installs a one-shot mock for `tee_invoke_supp_plugin`.
    ///
    /// The mock checks the incoming uuid/cmd/sub_cmd/request, writes `exp_response` into the
    /// shared buffer, reports its length and returns success.
    fn expect_success_request(
        ctx: &mock_api::extension::__tee_invoke_supp_plugin::Context,
        exp_uuid: &Uuid,
        exp_cmd: u32,
        exp_sub_cmd: u32,
        exp_request: &[u8],
        exp_response: &[u8],
    ) {
        let exp_request = exp_request.to_vec();
        let exp_response = exp_response.to_vec();
        let exp_uuid = exp_uuid.to_string();
        ctx.expect()
            .return_once_st(move |uuid, cmd, sub_cmd, buf, len, outlen| {
                let request_uuid = Uuid::from(unsafe { *uuid }).to_string();
                assert_eq!(exp_uuid, request_uuid);
                assert_eq!(cmd, exp_cmd);
                assert_eq!(sub_cmd, exp_sub_cmd);
                // Check the sizes before touching the buffer, so a mismatch fails the test
                // instead of reading or writing out of bounds.
                assert!(len >= exp_request.len());
                assert!(len >= exp_response.len());
                // An empty shared buffer may arrive as NULL, which `from_raw_parts_mut` does
                // not accept, so only build a slice when there is something to view.
                let buffer: &mut [u8] = if len == 0 {
                    &mut []
                } else {
                    unsafe { core::slice::from_raw_parts_mut(buf as *mut u8, len) }
                };
                assert_eq!(&buffer[..exp_request.len()], exp_request.as_slice());
                buffer[..exp_response.len()].copy_from_slice(&exp_response);
                unsafe { *outlen = exp_response.len() };
                raw::TEE_SUCCESS
            });
    }

    #[test]
    fn test_invoke() {
        let _lock = SERIAL_TEST_LOCK.lock().expect("should get the lock");

        let uuid: Uuid = random_uuid();
        let plugin = LoadablePlugin::new(&uuid);
        const REQUEST_LEN: usize = 32;
        let run_test = |request_size: usize, response_size: usize| {
            let (cmd, sub_cmd, request, exp_response) =
                generate_test_pairs(request_size, response_size);
            let fn1 = mock_api::extension::tee_invoke_supp_plugin_context();
            expect_success_request(&fn1, &uuid, cmd, sub_cmd, &request, &exp_response);
            let response = plugin.invoke(cmd, sub_cmd, &request).expect("should be ok");
            std::println!("*TA*: response is {:?}", response);
            assert_eq!(response, exp_response);
        };

        // test calling with output size less than input
        run_test(REQUEST_LEN, REQUEST_LEN / 2);
        // test calling with output size equals to input
        run_test(REQUEST_LEN, REQUEST_LEN);
        // test calling with output size greater than input.
        // Mark: Without explicitly setting the response size, this function
        // must not be called with a response size larger than the request size.
        {
            let (cmd, sub_cmd, request, exp_response) =
                generate_test_pairs(REQUEST_LEN, 2 * REQUEST_LEN);
            let fn1 = mock_api::extension::tee_invoke_supp_plugin_context();
            fn1.expect().return_once_st(move |_, _, _, _, _, outlen| {
                unsafe { *outlen = exp_response.len() };
                raw::TEE_SUCCESS
            });
            let err = plugin
                .invoke(cmd, sub_cmd, &request)
                .expect_err("should be err");
            assert_eq!(err.kind(), ErrorKind::ShortBuffer);
        }
    }

    // This test is equivalent to test_invoke, with the added verification that
    // capacity permits the response size to be larger than the request.
    #[test]
    fn test_invoke_with_capacity() {
        let _lock = SERIAL_TEST_LOCK.lock().expect("should get the lock");
        let uuid: Uuid = random_uuid();
        let plugin = LoadablePlugin::new(&uuid);
        const RESPONSE_LEN: usize = 32;

        let run_test = |request_size: usize, response_size: usize| {
            let (cmd, sub_cmd, request, exp_response) =
                generate_test_pairs(request_size, response_size);
            let fn1 = mock_api::extension::tee_invoke_supp_plugin_context();
            expect_success_request(&fn1, &uuid, cmd, sub_cmd, &request, &exp_response);

            let response = plugin
                .invoke_with_capacity(cmd, sub_cmd, exp_response.len())
                .chain_write_body(&request)
                .call()
                .unwrap();
            std::println!("*TA*: response is {:?}", response);
            assert_eq!(response, exp_response);
        };

        // test calling with output size less than input
        run_test(2 * RESPONSE_LEN, RESPONSE_LEN);
        // test calling with output size equals to input
        run_test(RESPONSE_LEN, RESPONSE_LEN);
        // test calling with output size greater than input
        run_test(RESPONSE_LEN / 2, RESPONSE_LEN);
    }

    #[test]
    fn test_invoke_with_writer() {
        let _lock = SERIAL_TEST_LOCK.lock().expect("should get the lock");
        let uuid: Uuid = random_uuid();
        let plugin = LoadablePlugin::new(&uuid);
        let fn1 = mock_api::extension::tee_invoke_supp_plugin_context();
        // impl a writer for Command
        struct Wrapper<'a, 'b>(&'b mut LoadablePluginCommand<'a>);
        impl<'a, 'b> std::io::Write for Wrapper<'a, 'b> {
            fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
                self.0.write_body(buf);
                Ok(buf.len())
            }
            fn flush(&mut self) -> std::io::Result<()> {
                Ok(())
            }
        }
        // serialize data into command directly
        let test_data = serde_json::json!({
            "code": 100,
            "message": "error"
        });
        let mut exp_request = serde_json::to_vec(&test_data).unwrap();
        let buffer_len = exp_request.len() * 2;
        let (cmd, sub_cmd, _, exp_response) = generate_test_pairs(0, buffer_len);
        let mut plugin_cmd = plugin.invoke_with_capacity(cmd, sub_cmd, buffer_len);
        // The plugin sees the serialized request followed by zero padding up to capacity.
        exp_request.resize(exp_response.len(), 0);

        expect_success_request(&fn1, &uuid, cmd, sub_cmd, &exp_request, &exp_response);
        serde_json::to_writer(Wrapper(&mut plugin_cmd), &test_data).unwrap();
        let response = plugin_cmd.call().unwrap();
        std::println!("*TA*: response is {:?}", response);
        assert_eq!(response, exp_response);
    }

    #[test]
    fn test_invoke_with_no_data() {
        let _lock = SERIAL_TEST_LOCK.lock().expect("should get the lock");

        let uuid: Uuid = random_uuid();
        let plugin = LoadablePlugin::new(&uuid);
        let fn1 = mock_api::extension::tee_invoke_supp_plugin_context();
        const OUTPUT_LEN: usize = 50;
        let (cmd, sub_cmd, request, exp_response) = generate_test_pairs(0, OUTPUT_LEN);
        expect_success_request(&fn1, &uuid, cmd, sub_cmd, &request, &exp_response);

        let response = plugin
            .invoke_with_capacity(cmd, sub_cmd, OUTPUT_LEN)
            .call()
            .unwrap();
        std::println!("*TA*: response is {:?}", response);
        assert_eq!(response, exp_response);
    }
}