Skip to main content

mas_jose/jwt/
signed.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use base64ct::{Base64UrlUnpadded, Encoding};
8use rand::thread_rng;
9use serde::{Serialize, de::DeserializeOwned};
10use signature::{RandomizedSigner, SignatureEncoding, Verifier, rand_core::CryptoRngCore};
11use thiserror::Error;
12
13use super::{header::JsonWebSignatureHeader, raw::RawJwt};
14use crate::{constraints::ConstraintSet, jwk::PublicJsonWebKeySet};
15
16#[derive(Clone, PartialEq, Eq)]
17pub struct Jwt<'a, T> {
18    raw: RawJwt<'a>,
19    header: JsonWebSignatureHeader,
20    payload: T,
21    signature: Vec<u8>,
22}
23
24impl<T> std::fmt::Display for Jwt<'_, T> {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        write!(f, "{}", self.raw)
27    }
28}
29
30impl<T> std::fmt::Debug for Jwt<'_, T>
31where
32    T: std::fmt::Debug,
33{
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("Jwt")
36            .field("raw", &"...")
37            .field("header", &self.header)
38            .field("payload", &self.payload)
39            .field("signature", &"...")
40            .finish()
41    }
42}
43
44#[derive(Debug, Error)]
45pub enum JwtDecodeError {
46    #[error(transparent)]
47    RawDecode {
48        #[from]
49        inner: super::raw::DecodeError,
50    },
51
52    #[error("failed to decode JWT header")]
53    DecodeHeader {
54        #[source]
55        inner: base64ct::Error,
56    },
57
58    #[error("failed to deserialize JWT header")]
59    DeserializeHeader {
60        #[source]
61        inner: serde_json::Error,
62    },
63
64    #[error("failed to decode JWT payload")]
65    DecodePayload {
66        #[source]
67        inner: base64ct::Error,
68    },
69
70    #[error("failed to deserialize JWT payload")]
71    DeserializePayload {
72        #[source]
73        inner: serde_json::Error,
74    },
75
76    #[error("failed to decode JWT signature")]
77    DecodeSignature {
78        #[source]
79        inner: base64ct::Error,
80    },
81}
82
83impl JwtDecodeError {
84    fn decode_header(inner: base64ct::Error) -> Self {
85        Self::DecodeHeader { inner }
86    }
87
88    fn deserialize_header(inner: serde_json::Error) -> Self {
89        Self::DeserializeHeader { inner }
90    }
91
92    fn decode_payload(inner: base64ct::Error) -> Self {
93        Self::DecodePayload { inner }
94    }
95
96    fn deserialize_payload(inner: serde_json::Error) -> Self {
97        Self::DeserializePayload { inner }
98    }
99
100    fn decode_signature(inner: base64ct::Error) -> Self {
101        Self::DecodeSignature { inner }
102    }
103}
104
105impl<'a, T> TryFrom<RawJwt<'a>> for Jwt<'a, T>
106where
107    T: DeserializeOwned,
108{
109    type Error = JwtDecodeError;
110    fn try_from(raw: RawJwt<'a>) -> Result<Self, Self::Error> {
111        let header_reader =
112            base64ct::Decoder::<'_, Base64UrlUnpadded>::new(raw.header().as_bytes())
113                .map_err(JwtDecodeError::decode_header)?;
114        let header =
115            serde_json::from_reader(header_reader).map_err(JwtDecodeError::deserialize_header)?;
116
117        let payload_reader =
118            base64ct::Decoder::<'_, Base64UrlUnpadded>::new(raw.payload().as_bytes())
119                .map_err(JwtDecodeError::decode_payload)?;
120        let payload =
121            serde_json::from_reader(payload_reader).map_err(JwtDecodeError::deserialize_payload)?;
122
123        let signature = Base64UrlUnpadded::decode_vec(raw.signature())
124            .map_err(JwtDecodeError::decode_signature)?;
125
126        Ok(Self {
127            raw,
128            header,
129            payload,
130            signature,
131        })
132    }
133}
134
135impl<'a, T> TryFrom<&'a str> for Jwt<'a, T>
136where
137    T: DeserializeOwned,
138{
139    type Error = JwtDecodeError;
140    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
141        let raw = RawJwt::try_from(value)?;
142        Self::try_from(raw)
143    }
144}
145
146impl<T> TryFrom<String> for Jwt<'static, T>
147where
148    T: DeserializeOwned,
149{
150    type Error = JwtDecodeError;
151    fn try_from(value: String) -> Result<Self, Self::Error> {
152        let raw = RawJwt::try_from(value)?;
153        Self::try_from(raw)
154    }
155}
156
157#[derive(Debug, Error)]
158pub enum JwtVerificationError {
159    #[error("failed to parse signature")]
160    ParseSignature,
161
162    #[error("signature verification failed")]
163    Verify {
164        #[source]
165        inner: signature::Error,
166    },
167}
168
169impl JwtVerificationError {
170    fn parse_signature<E>(_inner: E) -> Self {
171        Self::ParseSignature
172    }
173
174    fn verify(inner: signature::Error) -> Self {
175        Self::Verify { inner }
176    }
177}
178
179#[derive(Debug, Error, Default)]
180#[error("none of the keys worked")]
181pub struct NoKeyWorked {
182    _inner: (),
183}
184
185impl<'a, T> Jwt<'a, T> {
186    /// Get the JWT header
187    pub fn header(&self) -> &JsonWebSignatureHeader {
188        &self.header
189    }
190
191    /// Get the JWT payload
192    pub fn payload(&self) -> &T {
193        &self.payload
194    }
195
196    pub fn into_owned(self) -> Jwt<'static, T> {
197        Jwt {
198            raw: self.raw.into_owned(),
199            header: self.header,
200            payload: self.payload,
201            signature: self.signature,
202        }
203    }
204
205    /// Verify the signature of this JWT using the given key.
206    ///
207    /// # Errors
208    ///
209    /// Returns an error if the signature is invalid.
210    pub fn verify<K, S>(&self, key: &K) -> Result<(), JwtVerificationError>
211    where
212        K: Verifier<S>,
213        S: SignatureEncoding,
214    {
215        let signature =
216            S::try_from(&self.signature).map_err(JwtVerificationError::parse_signature)?;
217
218        key.verify(self.raw.signed_part().as_bytes(), &signature)
219            .map_err(JwtVerificationError::verify)
220    }
221
222    /// Verify the signature of this JWT using the given symmetric key.
223    ///
224    /// # Errors
225    ///
226    /// Returns an error if the signature is invalid or if the algorithm is not
227    /// supported.
228    pub fn verify_with_shared_secret(&self, secret: Vec<u8>) -> Result<(), NoKeyWorked> {
229        let verifier = crate::jwa::SymmetricKey::new_for_alg(secret, self.header().alg())
230            .map_err(|_| NoKeyWorked::default())?;
231
232        self.verify(&verifier).map_err(|_| NoKeyWorked::default())?;
233
234        Ok(())
235    }
236
237    /// Verify the signature of this JWT using the given JWKS.
238    ///
239    /// # Errors
240    ///
241    /// Returns an error if the signature is invalid, if no key matches the
242    /// constraints, or if the algorithm is not supported.
243    pub fn verify_with_jwks(&self, jwks: &PublicJsonWebKeySet) -> Result<(), NoKeyWorked> {
244        let constraints = ConstraintSet::from(self.header());
245        let candidates = constraints.filter(&**jwks);
246
247        for candidate in candidates {
248            let Ok(key) = crate::jwa::AsymmetricVerifyingKey::from_jwk_and_alg(
249                candidate.params(),
250                self.header().alg(),
251            ) else {
252                continue;
253            };
254
255            if self.verify(&key).is_ok() {
256                return Ok(());
257            }
258        }
259
260        Err(NoKeyWorked::default())
261    }
262
263    /// Get the raw JWT string as a borrowed [`str`]
264    pub fn as_str(&'a self) -> &'a str {
265        &self.raw
266    }
267
268    /// Get the raw JWT string as an owned [`String`]
269    pub fn into_string(self) -> String {
270        self.raw.into()
271    }
272
273    /// Split the JWT into its parts (header and payload).
274    pub fn into_parts(self) -> (JsonWebSignatureHeader, T) {
275        (self.header, self.payload)
276    }
277}
278
279#[derive(Debug, Error)]
280pub enum JwtSignatureError {
281    #[error("failed to serialize header")]
282    EncodeHeader {
283        #[source]
284        inner: serde_json::Error,
285    },
286
287    #[error("failed to serialize payload")]
288    EncodePayload {
289        #[source]
290        inner: serde_json::Error,
291    },
292
293    #[error("failed to sign")]
294    Signature {
295        #[from]
296        inner: signature::Error,
297    },
298}
299
300impl JwtSignatureError {
301    fn encode_header(inner: serde_json::Error) -> Self {
302        Self::EncodeHeader { inner }
303    }
304
305    fn encode_payload(inner: serde_json::Error) -> Self {
306        Self::EncodePayload { inner }
307    }
308}
309
310impl<T> Jwt<'static, T> {
311    /// Sign the given payload with the given key.
312    ///
313    /// # Errors
314    ///
315    /// Returns an error if the payload could not be serialized or if the key
316    /// could not sign the payload.
317    pub fn sign<K, S>(
318        header: JsonWebSignatureHeader,
319        payload: T,
320        key: &K,
321    ) -> Result<Self, JwtSignatureError>
322    where
323        K: RandomizedSigner<S>,
324        S: SignatureEncoding,
325        T: Serialize,
326    {
327        #[expect(clippy::disallowed_methods)]
328        Self::sign_with_rng(&mut thread_rng(), header, payload, key)
329    }
330
331    /// Sign the given payload with the given key using the given RNG.
332    ///
333    /// # Errors
334    ///
335    /// Returns an error if the payload could not be serialized or if the key
336    /// could not sign the payload.
337    pub fn sign_with_rng<R, K, S>(
338        rng: &mut R,
339        header: JsonWebSignatureHeader,
340        payload: T,
341        key: &K,
342    ) -> Result<Self, JwtSignatureError>
343    where
344        R: CryptoRngCore,
345        K: RandomizedSigner<S>,
346        S: SignatureEncoding,
347        T: Serialize,
348    {
349        let header_ = serde_json::to_vec(&header).map_err(JwtSignatureError::encode_header)?;
350        let header_ = Base64UrlUnpadded::encode_string(&header_);
351
352        let payload_ = serde_json::to_vec(&payload).map_err(JwtSignatureError::encode_payload)?;
353        let payload_ = Base64UrlUnpadded::encode_string(&payload_);
354
355        let mut inner = format!("{header_}.{payload_}");
356
357        let first_dot = header_.len();
358        let second_dot = inner.len();
359
360        let signature = key.try_sign_with_rng(rng, inner.as_bytes())?.to_vec();
361        let signature_ = Base64UrlUnpadded::encode_string(&signature);
362        inner.reserve_exact(1 + signature_.len());
363        inner.push('.');
364        inner.push_str(&signature_);
365
366        let raw = RawJwt::new(inner, first_dot, second_dot);
367
368        Ok(Self {
369            raw,
370            header,
371            payload,
372            signature,
373        })
374    }
375}
376
#[cfg(test)]
mod tests {
    #![allow(clippy::disallowed_methods)]
    use mas_iana::jose::JsonWebSignatureAlg;
    use rand::thread_rng;

    use super::*;

    #[test]
    fn test_jwt_decode() {
        let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";
        let jwt: Jwt<'_, serde_json::Value> = Jwt::try_from(jwt).unwrap();
        assert_eq!(jwt.raw.header(), "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9");
        assert_eq!(
            jwt.raw.payload(),
            "eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ"
        );
        assert_eq!(
            jwt.raw.signature(),
            "SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"
        );
        assert_eq!(
            jwt.raw.signed_part(),
            "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ"
        );
    }

    #[test]
    fn test_jwt_sign_and_verify() {
        let header = JsonWebSignatureHeader::new(JsonWebSignatureAlg::Es256);
        let payload = serde_json::json!({"hello": "world"});

        let key = ecdsa::SigningKey::<p256::NistP256>::random(&mut thread_rng());
        let signed = Jwt::sign::<_, ecdsa::Signature<_>>(header, payload, &key).unwrap();
        signed
            .verify::<_, ecdsa::Signature<_>>(key.verifying_key())
            .unwrap();
    }

    /// A signed token must survive being parsed back from its string form,
    /// and the parsed copy must still verify with the signing key.
    #[test]
    fn test_jwt_sign_roundtrip_through_string() {
        let header = JsonWebSignatureHeader::new(JsonWebSignatureAlg::Es256);
        let payload = serde_json::json!({"hello": "world"});

        let key = ecdsa::SigningKey::<p256::NistP256>::random(&mut thread_rng());
        let signed =
            Jwt::sign::<_, ecdsa::Signature<_>>(header, payload.clone(), &key).unwrap();

        let parsed: Jwt<'_, serde_json::Value> = Jwt::try_from(signed.as_str()).unwrap();
        assert_eq!(parsed.payload(), &payload);
        assert_eq!(parsed.as_str(), signed.as_str());
        parsed
            .verify::<_, ecdsa::Signature<_>>(key.verifying_key())
            .unwrap();
    }

    /// Verification with a key other than the signing key must fail with the
    /// `Verify` variant; the signature itself is still well-formed.
    #[test]
    fn test_jwt_verify_rejects_wrong_key() {
        let header = JsonWebSignatureHeader::new(JsonWebSignatureAlg::Es256);
        let payload = serde_json::json!({"hello": "world"});

        let key = ecdsa::SigningKey::<p256::NistP256>::random(&mut thread_rng());
        let other_key = ecdsa::SigningKey::<p256::NistP256>::random(&mut thread_rng());
        let signed = Jwt::sign::<_, ecdsa::Signature<_>>(header, payload, &key).unwrap();

        let result = signed.verify::<_, ecdsa::Signature<_>>(other_key.verifying_key());
        assert!(matches!(result, Err(JwtVerificationError::Verify { .. })));
    }

    #[test]
    fn test_jwt_decode_rejects_malformed_input() {
        // Not a dot-separated token at all.
        assert!(Jwt::<'_, serde_json::Value>::try_from("not-a-jwt").is_err());

        // "bm90IGpzb24" is valid base64url for the bytes `not json`, which is
        // not valid JSON.
        assert!(matches!(
            Jwt::<'_, serde_json::Value>::try_from("bm90IGpzb24.e30.e30"),
            Err(JwtDecodeError::DeserializeHeader { .. })
        ));
    }
}