More validation

master
D4VID 1 year ago
parent c48b921c78
commit 8a270a7486

@ -14,8 +14,12 @@ public class OAuthController(
JwtService jwt, JwtService jwt,
IDataProtectionProvider dataProtectionProvider IDataProtectionProvider dataProtectionProvider
) : ControllerBase { ) : ControllerBase {
private readonly Dictionary<string, string> _clients = new() { private record Client(string ClientId, string ClientSecret, string ClientOrigin, string ClientScope);
{"lmao", "yeet"},
private record CodeObject(string ClientId, string RedirectUri, DateTime Expiry);
private readonly Dictionary<string, Client> _clients = new() {
{"lmao", new Client("lmao", "yeet", "http://localhost:5126", "41")},
}; };
@ -28,26 +32,35 @@ public class OAuthController(
string? client_id, string? client_id,
string? state string? state
) { ) {
// Check if the required fields are present
if (string.IsNullOrEmpty(response_type) || string.IsNullOrEmpty(client_id) || string.IsNullOrEmpty(state)) { if (string.IsNullOrEmpty(response_type) || string.IsNullOrEmpty(client_id) || string.IsNullOrEmpty(state)) {
return Redirect($"{redirect_uri}?error=invalid_request"); return Redirect($"{redirect_uri}?error=invalid_request");
} }
// The only supported option is "code"
if (response_type != "code") { if (response_type != "code") {
return Redirect($"{redirect_uri}?error=unsupported_response_type&state={state}"); return Redirect($"{redirect_uri}?error=unsupported_response_type&state={state}");
} }
if (!_clients.ContainsKey(client_id)) { // Check if the client exists
if (!_clients.TryGetValue(client_id, out Client? client)) {
logger.LogInformation("Unknown client id"); logger.LogInformation("Unknown client id");
return Redirect($"{redirect_uri}?error=unauthorized_client&state={state}"); return Redirect($"{redirect_uri}?error=unauthorized_client&state={state}");
} }
// Check if the origin matches the pre-registered one
string origin = GetOrigin(redirect_uri);
if (origin != client.ClientOrigin) {
return Redirect($"{redirect_uri}?error=unauthorized_client&state={state}");
}
// Turn the client info with an expiration to a opaque value using the data protection api
IDataProtector protector = dataProtectionProvider.CreateProtector("oauth"); IDataProtector protector = dataProtectionProvider.CreateProtector("oauth");
CodeObject codeObject = new CodeObject( CodeObject codeObject = new CodeObject(
ClientId: client_id, ClientId: client_id,
RedirectUri: redirect_uri, RedirectUri: redirect_uri,
Expiry: DateTime.UtcNow.AddMinutes(5) Expiry: DateTime.UtcNow.AddMinutes(5)
); );
string code = protector.Protect(JsonSerializer.Serialize(codeObject)); string code = protector.Protect(JsonSerializer.Serialize(codeObject));
return Redirect($"{redirect_uri}?code={code}&state={state}"); return Redirect($"{redirect_uri}?code={code}&state={state}");
@ -64,6 +77,7 @@ public class OAuthController(
[HttpPost("token")] [HttpPost("token")]
[Consumes("application/x-www-form-urlencoded")] [Consumes("application/x-www-form-urlencoded")]
public ActionResult GenerateToken([FromForm] GenerateTokenRequest request) { public ActionResult GenerateToken([FromForm] GenerateTokenRequest request) {
// Check if all the required fields are present
if (string.IsNullOrEmpty(request.grant_type) || if (string.IsNullOrEmpty(request.grant_type) ||
string.IsNullOrEmpty(request.code) || string.IsNullOrEmpty(request.code) ||
string.IsNullOrEmpty(request.redirect_uri) || string.IsNullOrEmpty(request.redirect_uri) ||
@ -72,20 +86,24 @@ public class OAuthController(
return BadRequest(new {error = "invalid_request"}); return BadRequest(new {error = "invalid_request"});
} }
// The only supported option is "authorization_code"
if (request.grant_type != "authorization_code") { if (request.grant_type != "authorization_code") {
return BadRequest(new {error = "unsupported_grant_type"}); return BadRequest(new {error = "unsupported_grant_type"});
} }
if (!_clients.TryGetValue(request.client_id, out string? clientSecret)) { // Check if the client exists
if (!_clients.TryGetValue(request.client_id, out Client? client)) {
logger.LogInformation("Unknown client id"); logger.LogInformation("Unknown client id");
return BadRequest(new {error = "unauthorized_client"}); return BadRequest(new {error = "unauthorized_client"});
} }
if (request.client_secret != clientSecret) { // Check if the client secret matches
if (request.client_secret != client.ClientSecret) {
logger.LogInformation("Invalid client secret"); logger.LogInformation("Invalid client secret");
return BadRequest(new {error = "unauthorized_client"}); return BadRequest(new {error = "unauthorized_client"});
} }
// Retrieve client information set by the preceding authorization step
IDataProtector protector = dataProtectionProvider.CreateProtector("oauth"); IDataProtector protector = dataProtectionProvider.CreateProtector("oauth");
CodeObject? codeObject; CodeObject? codeObject;
try { try {
@ -98,22 +116,33 @@ public class OAuthController(
return BadRequest(new {error = "invalid_request"}); return BadRequest(new {error = "invalid_request"});
} }
// Check if the values are consistent
if (codeObject.ClientId != request.client_id || codeObject.RedirectUri != request.redirect_uri) { if (codeObject.ClientId != request.client_id || codeObject.RedirectUri != request.redirect_uri) {
return BadRequest(new {error = "invalid_request"}); return BadRequest(new {error = "invalid_request"});
} }
// Check the token's expiration
if (DateTime.UtcNow > codeObject.Expiry) { if (DateTime.UtcNow > codeObject.Expiry) {
logger.LogInformation("Expired token"); logger.LogInformation("Expired token");
return BadRequest(new {error = "invalid_grant"}); return BadRequest(new {error = "invalid_grant"});
} }
string token = jwt.GenerateToken(); // Generate the auth token for the application server
string token = jwt.GenerateToken("1", client.ClientId, "External", client.ClientScope);
// Add http headers to prevent caching
Response.Headers.Append("Cache-Control", "no-store"); Response.Headers.Append("Cache-Control", "no-store");
Response.Headers.Append("Pragma", "no-cache"); Response.Headers.Append("Pragma", "no-cache");
return Ok(new {access_token = token, token_type = "bearer"}); return Ok(new {
access_token = token,
token_type = "bearer",
scope = client.ClientScope
});
} }
private record CodeObject(string ClientId, string RedirectUri, DateTime Expiry); private static string GetOrigin(string url) {
Uri uri = new Uri(url);
return $"{uri.Scheme}://{uri.Authority}";
}
} }

@ -31,14 +31,15 @@ public class JwtService {
return rsaKey; return rsaKey;
} }
public string GenerateToken() { public string GenerateToken(string userId, string clientId, string role, string scope) {
var handler = new JsonWebTokenHandler(); var handler = new JsonWebTokenHandler();
var key = new RsaSecurityKey(_rsaKey); var key = new RsaSecurityKey(_rsaKey);
var token = handler.CreateToken(new SecurityTokenDescriptor { var token = handler.CreateToken(new SecurityTokenDescriptor {
Subject = new ClaimsIdentity(new[] { Subject = new ClaimsIdentity(new[] {
new Claim(JwtRegisteredClaimNames.Sub, "user1"), new Claim(JwtRegisteredClaimNames.Sub, userId),
new Claim("role", "External"), new Claim("client", clientId),
new Claim("scope", "scope:1") new Claim("role", role),
new Claim("scope", scope)
}), }),
Expires = DateTime.UtcNow.AddDays(10), Expires = DateTime.UtcNow.AddDays(10),
SigningCredentials = new SigningCredentials(key, SecurityAlgorithms.RsaSha256) SigningCredentials = new SigningCredentials(key, SecurityAlgorithms.RsaSha256)

Loading…
Cancel
Save