Add authenticated endpoint.
This commit is contained in:
29
api_server/api_server/jwks_schema.py
Normal file
29
api_server/api_server/jwks_schema.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Literal, Annotated, Union
|
||||
|
||||
|
||||
class ResponseJwks(BaseModel):
|
||||
keys: list[ResponseJwk]
|
||||
|
||||
|
||||
class ResponseRsaJwk(BaseModel):
|
||||
kty: Literal["RSA"]
|
||||
d: str | None = Field(default=None)
|
||||
q: str | None = Field(default=None)
|
||||
qi: str | None = Field(default=None)
|
||||
dq: str | None = Field(default=None)
|
||||
e: str | None = Field(default=None)
|
||||
key_ops: list[str] | None = Field(default=None)
|
||||
dp: str | None = Field(default=None)
|
||||
n: str | None = Field(default=None)
|
||||
p: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ResponseEcJwk(BaseModel):
|
||||
kty: Literal["EC"]
|
||||
|
||||
|
||||
ResponseJwk = Annotated[
|
||||
Union[ResponseRsaJwk, ResponseEcJwk], Field(discriminator="kty")
|
||||
]
|
||||
40
api_server/api_server/local_generate_long_lived_token.py
Normal file
40
api_server/api_server/local_generate_long_lived_token.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import jwt
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
import time
|
||||
|
||||
|
||||
def main():
|
||||
gateway_address = get_terraform_output("gateway_address")
|
||||
client_id = get_terraform_output("client_id")
|
||||
jwt_private_key = get_terraform_output("jwt_private_key")
|
||||
private_key = serialization.load_pem_private_key(
|
||||
jwt_private_key.encode("utf-8"), password=None, backend=default_backend()
|
||||
)
|
||||
encoded = jwt.encode(
|
||||
{
|
||||
"iss": "issuer of the token",
|
||||
"sub": "Alice",
|
||||
"aud": client_id,
|
||||
"iat": int(time.time()),
|
||||
"exp": int(time.time()) + 30 * 60 * 60 * 24,
|
||||
},
|
||||
private_key,
|
||||
algorithm="RS256",
|
||||
)
|
||||
print(encoded)
|
||||
|
||||
|
||||
def get_terraform_output(name: str) -> str:
|
||||
terraform_folder = Path(__file__).parent / "../../terraform"
|
||||
result = subprocess.run(
|
||||
["terraform", f"-chdir={terraform_folder}", "output", "-raw", name],
|
||||
stdout=subprocess.PIPE,
|
||||
)
|
||||
return result.stdout.decode("utf-8")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,4 +1,13 @@
|
||||
from __future__ import annotations
|
||||
from fastapi import FastAPI
|
||||
import jwt.algorithms
|
||||
import jwt
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from functools import cache
|
||||
import os
|
||||
from api_server.jwks_schema import ResponseJwks, ResponseRsaJwk
|
||||
import time
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@@ -6,3 +15,47 @@ app = FastAPI()
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "Hello World"}
|
||||
|
||||
|
||||
@app.get("/some_protected_endpoint")
|
||||
async def protected_endpoint():
|
||||
# TODO: Read X-Apigateway-Api-Userinfo header
|
||||
return {"message": "You reached the protected endpoint"}
|
||||
|
||||
|
||||
@app.get("/get_short_lived_token")
|
||||
async def get_short_lived_token():
|
||||
gateway_address = os.environ["JWT_GATEWAY_ADDRESS"]
|
||||
client_id = os.environ["JWT_CLIENT_ID"]
|
||||
jwt_private_key = os.environ["JWT_PRIVATE_KEY"]
|
||||
private_key = serialization.load_pem_private_key(
|
||||
jwt_private_key.encode("utf-8"), password=None, backend=default_backend()
|
||||
)
|
||||
encoded = jwt.encode(
|
||||
{
|
||||
"iss": "issuer of the token",
|
||||
"sub": "Alice",
|
||||
"aud": client_id,
|
||||
"iat": int(time.time()),
|
||||
"exp": int(time.time()) + 5 * 60,
|
||||
},
|
||||
private_key,
|
||||
algorithm="RS256",
|
||||
)
|
||||
|
||||
return {"message": "This JWT will only last 5 minutes.", "token": encoded}
|
||||
|
||||
|
||||
@app.get("/.well-known/jwks.json")
|
||||
async def jwks() -> ResponseJwks:
|
||||
return _generate_jwks()
|
||||
|
||||
|
||||
@cache
|
||||
def _generate_jwks() -> ResponseJwks:
|
||||
private_key_pem = os.environ["JWT_PRIVATE_KEY"]
|
||||
private_key = serialization.load_pem_private_key(
|
||||
private_key_pem.encode("utf-8"), password=None, backend=default_backend()
|
||||
)
|
||||
jwk_json = jwt.algorithms.RSAAlgorithm.to_jwk(private_key)
|
||||
return ResponseJwks(keys=[ResponseRsaJwk.model_validate_json(jwk_json)])
|
||||
|
||||
Reference in New Issue
Block a user