2 changes: 1 addition & 1 deletion .github/workflows/aws-proxy.yml
@@ -33,6 +33,7 @@ jobs:
LOCALSTACK_AUTH_TOKEN: ${{ secrets.LOCALSTACK_AUTH_TOKEN }}
run: |
set -e
cd aws-proxy
docker pull localstack/localstack-pro &
docker pull public.ecr.aws/lambda/python:3.8 &

@@ -49,7 +50,6 @@ jobs:
# build and install extension
localstack extensions init
(
cd aws-proxy
make install
. .venv/bin/activate
pip install --upgrade --pre localstack localstack-ext
38 changes: 16 additions & 22 deletions aws-proxy/Makefile
@@ -5,45 +5,39 @@ VENV_RUN = . $(VENV_ACTIVATE)
TEST_PATH ?= tests
PIP_CMD ?= pip

usage: ## Show this help
usage: ## Show this help
@grep -Fh "##" $(MAKEFILE_LIST) | grep -Fv fgrep | sed -e 's/:.*##\s*/##/g' | awk -F'##' '{ printf "%-25s %s\n", $$1, $$2 }'

venv: $(VENV_ACTIVATE)

$(VENV_ACTIVATE): setup.py setup.cfg
install: ## Install dependencies
test -d .venv || $(VENV_BIN) .venv
$(VENV_RUN); pip install --upgrade pip setuptools plux wheel
$(VENV_RUN); pip install --upgrade black isort pyproject-flake8 flake8-black flake8-isort
$(VENV_RUN); pip install -e .
$(VENV_RUN); pip install -e .[test]
touch $(VENV_DIR)/bin/activate

clean:
clean: ## Clean up
rm -rf .venv/
rm -rf build/
rm -rf .eggs/
rm -rf *.egg-info/

lint:
$(VENV_RUN); python -m pflake8 --show-source

format:
$(VENV_RUN); python -m isort .; python -m black .
format: ## Run ruff to format the whole codebase
($(VENV_RUN); python -m ruff format .; python -m ruff check --output-format=full --fix .)

install: venv
$(VENV_RUN); $(PIP_CMD) install -e ".[test]"
lint: ## Run code linter to check code style
($(VENV_RUN); python -m ruff check --output-format=full . && python -m ruff format --check .)

test: venv
test: ## Run tests
$(VENV_RUN); python -m pytest $(PYTEST_ARGS) $(TEST_PATH)

dist: venv
$(VENV_RUN); python setup.py sdist bdist_wheel
entrypoints: ## Generate plugin entrypoints for Python package
$(VENV_RUN); python -m plux entrypoints

build: ## Build the extension
mkdir -p build
cp -r setup.py setup.cfg README.md aws_proxy build/
(cd build && python setup.py sdist)
build: entrypoints ## Build the extension
$(VENV_RUN); python -m build --no-isolation . --outdir build
@# make sure that the entrypoints are contained in the dist folder and are non-empty
@test -s localstack_extension_aws_proxy.egg-info/entry_points.txt || (echo "Entrypoints were not correctly created! Aborting!" && exit 1)

enable: $(wildcard ./build/dist/localstack_extension_aws_proxy-*.tar.gz) ## Enable the extension in LocalStack
enable: $(wildcard ./build/localstack_extension_aws_proxy-*.tar.gz) ## Enable the extension in LocalStack
$(VENV_RUN); \
pip uninstall --yes localstack-extension-aws-proxy; \
localstack extensions -v install file://$?
63 changes: 48 additions & 15 deletions aws-proxy/aws_proxy/client/auth_proxy.py
@@ -15,7 +15,11 @@
from localstack import config as localstack_config
from localstack.aws.spec import load_service
from localstack.config import external_service_url
from localstack.constants import AWS_REGION_US_EAST_1, DOCKER_IMAGE_NAME_PRO, LOCALHOST_HOSTNAME
from localstack.constants import (
AWS_REGION_US_EAST_1,
DOCKER_IMAGE_NAME_PRO,
LOCALHOST_HOSTNAME,
)
from localstack.http import Request
from localstack.pro.core.bootstrap.licensingv2 import (
ENV_LOCALSTACK_API_KEY,
@@ -25,7 +29,10 @@
from localstack.utils.bootstrap import setup_logging
from localstack.utils.collections import select_attributes
from localstack.utils.container_utils.container_client import PortMappings
from localstack.utils.docker_utils import DOCKER_CLIENT, reserve_available_container_port
from localstack.utils.docker_utils import (
DOCKER_CLIENT,
reserve_available_container_port,
)
from localstack.utils.files import new_tmp_file, save_file
from localstack.utils.functions import run_safe
from localstack.utils.net import get_docker_host_from_container, get_free_tcp_port
@@ -39,8 +46,6 @@
from aws_proxy.shared.constants import HEADER_HOST_ORIGINAL
from aws_proxy.shared.models import AddProxyRequest, ProxyConfig

from .http2_server import run_server

LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)
if localstack_config.DEBUG:
@@ -66,9 +71,14 @@ def __init__(self, config: ProxyConfig, port: int = None):
super().__init__(port=port)

def do_run(self):
# note: keep import here, to avoid runtime errors
from .http2_server import run_server

self.register_in_instance()
bind_host = self.config.get("bind_host") or DEFAULT_BIND_HOST
proxy = run_server(port=self.port, bind_addresses=[bind_host], handler=self.proxy_request)
proxy = run_server(
port=self.port, bind_addresses=[bind_host], handler=self.proxy_request
)
proxy.join()

def proxy_request(self, request: Request, data: bytes) -> Response:
@@ -109,7 +119,9 @@ def proxy_request(self, request: Request, data: bytes) -> Response:
# adjust request dict and fix certain edge cases in the request
self._adjust_request_dict(service_name, request_dict)

headers_truncated = {k: truncate(to_str(v)) for k, v in dict(aws_request.headers).items()}
headers_truncated = {
k: truncate(to_str(v)) for k, v in dict(aws_request.headers).items()
}
LOG.debug(
"Sending request for service %s to AWS: %s %s - %s - %s",
service_name,
@@ -138,7 +150,9 @@ def proxy_request(self, request: Request, data: bytes) -> Response:
return response
except Exception as e:
if LOG.isEnabledFor(logging.DEBUG):
LOG.exception("Error when making request to AWS service %s: %s", service_name, e)
LOG.exception(
"Error when making request to AWS service %s: %s", service_name, e
)
return requests_response("", status_code=400)

def register_in_instance(self):
@@ -224,7 +238,10 @@ def _adjust_request_dict(self, service_name: str, request_dict: Dict):
body_str = run_safe(lambda: to_str(req_body)) or ""

# TODO: this custom fix should not be required - investigate and remove!
if "<CreateBucketConfiguration" in body_str and "LocationConstraint" not in body_str:
if (
"<CreateBucketConfiguration" in body_str
and "LocationConstraint" not in body_str
):
region = request_dict["context"]["client_region"]
if region == AWS_REGION_US_EAST_1:
request_dict["body"] = ""
@@ -238,15 +255,19 @@ def _adjust_request_dict(self, service_name: str, request_dict: Dict):
account_id = self._query_account_id_from_aws()
if "QueueUrl" in req_body:
queue_name = req_body["QueueUrl"].split("/")[-1]
req_body["QueueUrl"] = f"https://queue.amazonaws.com/{account_id}/{queue_name}"
req_body["QueueUrl"] = (
f"https://queue.amazonaws.com/{account_id}/{queue_name}"
)
if "QueueOwnerAWSAccountId" in req_body:
req_body["QueueOwnerAWSAccountId"] = account_id
if service_name == "sqs" and request_dict.get("url"):
req_json = run_safe(lambda: json.loads(body_str)) or {}
account_id = self._query_account_id_from_aws()
queue_name = req_json.get("QueueName")
if account_id and queue_name:
request_dict["url"] = f"https://queue.amazonaws.com/{account_id}/{queue_name}"
request_dict["url"] = (
f"https://queue.amazonaws.com/{account_id}/{queue_name}"
)
req_json["QueueOwnerAWSAccountId"] = account_id
request_dict["body"] = to_bytes(json.dumps(req_json))

@@ -256,7 +277,9 @@ def _fix_headers(self, request: Request, service_name: str):
host = request.headers.get("Host") or ""
regex = r"^(https?://)?([0-9.]+|localhost)(:[0-9]+)?"
if re.match(regex, host):
request.headers["Host"] = re.sub(regex, rf"\1s3.{LOCALHOST_HOSTNAME}", host)
request.headers["Host"] = re.sub(
regex, rf"\1s3.{LOCALHOST_HOSTNAME}", host
)
request.headers.pop("Content-Length", None)
request.headers.pop("x-localstack-request-url", None)
request.headers.pop("X-Forwarded-For", None)
@@ -311,7 +334,9 @@ def start_aws_auth_proxy_in_container(
# should consider building pre-baked images for the extension in the future. Also,
# the new packaged CLI binary can help us gain more stability over time...

logging.getLogger("localstack.utils.container_utils.docker_cmd_client").setLevel(logging.INFO)
logging.getLogger("localstack.utils.container_utils.docker_cmd_client").setLevel(
logging.INFO
)
logging.getLogger("localstack.utils.docker_utils").setLevel(logging.INFO)
logging.getLogger("localstack.utils.run").setLevel(logging.INFO)

@@ -328,13 +353,18 @@ def start_aws_auth_proxy_in_container(
image_name = DOCKER_IMAGE_NAME_PRO
# add host mapping for localstack.cloud to localhost to prevent the health check from failing
additional_flags = (
repl_config.PROXY_DOCKER_FLAGS + " --add-host=localhost.localstack.cloud:host-gateway"
repl_config.PROXY_DOCKER_FLAGS
+ " --add-host=localhost.localstack.cloud:host-gateway"
)
DOCKER_CLIENT.create_container(
image_name,
name=container_name,
entrypoint="",
command=["bash", "-c", f"touch {CONTAINER_LOG_FILE}; tail -f {CONTAINER_LOG_FILE}"],
command=[
"bash",
"-c",
f"touch {CONTAINER_LOG_FILE}; tail -f {CONTAINER_LOG_FILE}",
],
ports=ports,
additional_flags=additional_flags,
)
@@ -388,7 +418,10 @@ def start_aws_auth_proxy_in_container(
command = f"{venv_activate}; localstack aws proxy -c {CONTAINER_CONFIG_FILE} -p {port} --host 0.0.0.0 > {CONTAINER_LOG_FILE} 2>&1"
if use_docker_sdk_command:
DOCKER_CLIENT.exec_in_container(
container_name, command=["bash", "-c", command], env_vars=env_vars, interactive=True
container_name,
command=["bash", "-c", command],
env_vars=env_vars,
interactive=True,
)
else:
env_vars_list = []
12 changes: 9 additions & 3 deletions aws-proxy/aws_proxy/client/http2_server.py
@@ -106,7 +106,9 @@ def _do_create():
def _encode_headers(headers):
if RETURN_CASE_SENSITIVE_HEADERS:
return [(key.encode(), value.encode()) for key, value in headers.items()]
return [(key.lower().encode(), value.encode()) for key, value in headers.items()]
return [
(key.lower().encode(), value.encode()) for key, value in headers.items()
]

quart_asgi._encode_headers = quart_asgi.encode_headers = _encode_headers
quart_app.encode_headers = quart_utils.encode_headers = _encode_headers
@@ -116,7 +118,9 @@ def build_and_validate_headers(headers):
for name, value in headers:
if name[0] == b":"[0]:
raise ValueError("Pseudo headers are not valid")
header_name = bytes(name) if RETURN_CASE_SENSITIVE_HEADERS else bytes(name).lower()
header_name = (
bytes(name) if RETURN_CASE_SENSITIVE_HEADERS else bytes(name).lower()
)
validated_headers.append((header_name.strip(), bytes(value).strip()))
return validated_headers

@@ -212,7 +216,9 @@ async def index(path=None):
response.headers.pop("Content-Length", None)
result.headers.pop("Server", None)
result.headers.pop("Date", None)
headers = {k: str(v).replace("\n", r"\n") for k, v in result.headers.items()}
headers = {
k: str(v).replace("\n", r"\n") for k, v in result.headers.items()
}
response.headers.update(headers)
# set multi-value headers
multi_value_headers = getattr(result, "multi_value_headers", {})
48 changes: 38 additions & 10 deletions aws-proxy/aws_proxy/server/aws_request_forwarder.py
@@ -32,7 +32,9 @@ class AwsProxyHandler(Handler):
# maps port numbers to proxy instances
PROXY_INSTANCES: Dict[int, ProxyInstance] = {}

def __call__(self, chain: HandlerChain, context: RequestContext, response: Response):
def __call__(
self, chain: HandlerChain, context: RequestContext, response: Response
):
proxy = self.select_proxy(context)
if not proxy:
return
@@ -63,7 +65,9 @@ def select_proxy(self, context: RequestContext) -> Optional[ProxyInstance]:
proxy = self.PROXY_INSTANCES[port]
proxy_config = proxy.get("config") or {}
services = proxy_config.get("services") or {}
service_name = self._get_canonical_service_name(context.service.service_name)
service_name = self._get_canonical_service_name(
context.service.service_name
)
service_config = services.get(service_name)
if not service_config:
continue
@@ -100,7 +104,9 @@ def _request_matches_resource(
self, context: RequestContext, resource_name_pattern: str
) -> bool:
try:
service_name = self._get_canonical_service_name(context.service.service_name)
service_name = self._get_canonical_service_name(
context.service.service_name
)
if service_name == "s3":
bucket_name = context.service_request.get("Bucket") or ""
s3_bucket_arn = arns.s3_bucket_arn(bucket_name)
@@ -113,7 +119,9 @@
queue_name,
queue_url,
sqs_queue_arn(
queue_name, account_id=context.account_id, region_name=context.region
queue_name,
account_id=context.account_id,
region_name=context.region,
),
)
for candidate in candidates:
@@ -133,12 +141,16 @@
) from e
return True

def forward_request(self, context: RequestContext, proxy: ProxyInstance) -> requests.Response:
def forward_request(
self, context: RequestContext, proxy: ProxyInstance
) -> requests.Response:
"""Forward the given request to the proxy instance, and return the response."""
port = proxy["port"]
request = context.request
target_host = get_addressable_container_host(default_local_hostname=LOCALHOST)
url = f"http://{target_host}:{port}{request.path}?{to_str(request.query_string)}"
url = (
f"http://{target_host}:{port}{request.path}?{to_str(request.query_string)}"
)

# inject Auth header, to ensure we're passing the right region to the proxy (e.g., for Cognito InitiateAuth)
self._extract_region_from_domain(context)
@@ -156,10 +168,20 @@ def forward_request(self, context: RequestContext, proxy: ProxyInstance) -> requests.Response:
data = request.form
elif request.data:
data = request.data
LOG.debug("Forward request: %s %s - %s - %s", request.method, url, dict(headers), data)
LOG.debug(
"Forward request: %s %s - %s - %s",
request.method,
url,
dict(headers),
data,
)
# construct response
result = requests.request(
method=request.method, url=url, data=data, headers=dict(headers), stream=True
method=request.method,
url=url,
data=data,
headers=dict(headers),
stream=True,
)
# TODO: ugly hack for now, simply attaching an additional attribute for raw response content
result.raw_content = result.raw.read()
@@ -173,7 +195,10 @@ def forward_request(self, context: RequestContext, proxy: ProxyInstance) -> requests.Response:
)
except requests.exceptions.ConnectionError:
# remove unreachable proxy
LOG.info("Removing unreachable AWS forward proxy due to connection issue: %s", url)
LOG.info(
"Removing unreachable AWS forward proxy due to connection issue: %s",
url,
)
self.PROXY_INSTANCES.pop(port, None)
return result

Expand All @@ -186,7 +211,10 @@ def _is_read_request(self, context: RequestContext) -> bool:
if operation_name.lower().startswith(("describe", "get", "list", "query")):
return True
# service-specific rules
if context.service.service_name == "cognito-idp" and operation_name == "InitiateAuth":
if (
context.service.service_name == "cognito-idp"
and operation_name == "InitiateAuth"
):
return True
if context.service.service_name == "dynamodb" and operation_name in {
"Scan",