1use std::collections::HashSet;
8
9use mas_iana::jose::{JsonWebKeyType, JsonWebKeyUse, JsonWebSignatureAlg};
10
11use crate::jwt::JsonWebSignatureHeader;
12
13#[derive(Debug, Clone, PartialEq, Eq, Hash)]
14pub enum Constraint<'a> {
15 Alg {
16 constraint_alg: &'a JsonWebSignatureAlg,
17 },
18
19 Algs {
20 constraint_algs: &'a [JsonWebSignatureAlg],
21 },
22
23 Kid {
24 constraint_kid: &'a str,
25 },
26
27 Use {
28 constraint_use: &'a JsonWebKeyUse,
29 },
30
31 Kty {
32 constraint_kty: &'a JsonWebKeyType,
33 },
34}
35
36impl<'a> Constraint<'a> {
37 #[must_use]
38 pub fn alg(constraint_alg: &'a JsonWebSignatureAlg) -> Self {
39 Constraint::Alg { constraint_alg }
40 }
41
42 #[must_use]
43 pub fn algs(constraint_algs: &'a [JsonWebSignatureAlg]) -> Self {
44 Constraint::Algs { constraint_algs }
45 }
46
47 #[must_use]
48 pub fn kid(constraint_kid: &'a str) -> Self {
49 Constraint::Kid { constraint_kid }
50 }
51
52 #[must_use]
53 pub fn use_(constraint_use: &'a JsonWebKeyUse) -> Self {
54 Constraint::Use { constraint_use }
55 }
56
57 #[must_use]
58 pub fn kty(constraint_kty: &'a JsonWebKeyType) -> Self {
59 Constraint::Kty { constraint_kty }
60 }
61}
62
/// Verdict produced when one [`Constraint`] is evaluated against one candidate.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub enum ConstraintDecision {
    /// The candidate explicitly satisfies the constraint.
    Positive,
    /// The candidate neither satisfies nor violates the constraint explicitly.
    Neutral,
    /// The candidate violates the constraint.
    Negative,
}
69
70pub trait Constrainable {
71 fn alg(&self) -> Option<&JsonWebSignatureAlg> {
72 None
73 }
74
75 fn algs(&self) -> &[JsonWebSignatureAlg] {
77 &[]
78 }
79
80 fn kid(&self) -> Option<&str> {
82 None
83 }
84
85 fn use_(&self) -> Option<&JsonWebKeyUse> {
87 None
88 }
89
90 fn kty(&self) -> JsonWebKeyType;
92}
93
94impl Constraint<'_> {
95 fn decide<T: Constrainable>(&self, constrainable: &T) -> ConstraintDecision {
96 match self {
97 Constraint::Alg { constraint_alg } => {
98 if let Some(alg) = constrainable.alg() {
100 if alg == *constraint_alg {
101 ConstraintDecision::Positive
102 } else {
103 ConstraintDecision::Negative
104 }
105 } else if constrainable.algs().contains(constraint_alg) {
108 ConstraintDecision::Neutral
109 } else {
110 ConstraintDecision::Negative
111 }
112 }
113 Constraint::Algs { constraint_algs } => {
114 if let Some(alg) = constrainable.alg() {
115 if constraint_algs.contains(alg) {
116 ConstraintDecision::Positive
117 } else {
118 ConstraintDecision::Negative
119 }
120 } else if constrainable
121 .algs()
122 .iter()
123 .any(|alg| constraint_algs.contains(alg))
124 {
125 ConstraintDecision::Neutral
126 } else {
127 ConstraintDecision::Negative
128 }
129 }
130 Constraint::Kid { constraint_kid } => {
131 if let Some(kid) = constrainable.kid() {
132 if kid == *constraint_kid {
133 ConstraintDecision::Positive
134 } else {
135 ConstraintDecision::Negative
136 }
137 } else {
138 ConstraintDecision::Neutral
139 }
140 }
141 Constraint::Use { constraint_use } => {
142 if let Some(use_) = constrainable.use_() {
143 if use_ == *constraint_use {
144 ConstraintDecision::Positive
145 } else {
146 ConstraintDecision::Negative
147 }
148 } else {
149 ConstraintDecision::Neutral
150 }
151 }
152 Constraint::Kty { constraint_kty } => {
153 if **constraint_kty == constrainable.kty() {
154 ConstraintDecision::Positive
155 } else {
156 ConstraintDecision::Negative
157 }
158 }
159 }
160 }
161}
162
163#[derive(Default)]
164pub struct ConstraintSet<'a> {
165 constraints: HashSet<Constraint<'a>>,
166}
167
168impl<'a> FromIterator<Constraint<'a>> for ConstraintSet<'a> {
169 fn from_iter<T: IntoIterator<Item = Constraint<'a>>>(iter: T) -> Self {
170 Self {
171 constraints: HashSet::from_iter(iter),
172 }
173 }
174}
175
176impl<'a> ConstraintSet<'a> {
177 pub fn new(constraints: impl IntoIterator<Item = Constraint<'a>>) -> Self {
178 constraints.into_iter().collect()
179 }
180
181 pub fn filter<'b, T: Constrainable, I: IntoIterator<Item = &'b T>>(
182 &self,
183 constrainables: I,
184 ) -> Vec<&'b T> {
185 let mut selected = Vec::new();
186
187 'outer: for constrainable in constrainables {
188 let mut score = 0;
189
190 for constraint in &self.constraints {
191 match constraint.decide(constrainable) {
192 ConstraintDecision::Positive => score += 1,
193 ConstraintDecision::Neutral => {}
194 ConstraintDecision::Negative => continue 'outer,
196 }
197 }
198
199 selected.push((score, constrainable));
200 }
201
202 selected.sort_by_key(|(score, _)| *score);
203
204 selected
205 .into_iter()
206 .map(|(_score, constrainable)| constrainable)
207 .collect()
208 }
209
210 #[must_use]
211 pub fn alg(mut self, constraint_alg: &'a JsonWebSignatureAlg) -> Self {
212 self.constraints.insert(Constraint::alg(constraint_alg));
213 self
214 }
215
216 #[must_use]
217 pub fn algs(mut self, constraint_algs: &'a [JsonWebSignatureAlg]) -> Self {
218 self.constraints.insert(Constraint::algs(constraint_algs));
219 self
220 }
221
222 #[must_use]
223 pub fn kid(mut self, constraint_kid: &'a str) -> Self {
224 self.constraints.insert(Constraint::kid(constraint_kid));
225 self
226 }
227
228 #[must_use]
229 pub fn use_(mut self, constraint_use: &'a JsonWebKeyUse) -> Self {
230 self.constraints.insert(Constraint::use_(constraint_use));
231 self
232 }
233
234 #[must_use]
235 pub fn kty(mut self, constraint_kty: &'a JsonWebKeyType) -> Self {
236 self.constraints.insert(Constraint::kty(constraint_kty));
237 self
238 }
239}
240
241impl<'a> From<&'a JsonWebSignatureHeader> for ConstraintSet<'a> {
242 fn from(header: &'a JsonWebSignatureHeader) -> Self {
243 let mut constraints = Self::default().alg(header.alg());
244
245 if let Some(kid) = header.kid() {
246 constraints = constraints.kid(kid);
247 }
248
249 constraints
250 }
251}