Add authenticated endpoint.

This commit is contained in:
Tom Alexander
2024-10-15 23:03:52 -04:00
parent f9d3c551f0
commit 9ea4952327
14 changed files with 501 additions and 29 deletions

View File

@@ -0,0 +1,29 @@
from __future__ import annotations
from pydantic import BaseModel, Field
from typing import Literal, Annotated, Union
class ResponseJwks(BaseModel):
keys: list[ResponseJwk]
class ResponseRsaJwk(BaseModel):
kty: Literal["RSA"]
d: str | None = Field(default=None)
q: str | None = Field(default=None)
qi: str | None = Field(default=None)
dq: str | None = Field(default=None)
e: str | None = Field(default=None)
key_ops: list[str] | None = Field(default=None)
dp: str | None = Field(default=None)
n: str | None = Field(default=None)
p: str | None = Field(default=None)
class ResponseEcJwk(BaseModel):
kty: Literal["EC"]
ResponseJwk = Annotated[
Union[ResponseRsaJwk, ResponseEcJwk], Field(discriminator="kty")
]

View File

@@ -0,0 +1,40 @@
import jwt
import subprocess
from pathlib import Path
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
import time
def main():
gateway_address = get_terraform_output("gateway_address")
client_id = get_terraform_output("client_id")
jwt_private_key = get_terraform_output("jwt_private_key")
private_key = serialization.load_pem_private_key(
jwt_private_key.encode("utf-8"), password=None, backend=default_backend()
)
encoded = jwt.encode(
{
"iss": "issuer of the token",
"sub": "Alice",
"aud": client_id,
"iat": int(time.time()),
"exp": int(time.time()) + 30 * 60 * 60 * 24,
},
private_key,
algorithm="RS256",
)
print(encoded)
def get_terraform_output(name: str) -> str:
terraform_folder = Path(__file__).parent / "../../terraform"
result = subprocess.run(
["terraform", f"-chdir={terraform_folder}", "output", "-raw", name],
stdout=subprocess.PIPE,
)
return result.stdout.decode("utf-8")
if __name__ == "__main__":
main()

View File

@@ -1,4 +1,13 @@
from __future__ import annotations
from fastapi import FastAPI
import jwt.algorithms
import jwt
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
from functools import cache
import os
from api_server.jwks_schema import ResponseJwks, ResponseRsaJwk
import time
app = FastAPI()
@@ -6,3 +15,47 @@ app = FastAPI()
@app.get("/")
async def root():
return {"message": "Hello World"}
@app.get("/some_protected_endpoint")
async def protected_endpoint():
# TODO: Read X-Apigateway-Api-Userinfo header
return {"message": "You reached the protected endpoint"}
@app.get("/get_short_lived_token")
async def get_short_lived_token():
gateway_address = os.environ["JWT_GATEWAY_ADDRESS"]
client_id = os.environ["JWT_CLIENT_ID"]
jwt_private_key = os.environ["JWT_PRIVATE_KEY"]
private_key = serialization.load_pem_private_key(
jwt_private_key.encode("utf-8"), password=None, backend=default_backend()
)
encoded = jwt.encode(
{
"iss": "issuer of the token",
"sub": "Alice",
"aud": client_id,
"iat": int(time.time()),
"exp": int(time.time()) + 5 * 60,
},
private_key,
algorithm="RS256",
)
return {"message": "This JWT will only last 5 minutes.", "token": encoded}
@app.get("/.well-known/jwks.json")
async def jwks() -> ResponseJwks:
return _generate_jwks()
@cache
def _generate_jwks() -> ResponseJwks:
private_key_pem = os.environ["JWT_PRIVATE_KEY"]
private_key = serialization.load_pem_private_key(
private_key_pem.encode("utf-8"), password=None, backend=default_backend()
)
jwk_json = jwt.algorithms.RSAAlgorithm.to_jwk(private_key)
return ResponseJwks(keys=[ResponseRsaJwk.model_validate_json(jwk_json)])