Files
SlowAPI/slow/slow.py
2026-01-20 22:19:57 +03:30

330 lines
10 KiB
Python

import asyncio
import functools
import json
import re
import urllib.parse
from json.decoder import JSONDecodeError
from pathlib import Path
from typing import Any, Awaitable, Callable, Optional

from .responses import HTTPResponse, JSONResponse, Response
# Matches a `<name>` placeholder in a route template; group 1 captures the
# identifier that becomes the named group in the compiled route regex.
PR = re.compile(r"\<([a-zA-Z_][a-zA-Z0-9_]*)\>")
class CORS:
    """Cross-origin settings read by the request handler.

    Attributes:
        Origins: Origin header values that are allowed to receive CORS headers.
        Methods: HTTP methods advertised in ``Access-Control-Allow-Methods``.
        Disabled: Flag checked before CORS headers are added to a response.
    """

    Origins: set[str]
    Methods: set[str]
    Disabled: bool

    def __init__(self):
        self.Disabled = False
        # `{}` is an empty *dict*; an empty set must be spelled `set()` so
        # that callers can do `app.CORS.Origins.add(...)`.
        self.Origins: set[str] = set()
        self.Methods: set[str] = {"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"}
class Headers:
def __init__(self):
self._d: dict[str, str] = {}
def get(self, key: str, default: Optional[Any] = None) -> str | Any:
return self._d.get(key.lower(), default)
def set(self, key: str, value: str) -> None:
self._d[key.lower()] = value
def __str__(self):
return str(self._d)
def __contains__(self, key):
return self._d.__contains__(key)
class Request:
def __init__(self, method: str, path: str, headers: Headers, body: bytes):
self.method = method
self.path = path
self.headers = headers
self.body = body
def __str__(self):
return str(
{
"method": self.method,
"path": self.path,
"headers": self.headers,
"body": self.body,
}
)
def json(self) -> dict | None:
try:
return json.loads(self.body)
except JSONDecodeError:
return None
async def _default_404_route(request: Request):
    """Fallback handler used when no registered route matches the path.

    Returns a ``Response`` object rather than pre-rendered bytes: the
    connection handler appends headers to, and calls ``render`` on, whatever
    a route returns, which raw bytes cannot support.
    """
    # NOTE(review): assumes Response(status, headers, body) renders the
    # matching status line for 404 — confirm against responses.Response.
    return Response(
        404,
        ["Content-Type: text/html"],
        "404 Not Found contact admin".encode(encoding="utf-8"),
    )
async def _default_405_route(request: Request):
    """Fallback handler used when a path matches but the method is not registered.

    Returns a ``Response`` object rather than pre-rendered bytes: the
    connection handler appends headers to, and calls ``render`` on, whatever
    a route returns, which raw bytes cannot support.
    """
    # NOTE(review): assumes Response(status, headers, body) renders the
    # matching status line for 405 — confirm against responses.Response.
    return Response(
        405,
        ["Content-Type: text/html"],
        "405 Method Not Allowed".encode(encoding="utf-8"),
    )
class App:
    """Minimal asyncio HTTP application: a route registry plus a connection handler.

    Routes are registered with the method decorators (``GET``, ``POST``, ...).
    Path templates may contain ``<name>`` placeholders, which are passed to
    the handler as URL-decoded keyword arguments.
    """

    _ALLOWED_METHODS = ("GET", "POST", "PUT", "DELETE", "PATCH", "HEAD")

    def __init__(self):
        # Compiled path pattern -> {HTTP method -> handler coroutine}.
        self.routes: dict[
            re.Pattern[str], dict[str, Callable[..., Awaitable[Response]]]
        ] = {}
        # Status code -> handler used when routing fails.
        self.error_routes: dict[int, Callable[..., Awaitable[Response]]] = {
            404: _default_404_route,
            405: _default_405_route,
        }
        self.CORS = CORS()

    def _pattern_to_regex(self, temp) -> re.Pattern[str]:
        """Compile a path template, turning each ``<name>`` into a named group."""

        def _to_group(match):
            return r"(?P<" + match[1] + r">[a-zA-Z0-9\-._~:%&=]+)"

        # A callable replacement inserts the group text verbatim instead of
        # running it through re.sub's replacement-string escape processing.
        return re.compile(PR.sub(_to_group, temp))

    def _serve(
        self,
        path: str,
        method: str,
        func: Callable[..., Awaitable[Response]],
    ):
        """Register *func* as the handler for *method* on *path*.

        Raises:
            RuntimeError: if *method* is not a supported HTTP method.
            RuntimeWarning: if *path* already has a handler for *method*.
        """
        if method not in self._ALLOWED_METHODS:
            raise RuntimeError(f'Invalid method "{method}".')
        by_method = self.routes.setdefault(self._pattern_to_regex(path), {})
        if method in by_method:
            raise RuntimeWarning(f'Path "{path}" already exists.')
        by_method[method] = func

    def GET(self, path: str):
        """Decorator to register a GET HTTP route (and a derived HEAD route)."""

        def decorator(func: Callable[..., Awaitable[Response]]):
            self._serve(path, "GET", func)

            async def wrapper(*args, **kwargs):
                res: Response = await func(*args, **kwargs)
                # NOTE(review): the flag is named `no_header`, while a HEAD
                # reply normally omits the body — confirm how
                # Response.render interprets it.
                res.no_header = True
                return res

            self._serve(path, "HEAD", wrapper)
            return func

        return decorator

    def POST(self, path: str):
        """Decorator to register a POST HTTP route."""

        def decorator(func: Callable[..., Awaitable[Response]]):
            self._serve(path, "POST", func)
            return func

        return decorator

    def PUT(self, path: str):
        """Decorator to register a PUT HTTP route."""

        def decorator(func: Callable[..., Awaitable[Response]]):
            self._serve(path, "PUT", func)
            return func

        # Without this return, `@app.PUT(...)` evaluates to None and applying
        # it as a decorator raises TypeError.
        return decorator

    def DELETE(self, path: str):
        """Decorator to register a DELETE HTTP route."""

        def decorator(func: Callable[..., Awaitable[Response]]):
            self._serve(path, "DELETE", func)
            return func

        return decorator

    def PATCH(self, path: str):
        """Decorator to register a PATCH HTTP route."""

        def decorator(func: Callable[..., Awaitable[Response]]):
            self._serve(path, "PATCH", func)
            return func

        return decorator

    def HEAD(self, path: str):
        """Decorator to register a HEAD HTTP route."""

        def decorator(func: Callable[..., Awaitable[Response]]):
            self._serve(path, "HEAD", func)
            return func

        return decorator

    def error(self, code):
        """Decorator to register an error route."""

        def decorator(func):
            self.error_routes[code] = func
            return func

        return decorator

    def resolve(self, path, method) -> tuple[Callable[..., Awaitable[Response]], dict]:
        """Find the handler for *path* and *method*.

        Returns:
            ``(handler, kwargs)``: *kwargs* holds the URL-decoded placeholder
            values. The first pattern that fully matches *path* decides the
            outcome: the 405 handler if it lacks *method*; the 404 handler is
            returned when no pattern matches.
        """
        for pattern, route in self.routes.items():
            if m := pattern.fullmatch(path):
                if method not in route:
                    return self.error_routes[405], {}
                return route[method], {
                    k: urllib.parse.unquote(v) for k, v in m.groupdict().items()
                }
        return self.error_routes[404], {}

    def methods(self, path) -> set[str]:
        """Return the methods registered on the first pattern matching *path*."""
        for pattern, route in self.routes.items():
            if pattern.fullmatch(path):
                return set(route)
        return set()

    async def handle_client(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ):
        """Handle one incoming connection: parse a request, dispatch, reply, close.

        Any exception raised while serving is printed and then re-raised; the
        connection is closed in every case.
        """
        try:
            # Request line: "<METHOD> <PATH> <PROTOCOL>".
            request_line = await reader.readline()
            if not request_line:
                return
            parts = request_line.decode(encoding="utf-8").strip().split()
            if len(parts) < 3:
                return
            method, path, protocol = parts[0], parts[1], parts[2]
            # Explicit check instead of `assert`, which is stripped under -O.
            if protocol != "HTTP/1.1":
                raise ValueError(f'Unsupported protocol "{protocol}".')

            headers: Headers = Headers()
            while True:
                line = await reader.readline()
                if line in (b"\r\n", b"\n", b""):  # End of headers
                    break
                text = line.decode("utf-8").strip()
                if ":" in text:
                    key, value = text.split(":", 1)
                    headers.set(key.strip(), value.strip())

            content_length = int(headers.get("Content-Length", 0))
            # read(n) may return fewer than n bytes; readexactly() waits for
            # the whole declared body.
            body = (
                await reader.readexactly(content_length)
                if content_length > 0
                else b""
            )

            origin = headers.get("origin")
            origin_allowed = origin is not None and origin in self.CORS.Origins

            if (
                method == "OPTIONS"
            ):  # FIXME Should handle responding with available methods as well
                if origin_allowed:
                    head = [
                        "Content-Type: text/plain",
                        "Content-Length: 0",
                        f"Access-Control-Allow-Origin: {origin}",
                        f"Access-Control-Allow-Methods: {','.join(sorted(self.CORS.Methods))}",
                        "Access-Control-Allow-Headers: Content-Type,Authorization",  # CORS
                        "Vary: Origin",
                    ]
                    writer.write(Response(200, head, b"").render(self))
                else:
                    writer.write(Response(403, ["Vary: Origin"], b"").render(self))
                await writer.drain()
                return

            route, kwargs = self.resolve(path, method)
            response = await route(
                request=Request(method=method, path=path, headers=headers, body=body),
                **kwargs,
            )
            if isinstance(response, (bytes, bytearray)):
                # A handler that returns an already-rendered reply (raw
                # bytes) is written as-is; it has no header list to extend.
                writer.write(response)
            else:
                response.headers.append("Vary: Origin")
                # CORS headers belong on the response when CORS is NOT
                # disabled; the original condition was inverted.
                if origin_allowed and not self.CORS.Disabled:
                    allowed = sorted(self.CORS.Methods & self.methods(path))
                    response.headers.append(f"Access-Control-Allow-Origin: {origin}")
                    response.headers.append(
                        f"Access-Control-Allow-Methods: {','.join(allowed)}"
                    )
                    # No trailing CRLF inside the header string, matching
                    # the OPTIONS branch above.
                    response.headers.append(
                        "Access-Control-Allow-Headers: Content-Type,Authorization"
                    )
                writer.write(response.render(self))
            await writer.drain()
        except asyncio.IncompleteReadError:
            # Client closed the connection before sending the declared body.
            return
        except Exception as e:
            # Log first, then propagate (the original `print` sat after the
            # `raise` and could never run).
            print(f"Internal Server Error: {e}")
            raise
        finally:
            writer.close()
            await writer.wait_closed()

    async def run(self, host="127.0.0.1", port=8000):
        """Start the async server."""
        server = await asyncio.start_server(self.handle_client, host, port)
        print(f"Serving on http://{host}:{port}")
        async with server:
            await server.serve_forever()
# Matches a `{{ name }}` template placeholder (whitespace inside the braces
# is optional); group 1 captures the variable name.
_value_pattern = re.compile(r"\{\{\s*([a-zA-Z_][a-zA-Z_0-9]*)\s*\}\}")
def render(
    file: str | Path, variables: dict[str, Any] | None = None
) -> Response:  # TODO Move to another module
    """Render an HTML template file, substituting ``{{ name }}`` placeholders.

    Args:
        file: Path of the UTF-8 template file.
        variables: Placeholder name -> value. Values are converted with
            ``str()``. Placeholders with no entry are left untouched.

    Returns:
        An ``HTTPResponse`` carrying the rendered text as ``text/html``.
    """
    values = {} if variables is None else variables
    template: str = Path(file).read_text(encoding="utf-8")

    def _substitute(match):
        name = match[1]
        if name not in values:
            return match[0]
        # NOTE: values are inserted verbatim (no HTML escaping); callers
        # must escape untrusted data themselves.
        return str(values[name])

    # One pass with a callable replacement: values are inserted literally
    # (backslashes are not treated as regex replacement escapes) and text
    # coming from a value is never scanned again for placeholders.
    content = _value_pattern.sub(_substitute, template)
    return HTTPResponse(content, content_type="text/html; charset=utf-8")
def redirect(location: str):  # TODO Move to another module
    """Build a 307 response whose ``Location`` header is *location*."""
    # The f-prefix was missing, so the literal text "{location}" was sent.
    return Response(307, [f"Location: {location}"], b"")
def JSONAPI(func):  # TODO Move to another module
    """Decorator that turns a handler's return value into a ``JSONResponse``.

    The wrapped coroutine may return either a ``dict`` (sent with the
    default status) or a ``(status: int, payload: dict)`` tuple.

    Raises:
        RuntimeError: (from the wrapper) if the handler returns anything else.
    """

    @functools.wraps(func)  # keep the handler's name/docstring on the wrapper
    async def wrapper(*args, **kwargs):
        result = await func(*args, **kwargs)
        if isinstance(result, dict):
            return JSONResponse(result)
        if (
            isinstance(result, tuple)
            and len(result) == 2
            and isinstance(result[0], int)
            and isinstance(result[1], dict)
        ):
            status, payload = result
            return JSONResponse(payload, status)
        raise RuntimeError("Return value of JSONAPI route is not a dictionary")

    return wrapper