Files
SlowAPI/slow/slow.py
2026-01-20 21:01:24 +03:30

335 lines
11 KiB
Python

import asyncio
import functools
import http.client
import json
import re
import urllib.parse
from json.decoder import JSONDecodeError
from pathlib import Path
from typing import Any, Awaitable, Callable, Optional
# Matches a ``<name>`` placeholder inside a route path template; group 1 is
# the placeholder name (a Python-identifier-like token).
PR = re.compile(r"\<([a-zA-Z_][a-zA-Z0-9_]*)\>")
class CORS:
    """Per-application CORS settings.

    Attributes:
        Origins: Origins allowed to make cross-origin requests. Membership is
            tested with ``in`` against the request's ``Origin`` header.
        Methods: Method names joined with ``,`` to form the
            ``Access-Control-Allow-Methods`` header. Kept as a list so the
            emitted order is deterministic.
        Disabled: When true, ``HTTPResponse`` adds no CORS headers.
    """

    # The attributes are created as lists in ``__init__``, so they are
    # annotated as lists (they were previously declared as ``set[str]``).
    Origins: list[str]
    Methods: list[str]
    Disabled: bool

    def __init__(self):
        self.Disabled = False
        self.Origins = []
        self.Methods = ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"]
class Headers:
def __init__(self):
self._d: dict[str, str] = {}
def get(self, key: str, default: Optional[Any] = None) -> str | Any:
return self._d.get(key.lower(), default)
def set(self, key: str, value: str) -> None:
self._d[key.lower()] = value
def __str__(self):
return str(self._d)
def __contains__(self, key):
return self._d.__contains__(key)
class Request:
    """An HTTP request as passed to route handlers.

    Attributes:
        method: HTTP method token taken from the request line.
        path: Request target taken from the request line.
        headers: Header collection for the request.
        body: Raw request body bytes.
        app: The application object that is serving the request.
    """

    def __init__(
        self, method: str, path: str, headers: "Headers", body: bytes, app: "App"
    ):
        self.method = method
        self.path = path
        self.headers = headers
        self.body = body
        self.app = app

    def __str__(self):
        return str(
            {
                "method": self.method,
                "path": self.path,
                "headers": self.headers,
                "body": self.body,
            }
        )

    def json(self) -> Any:
        """Parse the body as JSON.

        Returns:
            The decoded JSON value (not necessarily a dict), or ``None`` when
            the body is not valid JSON or cannot be decoded as text.
        """
        try:
            return json.loads(self.body)
        except (JSONDecodeError, UnicodeDecodeError):
            # json.loads decodes bytes itself; malformed byte sequences raise
            # UnicodeDecodeError rather than JSONDecodeError.
            return None
async def _default_404_route(request: Request):
    """Return the built-in raw 404 response; ``request`` is not inspected."""
    raw_response = "HTTP/1.1 404 Not Found\r\nContent-Type: text/html\r\n\r\n404 Not Found contact admin"
    return raw_response.encode(encoding="utf-8")
async def _default_405_route(request: Request):
    """Return the built-in raw 405 response; ``request`` is not inspected."""
    raw_response = "HTTP/1.1 405 Method Not Allowed\r\nContent-Type: text/html\r\n\r\n405 Method Not Allowed"
    return raw_response.encode(encoding="utf-8")
class App:
    """Small asyncio HTTP application: a route table plus a connection handler.

    Routes are registered with the method decorators (``GET``, ``POST``,
    ``PUT``, ``DELETE``, ``PATCH``, ``HEAD``) using path templates in which
    ``<name>`` marks a captured value. Handlers are coroutines that are called
    as ``handler(request=..., **captures)`` and must return the complete raw
    HTTP response as ``bytes``.

    Attributes:
        routes: Compiled path pattern -> {HTTP method -> handler}.
        error_routes: Status code -> handler. ``resolve`` uses the 404 and
            405 entries; they are called with ``request`` only.
        CORS: CORS settings used when answering ``OPTIONS`` requests.
    """

    # Methods accepted by ``_serve``; one decorator exists for each of them.
    _ROUTABLE_METHODS = ("GET", "POST", "PUT", "DELETE", "PATCH", "HEAD")

    def __init__(self):
        self.routes: dict[
            re.Pattern[str], dict[str, Callable[..., Awaitable[bytes]]]
        ] = {}
        self.error_routes: dict[int, Callable[..., Awaitable[bytes]]] = {
            404: _default_404_route,
            405: _default_405_route,
        }
        self.CORS = CORS()

    def _pattern_to_regex(self, temp) -> re.Pattern[str]:
        """Compile a path template into a regex with one named group per ``<name>``.

        Text outside the placeholders is passed to ``re.compile`` unchanged,
        so regex metacharacters in a template keep their regex meaning.
        """

        def to_named_group(match: re.Match[str]) -> str:
            return r"(?P<" + match[1] + r">[a-zA-Z0-9\-._~:%&=]+)"

        return re.compile(PR.sub(to_named_group, temp))

    def _serve(self, path: str, method: str, func: Callable[..., Awaitable[bytes]]):
        """Register ``func`` as the handler for ``method`` requests to ``path``.

        Raises:
            RuntimeError: ``method`` is not a routable HTTP method.
            RuntimeWarning: a handler for this method and path already exists.
        """
        if method not in self._ROUTABLE_METHODS:
            raise RuntimeError(f'Invalid method "{method}".')
        handlers = self.routes.setdefault(self._pattern_to_regex(path), {})
        if method in handlers:
            raise RuntimeWarning(f'Path "{path}" already exists.')
        handlers[method] = func

    def _register(self, path: str, method: str):
        """Build a decorator that registers its target and returns it unchanged."""

        def decorator(func: Callable[..., Awaitable[bytes]]):
            self._serve(path, method, func)
            return func

        return decorator

    def GET(self, path: str):
        """Decorator to register a GET HTTP route."""
        return self._register(path, "GET")

    def POST(self, path: str):
        """Decorator to register a POST HTTP route."""
        return self._register(path, "POST")

    def PUT(self, path: str):
        """Decorator to register a PUT HTTP route."""
        return self._register(path, "PUT")

    def DELETE(self, path: str):
        """Decorator to register a DELETE HTTP route."""
        return self._register(path, "DELETE")

    def PATCH(self, path: str):
        """Decorator to register a PATCH HTTP route."""
        return self._register(path, "PATCH")

    def HEAD(self, path: str):
        """Decorator to register a HEAD HTTP route."""
        return self._register(path, "HEAD")

    def error(self, code):
        """Decorator to register an error route."""

        def decorator(func):
            self.error_routes[code] = func
            return func

        return decorator

    def resolve(self, path, method) -> tuple[Callable[..., Awaitable[bytes]], dict]:
        """Find the handler for ``path`` and ``method``.

        Patterns are tried in registration order and must match the whole
        path. Only the first matching pattern is considered: if it has no
        handler for ``method`` the 405 error route is returned.

        Returns:
            ``(handler, captures)`` where ``captures`` maps placeholder names
            to their percent-decoded values. Error routes get an empty dict.
        """
        for pattern, handlers in self.routes.items():
            match = pattern.fullmatch(path)
            if match is None:
                continue
            if method not in handlers:
                return self.error_routes[405], {}
            captures = {
                name: urllib.parse.unquote(value)
                for name, value in match.groupdict().items()
            }
            return handlers[method], captures
        return self.error_routes[404], {}

    @staticmethod
    def _status_only(status: int) -> bytes:
        """Build an empty-bodied response carrying just a status line."""
        reason = http.client.responses[status]
        return f"HTTP/1.1 {status} {reason}\r\nContent-Length: 0\r\n\r\n".encode(
            encoding="utf-8"
        )

    def _preflight_response(self, headers: "Headers") -> bytes:
        """Build the reply to an ``OPTIONS`` (CORS preflight) request.

        The request is allowed only when CORS is not disabled and the
        ``Origin`` header is one of the configured origins; anything else
        gets ``403 Forbidden``.
        """
        origin = headers.get("origin")
        if (
            not self.CORS.Disabled
            and origin is not None
            and origin in self.CORS.Origins
        ):
            lines = [
                "HTTP/1.1 200 OK",
                "Content-Type: text/plain",
                "Content-Length: 0",
                f"Access-Control-Allow-Origin: {origin}",
                f"Access-Control-Allow-Methods: {','.join(self.CORS.Methods)}",
                "Access-Control-Allow-Headers: Content-Type,Authorization",
                "Vary: Origin",
            ]
        else:
            lines = [
                "HTTP/1.1 403 Forbidden",
                "Content-Length: 0",
                "Vary: Origin",
            ]
        return ("\r\n".join(lines) + "\r\n\r\n").encode(encoding="utf-8")

    async def handle_client(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ):
        """Serve one HTTP/1.1 request on an accepted connection, then close it.

        Malformed request lines are dropped without a reply. Other problems
        are answered with a bare status response: 505 for a protocol other
        than HTTP/1.1, 400 for an unusable ``Content-Length``, and 500 when
        the handler (or anything else) raises.
        """
        try:
            request_line = await reader.readline()
            if not request_line:
                return

            parts = request_line.decode(encoding="utf-8").strip().split()
            if len(parts) < 3:
                return
            method, path, protocol = parts[0], parts[1], parts[2]
            if protocol != "HTTP/1.1":
                # An explicit check: an assert would vanish under ``python -O``.
                writer.write(self._status_only(505))
                await writer.drain()
                return

            headers: Headers = Headers()
            while True:
                line = await reader.readline()
                if line in (b"\r\n", b"\n", b""):  # End of headers
                    break
                key, colon, value = line.decode("utf-8").strip().partition(":")
                if colon:
                    headers.set(key.strip(), value.strip())

            try:
                content_length = int(headers.get("Content-Length", 0))
            except ValueError:
                content_length = -1
            if content_length < 0:
                writer.write(self._status_only(400))
                await writer.drain()
                return
            # read(n) may return fewer than n bytes; readexactly waits for the
            # whole body or raises IncompleteReadError if the peer hangs up.
            body = await reader.readexactly(content_length) if content_length else b""

            if method == "OPTIONS":
                response = self._preflight_response(headers)
            else:
                route, kwargs = self.resolve(path, method)
                response = await route(
                    request=Request(
                        method=method, path=path, headers=headers, body=body, app=self
                    ),
                    **kwargs,
                )
            writer.write(response)
            await writer.drain()
        except (ConnectionError, asyncio.IncompleteReadError):
            # The client went away mid-request; there is nobody to answer.
            pass
        except Exception as e:
            print(f"Internal Server Error: {e}")
            try:
                writer.write(self._status_only(500))
                await writer.drain()
            except OSError:
                pass  # Best effort only: the connection may already be gone.
        finally:
            writer.close()
            try:
                await writer.wait_closed()
            except OSError:
                pass  # A reset during shutdown must not escape the handler.

    async def run(self, host="127.0.0.1", port=8000):
        """Start the async server and serve until cancelled."""
        server = await asyncio.start_server(self.handle_client, host, port)
        print(f"Serving on http://{host}:{port}")
        async with server:
            await server.serve_forever()
def HTTPResponse(
request: Request,
content: str | bytes,
status=200,
content_type="text/plain; charset=utf-8",
headers=[],
) -> bytes:
content_bytes = content
if isinstance(content, str):
content_bytes = content.encode(encoding="utf-8")
head: str = f"HTTP/1.1 {status} {http.client.responses.get(status, 'Unkown Status Code')}\r\nContent-Type: {content_type}\r\nContent-Length: {len(content_bytes)}\r\n"
if (
"origin" in request.headers
and not request.app.CORS.Disabled
and request.headers.get("origin") in request.app.CORS.Origins
):
head += (
f"Access-Control-Allow-Origin: {request.headers.get('origin')}\r\n" # CORS
)
head += f"Access-Control-Allow-Methods: {','.join(request.app.CORS.Methods)}\r\n" # CORS
head += "Access-Control-Allow-Headers: Content-Type,Authorization\r\n" # CORS
head += "Vary: Origin\r\n"
head += "\r\n".join(headers) + ("\r\n" if len(headers) > 0 else "")
head += "\r\n"
return head.encode(encoding="utf-8") + content_bytes
# Matches a ``{{ name }}`` template placeholder; group 1 is the variable name.
_value_pattern = re.compile(r"\{\{\s*([a-zA-Z_][a-zA-Z_0-9]*)\s*\}\}")


def render(
    request: Request, file: str | Path, variables: dict[str, Any] | None = None
) -> bytes:
    """Render a template file and wrap it in a ``text/html`` response.

    Every ``{{ name }}`` placeholder whose name is a key of ``variables`` is
    replaced by ``str(variables[name])``; placeholders with unknown names are
    left in the output untouched.

    Args:
        request: The request being answered (forwarded to ``HTTPResponse``).
        file: Path of the UTF-8 template file.
        variables: Placeholder values, by name.

    Returns:
        The raw HTTP response bytes produced by ``HTTPResponse``.
    """
    values = variables if variables is not None else {}
    template = Path(file).read_text(encoding="utf-8")

    def fill(match: re.Match[str]) -> str:
        name = match[1]
        if name not in values:
            return match[0]
        # SECURITY: the value is inserted verbatim, with no HTML escaping;
        # callers must escape untrusted data (e.g. html.escape) themselves.
        return str(values[name])

    # A single pass with a callable replacement: values are taken literally
    # (backslashes are not treated as regex template escapes) and placeholders
    # that appear inside an inserted value are not expanded again.
    content = _value_pattern.sub(fill, template)
    return HTTPResponse(request, content, content_type="text/html; charset=utf-8")
def redirect(location: str):
    """Build a raw ``307 Temporary Redirect`` response pointing at ``location``.

    Args:
        location: Value for the ``Location`` header.

    Returns:
        The complete response as bytes, including the blank line that ends
        the header section.

    Raises:
        ValueError: ``location`` contains a CR or LF character, which would
            let the caller's input inject additional header lines.
    """
    if "\r" in location or "\n" in location:
        raise ValueError("redirect location must not contain CR or LF characters")
    return (
        "HTTP/1.1 307 Temporary Redirect\r\n"
        "Content-Length: 0\r\n"
        f"Location: {location}\r\n"
        "\r\n"
    ).encode(encoding="utf-8")
def JSONResponse(request: Request, d: dict, status=200) -> bytes:
    """Serialise ``d`` with ``json.dumps`` and wrap it in an HTTP response.

    NOTE(review): the emitted media type is ``text/json``; the registered
    type for JSON is ``application/json``. Left as is to preserve behaviour.
    """
    payload = json.dumps(d)
    media_type = "text/json; charset=utf-8"
    return HTTPResponse(request, payload, status=status, content_type=media_type)
def JSONAPI(func):
    """Decorator turning a route's return value into a JSON response.

    The wrapped coroutine may return either a ``dict`` (sent with status 200)
    or a ``(status, dict)`` tuple. The request is taken from the ``request``
    keyword argument, or from the first positional argument if the route was
    called positionally.

    Raises:
        RuntimeError: the route returned anything else.
    """

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        result = await func(*args, **kwargs)
        request = kwargs["request"] if "request" in kwargs else args[0]
        if isinstance(result, dict):
            return JSONResponse(request, result)
        if (
            isinstance(result, tuple)
            and len(result) == 2
            and isinstance(result[0], int)
            and isinstance(result[1], dict)
        ):
            status, payload = result
            return JSONResponse(request, payload, status)
        raise RuntimeError("Return value of JSONAPI route is not a dictionary")

    return wrapper